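# Untitled
#
# mail@pastecode.io avatar
# unknown
# python
# 14 days ago
# 22 kB
# 7
# Indexable
# Never
# (pastecode.io page metadata captured with the source; kept as comments so the file parses as Python.)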
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import pandas as pd
import ast
from tqdm import tqdm
import numpy as np
import os
import matplotlib.pyplot as plt
import math


token_lists = [torch.tensor([50364,  7956,   766,   264,  1442,   294,   264,  8687,    13, 50514]), torch.tensor([50364,  3087,   472,    11,  1767,    13, 50464]), torch.tensor([50364,  1468,   380,   718,   385,  2870,  4153,   311,  8410,   365,
        48832,    13, 50589]), torch.tensor([50364, 16346,   264,  1442,    13, 50464]), torch.tensor([50364,  3013,   786,   341,  1243,   360,   286,   362,  5482,    30,
        50541]), torch.tensor([50364, 13638,  6782,    13, 50464]), torch.tensor([50364,  8928,   257, 13548,   337,   805,   277,     6,  9023,    13,
        50564]), torch.tensor([50364,  7956,   309,   281,   589,   484,    13, 50464]), torch.tensor([50364,  2555,  5623,   385,   365,   264,  1329,   295,  3931,  2737,
          294,   452,  1859,    13, 50564]), torch.tensor([50364,  4387,   264,  1808, 19764,    13, 50600]), torch.tensor([50364, 15074,   493,  3905,  2583,   965,   490, 24859,    13, 50464]), torch.tensor([50364,  5506,   428,  7367,  8182,  3324, 50464]), torch.tensor([50364,  9476,  1737,  3098,    12,    83,  8400,  3847, 12628,   337,
         7776,   412,  1614, 14395,    11,  2736, 10550,   295,   958,  1243,
         7776,   412,  1649,   335,    13, 50828]), torch.tensor([50364, 17180,  2654,  4601,   294,  1176,  9139,  3089,  2272,    23,
         4436,    13, 50564]), torch.tensor([50364,  6767,  2037, 11590,   293,  1439,   271,  1966, 11590,  8252,
          544,  2049,   813,   661, 11590,   570, 50602]), torch.tensor([50364,  2053,   337,   257,  6477,  5214,  1219,   497, 34494,   322,
          264,  7893,  2319,    13,    16,    13, 50664]), torch.tensor([50364,   708,   307,  8589,   337,   965,   294,   452,  2654,  1859,
           30, 50484]), torch.tensor([50364,   708,   366,   264, 10486,   295,  8632,  3899,  6246,   666,
          257, 42852,    30, 50662]), torch.tensor([50364,  1119,   309, 18441,   294, 21247,    30, 50564]), torch.tensor([50364,  4552,    72,    11,   652,   385,   512,  4982,    13, 50564]), torch.tensor([50364,  2692,   286,   658, 17203,    13, 50444]), torch.tensor([50364,   708,   366,   264,  2583, 12881,    30, 50464]), torch.tensor([50364,  3560,   257,  3440,   965,    13, 50464]), torch.tensor([50364,  3013,  2007,   264,  1318,  1487,   490,    30,   708,   307,
          264,  1315,   295,   264,  1318,    30, 50564]), torch.tensor([50364, 15060,   452,  1808,  5811,   281,   512, 13590,  2017,    13,
        50564]), torch.tensor([50364,  5115,   385,   439,   452, 14511,    13, 50464]), torch.tensor([50364,  1144,   286,   362,   604,   777, 12524,    30, 50464]), torch.tensor([50364,  2205,    13, 50441]), torch.tensor([50364,   708,   307,   341,  2280,   307,  1219,   597,  2737,   294,
          452,  1859,    30, 50740]), torch.tensor([50364,   286,  2103,  1318,    13, 50464]), torch.tensor([50364,  5506,   452, 15066,  5924,    11, 23795,    13, 50464]), torch.tensor([50364,   708,   307,   264, 14330,    30, 50464]), torch.tensor([50364,  6895,   385,  2583,   466,  3899,    13, 50456]), torch.tensor([50364,   708,   307,   264,  4669,  4238,   295,  6309,    30, 50504]), torch.tensor([50364,  7497,   291,  1767,  4159,   452,  2920,  4365,   490, 30512,
           30, 50514]), torch.tensor([50364,  5010,    13, 50400]), torch.tensor([50364,   509,  1491,   300,   286,   362,   257,  3440,   365,  5041,
          412,   805,  4153,    13, 50508]), torch.tensor([50364, 22595,    11,   437,   366, 34808,  4127,  7901,    30, 50544]), torch.tensor([50364, 35089, 15920,   385,   295,   452,  1225,   311,  6154,   257,
         1243,   949,    13, 50528]), torch.tensor([50364,  1711,  1614,   257,    13,    76,    13,  4153,  2446,    11,
         4875,   264, 14183,    13, 50539]), torch.tensor([50364,  2555,   980,   385,   577,   938,   360,   291,   519,  3899,
          486,  1036,    30, 50564]), torch.tensor([50364,  1012,   264,  5503,   965,    13, 50564]), torch.tensor([50364, 49452,   264,   958,  2280,   322,   452, 12183,   370,   300,
          286,   500,   380,   536,   309,   797,    13, 50594]), torch.tensor([50364,  6895,   385,   452,  6792,  2093,  3021,  5191,    13, 50514]), torch.tensor([50364,  3240,   385,   484,    11,   420,   291,   434,   406,   294,
          264,  1379,   551,  1562,    13, 50504]), torch.tensor([50364,   440,  7742,  3314,  1296,   264,  2546,  7241,   293, 12641,
         7241,    13, 50564]), torch.tensor([50364,   708,   311,   264,  1315,   295,   264,  2280,    30, 50464]), torch.tensor([50364,  5115,   385,   257,  7647,    13, 50464]), torch.tensor([50364,  1468,   380,   718,   796,  2870,  4153,   311,  3440,   365,
        48832,    13, 50564]), torch.tensor([50364, 13548,   337,  4153,   412,   568,   280,    13,    76,    13,
        50564]), torch.tensor([50364, 40546,   257,  5081,   365,  7938,   322, 10383,  6499,    13,
        50614]), torch.tensor([50364,   639,   307,  1203,  1411,   322,   452,  7605,    13, 50514]), torch.tensor([50364,  1012,   281,   483,   281,   591, 11272,   305,   949, 24040,
          538,  3847,    13, 50564]), torch.tensor([50364, 44926,  2271,   364, 21839,   294,   945,  2077,   281, 12262,
         1291,   322,  1025,   392,  7638,    13, 50564]), torch.tensor([50389,   294,   879,   347,   325,   341,  2446,    30, 50714]), torch.tensor([50364,  6454,   286,  1565,   364, 21925,  4153,    30, 50464]), torch.tensor([50364,   286,   643,   281,  2845,   257, 16972,   281, 19059,    13,
        50464]), torch.tensor([50364,  1664,  4933,   264,  3440,   958,  1243,    11, 10383,   412,
          568, 14395,    13, 50575]), torch.tensor([50364,  6895,   385,   428, 44223,    13, 50464]), torch.tensor([50389,   295,   376,  2641, 27924,    30, 50714]), torch.tensor([50364,   708,   311,   926,   430,  7509,    30, 50464]), torch.tensor([50364,  5115,   385,   257,  7647,    13, 50464]), torch.tensor([50364,  5506,   264,  3894,  2497,   496,   372,    13, 50514]), torch.tensor([50364,  5735, 10371, 27274,    13, 50464]), torch.tensor([50364, 34441,   294,  9525,    13, 50464]), torch.tensor([50364,   422, 13020,    11,   286,   478,  1940,   786,   766,  4153,
           13, 50514]), torch.tensor([50364,  2102,   307,  6726,   341,  2153,    30, 50464]), torch.tensor([50364,   286,  1116,   411,   257,  4982,   586,    13, 50464]), torch.tensor([50364,  5506,   787,   439,  1318,  4736,  1296,   264,  1064, 50482]), torch.tensor([50364,   639,   307,   257,   665, 10864,    11,  1767,   500,   380,
          747, 23491,   295,   264,  6613,    13, 50494]), torch.tensor([50364,   639,   307,   452,  1036,  3233,  1454, 50464]), torch.tensor([50364,   708,   311,   257,  1266,    12,   810, 14330,    30, 50504]), torch.tensor([50364,  2555,  1261,   766,   452, 15642,  6765,    13, 50464]), torch.tensor([50364,   708,   390,   264, 11500,  3931,    30, 50564]), torch.tensor([50364,   437, 20774,   307,   322, 39426,    13, 50464]), torch.tensor([50364,  2555,  4160,   385,   295,   452,  5027,  3440,   322, 10383,
          412,   805, 14395,    13, 50580]), torch.tensor([50364,   708,   311,   322,   264,  6477,    30, 50464]), torch.tensor([50364,  1119,   456,   604,  2280,   926,    30, 50464]), torch.tensor([50364,  1144,   286,   362,   281,   360,   746, 19419,    30, 50464]), torch.tensor([50364,   708,   307,   264,   881,  2190,  7742,  3314,   294,  3533,
           30, 50548]), torch.tensor([50364,  5349,   777,  2280,  4926, 12899,   281, 12183,    13, 50564]), torch.tensor([50364,   286,   486,  3651,  1577,  2060, 11781,   498,   309,   307,
          886,  7679,    88,   294,  5634,    13, 50664]), torch.tensor([50364, 14895,   287,   910,  1373,   294,   264,  6525,    13, 50514]), torch.tensor([50364,  2555,   718,   385,   458,   300, 14183,  4305,   337, 10017,
          311,  3440,    13, 50544]), torch.tensor([50364, 11211,   281,  3679,   766,   264,  5811, 50614]), torch.tensor([50364,   407,   452,   958,  1520,   307,   412,  1386,   277,     6,
         9023,    13, 50514]), torch.tensor([50364, 13637,  1668,   281, 24469,  3664,   346,   337,  1614, 14395,
           13, 50504]), torch.tensor([50364,  5506, 10294, 20159,   295,  5116, 10647, 50604]), torch.tensor([50364,  3696,   356,   281,   300,  3796,    13, 50464]), torch.tensor([50364,   708,   311,  6782,   337, 10425,  4662,    30, 50514]), torch.tensor([50364,   708,   307,   264,  6343,  4292,   337,   341,  1243,    30,
        50539]), torch.tensor([50364,  5115,   385,   567,   575,   264,   881,  1230,   322,   472,
         8664,   322,   264, 40351,  8840,   945,    13, 50744]), torch.tensor([50364,   821,   366,   867,  1021,  3931,   294,   341,  1618,    11,
          406,  1096,   300, 50566]), torch.tensor([50364,  7956,   760,   264, 21367,    13, 50464]), torch.tensor([50364,  6895,   385,  5162,  2583,   322,  1729,  3983,    13, 50514]), torch.tensor([50364,  5349,   364, 14183,   337,  4153,  2446,   412, 11849,   669,
           13, 50614]), torch.tensor([50364,  1318, 50514]), torch.tensor([50364, 27868, 22717,     0, 50464]), torch.tensor([50364,  5115,   385,  2190,  2590,   295,  6419,  2651,  9701,  4964,
        11507,    13, 50589]), torch.tensor([50364,  2555,   976,   385,   257,   732,    12, 18048,  9164,   949,
          958,  8803,   311,  3440,    13, 50564])]

one_hot_labels = [torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0.]), torch.tensor([0., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0., 0., 0.,
        1., 1., 1., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 1., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 1., 
0., 0., 1., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 1., 0., 1., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 1., 1., 0., 1., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 1., 0., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 1., 1., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 1., 1., 0.]), torch.tensor([0., 1., 1., 1., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 1., 1., 1., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]), torch.tensor([0., 0., 
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 1., 1., 1., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 0., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 1., 1., 1., 0., 0., 0.])]
class EntityDataset(Dataset):
    """Map-style dataset pairing each token sequence with its per-token labels.

    Items are returned as ``(data[idx], labels[idx])`` exactly as stored; no
    conversion or validation happens here. ``data`` and ``labels`` are expected
    to be index-aligned, and the dataset length is taken from ``data``.
    """

    def __init__(self, data, labels):
        # Both sequences are kept by reference, not copied.
        self.data = data
        self.labels = labels

    def __len__(self):
        size = len(self.data)
        return size

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target

def pad_collate(batch, *, padding_value=0, label_padding_value=0):
    """Collate ``(tokens, labels)`` pairs into right-padded batch tensors.

    Args:
        batch: Sequence of ``(tokens, labels)`` pairs, e.g. items from
            ``EntityDataset``. Each element is a tensor whose first dimension is
            the sequence length; lengths may differ within the batch.
        padding_value: Fill value for padded token positions (default ``0``).
        label_padding_value: Fill value for padded label positions (default
            ``0``). Pass e.g. ``-100`` so padded positions can be told apart
            from real ``0`` labels and masked out downstream.

    Returns:
        ``(x_padded, y_padded)``, each shaped ``(batch, longest_sequence, *)``.
    """
    tokens, targets = zip(*batch)
    x_padded = pad_sequence(tokens, batch_first=True, padding_value=padding_value)
    y_padded = pad_sequence(targets, batch_first=True, padding_value=label_padding_value)
    return x_padded, y_padded

def load_model(model_path, vocab_size, embed_dim, model_dim, num_heads, num_classes,
               *, map_location="cpu", **model_kwargs):
    """Rebuild a ``TinyTransformer`` and load saved weights into it.

    Args:
        model_path: Path to a file written by ``torch.save(model.state_dict(), ...)``.
        vocab_size, embed_dim, model_dim, num_heads, num_classes: Constructor
            arguments. They must match the ones used when the weights were
            saved, otherwise ``load_state_dict`` rejects the checkpoint.
        map_location: Device the checkpoint tensors are loaded onto. The model
            is constructed on the CPU here, so loading to ``"cpu"`` gives the
            same result while letting a GPU-saved checkpoint load on a
            CPU-only machine.
        **model_kwargs: Extra ``TinyTransformer`` keyword arguments (e.g.
            ``dropout_rate``, ``num_encoder_layers``) for checkpoints trained
            with non-default settings.

    Returns:
        The model with weights loaded, switched to evaluation mode.
    """
    model = TinyTransformer(vocab_size=vocab_size, embed_dim=embed_dim, model_dim=model_dim,
                            num_heads=num_heads, num_classes=num_classes, **model_kwargs)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(model_path, map_location=map_location)
    model.load_state_dict(state_dict)
    model.eval()
    return model

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to an embedded sequence, then apply dropout.

    The table is precomputed for ``max_len`` positions and registered as the
    buffer ``pe`` with shape ``(max_len, 1, d_model)``, so it is part of the
    module's ``state_dict``.

    Args:
        d_model: Size of the input's last (feature) dimension; may be odd.
        dropout: Dropout probability applied after the encodings are added.
        max_len: Longest sequence the table covers; longer inputs fail to broadcast.
        batch_first: If ``False`` (default, original behaviour) ``forward``
            expects ``(seq_len, batch, d_model)``; if ``True`` it expects
            ``(batch, seq_len, d_model)``.
    """

    def __init__(self, d_model, dropout=0.1, max_len=100, *, batch_first=False):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.batch_first = batch_first

        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        # 1 / 10000^(2i / d_model) for each even feature index 2i.
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # An odd d_model has one fewer odd slot than even slot, so trim div_term
        # to fit; for an even d_model this slice keeps every entry.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        if self.batch_first:
            # (seq, 1, d) -> (1, seq, d) so it broadcasts over the leading batch axis.
            x = x + self.pe[: x.size(1)].transpose(0, 1)
        else:
            # (seq, 1, d) broadcasts over the batch axis at position 1.
            x = x + self.pe[: x.size(0)]
        return self.dropout(x)

class TinyTransformer(nn.Module):
    """Small Transformer encoder that emits one logit vector per input token.

    Pipeline: token embedding scaled by sqrt(embed_dim) -> sinusoidal positional
    encoding -> projection to ``model_dim`` (identity when sizes match) ->
    ``num_encoder_layers`` ``nn.TransformerEncoderLayer``s with
    ``batch_first=True`` -> dropout -> linear head. No sigmoid/softmax is
    applied: outputs are raw logits (``train`` pairs them with
    ``BCEWithLogitsLoss``).
    """

    def __init__(self, vocab_size, embed_dim, model_dim, num_heads, num_classes, dropout_rate=0.2, num_encoder_layers=1):
        super(TinyTransformer, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_encoding = PositionalEncoding(embed_dim, dropout_rate)
        # Bridge embed_dim -> model_dim. nn.Identity has no parameters, so when
        # the sizes match the state_dict is unchanged and old checkpoints load.
        self.input_projection = (
            nn.Identity() if embed_dim == model_dim else nn.Linear(embed_dim, model_dim)
        )
        # NOTE: nn.TransformerEncoder deep-copies this layer, so self.encoder_layer
        # itself is never run in forward(); it stays registered because its
        # weights are part of previously saved state_dicts.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_encoder_layers)
        self.output_layer = nn.Linear(model_dim, num_classes)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, src, src_key_padding_mask=None):
        """Return per-token logits.

        Args:
            src: Integer token ids of shape ``(batch, seq_len)`` (the encoder is
                built with ``batch_first=True``).
            src_key_padding_mask: Optional bool tensor ``(batch, seq_len)`` where
                ``True`` marks padded positions the encoder should ignore.
                ``None`` (default) attends to every position.

        Returns:
            Logits of shape ``(batch, seq_len, num_classes)``.
        """
        embedded = self.embedding(src) * math.sqrt(self.embedding.embedding_dim)
        # PositionalEncoding indexes its table along dim 0, i.e. it expects
        # (seq_len, batch, dim). Feeding the batch-first tensor directly would
        # add one position's encoding to every token, so swap axes around it.
        embedded = self.positional_encoding(embedded.transpose(0, 1)).transpose(0, 1)
        embedded = self.input_projection(embedded)
        encoded = self.transformer_encoder(embedded, src_key_padding_mask=src_key_padding_mask)
        return self.output_layer(self.dropout(encoded))


def train(model, data_loader, epochs=1, *, lr=0.001, ignore_index=-100, plot=True):
    """Train ``model`` in place with per-token binary cross-entropy on logits.

    Args:
        model: Module mapping a token batch to logits whose trailing size-1
            class dimension is squeezed away (``num_classes = 1``).
        data_loader: Iterable of ``(inputs, labels)`` batches, e.g. the
            ``DataLoader`` built with ``pad_collate``.
        epochs: Number of passes over ``data_loader``.
        lr: Adam learning rate.
        ignore_index: Label value excluded from the loss, e.g. padding written
            by ``pad_collate(..., label_padding_value=-100)``. The labels in
            this file are 0/1, so with the default 0-padding nothing is excluded.
        plot: If ``True``, show a matplotlib plot of per-batch losses at the end.

    Returns:
        List of per-batch loss values in the order they were computed.
    """
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    visual_loss = []
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for inputs, labels in tqdm(data_loader, desc="Processing batches", leave=True):
            optimizer.zero_grad()
            logits = model(inputs).squeeze(-1).reshape(-1)
            targets = labels.reshape(-1)
            valid = targets != ignore_index
            if not valid.any():
                # Entire batch is padding: BCE over zero elements would be NaN.
                continue
            loss = criterion(logits[valid], targets[valid])
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1
            visual_loss.append(loss_value)
            if len(visual_loss) % 100 == 0:
                print(f"Batch {len(visual_loss)} Loss: {loss_value}")
        # Average over batches that contributed a loss; avoids ZeroDivisionError
        # on an empty loader.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss}")

    if plot:
        # Plot the training loss for each batch
        plt.figure(figsize=(10, 6))
        plt.plot(visual_loss, label='Batch Loss')
        plt.xlabel('input')
        plt.ylabel('Loss')
        plt.title('Training Loss Per input')
        plt.legend()
        plt.grid(True)
        plt.show()
    return visual_loss

# --- Dataset & loader --------------------------------------------------------
# Token-id sequences and their per-token 0/1 labels, defined above.
data, labels = token_lists, one_hot_labels

dataset = EntityDataset(data=data, labels=labels)
# One sequence per batch, so pad_collate never has to insert padding.
data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=pad_collate,
)

# --- Model hyperparameters ---------------------------------------------------
vocab_size = 51865  # embedding-table size; every token id must be below this
embed_dim = model_dim = 128
num_heads = 4  # model_dim must be divisible by the number of attention heads
num_classes = 1  # a single logit per token (binary tag)

def eval_model(model, data_loader, threshold=0.5, *, ignore_index=-100, verbose=True):
    """Compute per-token accuracy of a binary token classifier over ``data_loader``.

    The model is expected to return raw logits: ``TinyTransformer`` applies no
    sigmoid, and ``train`` uses ``BCEWithLogitsLoss``. Logits are therefore
    passed through a sigmoid before being compared with ``threshold``, which
    is a probability.

    Correct and total counts are accumulated over every evaluated token of
    every batch. Previously only the last batch was scored.

    A batch is skipped when none of its non-ignored labels is positive, or when
    the first sample's token and label lengths differ.

    Args:
        model: Module mapping a token batch to logits with a trailing size-1
            class dimension.
        data_loader: Iterable of ``(inputs, labels)`` batches.
        threshold: Probability above which a token is predicted positive.
        ignore_index: Label value (e.g. padding) excluded from scoring.
        verbose: If ``True``, print the scored labels and predictions per batch.

    Returns:
        ``{"Accuracy": float}``; the value is ``nan`` if no batch was evaluated.
    """
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0

    with torch.no_grad():  # No need to track gradients during evaluation
        for inputs, labels in data_loader:
            labels_flat = labels.reshape(-1)
            valid = labels_flat != ignore_index
            labels_valid = labels_flat[valid]
            if not labels_valid.any():
                continue
            if len(inputs[0]) != len(labels[0]):
                continue

            logits = model(inputs).squeeze(-1)
            # Logits -> probabilities before thresholding.
            predicted = (torch.sigmoid(logits) > threshold).float()
            # Apply the same mask to predictions so both sides stay aligned.
            predicted_valid = predicted.reshape(-1)[valid]
            if verbose:
                print(labels_valid, predicted_valid)

            correct += int((predicted_valid == labels_valid).sum().item())
            total += int(labels_valid.numel())

    accuracy = correct / total if total else float("nan")
    return {
        "Accuracy": accuracy
    }

# Where the trained weights live. Set NER_MODEL_PATH to use a different file;
# the fallback is the original author's machine-specific location.
model_path = os.environ.get(
    "NER_MODEL_PATH",
    '/Users/afsarabenazir/Downloads/speech_projects/whisper-timestamped-master/models/transformer-NER.pth',
)
if os.path.isfile(model_path):
    # Saved weights found: rebuild the model, load them and report accuracy.
    # NOTE(review): evaluation reuses `data_loader`, the same samples the model
    # is trained on below, so this is training-set accuracy, not a held-out score.
    model = load_model(model_path, vocab_size, embed_dim, model_dim, num_heads, num_classes)
    results = eval_model(model, data_loader)
    print(f"Accuracy: {results['Accuracy']:.4f}")
else:
    # No weights yet: train from scratch and save the state_dict.
    model = TinyTransformer(vocab_size=vocab_size, embed_dim=embed_dim, model_dim=model_dim, num_heads=num_heads, num_classes=num_classes)
    train(model, data_loader)
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)
    torch.save(model.state_dict(), model_path)



# Leave a Comment  (pastecode.io page footer; kept as a comment so the file parses as Python.)