Untitled
unknown
plain_text
a year ago
2.4 kB
5
Indexable
import argparse
import os

import pandas as pd
from setfit import SetFitModel, SetFitTrainer
from sklearn.metrics import accuracy_score, f1_score


def load_data(train_path, validation_path, test_path):
    """Read the train, validation and test CSV files.

    Args:
        train_path: Path of the training CSV file.
        validation_path: Path of the validation CSV file.
        test_path: Path of the test CSV file.

    Returns:
        A ``(train_data, validation_data, test_data)`` tuple of DataFrames.
    """
    train_data = pd.read_csv(train_path)
    validation_data = pd.read_csv(validation_path)
    test_data = pd.read_csv(test_path)
    return train_data, validation_data, test_data


def _as_columns(frame):
    """Return the ``text`` and ``label`` columns of *frame* as plain lists."""
    return {"text": frame["text"].tolist(), "label": frame["label"].tolist()}


def train(train_data, validation_data, test_data, num_classes, multi_label,
          epochs, batch_size, learning_rate, model_dir):
    """Fine-tune a SetFit model, print test metrics and save the model.

    Each DataFrame must provide a ``text`` column and a ``label`` column.

    Args:
        train_data: DataFrame used for training.
        validation_data: DataFrame handed to the trainer as evaluation set.
        test_data: DataFrame used for the final accuracy / F1 report.
        num_classes: Accepted for interface compatibility; currently unused.
        multi_label: Accepted for interface compatibility; currently unused.
        epochs: Number of training epochs.
        batch_size: Training batch size.
        learning_rate: Learning rate handed to the trainer.
        model_dir: Directory the trained model is saved to.
    """
    # Load a pretrained model
    model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

    # NOTE(review): SetFitTrainer is documented to take `datasets.Dataset`
    # objects; plain dicts may be rejected at train time. If so, convert with
    # `datasets.Dataset.from_dict(...)` -- not done here because this file
    # does not import `datasets`. TODO confirm against the installed version.
    trainer = SetFitTrainer(
        model=model,
        train_dataset=_as_columns(train_data),
        eval_dataset=_as_columns(validation_data),
        batch_size=batch_size,
        # `num_iterations` controls how many contrastive pairs are generated
        # per sample, not how many passes are made; epochs map to num_epochs.
        num_epochs=epochs,
        learning_rate=learning_rate,
    )

    # Train the model
    trainer.train()

    # Evaluate the model. Prediction is a method of the model, not of the
    # trainer (SetFitTrainer has no `predict`).
    y_true = test_data["label"]
    y_pred = trainer.model.predict(test_data["text"].tolist())
    accuracy = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average="weighted")
    print(f"Accuracy: {accuracy}")
    print(f"F1 Score: {f1}")

    # Save the model
    trainer.model.save_pretrained(model_dir)


def _str_to_bool(value):
    """Parse a command-line boolean flag value.

    ``type=bool`` cannot be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if normalized in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Parse the command line, load the data and run training."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_path', type=str, required=True)
    parser.add_argument('--validation_path', type=str, required=True)
    parser.add_argument('--test_path', type=str, required=True)
    parser.add_argument('--num_classes', type=int, required=True)
    parser.add_argument('--multi_label', type=_str_to_bool, required=True)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--model_dir', type=str, default=os.getenv("SM_MODEL_DIR"))
    args = parser.parse_args()

    # Fail before training rather than after it when there is nowhere to save.
    if not args.model_dir:
        parser.error("--model_dir was not given and SM_MODEL_DIR is not set")

    train_data, validation_data, test_data = load_data(
        args.train_path, args.validation_path, args.test_path
    )
    train(
        train_data,
        validation_data,
        test_data,
        args.num_classes,
        args.multi_label,
        args.epochs,
        args.batch_size,
        args.learning_rate,
        args.model_dir,
    )


if __name__ == '__main__':
    main()
Editor is loading...
Leave a Comment