validation adapter script
import os
import random
import argparse
from pathlib import Path
import json
import itertools
import time
import copy
from typing import List

import numpy as np
import cv2
import wandb
from tqdm import tqdm

import torch
import torchvision
from torchvision.utils import make_grid
import torch.nn.functional as F
from torchvision import transforms

from diffusers import StableDiffusionPipeline
from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, UNet2DConditionModel
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
from safetensors import safe_open
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ip_adapter.utils import is_torch2_available, get_generator
from ip_adapter.ip_adapter import ImageProjModel

if is_torch2_available():
    from ip_adapter.attention_processor import (
        AttnProcessor2_0 as AttnProcessor,
    )
    from ip_adapter.attention_processor import (
        CNAttnProcessor2_0 as CNAttnProcessor,
    )
    from ip_adapter.attention_processor import (
        IPAttnProcessor2_0 as IPAttnProcessor,
    )
else:
    from ip_adapter.attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
from ip_adapter.resampler import Resampler

import mae
import dinov2.models.vision_transformer as vits
from StaticCOCOFreeviewDatasetIP_adapter_xl import *

# Reproducibility
seed = 99
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Set deterministic behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def draw_fixations(image, fixations, radius=10, color=(255, 0, 0), thickness=-1):
    """
    Draw fixation points on the image.

    :param image: numpy array of shape (H, W, C)
    :param fixations: list of (array([x1, x2, ...]), array([y1, y2, ...])) pairs
    :param radius: radius of the fixation points
    :param color: color of the fixation points (in the image's channel order)
    :param thickness: thickness of the fixation points (-1 for filled circle)
    :return: image with fixations drawn
    """
    image = image.copy()

    # Ensure fixations is a list
    if not isinstance(fixations, list):
        fixations = [fixations]

    for fixation_pair in fixations:
        if not (isinstance(fixation_pair, tuple) and len(fixation_pair) == 2):
            print(f"Skipping invalid fixation pair: {fixation_pair}")
            continue

        x_coords, y_coords = fixation_pair
        if not (isinstance(x_coords, np.ndarray) and isinstance(y_coords, np.ndarray)):
            print(f"Skipping non-array fixation pair: {fixation_pair}")
            continue

        if len(x_coords) != len(y_coords):
            print(f"Mismatched coordinate lengths: {len(x_coords)} vs {len(y_coords)}")
            continue

        for x, y in zip(x_coords, y_coords):
            try:
                cv2.circle(image, (int(x), int(y)), radius, color, thickness)
            except Exception as e:
                print(f"Error drawing fixation ({x}, {y}): {str(e)}")

    return image


def process_scanpath(scanpath, scale):
    scanpath_array = np.array(scanpath)
    scaled_scanpath = np.round(scanpath_array * scale).astype(int)
    return list(map(tuple, scaled_scanpath))


class IPAdapter(torch.nn.Module):
    """IP-Adapter"""

    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
        super().__init__()
        self.unet = unet
        self.image_proj_model = image_proj_model
        self.adapter_modules = adapter_modules

        if ckpt_path is not None:
            self.load_from_checkpoint(ckpt_path)

    def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
        device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

        # Conditional and unconditional image embeddings (zeros for the unconditional branch)
        image_embeds_cond, image_embeds_uncond = image_embeds.to(device), torch.zeros_like(image_embeds).to(device)
        ip_tokens_cond = self.image_proj_model(image_embeds_cond)
        ip_tokens_uncond = self.image_proj_model(image_embeds_uncond)
        encoder_hidden_states_cond = torch.cat([encoder_hidden_states, ip_tokens_cond], dim=1)
        encoder_hidden_states_uncond = torch.cat([encoder_hidden_states, ip_tokens_uncond], dim=1)

        # Predict the noise residual
        combined_encoder_hidden_states = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_cond], dim=0)
        combined_noisy_latents = torch.cat([noisy_latents, noisy_latents], dim=0)

        # Single forward pass through U-Net
        combined_noise_pred = self.unet(combined_noisy_latents, timesteps, combined_encoder_hidden_states).sample

        # Split the predictions
        noise_pred_uncond, noise_pred_cond = combined_noise_pred.chunk(2)

        return noise_pred_uncond, noise_pred_cond

    def load_from_checkpoint(self, ckpt_path: str):
        # Calculate original checksums
        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        state_dict = torch.load(ckpt_path, map_location="cpu")

        # Load state dict for image_proj_model and adapter_modules
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

        # Calculate new checksums
        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        # Verify if the weights have changed
        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"

        print(f"Successfully loaded weights from checkpoint {ckpt_path}")


def main():
    random.seed(42)

    os.environ['WANDB_DIR'] = "/mnt/disk/data/add_disk0/rraina/ip_adapter_training_spGFj414_val/"
    wandb.init(project='ip_adapter_training')

    pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
    image_encoder_path = "/home/rraina/DiffPeriph/dinov2/dinov2_vitl14_reg4_pretrain.pth"
    ip_ckpt = "/mnt/disk/data/add_disk0/rraina/ip_adapter_training_spGFj414/checkpoint-50000/ip_adapter.bin"
    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

    # Load scheduler, tokenizer and sd-models.
    noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder").to(device)
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae").to(device)
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet").to(device)

    # DINOv2 ViT-L/14 with registers as the image encoder
    image_encoder = vits.vit_large(img_size=518, patch_size=14, init_values=1.0, block_chunks=0, num_register_tokens=4)
    state_dict = torch.load(image_encoder_path, map_location='cpu')
    image_encoder.load_state_dict(state_dict, strict=True)
    image_encoder.eval().to(device)

    # Project image-encoder embeddings into the U-Net cross-attention space
    image_proj_model = ImageProjModel(
        cross_attention_dim=unet.config.cross_attention_dim,
        clip_embeddings_dim=1024,
        clip_extra_context_tokens=256,
    ).to(device)

    # Swap the U-Net attention processors for IP-Adapter processors
    attn_procs = {}
    unet_sd = unet.state_dict()
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor()
        else:
            layer_name = name.split(".processor")[0]
            weights = {
                "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
            }
            attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=256)
            attn_procs[name].load_state_dict(weights)
    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

    ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, ip_ckpt).to(device)

    # Freeze parameters of models to save more memory
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    image_encoder.requires_grad_(False)
    ip_adapter.requires_grad_(False)

    val_dataset = StaticCOCOFreeviewDatasetIP_adapter(
        root_dir='/mnt/disk/data/add_disk0/rraina/cocofreeview_fov_val/',
        json_file='/home/rraina/coco-freeview/all_fixations_info_val_real.json',
        annotation_file='/home/rraina/diffusion-peripheral/assets/annotations/captions_trainval2014_2017.json',
        tokenizer=tokenizer,
        encoder_size=(224, 224),
        gen_size=(512, 512),
        t_drop_rate=0.00,
        i_drop_rate=0.00,
        ti_drop_rate=0.00,
    )

    # Evaluate on a fixed random subset of 50 validation samples
    fixed_random_numbers = random.choices(range(0, 1000), k=50)
    print(fixed_random_numbers)

    val_dataloader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(val_dataset, fixed_random_numbers),
        shuffle=True,
        batch_size=1,
        num_workers=8,
    )

    num_inference_steps = 50

    for batch in tqdm(val_dataloader):
        with torch.no_grad():
            foveated_img_encoder = batch["foveated_img_encoder"].to(device)
            foveated_img_ipadapter = batch["foveated_img_ipadapter"].to(device)
            original_img_encoder = batch["original_img_encoder"].to(device)
            original_img_ipadapter = batch["original_img_ipadapter"].to(device)
            mask_img_ipadapter = batch["mask_img_ipadapter"].to(device)
            text_input_ids = batch["text_input_ids"].to(device)
            current_fixation_idx = batch['current_fixation_idx'].to(device)
            fixations = batch['fixations']
            drop_image_embeds = batch['drop_image_embeds']

            # Rescale fixations from the display resolution (1680x1050) to 512x512
            scale = np.array([512 / 1680, 512 / 1050])
            scaled_fixations = [process_scanpath(scanpath, scale) for scanpath in fixations]
            fixations = np.array(fixations)
            scale = np.array([512 / 1680, 512 / 1050])
            fixations = np.round(fixations * scale).astype(int)
            fixations = list(map(tuple, fixations))

            noise_scheduler.set_timesteps(num_inference_steps)
            init_timestep = min(int(num_inference_steps * 1.0), num_inference_steps)
            t_start = max(num_inference_steps - init_timestep, 0)
            timesteps = noise_scheduler.timesteps[t_start:].to(device)

            # Encode the original image to latents and add noise at the starting timestep
            init_latents = vae.encode(original_img_ipadapter).latent_dist.sample()
            init_latents = init_latents * vae.config.scaling_factor
            init_latents_orig = init_latents.to(device)
            noise = torch.randn(init_latents.shape).to(device)
            init_latents = noise_scheduler.add_noise(init_latents, noise, timesteps[:1]).to(device)
            latents = init_latents

            mask = torch.nn.functional.interpolate(1 - mask_img_ipadapter, size=latents.shape[-2:], mode='nearest').to(device)

            # Encode the foveated image and keep only patch tokens inside the fixation mask
            orig_image_embeds = image_encoder.forward_features(foveated_img_encoder)
            patch_tokens = orig_image_embeds['x_patchtokens']
            reg_tokens = orig_image_embeds['x_regtokens']
            cls_tokens = orig_image_embeds['x_clstoken']

            B, N, _ = patch_tokens.shape
            fixation_mask = F.interpolate(mask_img_ipadapter.float(), size=(int(N ** 0.5), int(N ** 0.5)), mode='area')
            fixation_mask = 1 - (fixation_mask == 1).float()
            fixation_mask = fixation_mask.view(B, N, 1)
            fixation_tokens = patch_tokens * fixation_mask

            # Concatenate fixation, register, and cls tokens
            image_embeds = fixation_tokens

            encoder_hidden_states = text_encoder(text_input_ids)[0].to(device)

            # Denoising loop with classifier-free guidance (scale w)
            w = 6.5
            for t in tqdm(timesteps):
                noise_pred_uncond, noise_pred_cond = ip_adapter(latents, t, encoder_hidden_states, image_embeds)
                noise_pred = noise_pred_uncond + w * (noise_pred_cond - noise_pred_uncond)
                latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

            # Decode the latents
            latents = 1 / vae.config.scaling_factor * latents
            with torch.autocast(device_type='cuda', dtype=torch.float16):  # autocast takes the device type, not an index
                image = vae.decode(latents).sample
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).detach().numpy()  # move to a CPU numpy array in HWC order
            image = (image * 255).round().astype("uint8")

            # Denormalize images
            mean = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
            std = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
            original_img = original_img_ipadapter.cpu().detach() * std + mean
            foveated_img = foveated_img_ipadapter.cpu().detach() * std + mean
            mask_img = mask_img_ipadapter.cpu().detach() * std + mean
            pred_img = torch.from_numpy(image).permute(0, 3, 1, 2).float() / 255.0

            original_img_with_fixations = draw_fixations((original_img[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8), fixations)
            original_img_with_fixations = torch.from_numpy(original_img_with_fixations).permute(2, 0, 1).float() / 255.0

            # Log a side-by-side grid to Weights & Biases
            img_grid = torchvision.utils.make_grid([original_img[0].float(), original_img_with_fixations, foveated_img[0].float(), pred_img[0]], nrow=4)
            wandb.log({
                "images": wandb.Image((np.transpose(img_grid.cpu().numpy(), (1, 2, 0)) * 255).astype(np.uint8), caption="Original | Original+Fixations | Foveated | Generated")
            })
            wandb.run.save()


if __name__ == "__main__":
    main()