Untitled

mail@pastecode.io avatar
unknown
python
15 days ago
1.2 kB
2
Indexable
Never
# Widths of the two MLPs (input width first) and the pooling spec applied
# between them; the spec is a '_'-separated list of pooling names.
mlp1_dim = [10, 32]
pooling = 'mean_std'
mlp2_dim = [64, 128]

# Grab a single sample and give it a leading batch axis of size 1.
sample = pixel_dataset[1901]
pixels_tmp = sample['pixels'].unsqueeze(0)
print("input", pixels_tmp.shape)

# mlp1: one `linlayer` per consecutive pair of widths in mlp1_dim,
# e.g. [10, 32] -> [linlayer(10, 32)].
layers = [
    linlayer(width_in, width_out)
    for width_in, width_out in zip(mlp1_dim[:-1], mlp1_dim[1:])
]
mlp1 = nn.Sequential(*layers)
print(mlp1)

# mlp2: for each consecutive pair of widths, a Linear followed by BatchNorm1d;
# every stage except the last also gets a ReLU, so the final output is not
# clipped at zero.
n_stages = len(mlp2_dim) - 1
layers = []
for stage, (width_in, width_out) in enumerate(zip(mlp2_dim[:-1], mlp2_dim[1:])):
    layers += [nn.Linear(width_in, width_out), nn.BatchNorm1d(width_out)]
    is_last_stage = stage == n_stages - 1
    if not is_last_stage:
        layers.append(nn.ReLU())
mlp2 = nn.Sequential(*layers)
print(mlp2)


import torch  # torch.cat is needed for the pooling step below

# Reductions over the last axis, selected by name from the `pooling` spec.
# NOTE(review): assumes mlp1 returns (N, channels, n_pixels) with pixels on the
# last axis, and that every pixel is valid (no mask is available in this
# script) — TODO confirm against linlayer and the dataset's padding.
pooling_methods = {
    'mean': lambda x: x.mean(dim=-1),
    'std': lambda x: x.std(dim=-1),
    'max': lambda x: x.max(dim=-1).values,
    'min': lambda x: x.min(dim=-1).values,
}

out = pixels_tmp
reshape_needed = out.dim() == 4
if reshape_needed:
    # Combine batch and temporal dimensions in case of sequential input so
    # mlp1 sees a 3-D tensor; the split is undone after mlp2.
    batch, temp = out.shape[:2]
    out = out.reshape(batch * temp, *out.shape[2:])

out = mlp1(out)

# Collapse the pooled axis. With pooling='mean_std' this concatenates
# mlp1_dim[-1] means and mlp1_dim[-1] stds (2 * 32 = 64 features), which is
# exactly mlp2_dim[0]. Skipping this step feeds the unpooled tensor to mlp2's
# first nn.Linear and fails with a shape mismatch.
out = torch.cat([pooling_methods[name](out) for name in pooling.split('_')], dim=1)

# NOTE: the code this was adapted from concatenated `extra` features here when
# `self.with_extra` was set; that path is not reproduced in this script.

print(out.shape)

# NOTE(review): mlp2 contains BatchNorm1d and runs in training mode here, so it
# needs more than one row (batch * temp > 1) — TODO confirm temp > 1.
out = mlp2(out)

if reshape_needed:
    # Restore the (batch, temp, features) layout.
    out = out.reshape(batch, temp, -1)
print(out)
Leave a Comment