mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-25 07:04:44 +08:00
fix: Refine clarification workflow state handling (#641)
* fix: support local models by making thought field optional in Plan model - Make thought field optional in Plan model to fix Pydantic validation errors with local models - Add Ollama configuration example to conf.yaml.example - Update documentation to include local model support - Improve planner prompt with better JSON format requirements Fixes local model integration issues where models like qwen3:14b would fail due to missing thought field in JSON output. * feat: Add intelligent clarification feature for research queries - Add multi-turn clarification process to refine vague research questions - Implement three-dimension clarification standard (Tech/App, Focus, Scope) - Add clarification state management in coordinator node - Update coordinator prompt with detailed clarification guidelines - Add UI settings to enable/disable clarification feature (disabled by default) - Update workflow to handle clarification rounds recursively - Add comprehensive test coverage for clarification functionality - Update documentation with clarification feature usage guide Key components: - src/graph/nodes.py: Core clarification logic and state management - src/prompts/coordinator.md: Detailed clarification guidelines - src/workflow.py: Recursive clarification handling - web/: UI settings integration - tests/: Comprehensive test coverage - docs/: Updated configuration guide * fix: Improve clarification conversation continuity - Add comprehensive conversation history to clarification context - Include previous exchanges summary in system messages - Add explicit guidelines for continuing rounds in coordinator prompt - Prevent LLM from starting new topics during clarification - Ensure topic continuity across clarification rounds Fixes issue where LLM would restart clarification instead of building upon previous exchanges. * fix: Add conversation history to clarification context * fix: resolve clarification feature message to planer, prompt, test issues - Optimize coordinator.md prompt template for better clarification flow - Simplify final message sent to planner after clarification - Fix API key assertion issues in test_search.py * fix: Add configurable max_clarification_rounds and comprehensive tests - Add max_clarification_rounds parameter for external configuration - Add comprehensive test cases for clarification feature in test_app.py - Fixes issues found during interactive mode testing where: - Recursive call failed due to missing initial_state parameter - Clarification exited prematurely at max rounds - Incorrect logging of max rounds reached * Move clarification tests to test_nodes.py and add max_clarification_rounds to zh.json * fix: add max_clarification_rounds parameter passing from frontend to backend - Add max_clarification_rounds parameter in store.ts sendMessage function - Add max_clarification_rounds type definition in chat.ts - Ensure frontend settings page clarification rounds are correctly passed to backend * fix: refine clarification workflow state handling and coverage - Add clarification history reconstruction - Fix clarified topic accumulation - Add clarified_research_topic state field - Preserve clarification state in recursive calls - Add comprehensive test coverage * refactor: optimize coordinator logic and type annotations - Simplify handoff topic logic in coordinator_node - Update type annotations from Tuple to tuple - Improve code readability and maintainability --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -51,7 +51,9 @@ class Configuration:
|
|||||||
mcp_settings: dict = None # MCP settings, including dynamic loaded tools
|
mcp_settings: dict = None # MCP settings, including dynamic loaded tools
|
||||||
report_style: str = ReportStyle.ACADEMIC.value # Report style
|
report_style: str = ReportStyle.ACADEMIC.value # Report style
|
||||||
enable_deep_thinking: bool = False # Whether to enable deep thinking
|
enable_deep_thinking: bool = False # Whether to enable deep thinking
|
||||||
enforce_web_search: bool = False # Enforce at least one web search step in every plan
|
enforce_web_search: bool = (
|
||||||
|
False # Enforce at least one web search step in every plan
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_runnable_config(
|
def from_runnable_config(
|
||||||
|
|||||||
@@ -31,6 +31,12 @@ from src.utils.json_utils import repair_json_output
|
|||||||
|
|
||||||
from ..config import SELECTED_SEARCH_ENGINE, SearchEngine
|
from ..config import SELECTED_SEARCH_ENGINE, SearchEngine
|
||||||
from .types import State
|
from .types import State
|
||||||
|
from .utils import (
|
||||||
|
build_clarified_topic_from_history,
|
||||||
|
get_message_content,
|
||||||
|
is_user_message,
|
||||||
|
reconstruct_clarification_history,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -49,6 +55,9 @@ def handoff_to_planner(
|
|||||||
@tool
|
@tool
|
||||||
def handoff_after_clarification(
|
def handoff_after_clarification(
|
||||||
locale: Annotated[str, "The user's detected language locale (e.g., en-US, zh-CN)."],
|
locale: Annotated[str, "The user's detected language locale (e.g., en-US, zh-CN)."],
|
||||||
|
research_topic: Annotated[
|
||||||
|
str, "The clarified research topic based on all clarification rounds."
|
||||||
|
],
|
||||||
):
|
):
|
||||||
"""Handoff to planner after clarification rounds are complete. Pass all clarification history to planner for analysis."""
|
"""Handoff to planner after clarification rounds are complete. Pass all clarification history to planner for analysis."""
|
||||||
return
|
return
|
||||||
@@ -107,7 +116,9 @@ def validate_and_fix_plan(plan: dict, enforce_web_search: bool = False) -> dict:
|
|||||||
# This ensures that at least one step will perform a web search as required.
|
# This ensures that at least one step will perform a web search as required.
|
||||||
steps[0]["step_type"] = "research"
|
steps[0]["step_type"] = "research"
|
||||||
steps[0]["need_search"] = True
|
steps[0]["need_search"] = True
|
||||||
logger.info("Converted first step to research with web search enforcement")
|
logger.info(
|
||||||
|
"Converted first step to research with web search enforcement"
|
||||||
|
)
|
||||||
elif not has_search_step and not steps:
|
elif not has_search_step and not steps:
|
||||||
# Add a default research step if no steps exist
|
# Add a default research step if no steps exist
|
||||||
logger.warning("Plan has no steps. Adding default research step.")
|
logger.warning("Plan has no steps. Adding default research step.")
|
||||||
@@ -126,7 +137,7 @@ def validate_and_fix_plan(plan: dict, enforce_web_search: bool = False) -> dict:
|
|||||||
def background_investigation_node(state: State, config: RunnableConfig):
|
def background_investigation_node(state: State, config: RunnableConfig):
|
||||||
logger.info("background investigation node is running.")
|
logger.info("background investigation node is running.")
|
||||||
configurable = Configuration.from_runnable_config(config)
|
configurable = Configuration.from_runnable_config(config)
|
||||||
query = state.get("research_topic")
|
query = state.get("clarified_research_topic") or state.get("research_topic")
|
||||||
background_investigation_results = None
|
background_investigation_results = None
|
||||||
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
|
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
|
||||||
searched_content = LoggedTavilySearch(
|
searched_content = LoggedTavilySearch(
|
||||||
@@ -316,7 +327,8 @@ def coordinator_node(
|
|||||||
|
|
||||||
# Check if clarification is enabled
|
# Check if clarification is enabled
|
||||||
enable_clarification = state.get("enable_clarification", False)
|
enable_clarification = state.get("enable_clarification", False)
|
||||||
|
initial_topic = state.get("research_topic", "")
|
||||||
|
clarified_topic = initial_topic
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# BRANCH 1: Clarification DISABLED (Legacy Mode)
|
# BRANCH 1: Clarification DISABLED (Legacy Mode)
|
||||||
# ============================================================
|
# ============================================================
|
||||||
@@ -338,7 +350,6 @@ def coordinator_node(
|
|||||||
.invoke(messages)
|
.invoke(messages)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process response - should directly handoff to planner
|
|
||||||
goto = "__end__"
|
goto = "__end__"
|
||||||
locale = state.get("locale", "en-US")
|
locale = state.get("locale", "en-US")
|
||||||
research_topic = state.get("research_topic", "")
|
research_topic = state.get("research_topic", "")
|
||||||
@@ -370,12 +381,28 @@ def coordinator_node(
|
|||||||
else:
|
else:
|
||||||
# Load clarification state
|
# Load clarification state
|
||||||
clarification_rounds = state.get("clarification_rounds", 0)
|
clarification_rounds = state.get("clarification_rounds", 0)
|
||||||
clarification_history = state.get("clarification_history", [])
|
clarification_history = list(state.get("clarification_history", []) or [])
|
||||||
|
clarification_history = [item for item in clarification_history if item]
|
||||||
max_clarification_rounds = state.get("max_clarification_rounds", 3)
|
max_clarification_rounds = state.get("max_clarification_rounds", 3)
|
||||||
|
|
||||||
# Prepare the messages for the coordinator
|
# Prepare the messages for the coordinator
|
||||||
|
state_messages = list(state.get("messages", []))
|
||||||
messages = apply_prompt_template("coordinator", state)
|
messages = apply_prompt_template("coordinator", state)
|
||||||
|
|
||||||
|
clarification_history = reconstruct_clarification_history(
|
||||||
|
state_messages, clarification_history, initial_topic
|
||||||
|
)
|
||||||
|
clarified_topic, clarification_history = build_clarified_topic_from_history(
|
||||||
|
clarification_history
|
||||||
|
)
|
||||||
|
logger.debug("Clarification history rebuilt: %s", clarification_history)
|
||||||
|
|
||||||
|
if clarification_history:
|
||||||
|
initial_topic = clarification_history[0]
|
||||||
|
latest_user_content = clarification_history[-1]
|
||||||
|
else:
|
||||||
|
latest_user_content = ""
|
||||||
|
|
||||||
# Add clarification status for first round
|
# Add clarification status for first round
|
||||||
if clarification_rounds == 0:
|
if clarification_rounds == 0:
|
||||||
messages.append(
|
messages.append(
|
||||||
@@ -385,91 +412,21 @@ def coordinator_node(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add clarification context if continuing conversation (round > 0)
|
logger.info(
|
||||||
elif clarification_rounds > 0:
|
"Clarification round %s/%s | topic: %s | latest user content: %s",
|
||||||
logger.info(
|
clarification_rounds,
|
||||||
f"Clarification enabled (rounds: {clarification_rounds}/{max_clarification_rounds}): Continuing conversation"
|
max_clarification_rounds,
|
||||||
)
|
clarified_topic or initial_topic,
|
||||||
|
latest_user_content or "N/A",
|
||||||
|
)
|
||||||
|
|
||||||
# Add user's response to clarification history (only user messages)
|
current_response = latest_user_content or "No response"
|
||||||
last_message = None
|
|
||||||
if state.get("messages"):
|
|
||||||
last_message = state["messages"][-1]
|
|
||||||
# Extract content from last message for logging
|
|
||||||
if isinstance(last_message, dict):
|
|
||||||
content = last_message.get("content", "No content")
|
|
||||||
else:
|
|
||||||
content = getattr(last_message, "content", "No content")
|
|
||||||
logger.info(f"Last message content: {content}")
|
|
||||||
# Handle dict format
|
|
||||||
if isinstance(last_message, dict):
|
|
||||||
if last_message.get("role") == "user":
|
|
||||||
clarification_history.append(last_message["content"])
|
|
||||||
logger.info(
|
|
||||||
f"Added user response to clarification history: {last_message['content']}"
|
|
||||||
)
|
|
||||||
# Handle object format (like HumanMessage)
|
|
||||||
elif hasattr(last_message, "role") and last_message.role == "user":
|
|
||||||
clarification_history.append(last_message.content)
|
|
||||||
logger.info(
|
|
||||||
f"Added user response to clarification history: {last_message.content}"
|
|
||||||
)
|
|
||||||
# Handle object format with content attribute (like the one in logs)
|
|
||||||
elif hasattr(last_message, "content"):
|
|
||||||
clarification_history.append(last_message.content)
|
|
||||||
logger.info(
|
|
||||||
f"Added user response to clarification history: {last_message.content}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build comprehensive clarification context with conversation history
|
clarification_context = f"""Continuing clarification (round {clarification_rounds}/{max_clarification_rounds}):
|
||||||
current_response = "No response"
|
|
||||||
if last_message:
|
|
||||||
# Handle dict format
|
|
||||||
if isinstance(last_message, dict):
|
|
||||||
if last_message.get("role") == "user":
|
|
||||||
current_response = last_message.get("content", "No response")
|
|
||||||
else:
|
|
||||||
# If last message is not from user, try to get the latest user message
|
|
||||||
messages = state.get("messages", [])
|
|
||||||
for msg in reversed(messages):
|
|
||||||
if isinstance(msg, dict) and msg.get("role") == "user":
|
|
||||||
current_response = msg.get("content", "No response")
|
|
||||||
break
|
|
||||||
# Handle object format (like HumanMessage)
|
|
||||||
elif hasattr(last_message, "role") and last_message.role == "user":
|
|
||||||
current_response = last_message.content
|
|
||||||
# Handle object format with content attribute (like the one in logs)
|
|
||||||
elif hasattr(last_message, "content"):
|
|
||||||
current_response = last_message.content
|
|
||||||
else:
|
|
||||||
# If last message is not from user, try to get the latest user message
|
|
||||||
messages = state.get("messages", [])
|
|
||||||
for msg in reversed(messages):
|
|
||||||
if isinstance(msg, dict) and msg.get("role") == "user":
|
|
||||||
current_response = msg.get("content", "No response")
|
|
||||||
break
|
|
||||||
elif hasattr(msg, "role") and msg.role == "user":
|
|
||||||
current_response = msg.content
|
|
||||||
break
|
|
||||||
elif hasattr(msg, "content"):
|
|
||||||
current_response = msg.content
|
|
||||||
break
|
|
||||||
|
|
||||||
# Create conversation history summary
|
|
||||||
conversation_summary = ""
|
|
||||||
if clarification_history:
|
|
||||||
conversation_summary = "Previous conversation:\n"
|
|
||||||
for i, response in enumerate(clarification_history, 1):
|
|
||||||
conversation_summary += f"- Round {i}: {response}\n"
|
|
||||||
|
|
||||||
clarification_context = f"""Continuing clarification (round {clarification_rounds}/{max_clarification_rounds}):
|
|
||||||
User's latest response: {current_response}
|
User's latest response: {current_response}
|
||||||
Ask for remaining missing dimensions. Do NOT repeat questions or start new topics."""
|
Ask for remaining missing dimensions. Do NOT repeat questions or start new topics."""
|
||||||
|
|
||||||
# Log the clarification context for debugging
|
messages.append({"role": "system", "content": clarification_context})
|
||||||
logger.info(f"Clarification context: {clarification_context}")
|
|
||||||
|
|
||||||
messages.append({"role": "system", "content": clarification_context})
|
|
||||||
|
|
||||||
# Bind both clarification tools
|
# Bind both clarification tools
|
||||||
tools = [handoff_to_planner, handoff_after_clarification]
|
tools = [handoff_to_planner, handoff_after_clarification]
|
||||||
@@ -483,7 +440,13 @@ def coordinator_node(
|
|||||||
# Initialize response processing variables
|
# Initialize response processing variables
|
||||||
goto = "__end__"
|
goto = "__end__"
|
||||||
locale = state.get("locale", "en-US")
|
locale = state.get("locale", "en-US")
|
||||||
research_topic = state.get("research_topic", "")
|
research_topic = (
|
||||||
|
clarification_history[0]
|
||||||
|
if clarification_history
|
||||||
|
else state.get("research_topic", "")
|
||||||
|
)
|
||||||
|
if not clarified_topic:
|
||||||
|
clarified_topic = research_topic
|
||||||
|
|
||||||
# --- Process LLM response ---
|
# --- Process LLM response ---
|
||||||
# No tool calls - LLM is asking a clarifying question
|
# No tool calls - LLM is asking a clarifying question
|
||||||
@@ -497,20 +460,21 @@ def coordinator_node(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Append coordinator's question to messages
|
# Append coordinator's question to messages
|
||||||
state_messages = state.get("messages", [])
|
updated_messages = list(state_messages)
|
||||||
if response.content:
|
if response.content:
|
||||||
state_messages.append(
|
updated_messages.append(
|
||||||
HumanMessage(content=response.content, name="coordinator")
|
HumanMessage(content=response.content, name="coordinator")
|
||||||
)
|
)
|
||||||
|
|
||||||
return Command(
|
return Command(
|
||||||
update={
|
update={
|
||||||
"messages": state_messages,
|
"messages": updated_messages,
|
||||||
"locale": locale,
|
"locale": locale,
|
||||||
"research_topic": research_topic,
|
"research_topic": research_topic,
|
||||||
"resources": configurable.resources,
|
"resources": configurable.resources,
|
||||||
"clarification_rounds": clarification_rounds,
|
"clarification_rounds": clarification_rounds,
|
||||||
"clarification_history": clarification_history,
|
"clarification_history": clarification_history,
|
||||||
|
"clarified_research_topic": clarified_topic,
|
||||||
"is_clarification_complete": False,
|
"is_clarification_complete": False,
|
||||||
"clarified_question": "",
|
"clarified_question": "",
|
||||||
"goto": goto,
|
"goto": goto,
|
||||||
@@ -521,7 +485,7 @@ def coordinator_node(
|
|||||||
else:
|
else:
|
||||||
# Max rounds reached - no more questions allowed
|
# Max rounds reached - no more questions allowed
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Max clarification rounds ({max_clarification_rounds}) reached. Handing off to planner."
|
f"Max clarification rounds ({max_clarification_rounds}) reached. Handing off to planner. Using prepared clarified topic: {clarified_topic}"
|
||||||
)
|
)
|
||||||
goto = "planner"
|
goto = "planner"
|
||||||
if state.get("enable_background_investigation"):
|
if state.get("enable_background_investigation"):
|
||||||
@@ -539,7 +503,7 @@ def coordinator_node(
|
|||||||
# ============================================================
|
# ============================================================
|
||||||
# Final: Build and return Command
|
# Final: Build and return Command
|
||||||
# ============================================================
|
# ============================================================
|
||||||
messages = state.get("messages", [])
|
messages = list(state.get("messages", []) or [])
|
||||||
if response.content:
|
if response.content:
|
||||||
messages.append(HumanMessage(content=response.content, name="coordinator"))
|
messages.append(HumanMessage(content=response.content, name="coordinator"))
|
||||||
|
|
||||||
@@ -554,10 +518,20 @@ def coordinator_node(
|
|||||||
logger.info("Handing off to planner")
|
logger.info("Handing off to planner")
|
||||||
goto = "planner"
|
goto = "planner"
|
||||||
|
|
||||||
# Extract locale and research_topic if provided
|
# Extract locale if provided
|
||||||
if tool_args.get("locale") and tool_args.get("research_topic"):
|
locale = tool_args.get("locale", locale)
|
||||||
locale = tool_args.get("locale")
|
if not enable_clarification and tool_args.get("research_topic"):
|
||||||
research_topic = tool_args.get("research_topic")
|
research_topic = tool_args["research_topic"]
|
||||||
|
|
||||||
|
if enable_clarification:
|
||||||
|
logger.info(
|
||||||
|
"Using prepared clarified topic: %s",
|
||||||
|
clarified_topic or research_topic,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Using research topic for handoff: %s", research_topic
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -584,16 +558,24 @@ def coordinator_node(
|
|||||||
clarification_rounds = 0
|
clarification_rounds = 0
|
||||||
clarification_history = []
|
clarification_history = []
|
||||||
|
|
||||||
|
clarified_research_topic_value = clarified_topic or research_topic
|
||||||
|
|
||||||
|
if enable_clarification:
|
||||||
|
handoff_topic = clarified_topic or research_topic
|
||||||
|
else:
|
||||||
|
handoff_topic = research_topic
|
||||||
|
|
||||||
return Command(
|
return Command(
|
||||||
update={
|
update={
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"locale": locale,
|
"locale": locale,
|
||||||
"research_topic": research_topic,
|
"research_topic": research_topic,
|
||||||
|
"clarified_research_topic": clarified_research_topic_value,
|
||||||
"resources": configurable.resources,
|
"resources": configurable.resources,
|
||||||
"clarification_rounds": clarification_rounds,
|
"clarification_rounds": clarification_rounds,
|
||||||
"clarification_history": clarification_history,
|
"clarification_history": clarification_history,
|
||||||
"is_clarification_complete": goto != "coordinator",
|
"is_clarification_complete": goto != "coordinator",
|
||||||
"clarified_question": research_topic if goto != "coordinator" else "",
|
"clarified_question": handoff_topic if goto != "coordinator" else "",
|
||||||
"goto": goto,
|
"goto": goto,
|
||||||
},
|
},
|
||||||
goto=goto,
|
goto=goto,
|
||||||
@@ -747,6 +729,7 @@ async def _execute_agent_step(
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
error_traceback = traceback.format_exc()
|
error_traceback = traceback.format_exc()
|
||||||
error_message = f"Error executing {agent_name} agent for step '{current_step.title}': {str(e)}"
|
error_message = f"Error executing {agent_name} agent for step '{current_step.title}': {str(e)}"
|
||||||
logger.exception(error_message)
|
logger.exception(error_message)
|
||||||
|
|||||||
@@ -2,6 +2,8 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
|
||||||
|
from dataclasses import field
|
||||||
|
|
||||||
from langgraph.graph import MessagesState
|
from langgraph.graph import MessagesState
|
||||||
|
|
||||||
from src.prompts.planner_model import Plan
|
from src.prompts.planner_model import Plan
|
||||||
@@ -14,6 +16,7 @@ class State(MessagesState):
|
|||||||
# Runtime Variables
|
# Runtime Variables
|
||||||
locale: str = "en-US"
|
locale: str = "en-US"
|
||||||
research_topic: str = ""
|
research_topic: str = ""
|
||||||
|
clarified_research_topic: str = ""
|
||||||
observations: list[str] = []
|
observations: list[str] = []
|
||||||
resources: list[Resource] = []
|
resources: list[Resource] = []
|
||||||
plan_iterations: int = 0
|
plan_iterations: int = 0
|
||||||
@@ -28,7 +31,7 @@ class State(MessagesState):
|
|||||||
False # Enable/disable clarification feature (default: False)
|
False # Enable/disable clarification feature (default: False)
|
||||||
)
|
)
|
||||||
clarification_rounds: int = 0
|
clarification_rounds: int = 0
|
||||||
clarification_history: list[str] = []
|
clarification_history: list[str] = field(default_factory=list)
|
||||||
is_clarification_complete: bool = False
|
is_clarification_complete: bool = False
|
||||||
clarified_question: str = ""
|
clarified_question: str = ""
|
||||||
max_clarification_rounds: int = (
|
max_clarification_rounds: int = (
|
||||||
|
|||||||
113
src/graph/utils.py
Normal file
113
src/graph/utils.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
ASSISTANT_SPEAKER_NAMES = {
|
||||||
|
"coordinator",
|
||||||
|
"planner",
|
||||||
|
"researcher",
|
||||||
|
"coder",
|
||||||
|
"reporter",
|
||||||
|
"background_investigator",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_message_content(message: Any) -> str:
|
||||||
|
"""Extract message content from dict or LangChain message."""
|
||||||
|
if isinstance(message, dict):
|
||||||
|
return message.get("content", "")
|
||||||
|
return getattr(message, "content", "")
|
||||||
|
|
||||||
|
|
||||||
|
def is_user_message(message: Any) -> bool:
|
||||||
|
"""Return True if the message originated from the end user."""
|
||||||
|
if isinstance(message, dict):
|
||||||
|
role = (message.get("role") or "").lower()
|
||||||
|
if role in {"user", "human"}:
|
||||||
|
return True
|
||||||
|
if role in {"assistant", "system"}:
|
||||||
|
return False
|
||||||
|
name = (message.get("name") or "").lower()
|
||||||
|
if name and name in ASSISTANT_SPEAKER_NAMES:
|
||||||
|
return False
|
||||||
|
return role == "" and name not in ASSISTANT_SPEAKER_NAMES
|
||||||
|
|
||||||
|
message_type = (getattr(message, "type", "") or "").lower()
|
||||||
|
name = (getattr(message, "name", "") or "").lower()
|
||||||
|
if message_type == "human":
|
||||||
|
return not (name and name in ASSISTANT_SPEAKER_NAMES)
|
||||||
|
|
||||||
|
role_attr = getattr(message, "role", None)
|
||||||
|
if isinstance(role_attr, str) and role_attr.lower() in {"user", "human"}:
|
||||||
|
return True
|
||||||
|
|
||||||
|
additional_role = getattr(message, "additional_kwargs", {}).get("role")
|
||||||
|
if isinstance(additional_role, str) and additional_role.lower() in {
|
||||||
|
"user",
|
||||||
|
"human",
|
||||||
|
}:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_latest_user_message(messages: list[Any]) -> tuple[Any, str]:
|
||||||
|
"""Return the latest user-authored message and its content."""
|
||||||
|
for message in reversed(messages or []):
|
||||||
|
if is_user_message(message):
|
||||||
|
content = get_message_content(message)
|
||||||
|
if content:
|
||||||
|
return message, content
|
||||||
|
return None, ""
|
||||||
|
|
||||||
|
|
||||||
|
def build_clarified_topic_from_history(
|
||||||
|
clarification_history: list[str],
|
||||||
|
) -> tuple[str, list[str]]:
|
||||||
|
"""Construct clarified topic string from an ordered clarification history."""
|
||||||
|
sequence = [item for item in clarification_history if item]
|
||||||
|
if not sequence:
|
||||||
|
return "", []
|
||||||
|
if len(sequence) == 1:
|
||||||
|
return sequence[0], sequence
|
||||||
|
head, *tail = sequence
|
||||||
|
clarified_string = f"{head} - {', '.join(tail)}"
|
||||||
|
return clarified_string, sequence
|
||||||
|
|
||||||
|
|
||||||
|
def reconstruct_clarification_history(
|
||||||
|
messages: list[Any],
|
||||||
|
fallback_history: list[str] | None = None,
|
||||||
|
base_topic: str = "",
|
||||||
|
) -> list[str]:
|
||||||
|
"""Rebuild clarification history from user-authored messages, with fallback.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: Conversation messages in chronological order.
|
||||||
|
fallback_history: Optional existing history to use if no user messages found.
|
||||||
|
base_topic: Optional topic to use when no user messages are available.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A cleaned clarification history containing unique consecutive user contents.
|
||||||
|
"""
|
||||||
|
sequence: list[str] = []
|
||||||
|
for message in messages or []:
|
||||||
|
if not is_user_message(message):
|
||||||
|
continue
|
||||||
|
content = get_message_content(message)
|
||||||
|
if not content:
|
||||||
|
continue
|
||||||
|
if sequence and sequence[-1] == content:
|
||||||
|
continue
|
||||||
|
sequence.append(content)
|
||||||
|
|
||||||
|
if sequence:
|
||||||
|
return sequence
|
||||||
|
|
||||||
|
fallback = [item for item in (fallback_history or []) if item]
|
||||||
|
if fallback:
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
base_topic = (base_topic or "").strip()
|
||||||
|
return [base_topic] if base_topic else []
|
||||||
@@ -25,6 +25,10 @@ from src.config.report_style import ReportStyle
|
|||||||
from src.config.tools import SELECTED_RAG_PROVIDER
|
from src.config.tools import SELECTED_RAG_PROVIDER
|
||||||
from src.graph.builder import build_graph_with_memory
|
from src.graph.builder import build_graph_with_memory
|
||||||
from src.graph.checkpoint import chat_stream_message
|
from src.graph.checkpoint import chat_stream_message
|
||||||
|
from src.graph.utils import (
|
||||||
|
build_clarified_topic_from_history,
|
||||||
|
reconstruct_clarification_history,
|
||||||
|
)
|
||||||
from src.llms.llm import get_configured_llm_models
|
from src.llms.llm import get_configured_llm_models
|
||||||
from src.podcast.graph.builder import build_graph as build_podcast_graph
|
from src.podcast.graph.builder import build_graph as build_podcast_graph
|
||||||
from src.ppt.graph.builder import build_graph as build_ppt_graph
|
from src.ppt.graph.builder import build_graph as build_ppt_graph
|
||||||
@@ -309,6 +313,14 @@ async def _astream_workflow_generator(
|
|||||||
if isinstance(message, dict) and "content" in message:
|
if isinstance(message, dict) and "content" in message:
|
||||||
_process_initial_messages(message, thread_id)
|
_process_initial_messages(message, thread_id)
|
||||||
|
|
||||||
|
clarification_history = reconstruct_clarification_history(messages)
|
||||||
|
|
||||||
|
clarified_topic, clarification_history = build_clarified_topic_from_history(
|
||||||
|
clarification_history
|
||||||
|
)
|
||||||
|
latest_message_content = messages[-1]["content"] if messages else ""
|
||||||
|
clarified_research_topic = clarified_topic or latest_message_content
|
||||||
|
|
||||||
# Prepare workflow input
|
# Prepare workflow input
|
||||||
workflow_input = {
|
workflow_input = {
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
@@ -318,7 +330,9 @@ async def _astream_workflow_generator(
|
|||||||
"observations": [],
|
"observations": [],
|
||||||
"auto_accepted_plan": auto_accepted_plan,
|
"auto_accepted_plan": auto_accepted_plan,
|
||||||
"enable_background_investigation": enable_background_investigation,
|
"enable_background_investigation": enable_background_investigation,
|
||||||
"research_topic": messages[-1]["content"] if messages else "",
|
"research_topic": latest_message_content,
|
||||||
|
"clarification_history": clarification_history,
|
||||||
|
"clarified_research_topic": clarified_research_topic,
|
||||||
"enable_clarification": enable_clarification,
|
"enable_clarification": enable_clarification,
|
||||||
"max_clarification_rounds": max_clarification_rounds,
|
"max_clarification_rounds": max_clarification_rounds,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import logging
|
|||||||
|
|
||||||
from src.config.configuration import get_recursion_limit
|
from src.config.configuration import get_recursion_limit
|
||||||
from src.graph import build_graph
|
from src.graph import build_graph
|
||||||
|
from src.graph.utils import build_clarified_topic_from_history
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -65,6 +66,8 @@ async def run_agent_workflow_async(
|
|||||||
"auto_accepted_plan": True,
|
"auto_accepted_plan": True,
|
||||||
"enable_background_investigation": enable_background_investigation,
|
"enable_background_investigation": enable_background_investigation,
|
||||||
}
|
}
|
||||||
|
initial_state["research_topic"] = user_input
|
||||||
|
initial_state["clarified_research_topic"] = user_input
|
||||||
|
|
||||||
# Only set clarification parameter if explicitly provided
|
# Only set clarification parameter if explicitly provided
|
||||||
# If None, State class default will be used (enable_clarification=False)
|
# If None, State class default will be used (enable_clarification=False)
|
||||||
@@ -137,7 +140,18 @@ async def run_agent_workflow_async(
|
|||||||
current_state["messages"] = final_state["messages"] + [
|
current_state["messages"] = final_state["messages"] + [
|
||||||
{"role": "user", "content": user_response}
|
{"role": "user", "content": user_response}
|
||||||
]
|
]
|
||||||
# Recursive call for clarification continuation
|
for key in (
|
||||||
|
"clarification_history",
|
||||||
|
"clarification_rounds",
|
||||||
|
"clarified_research_topic",
|
||||||
|
"research_topic",
|
||||||
|
"locale",
|
||||||
|
"enable_clarification",
|
||||||
|
"max_clarification_rounds",
|
||||||
|
):
|
||||||
|
if key in final_state:
|
||||||
|
current_state[key] = final_state[key]
|
||||||
|
|
||||||
return await run_agent_workflow_async(
|
return await run_agent_workflow_async(
|
||||||
user_input=user_response,
|
user_input=user_response,
|
||||||
max_plan_iterations=max_plan_iterations,
|
max_plan_iterations=max_plan_iterations,
|
||||||
|
|||||||
@@ -451,7 +451,9 @@ def test_human_feedback_node_accepted(monkeypatch, mock_state_base, mock_config)
|
|||||||
assert result.update["current_plan"]["has_enough_context"] is False
|
assert result.update["current_plan"]["has_enough_context"] is False
|
||||||
|
|
||||||
|
|
||||||
def test_human_feedback_node_invalid_interrupt(monkeypatch, mock_state_base, mock_config):
|
def test_human_feedback_node_invalid_interrupt(
|
||||||
|
monkeypatch, mock_state_base, mock_config
|
||||||
|
):
|
||||||
# interrupt returns something else, should raise TypeError
|
# interrupt returns something else, should raise TypeError
|
||||||
state = dict(mock_state_base)
|
state = dict(mock_state_base)
|
||||||
state["auto_accepted_plan"] = False
|
state["auto_accepted_plan"] = False
|
||||||
@@ -490,7 +492,9 @@ def test_human_feedback_node_json_decode_error_second_iteration(
|
|||||||
assert result.goto == "reporter"
|
assert result.goto == "reporter"
|
||||||
|
|
||||||
|
|
||||||
def test_human_feedback_node_not_enough_context(monkeypatch, mock_state_base, mock_config):
|
def test_human_feedback_node_not_enough_context(
|
||||||
|
monkeypatch, mock_state_base, mock_config
|
||||||
|
):
|
||||||
# Plan does not have enough context, should goto research_team
|
# Plan does not have enough context, should goto research_team
|
||||||
plan = {
|
plan = {
|
||||||
"has_enough_context": False,
|
"has_enough_context": False,
|
||||||
@@ -1446,7 +1450,9 @@ def test_handoff_tools():
|
|||||||
assert result is None # Tool should return None (no-op)
|
assert result is None # Tool should return None (no-op)
|
||||||
|
|
||||||
# Test handoff_after_clarification tool - use invoke() method
|
# Test handoff_after_clarification tool - use invoke() method
|
||||||
result = handoff_after_clarification.invoke({"locale": "en-US"})
|
result = handoff_after_clarification.invoke(
|
||||||
|
{"locale": "en-US", "research_topic": "renewable energy research"}
|
||||||
|
)
|
||||||
assert result is None # Tool should return None (no-op)
|
assert result is None # Tool should return None (no-op)
|
||||||
|
|
||||||
|
|
||||||
@@ -1468,9 +1474,13 @@ def test_coordinator_tools_with_clarification_enabled(mock_get_llm):
|
|||||||
"clarification_rounds": 2,
|
"clarification_rounds": 2,
|
||||||
"max_clarification_rounds": 3,
|
"max_clarification_rounds": 3,
|
||||||
"is_clarification_complete": False,
|
"is_clarification_complete": False,
|
||||||
"clarification_history": ["response 1", "response 2"],
|
"clarification_history": [
|
||||||
|
"Tell me about something",
|
||||||
|
"response 1",
|
||||||
|
"response 2",
|
||||||
|
],
|
||||||
"locale": "en-US",
|
"locale": "en-US",
|
||||||
"research_topic": "",
|
"research_topic": "Tell me about something",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Mock config
|
# Mock config
|
||||||
@@ -1567,3 +1577,289 @@ def test_coordinator_empty_llm_response_corner_case(mock_get_llm):
|
|||||||
# Should gracefully handle empty response by going to planner to ensure workflow continues
|
# Should gracefully handle empty response by going to planner to ensure workflow continues
|
||||||
assert result.goto == "planner"
|
assert result.goto == "planner"
|
||||||
assert result.update["locale"] == "en-US"
|
assert result.update["locale"] == "en-US"
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Clarification flow tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_clarification_handoff_combines_history():
|
||||||
|
"""Coordinator should merge original topic with all clarification answers before handoff."""
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
|
test_state = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Research artificial intelligence"},
|
||||||
|
{"role": "assistant", "content": "Which area of AI should we focus on?"},
|
||||||
|
{"role": "user", "content": "Machine learning applications"},
|
||||||
|
{"role": "assistant", "content": "What dimension of that should we cover?"},
|
||||||
|
{"role": "user", "content": "Technical implementation details"},
|
||||||
|
],
|
||||||
|
"enable_clarification": True,
|
||||||
|
"clarification_rounds": 2,
|
||||||
|
"clarification_history": [
|
||||||
|
"Research artificial intelligence",
|
||||||
|
"Machine learning applications",
|
||||||
|
"Technical implementation details",
|
||||||
|
],
|
||||||
|
"max_clarification_rounds": 3,
|
||||||
|
"research_topic": "Research artificial intelligence",
|
||||||
|
"clarified_research_topic": "Research artificial intelligence - Machine learning applications, Technical implementation details",
|
||||||
|
"locale": "en-US",
|
||||||
|
}
|
||||||
|
|
||||||
|
config = RunnableConfig(configurable={"thread_id": "clarification-test"})
|
||||||
|
|
||||||
|
mock_response = AIMessage(
|
||||||
|
content="Understood, handing off now.",
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"name": "handoff_after_clarification",
|
||||||
|
"args": {"locale": "en-US", "research_topic": "placeholder"},
|
||||||
|
"id": "tool-call-handoff",
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm:
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.bind_tools.return_value.invoke.return_value = mock_response
|
||||||
|
mock_get_llm.return_value = mock_llm
|
||||||
|
|
||||||
|
result = coordinator_node(test_state, config)
|
||||||
|
|
||||||
|
assert hasattr(result, "update")
|
||||||
|
update = result.update
|
||||||
|
assert update["clarification_history"] == [
|
||||||
|
"Research artificial intelligence",
|
||||||
|
"Machine learning applications",
|
||||||
|
"Technical implementation details",
|
||||||
|
]
|
||||||
|
expected_topic = (
|
||||||
|
"Research artificial intelligence - "
|
||||||
|
"Machine learning applications, Technical implementation details"
|
||||||
|
)
|
||||||
|
assert update["research_topic"] == "Research artificial intelligence"
|
||||||
|
assert update["clarified_research_topic"] == expected_topic
|
||||||
|
assert update["clarified_question"] == expected_topic
|
||||||
|
|
||||||
|
|
||||||
|
def test_clarification_history_reconstructed_from_messages():
|
||||||
|
"""Coordinator should rebuild clarification history from full message log when state is incomplete."""
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
|
incomplete_state = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Research on renewable energy"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Which type of renewable energy interests you?",
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "Solar and wind energy"},
|
||||||
|
{"role": "assistant", "content": "Which aspect should we focus on?"},
|
||||||
|
{"role": "user", "content": "Technical implementation"},
|
||||||
|
],
|
||||||
|
"enable_clarification": True,
|
||||||
|
"clarification_rounds": 2,
|
||||||
|
"clarification_history": ["Technical implementation"],
|
||||||
|
"max_clarification_rounds": 3,
|
||||||
|
"research_topic": "Research on renewable energy",
|
||||||
|
"clarified_research_topic": "Research on renewable energy",
|
||||||
|
"locale": "en-US",
|
||||||
|
}
|
||||||
|
|
||||||
|
config = RunnableConfig(configurable={"thread_id": "clarification-history-rebuild"})
|
||||||
|
|
||||||
|
mock_response = AIMessage(
|
||||||
|
content="Understood, handing over now.",
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"name": "handoff_after_clarification",
|
||||||
|
"args": {"locale": "en-US", "research_topic": "placeholder"},
|
||||||
|
"id": "tool-call-handoff",
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm:
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.bind_tools.return_value.invoke.return_value = mock_response
|
||||||
|
mock_get_llm.return_value = mock_llm
|
||||||
|
|
||||||
|
result = coordinator_node(incomplete_state, config)
|
||||||
|
|
||||||
|
update = result.update
|
||||||
|
assert update["clarification_history"] == [
|
||||||
|
"Research on renewable energy",
|
||||||
|
"Solar and wind energy",
|
||||||
|
"Technical implementation",
|
||||||
|
]
|
||||||
|
assert update["research_topic"] == "Research on renewable energy"
|
||||||
|
assert (
|
||||||
|
update["clarified_research_topic"]
|
||||||
|
== "Research on renewable energy - Solar and wind energy, Technical implementation"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_clarification_max_rounds_without_tool_call():
|
||||||
|
"""Coordinator should stop asking questions after max rounds and hand off with compiled topic."""
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
|
test_state = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Research artificial intelligence"},
|
||||||
|
{"role": "assistant", "content": "Which area should we focus on?"},
|
||||||
|
{"role": "user", "content": "Natural language processing"},
|
||||||
|
{"role": "assistant", "content": "Which domain matters most?"},
|
||||||
|
{"role": "user", "content": "Healthcare"},
|
||||||
|
{"role": "assistant", "content": "Any specific scenario to study?"},
|
||||||
|
{"role": "user", "content": "Clinical documentation"},
|
||||||
|
],
|
||||||
|
"enable_clarification": True,
|
||||||
|
"clarification_rounds": 3,
|
||||||
|
"clarification_history": [
|
||||||
|
"Research artificial intelligence",
|
||||||
|
"Natural language processing",
|
||||||
|
"Healthcare",
|
||||||
|
"Clinical documentation",
|
||||||
|
],
|
||||||
|
"max_clarification_rounds": 3,
|
||||||
|
"research_topic": "Research artificial intelligence",
|
||||||
|
"clarified_research_topic": "Research artificial intelligence - Natural language processing, Healthcare, Clinical documentation",
|
||||||
|
"locale": "en-US",
|
||||||
|
}
|
||||||
|
|
||||||
|
config = RunnableConfig(configurable={"thread_id": "clarification-max"})
|
||||||
|
|
||||||
|
mock_response = AIMessage(
|
||||||
|
content="Got it, sending this to the planner.",
|
||||||
|
tool_calls=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm:
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.bind_tools.return_value.invoke.return_value = mock_response
|
||||||
|
mock_get_llm.return_value = mock_llm
|
||||||
|
|
||||||
|
result = coordinator_node(test_state, config)
|
||||||
|
|
||||||
|
assert hasattr(result, "update")
|
||||||
|
update = result.update
|
||||||
|
expected_topic = (
|
||||||
|
"Research artificial intelligence - "
|
||||||
|
"Natural language processing, Healthcare, Clinical documentation"
|
||||||
|
)
|
||||||
|
assert update["research_topic"] == "Research artificial intelligence"
|
||||||
|
assert update["clarified_research_topic"] == expected_topic
|
||||||
|
assert result.goto == "planner"
|
||||||
|
|
||||||
|
|
||||||
|
def test_clarification_human_message_support():
|
||||||
|
"""Coordinator should treat HumanMessage instances from the user as user authored."""
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
|
test_state = {
|
||||||
|
"messages": [
|
||||||
|
HumanMessage(content="Research artificial intelligence"),
|
||||||
|
HumanMessage(content="Which area should we focus on?", name="coordinator"),
|
||||||
|
HumanMessage(content="Machine learning"),
|
||||||
|
HumanMessage(
|
||||||
|
content="Which dimension should we explore?", name="coordinator"
|
||||||
|
),
|
||||||
|
HumanMessage(content="Technical feasibility"),
|
||||||
|
],
|
||||||
|
"enable_clarification": True,
|
||||||
|
"clarification_rounds": 2,
|
||||||
|
"clarification_history": [
|
||||||
|
"Research artificial intelligence",
|
||||||
|
"Machine learning",
|
||||||
|
"Technical feasibility",
|
||||||
|
],
|
||||||
|
"max_clarification_rounds": 3,
|
||||||
|
"research_topic": "Research artificial intelligence",
|
||||||
|
"clarified_research_topic": "Research artificial intelligence - Machine learning, Technical feasibility",
|
||||||
|
"locale": "en-US",
|
||||||
|
}
|
||||||
|
|
||||||
|
config = RunnableConfig(configurable={"thread_id": "clarification-human"})
|
||||||
|
|
||||||
|
mock_response = AIMessage(
|
||||||
|
content="Moving to planner.",
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"name": "handoff_after_clarification",
|
||||||
|
"args": {"locale": "en-US", "research_topic": "placeholder"},
|
||||||
|
"id": "human-message-handoff",
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm:
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.bind_tools.return_value.invoke.return_value = mock_response
|
||||||
|
mock_get_llm.return_value = mock_llm
|
||||||
|
|
||||||
|
result = coordinator_node(test_state, config)
|
||||||
|
|
||||||
|
assert hasattr(result, "update")
|
||||||
|
update = result.update
|
||||||
|
expected_topic = (
|
||||||
|
"Research artificial intelligence - Machine learning, Technical feasibility"
|
||||||
|
)
|
||||||
|
assert update["clarification_history"] == [
|
||||||
|
"Research artificial intelligence",
|
||||||
|
"Machine learning",
|
||||||
|
"Technical feasibility",
|
||||||
|
]
|
||||||
|
assert update["research_topic"] == "Research artificial intelligence"
|
||||||
|
assert update["clarified_research_topic"] == expected_topic
|
||||||
|
|
||||||
|
|
||||||
|
def test_clarification_no_history_defaults_to_topic():
|
||||||
|
"""If clarification never started, coordinator should forward the original topic."""
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
|
test_state = {
|
||||||
|
"messages": [{"role": "user", "content": "What is quantum computing?"}],
|
||||||
|
"enable_clarification": True,
|
||||||
|
"clarification_rounds": 0,
|
||||||
|
"clarification_history": ["What is quantum computing?"],
|
||||||
|
"max_clarification_rounds": 3,
|
||||||
|
"research_topic": "What is quantum computing?",
|
||||||
|
"clarified_research_topic": "What is quantum computing?",
|
||||||
|
"locale": "en-US",
|
||||||
|
}
|
||||||
|
|
||||||
|
config = RunnableConfig(configurable={"thread_id": "clarification-none"})
|
||||||
|
|
||||||
|
mock_response = AIMessage(
|
||||||
|
content="Understood.",
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"name": "handoff_to_planner",
|
||||||
|
"args": {"locale": "en-US", "research_topic": "placeholder"},
|
||||||
|
"id": "clarification-none",
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm:
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.bind_tools.return_value.invoke.return_value = mock_response
|
||||||
|
mock_get_llm.return_value = mock_llm
|
||||||
|
|
||||||
|
result = coordinator_node(test_state, config)
|
||||||
|
|
||||||
|
assert hasattr(result, "update")
|
||||||
|
assert result.update["research_topic"] == "What is quantum computing?"
|
||||||
|
assert result.update["clarified_research_topic"] == "What is quantum computing?"
|
||||||
|
|||||||
@@ -47,6 +47,79 @@ class TestMakeEvent:
|
|||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_astream_workflow_generator_preserves_clarification_history():
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "Research on renewable energy"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "What type of renewable energy would you like to know about?",
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "Solar and wind energy"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Please tell me the research dimensions you focus on, such as technological development or market applications.",
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "Technological development"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Please specify the time range you want to focus on, such as current status or future trends.",
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "Current status and future trends"},
|
||||||
|
]
|
||||||
|
|
||||||
|
captured_data = {}
|
||||||
|
|
||||||
|
def empty_async_iterator(*args, **kwargs):
|
||||||
|
captured_data["workflow_input"] = args[1]
|
||||||
|
captured_data["workflow_config"] = args[2]
|
||||||
|
|
||||||
|
class IteratorObject:
|
||||||
|
def __aiter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self):
|
||||||
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
return IteratorObject()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("src.server.app._process_initial_messages"),
|
||||||
|
patch("src.server.app._stream_graph_events", side_effect=empty_async_iterator),
|
||||||
|
):
|
||||||
|
generator = _astream_workflow_generator(
|
||||||
|
messages=messages,
|
||||||
|
thread_id="clarification-thread",
|
||||||
|
resources=[],
|
||||||
|
max_plan_iterations=1,
|
||||||
|
max_step_num=1,
|
||||||
|
max_search_results=5,
|
||||||
|
auto_accepted_plan=True,
|
||||||
|
interrupt_feedback="",
|
||||||
|
mcp_settings={},
|
||||||
|
enable_background_investigation=True,
|
||||||
|
report_style=ReportStyle.ACADEMIC,
|
||||||
|
enable_deep_thinking=False,
|
||||||
|
enable_clarification=True,
|
||||||
|
max_clarification_rounds=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(StopAsyncIteration):
|
||||||
|
await generator.__anext__()
|
||||||
|
|
||||||
|
workflow_input = captured_data["workflow_input"]
|
||||||
|
assert workflow_input["clarification_history"] == [
|
||||||
|
"Research on renewable energy",
|
||||||
|
"Solar and wind energy",
|
||||||
|
"Technological development",
|
||||||
|
"Current status and future trends",
|
||||||
|
]
|
||||||
|
assert (
|
||||||
|
workflow_input["clarified_research_topic"]
|
||||||
|
== "Research on renewable energy - Solar and wind energy, Technological development, Current status and future trends"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestTTSEndpoint:
|
class TestTTSEndpoint:
|
||||||
@patch.dict(
|
@patch.dict(
|
||||||
os.environ,
|
os.environ,
|
||||||
|
|||||||
Reference in New Issue
Block a user