T5 recipe-generation fine-tuning script (Elasticsearch + Accelerate + Weights & Biases)

import argparse

import pandas as pd
import torch
from loguru import logger
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, get_linear_schedule_with_warmup
import wandb
from elasticdb import ElasticDB
import sys
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from rouge_score import rouge_scorer
from accelerate import Accelerator


# Configure Loguru logger
logger.add(sys.stdout, format="{time} {level} {message}", level="DEBUG")
logger.add("train.log", format="{time} {level} {message}", level="DEBUG", rotation="1 week", retention="10 days")


# Preprocessing function
def format_recipe(data_row):
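    """Turn one dataframe row into an (input_text, target_text) pair.

    Assumes each row carries 'ingredients' and 'directions' as lists of strings,
    as returned by the Elasticsearch index queried in __main__.
    """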
    ingredients = '\n'.join(data_row['ingredients']).strip().lower()
    instructions = '\n'.join(data_row['directions']).strip().lower()
    input_text = f"generate recipe instructions: {ingredients} </s>"
    target_text = f"{instructions} </s>"
    return input_text, target_text


def tokenize_data(examples, tokenizer, max_length=512):
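    """Tokenize (input_text, target_text) pairs into fixed-length tensors.

    Pad positions in the labels are replaced with -100 so the seq2seq loss ignores them.
    """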
    input_ids = []
    attention_masks = []
    labels = []
    for input_text, target_text in examples:
        input_encoding = tokenizer(input_text, max_length=max_length, padding='max_length', truncation=True,
                                   return_tensors='pt')
        target_encoding = tokenizer(target_text, max_length=max_length, padding='max_length', truncation=True,
                                    return_tensors='pt')

        # Mask pad tokens in the labels so they are ignored by the loss
        # (-100 is the ignore index used by Hugging Face seq2seq models)
        label_ids = target_encoding['input_ids'][0]
        label_ids[label_ids == tokenizer.pad_token_id] = -100

        input_ids.append(input_encoding['input_ids'][0])
        attention_masks.append(input_encoding['attention_mask'][0])
        labels.append(label_ids)

    return {'input_ids': torch.stack(input_ids),
            'attention_mask': torch.stack(attention_masks),
            'labels': torch.stack(labels)}


class RecipeDataset(Dataset):
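    """Thin torch Dataset over the dict of pre-tokenized tensors."""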
    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings['input_ids'])

    def __getitem__(self, idx):
        return {
            'input_ids': self.encodings['input_ids'][idx].clone().detach(),
            'attention_mask': self.encodings['attention_mask'][idx].clone().detach(),
            'labels': self.encodings['labels'][idx].clone().detach()
        }


def generate_predictions(model, dataloader, tokenizer, device):
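    """Run greedy generation over a dataloader and decode predictions and reference labels to text."""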
    model.eval()
    predictions = []
    references = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Allow enough room for full instructions; generate() otherwise falls back to a short default max length
            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=512)
            predictions.extend(
                [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in outputs])
            # Put pad tokens back in place of the -100 ignore index before decoding the references
            labels = torch.where(labels == -100, torch.full_like(labels, tokenizer.pad_token_id), labels)
            references.extend(
                [tokenizer.decode(l, skip_special_tokens=True, clean_up_tokenization_spaces=True) for l in labels])

    return predictions, references


def compute_rouge(predictions, references):
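    """Average the ROUGE-1/2/L F1 scores over all prediction/reference pairs and return them as percentages."""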
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    scores = []
    for pred, ref in zip(predictions, references):
        scores.append(scorer.score(ref, pred))

    # Averaging the scores
    avg_scores = {}
    for key in scores[0].keys():
        avg_scores[f'rouge_{key}'] = sum([score[key].fmeasure for score in scores]) / len(scores) * 100

    return avg_scores


if __name__ == "__main__":
    pretrained_model = "t5-base"
    wandb_project_name = "t5-base-recipe-model"
    es = ElasticDB()

    parser = argparse.ArgumentParser()
    # Command-line arguments
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--upper_limit", type=int, default=10)
    parser.add_argument("--batch_per_gpu", type=int, default=8)
    multi_gpu = False
    args = parser.parse_args()

    # Read the API key from the WANDB_API_KEY environment variable instead of hard-coding it
    wandb.login()
    wandb.init(project=wandb_project_name)
    wandb.config.update(args)

    if torch.cuda.is_available():
        accelerator = Accelerator()
        device = accelerator.device
        num_epochs = args.num_epochs
        gradient_accumulation_steps = 1
        logger.info("CUDA is available: {}".format(torch.cuda.is_available()))
        logger.info("GPU Count: {}".format(torch.cuda.device_count()))
        logger.info("Current device: {}".format(device))

        # Fetch training data from Elasticsearch
        records = es.search(index_name="recipenlp",
                            todos=True,
                            query={"size": 10, "query": {"match_all": {}}},
                            upper_limit=args.upper_limit)
        # Shuffle the rows before the train/validation split
        df = pd.DataFrame(records).sample(frac=1).reset_index(drop=True)
        logger.info("Dataframe created")
        formatted_data = [format_recipe(row) for _, row in df.iterrows()]
        logger.info("Data formatted")
        # Keep the tokenizer's max length consistent with the 512-token limit used in tokenize_data
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model, use_fast=True, model_max_length=512)
        logger.info("Tokenizer created")
        tokenized_data = tokenize_data(formatted_data, tokenizer)
        logger.info("Data tokenized")
        train_size = int(0.8 * len(tokenized_data['input_ids']))
        train_dataset = RecipeDataset({k: v[:train_size] for k, v in tokenized_data.items()})
        val_dataset = RecipeDataset({k: v[train_size:] for k, v in tokenized_data.items()})

        logger.info('Train dataset size: {}'.format(len(train_dataset)))
        logger.info('Validation dataset size: {}'.format(len(val_dataset)))

        logger.info("DataLoader created")
        # Initialize the model and move it to the device
        model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model)
        model = model.to(accelerator.device)
        logger.info("Model created")

        # DataLoader
        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_per_gpu,
                                  shuffle=True,
                                  num_workers=4,
                                  drop_last=True,
                                  )
        val_loader = DataLoader(val_dataset,
                                batch_size=args.batch_per_gpu,
                                shuffle=False,
                                num_workers=4,
                                drop_last=(accelerator.mixed_precision == "fp8"),
                                )
        logger.info("DataLoaders created")

        # Initialize the optimizer and scheduler
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=100,
            num_training_steps=(len(train_loader) * num_epochs)
        )

        # Reassign to the same names so the training and evaluation loops below use the prepared objects
        model, optimizer, train_loader, val_loader, lr_scheduler = accelerator.prepare(
            model, optimizer, train_loader, val_loader, lr_scheduler
        )
        logger.info("Device: {}".format(device))
        for epoch in range(num_epochs):
            model.train()
            total_loss = 0
            progress_bar = tqdm(enumerate(train_loader),
                                total=len(train_loader),
                                desc=f"Epoch {epoch + 1}/{num_epochs}")

            for step, batch in progress_bar:
                batch = {k: v.to(device) for k, v in batch.items()}
                # Forward pass
                outputs = model(**batch)
                loss = outputs.loss
                total_loss += loss.item()

                # Backward and optimize
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                # Update progress bar
                progress_bar.set_postfix({'Loss': f'{loss.item():.4f}'})

                if step % 10 == 0:
                    wandb.log({"train_loss": loss.item(), "epoch": epoch, "step": step})

            # Evaluation after each epoch; unwrap the model so .generate() is available
            # even when Accelerate wraps it (e.g. in DistributedDataParallel)
            predictions, references = generate_predictions(accelerator.unwrap_model(model), val_loader, tokenizer, device)
            scores = compute_rouge(predictions, references)
            wandb.log(scores)
            logger.info(scores)

        # All processes must reach the barrier; only the main process writes to disk
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(f"./{wandb_project_name}")

            # Save tokenizer alongside the model
            tokenizer.save_pretrained(f"./{wandb_project_name}")

        # Finish WandB logging
        wandb.finish()
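
# Example invocation (script name and argument values are illustrative):
#   accelerate launch train_recipe_t5.py --num_epochs 3 --learning_rate 5e-5 \
#       --upper_limit 10000 --batch_per_gpu 8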