"""Fine-tune a seq2seq model (default: t5-base) on the RACE multiple-choice
QA dataset and write validation predictions to a CSV file.

Example invocation (script name and output path are illustrative):
    python train_race_t5.py --out_fol out --batch_size 8
"""
import argparse
import copy
import csv
import logging
import math
import os
import string

import datasets
import numpy as np
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)


def parse_args():
    parser = argparse.ArgumentParser(description='Train long doc QA Model.')
    parser.add_argument('--model_name', type=str, default="t5-base")
    parser.add_argument('--out_fol', type=str, required=True)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--num_steps', type=int, default=80000)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--max_input_length', type=int, default=512)
    parser.add_argument('--grad_accum_steps', type=int, default=1)
    parser.add_argument('--parallelize', action="store_true")
    parser.add_argument('--gradient_checkpointing', action="store_true")
    parser.add_argument('--answer_format', default='text',
                        choices=['letter', 'number', 'text'])
    return parser.parse_args()


def main(args):
    logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(message)s')

    # Look for checkpoint directories in the output folder so training can
    # resume: "checkpoint-last" takes priority, then the highest-numbered
    # "checkpoint-<step>" directory, then the base pretrained model.
    checkpoint_nums = []
    if os.path.exists(args.out_fol):
        for path in os.listdir(args.out_fol):
            if (path.startswith("checkpoint-")
                    and not path.endswith("last")
                    and os.path.isdir(os.path.join(args.out_fol, path))):
                checkpoint_nums.append(int(path[len("checkpoint-"):]))
    if os.path.exists(os.path.join(args.out_fol, 'checkpoint-last')):
        pretrained = os.path.join(args.out_fol, "checkpoint-last")
    elif checkpoint_nums:
        pretrained = os.path.join(args.out_fol, f"checkpoint-{max(checkpoint_nums)}")
    else:
        pretrained = args.model_name

    def process_race_example(example):
        # Build a single lowercased input string of the form
        # "<question> \n (A): <option> ... \n <article>", plus the target
        # string in the requested answer format (letter, number, or text).
        answer_format = args.answer_format
        if answer_format in ('letter', 'text'):
            answer_labels = string.ascii_uppercase
        else:
            answer_labels = list(range(4))
        answer_labels_rev = {"A": 0, "B": 1, "C": 2, "D": 3}
        context = f"{example['question']} \n "
        for i in range(len(answer_labels_rev)):
            context += f"({answer_labels[i]}): {example['options'][i]} "
        context += f"\n {example['article']}"
        context = context.lower()
        if answer_format == 'text':
            label = example['options'][answer_labels_rev[example['answer'].upper()]].lower()
        elif answer_format == 'letter':
            label = example['answer'].lower()
        else:
            label = str(answer_labels_rev[example['answer'].upper()])
        return {"input_str": context, "label_str": label}

    model = AutoModelForSeq2SeqLM.from_pretrained(pretrained)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    if args.gradient_checkpointing:
        # Trade compute for memory by recomputing activations on the backward pass.
        model.gradient_checkpointing_enable()
    if args.parallelize:
        # Spread the model's layers across all visible GPUs.
        model.parallelize()
    else:
        model = model.cuda()

    def tokenize_fn(examples):
        tokenized_inputs = tokenizer(
            examples["input_str"],
            max_length=args.max_input_length,
            truncation=True,
            padding="max_length",
        )
        tokenized_labels = tokenizer(
            examples["label_str"],
            max_length=100,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        # Replace padding in the targets with -100 so it is ignored by the loss.
        target_ids = tokenized_labels["input_ids"]
        target_ids[target_ids == tokenizer.pad_token_id] = -100
        return {
            "input_ids": tokenized_inputs["input_ids"],
            "attention_mask": tokenized_inputs["attention_mask"],
            "labels": target_ids.tolist(),
        }

    race = datasets.load_dataset("race", "all")
    preprocessed_train = race["train"].map(
        process_race_example,
        remove_columns=['example_id', 'article', 'answer', 'question', 'options'],
    )
    tokenized_train = preprocessed_train.map(tokenize_fn, batched=True)
    tokenized_train.set_format(type='torch',
                               columns=['input_ids', 'attention_mask', 'labels'])

    # Print a single decoded example as a sanity check.
    print(tokenizer.decode(tokenized_train['input_ids'][0]))
    l_copy = copy.deepcopy(tokenized_train['labels'][0])
    l_copy[l_copy == -100] = tokenizer.pad_token_id
    print(tokenizer.decode(l_copy))

    validation = race["validation"].map(
        process_race_example,
        remove_columns=['example_id', 'article', 'answer', 'question', 'options'],
    )
    preprocessed_validation = validation.map(tokenize_fn, batched=True)
    preprocessed_validation.set_format(type='torch',
                                       columns=['input_ids', 'attention_mask', 'labels'])

    training_args = Seq2SeqTrainingArguments(
        output_dir=args.out_fol,
        logging_dir=args.out_fol,
        disable_tqdm=False,
        do_train=True,
        evaluation_strategy="steps",
        eval_steps=5000,
        gradient_accumulation_steps=args.grad_accum_steps,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=1,
        learning_rate=args.learning_rate,
        max_steps=args.num_steps,
        save_steps=5000,
        eval_accumulation_steps=1,
        logging_steps=500,
    )
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=preprocessed_validation,
        tokenizer=tokenizer,
    )
    if "checkpoint-" in pretrained:
        print(f"Training from checkpoint: {pretrained}")
        trainer.train(resume_from_checkpoint=pretrained)
    else:
        trainer.train()
    trainer.save_model(output_dir=os.path.join(args.out_fol, "checkpoint-last"))

    # Decode validation predictions in batches and collect them for the CSV.
    val_batch_size = 100
    outputs = []
    for i in range(math.ceil(len(preprocessed_validation) / val_batch_size)):
        data = datasets.Dataset.from_dict(
            preprocessed_validation[i * val_batch_size:(i + 1) * val_batch_size])
        preds, labels, _ = trainer.predict(data, num_beams=1)
        for j, (p, l) in enumerate(zip(preds[0], labels)):
            # The trainer returns logits, so take the argmax over the vocabulary
            # and trim everything after the end-of-sequence token, if present.
            pred_str = tokenizer.decode(np.argmax(p, axis=-1))
            if "</s>" in pred_str:
                pred_str = pred_str[:pred_str.find("</s>")]
            l[l == -100] = 0
            label_str = tokenizer.decode(l)
            if "</s>" in label_str:
                label_str = label_str[:label_str.find("</s>")]
            original_idx = i * val_batch_size + j
            outputs.append((validation['input_str'][original_idx], pred_str, label_str))

    results_csv_path = os.path.join(args.out_fol, "validation_predictions.csv")
    with open(results_csv_path, "w", newline="") as f:
        csvwriter = csv.writer(f)
        csvwriter.writerow(["Input", "Predicted", "Label"])
        for row in outputs:
            csvwriter.writerow(row)


if __name__ == "__main__":
    main(parse_args())