import argparse
import pandas as pd
from accelerate import Accelerator
from datasets import Dataset
from elasticdb import ElasticDB
import evaluate
from transformers import T5Tokenizer, DataCollatorForSeq2Seq
from transformers import T5ForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
import nltk
import numpy as np
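
# Example invocation (the filename is only a placeholder for wherever this script is saved):
#   python train_t5_recipes.py --model google-t5/t5-base --num_epochs 1 --upper_limit 100000

# sent_tokenize in compute_metrics needs the punkt models (no-op if already downloaded)
nltk.download("punkt", quiet=True)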

if __name__ == "__main__":
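    # Open a connection to the Elasticsearch instance that stores the recipe documents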
    es = ElasticDB()

    parser = argparse.ArgumentParser()
    # Command-line arguments
    parser.add_argument("--model", type=str, default="google-t5/t5-base")
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--upper_limit", type=int, default=100000)
    parser.add_argument("--batch_per_gpu", type=int, default=32)
    parser.add_argument("--token_max_length", type=int, default=256)
    parser.add_argument("--test_size", type=float, default=0.3)
    parser.add_argument("--wandb_project_name", type=str, default="t5-v_epoch_1-recipe-model-base")
    parser.add_argument("--accelerator_batch_device", type=bool, default=False)
    parser.add_argument("--saved_model_name", type=str, default="t5-v_epoch_1-small-base")

    args = parser.parse_args()
    # Instantiated so the script can run under `accelerate launch`; not used directly below
    accelerator = Accelerator()
    tokenizer = T5Tokenizer.from_pretrained(args.model)
    model = T5ForConditionalGeneration.from_pretrained(args.model)

    def tokenize_function(recipes):
        # Model inputs are the '|'-joined ingredient strings; targets are the
        # '|'-joined direction strings
        input_text = list(recipes['ingredients'])
        target_text = list(recipes['directions'])

        # Tokenize inputs and targets, truncating and padding to a fixed length
        model_inputs = tokenizer(input_text,
                                 max_length=args.token_max_length,
                                 truncation=True,
                                 padding="max_length",
                                 )
        labels = tokenizer(target_text,
                           max_length=args.token_max_length,
                           truncation=True,
                           padding="max_length",
                           )
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs


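    # Fetch up to --upper_limit recipe documents from the 'recipenlp' index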
    es_client = es.search(index_name="recipenlp",
                          todos=True,
                          query={"size": args.upper_limit,
                                 "query": {"match_all": {}}
                                 },
                          upper_limit=args.upper_limit)
    # Shuffle the retrieved documents and keep only the fields needed for training
    df = pd.DataFrame(es_client).sample(frac=1).reset_index(drop=True)
    data = df[['title', 'ingredients', 'directions']].copy()

    # Join the list-valued fields into single '|'-separated strings
    data['directions'] = data['directions'].apply(lambda x: '|'.join(x))
    data['ingredients'] = data['ingredients'].apply(lambda x: '|'.join(x))

    dataset = Dataset.from_pandas(data)

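    # Tokenize the whole dataset in batches of 100 examples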
    tokenized_datasets = dataset.map(tokenize_function, batched=True, batch_size=100)

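    # Split into train/validation/test; with the default --test_size of 0.3 this is roughly 70/15/15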
    train_temp_dataset = tokenized_datasets.train_test_split(test_size=args.test_size)

    validation_test_split = train_temp_dataset['test'].train_test_split(test_size=0.5)

    train_dataset = train_temp_dataset['train']
    validation_dataset = validation_test_split['train']
    test_dataset = validation_test_split['test']

    # Training hyperparameters (learning rate and epoch count come from the CLI flags)
    L_RATE = args.learning_rate
    BATCH_SIZE = 8
    PER_DEVICE_EVAL_BATCH = 8
    WEIGHT_DECAY = 0.01
    SAVE_TOTAL_LIM = 3
    NUM_EPOCHS = args.num_epochs

    # Set up training arguments
    model_args = Seq2SeqTrainingArguments(
        output_dir="./results",
        evaluation_strategy="steps",
        eval_steps=500,
        logging_strategy="steps",
        logging_steps=500,
        learning_rate=L_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH,
        weight_decay=WEIGHT_DECAY,
        save_total_limit=SAVE_TOTAL_LIM,
        num_train_epochs=NUM_EPOCHS,
        predict_with_generate=True,
        push_to_hub=False,
    )

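    # The collator batches examples and, given the model, builds decoder_input_ids from the labels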
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
    rouge = evaluate.load("rouge")

    # Compute ROUGE between generated directions and references at each evaluation step
    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        # Replace -100 (ignored positions) with the pad token id before decoding
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # rougeLsum expects a newline after each sentence
        decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
        decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

        result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
        return {k: round(v, 4) for k, v in result.items()}


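    # Fine-tune the model; evaluation during training runs on the validation split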
    trainer = Seq2SeqTrainer(
        model=model,
        args=model_args,
        train_dataset=train_dataset,
        eval_dataset=validation_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
    trainer.train()

    # Final evaluation on the held-out test split
    evaluation = trainer.evaluate(eval_dataset=test_dataset)
    print(evaluation)
    # Save the fine-tuned model and tokenizer to the same directory
    trainer.save_model(args.saved_model_name)
    tokenizer.save_pretrained(args.saved_model_name)
    print("Model saved")
    print("Training complete")
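
    # Optional sanity check (not part of the original script): generate directions for one
    # held-out ingredient list with the fine-tuned model still in memory, just to confirm
    # the saved checkpoint produces text. Uses the same tokenization settings as training.
    sample = test_dataset[0]
    sample_inputs = tokenizer(sample["ingredients"],
                              max_length=args.token_max_length,
                              truncation=True,
                              return_tensors="pt").to(model.device)
    generated = model.generate(**sample_inputs, max_length=args.token_max_length)
    print(tokenizer.decode(generated[0], skip_special_tokens=True))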