def forward(self, x):
    """Run the network on ``x`` and return the output of ``conv_end2``.

    Pipeline, as implemented below:

    1. Concatenate ``self.get_grid(x.shape, x.device)`` onto ``x`` along
       dim 1. Then apply ``conv_begin_0`` -> ``prelu_begin`` ->
       ``conv_begin_1``. The result is kept as ``x_0``.
    2. There are ``len(self.scales_per_block)`` blocks. For block ``i``,
       run ``self.num_per_block[i]`` layers:
       - Each layer computes ``prelu(SConv2d(x) + w(x))``.
       - It then adds the previous layer's output as a residual. For the
         first layer of a block, that is the block's input.
       - Layer modules come from ``SConv2d_list`` / ``w_list`` /
         ``prelu_list``, using one running index shared across all blocks.
       After the layers, apply ``ssc_list[i]`` and add the block's input
       back as a block-level residual.
    3. Apply ``lsc``, add ``x_0`` as a global residual, then apply
       ``conv_end1`` -> ``prelu_end`` -> ``conv_end2``.

    Args:
        x: Input tensor. The grid is concatenated along dim 1, so this is
            presumably channels-first, e.g. (batch, channels, H, W).
            TODO confirm against ``get_grid``.

    Returns:
        The tensor produced by ``conv_end2``.
    """
    grid = self.get_grid(x.shape, x.device)
    x = torch.cat((x, grid), dim=1)
    x = self.conv_begin_0(x)
    x = self.prelu_begin(x)
    x = self.conv_begin_1(x)

    x_0 = x            # global residual, added back after `lsc`
    x_block_in = x_0   # input of the current block (block-level residual)
    x_prev = x_0       # output of the previous layer (layer-level residual)

    # NOTE(review): `self.share_block[i]` is intentionally not consulted.
    # Its True and False cases executed identical code. If blocks are
    # meant to share weights, that must be arranged when the module lists
    # are built (e.g. the same module object at several indices).
    # TODO confirm in __init__.
    pointer = 0  # running index into the per-layer module lists, across blocks
    for i in range(len(self.scales_per_block)):
        for _ in range(self.num_per_block[i]):
            x = self.SConv2d_list[pointer](x) + self.w_list[pointer](x)
            x = self.prelu_list[pointer](x)
            x = x + x_prev
            x_prev = x
            pointer += 1
        x = self.ssc_list[i](x) + x_block_in
        # This block's output is both the next block's input and the
        # residual for the next block's first layer.
        x_block_in = x
        x_prev = x

    x = self.lsc(x) + x_0
    x = self.conv_end1(x)
    x = self.prelu_end(x)
    return self.conv_end2(x)
Editor is loading...