Untitled

 avatar
unknown
plain_text
a year ago
2.9 kB
5
Indexable
import time

def _decode_prediction(encoder_output):
    """Decode one encoder output to text; '^' stands in for an empty decode."""
    predicted = decode(encoder_output)
    return predicted if len(predicted) else "^"


def _decode_target(target_row):
    """Map one row of target indices to text, dropping '^' padding/blank chars."""
    chars = [num_to_char[idx] for idx in target_row]
    chars = [c for c in chars if c != "^"]
    return "".join(chars) if chars else "^"


def train_epoch(dataloader, criterion, epoch):
    """Run one training epoch, with periodic and final validation.

    Args:
        dataloader: iterable of (inputs, input_lengths, targets, target_lengths)
            batches.
        criterion: loss taking time-major outputs plus output/target lengths
            (CTC-style signature).
        epoch: current epoch number, used only for logging.

    Returns:
        (mean_train_loss, val_loss, mean_train_wer, val_wer), where the train
        means are averaged over the batches actually processed.

    Relies on module-level globals: ``model``, ``device``, ``optimizer``,
    ``scheduler``, ``decode``, ``num_to_char``, ``valid_dataloader``,
    ``eval_epoch``, ``torch``, ``jiwer``, ``gc``.
    """
    running_loss = 0.0
    running_wer = 0.0
    num_batches = 0  # batches actually accumulated (may stop early on NaN)
    print({"train/epoch": epoch})

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        inputs, input_lengths, targets, target_lengths = batch
        inputs, input_lengths = inputs.to(device), input_lengths.to(device)
        targets, target_lengths = targets.to(device), target_lengths.to(device)
        optimizer.zero_grad()

        outputs, output_lengths = model(inputs, input_lengths)

        # criterion expects time-major outputs (T, N, C), hence the permute.
        # NOTE(review): targets.float() is unusual for CTC-style losses (they
        # typically take integer targets) — kept as-is; confirm against the
        # criterion actually used.
        loss = criterion(
            outputs.permute(1, 0, 2).float(), targets.float(), output_lengths, target_lengths
        )

        # Abort the epoch on a NaN loss instead of poisoning the weights;
        # checked before metrics so the aborted batch is excluded everywhere.
        if torch.isnan(loss).item():
            break

        predict_sequences = [_decode_prediction(enc_out) for enc_out in outputs]
        label_sequences = [_decode_target(row) for row in targets.cpu().numpy()]

        batch_wer = torch.Tensor(
            [
                jiwer.wer(truth, hypot)
                for truth, hypot in zip(label_sequences, predict_sequences)
            ]
        ).mean().item()
        running_wer += batch_wer

        running_loss += loss.item()
        num_batches += 1

        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()

        # Per-batch progress log; scheduler may be None, so guard the lr read
        # (the original crashed here when no scheduler was supplied).
        last_lr = scheduler.get_last_lr()[0] if scheduler else None
        print(
            f"Epoch: {epoch} - train_loss: {running_loss / (batch_idx + 1)}, "
            f"train_wer: {running_wer / (batch_idx + 1)}, lr: {last_lr}"
        )

        # Periodic mid-epoch validation plus a few sample transcripts.
        if (batch_idx + 1) % 300 == 0:
            val_loss, val_wer = eval_epoch(
                valid_dataloader, criterion, epoch, "val", eval_all=False
            )
            gc.collect()
            # Bounded loop instead of the original bare `except: pass`.
            for i in range(min(32, len(predict_sequences), len(label_sequences))):
                print("=" * 25)
                print("Predict:", predict_sequences[i], "len:", len(predict_sequences[i]))
                print("Target:", label_sequences[i], "len:", len(label_sequences[i]))
                print("=" * 25)
            print(f"val_loss: {val_loss}, val_wer: {val_wer}")
            print()
            print()

        gc.collect()

    # Full validation pass at epoch end.
    val_loss, val_wer = eval_epoch(valid_dataloader, criterion, epoch, "val", True)
    print("VAL_LOSS:", val_loss, "VAL_WER:", val_wer)
    # Average over batches actually processed, so an early NaN abort does not
    # skew the means; fall back to the dataloader length if nothing ran.
    denom = num_batches or len(dataloader)
    return running_loss / denom, val_loss, running_wer / denom, val_wer