import torch


def compute_logprob_and_length(self, prompts, completions):
    """Return per-token log-probabilities and token counts for each completion.

    For every (prompt, completion) pair, the completion is scored by the model
    conditioned on the prompt. The tokenization comments assume a LLaMA-style
    SentencePiece tokenizer (`<s>` BOS, `</s>` EOS, SPIECE_UNDERLINE prefix).
    """
    completions_num_tokens = []
    completions_logprobs = []
    for prompt, completion in zip(prompts, completions):
        prompt_tokens = self.tokenizer(prompt, return_tensors="pt").to(
            self.model.device
        )  # <s> SPIECE_UNDERLINE [tokens]
        # Actual number of tokens in the prompt (without `<s>`)
        prompt_num_tokens = prompt_tokens.input_ids.shape[1] - 1

        completion_tokens = self.tokenizer(
            f"{completion} {self.tokenizer.eos_token}", return_tensors="pt"
        ).to(
            self.model.device
        )  # <s> SPIECE_UNDERLINE [tokens] SPIECE_UNDERLINE </s>
        # Actual number of tokens in the completion (without `<s> SPIECE_UNDERLINE`)
        completion_num_tokens = completion_tokens.input_ids.shape[1] - 1
        # 29871 is the id of the bare SPIECE_UNDERLINE ("▁") token in the
        # LLaMA tokenizer; drop it so only real completion tokens are counted.
        if completion_tokens.input_ids[0, 1] == 29871:
            completion_num_tokens = completion_num_tokens - 1
        completions_num_tokens.append(completion_num_tokens)

        # Concatenate the prompt (with BOS) and the completion tokens proper.
        inputs = torch.cat(
            (
                prompt_tokens.input_ids,
                completion_tokens.input_ids[:, -completion_num_tokens:],
            ),
            dim=-1,
        )
        outputs = self.model(inputs)
        # The logit at position i predicts the token at position i + 1, so this
        # slice covers exactly the completion tokens, including the
        # probabilities of the trailing 'SPIECE_UNDERLINE </s>' tokens.
        logits = outputs.logits[
            :, prompt_num_tokens : prompt_num_tokens + completion_num_tokens
        ]
        logprobs = logits.log_softmax(dim=-1)
        # >>> batch_size, sequence_length, vocab_size
        logprobs = logprobs.gather(
            dim=-1,
            index=completion_tokens.input_ids[:, -completion_num_tokens:].unsqueeze(
                -1
            ),
        ).squeeze(-1)
        # >>> batch_size, sequence_length
        completions_logprobs.append(logprobs.cpu().numpy())
    return completions_logprobs, completions_num_tokens
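

# ---------------------------------------------------------------------------
# Usage sketch: a minimal, hedged example of calling the function above.
# Assumptions not present in the original snippet: the function is attached as
# a method of a small wrapper class exposing `self.model` (a Hugging Face
# causal LM) and `self.tokenizer`; the class name, prompt/completion strings,
# and checkpoint name are all illustrative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    class Scorer:
        def __init__(self, model_name):
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(model_name)

    # Attach the function defined above as a method (hypothetical wiring).
    Scorer.compute_logprob_and_length = compute_logprob_and_length

    scorer = Scorer("meta-llama/Llama-2-7b-hf")  # illustrative checkpoint
    logprobs, num_tokens = scorer.compute_logprob_and_length(
        ["Question: What is the capital of France?\nAnswer:"],
        ["Paris"],
    )
    # Each entry of `logprobs` is a (1, completion_num_tokens) array; a common
    # length-normalized score is the mean token log-probability.
    scores = [lp.sum() / n for lp, n in zip(logprobs, num_tokens)]
    print(scores)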