From 756421c3ac30fd9b8e7ce1bad3f63d5181de3e1e Mon Sep 17 00:00:00 2001 From: Willem Jiang Date: Wed, 28 Jan 2026 21:25:16 +0800 Subject: [PATCH] fix(mcp-tool): using the async invocation for MCP tools (#840) --- src/graph/nodes.py | 5 +++-- tests/integration/test_nodes.py | 30 +++++++++++++++--------------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/graph/nodes.py b/src/graph/nodes.py index f8c4591..f48a847 100644 --- a/src/graph/nodes.py +++ b/src/graph/nodes.py @@ -1125,10 +1125,11 @@ async def _execute_agent_step( ) try: - # Use stream from the start to capture messages in real-time + # Use astream (async) from the start to capture messages in real-time # This allows us to retrieve accumulated messages even if recursion limit is hit + # NOTE: astream is required for MCP tools which only support async invocation accumulated_messages = [] - for chunk in agent.stream( + async for chunk in agent.astream( input=agent_input, config={"recursion_limit": recursion_limit}, stream_mode="values", diff --git a/tests/integration/test_nodes.py b/tests/integration/test_nodes.py index 46c6d90..5ff23f5 100644 --- a/tests/integration/test_nodes.py +++ b/tests/integration/test_nodes.py @@ -1107,12 +1107,12 @@ def mock_agent(): # Simulate agent returning a message list return {"messages": [MagicMock(content="result content")]} - def stream(input, config, stream_mode): - # Simulate agent.stream() yielding messages + async def astream(input, config, stream_mode): + # Simulate agent.astream() yielding messages (async generator) yield {"messages": [MagicMock(content="result content")]} agent.ainvoke = ainvoke - agent.stream = stream + agent.astream = astream return agent @@ -1177,12 +1177,12 @@ async def test_execute_agent_step_with_resources_and_researcher(mock_step): assert any("DO NOT include inline citations" in m.content for m in messages) return {"messages": [MagicMock(content="resource result")]} - def stream(input, config, stream_mode): - # Simulate agent.stream() yielding messages + async def astream(input, config, stream_mode): + # Simulate agent.astream() yielding messages (async generator) yield {"messages": [MagicMock(content="resource result")]} agent.ainvoke = ainvoke - agent.stream = stream + agent.astream = astream with patch( "src.graph.nodes.HumanMessage", side_effect=lambda content, name=None: MagicMock(content=content, name=name), @@ -2424,8 +2424,8 @@ async def test_execute_agent_step_preserves_multiple_tool_messages(): ] return {"messages": messages} - def stream(input, config, stream_mode): - # Simulate agent.stream() yielding the final messages + async def astream(input, config, stream_mode): + # Simulate agent.astream() yielding the final messages (async generator) messages = [ AIMessage( content="I'll search for information about this topic.", @@ -2460,7 +2460,7 @@ async def test_execute_agent_step_preserves_multiple_tool_messages(): yield {"messages": messages} agent.ainvoke = mock_ainvoke - agent.stream = stream + agent.astream = astream # Execute the agent step with patch( @@ -2556,8 +2556,8 @@ async def test_execute_agent_step_single_tool_call_still_works(): ] return {"messages": messages} - def stream(input, config, stream_mode): - # Simulate agent.stream() yielding the messages + async def astream(input, config, stream_mode): + # Simulate agent.astream() yielding the messages (async generator) messages = [ AIMessage( content="I'll search for information.", @@ -2579,7 +2579,7 @@ async def test_execute_agent_step_single_tool_call_still_works(): yield {"messages": messages} agent.ainvoke = mock_ainvoke - agent.stream = stream + agent.astream = astream with patch( "src.graph.nodes.HumanMessage", @@ -2639,8 +2639,8 @@ async def test_execute_agent_step_no_tool_calls_still_works(): ] return {"messages": messages} - def stream(input, config, stream_mode): - # Simulate agent.stream() yielding messages without tool calls + async def astream(input, config, stream_mode): + # Simulate agent.astream() yielding messages without tool calls (async generator) messages = [ AIMessage( content="Based on my knowledge, here is the answer without needing to search." @@ -2649,7 +2649,7 @@ async def test_execute_agent_step_no_tool_calls_still_works(): yield {"messages": messages} agent.ainvoke = mock_ainvoke - agent.stream = stream + agent.astream = astream with patch( "src.graph.nodes.HumanMessage",