Trainer
unknown
python
2 years ago
2.2 kB
6
Indexable
# Fine-tune a pretrained Marian MT checkpoint on the first 100k sentence pairs
# of WMT19 cs-en (English source, Czech target) and score it with SacreBLEU on
# the WMT19 cs-en validation split.
import os

import evaluate
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    DataCollatorWithPadding,
    MarianMTModel,
    MarianTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    Trainer,
    TrainingArguments,
    logging,
)

# Setting the verbosity level
logging.set_verbosity_error()

# Model loading
# NOTE(review): the checkpoint name says English->French, but the targets used
# below are Czech ("cs"); its target-side tokenizer was presumably trained on
# French text. Confirm this is intended (a Czech-target checkpoint such as
# "Helsinki-NLP/opus-mt-en-cs" may be what was meant).
model_name = "Helsinki-NLP/opus-mt-en-fr"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Evaluation metric, loaded once rather than on every evaluation call.
bleu_metric = evaluate.load("sacrebleu")


def compute_metrics(eval_preds):
    """Return corpus-level SacreBLEU for one evaluation pass.

    ``eval_preds`` is the ``(predictions, label_ids)`` pair the Trainer passes
    in. Predictions are generated token ids when ``predict_with_generate`` is
    enabled; when it is not, they are logits (possibly wrapped in a tuple with
    other model outputs), which are reduced to ids with a greedy argmax.

    Returns:
        ``{"bleu": score}`` where ``score`` is SacreBLEU's corpus score.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        # Model outputs can be (logits, ...); only the logits are needed.
        preds = preds[0]
    preds = np.asarray(preds)
    if preds.ndim == 3:
        # Logits shaped (batch, seq, vocab): pick the most likely token.
        preds = preds.argmax(axis=-1)

    # -100 marks padding added by the collator / Trainer concatenation; it is
    # not a valid token id, so map it to the pad token before decoding.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(np.asarray(labels) != -100, labels, tokenizer.pad_token_id)

    decoded_preds = [
        text.strip() for text in tokenizer.batch_decode(preds, skip_special_tokens=True)
    ]
    # sacrebleu accepts several references per prediction, hence one list each.
    decoded_refs = [
        [text.strip()]
        for text in tokenizer.batch_decode(labels, skip_special_tokens=True)
    ]

    result = bleu_metric.compute(predictions=decoded_preds, references=decoded_refs)
    return {"bleu": result["score"]}


# Dataset preparation
def tokenize_function(examples):
    """Tokenize a batch of WMT19 pairs: English as input, Czech as labels.

    ``text_target`` runs the targets through the tokenizer's target-side
    processing and stores their ids under the ``"labels"`` key.
    """
    sources = [pair["en"] for pair in examples["translation"]]
    targets = [pair["cs"] for pair in examples["translation"]]
    return tokenizer(sources, text_target=targets, truncation=True)


raw_datasets = load_dataset(
    "wmt19",
    "cs-en",
    split=["train[:100000]", "validation"],
    # HF_DATASETS_CACHE (the variable `datasets` itself honours) overrides the
    # original author's local cache directory.
    cache_dir=os.environ.get(
        "HF_DATASETS_CACHE",
        "/home/javorsky/personal_work_troja/.cache/huggingface/datasets/",
    ),
)
# The raw "translation" column is not a model input; drop it after tokenizing.
tokenized_train = raw_datasets[0].map(
    tokenize_function, batched=True, remove_columns=raw_datasets[0].column_names
)
tokenized_valid = raw_datasets[1].map(
    tokenize_function, batched=True, remove_columns=raw_datasets[1].column_names
)
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

# Training arguments
# NOTE(review): the output directory says "from-scratch" although a pretrained
# checkpoint is fine-tuned. Newer transformers releases rename
# `evaluation_strategy` to `eval_strategy`.
args = Seq2SeqTrainingArguments(
    f"models/{model_name.split('/')[-1]}.from-scratch",
    evaluation_strategy="steps",
    # NOTE(review): 100k examples / batch 32 is ~3125 steps for one epoch on a
    # single device, so an evaluation at step 4000 may never run — confirm.
    eval_steps=4000,
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    # Generate translations during evaluation: BLEU then scores real decoder
    # output, and the Trainer does not keep full-vocabulary logits for the
    # whole validation set in memory.
    predict_with_generate=True,
)

# Training
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_valid,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()
Editor is loading...