mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-10 01:04:46 +08:00
fix: migrate from deprecated create_react_agent to langchain.agents.create_agent (#802)
* fix: migrate from deprecated create_react_agent to langchain.agents.create_agent Fixes #799 - Replace deprecated langgraph.prebuilt.create_react_agent with langchain.agents.create_agent (LangGraph 1.0 migration) - Add DynamicPromptMiddleware to handle dynamic prompt templates (replaces the 'prompt' callable parameter) - Add PreModelHookMiddleware to handle pre-model hooks (replaces the 'pre_model_hook' parameter) - Update AgentState import from langchain.agents in template.py - Update tests to use the new API * fix:update the code with review comments
This commit is contained in:
@@ -1,10 +1,14 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from langchain.agents import create_agent as langchain_create_agent
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from src.agents.tool_interceptor import wrap_tools_with_interceptor
|
||||
from src.config.agents import AGENT_LLM_MAP
|
||||
@@ -14,6 +18,88 @@ from src.prompts import apply_prompt_template
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DynamicPromptMiddleware(AgentMiddleware):
|
||||
"""Middleware to apply dynamic prompt template before model invocation.
|
||||
|
||||
This middleware prepends a system message with the rendered prompt template
|
||||
to the messages list before the model is called.
|
||||
"""
|
||||
|
||||
def __init__(self, prompt_template: str, locale: str = "en-US"):
|
||||
self.prompt_template = prompt_template
|
||||
self.locale = locale
|
||||
|
||||
def before_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None:
|
||||
"""Apply prompt template and prepend system message to messages."""
|
||||
try:
|
||||
# Get the rendered messages including system prompt from template
|
||||
rendered_messages = apply_prompt_template(
|
||||
self.prompt_template, state, locale=self.locale
|
||||
)
|
||||
# The first message is the system prompt, extract it
|
||||
if rendered_messages and len(rendered_messages) > 0:
|
||||
system_message = rendered_messages[0]
|
||||
# Prepend system message to existing messages
|
||||
return {"messages": [system_message]}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to apply prompt template in before_model: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
async def abefore_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None:
|
||||
"""Async version of before_model."""
|
||||
return self.before_model(state, runtime)
|
||||
|
||||
|
||||
class PreModelHookMiddleware(AgentMiddleware):
|
||||
"""Middleware to execute a pre-model hook before model invocation.
|
||||
|
||||
This middleware wraps the legacy pre_model_hook callable and executes it
|
||||
as part of the middleware chain.
|
||||
"""
|
||||
|
||||
def __init__(self, pre_model_hook: Callable):
|
||||
self._pre_model_hook = pre_model_hook
|
||||
|
||||
def before_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None:
|
||||
"""Execute the pre-model hook."""
|
||||
if not self._pre_model_hook:
|
||||
return None
|
||||
|
||||
try:
|
||||
result = self._pre_model_hook(state, runtime)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Pre-model hook execution failed in before_model: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
async def abefore_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None:
|
||||
"""Async version of before_model."""
|
||||
if not self._pre_model_hook:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Check if the hook is async
|
||||
if inspect.iscoroutinefunction(self._pre_model_hook):
|
||||
result = await self._pre_model_hook(state, runtime)
|
||||
else:
|
||||
# Run synchronous hook in thread pool to avoid blocking event loop
|
||||
result = await asyncio.to_thread(self._pre_model_hook, state, runtime)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Pre-model hook execution failed in abefore_model: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
# Create agents using configured LLM types
|
||||
def create_agent(
|
||||
agent_name: str,
|
||||
@@ -64,18 +150,23 @@ def create_agent(
|
||||
llm_type = AGENT_LLM_MAP.get(agent_type, "basic")
|
||||
logger.debug(f"Agent '{agent_name}' using LLM type: {llm_type}")
|
||||
|
||||
logger.debug(f"Creating ReAct agent '{agent_name}' with locale: {locale}")
|
||||
logger.debug(f"Creating agent '{agent_name}' with locale: {locale}")
|
||||
|
||||
# Build middleware list
|
||||
# Use closure to capture locale from the workflow state instead of relying on
|
||||
# agent state.get("locale"), which doesn't have the locale field
|
||||
# See: https://github.com/bytedance/deer-flow/issues/743
|
||||
agent = create_react_agent(
|
||||
middleware = [DynamicPromptMiddleware(prompt_template, locale)]
|
||||
|
||||
# Add pre-model hook middleware if provided
|
||||
if pre_model_hook:
|
||||
middleware.append(PreModelHookMiddleware(pre_model_hook))
|
||||
|
||||
agent = langchain_create_agent(
|
||||
name=agent_name,
|
||||
model=get_llm_by_type(llm_type),
|
||||
tools=processed_tools,
|
||||
prompt=lambda state, captured_locale=locale: apply_prompt_template(
|
||||
prompt_template, state, locale=captured_locale
|
||||
),
|
||||
pre_model_hook=pre_model_hook,
|
||||
middleware=middleware,
|
||||
)
|
||||
logger.info(f"Agent '{agent_name}' created successfully")
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import os
|
||||
from datetime import datetime
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader, TemplateNotFound, select_autoescape
|
||||
from langgraph.prebuilt.chat_agent_executor import AgentState
|
||||
from langchain.agents import AgentState
|
||||
|
||||
from src.config.configuration import Configuration
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ class TestToolInterceptorIntegration:
|
||||
tools = [search_tool, db_tool]
|
||||
|
||||
# Create agent with interrupts on db_tool only
|
||||
with patch("src.agents.agents.create_react_agent") as mock_create, \
|
||||
with patch("src.agents.agents.langchain_create_agent") as mock_create, \
|
||||
patch("src.agents.agents.get_llm_by_type") as mock_llm:
|
||||
mock_create.return_value = MagicMock()
|
||||
mock_llm.return_value = MagicMock()
|
||||
@@ -55,7 +55,7 @@ class TestToolInterceptorIntegration:
|
||||
interrupt_before_tools=["db_tool"],
|
||||
)
|
||||
|
||||
# Verify create_react_agent was called with wrapped tools
|
||||
# Verify langchain_create_agent was called with wrapped tools
|
||||
assert mock_create.called
|
||||
call_args = mock_create.call_args
|
||||
wrapped_tools = call_args.kwargs["tools"]
|
||||
|
||||
335
tests/unit/agents/test_middleware.py
Normal file
335
tests/unit/agents/test_middleware.py
Normal file
@@ -0,0 +1,335 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from src.agents.agents import DynamicPromptMiddleware, PreModelHookMiddleware
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runtime():
|
||||
"""Mock Runtime object."""
|
||||
runtime = MagicMock()
|
||||
runtime.config = {}
|
||||
return runtime
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_state():
|
||||
"""Mock state object."""
|
||||
return {
|
||||
"messages": [HumanMessage(content="Test message")],
|
||||
"context": "Test context",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_messages():
|
||||
"""Mock messages returned by apply_prompt_template."""
|
||||
return [
|
||||
SystemMessage(content="Test system prompt"),
|
||||
HumanMessage(content="Test human message"),
|
||||
]
|
||||
|
||||
|
||||
class TestDynamicPromptMiddleware:
|
||||
"""Tests for DynamicPromptMiddleware class."""
|
||||
|
||||
def test_init(self):
|
||||
"""Test middleware initialization."""
|
||||
middleware = DynamicPromptMiddleware("test_template", locale="zh-CN")
|
||||
assert middleware.prompt_template == "test_template"
|
||||
assert middleware.locale == "zh-CN"
|
||||
|
||||
def test_init_default_locale(self):
|
||||
"""Test middleware initialization with default locale."""
|
||||
middleware = DynamicPromptMiddleware("test_template")
|
||||
assert middleware.prompt_template == "test_template"
|
||||
assert middleware.locale == "en-US"
|
||||
|
||||
@patch("src.agents.agents.apply_prompt_template")
|
||||
def test_before_model_success(
|
||||
self, mock_apply_template, mock_state, mock_runtime, mock_messages
|
||||
):
|
||||
"""Test before_model successfully applies prompt template."""
|
||||
mock_apply_template.return_value = mock_messages
|
||||
middleware = DynamicPromptMiddleware("test_template", locale="en-US")
|
||||
|
||||
result = middleware.before_model(mock_state, mock_runtime)
|
||||
|
||||
# Verify apply_prompt_template was called with correct arguments
|
||||
mock_apply_template.assert_called_once_with(
|
||||
"test_template", mock_state, locale="en-US"
|
||||
)
|
||||
|
||||
# Verify system message is returned
|
||||
assert result == {"messages": [mock_messages[0]]}
|
||||
assert result["messages"][0].content == "Test system prompt"
|
||||
|
||||
@patch("src.agents.agents.apply_prompt_template")
|
||||
def test_before_model_empty_messages(
|
||||
self, mock_apply_template, mock_state, mock_runtime
|
||||
):
|
||||
"""Test before_model with empty message list."""
|
||||
mock_apply_template.return_value = []
|
||||
middleware = DynamicPromptMiddleware("test_template")
|
||||
|
||||
result = middleware.before_model(mock_state, mock_runtime)
|
||||
|
||||
# Should return None when no messages are rendered
|
||||
assert result is None
|
||||
|
||||
@patch("src.agents.agents.apply_prompt_template")
|
||||
def test_before_model_none_messages(
|
||||
self, mock_apply_template, mock_state, mock_runtime
|
||||
):
|
||||
"""Test before_model when apply_prompt_template returns None."""
|
||||
mock_apply_template.return_value = None
|
||||
middleware = DynamicPromptMiddleware("test_template")
|
||||
|
||||
result = middleware.before_model(mock_state, mock_runtime)
|
||||
|
||||
# Should return None when template returns None
|
||||
assert result is None
|
||||
|
||||
@patch("src.agents.agents.apply_prompt_template")
|
||||
@patch("src.agents.agents.logger")
|
||||
def test_before_model_exception_handling(
|
||||
self, mock_logger, mock_apply_template, mock_state, mock_runtime
|
||||
):
|
||||
"""Test before_model handles exceptions gracefully."""
|
||||
mock_apply_template.side_effect = ValueError("Template rendering failed")
|
||||
middleware = DynamicPromptMiddleware("test_template")
|
||||
|
||||
result = middleware.before_model(mock_state, mock_runtime)
|
||||
|
||||
# Should return None on exception
|
||||
assert result is None
|
||||
|
||||
# Should log error with exc_info
|
||||
mock_logger.error.assert_called_once()
|
||||
error_message = mock_logger.error.call_args[0][0]
|
||||
assert "Failed to apply prompt template in before_model" in error_message
|
||||
assert mock_logger.error.call_args[1]["exc_info"] is True
|
||||
|
||||
@patch("src.agents.agents.apply_prompt_template")
|
||||
def test_before_model_with_different_locale(
|
||||
self, mock_apply_template, mock_state, mock_runtime, mock_messages
|
||||
):
|
||||
"""Test before_model with different locale."""
|
||||
mock_apply_template.return_value = mock_messages
|
||||
middleware = DynamicPromptMiddleware("test_template", locale="zh-CN")
|
||||
|
||||
result = middleware.before_model(mock_state, mock_runtime)
|
||||
|
||||
# Verify locale is passed correctly
|
||||
mock_apply_template.assert_called_once_with(
|
||||
"test_template", mock_state, locale="zh-CN"
|
||||
)
|
||||
assert result == {"messages": [mock_messages[0]]}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.agents.agents.apply_prompt_template")
|
||||
async def test_abefore_model(
|
||||
self, mock_apply_template, mock_state, mock_runtime, mock_messages
|
||||
):
|
||||
"""Test async version of before_model."""
|
||||
mock_apply_template.return_value = mock_messages
|
||||
middleware = DynamicPromptMiddleware("test_template")
|
||||
|
||||
result = await middleware.abefore_model(mock_state, mock_runtime)
|
||||
|
||||
# Should call the sync version and return same result
|
||||
assert result == {"messages": [mock_messages[0]]}
|
||||
mock_apply_template.assert_called_once_with(
|
||||
"test_template", mock_state, locale="en-US"
|
||||
)
|
||||
|
||||
|
||||
class TestPreModelHookMiddleware:
|
||||
"""Tests for PreModelHookMiddleware class."""
|
||||
|
||||
def test_init(self):
|
||||
"""Test middleware initialization."""
|
||||
hook = Mock()
|
||||
middleware = PreModelHookMiddleware(hook)
|
||||
assert middleware._pre_model_hook == hook
|
||||
|
||||
def test_before_model_with_sync_hook(self, mock_state, mock_runtime):
|
||||
"""Test before_model with synchronous hook."""
|
||||
hook = Mock(return_value={"custom_data": "test"})
|
||||
middleware = PreModelHookMiddleware(hook)
|
||||
|
||||
result = middleware.before_model(mock_state, mock_runtime)
|
||||
|
||||
# Verify hook was called with correct arguments
|
||||
hook.assert_called_once_with(mock_state, mock_runtime)
|
||||
assert result == {"custom_data": "test"}
|
||||
|
||||
def test_before_model_with_none_hook(self, mock_state, mock_runtime):
|
||||
"""Test before_model when hook is None."""
|
||||
middleware = PreModelHookMiddleware(None)
|
||||
|
||||
result = middleware.before_model(mock_state, mock_runtime)
|
||||
|
||||
# Should return None when hook is None
|
||||
assert result is None
|
||||
|
||||
def test_before_model_hook_returns_none(self, mock_state, mock_runtime):
|
||||
"""Test before_model when hook returns None."""
|
||||
hook = Mock(return_value=None)
|
||||
middleware = PreModelHookMiddleware(hook)
|
||||
|
||||
result = middleware.before_model(mock_state, mock_runtime)
|
||||
|
||||
hook.assert_called_once_with(mock_state, mock_runtime)
|
||||
assert result is None
|
||||
|
||||
@patch("src.agents.agents.logger")
|
||||
def test_before_model_hook_exception(
|
||||
self, mock_logger, mock_state, mock_runtime
|
||||
):
|
||||
"""Test before_model handles hook exceptions gracefully."""
|
||||
hook = Mock(side_effect=RuntimeError("Hook execution failed"))
|
||||
middleware = PreModelHookMiddleware(hook)
|
||||
|
||||
result = middleware.before_model(mock_state, mock_runtime)
|
||||
|
||||
# Should return None on exception
|
||||
assert result is None
|
||||
|
||||
# Should log error with exc_info
|
||||
mock_logger.error.assert_called_once()
|
||||
error_message = mock_logger.error.call_args[0][0]
|
||||
assert "Pre-model hook execution failed in before_model" in error_message
|
||||
assert mock_logger.error.call_args[1]["exc_info"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abefore_model_with_async_hook(self, mock_state, mock_runtime):
|
||||
"""Test async before_model with async hook."""
|
||||
async def async_hook(state, runtime):
|
||||
await asyncio.sleep(0.001) # Simulate async work
|
||||
return {"async_data": "test"}
|
||||
|
||||
middleware = PreModelHookMiddleware(async_hook)
|
||||
|
||||
result = await middleware.abefore_model(mock_state, mock_runtime)
|
||||
|
||||
assert result == {"async_data": "test"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.agents.agents.asyncio.to_thread")
|
||||
async def test_abefore_model_with_sync_hook(
|
||||
self, mock_to_thread, mock_state, mock_runtime
|
||||
):
|
||||
"""Test async before_model with synchronous hook uses asyncio.to_thread."""
|
||||
hook = Mock(return_value={"sync_data": "test"})
|
||||
mock_to_thread.return_value = {"sync_data": "test"}
|
||||
middleware = PreModelHookMiddleware(hook)
|
||||
|
||||
result = await middleware.abefore_model(mock_state, mock_runtime)
|
||||
|
||||
# Verify asyncio.to_thread was called with the sync hook
|
||||
mock_to_thread.assert_called_once_with(hook, mock_state, mock_runtime)
|
||||
assert result == {"sync_data": "test"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abefore_model_with_none_hook(self, mock_state, mock_runtime):
|
||||
"""Test async before_model when hook is None."""
|
||||
middleware = PreModelHookMiddleware(None)
|
||||
|
||||
result = await middleware.abefore_model(mock_state, mock_runtime)
|
||||
|
||||
# Should return None when hook is None
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.agents.agents.logger")
|
||||
async def test_abefore_model_async_hook_exception(
|
||||
self, mock_logger, mock_state, mock_runtime
|
||||
):
|
||||
"""Test async before_model handles async hook exceptions gracefully."""
|
||||
async def failing_hook(state, runtime):
|
||||
raise ValueError("Async hook failed")
|
||||
|
||||
middleware = PreModelHookMiddleware(failing_hook)
|
||||
|
||||
result = await middleware.abefore_model(mock_state, mock_runtime)
|
||||
|
||||
# Should return None on exception
|
||||
assert result is None
|
||||
|
||||
# Should log error with exc_info
|
||||
mock_logger.error.assert_called_once()
|
||||
error_message = mock_logger.error.call_args[0][0]
|
||||
assert "Pre-model hook execution failed in abefore_model" in error_message
|
||||
assert mock_logger.error.call_args[1]["exc_info"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.agents.agents.asyncio.to_thread")
|
||||
@patch("src.agents.agents.logger")
|
||||
async def test_abefore_model_sync_hook_exception(
|
||||
self, mock_logger, mock_to_thread, mock_state, mock_runtime
|
||||
):
|
||||
"""Test async before_model handles sync hook exceptions gracefully."""
|
||||
hook = Mock()
|
||||
mock_to_thread.side_effect = RuntimeError("Thread execution failed")
|
||||
middleware = PreModelHookMiddleware(hook)
|
||||
|
||||
result = await middleware.abefore_model(mock_state, mock_runtime)
|
||||
|
||||
# Should return None on exception
|
||||
assert result is None
|
||||
|
||||
# Should log error with exc_info
|
||||
mock_logger.error.assert_called_once()
|
||||
error_message = mock_logger.error.call_args[0][0]
|
||||
assert "Pre-model hook execution failed in abefore_model" in error_message
|
||||
assert mock_logger.error.call_args[1]["exc_info"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abefore_model_sync_hook_actual_execution(
|
||||
self, mock_state, mock_runtime
|
||||
):
|
||||
"""Test async before_model actually runs sync hook in thread pool."""
|
||||
# Track if hook was called
|
||||
hook_called = []
|
||||
|
||||
def sync_hook(state, runtime):
|
||||
hook_called.append(True)
|
||||
return {"data": "from_sync_hook"}
|
||||
|
||||
middleware = PreModelHookMiddleware(sync_hook)
|
||||
|
||||
result = await middleware.abefore_model(mock_state, mock_runtime)
|
||||
|
||||
# Verify hook was called and result returned
|
||||
assert len(hook_called) == 1
|
||||
assert result == {"data": "from_sync_hook"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abefore_model_detects_coroutine_function(
|
||||
self, mock_state, mock_runtime
|
||||
):
|
||||
"""Test that abefore_model correctly detects async vs sync functions."""
|
||||
# Test with async function
|
||||
async def async_hook(state, runtime):
|
||||
return {"type": "async"}
|
||||
|
||||
# Test with sync function
|
||||
def sync_hook(state, runtime):
|
||||
return {"type": "sync"}
|
||||
|
||||
async_middleware = PreModelHookMiddleware(async_hook)
|
||||
sync_middleware = PreModelHookMiddleware(sync_hook)
|
||||
|
||||
# Both should execute successfully
|
||||
async_result = await async_middleware.abefore_model(mock_state, mock_runtime)
|
||||
sync_result = await sync_middleware.abefore_model(mock_state, mock_runtime)
|
||||
|
||||
assert async_result == {"type": "async"}
|
||||
assert sync_result == {"type": "sync"}
|
||||
@@ -2,10 +2,10 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for agent locale restoration after create_react_agent execution.
|
||||
Unit tests for agent locale restoration after agent execution.
|
||||
|
||||
Tests that meta fields (especially locale) are properly restored after
|
||||
agent.ainvoke() returns, since create_react_agent creates a MessagesState
|
||||
agent.ainvoke() returns, since the agent creates a MessagesState
|
||||
subgraph that filters out custom fields.
|
||||
"""
|
||||
|
||||
@@ -22,7 +22,7 @@ class TestAgentLocaleRestoration:
|
||||
"""
|
||||
Demonstrate the problem: agent subgraph filters out locale.
|
||||
|
||||
When create_react_agent creates a subgraph with MessagesState,
|
||||
When the agent creates a subgraph with MessagesState,
|
||||
it only returns messages, not custom fields.
|
||||
"""
|
||||
# Simulate agent behavior: only returns messages
|
||||
|
||||
Reference in New Issue
Block a user