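"""Distributed batch-inference benchmark for a HuggingFace causal LM.

One process is spawned per GPU via torch.multiprocessing; each rank runs fp16
forward passes over its DistributedSampler shard, and rank 0 appends the
measured throughput to inference_times.csv.
"""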
import argparse
import csv
import os
import time
from typing import Dict, List

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
def setup(rank: int, world_size: int):
    # init_process_group needs a rendezvous endpoint; default to a single-node
    # setup, overridable via the environment.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
    dist.destroy_process_group()
class CustomDataset(Dataset):
    """Pre-tokenized dataset that yields dicts of input tensors."""

    def __init__(self, texts: List[str], tokenizer, max_length: int):
        self.encodings = tokenizer(texts, truncation=True, padding=True, max_length=max_length, return_tensors="pt")

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        return {key: val[idx] for key, val in self.encodings.items()}

    def __len__(self) -> int:
        return len(self.encodings.input_ids)
def run_inference(rank: int, world_size: int, args):
    setup(rank, world_size)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(args.model_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    if tokenizer.pad_token is None:
        # Causal LM tokenizers (e.g. GPT-2) often ship without a pad token,
        # which would make padding=True fail below.
        tokenizer.pad_token = tokenizer.eos_token

    # Mixed precision: run the whole model in fp16 on this rank's GPU
    model = model.half().to(device)
    model = DDP(model, device_ids=[rank])

    # Load dataset; clipping to max_length characters is only a cheap pre-filter,
    # the tokenizer enforces the real limit in tokens
    dataset = load_dataset(args.data_path)
    texts = [text[:args.max_length] for text in dataset['train']['text']]

    # Prepare dataset and dataloader; DistributedSampler gives each rank a
    # disjoint shard of the data
    dataset = CustomDataset(texts, tokenizer, max_length=args.max_length)
    sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, num_workers=4, pin_memory=True)
    # Inference: plain forward passes (no generation), timed per rank
    model.eval()
    total_tokens = 0
    start_time = time.time()
    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            total_tokens += batch['input_ids'].numel()  # includes padding tokens
    end_time = time.time()

    inference_time = end_time - start_time
    tokens_per_second = total_tokens / inference_time
    if rank == 0:
        # Append rank 0's stats; these cover only this rank's shard of the data
        with open('inference_times.csv', 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([args.num_gpus, args.model_path, args.data_path, args.max_length, args.batch_size, inference_time, tokens_per_second])
    cleanup()
    return tokens_per_second
def main():
    parser = argparse.ArgumentParser(description="Batch inference on a HuggingFace Causal LM")
    parser.add_argument("--num_gpus", type=int, default=4, help="Number of GPUs to use")
    parser.add_argument("--model_path", type=str, required=True, help="Path or hub ID of the model")
    parser.add_argument("--data_path", type=str, required=True, help="Path or hub ID of the HuggingFace dataset")
    parser.add_argument("--max_length", type=int, default=512, help="Maximum sequence length")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size per GPU")
    args = parser.parse_args()

    # mp.spawn calls run_inference(rank, world_size, args) once per process
    mp.spawn(run_inference, args=(args.num_gpus, args), nprocs=args.num_gpus)
if __name__ == "__main__":
    main()
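
# Example invocation (a sketch; the script filename and the model/dataset IDs
# are placeholders, not from the original):
#   python batch_inference.py --model_path <hf-model-id-or-path> \
#       --data_path <hf-dataset-id-or-path> --num_gpus 4 --batch_size 32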