mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
test: add unit tests for graph (#296)
* test: added unit test of builder * test: Add unit tests for nodes.py * test: add more unit tests in test_nodes * test: try to fix the unit test error on GitHub * test: reformate the code of test_nodes.py * Fix the test error of reset the local argument * Fixed the test error by setup args * reformat the code
This commit is contained in:
@@ -190,7 +190,7 @@ def human_feedback_node(
|
||||
goto = "reporter"
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Planner response is not a valid JSON")
|
||||
if plan_iterations > 0:
|
||||
if plan_iterations > 1: # the plan_iterations is increased before this check
|
||||
return Command(goto="reporter")
|
||||
else:
|
||||
return Command(goto="__end__")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
121
tests/unit/graph/test_builder.py
Normal file
121
tests/unit/graph/test_builder.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
import src.graph.builder as builder_mod
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_state():
|
||||
class Step:
|
||||
def __init__(self, execution_res=None, step_type=None):
|
||||
self.execution_res = execution_res
|
||||
self.step_type = step_type
|
||||
|
||||
class Plan:
|
||||
def __init__(self, steps):
|
||||
self.steps = steps
|
||||
|
||||
return {
|
||||
"Step": Step,
|
||||
"Plan": Plan,
|
||||
}
|
||||
|
||||
|
||||
def test_continue_to_running_research_team_no_plan(mock_state):
|
||||
state = {"current_plan": None}
|
||||
assert builder_mod.continue_to_running_research_team(state) == "planner"
|
||||
|
||||
|
||||
def test_continue_to_running_research_team_no_steps(mock_state):
|
||||
state = {"current_plan": mock_state["Plan"](steps=[])}
|
||||
assert builder_mod.continue_to_running_research_team(state) == "planner"
|
||||
|
||||
|
||||
def test_continue_to_running_research_team_all_executed(mock_state):
|
||||
Step = mock_state["Step"]
|
||||
Plan = mock_state["Plan"]
|
||||
steps = [Step(execution_res=True), Step(execution_res=True)]
|
||||
state = {"current_plan": Plan(steps=steps)}
|
||||
assert builder_mod.continue_to_running_research_team(state) == "planner"
|
||||
|
||||
|
||||
def test_continue_to_running_research_team_next_researcher(mock_state):
|
||||
Step = mock_state["Step"]
|
||||
Plan = mock_state["Plan"]
|
||||
steps = [
|
||||
Step(execution_res=True),
|
||||
Step(execution_res=None, step_type=builder_mod.StepType.RESEARCH),
|
||||
]
|
||||
state = {"current_plan": Plan(steps=steps)}
|
||||
assert builder_mod.continue_to_running_research_team(state) == "researcher"
|
||||
|
||||
|
||||
def test_continue_to_running_research_team_next_coder(mock_state):
|
||||
Step = mock_state["Step"]
|
||||
Plan = mock_state["Plan"]
|
||||
steps = [
|
||||
Step(execution_res=True),
|
||||
Step(execution_res=None, step_type=builder_mod.StepType.PROCESSING),
|
||||
]
|
||||
state = {"current_plan": Plan(steps=steps)}
|
||||
assert builder_mod.continue_to_running_research_team(state) == "coder"
|
||||
|
||||
|
||||
def test_continue_to_running_research_team_default_planner(mock_state):
|
||||
Step = mock_state["Step"]
|
||||
Plan = mock_state["Plan"]
|
||||
steps = [Step(execution_res=True), Step(execution_res=None, step_type=None)]
|
||||
state = {"current_plan": Plan(steps=steps)}
|
||||
assert builder_mod.continue_to_running_research_team(state) == "planner"
|
||||
|
||||
|
||||
@patch("src.graph.builder.StateGraph")
|
||||
def test_build_base_graph_adds_nodes_and_edges(MockStateGraph):
|
||||
mock_builder = MagicMock()
|
||||
MockStateGraph.return_value = mock_builder
|
||||
|
||||
builder_mod._build_base_graph()
|
||||
|
||||
# Check that all nodes and edges are added
|
||||
assert mock_builder.add_edge.call_count >= 2
|
||||
assert mock_builder.add_node.call_count >= 8
|
||||
mock_builder.add_conditional_edges.assert_called_once()
|
||||
|
||||
|
||||
@patch("src.graph.builder._build_base_graph")
|
||||
@patch("src.graph.builder.MemorySaver")
|
||||
def test_build_graph_with_memory_uses_memory(MockMemorySaver, mock_build_base_graph):
|
||||
mock_builder = MagicMock()
|
||||
mock_build_base_graph.return_value = mock_builder
|
||||
mock_memory = MagicMock()
|
||||
MockMemorySaver.return_value = mock_memory
|
||||
|
||||
builder_mod.build_graph_with_memory()
|
||||
|
||||
mock_builder.compile.assert_called_once_with(checkpointer=mock_memory)
|
||||
|
||||
|
||||
@patch("src.graph.builder._build_base_graph")
|
||||
def test_build_graph_without_memory(mock_build_base_graph):
|
||||
mock_builder = MagicMock()
|
||||
mock_build_base_graph.return_value = mock_builder
|
||||
|
||||
builder_mod.build_graph()
|
||||
|
||||
mock_builder.compile.assert_called_once_with()
|
||||
|
||||
|
||||
def test_graph_is_compiled():
|
||||
# The graph object should be the result of build_graph()
|
||||
with patch("src.graph.builder._build_base_graph") as mock_base:
|
||||
mock_builder = MagicMock()
|
||||
mock_base.return_value = mock_builder
|
||||
mock_builder.compile.return_value = "compiled_graph"
|
||||
# reload the module to re-run the graph assignment
|
||||
importlib.reload(sys.modules["src.graph.builder"])
|
||||
assert builder_mod.graph is not None
|
||||
14
uv.lock
generated
14
uv.lock
generated
@@ -403,6 +403,7 @@ dev = [
|
||||
]
|
||||
test = [
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-asyncio" },
|
||||
{ name = "pytest-cov" },
|
||||
]
|
||||
|
||||
@@ -429,6 +430,7 @@ requires-dist = [
|
||||
{ name = "numpy", specifier = ">=2.2.3" },
|
||||
{ name = "pandas", specifier = ">=2.2.3" },
|
||||
{ name = "pytest", marker = "extra == 'test'", specifier = ">=7.4.0" },
|
||||
{ name = "pytest-asyncio", marker = "extra == 'test'", specifier = ">=1.0.0" },
|
||||
{ name = "pytest-cov", marker = "extra == 'test'", specifier = ">=4.1.0" },
|
||||
{ name = "python-dotenv", specifier = ">=1.0.1" },
|
||||
{ name = "readabilipy", specifier = ">=0.3.0" },
|
||||
@@ -1692,6 +1694,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/30/3d/64ad57c803f1fa1e963a7946b6e0fea4a70df53c1a7fed304586539c2bac/pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820", size = 343634, upload-time = "2025-03-02T12:54:52.069Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-asyncio"
|
||||
version = "1.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pytest" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d0/d4/14f53324cb1a6381bef29d698987625d80052bb33932d8e7cbf9b337b17c/pytest_asyncio-1.0.0.tar.gz", hash = "sha256:d15463d13f4456e1ead2594520216b225a16f781e144f8fdf6c5bb4667c48b3f", size = 46960 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/30/05/ce271016e351fddc8399e546f6e23761967ee09c8c568bbfbecb0c150171/pytest_asyncio-1.0.0-py3-none-any.whl", hash = "sha256:4f024da9f1ef945e680dc68610b52550e36590a67fd31bb3b4943979a1f90ef3", size = 15976 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "6.0.0"
|
||||
|
||||
Reference in New Issue
Block a user