fix: Refine clarification workflow state handling (#641)

* fix: support local models by making thought field optional in Plan model

- Make thought field optional in Plan model to fix Pydantic validation errors with local models
- Add Ollama configuration example to conf.yaml.example
- Update documentation to include local model support
- Improve planner prompt with better JSON format requirements

Fixes local model integration issues where models like qwen3:14b would fail
due to missing thought field in JSON output.

* feat: Add intelligent clarification feature for research queries

- Add multi-turn clarification process to refine vague research questions
- Implement three-dimension clarification standard (Tech/App, Focus, Scope)
- Add clarification state management in coordinator node
- Update coordinator prompt with detailed clarification guidelines
- Add UI settings to enable/disable clarification feature (disabled by default)
- Update workflow to handle clarification rounds recursively
- Add comprehensive test coverage for clarification functionality
- Update documentation with clarification feature usage guide

Key components:
- src/graph/nodes.py: Core clarification logic and state management
- src/prompts/coordinator.md: Detailed clarification guidelines
- src/workflow.py: Recursive clarification handling
- web/: UI settings integration
- tests/: Comprehensive test coverage
- docs/: Updated configuration guide

* fix: Improve clarification conversation continuity

- Add comprehensive conversation history to clarification context
- Include previous exchanges summary in system messages
- Add explicit guidelines for continuing rounds in coordinator prompt
- Prevent LLM from starting new topics during clarification
- Ensure topic continuity across clarification rounds

Fixes issue where LLM would restart clarification instead of building upon previous exchanges.

* fix: Add conversation history to clarification context

* fix: resolve clarification feature message-to-planner, prompt, and test issues

- Optimize coordinator.md prompt template for better clarification flow
- Simplify final message sent to planner after clarification
- Fix API key assertion issues in test_search.py

* fix: Add configurable max_clarification_rounds and comprehensive tests

- Add max_clarification_rounds parameter for external configuration
- Add comprehensive test cases for clarification feature in test_app.py
- Fixes issues found during interactive mode testing where:
  - Recursive call failed due to missing initial_state parameter
  - Clarification exited prematurely at max rounds
  - Incorrect logging of max rounds reached

* Move clarification tests to test_nodes.py and add max_clarification_rounds to zh.json

* fix: add max_clarification_rounds parameter passing from frontend to backend

- Add max_clarification_rounds parameter in store.ts sendMessage function
- Add max_clarification_rounds type definition in chat.ts
- Ensure frontend settings page clarification rounds are correctly passed to backend

* fix: refine clarification workflow state handling and coverage

- Add clarification history reconstruction
- Fix clarified topic accumulation
- Add clarified_research_topic state field
- Preserve clarification state in recursive calls
- Add comprehensive test coverage

* refactor: optimize coordinator logic and type annotations

- Simplify handoff topic logic in coordinator_node
- Update type annotations from Tuple to tuple
- Improve code readability and maintainability

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
jimmyuconn1982
2025-10-22 22:49:07 +08:00
committed by GitHub
parent 9371ad23ee
commit 003f081a7b
9 changed files with 615 additions and 117 deletions

View File

@@ -51,7 +51,9 @@ class Configuration:
mcp_settings: dict = None # MCP settings, including dynamic loaded tools
report_style: str = ReportStyle.ACADEMIC.value # Report style
enable_deep_thinking: bool = False # Whether to enable deep thinking
enforce_web_search: bool = False # Enforce at least one web search step in every plan
enforce_web_search: bool = (
False # Enforce at least one web search step in every plan
)
@classmethod
def from_runnable_config(

View File

@@ -31,6 +31,12 @@ from src.utils.json_utils import repair_json_output
from ..config import SELECTED_SEARCH_ENGINE, SearchEngine
from .types import State
from .utils import (
build_clarified_topic_from_history,
get_message_content,
is_user_message,
reconstruct_clarification_history,
)
logger = logging.getLogger(__name__)
@@ -49,6 +55,9 @@ def handoff_to_planner(
@tool
def handoff_after_clarification(
    locale: Annotated[str, "The user's detected language locale (e.g., en-US, zh-CN)."],
    research_topic: Annotated[
        str, "The clarified research topic based on all clarification rounds."
    ],
):
    """Handoff to planner after clarification rounds are complete. Pass all clarification history to planner for analysis."""
    # Intentional no-op: this tool exists only so the LLM can signal the
    # handoff. The coordinator inspects the tool *call* (name and args) on
    # the model response; the body's return value is never used.
    return
@@ -78,23 +87,23 @@ def needs_clarification(state: dict) -> bool:
def validate_and_fix_plan(plan: dict, enforce_web_search: bool = False) -> dict:
"""
Validate and fix a plan to ensure it meets requirements.
Args:
plan: The plan dict to validate
enforce_web_search: If True, ensure at least one step has need_search=true
Returns:
The validated/fixed plan dict
"""
if not isinstance(plan, dict):
return plan
steps = plan.get("steps", [])
if enforce_web_search:
# Check if any step has need_search=true
has_search_step = any(step.get("need_search", False) for step in steps)
if not has_search_step and steps:
# Ensure first research step has web search enabled
for idx, step in enumerate(steps):
@@ -107,7 +116,9 @@ def validate_and_fix_plan(plan: dict, enforce_web_search: bool = False) -> dict:
# This ensures that at least one step will perform a web search as required.
steps[0]["step_type"] = "research"
steps[0]["need_search"] = True
logger.info("Converted first step to research with web search enforcement")
logger.info(
"Converted first step to research with web search enforcement"
)
elif not has_search_step and not steps:
# Add a default research step if no steps exist
logger.warning("Plan has no steps. Adding default research step.")
@@ -119,14 +130,14 @@ def validate_and_fix_plan(plan: dict, enforce_web_search: bool = False) -> dict:
"step_type": "research",
}
]
return plan
def background_investigation_node(state: State, config: RunnableConfig):
logger.info("background investigation node is running.")
configurable = Configuration.from_runnable_config(config)
query = state.get("research_topic")
query = state.get("clarified_research_topic") or state.get("research_topic")
background_investigation_results = None
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
searched_content = LoggedTavilySearch(
@@ -230,11 +241,11 @@ def planner_node(
return Command(goto="reporter")
else:
return Command(goto="__end__")
# Validate and fix plan to ensure web search requirements are met
if isinstance(curr_plan, dict):
curr_plan = validate_and_fix_plan(curr_plan, configurable.enforce_web_search)
if isinstance(curr_plan, dict) and curr_plan.get("has_enough_context"):
logger.info("Planner response has enough context.")
new_plan = Plan.model_validate(curr_plan)
@@ -316,7 +327,8 @@ def coordinator_node(
# Check if clarification is enabled
enable_clarification = state.get("enable_clarification", False)
initial_topic = state.get("research_topic", "")
clarified_topic = initial_topic
# ============================================================
# BRANCH 1: Clarification DISABLED (Legacy Mode)
# ============================================================
@@ -338,7 +350,6 @@ def coordinator_node(
.invoke(messages)
)
# Process response - should directly handoff to planner
goto = "__end__"
locale = state.get("locale", "en-US")
research_topic = state.get("research_topic", "")
@@ -370,12 +381,28 @@ def coordinator_node(
else:
# Load clarification state
clarification_rounds = state.get("clarification_rounds", 0)
clarification_history = state.get("clarification_history", [])
clarification_history = list(state.get("clarification_history", []) or [])
clarification_history = [item for item in clarification_history if item]
max_clarification_rounds = state.get("max_clarification_rounds", 3)
# Prepare the messages for the coordinator
state_messages = list(state.get("messages", []))
messages = apply_prompt_template("coordinator", state)
clarification_history = reconstruct_clarification_history(
state_messages, clarification_history, initial_topic
)
clarified_topic, clarification_history = build_clarified_topic_from_history(
clarification_history
)
logger.debug("Clarification history rebuilt: %s", clarification_history)
if clarification_history:
initial_topic = clarification_history[0]
latest_user_content = clarification_history[-1]
else:
latest_user_content = ""
# Add clarification status for first round
if clarification_rounds == 0:
messages.append(
@@ -385,91 +412,21 @@ def coordinator_node(
}
)
# Add clarification context if continuing conversation (round > 0)
elif clarification_rounds > 0:
logger.info(
f"Clarification enabled (rounds: {clarification_rounds}/{max_clarification_rounds}): Continuing conversation"
)
logger.info(
"Clarification round %s/%s | topic: %s | latest user content: %s",
clarification_rounds,
max_clarification_rounds,
clarified_topic or initial_topic,
latest_user_content or "N/A",
)
# Add user's response to clarification history (only user messages)
last_message = None
if state.get("messages"):
last_message = state["messages"][-1]
# Extract content from last message for logging
if isinstance(last_message, dict):
content = last_message.get("content", "No content")
else:
content = getattr(last_message, "content", "No content")
logger.info(f"Last message content: {content}")
# Handle dict format
if isinstance(last_message, dict):
if last_message.get("role") == "user":
clarification_history.append(last_message["content"])
logger.info(
f"Added user response to clarification history: {last_message['content']}"
)
# Handle object format (like HumanMessage)
elif hasattr(last_message, "role") and last_message.role == "user":
clarification_history.append(last_message.content)
logger.info(
f"Added user response to clarification history: {last_message.content}"
)
# Handle object format with content attribute (like the one in logs)
elif hasattr(last_message, "content"):
clarification_history.append(last_message.content)
logger.info(
f"Added user response to clarification history: {last_message.content}"
)
current_response = latest_user_content or "No response"
# Build comprehensive clarification context with conversation history
current_response = "No response"
if last_message:
# Handle dict format
if isinstance(last_message, dict):
if last_message.get("role") == "user":
current_response = last_message.get("content", "No response")
else:
# If last message is not from user, try to get the latest user message
messages = state.get("messages", [])
for msg in reversed(messages):
if isinstance(msg, dict) and msg.get("role") == "user":
current_response = msg.get("content", "No response")
break
# Handle object format (like HumanMessage)
elif hasattr(last_message, "role") and last_message.role == "user":
current_response = last_message.content
# Handle object format with content attribute (like the one in logs)
elif hasattr(last_message, "content"):
current_response = last_message.content
else:
# If last message is not from user, try to get the latest user message
messages = state.get("messages", [])
for msg in reversed(messages):
if isinstance(msg, dict) and msg.get("role") == "user":
current_response = msg.get("content", "No response")
break
elif hasattr(msg, "role") and msg.role == "user":
current_response = msg.content
break
elif hasattr(msg, "content"):
current_response = msg.content
break
# Create conversation history summary
conversation_summary = ""
if clarification_history:
conversation_summary = "Previous conversation:\n"
for i, response in enumerate(clarification_history, 1):
conversation_summary += f"- Round {i}: {response}\n"
clarification_context = f"""Continuing clarification (round {clarification_rounds}/{max_clarification_rounds}):
clarification_context = f"""Continuing clarification (round {clarification_rounds}/{max_clarification_rounds}):
User's latest response: {current_response}
Ask for remaining missing dimensions. Do NOT repeat questions or start new topics."""
# Log the clarification context for debugging
logger.info(f"Clarification context: {clarification_context}")
messages.append({"role": "system", "content": clarification_context})
messages.append({"role": "system", "content": clarification_context})
# Bind both clarification tools
tools = [handoff_to_planner, handoff_after_clarification]
@@ -483,7 +440,13 @@ def coordinator_node(
# Initialize response processing variables
goto = "__end__"
locale = state.get("locale", "en-US")
research_topic = state.get("research_topic", "")
research_topic = (
clarification_history[0]
if clarification_history
else state.get("research_topic", "")
)
if not clarified_topic:
clarified_topic = research_topic
# --- Process LLM response ---
# No tool calls - LLM is asking a clarifying question
@@ -497,20 +460,21 @@ def coordinator_node(
)
# Append coordinator's question to messages
state_messages = state.get("messages", [])
updated_messages = list(state_messages)
if response.content:
state_messages.append(
updated_messages.append(
HumanMessage(content=response.content, name="coordinator")
)
return Command(
update={
"messages": state_messages,
"messages": updated_messages,
"locale": locale,
"research_topic": research_topic,
"resources": configurable.resources,
"clarification_rounds": clarification_rounds,
"clarification_history": clarification_history,
"clarified_research_topic": clarified_topic,
"is_clarification_complete": False,
"clarified_question": "",
"goto": goto,
@@ -521,7 +485,7 @@ def coordinator_node(
else:
# Max rounds reached - no more questions allowed
logger.warning(
f"Max clarification rounds ({max_clarification_rounds}) reached. Handing off to planner."
f"Max clarification rounds ({max_clarification_rounds}) reached. Handing off to planner. Using prepared clarified topic: {clarified_topic}"
)
goto = "planner"
if state.get("enable_background_investigation"):
@@ -539,7 +503,7 @@ def coordinator_node(
# ============================================================
# Final: Build and return Command
# ============================================================
messages = state.get("messages", [])
messages = list(state.get("messages", []) or [])
if response.content:
messages.append(HumanMessage(content=response.content, name="coordinator"))
@@ -554,10 +518,20 @@ def coordinator_node(
logger.info("Handing off to planner")
goto = "planner"
# Extract locale and research_topic if provided
if tool_args.get("locale") and tool_args.get("research_topic"):
locale = tool_args.get("locale")
research_topic = tool_args.get("research_topic")
# Extract locale if provided
locale = tool_args.get("locale", locale)
if not enable_clarification and tool_args.get("research_topic"):
research_topic = tool_args["research_topic"]
if enable_clarification:
logger.info(
"Using prepared clarified topic: %s",
clarified_topic or research_topic,
)
else:
logger.info(
"Using research topic for handoff: %s", research_topic
)
break
except Exception as e:
@@ -584,16 +558,24 @@ def coordinator_node(
clarification_rounds = 0
clarification_history = []
clarified_research_topic_value = clarified_topic or research_topic
if enable_clarification:
handoff_topic = clarified_topic or research_topic
else:
handoff_topic = research_topic
return Command(
update={
"messages": messages,
"locale": locale,
"research_topic": research_topic,
"clarified_research_topic": clarified_research_topic_value,
"resources": configurable.resources,
"clarification_rounds": clarification_rounds,
"clarification_history": clarification_history,
"is_clarification_complete": goto != "coordinator",
"clarified_question": research_topic if goto != "coordinator" else "",
"clarified_question": handoff_topic if goto != "coordinator" else "",
"goto": goto,
},
goto=goto,
@@ -747,14 +729,15 @@ async def _execute_agent_step(
)
except Exception as e:
import traceback
error_traceback = traceback.format_exc()
error_message = f"Error executing {agent_name} agent for step '{current_step.title}': {str(e)}"
logger.exception(error_message)
logger.error(f"Full traceback:\n{error_traceback}")
detailed_error = f"[ERROR] {agent_name.capitalize()} Agent Error\n\nStep: {current_step.title}\n\nError Details:\n{str(e)}\n\nPlease check the logs for more information."
current_step.execution_res = detailed_error
return Command(
update={
"messages": [

View File

@@ -2,6 +2,8 @@
# SPDX-License-Identifier: MIT
from dataclasses import field
from langgraph.graph import MessagesState
from src.prompts.planner_model import Plan
@@ -14,6 +16,7 @@ class State(MessagesState):
# Runtime Variables
locale: str = "en-US"
research_topic: str = ""
clarified_research_topic: str = ""
observations: list[str] = []
resources: list[Resource] = []
plan_iterations: int = 0
@@ -28,7 +31,7 @@ class State(MessagesState):
False # Enable/disable clarification feature (default: False)
)
clarification_rounds: int = 0
clarification_history: list[str] = []
clarification_history: list[str] = field(default_factory=list)
is_clarification_complete: bool = False
clarified_question: str = ""
max_clarification_rounds: int = (

113
src/graph/utils.py Normal file
View File

@@ -0,0 +1,113 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from typing import Any
# Speaker names used by agents in this project. A message carrying one of
# these as its ``name`` was authored by an agent, not the end user — used by
# ``is_user_message`` to filter agent-tagged HumanMessage objects.
ASSISTANT_SPEAKER_NAMES = {
    "coordinator",
    "planner",
    "researcher",
    "coder",
    "reporter",
    "background_investigator",
}
def get_message_content(message: Any) -> str:
    """Return the ``content`` of *message*, or ``""`` when absent.

    Accepts both plain dict messages and LangChain message objects.
    """
    if not isinstance(message, dict):
        return getattr(message, "content", "")
    return message.get("content", "")
def is_user_message(message: Any) -> bool:
    """Return True if *message* originated from the end user.

    Supports dict-shaped messages (``role``/``name`` keys) and LangChain
    message objects (``type``/``name``/``role`` attributes, plus a possible
    ``role`` tucked into ``additional_kwargs``). A message whose speaker
    name matches a known agent is never treated as user-authored.
    """
    if isinstance(message, dict):
        role = (message.get("role") or "").lower()
        if role in {"user", "human"}:
            return True
        if role in {"assistant", "system"}:
            return False
        speaker = (message.get("name") or "").lower()
        if speaker and speaker in ASSISTANT_SPEAKER_NAMES:
            return False
        # Unknown non-empty roles (e.g. "tool") are not user-authored.
        return role == "" and speaker not in ASSISTANT_SPEAKER_NAMES

    # Object form: a HumanMessage may still be agent-authored — agents tag
    # themselves via the ``name`` attribute.
    kind = (getattr(message, "type", "") or "").lower()
    speaker = (getattr(message, "name", "") or "").lower()
    if kind == "human":
        return not (speaker and speaker in ASSISTANT_SPEAKER_NAMES)

    role_attr = getattr(message, "role", None)
    if isinstance(role_attr, str) and role_attr.lower() in {"user", "human"}:
        return True

    extra_role = getattr(message, "additional_kwargs", {}).get("role")
    return isinstance(extra_role, str) and extra_role.lower() in {"user", "human"}
def get_latest_user_message(messages: list[Any]) -> tuple[Any, str]:
    """Scan *messages* newest-first for a user-authored message with content.

    Returns the ``(message, content)`` pair of the most recent user message
    whose content is non-empty, or ``(None, "")`` when none exists.
    """
    for candidate in reversed(messages or []):
        if not is_user_message(candidate):
            continue
        text = get_message_content(candidate)
        if text:
            return candidate, text
    return None, ""
def build_clarified_topic_from_history(
    clarification_history: list[str],
) -> tuple[str, list[str]]:
    """Fold an ordered clarification history into a single topic string.

    The first entry is the base topic; later entries are refinements joined
    as ``"base - r1, r2, ..."``. Falsy entries are dropped. Returns the
    composed topic together with the cleaned history list.
    """
    cleaned = [entry for entry in clarification_history if entry]
    if not cleaned:
        return "", []
    base, refinements = cleaned[0], cleaned[1:]
    if not refinements:
        return base, cleaned
    return f"{base} - {', '.join(refinements)}", cleaned
def reconstruct_clarification_history(
messages: list[Any],
fallback_history: list[str] | None = None,
base_topic: str = "",
) -> list[str]:
"""Rebuild clarification history from user-authored messages, with fallback.
Args:
messages: Conversation messages in chronological order.
fallback_history: Optional existing history to use if no user messages found.
base_topic: Optional topic to use when no user messages are available.
Returns:
A cleaned clarification history containing unique consecutive user contents.
"""
sequence: list[str] = []
for message in messages or []:
if not is_user_message(message):
continue
content = get_message_content(message)
if not content:
continue
if sequence and sequence[-1] == content:
continue
sequence.append(content)
if sequence:
return sequence
fallback = [item for item in (fallback_history or []) if item]
if fallback:
return fallback
base_topic = (base_topic or "").strip()
return [base_topic] if base_topic else []

View File

@@ -25,6 +25,10 @@ from src.config.report_style import ReportStyle
from src.config.tools import SELECTED_RAG_PROVIDER
from src.graph.builder import build_graph_with_memory
from src.graph.checkpoint import chat_stream_message
from src.graph.utils import (
build_clarified_topic_from_history,
reconstruct_clarification_history,
)
from src.llms.llm import get_configured_llm_models
from src.podcast.graph.builder import build_graph as build_podcast_graph
from src.ppt.graph.builder import build_graph as build_ppt_graph
@@ -160,7 +164,7 @@ def _create_event_stream_message(
content = message_chunk.content
if not isinstance(content, str):
content = json.dumps(content, ensure_ascii=False)
event_stream_message = {
"thread_id": thread_id,
"agent": agent_name,
@@ -309,6 +313,14 @@ async def _astream_workflow_generator(
if isinstance(message, dict) and "content" in message:
_process_initial_messages(message, thread_id)
clarification_history = reconstruct_clarification_history(messages)
clarified_topic, clarification_history = build_clarified_topic_from_history(
clarification_history
)
latest_message_content = messages[-1]["content"] if messages else ""
clarified_research_topic = clarified_topic or latest_message_content
# Prepare workflow input
workflow_input = {
"messages": messages,
@@ -318,7 +330,9 @@ async def _astream_workflow_generator(
"observations": [],
"auto_accepted_plan": auto_accepted_plan,
"enable_background_investigation": enable_background_investigation,
"research_topic": messages[-1]["content"] if messages else "",
"research_topic": latest_message_content,
"clarification_history": clarification_history,
"clarified_research_topic": clarified_research_topic,
"enable_clarification": enable_clarification,
"max_clarification_rounds": max_clarification_rounds,
}

View File

@@ -208,7 +208,7 @@ class SearchResultPostProcessor:
url = image_url_val.get("url", "")
else:
url = image_url_val
if url and url not in seen_urls:
seen_urls.add(url)
return result.copy() # Return a copy to avoid modifying original

View File

@@ -5,6 +5,7 @@ import logging
from src.config.configuration import get_recursion_limit
from src.graph import build_graph
from src.graph.utils import build_clarified_topic_from_history
# Configure logging
logging.basicConfig(
@@ -65,6 +66,8 @@ async def run_agent_workflow_async(
"auto_accepted_plan": True,
"enable_background_investigation": enable_background_investigation,
}
initial_state["research_topic"] = user_input
initial_state["clarified_research_topic"] = user_input
# Only set clarification parameter if explicitly provided
# If None, State class default will be used (enable_clarification=False)
@@ -137,7 +140,18 @@ async def run_agent_workflow_async(
current_state["messages"] = final_state["messages"] + [
{"role": "user", "content": user_response}
]
# Recursive call for clarification continuation
for key in (
"clarification_history",
"clarification_rounds",
"clarified_research_topic",
"research_topic",
"locale",
"enable_clarification",
"max_clarification_rounds",
):
if key in final_state:
current_state[key] = final_state[key]
return await run_agent_workflow_async(
user_input=user_response,
max_plan_iterations=max_plan_iterations,

View File

@@ -451,7 +451,9 @@ def test_human_feedback_node_accepted(monkeypatch, mock_state_base, mock_config)
assert result.update["current_plan"]["has_enough_context"] is False
def test_human_feedback_node_invalid_interrupt(monkeypatch, mock_state_base, mock_config):
def test_human_feedback_node_invalid_interrupt(
monkeypatch, mock_state_base, mock_config
):
# interrupt returns something else, should raise TypeError
state = dict(mock_state_base)
state["auto_accepted_plan"] = False
@@ -490,7 +492,9 @@ def test_human_feedback_node_json_decode_error_second_iteration(
assert result.goto == "reporter"
def test_human_feedback_node_not_enough_context(monkeypatch, mock_state_base, mock_config):
def test_human_feedback_node_not_enough_context(
monkeypatch, mock_state_base, mock_config
):
# Plan does not have enough context, should goto research_team
plan = {
"has_enough_context": False,
@@ -1446,7 +1450,9 @@ def test_handoff_tools():
assert result is None # Tool should return None (no-op)
# Test handoff_after_clarification tool - use invoke() method
result = handoff_after_clarification.invoke({"locale": "en-US"})
result = handoff_after_clarification.invoke(
{"locale": "en-US", "research_topic": "renewable energy research"}
)
assert result is None # Tool should return None (no-op)
@@ -1468,9 +1474,13 @@ def test_coordinator_tools_with_clarification_enabled(mock_get_llm):
"clarification_rounds": 2,
"max_clarification_rounds": 3,
"is_clarification_complete": False,
"clarification_history": ["response 1", "response 2"],
"clarification_history": [
"Tell me about something",
"response 1",
"response 2",
],
"locale": "en-US",
"research_topic": "",
"research_topic": "Tell me about something",
}
# Mock config
@@ -1567,3 +1577,289 @@ def test_coordinator_empty_llm_response_corner_case(mock_get_llm):
# Should gracefully handle empty response by going to planner to ensure workflow continues
assert result.goto == "planner"
assert result.update["locale"] == "en-US"
# ============================================================================
# Clarification flow tests
# ============================================================================
def test_clarification_handoff_combines_history():
    """Coordinator should merge original topic with all clarification answers before handoff."""
    from langchain_core.messages import AIMessage
    from langchain_core.runnables import RunnableConfig

    # Arrange: a completed two-round clarification dialog whose history is
    # already present in state.
    test_state = {
        "messages": [
            {"role": "user", "content": "Research artificial intelligence"},
            {"role": "assistant", "content": "Which area of AI should we focus on?"},
            {"role": "user", "content": "Machine learning applications"},
            {"role": "assistant", "content": "What dimension of that should we cover?"},
            {"role": "user", "content": "Technical implementation details"},
        ],
        "enable_clarification": True,
        "clarification_rounds": 2,
        "clarification_history": [
            "Research artificial intelligence",
            "Machine learning applications",
            "Technical implementation details",
        ],
        "max_clarification_rounds": 3,
        "research_topic": "Research artificial intelligence",
        "clarified_research_topic": "Research artificial intelligence - Machine learning applications, Technical implementation details",
        "locale": "en-US",
    }
    config = RunnableConfig(configurable={"thread_id": "clarification-test"})
    # The LLM signals completion via the handoff tool; the "placeholder"
    # topic in args must be ignored in favor of the compiled history.
    mock_response = AIMessage(
        content="Understood, handing off now.",
        tool_calls=[
            {
                "name": "handoff_after_clarification",
                "args": {"locale": "en-US", "research_topic": "placeholder"},
                "id": "tool-call-handoff",
                "type": "tool_call",
            }
        ],
    )
    with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm:
        mock_llm = MagicMock()
        mock_llm.bind_tools.return_value.invoke.return_value = mock_response
        mock_get_llm.return_value = mock_llm
        result = coordinator_node(test_state, config)
    # Assert: history preserved, topic compiled as "base - r1, r2".
    assert hasattr(result, "update")
    update = result.update
    assert update["clarification_history"] == [
        "Research artificial intelligence",
        "Machine learning applications",
        "Technical implementation details",
    ]
    expected_topic = (
        "Research artificial intelligence - "
        "Machine learning applications, Technical implementation details"
    )
    assert update["research_topic"] == "Research artificial intelligence"
    assert update["clarified_research_topic"] == expected_topic
    assert update["clarified_question"] == expected_topic
def test_clarification_history_reconstructed_from_messages():
    """Coordinator should rebuild clarification history from full message log when state is incomplete."""
    from langchain_core.messages import AIMessage
    from langchain_core.runnables import RunnableConfig

    # Arrange: state holds only the latest answer; the full dialog lives in
    # the message log and must be recovered from the user messages.
    incomplete_state = {
        "messages": [
            {"role": "user", "content": "Research on renewable energy"},
            {
                "role": "assistant",
                "content": "Which type of renewable energy interests you?",
            },
            {"role": "user", "content": "Solar and wind energy"},
            {"role": "assistant", "content": "Which aspect should we focus on?"},
            {"role": "user", "content": "Technical implementation"},
        ],
        "enable_clarification": True,
        "clarification_rounds": 2,
        "clarification_history": ["Technical implementation"],
        "max_clarification_rounds": 3,
        "research_topic": "Research on renewable energy",
        "clarified_research_topic": "Research on renewable energy",
        "locale": "en-US",
    }
    config = RunnableConfig(configurable={"thread_id": "clarification-history-rebuild"})
    mock_response = AIMessage(
        content="Understood, handing over now.",
        tool_calls=[
            {
                "name": "handoff_after_clarification",
                "args": {"locale": "en-US", "research_topic": "placeholder"},
                "id": "tool-call-handoff",
                "type": "tool_call",
            }
        ],
    )
    with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm:
        mock_llm = MagicMock()
        mock_llm.bind_tools.return_value.invoke.return_value = mock_response
        mock_get_llm.return_value = mock_llm
        result = coordinator_node(incomplete_state, config)
    # Assert: all three user turns recovered and the topic recompiled.
    update = result.update
    assert update["clarification_history"] == [
        "Research on renewable energy",
        "Solar and wind energy",
        "Technical implementation",
    ]
    assert update["research_topic"] == "Research on renewable energy"
    assert (
        update["clarified_research_topic"]
        == "Research on renewable energy - Solar and wind energy, Technical implementation"
    )
def test_clarification_max_rounds_without_tool_call():
    """Coordinator should stop asking questions after max rounds and hand off with compiled topic."""
    from langchain_core.messages import AIMessage
    from langchain_core.runnables import RunnableConfig

    # Arrange: rounds already equal the maximum, and the mocked LLM makes NO
    # tool call — the node must force the handoff itself.
    test_state = {
        "messages": [
            {"role": "user", "content": "Research artificial intelligence"},
            {"role": "assistant", "content": "Which area should we focus on?"},
            {"role": "user", "content": "Natural language processing"},
            {"role": "assistant", "content": "Which domain matters most?"},
            {"role": "user", "content": "Healthcare"},
            {"role": "assistant", "content": "Any specific scenario to study?"},
            {"role": "user", "content": "Clinical documentation"},
        ],
        "enable_clarification": True,
        "clarification_rounds": 3,
        "clarification_history": [
            "Research artificial intelligence",
            "Natural language processing",
            "Healthcare",
            "Clinical documentation",
        ],
        "max_clarification_rounds": 3,
        "research_topic": "Research artificial intelligence",
        "clarified_research_topic": "Research artificial intelligence - Natural language processing, Healthcare, Clinical documentation",
        "locale": "en-US",
    }
    config = RunnableConfig(configurable={"thread_id": "clarification-max"})
    mock_response = AIMessage(
        content="Got it, sending this to the planner.",
        tool_calls=[],
    )
    with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm:
        mock_llm = MagicMock()
        mock_llm.bind_tools.return_value.invoke.return_value = mock_response
        mock_get_llm.return_value = mock_llm
        result = coordinator_node(test_state, config)
    # Assert: compiled topic retained and control routed to the planner.
    assert hasattr(result, "update")
    update = result.update
    expected_topic = (
        "Research artificial intelligence - "
        "Natural language processing, Healthcare, Clinical documentation"
    )
    assert update["research_topic"] == "Research artificial intelligence"
    assert update["clarified_research_topic"] == expected_topic
    assert result.goto == "planner"
def test_clarification_human_message_support():
    """Coordinator should treat HumanMessage instances from the user as user authored."""
    from langchain_core.messages import AIMessage, HumanMessage
    from langchain_core.runnables import RunnableConfig

    # Arrange: agent questions are HumanMessage objects tagged with
    # name="coordinator"; only the untagged ones are genuine user input.
    test_state = {
        "messages": [
            HumanMessage(content="Research artificial intelligence"),
            HumanMessage(content="Which area should we focus on?", name="coordinator"),
            HumanMessage(content="Machine learning"),
            HumanMessage(
                content="Which dimension should we explore?", name="coordinator"
            ),
            HumanMessage(content="Technical feasibility"),
        ],
        "enable_clarification": True,
        "clarification_rounds": 2,
        "clarification_history": [
            "Research artificial intelligence",
            "Machine learning",
            "Technical feasibility",
        ],
        "max_clarification_rounds": 3,
        "research_topic": "Research artificial intelligence",
        "clarified_research_topic": "Research artificial intelligence - Machine learning, Technical feasibility",
        "locale": "en-US",
    }
    config = RunnableConfig(configurable={"thread_id": "clarification-human"})
    mock_response = AIMessage(
        content="Moving to planner.",
        tool_calls=[
            {
                "name": "handoff_after_clarification",
                "args": {"locale": "en-US", "research_topic": "placeholder"},
                "id": "human-message-handoff",
                "type": "tool_call",
            }
        ],
    )
    with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm:
        mock_llm = MagicMock()
        mock_llm.bind_tools.return_value.invoke.return_value = mock_response
        mock_get_llm.return_value = mock_llm
        result = coordinator_node(test_state, config)
    # Assert: coordinator-tagged messages were excluded from the history.
    assert hasattr(result, "update")
    update = result.update
    expected_topic = (
        "Research artificial intelligence - Machine learning, Technical feasibility"
    )
    assert update["clarification_history"] == [
        "Research artificial intelligence",
        "Machine learning",
        "Technical feasibility",
    ]
    assert update["research_topic"] == "Research artificial intelligence"
    assert update["clarified_research_topic"] == expected_topic
def test_clarification_no_history_defaults_to_topic():
    """When no clarification round ever ran, the original topic passes through unchanged."""
    from langchain_core.messages import AIMessage
    from langchain_core.runnables import RunnableConfig

    original_topic = "What is quantum computing?"
    state = {
        "messages": [{"role": "user", "content": "What is quantum computing?"}],
        "enable_clarification": True,
        "clarification_rounds": 0,
        "clarification_history": ["What is quantum computing?"],
        "max_clarification_rounds": 3,
        "research_topic": "What is quantum computing?",
        "clarified_research_topic": "What is quantum computing?",
        "locale": "en-US",
    }
    run_config = RunnableConfig(configurable={"thread_id": "clarification-none"})

    # Direct handoff tool call — no clarification question asked.
    ack_reply = AIMessage(
        content="Understood.",
        tool_calls=[
            {
                "name": "handoff_to_planner",
                "args": {"locale": "en-US", "research_topic": "placeholder"},
                "id": "clarification-none",
                "type": "tool_call",
            }
        ],
    )

    with patch("src.graph.nodes.get_llm_by_type") as get_llm_mock:
        stub_llm = MagicMock()
        stub_llm.bind_tools.return_value.invoke.return_value = ack_reply
        get_llm_mock.return_value = stub_llm
        outcome = coordinator_node(state, run_config)

    assert hasattr(outcome, "update")
    assert outcome.update["research_topic"] == original_topic
    assert outcome.update["clarified_research_topic"] == original_topic

# NOTE(review): removed scrape/diff artifacts that were pasted into this file
# ("View File" and the hunk header "@@ -47,6 +47,79 @@ class TestMakeEvent:");
# the stray assertion below belongs to a TestMakeEvent test not shown in full.
assert result == expected
@pytest.mark.asyncio
async def test_astream_workflow_generator_preserves_clarification_history():
    """The workflow input built from a multi-turn chat keeps every user answer.

    All user replies must land in ``clarification_history`` and be compiled
    into ``clarified_research_topic`` before the graph is streamed.
    """
    messages = [
        {"role": "user", "content": "Research on renewable energy"},
        {
            "role": "assistant",
            "content": "What type of renewable energy would you like to know about?",
        },
        {"role": "user", "content": "Solar and wind energy"},
        {
            "role": "assistant",
            "content": "Please tell me the research dimensions you focus on, such as technological development or market applications.",
        },
        {"role": "user", "content": "Technological development"},
        {
            "role": "assistant",
            "content": "Please specify the time range you want to focus on, such as current status or future trends.",
        },
        {"role": "user", "content": "Current status and future trends"},
    ]
    captured_data = {}

    def capture_then_finish(*args, **kwargs):
        # Record the positional workflow input/config, then hand back an
        # async iterator that terminates immediately so nothing is streamed.
        captured_data["workflow_input"] = args[1]
        captured_data["workflow_config"] = args[2]

        async def _drained():
            return
            yield  # unreached; marks _drained as an (empty) async generator

        return _drained()

    with (
        patch("src.server.app._process_initial_messages"),
        patch("src.server.app._stream_graph_events", side_effect=capture_then_finish),
    ):
        generator = _astream_workflow_generator(
            messages=messages,
            thread_id="clarification-thread",
            resources=[],
            max_plan_iterations=1,
            max_step_num=1,
            max_search_results=5,
            auto_accepted_plan=True,
            interrupt_feedback="",
            mcp_settings={},
            enable_background_investigation=True,
            report_style=ReportStyle.ACADEMIC,
            enable_deep_thinking=False,
            enable_clarification=True,
            max_clarification_rounds=3,
        )
        # The patched stream yields nothing, so the first advance exhausts it.
        with pytest.raises(StopAsyncIteration):
            await generator.__anext__()

    workflow_input = captured_data["workflow_input"]
    assert workflow_input["clarification_history"] == [
        "Research on renewable energy",
        "Solar and wind energy",
        "Technological development",
        "Current status and future trends",
    ]
    assert (
        workflow_input["clarified_research_topic"]
        == "Research on renewable energy - Solar and wind energy, Technological development, Current status and future trends"
    )
class TestTTSEndpoint:
@patch.dict(
os.environ,