Untitled

mail@pastecode.io avatar
unknown
python
3 years ago
1.0 kB
12
Indexable
Never
def train_fn(loader, model, optimizer, criterion, epoch, scaler, patience=5):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        loader: Iterable yielding ``(flair, mask)`` tensor pairs. Presumably a
            DataLoader with ``batch_size=1`` whose items already carry their
            own batch dimension, which is why dim 0 is squeezed below --
            TODO confirm against the dataset.
        model: Module mapping ``flair`` to predictions; put in train mode here.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss called as ``criterion(prediction, mask)`` with an
            integer (``long``) mask -- presumably a class-index loss such as
            ``CrossEntropyLoss``; confirm.
        epoch: Epoch number, shown in the progress-bar description.
        scaler: Optional AMP ``GradScaler``; when ``None`` a plain
            ``backward()`` / ``optimizer.step()`` is used instead.
        patience: Unused by this function; kept so existing callers that
            pass it keep working.

    Returns:
        Mean of the per-batch loss values, or ``float("nan")`` if ``loader``
        yielded no batches.
    """
    # Make sure dropout / batch-norm run in training mode even if an earlier
    # evaluation pass left the model in eval mode.
    model.train()

    loop = tqdm(loader, ascii=True)
    # set_postfix replaces the entire postfix on every call, so the epoch is
    # shown in the description where per-batch loss updates cannot erase it.
    loop.set_description(f"Epoch {epoch}")
    loop.set_postfix(loss='-')

    epoch_loss = 0.0
    num_batches = 0

    for flair, mask in loop:
        flair = flair.to(DEVICE)
        mask = mask.to(DEVICE)

        # Drop the leading dimension (no-op unless that dimension has size 1).
        flair = torch.squeeze(flair, dim=0)
        mask = torch.squeeze(mask, dim=0)

        mask = mask.long()

        # Forward pass under mixed precision.
        # NOTE(review): autocast stays enabled even when scaler is None;
        # fp16 gradients without loss scaling can underflow -- confirm this
        # combination is intended.
        with torch.cuda.amp.autocast():
            prediction = model(flair)
            loss = criterion(prediction, mask)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # .item() synchronizes with the device; do it once per batch.
        batch_loss = loss.item()
        loop.set_postfix(loss=batch_loss)

        epoch_loss += batch_loss
        num_batches += 1

    # Average over the batches actually seen. len(loop) is 0 for an empty
    # loader (ZeroDivisionError) and may be undefined for loaders that do
    # not implement __len__.
    if num_batches == 0:
        return float("nan")
    return epoch_loss / num_batches