Training code
import torch
from transformers import LlavaForConditionalGeneration, AutoProcessor, Trainer, TrainingArguments, DataCollatorForSeq2Seq, BitsAndBytesConfig
from datasets import load_dataset, concatenate_datasets
from PIL import Image
from peft import LoraConfig, get_peft_model, TaskType
from accelerate import dispatch_model, infer_auto_device_map
from transformers.utils import logging
from main import get_image_embedding, generate_mask  # helpers from the author's local main.py (not used below)
import numpy as np
import os

print("HF_HOME:", os.getenv("HF_HOME"))  # Run `export HF_HOME=/media/robot/VR/huggingface_cache` first
print(torch.cuda.get_device_properties(0))
print("torch version:", torch.__version__)

# 1. Load Dataset
# Replace with your Hugging Face dataset name
dataset_name = "ntudlcv/dlcv_2024_final1"
dataset = load_dataset(dataset_name, split="train")

# If you also want to use the validation split as training data:
'''
train_dataset = load_dataset(dataset_name, split="train")
val_dataset = load_dataset(dataset_name, split="val")
# Combine train and val datasets into a single training dataset
dataset = concatenate_datasets([train_dataset, val_dataset])
'''
#dataset = dataset.train_test_split(test_size=0.99)["train"]

# 2. Preprocess the Dataset
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
processor.tokenizer.padding_side = "right"
processor.patch_size = 14  # Common default for LLaVA models (CLIP ViT-L/14)
processor.vision_feature_select_strategy = "default"


def preprocess_function(examples):
    """Tokenize the conversation text and encode the image for one example."""
    # User prompt (e.g., instructions)
    text_prompt = examples["conversations"][0]["value"]
    # Assistant response (e.g., description, explanation, or driving advice)
    #print("text prompt", text_prompt)
    text_response = examples["conversations"][1]["value"]

    image = examples["image"]
    if not isinstance(image, Image.Image):
        image = Image.open(image).convert("RGB")  # Ensure RGB format
    np_image = np.array(image)

    image_inputs = processor.image_processor(image, return_tensors="pt")
    #pixel_values = image_inputs["pixel_values"].to(torch.bfloat16)
    pixel_values = image_inputs["pixel_values"]

    combined_text = f"\nUSER: {text_prompt} \nASSISTANT:{text_response} {processor.tokenizer.eos_token}"
    #combined_text = f"{text_prompt} {text_response} {processor.tokenizer.eos_token}"

    # Tokenize the combined prompt + response
    text_inputs = processor.tokenizer(
        combined_text,
        padding="max_length",   # Fixed-length padding
        truncation=True,        # Truncate if longer than max length
        max_length=512,
        return_tensors="pt",
    )

    input_ids = text_inputs["input_ids"].squeeze(0)
    labels = input_ids.clone()
    # +1 for the BOS token; the "\nUSER: ... \nASSISTANT:" wrapper tokens are not
    # counted here, so the prompt masking is approximate.
    prompt_length = len(processor.tokenizer(text_prompt, add_special_tokens=False)["input_ids"]) + 1
    labels[:prompt_length] = -100  # Ignore prompt tokens in the loss
    labels[labels == processor.tokenizer.pad_token_id] = -100  # Also ignore padding tokens
    #print("Labels when training", labels)

    # Combine the inputs manually
    inputs = {
        "input_ids": input_ids,
        "attention_mask": text_inputs["attention_mask"].squeeze(0),
        "pixel_values": pixel_values.squeeze(0),
        "labels": labels,
        "image": np_image,
    }
    return inputs


# 3. Apply the preprocessing to the dataset
processed_dataset = dataset.map(preprocess_function, remove_columns=dataset.column_names, num_proc=8)  # num_proc=8 for faster processing
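# --- Optional sanity check (not in the original post; added for illustration) ---
# Decode only the supervised tokens of one processed example to confirm that the
# label masking leaves (roughly) just the assistant response as the training target.
sample = processed_dataset[0]
supervised_ids = [tok for tok in sample["labels"] if tok != -100]
print("Supervised target text:", processor.tokenizer.decode(supervised_ids, skip_special_tokens=True))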
# 4. Training Arguments
training_args = TrainingArguments(
    output_dir="./llava_finetuned",     # Directory to save the model
    per_device_train_batch_size=1,      # Adjust batch size based on GPU memory
    gradient_accumulation_steps=4,      # Simulates a larger batch size
    learning_rate=2e-4,                 # Standard learning rate for LoRA fine-tuning
    num_train_epochs=2,                 # Number of fine-tuning epochs
    logging_dir="./logs",               # Logs directory
    logging_steps=10,                   # Log every 10 steps
    save_strategy="epoch",              # Save the model at the end of each epoch
    report_to="none",                   # Disable reporting to external tools
    fp16=False,
    bf16=True,                          # bfloat16 mixed precision for faster training
    dataloader_num_workers=4,           # Parallelize data loading
    logging_first_step=True,            # Log the first step
    logging_strategy="steps",
)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="auto",
)


def force_bf16_hook(module, inputs, output):
    """Forward hook that casts module outputs to bfloat16."""
    if isinstance(output, torch.Tensor):
        return output.to(torch.bfloat16)
    elif isinstance(output, tuple):
        return tuple(o.to(torch.bfloat16) if isinstance(o, torch.Tensor) else o for o in output)
    else:
        return output


# Register the hook on the vision tower so its outputs match the bf16 compute dtype
for submodule in model.vision_tower.modules():
    if isinstance(submodule, torch.nn.Module):
        submodule.register_forward_hook(force_bf16_hook)

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,   # Causal language model
    r=8,                            # LoRA rank (small rank for efficient fine-tuning)
    lora_alpha=32,                  # Scaling factor for LoRA
    lora_dropout=0.1,               # Dropout to prevent overfitting
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],  # Attention and MLP projection layers
)

model = get_peft_model(model, lora_config)

# Cast LoRA parameters to bfloat16 and make sure they stay trainable
for name, param in model.named_parameters():
    if "lora" in name:  # LoRA layers are named with 'lora'
        param.data = param.data.to(torch.bfloat16)
        param.requires_grad = True

# Keep the vision tower frozen
model.vision_tower.eval()
for param in model.vision_tower.parameters():
    param.requires_grad = False

model.print_trainable_parameters()  # Show trainable parameters (should be much fewer now)

# Define the data collator for dynamic padding
data_collator = DataCollatorForSeq2Seq(
    tokenizer=processor.tokenizer,
    model=model,
    #padding="max_length",  # Force padding to a fixed length
    #max_length=256,
    pad_to_multiple_of=8,   # Optional optimization
)

# 5. Initialize the Trainer with the data collator
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=processed_dataset,
    tokenizer=processor.tokenizer,
    data_collator=data_collator,
)

# 6. Fine-tune the model
print("Starting fine-tuning with LoRA...")
trainer.train()

# 7. Save the fine-tuned model (LoRA adapter + processor)
print("Saving the model...")
model.save_pretrained("./llava_finetuned_lora")
processor.save_pretrained("./llava_finetuned_lora")
print("Fine-tuning complete! Model saved to './llava_finetuned_lora'")
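To try the result after training, one option is to reload the 4-bit base model and attach the saved LoRA adapter with peft. The sketch below is not part of the original script and rests on a few assumptions: the same base checkpoint and the ./llava_finetuned_lora output directory used above, a placeholder image path (test_image.jpg) and question, and the standard LLaVA-1.5 "USER: <image> ... ASSISTANT:" prompt template (the training script does not insert an <image> token, so adjust the prompt to match what your transformers version expects).

import torch
from PIL import Image
from transformers import LlavaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from peft import PeftModel

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# Reload the quantized base model, then attach the LoRA adapter saved by the training script
base_model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model = PeftModel.from_pretrained(base_model, "./llava_finetuned_lora")
model.eval()

processor = AutoProcessor.from_pretrained("./llava_finetuned_lora")
processor.patch_size = 14
processor.vision_feature_select_strategy = "default"

image = Image.open("test_image.jpg").convert("RGB")  # placeholder path
prompt = "USER: <image>\nDescribe the driving scene and give advice. ASSISTANT:"  # assumed template
inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=256)
print(processor.tokenizer.decode(output_ids[0], skip_special_tokens=True))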