feat: support human in the loop

Author: He Tao
Date: 2025-04-14 18:01:50 +08:00
Parent: a759c168fa
Commit: a7ae47fc7a
12 changed files with 329 additions and 21 deletions

src/graph/builder.py

@@ -8,6 +8,7 @@ from .nodes import (
     research_team_node,
     researcher_node,
     coder_node,
+    human_feedback_node,
 )
@@ -26,5 +27,6 @@ def build_graph():
     builder.add_node("research_team", research_team_node)
     builder.add_node("researcher", researcher_node)
     builder.add_node("coder", coder_node)
+    builder.add_node("human_feedback", human_feedback_node)
     builder.add_edge("reporter", END)
     return builder.compile(checkpointer=memory)
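
The human-in-the-loop wiring works because the graph is compiled with a checkpointer: interrupt() can pause a run, and a later call on the same thread_id resumes it. A minimal driver sketch (the thread id and user message are illustrative; the "__interrupt__" stream key and Command(resume=...) are standard LangGraph mechanics, not introduced by this commit):

from langgraph.types import Command

graph = build_graph()
config = {"configurable": {"thread_id": "demo-1"}}

# Phase 1: run until human_feedback calls interrupt(); the draft plan
# surfaces in the stream under the "__interrupt__" key.
for chunk in graph.stream(
    {
        "messages": [{"role": "user", "content": "Research topic X"}],
        "auto_accepted_plan": False,
    },
    config,
):
    if "__interrupt__" in chunk:
        print(chunk["__interrupt__"][0].value)  # the plan awaiting review

# Phase 2: resume the same thread with the reviewer's verdict.
graph.invoke(Command(resume="[ACCEPTED]"), config)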

src/graph/nodes.py

@@ -2,10 +2,10 @@ import logging
 import json
 from typing import Literal, Annotated
-from langchain_core.messages import HumanMessage
+from langchain_core.messages import HumanMessage, AIMessage
 from langchain_core.tools import tool
 from langchain_core.runnables import RunnableConfig
-from langgraph.types import Command
+from langgraph.types import Command, interrupt
 from src.llms.llm import get_llm_by_type
 from src.config.agents import AGENT_LLM_MAP
@@ -31,7 +31,7 @@ def handoff_to_planner(
 def planner_node(
     state: State, config: RunnableConfig
-) -> Command[Literal["research_team", "reporter", "__end__"]]:
+) -> Command[Literal["human_feedback", "reporter"]]:
     """Planner node that generates the full plan."""
     logger.info("Planner generating full plan")
     configurable = Configuration.from_runnable_config(config)
@@ -42,7 +42,6 @@ def planner_node(
         )
     else:
         llm = get_llm_by_type(AGENT_LLM_MAP["planner"])
-    current_plan = state.get("current_plan", None)
     plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
     # if the plan iterations is greater than the max plan iterations, return the reporter node
@@ -60,13 +59,48 @@ def planner_node(
     logger.debug(f"Current state messages: {state['messages']}")
     logger.debug(f"Planner response: {full_response}")
+    return Command(
+        update={
+            "messages": [AIMessage(content=full_response, name="planner")],
+            "current_plan": full_response,
+        },
+        goto="human_feedback",
+    )
+
+
+def human_feedback_node(
+    state,
+) -> Command[Literal["planner", "research_team", "reporter", "__end__"]]:
+    current_plan = state.get("current_plan", "")
+    # check if the plan is auto-accepted
+    auto_accepted_plan = state.get("auto_accepted_plan", False)
+    if not auto_accepted_plan:
+        feedback = interrupt(current_plan)
+
+        # if the feedback is not accepted, return to the planner node
+        if feedback and str(feedback).upper() != "[ACCEPTED]":
+            return Command(
+                update={
+                    "messages": [
+                        HumanMessage(content=feedback, name="feedback"),
+                    ],
+                },
+                goto="planner",
+            )
+        elif feedback and str(feedback).upper() == "[ACCEPTED]":
+            logger.info("Plan is accepted by user.")
+        else:
+            raise TypeError(f"Interrupt value of {feedback} is not supported.")
+
+    # if the plan is accepted, run the following node
+    plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
     goto = "research_team"
     try:
-        full_response = repair_json_output(full_response)
+        current_plan = repair_json_output(current_plan)
         # increment the plan iterations
         plan_iterations += 1
         # parse the plan
-        new_plan = json.loads(full_response)
+        new_plan = json.loads(current_plan)
         if new_plan["has_enough_context"]:
             goto = "reporter"
     except json.JSONDecodeError:
@@ -78,8 +112,6 @@ def planner_node(
     return Command(
         update={
-            "messages": [HumanMessage(content=full_response, name="planner")],
-            "last_plan": current_plan,
             "current_plan": Plan.model_validate(new_plan),
             "plan_iterations": plan_iterations,
         },
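
human_feedback_node enforces a small protocol around the resume value: anything other than the literal "[ACCEPTED]" is appended as a HumanMessage and routed back to the planner for another planning round (and another interrupt), "[ACCEPTED]" lets the stored plan string be repaired, parsed, and validated, and a falsy resume value raises TypeError. A hedged walk through both paths, reusing graph and Command from the sketch above (prompts are illustrative):

cfg = {"configurable": {"thread_id": "review-1"}}
graph.invoke(
    {
        "messages": [{"role": "user", "content": "Plan a market study"}],
        "auto_accepted_plan": False,
    },
    cfg,
)

# Revision: routed back to the planner, which re-plans and interrupts again.
graph.invoke(Command(resume="Add a competitor-analysis step."), cfg)

# Acceptance: current_plan is parsed; has_enough_context decides whether the
# run goes straight to reporter or through research_team first.
graph.invoke(Command(resume="[ACCEPTED]"), cfg)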

src/graph/types.py

@@ -11,6 +11,6 @@ class State(MessagesState):
     # Runtime Variables
     observations: Annotated[list[str], operator.add] = []
     plan_iterations: int = 0
-    last_plan: Plan = None
-    current_plan: Plan = None
+    current_plan: Plan | str = None
     final_report: str = ""
+    auto_accepted_plan: bool = False
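
Because current_plan is now typed Plan | str, it holds the planner's raw JSON while the plan is under review and only becomes a validated Plan once human_feedback_node accepts it. Code that may observe either phase can normalize defensively; a small sketch (the Plan import path is an assumption about this repo's layout):

from src.prompts.planner_model import Plan  # assumed location of Plan

plan = state.get("current_plan")
if isinstance(plan, str):
    # still the raw JSON string emitted by the planner; validate before use
    plan = Plan.model_validate_json(plan)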

src/server/app.py

@@ -7,6 +7,7 @@ from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from langchain_core.messages import AIMessageChunk, ToolMessage
+from langgraph.types import Command
 from src.graph.builder import build_graph
 from src.server.chat_request import ChatMessage, ChatRequest
@@ -42,6 +43,8 @@ async def chat_stream(request: ChatRequest):
             thread_id,
             request.max_plan_iterations,
             request.max_step_num,
+            request.auto_accepted_plan,
+            request.feedback,
         ),
         media_type="text/event-stream",
     )
@@ -52,9 +55,14 @@ async def _astream_workflow_generator(
     thread_id: str,
     max_plan_iterations: int,
     max_step_num: int,
+    auto_accepted_plan: bool,
+    feedback: str,
 ):
+    input_ = {"messages": messages, "auto_accepted_plan": auto_accepted_plan}
+    if not auto_accepted_plan and feedback:
+        input_ = Command(resume=feedback)
     async for agent, _, event_data in graph.astream(
-        {"messages": messages},
+        input_,
         config={
             "thread_id": thread_id,
             "max_plan_iterations": max_plan_iterations,

src/server/chat_request.py

@@ -35,3 +35,9 @@ class ChatRequest(BaseModel):
     max_step_num: Optional[int] = Field(
         3, description="The maximum number of steps in a plan"
     )
+    auto_accepted_plan: Optional[bool] = Field(
+        False, description="Whether to automatically accept the plan"
+    )
+    feedback: Optional[str] = Field(
+        None, description="Feedback from the user on the plan"
+    )

src/workflow.py

@@ -46,6 +46,7 @@ def run_agent_workflow(
     initial_state = {
         # Runtime Variables
         "messages": [{"role": "user", "content": user_input}],
+        "auto_accepted_plan": True,
     }
     config = {
         "configurable": {