from typing import List

import torch
from PIL import Image
from tqdm import tqdm

# New constants for instruct model tokens
SYSTEM_START_TOKEN = "<system>"
SYSTEM_END_TOKEN = "</system>"
USER_START_TOKEN = "<user>"
USER_END_TOKEN = "</user>"
ASSISTANT_START_TOKEN = "<assistant>"
ASSISTANT_END_TOKEN = "</assistant>"

SYSTEM_PROMPT = "You are a helpful AI assistant that can analyze images and provide detailed descriptions."

# NOTE: VLM_PROMPT (the per-image instruction text) is referenced below but is
# expected to be defined elsewhere in the full script.


@torch.no_grad()
def stream_chat(input_images: List[Image.Image], batch_size: int, pbar: tqdm, models: tuple) -> List[str]:
    clip_processor, clip_model, tokenizer, text_model, image_adapter = models
    torch.cuda.empty_cache()
    all_captions = []

    for i in range(0, len(input_images), batch_size):
        batch = input_images[i:i + batch_size]

        # Preprocess the image batch for the CLIP vision tower.
        try:
            images = clip_processor(images=batch, return_tensors='pt', padding=True).pixel_values.to('cuda')
        except ValueError as e:
            print(f"Error processing image batch: {e}")
            print("Skipping this batch and continuing...")
            continue

        # Encode the images and project them into the text model's embedding space.
        with torch.amp.autocast('cuda', enabled=True):
            vision_outputs = clip_model(pixel_values=images, output_hidden_states=True)
            image_features = vision_outputs.hidden_states[-2]  # penultimate layer
            embedded_images = image_adapter(image_features).to(dtype=torch.bfloat16)

        # Construct the instruct-style prompt. The image embeddings are spliced
        # in right after the prefix ending in USER_START_TOKEN.
        prompt_prefix = f"{SYSTEM_START_TOKEN}{SYSTEM_PROMPT}{SYSTEM_END_TOKEN}{USER_START_TOKEN}"
        full_prompt = f"{prompt_prefix}**<image>**{VLM_PROMPT}{USER_END_TOKEN}{ASSISTANT_START_TOKEN}"

        # Tokenize once, without special tokens: the BOS token is prepended
        # manually below, so the prefix/suffix split offsets stay aligned.
        prompt = tokenizer.encode(full_prompt, return_tensors='pt', add_special_tokens=False)
        prefix_len = tokenizer.encode(prompt_prefix, return_tensors='pt', add_special_tokens=False).shape[1]

        prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda')).to(dtype=torch.bfloat16)
        embedded_bos = text_model.model.embed_tokens(
            torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64)
        ).to(dtype=torch.bfloat16)

        # Assemble the input embeddings: <BOS> + system/user prefix + image
        # embeddings + remainder of the prompt (user text and assistant start).
        inputs_embeds = torch.cat([
            embedded_bos.expand(embedded_images.shape[0], -1, -1),
            prompt_embeds[:, :prefix_len, :].expand(embedded_images.shape[0], -1, -1),
            embedded_images,
            prompt_embeds[:, prefix_len:, :].expand(embedded_images.shape[0], -1, -1),
        ], dim=1).to(dtype=torch.bfloat16)

        # Mirror the same layout as token ids; the image positions are filled
        # with zeros since their embeddings are supplied directly above.
        input_ids = torch.cat([
            torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).expand(embedded_images.shape[0], -1),
            prompt[:, :prefix_len].expand(embedded_images.shape[0], -1),
            torch.zeros((embedded_images.shape[0], embedded_images.shape[1]), dtype=torch.long),
            prompt[:, prefix_len:].expand(embedded_images.shape[0], -1),
        ], dim=1).to('cuda')
        attention_mask = torch.ones_like(input_ids)

        generate_ids = text_model.generate(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=300,
            do_sample=True,
            top_k=10,
            temperature=0.5,
        )

        # Keep only the newly generated tokens.
        generate_ids = generate_ids[:, input_ids.shape[1]:]

        for ids in generate_ids:
            # Trim a trailing EOS token, then strip any leftover special markers.
            caption = tokenizer.decode(
                ids[:-1] if ids[-1] == tokenizer.eos_token_id else ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
            caption = caption.replace('<|end_of_text|>', '').replace('<|finetune_right_pad_id|>', '').strip()
            all_captions.append(caption)

        if pbar:
            pbar.update(len(batch))

    return all_captions
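

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original paste). One plausible way to drive
# stream_chat() over a folder of images. `load_models` is a hypothetical
# helper standing in for however the full script actually builds the
# (clip_processor, clip_model, tokenizer, text_model, image_adapter) tuple,
# and the "images" folder is an assumed input location.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from pathlib import Path

    image_paths = sorted(Path("images").glob("*.jpg"))  # assumed input folder
    input_images = [Image.open(p).convert("RGB") for p in image_paths]

    models = load_models()  # hypothetical: defined elsewhere in the full script

    with tqdm(total=len(input_images), desc="Captioning") as pbar:
        captions = stream_chat(input_images, batch_size=4, pbar=pbar, models=models)

    # Write one .txt caption file next to each image.
    for path, caption in zip(image_paths, captions):
        path.with_suffix(".txt").write_text(caption, encoding="utf-8")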