import argparse

import librosa
import torch
from omegaconf import OmegaConf
from transformers import LlamaTokenizer, AutoTokenizer, HubertForCTC

from model.audio_encoder import AudioEncoder
from model.audio_llama import AudioLlamaForCausalLM
from utils import merge_prompt_tokens, PROMPT_PREFIX, PROMPT_SUFFIX


class LLMSpeechTextInference():
    def __init__(self, config, audio_encoder_checkpoint, device):
        self.config = config
        self.device = device

        # Audio encoder.
        checkpoint = torch.load(audio_encoder_checkpoint, map_location="cpu")
        # print("Checkpoint keys:", checkpoint.keys())  # Debugging: Print keys
        self.audio_encoder = AudioEncoder(self.config)
        if 'audio_encoder' in checkpoint:
            self.audio_encoder.load_state_dict(checkpoint['audio_encoder'])
        else:
            self.audio_encoder.load_state_dict(checkpoint)
        self.audio_encoder.eval().to(self.device)
        print("Loaded audio encoder.\n")

        # LLM tokenizer.
        self.llm_tokenizer = LlamaTokenizer.from_pretrained(
            "GeneZC/MiniChat-2-3B",
            use_fast=False,
        )

        # Load and freeze LLM model weights.
        self.llm = AudioLlamaForCausalLM.from_pretrained(
            "GeneZC/MiniChat-2-3B",
            use_cache=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        ).eval()
        self.llm.to(self.device)
        print("Loaded LLM.\n")

        # Load HuBERT ASR model for getting CTC offsets.
        self.hubert_tokenizer = AutoTokenizer.from_pretrained("facebook/hubert-large-ls960-ft")
        self.hubert = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").eval().to(self.device)
        print("Loaded HuBERT.\n")

    def perform_hubert_asr(self, audio):
        # Feed audio through model to get greedily predicted transcription IDs.
        logits = self.hubert(audio).logits[0]
        pred_ids = torch.argmax(logits, axis=-1)

        # Decode transcription IDs to get text transcript.
        # NOTE: Always converts to lower case.
        transcript = self.hubert_tokenizer.decode(pred_ids).lower()

        return transcript

    def get_ctc_pool_ranges(self, audio, pool_range=4):
        # Feed audio through model to get greedily predicted transcription IDs.
        logits = self.hubert(audio).logits[0]
        pred_ids = torch.argmax(logits, axis=-1)

        # Perform decoding to get CTC offsets for each predicted word.
        outputs = self.hubert_tokenizer.decode(pred_ids, output_word_offsets=True)
        word_offsets = outputs.word_offsets
        ctc_word_offsets = [
            (word['start_offset'], word['end_offset']) for word in word_offsets
        ]

        # Add offset ranges for silence in between words. The first element of
        # each tuple is a flag denoting whether the offset corresponds to
        # a word (1) or silence (0).
        all_word_offsets = [(0, 0, ctc_word_offsets[0][0])]
        for i in range(len(ctc_word_offsets)-1):
            all_word_offsets.append((1, ctc_word_offsets[i][0], ctc_word_offsets[i][1]))
            all_word_offsets.append((0, ctc_word_offsets[i][1], ctc_word_offsets[i+1][0]))
        all_word_offsets.append((1, ctc_word_offsets[-1][0], ctc_word_offsets[-1][1]))
        all_word_offsets.append(
            (0, ctc_word_offsets[-1][1], ctc_word_offsets[-1][1] + (pool_range * 2))
        )

        # Aggregate the offsets into pooling ranges for the audio encoder.
        ctc_pool_ranges = []
        for is_word, start_offset, end_offset in all_word_offsets:
            if is_word == 1:
                startpoint = start_offset
                endpoint = start_offset + pool_range
                while startpoint < end_offset:
                    ctc_pool_ranges.append((startpoint, endpoint))
                    startpoint += pool_range
                    endpoint += pool_range
            else:
                ctc_pool_ranges.append((start_offset, end_offset))

        return ctc_pool_ranges

    def generate_llm_response(self, inputs_embeds, max_new_tokens=256):
        with torch.no_grad():
            with torch.autocast(
                device_type='cuda',
                dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            ):
                # Ensure inputs_embeds is in the correct dtype.
                inputs_embeds = inputs_embeds.to(dtype=self.llm.dtype)

                # Debugging shapes
                # print(f"inputs_embeds shape: {inputs_embeds.shape}")

                generate_ids = self.llm.generate(
                    input_ids=None,
                    inputs_embeds=inputs_embeds,
                    max_new_tokens=max_new_tokens,
                )

                # Debugging shapes
                # print(f"generate_ids shape: {generate_ids.shape}")

                response_text = self.llm_tokenizer.batch_decode(
                    generate_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                )

        return response_text

    def generate_text_response(self, input_text, max_new_tokens=256):
        # Create full prompt for instruction-tuned LLM.
        full_text_prompt = f"{PROMPT_PREFIX} {input_text}{PROMPT_SUFFIX} "

        with torch.no_grad():
            # Tokenize and get embeddings for the full text prompt.
            prompt_input_ids = self.llm_tokenizer(
                full_text_prompt, return_tensors='pt'
            ).input_ids.to(self.device)
            prompt_embeds = self.llm.model.embed_tokens(prompt_input_ids)

            # Generate the LLM response.
            llm_response = self.generate_llm_response(
                inputs_embeds=prompt_embeds,
                max_new_tokens=max_new_tokens,
            )[0]

        return llm_response

    def generate_asr_cascade_response(self, audio, additional_text_prompt="Summarize the following:", max_new_tokens=256):
        with torch.no_grad():
            # Perform ASR using HuBERT.
            audio_tensor = torch.tensor(audio).float().unsqueeze(0).to(self.device)
            asr_transcript = self.perform_hubert_asr(audio_tensor)

            # Combine the transcript with any additional text prompt.
            # NOTE: Assumes that the text prompt always comes before the
            # transcribed text.
            full_text = additional_text_prompt + asr_transcript
            llm_response = self.generate_text_response(full_text, max_new_tokens)

        return llm_response

    def generate_audio_response(self, audio, additional_text_prompt="Summarize the following:", max_new_tokens=256):
        with torch.no_grad():
            audio_tensor = torch.tensor(audio).float().unsqueeze(0).to(self.device)

            if self.audio_encoder.downsample_method == "ctc_pool":
                # Get the CTC pooling ranges for the audio.
                ctc_pool_ranges = self.get_ctc_pool_ranges(audio_tensor)

                # Get embeddings from the audio encoder.
                audio_embeds = self.audio_encoder(audio_tensor, [ctc_pool_ranges])
            else:
                audio_embeds = self.audio_encoder(audio_tensor, ctc_pool_ranges=None)

            # Debugging shapes
            # print(f"audio_embeds shape: {audio_embeds.shape}")

            # Combine the audio embeddings with any additional text prompt.
            if len(additional_text_prompt) > 0:
                additional_text_input_ids = self.llm_tokenizer(
                    additional_text_prompt, return_tensors='pt'
                ).input_ids[:, 1:].to(self.device)
                text_embeds = self.llm.model.embed_tokens(additional_text_input_ids)
                combined_embeds = torch.cat([text_embeds, audio_embeds], dim=1)
            else:
                combined_embeds = audio_embeds

            # Debugging shapes
            # print(f"combined_embeds shape: {combined_embeds.shape}")

            # Get the full embedding sequence and generate the LLM response.
            prompt_emb_sequence = merge_prompt_tokens(
                inputs_embeds=combined_embeds,
                tokenizer=self.llm_tokenizer,
                embed_tokens=self.llm.model.embed_tokens,
                device=self.device,
            )

            # Debugging shapes
            # print(f"prompt_emb_sequence shape: {prompt_emb_sequence.shape}")

            llm_response = self.generate_llm_response(prompt_emb_sequence, max_new_tokens)[0]

        return llm_response


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, help="yaml file for configuration")
    parser.add_argument('-g', '--gpu_idx', type=int, default=0, help="index of home GPU device")
    parser.add_argument('-p', '--audio_encoder_checkpoint', type=str, help="path to audio encoder checkpoint")
    parser.add_argument('-a', '--audio_file', type=str, required=True,
                        help="audio file containing speech utterance to be used in prompt")
    args = parser.parse_args()

    device = torch.device(f"cuda:{args.gpu_idx}" if torch.cuda.is_available() else "cpu")

    # Set up inferencer.
    config = OmegaConf.load(args.config)
    llm_inferencer = LLMSpeechTextInference(
        config=config,
        audio_encoder_checkpoint=args.audio_encoder_checkpoint,
        device=device,
    )

    # Load audio file.
    audio, sr = librosa.load(args.audio_file, sr=16000)

    # Generate LLM response.
    llm_response = llm_inferencer.generate_audio_response(
        audio,
        max_new_tokens=512,
    )
    print(f'LLM Response: {llm_response}')
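# Example invocation (a sketch, assuming this script is saved as inference.py;
# the config, checkpoint, and audio file paths below are hypothetical placeholders):
#
#   python inference.py \
#       --config configs/audio_encoder.yaml \
#       --audio_encoder_checkpoint checkpoints/audio_encoder.pt \
#       --audio_file sample_utterance.wav \
#       --gpu_idx 0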