mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
feat(channels): make mobile session settings configurable by channel and user (#1021)
This commit is contained in:
@@ -4,6 +4,8 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from src.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage
|
||||
from src.channels.store import ChannelStore
|
||||
@@ -14,6 +16,25 @@ DEFAULT_LANGGRAPH_URL = "http://localhost:2024"
|
||||
DEFAULT_GATEWAY_URL = "http://localhost:8001"
|
||||
DEFAULT_ASSISTANT_ID = "lead_agent"
|
||||
|
||||
DEFAULT_RUN_CONFIG: dict[str, Any] = {"recursion_limit": 100}
|
||||
DEFAULT_RUN_CONTEXT: dict[str, Any] = {
|
||||
"thinking_enabled": True,
|
||||
"is_plan_mode": False,
|
||||
"subagent_enabled": False,
|
||||
}
|
||||
|
||||
|
||||
def _as_dict(value: Any) -> dict[str, Any]:
|
||||
return dict(value) if isinstance(value, Mapping) else {}
|
||||
|
||||
|
||||
def _merge_dicts(*layers: Any) -> dict[str, Any]:
|
||||
merged: dict[str, Any] = {}
|
||||
for layer in layers:
|
||||
if isinstance(layer, Mapping):
|
||||
merged.update(layer)
|
||||
return merged
|
||||
|
||||
|
||||
def _extract_response_text(result: dict | list) -> str:
|
||||
"""Extract the last AI message text from a LangGraph runs.wait result.
|
||||
@@ -125,6 +146,8 @@ class ChannelManager:
|
||||
langgraph_url: str = DEFAULT_LANGGRAPH_URL,
|
||||
gateway_url: str = DEFAULT_GATEWAY_URL,
|
||||
assistant_id: str = DEFAULT_ASSISTANT_ID,
|
||||
default_session: dict[str, Any] | None = None,
|
||||
channel_sessions: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
self.bus = bus
|
||||
self.store = store
|
||||
@@ -132,11 +155,48 @@ class ChannelManager:
|
||||
self._langgraph_url = langgraph_url
|
||||
self._gateway_url = gateway_url
|
||||
self._assistant_id = assistant_id
|
||||
self._default_session = _as_dict(default_session)
|
||||
self._channel_sessions = dict(channel_sessions or {})
|
||||
self._client = None # lazy init — langgraph_sdk async client
|
||||
self._semaphore: asyncio.Semaphore | None = None
|
||||
self._running = False
|
||||
self._task: asyncio.Task | None = None
|
||||
|
||||
def _resolve_session_layer(self, msg: InboundMessage) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
channel_layer = _as_dict(self._channel_sessions.get(msg.channel_name))
|
||||
users_layer = _as_dict(channel_layer.get("users"))
|
||||
user_layer = _as_dict(users_layer.get(msg.user_id))
|
||||
return channel_layer, user_layer
|
||||
|
||||
def _resolve_run_params(self, msg: InboundMessage, thread_id: str) -> tuple[str, dict[str, Any], dict[str, Any]]:
|
||||
channel_layer, user_layer = self._resolve_session_layer(msg)
|
||||
|
||||
assistant_id = (
|
||||
user_layer.get("assistant_id")
|
||||
or channel_layer.get("assistant_id")
|
||||
or self._default_session.get("assistant_id")
|
||||
or self._assistant_id
|
||||
)
|
||||
if not isinstance(assistant_id, str) or not assistant_id.strip():
|
||||
assistant_id = self._assistant_id
|
||||
|
||||
run_config = _merge_dicts(
|
||||
DEFAULT_RUN_CONFIG,
|
||||
self._default_session.get("config"),
|
||||
channel_layer.get("config"),
|
||||
user_layer.get("config"),
|
||||
)
|
||||
|
||||
run_context = _merge_dicts(
|
||||
DEFAULT_RUN_CONTEXT,
|
||||
self._default_session.get("context"),
|
||||
channel_layer.get("context"),
|
||||
user_layer.get("context"),
|
||||
{"thread_id": thread_id},
|
||||
)
|
||||
|
||||
return assistant_id, run_config, run_context
|
||||
|
||||
# -- LangGraph SDK client (lazy) ----------------------------------------
|
||||
|
||||
def _get_client(self):
|
||||
@@ -246,18 +306,14 @@ class ChannelManager:
|
||||
if thread_id is None:
|
||||
thread_id = await self._create_thread(client, msg)
|
||||
|
||||
assistant_id, run_config, run_context = self._resolve_run_params(msg, thread_id)
|
||||
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
||||
result = await client.runs.wait(
|
||||
thread_id,
|
||||
self._assistant_id,
|
||||
assistant_id,
|
||||
input={"messages": [{"role": "human", "content": msg.text}]},
|
||||
config={"recursion_limit": 100},
|
||||
context={
|
||||
"thread_id": thread_id,
|
||||
"thinking_enabled": True,
|
||||
"is_plan_mode": False,
|
||||
"subagent_enabled": False,
|
||||
},
|
||||
config=run_config,
|
||||
context=run_context,
|
||||
)
|
||||
|
||||
response_text = _extract_response_text(result)
|
||||
|
||||
@@ -32,11 +32,19 @@ class ChannelService:
|
||||
config = dict(channels_config or {})
|
||||
langgraph_url = config.pop("langgraph_url", None) or "http://localhost:2024"
|
||||
gateway_url = config.pop("gateway_url", None) or "http://localhost:8001"
|
||||
default_session = config.pop("session", None)
|
||||
channel_sessions = {
|
||||
name: channel_config.get("session")
|
||||
for name, channel_config in config.items()
|
||||
if isinstance(channel_config, dict)
|
||||
}
|
||||
self.manager = ChannelManager(
|
||||
bus=self.bus,
|
||||
store=self.store,
|
||||
langgraph_url=langgraph_url,
|
||||
gateway_url=gateway_url,
|
||||
default_session=default_session if isinstance(default_session, dict) else None,
|
||||
channel_sessions=channel_sessions,
|
||||
)
|
||||
self._channels: dict[str, Any] = {} # name -> Channel instance
|
||||
self._config = config
|
||||
|
||||
@@ -424,6 +424,112 @@ class TestChannelManager:
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_handle_chat_uses_channel_session_overrides(self):
|
||||
from src.channels.manager import ChannelManager
|
||||
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||
manager = ChannelManager(
|
||||
bus=bus,
|
||||
store=store,
|
||||
channel_sessions={
|
||||
"telegram": {
|
||||
"assistant_id": "mobile_agent",
|
||||
"config": {"recursion_limit": 55},
|
||||
"context": {
|
||||
"thinking_enabled": False,
|
||||
"subagent_enabled": True,
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
outbound_received = []
|
||||
|
||||
async def capture_outbound(msg):
|
||||
outbound_received.append(msg)
|
||||
|
||||
bus.subscribe_outbound(capture_outbound)
|
||||
|
||||
mock_client = _make_mock_langgraph_client()
|
||||
manager._client = mock_client
|
||||
|
||||
await manager.start()
|
||||
|
||||
inbound = InboundMessage(channel_name="telegram", chat_id="chat1", user_id="user1", text="hi")
|
||||
await bus.publish_inbound(inbound)
|
||||
await _wait_for(lambda: len(outbound_received) >= 1)
|
||||
await manager.stop()
|
||||
|
||||
mock_client.runs.wait.assert_called_once()
|
||||
call_args = mock_client.runs.wait.call_args
|
||||
assert call_args[0][1] == "mobile_agent"
|
||||
assert call_args[1]["config"]["recursion_limit"] == 55
|
||||
assert call_args[1]["context"]["thinking_enabled"] is False
|
||||
assert call_args[1]["context"]["subagent_enabled"] is True
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_handle_chat_uses_user_session_overrides(self):
|
||||
from src.channels.manager import ChannelManager
|
||||
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||
manager = ChannelManager(
|
||||
bus=bus,
|
||||
store=store,
|
||||
default_session={"context": {"is_plan_mode": True}},
|
||||
channel_sessions={
|
||||
"telegram": {
|
||||
"assistant_id": "mobile_agent",
|
||||
"config": {"recursion_limit": 55},
|
||||
"context": {
|
||||
"thinking_enabled": False,
|
||||
"subagent_enabled": False,
|
||||
},
|
||||
"users": {
|
||||
"vip-user": {
|
||||
"assistant_id": "vip_agent",
|
||||
"config": {"recursion_limit": 77},
|
||||
"context": {
|
||||
"thinking_enabled": True,
|
||||
"subagent_enabled": True,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
outbound_received = []
|
||||
|
||||
async def capture_outbound(msg):
|
||||
outbound_received.append(msg)
|
||||
|
||||
bus.subscribe_outbound(capture_outbound)
|
||||
|
||||
mock_client = _make_mock_langgraph_client()
|
||||
manager._client = mock_client
|
||||
|
||||
await manager.start()
|
||||
|
||||
inbound = InboundMessage(channel_name="telegram", chat_id="chat1", user_id="vip-user", text="hi")
|
||||
await bus.publish_inbound(inbound)
|
||||
await _wait_for(lambda: len(outbound_received) >= 1)
|
||||
await manager.stop()
|
||||
|
||||
mock_client.runs.wait.assert_called_once()
|
||||
call_args = mock_client.runs.wait.call_args
|
||||
assert call_args[0][1] == "vip_agent"
|
||||
assert call_args[1]["config"]["recursion_limit"] == 77
|
||||
assert call_args[1]["context"]["thinking_enabled"] is True
|
||||
assert call_args[1]["context"]["subagent_enabled"] is True
|
||||
assert call_args[1]["context"]["is_plan_mode"] is True
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_handle_command_help(self):
|
||||
from src.channels.manager import ChannelManager
|
||||
|
||||
@@ -954,6 +1060,30 @@ class TestChannelService:
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_session_config_is_forwarded_to_manager(self):
|
||||
from src.channels.service import ChannelService
|
||||
|
||||
service = ChannelService(
|
||||
channels_config={
|
||||
"session": {"context": {"thinking_enabled": False}},
|
||||
"telegram": {
|
||||
"enabled": False,
|
||||
"session": {
|
||||
"assistant_id": "mobile_agent",
|
||||
"users": {
|
||||
"vip": {
|
||||
"assistant_id": "vip_agent",
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
assert service.manager._default_session["context"]["thinking_enabled"] is False
|
||||
assert service.manager._channel_sessions["telegram"]["assistant_id"] == "mobile_agent"
|
||||
assert service.manager._channel_sessions["telegram"]["users"]["vip"]["assistant_id"] == "vip_agent"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slack send retry tests
|
||||
|
||||
Reference in New Issue
Block a user