Untitled
unknown
plain_text
3 years ago
811 B
2
Indexable
# Run the model over every batch from `loader` with gradient tracking
# disabled and compute the loss for each batch.
with torch.no_grad():
    loop = tqdm(loader, ascii=True)
    # set_postfix() rebuilds the whole postfix on every call, so the epoch
    # and the loss placeholder must be passed together; two separate calls
    # would leave only the last one visible on the progress bar.
    loop.set_postfix(epoch=epoch, loss='-')

    # NOTE(review): batch_idx is unused in the visible fragment; it is kept
    # in case the continuation of this loop relies on it.
    for batch_idx, (flair, t2, t1, mask) in enumerate(loop):
        # Move each tensor to the target device, then drop dimension 0 when
        # it has size 1 (torch.squeeze leaves other sizes untouched).
        # NOTE(review): presumably the loader adds a leading batch dimension
        # of 1 around already-batched data - confirm against the dataset.
        flair, t2, t1, mask = (
            torch.squeeze(tensor.to(DEVICE), dim=0)
            for tensor in (flair, t2, t1, mask)
        )
        # NOTE(review): the cast to int64 suggests `criterion` expects
        # integer class indices as the target - confirm.
        mask = mask.long()

        # forward pass (under CUDA automatic mixed precision)
        with torch.cuda.amp.autocast():
            prediction = model(flair, t2, t1)
            loss = criterion(prediction, mask)
Editor is loading...