validation adapter script
import os
import random
import argparse
from pathlib import Path
import json
import itertools
import time
import copy
from typing import List

import numpy as np
import cv2
import wandb
from tqdm import tqdm

import torch
import torchvision
from torchvision.utils import make_grid
import torch.nn.functional as F
from torchvision import transforms
from diffusers import StableDiffusionPipeline
from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, UNet2DConditionModel
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
from safetensors import safe_open
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ip_adapter.utils import is_torch2_available, get_generator
from ip_adapter.ip_adapter import ImageProjModel
if is_torch2_available():
    from ip_adapter.attention_processor import (
        AttnProcessor2_0 as AttnProcessor,
        CNAttnProcessor2_0 as CNAttnProcessor,
        IPAttnProcessor2_0 as IPAttnProcessor,
    )
else:
    from ip_adapter.attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
from ip_adapter.resampler import Resampler
import mae
import dinov2.models.vision_transformer as vits
from StaticCOCOFreeviewDatasetIP_adapter_xl import *
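# Fix the Python, NumPy and Torch RNG seeds so the sampled validation subset and the added noise are reproducible.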
seed = 99
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Set deterministic behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def draw_fixations(image, fixations, radius=10, color=(255, 0, 0), thickness=-1):
"""
Draw fixation points on the image.
:param image: numpy array of shape (H, W, C)
:param fixations: list of (array([x1, x2, ...]), array([y1, y2, ...])) pairs
:param radius: radius of the fixation points
    :param color: color of the fixation points (in the image's channel order)
:param thickness: thickness of the fixation points (-1 for filled circle)
:return: image with fixations drawn
"""
image = image.copy()
# Ensure fixations is a list
if not isinstance(fixations, list):
fixations = [fixations]
for fixation_pair in fixations:
if not (isinstance(fixation_pair, tuple) and len(fixation_pair) == 2):
print(f"Skipping invalid fixation pair: {fixation_pair}")
continue
x_coords, y_coords = fixation_pair
if not (isinstance(x_coords, np.ndarray) and isinstance(y_coords, np.ndarray)):
print(f"Skipping non-array fixation pair: {fixation_pair}")
continue
if len(x_coords) != len(y_coords):
print(f"Mismatched coordinate lengths: {len(x_coords)} vs {len(y_coords)}")
continue
for x, y in zip(x_coords, y_coords):
try:
cv2.circle(image, (int(x), int(y)), radius, color, thickness)
except Exception as e:
print(f"Error drawing fixation ({x}, {y}): {str(e)}")
return image
def process_scanpath(scanpath, scale):
scanpath_array = np.array(scanpath)
scaled_scanpath = np.round(scanpath_array * scale).astype(int)
return list(map(tuple, scaled_scanpath))
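# Lightweight wrapper that couples the frozen U-Net with the IP-Adapter image projection and attention modules for validation inference.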
class IPAdapter(torch.nn.Module):
"""IP-Adapter"""
def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
super().__init__()
self.unet = unet
self.image_proj_model = image_proj_model
self.adapter_modules = adapter_modules
if ckpt_path is not None:
self.load_from_checkpoint(ckpt_path)
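    # Run the conditional and unconditional branches through the U-Net in a single batched pass so the caller can apply classifier-free guidance.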
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
        # Derive the device from the inputs instead of hard-coding a GPU index
        device = noisy_latents.device
        image_embeds_cond, image_embeds_uncond = image_embeds.to(device), torch.zeros_like(image_embeds).to(device)
ip_tokens_cond = self.image_proj_model(image_embeds_cond)
ip_tokens_uncond = self.image_proj_model(image_embeds_uncond)
encoder_hidden_states_cond = torch.cat([encoder_hidden_states, ip_tokens_cond], dim=1)
encoder_hidden_states_uncond = torch.cat([encoder_hidden_states, ip_tokens_uncond], dim=1)
# Predict the noise residual
combined_encoder_hidden_states = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_cond], dim=0)
combined_noisy_latents = torch.cat([noisy_latents, noisy_latents], dim=0)
# Single forward pass through U-Net
combined_noise_pred = self.unet(combined_noisy_latents, timesteps, combined_encoder_hidden_states).sample
# Split the predictions
noise_pred_uncond, noise_pred_cond = combined_noise_pred.chunk(2)
return noise_pred_uncond, noise_pred_cond
def load_from_checkpoint(self, ckpt_path: str):
# Calculate original checksums
orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
state_dict = torch.load(ckpt_path, map_location="cpu")
# Load state dict for image_proj_model and adapter_modules
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
# Calculate new checksums
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
# Verify if the weights have changed
assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
print(f"Successfully loaded weights from checkpoint {ckpt_path}")
def main():
random.seed(42)
os.environ['WANDB_DIR'] = "/mnt/disk/data/add_disk0/rraina/ip_adapter_training_spGFj414_val/"
wandb.init(project='ip_adapter_training')
pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
image_encoder_path = "/home/rraina/DiffPeriph/dinov2/dinov2_vitl14_reg4_pretrain.pth"
ip_ckpt = "/mnt/disk/data/add_disk0/rraina/ip_adapter_training_spGFj414/checkpoint-50000/ip_adapter.bin"
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
# Load scheduler, tokenizer and sd-models.
noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder").to(device)
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae").to(device)
unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet").to(device)
image_encoder = vits.vit_large(img_size=518, patch_size=14, init_values=1.0,
block_chunks=0, num_register_tokens=4)
state_dict = torch.load(image_encoder_path, map_location='cpu')
image_encoder.load_state_dict(state_dict, strict=True)
image_encoder.eval().to(device)
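    # Projection head that maps the image-encoder features into extra context tokens for the U-Net cross-attention layers.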
image_proj_model = ImageProjModel(
cross_attention_dim=unet.config.cross_attention_dim,
clip_embeddings_dim=1024,
clip_extra_context_tokens=256,
).to(device)
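    # Swap the U-Net attention processors: self-attention (attn1) keeps the default processor, cross-attention gets an IP-Adapter processor initialized from the original to_k/to_v weights.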
attn_procs = {}
unet_sd = unet.state_dict()
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor()
else:
layer_name = name.split(".processor")[0]
weights = {
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
}
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=256)
attn_procs[name].load_state_dict(weights)
unet.set_attn_processor(attn_procs)
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, ip_ckpt).to(device)
# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
image_encoder.requires_grad_(False)
ip_adapter.requires_grad_(False)
val_dataset = StaticCOCOFreeviewDatasetIP_adapter(root_dir='/mnt/disk/data/add_disk0/rraina/cocofreeview_fov_val/',
json_file='/home/rraina/coco-freeview/all_fixations_info_val_real.json',
annotation_file='/home/rraina/diffusion-peripheral/assets/annotations/captions_trainval2014_2017.json',
tokenizer=tokenizer,
encoder_size=(224,224),
gen_size=(512,512),
t_drop_rate=0.00, i_drop_rate=0.00, ti_drop_rate=0.00)
    # Sample 50 distinct indices without replacement so no validation item is repeated
    fixed_random_numbers = random.sample(range(0, 1000), k=50)
    print(fixed_random_numbers)
val_dataloader = torch.utils.data.DataLoader(
torch.utils.data.Subset(val_dataset, fixed_random_numbers),
shuffle=True,
batch_size=1,
num_workers=8,
)
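    # Denoise each validation item with DDIM and log a comparison grid to wandb.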
num_inference_steps=50
for batch in tqdm(val_dataloader):
with torch.no_grad():
foveated_img_encoder = batch["foveated_img_encoder"].to(device)
foveated_img_ipadapter = batch["foveated_img_ipadapter"].to(device)
original_img_encoder = batch["original_img_encoder"].to(device)
original_img_ipadapter = batch["original_img_ipadapter"].to(device)
mask_img_ipadapter = batch["mask_img_ipadapter"].to(device)
text_input_ids = batch["text_input_ids"].to(device)
current_fixation_idx = batch['current_fixation_idx'].to(device)
fixations = batch['fixations']
drop_image_embeds = batch['drop_image_embeds']
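            # Map fixation coordinates from the 1680x1050 display resolution to the 512x512 generation resolution.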
            scale = np.array([512 / 1680, 512 / 1050])
            fixations = np.round(np.array(fixations) * scale).astype(int)
            fixations = list(map(tuple, fixations))
noise_scheduler.set_timesteps(num_inference_steps)
init_timestep = min(int(num_inference_steps*1.0), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = noise_scheduler.timesteps[t_start:].to(device)
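            # img2img-style initialization: encode the original image to latents, then noise it to the first timestep of the schedule (strength 1.0 uses the full schedule).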
init_latents = vae.encode(original_img_ipadapter).latent_dist.sample()
init_latents = init_latents * vae.config.scaling_factor
init_latents_orig = init_latents.to(device)
noise = torch.randn(init_latents.shape).to(device)
init_latents = noise_scheduler.add_noise(init_latents, noise, timesteps[:1]).to(device)
latents = init_latents
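            # Inverted mask resized to the latent resolution (kept for reference; it is not applied during denoising below).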
mask = torch.nn.functional.interpolate(1-mask_img_ipadapter, size=latents.shape[-2:], mode='nearest').to(device)
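            # DINOv2 features of the foveated image: patch tokens plus register and CLS tokens (only the patch tokens are used below).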
orig_image_embeds = image_encoder.forward_features(foveated_img_encoder)
patch_tokens = orig_image_embeds['x_patchtokens']
reg_tokens = orig_image_embeds['x_regtokens']
cls_tokens = orig_image_embeds['x_clstoken']
B, N, _ = patch_tokens.shape
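            # Downsample the mask to the ViT patch grid and zero out patch tokens wherever the resized mask equals 1.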
fixation_mask = F.interpolate(mask_img_ipadapter.float(), size=(int(N ** 0.5), int(N ** 0.5)),
mode='area')
fixation_mask = 1 - (fixation_mask == 1).float()
fixation_mask = fixation_mask.view(B, N, 1)
fixation_tokens = patch_tokens * fixation_mask
# Concatenate fixation, register, and cls tokens
image_embeds = fixation_tokens
encoder_hidden_states = text_encoder(text_input_ids)[0].to(device)
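            # Classifier-free guidance with scale w: noise_pred = uncond + w * (cond - uncond).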
w = 6.5
for t in tqdm(timesteps):
noise_pred_uncond, noise_pred_cond = ip_adapter(latents, t, encoder_hidden_states, image_embeds)
noise_pred = noise_pred_uncond + w * (noise_pred_cond - noise_pred_uncond)
latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
# decode the latents
latents = 1 / vae.config.scaling_factor * latents
            with torch.autocast(device_type='cuda', dtype=torch.float16):
image = vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
            image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()  # cast to float32 before the numpy conversion
image = (image * 255).round().astype("uint8")
# denormalize images
mean = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
std = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
original_img = original_img_ipadapter.cpu().detach() * std + mean
foveated_img = foveated_img_ipadapter.cpu().detach() * std + mean
mask_img = mask_img_ipadapter.cpu().detach() * std + mean
pred_img = torch.from_numpy(image).permute(0, 3, 1, 2).float() / 255.0
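            # Overlay the scaled fixation points on the original image and assemble the comparison grid logged below.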
original_img_with_fixations = draw_fixations((original_img[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8), fixations)
original_img_with_fixations = torch.from_numpy(original_img_with_fixations).permute(2, 0, 1).float() / 255.0
img_grid = torchvision.utils.make_grid([original_img[0].float(), original_img_with_fixations, foveated_img[0].float(), pred_img[0]], nrow=4)
wandb.log({
"images": wandb.Image((np.transpose(img_grid.cpu().numpy(), (1, 2, 0)) * 255).astype(np.uint8),
caption="Original | Original+Fixations | Foveated | Generated")
})
wandb.run.save()
if __name__ == "__main__":
    main()