Untitled
unknown
plain_text
a year ago
4.4 kB
5
Indexable
import asyncio
import operator
import os
from typing import Annotated, List, Literal, Tuple, Union

import openai
from langgraph.graph import END, StateGraph
from pydantic import BaseModel, Field

# Read the key from the environment rather than hard-coding a secret in source.
openai.api_key = os.environ.get("OPENAI_API_KEY")


class PlanExecute(BaseModel):
    """Shared state passed through every node of the workflow graph.

    Only ``input`` is required; every other field has a default so the graph
    can be started with ``{"input": ...}`` alone. Nodes return partial dicts
    whose keys must be field names declared here.
    """

    input: str
    # Not read or written by the current graph; kept for compatibility.
    plan: List[str] = Field(default_factory=list)
    # Annotated with operator.add so LangGraph concatenates updates instead of replacing.
    past_steps: Annotated[List[Tuple], operator.add] = Field(default_factory=list)
    # Not read or written by the current graph; kept for compatibility.
    response: str = ""
    # Non-empty, stripped sentence fragments of ``input`` (split on ".").
    prompts: List[str] = Field(default_factory=list)
    # Last label produced by classify_input for ``input``.
    classification: str = ""
    # True when the classifier decided ``input`` needs historical context.
    context_needed: bool = False
    # Text returned by HistoricalContext.get_context.
    context: str = ""
    # One agent reply per entry in ``prompts``, in the same order.
    responses: List[str] = Field(default_factory=list)
    # ``responses`` joined with single spaces.
    aggregated_response: str = ""


class HistoricalContext:
    """In-memory list of past interactions, rendered as a context string."""

    def __init__(self):
        self.history = []

    def add_interaction(self, interaction):
        """Append one interaction to the history."""
        self.history.append(interaction)

    def get_context(self, limit: int = 5) -> str:
        """Return the last ``limit`` interactions joined by spaces.

        Args:
            limit: How many of the most recent interactions to include
                (default 5). A non-positive value yields an empty string.
        """
        if limit <= 0:
            # history[-0:] would return the WHOLE list, not an empty one.
            return ""
        return " ".join(str(item) for item in self.history[-limit:])


# Module-level, so it is shared by every run of ``app`` in this process.
history_manager = HistoricalContext()


def classify_input(input_text: str) -> str:
    """Classify text as ``"needs_context"`` or ``"no_context_needed"``.

    Placeholder heuristic: a case-insensitive check for the word "context".
    """
    if "context" in input_text.lower():
        return "needs_context"
    return "no_context_needed"


def rewrite_prompt_with_context(
    prompt,
    context,
    *,
    model: str = "gpt-35-turbo",
    max_tokens: int = 100,
    temperature: float = 0.7,
) -> str:
    """Ask the LLM to rewrite ``prompt`` so it incorporates ``context``.

    This is a synchronous (blocking) network call; async callers should run
    it in a worker thread.

    Args:
        prompt: The user prompt to rewrite.
        context: Historical context to fold into the prompt.
        model: Chat model / deployment name.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.

    Returns:
        The stripped rewritten prompt, or ``""`` if the model returned no text.
    """
    combined_prompt = f"Given the context: {context}\n\nRewrite the following prompt to include the necessary context:\n\nPrompt: {prompt}\n\nRewritten Prompt:"
    # NOTE(review): "gpt-35-turbo" is Azure deployment-style naming; the public
    # OpenAI model id is "gpt-3.5-turbo" — confirm against the target endpoint.
    # It is a chat model, so it must be called via the chat completions API
    # (the legacy openai.Completion API was removed in openai>=1.0).
    response = openai.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": combined_prompt}],
        max_tokens=max_tokens,
        temperature=temperature,
    )
    content = response.choices[0].message.content
    return (content or "").strip()


async def forward_to_agent_async(classification, prompt):
    """Mock agent call: wait 0.5 s, then return a canned reply string."""
    await asyncio.sleep(0.5)  # Mock delay for processing
    return f"Response from {classification} agent for prompt: {prompt}"


def _split_prompts(text: str) -> List[str]:
    """Split text on "." into stripped, non-empty fragments."""
    return [part.strip() for part in text.split(".") if part.strip()]


# Define state nodes. Each node receives the current PlanExecute state and
# returns a partial update dict keyed by PlanExecute field names.

async def input_parsing_step(state: PlanExecute):
    """Split the raw input into per-sentence prompts."""
    return {"prompts": _split_prompts(state.input)}


async def classifier_step(state: PlanExecute):
    """Classify the input and record whether context retrieval is needed."""
    classification = classify_input(state.input)
    return {
        "classification": classification,
        "context_needed": classification == "needs_context",
    }


async def context_retrieval_step(state: PlanExecute):
    """Fetch recent history; reached only when should_continue routes here."""
    return {"context": history_manager.get_context()}


async def contextual_rewriting_step(state: PlanExecute):
    """Rewrite the input with the retrieved context and re-derive the prompts.

    Skips the LLM call when there is no context. Falls back to the original
    input if the model returns empty text.
    """
    if not state.context:
        return {}
    # The OpenAI call blocks; run it off the event loop.
    rewritten = await asyncio.to_thread(
        rewrite_prompt_with_context, state.input, state.context
    )
    rewritten = rewritten or state.input
    # Re-split so the agents actually see the rewritten text.
    return {"input": rewritten, "prompts": _split_prompts(rewritten)}


async def reclassification_step(state: PlanExecute):
    """Classify the (possibly rewritten) input again."""
    return {"classification": classify_input(state.input)}


async def async_agent_step(state: PlanExecute):
    """Send every prompt to its classified agent concurrently."""
    tasks = [
        forward_to_agent_async(classify_input(prompt), prompt)
        for prompt in state.prompts
    ]
    results = await asyncio.gather(*tasks)
    return {"responses": list(results)}


async def aggregation_step(state: PlanExecute):
    """Join the agent replies and record this run's input in the history."""
    aggregated_response = " ".join(state.responses)
    # Without this, history_manager stays empty and context retrieval is a no-op.
    history_manager.add_interaction(state.input)
    return {"aggregated_response": aggregated_response}


def should_continue(state: PlanExecute) -> Literal["context_retrieval", "async_agent"]:
    """Route after classification: fetch context first, or go straight to agents."""
    if state.context_needed:
        return "context_retrieval"
    return "async_agent"


# Create the state graph
workflow = StateGraph(PlanExecute)
workflow.add_node("input_parsing", input_parsing_step)
workflow.add_node("classifier", classifier_step)
workflow.add_node("context_retrieval", context_retrieval_step)
workflow.add_node("contextual_rewriting", contextual_rewriting_step)
workflow.add_node("reclassification", reclassification_step)
workflow.add_node("async_agent", async_agent_step)
workflow.add_node("aggregation", aggregation_step)

workflow.set_entry_point("input_parsing")
workflow.add_edge("input_parsing", "classifier")
# A conditional edge takes a routing function plus a mapping from its return
# values to node names.
workflow.add_conditional_edges(
    "classifier",
    should_continue,
    {"context_retrieval": "context_retrieval", "async_agent": "async_agent"},
)
workflow.add_edge("context_retrieval", "contextual_rewriting")
workflow.add_edge("contextual_rewriting", "reclassification")
workflow.add_edge("reclassification", "async_agent")
workflow.add_edge("async_agent", "aggregation")
# Without a terminal edge the graph has no finish point.
workflow.add_edge("aggregation", END)

app = workflow.compile()


if __name__ == "__main__":
    # Example execution; guarded so importing this module does no network work.
    input_data = "Check the balance of my savings account. Forecast my expenses for next month."
    result = asyncio.run(app.ainvoke({"input": input_data}))
    print(result)
Editor is loading...
Leave a Comment