import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import pickle
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm.notebook import tqdm

class SentimentClassifier(nn.Module):
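    """MLP classification head: maps a fixed-size sentence embedding
    (default 768-dim) to logits over 3 sentiment classes."""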
    def __init__(self, input_dim=768, hidden_dim=256, output_dim=3, dropout_prob=0.3):
        super(SentimentClassifier, self).__init__()

        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=dropout_prob),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.mlp(x)

def train_epoch(model, train_loader, criterion, optimizer, device):
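    """Run one training pass over train_loader and return the mean per-sample
    loss and the accuracy for the epoch."""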
    model.train()
    train_epoch_loss = 0.0
    train_correct = 0
    train_total = 0

    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        # Zero the parameter gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        # Update metrics
        train_epoch_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs, 1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()

    train_epoch_loss /= train_total
    train_accuracy = train_correct / train_total

    return train_epoch_loss, train_accuracy

def validate_epoch(model, val_loader, criterion, device):
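    """Evaluate the model on val_loader with gradients disabled and return the
    mean per-sample loss and the accuracy."""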
    model.eval()
    val_epoch_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            val_epoch_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

    val_epoch_loss /= val_total
    val_accuracy = val_correct / val_total

    return val_epoch_loss, val_accuracy

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5, device='cpu', save_loss_plot=True):
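    """Train for num_epochs, printing per-epoch metrics, and optionally save the
    train/validation loss curves to 'training_val_loss_plot.png'."""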
    model.to(device)
    train_loss_history = []
    val_loss_history = []

    for epoch in range(num_epochs):
        train_epoch_loss, train_accuracy = train_epoch(model, train_loader, criterion, optimizer, device)
        val_epoch_loss, val_accuracy = validate_epoch(model, val_loader, criterion, device)

        train_loss_history.append(train_epoch_loss)
        val_loss_history.append(val_epoch_loss)

        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_epoch_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, "
              f"Val Loss: {val_epoch_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

    if save_loss_plot:
        plt.figure()
        plt.plot(range(1, num_epochs + 1), train_loss_history, label='Train Loss', marker='o')
        plt.plot(range(1, num_epochs + 1), val_loss_history, label='Val Loss', marker='o')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training and Validation Loss Over Epochs')
        plt.legend()
        plt.grid(True)
        plt.savefig('training_val_loss_plot.png')
        print("Training and validation loss plot saved as 'training_val_loss_plot.png'")

    return model

def generate_and_save_dataset(csv_file, tokenizer, pickle_file, max_length=128):
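    """Precompute a (feature, label) pair for every row of csv_file and pickle the list.

    `tokenizer` is expected to return a sliceable tensor of features for each
    text (the classifier here consumes fixed-size embeddings); only the first
    `max_length` elements are kept."""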
    data = pd.read_csv(csv_file)
    processed_data = []

    for _, row in data.iterrows():
        text = row['text']
        label = row['label']
        embedding = tokenizer(text)[:max_length]
        processed_data.append((embedding, label))

    with open(pickle_file, 'wb') as f:
        pickle.dump(processed_data, f)

    print(f"Dataset saved to {pickle_file}")

def load_dataset_from_pickle(pickle_file):
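    """Load a previously pickled list of (feature, label) pairs."""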
    with open(pickle_file, 'rb') as f:
        data = pickle.load(f)

    print(f"Dataset loaded from {pickle_file}")
    return data

class SentimentDataset(Dataset):
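    """Dataset of (embedding, label) pairs, built either from an existing pickle
    file or from a CSV file processed with the provided tokenizer."""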
    def __init__(self, csv_file=None, tokenizer=None, pickle_file=None, max_length=128):

        if pickle_file:
            self.data = load_dataset_from_pickle(pickle_file)
        elif csv_file and tokenizer:
            self.data = []
            raw_data = pd.read_csv(csv_file)
            for _, row in raw_data.iterrows():
                text = row['text']
                label = row['label']
                embedding = tokenizer(text)[:max_length]
                self.data.append((embedding, label))
        else:
            raise ValueError("Either pickle_file or both csv_file and tokenizer must be provided.")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Items are stored as (embedding, label) tuples (see __init__ and
        # generate_and_save_dataset), so unpack the tuple directly.
        embedding, label = self.data[idx]
        return embedding.clone().detach(), torch.tensor(label, dtype=torch.long)

def load_model(model_path):
    model = SentimentClassifier()
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()
    return model

if __name__ == "__main__":
    full_dataset = SentimentDataset(pickle_file='drive/MyDrive/sub_phobert2/pkl/dataset_all.pkl')
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

    model = SentimentClassifier()

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=5e-4)

    trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10, device='cpu')
    torch.save(trained_model.state_dict(), 'drive/MyDrive/sub_phobert2/sentiment_classifier.pth')

    # Reload the saved weights to confirm the checkpoint round-trips correctly.
    trained_model = load_model('drive/MyDrive/sub_phobert2/sentiment_classifier.pth')
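
    # Quick smoke test of the reloaded model. Assumption: real inputs would be
    # sentence embeddings from the same upstream encoder used to build the
    # pickled dataset; a random 768-dim vector stands in here purely to check
    # that the checkpoint loads and the forward pass runs.
    example_embedding = torch.randn(1, 768)
    with torch.no_grad():
        logits = trained_model(example_embedding)
        predicted_class = torch.argmax(logits, dim=1).item()
    print(f"Predicted class for the random example embedding: {predicted_class}")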