Untitled

mail@pastecode.io avatar
unknown
plain_text
2 years ago
589 B
1
Indexable
Never
for batch_idx, (flair, t2, t1, mask) in enumerate(loop):
        # One loop iteration: each item yielded by `loop` is unpacked into three
        # input tensors (flair, t2, t1) and a target `mask`, which are fed to
        # `model` and scored with `criterion`. `loop` is presumably a
        # (tqdm-wrapped) DataLoader — confirm against the caller.
        # NOTE(review): nothing in this snippet calls loss.backward() or steps an
        # optimizer; presumably that follows outside this paste.

        # Move each tensor to DEVICE, then drop its leading dimension.
        # NOTE(review): squeeze(dim=0) is a no-op unless that dimension has size
        # 1, so this presumably relies on batch_size=1 with each sample already
        # holding a stack of slices — confirm against the DataLoader.
        flair, t2, t1, mask = (
            torch.squeeze(tensor.to(DEVICE), dim=0)
            for tensor in (flair, t2, t1, mask)
        )

        # Cast the target to int64; presumably the criterion expects integer
        # class indices (e.g. CrossEntropyLoss) — confirm.
        mask = mask.long()

        # Forward pass under CUDA mixed-precision autocast.
        # torch.autocast(device_type="cuda") is the non-deprecated equivalent of
        # torch.cuda.amp.autocast(), which emits a FutureWarning on PyTorch >= 2.4.
        with torch.autocast(device_type="cuda"):
            # Variable name (sic) kept unchanged in case later code reads it.
            prediciton = model(flair, t2, t1)

            loss = criterion(prediciton, mask)