Support langgraph checkpointer (#1005)

* Add checkpointer configuration to config.example.yaml

- Introduced a new section for checkpointer configuration to enable state persistence for the embedded DeerFlowClient.
- Documented supported types: memory, sqlite, and postgres, along with examples for each.
- Clarified that the LangGraph Server manages its own state persistence separately.

* refactor(checkpointer): streamline checkpointer initialization and logging

* fix(uv.lock): update revision and add new wheel URLs for brotlicffi package

* feat: add langchain-anthropic dependency and update related configurations

* Fix checkpointer lifecycle, docstring, and path resolution bugs from PR #1005 review (#4)

* Initial plan

* Address all review suggestions from PR #1005

Co-authored-by: foreleven <4785594+foreleven@users.noreply.github.com>

* Fix resolve_path to always return real Path; move SQLite special-string handling to callers

Co-authored-by: foreleven <4785594+foreleven@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: foreleven <4785594+foreleven@users.noreply.github.com>

---------

Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: foreleven <4785594+foreleven@users.noreply.github.com>
This commit is contained in:
JeffJiang
2026-03-07 21:07:21 +08:00
committed by GitHub
parent 09325ca28f
commit d664ae5a4b
14 changed files with 819 additions and 84 deletions

View File

@@ -0,0 +1,255 @@
"""Unit tests for checkpointer config and singleton factory."""
import sys
from unittest.mock import MagicMock, patch
import pytest
from src.agents.checkpointer import get_checkpointer, reset_checkpointer
from src.config.checkpointer_config import (
CheckpointerConfig,
get_checkpointer_config,
load_checkpointer_config_from_dict,
set_checkpointer_config,
)
@pytest.fixture(autouse=True)
def reset_state():
    """Clear the checkpointer config and singleton around every test.

    Runs automatically for each test: both pieces of global state are
    cleared, control is handed to the test, then both are cleared again.
    """

    def _wipe():
        # Drop the stored config first, then the checkpointer instance.
        set_checkpointer_config(None)
        reset_checkpointer()

    _wipe()
    yield
    _wipe()
# ---------------------------------------------------------------------------
# Config tests
# ---------------------------------------------------------------------------
class TestCheckpointerConfig:
    """Loading, reading back, and clearing the global checkpointer config."""

    @staticmethod
    def _load_and_fetch(raw):
        """Load *raw* into the global config and return what the getter reports."""
        load_checkpointer_config_from_dict(raw)
        return get_checkpointer_config()

    def test_load_memory_config(self):
        loaded = self._load_and_fetch({"type": "memory"})
        assert loaded is not None
        assert loaded.type == "memory"
        assert loaded.connection_string is None

    def test_load_sqlite_config(self):
        loaded = self._load_and_fetch({"type": "sqlite", "connection_string": "/tmp/test.db"})
        assert loaded is not None
        assert loaded.type == "sqlite"
        assert loaded.connection_string == "/tmp/test.db"

    def test_load_postgres_config(self):
        loaded = self._load_and_fetch({"type": "postgres", "connection_string": "postgresql://localhost/db"})
        assert loaded is not None
        assert loaded.type == "postgres"
        assert loaded.connection_string == "postgresql://localhost/db"

    def test_default_connection_string_is_none(self):
        # Constructed directly, bypassing the global loader.
        direct = CheckpointerConfig(type="memory")
        assert direct.connection_string is None

    def test_set_config_to_none(self):
        load_checkpointer_config_from_dict({"type": "memory"})
        set_checkpointer_config(None)
        assert get_checkpointer_config() is None

    def test_invalid_type_raises(self):
        # NOTE(review): Exception is deliberately kept broad here; narrow it
        # once the concrete validation error type is confirmed.
        with pytest.raises(Exception):
            load_checkpointer_config_from_dict({"type": "unknown"})
# ---------------------------------------------------------------------------
# Factory tests
# ---------------------------------------------------------------------------
class TestGetCheckpointer:
    """Behaviour of the get_checkpointer() factory for each configured type."""

    @staticmethod
    def _fake_backend(saver_attr):
        """Build a stand-in backend module exposing a saver class as *saver_attr*.

        The saver class's ``from_conn_string`` returns a context manager whose
        ``__enter__`` yields the saver instance.

        Returns:
            A ``(module, saver_class, saver_instance)`` tuple of MagicMocks.
        """
        saver = MagicMock()
        ctx = MagicMock()
        ctx.__enter__ = MagicMock(return_value=saver)
        ctx.__exit__ = MagicMock(return_value=False)
        saver_cls = MagicMock()
        saver_cls.from_conn_string = MagicMock(return_value=ctx)
        module = MagicMock()
        setattr(module, saver_attr, saver_cls)
        return module, saver_cls, saver

    def test_returns_none_when_not_configured(self):
        assert get_checkpointer() is None

    def test_memory_returns_in_memory_saver(self):
        load_checkpointer_config_from_dict({"type": "memory"})
        from langgraph.checkpoint.memory import InMemorySaver

        assert isinstance(get_checkpointer(), InMemorySaver)

    def test_memory_singleton(self):
        load_checkpointer_config_from_dict({"type": "memory"})
        first = get_checkpointer()
        second = get_checkpointer()
        assert first is second

    def test_reset_clears_singleton(self):
        load_checkpointer_config_from_dict({"type": "memory"})
        before = get_checkpointer()
        reset_checkpointer()
        after = get_checkpointer()
        assert before is not after

    def test_sqlite_raises_when_package_missing(self):
        load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
        # A None entry in sys.modules makes the import raise ImportError.
        hidden = {"langgraph.checkpoint.sqlite": None}
        with patch.dict(sys.modules, hidden):
            reset_checkpointer()
            with pytest.raises(ImportError, match="langgraph-checkpoint-sqlite"):
                get_checkpointer()

    def test_postgres_raises_when_package_missing(self):
        load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
        # A None entry in sys.modules makes the import raise ImportError.
        hidden = {"langgraph.checkpoint.postgres": None}
        with patch.dict(sys.modules, hidden):
            reset_checkpointer()
            with pytest.raises(ImportError, match="langgraph-checkpoint-postgres"):
                get_checkpointer()

    def test_postgres_raises_when_connection_string_missing(self):
        load_checkpointer_config_from_dict({"type": "postgres"})
        # The package is made importable so the failure must come from the
        # missing connection string rather than from the import.
        pg_module = MagicMock()
        pg_module.PostgresSaver = MagicMock()
        with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": pg_module}):
            reset_checkpointer()
            with pytest.raises(ValueError, match="connection_string is required"):
                get_checkpointer()

    def test_sqlite_creates_saver(self):
        """SQLite checkpointer is created when package is available."""
        load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
        module, saver_cls, saver = self._fake_backend("SqliteSaver")

        with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": module}):
            reset_checkpointer()
            result = get_checkpointer()

        assert result is saver
        saver_cls.from_conn_string.assert_called_once()
        saver.setup.assert_called_once()

    def test_postgres_creates_saver(self):
        """Postgres checkpointer is created when packages are available."""
        load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
        module, saver_cls, saver = self._fake_backend("PostgresSaver")

        with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": module}):
            reset_checkpointer()
            result = get_checkpointer()

        assert result is saver
        saver_cls.from_conn_string.assert_called_once_with("postgresql://localhost/db")
        saver.setup.assert_called_once()
# ---------------------------------------------------------------------------
# app_config.py integration
# ---------------------------------------------------------------------------
class TestAppConfigLoadsCheckpointer:
    """Loading a checkpointer section populates the global config."""

    def test_load_checkpointer_section(self):
        """load_checkpointer_config_from_dict populates the global config."""
        # Start from an explicitly empty config.
        set_checkpointer_config(None)

        section = {"type": "memory"}
        load_checkpointer_config_from_dict(section)

        populated = get_checkpointer_config()
        assert populated is not None
        assert populated.type == "memory"
# ---------------------------------------------------------------------------
# DeerFlowClient falls back to config checkpointer
# ---------------------------------------------------------------------------
class TestClientCheckpointerFallback:
    """Which checkpointer DeerFlowClient passes on to create_agent."""

    @staticmethod
    def _agent_kwargs_for(checkpointer):
        """Run ``_ensure_agent`` on a patched client and return create_agent's kwargs.

        A "memory" checkpointer config is loaded first, then a DeerFlowClient
        is built with *checkpointer* while its collaborators in ``src.client``
        are patched out. The keyword arguments received by the patched
        ``create_agent`` are collected and returned as a dict.
        """
        from src.client import DeerFlowClient

        load_checkpointer_config_from_dict({"type": "memory"})

        seen = {}

        def record_create_agent(**kwargs):
            seen.update(kwargs)
            return MagicMock()

        app_config = MagicMock()
        app_config.models = [MagicMock()]
        app_config.get_model_config.return_value = MagicMock(supports_vision=False)
        app_config.checkpointer = None

        with (
            patch("src.client.get_app_config", return_value=app_config),
            patch("src.client.create_agent", side_effect=record_create_agent),
            patch("src.client.create_chat_model", return_value=MagicMock()),
            patch("src.client._build_middlewares", return_value=[]),
            patch("src.client.apply_prompt_template", return_value=""),
            patch("src.client.DeerFlowClient._get_tools", return_value=[]),
        ):
            client = DeerFlowClient(checkpointer=checkpointer)
            runnable_config = client._get_runnable_config("test-thread")
            client._ensure_agent(runnable_config)

        return seen

    def test_client_uses_config_checkpointer_when_none_provided(self):
        """DeerFlowClient._ensure_agent falls back to get_checkpointer() when checkpointer=None."""
        from langgraph.checkpoint.memory import InMemorySaver

        passed = self._agent_kwargs_for(None)

        assert "checkpointer" in passed
        assert isinstance(passed["checkpointer"], InMemorySaver)

    def test_client_explicit_checkpointer_takes_precedence(self):
        """An explicitly provided checkpointer is used even when config checkpointer is set."""
        explicit_cp = MagicMock()

        passed = self._agent_kwargs_for(explicit_cp)

        assert passed["checkpointer"] is explicit_cp

View File

@@ -28,6 +28,7 @@ if _skip_reason:
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
def client():
"""Create a real DeerFlowClient (no mocks)."""
@@ -38,6 +39,7 @@ def client():
def thread_tmp(tmp_path):
"""Provide a unique thread_id + tmp directory for file operations."""
import uuid
tid = f"live-test-{uuid.uuid4().hex[:8]}"
return tid, tmp_path
@@ -46,6 +48,7 @@ def thread_tmp(tmp_path):
# Scenario 1: Basic chat — model responds coherently
# ===========================================================================
class TestLiveBasicChat:
def test_chat_returns_nonempty_string(self, client):
"""chat() returns a non-empty response from the real model."""
@@ -65,6 +68,7 @@ class TestLiveBasicChat:
# Scenario 2: Streaming — events arrive in correct order
# ===========================================================================
class TestLiveStreaming:
def test_stream_yields_messages_tuple_and_end(self, client):
"""stream() produces at least one messages-tuple event and ends with end."""
@@ -81,10 +85,7 @@ class TestLiveStreaming:
def test_stream_ai_content_nonempty(self, client):
"""Streamed messages-tuple AI events contain non-empty content."""
ai_messages = [
e for e in client.stream("What color is the sky? One word.")
if e.type == "messages-tuple" and e.data.get("type") == "ai" and e.data.get("content")
]
ai_messages = [e for e in client.stream("What color is the sky? One word.") if e.type == "messages-tuple" and e.data.get("type") == "ai" and e.data.get("content")]
assert len(ai_messages) >= 1
for m in ai_messages:
assert len(m.data.get("content", "")) > 0
@@ -94,13 +95,11 @@ class TestLiveStreaming:
# Scenario 3: Tool use — agent calls a tool and returns result
# ===========================================================================
class TestLiveToolUse:
def test_agent_uses_bash_tool(self, client):
"""Agent uses bash tool when asked to run a command."""
events = list(client.stream(
"Use the bash tool to run: echo 'LIVE_TEST_OK'. "
"Then tell me the output."
))
events = list(client.stream("Use the bash tool to run: echo 'LIVE_TEST_OK'. Then tell me the output."))
types = [e.type for e in events]
print(f" event types: {types}")
@@ -122,10 +121,7 @@ class TestLiveToolUse:
def test_agent_uses_ls_tool(self, client):
"""Agent uses ls tool to list a directory."""
events = list(client.stream(
"Use the ls tool to list the contents of /mnt/user-data/workspace. "
"Just report what you see."
))
events = list(client.stream("Use the ls tool to list the contents of /mnt/user-data/workspace. Just report what you see."))
types = [e.type for e in events]
print(f" event types: {types}")
@@ -139,15 +135,11 @@ class TestLiveToolUse:
# Scenario 4: Multi-tool chain — agent chains tools in sequence
# ===========================================================================
class TestLiveMultiToolChain:
def test_write_then_read(self, client):
"""Agent writes a file, then reads it back."""
events = list(client.stream(
"Step 1: Use write_file to write 'integration_test_content' to "
"/mnt/user-data/outputs/live_test.txt. "
"Step 2: Use read_file to read that file back. "
"Step 3: Tell me the content you read."
))
events = list(client.stream("Step 1: Use write_file to write 'integration_test_content' to /mnt/user-data/outputs/live_test.txt. Step 2: Use read_file to read that file back. Step 3: Tell me the content you read."))
types = [e.type for e in events]
print(f" event types: {types}")
@@ -164,16 +156,14 @@ class TestLiveMultiToolChain:
ai_events = [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "ai" and e.data.get("content")]
tr_events = [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "tool"]
final_text = ai_events[-1].data["content"] if ai_events else ""
assert "integration_test_content" in final_text.lower() or any(
"integration_test_content" in e.data.get("content", "")
for e in tr_events
)
assert "integration_test_content" in final_text.lower() or any("integration_test_content" in e.data.get("content", "") for e in tr_events)
# ===========================================================================
# Scenario 5: File upload lifecycle with real filesystem
# ===========================================================================
class TestLiveFileUpload:
def test_upload_list_delete(self, client, thread_tmp):
"""Upload → list → delete → verify deletion."""
@@ -225,6 +215,7 @@ class TestLiveFileUpload:
# Scenario 6: Configuration query — real config loading
# ===========================================================================
class TestLiveConfigQueries:
def test_list_models_returns_configured_model(self, client):
"""list_models() returns at least one configured model with Gateway-aligned fields."""
@@ -266,25 +257,25 @@ class TestLiveConfigQueries:
# Scenario 7: Artifact read after agent writes
# ===========================================================================
class TestLiveArtifact:
def test_get_artifact_after_write(self, client):
"""Agent writes a file → client reads it back via get_artifact()."""
import uuid
thread_id = f"live-artifact-{uuid.uuid4().hex[:8]}"
# Ask agent to write a file
events = list(client.stream(
"Use write_file to create /mnt/user-data/outputs/artifact_test.json "
"with content: {\"status\": \"ok\", \"source\": \"live_test\"}",
thread_id=thread_id,
))
events = list(
client.stream(
'Use write_file to create /mnt/user-data/outputs/artifact_test.json with content: {"status": "ok", "source": "live_test"}',
thread_id=thread_id,
)
)
# Verify write happened
tc_events = [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "ai" and "tool_calls" in e.data]
assert any(
any(tc["name"] == "write_file" for tc in e.data["tool_calls"])
for e in tc_events
)
assert any(any(tc["name"] == "write_file" for tc in e.data["tool_calls"]) for e in tc_events)
# Read artifact
content, mime = client.get_artifact(thread_id, "mnt/user-data/outputs/artifact_test.json")
@@ -303,11 +294,13 @@ class TestLiveArtifact:
# Scenario 8: Per-call overrides
# ===========================================================================
class TestLiveOverrides:
def test_thinking_disabled_still_works(self, client):
"""Explicit thinking_enabled=False override produces a response."""
response = client.chat(
"Say OK.", thinking_enabled=False,
"Say OK.",
thinking_enabled=False,
)
assert len(response) > 0
print(f" response: {response}")
@@ -317,6 +310,7 @@ class TestLiveOverrides:
# Scenario 9: Error resilience
# ===========================================================================
class TestLiveErrorResilience:
def test_delete_nonexistent_upload(self, client):
with pytest.raises(FileNotFoundError):