mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-19 12:24:46 +08:00
Support custom channel assistant IDs via lead_agent (#1500)
* Support custom channel assistant IDs via lead agent * Normalize custom channel agent names
This commit is contained in:
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
@@ -17,6 +18,7 @@ logger = logging.getLogger(__name__)
|
||||
DEFAULT_LANGGRAPH_URL = "http://localhost:2024"
|
||||
DEFAULT_GATEWAY_URL = "http://localhost:8001"
|
||||
DEFAULT_ASSISTANT_ID = "lead_agent"
|
||||
CUSTOM_AGENT_NAME_PATTERN = re.compile(r"^[A-Za-z0-9-]+$")
|
||||
|
||||
DEFAULT_RUN_CONFIG: dict[str, Any] = {"recursion_limit": 100}
|
||||
DEFAULT_RUN_CONTEXT: dict[str, Any] = {
|
||||
@@ -33,6 +35,10 @@ CHANNEL_CAPABILITIES = {
|
||||
}
|
||||
|
||||
|
||||
class InvalidChannelSessionConfigError(ValueError):
|
||||
"""Raised when IM channel session overrides contain invalid agent config."""
|
||||
|
||||
|
||||
def _as_dict(value: Any) -> dict[str, Any]:
|
||||
return dict(value) if isinstance(value, Mapping) else {}
|
||||
|
||||
@@ -45,6 +51,21 @@ def _merge_dicts(*layers: Any) -> dict[str, Any]:
|
||||
return merged
|
||||
|
||||
|
||||
def _normalize_custom_agent_name(raw_value: str) -> str:
|
||||
"""Normalize legacy channel assistant IDs into valid custom agent names."""
|
||||
normalized = raw_value.strip().lower().replace("_", "-")
|
||||
if not normalized:
|
||||
raise InvalidChannelSessionConfigError(
|
||||
"Channel session assistant_id is empty. Use 'lead_agent' or a valid custom agent name."
|
||||
)
|
||||
if not CUSTOM_AGENT_NAME_PATTERN.fullmatch(normalized):
|
||||
raise InvalidChannelSessionConfigError(
|
||||
f"Invalid channel session assistant_id {raw_value!r}. "
|
||||
"Use 'lead_agent' or a custom agent name containing only letters, digits, and hyphens."
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
def _extract_response_text(result: dict | list) -> str:
|
||||
"""Extract the last AI message text from a LangGraph runs.wait result.
|
||||
|
||||
@@ -379,6 +400,13 @@ class ChannelManager:
|
||||
{"thread_id": thread_id},
|
||||
)
|
||||
|
||||
# Custom agents are implemented as lead_agent + agent_name context.
|
||||
# Keep backward compatibility for channel configs that set
|
||||
# assistant_id: <custom-agent-name> by routing through lead_agent.
|
||||
if assistant_id != DEFAULT_ASSISTANT_ID:
|
||||
run_context.setdefault("agent_name", _normalize_custom_agent_name(assistant_id))
|
||||
assistant_id = DEFAULT_ASSISTANT_ID
|
||||
|
||||
return assistant_id, run_config, run_context
|
||||
|
||||
# -- LangGraph SDK client (lazy) ----------------------------------------
|
||||
@@ -452,6 +480,14 @@ class ChannelManager:
|
||||
await self._handle_command(msg)
|
||||
else:
|
||||
await self._handle_chat(msg)
|
||||
except InvalidChannelSessionConfigError as exc:
|
||||
logger.warning(
|
||||
"Invalid channel session config for %s (chat=%s): %s",
|
||||
msg.channel_name,
|
||||
msg.chat_id,
|
||||
exc,
|
||||
)
|
||||
await self._send_error(msg, str(exc))
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error handling message from %s (chat=%s)",
|
||||
|
||||
Reference in New Issue
Block a user