import torch
from typing import List

from PIL import Image
from tqdm import tqdm

# New constants for instruct model tokens
SYSTEM_START_TOKEN = "<system>"
SYSTEM_END_TOKEN = "</system>"
USER_START_TOKEN = "<user>"
USER_END_TOKEN = "</user>"
ASSISTANT_START_TOKEN = "<assistant>"
ASSISTANT_END_TOKEN = "</assistant>"
SYSTEM_PROMPT = "You are a helpful AI assistant that can analyze images and provide detailed descriptions."
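# Note on layout (added for clarity): stream_chat below assembles one
# instruct-style turn of the form
#   <system>{SYSTEM_PROMPT}</system><user>**<image>**{VLM_PROMPT}</user><assistant>
# and splices the adapted CLIP image embeddings in immediately after the
# <user> tag, i.e. before the literal **<image>** marker text, when it
# builds inputs_embeds.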
@torch.no_grad()
def stream_chat(input_images: List[Image.Image], batch_size: int, pbar: tqdm, models: tuple) -> List[str]:
    clip_processor, clip_model, tokenizer, text_model, image_adapter = models

    torch.cuda.empty_cache()
    all_captions = []

    for i in range(0, len(input_images), batch_size):
        batch = input_images[i:i + batch_size]

        # Preprocess the image batch; skip any batch the processor rejects
        try:
            images = clip_processor(images=batch, return_tensors='pt', padding=True).pixel_values.to('cuda')
        except ValueError as e:
            print(f"Error processing image batch: {e}")
            print("Skipping this batch and continuing...")
            continue

        # Run CLIP and project the penultimate hidden states into the
        # text model's embedding space via the image adapter
        with torch.amp.autocast('cuda', enabled=True):
            vision_outputs = clip_model(pixel_values=images, output_hidden_states=True)
            image_features = vision_outputs.hidden_states[-2]
            embedded_images = image_adapter(image_features).to(dtype=torch.bfloat16)

        # Construct the new prompt structure. The BOS token is embedded
        # separately below, so no special tokens are added here (encoding
        # with add_special_tokens=True would prepend a second BOS and
        # misalign the prefix slicing that follows).
        # VLM_PROMPT is assumed to be defined elsewhere in the script.
        full_prompt = f"{SYSTEM_START_TOKEN}{SYSTEM_PROMPT}{SYSTEM_END_TOKEN}{USER_START_TOKEN}**<image>**{VLM_PROMPT}{USER_END_TOKEN}{ASSISTANT_START_TOKEN}"
        prompt = tokenizer.encode(full_prompt, add_special_tokens=False, return_tensors='pt')
        prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda')).to(dtype=torch.bfloat16)
        embedded_bos = text_model.model.embed_tokens(
            torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64)
        ).to(dtype=torch.bfloat16)

        # Number of prompt tokens up to and including the <user> tag; the
        # image embeddings are spliced in right after this prefix
        prefix_len = tokenizer.encode(
            f"{SYSTEM_START_TOKEN}{SYSTEM_PROMPT}{SYSTEM_END_TOKEN}{USER_START_TOKEN}",
            add_special_tokens=False, return_tensors='pt'
        ).shape[1]

        # <BOS> + prompt prefix + image embeddings + rest of prompt
        inputs_embeds = torch.cat([
            embedded_bos.expand(embedded_images.shape[0], -1, -1),
            prompt_embeds[:, :prefix_len, :].expand(embedded_images.shape[0], -1, -1),
            embedded_images,
            prompt_embeds[:, prefix_len:, :].expand(embedded_images.shape[0], -1, -1),
        ], dim=1).to(dtype=torch.bfloat16)

        # Matching token ids; the image positions are filled with dummy zeros
        input_ids = torch.cat([
            torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).expand(embedded_images.shape[0], -1),
            prompt[:, :prefix_len].expand(embedded_images.shape[0], -1),
            torch.zeros((embedded_images.shape[0], embedded_images.shape[1]), dtype=torch.long),
            prompt[:, prefix_len:].expand(embedded_images.shape[0], -1),
        ], dim=1).to('cuda')
        attention_mask = torch.ones_like(input_ids)

        generate_ids = text_model.generate(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=300,
            do_sample=True,
            top_k=10,
            temperature=0.5,
        )

        # Strip the prompt tokens, then decode each caption in the batch
        generate_ids = generate_ids[:, input_ids.shape[1]:]
        for ids in generate_ids:
            caption = tokenizer.decode(
                ids[:-1] if ids[-1] == tokenizer.eos_token_id else ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
            caption = caption.replace('<|end_of_text|>', '').replace('<|finetune_right_pad_id|>', '').strip()
            all_captions.append(caption)

        if pbar:
            pbar.update(len(batch))

    return all_captions
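

# --- Usage sketch (not part of the original paste) ---
# A minimal, hypothetical example of assembling the `models` tuple and
# calling stream_chat. The checkpoint names below are placeholder
# assumptions, the image adapter is project-specific and must be loaded
# from wherever the surrounding script defines it, and VLM_PROMPT must be
# defined before calling.
if __name__ == "__main__":
    from transformers import AutoModel, AutoModelForCausalLM, AutoProcessor, AutoTokenizer

    CLIP_PATH = "openai/clip-vit-large-patch14"   # assumed CLIP checkpoint
    TEXT_PATH = "meta-llama/Meta-Llama-3.1-8B"    # assumed text-model checkpoint

    clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
    clip_model = AutoModel.from_pretrained(CLIP_PATH).vision_model.to('cuda').eval()
    tokenizer = AutoTokenizer.from_pretrained(TEXT_PATH)
    text_model = AutoModelForCausalLM.from_pretrained(
        TEXT_PATH, torch_dtype=torch.bfloat16).to('cuda').eval()

    # Placeholder: load the trained module that maps CLIP hidden states
    # into the text model's embedding space (not defined in this snippet).
    image_adapter = ...

    images = [Image.open("example.jpg")]
    with tqdm(total=len(images)) as pbar:
        captions = stream_chat(
            images, batch_size=1, pbar=pbar,
            models=(clip_processor, clip_model, tokenizer, text_model, image_adapter),
        )
    print(captions[0])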