Llama BatchError finetuning

from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from datasets import load_dataset
import transformers
import torch
# from torch import nn
import numpy as np
import evaluate

PAD_TOKEN = '<pad>'

print('Loading model and tokenizer...')
# num_labels only configures sequence-classification heads; a causal-LM head ignores it.
# Full fp32 weights for a 7B model need roughly 28 GB; pass torch_dtype=torch.float16 to
# from_pretrained if memory is tight.
model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-chat-hf')
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf')

# Add a padding token (the Llama-2 tokenizer ships without one) and register its id with the model
tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
model.config.pad_token_id = tokenizer.pad_token_id  # tokenizer(PAD_TOKEN) returns a BatchEncoding, not an int id
model.resize_token_embeddings(len(tokenizer))
# Manually rebuilding model.embed_tokens is not needed: resize_token_embeddings above already
# extends the input embedding matrix to cover the new token.
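
# Optional sanity check (not in the original script): tokenizer and model should now agree
# on the pad id, and the resized embedding table should cover the enlarged vocabulary.
assert model.config.pad_token_id == tokenizer.pad_token_id
assert model.get_input_embeddings().num_embeddings >= len(tokenizer)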

# Per-epoch evaluation requires an eval_dataset, which is passed to the Trainer below via a
# held-out split. The default per-device batch size of 8 may not fit 2048-token sequences on
# a 7B model; lower it here if training hits out-of-memory errors.
training_args = TrainingArguments(
    output_dir="/nfs/hpc/share/wildma/ml_model_stuff/llama_training/training_session_one",
    evaluation_strategy="epoch",
)

# Load and tokenize the training set
# training_set = load_dataset('json', data_files='/nfs/hpc/share/wildma/ml_model_stuff/llama_training/verbose_labeled/train.json')
training_set = load_dataset('json', data_files='/nfs/hpc/share/wildma/ml_model_stuff/llama_training/deprecated/testset-02.json', field='data')
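# Illustrative shape of the JSON file, inferred from field='data' above and the 'text' column
# used by tokenize_function below (not the actual training data):
#   { "data": [ { "text": "..." }, { "text": "..." } ] }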

def tokenize_function(examples):
    # Pad/truncate every example to a fixed 2048 tokens (half of Llama-2's 4096-token context)
    return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=2048)

tokenized_training_set = training_set.map(tokenize_function, batched=True)
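
# The script enables per-epoch evaluation but never builds an eval set, so the Trainer would
# have nothing to evaluate. A minimal sketch: hold out 10% of the single 'train' split. The
# 10% size and the seed are assumptions, not taken from the original script.
split_sets = tokenized_training_set['train'].train_test_split(test_size=0.1, seed=42)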

# Metrics: token-level accuracy on the held-out split. Note that eval_pred carries the full
# (batch, seq_len, vocab) logits, which can get large for big evaluation sets.
metric = evaluate.load('accuracy')
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    # The logit at position i predicts token i + 1: shift, then drop the positions the
    # collator masked with -100 (padding).
    predictions = predictions[:, :-1].reshape(-1)
    labels = labels[:, 1:].reshape(-1)
    keep = labels != -100
    return metric.compute(predictions=predictions[keep], references=labels[keep])
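
# The tokenized examples carry no `labels` field, so the model cannot return a loss and
# Trainer.train() fails. A minimal sketch of the usual causal-LM fix: DataCollatorForLanguageModeling
# with mlm=False copies input_ids into labels and marks padding positions with -100 so the
# loss (and the accuracy metric above) ignores them.
data_collator = transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)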

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=split_sets['train'],
    eval_dataset=split_sets['test'],   # required by evaluation_strategy="epoch"
    data_collator=data_collator,       # supplies `labels` so the model returns a loss
    compute_metrics=compute_metrics,
)

print('Training the model...')
trainer.train()

print('Training complete. Saving the model...')
trainer.save_model('/nfs/hpc/share/wildma/ml_model_stuff/llama_training/trained_model_one')
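
# Not in the original script: save the tokenizer alongside the model so the added <pad>
# token comes back on reload (save_model alone does not write tokenizer files here because
# the tokenizer was never handed to the Trainer).
tokenizer.save_pretrained('/nfs/hpc/share/wildma/ml_model_stuff/llama_training/trained_model_one')

# Optional usage sketch: reload the checkpoint and generate once to confirm it loads cleanly.
# The prompt string is just a placeholder.
# reloaded = AutoModelForCausalLM.from_pretrained('/nfs/hpc/share/wildma/ml_model_stuff/llama_training/trained_model_one')
# inputs = tokenizer('Hello, world', return_tensors='pt')
# print(tokenizer.decode(reloaded.generate(**inputs, max_new_tokens=32)[0], skip_special_tokens=True))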