fix: migrate from deprecated create_react_agent to langchain.agents.create_agent (#802)

* fix: migrate from deprecated create_react_agent to langchain.agents.create_agent

Fixes #799

- Replace deprecated langgraph.prebuilt.create_react_agent with
  langchain.agents.create_agent (LangGraph 1.0 migration)
- Add DynamicPromptMiddleware to handle dynamic prompt templates
  (replaces the 'prompt' callable parameter)
- Add PreModelHookMiddleware to handle pre-model hooks
  (replaces the 'pre_model_hook' parameter)
- Update AgentState import from langchain.agents in template.py
- Update tests to use the new API

* fix:update the code with review comments
This commit is contained in:
Willem Jiang
2026-01-07 09:06:16 +08:00
committed by GitHub
parent 1ced90b055
commit d4ab77de5c
5 changed files with 440 additions and 14 deletions

View File

@@ -1,10 +1,14 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import asyncio
import inspect
import logging import logging
from typing import List, Optional from typing import Any, Callable, List, Optional
from langgraph.prebuilt import create_react_agent from langchain.agents import create_agent as langchain_create_agent
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from src.agents.tool_interceptor import wrap_tools_with_interceptor from src.agents.tool_interceptor import wrap_tools_with_interceptor
from src.config.agents import AGENT_LLM_MAP from src.config.agents import AGENT_LLM_MAP
@@ -14,6 +18,88 @@ from src.prompts import apply_prompt_template
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DynamicPromptMiddleware(AgentMiddleware):
"""Middleware to apply dynamic prompt template before model invocation.
This middleware prepends a system message with the rendered prompt template
to the messages list before the model is called.
"""
def __init__(self, prompt_template: str, locale: str = "en-US"):
self.prompt_template = prompt_template
self.locale = locale
def before_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None:
"""Apply prompt template and prepend system message to messages."""
try:
# Get the rendered messages including system prompt from template
rendered_messages = apply_prompt_template(
self.prompt_template, state, locale=self.locale
)
# The first message is the system prompt, extract it
if rendered_messages and len(rendered_messages) > 0:
system_message = rendered_messages[0]
# Prepend system message to existing messages
return {"messages": [system_message]}
return None
except Exception as e:
logger.error(
f"Failed to apply prompt template in before_model: {e}",
exc_info=True
)
return None
async def abefore_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None:
"""Async version of before_model."""
return self.before_model(state, runtime)
class PreModelHookMiddleware(AgentMiddleware):
"""Middleware to execute a pre-model hook before model invocation.
This middleware wraps the legacy pre_model_hook callable and executes it
as part of the middleware chain.
"""
def __init__(self, pre_model_hook: Callable):
self._pre_model_hook = pre_model_hook
def before_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None:
"""Execute the pre-model hook."""
if not self._pre_model_hook:
return None
try:
result = self._pre_model_hook(state, runtime)
return result
except Exception as e:
logger.error(
f"Pre-model hook execution failed in before_model: {e}",
exc_info=True
)
return None
async def abefore_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None:
"""Async version of before_model."""
if not self._pre_model_hook:
return None
try:
# Check if the hook is async
if inspect.iscoroutinefunction(self._pre_model_hook):
result = await self._pre_model_hook(state, runtime)
else:
# Run synchronous hook in thread pool to avoid blocking event loop
result = await asyncio.to_thread(self._pre_model_hook, state, runtime)
return result
except Exception as e:
logger.error(
f"Pre-model hook execution failed in abefore_model: {e}",
exc_info=True
)
return None
# Create agents using configured LLM types # Create agents using configured LLM types
def create_agent( def create_agent(
agent_name: str, agent_name: str,
@@ -64,18 +150,23 @@ def create_agent(
llm_type = AGENT_LLM_MAP.get(agent_type, "basic") llm_type = AGENT_LLM_MAP.get(agent_type, "basic")
logger.debug(f"Agent '{agent_name}' using LLM type: {llm_type}") logger.debug(f"Agent '{agent_name}' using LLM type: {llm_type}")
logger.debug(f"Creating ReAct agent '{agent_name}' with locale: {locale}") logger.debug(f"Creating agent '{agent_name}' with locale: {locale}")
# Build middleware list
# Use closure to capture locale from the workflow state instead of relying on # Use closure to capture locale from the workflow state instead of relying on
# agent state.get("locale"), which doesn't have the locale field # agent state.get("locale"), which doesn't have the locale field
# See: https://github.com/bytedance/deer-flow/issues/743 # See: https://github.com/bytedance/deer-flow/issues/743
agent = create_react_agent( middleware = [DynamicPromptMiddleware(prompt_template, locale)]
# Add pre-model hook middleware if provided
if pre_model_hook:
middleware.append(PreModelHookMiddleware(pre_model_hook))
agent = langchain_create_agent(
name=agent_name, name=agent_name,
model=get_llm_by_type(llm_type), model=get_llm_by_type(llm_type),
tools=processed_tools, tools=processed_tools,
prompt=lambda state, captured_locale=locale: apply_prompt_template( middleware=middleware,
prompt_template, state, locale=captured_locale
),
pre_model_hook=pre_model_hook,
) )
logger.info(f"Agent '{agent_name}' created successfully") logger.info(f"Agent '{agent_name}' created successfully")

View File

@@ -6,7 +6,7 @@ import os
from datetime import datetime from datetime import datetime
from jinja2 import Environment, FileSystemLoader, TemplateNotFound, select_autoescape from jinja2 import Environment, FileSystemLoader, TemplateNotFound, select_autoescape
from langgraph.prebuilt.chat_agent_executor import AgentState from langchain.agents import AgentState
from src.config.configuration import Configuration from src.config.configuration import Configuration

View File

@@ -42,7 +42,7 @@ class TestToolInterceptorIntegration:
tools = [search_tool, db_tool] tools = [search_tool, db_tool]
# Create agent with interrupts on db_tool only # Create agent with interrupts on db_tool only
with patch("src.agents.agents.create_react_agent") as mock_create, \ with patch("src.agents.agents.langchain_create_agent") as mock_create, \
patch("src.agents.agents.get_llm_by_type") as mock_llm: patch("src.agents.agents.get_llm_by_type") as mock_llm:
mock_create.return_value = MagicMock() mock_create.return_value = MagicMock()
mock_llm.return_value = MagicMock() mock_llm.return_value = MagicMock()
@@ -55,7 +55,7 @@ class TestToolInterceptorIntegration:
interrupt_before_tools=["db_tool"], interrupt_before_tools=["db_tool"],
) )
# Verify create_react_agent was called with wrapped tools # Verify langchain_create_agent was called with wrapped tools
assert mock_create.called assert mock_create.called
call_args = mock_create.call_args call_args = mock_create.call_args
wrapped_tools = call_args.kwargs["tools"] wrapped_tools = call_args.kwargs["tools"]

View File

@@ -0,0 +1,335 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import asyncio
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from langchain_core.messages import HumanMessage, SystemMessage
from src.agents.agents import DynamicPromptMiddleware, PreModelHookMiddleware
@pytest.fixture
def mock_runtime():
"""Mock Runtime object."""
runtime = MagicMock()
runtime.config = {}
return runtime
@pytest.fixture
def mock_state():
"""Mock state object."""
return {
"messages": [HumanMessage(content="Test message")],
"context": "Test context",
}
@pytest.fixture
def mock_messages():
"""Mock messages returned by apply_prompt_template."""
return [
SystemMessage(content="Test system prompt"),
HumanMessage(content="Test human message"),
]
class TestDynamicPromptMiddleware:
"""Tests for DynamicPromptMiddleware class."""
def test_init(self):
"""Test middleware initialization."""
middleware = DynamicPromptMiddleware("test_template", locale="zh-CN")
assert middleware.prompt_template == "test_template"
assert middleware.locale == "zh-CN"
def test_init_default_locale(self):
"""Test middleware initialization with default locale."""
middleware = DynamicPromptMiddleware("test_template")
assert middleware.prompt_template == "test_template"
assert middleware.locale == "en-US"
@patch("src.agents.agents.apply_prompt_template")
def test_before_model_success(
self, mock_apply_template, mock_state, mock_runtime, mock_messages
):
"""Test before_model successfully applies prompt template."""
mock_apply_template.return_value = mock_messages
middleware = DynamicPromptMiddleware("test_template", locale="en-US")
result = middleware.before_model(mock_state, mock_runtime)
# Verify apply_prompt_template was called with correct arguments
mock_apply_template.assert_called_once_with(
"test_template", mock_state, locale="en-US"
)
# Verify system message is returned
assert result == {"messages": [mock_messages[0]]}
assert result["messages"][0].content == "Test system prompt"
@patch("src.agents.agents.apply_prompt_template")
def test_before_model_empty_messages(
self, mock_apply_template, mock_state, mock_runtime
):
"""Test before_model with empty message list."""
mock_apply_template.return_value = []
middleware = DynamicPromptMiddleware("test_template")
result = middleware.before_model(mock_state, mock_runtime)
# Should return None when no messages are rendered
assert result is None
@patch("src.agents.agents.apply_prompt_template")
def test_before_model_none_messages(
self, mock_apply_template, mock_state, mock_runtime
):
"""Test before_model when apply_prompt_template returns None."""
mock_apply_template.return_value = None
middleware = DynamicPromptMiddleware("test_template")
result = middleware.before_model(mock_state, mock_runtime)
# Should return None when template returns None
assert result is None
@patch("src.agents.agents.apply_prompt_template")
@patch("src.agents.agents.logger")
def test_before_model_exception_handling(
self, mock_logger, mock_apply_template, mock_state, mock_runtime
):
"""Test before_model handles exceptions gracefully."""
mock_apply_template.side_effect = ValueError("Template rendering failed")
middleware = DynamicPromptMiddleware("test_template")
result = middleware.before_model(mock_state, mock_runtime)
# Should return None on exception
assert result is None
# Should log error with exc_info
mock_logger.error.assert_called_once()
error_message = mock_logger.error.call_args[0][0]
assert "Failed to apply prompt template in before_model" in error_message
assert mock_logger.error.call_args[1]["exc_info"] is True
@patch("src.agents.agents.apply_prompt_template")
def test_before_model_with_different_locale(
self, mock_apply_template, mock_state, mock_runtime, mock_messages
):
"""Test before_model with different locale."""
mock_apply_template.return_value = mock_messages
middleware = DynamicPromptMiddleware("test_template", locale="zh-CN")
result = middleware.before_model(mock_state, mock_runtime)
# Verify locale is passed correctly
mock_apply_template.assert_called_once_with(
"test_template", mock_state, locale="zh-CN"
)
assert result == {"messages": [mock_messages[0]]}
@pytest.mark.asyncio
@patch("src.agents.agents.apply_prompt_template")
async def test_abefore_model(
self, mock_apply_template, mock_state, mock_runtime, mock_messages
):
"""Test async version of before_model."""
mock_apply_template.return_value = mock_messages
middleware = DynamicPromptMiddleware("test_template")
result = await middleware.abefore_model(mock_state, mock_runtime)
# Should call the sync version and return same result
assert result == {"messages": [mock_messages[0]]}
mock_apply_template.assert_called_once_with(
"test_template", mock_state, locale="en-US"
)
class TestPreModelHookMiddleware:
"""Tests for PreModelHookMiddleware class."""
def test_init(self):
"""Test middleware initialization."""
hook = Mock()
middleware = PreModelHookMiddleware(hook)
assert middleware._pre_model_hook == hook
def test_before_model_with_sync_hook(self, mock_state, mock_runtime):
"""Test before_model with synchronous hook."""
hook = Mock(return_value={"custom_data": "test"})
middleware = PreModelHookMiddleware(hook)
result = middleware.before_model(mock_state, mock_runtime)
# Verify hook was called with correct arguments
hook.assert_called_once_with(mock_state, mock_runtime)
assert result == {"custom_data": "test"}
def test_before_model_with_none_hook(self, mock_state, mock_runtime):
"""Test before_model when hook is None."""
middleware = PreModelHookMiddleware(None)
result = middleware.before_model(mock_state, mock_runtime)
# Should return None when hook is None
assert result is None
def test_before_model_hook_returns_none(self, mock_state, mock_runtime):
"""Test before_model when hook returns None."""
hook = Mock(return_value=None)
middleware = PreModelHookMiddleware(hook)
result = middleware.before_model(mock_state, mock_runtime)
hook.assert_called_once_with(mock_state, mock_runtime)
assert result is None
@patch("src.agents.agents.logger")
def test_before_model_hook_exception(
self, mock_logger, mock_state, mock_runtime
):
"""Test before_model handles hook exceptions gracefully."""
hook = Mock(side_effect=RuntimeError("Hook execution failed"))
middleware = PreModelHookMiddleware(hook)
result = middleware.before_model(mock_state, mock_runtime)
# Should return None on exception
assert result is None
# Should log error with exc_info
mock_logger.error.assert_called_once()
error_message = mock_logger.error.call_args[0][0]
assert "Pre-model hook execution failed in before_model" in error_message
assert mock_logger.error.call_args[1]["exc_info"] is True
@pytest.mark.asyncio
async def test_abefore_model_with_async_hook(self, mock_state, mock_runtime):
"""Test async before_model with async hook."""
async def async_hook(state, runtime):
await asyncio.sleep(0.001) # Simulate async work
return {"async_data": "test"}
middleware = PreModelHookMiddleware(async_hook)
result = await middleware.abefore_model(mock_state, mock_runtime)
assert result == {"async_data": "test"}
@pytest.mark.asyncio
@patch("src.agents.agents.asyncio.to_thread")
async def test_abefore_model_with_sync_hook(
self, mock_to_thread, mock_state, mock_runtime
):
"""Test async before_model with synchronous hook uses asyncio.to_thread."""
hook = Mock(return_value={"sync_data": "test"})
mock_to_thread.return_value = {"sync_data": "test"}
middleware = PreModelHookMiddleware(hook)
result = await middleware.abefore_model(mock_state, mock_runtime)
# Verify asyncio.to_thread was called with the sync hook
mock_to_thread.assert_called_once_with(hook, mock_state, mock_runtime)
assert result == {"sync_data": "test"}
@pytest.mark.asyncio
async def test_abefore_model_with_none_hook(self, mock_state, mock_runtime):
"""Test async before_model when hook is None."""
middleware = PreModelHookMiddleware(None)
result = await middleware.abefore_model(mock_state, mock_runtime)
# Should return None when hook is None
assert result is None
@pytest.mark.asyncio
@patch("src.agents.agents.logger")
async def test_abefore_model_async_hook_exception(
self, mock_logger, mock_state, mock_runtime
):
"""Test async before_model handles async hook exceptions gracefully."""
async def failing_hook(state, runtime):
raise ValueError("Async hook failed")
middleware = PreModelHookMiddleware(failing_hook)
result = await middleware.abefore_model(mock_state, mock_runtime)
# Should return None on exception
assert result is None
# Should log error with exc_info
mock_logger.error.assert_called_once()
error_message = mock_logger.error.call_args[0][0]
assert "Pre-model hook execution failed in abefore_model" in error_message
assert mock_logger.error.call_args[1]["exc_info"] is True
@pytest.mark.asyncio
@patch("src.agents.agents.asyncio.to_thread")
@patch("src.agents.agents.logger")
async def test_abefore_model_sync_hook_exception(
self, mock_logger, mock_to_thread, mock_state, mock_runtime
):
"""Test async before_model handles sync hook exceptions gracefully."""
hook = Mock()
mock_to_thread.side_effect = RuntimeError("Thread execution failed")
middleware = PreModelHookMiddleware(hook)
result = await middleware.abefore_model(mock_state, mock_runtime)
# Should return None on exception
assert result is None
# Should log error with exc_info
mock_logger.error.assert_called_once()
error_message = mock_logger.error.call_args[0][0]
assert "Pre-model hook execution failed in abefore_model" in error_message
assert mock_logger.error.call_args[1]["exc_info"] is True
@pytest.mark.asyncio
async def test_abefore_model_sync_hook_actual_execution(
self, mock_state, mock_runtime
):
"""Test async before_model actually runs sync hook in thread pool."""
# Track if hook was called
hook_called = []
def sync_hook(state, runtime):
hook_called.append(True)
return {"data": "from_sync_hook"}
middleware = PreModelHookMiddleware(sync_hook)
result = await middleware.abefore_model(mock_state, mock_runtime)
# Verify hook was called and result returned
assert len(hook_called) == 1
assert result == {"data": "from_sync_hook"}
@pytest.mark.asyncio
async def test_abefore_model_detects_coroutine_function(
self, mock_state, mock_runtime
):
"""Test that abefore_model correctly detects async vs sync functions."""
# Test with async function
async def async_hook(state, runtime):
return {"type": "async"}
# Test with sync function
def sync_hook(state, runtime):
return {"type": "sync"}
async_middleware = PreModelHookMiddleware(async_hook)
sync_middleware = PreModelHookMiddleware(sync_hook)
# Both should execute successfully
async_result = await async_middleware.abefore_model(mock_state, mock_runtime)
sync_result = await sync_middleware.abefore_model(mock_state, mock_runtime)
assert async_result == {"type": "async"}
assert sync_result == {"type": "sync"}

View File

@@ -2,10 +2,10 @@
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
""" """
Unit tests for agent locale restoration after create_react_agent execution. Unit tests for agent locale restoration after agent execution.
Tests that meta fields (especially locale) are properly restored after Tests that meta fields (especially locale) are properly restored after
agent.ainvoke() returns, since create_react_agent creates a MessagesState agent.ainvoke() returns, since the agent creates a MessagesState
subgraph that filters out custom fields. subgraph that filters out custom fields.
""" """
@@ -22,7 +22,7 @@ class TestAgentLocaleRestoration:
""" """
Demonstrate the problem: agent subgraph filters out locale. Demonstrate the problem: agent subgraph filters out locale.
When create_react_agent creates a subgraph with MessagesState, When the agent creates a subgraph with MessagesState,
it only returns messages, not custom fields. it only returns messages, not custom fields.
""" """
# Simulate agent behavior: only returns messages # Simulate agent behavior: only returns messages