Untitled
user_3839718
python
import argparse
import sys

import pandas as pd
import torch
import wandb
from accelerate import Accelerator
from elasticdb import ElasticDB
from loguru import logger
from rouge_score import rouge_scorer
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, get_linear_schedule_with_warmup

# Configure Loguru logger
logger.add(sys.stdout, format="{time} {level} {message}", level="DEBUG")
logger.add("train.log", format="{time} {level} {message}", level="DEBUG",
           rotation="1 week", retention="10 days")


# Preprocessing: turn one dataframe row into an (input, target) text pair
def format_recipe(data_row):
    ingredients = '\n'.join(data_row['ingredients']).strip().lower()
    instructions = '\n'.join(data_row['directions']).strip().lower()
    input_text = f"generate recipe instructions: {ingredients} </s>"
    target_text = f"{instructions} </s>"
    return input_text, target_text


def tokenize_data(examples, tokenizer, max_length=512):
    input_ids = []
    attention_masks = []
    labels = []
    for input_text, target_text in examples:
        input_encoding = tokenizer(input_text, max_length=max_length, padding='max_length',
                                   truncation=True, return_tensors='pt')
        target_encoding = tokenizer(target_text, max_length=max_length, padding='max_length',
                                    truncation=True, return_tensors='pt')
        input_ids.append(input_encoding['input_ids'][0])
        attention_masks.append(input_encoding['attention_mask'][0])
        labels.append(target_encoding['input_ids'][0])
    return {'input_ids': torch.stack(input_ids),
            'attention_mask': torch.stack(attention_masks),
            'labels': torch.stack(labels)}


class RecipeDataset(Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings['input_ids'])

    def __getitem__(self, idx):
        return {
            'input_ids': self.encodings['input_ids'][idx].clone().detach(),
            'attention_mask': self.encodings['attention_mask'][idx].clone().detach(),
            'labels': self.encodings['labels'][idx].clone().detach()
        }


def generate_predictions(model, dataloader, tokenizer, device):
    model.eval()
    predictions = []
    references = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask)
            predictions.extend(
                [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                 for g in outputs])
            references.extend(
                [tokenizer.decode(l, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                 for l in labels])
    return predictions, references


def compute_rouge(predictions, references):
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    scores = [scorer.score(ref, pred) for pred, ref in zip(predictions, references)]
    # Average the F-measures across the validation set (reported as percentages)
    avg_scores = {}
    for key in scores[0].keys():
        avg_scores[f'rouge_{key}'] = sum(score[key].fmeasure for score in scores) / len(scores) * 100
    return avg_scores


if __name__ == "__main__":
    pretrained_model = "t5-base"
    wandb_project_name = "t5-base-recipe-model"
    es = ElasticDB()

    # Command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--upper_limit", type=int, default=10)
    parser.add_argument("--batch_per_gpu", type=int, default=8)
    args = parser.parse_args()

    wandb.login(key="7ee655b7d2b7349a2a188f1b2b4a4aeda0bf7460")
    wandb.init(project=wandb_project_name)
    wandb.config.update(args)

    # Accelerate picks the device (CPU or GPU) and handles distributed setup
    accelerator = Accelerator()
    device = accelerator.device
    num_epochs = args.num_epochs
    logger.info("CUDA is available: {}".format(torch.cuda.is_available()))
    logger.info("GPU Count: {}".format(torch.cuda.device_count()))
    logger.info("Current device: {}".format(device))

    # Fetch recipe documents from Elasticsearch
    es_client = es.search(index_name="recipenlp",
                          todos=True,
                          query={"size": 10, "query": {"match_all": {}}},
                          upper_limit=args.upper_limit)
    df = pd.DataFrame(es_client).sample(frac=1).reset_index(drop=True)
    logger.info("Dataframe created")

    formatted_data = [format_recipe(row) for _, row in df.iterrows()]
    logger.info("Data formatted")

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model, use_fast=True, model_max_length=128)
    logger.info("Tokenizer created")

    tokenized_data = tokenize_data(formatted_data, tokenizer)
    logger.info("Data tokenized")

    # 80/20 train/validation split
    train_size = int(0.8 * len(tokenized_data['input_ids']))
    train_dataset = RecipeDataset({k: v[:train_size] for k, v in tokenized_data.items()})
    val_dataset = RecipeDataset({k: v[train_size:] for k, v in tokenized_data.items()})
    logger.info('Train dataset size: {}'.format(len(train_dataset)))
    logger.info('Validation dataset size: {}'.format(len(val_dataset)))

    # Initialize the model (device placement is handled by accelerator.prepare below)
    model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model)
    logger.info("Model created")

    # DataLoaders
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_per_gpu,
                              shuffle=True,
                              num_workers=4,
                              drop_last=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_per_gpu,
                            shuffle=False,
                            num_workers=4,
                            drop_last=(accelerator.mixed_precision == "fp8"))
    logger.info("DataLoader created")

    # Optimizer and linear warmup schedule
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=100,
        num_training_steps=(len(train_loader) * num_epochs)
    )

    # Wrap everything with Accelerate and use the prepared objects from here on
    model, optimizer, train_loader, val_loader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_loader, val_loader, lr_scheduler
    )
    logger.info("Device: {}".format(device))

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        progress_bar = tqdm(enumerate(train_loader), total=len(train_loader),
                            desc=f"Epoch {epoch + 1}/{num_epochs}")
        for step, batch in progress_bar:
            outputs = model(**batch)  # Forward pass (batches are already on the right device)
            loss = outputs.loss
            total_loss += loss.item()

            # Backward and optimize
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # Update progress bar and log to WandB
            progress_bar.set_postfix({'Loss': f'{loss.item():.4f}'})
            if step % 10 == 0:
                wandb.log({"train_loss": loss.item(), "epoch": epoch, "step": step})

        # Evaluation after each epoch
        model.eval()
        predictions, references = generate_predictions(model, val_loader, tokenizer, device)
        scores = compute_rouge(predictions, references)
        wandb.log(scores)
        logger.info(scores)

    # Save model and tokenizer from the main process only
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    if accelerator.is_main_process:
        unwrapped_model.save_pretrained(f"./{wandb_project_name}")
        tokenizer.save_pretrained(f"./{wandb_project_name}")

    # Finish WandB logging
    wandb.finish()
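For reference, below is a minimal sketch of loading the checkpoint this script saves and generating instructions for a new ingredient list. The output directory t5-base-recipe-model and the prompt prefix mirror the training script above; the ingredient list and the generation settings (beam search, max_length=256) are illustrative assumptions, not part of the original code.

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the checkpoint written by the training script
model_dir = "./t5-base-recipe-model"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
model.eval()

# Hypothetical ingredient list; the prompt prefix matches format_recipe() above
ingredients = "\n".join(["2 cups flour", "1 cup sugar", "3 eggs"]).lower()
prompt = f"generate recipe instructions: {ingredients}"

inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
outputs = model.generate(**inputs, max_length=256, num_beams=4, early_stopping=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

If the training script is saved as, say, train.py (the file name is an assumption), it would typically be started with accelerate launch train.py --num_epochs 3 --batch_per_gpu 8 so that Accelerate can configure single- or multi-GPU execution.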