mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
feat: support human in the loop
This commit is contained in:
@@ -2,10 +2,10 @@ import logging
|
||||
import json
|
||||
from typing import Literal, Annotated
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import Command
|
||||
from langgraph.types import Command, interrupt
|
||||
|
||||
from src.llms.llm import get_llm_by_type
|
||||
from src.config.agents import AGENT_LLM_MAP
|
||||
@@ -31,7 +31,7 @@ def handoff_to_planner(
|
||||
|
||||
def planner_node(
|
||||
state: State, config: RunnableConfig
|
||||
) -> Command[Literal["research_team", "reporter", "__end__"]]:
|
||||
) -> Command[Literal["human_feedback", "reporter"]]:
|
||||
"""Planner node that generate the full plan."""
|
||||
logger.info("Planner generating full plan")
|
||||
configurable = Configuration.from_runnable_config(config)
|
||||
@@ -42,7 +42,6 @@ def planner_node(
|
||||
)
|
||||
else:
|
||||
llm = get_llm_by_type(AGENT_LLM_MAP["planner"])
|
||||
current_plan = state.get("current_plan", None)
|
||||
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
|
||||
|
||||
# if the plan iterations is greater than the max plan iterations, return the reporter node
|
||||
@@ -60,13 +59,48 @@ def planner_node(
|
||||
logger.debug(f"Current state messages: {state['messages']}")
|
||||
logger.debug(f"Planner response: {full_response}")
|
||||
|
||||
return Command(
|
||||
update={
|
||||
"messages": [AIMessage(content=full_response, name="planner")],
|
||||
"current_plan": full_response,
|
||||
},
|
||||
goto="human_feedback",
|
||||
)
|
||||
|
||||
|
||||
def human_feedback_node(
|
||||
state,
|
||||
) -> Command[Literal["planner", "research_team", "reporter", "__end__"]]:
|
||||
current_plan = state.get("current_plan", "")
|
||||
# check if the plan is auto accepted
|
||||
auto_accepted_plan = state.get("auto_accepted_plan", False)
|
||||
if not auto_accepted_plan:
|
||||
feedback = interrupt(current_plan)
|
||||
|
||||
# if the feedback is not accepted, return the planner node
|
||||
if feedback and str(feedback).upper() != "[ACCEPTED]":
|
||||
return Command(
|
||||
update={
|
||||
"messages": [
|
||||
HumanMessage(content=feedback, name="feedback"),
|
||||
],
|
||||
},
|
||||
goto="planner",
|
||||
)
|
||||
elif feedback and str(feedback).upper() == "[ACCEPTED]":
|
||||
logger.info("Plan is accepted by user.")
|
||||
else:
|
||||
raise TypeError(f"Interrupt value of {feedback} is not supported.")
|
||||
|
||||
# if the plan is accepted, run the following node
|
||||
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
|
||||
goto = "research_team"
|
||||
try:
|
||||
full_response = repair_json_output(full_response)
|
||||
current_plan = repair_json_output(current_plan)
|
||||
# increment the plan iterations
|
||||
plan_iterations += 1
|
||||
# parse the plan
|
||||
new_plan = json.loads(full_response)
|
||||
new_plan = json.loads(current_plan)
|
||||
if new_plan["has_enough_context"]:
|
||||
goto = "reporter"
|
||||
except json.JSONDecodeError:
|
||||
@@ -78,8 +112,6 @@ def planner_node(
|
||||
|
||||
return Command(
|
||||
update={
|
||||
"messages": [HumanMessage(content=full_response, name="planner")],
|
||||
"last_plan": current_plan,
|
||||
"current_plan": Plan.model_validate(new_plan),
|
||||
"plan_iterations": plan_iterations,
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user