Untitled

 avatar
unknown
plain_text
3 years ago
811 B
1
Indexable
# Evaluation pass over `loader` with autograd disabled for the whole loop.
# NOTE(review): `tqdm`, `loader`, `epoch`, `model`, `criterion` and `DEVICE`
# are not defined in this snippet; they must be supplied by the enclosing code.
with torch.no_grad():
    loop = tqdm(loader, ascii=True)
    # tqdm's set_postfix() replaces the entire postfix on each call, so both
    # fields must be set together -- two separate calls would drop `epoch`.
    loop.set_postfix(epoch=epoch, loss='-')

    for batch_idx, (flair, t2, t1, mask) in enumerate(loop):
        # Move each tensor to the target device, then remove the leading
        # dimension; torch.squeeze(..., dim=0) leaves the tensor unchanged
        # when that dimension is not of size 1.
        # NOTE(review): presumably the loader uses batch_size=1 and each item
        # is already a stack of samples/slices -- confirm against the Dataset.
        flair = torch.squeeze(flair.to(DEVICE), dim=0)
        t2 = torch.squeeze(t2.to(DEVICE), dim=0)
        t1 = torch.squeeze(t1.to(DEVICE), dim=0)
        mask = torch.squeeze(mask.to(DEVICE), dim=0)
        # Integer (int64) targets for the criterion -- presumably class
        # indices for a cross-entropy-style loss; verify against `criterion`.
        mask = mask.long()

        # Forward pass and loss under CUDA autocast (mixed precision).
        with torch.cuda.amp.autocast():
            prediction = model(flair, t2, t1)
            # NOTE(review): the 'loss' postfix placeholder set above is never
            # updated with this value within this snippet.
            loss = criterion(prediction, mask)