fix: recover from stale model context when configured models change (#898)

* fix: recover from stale model context after config model changes

* fix: fail fast on missing model config and expand model resolution tests

* fix: remove duplicate get_app_config imports

* fix: align model resolution tests with runtime imports

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix: remove duplicate model resolution test case

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Xinmin Zeng
2026-02-26 13:54:29 +08:00
committed by GitHub
parent 3e6e4b0b5f
commit 6a55860a15
3 changed files with 243 additions and 28 deletions

View File

@@ -1,3 +1,5 @@
import logging
from langchain.agents import create_agent
from langchain.agents.middleware import SummarizationMiddleware, TodoListMiddleware
from langchain_core.runnables import RunnableConfig
@@ -17,6 +19,25 @@ from src.config.summarization_config import get_summarization_config
from src.models import create_chat_model
from src.sandbox.middleware import SandboxMiddleware
logger = logging.getLogger(__name__)
def _resolve_model_name(requested_model_name: str | None) -> str:
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
app_config = get_app_config()
default_model_name = app_config.models[0].name if app_config.models else None
if default_model_name is None:
raise ValueError(
"No chat models are configured. Please configure at least one model in config.yaml."
)
if requested_model_name and app_config.get_model_config(requested_model_name):
return requested_model_name
if requested_model_name and requested_model_name != default_model_name:
logger.warning(f"Model '{requested_model_name}' not found in config; fallback to default model '{default_model_name}'.")
return default_model_name
def _create_summarization_middleware() -> SummarizationMiddleware | None:
"""Create and configure the summarization middleware from config."""
@@ -184,7 +205,7 @@ Being proactive with task management demonstrates thoroughness and ensures all r
# MemoryMiddleware queues conversation for memory update (after TitleMiddleware)
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
# ClarificationMiddleware should be last to intercept clarification requests after model calls
def _build_middlewares(config: RunnableConfig):
def _build_middlewares(config: RunnableConfig, model_name: str | None):
"""Build middleware chain based on runtime configuration.
Args:
@@ -212,14 +233,9 @@ def _build_middlewares(config: RunnableConfig):
# Add MemoryMiddleware (after TitleMiddleware)
middlewares.append(MemoryMiddleware())
# Add ViewImageMiddleware only if the current model supports vision
model_name = config.get("configurable", {}).get("model_name") or config.get("configurable", {}).get("model")
# Add ViewImageMiddleware only if the current model supports vision.
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
app_config = get_app_config()
# If no model_name specified, use the first model (default)
if model_name is None and app_config.models:
model_name = app_config.models[0].name
model_config = app_config.get_model_config(model_name) if model_name else None
if model_config is not None and model_config.supports_vision:
middlewares.append(ViewImageMiddleware())
@@ -240,11 +256,31 @@ def make_lead_agent(config: RunnableConfig):
from src.tools import get_available_tools
thinking_enabled = config.get("configurable", {}).get("thinking_enabled", True)
model_name = config.get("configurable", {}).get("model_name") or config.get("configurable", {}).get("model")
requested_model_name = config.get("configurable", {}).get("model_name") or config.get("configurable", {}).get("model")
model_name = _resolve_model_name(requested_model_name)
if model_name is None:
raise ValueError(
"No chat model could be resolved. Please configure at least one model in "
"config.yaml or provide a valid 'model_name'/'model' in the request."
)
is_plan_mode = config.get("configurable", {}).get("is_plan_mode", False)
subagent_enabled = config.get("configurable", {}).get("subagent_enabled", False)
max_concurrent_subagents = config.get("configurable", {}).get("max_concurrent_subagents", 3)
print(f"thinking_enabled: {thinking_enabled}, model_name: {model_name}, is_plan_mode: {is_plan_mode}, subagent_enabled: {subagent_enabled}, max_concurrent_subagents: {max_concurrent_subagents}")
app_config = get_app_config()
model_config = app_config.get_model_config(model_name) if model_name else None
if thinking_enabled and model_config is not None and not model_config.supports_thinking:
logger.warning(f"Thinking mode is enabled but model '{model_name}' does not support it; fallback to non-thinking mode.")
thinking_enabled = False
logger.info(
"thinking_enabled: %s, model_name: %s, is_plan_mode: %s, subagent_enabled: %s, max_concurrent_subagents: %s",
thinking_enabled,
model_name,
is_plan_mode,
subagent_enabled,
max_concurrent_subagents,
)
# Inject run metadata for LangSmith trace tagging
if "metadata" not in config:
@@ -261,7 +297,7 @@ def make_lead_agent(config: RunnableConfig):
return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled),
middleware=_build_middlewares(config),
middleware=_build_middlewares(config, model_name=model_name),
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents),
state_schema=ThreadState,
)

View File

@@ -0,0 +1,136 @@
"""Tests for lead agent runtime model resolution behavior."""
from __future__ import annotations
import pytest
from src.agents.lead_agent import agent as lead_agent_module
from src.config.app_config import AppConfig
from src.config.model_config import ModelConfig
from src.config.sandbox_config import SandboxConfig
def _make_app_config(models: list[ModelConfig]) -> AppConfig:
return AppConfig(
models=models,
sandbox=SandboxConfig(use="src.sandbox.local:LocalSandboxProvider"),
)
def _make_model(name: str, *, supports_thinking: bool) -> ModelConfig:
return ModelConfig(
name=name,
display_name=name,
description=None,
use="langchain_openai:ChatOpenAI",
model=name,
supports_thinking=supports_thinking,
supports_vision=False,
)
def test_resolve_model_name_falls_back_to_default(monkeypatch, caplog):
app_config = _make_app_config(
[
_make_model("default-model", supports_thinking=False),
_make_model("other-model", supports_thinking=True),
]
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
with caplog.at_level("WARNING"):
resolved = lead_agent_module._resolve_model_name("missing-model")
assert resolved == "default-model"
assert "fallback to default model 'default-model'" in caplog.text
def test_resolve_model_name_uses_default_when_none(monkeypatch):
app_config = _make_app_config(
[
_make_model("default-model", supports_thinking=False),
_make_model("other-model", supports_thinking=True),
]
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
resolved = lead_agent_module._resolve_model_name(None)
assert resolved == "default-model"
def test_resolve_model_name_raises_when_no_models_configured(monkeypatch):
app_config = _make_app_config([])
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
with pytest.raises(
ValueError,
match="No chat models are configured",
):
lead_agent_module._resolve_model_name("missing-model")
def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkeypatch):
app_config = _make_app_config([_make_model("safe-model", supports_thinking=False)])
import src.tools as tools_module
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(tools_module, "get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name: [])
captured: dict[str, object] = {}
def _fake_create_chat_model(*, name, thinking_enabled):
captured["name"] = name
captured["thinking_enabled"] = thinking_enabled
return object()
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
result = lead_agent_module.make_lead_agent(
{
"configurable": {
"model_name": "safe-model",
"thinking_enabled": True,
"is_plan_mode": False,
"subagent_enabled": False,
}
}
)
assert captured["name"] == "safe-model"
assert captured["thinking_enabled"] is False
assert result["model"] is not None
def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
app_config = _make_app_config(
[
_make_model("stale-model", supports_thinking=False),
ModelConfig(
name="vision-model",
display_name="vision-model",
description=None,
use="langchain_openai:ChatOpenAI",
model="vision-model",
supports_thinking=False,
supports_vision=True,
),
]
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda: None)
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
middlewares = lead_agent_module._build_middlewares(
{"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}},
model_name="vision-model",
)
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)

View File

@@ -12,7 +12,13 @@ import {
ZapIcon,
} from "lucide-react";
import { useSearchParams } from "next/navigation";
import { useCallback, useMemo, useState, type ComponentProps } from "react";
import {
useCallback,
useEffect,
useMemo,
useState,
type ComponentProps,
} from "react";
import {
PromptInput,
@@ -63,6 +69,21 @@ import {
import { ModeHoverGuide } from "./mode-hover-guide";
import { Tooltip } from "./tooltip";
type InputMode = "flash" | "thinking" | "pro" | "ultra";
function getResolvedMode(
mode: InputMode | undefined,
supportsThinking: boolean,
): InputMode {
if (!supportsThinking && mode !== "flash") {
return "flash";
}
if (mode) {
return mode;
}
return supportsThinking ? "pro" : "flash";
}
export function InputBox({
className,
disabled,
@@ -104,42 +125,64 @@ export function InputBox({
const searchParams = useSearchParams();
const [modelDialogOpen, setModelDialogOpen] = useState(false);
const { models } = useModels();
const selectedModel = useMemo(() => {
if (!context.model_name && models.length > 0) {
const model = models[0]!;
setTimeout(() => {
onContextChange?.({
...context,
model_name: model.name,
mode: model.supports_thinking ? "pro" : "flash",
});
}, 0);
return model;
useEffect(() => {
if (models.length === 0) {
return;
}
return models.find((m) => m.name === context.model_name);
const currentModel = models.find((m) => m.name === context.model_name);
const fallbackModel = currentModel ?? models[0]!;
const supportsThinking = fallbackModel.supports_thinking ?? false;
const nextModelName = fallbackModel.name;
const nextMode = getResolvedMode(context.mode, supportsThinking);
if (context.model_name === nextModelName && context.mode === nextMode) {
return;
}
onContextChange?.({
...context,
model_name: nextModelName,
mode: nextMode,
});
}, [context, models, onContextChange]);
const selectedModel = useMemo(() => {
if (models.length === 0) {
return undefined;
}
return models.find((m) => m.name === context.model_name) ?? models[0];
}, [context.model_name, models]);
const supportThinking = useMemo(
() => selectedModel?.supports_thinking ?? false,
[selectedModel],
);
const handleModelSelect = useCallback(
(model_name: string) => {
const model = models.find((m) => m.name === model_name);
if (!model) {
return;
}
onContextChange?.({
...context,
model_name,
mode: getResolvedMode(context.mode, model.supports_thinking ?? false),
});
setModelDialogOpen(false);
},
[onContextChange, context],
[onContextChange, context, models],
);
const handleModeSelect = useCallback(
(mode: "flash" | "thinking" | "pro" | "ultra") => {
(mode: InputMode) => {
onContextChange?.({
...context,
mode,
mode: getResolvedMode(mode, supportThinking),
});
},
[onContextChange, context],
[onContextChange, context, supportThinking],
);
const handleSubmit = useCallback(
async (message: PromptInputMessage) => {