fix: Refine clarification workflow state handling (#641)

* fix: support local models by making thought field optional in Plan model

- Make thought field optional in Plan model to fix Pydantic validation errors with local models
- Add Ollama configuration example to conf.yaml.example
- Update documentation to include local model support
- Improve planner prompt with better JSON format requirements

Fixes local model integration issues where models like qwen3:14b would fail
due to missing thought field in JSON output.

* feat: Add intelligent clarification feature for research queries

- Add multi-turn clarification process to refine vague research questions
- Implement three-dimension clarification standard (Tech/App, Focus, Scope)
- Add clarification state management in coordinator node
- Update coordinator prompt with detailed clarification guidelines
- Add UI settings to enable/disable clarification feature (disabled by default)
- Update workflow to handle clarification rounds recursively
- Add comprehensive test coverage for clarification functionality
- Update documentation with clarification feature usage guide

Key components:
- src/graph/nodes.py: Core clarification logic and state management
- src/prompts/coordinator.md: Detailed clarification guidelines
- src/workflow.py: Recursive clarification handling
- web/: UI settings integration
- tests/: Comprehensive test coverage
- docs/: Updated configuration guide

* fix: Improve clarification conversation continuity

- Add comprehensive conversation history to clarification context
- Include previous exchanges summary in system messages
- Add explicit guidelines for continuing rounds in coordinator prompt
- Prevent LLM from starting new topics during clarification
- Ensure topic continuity across clarification rounds

Fixes issue where LLM would restart clarification instead of building upon previous exchanges.

* fix: Add conversation history to clarification context

* fix: resolve clarification feature message to planner, prompt, test issues

- Optimize coordinator.md prompt template for better clarification flow
- Simplify final message sent to planner after clarification
- Fix API key assertion issues in test_search.py

* fix: Add configurable max_clarification_rounds and comprehensive tests

- Add max_clarification_rounds parameter for external configuration
- Add comprehensive test cases for clarification feature in test_app.py
- Fix issues found during interactive mode testing where:
  - Recursive call failed due to missing initial_state parameter
  - Clarification exited prematurely at max rounds
  - Incorrect logging of max rounds reached

* Move clarification tests to test_nodes.py and add max_clarification_rounds to zh.json

* fix: add max_clarification_rounds parameter passing from frontend to backend

- Add max_clarification_rounds parameter in store.ts sendMessage function
- Add max_clarification_rounds type definition in chat.ts
- Ensure frontend settings page clarification rounds are correctly passed to backend

* fix: refine clarification workflow state handling and coverage

- Add clarification history reconstruction
- Fix clarified topic accumulation
- Add clarified_research_topic state field
- Preserve clarification state in recursive calls
- Add comprehensive test coverage

* refactor: optimize coordinator logic and type annotations

- Simplify handoff topic logic in coordinator_node
- Update type annotations from Tuple to tuple
- Improve code readability and maintainability

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
jimmyuconn1982
2025-10-22 22:49:07 +08:00
committed by GitHub
parent 9371ad23ee
commit 003f081a7b
9 changed files with 615 additions and 117 deletions

View File

@@ -51,7 +51,9 @@ class Configuration:
mcp_settings: dict = None # MCP settings, including dynamic loaded tools mcp_settings: dict = None # MCP settings, including dynamic loaded tools
report_style: str = ReportStyle.ACADEMIC.value # Report style report_style: str = ReportStyle.ACADEMIC.value # Report style
enable_deep_thinking: bool = False # Whether to enable deep thinking enable_deep_thinking: bool = False # Whether to enable deep thinking
enforce_web_search: bool = False # Enforce at least one web search step in every plan enforce_web_search: bool = (
False # Enforce at least one web search step in every plan
)
@classmethod @classmethod
def from_runnable_config( def from_runnable_config(

View File

@@ -31,6 +31,12 @@ from src.utils.json_utils import repair_json_output
from ..config import SELECTED_SEARCH_ENGINE, SearchEngine from ..config import SELECTED_SEARCH_ENGINE, SearchEngine
from .types import State from .types import State
from .utils import (
build_clarified_topic_from_history,
get_message_content,
is_user_message,
reconstruct_clarification_history,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -49,6 +55,9 @@ def handoff_to_planner(
@tool @tool
def handoff_after_clarification( def handoff_after_clarification(
locale: Annotated[str, "The user's detected language locale (e.g., en-US, zh-CN)."], locale: Annotated[str, "The user's detected language locale (e.g., en-US, zh-CN)."],
research_topic: Annotated[
str, "The clarified research topic based on all clarification rounds."
],
): ):
"""Handoff to planner after clarification rounds are complete. Pass all clarification history to planner for analysis.""" """Handoff to planner after clarification rounds are complete. Pass all clarification history to planner for analysis."""
return return
@@ -107,7 +116,9 @@ def validate_and_fix_plan(plan: dict, enforce_web_search: bool = False) -> dict:
# This ensures that at least one step will perform a web search as required. # This ensures that at least one step will perform a web search as required.
steps[0]["step_type"] = "research" steps[0]["step_type"] = "research"
steps[0]["need_search"] = True steps[0]["need_search"] = True
logger.info("Converted first step to research with web search enforcement") logger.info(
"Converted first step to research with web search enforcement"
)
elif not has_search_step and not steps: elif not has_search_step and not steps:
# Add a default research step if no steps exist # Add a default research step if no steps exist
logger.warning("Plan has no steps. Adding default research step.") logger.warning("Plan has no steps. Adding default research step.")
@@ -126,7 +137,7 @@ def validate_and_fix_plan(plan: dict, enforce_web_search: bool = False) -> dict:
def background_investigation_node(state: State, config: RunnableConfig): def background_investigation_node(state: State, config: RunnableConfig):
logger.info("background investigation node is running.") logger.info("background investigation node is running.")
configurable = Configuration.from_runnable_config(config) configurable = Configuration.from_runnable_config(config)
query = state.get("research_topic") query = state.get("clarified_research_topic") or state.get("research_topic")
background_investigation_results = None background_investigation_results = None
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value: if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
searched_content = LoggedTavilySearch( searched_content = LoggedTavilySearch(
@@ -316,7 +327,8 @@ def coordinator_node(
# Check if clarification is enabled # Check if clarification is enabled
enable_clarification = state.get("enable_clarification", False) enable_clarification = state.get("enable_clarification", False)
initial_topic = state.get("research_topic", "")
clarified_topic = initial_topic
# ============================================================ # ============================================================
# BRANCH 1: Clarification DISABLED (Legacy Mode) # BRANCH 1: Clarification DISABLED (Legacy Mode)
# ============================================================ # ============================================================
@@ -338,7 +350,6 @@ def coordinator_node(
.invoke(messages) .invoke(messages)
) )
# Process response - should directly handoff to planner
goto = "__end__" goto = "__end__"
locale = state.get("locale", "en-US") locale = state.get("locale", "en-US")
research_topic = state.get("research_topic", "") research_topic = state.get("research_topic", "")
@@ -370,12 +381,28 @@ def coordinator_node(
else: else:
# Load clarification state # Load clarification state
clarification_rounds = state.get("clarification_rounds", 0) clarification_rounds = state.get("clarification_rounds", 0)
clarification_history = state.get("clarification_history", []) clarification_history = list(state.get("clarification_history", []) or [])
clarification_history = [item for item in clarification_history if item]
max_clarification_rounds = state.get("max_clarification_rounds", 3) max_clarification_rounds = state.get("max_clarification_rounds", 3)
# Prepare the messages for the coordinator # Prepare the messages for the coordinator
state_messages = list(state.get("messages", []))
messages = apply_prompt_template("coordinator", state) messages = apply_prompt_template("coordinator", state)
clarification_history = reconstruct_clarification_history(
state_messages, clarification_history, initial_topic
)
clarified_topic, clarification_history = build_clarified_topic_from_history(
clarification_history
)
logger.debug("Clarification history rebuilt: %s", clarification_history)
if clarification_history:
initial_topic = clarification_history[0]
latest_user_content = clarification_history[-1]
else:
latest_user_content = ""
# Add clarification status for first round # Add clarification status for first round
if clarification_rounds == 0: if clarification_rounds == 0:
messages.append( messages.append(
@@ -385,91 +412,21 @@ def coordinator_node(
} }
) )
# Add clarification context if continuing conversation (round > 0) logger.info(
elif clarification_rounds > 0: "Clarification round %s/%s | topic: %s | latest user content: %s",
logger.info( clarification_rounds,
f"Clarification enabled (rounds: {clarification_rounds}/{max_clarification_rounds}): Continuing conversation" max_clarification_rounds,
) clarified_topic or initial_topic,
latest_user_content or "N/A",
)
# Add user's response to clarification history (only user messages) current_response = latest_user_content or "No response"
last_message = None
if state.get("messages"):
last_message = state["messages"][-1]
# Extract content from last message for logging
if isinstance(last_message, dict):
content = last_message.get("content", "No content")
else:
content = getattr(last_message, "content", "No content")
logger.info(f"Last message content: {content}")
# Handle dict format
if isinstance(last_message, dict):
if last_message.get("role") == "user":
clarification_history.append(last_message["content"])
logger.info(
f"Added user response to clarification history: {last_message['content']}"
)
# Handle object format (like HumanMessage)
elif hasattr(last_message, "role") and last_message.role == "user":
clarification_history.append(last_message.content)
logger.info(
f"Added user response to clarification history: {last_message.content}"
)
# Handle object format with content attribute (like the one in logs)
elif hasattr(last_message, "content"):
clarification_history.append(last_message.content)
logger.info(
f"Added user response to clarification history: {last_message.content}"
)
# Build comprehensive clarification context with conversation history clarification_context = f"""Continuing clarification (round {clarification_rounds}/{max_clarification_rounds}):
current_response = "No response"
if last_message:
# Handle dict format
if isinstance(last_message, dict):
if last_message.get("role") == "user":
current_response = last_message.get("content", "No response")
else:
# If last message is not from user, try to get the latest user message
messages = state.get("messages", [])
for msg in reversed(messages):
if isinstance(msg, dict) and msg.get("role") == "user":
current_response = msg.get("content", "No response")
break
# Handle object format (like HumanMessage)
elif hasattr(last_message, "role") and last_message.role == "user":
current_response = last_message.content
# Handle object format with content attribute (like the one in logs)
elif hasattr(last_message, "content"):
current_response = last_message.content
else:
# If last message is not from user, try to get the latest user message
messages = state.get("messages", [])
for msg in reversed(messages):
if isinstance(msg, dict) and msg.get("role") == "user":
current_response = msg.get("content", "No response")
break
elif hasattr(msg, "role") and msg.role == "user":
current_response = msg.content
break
elif hasattr(msg, "content"):
current_response = msg.content
break
# Create conversation history summary
conversation_summary = ""
if clarification_history:
conversation_summary = "Previous conversation:\n"
for i, response in enumerate(clarification_history, 1):
conversation_summary += f"- Round {i}: {response}\n"
clarification_context = f"""Continuing clarification (round {clarification_rounds}/{max_clarification_rounds}):
User's latest response: {current_response} User's latest response: {current_response}
Ask for remaining missing dimensions. Do NOT repeat questions or start new topics.""" Ask for remaining missing dimensions. Do NOT repeat questions or start new topics."""
# Log the clarification context for debugging messages.append({"role": "system", "content": clarification_context})
logger.info(f"Clarification context: {clarification_context}")
messages.append({"role": "system", "content": clarification_context})
# Bind both clarification tools # Bind both clarification tools
tools = [handoff_to_planner, handoff_after_clarification] tools = [handoff_to_planner, handoff_after_clarification]
@@ -483,7 +440,13 @@ def coordinator_node(
# Initialize response processing variables # Initialize response processing variables
goto = "__end__" goto = "__end__"
locale = state.get("locale", "en-US") locale = state.get("locale", "en-US")
research_topic = state.get("research_topic", "") research_topic = (
clarification_history[0]
if clarification_history
else state.get("research_topic", "")
)
if not clarified_topic:
clarified_topic = research_topic
# --- Process LLM response --- # --- Process LLM response ---
# No tool calls - LLM is asking a clarifying question # No tool calls - LLM is asking a clarifying question
@@ -497,20 +460,21 @@ def coordinator_node(
) )
# Append coordinator's question to messages # Append coordinator's question to messages
state_messages = state.get("messages", []) updated_messages = list(state_messages)
if response.content: if response.content:
state_messages.append( updated_messages.append(
HumanMessage(content=response.content, name="coordinator") HumanMessage(content=response.content, name="coordinator")
) )
return Command( return Command(
update={ update={
"messages": state_messages, "messages": updated_messages,
"locale": locale, "locale": locale,
"research_topic": research_topic, "research_topic": research_topic,
"resources": configurable.resources, "resources": configurable.resources,
"clarification_rounds": clarification_rounds, "clarification_rounds": clarification_rounds,
"clarification_history": clarification_history, "clarification_history": clarification_history,
"clarified_research_topic": clarified_topic,
"is_clarification_complete": False, "is_clarification_complete": False,
"clarified_question": "", "clarified_question": "",
"goto": goto, "goto": goto,
@@ -521,7 +485,7 @@ def coordinator_node(
else: else:
# Max rounds reached - no more questions allowed # Max rounds reached - no more questions allowed
logger.warning( logger.warning(
f"Max clarification rounds ({max_clarification_rounds}) reached. Handing off to planner." f"Max clarification rounds ({max_clarification_rounds}) reached. Handing off to planner. Using prepared clarified topic: {clarified_topic}"
) )
goto = "planner" goto = "planner"
if state.get("enable_background_investigation"): if state.get("enable_background_investigation"):
@@ -539,7 +503,7 @@ def coordinator_node(
# ============================================================ # ============================================================
# Final: Build and return Command # Final: Build and return Command
# ============================================================ # ============================================================
messages = state.get("messages", []) messages = list(state.get("messages", []) or [])
if response.content: if response.content:
messages.append(HumanMessage(content=response.content, name="coordinator")) messages.append(HumanMessage(content=response.content, name="coordinator"))
@@ -554,10 +518,20 @@ def coordinator_node(
logger.info("Handing off to planner") logger.info("Handing off to planner")
goto = "planner" goto = "planner"
# Extract locale and research_topic if provided # Extract locale if provided
if tool_args.get("locale") and tool_args.get("research_topic"): locale = tool_args.get("locale", locale)
locale = tool_args.get("locale") if not enable_clarification and tool_args.get("research_topic"):
research_topic = tool_args.get("research_topic") research_topic = tool_args["research_topic"]
if enable_clarification:
logger.info(
"Using prepared clarified topic: %s",
clarified_topic or research_topic,
)
else:
logger.info(
"Using research topic for handoff: %s", research_topic
)
break break
except Exception as e: except Exception as e:
@@ -584,16 +558,24 @@ def coordinator_node(
clarification_rounds = 0 clarification_rounds = 0
clarification_history = [] clarification_history = []
clarified_research_topic_value = clarified_topic or research_topic
if enable_clarification:
handoff_topic = clarified_topic or research_topic
else:
handoff_topic = research_topic
return Command( return Command(
update={ update={
"messages": messages, "messages": messages,
"locale": locale, "locale": locale,
"research_topic": research_topic, "research_topic": research_topic,
"clarified_research_topic": clarified_research_topic_value,
"resources": configurable.resources, "resources": configurable.resources,
"clarification_rounds": clarification_rounds, "clarification_rounds": clarification_rounds,
"clarification_history": clarification_history, "clarification_history": clarification_history,
"is_clarification_complete": goto != "coordinator", "is_clarification_complete": goto != "coordinator",
"clarified_question": research_topic if goto != "coordinator" else "", "clarified_question": handoff_topic if goto != "coordinator" else "",
"goto": goto, "goto": goto,
}, },
goto=goto, goto=goto,
@@ -747,6 +729,7 @@ async def _execute_agent_step(
) )
except Exception as e: except Exception as e:
import traceback import traceback
error_traceback = traceback.format_exc() error_traceback = traceback.format_exc()
error_message = f"Error executing {agent_name} agent for step '{current_step.title}': {str(e)}" error_message = f"Error executing {agent_name} agent for step '{current_step.title}': {str(e)}"
logger.exception(error_message) logger.exception(error_message)

View File

@@ -2,6 +2,8 @@
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
from dataclasses import field
from langgraph.graph import MessagesState from langgraph.graph import MessagesState
from src.prompts.planner_model import Plan from src.prompts.planner_model import Plan
@@ -14,6 +16,7 @@ class State(MessagesState):
# Runtime Variables # Runtime Variables
locale: str = "en-US" locale: str = "en-US"
research_topic: str = "" research_topic: str = ""
clarified_research_topic: str = ""
observations: list[str] = [] observations: list[str] = []
resources: list[Resource] = [] resources: list[Resource] = []
plan_iterations: int = 0 plan_iterations: int = 0
@@ -28,7 +31,7 @@ class State(MessagesState):
False # Enable/disable clarification feature (default: False) False # Enable/disable clarification feature (default: False)
) )
clarification_rounds: int = 0 clarification_rounds: int = 0
clarification_history: list[str] = [] clarification_history: list[str] = field(default_factory=list)
is_clarification_complete: bool = False is_clarification_complete: bool = False
clarified_question: str = "" clarified_question: str = ""
max_clarification_rounds: int = ( max_clarification_rounds: int = (

113
src/graph/utils.py Normal file
View File

@@ -0,0 +1,113 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from typing import Any
# Speaker names that graph nodes attach when they author messages
# (e.g. HumanMessage(..., name="coordinator")); anything carrying one
# of these names is never treated as end-user input.
ASSISTANT_SPEAKER_NAMES = {
    "coordinator",
    "planner",
    "researcher",
    "coder",
    "reporter",
    "background_investigator",
}


def get_message_content(message: Any) -> str:
    """Extract message content from a dict message or a LangChain message.

    Returns an empty string when the message carries no content field.
    """
    if isinstance(message, dict):
        return message.get("content", "")
    return getattr(message, "content", "")


def is_user_message(message: Any) -> bool:
    """Return True if the message originated from the end user.

    Heuristic covering both dict-shaped messages and LangChain message
    objects:

    - an explicit ``user``/``human`` role wins; ``assistant``/``system``
      loses; any other non-empty role (e.g. ``tool``) is not user input;
    - a ``name`` in :data:`ASSISTANT_SPEAKER_NAMES` marks the message as
      node-authored, since graph nodes emit HumanMessage objects tagged
      with their own name;
    - a dict with no role and no assistant name is assumed user-authored.
    """
    if isinstance(message, dict):
        role = (message.get("role") or "").lower()
        if role in {"user", "human"}:
            return True
        if role in {"assistant", "system"}:
            return False
        name = (message.get("name") or "").lower()
        if name in ASSISTANT_SPEAKER_NAMES:
            return False
        # name is guaranteed not to be an assistant speaker here, so only
        # the "no role at all" case counts as user input.
        return role == ""

    message_type = (getattr(message, "type", "") or "").lower()
    name = (getattr(message, "name", "") or "").lower()
    if message_type == "human":
        # HumanMessage objects authored by graph nodes carry a node name.
        return name not in ASSISTANT_SPEAKER_NAMES
    role_attr = getattr(message, "role", None)
    if isinstance(role_attr, str) and role_attr.lower() in {"user", "human"}:
        return True
    # Some providers stash the role in additional_kwargs; guard against a
    # non-dict value instead of assuming the LangChain shape (the previous
    # code raised AttributeError when additional_kwargs was not a dict).
    extra = getattr(message, "additional_kwargs", None)
    extra_role = extra.get("role") if isinstance(extra, dict) else None
    if isinstance(extra_role, str) and extra_role.lower() in {"user", "human"}:
        return True
    return False
def get_latest_user_message(messages: list[Any]) -> tuple[Any, str]:
    """Return the most recent user-authored message and its content.

    Scans the conversation from newest to oldest, skipping user messages
    whose content is empty. Returns ``(None, "")`` when nothing matches.
    """
    for candidate in reversed(messages or []):
        if not is_user_message(candidate):
            continue
        text = get_message_content(candidate)
        if text:
            return candidate, text
    return None, ""
def build_clarified_topic_from_history(
    clarification_history: list[str],
) -> tuple[str, list[str]]:
    """Compose the clarified topic string from an ordered clarification history.

    Falsy entries are dropped. The first surviving entry is the original
    topic; any remaining entries are appended as a comma-separated
    refinement list, yielding ``"<topic> - <r1>, <r2>"``.

    Returns:
        A ``(clarified_topic, cleaned_history)`` pair, or ``("", [])`` when
        the history holds no usable entries.
    """
    cleaned = [entry for entry in clarification_history if entry]
    if not cleaned:
        return "", []
    topic = cleaned[0]
    refinements = cleaned[1:]
    if refinements:
        topic = f"{topic} - {', '.join(refinements)}"
    return topic, cleaned
def reconstruct_clarification_history(
    messages: list[Any],
    fallback_history: list[str] | None = None,
    base_topic: str = "",
) -> list[str]:
    """Rebuild clarification history from user-authored messages, with fallback.

    Args:
        messages: Conversation messages in chronological order.
        fallback_history: Optional existing history to use if no user messages found.
        base_topic: Optional topic to use when no user messages are available.

    Returns:
        A cleaned clarification history containing unique consecutive user contents.
    """
    history: list[str] = []
    for msg in messages or []:
        if not is_user_message(msg):
            continue
        text = get_message_content(msg)
        # Skip empty payloads and immediate repeats of the same answer.
        if text and (not history or history[-1] != text):
            history.append(text)
    if history:
        return history

    # No user messages found: fall back to the non-empty entries the
    # prior state carried, then to the bare topic.
    cleaned_fallback = [entry for entry in (fallback_history or []) if entry]
    if cleaned_fallback:
        return cleaned_fallback
    topic = (base_topic or "").strip()
    return [topic] if topic else []

View File

@@ -25,6 +25,10 @@ from src.config.report_style import ReportStyle
from src.config.tools import SELECTED_RAG_PROVIDER from src.config.tools import SELECTED_RAG_PROVIDER
from src.graph.builder import build_graph_with_memory from src.graph.builder import build_graph_with_memory
from src.graph.checkpoint import chat_stream_message from src.graph.checkpoint import chat_stream_message
from src.graph.utils import (
build_clarified_topic_from_history,
reconstruct_clarification_history,
)
from src.llms.llm import get_configured_llm_models from src.llms.llm import get_configured_llm_models
from src.podcast.graph.builder import build_graph as build_podcast_graph from src.podcast.graph.builder import build_graph as build_podcast_graph
from src.ppt.graph.builder import build_graph as build_ppt_graph from src.ppt.graph.builder import build_graph as build_ppt_graph
@@ -309,6 +313,14 @@ async def _astream_workflow_generator(
if isinstance(message, dict) and "content" in message: if isinstance(message, dict) and "content" in message:
_process_initial_messages(message, thread_id) _process_initial_messages(message, thread_id)
clarification_history = reconstruct_clarification_history(messages)
clarified_topic, clarification_history = build_clarified_topic_from_history(
clarification_history
)
latest_message_content = messages[-1]["content"] if messages else ""
clarified_research_topic = clarified_topic or latest_message_content
# Prepare workflow input # Prepare workflow input
workflow_input = { workflow_input = {
"messages": messages, "messages": messages,
@@ -318,7 +330,9 @@ async def _astream_workflow_generator(
"observations": [], "observations": [],
"auto_accepted_plan": auto_accepted_plan, "auto_accepted_plan": auto_accepted_plan,
"enable_background_investigation": enable_background_investigation, "enable_background_investigation": enable_background_investigation,
"research_topic": messages[-1]["content"] if messages else "", "research_topic": latest_message_content,
"clarification_history": clarification_history,
"clarified_research_topic": clarified_research_topic,
"enable_clarification": enable_clarification, "enable_clarification": enable_clarification,
"max_clarification_rounds": max_clarification_rounds, "max_clarification_rounds": max_clarification_rounds,
} }

View File

@@ -5,6 +5,7 @@ import logging
from src.config.configuration import get_recursion_limit from src.config.configuration import get_recursion_limit
from src.graph import build_graph from src.graph import build_graph
from src.graph.utils import build_clarified_topic_from_history
# Configure logging # Configure logging
logging.basicConfig( logging.basicConfig(
@@ -65,6 +66,8 @@ async def run_agent_workflow_async(
"auto_accepted_plan": True, "auto_accepted_plan": True,
"enable_background_investigation": enable_background_investigation, "enable_background_investigation": enable_background_investigation,
} }
initial_state["research_topic"] = user_input
initial_state["clarified_research_topic"] = user_input
# Only set clarification parameter if explicitly provided # Only set clarification parameter if explicitly provided
# If None, State class default will be used (enable_clarification=False) # If None, State class default will be used (enable_clarification=False)
@@ -137,7 +140,18 @@ async def run_agent_workflow_async(
current_state["messages"] = final_state["messages"] + [ current_state["messages"] = final_state["messages"] + [
{"role": "user", "content": user_response} {"role": "user", "content": user_response}
] ]
# Recursive call for clarification continuation for key in (
"clarification_history",
"clarification_rounds",
"clarified_research_topic",
"research_topic",
"locale",
"enable_clarification",
"max_clarification_rounds",
):
if key in final_state:
current_state[key] = final_state[key]
return await run_agent_workflow_async( return await run_agent_workflow_async(
user_input=user_response, user_input=user_response,
max_plan_iterations=max_plan_iterations, max_plan_iterations=max_plan_iterations,

View File

@@ -451,7 +451,9 @@ def test_human_feedback_node_accepted(monkeypatch, mock_state_base, mock_config)
assert result.update["current_plan"]["has_enough_context"] is False assert result.update["current_plan"]["has_enough_context"] is False
def test_human_feedback_node_invalid_interrupt(monkeypatch, mock_state_base, mock_config): def test_human_feedback_node_invalid_interrupt(
monkeypatch, mock_state_base, mock_config
):
# interrupt returns something else, should raise TypeError # interrupt returns something else, should raise TypeError
state = dict(mock_state_base) state = dict(mock_state_base)
state["auto_accepted_plan"] = False state["auto_accepted_plan"] = False
@@ -490,7 +492,9 @@ def test_human_feedback_node_json_decode_error_second_iteration(
assert result.goto == "reporter" assert result.goto == "reporter"
def test_human_feedback_node_not_enough_context(monkeypatch, mock_state_base, mock_config): def test_human_feedback_node_not_enough_context(
monkeypatch, mock_state_base, mock_config
):
# Plan does not have enough context, should goto research_team # Plan does not have enough context, should goto research_team
plan = { plan = {
"has_enough_context": False, "has_enough_context": False,
@@ -1446,7 +1450,9 @@ def test_handoff_tools():
assert result is None # Tool should return None (no-op) assert result is None # Tool should return None (no-op)
# Test handoff_after_clarification tool - use invoke() method # Test handoff_after_clarification tool - use invoke() method
result = handoff_after_clarification.invoke({"locale": "en-US"}) result = handoff_after_clarification.invoke(
{"locale": "en-US", "research_topic": "renewable energy research"}
)
assert result is None # Tool should return None (no-op) assert result is None # Tool should return None (no-op)
@@ -1468,9 +1474,13 @@ def test_coordinator_tools_with_clarification_enabled(mock_get_llm):
"clarification_rounds": 2, "clarification_rounds": 2,
"max_clarification_rounds": 3, "max_clarification_rounds": 3,
"is_clarification_complete": False, "is_clarification_complete": False,
"clarification_history": ["response 1", "response 2"], "clarification_history": [
"Tell me about something",
"response 1",
"response 2",
],
"locale": "en-US", "locale": "en-US",
"research_topic": "", "research_topic": "Tell me about something",
} }
# Mock config # Mock config
@@ -1567,3 +1577,289 @@ def test_coordinator_empty_llm_response_corner_case(mock_get_llm):
# Should gracefully handle empty response by going to planner to ensure workflow continues # Should gracefully handle empty response by going to planner to ensure workflow continues
assert result.goto == "planner" assert result.goto == "planner"
assert result.update["locale"] == "en-US" assert result.update["locale"] == "en-US"
# ============================================================================
# Clarification flow tests
# ============================================================================
def test_clarification_handoff_combines_history():
    """Coordinator should merge original topic with all clarification answers before handoff."""
    from langchain_core.messages import AIMessage
    from langchain_core.runnables import RunnableConfig

    # Full clarification transcript: the original topic plus two user
    # answers, mirrored in both the raw message log and the
    # clarification_history state field.
    test_state = {
        "messages": [
            {"role": "user", "content": "Research artificial intelligence"},
            {"role": "assistant", "content": "Which area of AI should we focus on?"},
            {"role": "user", "content": "Machine learning applications"},
            {"role": "assistant", "content": "What dimension of that should we cover?"},
            {"role": "user", "content": "Technical implementation details"},
        ],
        "enable_clarification": True,
        "clarification_rounds": 2,
        "clarification_history": [
            "Research artificial intelligence",
            "Machine learning applications",
            "Technical implementation details",
        ],
        "max_clarification_rounds": 3,
        "research_topic": "Research artificial intelligence",
        "clarified_research_topic": "Research artificial intelligence - Machine learning applications, Technical implementation details",
        "locale": "en-US",
    }
    config = RunnableConfig(configurable={"thread_id": "clarification-test"})
    # LLM answers with a handoff_after_clarification tool call; the topic
    # argument is a placeholder because the node must use its own compiled
    # clarified topic, not the tool args.
    mock_response = AIMessage(
        content="Understood, handing off now.",
        tool_calls=[
            {
                "name": "handoff_after_clarification",
                "args": {"locale": "en-US", "research_topic": "placeholder"},
                "id": "tool-call-handoff",
                "type": "tool_call",
            }
        ],
    )

    with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm:
        mock_llm = MagicMock()
        mock_llm.bind_tools.return_value.invoke.return_value = mock_response
        mock_get_llm.return_value = mock_llm
        result = coordinator_node(test_state, config)

    assert hasattr(result, "update")
    update = result.update
    # History must be preserved verbatim.
    assert update["clarification_history"] == [
        "Research artificial intelligence",
        "Machine learning applications",
        "Technical implementation details",
    ]
    # Clarified topic is "<original> - <answer1>, <answer2>".
    expected_topic = (
        "Research artificial intelligence - "
        "Machine learning applications, Technical implementation details"
    )
    assert update["research_topic"] == "Research artificial intelligence"
    assert update["clarified_research_topic"] == expected_topic
    assert update["clarified_question"] == expected_topic
def test_clarification_history_reconstructed_from_messages():
    """Coordinator should rebuild clarification history from full message log when state is incomplete."""
    from langchain_core.messages import AIMessage
    from langchain_core.runnables import RunnableConfig

    # State intentionally carries only the *last* user answer so the node has to
    # recover the earlier clarification turns from the message log itself.
    incomplete_state = {
        "messages": [
            {"role": "user", "content": "Research on renewable energy"},
            {
                "role": "assistant",
                "content": "Which type of renewable energy interests you?",
            },
            {"role": "user", "content": "Solar and wind energy"},
            {"role": "assistant", "content": "Which aspect should we focus on?"},
            {"role": "user", "content": "Technical implementation"},
        ],
        "enable_clarification": True,
        "clarification_rounds": 2,
        "clarification_history": ["Technical implementation"],
        "max_clarification_rounds": 3,
        "research_topic": "Research on renewable energy",
        "clarified_research_topic": "Research on renewable energy",
        "locale": "en-US",
    }
    config = RunnableConfig(configurable={"thread_id": "clarification-history-rebuild"})
    # LLM reply carries the handoff tool call, so the node should finish
    # clarification and compile the final topic.
    mock_response = AIMessage(
        content="Understood, handing over now.",
        tool_calls=[
            {
                "name": "handoff_after_clarification",
                "args": {"locale": "en-US", "research_topic": "placeholder"},
                "id": "tool-call-handoff",
                "type": "tool_call",
            }
        ],
    )
    with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm:
        mock_llm = MagicMock()
        mock_llm.bind_tools.return_value.invoke.return_value = mock_response
        mock_get_llm.return_value = mock_llm
        result = coordinator_node(incomplete_state, config)
    # Guard before attribute access, consistent with the sibling clarification tests.
    assert hasattr(result, "update")
    update = result.update
    assert update["clarification_history"] == [
        "Research on renewable energy",
        "Solar and wind energy",
        "Technical implementation",
    ]
    assert update["research_topic"] == "Research on renewable energy"
    assert (
        update["clarified_research_topic"]
        == "Research on renewable energy - Solar and wind energy, Technical implementation"
    )
def test_clarification_max_rounds_without_tool_call():
    """Once the round budget is exhausted the coordinator must route to the
    planner with the compiled topic, even when the LLM reply has no tool call."""
    from langchain_core.messages import AIMessage
    from langchain_core.runnables import RunnableConfig

    state = {
        "messages": [
            {"role": "user", "content": "Research artificial intelligence"},
            {"role": "assistant", "content": "Which area should we focus on?"},
            {"role": "user", "content": "Natural language processing"},
            {"role": "assistant", "content": "Which domain matters most?"},
            {"role": "user", "content": "Healthcare"},
            {"role": "assistant", "content": "Any specific scenario to study?"},
            {"role": "user", "content": "Clinical documentation"},
        ],
        "enable_clarification": True,
        "clarification_rounds": 3,
        "clarification_history": [
            "Research artificial intelligence",
            "Natural language processing",
            "Healthcare",
            "Clinical documentation",
        ],
        "max_clarification_rounds": 3,
        "research_topic": "Research artificial intelligence",
        "clarified_research_topic": (
            "Research artificial intelligence - "
            "Natural language processing, Healthcare, Clinical documentation"
        ),
        "locale": "en-US",
    }
    run_config = RunnableConfig(configurable={"thread_id": "clarification-max"})
    # Deliberately no tool calls: the node must still stop asking and hand off.
    llm_reply = AIMessage(
        content="Got it, sending this to the planner.",
        tool_calls=[],
    )

    with patch("src.graph.nodes.get_llm_by_type") as get_llm:
        fake_llm = MagicMock()
        fake_llm.bind_tools.return_value.invoke.return_value = llm_reply
        get_llm.return_value = fake_llm
        command = coordinator_node(state, run_config)

    assert hasattr(command, "update")
    compiled_topic = (
        "Research artificial intelligence - "
        "Natural language processing, Healthcare, Clinical documentation"
    )
    assert command.update["research_topic"] == "Research artificial intelligence"
    assert command.update["clarified_research_topic"] == compiled_topic
    assert command.goto == "planner"
def test_clarification_human_message_support():
    """HumanMessage objects without a name are user turns; those tagged with
    name="coordinator" are the assistant's clarification questions."""
    from langchain_core.messages import AIMessage, HumanMessage
    from langchain_core.runnables import RunnableConfig

    state = {
        "messages": [
            HumanMessage(content="Research artificial intelligence"),
            HumanMessage(content="Which area should we focus on?", name="coordinator"),
            HumanMessage(content="Machine learning"),
            HumanMessage(
                content="Which dimension should we explore?", name="coordinator"
            ),
            HumanMessage(content="Technical feasibility"),
        ],
        "enable_clarification": True,
        "clarification_rounds": 2,
        "clarification_history": [
            "Research artificial intelligence",
            "Machine learning",
            "Technical feasibility",
        ],
        "max_clarification_rounds": 3,
        "research_topic": "Research artificial intelligence",
        "clarified_research_topic": (
            "Research artificial intelligence - Machine learning, Technical feasibility"
        ),
        "locale": "en-US",
    }
    run_config = RunnableConfig(configurable={"thread_id": "clarification-human"})
    llm_reply = AIMessage(
        content="Moving to planner.",
        tool_calls=[
            {
                "name": "handoff_after_clarification",
                "args": {"locale": "en-US", "research_topic": "placeholder"},
                "id": "human-message-handoff",
                "type": "tool_call",
            }
        ],
    )

    with patch("src.graph.nodes.get_llm_by_type") as get_llm:
        fake_llm = MagicMock()
        fake_llm.bind_tools.return_value.invoke.return_value = llm_reply
        get_llm.return_value = fake_llm
        command = coordinator_node(state, run_config)

    assert hasattr(command, "update")
    compiled_topic = (
        "Research artificial intelligence - Machine learning, Technical feasibility"
    )
    assert command.update["clarification_history"] == [
        "Research artificial intelligence",
        "Machine learning",
        "Technical feasibility",
    ]
    assert command.update["research_topic"] == "Research artificial intelligence"
    assert command.update["clarified_research_topic"] == compiled_topic
def test_clarification_no_history_defaults_to_topic():
    """When zero clarification rounds have happened, the original question
    should pass through to the planner unchanged."""
    from langchain_core.messages import AIMessage
    from langchain_core.runnables import RunnableConfig

    question = "What is quantum computing?"
    state = {
        "messages": [{"role": "user", "content": question}],
        "enable_clarification": True,
        "clarification_rounds": 0,
        "clarification_history": [question],
        "max_clarification_rounds": 3,
        "research_topic": question,
        "clarified_research_topic": question,
        "locale": "en-US",
    }
    run_config = RunnableConfig(configurable={"thread_id": "clarification-none"})
    # Direct handoff tool call: no clarification dialogue ever started.
    llm_reply = AIMessage(
        content="Understood.",
        tool_calls=[
            {
                "name": "handoff_to_planner",
                "args": {"locale": "en-US", "research_topic": "placeholder"},
                "id": "clarification-none",
                "type": "tool_call",
            }
        ],
    )

    with patch("src.graph.nodes.get_llm_by_type") as get_llm:
        fake_llm = MagicMock()
        fake_llm.bind_tools.return_value.invoke.return_value = llm_reply
        get_llm.return_value = fake_llm
        command = coordinator_node(state, run_config)

    assert hasattr(command, "update")
    assert command.update["research_topic"] == question
    assert command.update["clarified_research_topic"] == question

        assert result == expected
@pytest.mark.asyncio
async def test_astream_workflow_generator_preserves_clarification_history():
    """Multi-turn clarification answers should be compiled into the workflow input."""
    conversation = [
        {"role": "user", "content": "Research on renewable energy"},
        {
            "role": "assistant",
            "content": "What type of renewable energy would you like to know about?",
        },
        {"role": "user", "content": "Solar and wind energy"},
        {
            "role": "assistant",
            "content": "Please tell me the research dimensions you focus on, such as technological development or market applications.",
        },
        {"role": "user", "content": "Technological development"},
        {
            "role": "assistant",
            "content": "Please specify the time range you want to focus on, such as current status or future trends.",
        },
        {"role": "user", "content": "Current status and future trends"},
    ]
    captured = {}

    def record_and_stream_nothing(*args, **kwargs):
        # _stream_graph_events(graph, input, config, ...) — grab the workflow
        # input/config positionally so they can be inspected afterwards.
        captured["workflow_input"] = args[1]
        captured["workflow_config"] = args[2]

        class _EmptyAsyncIterator:
            def __aiter__(self):
                return self

            async def __anext__(self):
                raise StopAsyncIteration

        return _EmptyAsyncIterator()

    with (
        patch("src.server.app._process_initial_messages"),
        patch(
            "src.server.app._stream_graph_events",
            side_effect=record_and_stream_nothing,
        ),
    ):
        stream = _astream_workflow_generator(
            messages=conversation,
            thread_id="clarification-thread",
            resources=[],
            max_plan_iterations=1,
            max_step_num=1,
            max_search_results=5,
            auto_accepted_plan=True,
            interrupt_feedback="",
            mcp_settings={},
            enable_background_investigation=True,
            report_style=ReportStyle.ACADEMIC,
            enable_deep_thinking=False,
            enable_clarification=True,
            max_clarification_rounds=3,
        )
        # The inner event stream is empty, so the first pull must terminate.
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()

    workflow_input = captured["workflow_input"]
    assert workflow_input["clarification_history"] == [
        "Research on renewable energy",
        "Solar and wind energy",
        "Technological development",
        "Current status and future trends",
    ]
    assert (
        workflow_input["clarified_research_topic"]
        == "Research on renewable energy - Solar and wind energy, Technological development, Current status and future trends"
    )
class TestTTSEndpoint:
    @patch.dict(
        os.environ,