mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-19 12:24:46 +08:00
fix: recover from stale model context when configured models change (#898)
* fix: recover from stale model context after config model changes * fix: fail fast on missing model config and expand model resolution tests * fix: remove duplicate get_app_config imports * fix: align model resolution tests with runtime imports * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: remove duplicate model resolution test case --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -1,3 +1,5 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain.agents.middleware import SummarizationMiddleware, TodoListMiddleware
|
from langchain.agents.middleware import SummarizationMiddleware, TodoListMiddleware
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
@@ -17,6 +19,25 @@ from src.config.summarization_config import get_summarization_config
|
|||||||
from src.models import create_chat_model
|
from src.models import create_chat_model
|
||||||
from src.sandbox.middleware import SandboxMiddleware
|
from src.sandbox.middleware import SandboxMiddleware
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_model_name(requested_model_name: str | None) -> str:
|
||||||
|
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
|
||||||
|
app_config = get_app_config()
|
||||||
|
default_model_name = app_config.models[0].name if app_config.models else None
|
||||||
|
if default_model_name is None:
|
||||||
|
raise ValueError(
|
||||||
|
"No chat models are configured. Please configure at least one model in config.yaml."
|
||||||
|
)
|
||||||
|
|
||||||
|
if requested_model_name and app_config.get_model_config(requested_model_name):
|
||||||
|
return requested_model_name
|
||||||
|
|
||||||
|
if requested_model_name and requested_model_name != default_model_name:
|
||||||
|
logger.warning(f"Model '{requested_model_name}' not found in config; fallback to default model '{default_model_name}'.")
|
||||||
|
return default_model_name
|
||||||
|
|
||||||
|
|
||||||
def _create_summarization_middleware() -> SummarizationMiddleware | None:
|
def _create_summarization_middleware() -> SummarizationMiddleware | None:
|
||||||
"""Create and configure the summarization middleware from config."""
|
"""Create and configure the summarization middleware from config."""
|
||||||
@@ -184,7 +205,7 @@ Being proactive with task management demonstrates thoroughness and ensures all r
|
|||||||
# MemoryMiddleware queues conversation for memory update (after TitleMiddleware)
|
# MemoryMiddleware queues conversation for memory update (after TitleMiddleware)
|
||||||
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
|
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
|
||||||
# ClarificationMiddleware should be last to intercept clarification requests after model calls
|
# ClarificationMiddleware should be last to intercept clarification requests after model calls
|
||||||
def _build_middlewares(config: RunnableConfig):
|
def _build_middlewares(config: RunnableConfig, model_name: str | None):
|
||||||
"""Build middleware chain based on runtime configuration.
|
"""Build middleware chain based on runtime configuration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -212,14 +233,9 @@ def _build_middlewares(config: RunnableConfig):
|
|||||||
# Add MemoryMiddleware (after TitleMiddleware)
|
# Add MemoryMiddleware (after TitleMiddleware)
|
||||||
middlewares.append(MemoryMiddleware())
|
middlewares.append(MemoryMiddleware())
|
||||||
|
|
||||||
# Add ViewImageMiddleware only if the current model supports vision
|
# Add ViewImageMiddleware only if the current model supports vision.
|
||||||
model_name = config.get("configurable", {}).get("model_name") or config.get("configurable", {}).get("model")
|
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
|
||||||
|
|
||||||
app_config = get_app_config()
|
app_config = get_app_config()
|
||||||
# If no model_name specified, use the first model (default)
|
|
||||||
if model_name is None and app_config.models:
|
|
||||||
model_name = app_config.models[0].name
|
|
||||||
|
|
||||||
model_config = app_config.get_model_config(model_name) if model_name else None
|
model_config = app_config.get_model_config(model_name) if model_name else None
|
||||||
if model_config is not None and model_config.supports_vision:
|
if model_config is not None and model_config.supports_vision:
|
||||||
middlewares.append(ViewImageMiddleware())
|
middlewares.append(ViewImageMiddleware())
|
||||||
@@ -240,11 +256,31 @@ def make_lead_agent(config: RunnableConfig):
|
|||||||
from src.tools import get_available_tools
|
from src.tools import get_available_tools
|
||||||
|
|
||||||
thinking_enabled = config.get("configurable", {}).get("thinking_enabled", True)
|
thinking_enabled = config.get("configurable", {}).get("thinking_enabled", True)
|
||||||
model_name = config.get("configurable", {}).get("model_name") or config.get("configurable", {}).get("model")
|
requested_model_name = config.get("configurable", {}).get("model_name") or config.get("configurable", {}).get("model")
|
||||||
|
model_name = _resolve_model_name(requested_model_name)
|
||||||
|
if model_name is None:
|
||||||
|
raise ValueError(
|
||||||
|
"No chat model could be resolved. Please configure at least one model in "
|
||||||
|
"config.yaml or provide a valid 'model_name'/'model' in the request."
|
||||||
|
)
|
||||||
is_plan_mode = config.get("configurable", {}).get("is_plan_mode", False)
|
is_plan_mode = config.get("configurable", {}).get("is_plan_mode", False)
|
||||||
subagent_enabled = config.get("configurable", {}).get("subagent_enabled", False)
|
subagent_enabled = config.get("configurable", {}).get("subagent_enabled", False)
|
||||||
max_concurrent_subagents = config.get("configurable", {}).get("max_concurrent_subagents", 3)
|
max_concurrent_subagents = config.get("configurable", {}).get("max_concurrent_subagents", 3)
|
||||||
print(f"thinking_enabled: {thinking_enabled}, model_name: {model_name}, is_plan_mode: {is_plan_mode}, subagent_enabled: {subagent_enabled}, max_concurrent_subagents: {max_concurrent_subagents}")
|
|
||||||
|
app_config = get_app_config()
|
||||||
|
model_config = app_config.get_model_config(model_name) if model_name else None
|
||||||
|
if thinking_enabled and model_config is not None and not model_config.supports_thinking:
|
||||||
|
logger.warning(f"Thinking mode is enabled but model '{model_name}' does not support it; fallback to non-thinking mode.")
|
||||||
|
thinking_enabled = False
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"thinking_enabled: %s, model_name: %s, is_plan_mode: %s, subagent_enabled: %s, max_concurrent_subagents: %s",
|
||||||
|
thinking_enabled,
|
||||||
|
model_name,
|
||||||
|
is_plan_mode,
|
||||||
|
subagent_enabled,
|
||||||
|
max_concurrent_subagents,
|
||||||
|
)
|
||||||
|
|
||||||
# Inject run metadata for LangSmith trace tagging
|
# Inject run metadata for LangSmith trace tagging
|
||||||
if "metadata" not in config:
|
if "metadata" not in config:
|
||||||
@@ -261,7 +297,7 @@ def make_lead_agent(config: RunnableConfig):
|
|||||||
return create_agent(
|
return create_agent(
|
||||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
|
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
|
||||||
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled),
|
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled),
|
||||||
middleware=_build_middlewares(config),
|
middleware=_build_middlewares(config, model_name=model_name),
|
||||||
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents),
|
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents),
|
||||||
state_schema=ThreadState,
|
state_schema=ThreadState,
|
||||||
)
|
)
|
||||||
|
|||||||
136
backend/tests/test_lead_agent_model_resolution.py
Normal file
136
backend/tests/test_lead_agent_model_resolution.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
"""Tests for lead agent runtime model resolution behavior."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.agents.lead_agent import agent as lead_agent_module
|
||||||
|
from src.config.app_config import AppConfig
|
||||||
|
from src.config.model_config import ModelConfig
|
||||||
|
from src.config.sandbox_config import SandboxConfig
|
||||||
|
|
||||||
|
|
||||||
|
def _make_app_config(models: list[ModelConfig]) -> AppConfig:
|
||||||
|
return AppConfig(
|
||||||
|
models=models,
|
||||||
|
sandbox=SandboxConfig(use="src.sandbox.local:LocalSandboxProvider"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_model(name: str, *, supports_thinking: bool) -> ModelConfig:
|
||||||
|
return ModelConfig(
|
||||||
|
name=name,
|
||||||
|
display_name=name,
|
||||||
|
description=None,
|
||||||
|
use="langchain_openai:ChatOpenAI",
|
||||||
|
model=name,
|
||||||
|
supports_thinking=supports_thinking,
|
||||||
|
supports_vision=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_model_name_falls_back_to_default(monkeypatch, caplog):
|
||||||
|
app_config = _make_app_config(
|
||||||
|
[
|
||||||
|
_make_model("default-model", supports_thinking=False),
|
||||||
|
_make_model("other-model", supports_thinking=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||||
|
|
||||||
|
with caplog.at_level("WARNING"):
|
||||||
|
resolved = lead_agent_module._resolve_model_name("missing-model")
|
||||||
|
|
||||||
|
assert resolved == "default-model"
|
||||||
|
assert "fallback to default model 'default-model'" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_model_name_uses_default_when_none(monkeypatch):
|
||||||
|
app_config = _make_app_config(
|
||||||
|
[
|
||||||
|
_make_model("default-model", supports_thinking=False),
|
||||||
|
_make_model("other-model", supports_thinking=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||||
|
|
||||||
|
resolved = lead_agent_module._resolve_model_name(None)
|
||||||
|
|
||||||
|
assert resolved == "default-model"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_model_name_raises_when_no_models_configured(monkeypatch):
|
||||||
|
app_config = _make_app_config([])
|
||||||
|
|
||||||
|
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="No chat models are configured",
|
||||||
|
):
|
||||||
|
lead_agent_module._resolve_model_name("missing-model")
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkeypatch):
|
||||||
|
app_config = _make_app_config([_make_model("safe-model", supports_thinking=False)])
|
||||||
|
|
||||||
|
import src.tools as tools_module
|
||||||
|
|
||||||
|
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||||
|
monkeypatch.setattr(tools_module, "get_available_tools", lambda **kwargs: [])
|
||||||
|
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name: [])
|
||||||
|
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
|
||||||
|
def _fake_create_chat_model(*, name, thinking_enabled):
|
||||||
|
captured["name"] = name
|
||||||
|
captured["thinking_enabled"] = thinking_enabled
|
||||||
|
return object()
|
||||||
|
|
||||||
|
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
|
||||||
|
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
|
||||||
|
|
||||||
|
result = lead_agent_module.make_lead_agent(
|
||||||
|
{
|
||||||
|
"configurable": {
|
||||||
|
"model_name": "safe-model",
|
||||||
|
"thinking_enabled": True,
|
||||||
|
"is_plan_mode": False,
|
||||||
|
"subagent_enabled": False,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert captured["name"] == "safe-model"
|
||||||
|
assert captured["thinking_enabled"] is False
|
||||||
|
assert result["model"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
|
||||||
|
app_config = _make_app_config(
|
||||||
|
[
|
||||||
|
_make_model("stale-model", supports_thinking=False),
|
||||||
|
ModelConfig(
|
||||||
|
name="vision-model",
|
||||||
|
display_name="vision-model",
|
||||||
|
description=None,
|
||||||
|
use="langchain_openai:ChatOpenAI",
|
||||||
|
model="vision-model",
|
||||||
|
supports_thinking=False,
|
||||||
|
supports_vision=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||||
|
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda: None)
|
||||||
|
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
|
||||||
|
|
||||||
|
middlewares = lead_agent_module._build_middlewares(
|
||||||
|
{"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}},
|
||||||
|
model_name="vision-model",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)
|
||||||
@@ -12,7 +12,13 @@ import {
|
|||||||
ZapIcon,
|
ZapIcon,
|
||||||
} from "lucide-react";
|
} from "lucide-react";
|
||||||
import { useSearchParams } from "next/navigation";
|
import { useSearchParams } from "next/navigation";
|
||||||
import { useCallback, useMemo, useState, type ComponentProps } from "react";
|
import {
|
||||||
|
useCallback,
|
||||||
|
useEffect,
|
||||||
|
useMemo,
|
||||||
|
useState,
|
||||||
|
type ComponentProps,
|
||||||
|
} from "react";
|
||||||
|
|
||||||
import {
|
import {
|
||||||
PromptInput,
|
PromptInput,
|
||||||
@@ -63,6 +69,21 @@ import {
|
|||||||
import { ModeHoverGuide } from "./mode-hover-guide";
|
import { ModeHoverGuide } from "./mode-hover-guide";
|
||||||
import { Tooltip } from "./tooltip";
|
import { Tooltip } from "./tooltip";
|
||||||
|
|
||||||
|
type InputMode = "flash" | "thinking" | "pro" | "ultra";
|
||||||
|
|
||||||
|
function getResolvedMode(
|
||||||
|
mode: InputMode | undefined,
|
||||||
|
supportsThinking: boolean,
|
||||||
|
): InputMode {
|
||||||
|
if (!supportsThinking && mode !== "flash") {
|
||||||
|
return "flash";
|
||||||
|
}
|
||||||
|
if (mode) {
|
||||||
|
return mode;
|
||||||
|
}
|
||||||
|
return supportsThinking ? "pro" : "flash";
|
||||||
|
}
|
||||||
|
|
||||||
export function InputBox({
|
export function InputBox({
|
||||||
className,
|
className,
|
||||||
disabled,
|
disabled,
|
||||||
@@ -104,42 +125,64 @@ export function InputBox({
|
|||||||
const searchParams = useSearchParams();
|
const searchParams = useSearchParams();
|
||||||
const [modelDialogOpen, setModelDialogOpen] = useState(false);
|
const [modelDialogOpen, setModelDialogOpen] = useState(false);
|
||||||
const { models } = useModels();
|
const { models } = useModels();
|
||||||
const selectedModel = useMemo(() => {
|
|
||||||
if (!context.model_name && models.length > 0) {
|
useEffect(() => {
|
||||||
const model = models[0]!;
|
if (models.length === 0) {
|
||||||
setTimeout(() => {
|
return;
|
||||||
onContextChange?.({
|
|
||||||
...context,
|
|
||||||
model_name: model.name,
|
|
||||||
mode: model.supports_thinking ? "pro" : "flash",
|
|
||||||
});
|
|
||||||
}, 0);
|
|
||||||
return model;
|
|
||||||
}
|
}
|
||||||
return models.find((m) => m.name === context.model_name);
|
const currentModel = models.find((m) => m.name === context.model_name);
|
||||||
|
const fallbackModel = currentModel ?? models[0]!;
|
||||||
|
const supportsThinking = fallbackModel.supports_thinking ?? false;
|
||||||
|
const nextModelName = fallbackModel.name;
|
||||||
|
const nextMode = getResolvedMode(context.mode, supportsThinking);
|
||||||
|
|
||||||
|
if (context.model_name === nextModelName && context.mode === nextMode) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
onContextChange?.({
|
||||||
|
...context,
|
||||||
|
model_name: nextModelName,
|
||||||
|
mode: nextMode,
|
||||||
|
});
|
||||||
}, [context, models, onContextChange]);
|
}, [context, models, onContextChange]);
|
||||||
|
|
||||||
|
const selectedModel = useMemo(() => {
|
||||||
|
if (models.length === 0) {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
return models.find((m) => m.name === context.model_name) ?? models[0];
|
||||||
|
}, [context.model_name, models]);
|
||||||
|
|
||||||
const supportThinking = useMemo(
|
const supportThinking = useMemo(
|
||||||
() => selectedModel?.supports_thinking ?? false,
|
() => selectedModel?.supports_thinking ?? false,
|
||||||
[selectedModel],
|
[selectedModel],
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleModelSelect = useCallback(
|
const handleModelSelect = useCallback(
|
||||||
(model_name: string) => {
|
(model_name: string) => {
|
||||||
|
const model = models.find((m) => m.name === model_name);
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
onContextChange?.({
|
onContextChange?.({
|
||||||
...context,
|
...context,
|
||||||
model_name,
|
model_name,
|
||||||
|
mode: getResolvedMode(context.mode, model.supports_thinking ?? false),
|
||||||
});
|
});
|
||||||
setModelDialogOpen(false);
|
setModelDialogOpen(false);
|
||||||
},
|
},
|
||||||
[onContextChange, context],
|
[onContextChange, context, models],
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleModeSelect = useCallback(
|
const handleModeSelect = useCallback(
|
||||||
(mode: "flash" | "thinking" | "pro" | "ultra") => {
|
(mode: InputMode) => {
|
||||||
onContextChange?.({
|
onContextChange?.({
|
||||||
...context,
|
...context,
|
||||||
mode,
|
mode: getResolvedMode(mode, supportThinking),
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
[onContextChange, context],
|
[onContextChange, context, supportThinking],
|
||||||
);
|
);
|
||||||
const handleSubmit = useCallback(
|
const handleSubmit = useCallback(
|
||||||
async (message: PromptInputMessage) => {
|
async (message: PromptInputMessage) => {
|
||||||
|
|||||||
Reference in New Issue
Block a user