import torch


def compute_logprob_and_length(self, prompt, completion):
    """Score `completion` after the prompt primed with " Yes" and " No",
    and return the per-token log-probabilities of the higher-scoring run."""
    completions_logprobs = []

    # Tokenize both primed prompts: <s> SPIECE_UNDERLINE [tokens]
    prompt_tokens_yes = self.tokenizer(prompt + " Yes", return_tensors="pt").to(
        self.model.device
    )
    prompt_tokens_no = self.tokenizer(prompt + " No", return_tensors="pt").to(
        self.model.device
    )

    for prompt_tokens in [prompt_tokens_yes, prompt_tokens_no]:
        # Actual number of prompt tokens (without the leading `<s>`)
        prompt_num_tokens = prompt_tokens.input_ids.shape[1] - 1

        # Completion tokenized as: <s> SPIECE_UNDERLINE [tokens] SPIECE_UNDERLINE </s>
        completion_tokens = self.tokenizer(
            f"{completion} {self.tokenizer.eos_token}", return_tensors="pt"
        ).to(self.model.device)
        # Actual number of completion tokens (without the leading `<s>`)
        completion_num_tokens = completion_tokens.input_ids.shape[1] - 1

        # Prompt tokens followed by the completion tokens (leading `<s>` dropped)
        inputs = torch.cat(
            (
                prompt_tokens.input_ids,
                completion_tokens.input_ids[:, -completion_num_tokens:],
            ),
            dim=-1,
        )
        with torch.no_grad():
            outputs = self.model(inputs)

        # The logit at position i predicts the token at position i + 1, so the
        # completion tokens (including 'SPIECE_UNDERLINE </s>') are predicted by
        # the logits starting at the last prompt position.
        logits = outputs.logits[
            :, prompt_num_tokens : prompt_num_tokens + completion_num_tokens
        ]
        logprobs = logits.log_softmax(dim=-1)  # batch_size, sequence_length, vocab_size

        # Gather the log-probability assigned to each actual completion token
        logprobs = logprobs.gather(
            dim=-1,
            index=completion_tokens.input_ids[:, -completion_num_tokens:].unsqueeze(-1),
        ).squeeze(-1)  # batch_size, sequence_length
        completions_logprobs.append(logprobs.cpu().numpy())

    # Return the per-token log-probabilities of the run with the higher total
    return max(completions_logprobs, key=lambda lp: lp.sum())
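# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original paste): the method above
# expects `self.model` and `self.tokenizer`, so a small wrapper class is
# assumed here. The checkpoint name, prompt, and completion below are
# illustrative placeholders for a SentencePiece-based causal LM.
# ---------------------------------------------------------------------------
from transformers import AutoModelForCausalLM, AutoTokenizer


class YesNoScorer:
    def __init__(self, model_name):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).eval()


# Attach the paste's function as a method of the hypothetical wrapper class.
YesNoScorer.compute_logprob_and_length = compute_logprob_and_length


if __name__ == "__main__":
    scorer = YesNoScorer("meta-llama/Llama-2-7b-hf")  # hypothetical checkpoint
    token_logprobs = scorer.compute_logprob_and_length(
        "Is Hanoi the capital of Vietnam? Answer:", "Yes"
    )
    print(token_logprobs, token_logprobs.sum())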