conservation
unknown
python
a year ago
3.3 kB
7
Indexable
import nltk
# from langchain.llms import OpenAI
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
# from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
from src.utils.logger import get_logger
import tiktoken
# Module-level logger shared by the whole file; writes to logs/openapi.log
# under the "Conservation" tag (spelling kept as-is — it is a runtime value).
logger = get_logger(logname="Conservation", logfile="logs/openapi.log")
class ConversationalAI:
    """Retrieval-augmented conversational QA.

    Wraps a LangChain ``ConversationalRetrievalChain`` backed by a Chroma
    vector store, with the OpenAI chat and embedding models selected via
    the supplied config dict.  Token usage and cost of each call are
    captured through an OpenAI callback handler and written to the logger.
    """

    def __init__(self, config):
        """Build the LLM, embeddings, vector store and retrieval chain.

        Expected config keys: ``LIST_MODELS``, ``LIST_EMB_MODELS``,
        ``MODEL``, ``EMB_MODEL``, ``DOCS_VER`` and ``PROMT`` (path to a
        prompt-template file containing a ``{question}`` placeholder).
        NOTE(review): 'PROMT' is a misspelling of 'PROMPT' kept for
        backward compatibility with existing config files.
        Optional key: ``TEMPERATURE`` (defaults to 0.7, the previously
        hard-coded value).
        """
        self.config = config
        LIST_MODELS = config['LIST_MODELS']
        LIST_EMB_MODELS = config['LIST_EMB_MODELS']
        MODEL = config['MODEL']
        EMB_MODEL = config['EMB_MODEL']
        DOCS_VER = config['DOCS_VER']
        with open(config['PROMT'], 'r', encoding='utf-8') as file:
            self.prompt_template = file.read()
        # Temperature is now configurable; default preserves old behavior.
        self.llm = ChatOpenAI(model_name=LIST_MODELS[MODEL],
                              temperature=config.get('TEMPERATURE', 0.7))
        self.embeddings = OpenAIEmbeddings(model=LIST_EMB_MODELS[EMB_MODEL])
        # Tokenizer matches the *embedding* model, so count_tokens()
        # estimates the embedding-side token spend of a query.
        self.tokenizer = tiktoken.encoding_for_model(LIST_EMB_MODELS[EMB_MODEL])
        self.persist_directory = f"resources/vectors/{DOCS_VER}/{EMB_MODEL}"
        self.vectorstore = Chroma(
            persist_directory=self.persist_directory,
            embedding_function=self.embeddings)
        self.callback_handler = OpenAICallbackHandler()
        self.qa_chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=self.vectorstore.as_retriever(),
            return_source_documents=False)
        # Fetch NLTK data quietly: downloads are cached after the first
        # run, and quiet=True stops console noise on every instantiation.
        nltk.download('punkt', quiet=True)
        nltk.download('averaged_perceptron_tagger', quiet=True)

    def count_tokens(self, text):
        """Return the token count of *text* under the embedding model's encoding."""
        return len(self.tokenizer.encode(text))

    def reset_callback_handler(self):
        """Replace the usage-tracking callback so per-call stats start from zero."""
        self.callback_handler = OpenAICallbackHandler()

    def chat(self, query):
        """Run one retrieval-QA turn for *query* and return the answer string.

        The query is substituted into the prompt template, the chain is
        invoked with an empty chat history, and token/cost usage for the
        call is logged as a dict.
        """
        query = self.prompt_template.format(question=query)
        self.reset_callback_handler()  # isolate this call's usage stats
        result = self.qa_chain({"question": query, "chat_history": []},
                               callbacks=[self.callback_handler])
        embedding_tokens = self.count_tokens(query)
        # NOTE(review): 'promt' key is misspelled but kept byte-identical
        # so downstream consumers of the log format keep working.
        payload = {'embedding': embedding_tokens,
                   'promt': self.callback_handler.prompt_tokens,
                   "completion": self.callback_handler.completion_tokens,
                   'total': self.callback_handler.total_tokens,
                   'cost': self.callback_handler.total_cost}
        logger.info(str(payload))
        return result['answer']
if __name__ == "__main__":
    # Imports used only by this demo entry point.
    import os

    import yaml

    # Load runtime configuration and expose the API key *before* the
    # OpenAI client objects inside ConversationalAI are constructed.
    with open("config/config.yaml", "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    os.environ['OPENAI_API_KEY'] = cfg['OPENAI_API_KEY']

    conversational_ai = ConversationalAI(cfg)

    # Example usage
    question = "What is the capital of France?"
    response = conversational_ai.chat(question)
    print(response)
Leave a Comment