Untitled

mail@pastecode.io avatar
unknown
python
14 days ago
2.3 kB
3
Indexable
Never
import os

import torch
from huggingface_hub import login
from torch.cuda.amp import autocast
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

from contexts import data

# Checkpoint directory for the Cohere Aya model. The default is the original
# author's machine-specific path; set AYA_MODEL_PATH to run elsewhere.
model_name = os.environ.get(
    "AYA_MODEL_PATH",
    "/Projects/Deep_Tracking/yonasab/PycharmProjects/rgpt_model_check/models/CohereAya",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Weights are loaded in float16 and placed by device_map="auto", which is why
# the model is never moved to `device` below.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
# NOTE(review): `device` is not used by the code in this file (generation uses
# model.device instead); kept in case another module imports it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_message_format(prompt):
    """Wrap ``prompt`` as a single-turn chat conversation.

    Returns a one-element list holding a user message dict with ``"role"``
    and ``"content"`` keys, which is the form passed to
    ``tokenizer.apply_chat_template`` in this file.
    """
    return [dict(role="user", content=prompt)]


def generate_aya_23(
        prompts,
        model,
        temperature=0.3,
        top_p=0.75,
        top_k=0,
        max_new_tokens=1024
):
    """Sample a chat-style completion from ``model`` for one prompt.

    Args:
        prompts: The user message text. Despite the plural name, it is wrapped
            as a single user turn by ``get_message_format``.
        model: A causal LM exposing ``generate`` and ``device``. Tokenization
            uses the module-level ``tokenizer``, not an argument.
        temperature, top_p, top_k, max_new_tokens: Sampling settings forwarded
            unchanged to ``model.generate``. Sampling is always enabled.

    Returns:
        A list of decoded strings, one per generated sequence. Each string
        contains only the newly generated text: the prompt tokens are sliced
        off and special tokens are skipped.
    """
    conversation = get_message_format(prompts)

    prompt_ids = tokenizer.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,
        padding=True,
        return_tensors="pt",
    ).to(model.device)
    # Every output row begins with these prompt tokens; remember how many.
    prompt_len = len(prompt_ids[0])

    sampling_kwargs = dict(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_new_tokens=max_new_tokens,
        do_sample=True,
    )
    output_ids = model.generate(prompt_ids, **sampling_kwargs)

    # Drop the echoed prompt so only the continuation is decoded.
    completions = [row[prompt_len:] for row in output_ids]
    return tokenizer.batch_decode(completions, skip_special_tokens=True)

def main():
    """Answer every question in ``contexts.data`` with the Aya model and print the results.

    ``data`` is expected to map a context name to a dict with a ``"text"``
    entry (the passage) and a ``"questions"`` entry (an iterable of question
    strings). Each question is appended to its passage on a new line to form
    the prompt.

    Hugging Face authentication is read from the ``HF_TOKEN`` environment
    variable. When it is unset, no login is attempted.
    """
    # Never hard-code credentials in source. The token previously embedded
    # here has been published and should be revoked. Note that the model and
    # tokenizer are loaded at import time, before main() runs, so this login
    # does not affect how they are loaded.
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(hf_token)

    for context_key, context_value in data.items():
        context = context_value["text"]
        print("\n")
        print(f"Processing {context_key}...")
        print("\n")
        for question in context_value["questions"]:
            print("\n----- שאלה -----")  # "Question"
            print(question)
            prompt = context + "\n" + question

            answer = generate_aya_23(prompt, model)
            print("\n----- תשובה -----")  # "Answer"
            # generate_aya_23 returns a list of decoded strings. Print the
            # text itself rather than the list's repr (brackets, quotes and
            # escaped newlines).
            print("\n".join(answer))


if __name__ == "__main__":
    main()
Leave a Comment