import argparse
import pandas as pd
import torch
from loguru import logger
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, get_linear_schedule_with_warmup
import wandb
from elasticdb import ElasticDB
import sys
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from rouge_score import rouge_scorer
from accelerate import Accelerator
# Configure Loguru logger
logger.add(sys.stdout, format="{time} {level} {message}", level="DEBUG")
logger.add("train.log", format="{time} {level} {message}", level="DEBUG", rotation="1 week", retention="10 days")
# Preprocessing: join the ingredient and direction lists into a T5-style
# input/target text pair (the T5 tokenizer appends the </s> EOS token itself,
# so it is not added manually here)
def format_recipe(data_row):
    ingredients = '\n'.join(data_row['ingredients']).strip().lower()
    instructions = '\n'.join(data_row['directions']).strip().lower()
    input_text = f"generate recipe instructions: {ingredients}"
    target_text = instructions
    return input_text, target_text
def tokenize_data(examples, tokenizer, max_length=512):
    input_ids = []
    attention_masks = []
    labels = []
    for input_text, target_text in examples:
        input_encoding = tokenizer(input_text, max_length=max_length, padding='max_length', truncation=True,
                                   return_tensors='pt')
        target_encoding = tokenizer(target_text, max_length=max_length, padding='max_length', truncation=True,
                                    return_tensors='pt')
        input_ids.append(input_encoding['input_ids'][0])
        attention_masks.append(input_encoding['attention_mask'][0])
        labels.append(target_encoding['input_ids'][0])
    labels = torch.stack(labels)
    # Replace padding token ids in the labels with -100 so they are ignored by the seq2seq loss
    labels[labels == tokenizer.pad_token_id] = -100
    return {'input_ids': torch.stack(input_ids),
            'attention_mask': torch.stack(attention_masks),
            'labels': labels}
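# Minimal torch Dataset wrapper around the pre-tokenized tensors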
class RecipeDataset(Dataset):
def __init__(self, encodings):
self.encodings = encodings
def __len__(self):
return len(self.encodings['input_ids'])
def __getitem__(self, idx):
return {
'input_ids': self.encodings['input_ids'][idx].clone().detach(),
'attention_mask': self.encodings['attention_mask'][idx].clone().detach(),
'labels': self.encodings['labels'][idx].clone().detach()
}
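# Greedy batched generation over a dataloader; returns decoded predictions and reference texts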
def generate_predictions(model, dataloader, tokenizer, device):
    model.eval()
    predictions = []
    references = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            # Restore pad token ids in the labels (masked to -100 for the loss) before decoding
            labels = labels.clone()
            labels[labels == -100] = tokenizer.pad_token_id
            # Without an explicit limit, generate() would stop at the model's default max length (20 tokens)
            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask,
                                     max_new_tokens=256)
            predictions.extend(
                [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in outputs])
            references.extend(
                [tokenizer.decode(l, skip_special_tokens=True, clean_up_tokenization_spaces=True) for l in labels])
    return predictions, references
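# Average the ROUGE-1/2/L F1 scores over all prediction/reference pairs (reported as percentages)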
def compute_rouge(predictions, references):
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    scores = []
    for pred, ref in zip(predictions, references):
        scores.append(scorer.score(ref, pred))
    # Averaging the F1 scores; the keys are already named 'rouge1', 'rouge2', 'rougeL'
    avg_scores = {}
    for key in scores[0].keys():
        avg_scores[key] = sum(score[key].fmeasure for score in scores) / len(scores) * 100
    return avg_scores
if __name__ == "__main__":
pretrained_model = "t5-base"
wandb_project_name = "t5-base-recipe-model"
es = ElasticDB()
parser = argparse.ArgumentParser()
# Command-line arguments
parser.add_argument("--num_epochs", type=int, default=3)
parser.add_argument("--learning_rate", type=float, default=5e-5)
parser.add_argument("--upper_limit", type=int, default=10)
parser.add_argument("--batch_per_gpu", type=int, default=8)
multi_gpu = False
args = parser.parse_args()
    # wandb.login() reads the API key from the WANDB_API_KEY environment variable;
    # never hard-code the key in the script
    wandb.login()
wandb.init(project=wandb_project_name)
wandb.config.update(args)
if torch.cuda.is_available():
accelerator = Accelerator()
device = accelerator.device
        num_epochs = args.num_epochs
gradient_accumulation_steps = 1
logger.info("CUDA is available: {}".format(torch.cuda.is_available()))
logger.info("GPU Count: {}".format(torch.cuda.device_count()))
logger.info("Current device: {}".format(device))
        # Fetch recipe documents from Elasticsearch (ElasticDB is a project-specific wrapper)
es_client = es.search(index_name="recipenlp",
todos=True,
query={"size": 10,
"query": {"match_all": {}}
},
upper_limit=args.upper_limit)
df = pd.DataFrame(es_client).sample(frac=1).reset_index(drop=True)
logger.info("Dataframe created")
formatted_data = [format_recipe(row) for _, row in df.iterrows()]
logger.info("Data formatted")
        # model_max_length matches the 512-token max_length used in tokenize_data
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model, use_fast=True, model_max_length=512)
logger.info("Tokenizer created")
tokenized_data = tokenize_data(formatted_data, tokenizer)
logger.info("Data tokenized")
train_size = int(0.8 * len(tokenized_data['input_ids']))
train_dataset = RecipeDataset({k: v[:train_size] for k, v in tokenized_data.items()})
val_dataset = RecipeDataset({k: v[train_size:] for k, v in tokenized_data.items()})
logger.info('Train dataset size: {}'.format(len(train_dataset)))
logger.info('Validation dataset size: {}'.format(len(val_dataset)))
logger.info("DataLoader created")
# Initialize the model and move it to the device
model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model)
model = model.to(accelerator.device)
logger.info("Model created")
# DataLoader
train_loader = DataLoader(train_dataset,
batch_size=args.batch_per_gpu,
shuffle=True,
num_workers=4,
drop_last=True,
)
val_loader = DataLoader(val_dataset,
batch_size=args.batch_per_gpu,
shuffle=False,
num_workers=4,
drop_last=(accelerator.mixed_precision == "fp8"),
                                )
        logger.info("DataLoaders created")
# Initialize the optimizer and scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
lr_scheduler = get_linear_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=100,
num_training_steps=(len(train_loader) * num_epochs)
)
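        # accelerator.prepare wraps the model, optimizer, dataloaders and scheduler for the
        # current device setup; the prepared dataloaders are assigned back to the names used below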
        model, optimizer, train_loader, val_loader, lr_scheduler = accelerator.prepare(
            model, optimizer, train_loader, val_loader, lr_scheduler
        )
logger.info("Device: {}".format(device))
for epoch in range(num_epochs):
model.train()
total_loss = 0
progress_bar = tqdm(enumerate(train_loader),
total=len(train_loader),
desc=f"Epoch {epoch + 1}/{num_epochs}")
for step, batch in progress_bar:
batch = {k: v.to(device) for k, v in batch.items()}
                # Forward pass
                outputs = model(**batch)
                loss = outputs.loss
total_loss += loss.item()
# Backward and optimize
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Update progress bar
progress_bar.set_postfix({'Loss': f'{loss.item():.4f}'})
if step % 10 == 0:
wandb.log({"train_loss": loss.item(), "epoch": epoch, "step": step})
            # Log the average training loss for the epoch, then evaluate
            wandb.log({"avg_train_loss": total_loss / len(train_loader), "epoch": epoch})
            model.eval()
            predictions, references = generate_predictions(accelerator.unwrap_model(model), val_loader, tokenizer, device)
scores = compute_rouge(predictions, references)
wandb.log(scores)
logger.info(scores)
        # All processes must reach this barrier; only the main process saves the model
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(f"./{wandb_project_name}")
# Save tokenizer
tokenizer.save_pretrained(f"./{wandb_project_name}")
# Finish WandB logging
wandb.finish()