Untitled

mail@pastecode.io avatar
unknown
python
2 months ago
2.2 kB
2
Indexable
Never
    def compute_logprob_and_length(self, prompts, completions):
        """Score each completion under ``self.model``, conditioned on its prompt.

        Each prompt and completion is tokenized separately with
        ``self.tokenizer``; the tokenizer's ``eos_token`` is appended to the
        completion text (separated by a space) so the end-of-sequence token
        is scored too. The prompt ids and the trailing completion ids are
        concatenated, passed through the model, and the log-probability the
        model assigns to every completion token is gathered.

        Args:
            prompts: Iterable of prompt strings.
            completions: Iterable of completion strings, paired with
                ``prompts`` via ``zip`` (surplus items of the longer iterable
                are ignored).

        Returns:
            A tuple ``(completions_logprobs, completions_num_tokens)``:
            ``completions_logprobs`` is a list with one NumPy array per pair,
            holding the per-token log-probabilities with shape
            ``(batch, num_tokens)`` (batch is presumably 1, since a single
            string is tokenized at a time); ``completions_num_tokens`` is a
            list with the number of scored completion tokens per pair.
        """
        # Token id checked right after the first completion token. According
        # to the original author's notes this is the SPIECE_UNDERLINE piece
        # of the tokenizer in use.
        # NOTE(review): tokenizer-specific constant -- confirm when the
        # tokenizer changes.
        spiece_underline_id = 29871

        tokenizer = self.tokenizer
        model = self.model
        device = model.device

        completions_num_tokens = []
        completions_logprobs = []

        # Scoring only: without no_grad() the logits would carry autograd
        # history, which wastes memory and makes `.numpy()` raise when the
        # model's parameters require gradients.
        with torch.no_grad():
            for prompt, completion in zip(prompts, completions):
                # Expected layout (per the original notes): <s> SPIECE_UNDERLINE [tokens]
                prompt_tokens = tokenizer(prompt, return_tensors="pt").to(device)
                # Number of prompt tokens, not counting the first (`<s>`) token.
                prompt_num_tokens = prompt_tokens.input_ids.shape[1] - 1

                # Expected layout: <s> SPIECE_UNDERLINE [tokens] SPIECE_UNDERLINE </s>
                completion_tokens = tokenizer(
                    f"{completion} {tokenizer.eos_token}", return_tensors="pt"
                ).to(device)
                # Number of completion tokens to score: drop the first token,
                # and also the token after it when it is the bare
                # SPIECE_UNDERLINE id.
                completion_num_tokens = completion_tokens.input_ids.shape[1] - 1
                if completion_tokens.input_ids[0, 1] == spiece_underline_id:
                    completion_num_tokens -= 1
                completions_num_tokens.append(completion_num_tokens)

                # Trailing ids of the completion; used both as model input and
                # as the gather targets below.
                completion_ids = completion_tokens.input_ids[
                    :, -completion_num_tokens:
                ]

                inputs = torch.cat(
                    (prompt_tokens.input_ids, completion_ids), dim=-1
                )
                outputs = model(inputs)

                # The logits at position i score the token at position i + 1,
                # so the window starting at the last prompt position lines up
                # with the completion tokens (including the trailing
                # 'SPIECE_UNDERLINE </s>' part).
                logits = outputs.logits[
                    :, prompt_num_tokens : prompt_num_tokens + completion_num_tokens
                ]
                logprobs = logits.log_softmax(dim=-1)
                # >>> batch_size, sequence_length, vocab_size

                logprobs = logprobs.gather(
                    dim=-1, index=completion_ids.unsqueeze(-1)
                ).squeeze(-1)
                # >>> batch_size, sequence_length

                # NumPy has no bfloat16 dtype; upcast only in that case so the
                # returned dtype is unchanged for every other model dtype.
                if logprobs.dtype == torch.bfloat16:
                    logprobs = logprobs.float()
                completions_logprobs.append(logprobs.cpu().numpy())
        return completions_logprobs, completions_num_tokens
Leave a Comment