diff --git a/README.md b/README.md index 84063fd..300991c 100644 --- a/README.md +++ b/README.md @@ -209,6 +209,16 @@ channels: # Gateway API URL (default: http://localhost:8001) gateway_url: http://localhost:8001 + # Optional: global session defaults for all mobile channels + session: + assistant_id: lead_agent + config: + recursion_limit: 100 + context: + thinking_enabled: true + is_plan_mode: false + subagent_enabled: false + feishu: enabled: true app_id: $FEISHU_APP_ID @@ -224,6 +234,20 @@ channels: enabled: true bot_token: $TELEGRAM_BOT_TOKEN allowed_users: [] # empty = allow all + + # Optional: per-channel / per-user session settings + session: + assistant_id: mobile_agent + context: + thinking_enabled: false + users: + "123456789": + assistant_id: vip_agent + config: + recursion_limit: 150 + context: + thinking_enabled: true + subagent_enabled: true ``` Set the corresponding API keys in your `.env` file: diff --git a/backend/src/channels/manager.py b/backend/src/channels/manager.py index bee4eba..93e64af 100644 --- a/backend/src/channels/manager.py +++ b/backend/src/channels/manager.py @@ -4,6 +4,8 @@ from __future__ import annotations import asyncio import logging +from collections.abc import Mapping +from typing import Any from src.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage from src.channels.store import ChannelStore @@ -14,6 +16,25 @@ DEFAULT_LANGGRAPH_URL = "http://localhost:2024" DEFAULT_GATEWAY_URL = "http://localhost:8001" DEFAULT_ASSISTANT_ID = "lead_agent" +DEFAULT_RUN_CONFIG: dict[str, Any] = {"recursion_limit": 100} +DEFAULT_RUN_CONTEXT: dict[str, Any] = { + "thinking_enabled": True, + "is_plan_mode": False, + "subagent_enabled": False, +} + + +def _as_dict(value: Any) -> dict[str, Any]: + return dict(value) if isinstance(value, Mapping) else {} + + +def _merge_dicts(*layers: Any) -> dict[str, Any]: + merged: dict[str, Any] = {} + for layer in layers: + if isinstance(layer, Mapping): + merged.update(layer) + return merged + def _extract_response_text(result: dict | list) -> str: """Extract the last AI message text from a LangGraph runs.wait result. @@ -125,6 +146,8 @@ class ChannelManager: langgraph_url: str = DEFAULT_LANGGRAPH_URL, gateway_url: str = DEFAULT_GATEWAY_URL, assistant_id: str = DEFAULT_ASSISTANT_ID, + default_session: dict[str, Any] | None = None, + channel_sessions: dict[str, Any] | None = None, ) -> None: self.bus = bus self.store = store @@ -132,11 +155,48 @@ class ChannelManager: self._langgraph_url = langgraph_url self._gateway_url = gateway_url self._assistant_id = assistant_id + self._default_session = _as_dict(default_session) + self._channel_sessions = dict(channel_sessions or {}) self._client = None # lazy init — langgraph_sdk async client self._semaphore: asyncio.Semaphore | None = None self._running = False self._task: asyncio.Task | None = None + def _resolve_session_layer(self, msg: InboundMessage) -> tuple[dict[str, Any], dict[str, Any]]: + channel_layer = _as_dict(self._channel_sessions.get(msg.channel_name)) + users_layer = _as_dict(channel_layer.get("users")) + user_layer = _as_dict(users_layer.get(msg.user_id)) + return channel_layer, user_layer + + def _resolve_run_params(self, msg: InboundMessage, thread_id: str) -> tuple[str, dict[str, Any], dict[str, Any]]: + channel_layer, user_layer = self._resolve_session_layer(msg) + + assistant_id = ( + user_layer.get("assistant_id") + or channel_layer.get("assistant_id") + or self._default_session.get("assistant_id") + or self._assistant_id + ) + if not isinstance(assistant_id, str) or not assistant_id.strip(): + assistant_id = self._assistant_id + + run_config = _merge_dicts( + DEFAULT_RUN_CONFIG, + self._default_session.get("config"), + channel_layer.get("config"), + user_layer.get("config"), + ) + + run_context = _merge_dicts( + DEFAULT_RUN_CONTEXT, + self._default_session.get("context"), + channel_layer.get("context"), + user_layer.get("context"), + {"thread_id": thread_id}, + ) + + return assistant_id, run_config, run_context + # -- LangGraph SDK client (lazy) ---------------------------------------- def _get_client(self): @@ -246,18 +306,14 @@ class ChannelManager: if thread_id is None: thread_id = await self._create_thread(client, msg) + assistant_id, run_config, run_context = self._resolve_run_params(msg, thread_id) logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100]) result = await client.runs.wait( thread_id, - self._assistant_id, + assistant_id, input={"messages": [{"role": "human", "content": msg.text}]}, - config={"recursion_limit": 100}, - context={ - "thread_id": thread_id, - "thinking_enabled": True, - "is_plan_mode": False, - "subagent_enabled": False, - }, + config=run_config, + context=run_context, ) response_text = _extract_response_text(result) diff --git a/backend/src/channels/service.py b/backend/src/channels/service.py index b28e494..72fa3ec 100644 --- a/backend/src/channels/service.py +++ b/backend/src/channels/service.py @@ -32,11 +32,19 @@ class ChannelService: config = dict(channels_config or {}) langgraph_url = config.pop("langgraph_url", None) or "http://localhost:2024" gateway_url = config.pop("gateway_url", None) or "http://localhost:8001" + default_session = config.pop("session", None) + channel_sessions = { + name: channel_config.get("session") + for name, channel_config in config.items() + if isinstance(channel_config, dict) + } self.manager = ChannelManager( bus=self.bus, store=self.store, langgraph_url=langgraph_url, gateway_url=gateway_url, + default_session=default_session if isinstance(default_session, dict) else None, + channel_sessions=channel_sessions, ) self._channels: dict[str, Any] = {} # name -> Channel instance self._config = config diff --git a/backend/tests/test_channels.py b/backend/tests/test_channels.py index 38695de..131476d 100644 --- a/backend/tests/test_channels.py +++ b/backend/tests/test_channels.py @@ -424,6 +424,112 @@ class TestChannelManager: _run(go()) + def test_handle_chat_uses_channel_session_overrides(self): + from src.channels.manager import ChannelManager + + async def go(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + manager = ChannelManager( + bus=bus, + store=store, + channel_sessions={ + "telegram": { + "assistant_id": "mobile_agent", + "config": {"recursion_limit": 55}, + "context": { + "thinking_enabled": False, + "subagent_enabled": True, + }, + } + }, + ) + + outbound_received = [] + + async def capture_outbound(msg): + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + + mock_client = _make_mock_langgraph_client() + manager._client = mock_client + + await manager.start() + + inbound = InboundMessage(channel_name="telegram", chat_id="chat1", user_id="user1", text="hi") + await bus.publish_inbound(inbound) + await _wait_for(lambda: len(outbound_received) >= 1) + await manager.stop() + + mock_client.runs.wait.assert_called_once() + call_args = mock_client.runs.wait.call_args + assert call_args[0][1] == "mobile_agent" + assert call_args[1]["config"]["recursion_limit"] == 55 + assert call_args[1]["context"]["thinking_enabled"] is False + assert call_args[1]["context"]["subagent_enabled"] is True + + _run(go()) + + def test_handle_chat_uses_user_session_overrides(self): + from src.channels.manager import ChannelManager + + async def go(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + manager = ChannelManager( + bus=bus, + store=store, + default_session={"context": {"is_plan_mode": True}}, + channel_sessions={ + "telegram": { + "assistant_id": "mobile_agent", + "config": {"recursion_limit": 55}, + "context": { + "thinking_enabled": False, + "subagent_enabled": False, + }, + "users": { + "vip-user": { + "assistant_id": "vip_agent", + "config": {"recursion_limit": 77}, + "context": { + "thinking_enabled": True, + "subagent_enabled": True, + }, + } + }, + } + }, + ) + + outbound_received = [] + + async def capture_outbound(msg): + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + + mock_client = _make_mock_langgraph_client() + manager._client = mock_client + + await manager.start() + + inbound = InboundMessage(channel_name="telegram", chat_id="chat1", user_id="vip-user", text="hi") + await bus.publish_inbound(inbound) + await _wait_for(lambda: len(outbound_received) >= 1) + await manager.stop() + + mock_client.runs.wait.assert_called_once() + call_args = mock_client.runs.wait.call_args + assert call_args[0][1] == "vip_agent" + assert call_args[1]["config"]["recursion_limit"] == 77 + assert call_args[1]["context"]["thinking_enabled"] is True + assert call_args[1]["context"]["subagent_enabled"] is True + assert call_args[1]["context"]["is_plan_mode"] is True + + _run(go()) + def test_handle_command_help(self): from src.channels.manager import ChannelManager @@ -954,6 +1060,30 @@ class TestChannelService: _run(go()) + def test_session_config_is_forwarded_to_manager(self): + from src.channels.service import ChannelService + + service = ChannelService( + channels_config={ + "session": {"context": {"thinking_enabled": False}}, + "telegram": { + "enabled": False, + "session": { + "assistant_id": "mobile_agent", + "users": { + "vip": { + "assistant_id": "vip_agent", + } + }, + }, + }, + } + ) + + assert service.manager._default_session["context"]["thinking_enabled"] is False + assert service.manager._channel_sessions["telegram"]["assistant_id"] == "mobile_agent" + assert service.manager._channel_sessions["telegram"]["users"]["vip"]["assistant_id"] == "vip_agent" + # --------------------------------------------------------------------------- # Slack send retry tests diff --git a/config.example.yaml b/config.example.yaml index 4429f1f..5f7ad88 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -394,6 +394,16 @@ memory: # # Gateway API URL for auxiliary queries like /models, /memory (default: http://localhost:8001) # gateway_url: http://localhost:8001 # +# # Optional: default mobile/session settings for all IM channels +# session: +# assistant_id: lead_agent +# config: +# recursion_limit: 100 +# context: +# thinking_enabled: true +# is_plan_mode: false +# subagent_enabled: false +# # feishu: # enabled: false # app_id: $FEISHU_APP_ID @@ -409,3 +419,19 @@ memory: # enabled: false # bot_token: $TELEGRAM_BOT_TOKEN # allowed_users: [] # empty = allow all +# +# # Optional: channel-level session overrides +# session: +# assistant_id: mobile_agent +# context: +# thinking_enabled: false +# +# # Optional: per-user overrides by user_id +# users: +# "123456789": +# assistant_id: vip_agent +# config: +# recursion_limit: 150 +# context: +# thinking_enabled: true +# subagent_enabled: true