GedankenNet

mail@pastecode.io avatar
unknown
python
13 days ago
1.3 kB
2
Indexable
Never
def forward(self, x):
    """Run the network on ``x`` and return the output of ``conv_end2``.

    Pipeline (every module attribute is defined elsewhere on the class):

    1. Concatenate ``self.get_grid(x.shape, x.device)`` onto ``x`` along
       dim 1. NOTE(review): this assumes a channels-first layout, e.g.
       (batch, channels, H, W), and that the grid holds extra (presumably
       coordinate) channels -- confirm against ``get_grid``.
    2. Stem: ``conv_begin_0`` -> ``prelu_begin`` -> ``conv_begin_1``.
    3. For each block ``i`` (``len(self.scales_per_block)`` of them), apply
       ``self.num_per_block[i]`` residual layers, each computing
       ``prelu(SConv2d(h) + w(h)) + h``. When ``self.share_block[i]`` is
       truthy, every layer in the block reuses the same module index and
       that index advances by exactly one afterwards (even if the block has
       zero layers). Otherwise each layer consumes its own consecutive index.
       The block ends with ``ssc_list[i](h)`` plus the block's input.
    4. Long skip: ``lsc(h)`` plus the stem output.
    5. Head: ``conv_end1`` -> ``prelu_end`` -> ``conv_end2``.

    Args:
        x: Input tensor; must be concatenable with the grid along dim 1.

    Returns:
        The tensor produced by ``conv_end2``.
    """
    coords = self.get_grid(x.shape, x.device)
    h = torch.cat((x, coords), dim=1)

    # Stem; its output feeds both the first block and the long skip.
    h = self.conv_begin_1(self.prelu_begin(self.conv_begin_0(h)))
    stem_out = h
    block_in = h

    # Index into SConv2d_list / w_list / prelu_list, shared across blocks.
    base = 0
    for blk in range(len(self.scales_per_block)):
        shared = self.share_block[blk]
        depth = self.num_per_block[blk]
        for step in range(depth):
            k = base if shared else base + step
            mixed = self.SConv2d_list[k](h) + self.w_list[k](h)
            h = self.prelu_list[k](mixed) + h  # residual around the layer
        # A shared block owns a single module slot regardless of its depth.
        base += 1 if shared else depth
        # Short skip: close the block with its own input.
        h = self.ssc_list[blk](h) + block_in
        block_in = h

    h = self.lsc(h) + stem_out
    return self.conv_end2(self.prelu_end(self.conv_end1(h)))
Leave a Comment