mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-12 10:04:45 +08:00
fix: Refine clarification workflow state handling (#641)
* fix: support local models by making thought field optional in Plan model - Make thought field optional in Plan model to fix Pydantic validation errors with local models - Add Ollama configuration example to conf.yaml.example - Update documentation to include local model support - Improve planner prompt with better JSON format requirements Fixes local model integration issues where models like qwen3:14b would fail due to missing thought field in JSON output. * feat: Add intelligent clarification feature for research queries - Add multi-turn clarification process to refine vague research questions - Implement three-dimension clarification standard (Tech/App, Focus, Scope) - Add clarification state management in coordinator node - Update coordinator prompt with detailed clarification guidelines - Add UI settings to enable/disable clarification feature (disabled by default) - Update workflow to handle clarification rounds recursively - Add comprehensive test coverage for clarification functionality - Update documentation with clarification feature usage guide Key components: - src/graph/nodes.py: Core clarification logic and state management - src/prompts/coordinator.md: Detailed clarification guidelines - src/workflow.py: Recursive clarification handling - web/: UI settings integration - tests/: Comprehensive test coverage - docs/: Updated configuration guide * fix: Improve clarification conversation continuity - Add comprehensive conversation history to clarification context - Include previous exchanges summary in system messages - Add explicit guidelines for continuing rounds in coordinator prompt - Prevent LLM from starting new topics during clarification - Ensure topic continuity across clarification rounds Fixes issue where LLM would restart clarification instead of building upon previous exchanges. 
* fix: Add conversation history to clarification context * fix: resolve clarification feature message to planer, prompt, test issues - Optimize coordinator.md prompt template for better clarification flow - Simplify final message sent to planner after clarification - Fix API key assertion issues in test_search.py * fix: Add configurable max_clarification_rounds and comprehensive tests - Add max_clarification_rounds parameter for external configuration - Add comprehensive test cases for clarification feature in test_app.py - Fixes issues found during interactive mode testing where: - Recursive call failed due to missing initial_state parameter - Clarification exited prematurely at max rounds - Incorrect logging of max rounds reached * Move clarification tests to test_nodes.py and add max_clarification_rounds to zh.json * fix: add max_clarification_rounds parameter passing from frontend to backend - Add max_clarification_rounds parameter in store.ts sendMessage function - Add max_clarification_rounds type definition in chat.ts - Ensure frontend settings page clarification rounds are correctly passed to backend * fix: refine clarification workflow state handling and coverage - Add clarification history reconstruction - Fix clarified topic accumulation - Add clarified_research_topic state field - Preserve clarification state in recursive calls - Add comprehensive test coverage * refactor: optimize coordinator logic and type annotations - Simplify handoff topic logic in coordinator_node - Update type annotations from Tuple to tuple - Improve code readability and maintainability --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -51,7 +51,9 @@ class Configuration:
|
||||
mcp_settings: dict = None # MCP settings, including dynamic loaded tools
|
||||
report_style: str = ReportStyle.ACADEMIC.value # Report style
|
||||
enable_deep_thinking: bool = False # Whether to enable deep thinking
|
||||
enforce_web_search: bool = False # Enforce at least one web search step in every plan
|
||||
enforce_web_search: bool = (
|
||||
False # Enforce at least one web search step in every plan
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(
|
||||
|
||||
@@ -31,6 +31,12 @@ from src.utils.json_utils import repair_json_output
|
||||
|
||||
from ..config import SELECTED_SEARCH_ENGINE, SearchEngine
|
||||
from .types import State
|
||||
from .utils import (
|
||||
build_clarified_topic_from_history,
|
||||
get_message_content,
|
||||
is_user_message,
|
||||
reconstruct_clarification_history,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -49,6 +55,9 @@ def handoff_to_planner(
|
||||
@tool
def handoff_after_clarification(
    locale: Annotated[str, "The user's detected language locale (e.g., en-US, zh-CN)."],
    research_topic: Annotated[
        str, "The clarified research topic based on all clarification rounds."
    ],
):
    """Handoff to planner after clarification rounds are complete. Pass all clarification history to planner for analysis."""
    # Intentionally a no-op: emitting this tool call is the signal itself — the
    # coordinator node reads the tool-call arguments from the LLM response
    # rather than any return value here.
    return
|
||||
@@ -78,23 +87,23 @@ def needs_clarification(state: dict) -> bool:
|
||||
def validate_and_fix_plan(plan: dict, enforce_web_search: bool = False) -> dict:
|
||||
"""
|
||||
Validate and fix a plan to ensure it meets requirements.
|
||||
|
||||
|
||||
Args:
|
||||
plan: The plan dict to validate
|
||||
enforce_web_search: If True, ensure at least one step has need_search=true
|
||||
|
||||
|
||||
Returns:
|
||||
The validated/fixed plan dict
|
||||
"""
|
||||
if not isinstance(plan, dict):
|
||||
return plan
|
||||
|
||||
|
||||
steps = plan.get("steps", [])
|
||||
|
||||
|
||||
if enforce_web_search:
|
||||
# Check if any step has need_search=true
|
||||
has_search_step = any(step.get("need_search", False) for step in steps)
|
||||
|
||||
|
||||
if not has_search_step and steps:
|
||||
# Ensure first research step has web search enabled
|
||||
for idx, step in enumerate(steps):
|
||||
@@ -107,7 +116,9 @@ def validate_and_fix_plan(plan: dict, enforce_web_search: bool = False) -> dict:
|
||||
# This ensures that at least one step will perform a web search as required.
|
||||
steps[0]["step_type"] = "research"
|
||||
steps[0]["need_search"] = True
|
||||
logger.info("Converted first step to research with web search enforcement")
|
||||
logger.info(
|
||||
"Converted first step to research with web search enforcement"
|
||||
)
|
||||
elif not has_search_step and not steps:
|
||||
# Add a default research step if no steps exist
|
||||
logger.warning("Plan has no steps. Adding default research step.")
|
||||
@@ -119,14 +130,14 @@ def validate_and_fix_plan(plan: dict, enforce_web_search: bool = False) -> dict:
|
||||
"step_type": "research",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
return plan
|
||||
|
||||
|
||||
def background_investigation_node(state: State, config: RunnableConfig):
|
||||
logger.info("background investigation node is running.")
|
||||
configurable = Configuration.from_runnable_config(config)
|
||||
query = state.get("research_topic")
|
||||
query = state.get("clarified_research_topic") or state.get("research_topic")
|
||||
background_investigation_results = None
|
||||
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
|
||||
searched_content = LoggedTavilySearch(
|
||||
@@ -230,11 +241,11 @@ def planner_node(
|
||||
return Command(goto="reporter")
|
||||
else:
|
||||
return Command(goto="__end__")
|
||||
|
||||
|
||||
# Validate and fix plan to ensure web search requirements are met
|
||||
if isinstance(curr_plan, dict):
|
||||
curr_plan = validate_and_fix_plan(curr_plan, configurable.enforce_web_search)
|
||||
|
||||
|
||||
if isinstance(curr_plan, dict) and curr_plan.get("has_enough_context"):
|
||||
logger.info("Planner response has enough context.")
|
||||
new_plan = Plan.model_validate(curr_plan)
|
||||
@@ -316,7 +327,8 @@ def coordinator_node(
|
||||
|
||||
# Check if clarification is enabled
|
||||
enable_clarification = state.get("enable_clarification", False)
|
||||
|
||||
initial_topic = state.get("research_topic", "")
|
||||
clarified_topic = initial_topic
|
||||
# ============================================================
|
||||
# BRANCH 1: Clarification DISABLED (Legacy Mode)
|
||||
# ============================================================
|
||||
@@ -338,7 +350,6 @@ def coordinator_node(
|
||||
.invoke(messages)
|
||||
)
|
||||
|
||||
# Process response - should directly handoff to planner
|
||||
goto = "__end__"
|
||||
locale = state.get("locale", "en-US")
|
||||
research_topic = state.get("research_topic", "")
|
||||
@@ -370,12 +381,28 @@ def coordinator_node(
|
||||
else:
|
||||
# Load clarification state
|
||||
clarification_rounds = state.get("clarification_rounds", 0)
|
||||
clarification_history = state.get("clarification_history", [])
|
||||
clarification_history = list(state.get("clarification_history", []) or [])
|
||||
clarification_history = [item for item in clarification_history if item]
|
||||
max_clarification_rounds = state.get("max_clarification_rounds", 3)
|
||||
|
||||
# Prepare the messages for the coordinator
|
||||
state_messages = list(state.get("messages", []))
|
||||
messages = apply_prompt_template("coordinator", state)
|
||||
|
||||
clarification_history = reconstruct_clarification_history(
|
||||
state_messages, clarification_history, initial_topic
|
||||
)
|
||||
clarified_topic, clarification_history = build_clarified_topic_from_history(
|
||||
clarification_history
|
||||
)
|
||||
logger.debug("Clarification history rebuilt: %s", clarification_history)
|
||||
|
||||
if clarification_history:
|
||||
initial_topic = clarification_history[0]
|
||||
latest_user_content = clarification_history[-1]
|
||||
else:
|
||||
latest_user_content = ""
|
||||
|
||||
# Add clarification status for first round
|
||||
if clarification_rounds == 0:
|
||||
messages.append(
|
||||
@@ -385,91 +412,21 @@ def coordinator_node(
|
||||
}
|
||||
)
|
||||
|
||||
# Add clarification context if continuing conversation (round > 0)
|
||||
elif clarification_rounds > 0:
|
||||
logger.info(
|
||||
f"Clarification enabled (rounds: {clarification_rounds}/{max_clarification_rounds}): Continuing conversation"
|
||||
)
|
||||
logger.info(
|
||||
"Clarification round %s/%s | topic: %s | latest user content: %s",
|
||||
clarification_rounds,
|
||||
max_clarification_rounds,
|
||||
clarified_topic or initial_topic,
|
||||
latest_user_content or "N/A",
|
||||
)
|
||||
|
||||
# Add user's response to clarification history (only user messages)
|
||||
last_message = None
|
||||
if state.get("messages"):
|
||||
last_message = state["messages"][-1]
|
||||
# Extract content from last message for logging
|
||||
if isinstance(last_message, dict):
|
||||
content = last_message.get("content", "No content")
|
||||
else:
|
||||
content = getattr(last_message, "content", "No content")
|
||||
logger.info(f"Last message content: {content}")
|
||||
# Handle dict format
|
||||
if isinstance(last_message, dict):
|
||||
if last_message.get("role") == "user":
|
||||
clarification_history.append(last_message["content"])
|
||||
logger.info(
|
||||
f"Added user response to clarification history: {last_message['content']}"
|
||||
)
|
||||
# Handle object format (like HumanMessage)
|
||||
elif hasattr(last_message, "role") and last_message.role == "user":
|
||||
clarification_history.append(last_message.content)
|
||||
logger.info(
|
||||
f"Added user response to clarification history: {last_message.content}"
|
||||
)
|
||||
# Handle object format with content attribute (like the one in logs)
|
||||
elif hasattr(last_message, "content"):
|
||||
clarification_history.append(last_message.content)
|
||||
logger.info(
|
||||
f"Added user response to clarification history: {last_message.content}"
|
||||
)
|
||||
current_response = latest_user_content or "No response"
|
||||
|
||||
# Build comprehensive clarification context with conversation history
|
||||
current_response = "No response"
|
||||
if last_message:
|
||||
# Handle dict format
|
||||
if isinstance(last_message, dict):
|
||||
if last_message.get("role") == "user":
|
||||
current_response = last_message.get("content", "No response")
|
||||
else:
|
||||
# If last message is not from user, try to get the latest user message
|
||||
messages = state.get("messages", [])
|
||||
for msg in reversed(messages):
|
||||
if isinstance(msg, dict) and msg.get("role") == "user":
|
||||
current_response = msg.get("content", "No response")
|
||||
break
|
||||
# Handle object format (like HumanMessage)
|
||||
elif hasattr(last_message, "role") and last_message.role == "user":
|
||||
current_response = last_message.content
|
||||
# Handle object format with content attribute (like the one in logs)
|
||||
elif hasattr(last_message, "content"):
|
||||
current_response = last_message.content
|
||||
else:
|
||||
# If last message is not from user, try to get the latest user message
|
||||
messages = state.get("messages", [])
|
||||
for msg in reversed(messages):
|
||||
if isinstance(msg, dict) and msg.get("role") == "user":
|
||||
current_response = msg.get("content", "No response")
|
||||
break
|
||||
elif hasattr(msg, "role") and msg.role == "user":
|
||||
current_response = msg.content
|
||||
break
|
||||
elif hasattr(msg, "content"):
|
||||
current_response = msg.content
|
||||
break
|
||||
|
||||
# Create conversation history summary
|
||||
conversation_summary = ""
|
||||
if clarification_history:
|
||||
conversation_summary = "Previous conversation:\n"
|
||||
for i, response in enumerate(clarification_history, 1):
|
||||
conversation_summary += f"- Round {i}: {response}\n"
|
||||
|
||||
clarification_context = f"""Continuing clarification (round {clarification_rounds}/{max_clarification_rounds}):
|
||||
clarification_context = f"""Continuing clarification (round {clarification_rounds}/{max_clarification_rounds}):
|
||||
User's latest response: {current_response}
|
||||
Ask for remaining missing dimensions. Do NOT repeat questions or start new topics."""
|
||||
|
||||
# Log the clarification context for debugging
|
||||
logger.info(f"Clarification context: {clarification_context}")
|
||||
|
||||
messages.append({"role": "system", "content": clarification_context})
|
||||
messages.append({"role": "system", "content": clarification_context})
|
||||
|
||||
# Bind both clarification tools
|
||||
tools = [handoff_to_planner, handoff_after_clarification]
|
||||
@@ -483,7 +440,13 @@ def coordinator_node(
|
||||
# Initialize response processing variables
|
||||
goto = "__end__"
|
||||
locale = state.get("locale", "en-US")
|
||||
research_topic = state.get("research_topic", "")
|
||||
research_topic = (
|
||||
clarification_history[0]
|
||||
if clarification_history
|
||||
else state.get("research_topic", "")
|
||||
)
|
||||
if not clarified_topic:
|
||||
clarified_topic = research_topic
|
||||
|
||||
# --- Process LLM response ---
|
||||
# No tool calls - LLM is asking a clarifying question
|
||||
@@ -497,20 +460,21 @@ def coordinator_node(
|
||||
)
|
||||
|
||||
# Append coordinator's question to messages
|
||||
state_messages = state.get("messages", [])
|
||||
updated_messages = list(state_messages)
|
||||
if response.content:
|
||||
state_messages.append(
|
||||
updated_messages.append(
|
||||
HumanMessage(content=response.content, name="coordinator")
|
||||
)
|
||||
|
||||
return Command(
|
||||
update={
|
||||
"messages": state_messages,
|
||||
"messages": updated_messages,
|
||||
"locale": locale,
|
||||
"research_topic": research_topic,
|
||||
"resources": configurable.resources,
|
||||
"clarification_rounds": clarification_rounds,
|
||||
"clarification_history": clarification_history,
|
||||
"clarified_research_topic": clarified_topic,
|
||||
"is_clarification_complete": False,
|
||||
"clarified_question": "",
|
||||
"goto": goto,
|
||||
@@ -521,7 +485,7 @@ def coordinator_node(
|
||||
else:
|
||||
# Max rounds reached - no more questions allowed
|
||||
logger.warning(
|
||||
f"Max clarification rounds ({max_clarification_rounds}) reached. Handing off to planner."
|
||||
f"Max clarification rounds ({max_clarification_rounds}) reached. Handing off to planner. Using prepared clarified topic: {clarified_topic}"
|
||||
)
|
||||
goto = "planner"
|
||||
if state.get("enable_background_investigation"):
|
||||
@@ -539,7 +503,7 @@ def coordinator_node(
|
||||
# ============================================================
|
||||
# Final: Build and return Command
|
||||
# ============================================================
|
||||
messages = state.get("messages", [])
|
||||
messages = list(state.get("messages", []) or [])
|
||||
if response.content:
|
||||
messages.append(HumanMessage(content=response.content, name="coordinator"))
|
||||
|
||||
@@ -554,10 +518,20 @@ def coordinator_node(
|
||||
logger.info("Handing off to planner")
|
||||
goto = "planner"
|
||||
|
||||
# Extract locale and research_topic if provided
|
||||
if tool_args.get("locale") and tool_args.get("research_topic"):
|
||||
locale = tool_args.get("locale")
|
||||
research_topic = tool_args.get("research_topic")
|
||||
# Extract locale if provided
|
||||
locale = tool_args.get("locale", locale)
|
||||
if not enable_clarification and tool_args.get("research_topic"):
|
||||
research_topic = tool_args["research_topic"]
|
||||
|
||||
if enable_clarification:
|
||||
logger.info(
|
||||
"Using prepared clarified topic: %s",
|
||||
clarified_topic or research_topic,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Using research topic for handoff: %s", research_topic
|
||||
)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
@@ -584,16 +558,24 @@ def coordinator_node(
|
||||
clarification_rounds = 0
|
||||
clarification_history = []
|
||||
|
||||
clarified_research_topic_value = clarified_topic or research_topic
|
||||
|
||||
if enable_clarification:
|
||||
handoff_topic = clarified_topic or research_topic
|
||||
else:
|
||||
handoff_topic = research_topic
|
||||
|
||||
return Command(
|
||||
update={
|
||||
"messages": messages,
|
||||
"locale": locale,
|
||||
"research_topic": research_topic,
|
||||
"clarified_research_topic": clarified_research_topic_value,
|
||||
"resources": configurable.resources,
|
||||
"clarification_rounds": clarification_rounds,
|
||||
"clarification_history": clarification_history,
|
||||
"is_clarification_complete": goto != "coordinator",
|
||||
"clarified_question": research_topic if goto != "coordinator" else "",
|
||||
"clarified_question": handoff_topic if goto != "coordinator" else "",
|
||||
"goto": goto,
|
||||
},
|
||||
goto=goto,
|
||||
@@ -747,14 +729,15 @@ async def _execute_agent_step(
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
error_traceback = traceback.format_exc()
|
||||
error_message = f"Error executing {agent_name} agent for step '{current_step.title}': {str(e)}"
|
||||
logger.exception(error_message)
|
||||
logger.error(f"Full traceback:\n{error_traceback}")
|
||||
|
||||
|
||||
detailed_error = f"[ERROR] {agent_name.capitalize()} Agent Error\n\nStep: {current_step.title}\n\nError Details:\n{str(e)}\n\nPlease check the logs for more information."
|
||||
current_step.execution_res = detailed_error
|
||||
|
||||
|
||||
return Command(
|
||||
update={
|
||||
"messages": [
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
|
||||
from dataclasses import field
|
||||
|
||||
from langgraph.graph import MessagesState
|
||||
|
||||
from src.prompts.planner_model import Plan
|
||||
@@ -14,6 +16,7 @@ class State(MessagesState):
|
||||
# Runtime Variables
|
||||
locale: str = "en-US"
|
||||
research_topic: str = ""
|
||||
clarified_research_topic: str = ""
|
||||
observations: list[str] = []
|
||||
resources: list[Resource] = []
|
||||
plan_iterations: int = 0
|
||||
@@ -28,7 +31,7 @@ class State(MessagesState):
|
||||
False # Enable/disable clarification feature (default: False)
|
||||
)
|
||||
clarification_rounds: int = 0
|
||||
clarification_history: list[str] = []
|
||||
clarification_history: list[str] = field(default_factory=list)
|
||||
is_clarification_complete: bool = False
|
||||
clarified_question: str = ""
|
||||
max_clarification_rounds: int = (
|
||||
|
||||
113
src/graph/utils.py
Normal file
113
src/graph/utils.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from typing import Any
|
||||
|
||||
# Speaker names that internal agents attach to the messages they emit.
# A message carrying one of these names is treated as assistant output,
# never as end-user input (see is_user_message below).
ASSISTANT_SPEAKER_NAMES = {
    "coordinator",
    "planner",
    "researcher",
    "coder",
    "reporter",
    "background_investigator",
}
||||
|
||||
|
||||
def get_message_content(message: Any) -> str:
    """Return the ``content`` field of *message*, or ``""`` when absent.

    Handles both plain dict messages and LangChain-style message objects.
    """
    if isinstance(message, dict):
        content = message.get("content", "")
    else:
        content = getattr(message, "content", "")
    return content
|
||||
|
||||
|
||||
def is_user_message(message: Any) -> bool:
    """Heuristically decide whether *message* was authored by the end user.

    Supports dict-shaped messages (``role``/``name`` keys) as well as
    LangChain-style objects (``type``, ``role``, ``name`` and
    ``additional_kwargs`` attributes). Messages attributed to one of the known
    internal agents in ``ASSISTANT_SPEAKER_NAMES`` are never user messages.
    """
    user_roles = {"user", "human"}

    if isinstance(message, dict):
        role = (message.get("role") or "").lower()
        if role in user_roles:
            return True
        if role in {"assistant", "system"}:
            return False
        name = (message.get("name") or "").lower()
        if name and name in ASSISTANT_SPEAKER_NAMES:
            return False
        # Unknown role: only an unnamed, role-less dict counts as user input.
        return role == "" and name not in ASSISTANT_SPEAKER_NAMES

    msg_type = (getattr(message, "type", "") or "").lower()
    speaker = (getattr(message, "name", "") or "").lower()
    if msg_type == "human":
        # HumanMessage counts as user input unless an internal agent sent it.
        return not (speaker and speaker in ASSISTANT_SPEAKER_NAMES)

    role_value = getattr(message, "role", None)
    if isinstance(role_value, str) and role_value.lower() in user_roles:
        return True

    extra_role = getattr(message, "additional_kwargs", {}).get("role")
    return isinstance(extra_role, str) and extra_role.lower() in user_roles
|
||||
|
||||
|
||||
def get_latest_user_message(messages: list[Any]) -> tuple[Any, str]:
    """Scan *messages* newest-first for the latest user-authored message.

    Returns:
        A ``(message, content)`` pair for the most recent user message whose
        content is non-empty, or ``(None, "")`` when there is none.
    """
    for candidate in reversed(messages or []):
        if not is_user_message(candidate):
            continue
        text = get_message_content(candidate)
        if text:
            return candidate, text
    return None, ""
|
||||
|
||||
|
||||
def build_clarified_topic_from_history(
    clarification_history: list[str],
) -> tuple[str, list[str]]:
    """Collapse an ordered clarification history into one topic string.

    Empty/falsy entries are dropped first. A single surviving entry becomes
    the topic as-is; with several, the first entry acts as the base topic and
    the remainder are appended as a comma-separated refinement list.

    Returns:
        A ``(clarified_topic, cleaned_history)`` pair, or ``("", [])`` when
        no usable entries remain.
    """
    cleaned = [entry for entry in clarification_history if entry]
    if not cleaned:
        return "", []
    base = cleaned[0]
    refinements = cleaned[1:]
    if not refinements:
        return base, cleaned
    return f"{base} - {', '.join(refinements)}", cleaned
|
||||
|
||||
|
||||
def reconstruct_clarification_history(
    messages: list[Any],
    fallback_history: list[str] | None = None,
    base_topic: str = "",
) -> list[str]:
    """Rebuild clarification history from user-authored messages, with fallback.

    Args:
        messages: Conversation messages in chronological order.
        fallback_history: Existing history used when the messages yield no
            user content.
        base_topic: Last-resort topic used when both sources are empty.

    Returns:
        The user-authored message contents in order with immediate duplicates
        collapsed; otherwise the cleaned fallback history; otherwise a
        single-element list holding the stripped base topic (or ``[]``).
    """
    history: list[str] = []
    for entry in messages or []:
        if not is_user_message(entry):
            continue
        text = get_message_content(entry)
        # Skip empty content and echoes of the immediately preceding entry.
        if not text or (history and history[-1] == text):
            continue
        history.append(text)
    if history:
        return history

    cleaned_fallback = [item for item in (fallback_history or []) if item]
    if cleaned_fallback:
        return cleaned_fallback

    topic = (base_topic or "").strip()
    return [topic] if topic else []
|
||||
@@ -25,6 +25,10 @@ from src.config.report_style import ReportStyle
|
||||
from src.config.tools import SELECTED_RAG_PROVIDER
|
||||
from src.graph.builder import build_graph_with_memory
|
||||
from src.graph.checkpoint import chat_stream_message
|
||||
from src.graph.utils import (
|
||||
build_clarified_topic_from_history,
|
||||
reconstruct_clarification_history,
|
||||
)
|
||||
from src.llms.llm import get_configured_llm_models
|
||||
from src.podcast.graph.builder import build_graph as build_podcast_graph
|
||||
from src.ppt.graph.builder import build_graph as build_ppt_graph
|
||||
@@ -160,7 +164,7 @@ def _create_event_stream_message(
|
||||
content = message_chunk.content
|
||||
if not isinstance(content, str):
|
||||
content = json.dumps(content, ensure_ascii=False)
|
||||
|
||||
|
||||
event_stream_message = {
|
||||
"thread_id": thread_id,
|
||||
"agent": agent_name,
|
||||
@@ -309,6 +313,14 @@ async def _astream_workflow_generator(
|
||||
if isinstance(message, dict) and "content" in message:
|
||||
_process_initial_messages(message, thread_id)
|
||||
|
||||
clarification_history = reconstruct_clarification_history(messages)
|
||||
|
||||
clarified_topic, clarification_history = build_clarified_topic_from_history(
|
||||
clarification_history
|
||||
)
|
||||
latest_message_content = messages[-1]["content"] if messages else ""
|
||||
clarified_research_topic = clarified_topic or latest_message_content
|
||||
|
||||
# Prepare workflow input
|
||||
workflow_input = {
|
||||
"messages": messages,
|
||||
@@ -318,7 +330,9 @@ async def _astream_workflow_generator(
|
||||
"observations": [],
|
||||
"auto_accepted_plan": auto_accepted_plan,
|
||||
"enable_background_investigation": enable_background_investigation,
|
||||
"research_topic": messages[-1]["content"] if messages else "",
|
||||
"research_topic": latest_message_content,
|
||||
"clarification_history": clarification_history,
|
||||
"clarified_research_topic": clarified_research_topic,
|
||||
"enable_clarification": enable_clarification,
|
||||
"max_clarification_rounds": max_clarification_rounds,
|
||||
}
|
||||
|
||||
@@ -208,7 +208,7 @@ class SearchResultPostProcessor:
|
||||
url = image_url_val.get("url", "")
|
||||
else:
|
||||
url = image_url_val
|
||||
|
||||
|
||||
if url and url not in seen_urls:
|
||||
seen_urls.add(url)
|
||||
return result.copy() # Return a copy to avoid modifying original
|
||||
|
||||
@@ -5,6 +5,7 @@ import logging
|
||||
|
||||
from src.config.configuration import get_recursion_limit
|
||||
from src.graph import build_graph
|
||||
from src.graph.utils import build_clarified_topic_from_history
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@@ -65,6 +66,8 @@ async def run_agent_workflow_async(
|
||||
"auto_accepted_plan": True,
|
||||
"enable_background_investigation": enable_background_investigation,
|
||||
}
|
||||
initial_state["research_topic"] = user_input
|
||||
initial_state["clarified_research_topic"] = user_input
|
||||
|
||||
# Only set clarification parameter if explicitly provided
|
||||
# If None, State class default will be used (enable_clarification=False)
|
||||
@@ -137,7 +140,18 @@ async def run_agent_workflow_async(
|
||||
current_state["messages"] = final_state["messages"] + [
|
||||
{"role": "user", "content": user_response}
|
||||
]
|
||||
# Recursive call for clarification continuation
|
||||
for key in (
|
||||
"clarification_history",
|
||||
"clarification_rounds",
|
||||
"clarified_research_topic",
|
||||
"research_topic",
|
||||
"locale",
|
||||
"enable_clarification",
|
||||
"max_clarification_rounds",
|
||||
):
|
||||
if key in final_state:
|
||||
current_state[key] = final_state[key]
|
||||
|
||||
return await run_agent_workflow_async(
|
||||
user_input=user_response,
|
||||
max_plan_iterations=max_plan_iterations,
|
||||
|
||||
@@ -451,7 +451,9 @@ def test_human_feedback_node_accepted(monkeypatch, mock_state_base, mock_config)
|
||||
assert result.update["current_plan"]["has_enough_context"] is False
|
||||
|
||||
|
||||
def test_human_feedback_node_invalid_interrupt(monkeypatch, mock_state_base, mock_config):
|
||||
def test_human_feedback_node_invalid_interrupt(
|
||||
monkeypatch, mock_state_base, mock_config
|
||||
):
|
||||
# interrupt returns something else, should raise TypeError
|
||||
state = dict(mock_state_base)
|
||||
state["auto_accepted_plan"] = False
|
||||
@@ -490,7 +492,9 @@ def test_human_feedback_node_json_decode_error_second_iteration(
|
||||
assert result.goto == "reporter"
|
||||
|
||||
|
||||
def test_human_feedback_node_not_enough_context(monkeypatch, mock_state_base, mock_config):
|
||||
def test_human_feedback_node_not_enough_context(
|
||||
monkeypatch, mock_state_base, mock_config
|
||||
):
|
||||
# Plan does not have enough context, should goto research_team
|
||||
plan = {
|
||||
"has_enough_context": False,
|
||||
@@ -1446,7 +1450,9 @@ def test_handoff_tools():
|
||||
assert result is None # Tool should return None (no-op)
|
||||
|
||||
# Test handoff_after_clarification tool - use invoke() method
|
||||
result = handoff_after_clarification.invoke({"locale": "en-US"})
|
||||
result = handoff_after_clarification.invoke(
|
||||
{"locale": "en-US", "research_topic": "renewable energy research"}
|
||||
)
|
||||
assert result is None # Tool should return None (no-op)
|
||||
|
||||
|
||||
@@ -1468,9 +1474,13 @@ def test_coordinator_tools_with_clarification_enabled(mock_get_llm):
|
||||
"clarification_rounds": 2,
|
||||
"max_clarification_rounds": 3,
|
||||
"is_clarification_complete": False,
|
||||
"clarification_history": ["response 1", "response 2"],
|
||||
"clarification_history": [
|
||||
"Tell me about something",
|
||||
"response 1",
|
||||
"response 2",
|
||||
],
|
||||
"locale": "en-US",
|
||||
"research_topic": "",
|
||||
"research_topic": "Tell me about something",
|
||||
}
|
||||
|
||||
# Mock config
|
||||
@@ -1567,3 +1577,289 @@ def test_coordinator_empty_llm_response_corner_case(mock_get_llm):
|
||||
# Should gracefully handle empty response by going to planner to ensure workflow continues
|
||||
assert result.goto == "planner"
|
||||
assert result.update["locale"] == "en-US"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Clarification flow tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_clarification_handoff_combines_history():
    """Coordinator should merge original topic with all clarification answers before handoff."""
    from langchain_core.messages import AIMessage
    from langchain_core.runnables import RunnableConfig

    # Full clarification session: original topic plus two rounds of answers.
    state = {
        "messages": [
            {"role": "user", "content": "Research artificial intelligence"},
            {"role": "assistant", "content": "Which area of AI should we focus on?"},
            {"role": "user", "content": "Machine learning applications"},
            {"role": "assistant", "content": "What dimension of that should we cover?"},
            {"role": "user", "content": "Technical implementation details"},
        ],
        "enable_clarification": True,
        "clarification_rounds": 2,
        "clarification_history": [
            "Research artificial intelligence",
            "Machine learning applications",
            "Technical implementation details",
        ],
        "max_clarification_rounds": 3,
        "research_topic": "Research artificial intelligence",
        "clarified_research_topic": "Research artificial intelligence - Machine learning applications, Technical implementation details",
        "locale": "en-US",
    }

    run_config = RunnableConfig(configurable={"thread_id": "clarification-test"})

    # LLM decides clarification is done and calls the handoff tool.
    handoff_reply = AIMessage(
        content="Understood, handing off now.",
        tool_calls=[
            {
                "name": "handoff_after_clarification",
                "args": {"locale": "en-US", "research_topic": "placeholder"},
                "id": "tool-call-handoff",
                "type": "tool_call",
            }
        ],
    )

    with patch("src.graph.nodes.get_llm_by_type") as get_llm_mock:
        llm = MagicMock()
        llm.bind_tools.return_value.invoke.return_value = handoff_reply
        get_llm_mock.return_value = llm

        result = coordinator_node(state, run_config)

    assert hasattr(result, "update")
    update = result.update
    assert update["clarification_history"] == [
        "Research artificial intelligence",
        "Machine learning applications",
        "Technical implementation details",
    ]
    expected_topic = "Research artificial intelligence - Machine learning applications, Technical implementation details"
    assert update["research_topic"] == "Research artificial intelligence"
    assert update["clarified_research_topic"] == expected_topic
    assert update["clarified_question"] == expected_topic
||||
|
||||
|
||||
def test_clarification_history_reconstructed_from_messages():
    """Coordinator should rebuild clarification history from full message log when state is incomplete."""
    from langchain_core.messages import AIMessage
    from langchain_core.runnables import RunnableConfig

    # State carries only the *last* answer; the earlier exchanges exist only
    # in the message log and must be recovered from there.
    partial_state = {
        "messages": [
            {"role": "user", "content": "Research on renewable energy"},
            {
                "role": "assistant",
                "content": "Which type of renewable energy interests you?",
            },
            {"role": "user", "content": "Solar and wind energy"},
            {"role": "assistant", "content": "Which aspect should we focus on?"},
            {"role": "user", "content": "Technical implementation"},
        ],
        "enable_clarification": True,
        "clarification_rounds": 2,
        "clarification_history": ["Technical implementation"],
        "max_clarification_rounds": 3,
        "research_topic": "Research on renewable energy",
        "clarified_research_topic": "Research on renewable energy",
        "locale": "en-US",
    }

    run_config = RunnableConfig(configurable={"thread_id": "clarification-history-rebuild"})

    handoff_reply = AIMessage(
        content="Understood, handing over now.",
        tool_calls=[
            {
                "name": "handoff_after_clarification",
                "args": {"locale": "en-US", "research_topic": "placeholder"},
                "id": "tool-call-handoff",
                "type": "tool_call",
            }
        ],
    )

    with patch("src.graph.nodes.get_llm_by_type") as get_llm_mock:
        llm = MagicMock()
        llm.bind_tools.return_value.invoke.return_value = handoff_reply
        get_llm_mock.return_value = llm

        result = coordinator_node(partial_state, run_config)

    update = result.update
    # History must contain every user turn, not just the one held in state.
    assert update["clarification_history"] == [
        "Research on renewable energy",
        "Solar and wind energy",
        "Technical implementation",
    ]
    assert update["research_topic"] == "Research on renewable energy"
    assert (
        update["clarified_research_topic"]
        == "Research on renewable energy - Solar and wind energy, Technical implementation"
    )
||||
|
||||
|
||||
def test_clarification_max_rounds_without_tool_call():
    """Coordinator should stop asking questions after max rounds and hand off with compiled topic."""
    from langchain_core.messages import AIMessage
    from langchain_core.runnables import RunnableConfig

    # Three rounds already done == max_clarification_rounds; even though the
    # LLM emits no tool call, the coordinator must hand off to the planner.
    state = {
        "messages": [
            {"role": "user", "content": "Research artificial intelligence"},
            {"role": "assistant", "content": "Which area should we focus on?"},
            {"role": "user", "content": "Natural language processing"},
            {"role": "assistant", "content": "Which domain matters most?"},
            {"role": "user", "content": "Healthcare"},
            {"role": "assistant", "content": "Any specific scenario to study?"},
            {"role": "user", "content": "Clinical documentation"},
        ],
        "enable_clarification": True,
        "clarification_rounds": 3,
        "clarification_history": [
            "Research artificial intelligence",
            "Natural language processing",
            "Healthcare",
            "Clinical documentation",
        ],
        "max_clarification_rounds": 3,
        "research_topic": "Research artificial intelligence",
        "clarified_research_topic": "Research artificial intelligence - Natural language processing, Healthcare, Clinical documentation",
        "locale": "en-US",
    }

    run_config = RunnableConfig(configurable={"thread_id": "clarification-max"})

    # Deliberately no tool calls: the max-rounds guard must force the handoff.
    plain_reply = AIMessage(
        content="Got it, sending this to the planner.",
        tool_calls=[],
    )

    with patch("src.graph.nodes.get_llm_by_type") as get_llm_mock:
        llm = MagicMock()
        llm.bind_tools.return_value.invoke.return_value = plain_reply
        get_llm_mock.return_value = llm

        result = coordinator_node(state, run_config)

    assert hasattr(result, "update")
    update = result.update
    expected_topic = "Research artificial intelligence - Natural language processing, Healthcare, Clinical documentation"
    assert update["research_topic"] == "Research artificial intelligence"
    assert update["clarified_research_topic"] == expected_topic
    assert result.goto == "planner"
||||
|
||||
|
||||
def test_clarification_human_message_support():
    """Coordinator should treat HumanMessage instances from the user as user authored."""
    from langchain_core.messages import AIMessage, HumanMessage
    from langchain_core.runnables import RunnableConfig

    # All turns are HumanMessage objects; coordinator-authored ones carry
    # name="coordinator" and must NOT be counted as user answers.
    state = {
        "messages": [
            HumanMessage(content="Research artificial intelligence"),
            HumanMessage(content="Which area should we focus on?", name="coordinator"),
            HumanMessage(content="Machine learning"),
            HumanMessage(
                content="Which dimension should we explore?", name="coordinator"
            ),
            HumanMessage(content="Technical feasibility"),
        ],
        "enable_clarification": True,
        "clarification_rounds": 2,
        "clarification_history": [
            "Research artificial intelligence",
            "Machine learning",
            "Technical feasibility",
        ],
        "max_clarification_rounds": 3,
        "research_topic": "Research artificial intelligence",
        "clarified_research_topic": "Research artificial intelligence - Machine learning, Technical feasibility",
        "locale": "en-US",
    }

    run_config = RunnableConfig(configurable={"thread_id": "clarification-human"})

    handoff_reply = AIMessage(
        content="Moving to planner.",
        tool_calls=[
            {
                "name": "handoff_after_clarification",
                "args": {"locale": "en-US", "research_topic": "placeholder"},
                "id": "human-message-handoff",
                "type": "tool_call",
            }
        ],
    )

    with patch("src.graph.nodes.get_llm_by_type") as get_llm_mock:
        llm = MagicMock()
        llm.bind_tools.return_value.invoke.return_value = handoff_reply
        get_llm_mock.return_value = llm

        result = coordinator_node(state, run_config)

    assert hasattr(result, "update")
    update = result.update
    expected_topic = (
        "Research artificial intelligence - Machine learning, Technical feasibility"
    )
    assert update["clarification_history"] == [
        "Research artificial intelligence",
        "Machine learning",
        "Technical feasibility",
    ]
    assert update["research_topic"] == "Research artificial intelligence"
    assert update["clarified_research_topic"] == expected_topic
||||
|
||||
|
||||
def test_clarification_no_history_defaults_to_topic():
    """If clarification never started, coordinator should forward the original topic."""
    from langchain_core.messages import AIMessage
    from langchain_core.runnables import RunnableConfig

    # Zero clarification rounds: the clarified topic equals the raw topic.
    state = {
        "messages": [{"role": "user", "content": "What is quantum computing?"}],
        "enable_clarification": True,
        "clarification_rounds": 0,
        "clarification_history": ["What is quantum computing?"],
        "max_clarification_rounds": 3,
        "research_topic": "What is quantum computing?",
        "clarified_research_topic": "What is quantum computing?",
        "locale": "en-US",
    }

    run_config = RunnableConfig(configurable={"thread_id": "clarification-none"})

    # Plain handoff_to_planner call, no clarification ever happened.
    handoff_reply = AIMessage(
        content="Understood.",
        tool_calls=[
            {
                "name": "handoff_to_planner",
                "args": {"locale": "en-US", "research_topic": "placeholder"},
                "id": "clarification-none",
                "type": "tool_call",
            }
        ],
    )

    with patch("src.graph.nodes.get_llm_by_type") as get_llm_mock:
        llm = MagicMock()
        llm.bind_tools.return_value.invoke.return_value = handoff_reply
        get_llm_mock.return_value = llm

        result = coordinator_node(state, run_config)

    assert hasattr(result, "update")
    assert result.update["research_topic"] == "What is quantum computing?"
    assert result.update["clarified_research_topic"] == "What is quantum computing?"
||||
|
||||
@@ -47,6 +47,79 @@ class TestMakeEvent:
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_astream_workflow_generator_preserves_clarification_history():
    """Clarification Q&A from the message log must survive into the workflow input state."""
    messages = [
        {"role": "user", "content": "Research on renewable energy"},
        {
            "role": "assistant",
            "content": "What type of renewable energy would you like to know about?",
        },
        {"role": "user", "content": "Solar and wind energy"},
        {
            "role": "assistant",
            "content": "Please tell me the research dimensions you focus on, such as technological development or market applications.",
        },
        {"role": "user", "content": "Technological development"},
        {
            "role": "assistant",
            "content": "Please specify the time range you want to focus on, such as current status or future trends.",
        },
        {"role": "user", "content": "Current status and future trends"},
    ]

    captured = {}

    def fake_stream(*args, **kwargs):
        # Record the state and config handed to the graph, then yield nothing.
        captured["workflow_input"] = args[1]
        captured["workflow_config"] = args[2]

        class _EmptyAiter:
            def __aiter__(self):
                return self

            async def __anext__(self):
                raise StopAsyncIteration

        return _EmptyAiter()

    with (
        patch("src.server.app._process_initial_messages"),
        patch("src.server.app._stream_graph_events", side_effect=fake_stream),
    ):
        gen = _astream_workflow_generator(
            messages=messages,
            thread_id="clarification-thread",
            resources=[],
            max_plan_iterations=1,
            max_step_num=1,
            max_search_results=5,
            auto_accepted_plan=True,
            interrupt_feedback="",
            mcp_settings={},
            enable_background_investigation=True,
            report_style=ReportStyle.ACADEMIC,
            enable_deep_thinking=False,
            enable_clarification=True,
            max_clarification_rounds=3,
        )

        # The patched stream is empty, so the first pull ends the generator.
        with pytest.raises(StopAsyncIteration):
            await gen.__anext__()

    workflow_input = captured["workflow_input"]
    assert workflow_input["clarification_history"] == [
        "Research on renewable energy",
        "Solar and wind energy",
        "Technological development",
        "Current status and future trends",
    ]
    assert (
        workflow_input["clarified_research_topic"]
        == "Research on renewable energy - Solar and wind energy, Technological development, Current status and future trends"
    )
||||
|
||||
|
||||
class TestTTSEndpoint:
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
|
||||
Reference in New Issue
Block a user