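"""Conversational retrieval QA over a Chroma vector store using LangChain and OpenAI models.

Loads a prompt template and a persisted Chroma index, answers questions through a
ConversationalRetrievalChain, and logs per-request token usage and cost.
"""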
import nltk
# from langchain.llms import OpenAI
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
# from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
from src.utils.logger import get_logger
import tiktoken

logger = get_logger(logname="Conversation", logfile="logs/openapi.log")


class ConversationalAI:
    def __init__(self, config):
        self.config = config

        LIST_MODELS = config['LIST_MODELS']
        LIST_EMB_MODELS = config['LIST_EMB_MODELS']

        MODEL = config['MODEL']
        EMB_MODEL = config['EMB_MODEL']
        DOCS_VER = config['DOCS_VER']

        with open(config['PROMT'], 'r', encoding='utf-8') as file:
            self.prompt_template = file.read()

        self.llm = ChatOpenAI(model_name=LIST_MODELS[MODEL], temperature=0.7)
        self.embeddings = OpenAIEmbeddings(model=LIST_EMB_MODELS[EMB_MODEL])
        self.tokenizer = tiktoken.encoding_for_model(LIST_EMB_MODELS[EMB_MODEL])

        self.persist_directory = f"resources/vectors/{DOCS_VER}/{EMB_MODEL}"
        self.vectorstore = Chroma(
            persist_directory=self.persist_directory, embedding_function=self.embeddings)

        self.callback_handler = OpenAICallbackHandler()
        self.qa_chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=self.vectorstore.as_retriever(),
            return_source_documents=False)

        # Download necessary NLTK data
        nltk.download('punkt')
        nltk.download('averaged_perceptron_tagger')

    def count_tokens(self, text):
        """Return the number of tokens in `text` using the embedding model's encoding."""
        return len(self.tokenizer.encode(text))

    def reset_callback_handler(self):
        """Replace the callback handler so token and cost counters start from zero."""
        self.callback_handler = OpenAICallbackHandler()

    def chat(self, query):
        """Answer `query` via the retrieval chain and log token usage and cost."""
        query = self.prompt_template.format(question=query)
        self.reset_callback_handler()
        result = self.qa_chain({"question": query, "chat_history": []},
                               callbacks=[self.callback_handler])
        
        embedding_tokens = self.count_tokens(query)

        payload = {'embedding': embedding_tokens,
                   'prompt': self.callback_handler.prompt_tokens,
                   'completion': self.callback_handler.completion_tokens,
                   'total': self.callback_handler.total_tokens,
                   'cost': self.callback_handler.total_cost}

        logger.info(str(payload))

        # Return the raw answer; the commented-out lines below would instead keep
        # only the text after the last "\n\nQuestion:" marker.
        # ans = "Question:" + result['answer'].split("\n\nQuestion:")[-1]
        # return ans
        return result['answer']


if __name__ == "__main__":
    import os
    import yaml

    with open("config/config.yaml", "r") as f:
        cfg = yaml.safe_load(f)

    # The API key must be in the environment before the OpenAI clients are created.
    os.environ['OPENAI_API_KEY'] = cfg['OPENAI_API_KEY']

    conversational_ai = ConversationalAI(cfg)

    # Example usage
    question = "What is the capital of France?"
    response = conversational_ai.chat(question)
    print(response)
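

# The class expects a config mapping with the keys read in __init__. The sketch
# below shows one plausible shape; the model names, prompt path, and other values
# are illustrative assumptions, not taken from the original config.yaml.
# LIST_MODELS / LIST_EMB_MODELS only need to be indexable by MODEL / EMB_MODEL
# (a dict keyed by short names works, as would a list with integer indices).
#
# EXAMPLE_CONFIG = {
#     'OPENAI_API_KEY': 'sk-...',                            # hypothetical
#     'LIST_MODELS': {'gpt-4o': 'gpt-4o'},                   # hypothetical
#     'LIST_EMB_MODELS': {'ada': 'text-embedding-ada-002'},  # hypothetical
#     'MODEL': 'gpt-4o',
#     'EMB_MODEL': 'ada',
#     'DOCS_VER': 'v1',                                      # hypothetical
#     'PROMT': 'resources/prompts/qa_prompt.txt',            # hypothetical path; key name matches the code
# }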