# training code
#
# Paste-site metadata (not part of the script): avatar / unknown / plain_text /
# 2 months ago / 7.1 kB / 4 / Indexable
import torch
from transformers import LlavaForConditionalGeneration, AutoProcessor, Trainer, TrainingArguments, DataCollatorForSeq2Seq, BitsAndBytesConfig
from datasets import load_dataset, concatenate_datasets
from PIL import Image
from peft import LoraConfig, get_peft_model, TaskType

from accelerate import dispatch_model, infer_auto_device_map
from transformers.utils import logging
from main import get_image_embedding, generate_mask
import numpy as np
import os
# Environment diagnostics, printed once at start-up.
# Run `export HF_HOME=/media/robot/VR/huggingface_cache` first so the
# Hugging Face cache directory is set before this script starts.
hf_home = os.getenv("HF_HOME")
print("HF_HOME:", hf_home)

# Properties of CUDA device 0.
gpu_properties = torch.cuda.get_device_properties(0)
print(gpu_properties)


print("torch version:", torch.__version__)
# 1. Load Dataset
# Hugging Face dataset identifier; replace with your own dataset name if needed.
dataset_name = "ntudlcv/dlcv_2024_final1"
dataset = load_dataset(dataset_name, split="train")

# To use the validation split as training data too, load both splits and
# concatenate them instead of using the single load above:
#
#     train_dataset = load_dataset(dataset_name, split="train")
#     val_dataset = load_dataset(dataset_name, split="val")
#     dataset = concatenate_datasets([train_dataset, val_dataset])

# Optional: keep only ~1% of the examples (e.g. for a quick trial run).
# dataset = dataset.train_test_split(test_size=0.99)["train"]

# 2. Preprocess the Dataset
# The processor bundles the tokenizer and the image processor used below.
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
processor.tokenizer.padding_side = "right"  # pad after the text, not before it
processor.patch_size = 14  # Common default for LLaVA models
processor.vision_feature_select_strategy = "default"




def preprocess_function(examples):
    """
    Turn one dataset example into tokenized text, image tensors and labels.

    Despite the plural name, ``examples`` is a single example: it is indexed
    as ``examples["conversations"][i]["value"]`` and the function is mapped
    without batching.

    Args:
        examples: mapping with a ``"conversations"`` list (entry 0 is the user
            prompt, entry 1 the assistant response, each under ``"value"``)
            and an ``"image"`` that is either a PIL image or something
            ``PIL.Image.open`` accepts.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels`` (each of
        length 512), ``pixel_values`` from the image processor, and ``image``
        (the image as a NumPy array). In ``labels`` every position that
        belongs to the prompt template or to padding is set to -100 so only
        the response (and the EOS token) contributes to the loss.
    """
    tokenizer = processor.tokenizer

    # User prompt (e.g., instructions)
    text_prompt = examples["conversations"][0]["value"]
    # Assistant response (e.g., description, explanation, or driving advice)
    text_response = examples["conversations"][1]["value"]

    image = examples["image"]
    if not isinstance(image, Image.Image):
        image = Image.open(image).convert("RGB")  # Ensure RGB format

    np_image = np.array(image)

    image_inputs = processor.image_processor(image, return_tensors="pt")
    pixel_values = image_inputs["pixel_values"]

    # Everything up to and including "ASSISTANT:" is context, not a target.
    prompt_text = f"\nUSER: {text_prompt} \nASSISTANT:"
    combined_text = f"{prompt_text}{text_response} {tokenizer.eos_token}"

    text_inputs = tokenizer(
        combined_text,
        padding="max_length",       # Fixed-length padding
        truncation=True,            # Truncate if longer than max length
        max_length=512,
        return_tensors="pt"
    )

    input_ids = text_inputs["input_ids"].squeeze(0)
    attention_mask = text_inputs["attention_mask"].squeeze(0)
    labels = input_ids.clone()

    # Measure the prompt with the same tokenizer settings used for the full
    # text, so any special tokens it prepends and the "USER:" / "ASSISTANT:"
    # template are counted. Counting only the raw prompt tokens would leave
    # part of the prompt unmasked.
    # NOTE(review): the tokenization at the prompt/response boundary may
    # differ by a token from the combined encoding -- confirm if exactness
    # matters.
    prompt_length = len(tokenizer(prompt_text)["input_ids"])
    labels[:prompt_length] = -100  # Ignore prompt tokens

    # Padding positions must not contribute to the loss either.
    labels[attention_mask == 0] = -100

    # Combine the inputs manually
    inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "pixel_values": pixel_values.squeeze(0),
        "labels": labels,
        "image": np_image
    }
    return inputs
   
    


# Run preprocess_function over every example using 8 worker processes and
# drop the original columns, keeping only the fields the function returns.
processed_dataset = dataset.map(
    preprocess_function,
    num_proc=8,
    remove_columns=dataset.column_names,
)


# 4. Training Arguments
training_args = TrainingArguments(
    # Optimisation schedule
    num_train_epochs=2,                 # Number of fine-tuning epochs
    learning_rate=2e-4,
    per_device_train_batch_size=1,      # Adjust to the available GPU memory
    gradient_accumulation_steps=4,      # Accumulate 4 steps per optimizer update
    # Precision: bfloat16 mixed precision is on, float16 is off
    bf16=True,
    fp16=False,
    # Data loading
    dataloader_num_workers=4,           # Parallel data-loading workers
    # Checkpointing
    output_dir="./llava_finetuned",     # Where checkpoints are written
    save_strategy="epoch",              # Save at the end of each epoch
    # Logging
    logging_dir="./logs",
    logging_strategy="steps",
    logging_steps=10,                   # Log every 10 steps
    logging_first_step=True,            # Also log the very first step
    report_to="none",                   # No external experiment trackers
)


# 4-bit NF4 quantization with double quantization; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)


# Load the quantized base model and let `device_map="auto"` place it on the
# available devices.
model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
    low_cpu_mem_usage=True,
)





def force_fp16_hook(module, inputs, output):
    """
    Forward hook that casts a module's floating-point outputs to bfloat16.

    The name is historical: the target dtype is ``torch.bfloat16``, not
    float16. Only floating-point tensors are cast; integer and boolean
    tensors are returned untouched so that index- or mask-like outputs are
    not corrupted by a dtype change.

    Args:
        module: the module whose forward pass just ran (unused).
        inputs: the inputs that were passed to the module (unused).
        output: the value returned by the module's forward pass.

    Returns:
        A bfloat16 tensor when ``output`` is a floating-point tensor; a new
        plain tuple with every floating-point tensor element cast when
        ``output`` is a tuple (other elements are kept as they are);
        otherwise ``output`` itself.
    """
    def _to_bf16(value):
        # Leave non-tensors and non-floating tensors exactly as they are.
        if isinstance(value, torch.Tensor) and value.is_floating_point():
            return value.to(torch.bfloat16)
        return value

    if isinstance(output, tuple):
        return tuple(_to_bf16(item) for item in output)
    return _to_bf16(output)

# Attach the casting hook to every module yielded by the vision tower's
# `modules()` iterator (which only ever yields nn.Module instances).
for vision_module in model.vision_tower.modules():
    vision_module.register_forward_hook(force_fp16_hook)


# Names of the layers that receive LoRA adapters.
# NOTE(review): besides the query/key/value projections this list also names
# "out_proj", "fc1" and "fc2". Those look like vision-encoder layer names --
# confirm. Every parameter under `model.vision_tower` has requires_grad set
# back to False further down, so adapters placed there would presumably never
# be trained.
lora_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"]

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,        # Causal language model
    r=8,                                 # LoRA rank
    lora_alpha=32,                       # LoRA scaling factor
    lora_dropout=0.1,                    # Dropout applied inside the adapters
    target_modules=lora_target_modules,
)


# Wrap the base model with the LoRA adapters described above.
model = get_peft_model(model, lora_config)



# Keep every LoRA parameter (any parameter whose name contains "lora") in
# bfloat16 and mark it trainable.
lora_parameters = (
    weight for weight_name, weight in model.named_parameters()
    if "lora" in weight_name
)
for lora_weight in lora_parameters:
    lora_weight.data = lora_weight.data.to(torch.bfloat16)
    lora_weight.requires_grad = True

# Put the vision tower in eval mode and freeze all of its parameters. This
# runs after the loop above, so it wins for any parameter under vision_tower.
model.vision_tower.eval()
for frozen_weight in model.vision_tower.parameters():
    frozen_weight.requires_grad = False

# Report how many parameters remain trainable.
model.print_trainable_parameters()


# Batch collator. The text is already padded to a fixed length of 512 during
# preprocessing, so little is left for it to pad; `pad_to_multiple_of=8` is an
# optional optimization.
# Disabled alternatives: padding="max_length", max_length=256.
data_collator = DataCollatorForSeq2Seq(
    model=model,
    tokenizer=processor.tokenizer,
    pad_to_multiple_of=8,
)


# 5. Initialize Trainer with data_collator
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=processed_dataset,
    tokenizer=processor.tokenizer,
)


# Step 6: Fine-Tuning the Model
print("Starting fine-tuning with LoRA...")
trainer.train()

# Step 7: Save the Fine-Tuned Model
# The model and the processor are written to the same directory.
print("Saving the model...")
lora_output_dir = "./llava_finetuned_lora"
for artifact in (model, processor):
    artifact.save_pretrained(lora_output_dir)

print("Fine-tuning complete! Model saved to './llava_finetuned_lora'")
# Leave a Comment  (paste-site footer, not part of the script)