Untitled
unknown
plain_text
3 years ago
2.8 kB
2
Indexable
class Swish(nn.Module): def forward(self, x): return x * torch.sigmoid(x) class ConcatBiFPN(nn.Module): # Concatenate a list of tensors along dimension def __init__(self, c1, c2): super(ConcatBiFPN, self).__init__() # self.relu = nn.ReLU() self.w1 = nn.Parameter(torch.ones(2, dtype = torch.float32), requires_grad = True) self.w2 = nn.Parameter(torch.ones(3, dtype = torch.float32), requires_grad = True) self.epsilon = 0.0001 self.conv = nn.Conv2d(c1, c2, kernel_size = 1, stride = 1, padding = 0) self.swish = Swish() def forward(self, x): outs = self._forward(x) return outs def _forward(self, x): if len(x) == 2: # w = self.relu(self.w1) w = self.w1 weight = w / (torch.sum(w, dim=0) + self.epsilon) # Connections for P6_0 and P7_0 to P6_1 respectively x = self.conv(self.swish(weight[0] * x[0] + weight[1] * x[1])) elif len(x) == 3: # w = self.relu(self.w2) w = self.w2 weight = w / (torch.sum(w, dim=0) + self.epsilon) x = self.conv(self.swish(weight[0] * x[0] + weight[1] * x[1] + weight[2] * x[2])) return x elif m is ConcatBiFPN: c2 = max([ch[x] for x in f]) # YOLOv5 🚀 by Ultralytics, GPL-3.0 license # Parameters nc: 80 # number of classes depth_multiple: 0.33 # model depth multiple width_multiple: 0.50 # layer channel multiple anchors: - [10,13, 16,30, 33,23] # P3/8 - [30,61, 62,45, 59,119] # P4/16 - [116,90, 156,198, 373,326] # P5/32 # YOLOv5 v6.0 backbone backbone: # [from, number, module, args] [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 [-1, 3, C3, [128]], [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 [-1, 6, C3, [256]], [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 [-1, 9, C3, [512]], [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 [-1, 3, C3, [1024]], [-1, 1, SPPFTR2, [1024, 5]], # 9 ] # YOLOv5 v6.0 head head: [[-1, 1, Conv, [512, 1, 1]], [-1, 1, nn.Upsample, [None, 2, 'nearest']], [[-1, 6], 1, Concat, [1]], # cat backbone P4 [-1, 3, C3, [512, False]], # 13 [-1, 1, Conv, [256, 1, 1]], [-1, 1, nn.Upsample, [None, 2, 'nearest']], [[-1, 4], 1, Concat, [1]], # cat backbone P3 [-1, 3, C3, [256, False]], # 17 (P3/8-small) [-1, 1, Conv, [512, 3, 2]], [[-1, 6, 13], 1, ConcatBiFPN, [256, 256]], # cat head P4 [-1, 3, C3, [512, False]], # 20 (P4/16-medium) [-1, 1, Conv, [512, 3, 2]], [[-1, 10], 1, Concat, [1]], # cat head P5 [-1, 3, C3, [1024, False]], # 23 (P5/32-large) [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) ]
Editor is loading...