diff --git a/src/graph/nodes.py b/src/graph/nodes.py index cb107f0..af313d5 100644 --- a/src/graph/nodes.py +++ b/src/graph/nodes.py @@ -26,7 +26,7 @@ from src.tools import ( python_repl_tool, ) from src.tools.search import LoggedTavilySearch -from src.utils.context_manager import ContextManager +from src.utils.context_manager import ContextManager, validate_message_content from src.utils.json_utils import repair_json_output from ..config import SELECTED_SEARCH_ENGINE, SearchEngine @@ -138,7 +138,8 @@ def background_investigation_node(state: State, config: RunnableConfig): logger.info("background investigation node is running.") configurable = Configuration.from_runnable_config(config) query = state.get("clarified_research_topic") or state.get("research_topic") - background_investigation_results = None + background_investigation_results = [] + if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value: searched_content = LoggedTavilySearch( max_results=configurable.max_search_results @@ -146,7 +147,27 @@ def background_investigation_node(state: State, config: RunnableConfig): # check if the searched_content is a tuple, then we need to unpack it if isinstance(searched_content, tuple): searched_content = searched_content[0] - if isinstance(searched_content, list): + + # Handle string JSON response (new format from fixed Tavily tool) + if isinstance(searched_content, str): + try: + parsed = json.loads(searched_content) + if isinstance(parsed, dict) and "error" in parsed: + logger.error(f"Tavily search error: {parsed['error']}") + background_investigation_results = [] + elif isinstance(parsed, list): + background_investigation_results = [ + f"## {elem.get('title', 'Untitled')}\n\n{elem.get('content', 'No content')}" + for elem in parsed + ] + else: + logger.error(f"Unexpected Tavily response format: {searched_content}") + background_investigation_results = [] + except json.JSONDecodeError: + logger.error(f"Failed to parse Tavily response as JSON: {searched_content}") + 
background_investigation_results = [] + # Handle legacy list format + elif isinstance(searched_content, list): background_investigation_results = [ f"## {elem['title']}\n\n{elem['content']}" for elem in searched_content ] @@ -159,10 +180,12 @@ def background_investigation_node(state: State, config: RunnableConfig): logger.error( f"Tavily search returned malformed response: {searched_content}" ) + background_investigation_results = [] else: background_investigation_results = get_web_search_tool( configurable.max_search_results ).invoke(query) + return { "background_investigation_results": json.dumps( background_investigation_results, ensure_ascii=False @@ -723,6 +746,14 @@ async def _execute_agent_step( recursion_limit = default_recursion_limit logger.info(f"Agent input: {agent_input}") + + # Validate message content before invoking agent + try: + validated_messages = validate_message_content(agent_input["messages"]) + agent_input["messages"] = validated_messages + except Exception as validation_error: + logger.error(f"Error validating agent input messages: {validation_error}") + try: result = await agent.ainvoke( input=agent_input, config={"recursion_limit": recursion_limit} @@ -734,6 +765,15 @@ async def _execute_agent_step( error_message = f"Error executing {agent_name} agent for step '{current_step.title}': {str(e)}" logger.exception(error_message) logger.error(f"Full traceback:\n{error_traceback}") + + # Enhanced error diagnostics for content-related errors + if "Field required" in str(e) and "content" in str(e): + logger.error(f"Message content validation error detected") + for i, msg in enumerate(agent_input.get('messages', [])): + logger.error(f"Message {i}: type={type(msg).__name__}, " + f"has_content={hasattr(msg, 'content')}, " + f"content_type={type(msg.content).__name__ if hasattr(msg, 'content') else 'N/A'}, " + f"content_len={len(str(msg.content)) if hasattr(msg, 'content') and msg.content else 0}") detailed_error = f"[ERROR] {agent_name.capitalize()} 
# Sentinel distinguishing "message has no content at all" from "content is None".
_CONTENT_MISSING = object()


def _read_message_content(msg: object) -> object:
    """Return the content of *msg*, or ``_CONTENT_MISSING`` when it has none.

    Dict-shaped messages are read through their ``"content"`` key, everything
    else through a ``content`` attribute.
    """
    if isinstance(msg, dict):
        return msg.get("content", _CONTENT_MISSING)
    return getattr(msg, "content", _CONTENT_MISSING)


def _write_message_content(msg: object, value: str) -> None:
    """Store *value* as the content of *msg* (dict key or attribute).

    Raises:
        AttributeError, TypeError: if *msg* does not allow the assignment
            (for example a tuple or an object with restrictive ``__slots__``).
    """
    if isinstance(msg, dict):
        msg["content"] = value
    else:
        msg.content = value


def _content_to_text(content: object) -> str:
    """Convert non-string message content to text.

    ``None`` becomes ``""``; lists and dicts become JSON; anything else is
    passed through ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, (list, dict)):
        try:
            # default=str keeps the JSON structure even when a nested value is
            # not JSON-serialisable, instead of discarding the whole content.
            return json.dumps(content, ensure_ascii=False, default=str)
        except (TypeError, ValueError):
            # Unsupported dict keys or circular references: fall back to the
            # plain repr-like text rather than failing the message.
            return str(content)
    return str(content)


def validate_message_content(messages: "List[BaseMessage]") -> "List[BaseMessage]":
    """
    Validate and fix all messages so that each one carries string content.

    The messages are fixed IN PLACE and the same objects are returned, in the
    same order. For every message:

    1. String content (including the empty string) is left untouched.
    2. Missing or ``None`` content is replaced by an empty string.
    3. List and dict content is converted to a JSON string.
    4. Any other content type is converted with ``str()``.

    Messages that have no content and cannot be given one (for example tuple
    message-likes) are passed through unchanged. This function never raises
    because of a single bad message.

    NOTE(review): flattening list content to JSON is the existing contract of
    this function; if callers ever pass multimodal content blocks, confirm
    that turning them into text is really what is wanted.

    Args:
        messages: List of messages to validate. Objects exposing a ``content``
            attribute and dicts with a ``"content"`` key are both handled.

    Returns:
        List of validated messages with fixed content.
    """
    import logging

    log = logging.getLogger(__name__)

    validated = []
    for i, msg in enumerate(messages):
        kind = type(msg).__name__
        try:
            content = _read_message_content(msg)

            if isinstance(content, str):
                # Already valid; nothing to fix.
                validated.append(msg)
                continue

            if content is _CONTENT_MISSING:
                log.warning("Message %d (%s) has no content attribute", i, kind)
                try:
                    _write_message_content(msg, "")
                except (AttributeError, TypeError):
                    # Immutable message-like: forwarding it untouched is safer
                    # than aborting validation of the whole list.
                    log.warning(
                        "Message %d (%s) cannot be assigned content, "
                        "passing it through unchanged",
                        i,
                        kind,
                    )
                validated.append(msg)
                continue

            if content is None:
                log.warning(
                    "Message %d (%s) has None content, setting to empty string",
                    i,
                    kind,
                )
            else:
                log.debug(
                    "Message %d (%s) has non-string content type %s, "
                    "converting to string",
                    i,
                    kind,
                    type(content).__name__,
                )
            _write_message_content(msg, _content_to_text(content))
            validated.append(msg)
        except Exception as e:
            log.error("Error validating message %d: %s", i, e)
            # Create a safe fallback content; tool messages get a JSON error
            # payload, every other message a readable marker.
            if isinstance(msg, ToolMessage):
                placeholder = json.dumps({"error": str(e)}, ensure_ascii=False)
            else:
                placeholder = f"[Error processing message: {str(e)}]"
            try:
                _write_message_content(msg, placeholder)
            except Exception:
                # The fallback itself must not raise, otherwise one bad
                # message prevents all the others from being validated.
                log.error(
                    "Message %d (%s) could not be updated, leaving it unchanged",
                    i,
                    kind,
                )
            validated.append(msg)

    return validated
test_background_investigation_node_malformed_response( # Parse and verify the JSON content results = result["background_investigation_results"] - assert json.loads(results) is None + assert json.loads(results) == [] @pytest.fixture diff --git a/tests/unit/tools/test_tavily_search_results_with_images.py b/tests/unit/tools/test_tavily_search_results_with_images.py index 1e24f04..e0f5301 100644 --- a/tests/unit/tools/test_tavily_search_results_with_images.py +++ b/tests/unit/tools/test_tavily_search_results_with_images.py @@ -1,6 +1,7 @@ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # SPDX-License-Identifier: MIT +import json from unittest.mock import AsyncMock, Mock, patch import pytest @@ -88,7 +89,7 @@ class TestTavilySearchWithImages: result, raw = search_tool._run("test query") - assert result == sample_cleaned_results + assert result == json.dumps(sample_cleaned_results, ensure_ascii=False) assert raw == sample_raw_results mock_api_wrapper.raw_results.assert_called_once_with( @@ -113,7 +114,9 @@ class TestTavilySearchWithImages: result, raw = search_tool._run("test query") - assert "API Error" in result + result_dict = json.loads(result) + assert "error" in result_dict + assert "API Error" in result_dict["error"] assert raw == {} mock_api_wrapper.clean_results_with_images.assert_not_called() @@ -131,7 +134,7 @@ class TestTavilySearchWithImages: result, raw = await search_tool._arun("test query") - assert result == sample_cleaned_results + assert result == json.dumps(sample_cleaned_results, ensure_ascii=False) assert raw == sample_raw_results mock_api_wrapper.raw_results_async.assert_called_once_with( @@ -159,7 +162,9 @@ class TestTavilySearchWithImages: result, raw = await search_tool._arun("test query") - assert "Async API Error" in result + result_dict = json.loads(result) + assert "error" in result_dict + assert "Async API Error" in result_dict["error"] assert raw == {} mock_api_wrapper.clean_results_with_images.assert_not_called() @@ -177,7 
+182,7 @@ class TestTavilySearchWithImages: result, raw = search_tool._run("test query", run_manager=mock_run_manager) - assert result == sample_cleaned_results + assert result == json.dumps(sample_cleaned_results, ensure_ascii=False) assert raw == sample_raw_results @pytest.mark.asyncio @@ -197,5 +202,5 @@ class TestTavilySearchWithImages: "test query", run_manager=mock_run_manager ) - assert result == sample_cleaned_results + assert result == json.dumps(sample_cleaned_results, ensure_ascii=False) assert raw == sample_raw_results