import torch


def compute_logprob_and_length(self, prompt, completion):
    completions_logprobs = []

    # Tokenize the prompt with " Yes" and " No" appended.
    # Each encoding looks like: <s> SPIECE_UNDERLINE [tokens]
    prompt_tokens_yes = self.tokenizer(prompt + " Yes", return_tensors="pt").to(
        self.model.device
    )
    prompt_tokens_no = self.tokenizer(prompt + " No", return_tensors="pt").to(
        self.model.device
    )

    for prompt_tokens in [prompt_tokens_yes, prompt_tokens_no]:
        # Actual number of tokens in the prompt (without `<s>`)
        prompt_num_tokens = prompt_tokens.input_ids.shape[1] - 1

        # Encoding: <s> SPIECE_UNDERLINE [tokens] SPIECE_UNDERLINE </s>
        completion_tokens = self.tokenizer(
            f"{completion} {self.tokenizer.eos_token}", return_tensors="pt"
        ).to(self.model.device)
        # Actual number of tokens in the completion (without `<s>`)
        completion_num_tokens = completion_tokens.input_ids.shape[1] - 1
        # Concatenate prompt and completion ids, dropping the completion's `<s>`
        inputs = torch.cat(
            (
                prompt_tokens.input_ids,
                completion_tokens.input_ids[:, -completion_num_tokens:],
            ),
            dim=-1,
        )
        with torch.no_grad():
            outputs = self.model(inputs)
        # logits[:, i] predicts the token at position i + 1, so the slice for
        # the completion (including the `SPIECE_UNDERLINE </s>` tokens) starts
        # at the index of the last prompt token.
        logits = outputs.logits[
            :, prompt_num_tokens : prompt_num_tokens + completion_num_tokens
        ]
        logprobs = logits.log_softmax(dim=-1)
        # >>> batch_size, sequence_length, vocab_size
        logprobs = logprobs.gather(
            dim=-1,
            index=completion_tokens.input_ids[:, -completion_num_tokens:].unsqueeze(-1),
        ).squeeze(-1)
        # >>> batch_size, sequence_length
        completions_logprobs.append(logprobs.cpu().numpy())

    # Compare the two candidates by total log-probability and return the
    # per-token log-probs of the more likely one.
    return max(completions_logprobs, key=lambda lp: lp.sum())
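

# --- Usage sketch ----------------------------------------------------------
# A minimal, hypothetical harness: the function above expects an object with
# `self.tokenizer` and `self.model`. The `YesNoScorer` class and the
# checkpoint name below are illustrative stand-ins, not part of the original
# paste; the BOS-offset arithmetic assumes a SentencePiece tokenizer that
# prepends `<s>` (e.g. Llama-family models).
from transformers import AutoModelForCausalLM, AutoTokenizer


class YesNoScorer:
    # Reuse the module-level function above as a bound method.
    compute_logprob_and_length = compute_logprob_and_length

    def __init__(self, model_name):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.model.eval()


if __name__ == "__main__":
    # Illustrative checkpoint; substitute any causal LM you have access to.
    scorer = YesNoScorer("meta-llama/Llama-2-7b-hf")
    logprobs = scorer.compute_logprob_and_length(
        "Is Hanoi the capital of Vietnam? Answer:", "It is."
    )
    print(logprobs)  # per-token log-probs of the more likely continuation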