diff --git a/src/config/configuration.py b/src/config/configuration.py index 890e0b5..093b7f1 100644 --- a/src/config/configuration.py +++ b/src/config/configuration.py @@ -51,7 +51,9 @@ class Configuration: mcp_settings: dict = None # MCP settings, including dynamic loaded tools report_style: str = ReportStyle.ACADEMIC.value # Report style enable_deep_thinking: bool = False # Whether to enable deep thinking - enforce_web_search: bool = False # Enforce at least one web search step in every plan + enforce_web_search: bool = ( + False # Enforce at least one web search step in every plan + ) @classmethod def from_runnable_config( diff --git a/src/graph/nodes.py b/src/graph/nodes.py index e6eaa56..cb107f0 100644 --- a/src/graph/nodes.py +++ b/src/graph/nodes.py @@ -31,6 +31,12 @@ from src.utils.json_utils import repair_json_output from ..config import SELECTED_SEARCH_ENGINE, SearchEngine from .types import State +from .utils import ( + build_clarified_topic_from_history, + get_message_content, + is_user_message, + reconstruct_clarification_history, +) logger = logging.getLogger(__name__) @@ -49,6 +55,9 @@ def handoff_to_planner( @tool def handoff_after_clarification( locale: Annotated[str, "The user's detected language locale (e.g., en-US, zh-CN)."], + research_topic: Annotated[ + str, "The clarified research topic based on all clarification rounds." + ], ): """Handoff to planner after clarification rounds are complete. Pass all clarification history to planner for analysis.""" return @@ -78,23 +87,23 @@ def needs_clarification(state: dict) -> bool: def validate_and_fix_plan(plan: dict, enforce_web_search: bool = False) -> dict: """ Validate and fix a plan to ensure it meets requirements. - + Args: plan: The plan dict to validate enforce_web_search: If True, ensure at least one step has need_search=true - + Returns: The validated/fixed plan dict """ if not isinstance(plan, dict): return plan - + steps = plan.get("steps", []) - + if enforce_web_search: # Check if any step has need_search=true has_search_step = any(step.get("need_search", False) for step in steps) - + if not has_search_step and steps: # Ensure first research step has web search enabled for idx, step in enumerate(steps): @@ -107,7 +116,9 @@ def validate_and_fix_plan(plan: dict, enforce_web_search: bool = False) -> dict: # This ensures that at least one step will perform a web search as required. steps[0]["step_type"] = "research" steps[0]["need_search"] = True - logger.info("Converted first step to research with web search enforcement") + logger.info( + "Converted first step to research with web search enforcement" + ) elif not has_search_step and not steps: # Add a default research step if no steps exist logger.warning("Plan has no steps. Adding default research step.") @@ -119,14 +130,14 @@ def validate_and_fix_plan(plan: dict, enforce_web_search: bool = False) -> dict: "step_type": "research", } ] - + return plan def background_investigation_node(state: State, config: RunnableConfig): logger.info("background investigation node is running.") configurable = Configuration.from_runnable_config(config) - query = state.get("research_topic") + query = state.get("clarified_research_topic") or state.get("research_topic") background_investigation_results = None if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value: searched_content = LoggedTavilySearch( @@ -230,11 +241,11 @@ def planner_node( return Command(goto="reporter") else: return Command(goto="__end__") - + # Validate and fix plan to ensure web search requirements are met if isinstance(curr_plan, dict): curr_plan = validate_and_fix_plan(curr_plan, configurable.enforce_web_search) - + if isinstance(curr_plan, dict) and curr_plan.get("has_enough_context"): logger.info("Planner response has enough context.") new_plan = Plan.model_validate(curr_plan) @@ -316,7 +327,8 @@ def coordinator_node( # Check if clarification is enabled enable_clarification = state.get("enable_clarification", False) - + initial_topic = state.get("research_topic", "") + clarified_topic = initial_topic # ============================================================ # BRANCH 1: Clarification DISABLED (Legacy Mode) # ============================================================ @@ -338,7 +350,6 @@ def coordinator_node( .invoke(messages) ) - # Process response - should directly handoff to planner goto = "__end__" locale = state.get("locale", "en-US") research_topic = state.get("research_topic", "") @@ -370,12 +381,28 @@ def coordinator_node( else: # Load clarification state clarification_rounds = state.get("clarification_rounds", 0) - clarification_history = state.get("clarification_history", []) + clarification_history = list(state.get("clarification_history", []) or []) + clarification_history = [item for item in clarification_history if item] max_clarification_rounds = state.get("max_clarification_rounds", 3) # Prepare the messages for the coordinator + state_messages = list(state.get("messages", [])) messages = apply_prompt_template("coordinator", state) + clarification_history = reconstruct_clarification_history( + state_messages, clarification_history, initial_topic + ) + clarified_topic, clarification_history = build_clarified_topic_from_history( + clarification_history + ) + logger.debug("Clarification history rebuilt: %s", clarification_history) + + if clarification_history: + initial_topic = clarification_history[0] + latest_user_content = clarification_history[-1] + else: + latest_user_content = "" + # Add clarification status for first round if clarification_rounds == 0: messages.append( @@ -385,91 +412,21 @@ def coordinator_node( } ) - # Add clarification context if continuing conversation (round > 0) - elif clarification_rounds > 0: - logger.info( - f"Clarification enabled (rounds: {clarification_rounds}/{max_clarification_rounds}): Continuing conversation" - ) + logger.info( + "Clarification round %s/%s | topic: %s | latest user content: %s", + clarification_rounds, + max_clarification_rounds, + clarified_topic or initial_topic, + latest_user_content or "N/A", + ) - # Add user's response to clarification history (only user messages) - last_message = None - if state.get("messages"): - last_message = state["messages"][-1] - # Extract content from last message for logging - if isinstance(last_message, dict): - content = last_message.get("content", "No content") - else: - content = getattr(last_message, "content", "No content") - logger.info(f"Last message content: {content}") - # Handle dict format - if isinstance(last_message, dict): - if last_message.get("role") == "user": - clarification_history.append(last_message["content"]) - logger.info( - f"Added user response to clarification history: {last_message['content']}" - ) - # Handle object format (like HumanMessage) - elif hasattr(last_message, "role") and last_message.role == "user": - clarification_history.append(last_message.content) - logger.info( - f"Added user response to clarification history: {last_message.content}" - ) - # Handle object format with content attribute (like the one in logs) - elif hasattr(last_message, "content"): - clarification_history.append(last_message.content) - logger.info( - f"Added user response to clarification history: {last_message.content}" - ) + current_response = latest_user_content or "No response" - # Build comprehensive clarification context with conversation history - current_response = "No response" - if last_message: - # Handle dict format - if isinstance(last_message, dict): - if last_message.get("role") == "user": - current_response = last_message.get("content", "No response") - else: - # If last message is not from user, try to get the latest user message - messages = state.get("messages", []) - for msg in reversed(messages): - if isinstance(msg, dict) and msg.get("role") == "user": - current_response = msg.get("content", "No response") - break - # Handle object format (like HumanMessage) - elif hasattr(last_message, "role") and last_message.role == "user": - current_response = last_message.content - # Handle object format with content attribute (like the one in logs) - elif hasattr(last_message, "content"): - current_response = last_message.content - else: - # If last message is not from user, try to get the latest user message - messages = state.get("messages", []) - for msg in reversed(messages): - if isinstance(msg, dict) and msg.get("role") == "user": - current_response = msg.get("content", "No response") - break - elif hasattr(msg, "role") and msg.role == "user": - current_response = msg.content - break - elif hasattr(msg, "content"): - current_response = msg.content - break - - # Create conversation history summary - conversation_summary = "" - if clarification_history: - conversation_summary = "Previous conversation:\n" - for i, response in enumerate(clarification_history, 1): - conversation_summary += f"- Round {i}: {response}\n" - - clarification_context = f"""Continuing clarification (round {clarification_rounds}/{max_clarification_rounds}): + clarification_context = f"""Continuing clarification (round {clarification_rounds}/{max_clarification_rounds}): User's latest response: {current_response} Ask for remaining missing dimensions. Do NOT repeat questions or start new topics.""" - # Log the clarification context for debugging - logger.info(f"Clarification context: {clarification_context}") - - messages.append({"role": "system", "content": clarification_context}) + messages.append({"role": "system", "content": clarification_context}) # Bind both clarification tools tools = [handoff_to_planner, handoff_after_clarification] @@ -483,7 +440,13 @@ def coordinator_node( # Initialize response processing variables goto = "__end__" locale = state.get("locale", "en-US") - research_topic = state.get("research_topic", "") + research_topic = ( + clarification_history[0] + if clarification_history + else state.get("research_topic", "") + ) + if not clarified_topic: + clarified_topic = research_topic # --- Process LLM response --- # No tool calls - LLM is asking a clarifying question @@ -497,20 +460,21 @@ def coordinator_node( ) # Append coordinator's question to messages - state_messages = state.get("messages", []) + updated_messages = list(state_messages) if response.content: - state_messages.append( + updated_messages.append( HumanMessage(content=response.content, name="coordinator") ) return Command( update={ - "messages": state_messages, + "messages": updated_messages, "locale": locale, "research_topic": research_topic, "resources": configurable.resources, "clarification_rounds": clarification_rounds, "clarification_history": clarification_history, + "clarified_research_topic": clarified_topic, "is_clarification_complete": False, "clarified_question": "", "goto": goto, @@ -521,7 +485,7 @@ def coordinator_node( else: # Max rounds reached - no more questions allowed logger.warning( - f"Max clarification rounds ({max_clarification_rounds}) reached. Handing off to planner." + f"Max clarification rounds ({max_clarification_rounds}) reached. Handing off to planner. Using prepared clarified topic: {clarified_topic}" ) goto = "planner" if state.get("enable_background_investigation"): @@ -539,7 +503,7 @@ def coordinator_node( # ============================================================ # Final: Build and return Command # ============================================================ - messages = state.get("messages", []) + messages = list(state.get("messages", []) or []) if response.content: messages.append(HumanMessage(content=response.content, name="coordinator")) @@ -554,10 +518,20 @@ def coordinator_node( logger.info("Handing off to planner") goto = "planner" - # Extract locale and research_topic if provided - if tool_args.get("locale") and tool_args.get("research_topic"): - locale = tool_args.get("locale") - research_topic = tool_args.get("research_topic") + # Extract locale if provided + locale = tool_args.get("locale", locale) + if not enable_clarification and tool_args.get("research_topic"): + research_topic = tool_args["research_topic"] + + if enable_clarification: + logger.info( + "Using prepared clarified topic: %s", + clarified_topic or research_topic, + ) + else: + logger.info( + "Using research topic for handoff: %s", research_topic + ) break except Exception as e: @@ -584,16 +558,24 @@ def coordinator_node( clarification_rounds = 0 clarification_history = [] + clarified_research_topic_value = clarified_topic or research_topic + + if enable_clarification: + handoff_topic = clarified_topic or research_topic + else: + handoff_topic = research_topic + return Command( update={ "messages": messages, "locale": locale, "research_topic": research_topic, + "clarified_research_topic": clarified_research_topic_value, "resources": configurable.resources, "clarification_rounds": clarification_rounds, "clarification_history": clarification_history, "is_clarification_complete": goto != "coordinator", - "clarified_question": research_topic if goto != "coordinator" else "", + "clarified_question": handoff_topic if goto != "coordinator" else "", "goto": goto, }, goto=goto, @@ -747,14 +729,15 @@ async def _execute_agent_step( ) except Exception as e: import traceback + error_traceback = traceback.format_exc() error_message = f"Error executing {agent_name} agent for step '{current_step.title}': {str(e)}" logger.exception(error_message) logger.error(f"Full traceback:\n{error_traceback}") - + detailed_error = f"[ERROR] {agent_name.capitalize()} Agent Error\n\nStep: {current_step.title}\n\nError Details:\n{str(e)}\n\nPlease check the logs for more information." current_step.execution_res = detailed_error - + return Command( update={ "messages": [ diff --git a/src/graph/types.py b/src/graph/types.py index f85ad57..8bac1c0 100644 --- a/src/graph/types.py +++ b/src/graph/types.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: MIT +from dataclasses import field + from langgraph.graph import MessagesState from src.prompts.planner_model import Plan @@ -14,6 +16,7 @@ class State(MessagesState): # Runtime Variables locale: str = "en-US" research_topic: str = "" + clarified_research_topic: str = "" observations: list[str] = [] resources: list[Resource] = [] plan_iterations: int = 0 @@ -28,7 +31,7 @@ class State(MessagesState): False # Enable/disable clarification feature (default: False) ) clarification_rounds: int = 0 - clarification_history: list[str] = [] + clarification_history: list[str] = field(default_factory=list) is_clarification_complete: bool = False clarified_question: str = "" max_clarification_rounds: int = ( diff --git a/src/graph/utils.py b/src/graph/utils.py new file mode 100644 index 0000000..2a2c0b4 --- /dev/null +++ b/src/graph/utils.py @@ -0,0 +1,113 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +from typing import Any + +ASSISTANT_SPEAKER_NAMES = { + "coordinator", + "planner", + "researcher", + "coder", + "reporter", + "background_investigator", +} + + +def get_message_content(message: Any) -> str: + """Extract message content from dict or LangChain message.""" + if isinstance(message, dict): + return message.get("content", "") + return getattr(message, "content", "") + + +def is_user_message(message: Any) -> bool: + """Return True if the message originated from the end user.""" + if isinstance(message, dict): + role = (message.get("role") or "").lower() + if role in {"user", "human"}: + return True + if role in {"assistant", "system"}: + return False + name = (message.get("name") or "").lower() + if name and name in ASSISTANT_SPEAKER_NAMES: + return False + return role == "" and name not in ASSISTANT_SPEAKER_NAMES + + message_type = (getattr(message, "type", "") or "").lower() + name = (getattr(message, "name", "") or "").lower() + if message_type == "human": + return not (name and name in ASSISTANT_SPEAKER_NAMES) + + role_attr = getattr(message, "role", None) + if isinstance(role_attr, str) and role_attr.lower() in {"user", "human"}: + return True + + additional_role = getattr(message, "additional_kwargs", {}).get("role") + if isinstance(additional_role, str) and additional_role.lower() in { + "user", + "human", + }: + return True + + return False + + +def get_latest_user_message(messages: list[Any]) -> tuple[Any, str]: + """Return the latest user-authored message and its content.""" + for message in reversed(messages or []): + if is_user_message(message): + content = get_message_content(message) + if content: + return message, content + return None, "" + + +def build_clarified_topic_from_history( + clarification_history: list[str], +) -> tuple[str, list[str]]: + """Construct clarified topic string from an ordered clarification history.""" + sequence = [item for item in clarification_history if item] + if not sequence: + return "", [] + if len(sequence) == 1: + return sequence[0], sequence + head, *tail = sequence + clarified_string = f"{head} - {', '.join(tail)}" + return clarified_string, sequence + + +def reconstruct_clarification_history( + messages: list[Any], + fallback_history: list[str] | None = None, + base_topic: str = "", +) -> list[str]: + """Rebuild clarification history from user-authored messages, with fallback. + + Args: + messages: Conversation messages in chronological order. + fallback_history: Optional existing history to use if no user messages found. + base_topic: Optional topic to use when no user messages are available. + + Returns: + A cleaned clarification history containing unique consecutive user contents. + """ + sequence: list[str] = [] + for message in messages or []: + if not is_user_message(message): + continue + content = get_message_content(message) + if not content: + continue + if sequence and sequence[-1] == content: + continue + sequence.append(content) + + if sequence: + return sequence + + fallback = [item for item in (fallback_history or []) if item] + if fallback: + return fallback + + base_topic = (base_topic or "").strip() + return [base_topic] if base_topic else [] diff --git a/src/server/app.py b/src/server/app.py index dc087a6..7e68b53 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -25,6 +25,10 @@ from src.config.report_style import ReportStyle from src.config.tools import SELECTED_RAG_PROVIDER from src.graph.builder import build_graph_with_memory from src.graph.checkpoint import chat_stream_message +from src.graph.utils import ( + build_clarified_topic_from_history, + reconstruct_clarification_history, +) from src.llms.llm import get_configured_llm_models from src.podcast.graph.builder import build_graph as build_podcast_graph from src.ppt.graph.builder import build_graph as build_ppt_graph @@ -160,7 +164,7 @@ def _create_event_stream_message( content = message_chunk.content if not isinstance(content, str): content = json.dumps(content, ensure_ascii=False) - + event_stream_message = { "thread_id": thread_id, "agent": agent_name, @@ -309,6 +313,14 @@ async def _astream_workflow_generator( if isinstance(message, dict) and "content" in message: _process_initial_messages(message, thread_id) + clarification_history = reconstruct_clarification_history(messages) + + clarified_topic, clarification_history = build_clarified_topic_from_history( + clarification_history + ) + latest_message_content = messages[-1]["content"] if messages else "" + clarified_research_topic = clarified_topic or latest_message_content + # Prepare workflow input workflow_input = { "messages": messages, @@ -318,7 +330,9 @@ async def _astream_workflow_generator( "observations": [], "auto_accepted_plan": auto_accepted_plan, "enable_background_investigation": enable_background_investigation, - "research_topic": messages[-1]["content"] if messages else "", + "research_topic": latest_message_content, + "clarification_history": clarification_history, + "clarified_research_topic": clarified_research_topic, "enable_clarification": enable_clarification, "max_clarification_rounds": max_clarification_rounds, } diff --git a/src/tools/search_postprocessor.py b/src/tools/search_postprocessor.py index 8f9813a..a250e7b 100644 --- a/src/tools/search_postprocessor.py +++ b/src/tools/search_postprocessor.py @@ -208,7 +208,7 @@ class SearchResultPostProcessor: url = image_url_val.get("url", "") else: url = image_url_val - + if url and url not in seen_urls: seen_urls.add(url) return result.copy() # Return a copy to avoid modifying original diff --git a/src/workflow.py b/src/workflow.py index 7687ce4..f8f1ee9 100644 --- a/src/workflow.py +++ b/src/workflow.py @@ -5,6 +5,7 @@ import logging from src.config.configuration import get_recursion_limit from src.graph import build_graph +from src.graph.utils import build_clarified_topic_from_history # Configure logging logging.basicConfig( @@ -65,6 +66,8 @@ async def run_agent_workflow_async( "auto_accepted_plan": True, "enable_background_investigation": enable_background_investigation, } + initial_state["research_topic"] = user_input + initial_state["clarified_research_topic"] = user_input # Only set clarification parameter if explicitly provided # If None, State class default will be used (enable_clarification=False) @@ -137,7 +140,18 @@ async def run_agent_workflow_async( current_state["messages"] = final_state["messages"] + [ {"role": "user", "content": user_response} ] - # Recursive call for clarification continuation + for key in ( + "clarification_history", + "clarification_rounds", + "clarified_research_topic", + "research_topic", + "locale", + "enable_clarification", + "max_clarification_rounds", + ): + if key in final_state: + current_state[key] = final_state[key] + return await run_agent_workflow_async( user_input=user_response, max_plan_iterations=max_plan_iterations, diff --git a/tests/integration/test_nodes.py b/tests/integration/test_nodes.py index 56ebe2e..ef1e49c 100644 --- a/tests/integration/test_nodes.py +++ b/tests/integration/test_nodes.py @@ -451,7 +451,9 @@ def test_human_feedback_node_accepted(monkeypatch, mock_state_base, mock_config) assert result.update["current_plan"]["has_enough_context"] is False -def test_human_feedback_node_invalid_interrupt(monkeypatch, mock_state_base, mock_config): +def test_human_feedback_node_invalid_interrupt( + monkeypatch, mock_state_base, mock_config +): # interrupt returns something else, should raise TypeError state = dict(mock_state_base) state["auto_accepted_plan"] = False @@ -490,7 +492,9 @@ def test_human_feedback_node_json_decode_error_second_iteration( assert result.goto == "reporter" -def test_human_feedback_node_not_enough_context(monkeypatch, mock_state_base, mock_config): +def test_human_feedback_node_not_enough_context( + monkeypatch, mock_state_base, mock_config +): # Plan does not have enough context, should goto research_team plan = { "has_enough_context": False, @@ -1446,7 +1450,9 @@ def test_handoff_tools(): assert result is None # Tool should return None (no-op) # Test handoff_after_clarification tool - use invoke() method - result = handoff_after_clarification.invoke({"locale": "en-US"}) + result = handoff_after_clarification.invoke( + {"locale": "en-US", "research_topic": "renewable energy research"} + ) assert result is None # Tool should return None (no-op) @@ -1468,9 +1474,13 @@ def test_coordinator_tools_with_clarification_enabled(mock_get_llm): "clarification_rounds": 2, "max_clarification_rounds": 3, "is_clarification_complete": False, - "clarification_history": ["response 1", "response 2"], + "clarification_history": [ + "Tell me about something", + "response 1", + "response 2", + ], "locale": "en-US", - "research_topic": "", + "research_topic": "Tell me about something", } # Mock config @@ -1567,3 +1577,289 @@ def test_coordinator_empty_llm_response_corner_case(mock_get_llm): # Should gracefully handle empty response by going to planner to ensure workflow continues assert result.goto == "planner" assert result.update["locale"] == "en-US" + + +# ============================================================================ +# Clarification flow tests +# ============================================================================ + + +def test_clarification_handoff_combines_history(): + """Coordinator should merge original topic with all clarification answers before handoff.""" + from langchain_core.messages import AIMessage + from langchain_core.runnables import RunnableConfig + + test_state = { + "messages": [ + {"role": "user", "content": "Research artificial intelligence"}, + {"role": "assistant", "content": "Which area of AI should we focus on?"}, + {"role": "user", "content": "Machine learning applications"}, + {"role": "assistant", "content": "What dimension of that should we cover?"}, + {"role": "user", "content": "Technical implementation details"}, + ], + "enable_clarification": True, + "clarification_rounds": 2, + "clarification_history": [ + "Research artificial intelligence", + "Machine learning applications", + "Technical implementation details", + ], + "max_clarification_rounds": 3, + "research_topic": "Research artificial intelligence", + "clarified_research_topic": "Research artificial intelligence - Machine learning applications, Technical implementation details", + "locale": "en-US", + } + + config = RunnableConfig(configurable={"thread_id": "clarification-test"}) + + mock_response = AIMessage( + content="Understood, handing off now.", + tool_calls=[ + { + "name": "handoff_after_clarification", + "args": {"locale": "en-US", "research_topic": "placeholder"}, + "id": "tool-call-handoff", + "type": "tool_call", + } + ], + ) + + with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm: + mock_llm = MagicMock() + mock_llm.bind_tools.return_value.invoke.return_value = mock_response + mock_get_llm.return_value = mock_llm + + result = coordinator_node(test_state, config) + + assert hasattr(result, "update") + update = result.update + assert update["clarification_history"] == [ + "Research artificial intelligence", + "Machine learning applications", + "Technical implementation details", + ] + expected_topic = ( + "Research artificial intelligence - " + "Machine learning applications, Technical implementation details" + ) + assert update["research_topic"] == "Research artificial intelligence" + assert update["clarified_research_topic"] == expected_topic + assert update["clarified_question"] == expected_topic + + +def test_clarification_history_reconstructed_from_messages(): + """Coordinator should rebuild clarification history from full message log when state is incomplete.""" + from langchain_core.messages import AIMessage + from langchain_core.runnables import RunnableConfig + + incomplete_state = { + "messages": [ + {"role": "user", "content": "Research on renewable energy"}, + { + "role": "assistant", + "content": "Which type of renewable energy interests you?", + }, + {"role": "user", "content": "Solar and wind energy"}, + {"role": "assistant", "content": "Which aspect should we focus on?"}, + {"role": "user", "content": "Technical implementation"}, + ], + "enable_clarification": True, + "clarification_rounds": 2, + "clarification_history": ["Technical implementation"], + "max_clarification_rounds": 3, + "research_topic": "Research on renewable energy", + "clarified_research_topic": "Research on renewable energy", + "locale": "en-US", + } + + config = RunnableConfig(configurable={"thread_id": "clarification-history-rebuild"}) + + mock_response = AIMessage( + content="Understood, handing over now.", + tool_calls=[ + { + "name": "handoff_after_clarification", + "args": {"locale": "en-US", "research_topic": "placeholder"}, + "id": "tool-call-handoff", + "type": "tool_call", + } + ], + ) + + with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm: + mock_llm = MagicMock() + mock_llm.bind_tools.return_value.invoke.return_value = mock_response + mock_get_llm.return_value = mock_llm + + result = coordinator_node(incomplete_state, config) + + update = result.update + assert update["clarification_history"] == [ + "Research on renewable energy", + "Solar and wind energy", + "Technical implementation", + ] + assert update["research_topic"] == "Research on renewable energy" + assert ( + update["clarified_research_topic"] + == "Research on renewable energy - Solar and wind energy, Technical implementation" + ) + + +def test_clarification_max_rounds_without_tool_call(): + """Coordinator should stop asking questions after max rounds and hand off with compiled topic.""" + from langchain_core.messages import AIMessage + from langchain_core.runnables import RunnableConfig + + test_state = { + "messages": [ + {"role": "user", "content": "Research artificial intelligence"}, + {"role": "assistant", "content": "Which area should we focus on?"}, + {"role": "user", "content": "Natural language processing"}, + {"role": "assistant", "content": "Which domain matters most?"}, + {"role": "user", "content": "Healthcare"}, + {"role": "assistant", "content": "Any specific scenario to study?"}, + {"role": "user", "content": "Clinical documentation"}, + ], + "enable_clarification": True, + "clarification_rounds": 3, + "clarification_history": [ + "Research artificial intelligence", + "Natural language processing", + "Healthcare", + "Clinical documentation", + ], + "max_clarification_rounds": 3, + "research_topic": "Research artificial intelligence", + "clarified_research_topic": "Research artificial intelligence - Natural language processing, Healthcare, Clinical documentation", + "locale": "en-US", + } + + config = RunnableConfig(configurable={"thread_id": "clarification-max"}) + + mock_response = AIMessage( + content="Got it, sending this to the planner.", + tool_calls=[], + ) + + with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm: + mock_llm = MagicMock() + mock_llm.bind_tools.return_value.invoke.return_value = mock_response + mock_get_llm.return_value = mock_llm + + result = coordinator_node(test_state, config) + + assert hasattr(result, "update") + update = result.update + expected_topic = ( + "Research artificial intelligence - " + "Natural language processing, Healthcare, Clinical documentation" + ) + assert update["research_topic"] == "Research artificial intelligence" + assert update["clarified_research_topic"] == expected_topic + assert result.goto == "planner" + + +def test_clarification_human_message_support(): + """Coordinator should treat HumanMessage instances from the user as user authored.""" + from langchain_core.messages import AIMessage, HumanMessage + from langchain_core.runnables import RunnableConfig + + test_state = { + "messages": [ + HumanMessage(content="Research artificial intelligence"), + HumanMessage(content="Which area should we focus on?", name="coordinator"), + HumanMessage(content="Machine learning"), + HumanMessage( + content="Which dimension should we explore?", name="coordinator" + ), + HumanMessage(content="Technical feasibility"), + ], + "enable_clarification": True, + "clarification_rounds": 2, + "clarification_history": [ + "Research artificial intelligence", + "Machine learning", + "Technical feasibility", + ], + "max_clarification_rounds": 3, + "research_topic": "Research artificial intelligence", + "clarified_research_topic": "Research artificial intelligence - Machine learning, Technical feasibility", + "locale": "en-US", + } + + config = RunnableConfig(configurable={"thread_id": "clarification-human"}) + + mock_response = AIMessage( + content="Moving to planner.", + tool_calls=[ + { + "name": "handoff_after_clarification", + "args": {"locale": "en-US", "research_topic": "placeholder"}, + "id": "human-message-handoff", + "type": "tool_call", + } + ], + ) + + with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm: + mock_llm = MagicMock() + mock_llm.bind_tools.return_value.invoke.return_value = mock_response + mock_get_llm.return_value = mock_llm + + result = coordinator_node(test_state, config) + + assert hasattr(result, "update") + update = result.update + expected_topic = ( + "Research artificial intelligence - Machine learning, Technical feasibility" + ) + assert update["clarification_history"] == [ + "Research artificial intelligence", + "Machine learning", + "Technical feasibility", + ] + assert update["research_topic"] == "Research artificial intelligence" + assert update["clarified_research_topic"] == expected_topic + + +def test_clarification_no_history_defaults_to_topic(): + """If clarification never started, coordinator should forward the original topic.""" + from langchain_core.messages import AIMessage + from langchain_core.runnables import RunnableConfig + + test_state = { + "messages": [{"role": "user", "content": "What is quantum computing?"}], + "enable_clarification": True, + "clarification_rounds": 0, + "clarification_history": ["What is quantum computing?"], + "max_clarification_rounds": 3, + "research_topic": "What is quantum computing?", + "clarified_research_topic": "What is quantum computing?", + "locale": "en-US", + } + + config = RunnableConfig(configurable={"thread_id": "clarification-none"}) + + mock_response = AIMessage( + content="Understood.", + tool_calls=[ + { + "name": "handoff_to_planner", + "args": {"locale": "en-US", "research_topic": "placeholder"}, + "id": "clarification-none", + "type": "tool_call", + } + ], + ) + + with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm: + mock_llm = MagicMock() + mock_llm.bind_tools.return_value.invoke.return_value = mock_response + mock_get_llm.return_value = mock_llm + + result = coordinator_node(test_state, config) + + assert hasattr(result, "update") + assert result.update["research_topic"] == "What is quantum computing?" + assert result.update["clarified_research_topic"] == "What is quantum computing?" diff --git a/tests/unit/server/test_app.py b/tests/unit/server/test_app.py index e73ea2f..314de0c 100644 --- a/tests/unit/server/test_app.py +++ b/tests/unit/server/test_app.py @@ -47,6 +47,79 @@ class TestMakeEvent: assert result == expected +@pytest.mark.asyncio +async def test_astream_workflow_generator_preserves_clarification_history(): + messages = [ + {"role": "user", "content": "Research on renewable energy"}, + { + "role": "assistant", + "content": "What type of renewable energy would you like to know about?", + }, + {"role": "user", "content": "Solar and wind energy"}, + { + "role": "assistant", + "content": "Please tell me the research dimensions you focus on, such as technological development or market applications.", + }, + {"role": "user", "content": "Technological development"}, + { + "role": "assistant", + "content": "Please specify the time range you want to focus on, such as current status or future trends.", + }, + {"role": "user", "content": "Current status and future trends"}, + ] + + captured_data = {} + + def empty_async_iterator(*args, **kwargs): + captured_data["workflow_input"] = args[1] + captured_data["workflow_config"] = args[2] + + class IteratorObject: + def __aiter__(self): + return self + + async def __anext__(self): + raise StopAsyncIteration + + return IteratorObject() + + with ( + patch("src.server.app._process_initial_messages"), + patch("src.server.app._stream_graph_events", side_effect=empty_async_iterator), + ): + generator = _astream_workflow_generator( + messages=messages, + thread_id="clarification-thread", + resources=[], + max_plan_iterations=1, + max_step_num=1, + max_search_results=5, + auto_accepted_plan=True, + interrupt_feedback="", + mcp_settings={}, + enable_background_investigation=True, + report_style=ReportStyle.ACADEMIC, + enable_deep_thinking=False, + enable_clarification=True, + max_clarification_rounds=3, + ) + + with pytest.raises(StopAsyncIteration): + await generator.__anext__() + + workflow_input = captured_data["workflow_input"] + assert workflow_input["clarification_history"] == [ + "Research on renewable energy", + "Solar and wind energy", + "Technological development", + "Current status and future trends", + ] + assert ( + workflow_input["clarified_research_topic"] + == "Research on renewable energy - Solar and wind energy, Technological development, Current status and future trends" + ) + + class TestTTSEndpoint: @patch.dict( os.environ,