diff --git a/backend/src/agents/lead_agent/agent.py b/backend/src/agents/lead_agent/agent.py index 7775cd1..7b4028d 100644 --- a/backend/src/agents/lead_agent/agent.py +++ b/backend/src/agents/lead_agent/agent.py @@ -1,3 +1,5 @@ +import logging + from langchain.agents import create_agent from langchain.agents.middleware import SummarizationMiddleware, TodoListMiddleware from langchain_core.runnables import RunnableConfig @@ -17,6 +19,25 @@ from src.config.summarization_config import get_summarization_config from src.models import create_chat_model from src.sandbox.middleware import SandboxMiddleware +logger = logging.getLogger(__name__) + + +def _resolve_model_name(requested_model_name: str | None) -> str: + """Resolve a runtime model name safely, falling back to default if invalid. Raises ValueError if no models are configured.""" + app_config = get_app_config() + default_model_name = app_config.models[0].name if app_config.models else None + if default_model_name is None: + raise ValueError( + "No chat models are configured. Please configure at least one model in config.yaml." 
+ ) + + if requested_model_name and app_config.get_model_config(requested_model_name): + return requested_model_name + + if requested_model_name and requested_model_name != default_model_name: + logger.warning(f"Model '{requested_model_name}' not found in config; fallback to default model '{default_model_name}'.") + return default_model_name + def _create_summarization_middleware() -> SummarizationMiddleware | None: """Create and configure the summarization middleware from config.""" @@ -184,7 +205,7 @@ Being proactive with task management demonstrates thoroughness and ensures all r # MemoryMiddleware queues conversation for memory update (after TitleMiddleware) # ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM # ClarificationMiddleware should be last to intercept clarification requests after model calls -def _build_middlewares(config: RunnableConfig): +def _build_middlewares(config: RunnableConfig, model_name: str | None): """Build middleware chain based on runtime configuration. Args: @@ -212,14 +233,9 @@ def _build_middlewares(config: RunnableConfig): # Add MemoryMiddleware (after TitleMiddleware) middlewares.append(MemoryMiddleware()) - # Add ViewImageMiddleware only if the current model supports vision - model_name = config.get("configurable", {}).get("model_name") or config.get("configurable", {}).get("model") - + # Add ViewImageMiddleware only if the current model supports vision. + # Use the resolved runtime model_name from make_lead_agent to avoid stale config values. 
app_config = get_app_config() - # If no model_name specified, use the first model (default) - if model_name is None and app_config.models: - model_name = app_config.models[0].name - model_config = app_config.get_model_config(model_name) if model_name else None if model_config is not None and model_config.supports_vision: middlewares.append(ViewImageMiddleware()) @@ -240,11 +256,31 @@ def make_lead_agent(config: RunnableConfig): from src.tools import get_available_tools thinking_enabled = config.get("configurable", {}).get("thinking_enabled", True) - model_name = config.get("configurable", {}).get("model_name") or config.get("configurable", {}).get("model") + requested_model_name = config.get("configurable", {}).get("model_name") or config.get("configurable", {}).get("model") + model_name = _resolve_model_name(requested_model_name) + if model_name is None: + raise ValueError( + "No chat model could be resolved. Please configure at least one model in " + "config.yaml or provide a valid 'model_name'/'model' in the request." 
+ ) is_plan_mode = config.get("configurable", {}).get("is_plan_mode", False) subagent_enabled = config.get("configurable", {}).get("subagent_enabled", False) max_concurrent_subagents = config.get("configurable", {}).get("max_concurrent_subagents", 3) - print(f"thinking_enabled: {thinking_enabled}, model_name: {model_name}, is_plan_mode: {is_plan_mode}, subagent_enabled: {subagent_enabled}, max_concurrent_subagents: {max_concurrent_subagents}") + + app_config = get_app_config() + model_config = app_config.get_model_config(model_name) if model_name else None + if thinking_enabled and model_config is not None and not model_config.supports_thinking: + logger.warning(f"Thinking mode is enabled but model '{model_name}' does not support it; fallback to non-thinking mode.") + thinking_enabled = False + + logger.info( + "thinking_enabled: %s, model_name: %s, is_plan_mode: %s, subagent_enabled: %s, max_concurrent_subagents: %s", + thinking_enabled, + model_name, + is_plan_mode, + subagent_enabled, + max_concurrent_subagents, + ) # Inject run metadata for LangSmith trace tagging if "metadata" not in config: @@ -261,7 +297,7 @@ def make_lead_agent(config: RunnableConfig): return create_agent( model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled), tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled), - middleware=_build_middlewares(config), + middleware=_build_middlewares(config, model_name=model_name), system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents), state_schema=ThreadState, ) diff --git a/backend/tests/test_lead_agent_model_resolution.py b/backend/tests/test_lead_agent_model_resolution.py new file mode 100644 index 0000000..b79829f --- /dev/null +++ b/backend/tests/test_lead_agent_model_resolution.py @@ -0,0 +1,136 @@ +"""Tests for lead agent runtime model resolution behavior.""" + +from __future__ import annotations + +import pytest + +from 
src.agents.lead_agent import agent as lead_agent_module +from src.config.app_config import AppConfig +from src.config.model_config import ModelConfig +from src.config.sandbox_config import SandboxConfig + + +def _make_app_config(models: list[ModelConfig]) -> AppConfig: + return AppConfig( + models=models, + sandbox=SandboxConfig(use="src.sandbox.local:LocalSandboxProvider"), + ) + + +def _make_model(name: str, *, supports_thinking: bool) -> ModelConfig: + return ModelConfig( + name=name, + display_name=name, + description=None, + use="langchain_openai:ChatOpenAI", + model=name, + supports_thinking=supports_thinking, + supports_vision=False, + ) + + +def test_resolve_model_name_falls_back_to_default(monkeypatch, caplog): + app_config = _make_app_config( + [ + _make_model("default-model", supports_thinking=False), + _make_model("other-model", supports_thinking=True), + ] + ) + + monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config) + + with caplog.at_level("WARNING"): + resolved = lead_agent_module._resolve_model_name("missing-model") + + assert resolved == "default-model" + assert "fallback to default model 'default-model'" in caplog.text + + +def test_resolve_model_name_uses_default_when_none(monkeypatch): + app_config = _make_app_config( + [ + _make_model("default-model", supports_thinking=False), + _make_model("other-model", supports_thinking=True), + ] + ) + + monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config) + + resolved = lead_agent_module._resolve_model_name(None) + + assert resolved == "default-model" + + +def test_resolve_model_name_raises_when_no_models_configured(monkeypatch): + app_config = _make_app_config([]) + + monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config) + + with pytest.raises( + ValueError, + match="No chat models are configured", + ): + lead_agent_module._resolve_model_name("missing-model") + + +def 
test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkeypatch): + app_config = _make_app_config([_make_model("safe-model", supports_thinking=False)]) + + import src.tools as tools_module + + monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config) + monkeypatch.setattr(tools_module, "get_available_tools", lambda **kwargs: []) + monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name: []) + + captured: dict[str, object] = {} + + def _fake_create_chat_model(*, name, thinking_enabled): + captured["name"] = name + captured["thinking_enabled"] = thinking_enabled + return object() + + monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model) + monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs) + + result = lead_agent_module.make_lead_agent( + { + "configurable": { + "model_name": "safe-model", + "thinking_enabled": True, + "is_plan_mode": False, + "subagent_enabled": False, + } + } + ) + + assert captured["name"] == "safe-model" + assert captured["thinking_enabled"] is False + assert result["model"] is not None + + +def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch): + app_config = _make_app_config( + [ + _make_model("stale-model", supports_thinking=False), + ModelConfig( + name="vision-model", + display_name="vision-model", + description=None, + use="langchain_openai:ChatOpenAI", + model="vision-model", + supports_thinking=False, + supports_vision=True, + ), + ] + ) + + monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config) + monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda: None) + monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None) + + middlewares = lead_agent_module._build_middlewares( + {"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}}, + model_name="vision-model", + ) + 
+ assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares) diff --git a/frontend/src/components/workspace/input-box.tsx b/frontend/src/components/workspace/input-box.tsx index b32febf..217f1f5 100644 --- a/frontend/src/components/workspace/input-box.tsx +++ b/frontend/src/components/workspace/input-box.tsx @@ -12,7 +12,13 @@ import { ZapIcon, } from "lucide-react"; import { useSearchParams } from "next/navigation"; -import { useCallback, useMemo, useState, type ComponentProps } from "react"; +import { + useCallback, + useEffect, + useMemo, + useState, + type ComponentProps, +} from "react"; import { PromptInput, @@ -63,6 +69,21 @@ import { import { ModeHoverGuide } from "./mode-hover-guide"; import { Tooltip } from "./tooltip"; +type InputMode = "flash" | "thinking" | "pro" | "ultra"; + +function getResolvedMode( + mode: InputMode | undefined, + supportsThinking: boolean, +): InputMode { + if (!supportsThinking && mode !== "flash") { + return "flash"; + } + if (mode) { + return mode; + } + return supportsThinking ? "pro" : "flash"; +} + export function InputBox({ className, disabled, @@ -104,42 +125,64 @@ export function InputBox({ const searchParams = useSearchParams(); const [modelDialogOpen, setModelDialogOpen] = useState(false); const { models } = useModels(); - const selectedModel = useMemo(() => { - if (!context.model_name && models.length > 0) { - const model = models[0]!; - setTimeout(() => { - onContextChange?.({ - ...context, - model_name: model.name, - mode: model.supports_thinking ? "pro" : "flash", - }); - }, 0); - return model; + + useEffect(() => { + if (models.length === 0) { + return; } - return models.find((m) => m.name === context.model_name); + const currentModel = models.find((m) => m.name === context.model_name); + const fallbackModel = currentModel ?? models[0]!; + const supportsThinking = fallbackModel.supports_thinking ?? 
false; + const nextModelName = fallbackModel.name; + const nextMode = getResolvedMode(context.mode, supportsThinking); + + if (context.model_name === nextModelName && context.mode === nextMode) { + return; + } + + onContextChange?.({ + ...context, + model_name: nextModelName, + mode: nextMode, + }); }, [context, models, onContextChange]); + + const selectedModel = useMemo(() => { + if (models.length === 0) { + return undefined; + } + return models.find((m) => m.name === context.model_name) ?? models[0]; + }, [context.model_name, models]); + const supportThinking = useMemo( () => selectedModel?.supports_thinking ?? false, [selectedModel], ); + const handleModelSelect = useCallback( (model_name: string) => { + const model = models.find((m) => m.name === model_name); + if (!model) { + return; + } onContextChange?.({ ...context, model_name, + mode: getResolvedMode(context.mode, model.supports_thinking ?? false), }); setModelDialogOpen(false); }, - [onContextChange, context], + [onContextChange, context, models], ); + const handleModeSelect = useCallback( - (mode: "flash" | "thinking" | "pro" | "ultra") => { + (mode: InputMode) => { onContextChange?.({ ...context, - mode, + mode: getResolvedMode(mode, supportThinking), }); }, - [onContextChange, context], + [onContextChange, context, supportThinking], ); const handleSubmit = useCallback( async (message: PromptInputMessage) => {