feat: add langgraph.json for langgraph studio debug

This commit is contained in:
He Tao
2025-04-22 15:33:53 +08:00
parent e99bb9bdba
commit abdc740531
8 changed files with 102 additions and 16 deletions

View File

@@ -1,8 +1,9 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from .builder import build_graph
from .builder import build_graph_with_memory, build_graph
__all__ = [
"build_graph_with_memory",
"build_graph",
]

View File

@@ -3,6 +3,7 @@
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from .types import State
from .nodes import (
coordinator_node,
@@ -15,13 +16,8 @@ from .nodes import (
)
def build_graph():
"""Build and return the agent workflow graph."""
# use persistent memory to save conversation history
# TODO: be compatible with SQLite / PostgreSQL
memory = MemorySaver()
# build state graph
def _build_base_graph():
"""Build and return the base state graph with all nodes and edges."""
builder = StateGraph(State)
builder.add_edge(START, "coordinator")
builder.add_node("coordinator", coordinator_node)
@@ -32,4 +28,25 @@ def build_graph():
builder.add_node("coder", coder_node)
builder.add_node("human_feedback", human_feedback_node)
builder.add_edge("reporter", END)
return builder
def build_graph_with_memory():
"""Build and return the agent workflow graph with memory."""
# use persistent memory to save conversation history
# TODO: be compatible with SQLite / PostgreSQL
memory = MemorySaver()
# build state graph
builder = _build_base_graph()
return builder.compile(checkpointer=memory)
def build_graph():
"""Build and return the agent workflow graph without memory."""
# build state graph
builder = _build_base_graph()
return builder.compile()
graph = build_graph()

View File

@@ -3,10 +3,10 @@
from langgraph.graph import END, START, StateGraph
from .audio_mixer_node import audio_mixer_node
from .script_writer_node import script_writer_node
from .state import PodcastState
from .tts_node import tts_node
from src.podcast.graph.audio_mixer_node import audio_mixer_node
from src.podcast.graph.script_writer_node import script_writer_node
from src.podcast.graph.state import PodcastState
from src.podcast.graph.tts_node import tts_node
def build_graph():
@@ -23,13 +23,14 @@ def build_graph():
return builder.compile()
workflow = build_graph()
if __name__ == "__main__":
from dotenv import load_dotenv
load_dotenv()
report_content = open("examples/nanjing_tangbao.md").read()
workflow = build_graph()
final_state = workflow.invoke({"input": report_content})
for line in final_state["script"].lines:
print("<M>" if line.speaker == "male" else "<F>", line.text)

View File

@@ -20,11 +20,12 @@ def build_graph():
return builder.compile()
workflow = build_graph()
if __name__ == "__main__":
from dotenv import load_dotenv
load_dotenv()
report_content = open("examples/nanjing_tangbao.md").read()
workflow = build_graph()
final_state = workflow.invoke({"input": report_content})

View File

@@ -14,7 +14,7 @@ from fastapi.responses import Response, StreamingResponse
from langchain_core.messages import AIMessageChunk, ToolMessage
from langgraph.types import Command
from src.graph.builder import build_graph
from src.graph.builder import build_graph_with_memory
from src.podcast.graph.builder import build_graph as build_podcast_graph
from src.ppt.graph.builder import build_graph as build_ppt_graph
from src.server.chat_request import (
@@ -43,7 +43,7 @@ app.add_middleware(
allow_headers=["*"], # Allows all headers
)
graph = build_graph()
graph = build_graph_with_memory()
@app.post("/api/chat/stream")