Untitled
unknown
plain_text
4 years ago
589 B
7
Indexable
# Training-step fragment: iterates the batches yielded by `loop` and computes
# the loss for each one under CUDA autocast. `loop`, `DEVICE`, `model` and
# `criterion` are defined outside this snippet.
for batch_idx, (flair, t2, t1, mask) in enumerate(loop):
    # Move the three input modalities and the target to DEVICE, then drop a
    # leading axis of size 1 (torch.squeeze(..., dim=0) is a no-op when that
    # axis is larger than 1). Presumably the loader adds an extra leading dim
    # of size 1 to each sample (batch_size=1?) — TODO confirm against the dataset.
    flair, t2, t1, mask = (
        torch.squeeze(tensor.to(DEVICE), dim=0)
        for tensor in (flair, t2, t1, mask)
    )
    # Cast the target to int64; presumably `criterion` expects integer class
    # indices (e.g. nn.CrossEntropyLoss) — verify against its definition.
    mask = mask.long()

    # forward pass (mixed precision via autocast)
    # NOTE(review): torch.cuda.amp.autocast() is deprecated in recent PyTorch
    # releases in favour of torch.amp.autocast("cuda"); switch once the
    # project's minimum torch version allows it.
    with torch.cuda.amp.autocast():
        prediction = model(flair, t2, t1)
        loss = criterion(prediction, mask)
    # NOTE(review): `loss` is not back-propagated in this snippet; the
    # backward pass / optimizer step presumably follows elsewhere — confirm.