Untitled
unknown
plain_text
2 years ago
5.5 kB
9
Indexable
import numpy as np
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
from bert_dataset import CustomDataset
class BertClassifier:
    """Fine-tune a pretrained BERT model for sequence classification.

    Handles model/tokenizer loading, training with a linear warmup LR
    schedule, validation, checkpointing of the best model by validation
    accuracy, and single-text prediction.
    """

    def __init__(self, model_path, tokenizer_path, n_classes=2, epochs=1,
                 model_save_path='/content/bert.pt'):
        """Load a local model/tokenizer and attach a fresh classification head.

        Args:
            model_path: directory containing a saved BertForSequenceClassification.
            tokenizer_path: directory containing a saved BertTokenizer.
            n_classes: number of target classes for the new classifier head.
            epochs: number of training epochs used by train().
            model_save_path: file path where train() stores the best checkpoint.
        """
        self.model = BertForSequenceClassification.from_pretrained(
            model_path, local_files_only=True)
        self.tokenizer = BertTokenizer.from_pretrained(
            tokenizer_path, local_files_only=True)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model_save_path = model_save_path
        self.max_len = 512  # BERT's maximum supported sequence length
        self.epochs = epochs
        # Encoder hidden size taken from the config instead of indexing
        # encoder.layer[1].output.dense (which crashes for models with < 2
        # layers); for BERT these values are identical.
        self.out_features = self.model.config.hidden_size
        # Swap in an untrained classification head sized for this task.
        self.model.classifier = torch.nn.Linear(self.out_features, n_classes)
        self.model.to(self.device)

    def goodSave(self):
        """Save model and tokenizer in HuggingFace format for later reloading."""
        self.model.save_pretrained("/content/model/")
        self.tokenizer.save_pretrained("/content/tokenizer/")

    def preparation(self, X_train, y_train, X_valid, y_valid):
        """Build datasets, loaders, optimizer, LR scheduler and loss function.

        Must be called before fit(), eval() or train().
        """
        # create datasets
        self.train_set = CustomDataset(X_train, y_train, self.tokenizer)
        self.valid_set = CustomDataset(X_valid, y_valid, self.tokenizer)

        # create data loaders (tiny batch size: BERT at max_len=512 is memory hungry)
        self.train_loader = DataLoader(self.train_set, batch_size=2, shuffle=True)
        self.valid_loader = DataLoader(self.valid_set, batch_size=2, shuffle=True)

        # helpers initialization
        # NOTE(review): transformers' AdamW (correct_bias=False) is deprecated
        # and removed in recent transformers versions; torch.optim.AdamW is the
        # replacement but always applies bias correction, so swapping it in
        # would slightly change optimization behavior — kept as-is.
        self.optimizer = AdamW(self.model.parameters(), lr=2e-5, correct_bias=False)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=0,
            num_training_steps=len(self.train_loader) * self.epochs
        )
        self.loss_fn = torch.nn.CrossEntropyLoss().to(self.device)

    def fit(self):
        """Run one training epoch over train_loader.

        Returns:
            (train_acc, train_loss): accuracy as a 0-dim double tensor and
            the mean per-batch loss as a float.
        """
        self.model = self.model.train()
        losses = []
        correct_predictions = 0

        for data in self.train_loader:
            input_ids = data["input_ids"].to(self.device)
            attention_mask = data["attention_mask"].to(self.device)
            targets = data["targets"].to(self.device)

            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            preds = torch.argmax(outputs.logits, dim=1)
            # Loss is computed externally (not via the model's labels= path)
            # so the replaced classifier head works for any n_classes.
            loss = self.loss_fn(outputs.logits, targets)

            correct_predictions += torch.sum(preds == targets)
            losses.append(loss.item())

            loss.backward()
            # Clip gradients to stabilize BERT fine-tuning.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()

        train_acc = correct_predictions.double() / len(self.train_set)
        train_loss = np.mean(losses)
        return train_acc, train_loss

    def eval(self):
        """Evaluate on the validation set without gradient updates.

        Returns:
            (val_acc, val_loss): accuracy as a 0-dim double tensor and the
            mean per-batch loss as a float.
        """
        self.model = self.model.eval()
        losses = []
        correct_predictions = 0

        with torch.no_grad():
            for data in self.valid_loader:
                input_ids = data["input_ids"].to(self.device)
                attention_mask = data["attention_mask"].to(self.device)
                targets = data["targets"].to(self.device)

                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )

                preds = torch.argmax(outputs.logits, dim=1)
                loss = self.loss_fn(outputs.logits, targets)
                correct_predictions += torch.sum(preds == targets)
                losses.append(loss.item())

        val_acc = correct_predictions.double() / len(self.valid_set)
        val_loss = np.mean(losses)
        return val_acc, val_loss

    def train(self):
        """Full training loop; keeps the checkpoint with best validation accuracy."""
        best_accuracy = 0
        for epoch in range(self.epochs):
            print(f'Epoch {epoch + 1}/{self.epochs}')
            train_acc, train_loss = self.fit()
            print(f'Train loss {train_loss} accuracy {train_acc}')
            val_acc, val_loss = self.eval()
            print(f'Val loss {val_loss} accuracy {val_acc}')
            print('-' * 10)

            if val_acc > best_accuracy:
                # Pickles the whole module; the class code must be importable
                # wherever this checkpoint is later loaded.
                torch.save(self.model, self.model_save_path)
                best_accuracy = val_acc

        # Restore the best checkpoint. map_location guards against reloading
        # a CUDA-saved model in a CPU-only session.
        self.model = torch.load(self.model_save_path, map_location=self.device)

    def predict(self, text):
        """Classify a single text string.

        Args:
            text: raw input string; truncated/padded to max_len tokens.

        Returns:
            Predicted class index (numpy integer).
        """
        # Switch to eval mode: with dropout active, predictions after training
        # would be nondeterministic.
        self.model.eval()

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )

        input_ids = encoding['input_ids'].flatten().to(self.device)
        attention_mask = encoding['attention_mask'].flatten().to(self.device)

        # no_grad: inference needs no autograd graph.
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids.unsqueeze(0),
                attention_mask=attention_mask.unsqueeze(0)
            )
        prediction = torch.argmax(outputs.logits, dim=1).cpu().numpy()[0]
        return prediction
Editor is loading...
Leave a Comment