GedankenNet
unknown
python
13 days ago
1.3 kB
2
Indexable
Never
def forward(self, x):
    """Run the network on ``x`` and return the output of ``conv_end2``.

    Pipeline:
      1. Concatenate ``x`` with the tensor returned by ``self.get_grid`` on dim 1.
      2. Stem: ``conv_begin_0`` -> ``prelu_begin`` -> ``conv_begin_1``.
      3. For each block ``i``: ``num_per_block[i]`` residual layers, each
         computing ``prelu_list[k](SConv2d_list[k](x) + w_list[k](x)) + prev``,
         then ``ssc_list[i]`` plus a skip from the previous block's output.
      4. ``lsc`` plus a long skip from the stem output.
      5. Head: ``conv_end1`` -> ``prelu_end`` -> ``conv_end2``.

    Args:
        x: Input tensor. It is concatenated with the grid along dim 1.
            NOTE(review): presumably (batch, channels, H, W) — confirm.

    Returns:
        The tensor produced by ``self.conv_end2``.
    """
    # Append the get_grid tensor to the input features along dim 1.
    grid = self.get_grid(x.shape, x.device)
    x = torch.cat((x, grid), dim=1)

    # Stem.
    x = self.conv_begin_0(x)
    x = self.prelu_begin(x)
    x = self.conv_begin_1(x)

    x_0 = x    # long skip: added back after self.lsc
    x_s = x_0  # block-level skip: added after each ssc_list[i]
    x_t = x_0  # layer-level skip: output of the previous layer

    # Index of the first entry in SConv2d_list / w_list / prelu_list that
    # belongs to the current block.
    pointer = 0
    for i in range(len(self.scales_per_block)):
        shared = self.share_block[i]
        n_layers = self.num_per_block[i]
        for k in range(n_layers):
            # A shared block reuses a single layer entry for every repetition;
            # an unshared block owns one entry per repetition.
            # NOTE(review): the original source was flattened onto one line,
            # so indentation was lost. This is the only reading under which
            # the share_block branches differ — confirm it matches how
            # __init__ fills SConv2d_list / w_list / prelu_list.
            idx = pointer if shared else pointer + k
            x = self.SConv2d_list[idx](x) + self.w_list[idx](x)
            x = self.prelu_list[idx](x)
            x = x + x_t
            x_t = x
        # A shared block consumes one entry; an unshared block consumes n_layers.
        pointer += 1 if shared else n_layers
        x = self.ssc_list[i](x)
        x = x + x_s
        x_s = x
        x_t = x_s

    x = self.lsc(x)
    x = x + x_0

    # Head.
    x = self.conv_end1(x)
    x = self.prelu_end(x)
    x = self.conv_end2(x)
    return x
Leave a Comment