# Untitled
# unknown
# plain_text
# 2 years ago
# 4.4 kB
# 16
# Indexable
import asyncio
import operator
import os
from typing import List, Tuple, Union, Literal, Annotated

import openai
from langgraph.graph import StateGraph, state
from pydantic import BaseModel, Field
# Configure OpenAI API.
# Take the key from the environment rather than a hard-coded literal: a
# placeholder assigned here would overwrite any key already supplied via
# OPENAI_API_KEY, and real keys committed to source get leaked.
openai.api_key = os.environ.get("OPENAI_API_KEY")
class PlanExecute(BaseModel):
    """Shared state object passed between the workflow's graph nodes.

    Every key that a node returns as an update is declared here, and every
    field except ``input`` has a default, so a run can be started with
    ``{"input": ...}`` alone.
    """

    input: str
    plan: List[str] = Field(default_factory=list)
    # operator.add in the Annotated metadata is the reducer: updates to this
    # key are concatenated onto the existing list instead of replacing it.
    past_steps: Annotated[List[Tuple], operator.add] = Field(default_factory=list)
    response: str = ""
    # Keys written by the workflow nodes.
    prompts: List[str] = Field(default_factory=list)  # input_parsing_step
    context_needed: bool = False  # classifier_step
    classification: str = ""  # classifier_step / reclassification_step
    context: str = ""  # context_retrieval_step
    responses: List[str] = Field(default_factory=list)  # async_agent_step
    aggregated_response: str = ""  # aggregation_step
class HistoricalContext:
    """In-memory list of past interactions, used as context for prompts."""

    def __init__(self):
        self.history = []

    def add_interaction(self, interaction):
        """Append one interaction to the end of the history."""
        self.history.append(interaction)

    def get_context(self, limit=5):
        """Return the most recent interactions joined by single spaces.

        Args:
            limit: Maximum number of trailing interactions to include
                (default 5).

        Returns:
            A space-separated string; ``""`` when the history is empty or
            ``limit`` is not positive.
        """
        # Explicit guard: history[-0:] would return the WHOLE list, not none.
        if limit <= 0:
            return ""
        # str() so non-string interactions don't make join() raise TypeError.
        return " ".join(str(item) for item in self.history[-limit:])


# Module-level instance; context_retrieval_step reads from it.
# NOTE(review): nothing in this file calls add_interaction(), so
# get_context() currently always returns "" — confirm where interactions are
# meant to be recorded.
history_manager = HistoricalContext()
def classify_input(input_text):
    """Label *input_text* by whether it asks for conversation context.

    Placeholder heuristic: a case-insensitive substring match on the word
    "context".
    """
    needs_context = "context" in input_text.lower()
    return "needs_context" if needs_context else "no_context_needed"
def rewrite_prompt_with_context(prompt, context, *, engine="gpt-35-turbo",
                                max_tokens=100, temperature=0.7):
    """Ask the LLM to rewrite *prompt* so that it carries *context*.

    Args:
        prompt: The user prompt to rewrite.
        context: Free-text context to fold into the prompt.
        engine: Name passed as ``engine`` to ``openai.Completion.create``.
        max_tokens: Upper bound on the length of the rewritten prompt.
        temperature: Sampling temperature for the completion.

    Returns:
        The rewritten prompt, or *prompt* unchanged if the model returned
        no usable text.
    """
    # NOTE(review): openai.Completion is the pre-1.0 SDK interface to the
    # legacy completions endpoint; chat-tuned models may reject it — confirm
    # the model/deployment named by `engine` supports completions.
    combined_prompt = (
        f"Given the context: {context}\n\n"
        "Rewrite the following prompt to include the necessary context:\n\n"
        f"Prompt: {prompt}\n\n"
        "Rewritten Prompt:"
    )
    response = openai.Completion.create(
        engine=engine,
        prompt=combined_prompt,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    rewritten = response.choices[0].text.strip() if response.choices else ""
    # An empty rewrite would replace the user's input with "" downstream;
    # keep the original prompt instead.
    return rewritten or prompt
async def forward_to_agent_async(classification, prompt, delay=0.5):
    """Mock agent call: wait *delay* seconds, then echo the routing decision.

    Args:
        classification: Label of the agent the prompt is routed to.
        prompt: The prompt text being "processed".
        delay: Simulated processing latency in seconds (default 0.5).

    Returns:
        A string naming the agent and echoing the prompt.
    """
    await asyncio.sleep(delay)  # Mock delay for processing
    return f"Response from {classification} agent for prompt: {prompt}"
# Define state nodes.
# Nodes are plain async functions: StateGraph.add_node accepts any callable,
# and the `state` name imported from langgraph.graph is the
# langgraph.graph.state submodule, not a decorator (applying it raises
# TypeError: 'module' object is not callable).
async def input_parsing_step(state: "PlanExecute"):
    """Split the raw input into sentence-level prompts on ".".

    Returns:
        ``{"prompts": [...]}`` with surrounding whitespace stripped and
        empty fragments dropped.
    """
    fragments = state.input.split(".")
    return {"prompts": [p.strip() for p in fragments if p.strip()]}
async def classifier_step(state: "PlanExecute"):
    """Classify the full input and record whether context must be fetched.

    Always writes both ``classification`` and ``context_needed``, so the
    routing function and later nodes see values from this run on either
    branch.
    """
    classification = classify_input(state.input)
    return {
        "classification": classification,
        "context_needed": classification == "needs_context",
    }
async def context_retrieval_step(state: "PlanExecute"):
    """Fetch recent interaction history when the classifier asked for it.

    Returns:
        ``{"context": ...}`` when ``context_needed`` is set, otherwise an
        empty update. The whole state is deliberately not echoed back: the
        return value is applied as an update, so echoing it could re-apply
        reducers (``operator.add`` on ``past_steps``) and duplicate data.
    """
    if not getattr(state, "context_needed", False):
        return {}
    return {"context": history_manager.get_context()}
async def contextual_rewriting_step(state: "PlanExecute"):
    """Rewrite the input using the retrieved context, if any was retrieved.

    Returns:
        ``{"input": <rewritten>}`` when non-empty context is present,
        otherwise an empty update (no LLM call is made for empty context).
    """
    context = getattr(state, "context", None)
    if not context:
        return {}
    # rewrite_prompt_with_context performs a synchronous network call; run it
    # in a worker thread so it does not block the event loop.
    rewritten_input = await asyncio.to_thread(
        rewrite_prompt_with_context, state.input, context
    )
    # NOTE(review): async_agent_step forwards `prompts` (produced by
    # input_parsing_step), not `input`, so this rewrite does not reach the
    # agents — confirm whether prompts should be re-split here.
    return {"input": rewritten_input}
async def reclassification_step(state: "PlanExecute"):
    """Re-run classification on the (possibly rewritten) input.

    Returns:
        ``{"classification": ...}``; ``input`` is not echoed back because it
        is unchanged here.
    """
    return {"classification": classify_input(state.input)}
async def async_agent_step(state: "PlanExecute"):
    """Classify each prompt and dispatch all of them to agents concurrently.

    Returns:
        ``{"responses": [...]}`` in the same order as the prompts
        (``asyncio.gather`` returns results in argument order).
    """
    prompts = getattr(state, "prompts", None) or []
    tasks = [forward_to_agent_async(classify_input(p), p) for p in prompts]
    results = await asyncio.gather(*tasks)
    return {"responses": results}
async def aggregation_step(state: "PlanExecute"):
    """Join all agent responses into one space-separated string.

    Returns:
        ``{"aggregated_response": ...}``; ``""`` when there are no responses.
    """
    responses = getattr(state, "responses", None) or []
    return {"aggregated_response": " ".join(responses)}
def should_continue(state: "PlanExecute") -> Literal["context_retrieval", "async_agent"]:
    """Route after classification: fetch context first, or go straight to the agents.

    Returns:
        The name of the next node: ``"context_retrieval"`` when the state's
        ``context_needed`` flag is set, otherwise ``"async_agent"``.
    """
    # Attribute access: the state object is a model instance, which has no
    # dict-style .get().
    if getattr(state, "context_needed", False):
        return "context_retrieval"
    return "async_agent"
# Create the state graph
workflow = StateGraph(PlanExecute)
workflow.add_node("input_parsing", input_parsing_step)
workflow.add_node("classifier", classifier_step)
workflow.add_node("context_retrieval", context_retrieval_step)
workflow.add_node("contextual_rewriting", contextual_rewriting_step)
workflow.add_node("reclassification", reclassification_step)
workflow.add_node("async_agent", async_agent_step)
workflow.add_node("aggregation", aggregation_step)

workflow.set_entry_point("input_parsing")
workflow.add_edge("input_parsing", "classifier")
# add_conditional_edges takes a routing callable plus an optional mapping
# from its return values to node names; a dict of predicate lambdas is not a
# supported form.
workflow.add_conditional_edges(
    "classifier",
    should_continue,
    {"context_retrieval": "context_retrieval", "async_agent": "async_agent"},
)
# Unconditional: contextual_rewriting_step itself only rewrites when context
# is present, and a single-branch conditional here would leave no route when
# it is absent.
workflow.add_edge("context_retrieval", "contextual_rewriting")
workflow.add_edge("contextual_rewriting", "reclassification")
workflow.add_edge("reclassification", "async_agent")
workflow.add_edge("async_agent", "aggregation")
# Mark the terminal node explicitly so the run ends after aggregation.
workflow.set_finish_point("aggregation")

app = workflow.compile()
# Example execution
input_data = "Check the balance of my savings account. Forecast my expenses for next month."
# A compiled graph is run with invoke()/ainvoke() (there is no run()); the
# nodes are coroutines, so use the async entry point.
result = asyncio.run(app.ainvoke({"input": input_data}))
print(result)
# Leave a Comment