# Source paste metadata (scraped page header, not code):
#   Untitled
#   unknown
#   python
#   a year ago
#   1.5 kB
#   2
#   Indexable
#   Never
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split  # noqa: F401 (random_split currently unused)
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl


class LitMNIST(pl.LightningModule):
    """Two-layer fully connected classifier for 28x28 MNIST digits (10 classes).

    Args:
        hidden_dim: Width of the hidden layer (default 128, the original value).
        lr: Learning rate for the Adam optimizer (default 1e-3, the original value).
    """

    def __init__(self, hidden_dim: int = 128, lr: float = 1e-3):
        super().__init__()
        # Records hidden_dim and lr in self.hparams (and in checkpoints).
        self.save_hyperparameters()
        self.layer1 = nn.Linear(28 * 28, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, 10).

        Each sample is flattened, so it must contain exactly 28 * 28 elements
        (e.g. the (1, 28, 28) tensors produced by ``ToTensor`` on MNIST).
        """
        x = x.view(x.size(0), -1)
        x = nn.functional.relu(self.layer1(x))
        return self.layer2(x)

    def training_step(self, batch, batch_idx):
        """Compute, log, and return the cross-entropy loss for one batch."""
        x, y = batch
        # Call the module (not .forward) so any registered hooks run.
        loss = nn.functional.cross_entropy(self(x), y)
        self.log("train_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Log test loss and accuracy; required for ``Trainer.test`` to run."""
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        # sync_dist reduces the logged metrics across devices under DDP.
        self.log_dict({"test_loss": loss, "test_acc": acc}, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """Use Adam over all model parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)


def main():
    """Download MNIST, train on two GPUs, then evaluate on the test split."""
    # Commonly used MNIST pixel mean / std for input normalization.
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    mnist_train = MNIST("data", train=True, download=True, transform=transform)
    mnist_test = MNIST("data", train=False, download=True, transform=transform)

    # Shuffle the training set each epoch; under DDP Lightning replaces the
    # sampler with a DistributedSampler that keeps this shuffle setting.
    train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)
    test_loader = DataLoader(mnist_test, batch_size=64)

    # Create the model and a Trainer that uses two GPUs.
    model = LitMNIST()
    trainer = pl.Trainer(accelerator="gpu", devices=2)

    trainer.fit(model, train_loader)

    # Trainer.test takes `dataloaders=`; the old `test_dataloaders=` keyword
    # is no longer accepted. Pass the trained model explicitly.
    trainer.test(model, dataloaders=test_loader)


# Guard the entry point: multi-GPU strategies may re-import or re-run this
# script in worker processes, and importing the module should not train.
if __name__ == "__main__":
    main()