validation adapter script

 avatar
unknown
python
6 months ago
15 kB
5
Indexable
import os
import random
import argparse
from pathlib import Path
import json
import itertools
import time
import wandb
from tqdm import tqdm
import copy
from typing import List
import wandb
import random

import torch
import torchvision
from torchvision.utils import make_grid
import torch.nn.functional as F
from torchvision import transforms

from diffusers import StableDiffusionPipeline
from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, UNet2DConditionModel
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
from safetensors import safe_open
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ip_adapter.utils import is_torch2_available, get_generator
from ip_adapter.ip_adapter import ImageProjModel
if is_torch2_available():
    from ip_adapter.attention_processor import (
        AttnProcessor2_0 as AttnProcessor,
    )
    from ip_adapter.attention_processor import (
        CNAttnProcessor2_0 as CNAttnProcessor,
    )
    from ip_adapter.attention_processor import (
        IPAttnProcessor2_0 as IPAttnProcessor,
    )
else:
    from ip_adapter.attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
from ip_adapter.resampler import Resampler

import mae
import dinov2.models.vision_transformer as vits
from StaticCOCOFreeviewDatasetIP_adapter_xl import *

# Fix every RNG this script touches (Python, NumPy, PyTorch CPU and all CUDA
# devices) so validation runs are repeatable.
seed = 99
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
del _seed_fn

# Ask cuDNN for deterministic kernels and turn off autotuning, which may pick
# different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def draw_fixations(image, fixations, radius=10, color=(255, 0, 0), thickness=-1):
    """
    Return a copy of *image* with a circle drawn at every fixation point.

    The input array is never modified.

    :param image: numpy array of shape (H, W, C)
    :param fixations: a single ``(xs, ys)`` tuple or a list of them, where
        ``xs`` and ``ys`` are equally long numpy arrays of pixel coordinates
    :param radius: circle radius in pixels
    :param color: per-channel fill value, written in the image's own channel
        order (cv2 does not reorder channels)
    :param thickness: outline thickness; -1 draws a filled circle
    :return: the annotated copy of the image

    Malformed entries are reported with ``print`` and skipped rather than
    raising. A failure while drawing a single point is reported the same way.
    """
    canvas = image.copy()
    pairs = fixations if isinstance(fixations, list) else [fixations]

    for pair in pairs:
        # Only 2-tuples are accepted; len() is evaluated only for tuples.
        if not isinstance(pair, tuple) or len(pair) != 2:
            print(f"Skipping invalid fixation pair: {pair}")
            continue

        xs, ys = pair
        if not isinstance(xs, np.ndarray) or not isinstance(ys, np.ndarray):
            print(f"Skipping non-array fixation pair: {pair}")
            continue

        n_x, n_y = len(xs), len(ys)
        if n_x != n_y:
            print(f"Mismatched coordinate lengths: {n_x} vs {n_y}")
            continue

        for x, y in zip(xs, ys):
            # The int() conversion stays inside the try so bad coordinates are
            # reported per point instead of aborting the whole drawing.
            try:
                cv2.circle(canvas, (int(x), int(y)), radius, color, thickness)
            except Exception as err:
                print(f"Error drawing fixation ({x}, {y}): {str(err)}")

    return canvas

def process_scanpath(scanpath, scale):
    """Scale a sequence of coordinate rows and round them to integer tuples.

    :param scanpath: array-like of coordinate rows (e.g. ``[[x, y], ...]``)
    :param scale: per-column multiplier, broadcast against each row
    :return: list with one tuple of ints per row (numpy rounding, half-to-even)
    """
    scaled = np.array(scanpath) * scale
    rounded = np.round(scaled).astype(int)
    return [tuple(row) for row in rounded]
    
class IPAdapter(torch.nn.Module):
    """IP-Adapter: a UNet bundled with an image-projection model and adapter modules.

    Args:
        unet: denoising network, called as
            ``unet(latents, timesteps, encoder_hidden_states)`` and expected to
            return an object exposing a ``.sample`` tensor.
        image_proj_model: module turning image embeddings into extra context
            tokens that are appended to the text hidden states.
        adapter_modules: module (e.g. ``ModuleList``) whose weights are restored
            from the checkpoint's ``"ip_adapter"`` entry.
        ckpt_path: optional checkpoint path; if given, weights are loaded
            immediately via :meth:`load_from_checkpoint`.
    """

    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
        super().__init__()
        self.unet = unet
        self.image_proj_model = image_proj_model
        self.adapter_modules = adapter_modules

        if ckpt_path is not None:
            self.load_from_checkpoint(ckpt_path)

    def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
        """Predict unconditional and image-conditioned noise in one batched UNet pass.

        The unconditional branch uses all-zero image embeddings. Both branches
        share the same text hidden states.

        Args:
            noisy_latents: latent batch fed to the UNet.
            timesteps: scalar timestep, or a 1-D tensor with one entry per latent.
            encoder_hidden_states: text hidden states; image tokens are
                concatenated to them along dim 1.
            image_embeds: image embeddings passed to ``image_proj_model``.

        Returns:
            ``(noise_pred_uncond, noise_pred_cond)``, each with the batch size of
            ``noisy_latents``.
        """
        # Follow the device the caller already uses for the text states (they are
        # concatenated with the image tokens below) instead of a hard-coded GPU.
        device = encoder_hidden_states.device
        image_embeds_cond = image_embeds.to(device)
        image_embeds_uncond = torch.zeros_like(image_embeds_cond)

        ip_tokens_cond = self.image_proj_model(image_embeds_cond)
        ip_tokens_uncond = self.image_proj_model(image_embeds_uncond)
        encoder_hidden_states_cond = torch.cat([encoder_hidden_states, ip_tokens_cond], dim=1)
        encoder_hidden_states_uncond = torch.cat([encoder_hidden_states, ip_tokens_uncond], dim=1)

        # Batch layout is [uncond..., cond...]; the latents are duplicated to match.
        combined_encoder_hidden_states = torch.cat(
            [encoder_hidden_states_uncond, encoder_hidden_states_cond], dim=0
        )
        combined_noisy_latents = torch.cat([noisy_latents, noisy_latents], dim=0)

        # Per-sample timesteps must be duplicated in the same order as the latents.
        if (
            torch.is_tensor(timesteps)
            and timesteps.ndim == 1
            and timesteps.shape[0] == noisy_latents.shape[0]
        ):
            timesteps = timesteps.repeat(2)

        # Single forward pass through the UNet for both branches.
        combined_noise_pred = self.unet(
            combined_noisy_latents, timesteps, combined_encoder_hidden_states
        ).sample
        noise_pred_uncond, noise_pred_cond = combined_noise_pred.chunk(2)
        return noise_pred_uncond, noise_pred_cond

    @staticmethod
    def _param_checksum(module):
        """Return the sum of all parameter values of *module*, or None if it has none."""
        params = list(module.parameters())
        if not params:
            return None
        with torch.no_grad():
            return float(sum(p.sum().item() for p in params))

    def load_from_checkpoint(self, ckpt_path: str):
        """Load ``image_proj`` and ``ip_adapter`` state dicts from *ckpt_path*.

        A parameter-sum checksum is compared before and after loading, to catch
        a load that left a module's weights unchanged.

        Raises:
            KeyError: if the checkpoint lacks ``"image_proj"`` or ``"ip_adapter"``.
            RuntimeError: if a state dict does not match its module (strict
                loading), or if a module with parameters has an unchanged
                checksum after loading.
        """
        orig_ip_proj_sum = self._param_checksum(self.image_proj_model)
        orig_adapter_sum = self._param_checksum(self.adapter_modules)

        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        state_dict = torch.load(ckpt_path, map_location="cpu")

        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

        new_ip_proj_sum = self._param_checksum(self.image_proj_model)
        new_adapter_sum = self._param_checksum(self.adapter_modules)

        # Explicit raises (not `assert`) so the check survives `python -O`.
        if orig_ip_proj_sum is not None and orig_ip_proj_sum == new_ip_proj_sum:
            raise RuntimeError("Weights of image_proj_model did not change!")
        if orig_adapter_sum is not None and orig_adapter_sum == new_adapter_sum:
            raise RuntimeError("Weights of adapter_modules did not change!")

        print(f"Successfully loaded weights from checkpoint {ckpt_path}")

def main():
    """Run IP-Adapter validation on a fixed subset of the validation set and log to W&B.

    For each sampled item this:
      1. encodes the foveated image with the DINOv2 encoder and zeroes the
         patch tokens selected by the fixation mask;
      2. noises the VAE latents of the original image to the first scheduled
         timestep and runs classifier-free-guided DDIM sampling through the
         IP-Adapter;
      3. decodes the result and logs a grid
         (original | original+fixations | foveated | generated).
    """
    random.seed(42)
    os.environ['WANDB_DIR'] = "/mnt/disk/data/add_disk0/rraina/ip_adapter_training_spGFj414_val/"
    wandb.init(project='ip_adapter_training')

    pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
    image_encoder_path = "/home/rraina/DiffPeriph/dinov2/dinov2_vitl14_reg4_pretrain.pth"
    ip_ckpt = "/mnt/disk/data/add_disk0/rraina/ip_adapter_training_spGFj414/checkpoint-50000/ip_adapter.bin"

    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

    # Load scheduler, tokenizer and SD components.
    noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder").to(device)
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae").to(device)
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet").to(device)

    # DINOv2 ViT-L/14 with 4 register tokens is the image encoder.
    image_encoder = vits.vit_large(img_size=518, patch_size=14, init_values=1.0,
                                   block_chunks=0, num_register_tokens=4)
    state_dict = torch.load(image_encoder_path, map_location='cpu')
    image_encoder.load_state_dict(state_dict, strict=True)
    image_encoder.eval().to(device)

    # clip_extra_context_tokens (256) must equal num_tokens of the IP attention
    # processors created below.
    image_proj_model = ImageProjModel(
        cross_attention_dim=unet.config.cross_attention_dim,
        clip_embeddings_dim=1024,
        clip_extra_context_tokens=256,
    ).to(device)

    # Self-attention ("attn1") layers keep the plain processor. Cross-attention
    # layers get IP processors whose image K/V projections start from the
    # UNet's own to_k / to_v weights.
    attn_procs = {}
    unet_sd = unet.state_dict()
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor()
        else:
            layer_name = name.split(".processor")[0]
            weights = {
                "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
            }
            attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=256)
            attn_procs[name].load_state_dict(weights)
    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

    ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, ip_ckpt).to(device)

    # Inference only: freeze everything.
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    image_encoder.requires_grad_(False)
    ip_adapter.requires_grad_(False)

    val_dataset = StaticCOCOFreeviewDatasetIP_adapter(root_dir='/mnt/disk/data/add_disk0/rraina/cocofreeview_fov_val/',
                                                      json_file='/home/rraina/coco-freeview/all_fixations_info_val_real.json',
                                                      annotation_file='/home/rraina/diffusion-peripheral/assets/annotations/captions_trainval2014_2017.json',
                                                      tokenizer=tokenizer,
                                                      encoder_size=(224, 224),
                                                      gen_size=(512, 512),
                                                      t_drop_rate=0.00, i_drop_rate=0.00, ti_drop_rate=0.00)

    # random.choices samples with replacement, so indices may repeat.
    fixed_random_numbers = random.choices(range(0, 1000), k=50)
    print(fixed_random_numbers)

    val_dataloader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(val_dataset, fixed_random_numbers),
        shuffle=True,
        batch_size=1,
        num_workers=8,
    )

    num_inference_steps = 50
    strength = 1.0  # fraction of the schedule to run; 1.0 runs every step
    guidance_scale = 6.5
    # Scales fixation coordinates from a 1680x1050 frame to 512x512
    # (as these constants imply).
    fixation_scale = np.array([512 / 1680, 512 / 1050])
    # Undo a (x - 0.5) / 0.5 normalisation for display.
    mean = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
    std = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)

    for batch in tqdm(val_dataloader):
        with torch.no_grad():
            foveated_img_encoder = batch["foveated_img_encoder"].to(device)
            foveated_img_ipadapter = batch["foveated_img_ipadapter"].to(device)
            original_img_ipadapter = batch["original_img_ipadapter"].to(device)
            mask_img_ipadapter = batch["mask_img_ipadapter"].to(device)
            text_input_ids = batch["text_input_ids"].to(device)

            fixations = np.array(batch['fixations'])
            fixations = np.round(fixations * fixation_scale).astype(int)
            fixations = list(map(tuple, fixations))

            noise_scheduler.set_timesteps(num_inference_steps)
            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
            t_start = max(num_inference_steps - init_timestep, 0)
            timesteps = noise_scheduler.timesteps[t_start:].to(device)

            # Noise the original image's latents to the first scheduled timestep.
            init_latents = vae.encode(original_img_ipadapter).latent_dist.sample()
            init_latents = init_latents * vae.config.scaling_factor
            noise = torch.randn(init_latents.shape).to(device)
            latents = noise_scheduler.add_noise(init_latents, noise, timesteps[:1]).to(device)

            # Image conditioning: DINOv2 patch tokens, with patches whose
            # area-averaged mask value is exactly 1 zeroed out.
            image_features = image_encoder.forward_features(foveated_img_encoder)
            patch_tokens = image_features['x_patchtokens']
            B, N, _ = patch_tokens.shape
            grid = int(N ** 0.5)
            fixation_mask = F.interpolate(mask_img_ipadapter.float(), size=(grid, grid), mode='area')
            fixation_mask = 1 - (fixation_mask == 1).float()
            image_embeds = patch_tokens * fixation_mask.view(B, N, 1)

            encoder_hidden_states = text_encoder(text_input_ids)[0].to(device)

            # Classifier-free guidance with the IP-Adapter's (uncond, cond) pair.
            for t in tqdm(timesteps):
                noise_pred_uncond, noise_pred_cond = ip_adapter(latents, t, encoder_hidden_states, image_embeds)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
                latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

            # Decode the latents. autocast takes a device *type* ("cuda"/"cpu"),
            # not "cuda:2"; fp16 decoding is only enabled on GPU.
            latents = 1 / vae.config.scaling_factor * latents
            with torch.autocast(device_type=device.type, dtype=torch.float16,
                                enabled=device.type == "cuda"):
                image = vae.decode(latents).sample

        # Convert to float32 before numpy work (decoder output may be fp16).
        image = (image.float() / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        image = (image * 255).round().astype("uint8")

        original_img = original_img_ipadapter.detach().cpu() * std + mean
        foveated_img = foveated_img_ipadapter.detach().cpu() * std + mean
        pred_img = torch.from_numpy(image).permute(0, 3, 1, 2).float() / 255.0

        original_np = (original_img[0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        original_img_with_fixations = draw_fixations(original_np, fixations)
        original_img_with_fixations = torch.from_numpy(original_img_with_fixations).permute(2, 0, 1).float() / 255.0

        img_grid = torchvision.utils.make_grid(
            [original_img[0].float(), original_img_with_fixations, foveated_img[0].float(), pred_img[0]],
            nrow=4,
        )

        wandb.log({
            "images": wandb.Image((np.transpose(img_grid.cpu().numpy(), (1, 2, 0)) * 255).astype(np.uint8),
                                  caption="Original | Original+Fixations | Foveated | Generated")
        })

    # Flush and close the W&B run (a no-argument run.save() is a deprecated no-op).
    wandb.finish()


if __name__ == "__main__":
    main()
Editor is loading...
Leave a Comment