mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-02 22:02:13 +08:00
Fix Windows backend test compatibility (#1384)
* Fix Windows backend test compatibility * Preserve ACP path style on Windows * Fix installer import ordering * Address review comments for Windows fixes --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -11,7 +11,6 @@ The provider itself handles:
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import fcntl
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
@@ -20,6 +19,12 @@ import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
try:
|
||||
import fcntl
|
||||
except ImportError: # pragma: no cover - Windows fallback
|
||||
fcntl = None # type: ignore[assignment]
|
||||
import msvcrt
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, Paths, get_paths
|
||||
from deerflow.sandbox.sandbox import Sandbox
|
||||
@@ -42,6 +47,24 @@ DEFAULT_REPLICAS = 3 # Maximum concurrent sandbox containers
|
||||
IDLE_CHECK_INTERVAL = 60 # Check every 60 seconds
|
||||
|
||||
|
||||
def _lock_file_exclusive(lock_file) -> None:
|
||||
if fcntl is not None:
|
||||
fcntl.flock(lock_file, fcntl.LOCK_EX)
|
||||
return
|
||||
|
||||
lock_file.seek(0)
|
||||
msvcrt.locking(lock_file.fileno(), msvcrt.LK_LOCK, 1)
|
||||
|
||||
|
||||
def _unlock_file(lock_file) -> None:
|
||||
if fcntl is not None:
|
||||
fcntl.flock(lock_file, fcntl.LOCK_UN)
|
||||
return
|
||||
|
||||
lock_file.seek(0)
|
||||
msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1)
|
||||
|
||||
|
||||
class AioSandboxProvider(SandboxProvider):
|
||||
"""Sandbox provider that manages containers running the AIO sandbox.
|
||||
|
||||
@@ -405,8 +428,10 @@ class AioSandboxProvider(SandboxProvider):
|
||||
lock_path = paths.thread_dir(thread_id) / f"{sandbox_id}.lock"
|
||||
|
||||
with open(lock_path, "a", encoding="utf-8") as lock_file:
|
||||
locked = False
|
||||
try:
|
||||
fcntl.flock(lock_file, fcntl.LOCK_EX)
|
||||
_lock_file_exclusive(lock_file)
|
||||
locked = True
|
||||
# Re-check in-process caches under the file lock in case another
|
||||
# thread in this process won the race while we were waiting.
|
||||
with self._lock:
|
||||
@@ -440,7 +465,8 @@ class AioSandboxProvider(SandboxProvider):
|
||||
|
||||
return self._create_sandbox(thread_id, sandbox_id)
|
||||
finally:
|
||||
fcntl.flock(lock_file, fcntl.LOCK_UN)
|
||||
if locked:
|
||||
_unlock_file(lock_file)
|
||||
|
||||
def _evict_oldest_warm(self) -> str | None:
|
||||
"""Destroy the oldest container in the warm pool to free capacity.
|
||||
|
||||
@@ -60,7 +60,14 @@ def _resolve_credential_path(env_var: str, default_relative_path: str) -> Path:
|
||||
configured_path = os.getenv(env_var)
|
||||
if configured_path:
|
||||
return Path(configured_path).expanduser()
|
||||
return Path.home() / default_relative_path
|
||||
return _home_dir() / default_relative_path
|
||||
|
||||
|
||||
def _home_dir() -> Path:
|
||||
home = os.getenv("HOME")
|
||||
if home:
|
||||
return Path(home).expanduser()
|
||||
return Path.home()
|
||||
|
||||
|
||||
def _load_json_file(path: Path, label: str) -> dict[str, Any] | None:
|
||||
@@ -90,7 +97,7 @@ def _read_secret_from_file_descriptor(env_var: str) -> str | None:
|
||||
return None
|
||||
|
||||
try:
|
||||
secret = Path(f"/dev/fd/{fd}").read_text().strip()
|
||||
secret = os.read(fd, 1024 * 1024).decode().strip()
|
||||
except OSError as e:
|
||||
logger.warning(f"Failed to read {env_var}: {e}")
|
||||
return None
|
||||
@@ -111,7 +118,7 @@ def _iter_claude_code_credential_paths() -> list[Path]:
|
||||
if override_path:
|
||||
paths.append(Path(override_path).expanduser())
|
||||
|
||||
default_path = Path.home() / ".claude/.credentials.json"
|
||||
default_path = _home_dir() / ".claude/.credentials.json"
|
||||
if not paths or paths[-1] != default_path:
|
||||
paths.append(default_path)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import posixpath
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
@@ -99,8 +100,8 @@ def _resolve_skills_path(path: str) -> str:
|
||||
if path == skills_container:
|
||||
return skills_host
|
||||
|
||||
relative = path[len(skills_container) :].lstrip("/")
|
||||
return str(Path(skills_host) / relative) if relative else skills_host
|
||||
relative = path[len(skills_container):].lstrip("/")
|
||||
return _join_path_preserving_style(skills_host, relative)
|
||||
|
||||
|
||||
def _is_acp_workspace_path(path: str) -> bool:
|
||||
@@ -190,23 +191,39 @@ def _resolve_acp_workspace_path(path: str, thread_id: str | None = None) -> str:
|
||||
return host_path
|
||||
|
||||
relative = path[len(_ACP_WORKSPACE_VIRTUAL_PATH) :].lstrip("/")
|
||||
if not relative:
|
||||
return host_path
|
||||
resolved = _join_path_preserving_style(host_path, relative)
|
||||
|
||||
resolved = Path(host_path).resolve() / relative
|
||||
# Ensure resolved path stays inside the ACP workspace
|
||||
if "/" in host_path and "\\" not in host_path:
|
||||
base_path = posixpath.normpath(host_path)
|
||||
candidate_path = posixpath.normpath(resolved)
|
||||
try:
|
||||
if posixpath.commonpath([base_path, candidate_path]) != base_path:
|
||||
raise PermissionError("Access denied: path traversal detected")
|
||||
except ValueError:
|
||||
raise PermissionError("Access denied: path traversal detected") from None
|
||||
return resolved
|
||||
|
||||
resolved_path = Path(resolved).resolve()
|
||||
try:
|
||||
resolved.resolve().relative_to(Path(host_path).resolve())
|
||||
resolved_path.relative_to(Path(host_path).resolve())
|
||||
except ValueError:
|
||||
raise PermissionError("Access denied: path traversal detected")
|
||||
|
||||
return str(resolved)
|
||||
return str(resolved_path)
|
||||
|
||||
|
||||
def _path_variants(path: str) -> set[str]:
|
||||
return {path, path.replace("\\", "/"), path.replace("/", "\\")}
|
||||
|
||||
|
||||
def _join_path_preserving_style(base: str, relative: str) -> str:
|
||||
if not relative:
|
||||
return base
|
||||
if "/" in base and "\\" not in base:
|
||||
return f"{base.rstrip('/')}/{relative}"
|
||||
return str(Path(base) / relative)
|
||||
|
||||
|
||||
def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadState] | None" = None) -> str:
|
||||
"""Sanitize an error message to avoid leaking host filesystem paths.
|
||||
|
||||
@@ -249,7 +266,7 @@ def replace_virtual_path(path: str, thread_data: ThreadDataState | None) -> str:
|
||||
return actual_base
|
||||
if path.startswith(f"{virtual_base}/"):
|
||||
rest = path[len(virtual_base) :].lstrip("/")
|
||||
return str(Path(actual_base) / rest) if rest else actual_base
|
||||
return _join_path_preserving_style(actual_base, rest)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
@@ -5,11 +5,12 @@ Both Gateway and Client delegate to these functions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import posixpath
|
||||
import shutil
|
||||
import stat
|
||||
import tempfile
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from pathlib import Path, PurePosixPath, PureWindowsPath
|
||||
|
||||
from deerflow.skills.loader import get_skills_root_path
|
||||
from deerflow.skills.validation import _validate_skill_frontmatter
|
||||
@@ -26,9 +27,14 @@ def is_unsafe_zip_member(info: zipfile.ZipInfo) -> bool:
|
||||
name = info.filename
|
||||
if not name:
|
||||
return False
|
||||
path = Path(name)
|
||||
normalized = name.replace("\\", "/")
|
||||
if normalized.startswith("/"):
|
||||
return True
|
||||
path = PurePosixPath(normalized)
|
||||
if path.is_absolute():
|
||||
return True
|
||||
if PureWindowsPath(name).is_absolute():
|
||||
return True
|
||||
if ".." in path.parts:
|
||||
return True
|
||||
return False
|
||||
@@ -90,7 +96,8 @@ def safe_extract_skill_archive(
|
||||
logger.warning("Skipping symlink entry in skill archive: %s", info.filename)
|
||||
continue
|
||||
|
||||
member_path = dest_root / info.filename
|
||||
normalized_name = posixpath.normpath(info.filename.replace("\\", "/"))
|
||||
member_path = dest_root.joinpath(*PurePosixPath(normalized_name).parts)
|
||||
if not member_path.resolve().is_relative_to(dest_root):
|
||||
raise ValueError(f"Zip entry escapes destination: {info.filename!r}")
|
||||
member_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
import importlib
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
# ── ensure_thread_dirs ───────────────────────────────────────────────────────
|
||||
@@ -71,3 +73,33 @@ def test_get_thread_mounts_includes_user_data_dirs(tmp_path, monkeypatch):
|
||||
assert "/mnt/user-data/workspace" in container_paths
|
||||
assert "/mnt/user-data/uploads" in container_paths
|
||||
assert "/mnt/user-data/outputs" in container_paths
|
||||
|
||||
|
||||
def test_discover_or_create_only_unlocks_when_lock_succeeds(tmp_path, monkeypatch):
|
||||
"""Unlock should not run if exclusive locking itself fails."""
|
||||
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||
provider = _make_provider(tmp_path)
|
||||
provider._discover_or_create_with_lock = aio_mod.AioSandboxProvider._discover_or_create_with_lock.__get__(
|
||||
provider,
|
||||
aio_mod.AioSandboxProvider,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))
|
||||
monkeypatch.setattr(
|
||||
aio_mod,
|
||||
"_lock_file_exclusive",
|
||||
lambda _lock_file: (_ for _ in ()).throw(RuntimeError("lock failed")),
|
||||
)
|
||||
|
||||
unlock_calls: list[object] = []
|
||||
monkeypatch.setattr(
|
||||
aio_mod,
|
||||
"_unlock_file",
|
||||
lambda lock_file: unlock_calls.append(lock_file),
|
||||
)
|
||||
|
||||
with patch.object(provider, "_create_sandbox", return_value="sandbox-id"):
|
||||
with pytest.raises(RuntimeError, match="lock failed"):
|
||||
provider._discover_or_create_with_lock("thread-5", "sandbox-5")
|
||||
|
||||
assert unlock_calls == []
|
||||
|
||||
@@ -2146,7 +2146,12 @@ class TestUploadDeleteSymlink:
|
||||
|
||||
# Create a symlink inside uploads dir pointing to outside file.
|
||||
link = uploads_dir / "harmless.txt"
|
||||
link.symlink_to(outside)
|
||||
try:
|
||||
link.symlink_to(outside)
|
||||
except OSError as exc:
|
||||
if getattr(exc, "winerror", None) == 1314:
|
||||
pytest.skip("symlink creation requires Developer Mode or elevated privileges on Windows")
|
||||
raise
|
||||
|
||||
with patch("deerflow.client.get_uploads_dir", return_value=uploads_dir), patch("deerflow.client.ensure_uploads_dir", return_value=uploads_dir):
|
||||
# The resolved path of the symlink escapes uploads_dir,
|
||||
|
||||
@@ -5,9 +5,16 @@ from __future__ import annotations
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from shutil import which
|
||||
|
||||
import pytest
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
SCRIPT_PATH = REPO_ROOT / "scripts" / "docker.sh"
|
||||
BASH_EXECUTABLE = which("bash") or r"C:\Program Files\Git\bin\bash.exe"
|
||||
|
||||
if not Path(BASH_EXECUTABLE).exists():
|
||||
pytestmark = pytest.mark.skip(reason="bash is required for docker.sh detection tests")
|
||||
|
||||
|
||||
def _detect_mode_with_config(config_content: str) -> str:
|
||||
@@ -19,7 +26,7 @@ def _detect_mode_with_config(config_content: str) -> str:
|
||||
command = f"source '{SCRIPT_PATH}' && PROJECT_ROOT='{tmp_root}' && detect_sandbox_mode"
|
||||
|
||||
output = subprocess.check_output(
|
||||
["bash", "-lc", command],
|
||||
[BASH_EXECUTABLE, "-lc", command],
|
||||
text=True,
|
||||
).strip()
|
||||
|
||||
@@ -30,7 +37,7 @@ def test_detect_mode_defaults_to_local_when_config_missing():
|
||||
"""No config file should default to local mode."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
command = f"source '{SCRIPT_PATH}' && PROJECT_ROOT='{tmpdir}' && detect_sandbox_mode"
|
||||
output = subprocess.check_output(["bash", "-lc", command], text=True).strip()
|
||||
output = subprocess.check_output([BASH_EXECUTABLE, "-lc", command], text=True).strip()
|
||||
|
||||
assert output == "local"
|
||||
|
||||
|
||||
@@ -25,6 +25,10 @@ class TestIsUnsafeZipMember:
|
||||
info = zipfile.ZipInfo("/etc/passwd")
|
||||
assert is_unsafe_zip_member(info) is True
|
||||
|
||||
def test_windows_absolute_path(self):
|
||||
info = zipfile.ZipInfo("C:\\Windows\\system32\\drivers\\etc\\hosts")
|
||||
assert is_unsafe_zip_member(info) is True
|
||||
|
||||
def test_dotdot_traversal(self):
|
||||
info = zipfile.ZipInfo("foo/../../../etc/passwd")
|
||||
assert is_unsafe_zip_member(info) is True
|
||||
|
||||
@@ -4,6 +4,10 @@ from langgraph.runtime import Runtime
|
||||
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
|
||||
|
||||
def _as_posix(path: str) -> str:
|
||||
return path.replace("\\", "/")
|
||||
|
||||
|
||||
class TestThreadDataMiddleware:
|
||||
def test_before_agent_returns_paths_when_thread_id_present_in_context(self, tmp_path):
|
||||
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
|
||||
@@ -11,9 +15,9 @@ class TestThreadDataMiddleware:
|
||||
result = middleware.before_agent(state={}, runtime=Runtime(context={"thread_id": "thread-123"}))
|
||||
|
||||
assert result is not None
|
||||
assert result["thread_data"]["workspace_path"].endswith("threads/thread-123/user-data/workspace")
|
||||
assert result["thread_data"]["uploads_path"].endswith("threads/thread-123/user-data/uploads")
|
||||
assert result["thread_data"]["outputs_path"].endswith("threads/thread-123/user-data/outputs")
|
||||
assert _as_posix(result["thread_data"]["workspace_path"]).endswith("threads/thread-123/user-data/workspace")
|
||||
assert _as_posix(result["thread_data"]["uploads_path"]).endswith("threads/thread-123/user-data/uploads")
|
||||
assert _as_posix(result["thread_data"]["outputs_path"]).endswith("threads/thread-123/user-data/outputs")
|
||||
|
||||
def test_before_agent_uses_thread_id_from_configurable_when_context_is_none(self, tmp_path, monkeypatch):
|
||||
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
|
||||
@@ -26,7 +30,7 @@ class TestThreadDataMiddleware:
|
||||
result = middleware.before_agent(state={}, runtime=runtime)
|
||||
|
||||
assert result is not None
|
||||
assert result["thread_data"]["workspace_path"].endswith("threads/thread-from-config/user-data/workspace")
|
||||
assert _as_posix(result["thread_data"]["workspace_path"]).endswith("threads/thread-from-config/user-data/workspace")
|
||||
assert runtime.context is None
|
||||
|
||||
def test_before_agent_uses_thread_id_from_configurable_when_context_missing_thread_id(self, tmp_path, monkeypatch):
|
||||
@@ -40,7 +44,7 @@ class TestThreadDataMiddleware:
|
||||
result = middleware.before_agent(state={}, runtime=runtime)
|
||||
|
||||
assert result is not None
|
||||
assert result["thread_data"]["uploads_path"].endswith("threads/thread-from-config/user-data/uploads")
|
||||
assert _as_posix(result["thread_data"]["uploads_path"]).endswith("threads/thread-from-config/user-data/uploads")
|
||||
assert runtime.context == {}
|
||||
|
||||
def test_before_agent_raises_clear_error_when_thread_id_missing_everywhere(self, tmp_path, monkeypatch):
|
||||
|
||||
@@ -87,7 +87,12 @@ class TestValidatePathTraversal:
|
||||
target = tmp_path.parent / "secret.txt"
|
||||
target.touch()
|
||||
link = tmp_path / "escape"
|
||||
link.symlink_to(target)
|
||||
try:
|
||||
link.symlink_to(target)
|
||||
except OSError as exc:
|
||||
if getattr(exc, "winerror", None) == 1314:
|
||||
pytest.skip("symlink creation requires Developer Mode or elevated privileges on Windows")
|
||||
raise
|
||||
with pytest.raises(PathTraversalError, match="traversal"):
|
||||
validate_path_traversal(link, tmp_path)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user