mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
fix: migrate from deprecated create_react_agent to langchain.agents.create_agent (#802)
* fix: migrate from deprecated create_react_agent to langchain.agents.create_agent Fixes #799 - Replace deprecated langgraph.prebuilt.create_react_agent with langchain.agents.create_agent (LangGraph 1.0 migration) - Add DynamicPromptMiddleware to handle dynamic prompt templates (replaces the 'prompt' callable parameter) - Add PreModelHookMiddleware to handle pre-model hooks (replaces the 'pre_model_hook' parameter) - Update AgentState import from langchain.agents in template.py - Update tests to use the new API * fix:update the code with review comments
This commit is contained in:
@@ -42,7 +42,7 @@ class TestToolInterceptorIntegration:
|
||||
tools = [search_tool, db_tool]
|
||||
|
||||
# Create agent with interrupts on db_tool only
|
||||
with patch("src.agents.agents.create_react_agent") as mock_create, \
|
||||
with patch("src.agents.agents.langchain_create_agent") as mock_create, \
|
||||
patch("src.agents.agents.get_llm_by_type") as mock_llm:
|
||||
mock_create.return_value = MagicMock()
|
||||
mock_llm.return_value = MagicMock()
|
||||
@@ -55,7 +55,7 @@ class TestToolInterceptorIntegration:
|
||||
interrupt_before_tools=["db_tool"],
|
||||
)
|
||||
|
||||
# Verify create_react_agent was called with wrapped tools
|
||||
# Verify langchain_create_agent was called with wrapped tools
|
||||
assert mock_create.called
|
||||
call_args = mock_create.call_args
|
||||
wrapped_tools = call_args.kwargs["tools"]
|
||||
|
||||
335
tests/unit/agents/test_middleware.py
Normal file
335
tests/unit/agents/test_middleware.py
Normal file
@@ -0,0 +1,335 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from src.agents.agents import DynamicPromptMiddleware, PreModelHookMiddleware
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runtime():
|
||||
"""Mock Runtime object."""
|
||||
runtime = MagicMock()
|
||||
runtime.config = {}
|
||||
return runtime
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_state():
|
||||
"""Mock state object."""
|
||||
return {
|
||||
"messages": [HumanMessage(content="Test message")],
|
||||
"context": "Test context",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_messages():
|
||||
"""Mock messages returned by apply_prompt_template."""
|
||||
return [
|
||||
SystemMessage(content="Test system prompt"),
|
||||
HumanMessage(content="Test human message"),
|
||||
]
|
||||
|
||||
|
||||
class TestDynamicPromptMiddleware:
|
||||
"""Tests for DynamicPromptMiddleware class."""
|
||||
|
||||
def test_init(self):
|
||||
"""Test middleware initialization."""
|
||||
middleware = DynamicPromptMiddleware("test_template", locale="zh-CN")
|
||||
assert middleware.prompt_template == "test_template"
|
||||
assert middleware.locale == "zh-CN"
|
||||
|
||||
def test_init_default_locale(self):
|
||||
"""Test middleware initialization with default locale."""
|
||||
middleware = DynamicPromptMiddleware("test_template")
|
||||
assert middleware.prompt_template == "test_template"
|
||||
assert middleware.locale == "en-US"
|
||||
|
||||
@patch("src.agents.agents.apply_prompt_template")
|
||||
def test_before_model_success(
|
||||
self, mock_apply_template, mock_state, mock_runtime, mock_messages
|
||||
):
|
||||
"""Test before_model successfully applies prompt template."""
|
||||
mock_apply_template.return_value = mock_messages
|
||||
middleware = DynamicPromptMiddleware("test_template", locale="en-US")
|
||||
|
||||
result = middleware.before_model(mock_state, mock_runtime)
|
||||
|
||||
# Verify apply_prompt_template was called with correct arguments
|
||||
mock_apply_template.assert_called_once_with(
|
||||
"test_template", mock_state, locale="en-US"
|
||||
)
|
||||
|
||||
# Verify system message is returned
|
||||
assert result == {"messages": [mock_messages[0]]}
|
||||
assert result["messages"][0].content == "Test system prompt"
|
||||
|
||||
@patch("src.agents.agents.apply_prompt_template")
|
||||
def test_before_model_empty_messages(
|
||||
self, mock_apply_template, mock_state, mock_runtime
|
||||
):
|
||||
"""Test before_model with empty message list."""
|
||||
mock_apply_template.return_value = []
|
||||
middleware = DynamicPromptMiddleware("test_template")
|
||||
|
||||
result = middleware.before_model(mock_state, mock_runtime)
|
||||
|
||||
# Should return None when no messages are rendered
|
||||
assert result is None
|
||||
|
||||
@patch("src.agents.agents.apply_prompt_template")
|
||||
def test_before_model_none_messages(
|
||||
self, mock_apply_template, mock_state, mock_runtime
|
||||
):
|
||||
"""Test before_model when apply_prompt_template returns None."""
|
||||
mock_apply_template.return_value = None
|
||||
middleware = DynamicPromptMiddleware("test_template")
|
||||
|
||||
result = middleware.before_model(mock_state, mock_runtime)
|
||||
|
||||
# Should return None when template returns None
|
||||
assert result is None
|
||||
|
||||
@patch("src.agents.agents.apply_prompt_template")
|
||||
@patch("src.agents.agents.logger")
|
||||
def test_before_model_exception_handling(
|
||||
self, mock_logger, mock_apply_template, mock_state, mock_runtime
|
||||
):
|
||||
"""Test before_model handles exceptions gracefully."""
|
||||
mock_apply_template.side_effect = ValueError("Template rendering failed")
|
||||
middleware = DynamicPromptMiddleware("test_template")
|
||||
|
||||
result = middleware.before_model(mock_state, mock_runtime)
|
||||
|
||||
# Should return None on exception
|
||||
assert result is None
|
||||
|
||||
# Should log error with exc_info
|
||||
mock_logger.error.assert_called_once()
|
||||
error_message = mock_logger.error.call_args[0][0]
|
||||
assert "Failed to apply prompt template in before_model" in error_message
|
||||
assert mock_logger.error.call_args[1]["exc_info"] is True
|
||||
|
||||
@patch("src.agents.agents.apply_prompt_template")
|
||||
def test_before_model_with_different_locale(
|
||||
self, mock_apply_template, mock_state, mock_runtime, mock_messages
|
||||
):
|
||||
"""Test before_model with different locale."""
|
||||
mock_apply_template.return_value = mock_messages
|
||||
middleware = DynamicPromptMiddleware("test_template", locale="zh-CN")
|
||||
|
||||
result = middleware.before_model(mock_state, mock_runtime)
|
||||
|
||||
# Verify locale is passed correctly
|
||||
mock_apply_template.assert_called_once_with(
|
||||
"test_template", mock_state, locale="zh-CN"
|
||||
)
|
||||
assert result == {"messages": [mock_messages[0]]}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.agents.agents.apply_prompt_template")
|
||||
async def test_abefore_model(
|
||||
self, mock_apply_template, mock_state, mock_runtime, mock_messages
|
||||
):
|
||||
"""Test async version of before_model."""
|
||||
mock_apply_template.return_value = mock_messages
|
||||
middleware = DynamicPromptMiddleware("test_template")
|
||||
|
||||
result = await middleware.abefore_model(mock_state, mock_runtime)
|
||||
|
||||
# Should call the sync version and return same result
|
||||
assert result == {"messages": [mock_messages[0]]}
|
||||
mock_apply_template.assert_called_once_with(
|
||||
"test_template", mock_state, locale="en-US"
|
||||
)
|
||||
|
||||
|
||||
class TestPreModelHookMiddleware:
|
||||
"""Tests for PreModelHookMiddleware class."""
|
||||
|
||||
def test_init(self):
|
||||
"""Test middleware initialization."""
|
||||
hook = Mock()
|
||||
middleware = PreModelHookMiddleware(hook)
|
||||
assert middleware._pre_model_hook == hook
|
||||
|
||||
def test_before_model_with_sync_hook(self, mock_state, mock_runtime):
|
||||
"""Test before_model with synchronous hook."""
|
||||
hook = Mock(return_value={"custom_data": "test"})
|
||||
middleware = PreModelHookMiddleware(hook)
|
||||
|
||||
result = middleware.before_model(mock_state, mock_runtime)
|
||||
|
||||
# Verify hook was called with correct arguments
|
||||
hook.assert_called_once_with(mock_state, mock_runtime)
|
||||
assert result == {"custom_data": "test"}
|
||||
|
||||
def test_before_model_with_none_hook(self, mock_state, mock_runtime):
|
||||
"""Test before_model when hook is None."""
|
||||
middleware = PreModelHookMiddleware(None)
|
||||
|
||||
result = middleware.before_model(mock_state, mock_runtime)
|
||||
|
||||
# Should return None when hook is None
|
||||
assert result is None
|
||||
|
||||
def test_before_model_hook_returns_none(self, mock_state, mock_runtime):
|
||||
"""Test before_model when hook returns None."""
|
||||
hook = Mock(return_value=None)
|
||||
middleware = PreModelHookMiddleware(hook)
|
||||
|
||||
result = middleware.before_model(mock_state, mock_runtime)
|
||||
|
||||
hook.assert_called_once_with(mock_state, mock_runtime)
|
||||
assert result is None
|
||||
|
||||
@patch("src.agents.agents.logger")
|
||||
def test_before_model_hook_exception(
|
||||
self, mock_logger, mock_state, mock_runtime
|
||||
):
|
||||
"""Test before_model handles hook exceptions gracefully."""
|
||||
hook = Mock(side_effect=RuntimeError("Hook execution failed"))
|
||||
middleware = PreModelHookMiddleware(hook)
|
||||
|
||||
result = middleware.before_model(mock_state, mock_runtime)
|
||||
|
||||
# Should return None on exception
|
||||
assert result is None
|
||||
|
||||
# Should log error with exc_info
|
||||
mock_logger.error.assert_called_once()
|
||||
error_message = mock_logger.error.call_args[0][0]
|
||||
assert "Pre-model hook execution failed in before_model" in error_message
|
||||
assert mock_logger.error.call_args[1]["exc_info"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abefore_model_with_async_hook(self, mock_state, mock_runtime):
|
||||
"""Test async before_model with async hook."""
|
||||
async def async_hook(state, runtime):
|
||||
await asyncio.sleep(0.001) # Simulate async work
|
||||
return {"async_data": "test"}
|
||||
|
||||
middleware = PreModelHookMiddleware(async_hook)
|
||||
|
||||
result = await middleware.abefore_model(mock_state, mock_runtime)
|
||||
|
||||
assert result == {"async_data": "test"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.agents.agents.asyncio.to_thread")
|
||||
async def test_abefore_model_with_sync_hook(
|
||||
self, mock_to_thread, mock_state, mock_runtime
|
||||
):
|
||||
"""Test async before_model with synchronous hook uses asyncio.to_thread."""
|
||||
hook = Mock(return_value={"sync_data": "test"})
|
||||
mock_to_thread.return_value = {"sync_data": "test"}
|
||||
middleware = PreModelHookMiddleware(hook)
|
||||
|
||||
result = await middleware.abefore_model(mock_state, mock_runtime)
|
||||
|
||||
# Verify asyncio.to_thread was called with the sync hook
|
||||
mock_to_thread.assert_called_once_with(hook, mock_state, mock_runtime)
|
||||
assert result == {"sync_data": "test"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abefore_model_with_none_hook(self, mock_state, mock_runtime):
|
||||
"""Test async before_model when hook is None."""
|
||||
middleware = PreModelHookMiddleware(None)
|
||||
|
||||
result = await middleware.abefore_model(mock_state, mock_runtime)
|
||||
|
||||
# Should return None when hook is None
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.agents.agents.logger")
|
||||
async def test_abefore_model_async_hook_exception(
|
||||
self, mock_logger, mock_state, mock_runtime
|
||||
):
|
||||
"""Test async before_model handles async hook exceptions gracefully."""
|
||||
async def failing_hook(state, runtime):
|
||||
raise ValueError("Async hook failed")
|
||||
|
||||
middleware = PreModelHookMiddleware(failing_hook)
|
||||
|
||||
result = await middleware.abefore_model(mock_state, mock_runtime)
|
||||
|
||||
# Should return None on exception
|
||||
assert result is None
|
||||
|
||||
# Should log error with exc_info
|
||||
mock_logger.error.assert_called_once()
|
||||
error_message = mock_logger.error.call_args[0][0]
|
||||
assert "Pre-model hook execution failed in abefore_model" in error_message
|
||||
assert mock_logger.error.call_args[1]["exc_info"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.agents.agents.asyncio.to_thread")
|
||||
@patch("src.agents.agents.logger")
|
||||
async def test_abefore_model_sync_hook_exception(
|
||||
self, mock_logger, mock_to_thread, mock_state, mock_runtime
|
||||
):
|
||||
"""Test async before_model handles sync hook exceptions gracefully."""
|
||||
hook = Mock()
|
||||
mock_to_thread.side_effect = RuntimeError("Thread execution failed")
|
||||
middleware = PreModelHookMiddleware(hook)
|
||||
|
||||
result = await middleware.abefore_model(mock_state, mock_runtime)
|
||||
|
||||
# Should return None on exception
|
||||
assert result is None
|
||||
|
||||
# Should log error with exc_info
|
||||
mock_logger.error.assert_called_once()
|
||||
error_message = mock_logger.error.call_args[0][0]
|
||||
assert "Pre-model hook execution failed in abefore_model" in error_message
|
||||
assert mock_logger.error.call_args[1]["exc_info"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abefore_model_sync_hook_actual_execution(
|
||||
self, mock_state, mock_runtime
|
||||
):
|
||||
"""Test async before_model actually runs sync hook in thread pool."""
|
||||
# Track if hook was called
|
||||
hook_called = []
|
||||
|
||||
def sync_hook(state, runtime):
|
||||
hook_called.append(True)
|
||||
return {"data": "from_sync_hook"}
|
||||
|
||||
middleware = PreModelHookMiddleware(sync_hook)
|
||||
|
||||
result = await middleware.abefore_model(mock_state, mock_runtime)
|
||||
|
||||
# Verify hook was called and result returned
|
||||
assert len(hook_called) == 1
|
||||
assert result == {"data": "from_sync_hook"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abefore_model_detects_coroutine_function(
|
||||
self, mock_state, mock_runtime
|
||||
):
|
||||
"""Test that abefore_model correctly detects async vs sync functions."""
|
||||
# Test with async function
|
||||
async def async_hook(state, runtime):
|
||||
return {"type": "async"}
|
||||
|
||||
# Test with sync function
|
||||
def sync_hook(state, runtime):
|
||||
return {"type": "sync"}
|
||||
|
||||
async_middleware = PreModelHookMiddleware(async_hook)
|
||||
sync_middleware = PreModelHookMiddleware(sync_hook)
|
||||
|
||||
# Both should execute successfully
|
||||
async_result = await async_middleware.abefore_model(mock_state, mock_runtime)
|
||||
sync_result = await sync_middleware.abefore_model(mock_state, mock_runtime)
|
||||
|
||||
assert async_result == {"type": "async"}
|
||||
assert sync_result == {"type": "sync"}
|
||||
@@ -2,10 +2,10 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for agent locale restoration after create_react_agent execution.
|
||||
Unit tests for agent locale restoration after agent execution.
|
||||
|
||||
Tests that meta fields (especially locale) are properly restored after
|
||||
agent.ainvoke() returns, since create_react_agent creates a MessagesState
|
||||
agent.ainvoke() returns, since the agent creates a MessagesState
|
||||
subgraph that filters out custom fields.
|
||||
"""
|
||||
|
||||
@@ -22,7 +22,7 @@ class TestAgentLocaleRestoration:
|
||||
"""
|
||||
Demonstrate the problem: agent subgraph filters out locale.
|
||||
|
||||
When create_react_agent creates a subgraph with MessagesState,
|
||||
When the agent creates a subgraph with MessagesState,
|
||||
it only returns messages, not custom fields.
|
||||
"""
|
||||
# Simulate agent behavior: only returns messages
|
||||
|
||||
Reference in New Issue
Block a user