Untitled

 avatar
unknown
plain_text
a month ago
10 kB
2
Indexable
from typing import Annotated, List
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
import functools
import tiktoken
from langchain_core.prompts import PromptTemplate
from services.llm.agents.utils import handle_tool_error, create_tool_node_with_fallback, _print_event
import random
from services.llm.agents.utils import trim_tokens
from services.conversation.conversation_summary_rubric import get_default_summary_template, get_task_template, \
    get_thread_template

class State(TypedDict):
    """Shared LangGraph state for the unified-summary agent.

    Every node in the graph reads from and returns partial updates to this
    schema.
    """

    # Conversation messages; the ``add_messages`` reducer appends new
    # messages instead of replacing the list on each state update.
    messages: Annotated[list, add_messages]
    # Token accounting for the latest LLM call
    # (``prompt_tokens`` / ``completion_tokens`` / ``total_tokens``).
    token_usage: dict
    # Per-call token counts — presumably one entry per tool/LLM invocation;
    # only referenced in commented-out code here, so TODO confirm its use.
    function_token_usage: list
    # Entity identifiers passed through to the agent invocation.
    entity_ids: list
    # RAG-related flags; not read by any code visible in this file —
    # NOTE(review): confirm against callers.
    rag_flag: dict
    # Name of the node that produced the latest message. ``agent_node``
    # returns this key, so it must be declared here or LangGraph rejects
    # the state update as an unknown key.
    sender: str


class UnifiedSummary:
    """LangGraph agent that generates a Markdown conversation summary.

    Depending on flags in the run config, the summary can additionally
    include "threads" (key topics) and per-thread action items. The graph
    (chatbot node + tool node with a conditional loop) is built eagerly in
    ``__init__``; call :meth:`get_compiled_graph` to obtain a runnable graph.
    """

    # Soft token budget for message trimming: 15% of a 128k context window.
    MAX_TOKENS = int(.15 * 128000)
    # Upper bound on retained messages (not used by the current trimming logic).
    MAX_MESSAGES = 50
    # Class-wide tokenizer used for all token accounting.
    tokenizer = tiktoken.get_encoding("o200k_base")

    def __init__(self, llm, system_message='', tools=None):
        """Store the LLM(s) and tools, then build the state graph.

        Args:
            llm: Sequence of chat models; ``agent_node`` picks one at random
                per invocation.
            system_message: Optional system prompt. Currently unused by the
                prompt template — kept for interface compatibility.
            tools: Tools to bind to each LLM. Defaults to no tools.
        """
        self.llm = llm
        # ``tools=None`` sentinel avoids the shared-mutable-default pitfall
        # of the previous ``tools=[]`` signature; behavior is unchanged for
        # callers that pass a list or nothing.
        self.tools = tools if tools is not None else []
        self.system_message = system_message
        self.build_graph()

    def create_agent(self, llms, state, config):
        """Build one prompt|llm chain per model in ``llms``.

        The prompt is assembled from a base summary instruction plus optional
        thread-generation and task-generation sections, controlled by
        ``config['metadata']['thread_gen']`` / ``['task_gen']``.

        Args:
            llms: Iterable of chat models to bind the tools to.
            state: Graph state (currently unused here; kept for signature
                stability with the node call site).
            config: LangGraph run config; ``configurable.conv_hist`` supplies
                the conversation transcript injected into the prompt.

        Returns:
            list: One ``prompt | llm.bind_tools(...)`` runnable per model.
        """
        # TODO: explain what are threads and what are tasks
        self.agent_prompt = """
You are a helpful bot tasked with generating conversation summaries in Markdown format. Use the conversation history as context to generate summary.
This summary will be used to generate the product spec/brd. The summary is a culmination of keypoints/topics discussed in the conversation.
Each keypoint is called a thread, which can be used to track individual. Each thread again will have a list of action-items, where each action item
is a individual sub task such that all the action items should be successfully completed to complete the parent thread.   

Follow these instructions one after the other:
- Based on the conversation history and the information gathered from the user, 
    give a detailed summary by filling the information in this template: {summary_template} without altering its format. 
- Summary should include all the unique topics/key-points discussed in the conversation.
- Output the content in exact same structure and in structured Markdown format without using any Markdown delimiters."""

        if config['metadata']['thread_gen']:
            # NOTE: this section contains no template placeholder — thread
            # structure comes purely from these natural-language rules.
            self.agent_prompt = self.agent_prompt + """\n\n
Instructions= Follow the below mentioned rules to generate threads in the next section and strictly do not return extra or less items, fetch all the necessary details.                     
Role : Content Analysts
Task : Analyze dialogue, such as conversations to identify key takeaways about names, dates, time, stakeholder, every detail from the discussion if any do not return extra items.
        (Get all the necessary key points from the conversation)
Output : 
    [List of Takeaways(
**Relevance**: Ensure each point directly relates to the main topic or objective of the discussion/document.
**Clarity**: Use clear and straightforward language to ensure understanding.
**Brevity**: Keep each point concise, ideally one or two sentences.
**Comprehensiveness**: Cover all significant aspects without omitting critical information.
**Logical Order**: Present points in a logical sequence to facilitate understanding.)]"""

        if config['metadata']['task_gen']:
            self.agent_prompt = self.agent_prompt + """\n\nBased on the conversation history and the information gathered from the user.
generate a list of action items from the conversation for each thread. Do not change the format of the template : {task_template}."""

        self.agent_prompt = self.agent_prompt + """ \n\nuser conversation: {conv_hist}"""

        prompt = PromptTemplate.from_template(self.agent_prompt)

        # TODO: call transcript is not being saved in the checkpointer
        conv_hist = config['configurable']['conv_hist']

        # Pre-fill static template slots; ``task_template`` is always bound
        # even when task_gen is off (harmless — the placeholder is absent).
        prompt = prompt.partial(summary_template=get_default_summary_template([]))
        prompt = prompt.partial(task_template=get_task_template())
        prompt = prompt.partial(conv_hist=conv_hist)

        return [prompt | llm.bind_tools(self.tools) for llm in llms]

    @staticmethod
    def trim_tokens_tool(messages: List[HumanMessage], max_tokens: int = MAX_TOKENS) -> List[HumanMessage]:
        """De-duplicate ``messages`` and keep the most recent ones that fit
        within ``max_tokens``.

        Duplicates are detected by (content, message-class-name). Messages
        are then accumulated from newest to oldest until the token budget is
        exhausted; original chronological order is preserved in the result.

        Was previously missing ``@staticmethod`` — it only worked when
        invoked via the class, and would misbind ``self`` on an instance.
        """
        unique_messages = []
        seen = set()
        for message in messages:
            message_key = (message.content, type(message).__name__)
            if message_key not in seen:
                seen.add(message_key)
                unique_messages.append(message)

        trimmed_messages = []
        total_tokens = 0
        # Walk newest-first; insert at the front to keep chronological order.
        for message in reversed(unique_messages):
            message_tokens = len(UnifiedSummary.tokenizer.encode(message.content))
            if total_tokens + message_tokens <= max_tokens:
                trimmed_messages.insert(0, message)
                total_tokens += message_tokens
            else:
                break
        print("total_tokens from trimmed messages", total_tokens)
        return trimmed_messages

    def agent_node(self, state, name, config):
        """Graph node: run one randomly-chosen agent chain and return the
        state update.

        ``build_graph`` registers this via ``functools.partial`` with
        ``name`` pre-bound as a keyword — NOTE(review): this relies on
        LangGraph passing ``config`` by keyword; confirm against the
        installed langgraph version.
        """
        chat_bot_agent = self.create_agent(self.llm, state, config)
        # Simple load-spreading across the provided models.
        agent_chosen = random.choice(chat_bot_agent)

        result = agent_chosen.invoke(
            {
                "entity_ids": state["entity_ids"],
                # TODO: define the token_limit as some weightage * model_limit
                "messages": trim_tokens(state['messages'], 64000)
            })
        # Convert the agent output into a format suitable for the global state.
        if isinstance(result, ToolMessage):
            pass
        else:
            # Approximate accounting: count completion tokens only; prompt
            # tokens are not tracked here.
            tokens_used = UnifiedSummary.token_calculator(text=result.content)
            state["token_usage"]["completion_tokens"] = tokens_used
            state["token_usage"]["total_tokens"] = tokens_used
            state["token_usage"]["prompt_tokens"] = 0

            # Re-wrap as an AIMessage tagged with this node's name.
            # NOTE(review): ``.dict()`` is the pydantic-v1 API — verify the
            # installed langchain-core still supports it.
            result = AIMessage(**result.dict(exclude={"type", "name"}), name=name)
        return {
            "messages": [result],
            # Track the sender so downstream logic knows who to pass to next.
            "sender": name,
        }

    def build_graph(self):
        """Construct the chatbot <-> tools graph (uncompiled)."""
        self.graph_builder = StateGraph(State)

        # Pre-bind the node name; LangGraph supplies (state, config) at call time.
        chatbot_node = functools.partial(self.agent_node, name="chatbot")

        self.graph_builder.add_node("chatbot", chatbot_node)
        tool_node = create_tool_node_with_fallback(tools=self.tools)
        self.graph_builder.add_node("tools", tool_node)
        # Route to "tools" when the model emitted tool calls, else to END.
        self.graph_builder.add_conditional_edges(
            "chatbot",
            tools_condition,
        )
        self.graph_builder.add_edge("tools", "chatbot")
        self.graph_builder.add_edge(START, "chatbot")

    def get_graph(self):
        """Render the compiled graph to ``graph.png`` (side effect: writes
        a file in the current working directory)."""
        with open('graph.png', 'wb') as png:
            graph = self.graph_builder.compile()
            png.write(graph.get_graph().draw_mermaid_png())

    def get_graph_builder(self):
        """Return the uncompiled StateGraph builder."""
        return self.graph_builder

    def get_compiled_graph(self):
        """Compile and return the runnable graph."""
        return self.graph_builder.compile()

    @staticmethod
    def token_calculator(text):
        """Return the o200k_base token count of ``text``.

        Was previously missing ``@staticmethod`` (class-level calls worked,
        instance calls would misbind ``self``).
        """
        return len(UnifiedSummary.tokenizer.encode(text))

    @staticmethod
    def format_messages(messages):
        """Render messages as "Role: content" strings for prompt injection.

        AIMessage is labeled "System", HumanMessage "User", and ToolMessage
        includes the tool name. Unknown message types are silently skipped.
        Was previously missing ``@staticmethod``.
        """
        formatted_messages = []
        for message in messages:
            if isinstance(message, AIMessage):
                formatted_messages.append(f"System: {message.content}")
            elif isinstance(message, HumanMessage):
                formatted_messages.append(f"User: {message.content}")
            elif isinstance(message, ToolMessage):
                formatted_messages.append(f"Tool: {message.name}, Output: {message.content}")
        return formatted_messages
Leave a Comment