# Untitled
#
# avatar
# unknown
# plain_text
# a year ago
# 4.4 kB
# 5
# Indexable
import asyncio
import operator
import os
from typing import List, Tuple, Union, Literal, Annotated

import openai
# NOTE(review): confirm that the installed langgraph.graph actually exports `state`.
from langgraph.graph import StateGraph, state
from pydantic import BaseModel, Field

# Configure OpenAI API
# Prefer the OPENAI_API_KEY environment variable so a real key never has to be
# written into the source; the original placeholder remains the fallback.
openai.api_key = os.environ.get("OPENAI_API_KEY", "YOUR_OPENAI_API_KEY")

class PlanExecute(BaseModel):
    """Graph state shared by every node of the workflow.

    Only ``input`` is required. Every other field has a default so a run can
    be started from ``{"input": ...}`` alone, and every key that the node
    functions in this file return is declared here so their updates have a
    field to land in.
    """

    # Raw user text; may be replaced by contextual_rewriting_step.
    input: str
    plan: List[str] = Field(default_factory=list)
    # Reducer field: updates are concatenated onto the existing list.
    past_steps: Annotated[List[Tuple], operator.add] = Field(default_factory=list)
    response: str = ""

    # Written by input_parsing_step.
    prompts: List[str] = Field(default_factory=list)
    # Written by classifier_step when the input needs context.
    context_needed: bool = False
    # Written by classifier_step / reclassification_step.
    classification: Union[str, None] = None
    # Written by context_retrieval_step; None means no context was retrieved.
    context: Union[str, None] = None
    # Written by async_agent_step.
    responses: List[str] = Field(default_factory=list)
    # Written by aggregation_step.
    aggregated_response: str = ""

class HistoricalContext:
    """In-memory log of past interactions used to build a context string."""

    def __init__(self):
        # Interactions in insertion order, oldest first.
        self.history = []

    def add_interaction(self, interaction):
        """Append *interaction* to the end of the history."""
        self.history.append(interaction)

    def get_context(self, limit=5):
        """Return the most recent interactions joined by single spaces.

        Args:
            limit: Maximum number of trailing interactions to include.
                Defaults to 5; a value of zero or less yields an empty string.

        Returns:
            The last ``limit`` interactions, each converted with ``str``,
            joined by a space. Empty string when there is no history.
        """
        # Guard explicitly: history[-0:] would return the WHOLE list.
        if limit <= 0:
            return ""
        # str() lets non-string interactions be stored without breaking join.
        return " ".join(str(entry) for entry in self.history[-limit:])

# Module-level history store; context_retrieval_step reads its context from it.
# NOTE(review): nothing in the visible code calls add_interaction, so the
# retrieved context stays empty unless history is populated elsewhere.
history_manager = HistoricalContext()

def classify_input(input_text):
    """Label *input_text* by whether it mentions the word "context".

    This is placeholder logic: a case-insensitive substring check.

    Returns:
        ``"needs_context"`` when "context" occurs anywhere in the text,
        otherwise ``"no_context_needed"``.
    """
    mentions_context = "context" in input_text.lower()
    return "needs_context" if mentions_context else "no_context_needed"

def rewrite_prompt_with_context(prompt, context):
    """Ask the OpenAI completion endpoint to fold *context* into *prompt*.

    Args:
        prompt: The prompt text to rewrite.
        context: Background text placed ahead of the prompt in the instruction.

    Returns:
        The text of the first completion choice, stripped of surrounding
        whitespace.
    """
    # NOTE(review): this uses the legacy ``openai.Completion`` interface —
    # confirm it matches the installed openai package version.
    instruction = (
        f"Given the context: {context}\n\n"
        "Rewrite the following prompt to include the necessary context:\n\n"
        f"Prompt: {prompt}\n\n"
        "Rewritten Prompt:"
    )
    request_options = {
        "engine": "gpt-35-turbo",
        "prompt": instruction,
        "max_tokens": 100,
        "temperature": 0.7,
    }
    completion = openai.Completion.create(**request_options)
    first_choice = completion.choices[0]
    return first_choice.text.strip()

async def forward_to_agent_async(classification, prompt):
    """Simulate handing *prompt* to the agent chosen by *classification*.

    Waits half a second as a stand-in for processing time, then returns a
    canned response string that names both arguments.
    """
    mock_latency_seconds = 0.5
    await asyncio.sleep(mock_latency_seconds)
    reply = f"Response from {classification} agent for prompt: {prompt}"
    return reply

# Define state nodes
@state
async def input_parsing_step(state: PlanExecute):
    """Split the raw input into individual prompts on "." boundaries.

    Returns:
        A state update ``{"prompts": [...]}`` in which every fragment is
        stripped of surrounding whitespace and empty fragments are dropped.
    """
    stripped_fragments = (piece.strip() for piece in state.input.split("."))
    non_empty = [fragment for fragment in stripped_fragments if fragment]
    return {"prompts": non_empty}

@state
async def classifier_step(state: PlanExecute):
    """Classify the current input and record the outcome in the state.

    Returns:
        ``{"context_needed": True, "input": ...}`` when the input is
        classified as needing context; otherwise
        ``{"classification": <label>, "input": ...}``.
    """
    text = state.input
    label = classify_input(text)
    if label != "needs_context":
        return {"classification": label, "input": text}
    return {"context_needed": True, "input": text}

@state
async def context_retrieval_step(state: PlanExecute):
    """Attach recent conversation history when the classifier asked for it.

    Returns:
        ``{"context": <history text>, "input": ...}`` when ``context_needed``
        is set on the state; otherwise an update that re-writes ``input``
        with its current value and changes nothing else.
    """
    # The state is a PlanExecute model, so fields are read as attributes (as
    # the parsing and classifier nodes do); getattr tolerates the flag being
    # absent from the model.
    if not getattr(state, "context_needed", False):
        # Return a minimal update rather than the whole state: echoing the
        # full state back would push past_steps through its operator.add
        # reducer a second time.
        return {"input": state.input}
    return {"context": history_manager.get_context(), "input": state.input}

@state
async def contextual_rewriting_step(state: PlanExecute):
    """Rewrite the input so that it carries the retrieved context.

    Returns:
        ``{"input": <rewritten prompt>}`` when a context has been retrieved;
        otherwise an update that re-writes ``input`` with its current value.
    """
    # Read the field as an attribute: ``"context" in state`` on a pydantic
    # model iterates (name, value) pairs and would never match the bare name.
    context = getattr(state, "context", None)
    if context is None:
        # Minimal no-op update; returning the whole state would re-apply the
        # operator.add reducer on past_steps.
        return {"input": state.input}
    # NOTE(review): only ``input`` is rewritten here; ``prompts`` produced by
    # input_parsing_step is left as-is — confirm that is intended.
    rewritten_input = rewrite_prompt_with_context(state.input, context)
    return {"input": rewritten_input}

@state
async def reclassification_step(state: PlanExecute):
    """Classify the (possibly rewritten) input again.

    Returns:
        ``{"classification": <label>, "input": ...}``.
    """
    # Attribute access, consistent with the other nodes: the state is a
    # PlanExecute model and does not support ``state["input"]``.
    current_input = state.input
    return {"classification": classify_input(current_input), "input": current_input}

@state
async def async_agent_step(state: PlanExecute):
    """Send every parsed prompt to its agent concurrently.

    Each prompt is classified on its own and forwarded with
    ``forward_to_agent_async``; all calls are awaited together.

    Returns:
        ``{"responses": [...]}`` with one response per prompt, in prompt order.
    """
    # Attribute access: the state is a PlanExecute model, not a dict.
    tasks = [
        forward_to_agent_async(classify_input(prompt), prompt)
        for prompt in state.prompts
    ]
    results = await asyncio.gather(*tasks)
    return {"responses": results}

@state
async def aggregation_step(state: PlanExecute):
    """Join all agent responses into one space-separated string.

    Returns:
        ``{"aggregated_response": <joined text>}``.
    """
    # Attribute access: the state is a PlanExecute model, not a dict.
    aggregated_response = " ".join(state.responses)
    return {"aggregated_response": aggregated_response}

def should_continue(state: PlanExecute) -> Literal["context_retrieval", "async_agent"]:
    """Pick the node that follows classification.

    Returns:
        ``"context_retrieval"`` when the state's ``context_needed`` flag is
        truthy, otherwise ``"async_agent"``.
    """
    # The state is a PlanExecute model, which has no dict-style ``.get``;
    # getattr with a default also covers the flag being absent.
    if getattr(state, "context_needed", False):
        return "context_retrieval"
    return "async_agent"

# Create the state graph
workflow = StateGraph(PlanExecute)

workflow.add_node("input_parsing", input_parsing_step)
workflow.add_node("classifier", classifier_step)
workflow.add_node("context_retrieval", context_retrieval_step)
workflow.add_node("contextual_rewriting", contextual_rewriting_step)
workflow.add_node("reclassification", reclassification_step)
workflow.add_node("async_agent", async_agent_step)
workflow.add_node("aggregation", aggregation_step)

workflow.set_entry_point("input_parsing")

workflow.add_edge("input_parsing", "classifier")

# add_conditional_edges takes a routing callable that returns the key of the
# next node; should_continue is that router, and the mapping spells out the
# node each returned key leads to.
workflow.add_conditional_edges(
    "classifier",
    should_continue,
    {
        "context_retrieval": "context_retrieval",
        "async_agent": "async_agent",
    },
)

# context_retrieval is only reached when context is needed, so the hand-off to
# rewriting is unconditional (the rewriting node itself skips when no context
# is present).
workflow.add_edge("context_retrieval", "contextual_rewriting")
workflow.add_edge("contextual_rewriting", "reclassification")
workflow.add_edge("reclassification", "async_agent")
workflow.add_edge("async_agent", "aggregation")

# Mark where the run ends so aggregation is not left as a dead end.
workflow.set_finish_point("aggregation")

app = workflow.compile()

# Example execution
input_data = "Check the balance of my savings account. Forecast my expenses for next month."
# The nodes are coroutines, so the compiled graph is driven through its async
# entry point.
result = asyncio.run(app.ainvoke({"input": input_data}))
print(result)
# Editor is loading...
# Leave a Comment