import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl

class LitMNIST(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(28 * 28, 128)
        self.layer2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.layer1(x)
        x = nn.functional.relu(x)
        x = self.layer2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.cross_entropy(y_hat, y)
        return loss
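
    def test_step(self, batch, batch_idx):
        # Added so that trainer.test() at the bottom has a step to run; a minimal
        # sketch that reports cross-entropy loss and accuracy on each test batch.
        # The metric names "test_loss"/"test_acc" are just illustrative log keys.
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log("test_loss", loss)
        self.log("test_acc", acc)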

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

# prepare data
# normalize with the standard MNIST mean/std
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_train = MNIST('data', train=True, download=True, transform=transform)
mnist_test = MNIST('data', train=False, download=True, transform=transform)

train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=64)

# create an instance of LitMNIST and a PyTorch Lightning Trainer that trains
# across two GPUs (assumes two GPUs are visible on this machine)
model = LitMNIST()
trainer = pl.Trainer(accelerator="gpu", devices=2)

# train the model using multiple GPUs
trainer.fit(model, train_loader)

# evaluate the fitted model on the test set
trainer.test(model, dataloaders=test_loader)
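
# Usage note (a sketch, assuming this file is saved as a standalone script such as
# train_mnist.py and that two GPUs are available):
#   python train_mnist.py
# With the default multi-GPU strategy in recent Lightning releases (DDP), Lightning
# runs one process per GPU, each executing this script, so the module-level setup
# above is repeated in every process.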