mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-28 00:04:47 +08:00
fix: migrate from deprecated create_react_agent to langchain.agents.create_agent (#802)
* fix: migrate from deprecated create_react_agent to langchain.agents.create_agent Fixes #799 - Replace deprecated langgraph.prebuilt.create_react_agent with langchain.agents.create_agent (LangGraph 1.0 migration) - Add DynamicPromptMiddleware to handle dynamic prompt templates (replaces the 'prompt' callable parameter) - Add PreModelHookMiddleware to handle pre-model hooks (replaces the 'pre_model_hook' parameter) - Update AgentState import from langchain.agents in template.py - Update tests to use the new API * fix:update the code with review comments
This commit is contained in:
@@ -1,10 +1,14 @@
|
|||||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional
|
from typing import Any, Callable, List, Optional
|
||||||
|
|
||||||
from langgraph.prebuilt import create_react_agent
|
from langchain.agents import create_agent as langchain_create_agent
|
||||||
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
from src.agents.tool_interceptor import wrap_tools_with_interceptor
|
from src.agents.tool_interceptor import wrap_tools_with_interceptor
|
||||||
from src.config.agents import AGENT_LLM_MAP
|
from src.config.agents import AGENT_LLM_MAP
|
||||||
@@ -14,6 +18,88 @@ from src.prompts import apply_prompt_template
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicPromptMiddleware(AgentMiddleware):
|
||||||
|
"""Middleware to apply dynamic prompt template before model invocation.
|
||||||
|
|
||||||
|
This middleware prepends a system message with the rendered prompt template
|
||||||
|
to the messages list before the model is called.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, prompt_template: str, locale: str = "en-US"):
|
||||||
|
self.prompt_template = prompt_template
|
||||||
|
self.locale = locale
|
||||||
|
|
||||||
|
def before_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None:
|
||||||
|
"""Apply prompt template and prepend system message to messages."""
|
||||||
|
try:
|
||||||
|
# Get the rendered messages including system prompt from template
|
||||||
|
rendered_messages = apply_prompt_template(
|
||||||
|
self.prompt_template, state, locale=self.locale
|
||||||
|
)
|
||||||
|
# The first message is the system prompt, extract it
|
||||||
|
if rendered_messages and len(rendered_messages) > 0:
|
||||||
|
system_message = rendered_messages[0]
|
||||||
|
# Prepend system message to existing messages
|
||||||
|
return {"messages": [system_message]}
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to apply prompt template in before_model: {e}",
|
||||||
|
exc_info=True
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def abefore_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None:
|
||||||
|
"""Async version of before_model."""
|
||||||
|
return self.before_model(state, runtime)
|
||||||
|
|
||||||
|
|
||||||
|
class PreModelHookMiddleware(AgentMiddleware):
|
||||||
|
"""Middleware to execute a pre-model hook before model invocation.
|
||||||
|
|
||||||
|
This middleware wraps the legacy pre_model_hook callable and executes it
|
||||||
|
as part of the middleware chain.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, pre_model_hook: Callable):
|
||||||
|
self._pre_model_hook = pre_model_hook
|
||||||
|
|
||||||
|
def before_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None:
|
||||||
|
"""Execute the pre-model hook."""
|
||||||
|
if not self._pre_model_hook:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = self._pre_model_hook(state, runtime)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Pre-model hook execution failed in before_model: {e}",
|
||||||
|
exc_info=True
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def abefore_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None:
|
||||||
|
"""Async version of before_model."""
|
||||||
|
if not self._pre_model_hook:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if the hook is async
|
||||||
|
if inspect.iscoroutinefunction(self._pre_model_hook):
|
||||||
|
result = await self._pre_model_hook(state, runtime)
|
||||||
|
else:
|
||||||
|
# Run synchronous hook in thread pool to avoid blocking event loop
|
||||||
|
result = await asyncio.to_thread(self._pre_model_hook, state, runtime)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Pre-model hook execution failed in abefore_model: {e}",
|
||||||
|
exc_info=True
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
# Create agents using configured LLM types
|
# Create agents using configured LLM types
|
||||||
def create_agent(
|
def create_agent(
|
||||||
agent_name: str,
|
agent_name: str,
|
||||||
@@ -64,18 +150,23 @@ def create_agent(
|
|||||||
llm_type = AGENT_LLM_MAP.get(agent_type, "basic")
|
llm_type = AGENT_LLM_MAP.get(agent_type, "basic")
|
||||||
logger.debug(f"Agent '{agent_name}' using LLM type: {llm_type}")
|
logger.debug(f"Agent '{agent_name}' using LLM type: {llm_type}")
|
||||||
|
|
||||||
logger.debug(f"Creating ReAct agent '{agent_name}' with locale: {locale}")
|
logger.debug(f"Creating agent '{agent_name}' with locale: {locale}")
|
||||||
|
|
||||||
|
# Build middleware list
|
||||||
# Use closure to capture locale from the workflow state instead of relying on
|
# Use closure to capture locale from the workflow state instead of relying on
|
||||||
# agent state.get("locale"), which doesn't have the locale field
|
# agent state.get("locale"), which doesn't have the locale field
|
||||||
# See: https://github.com/bytedance/deer-flow/issues/743
|
# See: https://github.com/bytedance/deer-flow/issues/743
|
||||||
agent = create_react_agent(
|
middleware = [DynamicPromptMiddleware(prompt_template, locale)]
|
||||||
|
|
||||||
|
# Add pre-model hook middleware if provided
|
||||||
|
if pre_model_hook:
|
||||||
|
middleware.append(PreModelHookMiddleware(pre_model_hook))
|
||||||
|
|
||||||
|
agent = langchain_create_agent(
|
||||||
name=agent_name,
|
name=agent_name,
|
||||||
model=get_llm_by_type(llm_type),
|
model=get_llm_by_type(llm_type),
|
||||||
tools=processed_tools,
|
tools=processed_tools,
|
||||||
prompt=lambda state, captured_locale=locale: apply_prompt_template(
|
middleware=middleware,
|
||||||
prompt_template, state, locale=captured_locale
|
|
||||||
),
|
|
||||||
pre_model_hook=pre_model_hook,
|
|
||||||
)
|
)
|
||||||
logger.info(f"Agent '{agent_name}' created successfully")
|
logger.info(f"Agent '{agent_name}' created successfully")
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import os
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from jinja2 import Environment, FileSystemLoader, TemplateNotFound, select_autoescape
|
from jinja2 import Environment, FileSystemLoader, TemplateNotFound, select_autoescape
|
||||||
from langgraph.prebuilt.chat_agent_executor import AgentState
|
from langchain.agents import AgentState
|
||||||
|
|
||||||
from src.config.configuration import Configuration
|
from src.config.configuration import Configuration
|
||||||
|
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ class TestToolInterceptorIntegration:
|
|||||||
tools = [search_tool, db_tool]
|
tools = [search_tool, db_tool]
|
||||||
|
|
||||||
# Create agent with interrupts on db_tool only
|
# Create agent with interrupts on db_tool only
|
||||||
with patch("src.agents.agents.create_react_agent") as mock_create, \
|
with patch("src.agents.agents.langchain_create_agent") as mock_create, \
|
||||||
patch("src.agents.agents.get_llm_by_type") as mock_llm:
|
patch("src.agents.agents.get_llm_by_type") as mock_llm:
|
||||||
mock_create.return_value = MagicMock()
|
mock_create.return_value = MagicMock()
|
||||||
mock_llm.return_value = MagicMock()
|
mock_llm.return_value = MagicMock()
|
||||||
@@ -55,7 +55,7 @@ class TestToolInterceptorIntegration:
|
|||||||
interrupt_before_tools=["db_tool"],
|
interrupt_before_tools=["db_tool"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify create_react_agent was called with wrapped tools
|
# Verify langchain_create_agent was called with wrapped tools
|
||||||
assert mock_create.called
|
assert mock_create.called
|
||||||
call_args = mock_create.call_args
|
call_args = mock_create.call_args
|
||||||
wrapped_tools = call_args.kwargs["tools"]
|
wrapped_tools = call_args.kwargs["tools"]
|
||||||
|
|||||||
335
tests/unit/agents/test_middleware.py
Normal file
335
tests/unit/agents/test_middleware.py
Normal file
@@ -0,0 +1,335 @@
|
|||||||
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
from src.agents.agents import DynamicPromptMiddleware, PreModelHookMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_runtime():
|
||||||
|
"""Mock Runtime object."""
|
||||||
|
runtime = MagicMock()
|
||||||
|
runtime.config = {}
|
||||||
|
return runtime
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_state():
|
||||||
|
"""Mock state object."""
|
||||||
|
return {
|
||||||
|
"messages": [HumanMessage(content="Test message")],
|
||||||
|
"context": "Test context",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_messages():
|
||||||
|
"""Mock messages returned by apply_prompt_template."""
|
||||||
|
return [
|
||||||
|
SystemMessage(content="Test system prompt"),
|
||||||
|
HumanMessage(content="Test human message"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestDynamicPromptMiddleware:
|
||||||
|
"""Tests for DynamicPromptMiddleware class."""
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
"""Test middleware initialization."""
|
||||||
|
middleware = DynamicPromptMiddleware("test_template", locale="zh-CN")
|
||||||
|
assert middleware.prompt_template == "test_template"
|
||||||
|
assert middleware.locale == "zh-CN"
|
||||||
|
|
||||||
|
def test_init_default_locale(self):
|
||||||
|
"""Test middleware initialization with default locale."""
|
||||||
|
middleware = DynamicPromptMiddleware("test_template")
|
||||||
|
assert middleware.prompt_template == "test_template"
|
||||||
|
assert middleware.locale == "en-US"
|
||||||
|
|
||||||
|
@patch("src.agents.agents.apply_prompt_template")
|
||||||
|
def test_before_model_success(
|
||||||
|
self, mock_apply_template, mock_state, mock_runtime, mock_messages
|
||||||
|
):
|
||||||
|
"""Test before_model successfully applies prompt template."""
|
||||||
|
mock_apply_template.return_value = mock_messages
|
||||||
|
middleware = DynamicPromptMiddleware("test_template", locale="en-US")
|
||||||
|
|
||||||
|
result = middleware.before_model(mock_state, mock_runtime)
|
||||||
|
|
||||||
|
# Verify apply_prompt_template was called with correct arguments
|
||||||
|
mock_apply_template.assert_called_once_with(
|
||||||
|
"test_template", mock_state, locale="en-US"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify system message is returned
|
||||||
|
assert result == {"messages": [mock_messages[0]]}
|
||||||
|
assert result["messages"][0].content == "Test system prompt"
|
||||||
|
|
||||||
|
@patch("src.agents.agents.apply_prompt_template")
|
||||||
|
def test_before_model_empty_messages(
|
||||||
|
self, mock_apply_template, mock_state, mock_runtime
|
||||||
|
):
|
||||||
|
"""Test before_model with empty message list."""
|
||||||
|
mock_apply_template.return_value = []
|
||||||
|
middleware = DynamicPromptMiddleware("test_template")
|
||||||
|
|
||||||
|
result = middleware.before_model(mock_state, mock_runtime)
|
||||||
|
|
||||||
|
# Should return None when no messages are rendered
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@patch("src.agents.agents.apply_prompt_template")
|
||||||
|
def test_before_model_none_messages(
|
||||||
|
self, mock_apply_template, mock_state, mock_runtime
|
||||||
|
):
|
||||||
|
"""Test before_model when apply_prompt_template returns None."""
|
||||||
|
mock_apply_template.return_value = None
|
||||||
|
middleware = DynamicPromptMiddleware("test_template")
|
||||||
|
|
||||||
|
result = middleware.before_model(mock_state, mock_runtime)
|
||||||
|
|
||||||
|
# Should return None when template returns None
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@patch("src.agents.agents.apply_prompt_template")
|
||||||
|
@patch("src.agents.agents.logger")
|
||||||
|
def test_before_model_exception_handling(
|
||||||
|
self, mock_logger, mock_apply_template, mock_state, mock_runtime
|
||||||
|
):
|
||||||
|
"""Test before_model handles exceptions gracefully."""
|
||||||
|
mock_apply_template.side_effect = ValueError("Template rendering failed")
|
||||||
|
middleware = DynamicPromptMiddleware("test_template")
|
||||||
|
|
||||||
|
result = middleware.before_model(mock_state, mock_runtime)
|
||||||
|
|
||||||
|
# Should return None on exception
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
# Should log error with exc_info
|
||||||
|
mock_logger.error.assert_called_once()
|
||||||
|
error_message = mock_logger.error.call_args[0][0]
|
||||||
|
assert "Failed to apply prompt template in before_model" in error_message
|
||||||
|
assert mock_logger.error.call_args[1]["exc_info"] is True
|
||||||
|
|
||||||
|
@patch("src.agents.agents.apply_prompt_template")
|
||||||
|
def test_before_model_with_different_locale(
|
||||||
|
self, mock_apply_template, mock_state, mock_runtime, mock_messages
|
||||||
|
):
|
||||||
|
"""Test before_model with different locale."""
|
||||||
|
mock_apply_template.return_value = mock_messages
|
||||||
|
middleware = DynamicPromptMiddleware("test_template", locale="zh-CN")
|
||||||
|
|
||||||
|
result = middleware.before_model(mock_state, mock_runtime)
|
||||||
|
|
||||||
|
# Verify locale is passed correctly
|
||||||
|
mock_apply_template.assert_called_once_with(
|
||||||
|
"test_template", mock_state, locale="zh-CN"
|
||||||
|
)
|
||||||
|
assert result == {"messages": [mock_messages[0]]}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("src.agents.agents.apply_prompt_template")
|
||||||
|
async def test_abefore_model(
|
||||||
|
self, mock_apply_template, mock_state, mock_runtime, mock_messages
|
||||||
|
):
|
||||||
|
"""Test async version of before_model."""
|
||||||
|
mock_apply_template.return_value = mock_messages
|
||||||
|
middleware = DynamicPromptMiddleware("test_template")
|
||||||
|
|
||||||
|
result = await middleware.abefore_model(mock_state, mock_runtime)
|
||||||
|
|
||||||
|
# Should call the sync version and return same result
|
||||||
|
assert result == {"messages": [mock_messages[0]]}
|
||||||
|
mock_apply_template.assert_called_once_with(
|
||||||
|
"test_template", mock_state, locale="en-US"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPreModelHookMiddleware:
|
||||||
|
"""Tests for PreModelHookMiddleware class."""
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
"""Test middleware initialization."""
|
||||||
|
hook = Mock()
|
||||||
|
middleware = PreModelHookMiddleware(hook)
|
||||||
|
assert middleware._pre_model_hook == hook
|
||||||
|
|
||||||
|
def test_before_model_with_sync_hook(self, mock_state, mock_runtime):
|
||||||
|
"""Test before_model with synchronous hook."""
|
||||||
|
hook = Mock(return_value={"custom_data": "test"})
|
||||||
|
middleware = PreModelHookMiddleware(hook)
|
||||||
|
|
||||||
|
result = middleware.before_model(mock_state, mock_runtime)
|
||||||
|
|
||||||
|
# Verify hook was called with correct arguments
|
||||||
|
hook.assert_called_once_with(mock_state, mock_runtime)
|
||||||
|
assert result == {"custom_data": "test"}
|
||||||
|
|
||||||
|
def test_before_model_with_none_hook(self, mock_state, mock_runtime):
|
||||||
|
"""Test before_model when hook is None."""
|
||||||
|
middleware = PreModelHookMiddleware(None)
|
||||||
|
|
||||||
|
result = middleware.before_model(mock_state, mock_runtime)
|
||||||
|
|
||||||
|
# Should return None when hook is None
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_before_model_hook_returns_none(self, mock_state, mock_runtime):
|
||||||
|
"""Test before_model when hook returns None."""
|
||||||
|
hook = Mock(return_value=None)
|
||||||
|
middleware = PreModelHookMiddleware(hook)
|
||||||
|
|
||||||
|
result = middleware.before_model(mock_state, mock_runtime)
|
||||||
|
|
||||||
|
hook.assert_called_once_with(mock_state, mock_runtime)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@patch("src.agents.agents.logger")
|
||||||
|
def test_before_model_hook_exception(
|
||||||
|
self, mock_logger, mock_state, mock_runtime
|
||||||
|
):
|
||||||
|
"""Test before_model handles hook exceptions gracefully."""
|
||||||
|
hook = Mock(side_effect=RuntimeError("Hook execution failed"))
|
||||||
|
middleware = PreModelHookMiddleware(hook)
|
||||||
|
|
||||||
|
result = middleware.before_model(mock_state, mock_runtime)
|
||||||
|
|
||||||
|
# Should return None on exception
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
# Should log error with exc_info
|
||||||
|
mock_logger.error.assert_called_once()
|
||||||
|
error_message = mock_logger.error.call_args[0][0]
|
||||||
|
assert "Pre-model hook execution failed in before_model" in error_message
|
||||||
|
assert mock_logger.error.call_args[1]["exc_info"] is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_abefore_model_with_async_hook(self, mock_state, mock_runtime):
|
||||||
|
"""Test async before_model with async hook."""
|
||||||
|
async def async_hook(state, runtime):
|
||||||
|
await asyncio.sleep(0.001) # Simulate async work
|
||||||
|
return {"async_data": "test"}
|
||||||
|
|
||||||
|
middleware = PreModelHookMiddleware(async_hook)
|
||||||
|
|
||||||
|
result = await middleware.abefore_model(mock_state, mock_runtime)
|
||||||
|
|
||||||
|
assert result == {"async_data": "test"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("src.agents.agents.asyncio.to_thread")
|
||||||
|
async def test_abefore_model_with_sync_hook(
|
||||||
|
self, mock_to_thread, mock_state, mock_runtime
|
||||||
|
):
|
||||||
|
"""Test async before_model with synchronous hook uses asyncio.to_thread."""
|
||||||
|
hook = Mock(return_value={"sync_data": "test"})
|
||||||
|
mock_to_thread.return_value = {"sync_data": "test"}
|
||||||
|
middleware = PreModelHookMiddleware(hook)
|
||||||
|
|
||||||
|
result = await middleware.abefore_model(mock_state, mock_runtime)
|
||||||
|
|
||||||
|
# Verify asyncio.to_thread was called with the sync hook
|
||||||
|
mock_to_thread.assert_called_once_with(hook, mock_state, mock_runtime)
|
||||||
|
assert result == {"sync_data": "test"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_abefore_model_with_none_hook(self, mock_state, mock_runtime):
|
||||||
|
"""Test async before_model when hook is None."""
|
||||||
|
middleware = PreModelHookMiddleware(None)
|
||||||
|
|
||||||
|
result = await middleware.abefore_model(mock_state, mock_runtime)
|
||||||
|
|
||||||
|
# Should return None when hook is None
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("src.agents.agents.logger")
|
||||||
|
async def test_abefore_model_async_hook_exception(
|
||||||
|
self, mock_logger, mock_state, mock_runtime
|
||||||
|
):
|
||||||
|
"""Test async before_model handles async hook exceptions gracefully."""
|
||||||
|
async def failing_hook(state, runtime):
|
||||||
|
raise ValueError("Async hook failed")
|
||||||
|
|
||||||
|
middleware = PreModelHookMiddleware(failing_hook)
|
||||||
|
|
||||||
|
result = await middleware.abefore_model(mock_state, mock_runtime)
|
||||||
|
|
||||||
|
# Should return None on exception
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
# Should log error with exc_info
|
||||||
|
mock_logger.error.assert_called_once()
|
||||||
|
error_message = mock_logger.error.call_args[0][0]
|
||||||
|
assert "Pre-model hook execution failed in abefore_model" in error_message
|
||||||
|
assert mock_logger.error.call_args[1]["exc_info"] is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("src.agents.agents.asyncio.to_thread")
|
||||||
|
@patch("src.agents.agents.logger")
|
||||||
|
async def test_abefore_model_sync_hook_exception(
|
||||||
|
self, mock_logger, mock_to_thread, mock_state, mock_runtime
|
||||||
|
):
|
||||||
|
"""Test async before_model handles sync hook exceptions gracefully."""
|
||||||
|
hook = Mock()
|
||||||
|
mock_to_thread.side_effect = RuntimeError("Thread execution failed")
|
||||||
|
middleware = PreModelHookMiddleware(hook)
|
||||||
|
|
||||||
|
result = await middleware.abefore_model(mock_state, mock_runtime)
|
||||||
|
|
||||||
|
# Should return None on exception
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
# Should log error with exc_info
|
||||||
|
mock_logger.error.assert_called_once()
|
||||||
|
error_message = mock_logger.error.call_args[0][0]
|
||||||
|
assert "Pre-model hook execution failed in abefore_model" in error_message
|
||||||
|
assert mock_logger.error.call_args[1]["exc_info"] is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_abefore_model_sync_hook_actual_execution(
|
||||||
|
self, mock_state, mock_runtime
|
||||||
|
):
|
||||||
|
"""Test async before_model actually runs sync hook in thread pool."""
|
||||||
|
# Track if hook was called
|
||||||
|
hook_called = []
|
||||||
|
|
||||||
|
def sync_hook(state, runtime):
|
||||||
|
hook_called.append(True)
|
||||||
|
return {"data": "from_sync_hook"}
|
||||||
|
|
||||||
|
middleware = PreModelHookMiddleware(sync_hook)
|
||||||
|
|
||||||
|
result = await middleware.abefore_model(mock_state, mock_runtime)
|
||||||
|
|
||||||
|
# Verify hook was called and result returned
|
||||||
|
assert len(hook_called) == 1
|
||||||
|
assert result == {"data": "from_sync_hook"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_abefore_model_detects_coroutine_function(
|
||||||
|
self, mock_state, mock_runtime
|
||||||
|
):
|
||||||
|
"""Test that abefore_model correctly detects async vs sync functions."""
|
||||||
|
# Test with async function
|
||||||
|
async def async_hook(state, runtime):
|
||||||
|
return {"type": "async"}
|
||||||
|
|
||||||
|
# Test with sync function
|
||||||
|
def sync_hook(state, runtime):
|
||||||
|
return {"type": "sync"}
|
||||||
|
|
||||||
|
async_middleware = PreModelHookMiddleware(async_hook)
|
||||||
|
sync_middleware = PreModelHookMiddleware(sync_hook)
|
||||||
|
|
||||||
|
# Both should execute successfully
|
||||||
|
async_result = await async_middleware.abefore_model(mock_state, mock_runtime)
|
||||||
|
sync_result = await sync_middleware.abefore_model(mock_state, mock_runtime)
|
||||||
|
|
||||||
|
assert async_result == {"type": "async"}
|
||||||
|
assert sync_result == {"type": "sync"}
|
||||||
@@ -2,10 +2,10 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Unit tests for agent locale restoration after create_react_agent execution.
|
Unit tests for agent locale restoration after agent execution.
|
||||||
|
|
||||||
Tests that meta fields (especially locale) are properly restored after
|
Tests that meta fields (especially locale) are properly restored after
|
||||||
agent.ainvoke() returns, since create_react_agent creates a MessagesState
|
agent.ainvoke() returns, since the agent creates a MessagesState
|
||||||
subgraph that filters out custom fields.
|
subgraph that filters out custom fields.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -22,7 +22,7 @@ class TestAgentLocaleRestoration:
|
|||||||
"""
|
"""
|
||||||
Demonstrate the problem: agent subgraph filters out locale.
|
Demonstrate the problem: agent subgraph filters out locale.
|
||||||
|
|
||||||
When create_react_agent creates a subgraph with MessagesState,
|
When the agent creates a subgraph with MessagesState,
|
||||||
it only returns messages, not custom fields.
|
it only returns messages, not custom fields.
|
||||||
"""
|
"""
|
||||||
# Simulate agent behavior: only returns messages
|
# Simulate agent behavior: only returns messages
|
||||||
|
|||||||
Reference in New Issue
Block a user