mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-20 12:54:45 +08:00
Support langgraph checkpointer (#1005)
* Add checkpointer configuration to config.example.yaml - Introduced a new section for checkpointer configuration to enable state persistence for the embedded DeerFlowClient. - Documented supported types: memory, sqlite, and postgres, along with examples for each. - Clarified that the LangGraph Server manages its own state persistence separately. * refactor(checkpointer): streamline checkpointer initialization and logging * fix(uv.lock): update revision and add new wheel URLs for brotlicffi package * feat: add langchain-anthropic dependency and update related configurations * Fix checkpointer lifecycle, docstring, and path resolution bugs from PR #1005 review (#4) * Initial plan * Address all review suggestions from PR #1005 Co-authored-by: foreleven <4785594+foreleven@users.noreply.github.com> * Fix resolve_path to always return real Path; move SQLite special-string handling to callers Co-authored-by: foreleven <4785594+foreleven@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: foreleven <4785594+foreleven@users.noreply.github.com> --------- Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: foreleven <4785594+foreleven@users.noreply.github.com>
This commit is contained in:
255
backend/tests/test_checkpointer.py
Normal file
255
backend/tests/test_checkpointer.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""Unit tests for checkpointer config and singleton factory."""
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.agents.checkpointer import get_checkpointer, reset_checkpointer
|
||||
from src.config.checkpointer_config import (
|
||||
CheckpointerConfig,
|
||||
get_checkpointer_config,
|
||||
load_checkpointer_config_from_dict,
|
||||
set_checkpointer_config,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset singleton state before each test."""
|
||||
set_checkpointer_config(None)
|
||||
reset_checkpointer()
|
||||
yield
|
||||
set_checkpointer_config(None)
|
||||
reset_checkpointer()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckpointerConfig:
|
||||
def test_load_memory_config(self):
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
config = get_checkpointer_config()
|
||||
assert config is not None
|
||||
assert config.type == "memory"
|
||||
assert config.connection_string is None
|
||||
|
||||
def test_load_sqlite_config(self):
|
||||
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
|
||||
config = get_checkpointer_config()
|
||||
assert config is not None
|
||||
assert config.type == "sqlite"
|
||||
assert config.connection_string == "/tmp/test.db"
|
||||
|
||||
def test_load_postgres_config(self):
|
||||
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
|
||||
config = get_checkpointer_config()
|
||||
assert config is not None
|
||||
assert config.type == "postgres"
|
||||
assert config.connection_string == "postgresql://localhost/db"
|
||||
|
||||
def test_default_connection_string_is_none(self):
|
||||
config = CheckpointerConfig(type="memory")
|
||||
assert config.connection_string is None
|
||||
|
||||
def test_set_config_to_none(self):
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
set_checkpointer_config(None)
|
||||
assert get_checkpointer_config() is None
|
||||
|
||||
def test_invalid_type_raises(self):
|
||||
with pytest.raises(Exception):
|
||||
load_checkpointer_config_from_dict({"type": "unknown"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Factory tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetCheckpointer:
|
||||
def test_returns_none_when_not_configured(self):
|
||||
assert get_checkpointer() is None
|
||||
|
||||
def test_memory_returns_in_memory_saver(self):
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
cp = get_checkpointer()
|
||||
assert isinstance(cp, InMemorySaver)
|
||||
|
||||
def test_memory_singleton(self):
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
cp1 = get_checkpointer()
|
||||
cp2 = get_checkpointer()
|
||||
assert cp1 is cp2
|
||||
|
||||
def test_reset_clears_singleton(self):
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
cp1 = get_checkpointer()
|
||||
reset_checkpointer()
|
||||
cp2 = get_checkpointer()
|
||||
assert cp1 is not cp2
|
||||
|
||||
def test_sqlite_raises_when_package_missing(self):
|
||||
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
|
||||
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": None}):
|
||||
reset_checkpointer()
|
||||
with pytest.raises(ImportError, match="langgraph-checkpoint-sqlite"):
|
||||
get_checkpointer()
|
||||
|
||||
def test_postgres_raises_when_package_missing(self):
|
||||
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
|
||||
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": None}):
|
||||
reset_checkpointer()
|
||||
with pytest.raises(ImportError, match="langgraph-checkpoint-postgres"):
|
||||
get_checkpointer()
|
||||
|
||||
def test_postgres_raises_when_connection_string_missing(self):
|
||||
load_checkpointer_config_from_dict({"type": "postgres"})
|
||||
mock_saver = MagicMock()
|
||||
mock_module = MagicMock()
|
||||
mock_module.PostgresSaver = mock_saver
|
||||
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_module}):
|
||||
reset_checkpointer()
|
||||
with pytest.raises(ValueError, match="connection_string is required"):
|
||||
get_checkpointer()
|
||||
|
||||
def test_sqlite_creates_saver(self):
|
||||
"""SQLite checkpointer is created when package is available."""
|
||||
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
|
||||
|
||||
mock_saver_instance = MagicMock()
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance)
|
||||
mock_cm.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_saver_cls = MagicMock()
|
||||
mock_saver_cls.from_conn_string = MagicMock(return_value=mock_cm)
|
||||
|
||||
mock_module = MagicMock()
|
||||
mock_module.SqliteSaver = mock_saver_cls
|
||||
|
||||
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}):
|
||||
reset_checkpointer()
|
||||
cp = get_checkpointer()
|
||||
|
||||
assert cp is mock_saver_instance
|
||||
mock_saver_cls.from_conn_string.assert_called_once()
|
||||
mock_saver_instance.setup.assert_called_once()
|
||||
|
||||
def test_postgres_creates_saver(self):
|
||||
"""Postgres checkpointer is created when packages are available."""
|
||||
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
|
||||
|
||||
mock_saver_instance = MagicMock()
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance)
|
||||
mock_cm.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_saver_cls = MagicMock()
|
||||
mock_saver_cls.from_conn_string = MagicMock(return_value=mock_cm)
|
||||
|
||||
mock_pg_module = MagicMock()
|
||||
mock_pg_module.PostgresSaver = mock_saver_cls
|
||||
|
||||
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_pg_module}):
|
||||
reset_checkpointer()
|
||||
cp = get_checkpointer()
|
||||
|
||||
assert cp is mock_saver_instance
|
||||
mock_saver_cls.from_conn_string.assert_called_once_with("postgresql://localhost/db")
|
||||
mock_saver_instance.setup.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# app_config.py integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAppConfigLoadsCheckpointer:
|
||||
def test_load_checkpointer_section(self):
|
||||
"""load_checkpointer_config_from_dict populates the global config."""
|
||||
set_checkpointer_config(None)
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
cfg = get_checkpointer_config()
|
||||
assert cfg is not None
|
||||
assert cfg.type == "memory"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DeerFlowClient falls back to config checkpointer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClientCheckpointerFallback:
|
||||
def test_client_uses_config_checkpointer_when_none_provided(self):
|
||||
"""DeerFlowClient._ensure_agent falls back to get_checkpointer() when checkpointer=None."""
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from src.client import DeerFlowClient
|
||||
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
|
||||
captured_kwargs = {}
|
||||
|
||||
def fake_create_agent(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return MagicMock()
|
||||
|
||||
model_mock = MagicMock()
|
||||
config_mock = MagicMock()
|
||||
config_mock.models = [model_mock]
|
||||
config_mock.get_model_config.return_value = MagicMock(supports_vision=False)
|
||||
config_mock.checkpointer = None
|
||||
|
||||
with (
|
||||
patch("src.client.get_app_config", return_value=config_mock),
|
||||
patch("src.client.create_agent", side_effect=fake_create_agent),
|
||||
patch("src.client.create_chat_model", return_value=MagicMock()),
|
||||
patch("src.client._build_middlewares", return_value=[]),
|
||||
patch("src.client.apply_prompt_template", return_value=""),
|
||||
patch("src.client.DeerFlowClient._get_tools", return_value=[]),
|
||||
):
|
||||
client = DeerFlowClient(checkpointer=None)
|
||||
config = client._get_runnable_config("test-thread")
|
||||
client._ensure_agent(config)
|
||||
|
||||
assert "checkpointer" in captured_kwargs
|
||||
assert isinstance(captured_kwargs["checkpointer"], InMemorySaver)
|
||||
|
||||
def test_client_explicit_checkpointer_takes_precedence(self):
|
||||
"""An explicitly provided checkpointer is used even when config checkpointer is set."""
|
||||
from src.client import DeerFlowClient
|
||||
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
|
||||
explicit_cp = MagicMock()
|
||||
captured_kwargs = {}
|
||||
|
||||
def fake_create_agent(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return MagicMock()
|
||||
|
||||
model_mock = MagicMock()
|
||||
config_mock = MagicMock()
|
||||
config_mock.models = [model_mock]
|
||||
config_mock.get_model_config.return_value = MagicMock(supports_vision=False)
|
||||
config_mock.checkpointer = None
|
||||
|
||||
with (
|
||||
patch("src.client.get_app_config", return_value=config_mock),
|
||||
patch("src.client.create_agent", side_effect=fake_create_agent),
|
||||
patch("src.client.create_chat_model", return_value=MagicMock()),
|
||||
patch("src.client._build_middlewares", return_value=[]),
|
||||
patch("src.client.apply_prompt_template", return_value=""),
|
||||
patch("src.client.DeerFlowClient._get_tools", return_value=[]),
|
||||
):
|
||||
client = DeerFlowClient(checkpointer=explicit_cp)
|
||||
config = client._get_runnable_config("test-thread")
|
||||
client._ensure_agent(config)
|
||||
|
||||
assert captured_kwargs["checkpointer"] is explicit_cp
|
||||
@@ -28,6 +28,7 @@ if _skip_reason:
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client():
|
||||
"""Create a real DeerFlowClient (no mocks)."""
|
||||
@@ -38,6 +39,7 @@ def client():
|
||||
def thread_tmp(tmp_path):
|
||||
"""Provide a unique thread_id + tmp directory for file operations."""
|
||||
import uuid
|
||||
|
||||
tid = f"live-test-{uuid.uuid4().hex[:8]}"
|
||||
return tid, tmp_path
|
||||
|
||||
@@ -46,6 +48,7 @@ def thread_tmp(tmp_path):
|
||||
# Scenario 1: Basic chat — model responds coherently
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestLiveBasicChat:
|
||||
def test_chat_returns_nonempty_string(self, client):
|
||||
"""chat() returns a non-empty response from the real model."""
|
||||
@@ -65,6 +68,7 @@ class TestLiveBasicChat:
|
||||
# Scenario 2: Streaming — events arrive in correct order
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestLiveStreaming:
|
||||
def test_stream_yields_messages_tuple_and_end(self, client):
|
||||
"""stream() produces at least one messages-tuple event and ends with end."""
|
||||
@@ -81,10 +85,7 @@ class TestLiveStreaming:
|
||||
|
||||
def test_stream_ai_content_nonempty(self, client):
|
||||
"""Streamed messages-tuple AI events contain non-empty content."""
|
||||
ai_messages = [
|
||||
e for e in client.stream("What color is the sky? One word.")
|
||||
if e.type == "messages-tuple" and e.data.get("type") == "ai" and e.data.get("content")
|
||||
]
|
||||
ai_messages = [e for e in client.stream("What color is the sky? One word.") if e.type == "messages-tuple" and e.data.get("type") == "ai" and e.data.get("content")]
|
||||
assert len(ai_messages) >= 1
|
||||
for m in ai_messages:
|
||||
assert len(m.data.get("content", "")) > 0
|
||||
@@ -94,13 +95,11 @@ class TestLiveStreaming:
|
||||
# Scenario 3: Tool use — agent calls a tool and returns result
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestLiveToolUse:
|
||||
def test_agent_uses_bash_tool(self, client):
|
||||
"""Agent uses bash tool when asked to run a command."""
|
||||
events = list(client.stream(
|
||||
"Use the bash tool to run: echo 'LIVE_TEST_OK'. "
|
||||
"Then tell me the output."
|
||||
))
|
||||
events = list(client.stream("Use the bash tool to run: echo 'LIVE_TEST_OK'. Then tell me the output."))
|
||||
|
||||
types = [e.type for e in events]
|
||||
print(f" event types: {types}")
|
||||
@@ -122,10 +121,7 @@ class TestLiveToolUse:
|
||||
|
||||
def test_agent_uses_ls_tool(self, client):
|
||||
"""Agent uses ls tool to list a directory."""
|
||||
events = list(client.stream(
|
||||
"Use the ls tool to list the contents of /mnt/user-data/workspace. "
|
||||
"Just report what you see."
|
||||
))
|
||||
events = list(client.stream("Use the ls tool to list the contents of /mnt/user-data/workspace. Just report what you see."))
|
||||
|
||||
types = [e.type for e in events]
|
||||
print(f" event types: {types}")
|
||||
@@ -139,15 +135,11 @@ class TestLiveToolUse:
|
||||
# Scenario 4: Multi-tool chain — agent chains tools in sequence
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestLiveMultiToolChain:
|
||||
def test_write_then_read(self, client):
|
||||
"""Agent writes a file, then reads it back."""
|
||||
events = list(client.stream(
|
||||
"Step 1: Use write_file to write 'integration_test_content' to "
|
||||
"/mnt/user-data/outputs/live_test.txt. "
|
||||
"Step 2: Use read_file to read that file back. "
|
||||
"Step 3: Tell me the content you read."
|
||||
))
|
||||
events = list(client.stream("Step 1: Use write_file to write 'integration_test_content' to /mnt/user-data/outputs/live_test.txt. Step 2: Use read_file to read that file back. Step 3: Tell me the content you read."))
|
||||
|
||||
types = [e.type for e in events]
|
||||
print(f" event types: {types}")
|
||||
@@ -164,16 +156,14 @@ class TestLiveMultiToolChain:
|
||||
ai_events = [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "ai" and e.data.get("content")]
|
||||
tr_events = [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "tool"]
|
||||
final_text = ai_events[-1].data["content"] if ai_events else ""
|
||||
assert "integration_test_content" in final_text.lower() or any(
|
||||
"integration_test_content" in e.data.get("content", "")
|
||||
for e in tr_events
|
||||
)
|
||||
assert "integration_test_content" in final_text.lower() or any("integration_test_content" in e.data.get("content", "") for e in tr_events)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Scenario 5: File upload lifecycle with real filesystem
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestLiveFileUpload:
|
||||
def test_upload_list_delete(self, client, thread_tmp):
|
||||
"""Upload → list → delete → verify deletion."""
|
||||
@@ -225,6 +215,7 @@ class TestLiveFileUpload:
|
||||
# Scenario 6: Configuration query — real config loading
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestLiveConfigQueries:
|
||||
def test_list_models_returns_configured_model(self, client):
|
||||
"""list_models() returns at least one configured model with Gateway-aligned fields."""
|
||||
@@ -266,25 +257,25 @@ class TestLiveConfigQueries:
|
||||
# Scenario 7: Artifact read after agent writes
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestLiveArtifact:
|
||||
def test_get_artifact_after_write(self, client):
|
||||
"""Agent writes a file → client reads it back via get_artifact()."""
|
||||
import uuid
|
||||
|
||||
thread_id = f"live-artifact-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Ask agent to write a file
|
||||
events = list(client.stream(
|
||||
"Use write_file to create /mnt/user-data/outputs/artifact_test.json "
|
||||
"with content: {\"status\": \"ok\", \"source\": \"live_test\"}",
|
||||
thread_id=thread_id,
|
||||
))
|
||||
events = list(
|
||||
client.stream(
|
||||
'Use write_file to create /mnt/user-data/outputs/artifact_test.json with content: {"status": "ok", "source": "live_test"}',
|
||||
thread_id=thread_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Verify write happened
|
||||
tc_events = [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "ai" and "tool_calls" in e.data]
|
||||
assert any(
|
||||
any(tc["name"] == "write_file" for tc in e.data["tool_calls"])
|
||||
for e in tc_events
|
||||
)
|
||||
assert any(any(tc["name"] == "write_file" for tc in e.data["tool_calls"]) for e in tc_events)
|
||||
|
||||
# Read artifact
|
||||
content, mime = client.get_artifact(thread_id, "mnt/user-data/outputs/artifact_test.json")
|
||||
@@ -303,11 +294,13 @@ class TestLiveArtifact:
|
||||
# Scenario 8: Per-call overrides
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestLiveOverrides:
|
||||
def test_thinking_disabled_still_works(self, client):
|
||||
"""Explicit thinking_enabled=False override produces a response."""
|
||||
response = client.chat(
|
||||
"Say OK.", thinking_enabled=False,
|
||||
"Say OK.",
|
||||
thinking_enabled=False,
|
||||
)
|
||||
assert len(response) > 0
|
||||
print(f" response: {response}")
|
||||
@@ -317,6 +310,7 @@ class TestLiveOverrides:
|
||||
# Scenario 9: Error resilience
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestLiveErrorResilience:
|
||||
def test_delete_nonexistent_upload(self, client):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
|
||||
Reference in New Issue
Block a user