def train_fn(loader, model, optimizer, criterion, epoch, scaler, patience=5):
    """Run one training epoch and return the mean per-batch loss.

    Each item yielded by ``loader`` is unpacked as a ``(flair, mask)`` pair.
    Both tensors are moved to ``DEVICE`` and squeezed along dim 0 before the
    forward pass, and the mask is cast to ``long`` before being passed to
    ``criterion``.

    Args:
        loader: Iterable of ``(flair, mask)`` tensor pairs.
        model: Module invoked as ``model(flair)``.
        optimizer: Optimizer whose gradients are zeroed and which is stepped
            once per batch.
        criterion: Loss function invoked as ``criterion(prediction, mask)``.
        epoch: Epoch index, displayed in the progress-bar postfix.
        scaler: Object exposing ``scale``/``step``/``update`` (e.g. a
            ``GradScaler``) for mixed-precision backward passes, or ``None``
            to call ``loss.backward()`` / ``optimizer.step()`` directly.
        patience: Not used by this function; kept for signature compatibility.

    Returns:
        The mean of the per-batch ``loss.item()`` values over all batches
        processed, or ``float('nan')`` if ``loader`` yielded no batches.
    """
    loop = tqdm(loader, ascii=True)
    # set_postfix replaces the entire postfix on each call, so epoch must be
    # passed every time or it disappears from the bar.
    loop.set_postfix(epoch=epoch, loss='-')
    epoch_loss = 0.0
    # Count batches ourselves rather than relying on len(loop), which fails
    # or returns None for loaders without a usable __len__.
    num_batches = 0
    for flair, mask in loop:
        flair = torch.squeeze(flair.to(DEVICE), dim=0)
        mask = torch.squeeze(mask.to(DEVICE), dim=0).long()

        # Forward pass under autocast.
        # NOTE(review): autocast is active even when scaler is None; with
        # float16 autocast and no gradient scaling, small gradients may
        # underflow -- confirm this is intended.
        with torch.cuda.amp.autocast():
            prediction = model(flair)
            loss = criterion(prediction, mask)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # Fetch the scalar once and reuse it for both the bar and the total.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        num_batches += 1
        loop.set_postfix(epoch=epoch, loss=batch_loss)

    if num_batches == 0:
        # An empty epoch has no meaningful loss; nan avoids it being mistaken
        # for a perfect (0.0) result by callers tracking the best loss.
        return float('nan')
    return epoch_loss / num_batches