Untitled
python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import pandas as pd
import ast
from tqdm import tqdm
import numpy as np
import os
import matplotlib.pyplot as plt
import math

token_lists = [torch.tensor([50364, 7956, 766, 264, 1442, 294, 264, 8687, 13, 50514]), torch.tensor([50364, 3087, 472, 11, 1767, 13, 50464]), torch.tensor([50364, 1468, 380, 718, 385, 2870, 4153, 311, 8410, 365, 48832, 13, 50589]), torch.tensor([50364, 16346, 264, 1442, 13, 50464]), torch.tensor([50364, 3013, 786, 341, 1243, 360, 286, 362, 5482, 30, 50541]), torch.tensor([50364, 13638, 6782, 13, 50464]), torch.tensor([50364, 8928, 257, 13548, 337, 805, 277, 6, 9023, 13, 50564]), torch.tensor([50364, 7956, 309, 281, 589, 484, 13, 50464]), torch.tensor([50364, 2555, 5623, 385, 365, 264, 1329, 295, 3931, 2737, 294, 452, 1859, 13, 50564]), torch.tensor([50364, 4387, 264, 1808, 19764, 13, 50600]), torch.tensor([50364, 15074, 493, 3905, 2583, 965, 490, 24859, 13, 50464]), torch.tensor([50364, 5506, 428, 7367, 8182, 3324, 50464]), torch.tensor([50364, 9476, 1737, 3098, 12, 83, 8400, 3847, 12628, 337, 7776, 412, 1614, 14395, 11, 2736, 10550, 295, 958, 1243, 7776, 412, 1649, 335, 13, 50828]), torch.tensor([50364, 17180, 2654, 4601, 294, 1176, 9139, 3089, 2272, 23, 4436, 13, 50564]), torch.tensor([50364, 6767, 2037, 11590, 293, 1439, 271, 1966, 11590, 8252, 544, 2049, 813, 661, 11590, 570, 50602]), torch.tensor([50364, 2053, 337, 257, 6477, 5214, 1219, 497, 34494, 322, 264, 7893, 2319, 13, 16, 13, 50664]), torch.tensor([50364, 708, 307, 8589, 337, 965, 294, 452, 2654, 1859, 30, 50484]), torch.tensor([50364, 708, 366, 264, 10486, 295, 8632, 3899, 6246, 666, 257, 42852, 30, 50662]), torch.tensor([50364, 1119, 309, 18441, 294, 21247, 30, 50564]), torch.tensor([50364, 4552, 72, 11, 652, 385, 512, 4982, 13, 50564]), torch.tensor([50364, 2692, 286, 658, 17203, 13, 50444]), torch.tensor([50364, 708, 366, 264, 2583, 12881, 30, 50464]), torch.tensor([50364, 3560, 257, 3440, 965, 13, 50464]), torch.tensor([50364, 3013, 2007, 264, 1318, 1487, 490, 30, 708, 307, 264, 1315, 295, 264, 1318, 30, 50564]), torch.tensor([50364, 15060, 452, 1808, 5811, 281, 512, 13590, 2017, 13, 50564]), torch.tensor([50364, 5115, 385, 439, 452, 14511, 13, 50464]), torch.tensor([50364, 1144, 286, 362, 604, 777, 12524, 30, 50464]), torch.tensor([50364, 2205, 13, 50441]), torch.tensor([50364, 708, 307, 341, 2280, 307, 1219, 597, 2737, 294, 452, 1859, 30, 50740]), torch.tensor([50364, 286, 2103, 1318, 13, 50464]), torch.tensor([50364, 5506, 452, 15066, 5924, 11, 23795, 13, 50464]), torch.tensor([50364, 708, 307, 264, 14330, 30, 50464]), torch.tensor([50364, 6895, 385, 2583, 466, 3899, 13, 50456]), torch.tensor([50364, 708, 307, 264, 4669, 4238, 295, 6309, 30, 50504]), torch.tensor([50364, 7497, 291, 1767, 4159, 452, 2920, 4365, 490, 30512, 30, 50514]), torch.tensor([50364, 5010, 13, 50400]), torch.tensor([50364, 509, 1491, 300, 286, 362, 257, 3440, 365, 5041, 412, 805, 4153, 13, 50508]), torch.tensor([50364, 22595, 11, 437, 366, 34808, 4127, 7901, 30, 50544]), torch.tensor([50364, 35089, 15920, 385, 295, 452, 1225, 311, 6154, 257, 1243, 949, 13, 50528]), torch.tensor([50364, 1711, 1614, 257, 13, 76, 13, 4153, 2446, 11, 4875, 264, 14183, 13, 50539]), torch.tensor([50364, 2555, 980, 385, 577, 938, 360, 291, 519, 3899, 486, 1036, 30, 50564]), torch.tensor([50364, 1012,
264, 5503, 965, 13, 50564]), torch.tensor([50364, 49452, 264, 958, 2280, 322, 452, 12183, 370, 300, 286, 500, 380, 536, 309, 797, 13, 50594]), torch.tensor([50364, 6895, 385, 452, 6792, 2093, 3021, 5191, 13, 50514]), torch.tensor([50364, 3240, 385, 484, 11, 420, 291, 434, 406, 294, 264, 1379, 551, 1562, 13, 50504]), torch.tensor([50364, 440, 7742, 3314, 1296, 264, 2546, 7241, 293, 12641, 7241, 13, 50564]), torch.tensor([50364, 708, 311, 264, 1315, 295, 264, 2280, 30, 50464]), torch.tensor([50364, 5115, 385, 257, 7647, 13, 50464]), torch.tensor([50364, 1468, 380, 718, 796, 2870, 4153, 311, 3440, 365, 48832, 13, 50564]), torch.tensor([50364, 13548, 337, 4153, 412, 568, 280, 13, 76, 13, 50564]), torch.tensor([50364, 40546, 257, 5081, 365, 7938, 322, 10383, 6499, 13, 50614]), torch.tensor([50364, 639, 307, 1203, 1411, 322, 452, 7605, 13, 50514]), torch.tensor([50364, 1012, 281, 483, 281, 591, 11272, 305, 949, 24040, 538, 3847, 13, 50564]), torch.tensor([50364, 44926, 2271, 364, 21839, 294, 945, 2077, 281, 12262, 1291, 322, 1025, 392, 7638, 13, 50564]), torch.tensor([50389, 294, 879, 347, 325, 341, 2446, 30, 50714]), torch.tensor([50364, 6454, 286, 1565, 364, 21925, 4153, 30, 50464]), torch.tensor([50364, 286, 643, 281, 2845, 257, 16972, 281, 19059, 13, 50464]), torch.tensor([50364, 1664, 4933, 264, 3440, 958, 1243, 11, 10383, 412, 568, 14395, 13, 50575]), torch.tensor([50364, 6895, 385, 428, 44223, 13, 50464]), torch.tensor([50389, 295, 376, 2641, 27924, 30, 50714]), torch.tensor([50364, 708, 311, 926, 430, 7509, 30, 50464]), torch.tensor([50364, 5115, 385, 257, 7647, 13, 50464]), torch.tensor([50364, 5506, 264, 3894, 2497, 496, 372, 13, 50514]), torch.tensor([50364, 5735, 10371, 27274, 13, 50464]), torch.tensor([50364, 34441, 294, 9525, 13, 50464]), torch.tensor([50364, 422, 13020, 11, 286, 478, 1940, 786, 766, 4153, 13, 50514]), torch.tensor([50364, 2102, 307, 6726, 341, 2153, 30, 50464]), torch.tensor([50364, 286, 1116, 411, 257, 4982, 586, 13, 50464]), torch.tensor([50364, 5506, 787, 439, 1318, 4736, 1296, 264, 1064, 50482]), torch.tensor([50364, 639, 307, 257, 665, 10864, 11, 1767, 500, 380, 747, 23491, 295, 264, 6613, 13, 50494]), torch.tensor([50364, 639, 307, 452, 1036, 3233, 1454, 50464]), torch.tensor([50364, 708, 311, 257, 1266, 12, 810, 14330, 30, 50504]), torch.tensor([50364, 2555, 1261, 766, 452, 15642, 6765, 13, 50464]), torch.tensor([50364, 708, 390, 264, 11500, 3931, 30, 50564]), torch.tensor([50364, 437, 20774, 307, 322, 39426, 13, 50464]), torch.tensor([50364, 2555, 4160, 385, 295, 452, 5027, 3440, 322, 10383, 412, 805, 14395, 13, 50580]), torch.tensor([50364, 708, 311, 322, 264, 6477, 30, 50464]), torch.tensor([50364, 1119, 456, 604, 2280, 926, 30, 50464]), torch.tensor([50364, 1144, 286, 362, 281, 360, 746, 19419, 30, 50464]), torch.tensor([50364, 708, 307, 264, 881, 2190, 7742, 3314, 294, 3533, 30, 50548]), torch.tensor([50364, 5349, 777, 2280, 4926, 12899, 281, 12183, 13, 50564]), torch.tensor([50364, 286, 486, 3651, 1577, 2060, 11781, 498, 309, 307, 886, 7679, 88, 294, 5634, 13, 50664]), torch.tensor([50364, 14895, 287, 910, 1373, 294, 264, 6525, 13, 50514]), torch.tensor([50364, 2555, 718, 385, 458, 300, 14183, 4305, 337, 10017, 311, 3440, 13, 50544]), torch.tensor([50364, 11211, 281, 3679, 766, 264, 5811, 50614]), torch.tensor([50364, 407, 452, 958, 1520, 307, 412, 1386, 277, 6, 9023, 13, 50514]), torch.tensor([50364, 13637, 1668, 281, 24469, 3664, 346, 337, 1614, 14395, 13, 50504]), torch.tensor([50364, 5506, 10294, 20159, 295, 5116, 10647, 50604]), 
torch.tensor([50364, 3696, 356, 281, 300, 3796, 13, 50464]), torch.tensor([50364, 708, 311, 6782, 337, 10425, 4662, 30, 50514]), torch.tensor([50364, 708, 307, 264, 6343, 4292, 337, 341, 1243, 30, 50539]), torch.tensor([50364, 5115, 385, 567, 575, 264, 881, 1230, 322, 472, 8664, 322, 264, 40351, 8840, 945, 13, 50744]), torch.tensor([50364, 821, 366, 867, 1021, 3931, 294, 341, 1618, 11, 406, 1096, 300, 50566]), torch.tensor([50364, 7956, 760, 264, 21367, 13, 50464]), torch.tensor([50364, 6895, 385, 5162, 2583, 322, 1729, 3983, 13, 50514]), torch.tensor([50364, 5349, 364, 14183, 337, 4153, 2446, 412, 11849, 669, 13, 50614]), torch.tensor([50364, 1318, 50514]), torch.tensor([50364, 27868, 22717, 0, 50464]), torch.tensor([50364, 5115, 385, 2190, 2590, 295, 6419, 2651, 9701, 4964, 11507, 13, 50589]), torch.tensor([50364, 2555, 976, 385, 257, 732, 12, 18048, 9164, 949, 958, 8803, 311, 3440, 13, 50564])] one_hot_labels = [torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0.]), torch.tensor([0., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0., 0., 0., 1., 1., 1., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 1., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 
0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 1., 0., 1., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 1., 1., 0., 1., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 1., 0., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 1., 1., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 1., 1., 0.]), torch.tensor([0., 1., 1., 1., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 1., 1., 1., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 1., 1., 1., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 0., 0., 1., 1., 1., 0.]), torch.tensor([0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0.]), torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0.]), torch.tensor([0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 1., 1., 1., 0., 0., 0.])]


class EntityDataset(Dataset):
    """Pairs each token-id sequence with its per-token 0/1 entity labels."""
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


def pad_collate(batch):
    # Pad every sequence in the batch to the longest one (padding value 0 for both
    # tokens and labels) so they can be stacked into a single tensor.
    xx, yy = zip(*batch)
    x_padded = pad_sequence(xx, batch_first=True, padding_value=0)
    y_padded = pad_sequence(yy, batch_first=True, padding_value=0)
    return x_padded, y_padded


def load_model(model_path, vocab_size, embed_dim, model_dim, num_heads, num_classes):
    model = TinyTransformer(vocab_size=vocab_size, embed_dim=embed_dim, model_dim=model_dim,
                            num_heads=num_heads, num_classes=num_classes)
    model.load_state_dict(torch.load(model_path))
    # out = sum(p.numel() for p in model.parameters())
    model.eval()
    return model


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=100):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        # Stored as [1, max_len, d_model] so it broadcasts over the batch dimension,
        # since the rest of the model runs with batch_first tensors.
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [batch_size, seq_len, d_model]; add the encoding for the first seq_len positions.
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class TinyTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim, model_dim, num_heads, num_classes,
                 dropout_rate=0.2, num_encoder_layers=1):
        super(TinyTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_encoding = PositionalEncoding(embed_dim, dropout_rate)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_encoder_layers)
        self.output_layer = nn.Linear(model_dim, num_classes)
        self.dropout = nn.Dropout(dropout_rate)
        # self.sigmoid = nn.Sigmoid()  # Sigmoid activation to output probabilities

    def forward(self, src):
        # src: [batch_size, seq_len] token ids -> per-token logits [batch_size, seq_len, num_classes]
        src = self.embedding(src) * math.sqrt(self.embedding.embedding_dim)  # Scale embedding
        src = self.positional_encoding(src)
        encoded_src = self.transformer_encoder(src)
        output = self.dropout(encoded_src)
        output = self.output_layer(output)
        return output


def train(model, data_loader, epochs=1):
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    visual_loss = []
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for inputs, labels in tqdm(data_loader, desc="Processing batches", leave=True):
            optimizer.zero_grad()
            outputs = model(inputs)        # Forward pass
            outputs = outputs.squeeze(-1)  # Remove the last dimension to match labels shape
            loss = criterion(outputs.view(-1), labels.view(-1))
            loss.backward()                # Backward pass
            optimizer.step()
            total_loss += loss.item()
            visual_loss.append(loss.item())
            if len(visual_loss) % 100 == 0:
                print(f"Batch {len(visual_loss)} Loss: {loss.item()}")
        avg_loss = total_loss / len(data_loader)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss}")

    # Plot the training loss for each batch
    plt.figure(figsize=(10, 6))
    plt.plot(visual_loss, label='Batch Loss')
    plt.xlabel('Batch')
    plt.ylabel('Loss')
    plt.title('Training Loss per Batch')
    plt.legend()
    plt.grid(True)
    plt.show()


# Example training data and labels, with multiple inputs
data = token_lists
labels = one_hot_labels
dataset = EntityDataset(data, labels)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=pad_collate)

vocab_size = 51865  # Adjust based on your actual vocabulary size
embed_dim = 128
model_dim = 128
num_heads = 4
num_classes = 1  # Output is binary for each token


def eval_model(model, data_loader, threshold=0.5):
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # No need to track gradients during evaluation
        for inputs, labels in data_loader:
            if not labels.any():
                continue
            if len(inputs[0]) != len(labels[0]):
                continue
            outputs = model(inputs)
            outputs = outputs.squeeze(-1)  # [batch_size, seq_len] logits
            # The model outputs raw logits (no sigmoid in forward), so convert to
            # probabilities before applying the threshold.
            predicted_labels = (torch.sigmoid(outputs) > threshold).float()
            # Flatten both labels and predictions to ensure shape alignment for masking
            labels_flat = labels.view(-1)
            predicted_labels_flat = predicted_labels.view(-1)
            # Mask out ignored positions (-100); note pad_collate pads labels with 0,
            # so with batch_size=1 every position is kept.
            valid_indices = labels_flat != -100
            predictions_valid = predicted_labels_flat[valid_indices]
            print(labels[0], predictions_valid)
            accuracy = accuracy_score(labels[0], predictions_valid)
            # Returns after the first valid batch.
            return {"Accuracy": accuracy}


model_path = '/Users/afsarabenazir/Downloads/speech_projects/whisper-timestamped-master/models/transformer-NER.pth'
if os.path.isfile(model_path):
    model = load_model(model_path, vocab_size, embed_dim, model_dim, num_heads, num_classes)
    results = eval_model(model, data_loader)
    print(f"Accuracy: {results['Accuracy']:.4f}")
else:
    model = TinyTransformer(vocab_size=vocab_size, embed_dim=embed_dim, model_dim=model_dim,
                            num_heads=num_heads, num_classes=num_classes)
    train(model, data_loader)
    torch.save(model.state_dict(), model_path)
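
# A minimal inference sketch, assuming the `model` built above (loaded or freshly
# trained) and one of the token_lists entries as input: it runs a single sequence
# through the model and converts the per-token logits into 0/1 entity flags, the
# same way eval_model does. `sample` and `entity_flags` are illustrative names
# introduced here, not part of the original script.
model.eval()
sample = token_lists[0].unsqueeze(0)  # shape [1, seq_len]: a batch containing one sequence
with torch.no_grad():
    logits = model(sample).squeeze(-1)                   # shape [1, seq_len], raw scores
    entity_flags = (torch.sigmoid(logits) > 0.5).long()  # 1 where a token is predicted to be part of an entity
print("tokens:      ", sample[0].tolist())
print("entity flags:", entity_flags[0].tolist())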