Untitled
unknown
python
3 years ago
1.0 kB
12
Indexable
Never
def train_fn(loader, model, optimizer, criterion, epoch, scaler, patience=5):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        loader: Iterable yielding ``(flair, mask)`` tensor pairs.
        model: Network invoked as ``model(flair)``.
        optimizer: Optimizer that updates the model's parameters.
        criterion: Loss callable invoked as ``criterion(prediction, mask)``.
        epoch: Epoch number; only displayed in the progress bar.
        scaler: Gradient scaler exposing ``scale``/``step``/``update`` for
            mixed-precision training, or ``None`` for full precision.
        patience: Unused by this function; kept so existing callers that
            pass it keep working.

    Returns:
        The sum of the batch losses divided by the number of batches
        processed, or ``0.0`` when the loader yields no batches.

    NOTE(review): this function does not call ``model.train()``; the
    caller is presumably responsible for the model's mode -- confirm.
    """
    progress = tqdm(loader, ascii=True)
    # set_postfix replaces the whole postfix on every call, so epoch and
    # loss must always be passed together or the epoch disappears.
    progress.set_postfix(epoch=epoch, loss='-')

    # Autocast without a scaler would back-propagate unscaled fp16
    # gradients (underflow risk), so mixed precision follows the scaler.
    use_amp = scaler is not None

    total_loss = 0.0
    num_batches = 0
    for flair, mask in progress:
        flair = flair.to(DEVICE)
        mask = mask.to(DEVICE)
        # Drops a leading axis of size 1 (no-op otherwise); presumably the
        # loader wraps pre-batched samples in a batch of one -- TODO confirm.
        flair = torch.squeeze(flair, dim=0)
        mask = torch.squeeze(mask, dim=0).long()

        # forward pass
        with torch.cuda.amp.autocast(enabled=use_amp):
            prediction = model(flair)
            loss = criterion(prediction, mask)

        # backward
        optimizer.zero_grad()
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # Read the scalar once: every .item() forces a host/device sync.
        batch_loss = loss.item()
        progress.set_postfix(epoch=epoch, loss=batch_loss)
        total_loss += batch_loss
        num_batches += 1

    if num_batches == 0:
        # Nothing was processed; avoid dividing by zero.
        return 0.0
    return total_loss / num_batches