Untitled

mail@pastecode.io avatarunknown
python
2 months ago
871 B
5
Indexable
Never
from transformers import GPT2Tokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer
from typing import Any
from functools import partial

def collate(tokeniser: "PreTrainedTokenizer", max_length: int, batch: list) -> Any:
    """Tokenise a batch of dataset samples into padded PyTorch tensors.

    Args:
        tokeniser: Hugging Face tokenizer used to encode the texts.
        max_length: Maximum sequence length; longer texts are truncated.
        batch: Dataset samples, each a dict with a 'text' key.

    Returns:
        The tokeniser's batch encoding (input_ids, attention_mask, ...)
        padded to the longest sequence in the batch, as 'pt' tensors.
    """
    texts = [sample['text'] for sample in batch]
    # Calling the tokenizer directly is the current API; batch_encode_plus
    # is deprecated and forwards to the same implementation.
    return tokeniser(
        texts,
        padding=True,
        max_length=max_length,
        truncation=True,
        return_tensors='pt',
    )

tokeniser: GPT2Tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# GPT-2 ships without a pad token; reuse the EOS *token string* for padding.
# BUG FIX: the original assigned `eos_token_id` (an int) to `pad_token`,
# which expects the token text (e.g. "<|endoftext|>"), not its id.
tokeniser.pad_token = tokeniser.eos_token

# Stream C4 so the (very large) dataset is never fully downloaded.
dataset = load_dataset('c4', 'en.noblocklist', split='train', streaming=True)

# NOTE(review): with a streaming IterableDataset, num_workers > 1 can yield
# duplicated samples unless the dataset shards across workers — confirm this
# is intended before relying on the output for training.
dataloader = DataLoader(
    dataset,
    batch_size=1024,
    num_workers=12,
    drop_last=True,
    collate_fn=partial(collate, tokeniser, 100),
    prefetch_factor=4,
)

# Consume a bounded number of batches (smoke test / throughput warm-up),
# matching the original's behaviour of stopping after ~100 batches.
for i, batch in enumerate(dataloader):
    if i >= 100:
        break