Untitled

mail@pastecode.io avatar
unknown
plain_text
7 months ago
516 B
3
Indexable
Never
# Visualize one test sample next to the activation maps that model1.layer1
# produces for it.
model1.eval()  # switch model1 to evaluation mode (nn.Module.eval)

index = 0  # which test sample to inspect
# Slice (rather than index) so the leading batch dimension of size 1 is kept.
data = X_test[index:index+1]

fig = plt.figure(figsize=(4, 4))
ax = fig.add_subplot(1, 1, 1)
# data[0][0] = first (only) sample, first channel. Assumes a channel-first
# (N, C, H, W) layout -- TODO confirm against how X_test was built.
ax.imshow(data[0][0], cmap='viridis')
# argmax assumes y_test rows are one-hot / per-class score vectors -- confirm.
ax.set_title(f'Original Label: {y_test[index].argmax()}')
ax.axis('off')

# as_tensor accepts NumPy arrays and existing tensors alike, avoiding the
# copy-construct warning torch.tensor() emits when given a tensor.
X = torch.as_tensor(data, dtype=torch.float32, device=device)
with torch.no_grad():
    # Run layer1 ONCE; previously it was re-run for every subplot.
    # Output presumably (N, C, H, W); [0] selects the single sample, leaving
    # one 2-D map per channel.
    feature_maps = model1.layer1(X)[0].cpu().detach().numpy()

N_ROWS, N_COLS = 4, 8
# Show at most one grid's worth of maps, and never index past the number of
# channels layer1 actually produced.
n_maps = min(N_ROWS * N_COLS, len(feature_maps))

fig = plt.figure(figsize=(10, 5))
for i in range(n_maps):
    ax = fig.add_subplot(N_ROWS, N_COLS, i + 1)
    ax.imshow(feature_maps[i], cmap='viridis')
    ax.axis('off')