Untitled
unknown
plain_text
a year ago
3.3 kB
7
Indexable
"""Distributed batch-inference benchmark for a HuggingFace causal LM.

Spawns one process per GPU, runs a forward pass over a dataset, and
appends throughput numbers (tokens/sec) to ``inference_times.csv``.
"""

import argparse
import csv
import os
import time
from typing import Dict, List

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset


def setup(rank: int, world_size: int) -> None:
    """Initialise the NCCL process group for this rank.

    ``init_process_group`` needs a rendezvous point; when launched via
    ``mp.spawn`` (rather than torchrun) nothing sets MASTER_ADDR /
    MASTER_PORT, so provide single-node defaults. ``setdefault`` keeps
    any externally supplied values.
    """
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup() -> None:
    """Tear down the distributed process group."""
    dist.destroy_process_group()


class CustomDataset(Dataset):
    """Eagerly tokenised text dataset.

    All texts are tokenised once in ``__init__`` (padded to the longest
    sequence in the corpus, truncated to ``max_length``); ``__getitem__``
    then just slices the cached tensors.
    """

    def __init__(self, texts: List[str], tokenizer, max_length: int):
        self.encodings = tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors="pt",
        )

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        return {key: val[idx] for key, val in self.encodings.items()}

    def __len__(self) -> int:
        return len(self.encodings.input_ids)


def run_inference(rank: int, world_size: int, args) -> float:
    """Per-process worker: forward-pass the whole shard and time it.

    Rank 0 appends a CSV row:
    [num_gpus, model_path, data_path, max_length, batch_size,
     inference_time, tokens_per_second].

    Returns the measured tokens/second for this rank.
    """
    setup(rank, world_size)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    # Load model and tokenizer.
    model = AutoModelForCausalLM.from_pretrained(args.model_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    # Many causal-LM tokenizers (e.g. GPT-2) ship without a pad token;
    # CustomDataset calls the tokenizer with padding=True, which would
    # raise without one. EOS is the conventional fallback for inference.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Mixed precision (fp16) + DDP wrapper.
    model = model.half().to(device)
    model = DDP(model, device_ids=[rank])

    # Load dataset; character-level clip is only a cheap pre-filter —
    # real truncation to max_length tokens happens in the tokenizer.
    dataset = load_dataset(args.data_path)
    texts = [text[: args.max_length] for text in dataset["train"]["text"]]

    # Shard the data across ranks.
    dataset = CustomDataset(texts, tokenizer, max_length=args.max_length)
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
    )

    # Timed inference loop.
    model.eval()
    total_tokens = 0
    start_time = time.time()
    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            total_tokens += batch["input_ids"].numel()
    # CUDA kernels are queued asynchronously and ranks finish at
    # different times — synchronise the device and all ranks so the
    # measured wall time reflects the slowest worker, not just queueing.
    torch.cuda.synchronize(device)
    dist.barrier()
    end_time = time.time()

    inference_time = end_time - start_time
    # Guard against a degenerate ~0s measurement (e.g. empty shard).
    tokens_per_second = total_tokens / inference_time if inference_time > 0 else 0.0

    if rank == 0:
        with open("inference_times.csv", "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(
                [
                    args.num_gpus,
                    args.model_path,
                    args.data_path,
                    args.max_length,
                    args.batch_size,
                    inference_time,
                    tokens_per_second,
                ]
            )

    cleanup()
    return tokens_per_second


def main() -> None:
    """Parse CLI args and spawn one inference worker per GPU."""
    parser = argparse.ArgumentParser(
        description="Batch inference on a HuggingFace Causal LM"
    )
    parser.add_argument("--num_gpus", type=int, default=4, help="Number of GPUs to use")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the model")
    parser.add_argument(
        "--data_path", type=str, required=True, help="Path to the HuggingFace dataset"
    )
    parser.add_argument("--max_length", type=int, default=512, help="Maximum sequence length")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    args = parser.parse_args()

    # mp.spawn calls run_inference(rank, *args) with rank prepended,
    # so this yields run_inference(rank, world_size=num_gpus, args).
    mp.spawn(run_inference, args=(args.num_gpus, args), nprocs=args.num_gpus)


if __name__ == "__main__":
    main()
Editor is loading...
Leave a Comment