# EncoderDecoder-Finetuning
#
# mail@pastecode.io avatar
# unknown
# python
# 2 years ago
# 2.7 kB
# 3
# Indexable
# Never
from transformers import EncoderDecoderModel
from seq2seq_trainer import Seq2SeqTrainer
from seq2seq_training_args import Seq2SeqTrainingArguments
import datasets

# Start with bert-base-uncased on both sides; smaller BERT checkpoints can be
# tried afterwards.
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased",
    "bert-base-uncased",
)

# NOTE(review): `tokenizer` is not defined in the visible part of this file;
# it has to exist before these overrides are evaluated.
_config_overrides = {
    # Special-token ids and vocabulary size for the encoder-decoder setup.
    "decoder_start_token_id": tokenizer.cls_token_id,
    "eos_token_id": tokenizer.sep_token_id,
    "pad_token_id": tokenizer.pad_token_id,
    "vocab_size": bert2bert.config.encoder.vocab_size,
    # Beam-search settings taken from the BART model configuration
    # (reported by the original author to work well on MT).
    "max_length": 142,
    "min_length": 56,
    "no_repeat_ngram_size": 3,
    "early_stopping": True,
    "length_penalty": 2.0,
    "num_beams": 4,
}

# Apply the overrides one by one, in the order listed above.
for _option, _setting in _config_overrides.items():
    setattr(bert2bert.config, _option, _setting)

%%capture
# Notebook cell: fetch fresh copies of the two helper modules
# (seq2seq_trainer.py and seq2seq_training_args.py) from the transformers
# GitHub repository. The %%capture cell magic keeps the cell's output from
# being displayed and must stay on the first line of the cell.
# NOTE(review): this paste imports both modules before this cell appears, so
# presumably this cell has to run first - confirm the intended cell order.
# Delete any existing local copies before downloading; presumably so wget
# does not store the new files under a ".1" suffix - confirm.
!rm seq2seq_trainer.py
!rm seq2seq_training_args.py
# NOTE(review): both URLs point at the `master` branch - verify they still
# resolve, or pin them to a specific commit.
!wget https://raw.githubusercontent.com/huggingface/transformers/master/examples/seq2seq/seq2seq_trainer.py
!wget https://raw.githubusercontent.com/huggingface/transformers/master/examples/seq2seq/seq2seq_training_args.py

%%capture
# Notebook cell: install extra Python packages with pip; %%capture hides the
# installer output and must stay on the first line of the cell.
# NOTE(review): presumably these are needed by the metric code (ROUGE / BLEU
# scoring) - confirm. Only git-python is pinned to a specific version.
!pip install git-python==1.0.3
!pip install rouge_score
!pip install sacrebleu

# Training configuration handed to the Seq2SeqTrainer.
# NOTE(review): `batch_size` is not defined in the visible part of this file.
training_args = Seq2SeqTrainingArguments(
    # Where outputs are written: the current working directory.
    output_dir="./",
    # One shared batch size for training and evaluation; fp16 enabled.
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    fp16=True,
    # Evaluate on a step schedule, with generation switched on for prediction.
    evaluation_strategy="steps",
    predict_with_generate=True,
    # Very short intervals; presumably smoke-test values - confirm before a
    # full-length run.
    eval_steps=4,
    logging_steps=2,
    save_steps=10,
    # Alternative values kept for reference (currently disabled):
    # logging_steps=1000,
    # save_steps=500,
    # eval_steps=7500,
    # warmup_steps=2000,
    # save_total_limit=3,
)

# Load the metric object that compute_metrics calls via metric.compute(...).
# NOTE(review): "my-metric" looks like a placeholder for a metric name or a
# local metric script path - confirm the real identifier.
# NOTE(review): datasets.load_metric was deprecated in later `datasets`
# releases in favour of the separate `evaluate` package - confirm against the
# installed version.
metric = datasets.load_metric("my-metric")


def compute_metrics(pred):
    """Decode predictions and references, then score them with ``metric``.

    Relies on the module-level ``tokenizer`` and ``metric`` objects.

    Args:
        pred: Object exposing ``label_ids`` (reference token ids) and
            ``predictions`` (generated token ids). Presumably the evaluation
            output supplied by the trainer - confirm.

    Returns:
        The value returned by ``metric.compute`` for the decoded strings.

    Note:
        ``pred.label_ids`` is modified in place: every ``-100`` entry is
        overwritten with ``tokenizer.pad_token_id``.
    """
    reference_ids, generated_ids = pred.label_ids, pred.predictions

    generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Swap the -100 placeholders for the pad id before decoding; with
    # skip_special_tokens=True the pad tokens are presumably dropped from the
    # decoded text.
    reference_ids[reference_ids == -100] = tokenizer.pad_token_id
    reference_text = tokenizer.batch_decode(reference_ids, skip_special_tokens=True)

    return metric.compute(predictions=generated_text, references=reference_text)

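    # Disabled alternative, kept for reference (it sits after the return, so
    # it would be unreachable even if uncommented): ROUGE-2 scoring. It needs
    # a `rouge` metric object, which is not defined in this file as shown.
    #rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid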

    #return {
    #    "rouge2_precision": round(rouge_output.precision, 4),
    #    "rouge2_recall": round(rouge_output.recall, 4),
    #    "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    #}
    
# Instantiate the trainer.
    
# Bundle the model, training arguments, datasets and metric callback into a
# trainer, then start fine-tuning.
# NOTE(review): `train_data` and `val_data` are not defined in the visible
# part of this file; they have to be prepared before this point.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=bert2bert,
    train_dataset=train_data,
    eval_dataset=val_data,
    compute_metrics=compute_metrics,
)

# Run the training loop.
trainer.train()