diff --git a/backend/src/config/__init__.py b/backend/src/config/__init__.py new file mode 100644 index 0000000..f9d546a --- /dev/null +++ b/backend/src/config/__init__.py @@ -0,0 +1,3 @@ +from .app_config import get_app_config + +__all__ = ["get_app_config"] diff --git a/backend/src/config/app_config.py b/backend/src/config/app_config.py new file mode 100644 index 0000000..2168652 --- /dev/null +++ b/backend/src/config/app_config.py @@ -0,0 +1,143 @@ +import os +from pathlib import Path +from typing import Self + +import yaml +from pydantic import BaseModel, ConfigDict, Field + +from src.config.model_config import ModelConfig +from src.config.sandbox_config import SandboxConfig +from src.config.tool_config import ToolConfig, ToolGroupConfig + + +class AppConfig(BaseModel): + """Config for the DeerFlow application""" + + models: list[ModelConfig] = Field( + default_factory=list, description="Available models" + ) + sandbox: SandboxConfig = Field(description="Sandbox configuration") + tools: list[ToolConfig] = Field(default_factory=list, description="Available tools") + tool_groups: list[ToolGroupConfig] = Field( + default_factory=list, description="Available tool groups" + ) + model_config = ConfigDict(extra="allow", frozen=False) + + @classmethod + def resolve_config_path(cls, config_path: str | None = None) -> Path: + """Resolve the config file path. + + Priority: + 1. If provided `config_path` argument, use it. + 2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it. + 3. Otherwise, first check the `config.yaml` in the current directory, then use `config.yaml` in the parent directory. + """ + if config_path: + path = Path(config_path) + if not Path.exists(path): + raise FileNotFoundError( + f"Config file specified by param `config_path` not found at {path}" + ) + return path + elif os.getenv("DEER_FLOW_CONFIG_PATH"): + path = Path(os.getenv("DEER_FLOW_CONFIG_PATH")) + if not Path.exists(path): + raise FileNotFoundError( + f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}" + ) + return path + else: + # Check if the config.yaml is in the parent directory of CWD + path = Path(os.getcwd()).parent / "config.yaml" + if not path.exists(): + raise FileNotFoundError(f"Config file not found at {path}") + return path + + @classmethod + def from_file(cls, config_path: str | None = None) -> Self: + """Load config from YAML file. + + See `resolve_config_path` for more details. + + Args: + config_path: Path to the config file. + + Returns: + AppConfig: The loaded config. + """ + resolved_path = cls.resolve_config_path(config_path) + with open(resolved_path, "r") as f: + config_data = yaml.safe_load(f) + cls.resolve_env_variables(config_data) + result = cls.model_validate(config_data) + return result + + @classmethod + def resolve_env_variables(cls, config: dict) -> dict: + """Recursively resolve environment variables in the config. + + Environment variables are resolved using the `os.getenv` function. Example: $OPENAI_API_KEY + + Args: + config: The config to resolve environment variables in. + + Returns: + The config with environment variables resolved. + """ + for key, value in config.items(): + if isinstance(value, str): + if value.startswith("$"): + env_value = os.getenv(value[1:], None) + if env_value is not None: + config[key] = env_value + else: + config[key] = value + elif isinstance(value, dict): + config[key] = cls.resolve_env_variables(value) + elif isinstance(value, list): + config[key] = [cls.resolve_env_variables(item) for item in value] + return config + + def get_model_config(self, name: str) -> ModelConfig | None: + """Get the model config by name. + + Args: + name: The name of the model to get the config for. + + Returns: + The model config if found, otherwise None. + """ + return next((model for model in self.models if model.name == name), None) + + def get_tool_config(self, name: str) -> ToolConfig | None: + """Get the tool config by name. + + Args: + name: The name of the tool to get the config for. + + Returns: + The tool config if found, otherwise None. + """ + return next((tool for tool in self.tools if tool.name == name), None) + + def get_tool_group_config(self, name: str) -> ToolGroupConfig | None: + """Get the tool group config by name. + + Args: + name: The name of the tool group to get the config for. + + Returns: + The tool group config if found, otherwise None. + """ + return next((group for group in self.tool_groups if group.name == name), None) + + +_app_config: AppConfig | None = None + + +def get_app_config() -> AppConfig: + """Get the DeerFlow config instance.""" + global _app_config + if _app_config is None: + _app_config = AppConfig.from_file() + return _app_config diff --git a/backend/src/config/model_config.py b/backend/src/config/model_config.py new file mode 100644 index 0000000..ab79b47 --- /dev/null +++ b/backend/src/config/model_config.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, ConfigDict, Field + + +class ModelConfig(BaseModel): + """Config section for a model""" + + name: str = Field(..., description="Unique name for the model") + display_name: str | None = Field( + ..., default_factory=lambda: None, description="Display name for the model" + ) + description: str | None = Field( + ..., default_factory=lambda: None, description="Description for the model" + ) + use: str = Field( + ..., + description="Class path of the model provider(e.g. langchain_openai.ChatOpenAI)", + ) + model: str = Field(..., description="Model name") + model_config = ConfigDict(extra="allow") + supports_thinking: bool = Field( + default_factory=lambda: False, description="Whether the model supports thinking" + ) + when_thinking_enabled: dict | None = Field( + default_factory=lambda: None, + description="Extra settings to be passed to the model when thinking is enabled", + ) diff --git a/backend/src/config/sandbox_config.py b/backend/src/config/sandbox_config.py new file mode 100644 index 0000000..48a25a3 --- /dev/null +++ b/backend/src/config/sandbox_config.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel, Field + + +class SandboxConfig(BaseModel): + """Config section for a sandbox""" + + use: str = Field( + ..., + description="Class path of the sandbox provider(e.g. src.sandbox.local:LocalSandbox)", + ) diff --git a/backend/src/config/tool_config.py b/backend/src/config/tool_config.py new file mode 100644 index 0000000..e267f0d --- /dev/null +++ b/backend/src/config/tool_config.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel, ConfigDict, Field + + +class ToolGroupConfig(BaseModel): + """Config section for a tool group""" + + name: str = Field(..., description="Unique name for the tool group") + model_config = ConfigDict(extra="allow") + + +class ToolConfig(BaseModel): + """Config section for a tool""" + + name: str = Field(..., description="Unique name for the tool") + group: str = Field(..., description="Group name for the tool") + use: str = Field( + ..., + description="Variable name of the tool provider(e.g. src.sandbox.tools:bash_tool)", + ) + model_config = ConfigDict(extra="allow")