import numpy as np

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    logging,
)
from datasets import load_dataset
import evaluate
# Setting the verbosity level
logging.set_verbosity_error()
# Model loading
model_name = "Helsinki-NLP/opus-mt-en-fr"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Evaluation metric loading and preparation
metric = evaluate.load("sacrebleu")
def compute_metrics(eval_preds):
    # With predict_with_generate=True the predictions are generated token ids,
    # so decode both predictions and labels to text before scoring with sacreBLEU.
    preds, labels = eval_preds
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = metric.compute(predictions=decoded_preds, references=[[l] for l in decoded_labels])
    return {"bleu": result["score"]}
# Dataset preparation
def tokenize_function(examples):
    model_inputs = tokenizer([e["en"] for e in examples["translation"]], truncation=True)
    with tokenizer.as_target_tokenizer():
        model_target = tokenizer([e["cs"] for e in examples["translation"]], truncation=True)
    model_inputs["labels"] = model_target["input_ids"]
    return model_inputs
raw_datasets = load_dataset("wmt19", "cs-en", split=['train[:100000]', 'validation'],
cache_dir="/home/javorsky/personal_work_troja/.cache/huggingface/datasets/")
tokenized_train = raw_datasets[0].map(tokenize_function, batched=True)
tokenized_valid = raw_datasets[1].map(tokenize_function, batched=True)
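# Quick spot-check (illustrative addition, not part of the original script):
# confirm that one tokenized example contains both source input_ids and target labels.
print({k: tokenized_train[0][k][:8] for k in ("input_ids", "labels")})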
# Passing the model lets the collator prepare decoder_input_ids from the padded labels
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
# Training arguments
args = Seq2SeqTrainingArguments(
    f"models/{model_name.split('/')[1]}.from-scratch",
    evaluation_strategy="steps",
    eval_steps=4000,
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    predict_with_generate=True,  # required so compute_metrics receives generated ids for BLEU
)
# Training
trainer = Seq2SeqTrainer(
model,
args,
train_dataset=tokenized_train,
eval_dataset=tokenized_valid,
data_collator=data_collator,
tokenizer=tokenizer,
compute_metrics=compute_metrics
)
trainer.train()
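# Optional follow-up (a minimal sketch, not from the original script): save the
# fine-tuned model and run one sanity-check translation on a validation sentence.
trainer.save_model()  # writes the fine-tuned model to args.output_dir
sample = raw_datasets[1][0]["translation"]["en"]
inputs = tokenizer(sample, return_tensors="pt").to(model.device)
generated = model.generate(**inputs, max_length=128)
print(sample, "->", tokenizer.decode(generated[0], skip_special_tokens=True))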