refactor: refine teh background check logic (#306)

This commit is contained in:
DanielWalnut
2025-06-11 11:10:02 +08:00
committed by GitHub
parent eeff1ebf80
commit 447e427fd3
4 changed files with 19 additions and 10 deletions

View File

@@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)
@tool
def handoff_to_planner(
task_title: Annotated[str, "The title of the task to be handed off."],
research_topic: Annotated[str, "The topic of the research task to be handed off."],
locale: Annotated[str, "The user's detected language locale (e.g., en-US, zh-CN)."],
):
"""Handoff to planner agent to do plan."""
@@ -48,7 +48,7 @@ def handoff_to_planner(
def background_investigation_node(state: State, config: RunnableConfig):
logger.info("background investigation node is running.")
configurable = Configuration.from_runnable_config(config)
query = state["messages"][-1].content
query = state.get("research_topic")
background_investigation_results = None
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
searched_content = LoggedTavilySearch(
@@ -87,10 +87,8 @@ def planner_node(
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
messages = apply_prompt_template("planner", state, configurable)
if (
plan_iterations == 0
and state.get("enable_background_investigation")
and state.get("background_investigation_results")
if state.get("enable_background_investigation") and state.get(
"background_investigation_results"
):
messages += [
{
@@ -221,6 +219,7 @@ def coordinator_node(
goto = "__end__"
locale = state.get("locale", "en-US") # Default locale if not specified
research_topic = state.get("research_topic", "")
if len(response.tool_calls) > 0:
goto = "planner"
@@ -231,8 +230,11 @@ def coordinator_node(
for tool_call in response.tool_calls:
if tool_call.get("name", "") != "handoff_to_planner":
continue
if tool_locale := tool_call.get("args", {}).get("locale"):
locale = tool_locale
if tool_call.get("args", {}).get("locale") and tool_call.get(
"args", {}
).get("research_topic"):
locale = tool_call.get("args", {}).get("locale")
research_topic = tool_call.get("args", {}).get("research_topic")
break
except Exception as e:
logger.error(f"Error processing tool calls: {e}")
@@ -243,7 +245,11 @@ def coordinator_node(
logger.debug(f"Coordinator response: {response}")
return Command(
update={"locale": locale, "resources": configurable.resources},
update={
"locale": locale,
"research_topic": research_topic,
"resources": configurable.resources,
},
goto=goto,
)

View File

@@ -12,6 +12,7 @@ class State(MessagesState):
# Runtime Variables
locale: str = "en-US"
research_topic: str = ""
observations: list[str] = []
resources: list[Resource] = []
plan_iterations: int = 0

View File

@@ -87,7 +87,7 @@ async def chat_stream(request: ChatRequest):
async def _astream_workflow_generator(
messages: List[ChatMessage],
messages: List[dict],
thread_id: str,
resources: List[Resource],
max_plan_iterations: int,
@@ -107,6 +107,7 @@ async def _astream_workflow_generator(
"observations": [],
"auto_accepted_plan": auto_accepted_plan,
"enable_background_investigation": enable_background_investigation,
"research_topic": messages[-1]["content"] if messages else "",
}
if not auto_accepted_plan and interrupt_feedback:
resume_msg = f"[{interrupt_feedback}]"

View File

@@ -20,6 +20,7 @@ MOCK_SEARCH_RESULTS = [
def mock_state():
return {
"messages": [HumanMessage(content="test query")],
"research_topic": "test query",
"background_investigation_results": None,
}