import argparse
import copy
import csv
import logging
import math
import os
import string

import datasets
import numpy as np
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

def parse_args():
    parser = argparse.ArgumentParser(description='Train a long-document QA model.')
    parser.add_argument('--model_name', type=str, default="t5-base")
    parser.add_argument('--out_fol', type=str, required=True)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--num_steps', type=int, default=80000)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--max_input_length', type=int, default=512)
    parser.add_argument('--grad_accum_steps', type=int, default=1)
    parser.add_argument('--parallelize', action="store_true")
    parser.add_argument('--gradient_checkpointing', action="store_true")
    parser.add_argument('--answer_format', default='text', choices=['letter', 'number', 'text'])
    return parser.parse_args()
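
# Example invocation (script name and output path are illustrative):
#   python train.py --out_fol /path/to/out --model_name t5-base --answer_format text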
def main(args):
    logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(message)s')

    # Look for checkpoint folders in the output folder so training can resume.
    checkpoint_nums = []
    if os.path.exists(args.out_fol):
        for path in os.listdir(args.out_fol):
            if (path.startswith("checkpoint-") and not path.endswith("last")
                    and os.path.isdir(os.path.join(args.out_fol, path))):
                checkpoint_nums.append(int(path[len("checkpoint-"):]))
    if os.path.exists(os.path.join(args.out_fol, 'checkpoint-last')):
        pretrained = os.path.join(args.out_fol, "checkpoint-last")
    elif checkpoint_nums:
        # Resume from the checkpoint with the highest step count.
        pretrained = os.path.join(args.out_fol, f"checkpoint-{max(checkpoint_nums)}")
    else:
        pretrained = args.model_name
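    # The Trainer writes "checkpoint-<step>" directories every save_steps, and a
    # completed run also leaves "checkpoint-last" (saved at the end of main), so
    # out_fol might contain e.g. checkpoint-5000/, checkpoint-10000/, checkpoint-last/.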
    def process_race_example(example):
        """Flatten a RACE example into an (input_str, label_str) pair for seq2seq training."""
        answer_format = args.answer_format
        if answer_format in ('letter', 'text'):
            answer_labels = string.ascii_uppercase
        else:
            answer_labels = list(range(4))
        answer_labels_rev = {"A": 0, "B": 1, "C": 2, "D": 3}
        # Input: the question, then the four answer options, then the article.
        context = f"{example['question']} \n "
        for i in range(len(answer_labels_rev)):
            context += f"({answer_labels[i]}): {example['options'][i]} "
        context += f"\n {example['article']}"
        context = context.lower()
        # Target: the answer option's text, letter, or index, per --answer_format.
        if answer_format == 'text':
            label = example['options'][answer_labels_rev[example['answer'].upper()]].lower()
        elif answer_format == 'letter':
            label = example['answer'].lower()
        else:
            label = str(answer_labels_rev[example['answer'].upper()])
        return {"input_str": context, "label_str": label}
    model = AutoModelForSeq2SeqLM.from_pretrained(pretrained)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    if args.gradient_checkpointing:
        # Wire up the otherwise-unused flag; gradient_checkpointing_enable()
        # trades extra forward computation for a smaller activation footprint.
        model.gradient_checkpointing_enable()
    if args.parallelize:
        # Shard the model's layers across all visible GPUs (T5-style model parallelism).
        model.parallelize()
    else:
        model = model.to(torch.device("cuda:0"))
    def tokenize_fn(examples):
        tokenized_inputs = tokenizer(
            examples["input_str"],
            max_length=args.max_input_length,
            truncation="only_first",
            padding="max_length",
        )
        tokenized_labels = tokenizer(
            examples["label_str"],
            max_length=100,
            truncation="only_first",
            padding="max_length",
            return_tensors="pt",
        )
        # Replace label padding with -100, the index the cross-entropy loss ignores.
        target_ids = tokenized_labels["input_ids"]
        target_ids[target_ids == tokenizer.pad_token_id] = -100
        return {
            "input_ids": tokenized_inputs["input_ids"],
            "attention_mask": tokenized_inputs["attention_mask"],
            "labels": target_ids.tolist(),
        }
    race = datasets.load_dataset("race", "all")

    preprocessed_train = race["train"].map(process_race_example, remove_columns=[
        'example_id', 'article', 'answer', 'question', 'options'
    ])
    tokenized_train = preprocessed_train.map(tokenize_fn, batched=True)
    tokenized_train.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

    # Print a single example to be sure of correctness.
    print(tokenizer.decode(tokenized_train['input_ids'][0]))
    l_copy = copy.deepcopy(tokenized_train['labels'][0])
    l_copy[l_copy == -100] = tokenizer.pad_token_id
    print(tokenizer.decode(l_copy))

    validation = race["validation"].map(process_race_example, remove_columns=[
        'example_id', 'article', 'answer', 'question', 'options'
    ])
    preprocessed_validation = validation.map(tokenize_fn, batched=True)
    preprocessed_validation.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
    training_args = Seq2SeqTrainingArguments(
        output_dir=args.out_fol,
        logging_dir=args.out_fol,
        disable_tqdm=False,
        do_train=True,
        evaluation_strategy="steps",
        eval_steps=5000,
        gradient_accumulation_steps=args.grad_accum_steps,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=1,
        learning_rate=args.learning_rate,
        max_steps=args.num_steps,
        save_steps=5000,
        eval_accumulation_steps=1,
        logging_steps=500,
    )
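    # Note: eval_accumulation_steps=1 moves prediction tensors off the GPU after
    # every eval step, which keeps the full-vocabulary logits from accumulating
    # in GPU memory during evaluation.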
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=preprocessed_validation,
        tokenizer=tokenizer,
    )
if "checkpoint-" in pretrained:
print(f"Training from checkpoint: {pretrained}")
trainer.train(resume_from_checkpoint=pretrained)
else:
trainer.train()
trainer.save_model(output_dir=os.path.join(args.out_fol, "checkpoint-last"))
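    # predict_with_generate is not enabled, so trainer.predict() returns
    # teacher-forced logits rather than generated sequences; the num_beams
    # argument below is effectively unused in this mode.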
    # Output predictions on the validation set in chunks to bound memory use
    # (per-chunk logits are shaped [chunk, target_len, vocab]).
    val_batch_size = 100
    outputs = []
    for i in range(math.ceil(len(preprocessed_validation) / val_batch_size)):
        data = datasets.Dataset.from_dict(preprocessed_validation[i*val_batch_size:(i+1)*val_batch_size])
        preds, labels, _ = trainer.predict(data, num_beams=1)
        for j, (p, l) in enumerate(zip(preds[0], labels)):
            # preds[0] holds the per-token logits; greedy-decode with argmax.
            ex_preds = np.argmax(p, axis=-1)
            pred_str = tokenizer.decode(ex_preds)
            # Trim everything from the end-of-sequence token onward, if present.
            if "</s>" in pred_str:
                pred_str = pred_str[:pred_str.find("</s>")]
            l[l == -100] = tokenizer.pad_token_id
            label_str = tokenizer.decode(l)
            if "</s>" in label_str:
                label_str = label_str[:label_str.find("</s>")]
            original_idx = i*val_batch_size + j
            outputs.append((validation['input_str'][original_idx], pred_str, label_str))

    results_csv_path = os.path.join(args.out_fol, "validation_predictions.csv")
    with open(results_csv_path, "w", newline="") as f:
        csvwriter = csv.writer(f)
        csvwriter.writerow(["Input", "Predicted", "Label"])
        csvwriter.writerows(outputs)
if __name__ == "__main__":
    main(parse_args())