Untitled
python
2 months ago
871 B
5
Indexable
Never
"""Stream the C4 corpus and tokenize it in batches with the GPT-2 tokenizer.

Smoke-test script: pulls a bounded number of batches through a
DataLoader whose collate function performs the tokenization.
"""
from functools import partial
from typing import Any

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer, PreTrainedTokenizer


def collate(tokeniser: PreTrainedTokenizer, max_length: int, batch: list) -> Any:
    """Tokenize one batch of dataset samples.

    Args:
        tokeniser: tokenizer used to encode the text.
        max_length: truncation length for each encoded sample.
        batch: list of dataset samples, each a dict with a 'text' field.

    Returns:
        The tokenizer's batch encoding (padded/truncated PyTorch tensors).
    """
    texts = [sample['text'] for sample in batch]
    return tokeniser.batch_encode_plus(
        texts,
        padding=True,
        max_length=max_length,
        truncation=True,
        return_tensors='pt',
    )


def main() -> None:
    """Build the streaming dataloader and iterate ~100 batches."""
    tokeniser: GPT2Tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # BUG FIX: pad_token must be a token *string*; the original assigned
    # `eos_token_id` (an int), which breaks padding in batch_encode_plus.
    tokeniser.pad_token = tokeniser.eos_token

    # streaming=True avoids materializing the full C4 corpus on disk.
    dataset = load_dataset('c4', 'en.noblocklist', split='train', streaming=True)
    dataloader = DataLoader(
        dataset,
        batch_size=1024,
        num_workers=12,
        drop_last=True,
        collate_fn=partial(collate, tokeniser, 100),
        prefetch_factor=4,
    )

    # Stop after fetching batch 101, matching the original counter logic.
    for i, batch in enumerate(dataloader, start=1):
        if i > 100:
            break


if __name__ == "__main__":
    main()