import torch
import torch.nn as nn

# Conv, DWConv and Swish are assumed to come from the YOLOv5 codebase
# (models/common.py and utils/activations.py in older releases).
from models.common import Conv, DWConv
from utils.activations import Swish


class ConcatBiFPN(nn.Module):
    # Fuse a list of 2 or 3 same-shaped feature maps with learnable,
    # fast-normalized weights (BiFPN-style weighted fusion), then refine
    # the result with a depthwise-separable convolution.
    def __init__(self, c1, c2):
        super(ConcatBiFPN, self).__init__()
        # self.relu = nn.ReLU()  # the original BiFPN applies ReLU so the fusion weights stay non-negative
        self.w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)  # weights for 2 inputs
        self.w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)  # weights for 3 inputs
        self.epsilon = 0.0001  # avoids division by zero in the weight normalization
        # act expects a module instance; passing the bare Swish class would
        # silently fall through to nn.Identity() in YOLOv5's Conv.
        self.dw_conv = DWConv(c1, c1, 3, 1, act=Swish())  # 3x3 depthwise conv
        self.pw_conv = Conv(c1, c2, 1, 1, act=Swish())    # 1x1 pointwise conv to c2 channels
        self.bn = nn.BatchNorm2d(num_features=c2, momentum=0.01, eps=1e-3)
        self.swish = Swish()

    def forward(self, x):
        return self._forward(x)

    def _forward(self, x):
        if len(x) == 2:
            # w = self.relu(self.w1)
            w = self.w1
            weight = w / (torch.sum(w, dim=0) + self.epsilon)  # fast normalized fusion
            x = self.swish(weight[0] * x[0] + weight[1] * x[1])
        elif len(x) == 3:
            # w = self.relu(self.w2)
            w = self.w2
            weight = w / (torch.sum(w, dim=0) + self.epsilon)
            x = self.swish(weight[0] * x[0] + weight[1] * x[1] + weight[2] * x[2])
        else:
            raise ValueError(f'ConcatBiFPN expects 2 or 3 inputs, got {len(x)}')
        x = self.dw_conv(x)
        x = self.pw_conv(x)
        x = self.bn(x)
        return x
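
# A minimal smoke test, assuming Conv, DWConv and Swish are importable as
# above. The channel count and feature-map size are made-up values for
# illustration; the fused inputs must all share the same shape because the
# module sums them element-wise.
if __name__ == '__main__':
    m = ConcatBiFPN(c1=256, c2=256)
    feats = [torch.randn(1, 256, 40, 40) for _ in range(3)]
    out = m(feats)
    print(out.shape)  # torch.Size([1, 256, 40, 40])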