Untitled

mail@pastecode.io avatar
unknown
python
a year ago
849 B
2
Indexable
Never
def run_model(model, tokenizer, batch):
    """Score each string in ``batch`` by P("no") vs. P("yes").

    Tokenizes the batch, runs a single forward pass, and at the last
    decoder position computes a softmax restricted to the "no"/"yes"
    token logits. Returns ``1 - P("yes")`` (equivalently ``P("no")`` in
    this two-way softmax) for each input, as a plain list of floats.

    Parameters:
        model: seq2seq model callable with ``input_ids``,
            ``attention_mask``, ``decoder_input_ids`` and
            ``return_dict``; must return an object exposing a
            ``.logits`` tensor of shape (batch, seq_len, vocab_size).
        tokenizer: tokenizer providing ``batch_encode_plus`` and
            ``encode``.
        batch: list of input strings.

    Returns:
        list[float]: one probability per input string.
    """
    # Tokenize once and KEEP the attention mask: with padding=True the
    # original code dropped the mask, letting the encoder attend to
    # padding tokens and skewing scores for shorter sequences.
    encoded = tokenizer.batch_encode_plus(
        batch, padding=True, truncation=True, return_tensors='pt'
    )
    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']
    # NOTE(review): the all-ones decoder input mirrors the original
    # code; T5-style decoders typically start from the pad token id —
    # confirm this choice is intentional.
    decoder_input_ids = torch.ones_like(input_ids)

    # Hoist the token-id lookups: each encode() call is invariant for
    # the whole batch (previously also re-run inside a debug print).
    no_id = tokenizer.encode("no")[0]
    yes_id = tokenizer.encode("yes")[0]

    with torch.no_grad():
        logits = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            return_dict=True,
        ).logits
        # Two-way softmax over just the candidate tokens at every
        # position, then take the final decoder position. Index 1 is
        # P("yes"), so 1 - P("yes") == P("no") — the original returned
        # this value under the misleading name `yes_probs`.
        pair_probs = torch.softmax(logits[:, :, [no_id, yes_id]], dim=-1)
        no_probs = 1 - pair_probs[:, -1, 1]
    return no_probs.tolist()