# Untitled
# unknown
# plain_text
# a year ago
# 2.4 kB
# 10
# Indexable
import argparse
import os
import pandas as pd
from setfit import SetFitModel, SetFitTrainer
from sklearn.metrics import accuracy_score, f1_score
def load_data(train_path, validation_path, test_path):
    """Read the train, validation and test CSV files into DataFrames.

    Args:
        train_path: Path of the training CSV file.
        validation_path: Path of the validation CSV file.
        test_path: Path of the test CSV file.

    Returns:
        A 3-tuple ``(train_data, validation_data, test_data)`` of pandas
        DataFrames, in that order.
    """
    # Files are read in the order train -> validation -> test.
    csv_paths = (train_path, validation_path, test_path)
    return tuple(pd.read_csv(csv_path) for csv_path in csv_paths)
def train(train_data, validation_data, test_data, num_classes, multi_label, epochs, batch_size, learning_rate, model_dir):
    """Fine-tune a SetFit model, print test metrics, and save the model.

    Args:
        train_data: DataFrame with ``text`` and ``label`` columns used for training.
        validation_data: DataFrame with ``text`` and ``label`` columns passed
            to the trainer as the evaluation set.
        test_data: DataFrame with ``text`` and ``label`` columns used for the
            final accuracy / F1 report.
        num_classes: Accepted for interface compatibility; not used by this
            function.
        multi_label: Accepted for interface compatibility; not used by this
            function.
        epochs: Number of training epochs.
        batch_size: Training batch size.
        learning_rate: Learning rate forwarded to the trainer.
        model_dir: Directory the trained model is written to.

    Returns:
        None. Accuracy and weighted F1 on ``test_data`` are printed to stdout.
    """
    # NOTE(review): `num_classes` and `multi_label` are never consulted, so a
    # multi-label run is trained exactly like a single-label one. If
    # multi-label support is intended, the model presumably needs a
    # multi-target strategy at load time -- confirm against the SetFit docs.

    # Load a pretrained model
    model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

    # NOTE(review): SetFit documents `train_dataset` / `eval_dataset` as
    # `datasets.Dataset` objects; plain dicts are passed here. `datasets` is
    # not imported by this file, so the conversion is left to a follow-up --
    # confirm and wrap with `Dataset.from_dict(...)` if required.
    train_dataset = {
        "text": train_data['text'].tolist(),
        "label": train_data['label'].tolist(),
    }
    eval_dataset = {
        "text": validation_data['text'].tolist(),
        "label": validation_data['label'].tolist(),
    }

    # Create a SetFit trainer
    trainer = SetFitTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        batch_size=batch_size,
        # `epochs` belongs in `num_epochs`; `num_iterations` is the number of
        # contrastive pair-generation rounds, a different knob.
        num_epochs=epochs,
        learning_rate=learning_rate,
    )

    # Train the model
    trainer.train()

    # Evaluate the model. Inference is a method of the model, not the trainer.
    raw_predictions = trainer.model.predict(test_data['text'].tolist())
    # Normalise tensor / array output to a plain list for the sklearn metrics.
    if hasattr(raw_predictions, "tolist"):
        y_pred = raw_predictions.tolist()
    else:
        y_pred = list(raw_predictions)

    accuracy = accuracy_score(test_data['label'], y_pred)
    f1 = f1_score(test_data['label'], y_pred, average='weighted')
    print(f"Accuracy: {accuracy}")
    print(f"F1 Score: {f1}")

    # Save the model
    trainer.model.save_pretrained(model_dir)
def str_to_bool(value):
    """Convert a command-line string such as ``"false"`` into a real bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter parses the text instead.

    Args:
        value: The raw argument text (a bool is returned unchanged).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_arg_parser():
    """Build the command-line parser for the training script.

    Returns:
        An ``argparse.ArgumentParser`` with the data paths, label settings,
        hyperparameters and output directory options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_path', type=str, required=True)
    parser.add_argument('--validation_path', type=str, required=True)
    parser.add_argument('--test_path', type=str, required=True)
    parser.add_argument('--num_classes', type=int, required=True)
    parser.add_argument('--multi_label', type=str_to_bool, required=True)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    # Falls back to the SM_MODEL_DIR environment variable when not given.
    parser.add_argument('--model_dir', type=str, default=os.getenv("SM_MODEL_DIR"))
    return parser


def main(argv=None):
    """Parse arguments, load the CSV splits and run training.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.
    """
    parser = build_arg_parser()
    args = parser.parse_args(argv)

    # Fail before training starts rather than at save time when no output
    # directory is available from either the flag or the environment.
    if not args.model_dir:
        parser.error("--model_dir must be given when SM_MODEL_DIR is not set")

    train_data, validation_data, test_data = load_data(
        args.train_path, args.validation_path, args.test_path
    )
    train(
        train_data,
        validation_data,
        test_data,
        args.num_classes,
        args.multi_label,
        args.epochs,
        args.batch_size,
        args.learning_rate,
        args.model_dir,
    )


if __name__ == '__main__':
    main()
# Leave a Comment