class ConcatBiFPN(nn.Module):
    """BiFPN-style weighted feature fusion followed by a separable conv.

    Despite the name, inputs are not concatenated. Two or three feature maps
    are combined by a weighted sum with learnable scalar weights, normalized
    by their sum ("fast normalized fusion", EfficientDet). The fused map then
    passes through Swish, a depthwise conv, a pointwise conv and a BatchNorm.

    Args:
        c1: Channel count of the inputs. The inputs are summed elementwise, so
            they must have matching (broadcastable) shapes.
        c2: Output channel count.
        nonneg_weights: Clamp the fusion weights with ReLU before normalizing,
            as in EfficientDet. This keeps the denominator sum(w) + eps from
            approaching zero when a weight is driven negative during training.
            Pass False to reproduce the previous unclamped behaviour (e.g. for
            checkpoints trained without the clamp).
    """

    def __init__(self, c1, c2, nonneg_weights=True):
        super().__init__()
        # One learnable scalar per input: w1 for 2-input fusion, w2 for 3-input.
        self.w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
        # Added to the weight sum so the normalization never divides by zero.
        self.epsilon = 0.0001
        self.nonneg_weights = nonneg_weights
        # NOTE(review): `act=Swish` passes the class, not an instance. Confirm
        # that DWConv/Conv accept a class here rather than silently falling
        # back to no activation.
        self.dw_conv = DWConv(c1, c1, 3, 1, act=Swish)
        self.pw_conv = Conv(c1, c2, 1, 1, act=Swish)
        # NOTE(review): if Conv already applies BatchNorm, this normalizes the
        # same features a second time. Confirm this is intended.
        self.bn = nn.BatchNorm2d(num_features=c2, momentum=0.01, eps=1e-3)
        self.swish = Swish()

    def forward(self, x):
        """Fuse the sequence of 2 or 3 feature maps in ``x``; return one tensor."""
        return self._forward(x)

    def _forward(self, x):
        # Pick the weight vector that matches the number of inputs.
        if len(x) == 2:
            fused = self._fuse(x, self.w1)
        elif len(x) == 3:
            fused = self._fuse(x, self.w2)
        else:
            raise ValueError(f"ConcatBiFPN expects 2 or 3 inputs, got {len(x)}")
        out = self.dw_conv(fused)
        out = self.pw_conv(out)
        return self.bn(out)

    def _fuse(self, inputs, w):
        """Return swish(sum_i weight_i * inputs[i]), where weight = w / (sum(w) + eps)."""
        if self.nonneg_weights:
            w = torch.relu(w)
        weight = w / (torch.sum(w, dim=0) + self.epsilon)
        return self.swish(sum(wi * xi for wi, xi in zip(weight, inputs)))