from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
import torch
from contexts import data
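
# `contexts.data` is not included in this paste. Based on how main() consumes
# it, it is assumed to be a dict mapping a context name to its passage text and
# a list of questions, roughly like this hypothetical stub:
#
# data = {
#     "context_1": {
#         "text": "<passage to answer questions about>",
#         "questions": ["<question 1>", "<question 2>"],
#     },
# }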

# Local checkpoint of the Cohere Aya model.
model_name = "/Projects/Deep_Tracking/yonasab/PycharmProjects/rgpt_model_check/models/CohereAya"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# device_map="auto" places the fp16 weights on the available GPU(s) automatically.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")

def get_message_format(prompt):
    # Wrap a raw prompt as a single-turn chat message for apply_chat_template.
    message = [{"role": "user", "content": prompt}]
    return message

def generate_aya_23(
    prompt,
    model,
    temperature=0.3,
    top_p=0.75,
    top_k=0,
    max_new_tokens=1024,
):
    messages = get_message_format(prompt)
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        padding=True,
        return_tensors="pt",
    )
    input_ids = input_ids.to(model.device)
    prompt_padded_len = len(input_ids[0])
    gen_tokens = model.generate(
        input_ids,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_new_tokens=max_new_tokens,
        do_sample=True,
    )
    # Keep only the newly generated tokens, dropping the echoed prompt.
    gen_tokens = [gt[prompt_padded_len:] for gt in gen_tokens]
    gen_text = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
    return gen_text

def main():
    # Authenticate with a personal Hugging Face access token (never commit real tokens).
    login("hf_...")  # placeholder: substitute your own token
    for context_key, context_value in data.items():
        context = context_value["text"]
        print(f"\nProcessing {context_key}...\n")
        for question in context_value["questions"]:
            print("\n----- Question -----")
            print(question)
            prompt = context + "\n" + question
            answer = generate_aya_23(prompt, model)
            print("\n----- Answer -----")
            # generate_aya_23 returns a list of decoded strings; one prompt yields one answer.
            print(answer[0])

if __name__ == "__main__":
    main()