EncoderDecoder-Finetuning
unknown
python
4 years ago
2.7 kB
9
Indexable
from transformers import EncoderDecoderModel
from seq2seq_trainer import Seq2SeqTrainer
from seq2seq_training_args import Seq2SeqTrainingArguments
import datasets
# Warm-start an encoder-decoder model from two copies of bert-base-uncased
# (smaller BERT variants can be tried afterwards).
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased"
)

# Adapt the combined config to the encoder-decoder setup: generation starts from
# the tokenizer's CLS id, ends on its SEP id, and pads with its pad id.
# NOTE(review): `tokenizer` is not defined in this file — presumably a BERT
# tokenizer created in an earlier notebook cell; confirm.
enc_dec_config = bert2bert.config
enc_dec_config.decoder_start_token_id = tokenizer.cls_token_id
enc_dec_config.eos_token_id = tokenizer.sep_token_id
enc_dec_config.pad_token_id = tokenizer.pad_token_id
enc_dec_config.vocab_size = enc_dec_config.encoder.vocab_size

# Beam-search settings borrowed from BART (original note: "works well on MT").
# NOTE(review): 142/56 look like BART summarization defaults; min_length=56
# forces every generated sequence to be at least 56 tokens long — confirm this
# is intended for translation targets.
beam_search_settings = {
    "max_length": 142,
    "min_length": 56,
    "no_repeat_ngram_size": 3,
    "early_stopping": True,
    "length_penalty": 2.0,
    "num_beams": 4,
}
for setting_name, setting_value in beam_search_settings.items():
    setattr(enc_dec_config, setting_name, setting_value)
%%capture
# Fetch fresh copies of the helper modules imported at the top of this file
# (`seq2seq_trainer`, `seq2seq_training_args`); %%capture hides the cell output.
# The old copies are removed first, presumably so wget does not save the new
# download under a suffixed name such as `seq2seq_trainer.py.1`.
# NOTE(review): this cell has to run BEFORE the `from seq2seq_trainer ...`
# imports; also, the `master/examples/seq2seq/` path may have moved upstream —
# confirm these URLs still resolve.
!rm seq2seq_trainer.py
!rm seq2seq_training_args.py
!wget https://raw.githubusercontent.com/huggingface/transformers/master/examples/seq2seq/seq2seq_trainer.py
!wget https://raw.githubusercontent.com/huggingface/transformers/master/examples/seq2seq/seq2seq_training_args.py
%%capture
# Install extra packages; %%capture suppresses the pip output.
# NOTE(review): presumably needed by the downloaded seq2seq helpers and by the
# ROUGE / BLEU metrics; only git-python is pinned — confirm versions.
!pip install git-python==1.0.3
!pip install rouge_score
!pip install sacrebleu
# Set the training arguments.
import torch  # only used to check whether a CUDA device is available for fp16

training_args = Seq2SeqTrainingArguments(
    # Generate sequences during evaluation so compute_metrics receives token ids.
    predict_with_generate=True,
    evaluation_strategy="steps",
    # NOTE(review): `batch_size` is not defined in this file — presumably set in
    # an earlier notebook cell; confirm.
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # fp16 mixed precision requires a CUDA device; hard-coding True makes the
    # run fail on CPU-only machines, so enable it only when a GPU is present.
    fp16=torch.cuda.is_available(),
    output_dir="./",
    # Very small step counts: a quick smoke-test configuration.
    logging_steps=2,
    save_steps=10,
    eval_steps=4,
    # Values apparently intended for a full training run (kept for reference):
    # logging_steps=1000,
    # save_steps=500,
    # eval_steps=7500,
    # warmup_steps=2000,
    # save_total_limit=3,
)
# NOTE(review): "my-metric" looks like a placeholder name, and
# `datasets.load_metric` is deprecated in newer `datasets` releases (metrics
# moved to the `evaluate` library) — confirm the metric name and library version.
metric = datasets.load_metric("my-metric")


def compute_metrics(pred):
    """Decode generated and reference token ids and score them with ``metric``.

    Relies on the module-level ``tokenizer`` and ``metric`` objects.

    Args:
        pred: Evaluation prediction passed in by the trainer. ``pred.predictions``
            holds generated token ids (``predict_with_generate=True``), and
            ``pred.label_ids`` holds reference ids in which ``-100`` marks
            positions that are ignored by the loss.

    Returns:
        Whatever ``metric.compute`` returns for the decoded strings.
    """
    import numpy as np

    pad_id = tokenizer.pad_token_id

    # -100 is not a real token id and cannot be decoded. Map it to the pad id,
    # which ``skip_special_tokens=True`` then drops. Predictions get the same
    # treatment because the trainer may pad ragged batches with -100.
    # np.where builds new arrays, so the trainer's own arrays are not mutated
    # (the previous version overwrote ``pred.label_ids`` in place).
    pred_ids = np.asarray(pred.predictions)
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_id)
    label_ids = np.asarray(pred.label_ids)
    label_ids = np.where(label_ids != -100, label_ids, pad_id)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    return metric.compute(predictions=pred_str, references=label_str)


# Earlier ROUGE-2 variant of the body above, kept for reference.
# NOTE(review): it presumably needs a `rouge` metric object loaded separately.
# rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid
# return {
#     "rouge2_precision": round(rouge_output.precision, 4),
#     "rouge2_recall": round(rouge_output.recall, 4),
#     "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
# }
# Assemble the trainer from the model, arguments, metric hook and data splits,
# then start fine-tuning.
# NOTE(review): `train_data` and `val_data` are not defined in this file —
# presumably tokenized datasets built in an earlier notebook cell; confirm.
trainer_kwargs = dict(
    model=bert2bert,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data,
    eval_dataset=val_data,
)
trainer = Seq2SeqTrainer(**trainer_kwargs)
trainer.train()
Editor is loading...