From d4ab77de5c630855b13c735828c61dcc076294cd Mon Sep 17 00:00:00 2001 From: Willem Jiang Date: Wed, 7 Jan 2026 09:06:16 +0800 Subject: [PATCH] fix: migrate from deprecated create_react_agent to langchain.agents.create_agent (#802) * fix: migrate from deprecated create_react_agent to langchain.agents.create_agent Fixes #799 - Replace deprecated langgraph.prebuilt.create_react_agent with langchain.agents.create_agent (LangGraph 1.0 migration) - Add DynamicPromptMiddleware to handle dynamic prompt templates (replaces the 'prompt' callable parameter) - Add PreModelHookMiddleware to handle pre-model hooks (replaces the 'pre_model_hook' parameter) - Update AgentState import from langchain.agents in template.py - Update tests to use the new API * fix:update the code with review comments --- src/agents/agents.py | 107 +++++- src/prompts/template.py | 2 +- .../test_tool_interceptor_integration.py | 4 +- tests/unit/agents/test_middleware.py | 335 ++++++++++++++++++ .../graph/test_agent_locale_restoration.py | 6 +- 5 files changed, 440 insertions(+), 14 deletions(-) create mode 100644 tests/unit/agents/test_middleware.py diff --git a/src/agents/agents.py b/src/agents/agents.py index 7310cc1..ee3b999 100644 --- a/src/agents/agents.py +++ b/src/agents/agents.py @@ -1,10 +1,14 @@ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # SPDX-License-Identifier: MIT +import asyncio +import inspect import logging -from typing import List, Optional +from typing import Any, Callable, List, Optional -from langgraph.prebuilt import create_react_agent +from langchain.agents import create_agent as langchain_create_agent +from langchain.agents.middleware import AgentMiddleware +from langgraph.runtime import Runtime from src.agents.tool_interceptor import wrap_tools_with_interceptor from src.config.agents import AGENT_LLM_MAP @@ -14,6 +18,88 @@ from src.prompts import apply_prompt_template logger = logging.getLogger(__name__) +class DynamicPromptMiddleware(AgentMiddleware): + """Middleware to apply dynamic prompt template before model invocation. + + This middleware prepends a system message with the rendered prompt template + to the messages list before the model is called. + """ + + def __init__(self, prompt_template: str, locale: str = "en-US"): + self.prompt_template = prompt_template + self.locale = locale + + def before_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None: + """Apply prompt template and prepend system message to messages.""" + try: + # Get the rendered messages including system prompt from template + rendered_messages = apply_prompt_template( + self.prompt_template, state, locale=self.locale + ) + # The first message is the system prompt, extract it + if rendered_messages and len(rendered_messages) > 0: + system_message = rendered_messages[0] + # Prepend system message to existing messages + return {"messages": [system_message]} + return None + except Exception as e: + logger.error( + f"Failed to apply prompt template in before_model: {e}", + exc_info=True + ) + return None + + async def abefore_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None: + """Async version of before_model.""" + return self.before_model(state, runtime) + + +class PreModelHookMiddleware(AgentMiddleware): + """Middleware to execute a pre-model hook before model invocation. + + This middleware wraps the legacy pre_model_hook callable and executes it + as part of the middleware chain. + """ + + def __init__(self, pre_model_hook: Callable): + self._pre_model_hook = pre_model_hook + + def before_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None: + """Execute the pre-model hook.""" + if not self._pre_model_hook: + return None + + try: + result = self._pre_model_hook(state, runtime) + return result + except Exception as e: + logger.error( + f"Pre-model hook execution failed in before_model: {e}", + exc_info=True + ) + return None + + async def abefore_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None: + """Async version of before_model.""" + if not self._pre_model_hook: + return None + + try: + # Check if the hook is async + if inspect.iscoroutinefunction(self._pre_model_hook): + result = await self._pre_model_hook(state, runtime) + else: + # Run synchronous hook in thread pool to avoid blocking event loop + result = await asyncio.to_thread(self._pre_model_hook, state, runtime) + return result + except Exception as e: + logger.error( + f"Pre-model hook execution failed in abefore_model: {e}", + exc_info=True + ) + return None + + # Create agents using configured LLM types def create_agent( agent_name: str, @@ -64,18 +150,23 @@ def create_agent( llm_type = AGENT_LLM_MAP.get(agent_type, "basic") logger.debug(f"Agent '{agent_name}' using LLM type: {llm_type}") - logger.debug(f"Creating ReAct agent '{agent_name}' with locale: {locale}") + logger.debug(f"Creating agent '{agent_name}' with locale: {locale}") + + # Build middleware list # Use closure to capture locale from the workflow state instead of relying on # agent state.get("locale"), which doesn't have the locale field # See: https://github.com/bytedance/deer-flow/issues/743 - agent = create_react_agent( + middleware = [DynamicPromptMiddleware(prompt_template, locale)] + + # Add pre-model hook middleware if provided + if pre_model_hook: + middleware.append(PreModelHookMiddleware(pre_model_hook)) + + agent = langchain_create_agent( name=agent_name, model=get_llm_by_type(llm_type), tools=processed_tools, - prompt=lambda state, captured_locale=locale: apply_prompt_template( - prompt_template, state, locale=captured_locale - ), - pre_model_hook=pre_model_hook, + middleware=middleware, ) logger.info(f"Agent '{agent_name}' created successfully") diff --git a/src/prompts/template.py b/src/prompts/template.py index 842e7c8..cba167f 100644 --- a/src/prompts/template.py +++ b/src/prompts/template.py @@ -6,7 +6,7 @@ import os from datetime import datetime from jinja2 import Environment, FileSystemLoader, TemplateNotFound, select_autoescape -from langgraph.prebuilt.chat_agent_executor import AgentState +from langchain.agents import AgentState from src.config.configuration import Configuration diff --git a/tests/integration/test_tool_interceptor_integration.py b/tests/integration/test_tool_interceptor_integration.py index 20b017e..73e37d6 100644 --- a/tests/integration/test_tool_interceptor_integration.py +++ b/tests/integration/test_tool_interceptor_integration.py @@ -42,7 +42,7 @@ class TestToolInterceptorIntegration: tools = [search_tool, db_tool] # Create agent with interrupts on db_tool only - with patch("src.agents.agents.create_react_agent") as mock_create, \ + with patch("src.agents.agents.langchain_create_agent") as mock_create, \ patch("src.agents.agents.get_llm_by_type") as mock_llm: mock_create.return_value = MagicMock() mock_llm.return_value = MagicMock() @@ -55,7 +55,7 @@ class TestToolInterceptorIntegration: interrupt_before_tools=["db_tool"], ) - # Verify create_react_agent was called with wrapped tools + # Verify langchain_create_agent was called with wrapped tools assert mock_create.called call_args = mock_create.call_args wrapped_tools = call_args.kwargs["tools"] diff --git a/tests/unit/agents/test_middleware.py b/tests/unit/agents/test_middleware.py new file mode 100644 index 0000000..8099462 --- /dev/null +++ b/tests/unit/agents/test_middleware.py @@ -0,0 +1,335 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import asyncio +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest +from langchain_core.messages import HumanMessage, SystemMessage + +from src.agents.agents import DynamicPromptMiddleware, PreModelHookMiddleware + + +@pytest.fixture +def mock_runtime(): + """Mock Runtime object.""" + runtime = MagicMock() + runtime.config = {} + return runtime + + +@pytest.fixture +def mock_state(): + """Mock state object.""" + return { + "messages": [HumanMessage(content="Test message")], + "context": "Test context", + } + + +@pytest.fixture +def mock_messages(): + """Mock messages returned by apply_prompt_template.""" + return [ + SystemMessage(content="Test system prompt"), + HumanMessage(content="Test human message"), + ] + + +class TestDynamicPromptMiddleware: + """Tests for DynamicPromptMiddleware class.""" + + def test_init(self): + """Test middleware initialization.""" + middleware = DynamicPromptMiddleware("test_template", locale="zh-CN") + assert middleware.prompt_template == "test_template" + assert middleware.locale == "zh-CN" + + def test_init_default_locale(self): + """Test middleware initialization with default locale.""" + middleware = DynamicPromptMiddleware("test_template") + assert middleware.prompt_template == "test_template" + assert middleware.locale == "en-US" + + @patch("src.agents.agents.apply_prompt_template") + def test_before_model_success( + self, mock_apply_template, mock_state, mock_runtime, mock_messages + ): + """Test before_model successfully applies prompt template.""" + mock_apply_template.return_value = mock_messages + middleware = DynamicPromptMiddleware("test_template", locale="en-US") + + result = middleware.before_model(mock_state, mock_runtime) + + # Verify apply_prompt_template was called with correct arguments + mock_apply_template.assert_called_once_with( + "test_template", mock_state, locale="en-US" + ) + + # Verify system message is returned + assert result == {"messages": [mock_messages[0]]} + assert result["messages"][0].content == "Test system prompt" + + @patch("src.agents.agents.apply_prompt_template") + def test_before_model_empty_messages( + self, mock_apply_template, mock_state, mock_runtime + ): + """Test before_model with empty message list.""" + mock_apply_template.return_value = [] + middleware = DynamicPromptMiddleware("test_template") + + result = middleware.before_model(mock_state, mock_runtime) + + # Should return None when no messages are rendered + assert result is None + + @patch("src.agents.agents.apply_prompt_template") + def test_before_model_none_messages( + self, mock_apply_template, mock_state, mock_runtime + ): + """Test before_model when apply_prompt_template returns None.""" + mock_apply_template.return_value = None + middleware = DynamicPromptMiddleware("test_template") + + result = middleware.before_model(mock_state, mock_runtime) + + # Should return None when template returns None + assert result is None + + @patch("src.agents.agents.apply_prompt_template") + @patch("src.agents.agents.logger") + def test_before_model_exception_handling( + self, mock_logger, mock_apply_template, mock_state, mock_runtime + ): + """Test before_model handles exceptions gracefully.""" + mock_apply_template.side_effect = ValueError("Template rendering failed") + middleware = DynamicPromptMiddleware("test_template") + + result = middleware.before_model(mock_state, mock_runtime) + + # Should return None on exception + assert result is None + + # Should log error with exc_info + mock_logger.error.assert_called_once() + error_message = mock_logger.error.call_args[0][0] + assert "Failed to apply prompt template in before_model" in error_message + assert mock_logger.error.call_args[1]["exc_info"] is True + + @patch("src.agents.agents.apply_prompt_template") + def test_before_model_with_different_locale( + self, mock_apply_template, mock_state, mock_runtime, mock_messages + ): + """Test before_model with different locale.""" + mock_apply_template.return_value = mock_messages + middleware = DynamicPromptMiddleware("test_template", locale="zh-CN") + + result = middleware.before_model(mock_state, mock_runtime) + + # Verify locale is passed correctly + mock_apply_template.assert_called_once_with( + "test_template", mock_state, locale="zh-CN" + ) + assert result == {"messages": [mock_messages[0]]} + + @pytest.mark.asyncio + @patch("src.agents.agents.apply_prompt_template") + async def test_abefore_model( + self, mock_apply_template, mock_state, mock_runtime, mock_messages + ): + """Test async version of before_model.""" + mock_apply_template.return_value = mock_messages + middleware = DynamicPromptMiddleware("test_template") + + result = await middleware.abefore_model(mock_state, mock_runtime) + + # Should call the sync version and return same result + assert result == {"messages": [mock_messages[0]]} + mock_apply_template.assert_called_once_with( + "test_template", mock_state, locale="en-US" + ) + + +class TestPreModelHookMiddleware: + """Tests for PreModelHookMiddleware class.""" + + def test_init(self): + """Test middleware initialization.""" + hook = Mock() + middleware = PreModelHookMiddleware(hook) + assert middleware._pre_model_hook == hook + + def test_before_model_with_sync_hook(self, mock_state, mock_runtime): + """Test before_model with synchronous hook.""" + hook = Mock(return_value={"custom_data": "test"}) + middleware = PreModelHookMiddleware(hook) + + result = middleware.before_model(mock_state, mock_runtime) + + # Verify hook was called with correct arguments + hook.assert_called_once_with(mock_state, mock_runtime) + assert result == {"custom_data": "test"} + + def test_before_model_with_none_hook(self, mock_state, mock_runtime): + """Test before_model when hook is None.""" + middleware = PreModelHookMiddleware(None) + + result = middleware.before_model(mock_state, mock_runtime) + + # Should return None when hook is None + assert result is None + + def test_before_model_hook_returns_none(self, mock_state, mock_runtime): + """Test before_model when hook returns None.""" + hook = Mock(return_value=None) + middleware = PreModelHookMiddleware(hook) + + result = middleware.before_model(mock_state, mock_runtime) + + hook.assert_called_once_with(mock_state, mock_runtime) + assert result is None + + @patch("src.agents.agents.logger") + def test_before_model_hook_exception( + self, mock_logger, mock_state, mock_runtime + ): + """Test before_model handles hook exceptions gracefully.""" + hook = Mock(side_effect=RuntimeError("Hook execution failed")) + middleware = PreModelHookMiddleware(hook) + + result = middleware.before_model(mock_state, mock_runtime) + + # Should return None on exception + assert result is None + + # Should log error with exc_info + mock_logger.error.assert_called_once() + error_message = mock_logger.error.call_args[0][0] + assert "Pre-model hook execution failed in before_model" in error_message + assert mock_logger.error.call_args[1]["exc_info"] is True + + @pytest.mark.asyncio + async def test_abefore_model_with_async_hook(self, mock_state, mock_runtime): + """Test async before_model with async hook.""" + async def async_hook(state, runtime): + await asyncio.sleep(0.001) # Simulate async work + return {"async_data": "test"} + + middleware = PreModelHookMiddleware(async_hook) + + result = await middleware.abefore_model(mock_state, mock_runtime) + + assert result == {"async_data": "test"} + + @pytest.mark.asyncio + @patch("src.agents.agents.asyncio.to_thread") + async def test_abefore_model_with_sync_hook( + self, mock_to_thread, mock_state, mock_runtime + ): + """Test async before_model with synchronous hook uses asyncio.to_thread.""" + hook = Mock(return_value={"sync_data": "test"}) + mock_to_thread.return_value = {"sync_data": "test"} + middleware = PreModelHookMiddleware(hook) + + result = await middleware.abefore_model(mock_state, mock_runtime) + + # Verify asyncio.to_thread was called with the sync hook + mock_to_thread.assert_called_once_with(hook, mock_state, mock_runtime) + assert result == {"sync_data": "test"} + + @pytest.mark.asyncio + async def test_abefore_model_with_none_hook(self, mock_state, mock_runtime): + """Test async before_model when hook is None.""" + middleware = PreModelHookMiddleware(None) + + result = await middleware.abefore_model(mock_state, mock_runtime) + + # Should return None when hook is None + assert result is None + + @pytest.mark.asyncio + @patch("src.agents.agents.logger") + async def test_abefore_model_async_hook_exception( + self, mock_logger, mock_state, mock_runtime + ): + """Test async before_model handles async hook exceptions gracefully.""" + async def failing_hook(state, runtime): + raise ValueError("Async hook failed") + + middleware = PreModelHookMiddleware(failing_hook) + + result = await middleware.abefore_model(mock_state, mock_runtime) + + # Should return None on exception + assert result is None + + # Should log error with exc_info + mock_logger.error.assert_called_once() + error_message = mock_logger.error.call_args[0][0] + assert "Pre-model hook execution failed in abefore_model" in error_message + assert mock_logger.error.call_args[1]["exc_info"] is True + + @pytest.mark.asyncio + @patch("src.agents.agents.asyncio.to_thread") + @patch("src.agents.agents.logger") + async def test_abefore_model_sync_hook_exception( + self, mock_logger, mock_to_thread, mock_state, mock_runtime + ): + """Test async before_model handles sync hook exceptions gracefully.""" + hook = Mock() + mock_to_thread.side_effect = RuntimeError("Thread execution failed") + middleware = PreModelHookMiddleware(hook) + + result = await middleware.abefore_model(mock_state, mock_runtime) + + # Should return None on exception + assert result is None + + # Should log error with exc_info + mock_logger.error.assert_called_once() + error_message = mock_logger.error.call_args[0][0] + assert "Pre-model hook execution failed in abefore_model" in error_message + assert mock_logger.error.call_args[1]["exc_info"] is True + + @pytest.mark.asyncio + async def test_abefore_model_sync_hook_actual_execution( + self, mock_state, mock_runtime + ): + """Test async before_model actually runs sync hook in thread pool.""" + # Track if hook was called + hook_called = [] + + def sync_hook(state, runtime): + hook_called.append(True) + return {"data": "from_sync_hook"} + + middleware = PreModelHookMiddleware(sync_hook) + + result = await middleware.abefore_model(mock_state, mock_runtime) + + # Verify hook was called and result returned + assert len(hook_called) == 1 + assert result == {"data": "from_sync_hook"} + + @pytest.mark.asyncio + async def test_abefore_model_detects_coroutine_function( + self, mock_state, mock_runtime + ): + """Test that abefore_model correctly detects async vs sync functions.""" + # Test with async function + async def async_hook(state, runtime): + return {"type": "async"} + + # Test with sync function + def sync_hook(state, runtime): + return {"type": "sync"} + + async_middleware = PreModelHookMiddleware(async_hook) + sync_middleware = PreModelHookMiddleware(sync_hook) + + # Both should execute successfully + async_result = await async_middleware.abefore_model(mock_state, mock_runtime) + sync_result = await sync_middleware.abefore_model(mock_state, mock_runtime) + + assert async_result == {"type": "async"} + assert sync_result == {"type": "sync"} diff --git a/tests/unit/graph/test_agent_locale_restoration.py b/tests/unit/graph/test_agent_locale_restoration.py index 34ade1e..570b3fb 100644 --- a/tests/unit/graph/test_agent_locale_restoration.py +++ b/tests/unit/graph/test_agent_locale_restoration.py @@ -2,10 +2,10 @@ # SPDX-License-Identifier: MIT """ -Unit tests for agent locale restoration after create_react_agent execution. +Unit tests for agent locale restoration after agent execution. Tests that meta fields (especially locale) are properly restored after -agent.ainvoke() returns, since create_react_agent creates a MessagesState +agent.ainvoke() returns, since the agent creates a MessagesState subgraph that filters out custom fields. """ @@ -22,7 +22,7 @@ class TestAgentLocaleRestoration: """ Demonstrate the problem: agent subgraph filters out locale. - When create_react_agent creates a subgraph with MessagesState, + When the agent creates a subgraph with MessagesState, it only returns messages, not custom fields. """ # Simulate agent behavior: only returns messages