# Evaluation pass over `loader`: autograd is disabled for the whole loop.
with torch.no_grad():
    loop = tqdm(loader, ascii=True)
    # tqdm.set_postfix rebuilds the whole postfix on each call, so both
    # fields must be passed together; two separate calls would drop `epoch`.
    loop.set_postfix(epoch=epoch, loss='-')
    for batch_idx, (flair, t2, t1, mask) in enumerate(loop):
        # Move every input to DEVICE, then remove dim 0.
        # NOTE(review): assumes the loader yields a leading size-1 dim to
        # strip (squeeze is a no-op otherwise) — TODO confirm.
        flair, t2, t1, mask = (
            torch.squeeze(x.to(DEVICE), dim=0) for x in (flair, t2, t1, mask)
        )
        # Cast the target to int64 before it is passed to `criterion`.
        mask = mask.long()
        # forward pass (mixed precision via CUDA autocast)
        with torch.cuda.amp.autocast():
            # NOTE(review): 'prediciton' is misspelled; name kept because code
            # outside this view may reference it.
            prediciton = model(flair, t2, t1)
            loss = criterion(prediciton, mask)