# Untitled
# unknown
# python
# 3 years ago
# 1.5 kB
# 12
# Indexable
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
class LitMNIST(pl.LightningModule):
    """Two-layer fully connected classifier producing 10 class logits.

    The network flattens each input sample to a vector of 28 * 28 values,
    applies a 128-unit hidden layer with ReLU, and maps the result to 10
    output logits. Training and testing both use cross-entropy loss.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(28 * 28, 128)
        self.layer2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return the logits for a batch of inputs.

        Args:
            x: Tensor whose first dimension is the batch; every sample must
                contain exactly 28 * 28 elements so it can be flattened
                into the first linear layer.

        Returns:
            Tensor of shape (batch, 10) with unnormalized class scores.
        """
        # Collapse everything except the batch dimension into one vector.
        x = x.view(x.size(0), -1)
        x = self.layer1(x)
        x = nn.functional.relu(x)
        x = self.layer2(x)
        return x

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one training batch.

        Args:
            batch: Pair ``(inputs, targets)`` as yielded by the dataloader.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            Scalar loss tensor used for the optimization step.
        """
        x, y = batch
        # Call the module itself rather than ``forward`` so that any
        # registered module hooks are executed as well.
        y_hat = self(x)
        loss = nn.functional.cross_entropy(y_hat, y)
        return loss

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch and log its loss and accuracy.

        ``Trainer.test`` requires the module to define this hook; without
        it the test run fails with a configuration error.

        Args:
            batch: Pair ``(inputs, targets)`` as yielded by the dataloader.
            batch_idx: Index of the batch within the test run (unused).

        Returns:
            Scalar cross-entropy loss tensor for the batch.
        """
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        accuracy = (logits.argmax(dim=1) == y).float().mean()
        # sync_dist=True reduces the metrics across processes when the
        # model is evaluated on more than one device.
        self.log_dict(
            {"test_loss": loss, "test_acc": accuracy},
            sync_dist=True,
        )
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (lr=0.001)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer
# Prepare data: convert images to tensors and normalize them with
# mean 0.1307 / std 0.3081 (the values conventionally used for MNIST).
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
mnist_train = MNIST('data', train=True, download=True, transform=transform)
mnist_test = MNIST('data', train=False, download=True, transform=transform)

# Shuffle the training data so batches differ between epochs; the test
# loader keeps its order because evaluation does not depend on it.
train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=64)

# Create an instance of LitMNIST and initialize the PyTorch Lightning
# Trainer with multiple GPUs.
model = LitMNIST()
trainer = pl.Trainer(accelerator="gpu", devices=2)

# Train the model using multiple GPUs.
trainer.fit(model, train_loader)

# Test the model on the test set. The keyword is ``dataloaders``; the
# older ``test_dataloaders`` name was deprecated and later removed. The
# model is passed explicitly so the weights just trained are evaluated.
# NOTE: Trainer.test requires the module to implement ``test_step``.
trainer.test(model, dataloaders=test_loader)