feat(channels): make mobile session settings configurable by channel and user (#1021)

This commit is contained in:
aworki
2026-03-08 22:19:40 +08:00
committed by GitHub
parent 8871fca5cb
commit ac1e1915ef
5 changed files with 252 additions and 8 deletions

View File

@@ -4,6 +4,8 @@ from __future__ import annotations
import asyncio
import logging
from collections.abc import Mapping
from typing import Any
from src.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage
from src.channels.store import ChannelStore
@@ -14,6 +16,25 @@ DEFAULT_LANGGRAPH_URL = "http://localhost:2024"
DEFAULT_GATEWAY_URL = "http://localhost:8001"
DEFAULT_ASSISTANT_ID = "lead_agent"
DEFAULT_RUN_CONFIG: dict[str, Any] = {"recursion_limit": 100}
DEFAULT_RUN_CONTEXT: dict[str, Any] = {
"thinking_enabled": True,
"is_plan_mode": False,
"subagent_enabled": False,
}
def _as_dict(value: Any) -> dict[str, Any]:
return dict(value) if isinstance(value, Mapping) else {}
def _merge_dicts(*layers: Any) -> dict[str, Any]:
merged: dict[str, Any] = {}
for layer in layers:
if isinstance(layer, Mapping):
merged.update(layer)
return merged
def _extract_response_text(result: dict | list) -> str:
"""Extract the last AI message text from a LangGraph runs.wait result.
@@ -125,6 +146,8 @@ class ChannelManager:
langgraph_url: str = DEFAULT_LANGGRAPH_URL,
gateway_url: str = DEFAULT_GATEWAY_URL,
assistant_id: str = DEFAULT_ASSISTANT_ID,
default_session: dict[str, Any] | None = None,
channel_sessions: dict[str, Any] | None = None,
) -> None:
self.bus = bus
self.store = store
@@ -132,11 +155,48 @@ class ChannelManager:
self._langgraph_url = langgraph_url
self._gateway_url = gateway_url
self._assistant_id = assistant_id
self._default_session = _as_dict(default_session)
self._channel_sessions = dict(channel_sessions or {})
self._client = None # lazy init — langgraph_sdk async client
self._semaphore: asyncio.Semaphore | None = None
self._running = False
self._task: asyncio.Task | None = None
def _resolve_session_layer(self, msg: InboundMessage) -> tuple[dict[str, Any], dict[str, Any]]:
channel_layer = _as_dict(self._channel_sessions.get(msg.channel_name))
users_layer = _as_dict(channel_layer.get("users"))
user_layer = _as_dict(users_layer.get(msg.user_id))
return channel_layer, user_layer
def _resolve_run_params(self, msg: InboundMessage, thread_id: str) -> tuple[str, dict[str, Any], dict[str, Any]]:
channel_layer, user_layer = self._resolve_session_layer(msg)
assistant_id = (
user_layer.get("assistant_id")
or channel_layer.get("assistant_id")
or self._default_session.get("assistant_id")
or self._assistant_id
)
if not isinstance(assistant_id, str) or not assistant_id.strip():
assistant_id = self._assistant_id
run_config = _merge_dicts(
DEFAULT_RUN_CONFIG,
self._default_session.get("config"),
channel_layer.get("config"),
user_layer.get("config"),
)
run_context = _merge_dicts(
DEFAULT_RUN_CONTEXT,
self._default_session.get("context"),
channel_layer.get("context"),
user_layer.get("context"),
{"thread_id": thread_id},
)
return assistant_id, run_config, run_context
# -- LangGraph SDK client (lazy) ----------------------------------------
def _get_client(self):
@@ -246,18 +306,14 @@ class ChannelManager:
if thread_id is None:
thread_id = await self._create_thread(client, msg)
assistant_id, run_config, run_context = self._resolve_run_params(msg, thread_id)
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
result = await client.runs.wait(
thread_id,
self._assistant_id,
assistant_id,
input={"messages": [{"role": "human", "content": msg.text}]},
config={"recursion_limit": 100},
context={
"thread_id": thread_id,
"thinking_enabled": True,
"is_plan_mode": False,
"subagent_enabled": False,
},
config=run_config,
context=run_context,
)
response_text = _extract_response_text(result)

View File

@@ -32,11 +32,19 @@ class ChannelService:
config = dict(channels_config or {})
langgraph_url = config.pop("langgraph_url", None) or "http://localhost:2024"
gateway_url = config.pop("gateway_url", None) or "http://localhost:8001"
default_session = config.pop("session", None)
channel_sessions = {
name: channel_config.get("session")
for name, channel_config in config.items()
if isinstance(channel_config, dict)
}
self.manager = ChannelManager(
bus=self.bus,
store=self.store,
langgraph_url=langgraph_url,
gateway_url=gateway_url,
default_session=default_session if isinstance(default_session, dict) else None,
channel_sessions=channel_sessions,
)
self._channels: dict[str, Any] = {} # name -> Channel instance
self._config = config

View File

@@ -424,6 +424,112 @@ class TestChannelManager:
_run(go())
def test_handle_chat_uses_channel_session_overrides(self):
from src.channels.manager import ChannelManager
async def go():
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
manager = ChannelManager(
bus=bus,
store=store,
channel_sessions={
"telegram": {
"assistant_id": "mobile_agent",
"config": {"recursion_limit": 55},
"context": {
"thinking_enabled": False,
"subagent_enabled": True,
},
}
},
)
outbound_received = []
async def capture_outbound(msg):
outbound_received.append(msg)
bus.subscribe_outbound(capture_outbound)
mock_client = _make_mock_langgraph_client()
manager._client = mock_client
await manager.start()
inbound = InboundMessage(channel_name="telegram", chat_id="chat1", user_id="user1", text="hi")
await bus.publish_inbound(inbound)
await _wait_for(lambda: len(outbound_received) >= 1)
await manager.stop()
mock_client.runs.wait.assert_called_once()
call_args = mock_client.runs.wait.call_args
assert call_args[0][1] == "mobile_agent"
assert call_args[1]["config"]["recursion_limit"] == 55
assert call_args[1]["context"]["thinking_enabled"] is False
assert call_args[1]["context"]["subagent_enabled"] is True
_run(go())
def test_handle_chat_uses_user_session_overrides(self):
from src.channels.manager import ChannelManager
async def go():
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
manager = ChannelManager(
bus=bus,
store=store,
default_session={"context": {"is_plan_mode": True}},
channel_sessions={
"telegram": {
"assistant_id": "mobile_agent",
"config": {"recursion_limit": 55},
"context": {
"thinking_enabled": False,
"subagent_enabled": False,
},
"users": {
"vip-user": {
"assistant_id": "vip_agent",
"config": {"recursion_limit": 77},
"context": {
"thinking_enabled": True,
"subagent_enabled": True,
},
}
},
}
},
)
outbound_received = []
async def capture_outbound(msg):
outbound_received.append(msg)
bus.subscribe_outbound(capture_outbound)
mock_client = _make_mock_langgraph_client()
manager._client = mock_client
await manager.start()
inbound = InboundMessage(channel_name="telegram", chat_id="chat1", user_id="vip-user", text="hi")
await bus.publish_inbound(inbound)
await _wait_for(lambda: len(outbound_received) >= 1)
await manager.stop()
mock_client.runs.wait.assert_called_once()
call_args = mock_client.runs.wait.call_args
assert call_args[0][1] == "vip_agent"
assert call_args[1]["config"]["recursion_limit"] == 77
assert call_args[1]["context"]["thinking_enabled"] is True
assert call_args[1]["context"]["subagent_enabled"] is True
assert call_args[1]["context"]["is_plan_mode"] is True
_run(go())
def test_handle_command_help(self):
from src.channels.manager import ChannelManager
@@ -954,6 +1060,30 @@ class TestChannelService:
_run(go())
def test_session_config_is_forwarded_to_manager(self):
from src.channels.service import ChannelService
service = ChannelService(
channels_config={
"session": {"context": {"thinking_enabled": False}},
"telegram": {
"enabled": False,
"session": {
"assistant_id": "mobile_agent",
"users": {
"vip": {
"assistant_id": "vip_agent",
}
},
},
},
}
)
assert service.manager._default_session["context"]["thinking_enabled"] is False
assert service.manager._channel_sessions["telegram"]["assistant_id"] == "mobile_agent"
assert service.manager._channel_sessions["telegram"]["users"]["vip"]["assistant_id"] == "vip_agent"
# ---------------------------------------------------------------------------
# Slack send retry tests