diff --git a/backend/packages/harness/deerflow/client.py b/backend/packages/harness/deerflow/client.py index 983aa9f..5f5c2ac 100644 --- a/backend/packages/harness/deerflow/client.py +++ b/backend/packages/harness/deerflow/client.py @@ -235,6 +235,8 @@ class DeerFlowClient: d: dict[str, Any] = {"type": "ai", "content": msg.content, "id": getattr(msg, "id", None)} if msg.tool_calls: d["tool_calls"] = [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in msg.tool_calls] + if getattr(msg, "usage_metadata", None): + d["usage_metadata"] = msg.usage_metadata return d if isinstance(msg, ToolMessage): return { @@ -296,9 +298,10 @@ class DeerFlowClient: StreamEvent with one of: - type="values" data={"title": str|None, "messages": [...], "artifacts": [...]} - type="messages-tuple" data={"type": "ai", "content": str, "id": str} + - type="messages-tuple" data={"type": "ai", "content": str, "id": str, "usage_metadata": {...}} - type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]} - type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str} - - type="end" data={} + - type="end" data={"usage": {"input_tokens": int, "output_tokens": int, "total_tokens": int}} """ if thread_id is None: thread_id = str(uuid.uuid4()) @@ -310,6 +313,7 @@ class DeerFlowClient: context = {"thread_id": thread_id} seen_ids: set[str] = set() + cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} for chunk in self._agent.stream(state, config=config, context=context, stream_mode="values"): messages = chunk.get("messages", []) @@ -322,6 +326,13 @@ class DeerFlowClient: seen_ids.add(msg_id) if isinstance(msg, AIMessage): + # Track token usage from AI messages + usage = getattr(msg, "usage_metadata", None) + if usage: + cumulative_usage["input_tokens"] += usage.get("input_tokens", 0) or 0 + cumulative_usage["output_tokens"] += usage.get("output_tokens", 0) or 0 + cumulative_usage["total_tokens"] += usage.get("total_tokens", 0) or 0 + if msg.tool_calls: yield StreamEvent( type="messages-tuple", @@ -335,10 +346,14 @@ class DeerFlowClient: text = self._extract_text(msg.content) if text: - yield StreamEvent( - type="messages-tuple", - data={"type": "ai", "content": text, "id": msg_id}, - ) + event_data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id} + if usage: + event_data["usage_metadata"] = { + "input_tokens": usage.get("input_tokens", 0) or 0, + "output_tokens": usage.get("output_tokens", 0) or 0, + "total_tokens": usage.get("total_tokens", 0) or 0, + } + yield StreamEvent(type="messages-tuple", data=event_data) elif isinstance(msg, ToolMessage): yield StreamEvent( @@ -362,7 +377,7 @@ class DeerFlowClient: }, ) - yield StreamEvent(type="end", data={}) + yield StreamEvent(type="end", data={"usage": cumulative_usage}) def chat(self, message: str, *, thread_id: str | None = None, **kwargs) -> str: """Send a message and return the final text response. diff --git a/backend/tests/test_token_usage.py b/backend/tests/test_token_usage.py new file mode 100644 index 0000000..db728e8 --- /dev/null +++ b/backend/tests/test_token_usage.py @@ -0,0 +1,306 @@ +"""Tests for token usage tracking in DeerFlowClient.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage + +from deerflow.client import DeerFlowClient + +# --------------------------------------------------------------------------- +# _serialize_message — usage_metadata passthrough +# --------------------------------------------------------------------------- + + +class TestSerializeMessageUsageMetadata: + """Verify _serialize_message includes usage_metadata when present.""" + + def test_ai_message_with_usage_metadata(self): + msg = AIMessage( + content="Hello", + id="msg-1", + usage_metadata={"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}, + ) + result = DeerFlowClient._serialize_message(msg) + assert result["type"] == "ai" + assert result["usage_metadata"] == { + "input_tokens": 100, + "output_tokens": 50, + "total_tokens": 150, + } + + def test_ai_message_without_usage_metadata(self): + msg = AIMessage(content="Hello", id="msg-2") + result = DeerFlowClient._serialize_message(msg) + assert result["type"] == "ai" + assert "usage_metadata" not in result + + def test_tool_message_never_has_usage_metadata(self): + msg = ToolMessage(content="result", tool_call_id="tc-1", name="search") + result = DeerFlowClient._serialize_message(msg) + assert result["type"] == "tool" + assert "usage_metadata" not in result + + def test_human_message_never_has_usage_metadata(self): + msg = HumanMessage(content="Hi") + result = DeerFlowClient._serialize_message(msg) + assert result["type"] == "human" + assert "usage_metadata" not in result + + def test_ai_message_with_tool_calls_and_usage(self): + msg = AIMessage( + content="", + id="msg-3", + tool_calls=[{"name": "search", "args": {"q": "test"}, "id": "tc-1"}], + usage_metadata={"input_tokens": 200, "output_tokens": 30, "total_tokens": 230}, + ) + result = DeerFlowClient._serialize_message(msg) + assert result["type"] == "ai" + assert result["tool_calls"] == [{"name": "search", "args": {"q": "test"}, "id": "tc-1"}] + assert result["usage_metadata"]["input_tokens"] == 200 + + def test_ai_message_with_zero_usage(self): + """usage_metadata with zero token counts should be included.""" + msg = AIMessage( + content="Hello", + id="msg-4", + usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, + ) + result = DeerFlowClient._serialize_message(msg) + assert result["usage_metadata"] == { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + } + + +# --------------------------------------------------------------------------- +# Cumulative usage tracking (simulated, no real agent) +# --------------------------------------------------------------------------- + + +class TestCumulativeUsageTracking: + """Test cumulative usage aggregation logic.""" + + def test_single_message_usage(self): + """Single AI message usage should be the total.""" + cumulative = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + usage = {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150} + cumulative["input_tokens"] += usage.get("input_tokens", 0) or 0 + cumulative["output_tokens"] += usage.get("output_tokens", 0) or 0 + cumulative["total_tokens"] += usage.get("total_tokens", 0) or 0 + assert cumulative == {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150} + + def test_multiple_messages_usage(self): + """Multiple AI messages should accumulate.""" + cumulative = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + messages_usage = [ + {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}, + {"input_tokens": 200, "output_tokens": 30, "total_tokens": 230}, + {"input_tokens": 150, "output_tokens": 80, "total_tokens": 230}, + ] + for usage in messages_usage: + cumulative["input_tokens"] += usage.get("input_tokens", 0) or 0 + cumulative["output_tokens"] += usage.get("output_tokens", 0) or 0 + cumulative["total_tokens"] += usage.get("total_tokens", 0) or 0 + assert cumulative == {"input_tokens": 450, "output_tokens": 160, "total_tokens": 610} + + def test_missing_usage_keys_treated_as_zero(self): + """Missing keys in usage dict should be treated as 0.""" + cumulative = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + usage = {"input_tokens": 50} # missing output_tokens, total_tokens + cumulative["input_tokens"] += usage.get("input_tokens", 0) or 0 + cumulative["output_tokens"] += usage.get("output_tokens", 0) or 0 + cumulative["total_tokens"] += usage.get("total_tokens", 0) or 0 + assert cumulative == {"input_tokens": 50, "output_tokens": 0, "total_tokens": 0} + + def test_empty_usage_metadata_stays_zero(self): + """No usage metadata should leave cumulative at zero.""" + cumulative = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + # Simulate: AI message without usage_metadata + usage = None + if usage: + cumulative["input_tokens"] += usage.get("input_tokens", 0) or 0 + assert cumulative == {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + + +# --------------------------------------------------------------------------- +# stream() integration — usage_metadata in end event and messages-tuple +# --------------------------------------------------------------------------- + + +def _make_agent_mock(chunks): + """Create a mock agent whose .stream() yields the given chunks.""" + agent = MagicMock() + agent.stream.return_value = iter(chunks) + return agent + + +def _mock_app_config(): + """Provide a minimal AppConfig mock.""" + model = MagicMock() + model.name = "test-model" + model.model = "test-model" + model.supports_thinking = False + model.supports_reasoning_effort = False + model.model_dump.return_value = {"name": "test-model", "use": "langchain_openai:ChatOpenAI"} + config = MagicMock() + config.models = [model] + return config + + +class TestStreamUsageIntegration: + """Test that stream() emits usage_metadata in messages-tuple and end events.""" + + def _make_client(self): + with patch("deerflow.client.get_app_config", return_value=_mock_app_config()): + return DeerFlowClient() + + def test_stream_emits_usage_in_messages_tuple(self): + """messages-tuple AI event should include usage_metadata when present.""" + client = self._make_client() + ai = AIMessage( + content="Hello!", + id="ai-1", + usage_metadata={"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}, + ) + chunks = [ + {"messages": [HumanMessage(content="hi", id="h-1"), ai]}, + ] + agent = _make_agent_mock(chunks) + + with ( + patch.object(client, "_ensure_agent"), + patch.object(client, "_agent", agent), + ): + events = list(client.stream("hi", thread_id="t1")) + + # Find the AI text messages-tuple event + ai_text_events = [ + e for e in events + if e.type == "messages-tuple" + and e.data.get("type") == "ai" + and e.data.get("content") == "Hello!" + ] + assert len(ai_text_events) == 1 + event_data = ai_text_events[0].data + assert "usage_metadata" in event_data + assert event_data["usage_metadata"] == { + "input_tokens": 100, + "output_tokens": 50, + "total_tokens": 150, + } + + def test_stream_cumulative_usage_in_end_event(self): + """end event should include cumulative usage across all AI messages.""" + client = self._make_client() + ai1 = AIMessage( + content="First", + id="ai-1", + usage_metadata={"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}, + ) + ai2 = AIMessage( + content="Second", + id="ai-2", + usage_metadata={"input_tokens": 200, "output_tokens": 30, "total_tokens": 230}, + ) + chunks = [ + {"messages": [HumanMessage(content="hi", id="h-1"), ai1]}, + {"messages": [HumanMessage(content="hi", id="h-1"), ai1, ai2]}, + ] + agent = _make_agent_mock(chunks) + + with ( + patch.object(client, "_ensure_agent"), + patch.object(client, "_agent", agent), + ): + events = list(client.stream("hi", thread_id="t1")) + + # Find the end event + end_events = [e for e in events if e.type == "end"] + assert len(end_events) == 1 + end_data = end_events[0].data + assert "usage" in end_data + assert end_data["usage"] == { + "input_tokens": 300, + "output_tokens": 80, + "total_tokens": 380, + } + + def test_stream_no_usage_metadata_no_usage_in_events(self): + """When AI messages have no usage_metadata, events should not include it.""" + client = self._make_client() + ai = AIMessage(content="Hello!", id="ai-1") + chunks = [ + {"messages": [HumanMessage(content="hi", id="h-1"), ai]}, + ] + agent = _make_agent_mock(chunks) + + with ( + patch.object(client, "_ensure_agent"), + patch.object(client, "_agent", agent), + ): + events = list(client.stream("hi", thread_id="t1")) + + # messages-tuple AI event should NOT have usage_metadata + ai_text_events = [ + e for e in events + if e.type == "messages-tuple" + and e.data.get("type") == "ai" + and e.data.get("content") == "Hello!" + ] + assert len(ai_text_events) == 1 + assert "usage_metadata" not in ai_text_events[0].data + + # end event should still exist but with zero usage + end_events = [e for e in events if e.type == "end"] + assert len(end_events) == 1 + usage = end_events[0].data.get("usage", {}) + assert usage.get("input_tokens", 0) == 0 + assert usage.get("output_tokens", 0) == 0 + assert usage.get("total_tokens", 0) == 0 + + def test_stream_usage_with_tool_calls(self): + """Usage should be tracked even when AI message has tool calls.""" + client = self._make_client() + ai_tool = AIMessage( + content="", + id="ai-1", + tool_calls=[{"name": "search", "args": {"q": "test"}, "id": "tc-1"}], + usage_metadata={"input_tokens": 150, "output_tokens": 25, "total_tokens": 175}, + ) + tool_result = ToolMessage(content="result", id="tm-1", tool_call_id="tc-1", name="search") + ai_final = AIMessage( + content="Here is the answer.", + id="ai-2", + usage_metadata={"input_tokens": 200, "output_tokens": 100, "total_tokens": 300}, + ) + chunks = [ + {"messages": [HumanMessage(content="search", id="h-1"), ai_tool]}, + {"messages": [HumanMessage(content="search", id="h-1"), ai_tool, tool_result]}, + {"messages": [HumanMessage(content="search", id="h-1"), ai_tool, tool_result, ai_final]}, + ] + agent = _make_agent_mock(chunks) + + with ( + patch.object(client, "_ensure_agent"), + patch.object(client, "_agent", agent), + ): + events = list(client.stream("search", thread_id="t1")) + + # Final AI text event should have usage_metadata + ai_text_events = [ + e for e in events + if e.type == "messages-tuple" + and e.data.get("type") == "ai" + and e.data.get("content") == "Here is the answer." + ] + assert len(ai_text_events) == 1 + assert ai_text_events[0].data["usage_metadata"]["total_tokens"] == 300 + + # end event should have cumulative usage + end_events = [e for e in events if e.type == "end"] + assert end_events[0].data["usage"]["input_tokens"] == 350 + assert end_events[0].data["usage"]["output_tokens"] == 125 + assert end_events[0].data["usage"]["total_tokens"] == 475