import argparse
import csv
import os
import time
from typing import Dict, List

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

def setup(rank: int, world_size: int):
    # NCCL rendezvous: default to a single-node setup unless the caller
    # has already exported MASTER_ADDR / MASTER_PORT.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class CustomDataset(Dataset):
    def __init__(self, texts: List[str], tokenizer, max_length: int):
        # Tokenize the whole corpus up front; padding=True pads every example
        # to the longest sequence so the default collate_fn can stack tensors.
        self.encodings = tokenizer(texts, truncation=True, padding=True,
                                   max_length=max_length, return_tensors="pt")

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        return {key: val[idx] for key, val in self.encodings.items()}

    def __len__(self) -> int:
        return len(self.encodings.input_ids)

def run_inference(rank: int, world_size: int, args):
    setup(rank, world_size)
    
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
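    # One process per GPU: the process rank doubles as the local CUDA device index.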

    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(args.model_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    if tokenizer.pad_token is None:
        # Many causal-LM tokenizers (e.g. GPT-2) ship without a pad token,
        # which would break padding=True in the dataset below.
        tokenizer.pad_token = tokenizer.eos_token

    # fp16 weights halve memory use and speed up inference on most GPUs
    model = model.half().to(device)
    # DDP is not strictly required for forward-only work, but it keeps the
    # per-rank setup identical to a training script
    model = DDP(model, device_ids=[rank])

    # Load dataset
    dataset = load_dataset(args.data_path)
    # Rough character-level pre-clip (~4 chars per token) to bound tokenizer work;
    # the tokenizer's truncation above enforces the actual token limit
    texts = [text[: args.max_length * 4] for text in dataset['train']['text']]

    # Prepare dataset and dataloader
    dataset = CustomDataset(texts, tokenizer, max_length=args.max_length)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler,
                            num_workers=4, pin_memory=True)

    # Inference
    model.eval()
    total_tokens = 0
    torch.cuda.synchronize(device)  # finish pending transfers before timing
    start_time = time.time()
    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            outputs = model(**batch)  # forward pass only; logits are discarded
            total_tokens += batch['input_ids'].numel()  # note: includes padding tokens

    # Wait for all queued GPU work before stopping the clock
    torch.cuda.synchronize(device)
    end_time = time.time()
    inference_time = end_time - start_time

    # Sum token counts across ranks so the logged throughput is global,
    # not just this rank's share
    token_count = torch.tensor(total_tokens, device=device)
    dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
    tokens_per_second = token_count.item() / inference_time

    if rank == 0:
        # Append one row per run; columns: num_gpus, model, dataset, max_length,
        # batch_size, wall-clock time (s), aggregate tokens/s
        with open('inference_times.csv', 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([args.num_gpus, args.model_path, args.data_path,
                             args.max_length, args.batch_size,
                             inference_time, tokens_per_second])

    cleanup()
    return tokens_per_second

def main():
    parser = argparse.ArgumentParser(description="Batch inference on a HuggingFace Causal LM")
    parser.add_argument("--num_gpus", type=int, default=4, help="Number of GPUs to use")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the model")
    parser.add_argument("--data_path", type=str, required=True, help="Path to the HuggingFace dataset")
    parser.add_argument("--max_length", type=int, default=512, help="Maximum sequence length")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")

    args = parser.parse_args()

    mp.spawn(run_inference, args=(args.num_gpus, args), nprocs=args.num_gpus)

if __name__ == "__main__":
    main()
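
# A minimal sketch of how this script might be launched, assuming it is saved
# as batch_inference.py (illustrative name) and pointed at a dataset such as
# ag_news that exposes a "train" split with a "text" column; model and dataset
# names here are examples, not requirements:
#
#   python batch_inference.py \
#       --num_gpus 4 \
#       --model_path gpt2 \
#       --data_path ag_news \
#       --max_length 512 \
#       --batch_size 32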