Untitled

 avatar
unknown
plain_text
a year ago
2.4 kB
5
Indexable
import argparse
import os
import pandas as pd
from setfit import SetFitModel, SetFitTrainer
from sklearn.metrics import accuracy_score, f1_score

def load_data(train_path, validation_path, test_path):
    """Read the train, validation and test CSV files into DataFrames.

    Each path is handed directly to ``pd.read_csv``; the resulting frames are
    returned as a tuple in the same order as the arguments
    (train, validation, test).
    """
    split_paths = (train_path, validation_path, test_path)
    return tuple(pd.read_csv(split_path) for split_path in split_paths)

def train(train_data, validation_data, test_data, num_classes, multi_label, epochs, batch_size, learning_rate, model_dir):
    """Fine-tune a SetFit classifier, print test metrics, and save the model.

    Args:
        train_data: DataFrame with ``text`` and ``label`` columns used for training.
        validation_data: DataFrame with ``text`` and ``label`` columns passed to
            the trainer as its evaluation set.
        test_data: DataFrame with ``text`` and ``label`` columns used for the
            accuracy / weighted-F1 report printed after training.
        num_classes: Number of target classes. Not used by the current
            configuration (default SetFit head); kept for interface compatibility.
        multi_label: When true, the model is loaded with a one-vs-rest
            multi-target strategy so an example may carry several labels.
        epochs: Forwarded to the trainer as ``num_iterations`` (see note below).
        batch_size: Batch size passed to the trainer.
        learning_rate: Learning rate passed to the trainer.
        model_dir: Directory the trained model is saved to.

    Raises:
        ValueError: If ``model_dir`` is None or empty. This is checked before
            training so a long run is not wasted on a save that cannot succeed.
    """
    # os.getenv("SM_MODEL_DIR") yields None outside SageMaker; fail fast
    # instead of after training.
    if not model_dir:
        raise ValueError("model_dir must be a non-empty path; pass --model_dir or set SM_MODEL_DIR")

    # Without a multi-target strategy SetFit fits a single-label head, which
    # cannot learn multi-label targets.
    model_kwargs = {"multi_target_strategy": "one-vs-rest"} if multi_label else {}
    model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2", **model_kwargs)

    # NOTE(review): for multi_label, the `label` cells presumably hold per-class
    # indicator lists; pd.read_csv returns such cells as strings — confirm they
    # are parsed into lists before reaching this function.
    # NOTE(review): SetFitTrainer documents train_dataset/eval_dataset as
    # `datasets.Dataset` objects; a plain dict lacks attributes such as
    # `column_names` that the trainer may access. Consider
    # `datasets.Dataset.from_pandas(...)` — confirm with the installed setfit version.
    trainer = SetFitTrainer(
        model=model,
        train_dataset={"text": train_data['text'].tolist(), "label": train_data['label'].tolist()},
        eval_dataset={"text": validation_data['text'].tolist(), "label": validation_data['label'].tolist()},
        batch_size=batch_size,
        # NOTE(review): `num_iterations` is the number of contrastive pairs
        # generated per example, not the number of passes over the data (the
        # trainer's epoch count is `num_epochs`). Confirm which was intended.
        num_iterations=epochs,
        learning_rate=learning_rate,
    )

    trainer.train()

    # SetFitTrainer has no `predict` method; predictions come from the
    # trained model itself.
    y_pred = trainer.model.predict(test_data['text'].tolist())
    accuracy = accuracy_score(test_data['label'], y_pred)
    f1 = f1_score(test_data['label'], y_pred, average='weighted')
    print(f"Accuracy: {accuracy}")
    print(f"F1 Score: {f1}")

    trainer.model.save_pretrained(model_dir)

def parse_bool(value):
    """Convert a command-line string such as "true" / "false" into a bool.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including "False"), so an explicit converter is needed.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean word.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Parse command-line arguments, load the CSV splits, then train and save."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_path', type=str, required=True)
    parser.add_argument('--validation_path', type=str, required=True)
    parser.add_argument('--test_path', type=str, required=True)
    parser.add_argument('--num_classes', type=int, required=True)
    parser.add_argument('--multi_label', type=parse_bool, required=True)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=16)
    # NOTE(review): 0.001 is far above SetFitTrainer's documented default
    # learning rate (2e-5) — confirm this is intended.
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--model_dir', type=str, default=os.getenv("SM_MODEL_DIR"))
    args = parser.parse_args()

    # The default comes from SM_MODEL_DIR, which is unset outside SageMaker.
    if not args.model_dir:
        parser.error("--model_dir is required when SM_MODEL_DIR is not set")

    train_data, validation_data, test_data = load_data(args.train_path, args.validation_path, args.test_path)
    train(train_data, validation_data, test_data, args.num_classes, args.multi_label, args.epochs, args.batch_size, args.learning_rate, args.model_dir)


if __name__ == '__main__':
    main()
Editor is loading...
Leave a Comment