Untitled
```python
import time

import torch
from tqdm import tqdm

results = []
for i in tqdm(range(0, len(dataset), batch_size)):
    # Prepare batch inputs: chat-templated prompts and their images
    batch_inputs_raw = dataset[i : i + batch_size]
    batch_texts = [
        self.processor.apply_chat_template(
            input_raw["messages"], add_generation_prompt=True
        )
        for input_raw in batch_inputs_raw
    ]
    batch_images = [
        [load_image(input_raw["image_path"], model.device)]
        for input_raw in batch_inputs_raw
    ]

    # Preprocess the batch and move the resulting tensors to the model's device
    start_time = time.time()
    inputs = self.processor(
        batch_images,
        batch_texts,
        add_special_tokens=False,
        return_tensors="pt",
        padding=True,
    ).to(model.device)
    end_time = time.time()
    print(f"Time taken for processing batch: {end_time - start_time} seconds")

    # Greedy decoding (do_sample=False); top_p is cleared so it has no effect
    start_time = time.time()
    with torch.no_grad():
        outputs = model.generate(
            **inputs, max_new_tokens=256, do_sample=False, top_p=None
        )
    end_time = time.time()
    print(f"Time taken for generating batch: {end_time - start_time} seconds")

    # Decode the generated token IDs back into strings
    start_time = time.time()
    output_strs = self.processor.batch_decode(outputs, skip_special_tokens=True)
    end_time = time.time()
    print(f"Time taken for decoding batch: {end_time - start_time} seconds")
```
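The snippet relies on a `load_image` helper that is not shown. Here is a minimal sketch, assuming it simply opens the file with PIL and converts it to RGB; the `device` parameter is accepted only to match the call site, since it is the processor output (not the raw image) that gets moved to the GPU via `.to(model.device)`:

```python
from PIL import Image


def load_image(image_path, device=None):
    # Hypothetical helper matching the call site above. The `device`
    # argument is unused in this sketch: the processor produces the
    # tensors, and those are moved to the device afterwards.
    return Image.open(image_path).convert("RGB")
```

One caveat if `model.generate` follows the Hugging Face API for decoder-only models: `outputs` includes the prompt tokens, so the decoded strings will contain the prompt text. Callers typically slice them off first, e.g. `outputs[:, inputs["input_ids"].shape[1]:]`, before calling `batch_decode`.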