import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import pandas as pd
import ast
from tqdm import tqdm
import numpy as np
import os
import matplotlib.pyplot as plt
import math
token_lists = [torch.tensor([50364, 7956, 766, 264, 1442, 294, 264, 8687, 13, 50514]), torch.tensor([50364, 3087, 472, 11, 1767, 13, 50464]), torch.tensor([50364, 1468, 380, 718, 385, 2870, 4153, 311, 8410, 365,
48832, 13, 50589]), torch.tensor([50364, 16346, 264, 1442, 13, 50464]), torch.tensor([50364, 3013, 786, 341, 1243, 360, 286, 362, 5482, 30,
50541]), torch.tensor([50364, 13638, 6782, 13, 50464]), torch.tensor([50364, 8928, 257, 13548, 337, 805, 277, 6, 9023, 13,
50564]), torch.tensor([50364, 7956, 309, 281, 589, 484, 13, 50464]), torch.tensor([50364, 2555, 5623, 385, 365, 264, 1329, 295, 3931, 2737,
294, 452, 1859, 13, 50564]), torch.tensor([50364, 4387, 264, 1808, 19764, 13, 50600]), torch.tensor([50364, 15074, 493, 3905, 2583, 965, 490, 24859, 13, 50464]), torch.tensor([50364, 5506, 428, 7367, 8182, 3324, 50464]), torch.tensor([50364, 9476, 1737, 3098, 12, 83, 8400, 3847, 12628, 337,
7776, 412, 1614, 14395, 11, 2736, 10550, 295, 958, 1243,
7776, 412, 1649, 335, 13, 50828]), torch.tensor([50364, 17180, 2654, 4601, 294, 1176, 9139, 3089, 2272, 23,
4436, 13, 50564]), torch.tensor([50364, 6767, 2037, 11590, 293, 1439, 271, 1966, 11590, 8252,
544, 2049, 813, 661, 11590, 570, 50602]), torch.tensor([50364, 2053, 337, 257, 6477, 5214, 1219, 497, 34494, 322,
264, 7893, 2319, 13, 16, 13, 50664]), torch.tensor([50364, 708, 307, 8589, 337, 965, 294, 452, 2654, 1859,
30, 50484]), torch.tensor([50364, 708, 366, 264, 10486, 295, 8632, 3899, 6246, 666,
257, 42852, 30, 50662]), torch.tensor([50364, 1119, 309, 18441, 294, 21247, 30, 50564]), torch.tensor([50364, 4552, 72, 11, 652, 385, 512, 4982, 13, 50564]), torch.tensor([50364, 2692, 286, 658, 17203, 13, 50444]), torch.tensor([50364, 708, 366, 264, 2583, 12881, 30, 50464]), torch.tensor([50364, 3560, 257, 3440, 965, 13, 50464]), torch.tensor([50364, 3013, 2007, 264, 1318, 1487, 490, 30, 708, 307,
264, 1315, 295, 264, 1318, 30, 50564]), torch.tensor([50364, 15060, 452, 1808, 5811, 281, 512, 13590, 2017, 13,
50564]), torch.tensor([50364, 5115, 385, 439, 452, 14511, 13, 50464]), torch.tensor([50364, 1144, 286, 362, 604, 777, 12524, 30, 50464]), torch.tensor([50364, 2205, 13, 50441]), torch.tensor([50364, 708, 307, 341, 2280, 307, 1219, 597, 2737, 294,
452, 1859, 30, 50740]), torch.tensor([50364, 286, 2103, 1318, 13, 50464]), torch.tensor([50364, 5506, 452, 15066, 5924, 11, 23795, 13, 50464]), torch.tensor([50364, 708, 307, 264, 14330, 30, 50464]), torch.tensor([50364, 6895, 385, 2583, 466, 3899, 13, 50456]), torch.tensor([50364, 708, 307, 264, 4669, 4238, 295, 6309, 30, 50504]), torch.tensor([50364, 7497, 291, 1767, 4159, 452, 2920, 4365, 490, 30512,
30, 50514]), torch.tensor([50364, 5010, 13, 50400]), torch.tensor([50364, 509, 1491, 300, 286, 362, 257, 3440, 365, 5041,
412, 805, 4153, 13, 50508]), torch.tensor([50364, 22595, 11, 437, 366, 34808, 4127, 7901, 30, 50544]), torch.tensor([50364, 35089, 15920, 385, 295, 452, 1225, 311, 6154, 257,
1243, 949, 13, 50528]), torch.tensor([50364, 1711, 1614, 257, 13, 76, 13, 4153, 2446, 11,
4875, 264, 14183, 13, 50539]), torch.tensor([50364, 2555, 980, 385, 577, 938, 360, 291, 519, 3899,
486, 1036, 30, 50564]), torch.tensor([50364, 1012, 264, 5503, 965, 13, 50564]), torch.tensor([50364, 49452, 264, 958, 2280, 322, 452, 12183, 370, 300,
286, 500, 380, 536, 309, 797, 13, 50594]), torch.tensor([50364, 6895, 385, 452, 6792, 2093, 3021, 5191, 13, 50514]), torch.tensor([50364, 3240, 385, 484, 11, 420, 291, 434, 406, 294,
264, 1379, 551, 1562, 13, 50504]), torch.tensor([50364, 440, 7742, 3314, 1296, 264, 2546, 7241, 293, 12641,
7241, 13, 50564]), torch.tensor([50364, 708, 311, 264, 1315, 295, 264, 2280, 30, 50464]), torch.tensor([50364, 5115, 385, 257, 7647, 13, 50464]), torch.tensor([50364, 1468, 380, 718, 796, 2870, 4153, 311, 3440, 365,
48832, 13, 50564]), torch.tensor([50364, 13548, 337, 4153, 412, 568, 280, 13, 76, 13,
50564]), torch.tensor([50364, 40546, 257, 5081, 365, 7938, 322, 10383, 6499, 13,
50614]), torch.tensor([50364, 639, 307, 1203, 1411, 322, 452, 7605, 13, 50514]), torch.tensor([50364, 1012, 281, 483, 281, 591, 11272, 305, 949, 24040,
538, 3847, 13, 50564]), torch.tensor([50364, 44926, 2271, 364, 21839, 294, 945, 2077, 281, 12262,
1291, 322, 1025, 392, 7638, 13, 50564]), torch.tensor([50389, 294, 879, 347, 325, 341, 2446, 30, 50714]), torch.tensor([50364, 6454, 286, 1565, 364, 21925, 4153, 30, 50464]), torch.tensor([50364, 286, 643, 281, 2845, 257, 16972, 281, 19059, 13,
50464]), torch.tensor([50364, 1664, 4933, 264, 3440, 958, 1243, 11, 10383, 412,
568, 14395, 13, 50575]), torch.tensor([50364, 6895, 385, 428, 44223, 13, 50464]), torch.tensor([50389, 295, 376, 2641, 27924, 30, 50714]), torch.tensor([50364, 708, 311, 926, 430, 7509, 30, 50464]), torch.tensor([50364, 5115, 385, 257, 7647, 13, 50464]), torch.tensor([50364, 5506, 264, 3894, 2497, 496, 372, 13, 50514]), torch.tensor([50364, 5735, 10371, 27274, 13, 50464]), torch.tensor([50364, 34441, 294, 9525, 13, 50464]), torch.tensor([50364, 422, 13020, 11, 286, 478, 1940, 786, 766, 4153,
13, 50514]), torch.tensor([50364, 2102, 307, 6726, 341, 2153, 30, 50464]), torch.tensor([50364, 286, 1116, 411, 257, 4982, 586, 13, 50464]), torch.tensor([50364, 5506, 787, 439, 1318, 4736, 1296, 264, 1064, 50482]), torch.tensor([50364, 639, 307, 257, 665, 10864, 11, 1767, 500, 380,
747, 23491, 295, 264, 6613, 13, 50494]), torch.tensor([50364, 639, 307, 452, 1036, 3233, 1454, 50464]), torch.tensor([50364, 708, 311, 257, 1266, 12, 810, 14330, 30, 50504]), torch.tensor([50364, 2555, 1261, 766, 452, 15642, 6765, 13, 50464]), torch.tensor([50364, 708, 390, 264, 11500, 3931, 30, 50564]), torch.tensor([50364, 437, 20774, 307, 322, 39426, 13, 50464]), torch.tensor([50364, 2555, 4160, 385, 295, 452, 5027, 3440, 322, 10383,
412, 805, 14395, 13, 50580]), torch.tensor([50364, 708, 311, 322, 264, 6477, 30, 50464]), torch.tensor([50364, 1119, 456, 604, 2280, 926, 30, 50464]), torch.tensor([50364, 1144, 286, 362, 281, 360, 746, 19419, 30, 50464]), torch.tensor([50364, 708, 307, 264, 881, 2190, 7742, 3314, 294, 3533,
30, 50548]), torch.tensor([50364, 5349, 777, 2280, 4926, 12899, 281, 12183, 13, 50564]), torch.tensor([50364, 286, 486, 3651, 1577, 2060, 11781, 498, 309, 307,
886, 7679, 88, 294, 5634, 13, 50664]), torch.tensor([50364, 14895, 287, 910, 1373, 294, 264, 6525, 13, 50514]), torch.tensor([50364, 2555, 718, 385, 458, 300, 14183, 4305, 337, 10017,
311, 3440, 13, 50544]), torch.tensor([50364, 11211, 281, 3679, 766, 264, 5811, 50614]), torch.tensor([50364, 407, 452, 958, 1520, 307, 412, 1386, 277, 6,
9023, 13, 50514]), torch.tensor([50364, 13637, 1668, 281, 24469, 3664, 346, 337, 1614, 14395,
13, 50504]), torch.tensor([50364, 5506, 10294, 20159, 295, 5116, 10647, 50604]), torch.tensor([50364, 3696, 356, 281, 300, 3796, 13, 50464]), torch.tensor([50364, 708, 311, 6782, 337, 10425, 4662, 30, 50514]), torch.tensor([50364, 708, 307, 264, 6343, 4292, 337, 341, 1243, 30,
50539]), torch.tensor([50364, 5115, 385, 567, 575, 264, 881, 1230, 322, 472,
8664, 322, 264, 40351, 8840, 945, 13, 50744]), torch.tensor([50364, 821, 366, 867, 1021, 3931, 294, 341, 1618, 11,
406, 1096, 300, 50566]), torch.tensor([50364, 7956, 760, 264, 21367, 13, 50464]), torch.tensor([50364, 6895, 385, 5162, 2583, 322, 1729, 3983, 13, 50514]), torch.tensor([50364, 5349, 364, 14183, 337, 4153, 2446, 412, 11849, 669,
13, 50614]), torch.tensor([50364, 1318, 50514]), torch.tensor([50364, 27868, 22717, 0, 50464]), torch.tensor([50364, 5115, 385, 2190, 2590, 295, 6419, 2651, 9701, 4964,
11507, 13, 50589]), torch.tensor([50364, 2555, 976, 385, 257, 732, 12, 18048, 9164, 949,
958, 8803, 311, 3440, 13, 50564])]
one_hot_labels = [torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0.]), torch.tensor([0., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0., 0., 0.,
1., 1., 1., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 1., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 1., 0., 1., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 1., 1., 0., 1., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 1., 0., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 1., 1., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 1., 1., 0.]), torch.tensor([0., 1., 1., 1., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 1., 1., 1., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 
0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 1., 1., 1., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 0., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 1., 1., 1., 0., 0., 0.])]
class EntityDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]
def pad_collate(batch):
    (xx, yy) = zip(*batch)
    x_padded = pad_sequence(xx, batch_first=True, padding_value=0)
    y_padded = pad_sequence(yy, batch_first=True, padding_value=0)
    return x_padded, y_padded
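# Illustrative check (not in the original paste): pad_collate should pad a
# batch of variable-length sequences up to the longest sequence in the batch.
_xb, _yb = pad_collate([(torch.tensor([1, 2, 3]), torch.tensor([0., 1., 0.])),
                        (torch.tensor([4, 5]), torch.tensor([1., 0.]))])
assert _xb.shape == (2, 3) and _yb.shape == (2, 3)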
def load_model(model_path, vocab_size, embed_dim, model_dim, num_heads, num_classes):
    model = TinyTransformer(vocab_size=vocab_size, embed_dim=embed_dim, model_dim=model_dim,
                            num_heads=num_heads, num_classes=num_classes)
    model.load_state_dict(torch.load(model_path))
    # out = sum(p.numel() for p in model.parameters())
    model.eval()
    return model
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=100):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
    def forward(self, x):
        # x is batch-first: [batch_size, seq_len, d_model]. Slice the table by
        # sequence length (not batch size) and broadcast it over the batch.
        x = x + self.pe[:x.size(1)].transpose(0, 1)
        return self.dropout(x)
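# Quick illustrative sanity check (not in the original paste): the positional
# encoding should leave the batch-first shape of its input unchanged.
_pe_check = PositionalEncoding(d_model=128)
_dummy = torch.zeros(2, 10, 128)                 # [batch, seq_len, d_model]
assert _pe_check(_dummy).shape == (2, 10, 128)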
class TinyTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim, model_dim, num_heads, num_classes, dropout_rate=0.2, num_encoder_layers=1):
        super(TinyTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_encoding = PositionalEncoding(embed_dim, dropout_rate)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_encoder_layers)
        self.output_layer = nn.Linear(model_dim, num_classes)
        self.dropout = nn.Dropout(dropout_rate)
        # self.sigmoid = nn.Sigmoid() # Sigmoid activation to output probabilities
    def forward(self, src):
        # src shape: [batch_size, seq_length] (token ids, batch_first)
        # embedded = self.embedding(src).permute(1, 0, 2) # [batch_size, seq_length, embed_dim]
        # encoded = self.transformer_encoder(embedded)
        # output = self.output_layer(encoded)
        # return self.sigmoid(output)
        src = self.embedding(src) * math.sqrt(self.embedding.embedding_dim)  # Scale embedding: [batch, seq, embed_dim]
        src = self.positional_encoding(src)
        encoded_src = self.transformer_encoder(src)                          # [batch, seq, model_dim]
        output = self.dropout(encoded_src)
        output = self.output_layer(output)                                   # [batch, seq, num_classes] logits
        return output
def train(model, data_loader, epochs=1):
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    visual_loss = []
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for inputs, labels in tqdm(data_loader, desc="Processing batches", leave=True):
            optimizer.zero_grad()
            outputs = model(inputs)        # Forward pass: [batch, seq, 1] logits
            outputs = outputs.squeeze(-1)  # Remove the last dimension to match the labels' shape
            loss = criterion(outputs.view(-1), labels.view(-1))
            loss.backward()                # Backward pass
            optimizer.step()
            total_loss += loss.item()
            visual_loss.append(loss.item())
            if len(visual_loss) % 100 == 0:
                print(f"Batch {len(visual_loss)} Loss: {loss.item()}")
        avg_loss = total_loss / len(data_loader)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss}")
    # Plot the training loss for each batch
    plt.figure(figsize=(10, 6))
    plt.plot(visual_loss, label='Batch Loss')
    plt.xlabel('Batch')
    plt.ylabel('Loss')
    plt.title('Training Loss per Batch')
    plt.legend()
    plt.grid(True)
    plt.show()
# Example training data and labels, with multiple inputs
data = token_lists
labels = one_hot_labels
dataset = EntityDataset(data, labels)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=pad_collate)
vocab_size = 51865 # Adjust based on your actual vocabulary size
embed_dim = 128
model_dim = 128
num_heads = 4
num_classes = 1 # Output is binary for each token
def eval_model(model, data_loader, threshold=0.5):
    model.eval()  # Set the model to evaluation mode
    all_labels = []
    all_predictions = []
    with torch.no_grad():  # No need to track gradients during evaluation
        for inputs, labels in data_loader:
            if not labels.any():
                continue
            if len(inputs[0]) != len(labels[0]):
                continue
            outputs = model(inputs)
            outputs = outputs.squeeze(-1)  # [batch, seq] logits
            # The model outputs logits (training uses BCEWithLogitsLoss), so apply
            # a sigmoid before thresholding to binary predictions.
            probabilities = torch.sigmoid(outputs)
            predicted_labels = (probabilities > threshold).float()
            # Flatten both labels and predictions to ensure shape alignment for masking
            labels_flat = labels.view(-1)
            predicted_labels_flat = predicted_labels.view(-1)
            # Mask out ignored positions (-100). Note: pad_collate pads with 0, so
            # with this collate function the mask keeps every position.
            valid_indices = labels_flat != -100
            all_labels.append(labels_flat[valid_indices])
            all_predictions.append(predicted_labels_flat[valid_indices])
            # print(labels_flat, predicted_labels_flat)  # debug
    all_labels = torch.cat(all_labels)
    all_predictions = torch.cat(all_predictions)
    accuracy = accuracy_score(all_labels, all_predictions)
    return {
        "Accuracy": accuracy
    }
model_path = '/Users/afsarabenazir/Downloads/speech_projects/whisper-timestamped-master/models/transformer-NER.pth'
if os.path.isfile(model_path):
    model = load_model(model_path, vocab_size, embed_dim, model_dim, num_heads, num_classes)
    results = eval_model(model, data_loader)
    print(f"Accuracy: {results['Accuracy']:.4f}")
else:
    model = TinyTransformer(vocab_size=vocab_size, embed_dim=embed_dim, model_dim=model_dim, num_heads=num_heads, num_classes=num_classes)
    train(model, data_loader)
    torch.save(model.state_dict(), model_path)
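# Illustrative inference sketch (not in the original paste): tag a single
# tokenized utterance with the model held in memory above. The input reuses
# the first example from token_lists purely for demonstration.
example_input = token_lists[0].unsqueeze(0)            # [1, seq_len]
model.eval()
with torch.no_grad():
    example_logits = model(example_input).squeeze(-1)  # [1, seq_len]
    example_probs = torch.sigmoid(example_logits)      # per-token probabilities
    example_preds = (example_probs > 0.5).float()      # binary entity tags
print(example_preds)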