conversation
unknown
python
6 months ago
3.3 kB
4
Indexable
import nltk
# from langchain.llms import OpenAI
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
# from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
from src.utils.logger import get_logger
import tiktoken

logger = get_logger(logname="Conservation", logfile="logs/openapi.log")


class ConversationalAI:
    """Retrieval-augmented question answering over a persisted Chroma vector store.

    Expects ``config`` (a mapping) to provide the keys: ``LIST_MODELS``,
    ``LIST_EMB_MODELS``, ``MODEL``, ``EMB_MODEL``, ``DOCS_VER`` and ``PROMT``
    (path to a prompt-template file containing a ``{question}`` placeholder).
    """

    def __init__(self, config):
        """Build the LLM, embeddings, vector store and retrieval chain from *config*."""
        self.config = config
        list_models = config['LIST_MODELS']
        list_emb_models = config['LIST_EMB_MODELS']
        model = config['MODEL']
        emb_model = config['EMB_MODEL']
        docs_ver = config['DOCS_VER']

        # Prompt template is read once at construction; chat() fills its
        # {question} placeholder.  NOTE(review): the config key is spelled
        # 'PROMT' — kept as-is for compatibility with existing config files.
        with open(config['PROMT'], 'r', encoding='utf-8') as file:
            self.prompt_template = file.read()

        self.llm = ChatOpenAI(model_name=list_models[model], temperature=0.7)
        self.embeddings = OpenAIEmbeddings(model=list_emb_models[emb_model])
        # Tokenizer matches the *embedding* model, so count_tokens() reflects
        # tokens sent for embedding, not for chat completion.
        self.tokenizer = tiktoken.encoding_for_model(list_emb_models[emb_model])

        # Vectors are versioned by document set and embedding model so that
        # re-embedding one version does not clobber another.
        self.persist_directory = f"resources/vectors/{docs_ver}/{emb_model}"
        self.vectorstore = Chroma(
            persist_directory=self.persist_directory,
            embedding_function=self.embeddings)

        self.callback_handler = OpenAICallbackHandler()
        self.qa_chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=self.vectorstore.as_retriever(),
            return_source_documents=False)

        # Download necessary NLTK data (no-op if already present locally).
        nltk.download('punkt')
        nltk.download('averaged_perceptron_tagger')

    def count_tokens(self, text):
        """Return the number of tokens *text* encodes to under the embedding tokenizer."""
        return len(self.tokenizer.encode(text))

    def reset_callback_handler(self):
        """Replace the callback handler so per-call token/cost counters start from zero."""
        self.callback_handler = OpenAICallbackHandler()

    def chat(self, query):
        """Answer *query* through the retrieval chain and return the answer text.

        Formats *query* into the prompt template, runs the chain with an empty
        chat history (each call is stateless), logs token usage and cost, and
        returns the chain's ``answer`` string.
        """
        query = self.prompt_template.format(question=query)
        self.reset_callback_handler()  # fresh counters for this call only
        result = self.qa_chain({"question": query, "chat_history": []},
                               callbacks=[self.callback_handler])
        embedding_tokens = self.count_tokens(query)
        # NOTE(review): the 'promt' key is a typo but is kept verbatim —
        # downstream log consumers may depend on it.
        payload = {'embedding': embedding_tokens,
                   'promt': self.callback_handler.prompt_tokens,
                   "completion": self.callback_handler.completion_tokens,
                   'total': self.callback_handler.total_tokens,
                   'cost': self.callback_handler.total_cost}
        logger.info(str(payload))
        # Dead code removed: the original also built
        #   ans = "Question:" + result['answer'].split("\n\nQuestion:")[-1]
        # and then discarded it, returning result['answer'] unchanged.
        return result['answer']


if __name__ == "__main__":
    import os
    import yaml

    with open("config/config.yaml", "r") as f:
        cfg = yaml.safe_load(f)
    os.environ['OPENAI_API_KEY'] = cfg['OPENAI_API_KEY']

    conversational_ai = ConversationalAI(cfg)

    # Example usage
    question = "What is the capital of France?"
    response = conversational_ai.chat(question)
    print(response)
Editor is loading...
Leave a Comment