mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-26 07:14:47 +08:00
feat: support function factory (#4)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from .lead_agent import lead_agent
|
||||
from .lead_agent import make_lead_agent
|
||||
from .thread_state import SandboxState, ThreadState
|
||||
|
||||
__all__ = ["lead_agent", "SandboxState", "ThreadState"]
|
||||
__all__ = ["make_lead_agent", "SandboxState", "ThreadState"]
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .agent import lead_agent
|
||||
from .agent import make_lead_agent
|
||||
|
||||
__all__ = ["lead_agent"]
|
||||
__all__ = ["make_lead_agent"]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from src.agents.lead_agent.prompt import apply_prompt_template
|
||||
from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
@@ -11,10 +12,15 @@ from src.tools import get_available_tools
|
||||
# ThreadDataMiddleware must be before SandboxMiddleware to ensure thread_id is available
|
||||
middlewares = [ThreadDataMiddleware(), SandboxMiddleware(), TitleMiddleware()]
|
||||
|
||||
lead_agent = create_agent(
|
||||
model=create_chat_model(thinking_enabled=True),
|
||||
tools=get_available_tools(),
|
||||
middleware=middlewares,
|
||||
system_prompt=apply_prompt_template(),
|
||||
state_schema=ThreadState,
|
||||
)
|
||||
|
||||
def make_lead_agent(config: RunnableConfig):
|
||||
thinking_enabled = config.get("configurable", {}).get("thinking_enabled", True)
|
||||
model_name = config.get("configurable", {}).get("model_name")
|
||||
print(f"thinking_enabled: {thinking_enabled}, model_name: {model_name}")
|
||||
return create_agent(
|
||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
|
||||
tools=get_available_tools(),
|
||||
middleware=middlewares,
|
||||
system_prompt=apply_prompt_template(),
|
||||
state_schema=ThreadState,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
SYSTEM_PROMPT = f"""
|
||||
|
||||
@@ -71,8 +71,9 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
||||
@override
|
||||
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
# Generate new thread ID and create directories
|
||||
print(runtime.context)
|
||||
thread_id = runtime.context["thread_id"]
|
||||
thread_id = runtime.context.get("thread_id")
|
||||
if thread_id is None:
|
||||
raise ValueError("Thread ID is required in the context")
|
||||
paths = self._create_thread_directories(thread_id)
|
||||
print(f"Created thread data directories for thread {thread_id}")
|
||||
|
||||
|
||||
@@ -168,14 +168,16 @@ def read_file_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
description: str,
|
||||
path: str,
|
||||
view_range: tuple[int, int] | None = None,
|
||||
start_line: int | None = None,
|
||||
end_line: int | None = None,
|
||||
) -> str:
|
||||
"""Read the contents of a text file.
|
||||
"""Read the contents of a text file. Use this to examine source code, configuration files, logs, or any text-based file.
|
||||
|
||||
Args:
|
||||
description: Explain why you are viewing this file in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
|
||||
description: Explain why you are reading this file in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
|
||||
path: The **absolute** path to the file to read.
|
||||
view_range: The range of lines to view. The range is inclusive and starts at 1. For example, (1, 10) will view the first 10 lines of the file.
|
||||
start_line: Optional starting line number (1-indexed, inclusive). Use with end_line to read a specific range.
|
||||
end_line: Optional ending line number (1-indexed, inclusive). Use with start_line to read a specific range.
|
||||
"""
|
||||
try:
|
||||
sandbox = sandbox_from_runtime(runtime)
|
||||
@@ -185,9 +187,8 @@ def read_file_tool(
|
||||
content = sandbox.read_file(path)
|
||||
if not content:
|
||||
return "(empty)"
|
||||
if view_range:
|
||||
start, end = view_range
|
||||
content = "\n".join(content.splitlines()[start - 1 : end])
|
||||
if start_line is not None and end_line is not None:
|
||||
content = "\n".join(content.splitlines()[start_line - 1 : end_line])
|
||||
return content
|
||||
except Exception as e:
|
||||
return f"Error: {e}"
|
||||
|
||||
Reference in New Issue
Block a user