mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-28 00:04:47 +08:00
feat: Add intelligent clarification feature in coordinator step for research queries (#613)
* fix: support local models by making thought field optional in Plan model

  - Make thought field optional in Plan model to fix Pydantic validation errors with local models
  - Add Ollama configuration example to conf.yaml.example
  - Update documentation to include local model support
  - Improve planner prompt with better JSON format requirements

  Fixes local model integration issues where models like qwen3:14b would fail due to missing thought field in JSON output.

* feat: Add intelligent clarification feature for research queries

  - Add multi-turn clarification process to refine vague research questions
  - Implement three-dimension clarification standard (Tech/App, Focus, Scope)
  - Add clarification state management in coordinator node
  - Update coordinator prompt with detailed clarification guidelines
  - Add UI settings to enable/disable clarification feature (disabled by default)
  - Update workflow to handle clarification rounds recursively
  - Add comprehensive test coverage for clarification functionality
  - Update documentation with clarification feature usage guide

  Key components:
  - src/graph/nodes.py: Core clarification logic and state management
  - src/prompts/coordinator.md: Detailed clarification guidelines
  - src/workflow.py: Recursive clarification handling
  - web/: UI settings integration
  - tests/: Comprehensive test coverage
  - docs/: Updated configuration guide

* fix: Improve clarification conversation continuity

  - Add comprehensive conversation history to clarification context
  - Include previous exchanges summary in system messages
  - Add explicit guidelines for continuing rounds in coordinator prompt
  - Prevent LLM from starting new topics during clarification
  - Ensure topic continuity across clarification rounds

  Fixes issue where LLM would restart clarification instead of building upon previous exchanges.

* fix: Add conversation history to clarification context

* fix: resolve clarification feature message to planner, prompt, test issues

  - Optimize coordinator.md prompt template for better clarification flow
  - Simplify final message sent to planner after clarification
  - Fix API key assertion issues in test_search.py

* fix: Add configurable max_clarification_rounds and comprehensive tests

  - Add max_clarification_rounds parameter for external configuration
  - Add comprehensive test cases for clarification feature in test_app.py
  - Fixes issues found during interactive mode testing where:
    - Recursive call failed due to missing initial_state parameter
    - Clarification exited prematurely at max rounds
    - Incorrect logging of max rounds reached

* Move clarification tests to test_nodes.py and add max_clarification_rounds to zh.json
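For quick orientation, here is a minimal sketch of the clarification gate the commit message describes. It is reconstructed from the truth table in test_clarification_parameters_combinations further down in this diff, not copied from src/graph/nodes.py, so the shipped implementation may differ in detail:

def needs_clarification(state: dict) -> bool:
    # Hypothetical reconstruction for illustration only; the real logic lives
    # in src/graph/nodes.py.
    if not state.get("enable_clarification", False):
        return False  # feature disabled
    if state.get("is_clarification_complete", False):
        return False  # the user has already given enough detail
    rounds = state.get("clarification_rounds", 0)
    max_rounds = state.get("max_clarification_rounds", 3)
    # Rounds 1..max_rounds count as "in progress"; round 0 means clarification
    # has not started, and anything past max_rounds falls through to the planner.
    return 1 <= rounds <= max_rounds

The new keyword arguments verified by test_clarification_workflow_integration can then be passed straight to the workflow entry point; the positional query argument below is assumed for illustration:

await run_agent_workflow_async(
    "Compare perovskite and silicon solar cells",  # assumed positional query argument
    enable_clarification=True,
    max_clarification_rounds=3,
    initial_state=None,
)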
@@ -514,6 +514,7 @@ def mock_state_coordinator():
    return {
        "messages": [{"role": "user", "content": "test"}],
        "locale": "en-US",
        "enable_clarification": False,
    }

@@ -1385,3 +1386,183 @@ async def test_researcher_node_without_resources(
    tools = args[3]
    assert patch_get_web_search_tool.return_value in tools
    assert result == "RESEARCHER_RESULT"


# ============================================================================
# Clarification Feature Tests
# ============================================================================


@pytest.mark.asyncio
async def test_clarification_workflow_integration():
    """Test the complete clarification workflow integration."""
    import inspect

    from src.workflow import run_agent_workflow_async

    # Verify that the function accepts clarification parameters
    sig = inspect.signature(run_agent_workflow_async)
    assert "max_clarification_rounds" in sig.parameters
    assert "enable_clarification" in sig.parameters
    assert "initial_state" in sig.parameters


def test_clarification_parameters_combinations():
    """Test various combinations of clarification parameters."""
    from src.graph.nodes import needs_clarification

    test_cases = [
        # (enable_clarification, clarification_rounds, max_rounds, is_complete, expected)
        (True, 0, 3, False, False),  # No rounds started
        (True, 1, 3, False, True),  # In progress
        (True, 2, 3, False, True),  # In progress
        (True, 3, 3, False, True),  # At max - still waiting for last answer
        (True, 4, 3, False, False),  # Exceeded max
        (True, 1, 3, True, False),  # Completed
        (False, 1, 3, False, False),  # Disabled
    ]

    for enable, rounds, max_rounds, complete, expected in test_cases:
        state = {
            "enable_clarification": enable,
            "clarification_rounds": rounds,
            "max_clarification_rounds": max_rounds,
            "is_clarification_complete": complete,
        }

        result = needs_clarification(state)
        assert result == expected, f"Failed for case: {state}"


def test_handoff_tools():
    """Test that handoff tools are properly defined."""
    from src.graph.nodes import handoff_after_clarification, handoff_to_planner

    # Test handoff_to_planner tool - use invoke() method
    result = handoff_to_planner.invoke(
        {"research_topic": "renewable energy", "locale": "en-US"}
    )
    assert result is None  # Tool should return None (no-op)

    # Test handoff_after_clarification tool - use invoke() method
    result = handoff_after_clarification.invoke({"locale": "en-US"})
    assert result is None  # Tool should return None (no-op)


@patch("src.graph.nodes.get_llm_by_type")
def test_coordinator_tools_with_clarification_enabled(mock_get_llm):
    """Test that coordinator binds correct tools when clarification is enabled."""
    # Mock LLM response
    mock_llm = MagicMock()
    mock_response = MagicMock()
    mock_response.content = "Let me clarify..."
    mock_response.tool_calls = []
    mock_llm.bind_tools.return_value.invoke.return_value = mock_response
    mock_get_llm.return_value = mock_llm

    # State with clarification enabled (in progress)
    state = {
        "messages": [{"role": "user", "content": "Tell me about something"}],
        "enable_clarification": True,
        "clarification_rounds": 2,
        "max_clarification_rounds": 3,
        "is_clarification_complete": False,
        "clarification_history": ["response 1", "response 2"],
        "locale": "en-US",
        "research_topic": "",
    }

    # Mock config
    config = {"configurable": {"resources": []}}

    # Call coordinator_node
    coordinator_node(state, config)

    # Verify that LLM was called with bind_tools
    assert mock_llm.bind_tools.called
    bound_tools = mock_llm.bind_tools.call_args[0][0]

    # Should bind 2 tools when clarification is enabled
    assert len(bound_tools) == 2
    tool_names = [tool.name for tool in bound_tools]
    assert "handoff_to_planner" in tool_names
    assert "handoff_after_clarification" in tool_names


@patch("src.graph.nodes.get_llm_by_type")
def test_coordinator_tools_with_clarification_disabled(mock_get_llm):
    """Test that coordinator binds only one tool when clarification is disabled."""
    # Mock LLM response with tool call
    mock_llm = MagicMock()
    mock_response = MagicMock()
    mock_response.content = ""
    mock_response.tool_calls = [
        {
            "name": "handoff_to_planner",
            "args": {"research_topic": "test", "locale": "en-US"},
        }
    ]
    mock_llm.bind_tools.return_value.invoke.return_value = mock_response
    mock_get_llm.return_value = mock_llm

    # State with clarification disabled
    state = {
        "messages": [{"role": "user", "content": "Tell me about something"}],
        "enable_clarification": False,
        "locale": "en-US",
        "research_topic": "",
    }

    # Mock config
    config = {"configurable": {"resources": []}}

    # Call coordinator_node
    coordinator_node(state, config)

    # Verify that LLM was called with bind_tools
    assert mock_llm.bind_tools.called
    bound_tools = mock_llm.bind_tools.call_args[0][0]

    # Should bind only 1 tool when clarification is disabled
    assert len(bound_tools) == 1
    assert bound_tools[0].name == "handoff_to_planner"


@patch("src.graph.nodes.get_llm_by_type")
def test_coordinator_empty_llm_response_corner_case(mock_get_llm):
    """
    Corner case test: LLM returns empty response when clarification is enabled.

    This tests error handling when LLM fails to return any content or tool calls
    in the initial state (clarification_rounds=0). The system should gracefully
    handle this by going to __end__ instead of crashing.

    Note: This is NOT a typical clarification workflow test, but rather tests
    fault tolerance when LLM misbehaves.
    """
    # Mock LLM response - empty response (failure scenario)
    mock_llm = MagicMock()
    mock_response = MagicMock()
    mock_response.content = ""
    mock_response.tool_calls = []
    mock_llm.bind_tools.return_value.invoke.return_value = mock_response
    mock_get_llm.return_value = mock_llm

    # State with clarification enabled but initial round
    state = {
        "messages": [{"role": "user", "content": "test"}],
        "enable_clarification": True,
        # clarification_rounds: 0 (default, not started)
        "locale": "en-US",
        "research_topic": "",
    }

    # Mock config
    config = {"configurable": {"resources": []}}

    # Call coordinator_node - should not crash
    result = coordinator_node(state, config)

    # Should gracefully handle empty response by going to __end__
    assert result.goto == "__end__"
    assert result.update["locale"] == "en-US"

@@ -96,7 +96,8 @@ def test_build_base_graph_adds_nodes_and_edges(MockStateGraph):
    # Check that all nodes and edges are added
    assert mock_builder.add_edge.call_count >= 2
    assert mock_builder.add_node.call_count >= 8
    mock_builder.add_conditional_edges.assert_called_once()
    # Now we have 2 conditional edges: research_team and coordinator
    assert mock_builder.add_conditional_edges.call_count == 2


@patch("src.graph.builder._build_base_graph")

@@ -157,7 +157,9 @@ def test_list_local_markdown_resources_populated(temp_examples_dir):
    # File without heading -> fallback title
    (temp_examples_dir / "file_two.md").write_text("No heading here.", encoding="utf-8")
    # Non-markdown file should be ignored
    (temp_examples_dir / "ignore.txt").write_text("Should not be picked up.", encoding="utf-8")
    (temp_examples_dir / "ignore.txt").write_text(
        "Should not be picked up.", encoding="utf-8"
    )

    resources = retriever._list_local_markdown_resources()
    # Order not guaranteed; sort by uri for assertions

@@ -815,7 +817,9 @@ def test_load_example_files_directory_missing(monkeypatch):
    assert called["insert"] == 0  # sanity (no insertion attempted)


def test_load_example_files_loads_and_skips_existing(monkeypatch, temp_load_skip_examples_dir):
def test_load_example_files_loads_and_skips_existing(
    monkeypatch, temp_load_skip_examples_dir
):
    _patch_init(monkeypatch)
    examples_dir_name = temp_load_skip_examples_dir.name

@@ -863,7 +867,9 @@ def test_load_example_files_loads_and_skips_existing(monkeypatch, temp_load_skip
    assert all(c["title"] == "Title Two" for c in calls)


def test_load_example_files_single_chunk_no_suffix(monkeypatch, temp_single_chunk_examples_dir):
def test_load_example_files_single_chunk_no_suffix(
    monkeypatch, temp_single_chunk_examples_dir
):
    _patch_init(monkeypatch)
    examples_dir_name = temp_single_chunk_examples_dir.name

@@ -901,6 +907,7 @@ def test_load_example_files_single_chunk_no_suffix(monkeypatch, temp_single_chun
# Clean up test database file after tests
import atexit


def cleanup_test_database():
    """Clean up milvus_demo.db file created during testing."""
    import os

@@ -532,6 +532,8 @@ class TestAstreamWorkflowGenerator:
            enable_background_investigation=False,
            report_style=ReportStyle.ACADEMIC,
            enable_deep_thinking=False,
            enable_clarification=False,
            max_clarification_rounds=3,
        )

        events = []

@@ -571,6 +573,8 @@ class TestAstreamWorkflowGenerator:
            enable_background_investigation=False,
            report_style=ReportStyle.ACADEMIC,
            enable_deep_thinking=False,
            enable_clarification=False,
            max_clarification_rounds=3,
        )

        events = []

@@ -605,6 +609,8 @@ class TestAstreamWorkflowGenerator:
            enable_background_investigation=False,
            report_style=ReportStyle.ACADEMIC,
            enable_deep_thinking=False,
            enable_clarification=False,
            max_clarification_rounds=3,
        )

        events = []

@@ -641,6 +647,8 @@ class TestAstreamWorkflowGenerator:
            enable_background_investigation=False,
            report_style=ReportStyle.ACADEMIC,
            enable_deep_thinking=False,
            enable_clarification=False,
            max_clarification_rounds=3,
        )

        events = []

@@ -682,6 +690,8 @@ class TestAstreamWorkflowGenerator:
            enable_background_investigation=False,
            report_style=ReportStyle.ACADEMIC,
            enable_deep_thinking=False,
            enable_clarification=False,
            max_clarification_rounds=3,
        )

        events = []

@@ -723,6 +733,8 @@ class TestAstreamWorkflowGenerator:
            enable_background_investigation=False,
            report_style=ReportStyle.ACADEMIC,
            enable_deep_thinking=False,
            enable_clarification=False,
            max_clarification_rounds=3,
        )

        events = []

@@ -761,6 +773,8 @@ class TestAstreamWorkflowGenerator:
            enable_background_investigation=False,
            report_style=ReportStyle.ACADEMIC,
            enable_deep_thinking=False,
            enable_clarification=False,
            max_clarification_rounds=3,
        )

        events = []

@@ -1,4 +1,5 @@
import pytest

from src.tools.search_postprocessor import SearchResultPostProcessor

@@ -1,5 +1,6 @@
import pytest
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage

from src.utils.context_manager import ContextManager

@@ -140,7 +141,6 @@ class TestContextManager:
        # return the original messages
        assert len(compressed["messages"]) == 4


    def test_count_message_tokens_with_additional_kwargs(self):
        """Test counting tokens for messages with additional kwargs"""
        context_manager = ContextManager(token_limit=1000)