Untitled
unknown
plain_text
a year ago
2.9 kB
6
Indexable
import time


def train_epoch(dataloader, criterion, epoch):
    """Run one training epoch; log per-batch metrics and periodic validation.

    Parameters
    ----------
    dataloader : iterable yielding (inputs, input_lengths, targets, target_lengths)
        Batches of padded sequences with their true lengths.
    criterion : callable(log_probs, targets, output_lengths, target_lengths)
        Loss fed time-major activations; the (T, N, C) permute plus the two
        length arguments suggests a CTC-style loss — TODO confirm.
    epoch : int
        Current epoch number, used only for logging.

    Returns
    -------
    tuple
        (mean train loss, final validation loss, mean train WER,
         final validation WER).

    NOTE(review): relies on module-level globals defined elsewhere in this
    file: model, device, optimizer, scheduler, decode, num_to_char,
    valid_dataloader, eval_epoch, jiwer, torch, gc.
    """
    size = len(dataloader)
    running_loss = 0
    running_wer = 0
    print({"train/epoch": epoch})
    model.train()

    for batch_idx, batch in enumerate(dataloader):
        inputs, input_lengths, targets, target_lengths = batch
        inputs, input_lengths = inputs.to(device), input_lengths.to(device)
        targets, target_lengths = targets.to(device), target_lengths.to(device)

        optimizer.zero_grad()
        outputs, output_lengths = model(inputs, input_lengths)
        # Criterion expects time-major activations: (batch, T, C) -> (T, batch, C).
        loss = criterion(
            outputs.permute(1, 0, 2).float(),
            targets.float(),
            output_lengths,
            target_lengths,
        )

        # Decode each item in the batch; '^' appears to act as the blank /
        # padding symbol (TODO confirm) and stands in for an empty decode so
        # jiwer.wer never receives an empty hypothesis string.
        predict_sequences = []
        for encoder_output in outputs:
            predict = decode(encoder_output)
            if not predict:
                predict = "^"
            predict_sequences.append(predict)

        # Reconstruct reference transcripts, dropping the '^' padding symbol.
        label_sequences = []
        for target in targets.cpu().numpy():
            chars = [num_to_char[idx] for idx in target]
            chars = [c for c in chars if c != '^']
            if not chars:
                chars = "^"
            label_sequences.append("".join(chars))

        wer = torch.Tensor(
            [
                jiwer.wer(truth, hypot)
                for truth, hypot in zip(label_sequences, predict_sequences)
            ]
        )
        wer = torch.mean(wer).item()
        running_wer += wer

        # Abort the epoch on a diverged (NaN) loss before backward() can
        # poison the weights. (This batch's WER was already accumulated,
        # matching the original behavior.)
        if torch.isnan(loss).item():
            break

        running_loss += loss.item()
        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()

        # Fix: the original read scheduler.get_last_lr() unconditionally and
        # raised AttributeError when no scheduler was configured; fall back to
        # the optimizer's own learning rate in that case. (The original also
        # guarded this print with the vacuous `(batch_idx + 1) % 1 == 0`.)
        lr = scheduler.get_last_lr()[0] if scheduler else optimizer.param_groups[0]["lr"]
        print(f"Epoch: {epoch} - train_loss: {running_loss / (batch_idx + 1)}, train_wer: {running_wer / (batch_idx + 1)}, lr: {lr}")

        # Every 300 batches: quick mid-epoch validation plus sample transcripts.
        if (batch_idx + 1) % 300 == 0:
            val_loss, val_wer = eval_epoch(
                valid_dataloader, criterion, epoch, "val", eval_all=False
            )
            gc.collect()
            # Show up to 32 prediction/target pairs from the last batch.
            # Bounded explicitly instead of the original bare `except: pass`
            # looping past the end of the lists.
            for i in range(min(32, len(predict_sequences), len(label_sequences))):
                print("=" * 25)
                print("Predict:", predict_sequences[i], "len:", len(predict_sequences[i]))
                print("Target:", label_sequences[i], "len:", len(label_sequences[i]))
                print("=" * 25)
            print(f"val_loss: {val_loss}, val_wer: {val_wer}")
            print()
            print()

    gc.collect()
    # Full validation pass (eval_all=True) at the end of the epoch.
    val_loss, val_wer = eval_epoch(valid_dataloader, criterion, epoch, "val", True)
    print("VAL_LOSS:", val_loss,
          "VAL_WER:", val_wer)
    return running_loss / size, val_loss, running_wer / size, val_wer
Editor is loading...