mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-09 16:54:46 +08:00
* feat: Enhance chat streaming and tool call processing - Added support for MongoDB checkpointer in the chat streaming workflow. - Introduced functions to process tool call chunks and sanitize arguments. - Improved event message creation with additional metadata. - Enhanced error handling for JSON serialization in event messages. - Updated the frontend to convert escaped characters in tool call arguments. - Refactored the workflow input preparation and initial message processing. - Added new dependencies for MongoDB integration and tool argument sanitization. * fix: Update MongoDB checkpointer configuration to use LANGGRAPH_CHECKPOINT_DB_URL * feat: Add support for Postgres checkpointing and update README with database recommendations * feat: Implement checkpoint saver functionality and update MongoDB connection handling * refactor: Improve code formatting and readability in app.py and json_utils.py * refactor: Clean up commented code and improve formatting in server.py * refactor: Remove unused imports and improve code organization in app.py * refactor: Improve code organization and remove unnecessary comments in app.py * chore: use langgraph-checkpoint-postgres==2.0.21 to avoid the JSON convert issue in the latest version, implement chat stream persistent with Postgres * feat: add MongoDB and PostgreSQL support for LangGraph checkpointing, enhance environment variable handling * fix: update comments for clarity on Windows event loop policy * chore: remove empty code changes in MongoDB and PostgreSQL checkpoint tests * chore: clean up unused imports and code in checkpoint-related files * chore: remove empty code changes in test_checkpoint.py * chore: remove empty code changes in test_checkpoint.py * chore: remove empty code changes in test_checkpoint.py * test: update status code assertions in MCP endpoint tests to allow for 403 responses * test: update MCP endpoint tests to assert specific status codes and enable MCP server configuration * chore: remove unnecessary environment variables from unittest workflow * fix: invert condition for MCP server configuration check to raise 403 when disabled * chore: remove pymongo from test dependencies in uv.lock * chore: optimize the _get_agent_name method * test: enhance ChatStreamManager tests for PostgreSQL and MongoDB initialization * test: add persistence tests for ChatStreamManager with PostgreSQL and MongoDB * test: add unit tests for ChatStreamManager initialization with PostgreSQL and MongoDB * test: enhance persistence tests for ChatStreamManager with PostgreSQL and MongoDB to verify message aggregation * test: add unit tests for ChatStreamManager with PostgreSQL and MongoDB * test: add unit tests for ChatStreamManager initialization with PostgreSQL and MongoDB * test: add unit tests for ChatStreamManager initialization with PostgreSQL and MongoDB --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
661 lines
20 KiB
Python
661 lines
20 KiB
Python
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import src.graph.checkpoint as checkpoint
|
|
|
|
# Connection strings for the PostgreSQL and MongoDB instances that the
# integration-style tests below expect to find running on localhost.
POSTGRES_URL = (
    "postgresql://postgres:postgres@localhost:5432/"
    "checkpointing_db"
)
MONGO_URL = (
    "mongodb://admin:admin@localhost:27017/"
    "checkpointing_db?authSource=admin"
)
|
|
|
|
|
|
def test_with_local_postgres_db():
    """Ensure the ChatStreamManager can be initialized with a local PostgreSQL DB.

    Requires a PostgreSQL server reachable at ``POSTGRES_URL``. The manager is
    always closed afterwards so the live connection is not leaked.
    """
    manager = checkpoint.ChatStreamManager(
        checkpoint_saver=True,
        db_uri=POSTGRES_URL,
    )
    try:
        assert manager.postgres_conn is not None
        assert manager.mongo_client is None
    finally:
        # Release the real connection even when an assertion fails.
        manager.close()
|
|
|
|
|
|
def test_with_local_mongo_db():
    """Ensure the ChatStreamManager can be initialized with a local MongoDB.

    Requires a MongoDB server reachable at ``MONGO_URL``. The manager is
    always closed afterwards so the live client is not leaked.
    """
    manager = checkpoint.ChatStreamManager(
        checkpoint_saver=True,
        db_uri=MONGO_URL,
    )
    try:
        assert manager.mongo_db is not None
        assert manager.postgres_conn is None
    finally:
        # Release the real client even when an assertion fails.
        manager.close()
|
|
|
|
|
|
def test_init_without_checkpoint_saver():
    """With checkpoint_saver disabled, no database client should be created."""
    disabled = checkpoint.ChatStreamManager(checkpoint_saver=False)

    assert disabled.checkpoint_saver is False
    # Neither backend handle is populated when persistence is turned off.
    for attr in ("mongo_client", "postgres_conn"):
        assert getattr(disabled, attr) is None
|
|
|
|
|
|
def test_process_stream_partial_buffer_postgres(monkeypatch):
    """A 'partial' chunk is buffered in memory; PostgreSQL setup is stubbed out."""

    def _skip_postgres_init(self):
        # Stand-in for _init_postgresql: no connection is opened.
        self.postgres_conn = None

    monkeypatch.setattr(
        checkpoint.ChatStreamManager,
        "_init_postgresql",
        _skip_postgres_init,
        raising=True,
    )
    manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=POSTGRES_URL)

    assert manager.process_stream_message("t1", "hello", finish_reason="partial") is True

    # The buffered chunk must be retrievable from the in-memory store.
    buffered = manager.store.search(("messages", "t1"), limit=10)
    assert "hello" in [entry.dict()["value"] for entry in buffered]
|
|
|
|
|
|
def test_process_stream_partial_buffer_mongo(monkeypatch):
    """A 'partial' chunk is buffered in memory; MongoDB setup is stubbed out."""

    def _skip_mongo_init(self):
        # Stand-in for _init_mongodb: keeps the test fast and offline.
        self.mongo_client = None
        self.mongo_db = None

    monkeypatch.setattr(
        checkpoint.ChatStreamManager,
        "_init_mongodb",
        _skip_mongo_init,
        raising=True,
    )
    manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)

    assert manager.process_stream_message("t2", "hello", finish_reason="partial") is True

    # The buffered chunk must be retrievable from the in-memory store.
    buffered = manager.store.search(("messages", "t2"), limit=10)
    assert "hello" in [entry.dict()["value"] for entry in buffered]
|
|
|
|
|
|
def test_persist_postgresql_local_db():
    """Ensure that the ChatStreamManager can persist to a local PostgreSQL DB.

    Covers both the first write for a thread and a second write to the same
    thread. Requires a PostgreSQL server reachable at ``POSTGRES_URL``; the
    manager is closed afterwards so the connection is not leaked.
    """
    manager = checkpoint.ChatStreamManager(
        checkpoint_saver=True,
        db_uri=POSTGRES_URL,
    )
    try:
        assert manager.postgres_conn is not None
        assert manager.mongo_client is None

        thread_id = "test_thread"
        # First write for this thread.
        assert (
            manager._persist_to_postgresql(thread_id, ["This is a test message."])
            is True
        )
        # Second write: the thread now already exists.
        assert manager._persist_to_postgresql(thread_id, ["Another message."]) is True
    finally:
        manager.close()
|
|
|
|
|
|
def test_persist_postgresql_called_with_aggregated_chunks(monkeypatch):
    """On 'stop', aggregated chunks should be persisted to PostgreSQL.

    Requires a PostgreSQL server reachable at ``POSTGRES_URL``; the manager is
    closed afterwards so the connection is not leaked.
    """
    manager = checkpoint.ChatStreamManager(
        checkpoint_saver=True,
        db_uri=POSTGRES_URL,
    )
    try:
        assert (
            manager.process_stream_message("thd3", "Hello", finish_reason="partial")
            is True
        )
        assert (
            manager.process_stream_message("thd3", " World", finish_reason="stop")
            is True
        )

        # Read back the stored row and verify the chunks were aggregated in order.
        with manager.postgres_conn.cursor() as cursor:
            cursor.execute(
                "SELECT messages FROM chat_streams WHERE thread_id = %s", ("thd3",)
            )
            existing_record = cursor.fetchone()
        assert existing_record is not None
        assert existing_record["messages"] == ["Hello", " World"]
    finally:
        manager.close()
|
|
|
|
|
|
def test_persist_not_attempted_when_saver_disabled():
    """A 'stop' chunk yields False when the checkpoint saver is disabled."""
    manager = checkpoint.ChatStreamManager(checkpoint_saver=False)
    # 'stop' would normally flush to a database, but persistence is off.
    outcome = manager.process_stream_message("t4", "hello", finish_reason="stop")
    assert outcome is False
|
|
|
|
|
|
def test_persist_mongodb_local_db():
    """Ensure that the ChatStreamManager can persist to a local MongoDB.

    Covers both the first write for a thread and a second write to the same
    thread. Requires a MongoDB server reachable at ``MONGO_URL``; the manager
    is closed afterwards so the client is not leaked.
    """
    manager = checkpoint.ChatStreamManager(
        checkpoint_saver=True,
        db_uri=MONGO_URL,
    )
    try:
        assert manager.mongo_db is not None
        assert manager.postgres_conn is None

        thread_id = "test_thread"
        # First write for this thread.
        assert (
            manager._persist_to_mongodb(thread_id, ["This is a test message."])
            is True
        )
        # Second write: the thread now already exists.
        assert manager._persist_to_mongodb(thread_id, ["Another message."]) is True
    finally:
        manager.close()
|
|
|
|
|
|
def test_persist_mongodb_called_with_aggregated_chunks(monkeypatch):
    """On 'stop', aggregated chunks should be persisted to MongoDB.

    Requires a MongoDB server reachable at ``MONGO_URL``; the manager is
    closed afterwards so the client is not leaked.
    """
    manager = checkpoint.ChatStreamManager(
        checkpoint_saver=True,
        db_uri=MONGO_URL,
    )
    try:
        assert (
            manager.process_stream_message("thd5", "Hello", finish_reason="partial")
            is True
        )
        assert (
            manager.process_stream_message("thd5", " World", finish_reason="stop")
            is True
        )

        # Read back the stored document and verify the chunks were aggregated in order.
        collection = manager.mongo_db.chat_streams
        existing_record = collection.find_one({"thread_id": "thd5"})
        assert existing_record is not None
        assert existing_record["messages"] == ["Hello", " World"]
    finally:
        manager.close()
|
|
|
|
|
|
def test_invalid_inputs_return_false(monkeypatch):
    """An empty thread_id or an empty message is rejected with False."""

    def _skip_mongo_init(self):
        # Stand-in for _init_mongodb: leave both Mongo handles unset.
        self.mongo_client = None
        self.mongo_db = None

    monkeypatch.setattr(
        checkpoint.ChatStreamManager,
        "_init_mongodb",
        _skip_mongo_init,
        raising=True,
    )
    manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)

    for thread_id, message in (("", "msg"), ("tid", "")):
        assert (
            manager.process_stream_message(thread_id, message, finish_reason="partial")
            is False
        )
|
|
|
|
|
|
def test_unsupported_db_uri_scheme():
    """Manager should log a warning for unsupported database URI schemes.

    Only the resulting absence of database handles is asserted here.
    """
    manager = checkpoint.ChatStreamManager(
        checkpoint_saver=True, db_uri="redis://localhost:6379/0"
    )
    # A redis:// URI matches neither backend, so no handle is created.
    assert all(
        getattr(manager, attr) is None
        for attr in ("mongo_client", "postgres_conn", "mongo_db")
    )
|
|
|
|
|
|
def test_process_stream_with_interrupt_finish_reason():
    """Test that 'interrupt' finish_reason triggers persistence like 'stop'.

    Requires a MongoDB server reachable at ``MONGO_URL``. Besides the return
    value, the stored document is read back to confirm both chunks were
    persisted; the manager is closed afterwards so the client is not leaked.
    """
    manager = checkpoint.ChatStreamManager(
        checkpoint_saver=True,
        db_uri=MONGO_URL,
    )
    try:
        # Add partial message
        assert (
            manager.process_stream_message(
                "int_test", "Interrupted", finish_reason="partial"
            )
            is True
        )
        # Interrupt should trigger persistence
        assert (
            manager.process_stream_message(
                "int_test", " message", finish_reason="interrupt"
            )
            is True
        )

        # Confirm the aggregated chunks actually reached the database, in order.
        record = manager.mongo_db.chat_streams.find_one({"thread_id": "int_test"})
        assert record is not None
        assert record["messages"] == ["Interrupted", " message"]
    finally:
        manager.close()
|
|
|
|
|
|
def test_postgresql_connection_failure(monkeypatch):
    """If psycopg.connect raises, the manager ends up without a PostgreSQL connection."""

    def _raise_on_connect(dsn, **kwargs):
        raise RuntimeError("Connection failed")

    monkeypatch.setattr("psycopg.connect", _raise_on_connect)

    manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=POSTGRES_URL)
    # Construction must not raise, and the connection handle stays unset.
    assert manager.postgres_conn is None
|
|
|
|
|
|
def test_mongodb_ping_failure(monkeypatch):
    """If the MongoDB ping fails during init, no database handle is kept."""

    class _AdminThatFailsPing:
        def command(self, name):
            raise RuntimeError("Ping failed")

    class _ClientWithBrokenPing:
        def __init__(self, uri):
            self.admin = _AdminThatFailsPing()

    monkeypatch.setattr(
        checkpoint, "MongoClient", lambda uri: _ClientWithBrokenPing(uri)
    )

    manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
    assert getattr(manager, "mongo_db", None) is None
|
|
|
|
|
|
def test_store_namespace_consistency():
    """Chunks and the cursor share the ('messages', thread_id) store namespace."""
    manager = checkpoint.ChatStreamManager(checkpoint_saver=False)
    namespace = ("messages", "ns_test")

    # Each partial chunk advances the cursor kept under the same namespace.
    for expected_index, chunk in enumerate(("chunk1", "chunk2")):
        assert (
            manager.process_stream_message("ns_test", chunk, finish_reason="partial")
            is True
        )
        cursor = manager.store.get(namespace, "cursor")
        assert cursor is not None
        assert cursor.value["index"] == expected_index
|
|
|
|
|
|
def test_cursor_initialization_edge_cases():
    """A stored cursor lacking an 'index' key is treated as starting at 0."""
    manager = checkpoint.ChatStreamManager(checkpoint_saver=False)
    namespace = ("messages", "edge_test")

    # Pre-seed a malformed cursor that has no 'index' entry.
    manager.store.put(namespace, "cursor", {})

    assert (
        manager.process_stream_message("edge_test", "test", finish_reason="partial")
        is True
    )

    # The missing index defaults to 0 and then advances by one for this chunk.
    assert manager.store.get(namespace, "cursor").value["index"] == 1
|
|
|
|
|
|
def test_multiple_threads_isolation():
    """Chunks buffered for one thread_id never appear under another."""
    manager = checkpoint.ChatStreamManager(checkpoint_saver=False)

    # Interleave writes across two threads.
    writes = (("thread1", "msg1"), ("thread2", "msg2"), ("thread1", "msg3"))
    for thread_id, chunk in writes:
        assert (
            manager.process_stream_message(thread_id, chunk, finish_reason="partial")
            is True
        )

    def _string_values(thread_id):
        # Keep only string payloads; non-string entries such as the cursor are skipped.
        found = []
        for item in manager.store.search(("messages", thread_id), limit=10):
            value = item.dict()["value"]
            if isinstance(value, str):
                found.append(value)
        return found

    thread1_values = _string_values("thread1")
    thread2_values = _string_values("thread2")

    assert "msg1" in thread1_values
    assert "msg3" in thread1_values
    assert "msg2" in thread2_values
    assert "msg1" not in thread2_values
    assert "msg2" not in thread1_values
|
|
|
|
|
|
def test_mongodb_insert_and_update_paths(monkeypatch):
    """Exercise MongoDB insert, update, and exception branches.

    Real MongoDB initialization is stubbed out so constructing the manager
    does not contact a server. A fake ``mongo_db`` is injected for each
    scenario.
    """

    def _no_mongo(self):
        self.mongo_client = None
        self.mongo_db = None

    monkeypatch.setattr(
        checkpoint.ChatStreamManager, "_init_mongodb", _no_mongo, raising=True
    )

    # Fake Mongo classes
    class FakeUpdateResult:
        def __init__(self, modified_count):
            self.modified_count = modified_count

    class FakeInsertResult:
        def __init__(self, inserted_id):
            self.inserted_id = inserted_id

    class FakeCollection:
        # Modes starting with "insert" report no existing document (insert
        # branch); any other mode reports one (update branch).
        def __init__(self, mode="insert_success"):
            self.mode = mode

        def find_one(self, query):
            if self.mode.startswith("insert"):
                return None
            return {"thread_id": query["thread_id"]}

        def update_one(self, q, s):
            if self.mode == "update_success":
                return FakeUpdateResult(1)
            return FakeUpdateResult(0)

        def insert_one(self, doc):
            if self.mode == "insert_success":
                return FakeInsertResult("ok")
            if self.mode == "insert_none":
                return FakeInsertResult(None)
            raise RuntimeError("boom")

    class FakeMongoDB:
        def __init__(self, mode):
            self.chat_streams = FakeCollection(mode)

    manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)

    # Insert success
    manager.mongo_db = FakeMongoDB("insert_success")
    assert manager._persist_to_mongodb("th1", ["a"]) is True

    # Insert returns None id => False
    manager.mongo_db = FakeMongoDB("insert_none")
    assert manager._persist_to_mongodb("th2", ["a"]) is False

    # Insert raises => False
    manager.mongo_db = FakeMongoDB("insert_raise")
    assert manager._persist_to_mongodb("th3", ["a"]) is False

    # Update success
    manager.mongo_db = FakeMongoDB("update_success")
    assert manager._persist_to_mongodb("th4", ["a"]) is True

    # Update modifies 0 => False
    manager.mongo_db = FakeMongoDB("update_zero")
    assert manager._persist_to_mongodb("th5", ["a"]) is False
|
|
|
|
|
|
def test_postgresql_insert_update_and_error_paths():
    """Exercise PostgreSQL update, insert, and error/rollback branches.

    Real PostgreSQL initialization is patched out while constructing the
    manager, so no live connection is opened (and later leaked when the fakes
    replace it). A ``FakeConn`` is injected for each scenario.
    """
    from unittest import mock

    # First SQL keyword of every statement the fake cursor executes.
    calls = {"executed": []}

    class FakeCursor:
        # SELECT behaviour depends on mode: "update" finds a row, "error"
        # raises, anything else finds nothing (insert path).
        def __init__(self, mode):
            self.mode = mode
            self.rowcount = 0

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def execute(self, sql, params=None):
            calls["executed"].append(sql.strip().split()[0])
            if "SELECT" in sql:
                if self.mode == "update":
                    self._fetch = {"id": "x"}
                elif self.mode == "error":
                    raise RuntimeError("sql error")
                else:
                    self._fetch = None
            else:
                # UPDATE or INSERT
                self.rowcount = 1

        def fetchone(self):
            return getattr(self, "_fetch", None)

    class FakeConn:
        def __init__(self, mode):
            self.mode = mode
            self.commit_called = False
            self.rollback_called = False

        def cursor(self):
            return FakeCursor(self.mode)

        def commit(self):
            self.commit_called = True

        def rollback(self):
            self.rollback_called = True

    def _no_pg(self):
        self.postgres_conn = None

    # Construct without touching a real database; each scenario injects a fake.
    with mock.patch.object(checkpoint.ChatStreamManager, "_init_postgresql", _no_pg):
        manager = checkpoint.ChatStreamManager(
            checkpoint_saver=True, db_uri=POSTGRES_URL
        )

    # Update path
    manager.postgres_conn = FakeConn("update")
    assert manager._persist_to_postgresql("t", ["m"]) is True
    assert manager.postgres_conn.commit_called is True

    # Insert path
    manager.postgres_conn = FakeConn("insert")
    assert manager._persist_to_postgresql("t", ["m"]) is True
    assert manager.postgres_conn.commit_called is True

    # Error path with rollback
    manager.postgres_conn = FakeConn("error")
    assert manager._persist_to_postgresql("t", ["m"]) is False
    assert manager.postgres_conn.rollback_called is True
|
|
|
|
|
|
def test_create_chat_streams_table_success_and_error():
    """Ensure table creation commits on success and rolls back on failure.

    Real PostgreSQL initialization is patched out while constructing the
    manager, so no live connection is opened or leaked; fake connections are
    injected instead.
    """
    from unittest import mock

    class FakeCursor:
        def __init__(self, should_fail=False):
            self.should_fail = should_fail

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def execute(self, sql):
            if self.should_fail:
                raise RuntimeError("ddl fail")

    class FakeConn:
        def __init__(self, should_fail=False):
            self.should_fail = should_fail
            self.commits = 0
            self.rollbacks = 0

        def cursor(self):
            return FakeCursor(self.should_fail)

        def commit(self):
            self.commits += 1

        def rollback(self):
            self.rollbacks += 1

    def _no_pg(self):
        self.postgres_conn = None

    # Construct without touching a real database.
    with mock.patch.object(checkpoint.ChatStreamManager, "_init_postgresql", _no_pg):
        manager = checkpoint.ChatStreamManager(
            checkpoint_saver=True, db_uri=POSTGRES_URL
        )

    # Success
    manager.postgres_conn = FakeConn(False)
    manager._create_chat_streams_table()
    assert manager.postgres_conn.commits == 1

    # Failure triggers rollback
    manager.postgres_conn = FakeConn(True)
    manager._create_chat_streams_table()
    assert manager.postgres_conn.rollbacks == 1
|
|
|
|
|
|
def test_close_closes_resources_and_handles_errors():
    """close() shuts down both clients and swallows errors raised while closing."""
    flags = {"mongo": 0, "pg": 0}

    class _FakeMongoClient:
        def close(self):
            flags["mongo"] += 1

    class _FakePgConn:
        def __init__(self, raise_on_close=False):
            self.raise_on_close = raise_on_close

        def close(self):
            if self.raise_on_close:
                raise RuntimeError("close fail")
            flags["pg"] += 1

    manager = checkpoint.ChatStreamManager(checkpoint_saver=False)

    # Happy path: each client is closed exactly once.
    manager.mongo_client, manager.postgres_conn = _FakeMongoClient(), _FakePgConn()
    manager.close()
    assert flags == {"mongo": 1, "pg": 1}

    # Error path: no Mongo client, and the PostgreSQL close() raises;
    # the exception must not escape close().
    manager.mongo_client, manager.postgres_conn = None, _FakePgConn(True)
    manager.close()
|
|
|
|
|
|
def test_context_manager_calls_close(monkeypatch):
    """Leaving a ``with`` block on the manager invokes close() exactly once."""
    close_calls = []

    def _skip_mongo_init(self):
        self.mongo_client = None
        self.mongo_db = None

    monkeypatch.setattr(
        checkpoint.ChatStreamManager,
        "_init_mongodb",
        _skip_mongo_init,
        raising=True,
    )

    manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)

    def _record_close():
        close_calls.append(True)

    # Shadow close() on the instance so the call can be counted.
    manager.close = _record_close

    with manager:
        pass

    assert len(close_calls) == 1
|
|
|
|
|
|
def test_init_mongodb_success_and_failure(monkeypatch):
    """MongoDB init keeps a DB handle on success and none when the client raises."""

    class _PingOkAdmin:
        def command(self, name):
            assert name == "ping"

    class _Database:
        pass

    class _HealthyClient:
        def __init__(self, uri):
            self.uri = uri
            self.admin = _PingOkAdmin()
            self.checkpointing_db = _Database()

        def close(self):
            pass

    class _ExplodingClient:
        def __init__(self, uri):
            raise RuntimeError("fail connect")

    # Success: a client whose ping succeeds yields a database handle.
    monkeypatch.setattr(checkpoint, "MongoClient", lambda uri: _HealthyClient(uri))
    healthy = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
    assert healthy.mongo_db is not None

    # Failure: constructing the client raises, so no handle is kept.
    monkeypatch.setattr(checkpoint, "MongoClient", _ExplodingClient)
    failed = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
    assert getattr(failed, "mongo_db", None) is None
|
|
|
|
|
|
def test_init_postgresql_calls_connect_and_create_table(monkeypatch):
    """PostgreSQL init should connect and create the required table.

    ``psycopg.connect`` and ``_create_chat_streams_table`` are replaced with
    counters so the real ``_init_postgresql`` runs without a database. The
    test checks that each is called once and that the returned connection
    is the one the manager keeps.
    """
    flags = {"connected": 0, "created": 0}

    class FakeConn:
        def close(self):
            pass

    fake_conn = FakeConn()

    def fake_connect(dsn, **kwargs):
        flags["connected"] += 1
        return fake_conn

    def fake_create_table(self):
        flags["created"] += 1

    # Patch the collaborators, not _init_postgresql itself; patching the
    # method under test would make the assertions below tautological.
    monkeypatch.setattr("psycopg.connect", fake_connect)
    monkeypatch.setattr(
        checkpoint.ChatStreamManager,
        "_create_chat_streams_table",
        fake_create_table,
        raising=True,
    )

    manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=POSTGRES_URL)
    assert manager.postgres_conn is fake_conn
    assert flags == {"connected": 1, "created": 1}
|
|
|
|
|
|
def test_chat_stream_message_wrapper(monkeypatch):
    """chat_stream_message delegates only when the saver flag is enabled."""
    recorded = {"args": None}

    def fake_process(tid, msg, fr):
        recorded["args"] = (tid, msg, fr)
        return True

    # Enabled: the call is forwarded to the default manager.
    monkeypatch.setattr(
        checkpoint, "get_bool_env", lambda k, d=False: True, raising=True
    )
    monkeypatch.setattr(
        checkpoint._default_manager,
        "process_stream_message",
        fake_process,
        raising=True,
    )
    assert checkpoint.chat_stream_message("tid", "msg", "stop") is True
    assert recorded["args"] == ("tid", "msg", "stop")

    # Disabled: returns False without ever reaching the manager.
    monkeypatch.setattr(
        checkpoint, "get_bool_env", lambda k, d=False: False, raising=True
    )
    recorded["args"] = None
    assert checkpoint.chat_stream_message("tid", "msg", "stop") is False
    assert recorded["args"] is None
|