Untitled

mail@pastecode.io avatar
unknown
plain_text
2 years ago
1.4 kB
1
Indexable
Never
class ConcatBiFPN(nn.Module):
    """Fuse two or three feature maps with learnable weights, then project them.

    Despite the "Concat" in the name, the inputs are NOT concatenated. They are
    combined as a normalised weighted sum: every learnable weight is divided by
    the sum of all weights plus ``epsilon`` and the inputs are added
    elementwise, so they must have shapes that can be added together.  The
    fused map then goes through ``swish`` -> ``dw_conv`` -> ``pw_conv`` ->
    ``bn``.

    Args:
        c1: Channel count handed to ``DWConv`` as both its input and output
            size, and to ``Conv`` as its input size.
        c2: Channel count handed to ``Conv`` as its output size and to
            ``nn.BatchNorm2d`` as ``num_features``.

    Attributes (learnable):
        w1: Two fusion weights, used when two inputs are given.
        w2: Three fusion weights, used when three inputs are given.
    """

    def __init__(self, c1, c2):
        super().__init__()
        # NOTE(review): a ReLU on the fusion weights was commented out in the
        # original code, so nothing keeps w1/w2 non-negative.  If they drift
        # so that their sum approaches -epsilon, the normalisation below
        # divides by ~0.  Left unchanged to keep trained checkpoints
        # numerically identical -- confirm whether the ReLU should return.
        self.w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
        # Added to the weight sum so the division is never by exactly zero
        # while the weights are at their initial value of one.
        self.epsilon = 0.0001
        self.dw_conv = DWConv(c1, c1, 3, 1, act=Swish)
        self.pw_conv = Conv(c1, c2, 1, 1, act=Swish)
        self.bn = nn.BatchNorm2d(num_features=c2, momentum=0.01, eps=1e-3)
        self.swish = Swish()

    def forward(self, x):
        """Run the fusion block; see ``_forward`` for the accepted inputs."""
        return self._forward(x)

    def _fuse(self, tensors, raw_weights):
        """Return ``swish`` of the normalised weighted sum of ``tensors``.

        ``raw_weights`` must hold one scalar weight per entry of ``tensors``.
        Terms are accumulated left to right, matching the original explicit
        ``w[0] * x[0] + w[1] * x[1] (+ w[2] * x[2])`` expression.
        """
        weight = raw_weights / (torch.sum(raw_weights, dim=0) + self.epsilon)
        fused = weight[0] * tensors[0]
        for idx in range(1, len(tensors)):
            fused = fused + weight[idx] * tensors[idx]
        return self.swish(fused)

    def _forward(self, x):
        """Fuse the inputs and apply the depthwise/pointwise/BN stack.

        Args:
            x: A sequence of two or three tensors to fuse.

        Returns:
            The output of ``bn(pw_conv(dw_conv(fused)))``.

        Raises:
            ValueError: If ``x`` is a list or tuple whose length is not 2 or 3.
                Previously such input was passed straight to ``dw_conv`` as a
                Python list and failed there with an unrelated error.
        """
        n_inputs = len(x)
        if n_inputs == 2:
            x = self._fuse(x, self.w1)
        elif n_inputs == 3:
            x = self._fuse(x, self.w2)
        elif isinstance(x, (list, tuple)):
            raise ValueError(
                f"ConcatBiFPN expects 2 or 3 input tensors, got {n_inputs}"
            )
        # NOTE(review): dispatch is on len(x) only, so a bare tensor whose
        # leading dimension is 2 or 3 is treated as a sequence of inputs, and
        # any other non-list input reaches the convolutions untouched.  This
        # mirrors the original behaviour -- confirm callers always pass lists.
        x = self.dw_conv(x)
        x = self.pw_conv(x)
        x = self.bn(x)
        return x