# --- Paste-site page metadata (not part of the program) ---
# Untitled
#
# mail@pastecode.io avatar
# unknown
# python
# 20 days ago
# 4.0 kB
# 112
# Indexable
# Never
# Role delimiters for the instruct-model prompt. Each pair is the opening and
# closing marker placed around that role's text when the prompt is assembled.
SYSTEM_START_TOKEN, SYSTEM_END_TOKEN = "<system>", "</system>"
USER_START_TOKEN, USER_END_TOKEN = "<user>", "</user>"
ASSISTANT_START_TOKEN, ASSISTANT_END_TOKEN = "<assistant>", "</assistant>"

# Fixed system message placed between the system delimiters.
SYSTEM_PROMPT = (
    "You are a helpful AI assistant that can analyze images "
    "and provide detailed descriptions."
)

@torch.no_grad()
def stream_chat(input_images: List[Image.Image], batch_size: int, pbar: tqdm, models: tuple) -> List[str]:
    """Caption images in batches with the CLIP -> image adapter -> text model pipeline.

    For every batch the images are preprocessed and run through the CLIP model,
    the second-to-last hidden state is projected by the image adapter, and the
    resulting image embeddings are spliced into the embedded prompt directly
    after the user-start token. The text model then samples a continuation,
    which is decoded into one caption per image.

    Args:
        input_images: Images to caption.
        batch_size: Number of images processed per forward pass.
        pbar: Optional progress bar; advanced by the batch size after each
            successfully captioned batch. Pass ``None`` to disable.
        models: Tuple ``(clip_processor, clip_model, tokenizer, text_model,
            image_adapter)``.

    Returns:
        The captions, in input order. NOTE: a batch whose preprocessing raises
        ``ValueError`` is skipped entirely, so the returned list can be shorter
        than ``input_images`` and is then no longer index-aligned with it.
    """
    clip_processor, clip_model, tokenizer, text_model, image_adapter = models
    torch.cuda.empty_cache()
    all_captions = []

    # The text around the image is identical for every batch, so tokenize and
    # embed it once instead of once (prefix: four times) per batch.
    prefix_text = f"{SYSTEM_START_TOKEN}{SYSTEM_PROMPT}{SYSTEM_END_TOKEN}{USER_START_TOKEN}"
    # NOTE(review): the literal "**<image>**" looks like a markdown artifact /
    # leftover placeholder (the real image embeddings are inserted right before
    # it). Kept unchanged to preserve the prompt text; confirm whether it
    # should be removed.
    suffix_text = f"**<image>**{VLM_PROMPT}{USER_END_TOKEN}{ASSISTANT_START_TOKEN}"

    # BOS is added exactly once, by hand. The two halves are tokenized
    # separately with add_special_tokens=False so that (a) a tokenizer that
    # auto-prepends BOS cannot produce a second BOS / shift the split point by
    # one token, and (b) the image insertion point cannot be moved by tokens
    # merging across the prefix/suffix boundary.
    bos_ids = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long)
    prefix_ids = torch.cat([
        bos_ids,
        tokenizer.encode(prefix_text, add_special_tokens=False, return_tensors='pt'),
    ], dim=1)
    suffix_ids = tokenizer.encode(suffix_text, add_special_tokens=False, return_tensors='pt')

    embed_tokens = text_model.model.embed_tokens
    prefix_embeds = embed_tokens(prefix_ids.to('cuda')).to(dtype=torch.bfloat16)
    suffix_embeds = embed_tokens(suffix_ids.to('cuda')).to(dtype=torch.bfloat16)

    for start in range(0, len(input_images), batch_size):
        batch = input_images[start:start + batch_size]

        try:
            images = clip_processor(images=batch, return_tensors='pt', padding=True).pixel_values.to('cuda')
        except ValueError as e:
            print(f"Error processing image batch: {e}")
            print("Skipping this batch and continuing...")
            continue

        with torch.amp.autocast_mode.autocast('cuda', enabled=True):
            vision_outputs = clip_model(pixel_values=images, output_hidden_states=True)
            # Second-to-last hidden state is what the adapter consumes.
            image_features = vision_outputs.hidden_states[-2]
            embedded_images = image_adapter(image_features).to(dtype=torch.bfloat16)

        num_images, num_image_tokens = embedded_images.shape[0], embedded_images.shape[1]

        # [BOS + prefix] [image embeddings] [suffix], broadcast over the batch.
        inputs_embeds = torch.cat([
            prefix_embeds.expand(num_images, -1, -1),
            embedded_images,
            suffix_embeds.expand(num_images, -1, -1),
        ], dim=1).to(dtype=torch.bfloat16)

        # Token ids mirroring inputs_embeds position-for-position; the image
        # positions have no real token, so they are filled with zeros.
        input_ids = torch.cat([
            prefix_ids.expand(num_images, -1),
            torch.zeros((num_images, num_image_tokens), dtype=torch.long),
            suffix_ids.expand(num_images, -1),
        ], dim=1).to('cuda')

        attention_mask = torch.ones_like(input_ids)

        generate_ids = text_model.generate(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=300,
            do_sample=True,
            top_k=10,
            temperature=0.5,
        )

        # Drop the prompt portion so only newly generated tokens are decoded.
        generate_ids = generate_ids[:, input_ids.shape[1]:]

        for ids in generate_ids:
            caption = tokenizer.decode(ids[:-1] if ids[-1] == tokenizer.eos_token_id else ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            caption = caption.replace('<|end_of_text|>', '').replace('<|finetune_right_pad_id|>', '').strip()
            all_captions.append(caption)

        # Explicit None check: tqdm defines __bool__ (False when total == 0 and
        # it can raise for a manual bar without a total), so truthiness is not
        # a reliable "was a bar supplied" test.
        if pbar is not None:
            pbar.update(len(batch))

    return all_captions
# Leave a Comment  (paste-site page footer; not part of the program)