import time
def _batch_predictions(outputs):
    """Decode every per-sample model output in ``outputs`` into a hypothesis string.

    Uses the module-level ``decode`` helper. An empty decode is replaced by
    ``"^"`` (presumably so jiwer never sees an empty hypothesis — TODO confirm).
    """
    predictions = []
    for encoder_output in outputs:
        predict = decode(encoder_output)
        predictions.append("^" if len(predict) == 0 else predict)
    return predictions


def _batch_labels(targets):
    """Turn each row of target indices into its reference string via ``num_to_char``.

    Characters equal to ``"^"`` are dropped (presumably a padding/blank
    symbol — TODO confirm); a row that ends up empty becomes ``"^"``.
    """
    labels = []
    for row in targets.cpu().numpy():
        chars = [num_to_char[idx] for idx in row]
        kept = [c for c in chars if c != "^"]
        labels.append("".join(kept) if kept else "^")
    return labels


def _batch_mean_wer(references, hypotheses):
    """Return the mean jiwer WER over paired reference/hypothesis strings (NaN if none)."""
    scores = [jiwer.wer(ref, hyp) for ref, hyp in zip(references, hypotheses)]
    if not scores:
        return float("nan")
    return sum(scores) / len(scores)


def _print_prediction_samples(predictions, labels, limit):
    """Print up to ``limit`` prediction/target pairs with their lengths."""
    for predict, label in zip(predictions[:limit], labels[:limit]):
        print("=" * 25)
        print("Predict:", predict, "len:", len(predict))
        print("Target:", label, "len:", len(label))
        print("=" * 25)


def _current_learning_rate(sched, optim):
    """Return the current LR from ``sched`` if set, else from the optimizer's first param group."""
    if sched:
        return sched.get_last_lr()[0]
    return optim.param_groups[0]["lr"]


def train_epoch(dataloader, criterion, epoch, *, log_interval=1, eval_interval=300, num_samples=32):
    """Train the global ``model`` for one epoch over ``dataloader``, then validate.

    Relies on module-level globals: ``model``, ``optimizer``, ``scheduler``
    (may be falsy), ``device``, ``decode``, ``num_to_char``, ``jiwer``,
    ``eval_epoch`` and ``valid_dataloader``.

    Args:
        dataloader: iterable yielding
            ``(inputs, input_lengths, targets, target_lengths)`` batches.
        criterion: loss called as
            ``criterion(outputs.permute(1, 0, 2).float(), targets.float(),
            output_lengths, target_lengths)``.
        epoch: epoch number, used only for logging and passed to ``eval_epoch``.
        log_interval: print running train loss/WER/LR every this many batches
            (<= 0 disables).
        eval_interval: run a partial validation and dump sample predictions
            every this many batches (<= 0 disables).
        num_samples: maximum number of prediction/target pairs printed at each
            mid-epoch validation.

    Returns:
        ``(mean_train_loss, val_loss, mean_train_wer, val_wer)``. The train
        means are taken over the batches actually processed. The epoch stops
        early at the first NaN loss, and that batch is not counted.
    """
    running_loss = 0.0
    running_wer = 0.0
    num_batches = 0
    print({"train/epoch": epoch})
    model.train()

    for batch_idx, batch in enumerate(dataloader):
        inputs, input_lengths, targets, target_lengths = batch
        inputs, input_lengths = inputs.to(device), input_lengths.to(device)
        targets, target_lengths = targets.to(device), target_lengths.to(device)

        optimizer.zero_grad()
        outputs, output_lengths = model(inputs, input_lengths)
        # permute(1, 0, 2) presumably turns (N, T, C) into the (T, N, C)
        # layout CTC-style losses expect — TODO confirm against the model.
        loss = criterion(
            outputs.permute(1, 0, 2).float(), targets.float(), output_lengths, target_lengths
        )
        # Check before touching any running statistic so a diverged batch
        # does not leak into the epoch averages.
        if torch.isnan(loss).item():
            print(f"NaN loss at batch {batch_idx}; stopping epoch {epoch} early")
            break

        predict_sequences = _batch_predictions(outputs)
        label_sequences = _batch_labels(targets)
        running_wer += _batch_mean_wer(label_sequences, predict_sequences)
        running_loss += loss.item()
        num_batches += 1

        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()

        step = batch_idx + 1
        if log_interval > 0 and step % log_interval == 0:
            lr = _current_learning_rate(scheduler, optimizer)
            print(
                f"Epoch: {epoch} - train_loss: {running_loss / num_batches}, "
                f"train_wer: {running_wer / num_batches}, lr: {lr}"
            )

        if eval_interval > 0 and step % eval_interval == 0:
            val_loss, val_wer = eval_epoch(
                valid_dataloader, criterion, epoch, "val", eval_all=False
            )
            # eval_epoch may put the model in eval mode (not visible here —
            # TODO confirm); restore training mode before the next batch.
            model.train()
            gc.collect()
            _print_prediction_samples(predict_sequences, label_sequences, num_samples)
            print(f"val_loss: {val_loss}, val_wer: {val_wer}")
            print()
            print()

    gc.collect()
    # NOTE(review): the positional True presumably maps to eval_all — confirm.
    val_loss, val_wer = eval_epoch(valid_dataloader, criterion, epoch, "val", True)
    print("VAL_LOSS:", val_loss, "VAL_WER:", val_wer)
    denom = max(num_batches, 1)  # avoid ZeroDivisionError on an empty/immediately-NaN epoch
    return running_loss / denom, val_loss, running_wer / denom, val_wer