import argparse
import librosa
import torch
from omegaconf import OmegaConf
from transformers import LlamaTokenizer, AutoTokenizer, HubertForCTC
from model.audio_encoder import AudioEncoder
from model.audio_llama import AudioLlamaForCausalLM
from utils import merge_prompt_tokens, PROMPT_PREFIX, PROMPT_SUFFIX
class LLMSpeechTextInference:
    """Wraps a frozen MiniChat-2-3B LLM, a HuBERT CTC model, and a trained
    audio encoder so a speech utterance can be fed to the LLM either as
    transcribed text (ASR cascade) or directly as audio embeddings."""
def __init__(self, config, audio_encoder_checkpoint, device):
self.config = config
self.device = device
# Audio encoder.
checkpoint = torch.load(audio_encoder_checkpoint, map_location="cpu")
# print("Checkpoint keys:", checkpoint.keys()) # Debugging: Print keys
self.audio_encoder = AudioEncoder(self.config)
if 'audio_encoder' in checkpoint:
self.audio_encoder.load_state_dict(checkpoint['audio_encoder'])
else:
self.audio_encoder.load_state_dict(checkpoint)
self.audio_encoder.eval().to(self.device)
print("Loaded audio encoder.\n")
# LLM tokenizer.
self.llm_tokenizer = LlamaTokenizer.from_pretrained(
"GeneZC/MiniChat-2-3B",
use_fast=False,
)
# Load and freeze LLM model weights.
self.llm = AudioLlamaForCausalLM.from_pretrained(
"GeneZC/MiniChat-2-3B",
use_cache=True,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
).eval()
self.llm.to(self.device)
print("Loaded LLM.\n")
# Load HuBERT ASR model for getting CTC offsets.
self.hubert_tokenizer = AutoTokenizer.from_pretrained("facebook/hubert-large-ls960-ft")
        self.hubert = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").eval().to(self.device)
print("Loaded HuBERT.\n")
def perform_hubert_asr(self, audio):
# Feed audio through model to get greedily predicted transcription IDs.
logits = self.hubert(audio).logits[0]
        pred_ids = torch.argmax(logits, dim=-1)
# Decode transcription IDs to get text transcript.
# NOTE: Always converts to lower case.
transcript = self.hubert_tokenizer.decode(pred_ids).lower()
return transcript
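    # Rough usage sketch (the transcript shown is illustrative). HuBERT expects a
    # 16 kHz waveform shaped (1, num_samples) as a float tensor, which is how the
    # generate_*_response() methods below prepare their audio. "llm_inferencer"
    # is a constructed LLMSpeechTextInference instance, as in __main__ below.
    #
    #   >>> wav, _ = librosa.load("speech.wav", sr=16000)
    #   >>> wav = torch.tensor(wav).float().unsqueeze(0).to(llm_inferencer.device)
    #   >>> llm_inferencer.perform_hubert_asr(wav)
    #   'hello world'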
def get_ctc_pool_ranges(self, audio, pool_range=4):
# Feed audio through model to get greedily predicted transcription IDs.
logits = self.hubert(audio).logits[0]
        pred_ids = torch.argmax(logits, dim=-1)
# Perform decoding to get CTC offsets for each predicted word.
outputs = self.hubert_tokenizer.decode(pred_ids, output_word_offsets=True)
word_offsets = outputs.word_offsets
ctc_word_offsets = [
(word['start_offset'], word['end_offset']) for word in word_offsets
]
# Add offset ranges for silence in between words. The first element of
# each tuple is a flag denoting whether the offset corresponds to
# a word (1) or silence (0).
all_word_offsets = [(0, 0, ctc_word_offsets[0][0])]
for i in range(len(ctc_word_offsets)-1):
all_word_offsets.append((1, ctc_word_offsets[i][0], ctc_word_offsets[i][1]))
all_word_offsets.append((0, ctc_word_offsets[i][1], ctc_word_offsets[i+1][0]))
all_word_offsets.append((1, ctc_word_offsets[-1][0], ctc_word_offsets[-1][1]))
all_word_offsets.append(
(0, ctc_word_offsets[-1][1], ctc_word_offsets[-1][1] + (pool_range * 2))
)
# Aggregate the offsets into pooling ranges for the audio encoder.
ctc_pool_ranges = []
for is_word, start_offset, end_offset in all_word_offsets:
if is_word == 1:
startpoint = start_offset
endpoint = start_offset + pool_range
while startpoint < end_offset:
ctc_pool_ranges.append((startpoint, endpoint))
startpoint += pool_range
endpoint += pool_range
else:
ctc_pool_ranges.append((start_offset, end_offset))
return ctc_pool_ranges
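    # Worked example (with made-up frame offsets) of how CTC word offsets become
    # pooling ranges under the default pool_range=4. For two words detected at
    # offsets (5, 14) and (20, 26):
    #
    #   all_word_offsets = [(0, 0, 5), (1, 5, 14), (0, 14, 20),
    #                       (1, 20, 26), (0, 26, 34)]
    #   ctc_pool_ranges  = [(0, 5),                     # leading silence
    #                       (5, 9), (9, 13), (13, 17),  # word 1, windows of 4
    #                       (14, 20),                   # silence between words
    #                       (20, 24), (24, 28),         # word 2
    #                       (26, 34)]                   # trailing silence, padded by pool_range * 2
    #
    # Note that the last window of a word may overrun its end offset; the audio
    # encoder is expected to tolerate that when pooling.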
def generate_llm_response(self, inputs_embeds, max_new_tokens=256):
with torch.no_grad():
            # CUDA autocast only supports half-precision dtypes, so disable it
            # (rather than requesting float32) when no GPU is available.
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=torch.cuda.is_available()):
# Ensure inputs_embeds is in the correct dtype
inputs_embeds = inputs_embeds.to(dtype=self.llm.dtype)
# Debugging shapes
# print(f"inputs_embeds shape: {inputs_embeds.shape}")
generate_ids = self.llm.generate(
input_ids=None,
inputs_embeds=inputs_embeds,
max_new_tokens=max_new_tokens,
)
# Debugging shapes
# print(f"generate_ids shape: {generate_ids.shape}")
response_text = self.llm_tokenizer.batch_decode(
generate_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
return response_text
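    # Because generation starts from precomputed embeddings rather than token ids,
    # the ids returned by generate() normally cover only the newly generated
    # tokens, so batch_decode() yields just the response text (exact behavior can
    # vary across transformers versions). A quick sketch of calling this directly
    # with plain text embeddings:
    #
    #   >>> ids = llm_inferencer.llm_tokenizer("Hello!", return_tensors="pt").input_ids.to(llm_inferencer.device)
    #   >>> embeds = llm_inferencer.llm.model.embed_tokens(ids)
    #   >>> llm_inferencer.generate_llm_response(embeds, max_new_tokens=32)[0]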
def generate_text_response(self, input_text, max_new_tokens=256):
# Create full prompt for instruction-tuned LLM.
full_text_prompt = f"{PROMPT_PREFIX} {input_text}{PROMPT_SUFFIX} "
with torch.no_grad():
# Tokenize and get embeddings for the full text prompt.
prompt_input_ids = self.llm_tokenizer(
full_text_prompt, return_tensors='pt'
).input_ids.to(self.device)
prompt_embeds = self.llm.model.embed_tokens(prompt_input_ids)
# Generate the LLM response.
llm_response = self.generate_llm_response(
inputs_embeds=prompt_embeds,
max_new_tokens=max_new_tokens,
)[0]
return llm_response
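    # Text-only usage sketch (the answer shown is illustrative). PROMPT_PREFIX and
    # PROMPT_SUFFIX are defined in utils.py and wrap the input in the instruction
    # format expected by the instruction-tuned MiniChat-2-3B model.
    #
    #   >>> llm_inferencer.generate_text_response("What is the capital of France?")
    #   'The capital of France is Paris.'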
def generate_asr_cascade_response(self, audio, additional_text_prompt="Summarize the following:", max_new_tokens=256):
with torch.no_grad():
# Perform ASR using HuBERT.
audio_tensor = torch.tensor(audio).float().unsqueeze(0).to(self.device)
asr_transcript = self.perform_hubert_asr(audio_tensor)
# Combine the transcript with any additional text prompt.
# NOTE: Assumes that the text prompt always comes before the
# transcribed text.
            full_text = f"{additional_text_prompt} {asr_transcript}"
llm_response = self.generate_text_response(full_text, max_new_tokens)
return llm_response
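    # Cascade sketch: HuBERT first transcribes the waveform to (lower-cased) text,
    # which is appended to the text prompt and then handled exactly like a plain
    # text query. Here, audio is the raw 16 kHz numpy waveform from librosa.load.
    #
    #   >>> llm_inferencer.generate_asr_cascade_response(audio, "Summarize the following:")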
def generate_audio_response(self, audio, additional_text_prompt="Summarize the following:", max_new_tokens=256):
with torch.no_grad():
audio_tensor = torch.tensor(audio).float().unsqueeze(0).to(self.device)
if self.audio_encoder.downsample_method == "ctc_pool":
# Get the CTC pooling ranges for the audio.
ctc_pool_ranges = self.get_ctc_pool_ranges(audio_tensor)
# Get embeddings from the audio encoder.
audio_embeds = self.audio_encoder(audio_tensor, [ctc_pool_ranges])
else:
audio_embeds = self.audio_encoder(audio_tensor, ctc_pool_ranges=None)
# Debugging shapes
# print(f"audio_embeds shape: {audio_embeds.shape}")
# Combine the audio embeddings with any additional text prompt.
if len(additional_text_prompt) > 0:
                # Drop the leading BOS token added by the tokenizer so that only
                # the prompt text itself sits in front of the audio embeddings.
                additional_text_input_ids = self.llm_tokenizer(
                    additional_text_prompt, return_tensors='pt'
                ).input_ids[:, 1:].to(self.device)
text_embeds = self.llm.model.embed_tokens(additional_text_input_ids)
combined_embeds = torch.cat([text_embeds, audio_embeds], dim=1)
else:
combined_embeds = audio_embeds
# Debugging shapes
# print(f"combined_embeds shape: {combined_embeds.shape}")
# Get the full embedding sequence and generate the LLM response
prompt_emb_sequence = merge_prompt_tokens(
inputs_embeds=combined_embeds,
tokenizer=self.llm_tokenizer,
embed_tokens=self.llm.model.embed_tokens,
device=self.device,
)
# Debugging shapes
# print(f"prompt_emb_sequence shape: {prompt_emb_sequence.shape}")
llm_response = self.generate_llm_response(prompt_emb_sequence, max_new_tokens)[0]
return llm_response
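    # End-to-end audio prompting sketch: unlike the ASR cascade, the waveform is
    # mapped directly to LLM-space embeddings by the audio encoder (pooled per
    # word when downsample_method == "ctc_pool") and spliced into the prompt via
    # merge_prompt_tokens().
    #
    #   >>> audio, _ = librosa.load("speech.wav", sr=16000)   # file name is illustrative
    #   >>> llm_inferencer.generate_audio_response(audio, additional_text_prompt="")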
if __name__ == '__main__':
parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, required=True, help="yaml file for configuration")
    parser.add_argument('-g', '--gpu_idx', type=int, default=0, help="index of home GPU device")
    parser.add_argument('-p', '--audio_encoder_checkpoint', type=str, required=True, help="path to audio encoder checkpoint")
    parser.add_argument('-a', '--audio_file', type=str, required=True, help="audio file containing speech utterance to be used in prompt")
args = parser.parse_args()
device = torch.device(f"cuda:{args.gpu_idx}" if torch.cuda.is_available() else "cpu")
# Set up inferencer.
config = OmegaConf.load(args.config)
llm_inferencer = LLMSpeechTextInference(
config=config,
audio_encoder_checkpoint=args.audio_encoder_checkpoint,
device=device,
)
# Load audio file.
audio, sr = librosa.load(args.audio_file, sr=16000)
# Generate LLM response.
llm_response = llm_inferencer.generate_audio_response(
audio,
max_new_tokens=512,
)
    print(f'LLM Response: {llm_response}')
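
# Example invocation (assuming this script is saved as inference.py; all paths
# below are placeholders):
#
#   python inference.py \
#       --config configs/audio_encoder.yaml \
#       --audio_encoder_checkpoint checkpoints/audio_encoder.pt \
#       --gpu_idx 0 \
#       --audio_file samples/utterance.wav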