import torch


def compute_logprob_and_length(self, prompt, completion):
    completions_logprobs = []
    # Tokenize the prompt with " Yes" and " No" appended; the variant whose
    # continuation scores higher is returned at the end.
    prompt_tokens_yes = self.tokenizer(prompt + " Yes", return_tensors="pt").to(
        self.model.device
    )  # <s> SPIECE_UNDERLINE [tokens]
    prompt_tokens_no = self.tokenizer(prompt + " No", return_tensors="pt").to(
        self.model.device
    )
    for prompt_tokens in [prompt_tokens_yes, prompt_tokens_no]:
        # Actual number of prompt tokens (without `<s>`)
        prompt_num_tokens = prompt_tokens.input_ids.shape[1] - 1
        completion_tokens = self.tokenizer(
            f"{completion} {self.tokenizer.eos_token}", return_tensors="pt"
        ).to(self.model.device)  # <s> SPIECE_UNDERLINE [tokens] SPIECE_UNDERLINE </s>
        # Actual number of tokens in completion (without `<s> SPIECE_UNDERLINE`)
        completion_num_tokens = completion_tokens.input_ids.shape[1] - 1
        inputs = torch.concatenate(
            (
                prompt_tokens.input_ids,
                completion_tokens.input_ids[:, -completion_num_tokens:],
            ),
            dim=-1,
        )
        # [input_tokens] -> [next_token]
        outputs = self.model(inputs)
        # The logit at position i predicts the token at position i + 1, so this
        # slice covers exactly the completion tokens, including the probability
        # of the trailing 'SPIECE_UNDERLINE </s>'.
        logits = outputs.logits[
            :, prompt_num_tokens : prompt_num_tokens + completion_num_tokens
        ]
        logprobs = logits.log_softmax(dim=-1)
        # >>> batch_size, sequence_length, vocab_size
        logprobs = logprobs.gather(
            dim=-1,
            index=completion_tokens.input_ids[:, -completion_num_tokens:].unsqueeze(-1),
        ).squeeze(-1)
        # >>> batch_size, sequence_length
        # Total log-probability of the completion under this prompt variant.
        completions_logprobs.append(logprobs.sum().item())
    return max(completions_logprobs)
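

# --- Usage sketch (not part of the original paste) ---------------------------
# A minimal, hypothetical wrapper showing how the method above could be wired
# to a Hugging Face causal LM. The class name, the checkpoint string, and the
# example prompt/completion are assumptions for illustration only; the `- 1`
# token counting above assumes a LLaMA-style SentencePiece tokenizer that
# prepends a `<s>` BOS token.
from transformers import AutoModelForCausalLM, AutoTokenizer


class YesNoScorer:
    # Reuse the module-level function as an instance method.
    compute_logprob_and_length = compute_logprob_and_length

    def __init__(self, model_name="meta-llama/Llama-2-7b-hf"):  # assumed checkpoint
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.model.eval()


if __name__ == "__main__":
    scorer = YesNoScorer()
    with torch.no_grad():
        best_logprob = scorer.compute_logprob_and_length(
            prompt="Is Hanoi the capital of Vietnam? Answer:",
            completion="That is correct.",
        )
    print(best_logprob)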