mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-04 06:32:13 +08:00
feat: add config modules
This commit is contained in:
3
backend/src/config/__init__.py
Normal file
3
backend/src/config/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .app_config import get_app_config
|
||||
|
||||
__all__ = ["get_app_config"]
|
||||
143
backend/src/config/app_config.py
Normal file
143
backend/src/config/app_config.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Self
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from src.config.model_config import ModelConfig
|
||||
from src.config.sandbox_config import SandboxConfig
|
||||
from src.config.tool_config import ToolConfig, ToolGroupConfig
|
||||
|
||||
|
||||
class AppConfig(BaseModel):
|
||||
"""Config for the DeerFlow application"""
|
||||
|
||||
models: list[ModelConfig] = Field(
|
||||
default_factory=list, description="Available models"
|
||||
)
|
||||
sandbox: SandboxConfig = Field(description="Sandbox configuration")
|
||||
tools: list[ToolConfig] = Field(default_factory=list, description="Available tools")
|
||||
tool_groups: list[ToolGroupConfig] = Field(
|
||||
default_factory=list, description="Available tool groups"
|
||||
)
|
||||
model_config = ConfigDict(extra="allow", frozen=False)
|
||||
|
||||
@classmethod
|
||||
def resolve_config_path(cls, config_path: str | None = None) -> Path:
|
||||
"""Resolve the config file path.
|
||||
|
||||
Priority:
|
||||
1. If provided `config_path` argument, use it.
|
||||
2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it.
|
||||
3. Otherwise, first check the `config.yaml` in the current directory, then use `config.yaml` in the parent directory.
|
||||
"""
|
||||
if config_path:
|
||||
path = Path(config_path)
|
||||
if not Path.exists(path):
|
||||
raise FileNotFoundError(
|
||||
f"Config file specified by param `config_path` not found at {path}"
|
||||
)
|
||||
return path
|
||||
elif os.getenv("DEER_FLOW_CONFIG_PATH"):
|
||||
path = Path(os.getenv("DEER_FLOW_CONFIG_PATH"))
|
||||
if not Path.exists(path):
|
||||
raise FileNotFoundError(
|
||||
f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}"
|
||||
)
|
||||
return path
|
||||
else:
|
||||
# Check if the config.yaml is in the parent directory of CWD
|
||||
path = Path(os.getcwd()).parent / "config.yaml"
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Config file not found at {path}")
|
||||
return path
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, config_path: str | None = None) -> Self:
|
||||
"""Load config from YAML file.
|
||||
|
||||
See `resolve_config_path` for more details.
|
||||
|
||||
Args:
|
||||
config_path: Path to the config file.
|
||||
|
||||
Returns:
|
||||
AppConfig: The loaded config.
|
||||
"""
|
||||
resolved_path = cls.resolve_config_path(config_path)
|
||||
with open(resolved_path, "r") as f:
|
||||
config_data = yaml.safe_load(f)
|
||||
cls.resolve_env_variables(config_data)
|
||||
result = cls.model_validate(config_data)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def resolve_env_variables(cls, config: dict) -> dict:
|
||||
"""Recursively resolve environment variables in the config.
|
||||
|
||||
Environment variables are resolved using the `os.getenv` function. Example: $OPENAI_API_KEY
|
||||
|
||||
Args:
|
||||
config: The config to resolve environment variables in.
|
||||
|
||||
Returns:
|
||||
The config with environment variables resolved.
|
||||
"""
|
||||
for key, value in config.items():
|
||||
if isinstance(value, str):
|
||||
if value.startswith("$"):
|
||||
env_value = os.getenv(value[1:], None)
|
||||
if env_value is not None:
|
||||
config[key] = env_value
|
||||
else:
|
||||
config[key] = value
|
||||
elif isinstance(value, dict):
|
||||
config[key] = cls.resolve_env_variables(value)
|
||||
elif isinstance(value, list):
|
||||
config[key] = [cls.resolve_env_variables(item) for item in value]
|
||||
return config
|
||||
|
||||
def get_model_config(self, name: str) -> ModelConfig | None:
|
||||
"""Get the model config by name.
|
||||
|
||||
Args:
|
||||
name: The name of the model to get the config for.
|
||||
|
||||
Returns:
|
||||
The model config if found, otherwise None.
|
||||
"""
|
||||
return next((model for model in self.models if model.name == name), None)
|
||||
|
||||
def get_tool_config(self, name: str) -> ToolConfig | None:
|
||||
"""Get the tool config by name.
|
||||
|
||||
Args:
|
||||
name: The name of the tool to get the config for.
|
||||
|
||||
Returns:
|
||||
The tool config if found, otherwise None.
|
||||
"""
|
||||
return next((tool for tool in self.tools if tool.name == name), None)
|
||||
|
||||
def get_tool_group_config(self, name: str) -> ToolGroupConfig | None:
|
||||
"""Get the tool group config by name.
|
||||
|
||||
Args:
|
||||
name: The name of the tool group to get the config for.
|
||||
|
||||
Returns:
|
||||
The tool group config if found, otherwise None.
|
||||
"""
|
||||
return next((group for group in self.tool_groups if group.name == name), None)
|
||||
|
||||
|
||||
_app_config: AppConfig | None = None
|
||||
|
||||
|
||||
def get_app_config() -> AppConfig:
|
||||
"""Get the DeerFlow config instance."""
|
||||
global _app_config
|
||||
if _app_config is None:
|
||||
_app_config = AppConfig.from_file()
|
||||
return _app_config
|
||||
26
backend/src/config/model_config.py
Normal file
26
backend/src/config/model_config.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
"""Config section for a model"""
|
||||
|
||||
name: str = Field(..., description="Unique name for the model")
|
||||
display_name: str | None = Field(
|
||||
..., default_factory=lambda: None, description="Display name for the model"
|
||||
)
|
||||
description: str | None = Field(
|
||||
..., default_factory=lambda: None, description="Description for the model"
|
||||
)
|
||||
use: str = Field(
|
||||
...,
|
||||
description="Class path of the model provider(e.g. langchain_openai.ChatOpenAI)",
|
||||
)
|
||||
model: str = Field(..., description="Model name")
|
||||
model_config = ConfigDict(extra="allow")
|
||||
supports_thinking: bool = Field(
|
||||
default_factory=lambda: False, description="Whether the model supports thinking"
|
||||
)
|
||||
when_thinking_enabled: dict | None = Field(
|
||||
default_factory=lambda: None,
|
||||
description="Extra settings to be passed to the model when thinking is enabled",
|
||||
)
|
||||
10
backend/src/config/sandbox_config.py
Normal file
10
backend/src/config/sandbox_config.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SandboxConfig(BaseModel):
|
||||
"""Config section for a sandbox"""
|
||||
|
||||
use: str = Field(
|
||||
...,
|
||||
description="Class path of the sandbox provider(e.g. src.sandbox.local:LocalSandbox)",
|
||||
)
|
||||
20
backend/src/config/tool_config.py
Normal file
20
backend/src/config/tool_config.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ToolGroupConfig(BaseModel):
|
||||
"""Config section for a tool group"""
|
||||
|
||||
name: str = Field(..., description="Unique name for the tool group")
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
"""Config section for a tool"""
|
||||
|
||||
name: str = Field(..., description="Unique name for the tool")
|
||||
group: str = Field(..., description="Group name for the tool")
|
||||
use: str = Field(
|
||||
...,
|
||||
description="Variable name of the tool provider(e.g. src.sandbox.tools:bash_tool)",
|
||||
)
|
||||
model_config = ConfigDict(extra="allow")
|
||||
Reference in New Issue
Block a user