mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-27 15:54:48 +08:00
fix: migrate from deprecated create_react_agent to langchain.agents.create_agent (#802)
* fix: migrate from deprecated create_react_agent to langchain.agents.create_agent Fixes #799 - Replace deprecated langgraph.prebuilt.create_react_agent with langchain.agents.create_agent (LangGraph 1.0 migration) - Add DynamicPromptMiddleware to handle dynamic prompt templates (replaces the 'prompt' callable parameter) - Add PreModelHookMiddleware to handle pre-model hooks (replaces the 'pre_model_hook' parameter) - Update AgentState import from langchain.agents in template.py - Update tests to use the new API * fix:update the code with review comments
This commit is contained in:
@@ -1,10 +1,14 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from langchain.agents import create_agent as langchain_create_agent
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from src.agents.tool_interceptor import wrap_tools_with_interceptor
|
||||
from src.config.agents import AGENT_LLM_MAP
|
||||
@@ -14,6 +18,88 @@ from src.prompts import apply_prompt_template
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DynamicPromptMiddleware(AgentMiddleware):
|
||||
"""Middleware to apply dynamic prompt template before model invocation.
|
||||
|
||||
This middleware prepends a system message with the rendered prompt template
|
||||
to the messages list before the model is called.
|
||||
"""
|
||||
|
||||
def __init__(self, prompt_template: str, locale: str = "en-US"):
|
||||
self.prompt_template = prompt_template
|
||||
self.locale = locale
|
||||
|
||||
def before_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None:
|
||||
"""Apply prompt template and prepend system message to messages."""
|
||||
try:
|
||||
# Get the rendered messages including system prompt from template
|
||||
rendered_messages = apply_prompt_template(
|
||||
self.prompt_template, state, locale=self.locale
|
||||
)
|
||||
# The first message is the system prompt, extract it
|
||||
if rendered_messages and len(rendered_messages) > 0:
|
||||
system_message = rendered_messages[0]
|
||||
# Prepend system message to existing messages
|
||||
return {"messages": [system_message]}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to apply prompt template in before_model: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
async def abefore_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None:
|
||||
"""Async version of before_model."""
|
||||
return self.before_model(state, runtime)
|
||||
|
||||
|
||||
class PreModelHookMiddleware(AgentMiddleware):
|
||||
"""Middleware to execute a pre-model hook before model invocation.
|
||||
|
||||
This middleware wraps the legacy pre_model_hook callable and executes it
|
||||
as part of the middleware chain.
|
||||
"""
|
||||
|
||||
def __init__(self, pre_model_hook: Callable):
|
||||
self._pre_model_hook = pre_model_hook
|
||||
|
||||
def before_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None:
|
||||
"""Execute the pre-model hook."""
|
||||
if not self._pre_model_hook:
|
||||
return None
|
||||
|
||||
try:
|
||||
result = self._pre_model_hook(state, runtime)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Pre-model hook execution failed in before_model: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
async def abefore_model(self, state: Any, runtime: Runtime) -> dict[str, Any] | None:
|
||||
"""Async version of before_model."""
|
||||
if not self._pre_model_hook:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Check if the hook is async
|
||||
if inspect.iscoroutinefunction(self._pre_model_hook):
|
||||
result = await self._pre_model_hook(state, runtime)
|
||||
else:
|
||||
# Run synchronous hook in thread pool to avoid blocking event loop
|
||||
result = await asyncio.to_thread(self._pre_model_hook, state, runtime)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Pre-model hook execution failed in abefore_model: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
# Create agents using configured LLM types
|
||||
def create_agent(
|
||||
agent_name: str,
|
||||
@@ -64,18 +150,23 @@ def create_agent(
|
||||
llm_type = AGENT_LLM_MAP.get(agent_type, "basic")
|
||||
logger.debug(f"Agent '{agent_name}' using LLM type: {llm_type}")
|
||||
|
||||
logger.debug(f"Creating ReAct agent '{agent_name}' with locale: {locale}")
|
||||
logger.debug(f"Creating agent '{agent_name}' with locale: {locale}")
|
||||
|
||||
# Build middleware list
|
||||
# Use closure to capture locale from the workflow state instead of relying on
|
||||
# agent state.get("locale"), which doesn't have the locale field
|
||||
# See: https://github.com/bytedance/deer-flow/issues/743
|
||||
agent = create_react_agent(
|
||||
middleware = [DynamicPromptMiddleware(prompt_template, locale)]
|
||||
|
||||
# Add pre-model hook middleware if provided
|
||||
if pre_model_hook:
|
||||
middleware.append(PreModelHookMiddleware(pre_model_hook))
|
||||
|
||||
agent = langchain_create_agent(
|
||||
name=agent_name,
|
||||
model=get_llm_by_type(llm_type),
|
||||
tools=processed_tools,
|
||||
prompt=lambda state, captured_locale=locale: apply_prompt_template(
|
||||
prompt_template, state, locale=captured_locale
|
||||
),
|
||||
pre_model_hook=pre_model_hook,
|
||||
middleware=middleware,
|
||||
)
|
||||
logger.info(f"Agent '{agent_name}' created successfully")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user