fix: migrate from deprecated create_react_agent to langchain.agents.create_agent (#802)

* fix: migrate from deprecated create_react_agent to langchain.agents.create_agent

Fixes #799

- Replace deprecated langgraph.prebuilt.create_react_agent with
  langchain.agents.create_agent (LangGraph 1.0 migration)
- Add DynamicPromptMiddleware to handle dynamic prompt templates
  (replaces the 'prompt' callable parameter)
- Add PreModelHookMiddleware to handle pre-model hooks
  (replaces the 'pre_model_hook' parameter)
- Update AgentState import from langchain.agents in template.py
- Update tests to use the new API

* fix:update the code with review comments
This commit is contained in:
Willem Jiang
2026-01-07 09:06:16 +08:00
committed by GitHub
parent 1ced90b055
commit d4ab77de5c
5 changed files with 440 additions and 14 deletions

View File

@@ -1,10 +1,14 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import asyncio
import inspect
import logging
from typing import List, Optional
from typing import Any, Callable, List, Optional
from langgraph.prebuilt import create_react_agent
from langchain.agents import create_agent as langchain_create_agent
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from src.agents.tool_interceptor import wrap_tools_with_interceptor
from src.config.agents import AGENT_LLM_MAP
@@ -14,6 +18,88 @@ from src.prompts import apply_prompt_template
logger = logging.getLogger(__name__)
class DynamicPromptMiddleware(AgentMiddleware):
"""Middleware to apply dynamic prompt template before model invocation.
This middleware prepends a system message with the rendered prompt template
to the messages list before the model is called.
"""
def __init__(self, prompt_template: str, locale: str = "en-US"):
self.prompt_template = prompt_template
self.locale = locale
def before_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None:
"""Apply prompt template and prepend system message to messages."""
try:
# Get the rendered messages including system prompt from template
rendered_messages = apply_prompt_template(
self.prompt_template, state, locale=self.locale
)
# The first message is the system prompt, extract it
if rendered_messages and len(rendered_messages) > 0:
system_message = rendered_messages[0]
# Prepend system message to existing messages
return {"messages": [system_message]}
return None
except Exception as e:
logger.error(
f"Failed to apply prompt template in before_model: {e}",
exc_info=True
)
return None
async def abefore_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None:
"""Async version of before_model."""
return self.before_model(state, runtime)
class PreModelHookMiddleware(AgentMiddleware):
"""Middleware to execute a pre-model hook before model invocation.
This middleware wraps the legacy pre_model_hook callable and executes it
as part of the middleware chain.
"""
def __init__(self, pre_model_hook: Callable):
self._pre_model_hook = pre_model_hook
def before_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None:
"""Execute the pre-model hook."""
if not self._pre_model_hook:
return None
try:
result = self._pre_model_hook(state, runtime)
return result
except Exception as e:
logger.error(
f"Pre-model hook execution failed in before_model: {e}",
exc_info=True
)
return None
async def abefore_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None:
"""Async version of before_model."""
if not self._pre_model_hook:
return None
try:
# Check if the hook is async
if inspect.iscoroutinefunction(self._pre_model_hook):
result = await self._pre_model_hook(state, runtime)
else:
# Run synchronous hook in thread pool to avoid blocking event loop
result = await asyncio.to_thread(self._pre_model_hook, state, runtime)
return result
except Exception as e:
logger.error(
f"Pre-model hook execution failed in abefore_model: {e}",
exc_info=True
)
return None
# Create agents using configured LLM types
def create_agent(
agent_name: str,
@@ -64,18 +150,23 @@ def create_agent(
llm_type = AGENT_LLM_MAP.get(agent_type, "basic")
logger.debug(f"Agent '{agent_name}' using LLM type: {llm_type}")
logger.debug(f"Creating ReAct agent '{agent_name}' with locale: {locale}")
logger.debug(f"Creating agent '{agent_name}' with locale: {locale}")
# Build middleware list
# Use closure to capture locale from the workflow state instead of relying on
# agent state.get("locale"), which doesn't have the locale field
# See: https://github.com/bytedance/deer-flow/issues/743
agent = create_react_agent(
middleware = [DynamicPromptMiddleware(prompt_template, locale)]
# Add pre-model hook middleware if provided
if pre_model_hook:
middleware.append(PreModelHookMiddleware(pre_model_hook))
agent = langchain_create_agent(
name=agent_name,
model=get_llm_by_type(llm_type),
tools=processed_tools,
prompt=lambda state, captured_locale=locale: apply_prompt_template(
prompt_template, state, locale=captured_locale
),
pre_model_hook=pre_model_hook,
middleware=middleware,
)
logger.info(f"Agent '{agent_name}' created successfully")

View File

@@ -6,7 +6,7 @@ import os
from datetime import datetime
from jinja2 import Environment, FileSystemLoader, TemplateNotFound, select_autoescape
from langgraph.prebuilt.chat_agent_executor import AgentState
from langchain.agents import AgentState
from src.config.configuration import Configuration

View File

@@ -42,7 +42,7 @@ class TestToolInterceptorIntegration:
tools = [search_tool, db_tool]
# Create agent with interrupts on db_tool only
with patch("src.agents.agents.create_react_agent") as mock_create, \
with patch("src.agents.agents.langchain_create_agent") as mock_create, \
patch("src.agents.agents.get_llm_by_type") as mock_llm:
mock_create.return_value = MagicMock()
mock_llm.return_value = MagicMock()
@@ -55,7 +55,7 @@ class TestToolInterceptorIntegration:
interrupt_before_tools=["db_tool"],
)
# Verify create_react_agent was called with wrapped tools
# Verify langchain_create_agent was called with wrapped tools
assert mock_create.called
call_args = mock_create.call_args
wrapped_tools = call_args.kwargs["tools"]

View File

@@ -0,0 +1,335 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import asyncio
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from langchain_core.messages import HumanMessage, SystemMessage
from src.agents.agents import DynamicPromptMiddleware, PreModelHookMiddleware
@pytest.fixture
def mock_runtime():
"""Mock Runtime object."""
runtime = MagicMock()
runtime.config = {}
return runtime
@pytest.fixture
def mock_state():
"""Mock state object."""
return {
"messages": [HumanMessage(content="Test message")],
"context": "Test context",
}
@pytest.fixture
def mock_messages():
"""Mock messages returned by apply_prompt_template."""
return [
SystemMessage(content="Test system prompt"),
HumanMessage(content="Test human message"),
]
class TestDynamicPromptMiddleware:
"""Tests for DynamicPromptMiddleware class."""
def test_init(self):
"""Test middleware initialization."""
middleware = DynamicPromptMiddleware("test_template", locale="zh-CN")
assert middleware.prompt_template == "test_template"
assert middleware.locale == "zh-CN"
def test_init_default_locale(self):
"""Test middleware initialization with default locale."""
middleware = DynamicPromptMiddleware("test_template")
assert middleware.prompt_template == "test_template"
assert middleware.locale == "en-US"
@patch("src.agents.agents.apply_prompt_template")
def test_before_model_success(
self, mock_apply_template, mock_state, mock_runtime, mock_messages
):
"""Test before_model successfully applies prompt template."""
mock_apply_template.return_value = mock_messages
middleware = DynamicPromptMiddleware("test_template", locale="en-US")
result = middleware.before_model(mock_state, mock_runtime)
# Verify apply_prompt_template was called with correct arguments
mock_apply_template.assert_called_once_with(
"test_template", mock_state, locale="en-US"
)
# Verify system message is returned
assert result == {"messages": [mock_messages[0]]}
assert result["messages"][0].content == "Test system prompt"
@patch("src.agents.agents.apply_prompt_template")
def test_before_model_empty_messages(
self, mock_apply_template, mock_state, mock_runtime
):
"""Test before_model with empty message list."""
mock_apply_template.return_value = []
middleware = DynamicPromptMiddleware("test_template")
result = middleware.before_model(mock_state, mock_runtime)
# Should return None when no messages are rendered
assert result is None
@patch("src.agents.agents.apply_prompt_template")
def test_before_model_none_messages(
self, mock_apply_template, mock_state, mock_runtime
):
"""Test before_model when apply_prompt_template returns None."""
mock_apply_template.return_value = None
middleware = DynamicPromptMiddleware("test_template")
result = middleware.before_model(mock_state, mock_runtime)
# Should return None when template returns None
assert result is None
@patch("src.agents.agents.apply_prompt_template")
@patch("src.agents.agents.logger")
def test_before_model_exception_handling(
self, mock_logger, mock_apply_template, mock_state, mock_runtime
):
"""Test before_model handles exceptions gracefully."""
mock_apply_template.side_effect = ValueError("Template rendering failed")
middleware = DynamicPromptMiddleware("test_template")
result = middleware.before_model(mock_state, mock_runtime)
# Should return None on exception
assert result is None
# Should log error with exc_info
mock_logger.error.assert_called_once()
error_message = mock_logger.error.call_args[0][0]
assert "Failed to apply prompt template in before_model" in error_message
assert mock_logger.error.call_args[1]["exc_info"] is True
@patch("src.agents.agents.apply_prompt_template")
def test_before_model_with_different_locale(
self, mock_apply_template, mock_state, mock_runtime, mock_messages
):
"""Test before_model with different locale."""
mock_apply_template.return_value = mock_messages
middleware = DynamicPromptMiddleware("test_template", locale="zh-CN")
result = middleware.before_model(mock_state, mock_runtime)
# Verify locale is passed correctly
mock_apply_template.assert_called_once_with(
"test_template", mock_state, locale="zh-CN"
)
assert result == {"messages": [mock_messages[0]]}
@pytest.mark.asyncio
@patch("src.agents.agents.apply_prompt_template")
async def test_abefore_model(
self, mock_apply_template, mock_state, mock_runtime, mock_messages
):
"""Test async version of before_model."""
mock_apply_template.return_value = mock_messages
middleware = DynamicPromptMiddleware("test_template")
result = await middleware.abefore_model(mock_state, mock_runtime)
# Should call the sync version and return same result
assert result == {"messages": [mock_messages[0]]}
mock_apply_template.assert_called_once_with(
"test_template", mock_state, locale="en-US"
)
class TestPreModelHookMiddleware:
"""Tests for PreModelHookMiddleware class."""
def test_init(self):
"""Test middleware initialization."""
hook = Mock()
middleware = PreModelHookMiddleware(hook)
assert middleware._pre_model_hook == hook
def test_before_model_with_sync_hook(self, mock_state, mock_runtime):
"""Test before_model with synchronous hook."""
hook = Mock(return_value={"custom_data": "test"})
middleware = PreModelHookMiddleware(hook)
result = middleware.before_model(mock_state, mock_runtime)
# Verify hook was called with correct arguments
hook.assert_called_once_with(mock_state, mock_runtime)
assert result == {"custom_data": "test"}
def test_before_model_with_none_hook(self, mock_state, mock_runtime):
"""Test before_model when hook is None."""
middleware = PreModelHookMiddleware(None)
result = middleware.before_model(mock_state, mock_runtime)
# Should return None when hook is None
assert result is None
def test_before_model_hook_returns_none(self, mock_state, mock_runtime):
"""Test before_model when hook returns None."""
hook = Mock(return_value=None)
middleware = PreModelHookMiddleware(hook)
result = middleware.before_model(mock_state, mock_runtime)
hook.assert_called_once_with(mock_state, mock_runtime)
assert result is None
@patch("src.agents.agents.logger")
def test_before_model_hook_exception(
self, mock_logger, mock_state, mock_runtime
):
"""Test before_model handles hook exceptions gracefully."""
hook = Mock(side_effect=RuntimeError("Hook execution failed"))
middleware = PreModelHookMiddleware(hook)
result = middleware.before_model(mock_state, mock_runtime)
# Should return None on exception
assert result is None
# Should log error with exc_info
mock_logger.error.assert_called_once()
error_message = mock_logger.error.call_args[0][0]
assert "Pre-model hook execution failed in before_model" in error_message
assert mock_logger.error.call_args[1]["exc_info"] is True
@pytest.mark.asyncio
async def test_abefore_model_with_async_hook(self, mock_state, mock_runtime):
"""Test async before_model with async hook."""
async def async_hook(state, runtime):
await asyncio.sleep(0.001) # Simulate async work
return {"async_data": "test"}
middleware = PreModelHookMiddleware(async_hook)
result = await middleware.abefore_model(mock_state, mock_runtime)
assert result == {"async_data": "test"}
@pytest.mark.asyncio
@patch("src.agents.agents.asyncio.to_thread")
async def test_abefore_model_with_sync_hook(
self, mock_to_thread, mock_state, mock_runtime
):
"""Test async before_model with synchronous hook uses asyncio.to_thread."""
hook = Mock(return_value={"sync_data": "test"})
mock_to_thread.return_value = {"sync_data": "test"}
middleware = PreModelHookMiddleware(hook)
result = await middleware.abefore_model(mock_state, mock_runtime)
# Verify asyncio.to_thread was called with the sync hook
mock_to_thread.assert_called_once_with(hook, mock_state, mock_runtime)
assert result == {"sync_data": "test"}
@pytest.mark.asyncio
async def test_abefore_model_with_none_hook(self, mock_state, mock_runtime):
"""Test async before_model when hook is None."""
middleware = PreModelHookMiddleware(None)
result = await middleware.abefore_model(mock_state, mock_runtime)
# Should return None when hook is None
assert result is None
@pytest.mark.asyncio
@patch("src.agents.agents.logger")
async def test_abefore_model_async_hook_exception(
self, mock_logger, mock_state, mock_runtime
):
"""Test async before_model handles async hook exceptions gracefully."""
async def failing_hook(state, runtime):
raise ValueError("Async hook failed")
middleware = PreModelHookMiddleware(failing_hook)
result = await middleware.abefore_model(mock_state, mock_runtime)
# Should return None on exception
assert result is None
# Should log error with exc_info
mock_logger.error.assert_called_once()
error_message = mock_logger.error.call_args[0][0]
assert "Pre-model hook execution failed in abefore_model" in error_message
assert mock_logger.error.call_args[1]["exc_info"] is True
@pytest.mark.asyncio
@patch("src.agents.agents.asyncio.to_thread")
@patch("src.agents.agents.logger")
async def test_abefore_model_sync_hook_exception(
self, mock_logger, mock_to_thread, mock_state, mock_runtime
):
"""Test async before_model handles sync hook exceptions gracefully."""
hook = Mock()
mock_to_thread.side_effect = RuntimeError("Thread execution failed")
middleware = PreModelHookMiddleware(hook)
result = await middleware.abefore_model(mock_state, mock_runtime)
# Should return None on exception
assert result is None
# Should log error with exc_info
mock_logger.error.assert_called_once()
error_message = mock_logger.error.call_args[0][0]
assert "Pre-model hook execution failed in abefore_model" in error_message
assert mock_logger.error.call_args[1]["exc_info"] is True
@pytest.mark.asyncio
async def test_abefore_model_sync_hook_actual_execution(
self, mock_state, mock_runtime
):
"""Test async before_model actually runs sync hook in thread pool."""
# Track if hook was called
hook_called = []
def sync_hook(state, runtime):
hook_called.append(True)
return {"data": "from_sync_hook"}
middleware = PreModelHookMiddleware(sync_hook)
result = await middleware.abefore_model(mock_state, mock_runtime)
# Verify hook was called and result returned
assert len(hook_called) == 1
assert result == {"data": "from_sync_hook"}
@pytest.mark.asyncio
async def test_abefore_model_detects_coroutine_function(
self, mock_state, mock_runtime
):
"""Test that abefore_model correctly detects async vs sync functions."""
# Test with async function
async def async_hook(state, runtime):
return {"type": "async"}
# Test with sync function
def sync_hook(state, runtime):
return {"type": "sync"}
async_middleware = PreModelHookMiddleware(async_hook)
sync_middleware = PreModelHookMiddleware(sync_hook)
# Both should execute successfully
async_result = await async_middleware.abefore_model(mock_state, mock_runtime)
sync_result = await sync_middleware.abefore_model(mock_state, mock_runtime)
assert async_result == {"type": "async"}
assert sync_result == {"type": "sync"}

View File

@@ -2,10 +2,10 @@
# SPDX-License-Identifier: MIT
"""
Unit tests for agent locale restoration after create_react_agent execution.
Unit tests for agent locale restoration after agent execution.
Tests that meta fields (especially locale) are properly restored after
agent.ainvoke() returns, since create_react_agent creates a MessagesState
agent.ainvoke() returns, since the agent creates a MessagesState
subgraph that filters out custom fields.
"""
@@ -22,7 +22,7 @@ class TestAgentLocaleRestoration:
"""
Demonstrate the problem: agent subgraph filters out locale.
When create_react_agent creates a subgraph with MessagesState,
When the agent creates a subgraph with MessagesState,
it only returns messages, not custom fields.
"""
# Simulate agent behavior: only returns messages