# Trainer
# Snippet metadata from the paste site: language python, 2.2 kB,
# posted 2 years ago by an unknown author.
import os

import evaluate
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    DataCollatorWithPadding,
    MarianMTModel,
    MarianTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    Trainer,
    TrainingArguments,
    logging,
)

# Only errors from transformers are printed; info/warning output is silenced.
logging.set_verbosity_error()

# Hub id of the pretrained checkpoint; both tokenizer and seq2seq weights come
# from it.
# NOTE(review): this is an English->French checkpoint, while the data loaded
# below is WMT19 cs-en with Czech targets -- confirm the mismatch is intended.
model_name = "Helsinki-NLP/opus-mt-en-fr"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)

# Evaluation metric loading and preparation.
# sacreBLEU is loaded once at start-up, like the model/tokenizer, instead of
# being re-loaded on every evaluation round.
bleu_metric = evaluate.load("sacrebleu")


def compute_metrics(eval_preds):
    """Decode predictions and labels, then score them with sacreBLEU.

    Args:
        eval_preds: ``(predictions, labels)`` as passed by the trainer.
            ``predictions`` holds either generated token ids (2-D, when
            ``predict_with_generate=True``) or logits (3-D). It may also be a
            tuple whose first element holds them.

    Returns:
        The dict returned by the sacreBLEU metric's ``compute``.
    """
    preds, labels = eval_preds
    # Models can return extra outputs alongside logits; the logits/ids come first.
    if isinstance(preds, tuple):
        preds = preds[0]
    preds = np.asarray(preds)
    # Logits (batch, seq, vocab) -> greedy token ids (batch, seq).
    if preds.ndim == 3:
        preds = np.argmax(preds, axis=-1)

    # -100 marks ignored/padded positions. It is not a valid token id, so map
    # it to the pad token before decoding.
    pad_id = tokenizer.pad_token_id
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.asarray(labels)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # sacreBLEU wants plain strings as predictions and, per example, a list of
    # reference strings.
    predictions = [pred.strip() for pred in decoded_preds]
    references = [[ref.strip()] for ref in decoded_labels]
    return bleu_metric.compute(predictions=predictions, references=references)

# Dataset preparation
def tokenize_function(examples):
    """Tokenize a batch of translation pairs for seq2seq training.

    Each item of ``examples["translation"]`` is indexed by language code.
    The ``"en"`` text becomes the model input. The ``"cs"`` text, tokenized
    in target mode, becomes ``labels``.

    NOTE(review): the checkpoint named in ``model_name`` is en->fr, but the
    targets here are Czech -- confirm this is intentional.
    """
    pairs = examples["translation"]
    sources = [pair["en"] for pair in pairs]
    targets = [pair["cs"] for pair in pairs]

    encoded = tokenizer(sources, truncation=True)
    # Target-side text is tokenized with the tokenizer in target mode.
    with tokenizer.as_target_tokenizer():
        target_ids = tokenizer(targets, truncation=True)["input_ids"]
    encoded["labels"] = target_ids
    return encoded

# Load the first 100k WMT19 cs-en training pairs plus the validation split.
# The cache directory can be overridden with the standard HF_DATASETS_CACHE
# variable. The default keeps the original, machine-specific location.
dataset_cache_dir = os.environ.get(
    "HF_DATASETS_CACHE",
    "/home/javorsky/personal_work_troja/.cache/huggingface/datasets/",
)
raw_datasets = load_dataset(
    "wmt19",
    "cs-en",
    split=["train[:100000]", "validation"],
    cache_dir=dataset_cache_dir,
)

# Drop the raw text column once it has been tokenized, so only model inputs
# and labels reach the data collator.
tokenized_train = raw_datasets[0].map(
    tokenize_function, batched=True, remove_columns=["translation"]
)
tokenized_valid = raw_datasets[1].map(
    tokenize_function, batched=True, remove_columns=["translation"]
)
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

# Training arguments.
# 100k training examples / batch size 32 = ~3125 optimizer steps per epoch on
# a single device. The previous eval_steps=4000 meant evaluation never ran
# during the single training epoch.
# NOTE(review): the output directory says "from-scratch", but the model is
# fine-tuned from a pretrained checkpoint -- confirm the naming.
args = Seq2SeqTrainingArguments(
    f"models/{model_name.split('/')[1]}.from-scratch",
    evaluation_strategy="steps",
    eval_steps=1000,
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    # Generate translations during evaluation. BLEU should be computed on real
    # decoder output, not teacher-forced argmax. It also avoids gathering full
    # (batch x seq x vocab) logits for the whole validation set in memory.
    predict_with_generate=True,
)

# Training: fine-tune `model` on the tokenized train split and evaluate on the
# tokenized validation split. The evaluation cadence comes from `args`.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_valid,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()
# (end of snippet)