diff --git a/src/graph/nodes.py b/src/graph/nodes.py index 16407c8..9665524 100644 --- a/src/graph/nodes.py +++ b/src/graph/nodes.py @@ -202,7 +202,7 @@ def planner_node( def human_feedback_node( - state, + state: State, config: RunnableConfig ) -> Command[Literal["planner", "research_team", "reporter", "__end__"]]: current_plan = state.get("current_plan", "") # check if the plan is auto accepted diff --git a/tests/integration/test_nodes.py b/tests/integration/test_nodes.py index caae2aa..0f6a268 100644 --- a/tests/integration/test_nodes.py +++ b/tests/integration/test_nodes.py @@ -416,52 +416,52 @@ def mock_state_base(): } -def test_human_feedback_node_auto_accepted(monkeypatch, mock_state_base): +def test_human_feedback_node_auto_accepted(monkeypatch, mock_state_base, mock_config): # auto_accepted_plan True, should skip interrupt and parse plan state = dict(mock_state_base) state["auto_accepted_plan"] = True - result = human_feedback_node(state) + result = human_feedback_node(state, mock_config) assert isinstance(result, Command) assert result.goto == "research_team" assert result.update["plan_iterations"] == 1 assert result.update["current_plan"]["has_enough_context"] is False -def test_human_feedback_node_edit_plan(monkeypatch, mock_state_base): +def test_human_feedback_node_edit_plan(monkeypatch, mock_state_base, mock_config): # interrupt returns [EDIT_PLAN]..., should return Command to planner state = dict(mock_state_base) state["auto_accepted_plan"] = False with patch("src.graph.nodes.interrupt", return_value="[EDIT_PLAN] Please revise"): - result = human_feedback_node(state) + result = human_feedback_node(state, mock_config) assert isinstance(result, Command) assert result.goto == "planner" assert result.update["messages"][0].name == "feedback" assert "[EDIT_PLAN]" in result.update["messages"][0].content -def test_human_feedback_node_accepted(monkeypatch, mock_state_base): +def test_human_feedback_node_accepted(monkeypatch, mock_state_base, mock_config): # interrupt returns [ACCEPTED]..., should proceed to parse plan state = dict(mock_state_base) state["auto_accepted_plan"] = False with patch("src.graph.nodes.interrupt", return_value="[ACCEPTED] Looks good!"): - result = human_feedback_node(state) + result = human_feedback_node(state, mock_config) assert isinstance(result, Command) assert result.goto == "research_team" assert result.update["plan_iterations"] == 1 assert result.update["current_plan"]["has_enough_context"] is False -def test_human_feedback_node_invalid_interrupt(monkeypatch, mock_state_base): +def test_human_feedback_node_invalid_interrupt(monkeypatch, mock_state_base, mock_config): # interrupt returns something else, should raise TypeError state = dict(mock_state_base) state["auto_accepted_plan"] = False with patch("src.graph.nodes.interrupt", return_value="RANDOM_FEEDBACK"): with pytest.raises(TypeError): - human_feedback_node(state) + human_feedback_node(state, mock_config) def test_human_feedback_node_json_decode_error_first_iteration( - monkeypatch, mock_state_base + monkeypatch, mock_state_base, mock_config ): # repair_json_output returns bad json, json.loads raises JSONDecodeError, plan_iterations=0 state = dict(mock_state_base) @@ -470,13 +470,13 @@ def test_human_feedback_node_json_decode_error_first_iteration( with patch( "src.graph.nodes.json.loads", side_effect=json.JSONDecodeError("err", "doc", 0) ): - result = human_feedback_node(state) + result = human_feedback_node(state, mock_config) assert isinstance(result, Command) assert result.goto == "__end__" def test_human_feedback_node_json_decode_error_second_iteration( - monkeypatch, mock_state_base + monkeypatch, mock_state_base, mock_config ): # repair_json_output returns bad json, json.loads raises JSONDecodeError, plan_iterations>0 state = dict(mock_state_base) @@ -485,12 +485,12 @@ def test_human_feedback_node_json_decode_error_second_iteration( with patch( "src.graph.nodes.json.loads", side_effect=json.JSONDecodeError("err", "doc", 0) ): - result = human_feedback_node(state) + result = human_feedback_node(state, mock_config) assert isinstance(result, Command) assert result.goto == "reporter" -def test_human_feedback_node_not_enough_context(monkeypatch, mock_state_base): +def test_human_feedback_node_not_enough_context(monkeypatch, mock_state_base, mock_config): # Plan does not have enough context, should goto research_team plan = { "has_enough_context": False, @@ -502,7 +502,7 @@ def test_human_feedback_node_not_enough_context(monkeypatch, mock_state_base): state = dict(mock_state_base) state["current_plan"] = json.dumps(plan) state["auto_accepted_plan"] = True - result = human_feedback_node(state) + result = human_feedback_node(state, mock_config) assert isinstance(result, Command) assert result.goto == "research_team" assert result.update["plan_iterations"] == 1