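# Interactive demo: load LongLLaMA (long_llama_code_7b_instruct) and, for each
# user prompt, print the model's greedy next-token prediction at every
# position of the prompt.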
import torch
from transformers import LlamaTokenizer, AutoModelForCausalLM

GPU = torch.device("cuda")  # assumes a CUDA-capable GPU is available

tokenizer = LlamaTokenizer.from_pretrained("syzymon/long_llama_code_7b_instruct")
# trust_remote_code is required because LongLLaMA ships custom modeling code
# for its extended-context attention. Note that float32 weights for a 7B model
# need roughly 28 GB of GPU memory; torch.bfloat16 roughly halves that if the
# hardware supports it.
model = AutoModelForCausalLM.from_pretrained("syzymon/long_llama_code_7b_instruct",
                                             torch_dtype=torch.float32,
                                             trust_remote_code=True)
model.to(GPU)
model.eval()  # inference mode: disables dropout
while True:
    # Get prompt from user
    prompt = input("Enter your prompt (or 'exit' to quit): ")

    if prompt.lower() == 'exit':
        break

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    input_ids = input_ids.to(GPU)
    # Single forward pass; no_grad avoids building the autograd graph
    with torch.no_grad():
        outputs = model(input_ids=input_ids)
    logits = outputs.logits

    # argmax over the vocabulary gives, for each prompt position, the single
    # most probable next token. This is a one-step greedy prediction per
    # position, not an autoregressive continuation of the prompt.
    predicted_token_ids = torch.argmax(logits, dim=-1)

    # Decode the predicted token IDs back to text
    decoded_text = tokenizer.decode(predicted_token_ids[0].tolist())

    print(decoded_text)
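    # The argmax above is only a one-step prediction per position. For an
    # actual multi-token completion of the prompt, Hugging Face's generate()
    # runs autoregressive decoding. A minimal sketch, assuming greedy
    # decoding and an arbitrary budget of 256 new tokens:
    generated = model.generate(input_ids, max_new_tokens=256)
    # Strip the prompt tokens so only the newly generated text is printed
    completion = tokenizer.decode(generated[0, input_ids.shape[1]:],
                                  skip_special_tokens=True)
    print(completion)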