import argparse
import copy
import csv
import logging
import math
import os
import string

import datasets
import numpy as np
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments


def parse_args():
    parser = argparse.ArgumentParser(description='Train a long-document QA model.')
    parser.add_argument('--model_name', type=str, default="t5-base")
    parser.add_argument('--out_fol', type=str, required=True)

    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--num_steps', type=int, default=80000)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--max_input_length', type=int, default=512)

    parser.add_argument('--grad_accum_steps', type=int, default=1)
    parser.add_argument('--parallelize', action="store_true")
    parser.add_argument('--gradient_checkpointing', action="store_true")
    parser.add_argument('--answer_format', default='text', choices=['letter', 'number', 'text'])
    return parser.parse_args()


def main(args):
    logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(message)s')
    
    # Look for checkpoint files in the output folder.
    checkpoint_nums = []
    if os.path.exists(args.out_fol):
        for path in os.listdir(args.out_fol):
            if path.startswith("checkpoint-") and (not path.endswith("last")) and os.path.isdir(os.path.join(args.out_fol, path)):
                checkpoint_nums.append(int(path[len("checkpoint-"):]))
    if os.path.exists(os.path.join(args.out_fol, 'checkpoint-last')):
        pretrained = os.path.join(args.out_fol, "checkpoint-last")
    elif checkpoint_nums:
        latest_ckpt_num = sorted(checkpoint_nums)[-1]
        pretrained = os.path.join(args.out_fol, f"checkpoint-{latest_ckpt_num}")
    else:
        pretrained = args.model_name
        
    def process_race_example(example):
        """Convert a RACE example into an (input_str, label_str) pair for seq2seq training."""
        answer_format = args.answer_format
        if answer_format in ('letter', 'text'):
            answer_labels = string.ascii_uppercase
        else:
            answer_labels = list(range(4))
        # Input: the question, the four labelled options, then the article, all lowercased.
        context = f"{example['question']} \n "
        answer_labels_rev = {"A": 0, "B": 1, "C": 2, "D": 3}
        for i in range(len(answer_labels_rev)):
            context += f"({answer_labels[i]}): {example['options'][i]} "
        context += f"\n {example['article']}"
        context = context.lower()
        # Target: the answer's text, its letter, or its index, depending on --answer_format.
        if answer_format == 'text':
            label = example['options'][answer_labels_rev[example['answer'].upper()]].lower()
        elif answer_format == 'letter':
            label = example['answer'].lower()
        else:
            label = str(answer_labels_rev[example['answer'].upper()])
        return {"input_str": context, "label_str": label}

    model = AutoModelForSeq2SeqLM.from_pretrained(
        pretrained,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
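    # With --parallelize, model.parallelize() splits the T5 blocks across the visible GPUs;
    # otherwise the whole model is moved to a single GPU.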
    if args.parallelize:
        model.parallelize()
    else:
        model = model.cuda()

    def tokenize_fn(examples):
        tokenized_inputs = tokenizer(
            examples["input_str"],
            max_length=args.max_input_length,
            truncation="only_first",
            padding="max_length",
        )
        tokenized_labels = tokenizer(
            examples["label_str"],
            max_length=100,
            truncation="only_first",
            padding="max_length",
            return_tensors="pt",
        )
        target_ids = tokenized_labels["input_ids"]
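        # Replace padding positions in the labels with -100 so they are ignored by the loss.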
        target_ids[target_ids[:, :] == tokenizer.pad_token_id] = -100
        out_dict = {
            "input_ids": tokenized_inputs["input_ids"],
            "attention_mask": tokenized_inputs["attention_mask"],
            "labels": target_ids.tolist(),
        }
        return out_dict
    
    # Load the RACE reading-comprehension dataset ("all" combines the middle- and high-school subsets).
    race = datasets.load_dataset("race", "all")

    preprocessed_train = race["train"].map(process_race_example, remove_columns=[
        'example_id', 'article', 'answer', 'question', 'options'
    ])
    tokenized_train = preprocessed_train.map(tokenize_fn, batched=True)
    tokenized_train.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
    
    # Decode one training example to sanity-check the tokenization and label masking
    print(tokenizer.decode(tokenized_train['input_ids'][0]))
    l_copy = copy.deepcopy(tokenized_train['labels'][0])
    l_copy[l_copy[:] == -100] = tokenizer.pad_token_id
    print(tokenizer.decode(l_copy))

    preprocessed_validation = race["validation"].map(process_race_example, remove_columns=[
        'example_id', 'article', 'answer', 'question', 'options'
    ])
    tokenized_validation = preprocessed_validation.map(tokenize_fn, batched=True)
    tokenized_validation.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

    training_args = Seq2SeqTrainingArguments(
        output_dir=args.out_fol,
        logging_dir=args.out_fol,
        disable_tqdm=False,
        do_train=True,
        evaluation_strategy="steps",
        eval_steps=5000,
        gradient_accumulation_steps=args.grad_accum_steps,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=1,
        learning_rate=args.learning_rate,
        max_steps=args.num_steps,
        save_steps=5000,
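        # Move prediction tensors off the GPU after every eval step to limit memory use.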
        eval_accumulation_steps=1,
        logging_steps=500,
    )
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_validation,
        tokenizer=tokenizer,
    )
    
    # Resume from the detected checkpoint if one was found; otherwise train the base model from scratch.
    if "checkpoint-" in pretrained:
        print(f"Training from checkpoint: {pretrained}")
        trainer.train(resume_from_checkpoint=pretrained)
    else:
        trainer.train()
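    # Save the final weights as "checkpoint-last" so a later run resumes from them (see the checkpoint lookup above).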
    trainer.save_model(output_dir=os.path.join(args.out_fol, "checkpoint-last"))
    
    # Output predictions
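    # predict_with_generate is not enabled, so trainer.predict returns raw logits; they are decoded greedily below.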
    val_batch_size = 100
    outputs = []
    for i in range(math.ceil(len(tokenized_validation) / val_batch_size)):
        data = datasets.Dataset.from_dict(tokenized_validation[i*val_batch_size:(i+1)*val_batch_size])
        preds, labels, eval_metrics = trainer.predict(data, num_beams=1)
        for j, (p, l) in enumerate(zip(preds[0], labels)):
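            # Greedy decode: argmax over the vocabulary dimension of the teacher-forced logits.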
            ex_preds = np.argmax(p, axis=-1)
            pred_str = tokenizer.decode(ex_preds)
            if "</s>" in pred_str:
                pred_str = pred_str[:pred_str.find("</s>")]
            l[l[:] == -100] = tokenizer.pad_token_id
            label_str = tokenizer.decode(l)
            if "</s>" in label_str:
                label_str = label_str[:label_str.find("</s>")]
            original_idx = i*val_batch_size+j
            outputs.append((preprocessed_validation['input_str'][original_idx], pred_str, label_str))
            
    # Write (input, prediction, gold label) triples to a CSV for manual inspection.
    results_csv_path = os.path.join(args.out_fol, "validation_predictions.csv")
    with open(results_csv_path, "w", newline="") as f:
        csvwriter = csv.writer(f)
        fieldnames = ["Input", "Predicted", "Label"]
        csvwriter.writerow(fieldnames)
        for row in outputs:
            csvwriter.writerow(row)

if __name__ == "__main__":
    main(parse_args())