Compare commits
39 Commits
v2.0-m1-rc0
...
fix-2788
| Author | SHA1 | Date | |
|---|---|---|---|
| dad3997459 | |||
| b67c2a4e56 | |||
| 94da8f67d7 | |||
| 5127f08e1a | |||
| dfa4eb0c1a | |||
| 08ee7adeba | |||
| 1c96a6afc8 | |||
| 417416087b | |||
| 881ff71252 | |||
| f76e4e35c8 | |||
| 0d1053ca44 | |||
| 4063dd7157 | |||
| 7a3c58a733 | |||
| 1edc9d9fae | |||
| 7caf03e97c | |||
| 41b04a556f | |||
| c1b7f1d189 | |||
| 109490da25 | |||
| 14c0a32ee6 | |||
| 70737af7cd | |||
| 2b1fcb3e43 | |||
| 7de9b5828b | |||
| 37db689349 | |||
| bd45cb2846 | |||
| 5fd0e6ac89 | |||
| 530bda7107 | |||
| 6c220a9aef | |||
| daa3ffc29b | |||
| 27559f3675 | |||
| cef4224381 | |||
| 2b0e62f679 | |||
| 1336872b15 | |||
| 4ead2c6b19 | |||
| 59c4a3f0a4 | |||
| e8675f266d | |||
| 680187ddc2 | |||
| aded753de3 | |||
| 028493bfd8 | |||
| 8e48b7e85c |
@@ -48,3 +48,14 @@ INFOQUEST_API_KEY=your-infoquest-api-key
|
||||
|
||||
# Set to "false" to disable Swagger UI, ReDoc, and OpenAPI schema in production
|
||||
# GATEWAY_ENABLE_DOCS=false
|
||||
|
||||
# ── Frontend SSR → Gateway wiring ─────────────────────────────────────────────
|
||||
# The Next.js server uses these to reach the Gateway during SSR (auth checks,
|
||||
# /api/* rewrites). They default to localhost values that match `make dev` and
|
||||
# `make start`, so most local users do not need to set them.
|
||||
#
|
||||
# Override only when the Gateway is not on localhost:8001 (e.g. when the
|
||||
# frontend and gateway run on different hosts, in containers with a service
|
||||
# alias, or behind a different port). docker-compose already sets these.
|
||||
# DEER_FLOW_INTERNAL_GATEWAY_BASE_URL=http://localhost:8001
|
||||
# DEER_FLOW_TRUSTED_ORIGINS=http://localhost:3000,http://localhost:2026
|
||||
|
||||
+6
-2
@@ -263,8 +263,10 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
|
||||
- `present_files` - Make output files visible to user (only `/mnt/user-data/outputs`)
|
||||
- `ask_clarification` - Request clarification (intercepted by ClarificationMiddleware → interrupts)
|
||||
- `view_image` - Read image as base64 (added only if model supports vision)
|
||||
- `setup_agent` - Bootstrap-only: persist a brand-new custom agent's `SOUL.md` and `config.yaml`. Bound only when `is_bootstrap=True`.
|
||||
- `update_agent` - Custom-agent-only: persist self-updates to the current agent's `SOUL.md` / `config.yaml` from inside a normal chat (partial update + atomic write). Bound when `agent_name` is set and `is_bootstrap=False`.
|
||||
4. **Subagent tool** (if enabled):
|
||||
- `task` - Delegate to subagent (description, prompt, subagent_type, max_turns)
|
||||
- `task` - Delegate to subagent (description, prompt, subagent_type)
|
||||
|
||||
**Community tools** (`packages/harness/deerflow/community/`):
|
||||
- `tavily/` - Web search (5 results default) and web fetch (4KB limit)
|
||||
@@ -354,10 +356,11 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
|
||||
**Per-User Isolation**:
|
||||
- Memory is stored per-user at `{base_dir}/users/{user_id}/memory.json`
|
||||
- Per-agent per-user memory at `{base_dir}/users/{user_id}/agents/{agent_name}/memory.json`
|
||||
- Custom agent definitions (`SOUL.md` + `config.yaml`) are also per-user at `{base_dir}/users/{user_id}/agents/{agent_name}/`. The legacy shared layout `{base_dir}/agents/{agent_name}/` remains read-only fallback for unmigrated installations
|
||||
- `user_id` is resolved via `get_effective_user_id()` from `deerflow.runtime.user_context`
|
||||
- In no-auth mode, `user_id` defaults to `"default"` (constant `DEFAULT_USER_ID`)
|
||||
- Absolute `storage_path` in config opts out of per-user isolation
|
||||
- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json` and `threads/` into per-user layout; supports `--dry-run`
|
||||
- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json`, `threads/`, and `agents/` into per-user layout. Supports `--dry-run` (preview changes) and `--user-id USER_ID` (assign unowned legacy data to a user, defaults to `default`).
|
||||
|
||||
**Data Structure** (stored in `{base_dir}/users/{user_id}/memory.json`):
|
||||
- **User Context**: `workContext`, `personalContext`, `topOfMind` (1-3 sentence summaries)
|
||||
@@ -517,6 +520,7 @@ Multi-file upload with automatic document conversion:
|
||||
- Rejects directory inputs before copying so uploads stay all-or-nothing
|
||||
- Reuses one conversion worker per request when called from an active event loop
|
||||
- Files stored in thread-isolated directories
|
||||
- Duplicate filenames in a single upload request are auto-renamed with `_N` suffixes so later files do not truncate earlier files
|
||||
- Agent receives uploaded file list via `UploadsMiddleware`
|
||||
|
||||
See [docs/FILE_UPLOAD.md](docs/FILE_UPLOAD.md) for details.
|
||||
|
||||
+1
-1
@@ -124,7 +124,7 @@ FastAPI application providing REST endpoints for frontend integration:
|
||||
| `POST /api/memory/reload` | Force memory reload |
|
||||
| `GET /api/memory/config` | Memory configuration |
|
||||
| `GET /api/memory/status` | Combined config + data |
|
||||
| `POST /api/threads/{id}/uploads` | Upload files (auto-converts PDF/PPT/Excel/Word to Markdown, rejects directory paths) |
|
||||
| `POST /api/threads/{id}/uploads` | Upload files (auto-converts PDF/PPT/Excel/Word to Markdown, rejects directory paths, auto-renames duplicate filenames in one request) |
|
||||
| `GET /api/threads/{id}/uploads/list` | List uploaded files |
|
||||
| `DELETE /api/threads/{id}` | Delete DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
|
||||
| `GET /api/threads/{id}/artifacts/{path}` | Serve generated artifacts |
|
||||
|
||||
@@ -146,6 +146,13 @@ def _normalize_custom_agent_name(raw_value: str) -> str:
|
||||
return normalized
|
||||
|
||||
|
||||
def _strip_loop_warning_text(text: str) -> str:
|
||||
"""Remove middleware-authored loop warning lines from display text."""
|
||||
if "[LOOP DETECTED]" not in text:
|
||||
return text
|
||||
return "\n".join(line for line in text.splitlines() if "[LOOP DETECTED]" not in line).strip()
|
||||
|
||||
|
||||
def _extract_response_text(result: dict | list) -> str:
|
||||
"""Extract the last AI message text from a LangGraph runs.wait result.
|
||||
|
||||
@@ -155,7 +162,7 @@ def _extract_response_text(result: dict | list) -> str:
|
||||
Handles special cases:
|
||||
- Regular AI text responses
|
||||
- Clarification interrupts (``ask_clarification`` tool messages)
|
||||
- AI messages with tool_calls but no text content
|
||||
- Strips loop-detection warnings attached to tool-call AI messages
|
||||
"""
|
||||
if isinstance(result, list):
|
||||
messages = result
|
||||
@@ -185,7 +192,12 @@ def _extract_response_text(result: dict | list) -> str:
|
||||
# Regular AI message with text content
|
||||
if msg_type == "ai":
|
||||
content = msg.get("content", "")
|
||||
has_tool_calls = bool(msg.get("tool_calls"))
|
||||
if isinstance(content, str) and content:
|
||||
if has_tool_calls:
|
||||
content = _strip_loop_warning_text(content)
|
||||
if not content:
|
||||
continue
|
||||
return content
|
||||
# content can be a list of content blocks
|
||||
if isinstance(content, list):
|
||||
@@ -196,6 +208,8 @@ def _extract_response_text(result: dict | list) -> str:
|
||||
elif isinstance(block, str):
|
||||
parts.append(block)
|
||||
text = "".join(parts)
|
||||
if has_tool_calls:
|
||||
text = _strip_loop_warning_text(text)
|
||||
if text:
|
||||
return text
|
||||
return ""
|
||||
@@ -589,6 +603,17 @@ class ChannelManager:
|
||||
user_layer.get("config"),
|
||||
)
|
||||
|
||||
configurable = run_config.get("configurable")
|
||||
if isinstance(configurable, Mapping):
|
||||
configurable = dict(configurable)
|
||||
else:
|
||||
configurable = {}
|
||||
run_config["configurable"] = configurable
|
||||
# Pin channel-triggered runs to the root graph namespace so follow-up
|
||||
# turns continue from the same conversation checkpoint.
|
||||
configurable["checkpoint_ns"] = ""
|
||||
configurable["thread_id"] = thread_id
|
||||
|
||||
run_context = _merge_dicts(
|
||||
DEFAULT_RUN_CONTEXT,
|
||||
self._default_session.get("context"),
|
||||
@@ -972,7 +997,11 @@ class ChannelManager:
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as http:
|
||||
resp = await http.get(f"{self._gateway_url}{path}", timeout=10)
|
||||
resp = await http.get(
|
||||
f"{self._gateway_url}{path}",
|
||||
timeout=10,
|
||||
headers=create_internal_auth_headers(),
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
except Exception:
|
||||
|
||||
@@ -4,8 +4,10 @@ Per RFC-001:
|
||||
State-changing operations require CSRF protection.
|
||||
"""
|
||||
|
||||
import os
|
||||
import secrets
|
||||
from collections.abc import Callable
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
@@ -19,7 +21,7 @@ CSRF_TOKEN_LENGTH = 64 # bytes
|
||||
|
||||
def is_secure_request(request: Request) -> bool:
|
||||
"""Detect whether the original client request was made over HTTPS."""
|
||||
return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"
|
||||
return _request_scheme(request) == "https"
|
||||
|
||||
|
||||
def generate_csrf_token() -> str:
|
||||
@@ -61,6 +63,109 @@ def is_auth_endpoint(request: Request) -> bool:
|
||||
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
|
||||
|
||||
|
||||
def _host_with_optional_port(hostname: str, port: int | None, scheme: str) -> str:
|
||||
"""Return normalized host[:port], omitting default ports."""
|
||||
host = hostname.lower()
|
||||
if ":" in host and not host.startswith("["):
|
||||
host = f"[{host}]"
|
||||
|
||||
if port is None or (scheme == "http" and port == 80) or (scheme == "https" and port == 443):
|
||||
return host
|
||||
return f"{host}:{port}"
|
||||
|
||||
|
||||
def _normalize_origin(origin: str) -> str | None:
|
||||
"""Return a normalized scheme://host[:port] origin, or None for invalid input."""
|
||||
try:
|
||||
parsed = urlsplit(origin.strip())
|
||||
port = parsed.port
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
scheme = parsed.scheme.lower()
|
||||
if scheme not in {"http", "https"} or not parsed.hostname:
|
||||
return None
|
||||
|
||||
# Browser Origin is only scheme/host/port. Reject URL-shaped or credentialed values.
|
||||
if parsed.username or parsed.password or parsed.path or parsed.query or parsed.fragment:
|
||||
return None
|
||||
|
||||
return f"{scheme}://{_host_with_optional_port(parsed.hostname, port, scheme)}"
|
||||
|
||||
|
||||
def _configured_cors_origins() -> set[str]:
|
||||
"""Return explicit configured browser origins that may call auth routes."""
|
||||
origins = set()
|
||||
for raw_origin in os.environ.get("GATEWAY_CORS_ORIGINS", "").split(","):
|
||||
origin = raw_origin.strip()
|
||||
if not origin or origin == "*":
|
||||
continue
|
||||
normalized = _normalize_origin(origin)
|
||||
if normalized:
|
||||
origins.add(normalized)
|
||||
return origins
|
||||
|
||||
|
||||
def _first_header_value(value: str | None) -> str | None:
|
||||
"""Return the first value from a comma-separated proxy header."""
|
||||
if not value:
|
||||
return None
|
||||
first = value.split(",", 1)[0].strip()
|
||||
return first or None
|
||||
|
||||
|
||||
def _forwarded_param(request: Request, name: str) -> str | None:
|
||||
"""Extract a parameter from the first RFC 7239 Forwarded header entry."""
|
||||
forwarded = _first_header_value(request.headers.get("forwarded"))
|
||||
if not forwarded:
|
||||
return None
|
||||
|
||||
for part in forwarded.split(";"):
|
||||
key, sep, value = part.strip().partition("=")
|
||||
if sep and key.lower() == name:
|
||||
return value.strip().strip('"') or None
|
||||
return None
|
||||
|
||||
|
||||
def _request_scheme(request: Request) -> str:
|
||||
"""Resolve the original request scheme from trusted proxy headers."""
|
||||
scheme = _forwarded_param(request, "proto") or _first_header_value(request.headers.get("x-forwarded-proto")) or request.url.scheme
|
||||
return scheme.lower()
|
||||
|
||||
|
||||
def _request_origin(request: Request) -> str | None:
|
||||
"""Build the origin for the URL the browser is targeting."""
|
||||
scheme = _request_scheme(request)
|
||||
host = _forwarded_param(request, "host") or _first_header_value(request.headers.get("x-forwarded-host")) or request.headers.get("host") or request.url.netloc
|
||||
|
||||
forwarded_port = _first_header_value(request.headers.get("x-forwarded-port"))
|
||||
if forwarded_port and ":" not in host.rsplit("]", 1)[-1]:
|
||||
host = f"{host}:{forwarded_port}"
|
||||
|
||||
return _normalize_origin(f"{scheme}://{host}")
|
||||
|
||||
|
||||
def is_allowed_auth_origin(request: Request) -> bool:
|
||||
"""Allow auth POSTs only from the same origin or explicit configured origins.
|
||||
|
||||
Login/register/initialize are exempt from the double-submit token because
|
||||
first-time browser clients do not have a CSRF token yet. They still create
|
||||
a session cookie, so browser requests with a hostile Origin header must be
|
||||
rejected to prevent login CSRF / session fixation. Requests without Origin
|
||||
are allowed for non-browser clients such as curl and mobile integrations.
|
||||
"""
|
||||
origin = request.headers.get("origin")
|
||||
if not origin:
|
||||
return True
|
||||
|
||||
normalized_origin = _normalize_origin(origin)
|
||||
if normalized_origin is None:
|
||||
return False
|
||||
|
||||
request_origin = _request_origin(request)
|
||||
return normalized_origin in _configured_cors_origins() or (request_origin is not None and normalized_origin == request_origin)
|
||||
|
||||
|
||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
|
||||
|
||||
@@ -70,6 +175,12 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
_is_auth = is_auth_endpoint(request)
|
||||
|
||||
if should_check_csrf(request) and _is_auth and not is_allowed_auth_origin(request):
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Cross-site auth request denied."},
|
||||
)
|
||||
|
||||
if should_check_csrf(request) and not _is_auth:
|
||||
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
|
||||
header_token = request.headers.get(CSRF_HEADER_NAME)
|
||||
|
||||
@@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
|
||||
from deerflow.config.agents_api_config import get_agents_api_config
|
||||
from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api", tags=["agents"])
|
||||
@@ -86,11 +87,11 @@ def _require_agents_api_enabled() -> None:
|
||||
)
|
||||
|
||||
|
||||
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False) -> AgentResponse:
|
||||
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False, *, user_id: str | None = None) -> AgentResponse:
|
||||
"""Convert AgentConfig to AgentResponse."""
|
||||
soul: str | None = None
|
||||
if include_soul:
|
||||
soul = load_agent_soul(agent_cfg.name) or ""
|
||||
soul = load_agent_soul(agent_cfg.name, user_id=user_id) or ""
|
||||
|
||||
return AgentResponse(
|
||||
name=agent_cfg.name,
|
||||
@@ -116,9 +117,10 @@ async def list_agents() -> AgentsListResponse:
|
||||
"""
|
||||
_require_agents_api_enabled()
|
||||
|
||||
user_id = get_effective_user_id()
|
||||
try:
|
||||
agents = list_custom_agents()
|
||||
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True) for a in agents])
|
||||
agents = list_custom_agents(user_id=user_id)
|
||||
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True, user_id=user_id) for a in agents])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list agents: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list agents: {str(e)}")
|
||||
@@ -144,7 +146,12 @@ async def check_agent_name(name: str) -> dict:
|
||||
_require_agents_api_enabled()
|
||||
_validate_agent_name(name)
|
||||
normalized = _normalize_agent_name(name)
|
||||
available = not get_paths().agent_dir(normalized).exists()
|
||||
user_id = get_effective_user_id()
|
||||
paths = get_paths()
|
||||
# Treat the name as taken if either the per-user path or the legacy shared
|
||||
# path holds an agent — picking a name that collides with an unmigrated
|
||||
# legacy agent would shadow the legacy entry once migration runs.
|
||||
available = not paths.user_agent_dir(user_id, normalized).exists() and not paths.agent_dir(normalized).exists()
|
||||
return {"available": available, "name": normalized}
|
||||
|
||||
|
||||
@@ -169,10 +176,11 @@ async def get_agent(name: str) -> AgentResponse:
|
||||
_require_agents_api_enabled()
|
||||
_validate_agent_name(name)
|
||||
name = _normalize_agent_name(name)
|
||||
user_id = get_effective_user_id()
|
||||
|
||||
try:
|
||||
agent_cfg = load_agent_config(name)
|
||||
return _agent_config_to_response(agent_cfg, include_soul=True)
|
||||
agent_cfg = load_agent_config(name, user_id=user_id)
|
||||
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
||||
except Exception as e:
|
||||
@@ -202,10 +210,13 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
|
||||
_require_agents_api_enabled()
|
||||
_validate_agent_name(request.name)
|
||||
normalized_name = _normalize_agent_name(request.name)
|
||||
user_id = get_effective_user_id()
|
||||
paths = get_paths()
|
||||
|
||||
agent_dir = get_paths().agent_dir(normalized_name)
|
||||
agent_dir = paths.user_agent_dir(user_id, normalized_name)
|
||||
legacy_dir = paths.agent_dir(normalized_name)
|
||||
|
||||
if agent_dir.exists():
|
||||
if agent_dir.exists() or legacy_dir.exists():
|
||||
raise HTTPException(status_code=409, detail=f"Agent '{normalized_name}' already exists")
|
||||
|
||||
try:
|
||||
@@ -232,8 +243,8 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
|
||||
|
||||
logger.info(f"Created agent '{normalized_name}' at {agent_dir}")
|
||||
|
||||
agent_cfg = load_agent_config(normalized_name)
|
||||
return _agent_config_to_response(agent_cfg, include_soul=True)
|
||||
agent_cfg = load_agent_config(normalized_name, user_id=user_id)
|
||||
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -267,13 +278,20 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
|
||||
_require_agents_api_enabled()
|
||||
_validate_agent_name(name)
|
||||
name = _normalize_agent_name(name)
|
||||
user_id = get_effective_user_id()
|
||||
|
||||
try:
|
||||
agent_cfg = load_agent_config(name)
|
||||
agent_cfg = load_agent_config(name, user_id=user_id)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
||||
|
||||
agent_dir = get_paths().agent_dir(name)
|
||||
paths = get_paths()
|
||||
agent_dir = paths.user_agent_dir(user_id, name)
|
||||
if not agent_dir.exists() and paths.agent_dir(name).exists():
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before updating."),
|
||||
)
|
||||
|
||||
try:
|
||||
# Update config if any config fields changed
|
||||
@@ -314,8 +332,8 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
|
||||
|
||||
logger.info(f"Updated agent '{name}'")
|
||||
|
||||
refreshed_cfg = load_agent_config(name)
|
||||
return _agent_config_to_response(refreshed_cfg, include_soul=True)
|
||||
refreshed_cfg = load_agent_config(name, user_id=user_id)
|
||||
return _agent_config_to_response(refreshed_cfg, include_soul=True, user_id=user_id)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -402,15 +420,22 @@ async def delete_agent(name: str) -> None:
|
||||
name: The agent name.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if agent not found.
|
||||
HTTPException: 404 if no per-user copy exists; 409 if only a legacy
|
||||
shared copy exists (suggesting the migration script).
|
||||
"""
|
||||
_require_agents_api_enabled()
|
||||
_validate_agent_name(name)
|
||||
name = _normalize_agent_name(name)
|
||||
|
||||
agent_dir = get_paths().agent_dir(name)
|
||||
user_id = get_effective_user_id()
|
||||
paths = get_paths()
|
||||
agent_dir = paths.user_agent_dir(user_id, name)
|
||||
|
||||
if not agent_dir.exists():
|
||||
if paths.agent_dir(name).exists():
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before deleting."),
|
||||
)
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
||||
|
||||
try:
|
||||
|
||||
@@ -68,6 +68,27 @@ class RunResponse(BaseModel):
|
||||
updated_at: str = ""
|
||||
|
||||
|
||||
class ThreadTokenUsageModelBreakdown(BaseModel):
|
||||
tokens: int = 0
|
||||
runs: int = 0
|
||||
|
||||
|
||||
class ThreadTokenUsageCallerBreakdown(BaseModel):
|
||||
lead_agent: int = 0
|
||||
subagent: int = 0
|
||||
middleware: int = 0
|
||||
|
||||
|
||||
class ThreadTokenUsageResponse(BaseModel):
|
||||
thread_id: str
|
||||
total_tokens: int = 0
|
||||
total_input_tokens: int = 0
|
||||
total_output_tokens: int = 0
|
||||
total_runs: int = 0
|
||||
by_model: dict[str, ThreadTokenUsageModelBreakdown] = Field(default_factory=dict)
|
||||
by_caller: ThreadTokenUsageCallerBreakdown = Field(default_factory=ThreadTokenUsageCallerBreakdown)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -368,10 +389,10 @@ async def list_run_events(
|
||||
return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/token-usage")
|
||||
@router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse)
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def thread_token_usage(thread_id: str, request: Request) -> dict:
|
||||
async def thread_token_usage(thread_id: str, request: Request) -> ThreadTokenUsageResponse:
|
||||
"""Thread-level token usage aggregation."""
|
||||
run_store = get_run_store(request)
|
||||
agg = await run_store.aggregate_tokens_by_thread(thread_id)
|
||||
return {"thread_id": thread_id, **agg}
|
||||
return ThreadTokenUsageResponse(thread_id=thread_id, **agg)
|
||||
|
||||
@@ -16,6 +16,7 @@ from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provi
|
||||
from deerflow.uploads.manager import (
|
||||
PathTraversalError,
|
||||
UnsafeUploadPathError,
|
||||
claim_unique_filename,
|
||||
delete_file_safe,
|
||||
enrich_file_listing,
|
||||
ensure_uploads_dir,
|
||||
@@ -192,6 +193,10 @@ async def upload_files(
|
||||
sandbox_sync_targets = []
|
||||
skipped_files = []
|
||||
total_size = 0
|
||||
# Track filenames within this request so duplicate form parts do not
|
||||
# silently truncate each other. Existing uploads keep the historical
|
||||
# overwrite behavior for a single replacement upload.
|
||||
seen_filenames: set[str] = set()
|
||||
|
||||
sandbox_provider = get_sandbox_provider()
|
||||
sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
|
||||
@@ -208,7 +213,8 @@ async def upload_files(
|
||||
continue
|
||||
|
||||
try:
|
||||
safe_filename = normalize_filename(file.filename)
|
||||
original_filename = normalize_filename(file.filename)
|
||||
safe_filename = claim_unique_filename(original_filename, seen_filenames)
|
||||
except ValueError:
|
||||
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
|
||||
continue
|
||||
@@ -236,6 +242,8 @@ async def upload_files(
|
||||
"virtual_path": virtual_path,
|
||||
"artifact_url": upload_artifact_url(thread_id, safe_filename),
|
||||
}
|
||||
if safe_filename != original_filename:
|
||||
file_info["original_filename"] = original_filename
|
||||
|
||||
logger.info(f"Saved file: {safe_filename} ({file_size} bytes) to {file_info['path']}")
|
||||
|
||||
|
||||
@@ -136,6 +136,24 @@ def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, An
|
||||
runtime_context.setdefault(key, context[key])
|
||||
|
||||
|
||||
def inject_authenticated_user_context(config: dict[str, Any], request: Request) -> None:
|
||||
"""Stamp the authenticated user into the run context for background tools.
|
||||
|
||||
Tool execution may happen after the request handler has returned, so tools
|
||||
that persist user-scoped files should not rely only on ambient ContextVars.
|
||||
The value comes from server-side auth state, never from client context.
|
||||
"""
|
||||
|
||||
user = getattr(request.state, "user", None)
|
||||
user_id = getattr(user, "id", None)
|
||||
if user_id is None:
|
||||
return
|
||||
|
||||
runtime_context = config.setdefault("context", {})
|
||||
if isinstance(runtime_context, dict):
|
||||
runtime_context["user_id"] = str(user_id)
|
||||
|
||||
|
||||
def resolve_agent_factory(assistant_id: str | None):
|
||||
"""Resolve the agent factory callable from config.
|
||||
|
||||
@@ -288,6 +306,7 @@ async def start_run(
|
||||
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
||||
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
||||
merge_run_context_overrides(config, getattr(body, "context", None))
|
||||
inject_authenticated_user_context(config, request)
|
||||
|
||||
stream_modes = normalize_stream_modes(body.stream_mode)
|
||||
|
||||
|
||||
@@ -79,7 +79,9 @@ async def main():
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents import make_lead_agent
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.mcp import initialize_mcp_tools
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
# Initialize MCP tools at startup
|
||||
try:
|
||||
@@ -113,6 +115,8 @@ async def main():
|
||||
print("Tip: `uv sync --group dev` to enable arrow-key & history support")
|
||||
print("=" * 50)
|
||||
|
||||
seen_artifacts: set[str] = set()
|
||||
|
||||
while True:
|
||||
try:
|
||||
if session:
|
||||
@@ -134,6 +138,22 @@ async def main():
|
||||
last_message = result["messages"][-1]
|
||||
print(f"\nAgent: {last_message.content}")
|
||||
|
||||
# Show files presented to the user this turn (new artifacts only)
|
||||
artifacts = result.get("artifacts") or []
|
||||
new_artifacts = [p for p in artifacts if p not in seen_artifacts]
|
||||
if new_artifacts:
|
||||
thread_id = config["configurable"]["thread_id"]
|
||||
user_id = get_effective_user_id()
|
||||
paths = get_paths()
|
||||
print("\n[Presented files]")
|
||||
for virtual in new_artifacts:
|
||||
try:
|
||||
physical = paths.resolve_virtual_path(thread_id, virtual, user_id=user_id)
|
||||
print(f" - {virtual}\n → {physical}")
|
||||
except ValueError as exc:
|
||||
print(f" - {virtual} (failed to resolve physical path: {exc})")
|
||||
seen_artifacts.update(new_artifacts)
|
||||
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
|
||||
@@ -173,7 +173,7 @@ def _assemble_from_features(
|
||||
9. MemoryMiddleware (memory feature)
|
||||
10. ViewImageMiddleware (vision feature)
|
||||
11. SubagentLimitMiddleware (subagent feature)
|
||||
12. LoopDetectionMiddleware (always)
|
||||
12. LoopDetectionMiddleware (loop_detection feature)
|
||||
13. ClarificationMiddleware (always last)
|
||||
|
||||
Two-phase ordering:
|
||||
@@ -272,10 +272,15 @@ def _assemble_from_features(
|
||||
|
||||
extra_tools.append(task_tool)
|
||||
|
||||
# --- [12] LoopDetection (always) ---
|
||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||
# --- [12] LoopDetection ---
|
||||
if feat.loop_detection is not False:
|
||||
if isinstance(feat.loop_detection, AgentMiddleware):
|
||||
chain.append(feat.loop_detection)
|
||||
else:
|
||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
||||
|
||||
chain.append(LoopDetectionMiddleware())
|
||||
chain.append(LoopDetectionMiddleware.from_config(LoopDetectionConfig()))
|
||||
|
||||
# --- [13] Clarification (always last among built-ins) ---
|
||||
chain.append(ClarificationMiddleware())
|
||||
|
||||
@@ -31,6 +31,7 @@ class RuntimeFeatures:
|
||||
vision: bool | AgentMiddleware = False
|
||||
auto_title: bool | AgentMiddleware = False
|
||||
guardrail: Literal[False] | AgentMiddleware = False
|
||||
loop_detection: bool | AgentMiddleware = True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -20,6 +20,8 @@ from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.config.agents_config import load_agent_config, validate_agent_name
|
||||
from deerflow.config.app_config import AppConfig, get_app_config
|
||||
from deerflow.models import create_chat_model
|
||||
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
|
||||
from deerflow.skills.types import Skill
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -256,6 +258,12 @@ def _build_middlewares(
|
||||
resolved_app_config = app_config or get_app_config()
|
||||
middlewares = build_lead_runtime_middlewares(app_config=resolved_app_config, lazy_init=True)
|
||||
|
||||
# Always inject current date (and optionally memory) as <system-reminder> into the
|
||||
# first HumanMessage to keep the system prompt fully static for prefix-cache reuse.
|
||||
from deerflow.agents.middlewares.dynamic_context_middleware import DynamicContextMiddleware
|
||||
|
||||
middlewares.append(DynamicContextMiddleware(agent_name=agent_name, app_config=resolved_app_config))
|
||||
|
||||
# Add summarization middleware if enabled
|
||||
summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config)
|
||||
if summarization_middleware is not None:
|
||||
@@ -297,7 +305,9 @@ def _build_middlewares(
|
||||
middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents))
|
||||
|
||||
# LoopDetectionMiddleware — detect and break repetitive tool call loops
|
||||
middlewares.append(LoopDetectionMiddleware())
|
||||
loop_detection_config = resolved_app_config.loop_detection
|
||||
if loop_detection_config.enabled:
|
||||
middlewares.append(LoopDetectionMiddleware.from_config(loop_detection_config))
|
||||
|
||||
# Inject custom middlewares before ClarificationMiddleware
|
||||
if custom_middlewares:
|
||||
@@ -308,6 +318,28 @@ def _build_middlewares(
|
||||
return middlewares
|
||||
|
||||
|
||||
def _available_skill_names(agent_config, is_bootstrap: bool) -> set[str] | None:
|
||||
if is_bootstrap:
|
||||
return {"bootstrap"}
|
||||
if agent_config and agent_config.skills is not None:
|
||||
return set(agent_config.skills)
|
||||
return None
|
||||
|
||||
|
||||
def _load_enabled_skills_for_tool_policy(available_skills: set[str] | None, *, app_config: AppConfig) -> list[Skill]:
|
||||
try:
|
||||
from deerflow.agents.lead_agent.prompt import get_enabled_skills_for_config
|
||||
|
||||
skills = get_enabled_skills_for_config(app_config)
|
||||
except Exception:
|
||||
logger.exception("Failed to load skills for allowed-tools policy")
|
||||
raise
|
||||
|
||||
if available_skills is None:
|
||||
return skills
|
||||
return [skill for skill in skills if skill.name in available_skills]
|
||||
|
||||
|
||||
def make_lead_agent(config: RunnableConfig):
|
||||
"""LangGraph graph factory; keep the signature compatible with LangGraph Server."""
|
||||
runtime_config = _get_runtime_config(config)
|
||||
@@ -318,7 +350,7 @@ def make_lead_agent(config: RunnableConfig):
|
||||
def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
|
||||
# Lazy import to avoid circular dependency
|
||||
from deerflow.tools import get_available_tools
|
||||
from deerflow.tools.builtins import setup_agent
|
||||
from deerflow.tools.builtins import setup_agent, update_agent
|
||||
|
||||
cfg = _get_runtime_config(config)
|
||||
resolved_app_config = app_config
|
||||
@@ -333,6 +365,7 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
|
||||
agent_name = validate_agent_name(cfg.get("agent_name"))
|
||||
|
||||
agent_config = load_agent_config(agent_name) if not is_bootstrap else None
|
||||
available_skills = _available_skill_names(agent_config, is_bootstrap)
|
||||
# Custom agent model from agent config (if any), or None to let _resolve_model_name pick the default
|
||||
agent_model_name = agent_config.model if agent_config and agent_config.model else None
|
||||
|
||||
@@ -371,15 +404,18 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
|
||||
"is_plan_mode": is_plan_mode,
|
||||
"subagent_enabled": subagent_enabled,
|
||||
"tool_groups": agent_config.tool_groups if agent_config else None,
|
||||
"available_skills": ["bootstrap"] if is_bootstrap else (agent_config.skills if agent_config and agent_config.skills is not None else None),
|
||||
"available_skills": sorted(available_skills) if available_skills is not None else None,
|
||||
}
|
||||
)
|
||||
|
||||
skills_for_tool_policy = _load_enabled_skills_for_tool_policy(available_skills, app_config=resolved_app_config)
|
||||
|
||||
if is_bootstrap:
|
||||
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
|
||||
tools = get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent]
|
||||
return create_agent(
|
||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config),
|
||||
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent],
|
||||
tools=filter_tools_by_skill_allowed_tools(tools, skills_for_tool_policy),
|
||||
middleware=_build_middlewares(config, model_name=model_name, app_config=resolved_app_config),
|
||||
system_prompt=apply_prompt_template(
|
||||
subagent_enabled=subagent_enabled,
|
||||
@@ -390,15 +426,14 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
|
||||
state_schema=ThreadState,
|
||||
)
|
||||
|
||||
# Custom agents can update their own SOUL.md / config via update_agent.
|
||||
# The default agent (no agent_name) does not see this tool.
|
||||
extra_tools = [update_agent] if agent_name else []
|
||||
# Default lead agent (unchanged behavior)
|
||||
tools = get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled, app_config=resolved_app_config)
|
||||
return create_agent(
|
||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config),
|
||||
tools=get_available_tools(
|
||||
model_name=model_name,
|
||||
groups=agent_config.tool_groups if agent_config else None,
|
||||
subagent_enabled=subagent_enabled,
|
||||
app_config=resolved_app_config,
|
||||
),
|
||||
tools=filter_tools_by_skill_allowed_tools(tools + extra_tools, skills_for_tool_policy),
|
||||
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name, app_config=resolved_app_config),
|
||||
system_prompt=apply_prompt_template(
|
||||
subagent_enabled=subagent_enabled,
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -20,6 +19,7 @@ logger = logging.getLogger(__name__)
|
||||
_ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS = 5.0
|
||||
_enabled_skills_lock = threading.Lock()
|
||||
_enabled_skills_cache: list[Skill] | None = None
|
||||
_enabled_skills_by_config_cache: dict[int, tuple[object, list[Skill]]] = {}
|
||||
_enabled_skills_refresh_active = False
|
||||
_enabled_skills_refresh_version = 0
|
||||
_enabled_skills_refresh_event = threading.Event()
|
||||
@@ -84,6 +84,7 @@ def _invalidate_enabled_skills_cache() -> threading.Event:
|
||||
_get_cached_skills_prompt_section.cache_clear()
|
||||
with _enabled_skills_lock:
|
||||
_enabled_skills_cache = None
|
||||
_enabled_skills_by_config_cache.clear()
|
||||
_enabled_skills_refresh_version += 1
|
||||
_enabled_skills_refresh_event.clear()
|
||||
if _enabled_skills_refresh_active:
|
||||
@@ -107,6 +108,15 @@ def warm_enabled_skills_cache(timeout_seconds: float = _ENABLED_SKILLS_REFRESH_W
|
||||
|
||||
|
||||
def _get_enabled_skills():
|
||||
return get_cached_enabled_skills()
|
||||
|
||||
|
||||
def get_cached_enabled_skills() -> list[Skill]:
|
||||
"""Return the cached enabled-skills list, kicking off a background refresh on miss.
|
||||
|
||||
Safe to call from request paths: never blocks on disk I/O. Returns an empty
|
||||
list on cache miss; the next call will see the warmed result.
|
||||
"""
|
||||
with _enabled_skills_lock:
|
||||
cached = _enabled_skills_cache
|
||||
|
||||
@@ -117,17 +127,29 @@ def _get_enabled_skills():
|
||||
return []
|
||||
|
||||
|
||||
def _get_enabled_skills_for_config(app_config: AppConfig | None = None) -> list[Skill]:
|
||||
def get_enabled_skills_for_config(app_config: AppConfig | None = None) -> list[Skill]:
|
||||
"""Return enabled skills using the caller's config source.
|
||||
|
||||
When a concrete ``app_config`` is supplied, bypass the global enabled-skills
|
||||
cache so the skill list and skill paths are resolved from the same config
|
||||
object. This keeps request-scoped config injection consistent even while the
|
||||
release branch still supports global fallback paths.
|
||||
When a concrete ``app_config`` is supplied, cache the loaded skills by that
|
||||
config object's identity so request-scoped config injection still resolves
|
||||
skill paths from the matching config without rescanning storage on every
|
||||
agent factory call.
|
||||
"""
|
||||
if app_config is None:
|
||||
return _get_enabled_skills()
|
||||
return list(get_or_new_skill_storage(app_config=app_config).load_skills(enabled_only=True))
|
||||
|
||||
cache_key = id(app_config)
|
||||
with _enabled_skills_lock:
|
||||
cached = _enabled_skills_by_config_cache.get(cache_key)
|
||||
if cached is not None:
|
||||
cached_config, cached_skills = cached
|
||||
if cached_config is app_config:
|
||||
return list(cached_skills)
|
||||
|
||||
skills = list(get_or_new_skill_storage(app_config=app_config).load_skills(enabled_only=True))
|
||||
with _enabled_skills_lock:
|
||||
_enabled_skills_by_config_cache[cache_key] = (app_config, skills)
|
||||
return list(skills)
|
||||
|
||||
|
||||
def _skill_mutability_label(category: SkillCategory | str) -> str:
|
||||
@@ -344,8 +366,7 @@ You are {agent_name}, an open-source super agent.
|
||||
</role>
|
||||
|
||||
{soul}
|
||||
{memory_context}
|
||||
|
||||
{self_update_section}
|
||||
<thinking_style>
|
||||
- Think concisely and strategically about the user's request BEFORE taking action
|
||||
- Break down the task: What is clear? What is ambiguous? What is missing?
|
||||
@@ -604,7 +625,7 @@ You have access to skills that provide optimized workflows for specific tasks. E
|
||||
|
||||
def get_skills_prompt_section(available_skills: set[str] | None = None, *, app_config: AppConfig | None = None) -> str:
|
||||
"""Generate the skills prompt section with available skills list."""
|
||||
skills = _get_enabled_skills_for_config(app_config)
|
||||
skills = get_enabled_skills_for_config(app_config)
|
||||
|
||||
if app_config is None:
|
||||
try:
|
||||
@@ -643,6 +664,26 @@ def get_agent_soul(agent_name: str | None) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _build_self_update_section(agent_name: str | None) -> str:
|
||||
"""Prompt block that teaches the custom agent to persist self-updates via update_agent."""
|
||||
if not agent_name:
|
||||
return ""
|
||||
return f"""<self_update>
|
||||
You are running as the custom agent **{agent_name}** with a persisted SOUL.md and config.yaml.
|
||||
|
||||
When the user asks you to update your own description, personality, behaviour, skill set, tool groups, or default model,
|
||||
you MUST persist the change with the `update_agent` tool. Do NOT use `bash`, `write_file`, or any sandbox tool to edit
|
||||
SOUL.md or config.yaml — those write into a temporary sandbox/tool workspace and the changes will be lost on the next turn.
|
||||
|
||||
Rules:
|
||||
- Always pass the FULL replacement text for `soul` (no patch semantics). Start from your current SOUL above and apply the user's edits.
|
||||
- Only pass the fields that should change. Omit the others to preserve them.
|
||||
- Pass `skills=[]` to disable all skills, or omit `skills` to keep the existing whitelist.
|
||||
- After `update_agent` returns successfully, tell the user the change is persisted and will take effect on the next turn.
|
||||
</self_update>
|
||||
"""
|
||||
|
||||
|
||||
def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) -> str:
|
||||
"""Generate <available-deferred-tools> block for the system prompt.
|
||||
|
||||
@@ -732,9 +773,6 @@ def apply_prompt_template(
|
||||
available_skills: set[str] | None = None,
|
||||
app_config: AppConfig | None = None,
|
||||
) -> str:
|
||||
# Get memory context
|
||||
memory_context = _get_memory_context(agent_name, app_config=app_config)
|
||||
|
||||
# Include subagent section only if enabled (from runtime parameter)
|
||||
n = max_concurrent_subagents
|
||||
subagent_section = _build_subagent_section(n, app_config=app_config) if subagent_enabled else ""
|
||||
@@ -768,17 +806,18 @@ def apply_prompt_template(
|
||||
custom_mounts_section = _build_custom_mounts_section(app_config=app_config)
|
||||
acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)
|
||||
|
||||
# Format the prompt with dynamic skills and memory
|
||||
prompt = SYSTEM_PROMPT_TEMPLATE.format(
|
||||
# Build and return the fully static system prompt.
|
||||
# Memory and current date are injected per-turn via DynamicContextMiddleware
|
||||
# as a <system-reminder> in the first HumanMessage, keeping this prompt
|
||||
# identical across users and sessions for maximum prefix-cache reuse.
|
||||
return SYSTEM_PROMPT_TEMPLATE.format(
|
||||
agent_name=agent_name or "DeerFlow 2.0",
|
||||
soul=get_agent_soul(agent_name),
|
||||
self_update_section=_build_self_update_section(agent_name),
|
||||
skills_section=skills_section,
|
||||
deferred_tools_section=deferred_tools_section,
|
||||
memory_context=memory_context,
|
||||
subagent_section=subagent_section,
|
||||
subagent_reminder=subagent_reminder,
|
||||
subagent_thinking=subagent_thinking,
|
||||
acp_section=acp_and_mounts_section,
|
||||
)
|
||||
|
||||
return prompt + f"\n<current_date>{datetime.now().strftime('%Y-%m-%d, %A')}</current_date>"
|
||||
|
||||
@@ -0,0 +1,204 @@
|
||||
"""Middleware to inject dynamic context (memory, current date) as a system-reminder.
|
||||
|
||||
The system prompt is kept fully static for maximum prefix-cache reuse across users
|
||||
and sessions. The current date is always injected. Per-user memory is also injected
|
||||
when ``memory.injection_enabled`` is True in the app config. Both are delivered once
|
||||
per conversation as a dedicated <system-reminder> HumanMessage inserted before the
|
||||
first user message (frozen-snapshot pattern).
|
||||
|
||||
When a conversation spans midnight the middleware detects the date change and injects
|
||||
a lightweight date-update reminder as a separate HumanMessage before the current turn.
|
||||
This correction is persisted so subsequent turns on the new day see a consistent history
|
||||
and do not re-inject.
|
||||
|
||||
Reminder format:
|
||||
|
||||
<system-reminder>
|
||||
<memory>...</memory>
|
||||
|
||||
<current_date>2026-05-08, Friday</current_date>
|
||||
</system-reminder>
|
||||
|
||||
Date-update format:
|
||||
|
||||
<system-reminder>
|
||||
<current_date>2026-05-09, Saturday</current_date>
|
||||
</system-reminder>
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, override
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.config.app_config import AppConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DATE_RE = re.compile(r"<current_date>([^<]+)</current_date>")
|
||||
_DYNAMIC_CONTEXT_REMINDER_KEY = "dynamic_context_reminder"
|
||||
_SUMMARY_MESSAGE_NAME = "summary"
|
||||
|
||||
|
||||
def _extract_date(content: str) -> str | None:
|
||||
"""Return the first <current_date> value found in *content*, or None."""
|
||||
m = _DATE_RE.search(content)
|
||||
return m.group(1) if m else None
|
||||
|
||||
|
||||
def is_dynamic_context_reminder(message: object) -> bool:
|
||||
"""Return whether *message* is a hidden dynamic-context reminder."""
|
||||
return isinstance(message, HumanMessage) and bool(message.additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY))
|
||||
|
||||
|
||||
def _last_injected_date(messages: list) -> str | None:
|
||||
"""Scan messages in reverse and return the most recently injected date.
|
||||
|
||||
Detection uses the ``dynamic_context_reminder`` additional_kwargs flag rather
|
||||
than content substring matching, so user messages containing ``<system-reminder>``
|
||||
are not mistakenly treated as injected reminders.
|
||||
"""
|
||||
for msg in reversed(messages):
|
||||
if is_dynamic_context_reminder(msg):
|
||||
content_str = msg.content if isinstance(msg.content, str) else str(msg.content)
|
||||
return _extract_date(content_str)
|
||||
return None
|
||||
|
||||
|
||||
def _is_user_injection_target(message: object) -> bool:
|
||||
"""Return whether *message* can receive a dynamic-context reminder."""
|
||||
return isinstance(message, HumanMessage) and not is_dynamic_context_reminder(message) and message.name != _SUMMARY_MESSAGE_NAME
|
||||
|
||||
|
||||
class DynamicContextMiddleware(AgentMiddleware):
|
||||
"""Inject memory and current date into HumanMessages as a <system-reminder>.
|
||||
|
||||
First turn
|
||||
----------
|
||||
Prepends a full system-reminder (memory + date) to the first HumanMessage and
|
||||
persists it (same message ID). The first message is then frozen for the whole
|
||||
session — its content never changes again, so the prefix cache can hit on every
|
||||
subsequent turn.
|
||||
|
||||
Midnight crossing
|
||||
-----------------
|
||||
If the conversation spans midnight, the current date differs from the date that
|
||||
was injected earlier. In that case a lightweight date-update reminder is prepended
|
||||
to the **current** (last) HumanMessage and persisted. Subsequent turns on the new
|
||||
day see the corrected date in history and skip re-injection.
|
||||
"""
|
||||
|
||||
def __init__(self, agent_name: str | None = None, *, app_config: AppConfig | None = None):
|
||||
super().__init__()
|
||||
self._agent_name = agent_name
|
||||
self._app_config = app_config
|
||||
|
||||
def _build_full_reminder(self) -> str:
|
||||
from deerflow.agents.lead_agent.prompt import _get_memory_context
|
||||
|
||||
# Memory injection is gated by injection_enabled; date is always included.
|
||||
injection_enabled = self._app_config.memory.injection_enabled if self._app_config else True
|
||||
memory_context = _get_memory_context(self._agent_name, app_config=self._app_config) if injection_enabled else ""
|
||||
current_date = datetime.now().strftime("%Y-%m-%d, %A")
|
||||
|
||||
lines: list[str] = ["<system-reminder>"]
|
||||
if memory_context:
|
||||
lines.append(memory_context.strip())
|
||||
lines.append("") # blank line separating memory from date
|
||||
lines.append(f"<current_date>{current_date}</current_date>")
|
||||
lines.append("</system-reminder>")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _build_date_update_reminder(self) -> str:
|
||||
current_date = datetime.now().strftime("%Y-%m-%d, %A")
|
||||
return "\n".join(
|
||||
[
|
||||
"<system-reminder>",
|
||||
f"<current_date>{current_date}</current_date>",
|
||||
"</system-reminder>",
|
||||
]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _make_reminder_and_user_messages(original: HumanMessage, reminder_content: str) -> tuple[HumanMessage, HumanMessage]:
|
||||
"""Return (reminder_msg, user_msg) using the ID-swap technique.
|
||||
|
||||
reminder_msg takes the original message's ID so that add_messages replaces it
|
||||
in-place (preserving position). user_msg carries the original content with a
|
||||
derived ``{id}__user`` ID and is appended immediately after by add_messages.
|
||||
|
||||
If the original message has no ID a stable UUID is generated so the derived
|
||||
``{id}__user`` ID never collapses to the ambiguous ``None__user`` string.
|
||||
"""
|
||||
stable_id = original.id or str(uuid.uuid4())
|
||||
reminder_msg = HumanMessage(
|
||||
content=reminder_content,
|
||||
id=stable_id,
|
||||
additional_kwargs={"hide_from_ui": True, _DYNAMIC_CONTEXT_REMINDER_KEY: True},
|
||||
)
|
||||
user_msg = HumanMessage(
|
||||
content=original.content,
|
||||
id=f"{stable_id}__user",
|
||||
name=original.name,
|
||||
additional_kwargs=original.additional_kwargs,
|
||||
)
|
||||
return reminder_msg, user_msg
|
||||
|
||||
def _inject(self, state) -> dict | None:
|
||||
messages = list(state.get("messages", []))
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d, %A")
|
||||
last_date = _last_injected_date(messages)
|
||||
logger.debug(
|
||||
"DynamicContextMiddleware._inject: msg_count=%d last_date=%r current_date=%r",
|
||||
len(messages),
|
||||
last_date,
|
||||
current_date,
|
||||
)
|
||||
|
||||
if last_date is None:
|
||||
# ── First turn: inject full reminder as a separate HumanMessage ─────
|
||||
first_idx = next((i for i, m in enumerate(messages) if _is_user_injection_target(m)), None)
|
||||
if first_idx is None:
|
||||
return None
|
||||
full_reminder = self._build_full_reminder()
|
||||
logger.info(
|
||||
"DynamicContextMiddleware: injecting full reminder (len=%d, has_memory=%s) into first HumanMessage id=%r",
|
||||
len(full_reminder),
|
||||
"<memory>" in full_reminder,
|
||||
messages[first_idx].id,
|
||||
)
|
||||
reminder_msg, user_msg = self._make_reminder_and_user_messages(messages[first_idx], full_reminder)
|
||||
return {"messages": [reminder_msg, user_msg]}
|
||||
|
||||
if last_date == current_date:
|
||||
# ── Same day: nothing to do ──────────────────────────────────────────
|
||||
return None
|
||||
|
||||
# ── Midnight crossed: inject date-update reminder as a separate HumanMessage ──
|
||||
last_human_idx = next((i for i in reversed(range(len(messages))) if _is_user_injection_target(messages[i])), None)
|
||||
if last_human_idx is None:
|
||||
return None
|
||||
|
||||
reminder_msg, user_msg = self._make_reminder_and_user_messages(messages[last_human_idx], self._build_date_update_reminder())
|
||||
logger.info("DynamicContextMiddleware: midnight crossing detected — injected date update before current turn")
|
||||
return {"messages": [reminder_msg, user_msg]}
|
||||
|
||||
@override
|
||||
def before_agent(self, state, runtime: Runtime) -> dict | None:
|
||||
return self._inject(state)
|
||||
|
||||
@override
|
||||
async def abefore_agent(self, state, runtime: Runtime) -> dict | None:
|
||||
return self._inject(state)
|
||||
@@ -12,19 +12,23 @@ Detection strategy:
|
||||
response so the agent is forced to produce a final text answer.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict, defaultdict
|
||||
from copy import deepcopy
|
||||
from typing import override
|
||||
from typing import TYPE_CHECKING, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Defaults — can be overridden via constructor
|
||||
@@ -140,6 +144,9 @@ _TOOL_FREQ_HARD_STOP_MSG = "[FORCED STOP] Tool {tool_name} called {count} times
|
||||
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Detects and breaks repetitive tool call loops.
|
||||
|
||||
Threshold parameters are validated upstream by :class:`LoopDetectionConfig`;
|
||||
construct via :meth:`from_config` to ensure values pass Pydantic validation.
|
||||
|
||||
Args:
|
||||
warn_threshold: Number of identical tool call sets before injecting
|
||||
a warning message. Default: 3.
|
||||
@@ -155,6 +162,14 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
Default: 30.
|
||||
tool_freq_hard_limit: Number of calls to the same tool type before
|
||||
forcing a stop. Default: 50.
|
||||
tool_freq_overrides: Per-tool overrides for frequency thresholds,
|
||||
keyed by tool name. Each value is a ``(warn, hard_limit)`` tuple
|
||||
that replaces ``tool_freq_warn`` / ``tool_freq_hard_limit`` for
|
||||
that specific tool. Tools not listed here fall back to the global
|
||||
thresholds. Useful for raising limits on intentionally
|
||||
high-frequency tools (e.g. ``bash`` in batch pipelines) without
|
||||
weakening protection on all other tools. Default: ``None``
|
||||
(no overrides).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -165,6 +180,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
|
||||
tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN,
|
||||
tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT,
|
||||
tool_freq_overrides: dict[str, tuple[int, int]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.warn_threshold = warn_threshold
|
||||
@@ -173,14 +189,26 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
self.max_tracked_threads = max_tracked_threads
|
||||
self.tool_freq_warn = tool_freq_warn
|
||||
self.tool_freq_hard_limit = tool_freq_hard_limit
|
||||
self._tool_freq_overrides: dict[str, tuple[int, int]] = tool_freq_overrides or {}
|
||||
self._lock = threading.Lock()
|
||||
# Per-thread tracking using OrderedDict for LRU eviction
|
||||
self._history: OrderedDict[str, list[str]] = OrderedDict()
|
||||
self._warned: dict[str, set[str]] = defaultdict(set)
|
||||
# Per-thread, per-tool-type cumulative call counts
|
||||
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
||||
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: LoopDetectionConfig) -> LoopDetectionMiddleware:
|
||||
"""Construct from a Pydantic-validated config, trusting its validation."""
|
||||
return cls(
|
||||
warn_threshold=config.warn_threshold,
|
||||
hard_limit=config.hard_limit,
|
||||
window_size=config.window_size,
|
||||
max_tracked_threads=config.max_tracked_threads,
|
||||
tool_freq_warn=config.tool_freq_warn,
|
||||
tool_freq_hard_limit=config.tool_freq_hard_limit,
|
||||
tool_freq_overrides={name: (o.warn, o.hard_limit) for name, o in config.tool_freq_overrides.items()},
|
||||
)
|
||||
|
||||
def _get_thread_id(self, runtime: Runtime) -> str:
|
||||
"""Extract thread_id from runtime context for per-thread tracking."""
|
||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
||||
@@ -280,7 +308,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
freq[name] += 1
|
||||
tc_count = freq[name]
|
||||
|
||||
if tc_count >= self.tool_freq_hard_limit:
|
||||
if name in self._tool_freq_overrides:
|
||||
eff_warn, eff_hard = self._tool_freq_overrides[name]
|
||||
else:
|
||||
eff_warn, eff_hard = self.tool_freq_warn, self.tool_freq_hard_limit
|
||||
|
||||
if tc_count >= eff_hard:
|
||||
logger.error(
|
||||
"Tool frequency hard limit reached — forcing stop",
|
||||
extra={
|
||||
@@ -291,7 +324,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
)
|
||||
return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True
|
||||
|
||||
if tc_count >= self.tool_freq_warn:
|
||||
if tc_count >= eff_warn:
|
||||
warned = self._tool_freq_warned[thread_id]
|
||||
if name not in warned:
|
||||
warned.add(name)
|
||||
@@ -356,13 +389,30 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
return {"messages": [stripped_msg]}
|
||||
|
||||
if warning:
|
||||
# Inject as HumanMessage instead of SystemMessage to avoid
|
||||
# Anthropic's "multiple non-consecutive system messages" error.
|
||||
# Anthropic models require system messages only at the start of
|
||||
# the conversation; injecting one mid-conversation crashes
|
||||
# langchain_anthropic's _format_messages(). HumanMessage works
|
||||
# with all providers. See #1299.
|
||||
return {"messages": [HumanMessage(content=warning, name="loop_warning")]}
|
||||
# WORKAROUND for v2.0-m1 — see #2724.
|
||||
#
|
||||
# Append the warning to the AIMessage content instead of
|
||||
# injecting a separate HumanMessage. Inserting any non-tool
|
||||
# message between an AIMessage(tool_calls=...) and its
|
||||
# ToolMessage responses breaks OpenAI/Moonshot strict pairing
|
||||
# validation ("tool_call_ids did not have response messages")
|
||||
# because the tools node has not run yet at after_model time.
|
||||
# tool_calls are preserved so the tools node still executes.
|
||||
#
|
||||
# This is a temporary mitigation: mutating an existing
|
||||
# AIMessage to carry framework-authored text leaks loop-warning
|
||||
# text into downstream consumers (MemoryMiddleware fact
|
||||
# extraction, TitleMiddleware, telemetry, model replay) as if
|
||||
# the model said it. The proper fix is to defer warning
|
||||
# injection from after_model to wrap_model_call so every prior
|
||||
# ToolMessage is already in the request — see RFC #2517 (which
|
||||
# lists "loop intervention does not leave invalid
|
||||
# tool-call/tool-message state" as acceptance criteria) and
|
||||
# the prototype on `fix/loop-detection-tool-call-pairing`.
|
||||
messages = state.get("messages", [])
|
||||
last_msg = messages[-1]
|
||||
patched_msg = last_msg.model_copy(update={"content": self._append_text(last_msg.content, warning)})
|
||||
return {"messages": [patched_msg]}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
|
||||
from deerflow.subagents.executor import MAX_CONCURRENT_SUBAGENTS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -63,7 +64,7 @@ class SubagentLimitMiddleware(AgentMiddleware[AgentState]):
|
||||
logger.warning(f"Truncated {dropped_count} excess task tool call(s) from model response (limit: {self.max_concurrent})")
|
||||
|
||||
# Replace the AIMessage with truncated tool_calls (same id triggers replacement)
|
||||
updated_msg = last_msg.model_copy(update={"tool_calls": truncated_tool_calls})
|
||||
updated_msg = clone_ai_message_with_tool_calls(last_msg, truncated_tool_calls)
|
||||
return {"messages": [updated_msg]}
|
||||
|
||||
@override
|
||||
|
||||
@@ -14,6 +14,9 @@ from langgraph.config import get_config
|
||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.middlewares.dynamic_context_middleware import is_dynamic_context_reminder
|
||||
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -78,10 +81,7 @@ def _clone_ai_message(
|
||||
content: Any | None = None,
|
||||
) -> AIMessage:
|
||||
"""Clone an AIMessage while replacing its tool_calls list and optional content."""
|
||||
update: dict[str, Any] = {"tool_calls": tool_calls}
|
||||
if content is not None:
|
||||
update["content"] = content
|
||||
return message.model_copy(update=update)
|
||||
return clone_ai_message_with_tool_calls(message, tool_calls, content=content)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -136,6 +136,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
||||
return None
|
||||
|
||||
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
|
||||
messages_to_summarize, preserved_messages = self._preserve_dynamic_context_reminders(messages_to_summarize, preserved_messages)
|
||||
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
|
||||
summary = self._create_summary(messages_to_summarize)
|
||||
new_messages = self._build_new_messages(summary)
|
||||
@@ -161,6 +162,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
||||
return None
|
||||
|
||||
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
|
||||
messages_to_summarize, preserved_messages = self._preserve_dynamic_context_reminders(messages_to_summarize, preserved_messages)
|
||||
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
|
||||
summary = await self._acreate_summary(messages_to_summarize)
|
||||
new_messages = self._build_new_messages(summary)
|
||||
@@ -180,6 +182,24 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
||||
"""
|
||||
return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
|
||||
|
||||
def _preserve_dynamic_context_reminders(
|
||||
self,
|
||||
messages_to_summarize: list[AnyMessage],
|
||||
preserved_messages: list[AnyMessage],
|
||||
) -> tuple[list[AnyMessage], list[AnyMessage]]:
|
||||
"""Keep hidden dynamic-context reminders out of summary compression.
|
||||
|
||||
These reminders carry the current date and optional memory. If summarization
|
||||
removes them, DynamicContextMiddleware can mistake the summary HumanMessage
|
||||
for the first user message and inject the reminder in the wrong place.
|
||||
"""
|
||||
reminders = [msg for msg in messages_to_summarize if is_dynamic_context_reminder(msg)]
|
||||
if not reminders:
|
||||
return messages_to_summarize, preserved_messages
|
||||
|
||||
remaining = [msg for msg in messages_to_summarize if not is_dynamic_context_reminder(msg)]
|
||||
return remaining, reminders + preserved_messages
|
||||
|
||||
def _partition_with_skill_rescue(
|
||||
self,
|
||||
messages: list[AnyMessage],
|
||||
|
||||
@@ -9,6 +9,7 @@ from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.middlewares.dynamic_context_middleware import is_dynamic_context_reminder
|
||||
from deerflow.config.title_config import get_title_config
|
||||
from deerflow.models import create_chat_model
|
||||
|
||||
@@ -61,6 +62,10 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _is_user_message_for_title(message: object) -> bool:
|
||||
return getattr(message, "type", None) == "human" and not is_dynamic_context_reminder(message)
|
||||
|
||||
def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
|
||||
"""Check if we should generate a title for this thread."""
|
||||
config = self._get_title_config()
|
||||
@@ -77,7 +82,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
return False
|
||||
|
||||
# Count user and assistant messages
|
||||
user_messages = [m for m in messages if m.type == "human"]
|
||||
user_messages = [m for m in messages if self._is_user_message_for_title(m)]
|
||||
assistant_messages = [m for m in messages if m.type == "ai"]
|
||||
|
||||
# Generate title after first complete exchange
|
||||
@@ -91,7 +96,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
config = self._get_title_config()
|
||||
messages = state.get("messages", [])
|
||||
|
||||
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
|
||||
user_msg_content = next((m.content for m in messages if self._is_user_message_for_title(m)), "")
|
||||
assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")
|
||||
|
||||
user_msg = self._normalize_content(user_msg_content)
|
||||
|
||||
@@ -267,11 +267,20 @@ class TokenUsageMiddleware(AgentMiddleware):
|
||||
|
||||
usage = getattr(last, "usage_metadata", None)
|
||||
if usage:
|
||||
input_token_details = usage.get("input_token_details") or {}
|
||||
output_token_details = usage.get("output_token_details") or {}
|
||||
detail_parts = []
|
||||
if input_token_details:
|
||||
detail_parts.append(f"input_token_details={input_token_details}")
|
||||
if output_token_details:
|
||||
detail_parts.append(f"output_token_details={output_token_details}")
|
||||
detail_suffix = f" {' '.join(detail_parts)}" if detail_parts else ""
|
||||
logger.info(
|
||||
"LLM token usage: input=%s output=%s total=%s",
|
||||
"LLM token usage: input=%s output=%s total=%s%s",
|
||||
usage.get("input_tokens", "?"),
|
||||
usage.get("output_tokens", "?"),
|
||||
usage.get("total_tokens", "?"),
|
||||
detail_suffix,
|
||||
)
|
||||
|
||||
todos = state.get("todos") or []
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
"""Helpers for keeping AIMessage tool-call metadata consistent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
def _raw_tool_call_id(raw_tool_call: Any) -> str | None:
|
||||
if not isinstance(raw_tool_call, dict):
|
||||
return None
|
||||
|
||||
raw_id = raw_tool_call.get("id")
|
||||
return raw_id if isinstance(raw_id, str) and raw_id else None
|
||||
|
||||
|
||||
def clone_ai_message_with_tool_calls(
|
||||
message: AIMessage,
|
||||
tool_calls: list[dict[str, Any]],
|
||||
*,
|
||||
content: Any | None = None,
|
||||
) -> AIMessage:
|
||||
"""Clone an AIMessage while keeping raw provider tool-call metadata in sync."""
|
||||
kept_ids = {tc["id"] for tc in tool_calls if isinstance(tc.get("id"), str) and tc["id"]}
|
||||
|
||||
update: dict[str, Any] = {"tool_calls": tool_calls}
|
||||
if content is not None:
|
||||
update["content"] = content
|
||||
|
||||
additional_kwargs = dict(getattr(message, "additional_kwargs", {}) or {})
|
||||
raw_tool_calls = additional_kwargs.get("tool_calls")
|
||||
if isinstance(raw_tool_calls, list):
|
||||
synced_raw_tool_calls = [raw_tc for raw_tc in raw_tool_calls if _raw_tool_call_id(raw_tc) in kept_ids]
|
||||
if synced_raw_tool_calls:
|
||||
additional_kwargs["tool_calls"] = synced_raw_tool_calls
|
||||
else:
|
||||
additional_kwargs.pop("tool_calls", None)
|
||||
|
||||
if not tool_calls:
|
||||
additional_kwargs.pop("function_call", None)
|
||||
|
||||
update["additional_kwargs"] = additional_kwargs
|
||||
|
||||
response_metadata = dict(getattr(message, "response_metadata", {}) or {})
|
||||
if not tool_calls and response_metadata.get("finish_reason") == "tool_calls":
|
||||
response_metadata["finish_reason"] = "stop"
|
||||
update["response_metadata"] = response_metadata
|
||||
|
||||
return message.model_copy(update=update)
|
||||
@@ -80,6 +80,7 @@ class AioSandboxProvider(SandboxProvider):
|
||||
port: 8080 # Base port for local containers
|
||||
container_prefix: deer-flow-sandbox
|
||||
idle_timeout: 600 # Idle timeout in seconds (0 to disable)
|
||||
auto_restart: true # Restart crashed containers automatically
|
||||
replicas: 3 # Max concurrent sandbox containers (LRU eviction when exceeded)
|
||||
mounts: # Volume mounts for local containers
|
||||
- host_path: /path/on/host
|
||||
@@ -164,12 +165,14 @@ class AioSandboxProvider(SandboxProvider):
|
||||
|
||||
idle_timeout = getattr(sandbox_config, "idle_timeout", None)
|
||||
replicas = getattr(sandbox_config, "replicas", None)
|
||||
auto_restart = getattr(sandbox_config, "auto_restart", True)
|
||||
|
||||
return {
|
||||
"image": sandbox_config.image or DEFAULT_IMAGE,
|
||||
"port": sandbox_config.port or DEFAULT_PORT,
|
||||
"container_prefix": sandbox_config.container_prefix or DEFAULT_CONTAINER_PREFIX,
|
||||
"idle_timeout": idle_timeout if idle_timeout is not None else DEFAULT_IDLE_TIMEOUT,
|
||||
"auto_restart": auto_restart,
|
||||
"replicas": replicas if replicas is not None else DEFAULT_REPLICAS,
|
||||
"mounts": sandbox_config.mounts or [],
|
||||
"environment": self._resolve_env_vars(sandbox_config.environment or {}),
|
||||
@@ -608,18 +611,58 @@ class AioSandboxProvider(SandboxProvider):
|
||||
def get(self, sandbox_id: str) -> Sandbox | None:
|
||||
"""Get a sandbox by ID. Updates last activity timestamp.
|
||||
|
||||
When ``auto_restart`` is enabled (the default), the container's liveness
|
||||
is verified on each lookup. If the underlying container has crashed, the
|
||||
sandbox is evicted from all caches so that the next ``acquire()`` call will
|
||||
transparently create a fresh container.
|
||||
|
||||
Args:
|
||||
sandbox_id: The ID of the sandbox.
|
||||
|
||||
Returns:
|
||||
The sandbox instance if found, None otherwise.
|
||||
The sandbox instance if found and alive, None otherwise.
|
||||
"""
|
||||
with self._lock:
|
||||
sandbox = self._sandboxes.get(sandbox_id)
|
||||
if sandbox is not None:
|
||||
self._last_activity[sandbox_id] = time.time()
|
||||
if sandbox is None:
|
||||
return None
|
||||
self._last_activity[sandbox_id] = time.time()
|
||||
auto_restart = self._config.get("auto_restart", True)
|
||||
info = self._sandbox_infos.get(sandbox_id) if auto_restart else None
|
||||
|
||||
if not info:
|
||||
return sandbox
|
||||
|
||||
if self._backend.is_alive(info):
|
||||
return sandbox
|
||||
|
||||
info_to_destroy = None
|
||||
with self._lock:
|
||||
current_sandbox = self._sandboxes.get(sandbox_id)
|
||||
current_info = self._sandbox_infos.get(sandbox_id)
|
||||
if current_sandbox is None:
|
||||
return None
|
||||
if current_info is not info:
|
||||
self._last_activity[sandbox_id] = time.time()
|
||||
return current_sandbox
|
||||
|
||||
logger.warning(f"Sandbox {sandbox_id} container is not alive, evicting from cache for auto-restart")
|
||||
self._sandboxes.pop(sandbox_id, None)
|
||||
self._sandbox_infos.pop(sandbox_id, None)
|
||||
self._last_activity.pop(sandbox_id, None)
|
||||
self._warm_pool.pop(sandbox_id, None)
|
||||
thread_ids = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
|
||||
for tid in thread_ids:
|
||||
del self._thread_sandboxes[tid]
|
||||
info_to_destroy = info
|
||||
|
||||
if info_to_destroy:
|
||||
try:
|
||||
self._backend.destroy(info_to_destroy)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cleanup dead sandbox {sandbox_id}: {e}")
|
||||
return None
|
||||
|
||||
def release(self, sandbox_id: str) -> None:
|
||||
"""Release a sandbox from active use into the warm pool.
|
||||
|
||||
|
||||
@@ -84,8 +84,52 @@ class RemoteSandboxBackend(SandboxBackend):
|
||||
"""
|
||||
return self._provisioner_discover(sandbox_id)
|
||||
|
||||
def list_running(self) -> list[SandboxInfo]:
|
||||
"""Return all sandboxes currently managed by the provisioner.
|
||||
|
||||
Calls ``GET /api/sandboxes`` so that ``AioSandboxProvider._reconcile_orphans()``
|
||||
can adopt pods that were created by a previous process and were never
|
||||
explicitly destroyed.
|
||||
Without this, a process restart silently orphans all existing k8s Pods —
|
||||
they stay running forever because the idle checker only
|
||||
tracks in-process state.
|
||||
"""
|
||||
return self._provisioner_list()
|
||||
|
||||
# ── Provisioner API calls ─────────────────────────────────────────────
|
||||
|
||||
def _provisioner_list(self) -> list[SandboxInfo]:
|
||||
"""GET /api/sandboxes → list all running sandboxes."""
|
||||
try:
|
||||
resp = requests.get(f"{self._provisioner_url}/api/sandboxes", timeout=10)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
if not isinstance(data, dict):
|
||||
logger.warning("Provisioner list_running returned non-dict payload: %r", type(data))
|
||||
return []
|
||||
|
||||
sandboxes = data.get("sandboxes", [])
|
||||
if not isinstance(sandboxes, list):
|
||||
logger.warning("Provisioner list_running returned non-list sandboxes: %r", type(sandboxes))
|
||||
return []
|
||||
|
||||
infos: list[SandboxInfo] = []
|
||||
for sandbox in sandboxes:
|
||||
if not isinstance(sandbox, dict):
|
||||
logger.warning("Provisioner list_running entry is not a dict: %r", type(sandbox))
|
||||
continue
|
||||
|
||||
sandbox_id = sandbox.get("sandbox_id")
|
||||
sandbox_url = sandbox.get("sandbox_url")
|
||||
if isinstance(sandbox_id, str) and sandbox_id and isinstance(sandbox_url, str) and sandbox_url:
|
||||
infos.append(SandboxInfo(sandbox_id=sandbox_id, sandbox_url=sandbox_url))
|
||||
|
||||
logger.info("Provisioner list_running: %d sandbox(es) found", len(infos))
|
||||
return infos
|
||||
except requests.RequestException as exc:
|
||||
logger.warning("Provisioner list_running failed: %s", exc)
|
||||
return []
|
||||
|
||||
def _provisioner_create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
|
||||
"""POST /api/sandboxes → create Pod + Service."""
|
||||
try:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from .app_config import get_app_config
|
||||
from .extensions_config import ExtensionsConfig, get_extensions_config
|
||||
from .loop_detection_config import LoopDetectionConfig
|
||||
from .memory_config import MemoryConfig, get_memory_config
|
||||
from .paths import Paths, get_paths
|
||||
from .skill_evolution_config import SkillEvolutionConfig
|
||||
@@ -20,6 +21,7 @@ __all__ = [
|
||||
"SkillsConfig",
|
||||
"ExtensionsConfig",
|
||||
"get_extensions_config",
|
||||
"LoopDetectionConfig",
|
||||
"MemoryConfig",
|
||||
"get_memory_config",
|
||||
"get_tracing_config",
|
||||
|
||||
@@ -1,13 +1,22 @@
|
||||
"""Configuration and loaders for custom agents."""
|
||||
"""Configuration and loaders for custom agents.
|
||||
|
||||
Custom agents are stored per-user under ``{base_dir}/users/{user_id}/agents/{name}/``.
|
||||
A legacy shared layout at ``{base_dir}/agents/{name}/`` is still readable so that
|
||||
installations that pre-date user isolation continue to work until they run the
|
||||
``scripts/migrate_user_isolation.py`` migration. New writes always target the
|
||||
per-user layout.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -40,14 +49,47 @@ class AgentConfig(BaseModel):
|
||||
skills: list[str] | None = None
|
||||
|
||||
|
||||
def load_agent_config(name: str | None) -> AgentConfig | None:
|
||||
def resolve_agent_dir(name: str, *, user_id: str | None = None) -> Path:
|
||||
"""Return the on-disk directory for an agent, preferring the per-user layout.
|
||||
|
||||
Resolution order:
|
||||
1. ``{base_dir}/users/{user_id}/agents/{name}/`` (per-user, current layout).
|
||||
2. ``{base_dir}/agents/{name}/`` (legacy shared layout — read-only fallback).
|
||||
|
||||
If neither exists, the per-user path is returned so callers that intend to
|
||||
create the agent write into the new layout.
|
||||
|
||||
Args:
|
||||
name: Validated agent name.
|
||||
user_id: Owner of the agent. Defaults to the effective user from the
|
||||
request context (or ``"default"`` in no-auth mode).
|
||||
"""
|
||||
paths = get_paths()
|
||||
effective_user = user_id or get_effective_user_id()
|
||||
user_path = paths.user_agent_dir(effective_user, name)
|
||||
if user_path.exists():
|
||||
return user_path
|
||||
|
||||
legacy_path = paths.agent_dir(name)
|
||||
if legacy_path.exists():
|
||||
return legacy_path
|
||||
|
||||
return user_path
|
||||
|
||||
|
||||
def load_agent_config(name: str | None, *, user_id: str | None = None) -> AgentConfig | None:
|
||||
"""Load the custom or default agent's config from its directory.
|
||||
|
||||
Reads from the per-user layout first; falls back to the legacy shared layout
|
||||
for installations that have not yet been migrated.
|
||||
|
||||
Args:
|
||||
name: The agent name.
|
||||
user_id: Owner of the agent. Defaults to the effective user from the
|
||||
current request context.
|
||||
|
||||
Returns:
|
||||
AgentConfig instance.
|
||||
AgentConfig instance, or ``None`` if ``name`` is ``None``.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the agent directory or config.yaml does not exist.
|
||||
@@ -58,7 +100,7 @@ def load_agent_config(name: str | None) -> AgentConfig | None:
|
||||
return None
|
||||
|
||||
name = validate_agent_name(name)
|
||||
agent_dir = get_paths().agent_dir(name)
|
||||
agent_dir = resolve_agent_dir(name, user_id=user_id)
|
||||
config_file = agent_dir / "config.yaml"
|
||||
|
||||
if not agent_dir.exists():
|
||||
@@ -84,7 +126,7 @@ def load_agent_config(name: str | None) -> AgentConfig | None:
|
||||
return AgentConfig(**data)
|
||||
|
||||
|
||||
def load_agent_soul(agent_name: str | None) -> str | None:
|
||||
def load_agent_soul(agent_name: str | None, *, user_id: str | None = None) -> str | None:
|
||||
"""Read the SOUL.md file for a custom agent, if it exists.
|
||||
|
||||
SOUL.md defines the agent's personality, values, and behavioral guardrails.
|
||||
@@ -92,11 +134,16 @@ def load_agent_soul(agent_name: str | None) -> str | None:
|
||||
|
||||
Args:
|
||||
agent_name: The name of the agent or None for the default agent.
|
||||
user_id: Owner of the agent. Defaults to the effective user from the
|
||||
current request context.
|
||||
|
||||
Returns:
|
||||
The SOUL.md content as a string, or None if the file does not exist.
|
||||
"""
|
||||
agent_dir = get_paths().agent_dir(agent_name) if agent_name else get_paths().base_dir
|
||||
if agent_name:
|
||||
agent_dir = resolve_agent_dir(agent_name, user_id=user_id)
|
||||
else:
|
||||
agent_dir = get_paths().base_dir
|
||||
soul_path = agent_dir / SOUL_FILENAME
|
||||
if not soul_path.exists():
|
||||
return None
|
||||
@@ -104,32 +151,50 @@ def load_agent_soul(agent_name: str | None) -> str | None:
|
||||
return content or None
|
||||
|
||||
|
||||
def list_custom_agents() -> list[AgentConfig]:
|
||||
def list_custom_agents(*, user_id: str | None = None) -> list[AgentConfig]:
|
||||
"""Scan the agents directory and return all valid custom agents.
|
||||
|
||||
Returns the union of agents in the per-user layout and the legacy shared
|
||||
layout, so that pre-migration installations remain visible until they are
|
||||
migrated. Per-user entries shadow legacy entries with the same name.
|
||||
|
||||
Args:
|
||||
user_id: Owner whose agents to list. Defaults to the effective user
|
||||
from the current request context.
|
||||
|
||||
Returns:
|
||||
List of AgentConfig for each valid agent directory found.
|
||||
"""
|
||||
agents_dir = get_paths().agents_dir
|
||||
|
||||
if not agents_dir.exists():
|
||||
return []
|
||||
paths = get_paths()
|
||||
effective_user = user_id or get_effective_user_id()
|
||||
|
||||
seen: set[str] = set()
|
||||
agents: list[AgentConfig] = []
|
||||
|
||||
for entry in sorted(agents_dir.iterdir()):
|
||||
if not entry.is_dir():
|
||||
user_root = paths.user_agents_dir(effective_user)
|
||||
legacy_root = paths.agents_dir
|
||||
|
||||
for root in (user_root, legacy_root):
|
||||
if not root.exists():
|
||||
continue
|
||||
for entry in sorted(root.iterdir()):
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
if entry.name in seen:
|
||||
continue
|
||||
config_file = entry / "config.yaml"
|
||||
if not config_file.exists():
|
||||
logger.debug(f"Skipping {entry.name}: no config.yaml")
|
||||
continue
|
||||
|
||||
config_file = entry / "config.yaml"
|
||||
if not config_file.exists():
|
||||
logger.debug(f"Skipping {entry.name}: no config.yaml")
|
||||
continue
|
||||
|
||||
try:
|
||||
agent_cfg = load_agent_config(entry.name)
|
||||
agents.append(agent_cfg)
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping agent '{entry.name}': {e}")
|
||||
try:
|
||||
agent_cfg = load_agent_config(entry.name, user_id=effective_user)
|
||||
if agent_cfg is None:
|
||||
continue
|
||||
agents.append(agent_cfg)
|
||||
seen.add(entry.name)
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping agent '{entry.name}': {e}")
|
||||
|
||||
agents.sort(key=lambda a: a.name)
|
||||
return agents
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from contextvars import ContextVar
|
||||
from pathlib import Path
|
||||
from typing import Any, Self
|
||||
@@ -14,6 +15,7 @@ from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpo
|
||||
from deerflow.config.database_config import DatabaseConfig
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
|
||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
||||
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
|
||||
from deerflow.config.model_config import ModelConfig
|
||||
from deerflow.config.run_events_config import RunEventsConfig
|
||||
@@ -99,6 +101,7 @@ class AppConfig(BaseModel):
|
||||
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
|
||||
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
|
||||
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
|
||||
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
|
||||
model_config = ConfigDict(extra="allow")
|
||||
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
|
||||
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
|
||||
@@ -157,56 +160,54 @@ class AppConfig(BaseModel):
|
||||
config_data = cls.resolve_env_variables(config_data)
|
||||
cls._apply_database_defaults(config_data)
|
||||
|
||||
# Load title config if present
|
||||
if "title" in config_data:
|
||||
load_title_config_from_dict(config_data["title"])
|
||||
|
||||
# Load summarization config if present
|
||||
if "summarization" in config_data:
|
||||
load_summarization_config_from_dict(config_data["summarization"])
|
||||
|
||||
# Load memory config if present
|
||||
if "memory" in config_data:
|
||||
load_memory_config_from_dict(config_data["memory"])
|
||||
|
||||
# Always refresh agents API config so removed config sections reset
|
||||
# singleton-backed state to its default/disabled values on reload.
|
||||
load_agents_api_config_from_dict(config_data.get("agents_api") or {})
|
||||
|
||||
# Load subagents config if present
|
||||
if "subagents" in config_data:
|
||||
load_subagents_config_from_dict(config_data["subagents"])
|
||||
|
||||
# Load tool_search config if present
|
||||
if "tool_search" in config_data:
|
||||
load_tool_search_config_from_dict(config_data["tool_search"])
|
||||
|
||||
# Load guardrails config if present
|
||||
if "guardrails" in config_data:
|
||||
load_guardrails_config_from_dict(config_data["guardrails"])
|
||||
|
||||
# Load circuit_breaker config if present
|
||||
if "circuit_breaker" in config_data:
|
||||
config_data["circuit_breaker"] = config_data["circuit_breaker"]
|
||||
|
||||
# Load checkpointer config if present
|
||||
if "checkpointer" in config_data:
|
||||
load_checkpointer_config_from_dict(config_data["checkpointer"])
|
||||
|
||||
# Load stream bridge config if present
|
||||
if "stream_bridge" in config_data:
|
||||
load_stream_bridge_config_from_dict(config_data["stream_bridge"])
|
||||
|
||||
# Always refresh ACP agent config so removed entries do not linger across reloads.
|
||||
load_acp_config_from_dict(config_data.get("acp_agents", {}))
|
||||
|
||||
# Load extensions config separately (it's in a different file)
|
||||
extensions_config = ExtensionsConfig.from_file()
|
||||
config_data["extensions"] = extensions_config.model_dump()
|
||||
|
||||
result = cls.model_validate(config_data)
|
||||
acp_agents = cls._validate_acp_agents(config_data.get("acp_agents", {}))
|
||||
cls._apply_singleton_configs(result, acp_agents)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _validate_acp_agents(
|
||||
cls,
|
||||
config_data: Mapping[str, Mapping[str, object]] | None,
|
||||
) -> dict[str, ACPAgentConfig]:
|
||||
if config_data is None:
|
||||
config_data = {}
|
||||
return {name: ACPAgentConfig(**cfg) for name, cfg in config_data.items()}
|
||||
|
||||
@classmethod
|
||||
def _apply_singleton_configs(cls, config: Self, acp_agents: dict[str, ACPAgentConfig]) -> None:
|
||||
from deerflow.config.checkpointer_config import get_checkpointer_config
|
||||
|
||||
previous_checkpointer_config = get_checkpointer_config()
|
||||
|
||||
load_title_config_from_dict(config.title.model_dump())
|
||||
load_summarization_config_from_dict(config.summarization.model_dump())
|
||||
load_memory_config_from_dict(config.memory.model_dump())
|
||||
load_agents_api_config_from_dict(config.agents_api.model_dump())
|
||||
load_subagents_config_from_dict(config.subagents.model_dump())
|
||||
load_tool_search_config_from_dict(config.tool_search.model_dump())
|
||||
load_guardrails_config_from_dict(config.guardrails.model_dump())
|
||||
load_checkpointer_config_from_dict(config.checkpointer.model_dump() if config.checkpointer is not None else None)
|
||||
load_stream_bridge_config_from_dict(config.stream_bridge.model_dump() if config.stream_bridge is not None else None)
|
||||
load_acp_config_from_dict({name: agent.model_dump() for name, agent in acp_agents.items()})
|
||||
|
||||
if previous_checkpointer_config != config.checkpointer:
|
||||
# These runtime singletons derive their backend from checkpointer config.
|
||||
# Keep imports local to avoid cycles: both providers import get_app_config.
|
||||
from deerflow.runtime.checkpointer import reset_checkpointer
|
||||
from deerflow.runtime.store import reset_store
|
||||
|
||||
reset_checkpointer()
|
||||
reset_store()
|
||||
|
||||
@classmethod
|
||||
def _apply_database_defaults(cls, config_data: dict[str, Any]) -> None:
|
||||
"""Apply config.yaml defaults for persistence when the section is absent."""
|
||||
|
||||
@@ -14,12 +14,13 @@ class CheckpointerConfig(BaseModel):
|
||||
description="Checkpointer backend type. "
|
||||
"'memory' is in-process only (lost on restart). "
|
||||
"'sqlite' persists to a local file (requires langgraph-checkpoint-sqlite). "
|
||||
"'postgres' persists to PostgreSQL (requires langgraph-checkpoint-postgres)."
|
||||
"'postgres' persists to PostgreSQL (install with deerflow-harness[postgres])."
|
||||
)
|
||||
connection_string: str | None = Field(
|
||||
default=None,
|
||||
description="Connection string for sqlite (file path) or postgres (DSN). "
|
||||
"Required for sqlite and postgres types. "
|
||||
"Optional for sqlite and defaults to 'store.db' when omitted. "
|
||||
"Required for postgres. "
|
||||
"For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. "
|
||||
"For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.",
|
||||
)
|
||||
@@ -40,7 +41,10 @@ def set_checkpointer_config(config: CheckpointerConfig | None) -> None:
|
||||
_checkpointer_config = config
|
||||
|
||||
|
||||
def load_checkpointer_config_from_dict(config_dict: dict) -> None:
|
||||
def load_checkpointer_config_from_dict(config_dict: dict | None) -> None:
|
||||
"""Load checkpointer configuration from a dictionary."""
|
||||
global _checkpointer_config
|
||||
if config_dict is None:
|
||||
_checkpointer_config = None
|
||||
return
|
||||
_checkpointer_config = CheckpointerConfig(**config_dict)
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
"""Configuration for loop detection middleware."""
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class ToolFreqOverride(BaseModel):
|
||||
"""Per-tool frequency threshold override.
|
||||
|
||||
Can be higher or lower than the global defaults. Commonly used to raise
|
||||
thresholds for high-frequency tools like bash in batch workflows (e.g.
|
||||
RNA-seq pipelines) without weakening protection on every other tool.
|
||||
"""
|
||||
|
||||
warn: int = Field(ge=1)
|
||||
hard_limit: int = Field(ge=1)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate(self) -> "ToolFreqOverride":
|
||||
if self.hard_limit < self.warn:
|
||||
raise ValueError("hard_limit must be >= warn")
|
||||
return self
|
||||
|
||||
|
||||
class LoopDetectionConfig(BaseModel):
|
||||
"""Configuration for repetitive tool-call loop detection."""
|
||||
|
||||
enabled: bool = Field(
|
||||
default=True,
|
||||
description="Whether to enable repetitive tool-call loop detection",
|
||||
)
|
||||
warn_threshold: int = Field(
|
||||
default=3,
|
||||
ge=1,
|
||||
description="Number of identical tool-call sets before injecting a warning",
|
||||
)
|
||||
hard_limit: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
description="Number of identical tool-call sets before forcing a stop",
|
||||
)
|
||||
window_size: int = Field(
|
||||
default=20,
|
||||
ge=1,
|
||||
description="Number of recent tool-call sets to track per thread",
|
||||
)
|
||||
max_tracked_threads: int = Field(
|
||||
default=100,
|
||||
ge=1,
|
||||
description="Maximum number of thread histories to keep in memory",
|
||||
)
|
||||
tool_freq_warn: int = Field(
|
||||
default=30,
|
||||
ge=1,
|
||||
description="Number of calls to the same tool type before injecting a frequency warning",
|
||||
)
|
||||
tool_freq_hard_limit: int = Field(
|
||||
default=50,
|
||||
ge=1,
|
||||
description="Number of calls to the same tool type before forcing a stop",
|
||||
)
|
||||
tool_freq_overrides: dict[str, ToolFreqOverride] = Field(
|
||||
default_factory=dict,
|
||||
description=("Per-tool overrides for tool_freq_warn / tool_freq_hard_limit, keyed by tool name. Values can be higher or lower than the global defaults. Commonly used to raise thresholds for high-frequency tools like bash."),
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_thresholds(self) -> "LoopDetectionConfig":
|
||||
"""Ensure hard stop cannot happen before the warning threshold."""
|
||||
if self.hard_limit < self.warn_threshold:
|
||||
raise ValueError("hard_limit must be greater than or equal to warn_threshold")
|
||||
if self.tool_freq_hard_limit < self.tool_freq_warn:
|
||||
raise ValueError("tool_freq_hard_limit must be greater than or equal to tool_freq_warn")
|
||||
return self
|
||||
@@ -132,15 +132,20 @@ class Paths:
|
||||
|
||||
@property
|
||||
def agents_dir(self) -> Path:
|
||||
"""Root directory for all custom agents: `{base_dir}/agents/`."""
|
||||
"""Legacy root for shared (pre user-isolation) custom agents: `{base_dir}/agents/`.
|
||||
|
||||
New code should use :meth:`user_agents_dir` instead. This property remains
|
||||
only as a read-side fallback for installations that have not yet run the
|
||||
``migrate_user_isolation.py`` script.
|
||||
"""
|
||||
return self.base_dir / "agents"
|
||||
|
||||
def agent_dir(self, name: str) -> Path:
|
||||
"""Directory for a specific agent: `{base_dir}/agents/{name}/`."""
|
||||
"""Legacy per-agent directory (no user isolation): `{base_dir}/agents/{name}/`."""
|
||||
return self.agents_dir / name.lower()
|
||||
|
||||
def agent_memory_file(self, name: str) -> Path:
|
||||
"""Per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
|
||||
"""Legacy per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
|
||||
return self.agent_dir(name) / "memory.json"
|
||||
|
||||
def user_dir(self, user_id: str) -> Path:
|
||||
@@ -151,9 +156,17 @@ class Paths:
|
||||
"""Per-user memory file: `{base_dir}/users/{user_id}/memory.json`."""
|
||||
return self.user_dir(user_id) / "memory.json"
|
||||
|
||||
def user_agents_dir(self, user_id: str) -> Path:
|
||||
"""Per-user root for that user's custom agents: `{base_dir}/users/{user_id}/agents/`."""
|
||||
return self.user_dir(user_id) / "agents"
|
||||
|
||||
def user_agent_dir(self, user_id: str, agent_name: str) -> Path:
|
||||
"""Per-user per-agent directory: `{base_dir}/users/{user_id}/agents/{name}/`."""
|
||||
return self.user_agents_dir(user_id) / agent_name.lower()
|
||||
|
||||
def user_agent_memory_file(self, user_id: str, agent_name: str) -> Path:
|
||||
"""Per-user per-agent memory: `{base_dir}/users/{user_id}/agents/{name}/memory.json`."""
|
||||
return self.user_dir(user_id) / "agents" / agent_name.lower() / "memory.json"
|
||||
return self.user_agent_dir(user_id, agent_name) / "memory.json"
|
||||
|
||||
def thread_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
||||
"""
|
||||
|
||||
@@ -23,6 +23,9 @@ class SandboxConfig(BaseModel):
|
||||
replicas: Maximum number of concurrent sandbox containers (default: 3). When the limit is reached the least-recently-used sandbox is evicted to make room.
|
||||
container_prefix: Prefix for container names (default: deer-flow-sandbox)
|
||||
idle_timeout: Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.
|
||||
auto_restart: Automatically restart sandbox containers that have crashed (default: true). When a tool call
|
||||
detects the container is no longer alive, the sandbox is evicted from cache and transparently recreated
|
||||
on the next acquire. Set to false to disable.
|
||||
mounts: List of volume mounts to share directories with the container
|
||||
environment: Environment variables to inject into the container (values starting with $ are resolved from host env)
|
||||
"""
|
||||
@@ -55,6 +58,10 @@ class SandboxConfig(BaseModel):
|
||||
default=None,
|
||||
description="Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.",
|
||||
)
|
||||
auto_restart: bool = Field(
|
||||
default=True,
|
||||
description="Automatically restart sandbox containers that have crashed. When a tool call detects the container is no longer alive, the sandbox is evicted from cache and transparently recreated on the next acquire.",
|
||||
)
|
||||
mounts: list[VolumeMountConfig] = Field(
|
||||
default_factory=list,
|
||||
description="List of volume mounts to share directories between host and container",
|
||||
|
||||
@@ -40,7 +40,10 @@ def set_stream_bridge_config(config: StreamBridgeConfig | None) -> None:
|
||||
_stream_bridge_config = config
|
||||
|
||||
|
||||
def load_stream_bridge_config_from_dict(config_dict: dict) -> None:
|
||||
def load_stream_bridge_config_from_dict(config_dict: dict | None) -> None:
|
||||
"""Load stream bridge configuration from a dictionary."""
|
||||
global _stream_bridge_config
|
||||
if config_dict is None:
|
||||
_stream_bridge_config = None
|
||||
return
|
||||
_stream_bridge_config = StreamBridgeConfig(**config_dict)
|
||||
|
||||
@@ -179,9 +179,3 @@ def load_subagents_config_from_dict(config_dict: dict) -> None:
|
||||
overrides_summary or "none",
|
||||
custom_agents_names or "none",
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Subagents config loaded: default timeout=%ss, default max_turns=%s, no per-agent overrides",
|
||||
_subagents_config.timeout_seconds,
|
||||
_subagents_config.max_turns,
|
||||
)
|
||||
|
||||
@@ -4,4 +4,4 @@ from pydantic import BaseModel, Field
|
||||
class TokenUsageConfig(BaseModel):
|
||||
"""Configuration for token usage tracking."""
|
||||
|
||||
enabled: bool = Field(default=False, description="Enable token usage tracking middleware")
|
||||
enabled: bool = Field(default=True, description="Enable token usage tracking middleware")
|
||||
|
||||
@@ -196,6 +196,10 @@ class ClaudeChatModel(ChatAnthropic):
|
||||
enforced by both the Anthropic API and AWS Bedrock. Breakpoints are
|
||||
placed on the *last* eligible blocks because later breakpoints cover a
|
||||
larger prefix and yield better cache hit rates.
|
||||
|
||||
The system prompt is expected to be fully static (no per-user memory or
|
||||
current date). Dynamic context is injected per-turn via
|
||||
DynamicContextMiddleware as a <system-reminder> in the first HumanMessage.
|
||||
"""
|
||||
MAX_CACHE_BREAKPOINTS = 4
|
||||
|
||||
|
||||
@@ -81,7 +81,16 @@ async def init_engine(
|
||||
try:
|
||||
import asyncpg # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError("database.backend is set to 'postgres' but asyncpg is not installed.\nInstall it with:\n uv sync --extra postgres\nOr switch to backend: sqlite in config.yaml for single-node deployment.") from None
|
||||
raise ImportError(
|
||||
"database.backend is set to 'postgres' but asyncpg is not installed.\n"
|
||||
"Install it with:\n"
|
||||
" cd backend && uv sync --all-packages --extra postgres\n"
|
||||
"On the next `make dev` the postgres extra is auto-detected from\n"
|
||||
"config.yaml (database.backend: postgres) and reinstalled, so it\n"
|
||||
"will not be wiped again. Set UV_EXTRAS=postgres in .env to opt in\n"
|
||||
"explicitly. Or switch to backend: sqlite in config.yaml for\n"
|
||||
"single-node deployment."
|
||||
) from None
|
||||
|
||||
if backend == "sqlite":
|
||||
import os
|
||||
|
||||
@@ -36,7 +36,9 @@ logger = logging.getLogger(__name__)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SQLITE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite checkpointer. Install it with: uv add langgraph-checkpoint-sqlite"
|
||||
POSTGRES_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL checkpointer. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool"
|
||||
POSTGRES_INSTALL = (
|
||||
"langgraph-checkpoint-postgres is required for the PostgreSQL checkpointer. Install the package extra with: pip install 'deerflow-harness[postgres]' (or use: uv sync --all-packages --extra postgres when developing locally)"
|
||||
)
|
||||
POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -9,6 +9,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
@@ -33,20 +34,21 @@ class DbRunEventStore(RunEventStore):
|
||||
if isinstance(val, datetime):
|
||||
d["created_at"] = val.isoformat()
|
||||
d.pop("id", None)
|
||||
# Restore dict content that was JSON-serialized on write
|
||||
# Restore structured content that was JSON-serialized on write.
|
||||
raw = d.get("content", "")
|
||||
if isinstance(raw, str) and d.get("metadata", {}).get("content_is_dict"):
|
||||
metadata = d.get("metadata", {})
|
||||
if isinstance(raw, str) and (metadata.get("content_is_json") or metadata.get("content_is_dict")):
|
||||
try:
|
||||
d["content"] = json.loads(raw)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
# Content looked like JSON (content_is_dict flag) but failed to parse;
|
||||
# Content looked like JSON but failed to parse;
|
||||
# keep the raw string as-is.
|
||||
logger.debug("Failed to deserialize content as JSON for event seq=%s", d.get("seq"))
|
||||
return d
|
||||
|
||||
def _truncate_trace(self, category: str, content: str | dict, metadata: dict | None) -> tuple[str | dict, dict]:
|
||||
def _truncate_trace(self, category: str, content: Any, metadata: dict | None) -> tuple[Any, dict]:
|
||||
if category == "trace":
|
||||
text = json.dumps(content, default=str, ensure_ascii=False) if isinstance(content, dict) else content
|
||||
text = content if isinstance(content, str) else json.dumps(content, default=str, ensure_ascii=False)
|
||||
encoded = text.encode("utf-8")
|
||||
if len(encoded) > self._max_trace_content:
|
||||
# Truncate by bytes, then decode back (may cut a multi-byte char, so use errors="ignore")
|
||||
@@ -54,6 +56,18 @@ class DbRunEventStore(RunEventStore):
|
||||
metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)}
|
||||
return content, metadata or {}
|
||||
|
||||
@staticmethod
|
||||
def _content_to_db(content: Any, metadata: dict | None) -> tuple[str, dict]:
|
||||
metadata = metadata or {}
|
||||
if isinstance(content, str):
|
||||
return content, metadata
|
||||
|
||||
db_content = json.dumps(content, default=str, ensure_ascii=False)
|
||||
metadata = {**metadata, "content_is_json": True}
|
||||
if isinstance(content, dict):
|
||||
metadata["content_is_dict"] = True
|
||||
return db_content, metadata
|
||||
|
||||
@staticmethod
|
||||
def _user_id_from_context() -> str | None:
|
||||
"""Soft read of user_id from contextvar for write paths.
|
||||
@@ -82,11 +96,7 @@ class DbRunEventStore(RunEventStore):
|
||||
the initial ``human_message`` event (once per run).
|
||||
"""
|
||||
content, metadata = self._truncate_trace(category, content, metadata)
|
||||
if isinstance(content, dict):
|
||||
db_content = json.dumps(content, default=str, ensure_ascii=False)
|
||||
metadata = {**(metadata or {}), "content_is_dict": True}
|
||||
else:
|
||||
db_content = content
|
||||
db_content, metadata = self._content_to_db(content, metadata)
|
||||
user_id = self._user_id_from_context()
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
@@ -128,11 +138,7 @@ class DbRunEventStore(RunEventStore):
|
||||
category = e.get("category", "trace")
|
||||
metadata = e.get("metadata")
|
||||
content, metadata = self._truncate_trace(category, content, metadata)
|
||||
if isinstance(content, dict):
|
||||
db_content = json.dumps(content, default=str, ensure_ascii=False)
|
||||
metadata = {**(metadata or {}), "content_is_dict": True}
|
||||
else:
|
||||
db_content = content
|
||||
db_content, metadata = self._content_to_db(content, metadata)
|
||||
row = RunEventRow(
|
||||
thread_id=e["thread_id"],
|
||||
run_id=e["run_id"],
|
||||
|
||||
@@ -36,7 +36,9 @@ logger = logging.getLogger(__name__)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SQLITE_STORE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite store. Install it with: uv add langgraph-checkpoint-sqlite"
|
||||
POSTGRES_STORE_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL store. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool"
|
||||
POSTGRES_STORE_INSTALL = (
|
||||
"langgraph-checkpoint-postgres is required for the PostgreSQL store. Install the package extra with: pip install 'deerflow-harness[postgres]' (or use: uv sync --all-packages --extra postgres when developing locally)"
|
||||
)
|
||||
POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -42,6 +42,13 @@ class LocalSandbox(Sandbox):
|
||||
"""Return whether the selected shell is cmd.exe."""
|
||||
return LocalSandbox._shell_name(shell) in {"cmd", "cmd.exe"}
|
||||
|
||||
@staticmethod
|
||||
def _is_msys_shell(shell: str) -> bool:
|
||||
"""Return whether the selected shell is a Git Bash/MSYS shell."""
|
||||
normalized = shell.replace("\\", "/").lower()
|
||||
shell_name = LocalSandbox._shell_name(shell)
|
||||
return shell_name in {"sh.exe", "bash.exe"} and any(part in normalized for part in ("/git/", "/mingw", "/msys"))
|
||||
|
||||
@staticmethod
|
||||
def _find_first_available_shell(candidates: tuple[str, ...]) -> str | None:
|
||||
"""Return the first executable shell path or command found from candidates."""
|
||||
@@ -303,12 +310,19 @@ class LocalSandbox(Sandbox):
|
||||
shell = self._get_shell()
|
||||
|
||||
if os.name == "nt":
|
||||
env = None
|
||||
if self._is_powershell(shell):
|
||||
args = [shell, "-NoProfile", "-Command", resolved_command]
|
||||
elif self._is_cmd_shell(shell):
|
||||
args = [shell, "/c", resolved_command]
|
||||
else:
|
||||
args = [shell, "-c", resolved_command]
|
||||
if self._is_msys_shell(shell):
|
||||
env = {
|
||||
**os.environ,
|
||||
"MSYS_NO_PATHCONV": "1",
|
||||
"MSYS2_ARG_CONV_EXCL": "*",
|
||||
}
|
||||
|
||||
result = subprocess.run(
|
||||
args,
|
||||
@@ -316,6 +330,7 @@ class LocalSandbox(Sandbox):
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=600,
|
||||
env=env,
|
||||
)
|
||||
else:
|
||||
args = [shell, "-c", resolved_command]
|
||||
|
||||
@@ -3,10 +3,9 @@ import re
|
||||
import shlex
|
||||
from pathlib import Path
|
||||
|
||||
from langchain.tools import ToolRuntime, tool
|
||||
from langgraph.typing import ContextT
|
||||
from langchain.tools import tool
|
||||
|
||||
from deerflow.agents.thread_state import ThreadDataState, ThreadState
|
||||
from deerflow.agents.thread_state import ThreadDataState
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
|
||||
from deerflow.sandbox.exceptions import (
|
||||
@@ -19,6 +18,7 @@ from deerflow.sandbox.sandbox import Sandbox
|
||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||
from deerflow.sandbox.search import GrepMatch
|
||||
from deerflow.sandbox.security import LOCAL_HOST_BASH_DISABLED_MESSAGE, is_host_bash_allowed
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
_ABSOLUTE_PATH_PATTERN = re.compile(r"(?<![:\w])(?<!:/)/(?:[^\s\"'`;&|<>()]+)")
|
||||
_FILE_URL_PATTERN = re.compile(r"\bfile://\S+", re.IGNORECASE)
|
||||
@@ -419,7 +419,7 @@ def _join_path_preserving_style(base: str, relative: str) -> str:
|
||||
return f"{stripped_base}{separator}{normalized_relative}"
|
||||
|
||||
|
||||
def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadState] | None" = None) -> str:
|
||||
def _sanitize_error(error: Exception, runtime: Runtime | None = None) -> str:
|
||||
"""Sanitize an error message to avoid leaking host filesystem paths.
|
||||
|
||||
In local-sandbox mode, resolved host paths in the error string are masked
|
||||
@@ -994,7 +994,7 @@ def _apply_cwd_prefix(command: str, thread_data: ThreadDataState | None) -> str:
|
||||
return command
|
||||
|
||||
|
||||
def get_thread_data(runtime: ToolRuntime[ContextT, ThreadState] | None) -> ThreadDataState | None:
|
||||
def get_thread_data(runtime: Runtime | None) -> ThreadDataState | None:
|
||||
"""Extract thread_data from runtime state."""
|
||||
if runtime is None:
|
||||
return None
|
||||
@@ -1003,7 +1003,7 @@ def get_thread_data(runtime: ToolRuntime[ContextT, ThreadState] | None) -> Threa
|
||||
return runtime.state.get("thread_data")
|
||||
|
||||
|
||||
def is_local_sandbox(runtime: ToolRuntime[ContextT, ThreadState] | None) -> bool:
|
||||
def is_local_sandbox(runtime: Runtime | None) -> bool:
|
||||
"""Check if the current sandbox is a local sandbox.
|
||||
|
||||
Path replacement is only needed for local sandbox since aio sandbox
|
||||
@@ -1019,7 +1019,7 @@ def is_local_sandbox(runtime: ToolRuntime[ContextT, ThreadState] | None) -> bool
|
||||
return sandbox_state.get("sandbox_id") == "local"
|
||||
|
||||
|
||||
def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = None) -> Sandbox:
|
||||
def sandbox_from_runtime(runtime: Runtime | None = None) -> Sandbox:
|
||||
"""Extract sandbox instance from tool runtime.
|
||||
|
||||
DEPRECATED: Use ensure_sandbox_initialized() for lazy initialization support.
|
||||
@@ -1048,7 +1048,7 @@ def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = No
|
||||
return sandbox
|
||||
|
||||
|
||||
def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | None = None) -> Sandbox:
|
||||
def ensure_sandbox_initialized(runtime: Runtime | None = None) -> Sandbox:
|
||||
"""Ensure sandbox is initialized, acquiring lazily if needed.
|
||||
|
||||
On first call, acquires a sandbox from the provider and stores it in runtime state.
|
||||
@@ -1107,7 +1107,7 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
|
||||
return sandbox
|
||||
|
||||
|
||||
def ensure_thread_directories_exist(runtime: ToolRuntime[ContextT, ThreadState] | None) -> None:
|
||||
def ensure_thread_directories_exist(runtime: Runtime | None) -> None:
|
||||
"""Ensure thread data directories (workspace, uploads, outputs) exist.
|
||||
|
||||
This function is called lazily when any sandbox tool is first used.
|
||||
@@ -1221,7 +1221,7 @@ def _truncate_ls_output(output: str, max_chars: int) -> str:
|
||||
|
||||
|
||||
@tool("bash", parse_docstring=True)
|
||||
def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, command: str) -> str:
|
||||
def bash_tool(runtime: Runtime, description: str, command: str) -> str:
|
||||
"""Execute a bash command in a Linux environment.
|
||||
|
||||
|
||||
@@ -1270,7 +1270,7 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
|
||||
|
||||
|
||||
@tool("ls", parse_docstring=True)
|
||||
def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path: str) -> str:
|
||||
def ls_tool(runtime: Runtime, description: str, path: str) -> str:
|
||||
"""List the contents of a directory up to 2 levels deep in tree format.
|
||||
|
||||
Args:
|
||||
@@ -1318,7 +1318,7 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
|
||||
|
||||
@tool("glob", parse_docstring=True)
|
||||
def glob_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
description: str,
|
||||
pattern: str,
|
||||
path: str,
|
||||
@@ -1368,7 +1368,7 @@ def glob_tool(
|
||||
|
||||
@tool("grep", parse_docstring=True)
|
||||
def grep_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
description: str,
|
||||
pattern: str,
|
||||
path: str,
|
||||
@@ -1438,7 +1438,7 @@ def grep_tool(
|
||||
|
||||
@tool("read_file", parse_docstring=True)
|
||||
def read_file_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
description: str,
|
||||
path: str,
|
||||
start_line: int | None = None,
|
||||
@@ -1493,7 +1493,7 @@ def read_file_tool(
|
||||
|
||||
@tool("write_file", parse_docstring=True)
|
||||
def write_file_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
description: str,
|
||||
path: str,
|
||||
content: str,
|
||||
@@ -1533,7 +1533,7 @@ def write_file_tool(
|
||||
|
||||
@tool("str_replace", parse_docstring=True)
|
||||
def str_replace_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
description: str,
|
||||
path: str,
|
||||
old_str: str,
|
||||
|
||||
@@ -9,6 +9,29 @@ from .types import SKILL_MD_FILE, Skill, SkillCategory
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_allowed_tools(raw: object, skill_file: Path) -> list[str] | None:
|
||||
"""Parse the optional allowed-tools frontmatter field.
|
||||
|
||||
Returns None when the field is omitted. Returns a list when the field is a
|
||||
YAML sequence of strings, including an empty list for explicit no-tool
|
||||
skills. Raises ValueError for malformed values.
|
||||
"""
|
||||
if raw is None:
|
||||
return None
|
||||
if not isinstance(raw, list):
|
||||
raise ValueError(f"allowed-tools in {skill_file} must be a list of strings")
|
||||
|
||||
allowed_tools: list[str] = []
|
||||
for item in raw:
|
||||
if not isinstance(item, str):
|
||||
raise ValueError(f"allowed-tools in {skill_file} must contain only strings")
|
||||
tool_name = item.strip()
|
||||
if not tool_name:
|
||||
raise ValueError(f"allowed-tools in {skill_file} cannot contain empty tool names")
|
||||
allowed_tools.append(tool_name)
|
||||
return allowed_tools
|
||||
|
||||
|
||||
def parse_skill_file(skill_file: Path, category: SkillCategory, relative_path: Path | None = None) -> Skill | None:
|
||||
"""Parse a SKILL.md file and extract metadata.
|
||||
|
||||
@@ -64,6 +87,12 @@ def parse_skill_file(skill_file: Path, category: SkillCategory, relative_path: P
|
||||
if license_text is not None:
|
||||
license_text = str(license_text).strip() or None
|
||||
|
||||
try:
|
||||
allowed_tools = parse_allowed_tools(metadata.get("allowed-tools"), skill_file)
|
||||
except ValueError as exc:
|
||||
logger.error("Invalid allowed-tools in %s: %s", skill_file, exc)
|
||||
return None
|
||||
|
||||
return Skill(
|
||||
name=name,
|
||||
description=description,
|
||||
@@ -72,6 +101,7 @@ def parse_skill_file(skill_file: Path, category: SkillCategory, relative_path: P
|
||||
skill_file=skill_file,
|
||||
relative_path=relative_path or Path(skill_file.parent.name),
|
||||
category=category,
|
||||
allowed_tools=allowed_tools,
|
||||
enabled=True, # Actual state comes from the extensions config file.
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
import logging
|
||||
from typing import Protocol
|
||||
|
||||
from deerflow.skills.types import Skill
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NamedTool(Protocol):
|
||||
name: str
|
||||
|
||||
|
||||
def allowed_tool_names_for_skills(skills: list[Skill]) -> set[str] | None:
|
||||
"""Return the union of explicit skill allowed-tools declarations.
|
||||
|
||||
None means legacy allow-all behavior. It is returned only when no loaded
|
||||
skill declares allowed-tools. Once any skill declares the field, legacy
|
||||
skills without the field contribute no tools instead of disabling the
|
||||
explicit restrictions from other skills.
|
||||
"""
|
||||
if not skills:
|
||||
return None
|
||||
|
||||
allowed: set[str] = set()
|
||||
has_explicit_declaration = False
|
||||
for skill in skills:
|
||||
if skill.allowed_tools is None:
|
||||
continue
|
||||
has_explicit_declaration = True
|
||||
if not skill.allowed_tools:
|
||||
logger.info("Skill %s declared empty allowed-tools", skill.name)
|
||||
allowed.update(skill.allowed_tools)
|
||||
|
||||
if not has_explicit_declaration:
|
||||
return None
|
||||
return allowed
|
||||
|
||||
|
||||
def filter_tools_by_skill_allowed_tools[ToolT: NamedTool](tools: list[ToolT], skills: list[Skill]) -> list[ToolT]:
|
||||
allowed = allowed_tool_names_for_skills(skills)
|
||||
if allowed is None:
|
||||
return tools
|
||||
|
||||
return [tool for tool in tools if tool.name in allowed]
|
||||
@@ -27,6 +27,7 @@ class Skill:
|
||||
skill_file: Path
|
||||
relative_path: Path # Relative path from category root to skill directory
|
||||
category: SkillCategory # 'public' or 'custom'
|
||||
allowed_tools: list[str] | None = None
|
||||
enabled: bool = False # Whether this skill is enabled
|
||||
|
||||
@property
|
||||
|
||||
@@ -8,6 +8,7 @@ from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from deerflow.skills.parser import parse_allowed_tools
|
||||
from deerflow.skills.types import SKILL_MD_FILE
|
||||
|
||||
# Allowed properties in SKILL.md frontmatter
|
||||
@@ -84,4 +85,9 @@ def _validate_skill_frontmatter(skill_dir: Path) -> tuple[bool, str, str | None]
|
||||
if len(description) > 1024:
|
||||
return False, f"Description is too long ({len(description)} characters). Maximum is 1024 characters.", None
|
||||
|
||||
try:
|
||||
parse_allowed_tools(frontmatter.get("allowed-tools"), skill_md)
|
||||
except ValueError as e:
|
||||
return False, str(e).replace(str(skill_md), SKILL_MD_FILE), None
|
||||
|
||||
return True, "Skill is valid!", name
|
||||
|
||||
@@ -23,6 +23,8 @@ from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadSt
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.models import create_chat_model
|
||||
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
|
||||
from deerflow.skills.types import Skill
|
||||
from deerflow.subagents.config import SubagentConfig, resolve_subagent_model_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -260,16 +262,16 @@ class SubagentExecutor:
|
||||
# Generate trace_id if not provided (for top-level calls)
|
||||
self.trace_id = trace_id or str(uuid.uuid4())[:8]
|
||||
|
||||
# Filter tools based on config
|
||||
self.tools = _filter_tools(
|
||||
self._base_tools = _filter_tools(
|
||||
tools,
|
||||
config.tools,
|
||||
config.disallowed_tools,
|
||||
)
|
||||
self.tools = self._base_tools
|
||||
|
||||
logger.info(f"[trace={self.trace_id}] SubagentExecutor initialized: {config.name} with {len(self.tools)} tools")
|
||||
|
||||
def _create_agent(self):
|
||||
def _create_agent(self, tools: list[BaseTool] | None = None):
|
||||
"""Create the agent instance."""
|
||||
app_config = self.app_config or get_app_config()
|
||||
if self.model_name is None:
|
||||
@@ -283,26 +285,14 @@ class SubagentExecutor:
|
||||
|
||||
return create_agent(
|
||||
model=model,
|
||||
tools=self.tools,
|
||||
tools=tools if tools is not None else self.tools,
|
||||
middleware=middlewares,
|
||||
system_prompt=self.config.system_prompt,
|
||||
state_schema=ThreadState,
|
||||
)
|
||||
|
||||
async def _load_skill_messages(self) -> list[SystemMessage]:
|
||||
"""Load skill content as conversation items based on config.skills.
|
||||
|
||||
Aligned with Codex's pattern: each subagent loads its own skills
|
||||
per-session and injects them as conversation items (developer messages),
|
||||
not as system prompt text. The config.skills whitelist controls which
|
||||
skills are loaded:
|
||||
- None: load all enabled skills
|
||||
- []: no skills
|
||||
- ["skill-a", "skill-b"]: only these skills
|
||||
|
||||
Returns:
|
||||
List of SystemMessages containing skill content.
|
||||
"""
|
||||
async def _load_skills(self) -> list[Skill]:
|
||||
"""Load enabled skill metadata based on config.skills."""
|
||||
if self.config.skills is not None and len(self.config.skills) == 0:
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} skills=[] — skipping skill loading")
|
||||
return []
|
||||
@@ -316,8 +306,8 @@ class SubagentExecutor:
|
||||
all_skills = await asyncio.to_thread(storage.load_skills, enabled_only=True)
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} loaded {len(all_skills)} enabled skills from disk")
|
||||
except Exception:
|
||||
logger.warning(f"[trace={self.trace_id}] Failed to load skills for subagent {self.config.name}", exc_info=True)
|
||||
return []
|
||||
logger.exception(f"[trace={self.trace_id}] Failed to load skills for subagent {self.config.name}")
|
||||
raise
|
||||
|
||||
if not all_skills:
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} no enabled skills found")
|
||||
@@ -326,10 +316,26 @@ class SubagentExecutor:
|
||||
# Filter by config.skills whitelist
|
||||
if self.config.skills is not None:
|
||||
allowed = set(self.config.skills)
|
||||
skills = [s for s in all_skills if s.name in allowed]
|
||||
else:
|
||||
skills = all_skills
|
||||
return [s for s in all_skills if s.name in allowed]
|
||||
return all_skills
|
||||
|
||||
def _apply_skill_allowed_tools(self, skills: list[Skill]) -> list[BaseTool]:
|
||||
return filter_tools_by_skill_allowed_tools(self._base_tools, skills)
|
||||
|
||||
async def _load_skill_messages(self, skills: list[Skill]) -> list[SystemMessage]:
|
||||
"""Load skill content as conversation items based on config.skills.
|
||||
|
||||
Aligned with Codex's pattern: each subagent loads its own skills
|
||||
per-session and injects them as conversation items (developer messages),
|
||||
not as system prompt text. The config.skills whitelist controls which
|
||||
skills are loaded:
|
||||
- None: load all enabled skills
|
||||
- []: no skills
|
||||
- ["skill-a", "skill-b"]: only these skills
|
||||
|
||||
Returns:
|
||||
List of SystemMessages containing skill content.
|
||||
"""
|
||||
if not skills:
|
||||
return []
|
||||
|
||||
@@ -347,19 +353,21 @@ class SubagentExecutor:
|
||||
|
||||
return messages
|
||||
|
||||
async def _build_initial_state(self, task: str) -> dict[str, Any]:
|
||||
async def _build_initial_state(self, task: str) -> tuple[dict[str, Any], list[BaseTool]]:
|
||||
"""Build the initial state for agent execution.
|
||||
|
||||
Args:
|
||||
task: The task description.
|
||||
|
||||
Returns:
|
||||
Initial state dictionary.
|
||||
Initial state dictionary and tools filtered by loaded skill metadata.
|
||||
"""
|
||||
# Load skills as conversation items (Codex pattern)
|
||||
skill_messages = await self._load_skill_messages()
|
||||
skills = await self._load_skills()
|
||||
filtered_tools = self._apply_skill_allowed_tools(skills)
|
||||
skill_messages = await self._load_skill_messages(skills)
|
||||
|
||||
messages: list = []
|
||||
messages: list[Any] = []
|
||||
# Skill content injected as developer/system messages before the task
|
||||
messages.extend(skill_messages)
|
||||
# Then the actual task
|
||||
@@ -375,7 +383,7 @@ class SubagentExecutor:
|
||||
if self.thread_data is not None:
|
||||
state["thread_data"] = self.thread_data
|
||||
|
||||
return state
|
||||
return state, filtered_tools
|
||||
|
||||
async def _aexecute(self, task: str, result_holder: SubagentResult | None = None) -> SubagentResult:
|
||||
"""Execute a task asynchronously.
|
||||
@@ -405,8 +413,8 @@ class SubagentExecutor:
|
||||
result.ai_messages = ai_messages
|
||||
|
||||
try:
|
||||
agent = self._create_agent()
|
||||
state = await self._build_initial_state(task)
|
||||
state, filtered_tools = await self._build_initial_state(task)
|
||||
agent = self._create_agent(filtered_tools)
|
||||
|
||||
# Build config with thread_id for sandbox access and recursion limit
|
||||
run_config: RunnableConfig = {
|
||||
|
||||
@@ -2,10 +2,12 @@ from .clarification_tool import ask_clarification_tool
|
||||
from .present_file_tool import present_file_tool
|
||||
from .setup_agent_tool import setup_agent
|
||||
from .task_tool import task_tool
|
||||
from .update_agent_tool import update_agent
|
||||
from .view_image_tool import view_image_tool
|
||||
|
||||
__all__ = [
|
||||
"setup_agent",
|
||||
"update_agent",
|
||||
"present_file_tool",
|
||||
"ask_clarification_tool",
|
||||
"view_image_tool",
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
||||
from langchain.tools import InjectedToolCallId, tool
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.config import get_config
|
||||
from langgraph.types import Command
|
||||
from langgraph.typing import ContextT
|
||||
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"
|
||||
|
||||
|
||||
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState]) -> str | None:
|
||||
def _get_thread_id(runtime: Runtime) -> str | None:
|
||||
"""Resolve the current thread id from runtime context or RunnableConfig."""
|
||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
||||
if thread_id:
|
||||
@@ -32,7 +31,7 @@ def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState]) -> str | None:
|
||||
|
||||
|
||||
def _normalize_presented_filepath(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
filepath: str,
|
||||
) -> str:
|
||||
"""Normalize a presented file path to the `/mnt/user-data/outputs/*` contract.
|
||||
@@ -83,7 +82,7 @@ def _normalize_presented_filepath(
|
||||
|
||||
@tool("present_files", parse_docstring=True)
|
||||
def present_file_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
filepaths: list[str],
|
||||
tool_call_id: Annotated[str, InjectedToolCallId],
|
||||
) -> Command:
|
||||
|
||||
@@ -3,20 +3,28 @@ import logging
|
||||
import yaml
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.prebuilt import ToolRuntime
|
||||
from langgraph.types import Command
|
||||
|
||||
from deerflow.config.agents_config import validate_agent_name
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_runtime_user_id(runtime: Runtime) -> str:
|
||||
context_user_id = runtime.context.get("user_id") if runtime.context else None
|
||||
if context_user_id:
|
||||
return str(context_user_id)
|
||||
return get_effective_user_id()
|
||||
|
||||
|
||||
@tool
|
||||
def setup_agent(
|
||||
soul: str,
|
||||
description: str,
|
||||
runtime: ToolRuntime,
|
||||
runtime: Runtime,
|
||||
skills: list[str] | None = None,
|
||||
) -> Command:
|
||||
"""Setup the custom DeerFlow agent.
|
||||
@@ -34,7 +42,14 @@ def setup_agent(
|
||||
try:
|
||||
agent_name = validate_agent_name(agent_name)
|
||||
paths = get_paths()
|
||||
agent_dir = paths.agent_dir(agent_name) if agent_name else paths.base_dir
|
||||
if agent_name:
|
||||
# Custom agents are persisted under the current user's bucket so
|
||||
# different users do not see each other's agents.
|
||||
user_id = _get_runtime_user_id(runtime)
|
||||
agent_dir = paths.user_agent_dir(user_id, agent_name)
|
||||
else:
|
||||
# Default agent (no agent_name): SOUL.md lives at the global base dir.
|
||||
agent_dir = paths.base_dir
|
||||
is_new_dir = not agent_dir.exists()
|
||||
agent_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@@ -6,11 +6,9 @@ import uuid
|
||||
from dataclasses import replace
|
||||
from typing import TYPE_CHECKING, Annotated, Any, cast
|
||||
|
||||
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
||||
from langchain.tools import InjectedToolCallId, tool
|
||||
from langgraph.config import get_stream_writer
|
||||
from langgraph.typing import ContextT
|
||||
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed
|
||||
from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config
|
||||
@@ -21,6 +19,7 @@ from deerflow.subagents.executor import (
|
||||
get_background_task_result,
|
||||
request_cancel_background_task,
|
||||
)
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.config.app_config import AppConfig
|
||||
@@ -50,12 +49,11 @@ def _merge_skill_allowlists(parent: list[str] | None, child: list[str] | None) -
|
||||
|
||||
@tool("task", parse_docstring=True)
|
||||
async def task_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
description: str,
|
||||
prompt: str,
|
||||
subagent_type: str,
|
||||
tool_call_id: Annotated[str, InjectedToolCallId],
|
||||
max_turns: int | None = None,
|
||||
) -> str:
|
||||
"""Delegate a task to a specialized subagent that runs in its own context.
|
||||
|
||||
@@ -91,7 +89,6 @@ async def task_tool(
|
||||
description: A short (3-5 word) description of the task for logging/display. ALWAYS PROVIDE THIS PARAMETER FIRST.
|
||||
prompt: The task description for the subagent. Be specific and clear about what needs to be done. ALWAYS PROVIDE THIS PARAMETER SECOND.
|
||||
subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
|
||||
max_turns: Optional maximum number of agent turns. Defaults to subagent's configured max.
|
||||
"""
|
||||
runtime_app_config = _get_runtime_app_config(runtime)
|
||||
available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names()
|
||||
@@ -113,9 +110,6 @@ async def task_tool(
|
||||
# each subagent loads its own skills based on config, injected as conversation items).
|
||||
# No longer appended to system_prompt here.
|
||||
|
||||
if max_turns is not None:
|
||||
overrides["max_turns"] = max_turns
|
||||
|
||||
# Extract parent context from runtime
|
||||
sandbox_state = None
|
||||
thread_data = None
|
||||
|
||||
@@ -0,0 +1,241 @@
|
||||
"""update_agent tool — let a custom agent persist updates to its own SOUL.md / config.
|
||||
|
||||
Bound to the lead agent only when ``runtime.context['agent_name']`` is set
|
||||
(i.e. inside an existing custom agent's chat). The default agent does not see
|
||||
this tool, and the bootstrap flow continues to use ``setup_agent`` for the
|
||||
initial creation handshake.
|
||||
|
||||
The tool writes back to ``{base_dir}/users/{user_id}/agents/{agent_name}/{config.yaml,SOUL.md}``
|
||||
so an agent created by one user is never visible to (or mutable by) another.
|
||||
Writes are staged into temp files first; both files are renamed into place only
|
||||
after both temp files are successfully written, so a partial failure cannot leave
|
||||
config.yaml updated while SOUL.md still holds stale content.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import Command
|
||||
|
||||
from deerflow.config.agents_config import load_agent_config, validate_agent_name
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _stage_temp(path: Path, text: str) -> Path:
|
||||
"""Write ``text`` into a sibling temp file and return its path.
|
||||
|
||||
The caller is responsible for ``Path.replace``-ing the temp into the target
|
||||
once every staged file is ready, or for unlinking it on failure.
|
||||
"""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fd = tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
dir=path.parent,
|
||||
suffix=".tmp",
|
||||
delete=False,
|
||||
encoding="utf-8",
|
||||
)
|
||||
try:
|
||||
fd.write(text)
|
||||
fd.flush()
|
||||
fd.close()
|
||||
return Path(fd.name)
|
||||
except BaseException:
|
||||
fd.close()
|
||||
Path(fd.name).unlink(missing_ok=True)
|
||||
raise
|
||||
|
||||
|
||||
def _cleanup_temps(temps: list[Path]) -> None:
|
||||
"""Best-effort removal of staged temp files."""
|
||||
for tmp in temps:
|
||||
try:
|
||||
tmp.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
logger.debug("Failed to clean up temp file %s", tmp, exc_info=True)
|
||||
|
||||
|
||||
@tool
|
||||
def update_agent(
|
||||
runtime: Runtime,
|
||||
soul: str | None = None,
|
||||
description: str | None = None,
|
||||
skills: list[str] | None = None,
|
||||
tool_groups: list[str] | None = None,
|
||||
model: str | None = None,
|
||||
) -> Command:
|
||||
"""Persist updates to the current custom agent's SOUL.md and config.yaml.
|
||||
|
||||
Use this when the user asks to refine the agent's identity, description,
|
||||
skill whitelist, tool-group whitelist, or default model. Only the fields
|
||||
you explicitly pass are updated; omitted fields keep their existing values.
|
||||
|
||||
Pass ``soul`` as the FULL replacement SOUL.md content — there is no patch
|
||||
semantics, so always start from the current SOUL and apply your edits.
|
||||
|
||||
Pass ``skills=[]`` to disable all skills for this agent. Omit ``skills``
|
||||
entirely to keep the existing whitelist.
|
||||
|
||||
Args:
|
||||
soul: Optional full replacement SOUL.md content.
|
||||
description: Optional new one-line description.
|
||||
skills: Optional skill whitelist. ``[]`` = no skills, omit = unchanged.
|
||||
tool_groups: Optional tool-group whitelist. ``[]`` = empty, omit = unchanged.
|
||||
model: Optional model override (must match a configured model name).
|
||||
|
||||
Returns:
|
||||
Command with a ToolMessage describing the result. Changes take effect
|
||||
on the next user turn (when the lead agent is rebuilt with the fresh
|
||||
SOUL.md and config.yaml).
|
||||
"""
|
||||
tool_call_id = runtime.tool_call_id
|
||||
agent_name_raw: str | None = runtime.context.get("agent_name") if runtime.context else None
|
||||
|
||||
def _err(message: str) -> Command:
|
||||
return Command(update={"messages": [ToolMessage(content=f"Error: {message}", tool_call_id=tool_call_id)]})
|
||||
|
||||
if soul is None and description is None and skills is None and tool_groups is None and model is None:
|
||||
return _err("No fields provided. Pass at least one of: soul, description, skills, tool_groups, model.")
|
||||
|
||||
try:
|
||||
agent_name = validate_agent_name(agent_name_raw)
|
||||
except ValueError as e:
|
||||
return _err(str(e))
|
||||
|
||||
if not agent_name:
|
||||
return _err("update_agent is only available inside a custom agent's chat. There is no agent_name in the current runtime context, so there is nothing to update. If you are inside the bootstrap flow, use setup_agent instead.")
|
||||
|
||||
# Resolve the active user so that updates only affect this user's agent.
|
||||
# ``get_effective_user_id`` returns DEFAULT_USER_ID when no auth context
|
||||
# is set (matching how memory and thread storage behave).
|
||||
user_id = get_effective_user_id()
|
||||
|
||||
# Reject an unknown ``model`` *before* touching the filesystem. Otherwise
|
||||
# ``_resolve_model_name`` silently falls back to the default at runtime
|
||||
# and the user sees confusing repeated warnings on every later turn.
|
||||
if model is not None and get_app_config().get_model_config(model) is None:
|
||||
return _err(f"Unknown model '{model}'. Pass a model name that exists in config.yaml's models section.")
|
||||
|
||||
paths = get_paths()
|
||||
agent_dir = paths.user_agent_dir(user_id, agent_name)
|
||||
if not agent_dir.exists() and paths.agent_dir(agent_name).exists():
|
||||
return _err(f"Agent '{agent_name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before updating.")
|
||||
|
||||
try:
|
||||
existing_cfg = load_agent_config(agent_name, user_id=user_id)
|
||||
except FileNotFoundError:
|
||||
return _err(f"Agent '{agent_name}' does not exist for the current user. Use setup_agent to create a new agent first.")
|
||||
except ValueError as e:
|
||||
return _err(f"Agent '{agent_name}' has an unreadable config: {e}")
|
||||
|
||||
if existing_cfg is None:
|
||||
return _err(f"Agent '{agent_name}' could not be loaded.")
|
||||
|
||||
updated_fields: list[str] = []
|
||||
|
||||
# Force the on-disk ``name`` to match the directory we are writing into,
|
||||
# even if ``existing_cfg.name`` had drifted (e.g. from manual yaml edits).
|
||||
config_data: dict[str, Any] = {"name": agent_name}
|
||||
new_description = description if description is not None else existing_cfg.description
|
||||
config_data["description"] = new_description
|
||||
if description is not None and description != existing_cfg.description:
|
||||
updated_fields.append("description")
|
||||
|
||||
new_model = model if model is not None else existing_cfg.model
|
||||
if new_model is not None:
|
||||
config_data["model"] = new_model
|
||||
if model is not None and model != existing_cfg.model:
|
||||
updated_fields.append("model")
|
||||
|
||||
new_tool_groups = tool_groups if tool_groups is not None else existing_cfg.tool_groups
|
||||
if new_tool_groups is not None:
|
||||
config_data["tool_groups"] = new_tool_groups
|
||||
if tool_groups is not None and tool_groups != existing_cfg.tool_groups:
|
||||
updated_fields.append("tool_groups")
|
||||
|
||||
new_skills = skills if skills is not None else existing_cfg.skills
|
||||
if new_skills is not None:
|
||||
config_data["skills"] = new_skills
|
||||
if skills is not None and skills != existing_cfg.skills:
|
||||
updated_fields.append("skills")
|
||||
|
||||
config_changed = bool({"description", "model", "tool_groups", "skills"} & set(updated_fields))
|
||||
|
||||
# Stage every file we intend to rewrite into a temp sibling. Only after
|
||||
# *all* temp files exist do we rename them into place — so a failure on
|
||||
# SOUL.md cannot leave config.yaml already replaced.
|
||||
pending: list[tuple[Path, Path]] = []
|
||||
staged_temps: list[Path] = []
|
||||
|
||||
try:
|
||||
agent_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if config_changed:
|
||||
yaml_text = yaml.dump(config_data, default_flow_style=False, allow_unicode=True, sort_keys=False)
|
||||
config_target = agent_dir / "config.yaml"
|
||||
config_tmp = _stage_temp(config_target, yaml_text)
|
||||
staged_temps.append(config_tmp)
|
||||
pending.append((config_tmp, config_target))
|
||||
|
||||
if soul is not None:
|
||||
soul_target = agent_dir / "SOUL.md"
|
||||
soul_tmp = _stage_temp(soul_target, soul)
|
||||
staged_temps.append(soul_tmp)
|
||||
pending.append((soul_tmp, soul_target))
|
||||
updated_fields.append("soul")
|
||||
|
||||
# Commit phase. ``Path.replace`` is atomic per file on POSIX/NTFS and
|
||||
# the staging step above means any earlier failure has already been
|
||||
# reported. The remaining failure mode is a crash *between* two
|
||||
# ``replace`` calls, which is reported via the partial-write error
|
||||
# branch below so the caller knows which files are now on disk.
|
||||
committed: list[Path] = []
|
||||
try:
|
||||
for tmp, target in pending:
|
||||
tmp.replace(target)
|
||||
committed.append(target)
|
||||
except Exception as e:
|
||||
_cleanup_temps([t for t, _ in pending if t not in committed])
|
||||
if committed:
|
||||
logger.error(
|
||||
"[update_agent] Partial write for agent '%s' (user=%s): committed=%s, failed during rename: %s",
|
||||
agent_name,
|
||||
user_id,
|
||||
[p.name for p in committed],
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
return _err(f"Partial update for agent '{agent_name}': {[p.name for p in committed]} were updated, but the rest failed ({e}). Re-run update_agent to retry the remaining fields.")
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
_cleanup_temps(staged_temps)
|
||||
logger.error("[update_agent] Failed to update agent '%s' (user=%s): %s", agent_name, user_id, e, exc_info=True)
|
||||
return _err(f"Failed to update agent '{agent_name}': {e}")
|
||||
|
||||
if not updated_fields:
|
||||
return Command(update={"messages": [ToolMessage(content=f"No changes applied to agent '{agent_name}'. The provided values matched the existing config.", tool_call_id=tool_call_id)]})
|
||||
|
||||
logger.info("[update_agent] Updated agent '%s' (user=%s) fields: %s", agent_name, user_id, updated_fields)
|
||||
return Command(
|
||||
update={
|
||||
"messages": [
|
||||
ToolMessage(
|
||||
content=(f"Agent '{agent_name}' updated successfully. Changed: {', '.join(updated_fields)}. The new configuration takes effect on the next user turn."),
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
@@ -3,13 +3,13 @@ import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
||||
from langchain.tools import InjectedToolCallId, tool
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
from langgraph.typing import ContextT
|
||||
|
||||
from deerflow.agents.thread_state import ThreadDataState, ThreadState
|
||||
from deerflow.agents.thread_state import ThreadDataState
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
_ALLOWED_IMAGE_VIRTUAL_ROOTS = (
|
||||
f"{VIRTUAL_PATH_PREFIX}/workspace",
|
||||
@@ -48,7 +48,7 @@ def _sanitize_image_error(error: Exception, thread_data: ThreadDataState | None)
|
||||
|
||||
@tool("view_image", parse_docstring=True)
|
||||
def view_image_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
image_path: str,
|
||||
tool_call_id: Annotated[str, InjectedToolCallId],
|
||||
) -> Command:
|
||||
|
||||
@@ -7,16 +7,15 @@ import logging
|
||||
from typing import Any
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from langchain.tools import ToolRuntime, tool
|
||||
from langgraph.typing import ContextT
|
||||
from langchain.tools import tool
|
||||
|
||||
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.mcp.tools import _make_sync_tool_wrapper
|
||||
from deerflow.skills.security_scanner import scan_skill_content
|
||||
from deerflow.skills.storage import get_or_new_skill_storage
|
||||
from deerflow.skills.storage.skill_storage import SkillStorage
|
||||
from deerflow.skills.types import SKILL_MD_FILE
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,7 +30,7 @@ def _get_lock(name: str) -> asyncio.Lock:
|
||||
return lock
|
||||
|
||||
|
||||
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState] | None) -> str | None:
|
||||
def _get_thread_id(runtime: Runtime | None) -> str | None:
|
||||
if runtime is None:
|
||||
return None
|
||||
if runtime.context and runtime.context.get("thread_id"):
|
||||
@@ -65,7 +64,7 @@ async def _to_thread(func, /, *args, **kwargs):
|
||||
|
||||
|
||||
async def _skill_manage_impl(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
action: str,
|
||||
name: str,
|
||||
content: str | None = None,
|
||||
@@ -204,7 +203,7 @@ async def _skill_manage_impl(
|
||||
|
||||
@tool("skill_manage", parse_docstring=True)
|
||||
async def skill_manage_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
action: str,
|
||||
name: str,
|
||||
content: str | None = None,
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
|
||||
# Concrete runtime type used by all DeerFlow tools.
|
||||
# Using dict[str, Any] for the context parameter instead of the unbound ContextT
|
||||
# TypeVar prevents PydanticSerializationUnexpectedValue warnings when LangChain
|
||||
# calls model_dump() on a tool's auto-generated args_schema.
|
||||
Runtime = ToolRuntime[dict[str, Any], ThreadState]
|
||||
@@ -121,9 +121,11 @@ def open_upload_file_no_symlink(base_dir: Path, filename: str) -> tuple[Path, ob
|
||||
Upload directories may be mounted into local sandboxes. A sandbox process can
|
||||
therefore leave a symlink at a future upload filename. Normal ``Path.write_bytes``
|
||||
follows that link and can overwrite files outside the uploads directory with
|
||||
gateway privileges. This helper rejects symlink destinations and uses
|
||||
``O_NOFOLLOW`` where available so the final path component cannot be raced into
|
||||
a symlink between validation and open.
|
||||
gateway privileges. This helper rejects symlink destinations using ``O_NOFOLLOW``
|
||||
on POSIX. On Windows (which lacks ``O_NOFOLLOW``), it uses dual ``lstat`` checks
|
||||
and ``fstat`` validation after ``open()`` to reduce the TOCTOU window; this does
|
||||
not eliminate all races but makes exploitation significantly harder. Path-traversal
|
||||
validation prevents escapes from *base_dir* in both cases.
|
||||
"""
|
||||
safe_name = normalize_filename(filename)
|
||||
dest = base_dir / safe_name
|
||||
@@ -138,23 +140,65 @@ def open_upload_file_no_symlink(base_dir: Path, filename: str) -> tuple[Path, ob
|
||||
|
||||
validate_path_traversal(dest, base_dir)
|
||||
|
||||
if not hasattr(os, "O_NOFOLLOW"):
|
||||
raise UnsafeUploadPathError("Upload writes require O_NOFOLLOW support")
|
||||
has_nofollow = hasattr(os, "O_NOFOLLOW")
|
||||
|
||||
flags = os.O_WRONLY | os.O_CREAT | os.O_NOFOLLOW
|
||||
if hasattr(os, "O_NONBLOCK"):
|
||||
flags |= os.O_NONBLOCK
|
||||
if has_nofollow:
|
||||
# POSIX: O_NOFOLLOW makes open() fail with ELOOP if dest is a symlink.
|
||||
flags = os.O_WRONLY | os.O_CREAT | os.O_NOFOLLOW
|
||||
if hasattr(os, "O_NONBLOCK"):
|
||||
flags |= os.O_NONBLOCK
|
||||
|
||||
try:
|
||||
fd = os.open(dest, flags, 0o600)
|
||||
except OSError as exc:
|
||||
if exc.errno in {errno.ELOOP, errno.EISDIR, errno.ENOTDIR, errno.ENXIO, errno.EAGAIN}:
|
||||
raise UnsafeUploadPathError(f"Unsafe upload destination: {safe_name}") from exc
|
||||
raise
|
||||
|
||||
try:
|
||||
opened_stat = os.fstat(fd)
|
||||
if not stat.S_ISREG(opened_stat.st_mode) or opened_stat.st_nlink != 1:
|
||||
raise UnsafeUploadPathError(f"Upload destination is not an exclusive regular file: {safe_name}")
|
||||
os.ftruncate(fd, 0)
|
||||
fh = os.fdopen(fd, "wb")
|
||||
fd = -1
|
||||
finally:
|
||||
if fd >= 0:
|
||||
os.close(fd)
|
||||
return dest, fh
|
||||
|
||||
# Windows: no O_NOFOLLOW available. Uses a second lstat immediately before open()
|
||||
# to narrow the TOCTOU window, then fstat after open() as a further defence.
|
||||
# Note: a narrow race window remains between the pre-open lstat and open(); the
|
||||
# path-traversal check mitigates escapes from base_dir but cannot prevent an
|
||||
# attacker who can atomically replace dest with a symlink after the check.
|
||||
if st is not None and st.st_nlink > 1:
|
||||
raise UnsafeUploadPathError(f"Upload destination has multiple links: {safe_name}")
|
||||
|
||||
flags = os.O_WRONLY | os.O_CREAT
|
||||
if hasattr(os, "O_BINARY"):
|
||||
flags |= os.O_BINARY
|
||||
|
||||
try:
|
||||
pre_open_st = os.lstat(dest)
|
||||
except FileNotFoundError:
|
||||
pre_open_st = None
|
||||
|
||||
if pre_open_st is not None and not stat.S_ISREG(pre_open_st.st_mode):
|
||||
raise UnsafeUploadPathError(f"Upload destination is not a regular file: {safe_name}")
|
||||
if pre_open_st is not None and pre_open_st.st_nlink > 1:
|
||||
raise UnsafeUploadPathError(f"Upload destination has multiple links: {safe_name}")
|
||||
|
||||
try:
|
||||
fd = os.open(dest, flags, 0o600)
|
||||
except OSError as exc:
|
||||
if exc.errno in {errno.ELOOP, errno.EISDIR, errno.ENOTDIR, errno.ENXIO, errno.EAGAIN}:
|
||||
if exc.errno in {errno.EISDIR, errno.ENOTDIR, errno.ENXIO, errno.EAGAIN}:
|
||||
raise UnsafeUploadPathError(f"Unsafe upload destination: {safe_name}") from exc
|
||||
raise
|
||||
|
||||
try:
|
||||
opened_stat = os.fstat(fd)
|
||||
if not stat.S_ISREG(opened_stat.st_mode) or opened_stat.st_nlink != 1:
|
||||
if not stat.S_ISREG(opened_stat.st_mode) or opened_stat.st_nlink > 1:
|
||||
raise UnsafeUploadPathError(f"Upload destination is not an exclusive regular file: {safe_name}")
|
||||
os.ftruncate(fd, 0)
|
||||
fh = os.fdopen(fd, "wb")
|
||||
|
||||
@@ -8,7 +8,7 @@ dependencies = [
|
||||
"deerflow-harness",
|
||||
"fastapi>=0.115.0",
|
||||
"httpx>=0.28.0",
|
||||
"python-multipart>=0.0.26",
|
||||
"python-multipart>=0.0.27",
|
||||
"sse-starlette>=2.1.0",
|
||||
"uvicorn[standard]>=0.34.0",
|
||||
"lark-oapi>=1.4.0",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""One-time migration: move legacy thread dirs and memory into per-user layout.
|
||||
|
||||
Usage:
|
||||
PYTHONPATH=. python scripts/migrate_user_isolation.py [--dry-run]
|
||||
PYTHONPATH=. python scripts/migrate_user_isolation.py [--dry-run] [--user-id USER_ID]
|
||||
|
||||
The script is idempotent — re-running it after a successful migration is a no-op.
|
||||
"""
|
||||
@@ -69,6 +69,67 @@ def migrate_thread_dirs(
|
||||
return report
|
||||
|
||||
|
||||
def migrate_agents(
|
||||
paths: Paths,
|
||||
user_id: str = "default",
|
||||
*,
|
||||
dry_run: bool = False,
|
||||
) -> list[dict]:
|
||||
"""Move legacy custom-agent directories into per-user layout.
|
||||
|
||||
Legacy layout: ``{base_dir}/agents/{name}/``
|
||||
Per-user layout: ``{base_dir}/users/{user_id}/agents/{name}/``
|
||||
|
||||
Pre-existing per-user agents take precedence: if a destination already
|
||||
exists for an agent name, the legacy copy is moved to
|
||||
``{base_dir}/migration-conflicts/agents/{name}/`` for manual review.
|
||||
|
||||
Args:
|
||||
paths: Paths instance.
|
||||
user_id: Target user to receive the legacy agents (defaults to
|
||||
``"default"``, matching ``DEFAULT_USER_ID`` for no-auth setups).
|
||||
dry_run: If True, only log what would happen.
|
||||
|
||||
Returns:
|
||||
List of migration report entries, one per legacy agent directory found.
|
||||
"""
|
||||
report: list[dict] = []
|
||||
legacy_agents = paths.agents_dir
|
||||
if not legacy_agents.exists():
|
||||
logger.info("No legacy agents directory found — nothing to migrate.")
|
||||
return report
|
||||
|
||||
for agent_dir in sorted(legacy_agents.iterdir()):
|
||||
if not agent_dir.is_dir():
|
||||
continue
|
||||
agent_name = agent_dir.name
|
||||
dest = paths.user_agent_dir(user_id, agent_name)
|
||||
|
||||
entry = {"agent": agent_name, "user_id": user_id, "action": ""}
|
||||
|
||||
if dest.exists():
|
||||
conflicts_dir = paths.base_dir / "migration-conflicts" / "agents" / agent_name
|
||||
entry["action"] = f"conflict -> {conflicts_dir}"
|
||||
if not dry_run:
|
||||
conflicts_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(agent_dir), str(conflicts_dir))
|
||||
logger.warning("Conflict for agent %s: moved legacy copy to %s", agent_name, conflicts_dir)
|
||||
else:
|
||||
entry["action"] = f"moved -> {dest}"
|
||||
if not dry_run:
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(agent_dir), str(dest))
|
||||
logger.info("Migrated agent %s -> user %s", agent_name, user_id)
|
||||
|
||||
report.append(entry)
|
||||
|
||||
# Clean up empty legacy agents dir
|
||||
if not dry_run and legacy_agents.exists() and not any(legacy_agents.iterdir()):
|
||||
legacy_agents.rmdir()
|
||||
|
||||
return report
|
||||
|
||||
|
||||
def migrate_memory(
|
||||
paths: Paths,
|
||||
user_id: str = "default",
|
||||
@@ -127,6 +188,12 @@ def _build_owner_map_from_db(paths: Paths) -> dict[str, str]:
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Migrate DeerFlow data to per-user layout")
|
||||
parser.add_argument("--dry-run", action="store_true", help="Log actions without making changes")
|
||||
parser.add_argument(
|
||||
"--user-id",
|
||||
default="default",
|
||||
metavar="USER_ID",
|
||||
help=("User ID to claim un-owned legacy data (global memory.json and legacy custom agents). Defaults to 'default'. In multi-user installs, set this to the operator account that should inherit those legacy artifacts."),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
@@ -134,26 +201,42 @@ def main() -> None:
|
||||
paths = get_paths()
|
||||
logger.info("Base directory: %s", paths.base_dir)
|
||||
logger.info("Dry run: %s", args.dry_run)
|
||||
logger.info("Claiming un-owned legacy data for user_id=%s", args.user_id)
|
||||
|
||||
owner_map = _build_owner_map_from_db(paths)
|
||||
logger.info("Found %d thread ownership records in DB", len(owner_map))
|
||||
|
||||
report = migrate_thread_dirs(paths, owner_map, dry_run=args.dry_run)
|
||||
migrate_memory(paths, user_id="default", dry_run=args.dry_run)
|
||||
migrate_memory(paths, user_id=args.user_id, dry_run=args.dry_run)
|
||||
agent_report = migrate_agents(paths, user_id=args.user_id, dry_run=args.dry_run)
|
||||
|
||||
if report:
|
||||
logger.info("Migration report:")
|
||||
logger.info("Thread migration report:")
|
||||
for entry in report:
|
||||
logger.info(" thread=%s user=%s action=%s", entry["thread_id"], entry["user_id"], entry["action"])
|
||||
else:
|
||||
logger.info("No threads to migrate.")
|
||||
|
||||
if agent_report:
|
||||
logger.info("Agent migration report:")
|
||||
for entry in agent_report:
|
||||
logger.info(" agent=%s user=%s action=%s", entry["agent"], entry["user_id"], entry["action"])
|
||||
else:
|
||||
logger.info("No agents to migrate.")
|
||||
|
||||
unowned = [e for e in report if e["user_id"] == "default"]
|
||||
if unowned:
|
||||
logger.warning("%d thread(s) had no owner and were assigned to 'default':", len(unowned))
|
||||
for e in unowned:
|
||||
logger.warning(" %s", e["thread_id"])
|
||||
|
||||
if agent_report:
|
||||
logger.warning(
|
||||
"%d legacy agent(s) were assigned to '%s'. If those agents belonged to other users, move them manually under {base_dir}/users/<user_id>/agents/.",
|
||||
len(agent_report),
|
||||
args.user_id,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
"""Tests for AioSandboxProvider auto-restart of crashed containers."""
|
||||
|
||||
import importlib
|
||||
import threading
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def _import_provider():
|
||||
return importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||
|
||||
|
||||
def _make_provider(*, auto_restart=True, alive=True):
|
||||
"""Build a minimal AioSandboxProvider with a mock backend.
|
||||
|
||||
Args:
|
||||
auto_restart: Value for the auto_restart config key.
|
||||
alive: Whether the mock backend reports containers as alive.
|
||||
"""
|
||||
mod = _import_provider()
|
||||
with patch.object(mod.AioSandboxProvider, "_start_idle_checker"):
|
||||
provider = mod.AioSandboxProvider.__new__(mod.AioSandboxProvider)
|
||||
provider._config = {"auto_restart": auto_restart}
|
||||
provider._lock = threading.Lock()
|
||||
provider._sandboxes = {}
|
||||
provider._sandbox_infos = {}
|
||||
provider._thread_sandboxes = {}
|
||||
provider._thread_locks = {}
|
||||
provider._last_activity = {}
|
||||
provider._warm_pool = {}
|
||||
provider._shutdown_called = False
|
||||
provider._idle_checker_stop = threading.Event()
|
||||
|
||||
backend = MagicMock()
|
||||
backend.is_alive.return_value = alive
|
||||
provider._backend = backend
|
||||
|
||||
return provider, backend
|
||||
|
||||
|
||||
def _seed_sandbox(provider, sandbox_id="dead-beef", thread_id="thread-1"):
|
||||
"""Insert a sandbox into the provider's caches as if it were acquired."""
|
||||
sandbox = MagicMock()
|
||||
info = MagicMock()
|
||||
|
||||
provider._sandboxes[sandbox_id] = sandbox
|
||||
provider._sandbox_infos[sandbox_id] = info
|
||||
provider._last_activity[sandbox_id] = 0.0
|
||||
if thread_id:
|
||||
provider._thread_sandboxes[thread_id] = sandbox_id
|
||||
|
||||
return sandbox, info
|
||||
|
||||
|
||||
# ── get() returns sandbox when container is alive ──────────────────────────
|
||||
|
||||
|
||||
def test_get_returns_sandbox_when_container_alive():
|
||||
"""When auto_restart is on and the container is alive, get() returns the sandbox."""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=True)
|
||||
sandbox, _ = _seed_sandbox(provider)
|
||||
|
||||
result = provider.get("dead-beef")
|
||||
|
||||
assert result is sandbox
|
||||
backend.is_alive.assert_called_once()
|
||||
|
||||
|
||||
def test_get_returns_sandbox_when_auto_restart_disabled():
|
||||
"""When auto_restart is off, get() skips the health check entirely."""
|
||||
provider, backend = _make_provider(auto_restart=False)
|
||||
sandbox, _ = _seed_sandbox(provider)
|
||||
|
||||
result = provider.get("dead-beef")
|
||||
|
||||
assert result is sandbox
|
||||
backend.is_alive.assert_not_called()
|
||||
|
||||
|
||||
# ── get() evicts dead sandbox when auto_restart is on ──────────────────────
|
||||
|
||||
|
||||
def test_get_evicts_dead_sandbox_when_auto_restart_enabled():
|
||||
"""When the container is dead and auto_restart is on, get() returns None and cleans caches."""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=False)
|
||||
_, info = _seed_sandbox(provider, sandbox_id="dead-beef", thread_id="thread-1")
|
||||
|
||||
result = provider.get("dead-beef")
|
||||
|
||||
assert result is None
|
||||
assert "dead-beef" not in provider._sandboxes
|
||||
assert "dead-beef" not in provider._sandbox_infos
|
||||
assert "dead-beef" not in provider._last_activity
|
||||
assert "thread-1" not in provider._thread_sandboxes
|
||||
backend.destroy.assert_called_once_with(info)
|
||||
|
||||
|
||||
def test_get_returns_dead_sandbox_when_auto_restart_disabled():
|
||||
"""When auto_restart is off, get() returns the cached sandbox even if the container is dead."""
|
||||
provider, backend = _make_provider(auto_restart=False, alive=False)
|
||||
sandbox, _ = _seed_sandbox(provider)
|
||||
|
||||
result = provider.get("dead-beef")
|
||||
|
||||
assert result is sandbox
|
||||
# Caches are untouched
|
||||
assert "dead-beef" in provider._sandboxes
|
||||
|
||||
|
||||
def test_get_eviction_cleans_multiple_thread_mappings():
|
||||
"""A sandbox mapped to multiple thread IDs has all mappings cleaned on eviction."""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=False)
|
||||
_seed_sandbox(provider, sandbox_id="sid-1", thread_id="t-a")
|
||||
# Manually add a second thread mapping to the same sandbox
|
||||
provider._thread_sandboxes["t-b"] = "sid-1"
|
||||
|
||||
result = provider.get("sid-1")
|
||||
|
||||
assert result is None
|
||||
assert "t-a" not in provider._thread_sandboxes
|
||||
assert "t-b" not in provider._thread_sandboxes
|
||||
|
||||
|
||||
# ── get() does not check health for unknown sandbox IDs ────────────────────
|
||||
|
||||
|
||||
def test_get_returns_none_for_unknown_id():
|
||||
"""If the sandbox_id is not in cache, get() returns None without checking health."""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=True)
|
||||
|
||||
result = provider.get("nonexistent")
|
||||
|
||||
assert result is None
|
||||
backend.is_alive.assert_not_called()
|
||||
|
||||
|
||||
# ── get() handles missing sandbox_info gracefully ──────────────────────────
|
||||
|
||||
|
||||
def test_get_handles_missing_info_gracefully():
|
||||
"""If sandbox is cached but info is missing, get() skips the health check."""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=False)
|
||||
sandbox = MagicMock()
|
||||
provider._sandboxes["sid-x"] = sandbox
|
||||
provider._sandbox_infos.pop("sid-x", None) # Ensure no info
|
||||
provider._last_activity["sid-x"] = 0.0
|
||||
|
||||
result = provider.get("sid-x")
|
||||
|
||||
# No info → cannot call is_alive → sandbox returned as-is
|
||||
assert result is sandbox
|
||||
backend.is_alive.assert_not_called()
|
||||
|
||||
|
||||
def test_get_liveness_check_runs_outside_provider_lock():
|
||||
"""get() should not hold the provider lock while checking backend liveness."""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=False)
|
||||
_seed_sandbox(provider, sandbox_id="sid-locked", thread_id="thread-1")
|
||||
|
||||
def _assert_lock_not_held(_):
|
||||
assert not provider._lock.locked()
|
||||
return False
|
||||
|
||||
backend.is_alive.side_effect = _assert_lock_not_held
|
||||
|
||||
assert provider.get("sid-locked") is None
|
||||
|
||||
|
||||
def test_get_still_evicts_when_backend_destroy_fails():
|
||||
"""Cleanup errors should not keep stale sandbox state in memory."""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=False)
|
||||
_seed_sandbox(provider, sandbox_id="sid-fail", thread_id="thread-1")
|
||||
backend.destroy.side_effect = RuntimeError("boom")
|
||||
|
||||
assert provider.get("sid-fail") is None
|
||||
assert "sid-fail" not in provider._sandboxes
|
||||
assert "sid-fail" not in provider._sandbox_infos
|
||||
assert "thread-1" not in provider._thread_sandboxes
|
||||
backend.destroy.assert_called_once()
|
||||
|
||||
|
||||
# ── Integration: eviction clears caches for recreation ─────────────────────
|
||||
|
||||
|
||||
def test_eviction_clears_all_caches_for_recreation():
|
||||
"""After eviction, all caches are clean so _acquire_internal can recreate.
|
||||
|
||||
This verifies the preconditions for transparent restart: when get() evicts
|
||||
a dead sandbox, the next _acquire_internal call will find no cached entry,
|
||||
no warm-pool entry, and fall through to _create_sandbox.
|
||||
"""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=False)
|
||||
_seed_sandbox(provider, sandbox_id="sid-1", thread_id="thread-1")
|
||||
|
||||
# Before eviction: caches populated
|
||||
assert "sid-1" in provider._sandboxes
|
||||
assert "sid-1" in provider._sandbox_infos
|
||||
assert "thread-1" in provider._thread_sandboxes
|
||||
|
||||
# get() detects the dead container and evicts
|
||||
assert provider.get("sid-1") is None
|
||||
|
||||
# After eviction: all caches clean
|
||||
assert "sid-1" not in provider._sandboxes
|
||||
assert "sid-1" not in provider._sandbox_infos
|
||||
assert "thread-1" not in provider._thread_sandboxes
|
||||
assert "sid-1" not in provider._warm_pool
|
||||
|
||||
# _acquire_internal for the same thread would find nothing cached
|
||||
# and generate the deterministic ID, then discover fails (container
|
||||
# is gone), falling through to _create_sandbox — a fresh start.
|
||||
@@ -4,10 +4,40 @@ import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from pydantic import ValidationError
|
||||
|
||||
from deerflow.config.agents_api_config import get_agents_api_config
|
||||
import deerflow.config.app_config as app_config_module
|
||||
from deerflow.config.acp_config import load_acp_config_from_dict
|
||||
from deerflow.config.agents_api_config import get_agents_api_config, load_agents_api_config_from_dict
|
||||
from deerflow.config.app_config import AppConfig, get_app_config, reset_app_config
|
||||
from deerflow.config.checkpointer_config import get_checkpointer_config, load_checkpointer_config_from_dict
|
||||
from deerflow.config.guardrails_config import get_guardrails_config, load_guardrails_config_from_dict
|
||||
from deerflow.config.memory_config import get_memory_config, load_memory_config_from_dict
|
||||
from deerflow.config.stream_bridge_config import get_stream_bridge_config, load_stream_bridge_config_from_dict
|
||||
from deerflow.config.subagents_config import get_subagents_app_config, load_subagents_config_from_dict
|
||||
from deerflow.config.summarization_config import get_summarization_config, load_summarization_config_from_dict
|
||||
from deerflow.config.title_config import get_title_config, load_title_config_from_dict
|
||||
from deerflow.config.tool_search_config import get_tool_search_config, load_tool_search_config_from_dict
|
||||
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
|
||||
from deerflow.runtime.store import get_store, reset_store
|
||||
|
||||
|
||||
def _reset_config_singletons() -> None:
|
||||
load_title_config_from_dict({})
|
||||
load_summarization_config_from_dict({})
|
||||
load_memory_config_from_dict({})
|
||||
load_agents_api_config_from_dict({})
|
||||
load_subagents_config_from_dict({})
|
||||
load_tool_search_config_from_dict({})
|
||||
load_guardrails_config_from_dict({})
|
||||
load_checkpointer_config_from_dict(None)
|
||||
load_stream_bridge_config_from_dict(None)
|
||||
load_acp_config_from_dict({})
|
||||
reset_checkpointer()
|
||||
reset_store()
|
||||
reset_app_config()
|
||||
|
||||
|
||||
def _write_config(path: Path, *, model_name: str, supports_thinking: bool) -> None:
|
||||
@@ -53,6 +83,23 @@ def _write_config_with_agents_api(
|
||||
path.write_text(yaml.safe_dump(config), encoding="utf-8")
|
||||
|
||||
|
||||
def _write_config_with_sections(path: Path, sections: dict | None = None) -> None:
|
||||
config = {
|
||||
"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"},
|
||||
"models": [
|
||||
{
|
||||
"name": "first-model",
|
||||
"use": "langchain_openai:ChatOpenAI",
|
||||
"model": "gpt-test",
|
||||
}
|
||||
],
|
||||
}
|
||||
if sections:
|
||||
config.update(sections)
|
||||
|
||||
path.write_text(yaml.safe_dump(config), encoding="utf-8")
|
||||
|
||||
|
||||
def _write_extensions_config(path: Path) -> None:
|
||||
path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8")
|
||||
|
||||
@@ -175,3 +222,168 @@ def test_get_app_config_resets_agents_api_config_when_section_removed(tmp_path,
|
||||
assert get_agents_api_config().enabled is False
|
||||
finally:
|
||||
reset_app_config()
|
||||
|
||||
|
||||
def test_get_app_config_resets_singleton_configs_when_sections_removed(tmp_path, monkeypatch):
|
||||
config_path = tmp_path / "config.yaml"
|
||||
extensions_path = tmp_path / "extensions_config.json"
|
||||
_write_extensions_config(extensions_path)
|
||||
_write_config_with_sections(
|
||||
config_path,
|
||||
{
|
||||
"title": {"enabled": False, "max_words": 3},
|
||||
"summarization": {"enabled": True},
|
||||
"memory": {"enabled": False, "max_facts": 50},
|
||||
"subagents": {"timeout_seconds": 42, "agents": {"reviewer": {"max_turns": 2}}},
|
||||
"tool_search": {"enabled": True},
|
||||
"guardrails": {"enabled": True, "fail_closed": False},
|
||||
"checkpointer": {"type": "memory"},
|
||||
"stream_bridge": {"type": "memory", "queue_maxsize": 12},
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
|
||||
reset_app_config()
|
||||
|
||||
try:
|
||||
get_app_config()
|
||||
assert get_title_config().enabled is False
|
||||
assert get_summarization_config().enabled is True
|
||||
assert get_memory_config().enabled is False
|
||||
assert get_subagents_app_config().timeout_seconds == 42
|
||||
assert get_tool_search_config().enabled is True
|
||||
assert get_guardrails_config().enabled is True
|
||||
assert get_checkpointer_config() is not None
|
||||
assert get_stream_bridge_config() is not None
|
||||
|
||||
_write_config_with_sections(config_path)
|
||||
next_mtime = config_path.stat().st_mtime + 5
|
||||
os.utime(config_path, (next_mtime, next_mtime))
|
||||
|
||||
get_app_config()
|
||||
assert get_title_config().enabled is True
|
||||
assert get_summarization_config().enabled is False
|
||||
assert get_memory_config().enabled is True
|
||||
assert get_subagents_app_config().timeout_seconds == 900
|
||||
assert get_tool_search_config().enabled is False
|
||||
assert get_guardrails_config().enabled is False
|
||||
assert get_checkpointer_config() is None
|
||||
assert get_stream_bridge_config() is None
|
||||
finally:
|
||||
_reset_config_singletons()
|
||||
|
||||
|
||||
def test_get_app_config_resets_persistence_runtime_singletons_when_checkpointer_removed(tmp_path, monkeypatch):
|
||||
config_path = tmp_path / "config.yaml"
|
||||
extensions_path = tmp_path / "extensions_config.json"
|
||||
_write_extensions_config(extensions_path)
|
||||
_write_config_with_sections(config_path, {"checkpointer": {"type": "memory"}})
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
|
||||
reset_checkpointer()
|
||||
reset_store()
|
||||
reset_app_config()
|
||||
|
||||
try:
|
||||
get_app_config()
|
||||
initial_checkpointer = get_checkpointer()
|
||||
initial_store = get_store()
|
||||
|
||||
_write_config_with_sections(config_path)
|
||||
next_mtime = config_path.stat().st_mtime + 5
|
||||
os.utime(config_path, (next_mtime, next_mtime))
|
||||
|
||||
get_app_config()
|
||||
|
||||
assert get_checkpointer_config() is None
|
||||
assert get_checkpointer() is not initial_checkpointer
|
||||
assert get_store() is not initial_store
|
||||
finally:
|
||||
_reset_config_singletons()
|
||||
|
||||
|
||||
def test_get_app_config_keeps_persistence_runtime_singletons_when_checkpointer_unchanged(tmp_path, monkeypatch):
|
||||
config_path = tmp_path / "config.yaml"
|
||||
extensions_path = tmp_path / "extensions_config.json"
|
||||
_write_extensions_config(extensions_path)
|
||||
_write_config_with_sections(
|
||||
config_path,
|
||||
{
|
||||
"title": {"enabled": False},
|
||||
"checkpointer": {"type": "memory"},
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
|
||||
_reset_config_singletons()
|
||||
|
||||
try:
|
||||
get_app_config()
|
||||
initial_checkpointer = get_checkpointer()
|
||||
initial_store = get_store()
|
||||
|
||||
_write_config_with_sections(
|
||||
config_path,
|
||||
{
|
||||
"title": {"enabled": True},
|
||||
"checkpointer": {"type": "memory"},
|
||||
},
|
||||
)
|
||||
next_mtime = config_path.stat().st_mtime + 5
|
||||
os.utime(config_path, (next_mtime, next_mtime))
|
||||
|
||||
get_app_config()
|
||||
|
||||
assert get_checkpointer() is initial_checkpointer
|
||||
assert get_store() is initial_store
|
||||
finally:
|
||||
_reset_config_singletons()
|
||||
|
||||
|
||||
def test_get_app_config_does_not_mutate_singletons_when_reload_validation_fails(tmp_path, monkeypatch):
|
||||
config_path = tmp_path / "config.yaml"
|
||||
extensions_path = tmp_path / "extensions_config.json"
|
||||
_write_extensions_config(extensions_path)
|
||||
_write_config_with_sections(
|
||||
config_path,
|
||||
{
|
||||
"title": {"enabled": False},
|
||||
"tool_search": {"enabled": True},
|
||||
"checkpointer": {"type": "memory"},
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
|
||||
_reset_config_singletons()
|
||||
|
||||
try:
|
||||
previous_app_config = get_app_config()
|
||||
initial_checkpointer = get_checkpointer()
|
||||
initial_store = get_store()
|
||||
|
||||
_write_config_with_sections(
|
||||
config_path,
|
||||
{
|
||||
"title": False,
|
||||
"tool_search": False,
|
||||
"checkpointer": {"type": "memory"},
|
||||
},
|
||||
)
|
||||
next_mtime = config_path.stat().st_mtime + 5
|
||||
os.utime(config_path, (next_mtime, next_mtime))
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
get_app_config()
|
||||
|
||||
assert app_config_module._app_config is previous_app_config
|
||||
assert get_title_config().enabled is False
|
||||
assert get_tool_search_config().enabled is True
|
||||
assert get_checkpointer_config() is not None
|
||||
assert get_checkpointer() is initial_checkpointer
|
||||
assert get_store() is initial_store
|
||||
finally:
|
||||
_reset_config_singletons()
|
||||
|
||||
@@ -372,6 +372,37 @@ class TestExtractResponseText:
|
||||
# Should return "" (no text in current turn), NOT "Hi there!" from previous turn
|
||||
assert _extract_response_text(result) == ""
|
||||
|
||||
def test_does_not_publish_loop_warning_on_tool_calling_ai_message(self):
|
||||
"""Loop-detection warning text on a tool-calling AI message is middleware-authored."""
|
||||
from app.channels.manager import _extract_response_text
|
||||
|
||||
result = {
|
||||
"messages": [
|
||||
{"type": "human", "content": "search the repo"},
|
||||
{
|
||||
"type": "ai",
|
||||
"content": "[LOOP DETECTED] You are repeating the same tool calls.",
|
||||
"tool_calls": [{"name": "grep", "args": {"pattern": "TODO"}, "id": "call_1"}],
|
||||
},
|
||||
]
|
||||
}
|
||||
assert _extract_response_text(result) == ""
|
||||
|
||||
def test_preserves_visible_text_when_stripping_loop_warning(self):
|
||||
from app.channels.manager import _extract_response_text
|
||||
|
||||
result = {
|
||||
"messages": [
|
||||
{"type": "human", "content": "prepare the report"},
|
||||
{
|
||||
"type": "ai",
|
||||
"content": "Here is the report.\n\n[LOOP DETECTED] You are repeating the same tool calls.",
|
||||
"tool_calls": [{"name": "present_files", "args": {"filepaths": ["/mnt/user-data/outputs/report.md"]}, "id": "call_1"}],
|
||||
},
|
||||
]
|
||||
}
|
||||
assert _extract_response_text(result) == "Here is the report."
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChannelManager tests
|
||||
@@ -435,6 +466,47 @@ class TestChannelManager:
|
||||
assert headers["Cookie"] == f"csrf_token={csrf_token}"
|
||||
assert headers["X-DeerFlow-Internal-Token"]
|
||||
|
||||
def test_fetch_gateway_includes_internal_auth_headers(self, monkeypatch):
|
||||
from app.channels.manager import ChannelManager
|
||||
|
||||
class MockResponse:
|
||||
def raise_for_status(self):
|
||||
return None
|
||||
|
||||
def json(self):
|
||||
return {"models": [{"name": "default"}]}
|
||||
|
||||
class MockAsyncClient:
|
||||
def __init__(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def get(self, url, **kwargs):
|
||||
calls.append({"url": url, **kwargs})
|
||||
return MockResponse()
|
||||
|
||||
calls = []
|
||||
monkeypatch.setattr("app.channels.manager.httpx.AsyncClient", MockAsyncClient)
|
||||
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||
manager = ChannelManager(bus=bus, store=store, gateway_url="http://gateway:8001")
|
||||
|
||||
reply = await manager._fetch_gateway("/api/models", "models")
|
||||
|
||||
assert reply == "Available models:\n• default"
|
||||
assert calls[0]["url"] == "http://gateway:8001/api/models"
|
||||
assert calls[0]["timeout"] == 10
|
||||
assert calls[0]["headers"]["X-DeerFlow-Internal-Token"]
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_handle_chat_calls_channel_receive_file_for_inbound_files(self, monkeypatch):
|
||||
from app.channels.manager import ChannelManager
|
||||
|
||||
@@ -530,6 +602,8 @@ class TestChannelManager:
|
||||
assert call_args[0][0] == "test-thread-123" # thread_id
|
||||
assert call_args[0][1] == "lead_agent" # assistant_id
|
||||
assert call_args[1]["input"]["messages"][0]["content"] == "hi"
|
||||
assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == ""
|
||||
assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123"
|
||||
|
||||
assert len(outbound_received) == 1
|
||||
assert outbound_received[0].text == "Hello from agent!"
|
||||
@@ -661,12 +735,135 @@ class TestChannelManager:
|
||||
call_args = mock_client.runs.wait.call_args
|
||||
assert call_args[0][1] == "lead_agent"
|
||||
assert call_args[1]["config"]["recursion_limit"] == 55
|
||||
assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == ""
|
||||
assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123"
|
||||
assert call_args[1]["context"]["thinking_enabled"] is False
|
||||
assert call_args[1]["context"]["subagent_enabled"] is True
|
||||
assert call_args[1]["context"]["agent_name"] == "mobile-agent"
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_clarification_follow_up_preserves_history(self):
|
||||
"""Conversation should continue after ask_clarification instead of resetting history."""
|
||||
from app.channels.manager import ChannelManager
|
||||
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||
manager = ChannelManager(bus=bus, store=store)
|
||||
|
||||
outbound_received = []
|
||||
|
||||
async def capture_outbound(msg):
|
||||
outbound_received.append(msg)
|
||||
|
||||
bus.subscribe_outbound(capture_outbound)
|
||||
|
||||
history_by_checkpoint: dict[tuple[str, str], list[str]] = {}
|
||||
|
||||
async def _runs_wait(thread_id, assistant_id, *, input, config, context):
|
||||
del assistant_id, context # unused in this test, kept for signature parity
|
||||
|
||||
checkpoint_ns = config.get("configurable", {}).get("checkpoint_ns")
|
||||
key = (thread_id, str(checkpoint_ns))
|
||||
history = history_by_checkpoint.setdefault(key, [])
|
||||
|
||||
human_text = input["messages"][0]["content"]
|
||||
history.append(human_text)
|
||||
|
||||
if len(history) == 1:
|
||||
return {
|
||||
"messages": [
|
||||
{"type": "human", "content": history[0]},
|
||||
{
|
||||
"type": "ai",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"name": "ask_clarification",
|
||||
"args": {"question": "Which environment should I use?"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"type": "tool",
|
||||
"name": "ask_clarification",
|
||||
"content": "Which environment should I use?",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
if len(history) == 2 and history[0] == "Deploy my app" and history[1] == "prod":
|
||||
return {
|
||||
"messages": [
|
||||
{"type": "human", "content": history[0]},
|
||||
{
|
||||
"type": "ai",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"name": "ask_clarification",
|
||||
"args": {"question": "Which environment should I use?"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"type": "tool",
|
||||
"name": "ask_clarification",
|
||||
"content": "Which environment should I use?",
|
||||
},
|
||||
{"type": "human", "content": history[1]},
|
||||
{"type": "ai", "content": "Got it. I will deploy to prod."},
|
||||
]
|
||||
}
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
{"type": "human", "content": history[-1]},
|
||||
{"type": "ai", "content": "History missing; clarification repeated."},
|
||||
]
|
||||
}
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.threads.create = AsyncMock(return_value={"thread_id": "clarify-thread-1"})
|
||||
mock_client.threads.get = AsyncMock(return_value={"thread_id": "clarify-thread-1"})
|
||||
mock_client.runs.wait = AsyncMock(side_effect=_runs_wait)
|
||||
manager._client = mock_client
|
||||
|
||||
await manager.start()
|
||||
|
||||
await bus.publish_inbound(
|
||||
InboundMessage(
|
||||
channel_name="test",
|
||||
chat_id="chat1",
|
||||
user_id="user1",
|
||||
text="Deploy my app",
|
||||
)
|
||||
)
|
||||
await _wait_for(lambda: len(outbound_received) >= 1)
|
||||
|
||||
await bus.publish_inbound(
|
||||
InboundMessage(
|
||||
channel_name="test",
|
||||
chat_id="chat1",
|
||||
user_id="user1",
|
||||
text="prod",
|
||||
)
|
||||
)
|
||||
await _wait_for(lambda: len(outbound_received) >= 2)
|
||||
await manager.stop()
|
||||
|
||||
assert outbound_received[0].text == "Which environment should I use?"
|
||||
assert outbound_received[1].text == "Got it. I will deploy to prod."
|
||||
|
||||
assert mock_client.runs.wait.call_count == 2
|
||||
first_call = mock_client.runs.wait.call_args_list[0]
|
||||
second_call = mock_client.runs.wait.call_args_list[1]
|
||||
assert first_call.kwargs["config"]["configurable"]["checkpoint_ns"] == ""
|
||||
assert second_call.kwargs["config"]["configurable"]["checkpoint_ns"] == ""
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_handle_chat_uses_user_session_overrides(self):
|
||||
from app.channels.manager import ChannelManager
|
||||
|
||||
@@ -1343,6 +1540,8 @@ class TestChannelManager:
|
||||
call_args = mock_client.runs.stream.call_args
|
||||
|
||||
assert call_args[1]["input"]["messages"][0]["content"] == "hello"
|
||||
assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == ""
|
||||
assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123"
|
||||
assert call_args[1]["context"]["is_bootstrap"] is True
|
||||
|
||||
# Final message should be published
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Unit tests for checkpointer config and singleton factory."""
|
||||
"""Unit tests for checkpointer config, packaging metadata, and factories."""
|
||||
|
||||
import sys
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -13,6 +15,8 @@ from deerflow.config.checkpointer_config import (
|
||||
set_checkpointer_config,
|
||||
)
|
||||
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
|
||||
from deerflow.runtime.checkpointer.provider import POSTGRES_INSTALL
|
||||
from deerflow.runtime.store.provider import POSTGRES_STORE_INSTALL
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -67,6 +71,42 @@ class TestCheckpointerConfig:
|
||||
with pytest.raises(Exception):
|
||||
load_checkpointer_config_from_dict({"type": "unknown"})
|
||||
|
||||
def test_connection_string_description_matches_runtime_defaults(self):
|
||||
description = CheckpointerConfig.model_fields["connection_string"].description
|
||||
|
||||
assert description is not None
|
||||
assert "Optional for sqlite" in description
|
||||
assert "defaults to 'store.db'" in description
|
||||
assert "Required for postgres" in description
|
||||
|
||||
|
||||
class TestHarnessPackaging:
|
||||
def test_pyproject_declares_postgres_extra(self):
|
||||
pyproject_path = Path(__file__).resolve().parents[1] / "packages" / "harness" / "pyproject.toml"
|
||||
data = tomllib.loads(pyproject_path.read_text())
|
||||
|
||||
optional_dependencies = data["project"]["optional-dependencies"]
|
||||
assert "postgres" in optional_dependencies
|
||||
assert optional_dependencies["postgres"] == [
|
||||
"asyncpg>=0.29",
|
||||
"langgraph-checkpoint-postgres>=3.0.5",
|
||||
"psycopg[binary]>=3.3.3",
|
||||
"psycopg-pool>=3.3.0",
|
||||
]
|
||||
|
||||
def test_workspace_pyproject_forwards_postgres_extra_to_harness(self):
|
||||
pyproject_path = Path(__file__).resolve().parents[1] / "pyproject.toml"
|
||||
data = tomllib.loads(pyproject_path.read_text())
|
||||
|
||||
optional_dependencies = data["project"]["optional-dependencies"]
|
||||
assert optional_dependencies["postgres"] == ["deerflow-harness[postgres]"]
|
||||
|
||||
def test_postgres_missing_dependency_messages_recommend_package_extra(self):
|
||||
assert "deerflow-harness[postgres]" in POSTGRES_INSTALL
|
||||
assert "deerflow-harness[postgres]" in POSTGRES_STORE_INSTALL
|
||||
assert "uv sync --all-packages --extra postgres" in POSTGRES_INSTALL
|
||||
assert "uv sync --all-packages --extra postgres" in POSTGRES_STORE_INSTALL
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Factory tests
|
||||
|
||||
@@ -192,6 +192,7 @@ def test_agent_features_defaults():
|
||||
assert f.vision is False
|
||||
assert f.auto_title is False
|
||||
assert f.guardrail is False
|
||||
assert f.loop_detection is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -630,6 +631,51 @@ def test_loop_detection_before_clarification(mock_create_agent):
|
||||
assert loop_idx == clar_idx - 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 30b. loop_detection=False skips LoopDetectionMiddleware
|
||||
# ---------------------------------------------------------------------------
|
||||
@patch("deerflow.agents.factory.create_agent")
|
||||
def test_loop_detection_disabled(mock_create_agent):
|
||||
mock_create_agent.return_value = MagicMock()
|
||||
create_deerflow_agent(
|
||||
_make_mock_model(),
|
||||
features=RuntimeFeatures(sandbox=False, loop_detection=False),
|
||||
)
|
||||
|
||||
call_kwargs = mock_create_agent.call_args[1]
|
||||
mw_types = [type(m).__name__ for m in call_kwargs["middleware"]]
|
||||
assert "LoopDetectionMiddleware" not in mw_types
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 30c. loop_detection=<custom AgentMiddleware> replaces the default
|
||||
# ---------------------------------------------------------------------------
|
||||
@patch("deerflow.agents.factory.create_agent")
|
||||
def test_loop_detection_custom_middleware(mock_create_agent):
|
||||
from langchain.agents.middleware import AgentMiddleware as AM
|
||||
|
||||
mock_create_agent.return_value = MagicMock()
|
||||
|
||||
class MyLoopDetection(AM):
|
||||
pass
|
||||
|
||||
custom = MyLoopDetection()
|
||||
create_deerflow_agent(
|
||||
_make_mock_model(),
|
||||
features=RuntimeFeatures(sandbox=False, loop_detection=custom),
|
||||
)
|
||||
|
||||
call_kwargs = mock_create_agent.call_args[1]
|
||||
middleware = call_kwargs["middleware"]
|
||||
assert custom in middleware
|
||||
mw_types = [type(m).__name__ for m in middleware]
|
||||
# Default LoopDetectionMiddleware must not also appear.
|
||||
assert "LoopDetectionMiddleware" not in mw_types
|
||||
# Custom replacement still sits immediately before ClarificationMiddleware.
|
||||
assert mw_types[-1] == "ClarificationMiddleware"
|
||||
assert mw_types[-2] == "MyLoopDetection"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 31. plan_mode=True adds TodoMiddleware
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -85,6 +85,8 @@ def test_load_claude_code_credential_from_override_path(tmp_path, monkeypatch):
|
||||
|
||||
def test_load_claude_code_credential_ignores_directory_path(tmp_path, monkeypatch):
|
||||
_clear_claude_code_env(monkeypatch)
|
||||
# Redirect HOME so the default ~/.claude/.credentials.json doesn't exist
|
||||
monkeypatch.setenv("HOME", str(tmp_path))
|
||||
cred_dir = tmp_path / "claude-creds-dir"
|
||||
cred_dir.mkdir()
|
||||
monkeypatch.setenv("CLAUDE_CODE_CREDENTIALS_PATH", str(cred_dir))
|
||||
|
||||
@@ -0,0 +1,235 @@
|
||||
"""Tests for CSRF middleware."""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from app.gateway.csrf_middleware import CSRFMiddleware
|
||||
|
||||
|
||||
def _make_app() -> FastAPI:
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
@app.post("/api/v1/auth/login/local")
|
||||
async def login_local():
|
||||
return {"ok": True}
|
||||
|
||||
@app.post("/api/v1/auth/register")
|
||||
async def register():
|
||||
return {"ok": True}
|
||||
|
||||
@app.post("/api/threads/abc/runs/stream")
|
||||
async def protected_mutation():
|
||||
return {"ok": True}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def test_auth_post_rejects_cross_origin_browser_request():
|
||||
"""CSRF-exempt auth routes must not accept hostile browser origins.
|
||||
|
||||
Login/register endpoints intentionally skip the double-submit token because
|
||||
first-time callers do not have a token yet. They still set an auth session,
|
||||
so a hostile cross-site form POST must be rejected to avoid login CSRF /
|
||||
session fixation.
|
||||
"""
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
headers={"Origin": "https://evil.example"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "Cross-site auth request denied."
|
||||
|
||||
|
||||
def test_auth_post_allows_same_origin_browser_request():
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
headers={"Origin": "https://deerflow.example"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.cookies.get("csrf_token")
|
||||
|
||||
|
||||
def test_auth_post_rejects_malformed_origin_with_path():
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
headers={"Origin": "https://deerflow.example/path"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "Cross-site auth request denied."
|
||||
assert response.cookies.get("csrf_token") is None
|
||||
|
||||
|
||||
def test_auth_post_rejects_malformed_origin_with_invalid_port():
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
headers={"Origin": "https://deerflow.example:bad"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "Cross-site auth request denied."
|
||||
assert response.cookies.get("csrf_token") is None
|
||||
|
||||
|
||||
def test_auth_post_allows_same_origin_default_port_equivalence():
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
headers={"Origin": "https://deerflow.example:443"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.cookies.get("csrf_token")
|
||||
|
||||
|
||||
def test_auth_post_allows_forwarded_same_origin():
|
||||
client = TestClient(_make_app(), base_url="http://internal:8000")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
headers={
|
||||
"Origin": "https://deerflow.example",
|
||||
"X-Forwarded-Proto": "https",
|
||||
"X-Forwarded-Host": "deerflow.example, internal:8000",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.cookies.get("csrf_token")
|
||||
|
||||
|
||||
def test_auth_post_allows_forwarded_same_origin_with_non_default_port():
|
||||
client = TestClient(_make_app(), base_url="http://internal:8000")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
headers={
|
||||
"Origin": "http://localhost:2026",
|
||||
"X-Forwarded-Proto": "http",
|
||||
"X-Forwarded-Host": "localhost:2026",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.cookies.get("csrf_token")
|
||||
|
||||
|
||||
def test_auth_post_allows_rfc_forwarded_same_origin():
|
||||
client = TestClient(_make_app(), base_url="http://internal:8000")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
headers={
|
||||
"Origin": "https://deerflow.example",
|
||||
"Forwarded": "proto=https;host=deerflow.example",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.cookies.get("csrf_token")
|
||||
assert "secure" in response.headers["set-cookie"].lower()
|
||||
|
||||
|
||||
def test_auth_post_allows_explicit_configured_origin(monkeypatch):
|
||||
monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "https://app.example")
|
||||
client = TestClient(_make_app(), base_url="https://api.example")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/register",
|
||||
headers={"Origin": "https://app.example"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.cookies.get("csrf_token")
|
||||
|
||||
|
||||
def test_auth_post_does_not_treat_wildcard_cors_as_allowed_origin(monkeypatch):
|
||||
monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "*")
|
||||
client = TestClient(_make_app(), base_url="https://api.example")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
headers={"Origin": "https://evil.example"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "Cross-site auth request denied."
|
||||
|
||||
|
||||
def test_auth_post_sets_strict_samesite_csrf_cookie():
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
headers={"Origin": "https://deerflow.example"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
set_cookie = response.headers["set-cookie"].lower()
|
||||
assert "csrf_token=" in set_cookie
|
||||
assert "samesite=strict" in set_cookie
|
||||
assert "secure" in set_cookie
|
||||
|
||||
|
||||
def test_auth_post_without_origin_still_allows_non_browser_clients():
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
|
||||
response = client.post("/api/v1/auth/login/local")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.cookies.get("csrf_token")
|
||||
|
||||
|
||||
def test_non_auth_mutation_still_requires_double_submit_token():
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
|
||||
response = client.post(
|
||||
"/api/threads/abc/runs/stream",
|
||||
headers={"Origin": "https://deerflow.example"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "CSRF token missing. Include X-CSRF-Token header."
|
||||
|
||||
|
||||
def test_non_auth_mutation_allows_valid_double_submit_token():
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
client.cookies.set("csrf_token", "known-token")
|
||||
|
||||
response = client.post(
|
||||
"/api/threads/abc/runs/stream",
|
||||
headers={
|
||||
"Origin": "https://deerflow.example",
|
||||
"X-CSRF-Token": "known-token",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_non_auth_mutation_rejects_mismatched_double_submit_token():
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
client.cookies.set("csrf_token", "cookie-token")
|
||||
|
||||
response = client.post(
|
||||
"/api/threads/abc/runs/stream",
|
||||
headers={
|
||||
"Origin": "https://deerflow.example",
|
||||
"X-CSRF-Token": "header-token",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "CSRF token mismatch."
|
||||
@@ -537,7 +537,10 @@ class TestAgentsAPI:
|
||||
def test_create_persists_files_on_disk(self, agent_client, tmp_path):
|
||||
agent_client.post("/api/agents", json={"name": "disk-check", "soul": "disk soul"})
|
||||
|
||||
agent_dir = tmp_path / "agents" / "disk-check"
|
||||
# tests/conftest.py installs an autouse fixture that sets the
|
||||
# contextvar to "test-user-autouse", so the agent is persisted under
|
||||
# users/test-user-autouse/agents/ rather than the legacy shared dir.
|
||||
agent_dir = tmp_path / "users" / "test-user-autouse" / "agents" / "disk-check"
|
||||
assert agent_dir.exists()
|
||||
assert (agent_dir / "config.yaml").exists()
|
||||
assert (agent_dir / "SOUL.md").exists()
|
||||
@@ -545,12 +548,23 @@ class TestAgentsAPI:
|
||||
|
||||
def test_delete_removes_files_from_disk(self, agent_client, tmp_path):
|
||||
agent_client.post("/api/agents", json={"name": "remove-me", "soul": "bye"})
|
||||
agent_dir = tmp_path / "agents" / "remove-me"
|
||||
agent_dir = tmp_path / "users" / "test-user-autouse" / "agents" / "remove-me"
|
||||
assert agent_dir.exists()
|
||||
|
||||
agent_client.delete("/api/agents/remove-me")
|
||||
assert not agent_dir.exists()
|
||||
|
||||
def test_create_rejects_legacy_name_collision(self, agent_client, tmp_path):
|
||||
"""An unmigrated legacy agent must still block name collision so that
|
||||
running the migration script later won't shadow the legacy entry."""
|
||||
legacy_dir = tmp_path / "agents" / "legacy-agent"
|
||||
legacy_dir.mkdir(parents=True)
|
||||
(legacy_dir / "config.yaml").write_text("name: legacy-agent\n", encoding="utf-8")
|
||||
(legacy_dir / "SOUL.md").write_text("legacy soul", encoding="utf-8")
|
||||
|
||||
response = agent_client.post("/api/agents", json={"name": "legacy-agent", "soul": "x"})
|
||||
assert response.status_code == 409
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 9. Gateway API – User Profile endpoints
|
||||
|
||||
@@ -0,0 +1,201 @@
|
||||
"""Unit tests for scripts/detect_uv_extras.py.
|
||||
|
||||
The detector resolves uv extras for `make dev` so that postgres (and any
|
||||
future opt-in extras) are not wiped on every restart — see Issue #2754.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
DETECT_SCRIPT_PATH = REPO_ROOT / "scripts" / "detect_uv_extras.py"
|
||||
|
||||
|
||||
spec = importlib.util.spec_from_file_location("deerflow_detect_uv_extras", DETECT_SCRIPT_PATH)
|
||||
assert spec is not None and spec.loader is not None
|
||||
detect = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(detect)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def isolated_cwd(tmp_path, monkeypatch):
|
||||
"""Isolate `find_config_file()` from the real repo by chdir + clearing env."""
|
||||
monkeypatch.chdir(tmp_path)
|
||||
monkeypatch.delenv("UV_EXTRAS", raising=False)
|
||||
monkeypatch.delenv("DEER_FLOW_CONFIG_PATH", raising=False)
|
||||
return tmp_path
|
||||
|
||||
|
||||
def test_parse_env_extras_supports_comma_and_whitespace():
|
||||
assert detect.parse_env_extras("postgres") == ["postgres"]
|
||||
assert detect.parse_env_extras("postgres,ollama") == ["postgres", "ollama"]
|
||||
assert detect.parse_env_extras("postgres ollama") == ["postgres", "ollama"]
|
||||
assert detect.parse_env_extras(" postgres , ollama ,") == ["postgres", "ollama"]
|
||||
assert detect.parse_env_extras("") == []
|
||||
|
||||
|
||||
def test_parse_env_extras_drops_shell_metacharacters(capsys):
|
||||
"""A `.env` value containing shell injection bait must not pass through.
|
||||
|
||||
The whitelist guarantees the *bytes* that reach `uv sync` cannot include
|
||||
shell metacharacters. Any name that looks identifier-like still survives
|
||||
(uv itself will reject unknown extras with its own error), but `;`, `&`,
|
||||
backticks, parentheses, slashes, etc. are stripped.
|
||||
"""
|
||||
# Pure-metacharacter inputs collapse to empty.
|
||||
assert detect.parse_env_extras(";") == []
|
||||
assert detect.parse_env_extras("$(whoami)") == []
|
||||
assert detect.parse_env_extras("`echo bad`") == []
|
||||
assert detect.parse_env_extras("postgres;evil") == [] # single token, contains `;`
|
||||
# Splitting on whitespace yields ['rm'] which is identifier-shaped, but the
|
||||
# destructive bits (`;`, `-rf`, `/`) are dropped.
|
||||
assert detect.parse_env_extras("; rm -rf /") == ["rm"]
|
||||
err = capsys.readouterr().err
|
||||
assert "ignoring invalid UV_EXTRAS entry" in err
|
||||
assert "';'" in err # confirms the dangerous token was reported and dropped
|
||||
|
||||
|
||||
def test_parse_env_extras_rejects_leading_digits_and_punctuation():
|
||||
"""Names must start with a letter — pyproject extras follow this shape."""
|
||||
assert detect.parse_env_extras("1postgres") == []
|
||||
assert detect.parse_env_extras("-postgres") == []
|
||||
# Hyphens and underscores inside the name are fine.
|
||||
assert detect.parse_env_extras("post_gres") == ["post_gres"]
|
||||
assert detect.parse_env_extras("post-gres") == ["post-gres"]
|
||||
|
||||
|
||||
def test_format_flags_emits_one_flag_per_extra():
|
||||
assert detect.format_flags([]) == ""
|
||||
assert detect.format_flags(["postgres"]) == "--extra postgres"
|
||||
assert detect.format_flags(["postgres", "ollama"]) == "--extra postgres --extra ollama"
|
||||
|
||||
|
||||
def test_strip_comment_preserves_quoted_hash():
|
||||
assert detect._strip_comment("backend: postgres # trailing") == "backend: postgres"
|
||||
assert detect._strip_comment('name: "value#with-hash"') == 'name: "value#with-hash"'
|
||||
assert detect._strip_comment("# whole line comment") == ""
|
||||
|
||||
|
||||
def test_section_value_finds_nested_key():
|
||||
yaml_lines = [
|
||||
"database:",
|
||||
" backend: postgres",
|
||||
" postgres_url: $DATABASE_URL",
|
||||
"",
|
||||
"checkpointer:",
|
||||
" type: sqlite",
|
||||
]
|
||||
assert detect.section_value(yaml_lines, "database", "backend") == "postgres"
|
||||
assert detect.section_value(yaml_lines, "checkpointer", "type") == "sqlite"
|
||||
assert detect.section_value(yaml_lines, "database", "missing") is None
|
||||
assert detect.section_value(yaml_lines, "absent_section", "anything") is None
|
||||
|
||||
|
||||
def test_section_value_ignores_commented_lines():
|
||||
yaml_lines = [
|
||||
"# database:",
|
||||
"# backend: postgres",
|
||||
"database:",
|
||||
" backend: sqlite",
|
||||
]
|
||||
assert detect.section_value(yaml_lines, "database", "backend") == "sqlite"
|
||||
|
||||
|
||||
def test_section_value_strips_quotes():
|
||||
yaml_lines = [
|
||||
"database:",
|
||||
' backend: "postgres"',
|
||||
]
|
||||
assert detect.section_value(yaml_lines, "database", "backend") == "postgres"
|
||||
|
||||
|
||||
def test_section_value_does_not_descend_into_grandchildren():
|
||||
yaml_lines = [
|
||||
"database:",
|
||||
" backend: sqlite",
|
||||
" nested:",
|
||||
" backend: postgres",
|
||||
]
|
||||
# Only the immediate child level counts — keeps the parser predictable.
|
||||
assert detect.section_value(yaml_lines, "database", "backend") == "sqlite"
|
||||
|
||||
|
||||
def test_detect_from_config_postgres_via_database(tmp_path):
|
||||
cfg = tmp_path / "config.yaml"
|
||||
cfg.write_text("database:\n backend: postgres\n postgres_url: $DATABASE_URL\n")
|
||||
assert detect.detect_from_config(cfg) == ["postgres"]
|
||||
|
||||
|
||||
def test_detect_from_config_postgres_via_checkpointer(tmp_path):
|
||||
cfg = tmp_path / "config.yaml"
|
||||
cfg.write_text("checkpointer:\n type: postgres\n connection_string: postgresql://localhost/db\n")
|
||||
assert detect.detect_from_config(cfg) == ["postgres"]
|
||||
|
||||
|
||||
def test_detect_from_config_sqlite_returns_no_extras(tmp_path):
|
||||
cfg = tmp_path / "config.yaml"
|
||||
cfg.write_text("database:\n backend: sqlite\n sqlite_dir: .deer-flow/data\n")
|
||||
assert detect.detect_from_config(cfg) == []
|
||||
|
||||
|
||||
def test_detect_from_config_dedupes_when_both_present(tmp_path):
|
||||
cfg = tmp_path / "config.yaml"
|
||||
cfg.write_text("checkpointer:\n type: postgres\ndatabase:\n backend: postgres\n")
|
||||
# Sorted unique extras, no double-counting.
|
||||
assert detect.detect_from_config(cfg) == ["postgres"]
|
||||
|
||||
|
||||
def test_detect_from_config_missing_file_returns_empty(tmp_path):
|
||||
assert detect.detect_from_config(tmp_path / "does-not-exist.yaml") == []
|
||||
|
||||
|
||||
def test_resolve_extras_env_overrides_config(isolated_cwd, monkeypatch):
|
||||
cfg = isolated_cwd / "config.yaml"
|
||||
cfg.write_text("database:\n backend: sqlite\n")
|
||||
monkeypatch.setenv("UV_EXTRAS", "postgres")
|
||||
|
||||
assert detect.resolve_extras() == ["postgres"]
|
||||
|
||||
|
||||
def test_resolve_extras_env_supports_multiple(isolated_cwd, monkeypatch):
|
||||
monkeypatch.setenv("UV_EXTRAS", "postgres,ollama")
|
||||
assert detect.resolve_extras() == ["postgres", "ollama"]
|
||||
|
||||
|
||||
def test_resolve_extras_falls_back_to_config(isolated_cwd):
|
||||
(isolated_cwd / "config.yaml").write_text("database:\n backend: postgres\n")
|
||||
assert detect.resolve_extras() == ["postgres"]
|
||||
|
||||
|
||||
def test_resolve_extras_respects_explicit_config_path(tmp_path, monkeypatch):
|
||||
monkeypatch.delenv("UV_EXTRAS", raising=False)
|
||||
elsewhere = tmp_path / "elsewhere.yaml"
|
||||
elsewhere.write_text("database:\n backend: postgres\n")
|
||||
monkeypatch.chdir(tmp_path)
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(elsewhere))
|
||||
|
||||
assert detect.resolve_extras() == ["postgres"]
|
||||
|
||||
|
||||
def test_resolve_extras_no_config_no_env(isolated_cwd):
|
||||
assert detect.resolve_extras() == []
|
||||
|
||||
|
||||
def test_resolve_extras_finds_backend_subdir_config(isolated_cwd):
|
||||
sub = isolated_cwd / "backend"
|
||||
sub.mkdir()
|
||||
(sub / "config.yaml").write_text("database:\n backend: postgres\n")
|
||||
assert detect.resolve_extras() == ["postgres"]
|
||||
|
||||
|
||||
def test_resolve_extras_root_config_takes_precedence(isolated_cwd):
|
||||
(isolated_cwd / "config.yaml").write_text("database:\n backend: sqlite\n")
|
||||
sub = isolated_cwd / "backend"
|
||||
sub.mkdir()
|
||||
(sub / "config.yaml").write_text("database:\n backend: postgres\n")
|
||||
# Root config.yaml is checked first, matching the precedence in serve.sh.
|
||||
assert detect.resolve_extras() == []
|
||||
@@ -0,0 +1,102 @@
|
||||
"""Unit tests for docker/dev-entrypoint.sh (UV_EXTRAS validation + parsing).
|
||||
|
||||
Exercises the script via its `--print-extras` dry-run hook so we don't actually
|
||||
launch uvicorn or hit /app/logs. Together with test_detect_uv_extras.py these
|
||||
cover both the local make-dev path and the docker-compose-dev path with the
|
||||
same shape — see PR #2767 / Issue #2754.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
ENTRYPOINT = REPO_ROOT / "docker" / "dev-entrypoint.sh"
|
||||
|
||||
|
||||
def _run(uv_extras: str | None) -> subprocess.CompletedProcess[str]:
|
||||
"""Invoke `dev-entrypoint.sh --print-extras` with UV_EXTRAS set."""
|
||||
env = os.environ.copy()
|
||||
env.pop("UV_EXTRAS", None)
|
||||
if uv_extras is not None:
|
||||
env["UV_EXTRAS"] = uv_extras
|
||||
return subprocess.run(
|
||||
["sh", str(ENTRYPOINT), "--print-extras"],
|
||||
env=env,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
|
||||
|
||||
def test_entrypoint_script_exists_and_is_posix_sh():
|
||||
assert ENTRYPOINT.is_file()
|
||||
# Catch syntax errors before runtime — `sh -n` is a parse-only check.
|
||||
proc = subprocess.run(["sh", "-n", str(ENTRYPOINT)], capture_output=True, text=True, check=False)
|
||||
assert proc.returncode == 0, proc.stderr
|
||||
|
||||
|
||||
def test_no_uv_extras_yields_empty_flags():
|
||||
proc = _run(None)
|
||||
assert proc.returncode == 0
|
||||
assert proc.stdout.strip() == ""
|
||||
|
||||
|
||||
def test_single_extra():
|
||||
proc = _run("postgres")
|
||||
assert proc.returncode == 0
|
||||
assert proc.stdout.strip() == "--extra postgres"
|
||||
|
||||
|
||||
def test_multi_extra_comma_separated():
|
||||
proc = _run("postgres,ollama")
|
||||
assert proc.returncode == 0
|
||||
assert proc.stdout.strip() == "--extra postgres --extra ollama"
|
||||
|
||||
|
||||
def test_multi_extra_whitespace_separated():
|
||||
proc = _run("postgres ollama")
|
||||
assert proc.returncode == 0
|
||||
assert proc.stdout.strip() == "--extra postgres --extra ollama"
|
||||
|
||||
|
||||
def test_multi_extra_mixed_separators():
|
||||
proc = _run(" postgres , ollama ,")
|
||||
assert proc.returncode == 0
|
||||
assert proc.stdout.strip() == "--extra postgres --extra ollama"
|
||||
|
||||
|
||||
def test_empty_string_yields_empty_flags():
|
||||
proc = _run("")
|
||||
assert proc.returncode == 0
|
||||
assert proc.stdout.strip() == ""
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"bad_value",
|
||||
[
|
||||
"; rm -rf /", # the canonical injection attempt
|
||||
"$(whoami)", # command substitution
|
||||
"`echo bad`", # backticks
|
||||
"postgres;evil", # mixed legal+illegal in a single token
|
||||
"1postgres", # leading digit
|
||||
"-postgres", # leading hyphen
|
||||
"post gres extra/path", # contains slash
|
||||
],
|
||||
)
|
||||
def test_metacharacters_abort_with_nonzero_exit(bad_value):
|
||||
proc = _run(bad_value)
|
||||
assert proc.returncode != 0, f"expected abort for {bad_value!r}, got 0"
|
||||
assert "is invalid" in proc.stderr
|
||||
assert proc.stdout.strip() == ""
|
||||
|
||||
|
||||
def test_underscores_and_hyphens_in_name_are_allowed():
|
||||
"""Mirrors uv's accepted shape for `[project.optional-dependencies]` keys."""
|
||||
proc = _run("post_gres,post-gres")
|
||||
assert proc.returncode == 0
|
||||
assert proc.stdout.strip() == "--extra post_gres --extra post-gres"
|
||||
@@ -0,0 +1,336 @@
|
||||
"""Tests for DynamicContextMiddleware.
|
||||
|
||||
Verifies that memory and current date are injected as a <system-reminder> into
|
||||
the first HumanMessage exactly once per session (frozen-snapshot pattern).
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest import mock
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from deerflow.agents.middlewares.dynamic_context_middleware import (
|
||||
_DYNAMIC_CONTEXT_REMINDER_KEY,
|
||||
DynamicContextMiddleware,
|
||||
)
|
||||
|
||||
_SYSTEM_REMINDER_TAG = "<system-reminder>"
|
||||
|
||||
|
||||
def _make_middleware(**kwargs) -> DynamicContextMiddleware:
|
||||
return DynamicContextMiddleware(**kwargs)
|
||||
|
||||
|
||||
def _fake_runtime():
|
||||
return SimpleNamespace(context={})
|
||||
|
||||
|
||||
def _reminder_msg(content: str, msg_id: str) -> HumanMessage:
|
||||
"""Build a reminder HumanMessage the way the middleware would produce it."""
|
||||
return HumanMessage(
|
||||
content=content,
|
||||
id=msg_id,
|
||||
additional_kwargs={"hide_from_ui": True, _DYNAMIC_CONTEXT_REMINDER_KEY: True},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Basic injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_injects_system_reminder_into_first_human_message():
|
||||
mw = _make_middleware()
|
||||
state = {"messages": [HumanMessage(content="Hello", id="msg-1")]}
|
||||
|
||||
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
assert result is not None
|
||||
updated_msgs = result["messages"]
|
||||
assert len(updated_msgs) == 2
|
||||
|
||||
reminder_msg = updated_msgs[0]
|
||||
assert isinstance(reminder_msg, HumanMessage)
|
||||
assert reminder_msg.id == "msg-1" # takes the original ID (position swap)
|
||||
assert reminder_msg.additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
|
||||
assert _SYSTEM_REMINDER_TAG in reminder_msg.content
|
||||
assert "<current_date>2026-05-08, Friday</current_date>" in reminder_msg.content
|
||||
assert "Hello" not in reminder_msg.content # reminder only — no user text
|
||||
|
||||
user_msg = updated_msgs[1]
|
||||
assert isinstance(user_msg, HumanMessage)
|
||||
assert user_msg.id == "msg-1__user" # derived ID
|
||||
assert user_msg.content == "Hello"
|
||||
|
||||
|
||||
def test_memory_included_when_present():
|
||||
mw = _make_middleware()
|
||||
state = {"messages": [HumanMessage(content="Hi", id="msg-1")]}
|
||||
|
||||
with (
|
||||
mock.patch(
|
||||
"deerflow.agents.lead_agent.prompt._get_memory_context",
|
||||
return_value="<memory>\nUser prefers Python.\n</memory>",
|
||||
),
|
||||
mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt,
|
||||
):
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
# Reminder is the first returned message; user query is the second
|
||||
reminder_content = result["messages"][0].content
|
||||
assert "User prefers Python." in reminder_content
|
||||
assert "<current_date>2026-05-08, Friday</current_date>" in reminder_content
|
||||
assert result["messages"][1].content == "Hi"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Frozen-snapshot: no re-injection within a session
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_skips_injection_if_already_present():
|
||||
"""Second turn: separate reminder message already present → no update."""
|
||||
mw = _make_middleware()
|
||||
reminder_content = "<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>"
|
||||
state = {
|
||||
"messages": [
|
||||
_reminder_msg(reminder_content, "msg-1"),
|
||||
HumanMessage(content="Hello", id="msg-1__user"),
|
||||
AIMessage(content="Hi there"),
|
||||
HumanMessage(content="Follow-up", id="msg-2"),
|
||||
]
|
||||
}
|
||||
|
||||
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
assert result is None # no update needed
|
||||
|
||||
|
||||
def test_injects_only_into_first_human_message_not_later_ones():
|
||||
"""Reminder targets the first HumanMessage; subsequent messages are not touched."""
|
||||
mw = _make_middleware()
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="First", id="msg-1"),
|
||||
AIMessage(content="Reply"),
|
||||
HumanMessage(content="Second", id="msg-2"),
|
||||
]
|
||||
}
|
||||
|
||||
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
assert result is not None
|
||||
msgs = result["messages"]
|
||||
# Only the two injected messages are returned (reminder + original first query)
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0].id == "msg-1" # reminder takes first message's ID
|
||||
assert msgs[0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
|
||||
assert _SYSTEM_REMINDER_TAG in msgs[0].content
|
||||
assert msgs[1].id == "msg-1__user" # original content with derived ID
|
||||
assert msgs[1].content == "First"
|
||||
# "Second" (msg-2) is not in the returned update — it is left unchanged
|
||||
assert all(m.id != "msg-2" for m in msgs)
|
||||
|
||||
|
||||
def test_summary_human_message_is_not_used_as_injection_target():
|
||||
"""After summarization, the synthetic summary HumanMessage is not a user turn."""
|
||||
mw = _make_middleware()
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="Here is a summary of the conversation to date:\n\n...", id="summary-1", name="summary"),
|
||||
AIMessage(content="Earlier reply"),
|
||||
HumanMessage(content="Follow-up", id="msg-2"),
|
||||
]
|
||||
}
|
||||
|
||||
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
assert result is not None
|
||||
msgs = result["messages"]
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0].id == "msg-2"
|
||||
assert msgs[0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
|
||||
assert msgs[1].id == "msg-2__user"
|
||||
assert msgs[1].content == "Follow-up"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_no_messages_returns_none():
|
||||
mw = _make_middleware()
|
||||
result = mw.before_agent({"messages": []}, _fake_runtime())
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_no_human_message_returns_none():
|
||||
mw = _make_middleware()
|
||||
state = {"messages": [AIMessage(content="assistant only")]}
|
||||
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""):
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_list_content_message_handled_as_separate_reminder():
|
||||
"""List-content (e.g. multi-modal) messages remain intact; reminder is a separate message."""
|
||||
mw = _make_middleware()
|
||||
original_content = [{"type": "text", "text": "Hello"}]
|
||||
state = {"messages": [HumanMessage(content=original_content, id="msg-1")]}
|
||||
|
||||
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
assert result is not None
|
||||
msgs = result["messages"]
|
||||
assert len(msgs) == 2
|
||||
# Reminder is a plain string message with the flag set
|
||||
assert isinstance(msgs[0].content, str)
|
||||
assert msgs[0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
|
||||
assert _SYSTEM_REMINDER_TAG in msgs[0].content
|
||||
# Original list-content message is untouched
|
||||
assert msgs[1].content == original_content
|
||||
|
||||
|
||||
def test_reminder_uses_original_id_user_message_uses_derived_id():
|
||||
"""Reminder takes original ID (position swap); user message gets {id}__user."""
|
||||
mw = _make_middleware()
|
||||
original_id = "original-id-abc"
|
||||
state = {"messages": [HumanMessage(content="Hello", id=original_id)]}
|
||||
|
||||
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
assert result["messages"][0].id == original_id
|
||||
assert result["messages"][1].id == f"{original_id}__user"
|
||||
|
||||
|
||||
def test_message_without_id_gets_stable_uuid():
|
||||
"""If the original HumanMessage has no ID, a UUID is generated and used consistently."""
|
||||
mw = _make_middleware()
|
||||
state = {"messages": [HumanMessage(content="Hello", id=None)]}
|
||||
|
||||
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
assert result is not None
|
||||
reminder_id = result["messages"][0].id
|
||||
user_id = result["messages"][1].id
|
||||
assert reminder_id is not None
|
||||
assert reminder_id != "None"
|
||||
assert user_id == f"{reminder_id}__user"
|
||||
|
||||
|
||||
def test_user_message_containing_system_reminder_tag_does_not_prevent_injection():
|
||||
"""A user message containing '<system-reminder>' must not be mistaken for a reminder."""
|
||||
mw = _make_middleware()
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="What is <system-reminder>?", id="msg-1"),
|
||||
]
|
||||
}
|
||||
|
||||
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
# Injection must happen — the user message does NOT carry the reminder flag
|
||||
assert result is not None
|
||||
assert result["messages"][0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Midnight crossing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_midnight_crossing_injects_date_update_as_separate_message():
|
||||
"""When the date has changed, a separate date-update reminder is injected before
|
||||
the current turn's HumanMessage using the ID-swap technique."""
|
||||
mw = _make_middleware()
|
||||
reminder_content = "<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>"
|
||||
state = {
|
||||
"messages": [
|
||||
_reminder_msg(reminder_content, "msg-1"),
|
||||
HumanMessage(content="Hello", id="msg-1__user"),
|
||||
AIMessage(content="Response"),
|
||||
HumanMessage(content="Good morning", id="msg-2"),
|
||||
]
|
||||
}
|
||||
|
||||
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-09, Saturday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
assert result is not None
|
||||
msgs = result["messages"]
|
||||
assert len(msgs) == 2
|
||||
|
||||
# Date-update reminder takes the current message's ID
|
||||
assert msgs[0].id == "msg-2"
|
||||
assert msgs[0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
|
||||
assert _SYSTEM_REMINDER_TAG in msgs[0].content
|
||||
assert "<current_date>2026-05-09, Saturday</current_date>" in msgs[0].content
|
||||
assert "Good morning" not in msgs[0].content # reminder only
|
||||
|
||||
# Original user text appended with derived ID
|
||||
assert msgs[1].id == "msg-2__user"
|
||||
assert msgs[1].content == "Good morning"
|
||||
|
||||
|
||||
def test_midnight_crossing_id_swap():
|
||||
"""Date-update reminder uses original ID; user message uses {id}__user."""
|
||||
mw = _make_middleware()
|
||||
reminder_content = "<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>"
|
||||
state = {
|
||||
"messages": [
|
||||
_reminder_msg(reminder_content, "msg-1"),
|
||||
HumanMessage(content="Next day message", id="msg-2"),
|
||||
]
|
||||
}
|
||||
|
||||
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-09, Saturday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
assert result["messages"][0].id == "msg-2"
|
||||
assert result["messages"][1].id == "msg-2__user"
|
||||
|
||||
|
||||
def test_no_second_midnight_injection_once_date_updated():
|
||||
"""After a midnight update is persisted, the same-day path skips re-injection."""
|
||||
mw = _make_middleware()
|
||||
date_update_content = "<system-reminder>\n<current_date>2026-05-09, Saturday</current_date>\n</system-reminder>"
|
||||
state = {
|
||||
"messages": [
|
||||
_reminder_msg(
|
||||
"<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>",
|
||||
"msg-1",
|
||||
),
|
||||
HumanMessage(content="Hello", id="msg-1__user"),
|
||||
AIMessage(content="Response"),
|
||||
_reminder_msg(date_update_content, "msg-2"),
|
||||
HumanMessage(content="Good morning", id="msg-2__user"),
|
||||
AIMessage(content="Good morning!"),
|
||||
HumanMessage(content="Third turn", id="msg-3"),
|
||||
]
|
||||
}
|
||||
|
||||
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-09, Saturday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
assert result is None # same day as last injected date → no update
|
||||
@@ -50,7 +50,7 @@ def test_nginx_routes_official_langgraph_prefix_to_gateway_api():
|
||||
assert "/api/langgraph-compat" not in content
|
||||
assert "proxy_pass http://langgraph" not in content
|
||||
assert "rewrite ^/api/langgraph/(.*) /api/$1 break;" in content
|
||||
assert "proxy_pass http://gateway" in content
|
||||
assert "proxy_pass http://gateway" in content or "proxy_pass http://$gateway_upstream" in content
|
||||
|
||||
|
||||
def test_frontend_rewrites_langgraph_prefix_to_gateway():
|
||||
|
||||
@@ -324,6 +324,21 @@ def test_context_does_not_override_existing_configurable():
|
||||
assert config["configurable"]["subagent_enabled"] is True
|
||||
|
||||
|
||||
def test_inject_authenticated_user_context_overrides_client_user_id():
|
||||
"""Run context should carry the authenticated user, not client-supplied user_id."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.gateway.services import build_run_config, inject_authenticated_user_context
|
||||
|
||||
config = build_run_config("thread-1", None, None)
|
||||
config["context"] = {"user_id": "spoofed-client"}
|
||||
request = SimpleNamespace(state=SimpleNamespace(user=SimpleNamespace(id="auth-user-42")))
|
||||
|
||||
inject_authenticated_user_context(config, request)
|
||||
|
||||
assert config["context"]["user_id"] == "auth-user-42"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -8,17 +8,20 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.model_config import ModelConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.config.summarization_config import SummarizationConfig
|
||||
|
||||
|
||||
def _make_app_config(models: list[ModelConfig]) -> AppConfig:
|
||||
def _make_app_config(models: list[ModelConfig], loop_detection: LoopDetectionConfig | None = None) -> AppConfig:
|
||||
return AppConfig(
|
||||
models=models,
|
||||
sandbox=SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider"),
|
||||
loop_detection=loop_detection or LoopDetectionConfig(),
|
||||
)
|
||||
|
||||
|
||||
@@ -340,6 +343,59 @@ def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypa
|
||||
assert middlewares[0] == "base-middleware"
|
||||
|
||||
|
||||
def test_build_middlewares_uses_loop_detection_config(monkeypatch):
|
||||
app_config = _make_app_config(
|
||||
[_make_model("safe-model", supports_thinking=False)],
|
||||
loop_detection=LoopDetectionConfig(
|
||||
warn_threshold=7,
|
||||
hard_limit=9,
|
||||
window_size=30,
|
||||
max_tracked_threads=40,
|
||||
tool_freq_warn=50,
|
||||
tool_freq_hard_limit=60,
|
||||
),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
monkeypatch.setattr(lead_agent_module, "build_lead_runtime_middlewares", lambda *, app_config, lazy_init=True: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda *, app_config=None: None)
|
||||
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
|
||||
|
||||
middlewares = lead_agent_module._build_middlewares(
|
||||
{"configurable": {"is_plan_mode": False, "subagent_enabled": False}},
|
||||
model_name="safe-model",
|
||||
app_config=app_config,
|
||||
)
|
||||
|
||||
loop_detection = next(m for m in middlewares if isinstance(m, LoopDetectionMiddleware))
|
||||
assert loop_detection.warn_threshold == 7
|
||||
assert loop_detection.hard_limit == 9
|
||||
assert loop_detection.window_size == 30
|
||||
assert loop_detection.max_tracked_threads == 40
|
||||
assert loop_detection.tool_freq_warn == 50
|
||||
assert loop_detection.tool_freq_hard_limit == 60
|
||||
|
||||
|
||||
def test_build_middlewares_omits_loop_detection_when_disabled(monkeypatch):
|
||||
app_config = _make_app_config(
|
||||
[_make_model("safe-model", supports_thinking=False)],
|
||||
loop_detection=LoopDetectionConfig(enabled=False),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
monkeypatch.setattr(lead_agent_module, "build_lead_runtime_middlewares", lambda *, app_config, lazy_init=True: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda *, app_config=None: None)
|
||||
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
|
||||
|
||||
middlewares = lead_agent_module._build_middlewares(
|
||||
{"configurable": {"is_plan_mode": False, "subagent_enabled": False}},
|
||||
model_name="safe-model",
|
||||
app_config=app_config,
|
||||
)
|
||||
|
||||
assert not any(isinstance(m, LoopDetectionMiddleware) for m in middlewares)
|
||||
|
||||
|
||||
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
|
||||
app_config = _make_app_config([_make_model("model-masswork", supports_thinking=False)])
|
||||
app_config.summarization = SummarizationConfig(enabled=True, model_name="model-masswork")
|
||||
|
||||
@@ -1,22 +1,37 @@
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
|
||||
import anyio
|
||||
|
||||
from deerflow.agents.lead_agent import prompt as prompt_module
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.subagents_config import CustomSubagentConfig, SubagentsAppConfig
|
||||
from deerflow.skills.types import Skill
|
||||
from deerflow.skills.types import Skill, SkillCategory
|
||||
|
||||
|
||||
def _set_skills_cache_state(*, skills=None, active=False, version=0):
|
||||
prompt_module._get_cached_skills_prompt_section.cache_clear()
|
||||
with prompt_module._enabled_skills_lock:
|
||||
prompt_module._enabled_skills_cache = skills
|
||||
prompt_module._enabled_skills_by_config_cache.clear()
|
||||
prompt_module._enabled_skills_refresh_active = active
|
||||
prompt_module._enabled_skills_refresh_version = version
|
||||
prompt_module._enabled_skills_refresh_event.clear()
|
||||
|
||||
|
||||
def test_build_self_update_section_empty_for_default_agent():
|
||||
assert prompt_module._build_self_update_section(None) == ""
|
||||
|
||||
|
||||
def test_build_self_update_section_present_for_custom_agent():
|
||||
section = prompt_module._build_self_update_section("my-agent")
|
||||
|
||||
assert "<self_update>" in section
|
||||
assert "my-agent" in section
|
||||
assert "update_agent" in section
|
||||
|
||||
|
||||
def test_build_custom_mounts_section_returns_empty_when_no_mounts(monkeypatch):
|
||||
config = SimpleNamespace(sandbox=SimpleNamespace(mounts=[]))
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
@@ -220,7 +235,7 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
|
||||
skill_dir=skill_dir,
|
||||
skill_file=skill_dir / "SKILL.md",
|
||||
relative_path=skill_dir.relative_to(tmp_path),
|
||||
category="custom",
|
||||
category=SkillCategory.CUSTOM,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
@@ -240,6 +255,58 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
|
||||
_set_skills_cache_state()
|
||||
|
||||
|
||||
def test_explicit_config_enabled_skills_are_cached_by_config_identity(monkeypatch, tmp_path):
|
||||
def make_skill(name: str) -> Skill:
|
||||
skill_dir = tmp_path / name
|
||||
return Skill(
|
||||
name=name,
|
||||
description=f"Description for {name}",
|
||||
license="MIT",
|
||||
skill_dir=skill_dir,
|
||||
skill_file=skill_dir / "SKILL.md",
|
||||
relative_path=skill_dir.relative_to(tmp_path),
|
||||
category=SkillCategory.CUSTOM,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
config = cast(
|
||||
AppConfig,
|
||||
cast(
|
||||
object,
|
||||
SimpleNamespace(
|
||||
skills=SimpleNamespace(container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=False),
|
||||
),
|
||||
),
|
||||
)
|
||||
load_count = 0
|
||||
|
||||
def fake_get_or_new_skill_storage(**kwargs):
|
||||
nonlocal load_count
|
||||
assert kwargs == {"app_config": config}
|
||||
|
||||
def load_skills(*, enabled_only):
|
||||
nonlocal load_count
|
||||
load_count += 1
|
||||
assert enabled_only is True
|
||||
return [make_skill("cached-skill")]
|
||||
|
||||
return SimpleNamespace(load_skills=load_skills)
|
||||
|
||||
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", fake_get_or_new_skill_storage)
|
||||
_set_skills_cache_state()
|
||||
|
||||
try:
|
||||
first = prompt_module.get_skills_prompt_section(app_config=config)
|
||||
second = prompt_module.get_skills_prompt_section(app_config=config)
|
||||
|
||||
assert "cached-skill" in first
|
||||
assert "cached-skill" in second
|
||||
assert load_count == 1
|
||||
finally:
|
||||
_set_skills_cache_state()
|
||||
|
||||
|
||||
def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_path):
|
||||
started = threading.Event()
|
||||
release = threading.Event()
|
||||
@@ -257,7 +324,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa
|
||||
skill_dir=skill_dir,
|
||||
skill_file=skill_dir / "SKILL.md",
|
||||
relative_path=skill_dir.relative_to(tmp_path),
|
||||
category="custom",
|
||||
category=SkillCategory.CUSTOM,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -6,7 +6,12 @@ from deerflow.config.agents_config import AgentConfig
|
||||
from deerflow.skills.types import Skill
|
||||
|
||||
|
||||
def _make_skill(name: str) -> Skill:
|
||||
class NamedTool:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
|
||||
def _make_skill(name: str, allowed_tools: list[str] | None = None) -> Skill:
|
||||
return Skill(
|
||||
name=name,
|
||||
description=f"Description for {name}",
|
||||
@@ -15,6 +20,7 @@ def _make_skill(name: str) -> Skill:
|
||||
skill_file=Path(f"/tmp/{name}/SKILL.md"),
|
||||
relative_path=Path(name),
|
||||
category="public",
|
||||
allowed_tools=allowed_tools,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
@@ -132,6 +138,7 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_load_enabled_skills_for_tool_policy", lambda available_skills, *, app_config: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
|
||||
|
||||
@@ -164,3 +171,106 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["skill1"]))
|
||||
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
assert captured_skills[-1] == {"skill1"}
|
||||
|
||||
|
||||
def test_make_lead_agent_filters_tools_from_available_skills(monkeypatch):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "apply_prompt_template", lambda **kwargs: "mock_prompt")
|
||||
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["restricted", "legacy"]))
|
||||
monkeypatch.setattr(lead_agent_module, "_load_enabled_skills_for_tool_policy", lambda available_skills, *, app_config: [_make_skill("restricted", ["read_file"]), _make_skill("legacy", None)])
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")])
|
||||
|
||||
mock_app_config = MagicMock()
|
||||
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
|
||||
|
||||
agent_kwargs = lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
|
||||
assert [tool.name for tool in agent_kwargs["tools"]] == ["read_file"]
|
||||
|
||||
|
||||
def test_make_lead_agent_all_legacy_skills_preserve_all_tools(monkeypatch):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "apply_prompt_template", lambda **kwargs: "mock_prompt")
|
||||
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=None))
|
||||
monkeypatch.setattr(lead_agent_module, "_load_enabled_skills_for_tool_policy", lambda available_skills, *, app_config: [_make_skill("legacy", None)])
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [NamedTool("bash"), NamedTool("read_file")])
|
||||
|
||||
mock_app_config = MagicMock()
|
||||
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
|
||||
|
||||
agent_kwargs = lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
|
||||
assert [tool.name for tool in agent_kwargs["tools"]] == ["bash", "read_file", "update_agent"]
|
||||
|
||||
|
||||
def test_make_lead_agent_enforces_allowed_tools_when_skill_cache_is_cold(monkeypatch):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
from deerflow.agents.lead_agent import prompt as prompt_module
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "apply_prompt_template", lambda **kwargs: "mock_prompt")
|
||||
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["restricted"]))
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")])
|
||||
|
||||
mock_app_config = MagicMock()
|
||||
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
|
||||
mock_storage = SimpleNamespace(load_skills=lambda *, enabled_only: [_make_skill("restricted", ["read_file"])])
|
||||
|
||||
with prompt_module._enabled_skills_lock:
|
||||
prompt_module._enabled_skills_cache = None
|
||||
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", lambda app_config=None, **kwargs: mock_storage)
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
|
||||
|
||||
agent_kwargs = lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
|
||||
assert [tool.name for tool in agent_kwargs["tools"]] == ["read_file"]
|
||||
|
||||
|
||||
def test_make_lead_agent_fails_closed_when_skill_policy_load_fails(monkeypatch):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
from deerflow.agents.lead_agent import prompt as prompt_module
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
|
||||
create_agent_mock = MagicMock()
|
||||
monkeypatch.setattr(lead_agent_module, "create_agent", create_agent_mock)
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["restricted"]))
|
||||
|
||||
mock_app_config = MagicMock()
|
||||
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
|
||||
|
||||
def fail_storage(*args, **kwargs):
|
||||
raise RuntimeError("skill storage unavailable")
|
||||
|
||||
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", fail_storage)
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
|
||||
|
||||
with pytest.raises(RuntimeError, match="skill storage unavailable"):
|
||||
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
|
||||
create_agent_mock.assert_not_called()
|
||||
|
||||
@@ -105,6 +105,7 @@ def test_execute_command_uses_powershell_command_mode_on_windows(monkeypatch):
|
||||
"capture_output": True,
|
||||
"text": True,
|
||||
"timeout": 600,
|
||||
"env": None,
|
||||
},
|
||||
)
|
||||
]
|
||||
@@ -118,6 +119,7 @@ def test_execute_command_uses_posix_shell_command_mode_on_windows(monkeypatch):
|
||||
return SimpleNamespace(stdout="ok", stderr="", returncode=0)
|
||||
|
||||
monkeypatch.setattr(local_sandbox.os, "name", "nt")
|
||||
monkeypatch.setattr(local_sandbox.os, "environ", {"PATH": r"C:\Program Files\Git\bin"})
|
||||
monkeypatch.setattr(LocalSandbox, "_get_shell", staticmethod(lambda: r"C:\Program Files\Git\bin\sh.exe"))
|
||||
monkeypatch.setattr(local_sandbox.subprocess, "run", fake_run)
|
||||
|
||||
@@ -132,11 +134,33 @@ def test_execute_command_uses_posix_shell_command_mode_on_windows(monkeypatch):
|
||||
"capture_output": True,
|
||||
"text": True,
|
||||
"timeout": 600,
|
||||
"env": {
|
||||
"PATH": r"C:\Program Files\Git\bin",
|
||||
"MSYS_NO_PATHCONV": "1",
|
||||
"MSYS2_ARG_CONV_EXCL": "*",
|
||||
},
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def test_execute_command_does_not_set_msys_env_for_non_msys_posix_shell_on_windows(monkeypatch):
|
||||
calls: list[tuple[object, dict]] = []
|
||||
|
||||
def fake_run(*args, **kwargs):
|
||||
calls.append((args[0], kwargs))
|
||||
return SimpleNamespace(stdout="ok", stderr="", returncode=0)
|
||||
|
||||
monkeypatch.setattr(local_sandbox.os, "name", "nt")
|
||||
monkeypatch.setattr(LocalSandbox, "_get_shell", staticmethod(lambda: r"C:\tools\busybox\sh.exe"))
|
||||
monkeypatch.setattr(local_sandbox.subprocess, "run", fake_run)
|
||||
|
||||
output = LocalSandbox("t").execute_command("echo /mnt/skills/demo")
|
||||
|
||||
assert output == "ok"
|
||||
assert calls[0][1]["env"] is None
|
||||
|
||||
|
||||
def test_execute_command_uses_cmd_command_mode_on_windows(monkeypatch):
|
||||
calls: list[tuple[object, dict]] = []
|
||||
|
||||
@@ -159,6 +183,7 @@ def test_execute_command_uses_cmd_command_mode_on_windows(monkeypatch):
|
||||
"capture_output": True,
|
||||
"text": True,
|
||||
"timeout": 600,
|
||||
"env": None,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
"""Tests for loop detection configuration."""
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
||||
|
||||
|
||||
class TestLoopDetectionConfig:
|
||||
def test_defaults_match_middleware_defaults(self):
|
||||
config = LoopDetectionConfig()
|
||||
|
||||
assert config.enabled is True
|
||||
assert config.warn_threshold == 3
|
||||
assert config.hard_limit == 5
|
||||
assert config.window_size == 20
|
||||
assert config.max_tracked_threads == 100
|
||||
assert config.tool_freq_warn == 30
|
||||
assert config.tool_freq_hard_limit == 50
|
||||
|
||||
def test_accepts_custom_values(self):
|
||||
config = LoopDetectionConfig(
|
||||
enabled=False,
|
||||
warn_threshold=10,
|
||||
hard_limit=20,
|
||||
window_size=50,
|
||||
max_tracked_threads=200,
|
||||
tool_freq_warn=60,
|
||||
tool_freq_hard_limit=80,
|
||||
)
|
||||
|
||||
assert config.enabled is False
|
||||
assert config.warn_threshold == 10
|
||||
assert config.hard_limit == 20
|
||||
assert config.window_size == 50
|
||||
assert config.max_tracked_threads == 200
|
||||
assert config.tool_freq_warn == 60
|
||||
assert config.tool_freq_hard_limit == 80
|
||||
|
||||
def test_rejects_zero_thresholds(self):
|
||||
with pytest.raises(ValueError):
|
||||
LoopDetectionConfig(warn_threshold=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
LoopDetectionConfig(hard_limit=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
LoopDetectionConfig(tool_freq_warn=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
LoopDetectionConfig(tool_freq_hard_limit=0)
|
||||
|
||||
def test_rejects_hard_limit_below_warn_threshold(self):
|
||||
with pytest.raises(ValueError, match="hard_limit"):
|
||||
LoopDetectionConfig(warn_threshold=5, hard_limit=4)
|
||||
|
||||
def test_rejects_tool_freq_hard_limit_below_warn_threshold(self):
|
||||
with pytest.raises(ValueError, match="tool_freq_hard_limit"):
|
||||
LoopDetectionConfig(tool_freq_warn=5, tool_freq_hard_limit=4)
|
||||
|
||||
def test_tool_freq_override_valid(self):
|
||||
config = LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 150, "hard_limit": 300}})
|
||||
override = config.tool_freq_overrides["bash"]
|
||||
assert override.warn == 150
|
||||
assert override.hard_limit == 300
|
||||
|
||||
def test_tool_freq_override_rejects_zero_warn(self):
|
||||
with pytest.raises(ValueError):
|
||||
LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 0, "hard_limit": 10}})
|
||||
|
||||
def test_tool_freq_override_rejects_hard_limit_below_warn(self):
|
||||
with pytest.raises(ValueError, match="hard_limit"):
|
||||
LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 100, "hard_limit": 50}})
|
||||
@@ -3,7 +3,7 @@
|
||||
import copy
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.messages import AIMessage, SystemMessage
|
||||
|
||||
from deerflow.agents.middlewares.loop_detection_middleware import (
|
||||
_HARD_STOP_MSG,
|
||||
@@ -146,14 +146,42 @@ class TestLoopDetection:
|
||||
for _ in range(2):
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
|
||||
# Third identical call triggers warning
|
||||
# Third identical call triggers warning. The warning is appended to
|
||||
# the AIMessage content (tool_calls preserved) — never inserted as a
|
||||
# separate HumanMessage between the AIMessage(tool_calls) and its
|
||||
# ToolMessage responses, which would break OpenAI/Moonshot strict
|
||||
# tool-call pairing validation.
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is not None
|
||||
msgs = result["messages"]
|
||||
assert len(msgs) == 1
|
||||
assert isinstance(msgs[0], HumanMessage)
|
||||
assert isinstance(msgs[0], AIMessage)
|
||||
assert len(msgs[0].tool_calls) == len(call)
|
||||
assert msgs[0].tool_calls[0]["id"] == call[0]["id"]
|
||||
assert "LOOP DETECTED" in msgs[0].content
|
||||
|
||||
def test_warn_does_not_break_tool_call_pairing(self):
|
||||
"""Regression: the warn branch must NOT inject a non-tool message
|
||||
after an AIMessage(tool_calls=...). Moonshot/OpenAI reject the next
|
||||
request with 'tool_call_ids did not have response messages' if any
|
||||
non-tool message is wedged between the AIMessage and its ToolMessage
|
||||
responses. See #2029.
|
||||
"""
|
||||
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
|
||||
runtime = _make_runtime()
|
||||
call = [_bash_call("ls")]
|
||||
|
||||
for _ in range(2):
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is not None
|
||||
msgs = result["messages"]
|
||||
assert len(msgs) == 1
|
||||
assert isinstance(msgs[0], AIMessage)
|
||||
assert len(msgs[0].tool_calls) == len(call)
|
||||
assert msgs[0].tool_calls[0]["id"] == call[0]["id"]
|
||||
|
||||
def test_warn_only_injected_once(self):
|
||||
"""Warning for the same hash should only be injected once per thread."""
|
||||
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
|
||||
@@ -483,7 +511,11 @@ class TestToolFrequencyDetection:
|
||||
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_4.py")]), runtime)
|
||||
assert result is not None
|
||||
msg = result["messages"][0]
|
||||
assert isinstance(msg, HumanMessage)
|
||||
# Warning is appended to the AIMessage content; tool_calls preserved
|
||||
# so the tools node still runs and Moonshot/OpenAI tool-call pairing
|
||||
# validation does not break.
|
||||
assert isinstance(msg, AIMessage)
|
||||
assert msg.tool_calls
|
||||
assert "read_file" in msg.content
|
||||
assert "LOOP DETECTED" in msg.content
|
||||
|
||||
@@ -616,6 +648,37 @@ class TestToolFrequencyDetection:
|
||||
assert result is not None
|
||||
assert "read_file" in result["messages"][0].content
|
||||
|
||||
def test_override_tool_uses_override_thresholds(self):
|
||||
"""A tool in tool_freq_overrides uses its own thresholds, not the global ones."""
|
||||
mw = LoopDetectionMiddleware(
|
||||
tool_freq_warn=5,
|
||||
tool_freq_hard_limit=10,
|
||||
tool_freq_overrides={"bash": (50, 100)},
|
||||
)
|
||||
runtime = _make_runtime()
|
||||
|
||||
# 10 bash calls — would hit global hard_limit=10, but bash override is 100
|
||||
for i in range(10):
|
||||
result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime)
|
||||
assert result is None, f"unexpected trigger on call {i + 1}"
|
||||
|
||||
def test_non_override_tool_falls_back_to_global(self):
|
||||
"""A tool NOT in tool_freq_overrides uses the global warn/hard_limit."""
|
||||
mw = LoopDetectionMiddleware(
|
||||
tool_freq_warn=3,
|
||||
tool_freq_hard_limit=6,
|
||||
tool_freq_overrides={"bash": (50, 100)},
|
||||
)
|
||||
runtime = _make_runtime()
|
||||
|
||||
for i in range(2):
|
||||
mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime)
|
||||
|
||||
# 3rd read_file call hits global warn=3 (read_file has no override)
|
||||
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
|
||||
assert result is not None
|
||||
assert "read_file" in result["messages"][0].content
|
||||
|
||||
def test_hash_detection_takes_priority(self):
|
||||
"""Hash-based hard stop fires before frequency check for identical calls."""
|
||||
mw = LoopDetectionMiddleware(
|
||||
@@ -636,3 +699,48 @@ class TestToolFrequencyDetection:
|
||||
msg = result["messages"][0]
|
||||
assert isinstance(msg, AIMessage)
|
||||
assert _HARD_STOP_MSG in msg.content
|
||||
|
||||
|
||||
class TestFromConfig:
|
||||
"""Tests for LoopDetectionMiddleware.from_config — the sole validated construction path."""
|
||||
|
||||
@staticmethod
|
||||
def _config(**kwargs):
|
||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
||||
|
||||
return LoopDetectionConfig(**kwargs)
|
||||
|
||||
def test_scalar_fields_mapped(self):
|
||||
config = self._config(
|
||||
warn_threshold=4,
|
||||
hard_limit=8,
|
||||
window_size=15,
|
||||
max_tracked_threads=50,
|
||||
tool_freq_warn=20,
|
||||
tool_freq_hard_limit=40,
|
||||
)
|
||||
mw = LoopDetectionMiddleware.from_config(config)
|
||||
assert mw.warn_threshold == 4
|
||||
assert mw.hard_limit == 8
|
||||
assert mw.window_size == 15
|
||||
assert mw.max_tracked_threads == 50
|
||||
assert mw.tool_freq_warn == 20
|
||||
assert mw.tool_freq_hard_limit == 40
|
||||
|
||||
def test_overrides_converted_to_tuples(self):
|
||||
config = self._config(tool_freq_overrides={"bash": {"warn": 50, "hard_limit": 100}})
|
||||
mw = LoopDetectionMiddleware.from_config(config)
|
||||
assert mw._tool_freq_overrides == {"bash": (50, 100)}
|
||||
|
||||
def test_empty_overrides(self):
|
||||
mw = LoopDetectionMiddleware.from_config(self._config())
|
||||
assert mw._tool_freq_overrides == {}
|
||||
|
||||
def test_constructed_middleware_detects_loops(self):
|
||||
mw = LoopDetectionMiddleware.from_config(self._config(warn_threshold=2, hard_limit=4))
|
||||
runtime = _make_runtime()
|
||||
call = [_bash_call("ls")]
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is not None
|
||||
assert "LOOP DETECTED" in result["messages"][0].content
|
||||
|
||||
@@ -125,3 +125,68 @@ class TestMigrateMemory:
|
||||
from scripts.migrate_user_isolation import migrate_memory
|
||||
|
||||
migrate_memory(paths, user_id="default") # should not raise
|
||||
|
||||
|
||||
class TestMigrateAgents:
|
||||
@staticmethod
|
||||
def _seed_legacy_agent(paths: Paths, name: str, *, soul: str = "soul", description: str = "d") -> Path:
|
||||
legacy_dir = paths.agents_dir / name
|
||||
legacy_dir.mkdir(parents=True, exist_ok=True)
|
||||
(legacy_dir / "config.yaml").write_text(f"name: {name}\ndescription: {description}\n", encoding="utf-8")
|
||||
(legacy_dir / "SOUL.md").write_text(soul, encoding="utf-8")
|
||||
return legacy_dir
|
||||
|
||||
def test_moves_legacy_into_user_layout(self, base_dir: Path, paths: Paths):
|
||||
self._seed_legacy_agent(paths, "agent-a", soul="soul-a")
|
||||
self._seed_legacy_agent(paths, "agent-b", soul="soul-b")
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_agents
|
||||
|
||||
report = migrate_agents(paths, user_id="default")
|
||||
|
||||
assert {entry["agent"] for entry in report} == {"agent-a", "agent-b"}
|
||||
for entry in report:
|
||||
assert entry["user_id"] == "default"
|
||||
assert "moved -> " in entry["action"]
|
||||
|
||||
for name, soul in [("agent-a", "soul-a"), ("agent-b", "soul-b")]:
|
||||
dest = paths.user_agent_dir("default", name)
|
||||
assert dest.exists(), f"{name} should have moved into the per-user layout"
|
||||
assert (dest / "SOUL.md").read_text() == soul
|
||||
|
||||
# Legacy agents/ root is cleaned up once empty.
|
||||
assert not paths.agents_dir.exists()
|
||||
|
||||
def test_dry_run_does_not_move(self, base_dir: Path, paths: Paths):
|
||||
legacy_dir = self._seed_legacy_agent(paths, "agent-a")
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_agents
|
||||
|
||||
report = migrate_agents(paths, user_id="default", dry_run=True)
|
||||
|
||||
assert len(report) == 1
|
||||
assert legacy_dir.exists(), "dry-run must not touch the filesystem"
|
||||
assert not paths.user_agent_dir("default", "agent-a").exists()
|
||||
|
||||
def test_existing_destination_is_treated_as_conflict(self, base_dir: Path, paths: Paths):
|
||||
self._seed_legacy_agent(paths, "agent-a", soul="legacy soul")
|
||||
dest = paths.user_agent_dir("default", "agent-a")
|
||||
dest.mkdir(parents=True)
|
||||
(dest / "SOUL.md").write_text("preexisting", encoding="utf-8")
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_agents
|
||||
|
||||
report = migrate_agents(paths, user_id="default")
|
||||
|
||||
assert report[0]["action"].startswith("conflict -> ")
|
||||
# Per-user destination must be left untouched.
|
||||
assert (dest / "SOUL.md").read_text() == "preexisting"
|
||||
# Legacy copy lands under migration-conflicts/agents/.
|
||||
conflicts_dir = paths.base_dir / "migration-conflicts" / "agents" / "agent-a"
|
||||
assert (conflicts_dir / "SOUL.md").read_text() == "legacy soul"
|
||||
|
||||
def test_no_legacy_dir_is_noop(self, base_dir: Path, paths: Paths):
|
||||
from scripts.migrate_user_isolation import migrate_agents
|
||||
|
||||
report = migrate_agents(paths, user_id="default")
|
||||
assert report == []
|
||||
|
||||
@@ -50,6 +50,21 @@ class TestUserAgentMemoryFile:
|
||||
assert paths.user_agent_memory_file("bob", "MyAgent") == expected
|
||||
|
||||
|
||||
class TestUserAgentDir:
|
||||
def test_user_agents_dir(self, paths: Paths):
|
||||
assert paths.user_agents_dir("alice") == paths.base_dir / "users" / "alice" / "agents"
|
||||
|
||||
def test_user_agent_dir(self, paths: Paths):
|
||||
assert paths.user_agent_dir("alice", "code-reviewer") == paths.base_dir / "users" / "alice" / "agents" / "code-reviewer"
|
||||
|
||||
def test_user_agent_dir_lowercases_name(self, paths: Paths):
|
||||
assert paths.user_agent_dir("alice", "CodeReviewer") == paths.base_dir / "users" / "alice" / "agents" / "codereviewer"
|
||||
|
||||
def test_user_agent_dir_validates_user_id(self, paths: Paths):
|
||||
with pytest.raises(ValueError, match="Invalid user_id"):
|
||||
paths.user_agent_dir("../escape", "myagent")
|
||||
|
||||
|
||||
class TestUserThreadDir:
|
||||
def test_user_thread_dir(self, paths: Paths):
|
||||
expected = paths.base_dir / "users" / "u1" / "threads" / "t1"
|
||||
|
||||
@@ -8,7 +8,9 @@ Tests:
|
||||
5. Postgres missing-dep error message
|
||||
"""
|
||||
|
||||
import sys
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -221,13 +223,8 @@ class TestEngineLifecycle:
|
||||
"""If asyncpg is not installed, error message tells user what to do."""
|
||||
from deerflow.persistence.engine import init_engine
|
||||
|
||||
try:
|
||||
import asyncpg # noqa: F401
|
||||
|
||||
pytest.skip("asyncpg is installed -- cannot test missing-dep path")
|
||||
except ImportError:
|
||||
# asyncpg is not installed — this is the expected state for this test.
|
||||
# We proceed to verify that init_engine raises an actionable ImportError.
|
||||
pass # noqa: S110 — intentionally ignored
|
||||
with pytest.raises(ImportError, match="uv sync --extra postgres"):
|
||||
with (
|
||||
patch.dict(sys.modules, {"asyncpg": None}),
|
||||
pytest.raises(ImportError, match="uv sync --all-packages --extra postgres"),
|
||||
):
|
||||
await init_engine("postgres", url="postgresql+asyncpg://x:x@localhost/x")
|
||||
|
||||
@@ -0,0 +1,293 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from deerflow.community.aio_sandbox.remote_backend import RemoteSandboxBackend
|
||||
from deerflow.community.aio_sandbox.sandbox_info import SandboxInfo
|
||||
|
||||
|
||||
class _StubResponse:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
status_code: int = 200,
|
||||
payload: object | None = None,
|
||||
json_exc: Exception | None = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self._payload = {} if payload is None else payload
|
||||
self._json_exc = json_exc
|
||||
self.ok = 200 <= status_code < 400
|
||||
self.text = ""
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
if self.status_code >= 400:
|
||||
raise requests.HTTPError(f"HTTP {self.status_code}")
|
||||
|
||||
def json(self) -> object:
|
||||
if self._json_exc is not None:
|
||||
raise self._json_exc
|
||||
return self._payload
|
||||
|
||||
|
||||
def test_list_running_delegates_to_provisioner_list(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
sandbox_info = SandboxInfo(sandbox_id="test-id", sandbox_url="http://localhost:8080")
|
||||
|
||||
def mock_list():
|
||||
return [sandbox_info]
|
||||
|
||||
monkeypatch.setattr(backend, "_provisioner_list", mock_list)
|
||||
|
||||
assert backend.list_running() == [sandbox_info]
|
||||
|
||||
|
||||
def test_provisioner_list_returns_sandbox_infos_and_filters_invalid_entries(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get(url: str, timeout: int):
|
||||
assert url == "http://provisioner:8002/api/sandboxes"
|
||||
assert timeout == 10
|
||||
return _StubResponse(
|
||||
payload={
|
||||
"sandboxes": [
|
||||
{"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"},
|
||||
{"sandbox_id": "missing-url"},
|
||||
{"sandbox_url": "http://k3s:31002"},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
|
||||
infos = backend._provisioner_list()
|
||||
assert len(infos) == 1
|
||||
assert infos[0].sandbox_id == "abc123"
|
||||
assert infos[0].sandbox_url == "http://k3s:31001"
|
||||
|
||||
|
||||
def test_provisioner_list_returns_empty_on_request_exception(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get(url: str, timeout: int):
|
||||
raise requests.RequestException("network down")
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
|
||||
assert backend._provisioner_list() == []
|
||||
|
||||
|
||||
def test_provisioner_list_returns_empty_when_payload_is_not_dict(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get(url: str, timeout: int):
|
||||
return _StubResponse(payload=[{"sandbox_id": "abc", "sandbox_url": "http://k3s:31001"}])
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
|
||||
assert backend._provisioner_list() == []
|
||||
|
||||
|
||||
def test_provisioner_list_returns_empty_when_sandboxes_is_not_list(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get(url: str, timeout: int):
|
||||
return _StubResponse(payload={"sandboxes": {"sandbox_id": "abc"}})
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
|
||||
assert backend._provisioner_list() == []
|
||||
|
||||
|
||||
def test_provisioner_list_skips_non_dict_sandbox_entries(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get(url: str, timeout: int):
|
||||
return _StubResponse(
|
||||
payload={
|
||||
"sandboxes": [
|
||||
{"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"},
|
||||
"bad-entry",
|
||||
123,
|
||||
None,
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
|
||||
infos = backend._provisioner_list()
|
||||
assert len(infos) == 1
|
||||
assert infos[0].sandbox_id == "abc123"
|
||||
assert infos[0].sandbox_url == "http://k3s:31001"
|
||||
|
||||
|
||||
def test_create_delegates_to_provisioner_create(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
expected = SandboxInfo(sandbox_id="abc123", sandbox_url="http://k3s:31001")
|
||||
|
||||
def mock_create(thread_id: str, sandbox_id: str, extra_mounts=None):
|
||||
assert thread_id == "thread-1"
|
||||
assert sandbox_id == "abc123"
|
||||
assert extra_mounts == [("/host", "/container", False)]
|
||||
return expected
|
||||
|
||||
monkeypatch.setattr(backend, "_provisioner_create", mock_create)
|
||||
|
||||
result = backend.create("thread-1", "abc123", extra_mounts=[("/host", "/container", False)])
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_provisioner_create_returns_sandbox_info(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_post(url: str, json: dict, timeout: int):
|
||||
assert url == "http://provisioner:8002/api/sandboxes"
|
||||
assert json == {"sandbox_id": "abc123", "thread_id": "thread-1"}
|
||||
assert timeout == 30
|
||||
return _StubResponse(payload={"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"})
|
||||
|
||||
monkeypatch.setattr(requests, "post", mock_post)
|
||||
|
||||
info = backend._provisioner_create("thread-1", "abc123")
|
||||
assert info.sandbox_id == "abc123"
|
||||
assert info.sandbox_url == "http://k3s:31001"
|
||||
|
||||
|
||||
def test_provisioner_create_raises_runtime_error_on_request_exception(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_post(url: str, json: dict, timeout: int):
|
||||
raise requests.RequestException("boom")
|
||||
|
||||
monkeypatch.setattr(requests, "post", mock_post)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Provisioner create failed"):
|
||||
backend._provisioner_create("thread-1", "abc123")
|
||||
|
||||
|
||||
def test_destroy_delegates_to_provisioner_destroy(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
called: list[str] = []
|
||||
|
||||
def mock_destroy(sandbox_id: str):
|
||||
called.append(sandbox_id)
|
||||
|
||||
monkeypatch.setattr(backend, "_provisioner_destroy", mock_destroy)
|
||||
|
||||
backend.destroy(SandboxInfo(sandbox_id="abc123", sandbox_url="http://k3s:31001"))
|
||||
assert called == ["abc123"]
|
||||
|
||||
|
||||
def test_provisioner_destroy_calls_delete(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_delete(url: str, timeout: int):
|
||||
assert url == "http://provisioner:8002/api/sandboxes/abc123"
|
||||
assert timeout == 15
|
||||
return _StubResponse(status_code=200)
|
||||
|
||||
monkeypatch.setattr(requests, "delete", mock_delete)
|
||||
|
||||
backend._provisioner_destroy("abc123")
|
||||
|
||||
|
||||
def test_provisioner_destroy_swallows_request_exception(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_delete(url: str, timeout: int):
|
||||
raise requests.RequestException("network down")
|
||||
|
||||
monkeypatch.setattr(requests, "delete", mock_delete)
|
||||
|
||||
backend._provisioner_destroy("abc123")
|
||||
|
||||
|
||||
def test_is_alive_delegates_to_provisioner_is_alive(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_is_alive(sandbox_id: str):
|
||||
assert sandbox_id == "abc123"
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(backend, "_provisioner_is_alive", mock_is_alive)
|
||||
|
||||
alive = backend.is_alive(SandboxInfo(sandbox_id="abc123", sandbox_url="http://k3s:31001"))
|
||||
assert alive is True
|
||||
|
||||
|
||||
def test_provisioner_is_alive_true_only_when_status_running(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get_running(url: str, timeout: int):
|
||||
return _StubResponse(payload={"status": "Running"})
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get_running)
|
||||
assert backend._provisioner_is_alive("abc123") is True
|
||||
|
||||
def mock_get_pending(url: str, timeout: int):
|
||||
return _StubResponse(payload={"status": "Pending"})
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get_pending)
|
||||
assert backend._provisioner_is_alive("abc123") is False
|
||||
|
||||
|
||||
def test_provisioner_is_alive_returns_false_on_request_exception(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get(url: str, timeout: int):
|
||||
raise requests.RequestException("boom")
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
assert backend._provisioner_is_alive("abc123") is False
|
||||
|
||||
|
||||
def test_discover_delegates_to_provisioner_discover(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
expected = SandboxInfo(sandbox_id="abc123", sandbox_url="http://k3s:31001")
|
||||
|
||||
def mock_discover(sandbox_id: str):
|
||||
assert sandbox_id == "abc123"
|
||||
return expected
|
||||
|
||||
monkeypatch.setattr(backend, "_provisioner_discover", mock_discover)
|
||||
|
||||
result = backend.discover("abc123")
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_provisioner_discover_returns_none_on_404(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get(url: str, timeout: int):
|
||||
return _StubResponse(status_code=404)
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
|
||||
assert backend._provisioner_discover("abc123") is None
|
||||
|
||||
|
||||
def test_provisioner_discover_returns_info_on_success(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get(url: str, timeout: int):
|
||||
return _StubResponse(payload={"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"})
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
|
||||
info = backend._provisioner_discover("abc123")
|
||||
assert info is not None
|
||||
assert info.sandbox_id == "abc123"
|
||||
assert info.sandbox_url == "http://k3s:31001"
|
||||
|
||||
|
||||
def test_provisioner_discover_returns_none_on_request_exception(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get(url: str, timeout: int):
|
||||
raise requests.RequestException("boom")
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
|
||||
assert backend._provisioner_discover("abc123") is None
|
||||
@@ -310,6 +310,28 @@ class TestDbRunEventStore:
|
||||
|
||||
await close_engine()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_structured_content_round_trips(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
s = DbRunEventStore(get_session_factory())
|
||||
|
||||
content = [{"type": "text", "text": "hello"}, {"type": "image_url", "image_url": {"url": "https://example.test/a.png"}}]
|
||||
record = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content=content)
|
||||
|
||||
assert record["content"] == content
|
||||
assert record["metadata"]["content_is_json"] is True
|
||||
assert "content_is_dict" not in record["metadata"]
|
||||
|
||||
messages = await s.list_messages("t1")
|
||||
assert messages[0]["content"] == content
|
||||
assert messages[0]["metadata"]["content_is_json"] is True
|
||||
|
||||
await close_engine()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_pagination(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
@@ -373,6 +395,55 @@ class TestDbRunEventStore:
|
||||
assert seqs == list(range(1, 51))
|
||||
await close_engine()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_put_batch_accepts_structured_content(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
s = DbRunEventStore(get_session_factory())
|
||||
|
||||
content = [{"messages": [{"type": "ai", "content": ""}]}]
|
||||
results = await s.put_batch(
|
||||
[
|
||||
{
|
||||
"thread_id": "t1",
|
||||
"run_id": "r1",
|
||||
"event_type": "run.end",
|
||||
"category": "outputs",
|
||||
"content": content,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
assert results[0]["content"] == content
|
||||
assert results[0]["metadata"]["content_is_json"] is True
|
||||
|
||||
events = await s.list_events("t1", "r1")
|
||||
assert events[0]["content"] == content
|
||||
assert events[0]["metadata"]["content_is_json"] is True
|
||||
|
||||
await close_engine()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_dict_content_keeps_legacy_metadata_flag(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
s = DbRunEventStore(get_session_factory())
|
||||
|
||||
content = {"status": "success"}
|
||||
record = await s.put(thread_id="t1", run_id="r1", event_type="run.end", category="outputs", content=content)
|
||||
|
||||
assert record["content"] == content
|
||||
assert record["metadata"]["content_is_json"] is True
|
||||
assert record["metadata"]["content_is_dict"] is True
|
||||
|
||||
await close_engine()
|
||||
|
||||
|
||||
# -- Factory tests --
|
||||
|
||||
|
||||
@@ -166,6 +166,61 @@ class TestRunRepository:
|
||||
assert row["total_tokens"] == 100
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("success-run", thread_id="t1", status="running")
|
||||
await repo.update_run_completion(
|
||||
"success-run",
|
||||
status="success",
|
||||
total_input_tokens=70,
|
||||
total_output_tokens=30,
|
||||
total_tokens=100,
|
||||
lead_agent_tokens=80,
|
||||
subagent_tokens=15,
|
||||
middleware_tokens=5,
|
||||
)
|
||||
await repo.put("error-run", thread_id="t1", status="running")
|
||||
await repo.update_run_completion(
|
||||
"error-run",
|
||||
status="error",
|
||||
total_input_tokens=20,
|
||||
total_output_tokens=30,
|
||||
total_tokens=50,
|
||||
lead_agent_tokens=40,
|
||||
subagent_tokens=10,
|
||||
)
|
||||
await repo.put("running-run", thread_id="t1", status="running")
|
||||
await repo.update_run_completion(
|
||||
"running-run",
|
||||
status="running",
|
||||
total_input_tokens=900,
|
||||
total_output_tokens=99,
|
||||
total_tokens=999,
|
||||
lead_agent_tokens=999,
|
||||
)
|
||||
await repo.put("other-thread-run", thread_id="t2", status="running")
|
||||
await repo.update_run_completion(
|
||||
"other-thread-run",
|
||||
status="success",
|
||||
total_tokens=888,
|
||||
lead_agent_tokens=888,
|
||||
)
|
||||
|
||||
agg = await repo.aggregate_tokens_by_thread("t1")
|
||||
|
||||
assert agg["total_tokens"] == 150
|
||||
assert agg["total_input_tokens"] == 90
|
||||
assert agg["total_output_tokens"] == 60
|
||||
assert agg["total_runs"] == 2
|
||||
assert agg["by_model"] == {"unknown": {"tokens": 150, "runs": 2}}
|
||||
assert agg["by_caller"] == {
|
||||
"lead_agent": 120,
|
||||
"subagent": 25,
|
||||
"middleware": 5,
|
||||
}
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_ordered_desc(self, tmp_path):
|
||||
"""list_by_thread returns newest first."""
|
||||
|
||||
@@ -6,6 +6,8 @@ from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.tools.builtins.setup_agent_tool import setup_agent
|
||||
|
||||
# --- Helpers ---
|
||||
@@ -27,6 +29,7 @@ def _make_paths_mock(tmp_path: Path):
|
||||
paths = MagicMock()
|
||||
paths.base_dir = tmp_path
|
||||
paths.agent_dir = lambda name: tmp_path / "agents" / name
|
||||
paths.user_agent_dir = lambda user_id, name: tmp_path / "users" / user_id / "agents" / name
|
||||
return paths
|
||||
|
||||
|
||||
@@ -54,7 +57,7 @@ def test_setup_agent_rejects_invalid_agent_name_before_writing(tmp_path, monkeyp
|
||||
messages = result.update["messages"]
|
||||
assert len(messages) == 1
|
||||
assert "Invalid agent name" in messages[0].content
|
||||
assert not (tmp_path / "agents").exists()
|
||||
assert not (tmp_path / "users" / "test-user-autouse" / "agents").exists()
|
||||
assert not (outside_dir / "evil" / "SOUL.md").exists()
|
||||
|
||||
|
||||
@@ -68,7 +71,7 @@ def test_setup_agent_rejects_absolute_agent_name_before_writing(tmp_path, monkey
|
||||
messages = result.update["messages"]
|
||||
assert len(messages) == 1
|
||||
assert "Invalid agent name" in messages[0].content
|
||||
assert not (tmp_path / "agents").exists()
|
||||
assert not (tmp_path / "users" / "test-user-autouse" / "agents").exists()
|
||||
assert not (Path(absolute_agent) / "SOUL.md").exists()
|
||||
|
||||
|
||||
@@ -81,10 +84,10 @@ class TestSetupAgentNoDataLoss:
|
||||
def test_existing_agent_dir_preserved_on_failure(self, tmp_path: Path):
|
||||
"""If the agent directory already exists and setup fails,
|
||||
the directory and its contents must NOT be deleted."""
|
||||
agent_dir = tmp_path / "agents" / "test-agent"
|
||||
agent_dir = tmp_path / "users" / "test-user-autouse" / "agents" / "test-agent"
|
||||
agent_dir.mkdir(parents=True)
|
||||
old_soul = agent_dir / "SOUL.md"
|
||||
old_soul.write_text("original soul content")
|
||||
old_soul.write_text("original soul content", encoding="utf-8")
|
||||
|
||||
with patch("deerflow.tools.builtins.setup_agent_tool.get_paths", return_value=_make_paths_mock(tmp_path)):
|
||||
# Force soul_file.write_text to raise after directory already exists
|
||||
@@ -103,7 +106,7 @@ class TestSetupAgentNoDataLoss:
|
||||
def test_new_agent_dir_cleaned_up_on_failure(self, tmp_path: Path):
|
||||
"""If the agent directory is newly created and setup fails,
|
||||
the directory should be cleaned up."""
|
||||
agent_dir = tmp_path / "agents" / "test-agent"
|
||||
agent_dir = tmp_path / "users" / "test-user-autouse" / "agents" / "test-agent"
|
||||
assert not agent_dir.exists()
|
||||
|
||||
with patch("deerflow.tools.builtins.setup_agent_tool.get_paths", return_value=_make_paths_mock(tmp_path)):
|
||||
@@ -121,7 +124,27 @@ class TestSetupAgentNoDataLoss:
|
||||
"""Happy path: setup_agent creates config.yaml and SOUL.md."""
|
||||
_call_setup_agent(tmp_path, soul="# My Agent", description="A test agent")
|
||||
|
||||
agent_dir = tmp_path / "agents" / "test-agent"
|
||||
agent_dir = tmp_path / "users" / "test-user-autouse" / "agents" / "test-agent"
|
||||
assert agent_dir.exists()
|
||||
assert (agent_dir / "SOUL.md").read_text() == "# My Agent"
|
||||
assert (agent_dir / "config.yaml").exists()
|
||||
|
||||
@pytest.mark.no_auto_user
|
||||
def test_runtime_user_id_used_when_contextvar_missing(self, tmp_path: Path):
|
||||
"""setup_agent should not fall back to default when runtime carries user_id."""
|
||||
runtime = _DummyRuntime(
|
||||
context={"agent_name": "test-agent", "user_id": "auth-user-42"},
|
||||
tool_call_id="tool-3",
|
||||
)
|
||||
|
||||
with patch("deerflow.tools.builtins.setup_agent_tool.get_paths", return_value=_make_paths_mock(tmp_path)):
|
||||
setup_agent.func(
|
||||
soul="# My Agent",
|
||||
description="A test agent",
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
expected_dir = tmp_path / "users" / "auth-user-42" / "agents" / "test-agent"
|
||||
default_dir = tmp_path / "users" / "default" / "agents" / "test-agent"
|
||||
assert (expected_dir / "SOUL.md").read_text() == "# My Agent"
|
||||
assert not default_dir.exists()
|
||||
|
||||
@@ -313,7 +313,7 @@ class TestWriteConfigYaml:
|
||||
{
|
||||
"config_version": 5,
|
||||
"log_level": "info",
|
||||
"token_usage": {"enabled": False},
|
||||
"token_usage": {"enabled": True},
|
||||
"tool_groups": [{"name": "web"}, {"name": "file:read"}, {"name": "file:write"}, {"name": "bash"}],
|
||||
"tools": [
|
||||
{
|
||||
@@ -361,7 +361,7 @@ class TestWriteConfigYaml:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
assert data["log_level"] == "info"
|
||||
assert data["token_usage"]["enabled"] is False
|
||||
assert data["token_usage"]["enabled"] is True
|
||||
assert data["tool_groups"][0]["name"] == "web"
|
||||
assert data["summarization"]["max_tokens"] == 2048
|
||||
assert any(tool["name"] == "image_search" and tool["max_results"] == 5 for tool in data["tools"])
|
||||
|
||||
@@ -86,6 +86,33 @@ def test_parse_license_field(tmp_path):
|
||||
assert skill.license == "MIT"
|
||||
|
||||
|
||||
def test_parse_missing_allowed_tools_returns_none(tmp_path):
|
||||
skill_file = _write_skill(tmp_path, "name: my-skill\ndescription: Test")
|
||||
skill = parse_skill_file(skill_file, category="custom")
|
||||
assert skill is not None
|
||||
assert skill.allowed_tools is None
|
||||
|
||||
|
||||
def test_parse_allowed_tools_list(tmp_path):
|
||||
skill_file = _write_skill(tmp_path, 'name: my-skill\ndescription: Test\nallowed-tools: ["bash", "read_file"]')
|
||||
skill = parse_skill_file(skill_file, category="custom")
|
||||
assert skill is not None
|
||||
assert skill.allowed_tools == ["bash", "read_file"]
|
||||
|
||||
|
||||
def test_parse_empty_allowed_tools_list(tmp_path):
|
||||
skill_file = _write_skill(tmp_path, "name: my-skill\ndescription: Test\nallowed-tools: []")
|
||||
skill = parse_skill_file(skill_file, category="custom")
|
||||
assert skill is not None
|
||||
assert skill.allowed_tools == []
|
||||
|
||||
|
||||
def test_parse_invalid_allowed_tools_returns_none(tmp_path):
|
||||
skill_file = _write_skill(tmp_path, "name: my-skill\ndescription: Test\nallowed-tools: bash")
|
||||
skill = parse_skill_file(skill_file, category="custom")
|
||||
assert skill is None
|
||||
|
||||
|
||||
def test_parse_missing_name_returns_none(tmp_path):
|
||||
"""Skills missing a name field are rejected."""
|
||||
skill_file = _write_skill(tmp_path, "description: A test skill")
|
||||
|
||||
@@ -30,13 +30,47 @@ class TestValidateSkillFrontmatter:
|
||||
def test_valid_with_all_allowed_fields(self, tmp_path):
|
||||
skill_dir = _write_skill(
|
||||
tmp_path,
|
||||
"---\nname: my-skill\ndescription: A skill\nlicense: MIT\nversion: '1.0'\nauthor: test\n---\n\nBody\n",
|
||||
"---\nname: my-skill\ndescription: A skill\nlicense: MIT\nversion: '1.0'\nauthor: test\nallowed-tools: [bash, read_file]\n---\n\nBody\n",
|
||||
)
|
||||
valid, msg, name = _validate_skill_frontmatter(skill_dir)
|
||||
assert valid is True
|
||||
assert msg == "Skill is valid!"
|
||||
assert name == "my-skill"
|
||||
|
||||
def test_allows_empty_allowed_tools(self, tmp_path):
|
||||
skill_dir = _write_skill(
|
||||
tmp_path,
|
||||
"---\nname: my-skill\ndescription: A skill\nallowed-tools: []\n---\n\nBody\n",
|
||||
)
|
||||
valid, msg, name = _validate_skill_frontmatter(skill_dir)
|
||||
assert valid is True
|
||||
assert msg == "Skill is valid!"
|
||||
assert name == "my-skill"
|
||||
|
||||
def test_rejects_allowed_tools_string(self, tmp_path):
|
||||
skill_dir = _write_skill(
|
||||
tmp_path,
|
||||
"---\nname: my-skill\ndescription: A skill\nallowed-tools: bash\n---\n\nBody\n",
|
||||
)
|
||||
valid, msg, name = _validate_skill_frontmatter(skill_dir)
|
||||
assert valid is False
|
||||
assert "allowed-tools" in msg
|
||||
assert str(tmp_path) not in msg
|
||||
assert "SKILL.md" in msg
|
||||
assert name is None
|
||||
|
||||
def test_rejects_allowed_tools_non_string_entry(self, tmp_path):
|
||||
skill_dir = _write_skill(
|
||||
tmp_path,
|
||||
"---\nname: my-skill\ndescription: A skill\nallowed-tools: [bash, 1]\n---\n\nBody\n",
|
||||
)
|
||||
valid, msg, name = _validate_skill_frontmatter(skill_dir)
|
||||
assert valid is False
|
||||
assert "allowed-tools" in msg
|
||||
assert str(tmp_path) not in msg
|
||||
assert "SKILL.md" in msg
|
||||
assert name is None
|
||||
|
||||
def test_missing_skill_md(self, tmp_path):
|
||||
valid, msg, name = _validate_skill_frontmatter(tmp_path)
|
||||
assert valid is False
|
||||
|
||||
@@ -17,11 +17,14 @@ import asyncio
|
||||
import sys
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.skills.types import Skill
|
||||
|
||||
# Module names that need to be mocked to break circular imports
|
||||
_MOCKED_MODULE_NAMES = [
|
||||
"deerflow.agents",
|
||||
@@ -32,14 +35,15 @@ _MOCKED_MODULE_NAMES = [
|
||||
"deerflow.sandbox.middleware",
|
||||
"deerflow.sandbox.security",
|
||||
"deerflow.models",
|
||||
"deerflow.skills.storage",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup_executor_classes():
|
||||
"""Set up mocked modules and import real executor classes.
|
||||
|
||||
This fixture runs once per session and yields the executor classes.
|
||||
This fixture runs once per test and yields the executor classes.
|
||||
It handles module cleanup to avoid affecting other test files.
|
||||
"""
|
||||
# Save original modules
|
||||
@@ -53,6 +57,9 @@ def _setup_executor_classes():
|
||||
# Set up mocks
|
||||
for name in _MOCKED_MODULE_NAMES:
|
||||
sys.modules[name] = MagicMock()
|
||||
storage_module = ModuleType("deerflow.skills.storage")
|
||||
storage_module.get_or_new_skill_storage = lambda **kwargs: SimpleNamespace(load_skills=lambda *, enabled_only: [])
|
||||
sys.modules["deerflow.skills.storage"] = storage_module
|
||||
|
||||
# Import real classes inside fixture
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
@@ -117,6 +124,26 @@ class MockAIMessage:
|
||||
return msg
|
||||
|
||||
|
||||
class NamedTool:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
|
||||
def _skill(name: str, allowed_tools: list[str] | None) -> Skill:
|
||||
skill_dir = Path(f"/tmp/{name}")
|
||||
return Skill(
|
||||
name=name,
|
||||
description=f"{name} skill",
|
||||
license=None,
|
||||
skill_dir=skill_dir,
|
||||
skill_file=skill_dir / "SKILL.md",
|
||||
relative_path=Path(name),
|
||||
category="custom",
|
||||
allowed_tools=allowed_tools,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
|
||||
async def async_iterator(items):
|
||||
"""Helper to create an async iterator from a list."""
|
||||
for item in items:
|
||||
@@ -288,7 +315,7 @@ class TestAgentConstruction:
|
||||
captured["app_config"] = app_config
|
||||
return SimpleNamespace(load_skills=lambda *, enabled_only: [SimpleNamespace(name="demo-skill", skill_file=skill_file)])
|
||||
|
||||
monkeypatch.setattr("deerflow.skills.storage.get_or_new_skill_storage", fake_get_or_new_skill_storage)
|
||||
monkeypatch.setattr(sys.modules["deerflow.skills.storage"], "get_or_new_skill_storage", fake_get_or_new_skill_storage)
|
||||
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
@@ -297,7 +324,8 @@ class TestAgentConstruction:
|
||||
thread_id="test-thread",
|
||||
)
|
||||
|
||||
messages = await executor._load_skill_messages()
|
||||
skills = await executor._load_skills()
|
||||
messages = await executor._load_skill_messages(skills)
|
||||
|
||||
assert captured["app_config"] is app_config
|
||||
assert len(messages) == 1
|
||||
@@ -487,6 +515,115 @@ class TestAsyncExecutionPath:
|
||||
assert "Task" in result.result
|
||||
|
||||
|
||||
class TestSkillAllowedTools:
|
||||
@pytest.mark.anyio
|
||||
async def test_skill_allowed_tools_union_filters_agent_tools(self, classes, base_config, mock_agent, msg):
|
||||
SubagentExecutor = classes["SubagentExecutor"]
|
||||
|
||||
final_state = {"messages": [msg.human("Task"), msg.ai("Done", "msg-1")]}
|
||||
mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
|
||||
tools = [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")]
|
||||
executor = SubagentExecutor(config=base_config, tools=tools, thread_id="test-thread")
|
||||
|
||||
async def load_skills():
|
||||
return [_skill("a", ["bash"]), _skill("b", ["read_file"])]
|
||||
|
||||
with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock:
|
||||
await executor._aexecute("Task")
|
||||
|
||||
create_agent_mock.assert_called_once()
|
||||
assert [tool.name for tool in create_agent_mock.call_args.args[0]] == ["bash", "read_file"]
|
||||
assert [tool.name for tool in executor.tools] == ["bash", "read_file", "web_search"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_all_missing_allowed_tools_preserves_legacy_allow_all(self, classes, base_config, mock_agent, msg):
|
||||
SubagentExecutor = classes["SubagentExecutor"]
|
||||
|
||||
final_state = {"messages": [msg.human("Task"), msg.ai("Done", "msg-1")]}
|
||||
mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
|
||||
tools = [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")]
|
||||
executor = SubagentExecutor(config=base_config, tools=tools, thread_id="test-thread")
|
||||
|
||||
async def load_skills():
|
||||
return [_skill("legacy-a", None), _skill("legacy-b", None)]
|
||||
|
||||
with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock:
|
||||
await executor._aexecute("Task")
|
||||
|
||||
assert [tool.name for tool in create_agent_mock.call_args.args[0]] == ["bash", "read_file", "web_search"]
|
||||
assert [tool.name for tool in executor.tools] == ["bash", "read_file", "web_search"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_mixed_missing_allowed_tools_does_not_disable_explicit_restrictions(self, classes, base_config, mock_agent, msg):
|
||||
SubagentExecutor = classes["SubagentExecutor"]
|
||||
|
||||
final_state = {"messages": [msg.human("Task"), msg.ai("Done", "msg-1")]}
|
||||
mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
|
||||
tools = [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")]
|
||||
executor = SubagentExecutor(config=base_config, tools=tools, thread_id="test-thread")
|
||||
|
||||
async def load_skills():
|
||||
return [_skill("legacy", None), _skill("restricted", ["bash"])]
|
||||
|
||||
with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock:
|
||||
await executor._aexecute("Task")
|
||||
|
||||
assert [tool.name for tool in create_agent_mock.call_args.args[0]] == ["bash"]
|
||||
assert [tool.name for tool in executor.tools] == ["bash", "read_file", "web_search"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_mixed_missing_allowed_tools_order_does_not_disable_explicit_restrictions(self, classes, base_config, mock_agent, msg):
|
||||
SubagentExecutor = classes["SubagentExecutor"]
|
||||
|
||||
final_state = {"messages": [msg.human("Task"), msg.ai("Done", "msg-1")]}
|
||||
mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
|
||||
tools = [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")]
|
||||
executor = SubagentExecutor(config=base_config, tools=tools, thread_id="test-thread")
|
||||
|
||||
async def load_skills():
|
||||
return [_skill("restricted", ["bash"]), _skill("legacy", None)]
|
||||
|
||||
with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock:
|
||||
await executor._aexecute("Task")
|
||||
|
||||
assert [tool.name for tool in create_agent_mock.call_args.args[0]] == ["bash"]
|
||||
assert [tool.name for tool in executor.tools] == ["bash", "read_file", "web_search"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_empty_allowed_tools_contributes_no_tools(self, classes, base_config, mock_agent, msg, caplog):
|
||||
SubagentExecutor = classes["SubagentExecutor"]
|
||||
|
||||
final_state = {"messages": [msg.human("Task"), msg.ai("Done", "msg-1")]}
|
||||
mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
|
||||
tools = [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")]
|
||||
executor = SubagentExecutor(config=base_config, tools=tools, thread_id="test-thread")
|
||||
|
||||
async def load_skills():
|
||||
return [_skill("empty", []), _skill("reader", ["read_file"])]
|
||||
|
||||
with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock, caplog.at_level("INFO"):
|
||||
await executor._aexecute("Task")
|
||||
|
||||
assert [tool.name for tool in create_agent_mock.call_args.args[0]] == ["read_file"]
|
||||
assert [tool.name for tool in executor.tools] == ["bash", "read_file", "web_search"]
|
||||
assert "declared empty allowed-tools" in caplog.text
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_skill_load_failure_fails_without_creating_agent(self, classes, base_config, mock_agent):
|
||||
SubagentExecutor = classes["SubagentExecutor"]
|
||||
executor = SubagentExecutor(config=base_config, tools=[NamedTool("bash")], thread_id="test-thread")
|
||||
|
||||
async def load_skills():
|
||||
raise RuntimeError("skill storage unavailable")
|
||||
|
||||
with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock:
|
||||
result = await executor._aexecute("Task")
|
||||
|
||||
assert result.status == classes["SubagentStatus"].FAILED
|
||||
assert result.error == "skill storage unavailable"
|
||||
create_agent_mock.assert_not_called()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Sync Execution Path Tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@@ -27,6 +27,14 @@ def _other_call(name="bash", call_id="call_other"):
|
||||
return {"name": name, "id": call_id, "args": {}}
|
||||
|
||||
|
||||
def _raw_tool_call(call_id: str, name: str = "task") -> dict:
|
||||
return {
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {"name": name, "arguments": "{}"},
|
||||
}
|
||||
|
||||
|
||||
class TestClampSubagentLimit:
|
||||
def test_below_min_clamped_to_min(self):
|
||||
assert _clamp_subagent_limit(0) == MIN_SUBAGENT_LIMIT
|
||||
@@ -117,6 +125,23 @@ class TestTruncateTaskCalls:
|
||||
task_calls = [tc for tc in updated_msg.tool_calls if tc["name"] == "task"]
|
||||
assert len(task_calls) == 2
|
||||
|
||||
def test_truncation_syncs_raw_provider_tool_calls(self):
|
||||
mw = SubagentLimitMiddleware(max_concurrent=2)
|
||||
msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[_task_call("t1"), _task_call("t2"), _task_call("t3"), _task_call("t4")],
|
||||
additional_kwargs={"tool_calls": [_raw_tool_call("t1"), _raw_tool_call("t2"), _raw_tool_call("t3"), _raw_tool_call("t4")]},
|
||||
response_metadata={"finish_reason": "tool_calls"},
|
||||
)
|
||||
|
||||
result = mw._truncate_task_calls({"messages": [msg]})
|
||||
|
||||
assert result is not None
|
||||
updated_msg = result["messages"][0]
|
||||
assert [tc["id"] for tc in updated_msg.tool_calls] == ["t1", "t2"]
|
||||
assert [tc["id"] for tc in updated_msg.additional_kwargs["tool_calls"]] == ["t1", "t2"]
|
||||
assert updated_msg.response_metadata["finish_reason"] == "tool_calls"
|
||||
|
||||
def test_only_non_task_calls_returns_none(self):
|
||||
mw = SubagentLimitMiddleware()
|
||||
msg = AIMessage(
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage
|
||||
|
||||
from deerflow.agents.memory.summarization_hook import memory_flush_hook
|
||||
from deerflow.agents.middlewares.dynamic_context_middleware import _DYNAMIC_CONTEXT_REMINDER_KEY, DynamicContextMiddleware
|
||||
from deerflow.agents.middlewares.summarization_middleware import DeerFlowSummarizationMiddleware, SummarizationEvent
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
@@ -20,6 +22,14 @@ def _messages() -> list:
|
||||
]
|
||||
|
||||
|
||||
def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage:
|
||||
return HumanMessage(
|
||||
content="<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>",
|
||||
id=msg_id,
|
||||
additional_kwargs={"hide_from_ui": True, _DYNAMIC_CONTEXT_REMINDER_KEY: True},
|
||||
)
|
||||
|
||||
|
||||
def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace:
|
||||
context = {}
|
||||
if thread_id is not None:
|
||||
@@ -75,6 +85,14 @@ def _skill_conversation() -> list:
|
||||
]
|
||||
|
||||
|
||||
def _raw_tool_call(tool_id: str, name: str = "read_file") -> dict:
|
||||
return {
|
||||
"id": tool_id,
|
||||
"type": "function",
|
||||
"function": {"name": name, "arguments": "{}"},
|
||||
}
|
||||
|
||||
|
||||
def test_before_summarization_hook_receives_messages_before_compression() -> None:
|
||||
captured: list[SummarizationEvent] = []
|
||||
middleware = _middleware(before_summarization=[captured.append])
|
||||
@@ -90,6 +108,38 @@ def test_before_summarization_hook_receives_messages_before_compression() -> Non
|
||||
assert result["messages"][1].content.startswith("Here is a summary")
|
||||
|
||||
|
||||
def test_dynamic_context_reminder_is_preserved_across_summarization() -> None:
|
||||
captured: list[SummarizationEvent] = []
|
||||
middleware = _middleware(before_summarization=[captured.append])
|
||||
reminder = _dynamic_context_reminder()
|
||||
|
||||
result = middleware.before_model(
|
||||
{
|
||||
"messages": [
|
||||
reminder,
|
||||
HumanMessage(content="user-1"),
|
||||
AIMessage(content="assistant-1"),
|
||||
HumanMessage(content="user-2"),
|
||||
]
|
||||
},
|
||||
_runtime(),
|
||||
)
|
||||
|
||||
assert len(captured) == 1
|
||||
assert [message.content for message in captured[0].messages_to_summarize] == ["user-1"]
|
||||
assert captured[0].preserved_messages[0] is reminder
|
||||
|
||||
emitted = result["messages"]
|
||||
assert isinstance(emitted[0], RemoveMessage)
|
||||
assert emitted[1].name == "summary"
|
||||
assert emitted[2] is reminder
|
||||
|
||||
followup_state = {"messages": [*emitted[1:], HumanMessage(content="Follow-up", id="msg-2")]}
|
||||
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
|
||||
assert DynamicContextMiddleware().before_agent(followup_state, _runtime()) is None
|
||||
|
||||
|
||||
def test_before_summarization_hook_not_called_when_threshold_not_met() -> None:
|
||||
captured: list[SummarizationEvent] = []
|
||||
middleware = _middleware(before_summarization=[captured.append], trigger=("messages", 10))
|
||||
@@ -413,6 +463,47 @@ def test_skill_rescue_does_not_preserve_non_skill_outputs_from_mixed_tool_calls(
|
||||
assert any(isinstance(m, ToolMessage) and m.content == "user notes" for m in summarized)
|
||||
|
||||
|
||||
def test_skill_rescue_syncs_raw_provider_tool_calls_on_split_ai_messages() -> None:
|
||||
captured: list[SummarizationEvent] = []
|
||||
middleware = _middleware(
|
||||
before_summarization=[captured.append],
|
||||
trigger=("messages", 4),
|
||||
keep=("messages", 2),
|
||||
preserve_recent_skill_count=5,
|
||||
preserve_recent_skill_tokens=10_000,
|
||||
preserve_recent_skill_tokens_per_skill=10_000,
|
||||
)
|
||||
|
||||
messages = [
|
||||
HumanMessage(content="u1"),
|
||||
AIMessage(
|
||||
content="reading skill and notes",
|
||||
tool_calls=[
|
||||
_skill_read_call("skill-1", "alpha"),
|
||||
{"name": "read_file", "id": "file-1", "args": {"path": "/mnt/user-data/workspace/notes.md"}},
|
||||
],
|
||||
additional_kwargs={"tool_calls": [_raw_tool_call("skill-1"), _raw_tool_call("file-1")]},
|
||||
),
|
||||
ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
|
||||
ToolMessage(content="user notes", tool_call_id="file-1"),
|
||||
HumanMessage(content="u2"),
|
||||
AIMessage(content="done"),
|
||||
]
|
||||
|
||||
middleware.before_model({"messages": messages}, _runtime())
|
||||
|
||||
preserved = captured[0].preserved_messages
|
||||
summarized = captured[0].messages_to_summarize
|
||||
|
||||
preserved_ai = next(m for m in preserved if isinstance(m, AIMessage) and m.tool_calls)
|
||||
summarized_ai = next(m for m in summarized if isinstance(m, AIMessage) and m.tool_calls)
|
||||
|
||||
assert [tc["id"] for tc in preserved_ai.tool_calls] == ["skill-1"]
|
||||
assert [tc["id"] for tc in preserved_ai.additional_kwargs["tool_calls"]] == ["skill-1"]
|
||||
assert [tc["id"] for tc in summarized_ai.tool_calls] == ["file-1"]
|
||||
assert [tc["id"] for tc in summarized_ai.additional_kwargs["tool_calls"]] == ["file-1"]
|
||||
|
||||
|
||||
def test_skill_rescue_clears_content_on_rescued_ai_clone() -> None:
|
||||
captured: list[SummarizationEvent] = []
|
||||
middleware = _middleware(
|
||||
@@ -451,6 +542,42 @@ def test_skill_rescue_clears_content_on_rescued_ai_clone() -> None:
|
||||
assert summarized_ai.content == "reading skill and notes"
|
||||
|
||||
|
||||
def test_skill_rescue_removes_raw_provider_tool_calls_from_content_only_summary_clone() -> None:
|
||||
captured: list[SummarizationEvent] = []
|
||||
middleware = _middleware(
|
||||
before_summarization=[captured.append],
|
||||
trigger=("messages", 4),
|
||||
keep=("messages", 2),
|
||||
preserve_recent_skill_count=5,
|
||||
preserve_recent_skill_tokens=10_000,
|
||||
preserve_recent_skill_tokens_per_skill=10_000,
|
||||
)
|
||||
|
||||
messages = [
|
||||
HumanMessage(content="u1"),
|
||||
AIMessage(
|
||||
content="reading skill",
|
||||
tool_calls=[_skill_read_call("skill-1", "alpha")],
|
||||
additional_kwargs={"tool_calls": [_raw_tool_call("skill-1")], "function_call": {"name": "read_file"}},
|
||||
response_metadata={"finish_reason": "tool_calls"},
|
||||
),
|
||||
ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
|
||||
HumanMessage(content="u2"),
|
||||
AIMessage(content="done"),
|
||||
]
|
||||
|
||||
middleware.before_model({"messages": messages}, _runtime())
|
||||
|
||||
summarized = captured[0].messages_to_summarize
|
||||
summarized_ai = next(m for m in summarized if isinstance(m, AIMessage))
|
||||
|
||||
assert summarized_ai.content == "reading skill"
|
||||
assert summarized_ai.tool_calls == []
|
||||
assert "tool_calls" not in summarized_ai.additional_kwargs
|
||||
assert "function_call" not in summarized_ai.additional_kwargs
|
||||
assert summarized_ai.response_metadata["finish_reason"] == "stop"
|
||||
|
||||
|
||||
def test_skill_rescue_only_preserves_skill_calls_with_matched_tool_results() -> None:
|
||||
captured: list[SummarizationEvent] = []
|
||||
middleware = _middleware(
|
||||
|
||||
@@ -221,7 +221,6 @@ def test_task_tool_emits_running_and_completed_events(monkeypatch):
|
||||
prompt="collect diagnostics",
|
||||
subagent_type="general-purpose",
|
||||
tool_call_id="tc-123",
|
||||
max_turns=7,
|
||||
)
|
||||
|
||||
assert output == "Task Succeeded. Result: all done"
|
||||
@@ -229,7 +228,7 @@ def test_task_tool_emits_running_and_completed_events(monkeypatch):
|
||||
assert captured["task_id"] == "tc-123"
|
||||
assert captured["executor_kwargs"]["thread_id"] == "thread-1"
|
||||
assert captured["executor_kwargs"]["parent_model"] == "ark-model"
|
||||
assert captured["executor_kwargs"]["config"].max_turns == 7
|
||||
assert captured["executor_kwargs"]["config"].max_turns == config.max_turns
|
||||
# Skills are no longer appended to system_prompt; they are loaded per-session
|
||||
# by SubagentExecutor and injected as conversation items (Codex pattern).
|
||||
assert captured["executor_kwargs"]["config"].system_prompt == "Base system prompt"
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Tests for thread-level token usage aggregation API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from _router_auth_helpers import make_authed_test_app
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.routers import thread_runs
|
||||
|
||||
|
||||
def _make_app(run_store: MagicMock):
|
||||
app = make_authed_test_app()
|
||||
app.include_router(thread_runs.router)
|
||||
app.state.run_store = run_store
|
||||
return app
|
||||
|
||||
|
||||
def test_thread_token_usage_returns_stable_shape():
|
||||
run_store = MagicMock()
|
||||
run_store.aggregate_tokens_by_thread = AsyncMock(
|
||||
return_value={
|
||||
"total_tokens": 150,
|
||||
"total_input_tokens": 90,
|
||||
"total_output_tokens": 60,
|
||||
"total_runs": 2,
|
||||
"by_model": {"unknown": {"tokens": 150, "runs": 2}},
|
||||
"by_caller": {
|
||||
"lead_agent": 120,
|
||||
"subagent": 25,
|
||||
"middleware": 5,
|
||||
},
|
||||
},
|
||||
)
|
||||
app = _make_app(run_store)
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/threads/thread-1/token-usage")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"thread_id": "thread-1",
|
||||
"total_tokens": 150,
|
||||
"total_input_tokens": 90,
|
||||
"total_output_tokens": 60,
|
||||
"total_runs": 2,
|
||||
"by_model": {"unknown": {"tokens": 150, "runs": 2}},
|
||||
"by_caller": {
|
||||
"lead_agent": 120,
|
||||
"subagent": 25,
|
||||
"middleware": 5,
|
||||
},
|
||||
}
|
||||
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1")
|
||||
@@ -7,6 +7,7 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from deerflow.agents.middlewares import title_middleware as title_middleware_module
|
||||
from deerflow.agents.middlewares.dynamic_context_middleware import _DYNAMIC_CONTEXT_REMINDER_KEY
|
||||
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
||||
from deerflow.config.title_config import TitleConfig, get_title_config, set_title_config
|
||||
|
||||
@@ -44,6 +45,22 @@ class TestTitleMiddlewareCoreLogic:
|
||||
|
||||
assert middleware._should_generate_title(state) is True
|
||||
|
||||
def test_should_generate_title_with_dynamic_context_reminder(self):
|
||||
_set_test_title_config(enabled=True)
|
||||
middleware = TitleMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(
|
||||
content="<system-reminder>\n<memory>User prefers Python.</memory>\n</system-reminder>",
|
||||
additional_kwargs={_DYNAMIC_CONTEXT_REMINDER_KEY: True},
|
||||
),
|
||||
HumanMessage(content="帮我总结这段代码"),
|
||||
AIMessage(content="好的,我先看结构"),
|
||||
]
|
||||
}
|
||||
|
||||
assert middleware._should_generate_title(state) is True
|
||||
|
||||
def test_should_not_generate_title_when_disabled_or_already_set(self):
|
||||
middleware = TitleMiddleware()
|
||||
|
||||
@@ -243,6 +260,25 @@ class TestTitleMiddlewareCoreLogic:
|
||||
prompt, _ = middleware._build_title_prompt(state)
|
||||
assert "<think>" not in prompt
|
||||
|
||||
def test_build_title_prompt_uses_real_user_message_with_dynamic_context_reminder(self):
|
||||
_set_test_title_config(enabled=True)
|
||||
middleware = TitleMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(
|
||||
content="<system-reminder>\n<memory>User prefers Python.</memory>\n</system-reminder>",
|
||||
additional_kwargs={_DYNAMIC_CONTEXT_REMINDER_KEY: True},
|
||||
),
|
||||
HumanMessage(content="请帮我写测试"),
|
||||
AIMessage(content="好的"),
|
||||
]
|
||||
}
|
||||
|
||||
prompt, user_msg = middleware._build_title_prompt(state)
|
||||
assert user_msg == "请帮我写测试"
|
||||
assert "<system-reminder>" not in prompt
|
||||
assert "User prefers Python" not in prompt
|
||||
|
||||
def test_generate_title_async_strips_think_tags_in_response(self, monkeypatch):
|
||||
"""Async title generation strips <think> blocks from the model response."""
|
||||
_set_test_title_config(max_chars=50)
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
from deerflow.config.token_usage_config import TokenUsageConfig
|
||||
|
||||
|
||||
def test_token_usage_enabled_by_default():
|
||||
assert TokenUsageConfig().enabled is True
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for TokenUsageMiddleware attribution annotations."""
|
||||
|
||||
import logging
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
@@ -17,6 +18,82 @@ def _make_runtime():
|
||||
|
||||
|
||||
class TestTokenUsageMiddleware:
|
||||
def test_logs_cache_token_details(self, caplog):
|
||||
middleware = TokenUsageMiddleware()
|
||||
message = AIMessage(
|
||||
content="Here is the final answer.",
|
||||
usage_metadata={
|
||||
"input_tokens": 350,
|
||||
"output_tokens": 240,
|
||||
"total_tokens": 590,
|
||||
"input_token_details": {
|
||||
"audio": 10,
|
||||
"cache_creation": 200,
|
||||
"cache_read": 100,
|
||||
},
|
||||
"output_token_details": {
|
||||
"audio": 10,
|
||||
"reasoning": 200,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
with caplog.at_level(
|
||||
logging.INFO,
|
||||
logger="deerflow.agents.middlewares.token_usage_middleware",
|
||||
):
|
||||
result = middleware.after_model({"messages": [message]}, _make_runtime())
|
||||
|
||||
assert result is not None
|
||||
assert "LLM token usage: input=350 output=240 total=590" in caplog.text
|
||||
assert "input_token_details={'audio': 10, 'cache_creation': 200, 'cache_read': 100}" in caplog.text
|
||||
assert "output_token_details={'audio': 10, 'reasoning': 200}" in caplog.text
|
||||
|
||||
def test_logs_basic_tokens_when_no_detail_fields_in_usage_metadata(self, caplog):
|
||||
"""When usage_metadata has only totals (no input_token_details), log just the counts."""
|
||||
middleware = TokenUsageMiddleware()
|
||||
message = AIMessage(
|
||||
content="Here is the final answer.",
|
||||
usage_metadata={
|
||||
"input_tokens": 350,
|
||||
"output_tokens": 240,
|
||||
"total_tokens": 590,
|
||||
},
|
||||
)
|
||||
|
||||
with caplog.at_level(
|
||||
logging.INFO,
|
||||
logger="deerflow.agents.middlewares.token_usage_middleware",
|
||||
):
|
||||
result = middleware.after_model({"messages": [message]}, _make_runtime())
|
||||
|
||||
assert result is not None
|
||||
assert "LLM token usage: input=350 output=240 total=590" in caplog.text
|
||||
assert "input_token_details" not in caplog.text
|
||||
|
||||
def test_no_log_when_usage_metadata_is_missing(self, caplog):
|
||||
"""When usage_metadata is absent, no token usage line is logged."""
|
||||
middleware = TokenUsageMiddleware()
|
||||
message = AIMessage(
|
||||
content="Here is the final answer.",
|
||||
response_metadata={
|
||||
"usage": {
|
||||
"input_tokens": 350,
|
||||
"output_tokens": 240,
|
||||
"total_tokens": 590,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
with caplog.at_level(
|
||||
logging.INFO,
|
||||
logger="deerflow.agents.middlewares.token_usage_middleware",
|
||||
):
|
||||
result = middleware.after_model({"messages": [message]}, _make_runtime())
|
||||
|
||||
assert result is not None
|
||||
assert "LLM token usage" not in caplog.text
|
||||
|
||||
def test_annotates_todo_updates_with_structured_actions(self):
|
||||
middleware = TokenUsageMiddleware()
|
||||
message = AIMessage(
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Regression test: tool args schemas must not emit Pydantic serialization warnings.
|
||||
|
||||
DeerFlow tools annotate their runtime parameter as ``Runtime``
|
||||
(``deerflow.tools.types.Runtime`` = ``ToolRuntime[dict[str, Any], ThreadState]``)
|
||||
so the LangChain tool framework injects the runtime automatically.
|
||||
When the inner ``Runtime.context`` field is left as the unbound ``ContextT``
|
||||
TypeVar (default ``None``), Pydantic's ``model_dump()`` on the auto-generated
|
||||
args schema emits a ``PydanticSerializationUnexpectedValue`` warning on every
|
||||
tool call because the actual context DeerFlow installs is a dict. Using the
|
||||
``Runtime`` alias (which binds the context to ``dict[str, Any]``) keeps
|
||||
Pydantic's serialization expectations aligned with reality.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from deerflow.sandbox.tools import (
|
||||
bash_tool,
|
||||
glob_tool,
|
||||
grep_tool,
|
||||
ls_tool,
|
||||
read_file_tool,
|
||||
str_replace_tool,
|
||||
write_file_tool,
|
||||
)
|
||||
from deerflow.tools.builtins.present_file_tool import present_file_tool
|
||||
from deerflow.tools.builtins.setup_agent_tool import setup_agent
|
||||
from deerflow.tools.builtins.task_tool import task_tool
|
||||
from deerflow.tools.builtins.update_agent_tool import update_agent
|
||||
from deerflow.tools.builtins.view_image_tool import view_image_tool
|
||||
from deerflow.tools.skill_manage_tool import skill_manage_tool
|
||||
|
||||
|
||||
def _make_runtime(context: dict) -> ToolRuntime:
|
||||
return ToolRuntime(
|
||||
state={"sandbox": {"sandbox_id": "local"}, "thread_data": {}},
|
||||
context=context,
|
||||
config={"configurable": {"thread_id": context.get("thread_id", "thread-1")}},
|
||||
stream_writer=lambda _: None,
|
||||
tools=[],
|
||||
tool_call_id="call-1",
|
||||
store=None,
|
||||
)
|
||||
|
||||
|
||||
_TOOL_CASES = [
|
||||
(bash_tool, {"description": "list", "command": "ls"}),
|
||||
(ls_tool, {"description": "list", "path": "/tmp"}),
|
||||
(glob_tool, {"description": "find", "pattern": "*.py", "path": "/tmp"}),
|
||||
(grep_tool, {"description": "search", "pattern": "x", "path": "/tmp"}),
|
||||
(read_file_tool, {"description": "read", "path": "/tmp/x"}),
|
||||
(write_file_tool, {"description": "write", "path": "/tmp/x", "content": "hi"}),
|
||||
(str_replace_tool, {"description": "replace", "path": "/tmp/x", "old_str": "a", "new_str": "b"}),
|
||||
(present_file_tool, {"filepaths": ["/tmp/x"], "tool_call_id": "call-1"}),
|
||||
(view_image_tool, {"image_path": "/tmp/img.png", "tool_call_id": "call-1"}),
|
||||
(task_tool, {"description": "do", "prompt": "go", "subagent_type": "general-purpose", "tool_call_id": "call-1"}),
|
||||
(skill_manage_tool, {"action": "list", "name": "demo"}),
|
||||
(setup_agent, {"soul": "s", "description": "d"}),
|
||||
(update_agent, {}),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("tool_obj", "extra_args"),
|
||||
_TOOL_CASES,
|
||||
ids=[case[0].name for case in _TOOL_CASES],
|
||||
)
|
||||
def test_tool_args_schema_does_not_emit_pydantic_context_warning(tool_obj, extra_args) -> None:
|
||||
"""``model_dump()`` of the auto-generated args_schema must not warn about ``context``.
|
||||
|
||||
The model_dump path is hit by LangChain's ``BaseTool._parse_input`` on every tool
|
||||
invocation (see langchain_core/tools/base.py:712), so any warning here would fire
|
||||
once per tool call and pollute production logs.
|
||||
"""
|
||||
schema = tool_obj.args_schema
|
||||
assert schema is not None, f"{tool_obj.name} has no args_schema"
|
||||
|
||||
runtime_obj = _make_runtime({"thread_id": "thread-1", "sandbox_id": "local"})
|
||||
payload = {**extra_args, "runtime": runtime_obj}
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
warnings.simplefilter("always")
|
||||
validated = schema.model_validate(payload)
|
||||
validated.model_dump()
|
||||
|
||||
pydantic_warnings = [w for w in caught if "PydanticSerializationUnexpectedValue" in str(w.message)]
|
||||
assert not pydantic_warnings, f"{tool_obj.name} args_schema.model_dump() emitted Pydantic context serialization warnings: {[str(w.message) for w in pydantic_warnings]}"
|
||||
@@ -0,0 +1,310 @@
|
||||
"""Tests for update_agent tool — partial updates, atomic writes, and validation.
|
||||
|
||||
Resolves issue #2616: a custom agent must be able to persist updates to its
|
||||
own SOUL.md / config.yaml from inside a normal chat (not only from bootstrap).
|
||||
|
||||
The tool writes per-user (``{base_dir}/users/{user_id}/agents/{name}/``) so
|
||||
that one user's update cannot mutate another user's agent.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from deerflow.config.agents_config import AgentConfig
|
||||
from deerflow.tools.builtins.update_agent_tool import update_agent
|
||||
|
||||
DEFAULT_USER = "test-user-autouse" # matches the autouse fixture in tests/conftest.py
|
||||
|
||||
|
||||
class _DummyRuntime(SimpleNamespace):
|
||||
context: dict
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
def _runtime(agent_name: str | None = "test-agent", tool_call_id: str = "call_1") -> _DummyRuntime:
|
||||
return _DummyRuntime(context={"agent_name": agent_name} if agent_name is not None else {}, tool_call_id=tool_call_id)
|
||||
|
||||
|
||||
def _make_paths_mock(tmp_path: Path) -> MagicMock:
|
||||
paths = MagicMock()
|
||||
paths.base_dir = tmp_path
|
||||
paths.agent_dir = lambda name: tmp_path / "agents" / name
|
||||
paths.agents_dir = tmp_path / "agents"
|
||||
paths.user_agent_dir = lambda user_id, name: tmp_path / "users" / user_id / "agents" / name
|
||||
paths.user_agents_dir = lambda user_id: tmp_path / "users" / user_id / "agents"
|
||||
return paths
|
||||
|
||||
|
||||
def _user_agent_dir(tmp_path: Path, name: str = "test-agent", user_id: str = DEFAULT_USER) -> Path:
|
||||
return tmp_path / "users" / user_id / "agents" / name
|
||||
|
||||
|
||||
def _seed_agent(
|
||||
tmp_path: Path,
|
||||
name: str = "test-agent",
|
||||
*,
|
||||
description: str = "old desc",
|
||||
soul: str = "old soul",
|
||||
skills: list[str] | None = None,
|
||||
user_id: str = DEFAULT_USER,
|
||||
) -> Path:
|
||||
"""Create a baseline agent dir with config.yaml and SOUL.md for tests to mutate."""
|
||||
agent_dir = _user_agent_dir(tmp_path, name, user_id=user_id)
|
||||
agent_dir.mkdir(parents=True, exist_ok=True)
|
||||
cfg: dict = {"name": name, "description": description}
|
||||
if skills is not None:
|
||||
cfg["skills"] = skills
|
||||
(agent_dir / "config.yaml").write_text(yaml.safe_dump(cfg, sort_keys=False), encoding="utf-8")
|
||||
(agent_dir / "SOUL.md").write_text(soul, encoding="utf-8")
|
||||
return agent_dir
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def patched_paths(tmp_path: Path):
|
||||
paths_mock = _make_paths_mock(tmp_path)
|
||||
with patch("deerflow.tools.builtins.update_agent_tool.get_paths", return_value=paths_mock):
|
||||
# load_agent_config also calls get_paths(); patch the same target it uses.
|
||||
with patch("deerflow.config.agents_config.get_paths", return_value=paths_mock):
|
||||
yield paths_mock
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def stub_app_config():
|
||||
"""Stub get_app_config so model validation accepts only known names."""
|
||||
fake = MagicMock()
|
||||
fake.get_model_config.side_effect = lambda name: object() if name in {"gpt-known", "m1"} else None
|
||||
with patch("deerflow.tools.builtins.update_agent_tool.get_app_config", return_value=fake):
|
||||
yield fake
|
||||
|
||||
|
||||
# --- Validation tests ---
|
||||
|
||||
|
||||
def test_update_agent_rejects_missing_agent_name(patched_paths):
|
||||
result = update_agent.func(runtime=_runtime(agent_name=None), soul="new soul")
|
||||
|
||||
msg = result.update["messages"][0]
|
||||
assert "only available inside a custom agent's chat" in msg.content
|
||||
|
||||
|
||||
def test_update_agent_rejects_invalid_agent_name(patched_paths):
|
||||
result = update_agent.func(runtime=_runtime(agent_name="../../etc/passwd"), soul="x")
|
||||
|
||||
msg = result.update["messages"][0]
|
||||
assert "Invalid agent name" in msg.content
|
||||
|
||||
|
||||
def test_update_agent_rejects_unknown_agent(tmp_path, patched_paths):
|
||||
result = update_agent.func(runtime=_runtime(agent_name="ghost"), soul="x")
|
||||
|
||||
msg = result.update["messages"][0]
|
||||
assert "does not exist" in msg.content
|
||||
assert not _user_agent_dir(tmp_path, "ghost").exists()
|
||||
|
||||
|
||||
def test_update_agent_requires_at_least_one_field(tmp_path, patched_paths):
|
||||
_seed_agent(tmp_path)
|
||||
|
||||
result = update_agent.func(runtime=_runtime())
|
||||
|
||||
msg = result.update["messages"][0]
|
||||
assert "No fields provided" in msg.content
|
||||
|
||||
|
||||
def test_update_agent_rejects_unknown_model(tmp_path, patched_paths, stub_app_config):
|
||||
"""Copilot review: model must be validated against configured models before
|
||||
being persisted; otherwise _resolve_model_name silently falls back to the
|
||||
default and the user gets repeated warnings on every later turn."""
|
||||
_seed_agent(tmp_path)
|
||||
|
||||
result = update_agent.func(runtime=_runtime(), model="not-in-config")
|
||||
|
||||
msg = result.update["messages"][0]
|
||||
assert "Unknown model" in msg.content
|
||||
cfg = yaml.safe_load((_user_agent_dir(tmp_path) / "config.yaml").read_text())
|
||||
assert "model" not in cfg, "Invalid model must not have been written to config.yaml"
|
||||
|
||||
|
||||
def test_update_agent_accepts_known_model(tmp_path, patched_paths, stub_app_config):
|
||||
_seed_agent(tmp_path)
|
||||
|
||||
result = update_agent.func(runtime=_runtime(), model="gpt-known")
|
||||
|
||||
cfg = yaml.safe_load((_user_agent_dir(tmp_path) / "config.yaml").read_text())
|
||||
assert cfg["model"] == "gpt-known"
|
||||
assert "model" in result.update["messages"][0].content
|
||||
|
||||
|
||||
# --- Partial update tests ---
|
||||
|
||||
|
||||
def test_update_agent_updates_soul_only(tmp_path, patched_paths):
|
||||
agent_dir = _seed_agent(tmp_path, description="keep me", soul="old soul")
|
||||
|
||||
result = update_agent.func(runtime=_runtime(), soul="brand new soul")
|
||||
|
||||
assert (agent_dir / "SOUL.md").read_text() == "brand new soul"
|
||||
cfg = yaml.safe_load((agent_dir / "config.yaml").read_text())
|
||||
assert cfg["description"] == "keep me", "description must be preserved"
|
||||
assert "soul" in result.update["messages"][0].content
|
||||
|
||||
|
||||
def test_update_agent_updates_description_only(tmp_path, patched_paths):
|
||||
agent_dir = _seed_agent(tmp_path, description="old desc", soul="keep this soul")
|
||||
|
||||
result = update_agent.func(runtime=_runtime(), description="new desc")
|
||||
|
||||
cfg = yaml.safe_load((agent_dir / "config.yaml").read_text())
|
||||
assert cfg["description"] == "new desc"
|
||||
assert (agent_dir / "SOUL.md").read_text() == "keep this soul", "SOUL.md must be preserved"
|
||||
assert "description" in result.update["messages"][0].content
|
||||
|
||||
|
||||
def test_update_agent_skills_empty_list_disables_all(tmp_path, patched_paths):
|
||||
agent_dir = _seed_agent(tmp_path, skills=["a", "b"])
|
||||
|
||||
result = update_agent.func(runtime=_runtime(), skills=[])
|
||||
|
||||
cfg = yaml.safe_load((agent_dir / "config.yaml").read_text())
|
||||
assert cfg["skills"] == [], "empty list must persist as empty list (not be omitted)"
|
||||
assert "skills" in result.update["messages"][0].content
|
||||
|
||||
|
||||
def test_update_agent_skills_omitted_keeps_existing(tmp_path, patched_paths):
|
||||
agent_dir = _seed_agent(tmp_path, skills=["alpha", "beta"])
|
||||
|
||||
update_agent.func(runtime=_runtime(), description="bumped")
|
||||
|
||||
cfg = yaml.safe_load((agent_dir / "config.yaml").read_text())
|
||||
assert cfg["skills"] == ["alpha", "beta"], "omitting skills must preserve the existing whitelist"
|
||||
|
||||
|
||||
def test_update_agent_no_op_when_values_match_existing(tmp_path, patched_paths):
|
||||
_seed_agent(tmp_path, description="same")
|
||||
|
||||
result = update_agent.func(runtime=_runtime(), description="same")
|
||||
|
||||
assert "No changes applied" in result.update["messages"][0].content
|
||||
|
||||
|
||||
def test_update_agent_forces_name_to_directory(tmp_path, patched_paths):
|
||||
"""Copilot review: if the existing config.yaml has a drifted ``name`` field,
|
||||
update_agent must rewrite it to match the directory name so on-disk state
|
||||
stays consistent with the runtime context."""
|
||||
agent_dir = _user_agent_dir(tmp_path)
|
||||
agent_dir.mkdir(parents=True)
|
||||
(agent_dir / "config.yaml").write_text(yaml.safe_dump({"name": "drifted-name", "description": "old"}, sort_keys=False), encoding="utf-8")
|
||||
(agent_dir / "SOUL.md").write_text("soul", encoding="utf-8")
|
||||
|
||||
update_agent.func(runtime=_runtime(), description="bumped")
|
||||
|
||||
cfg = yaml.safe_load((agent_dir / "config.yaml").read_text())
|
||||
assert cfg["name"] == "test-agent", "config.yaml name must follow the directory name, not legacy yaml content"
|
||||
|
||||
|
||||
# --- Atomicity tests ---
|
||||
|
||||
|
||||
def test_update_agent_failure_preserves_existing_files(tmp_path, patched_paths):
|
||||
agent_dir = _seed_agent(tmp_path, soul="original soul")
|
||||
|
||||
real_replace = Path.replace
|
||||
|
||||
def _explode(self, target):
|
||||
if str(target).endswith("SOUL.md"):
|
||||
raise OSError("disk full")
|
||||
return real_replace(self, target)
|
||||
|
||||
with patch.object(Path, "replace", _explode):
|
||||
result = update_agent.func(runtime=_runtime(), soul="poisoned content")
|
||||
|
||||
assert (agent_dir / "SOUL.md").read_text() == "original soul", "atomic write must not corrupt existing SOUL.md"
|
||||
assert "Error" in result.update["messages"][0].content
|
||||
leftover_tmps = list(agent_dir.glob("*.tmp"))
|
||||
assert leftover_tmps == [], "temp files must be cleaned up on failure"
|
||||
|
||||
|
||||
def test_update_agent_soul_failure_does_not_replace_config(tmp_path, patched_paths):
|
||||
"""Copilot review: if both config.yaml and SOUL.md are scheduled to be
|
||||
written and SOUL.md staging fails *before* any rename, config.yaml must
|
||||
NOT be replaced. The fix stages every temp file first and only renames
|
||||
after all temps exist on disk."""
|
||||
agent_dir = _seed_agent(tmp_path, description="original-desc", soul="original soul")
|
||||
|
||||
real_named_temp_file = __import__("tempfile").NamedTemporaryFile
|
||||
call_count = {"n": 0}
|
||||
|
||||
def _explode_on_soul(*args, **kwargs):
|
||||
# Inspect target dir + suffix; the SOUL temp file is the second one we stage.
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] >= 2:
|
||||
raise OSError("disk full while staging SOUL.md")
|
||||
return real_named_temp_file(*args, **kwargs)
|
||||
|
||||
with patch("deerflow.tools.builtins.update_agent_tool.tempfile.NamedTemporaryFile", side_effect=_explode_on_soul):
|
||||
result = update_agent.func(runtime=_runtime(), description="new-desc", soul="new soul")
|
||||
|
||||
cfg = yaml.safe_load((agent_dir / "config.yaml").read_text())
|
||||
assert cfg["description"] == "original-desc", "config.yaml must not be replaced when SOUL.md staging fails"
|
||||
assert (agent_dir / "SOUL.md").read_text() == "original soul"
|
||||
assert "Error" in result.update["messages"][0].content
|
||||
assert list(agent_dir.glob("*.tmp")) == [], "staged config.yaml temp must be cleaned up on SOUL.md failure"
|
||||
|
||||
|
||||
# --- Per-user isolation ---
|
||||
|
||||
|
||||
def test_update_agent_only_writes_under_current_user(tmp_path, patched_paths):
|
||||
"""An update from user 'alice' must never touch user 'bob's agent files."""
|
||||
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
||||
|
||||
# Seed an agent for both users with the same name.
|
||||
alice_dir = _seed_agent(tmp_path, name="shared", description="alice-desc", soul="alice soul", user_id="alice")
|
||||
bob_dir = _seed_agent(tmp_path, name="shared", description="bob-desc", soul="bob soul", user_id="bob")
|
||||
|
||||
# Override the autouse contextvar so update_agent runs as Alice.
|
||||
token = set_current_user(SimpleNamespace(id="alice"))
|
||||
try:
|
||||
update_agent.func(runtime=_runtime(agent_name="shared"), description="alice-bumped")
|
||||
finally:
|
||||
reset_current_user(token)
|
||||
|
||||
alice_cfg = yaml.safe_load((alice_dir / "config.yaml").read_text())
|
||||
bob_cfg = yaml.safe_load((bob_dir / "config.yaml").read_text())
|
||||
assert alice_cfg["description"] == "alice-bumped"
|
||||
assert bob_cfg["description"] == "bob-desc", "bob's config.yaml must not have been touched"
|
||||
assert (bob_dir / "SOUL.md").read_text() == "bob soul"
|
||||
|
||||
|
||||
# --- Loader passthrough sanity check ---
|
||||
|
||||
|
||||
def test_update_agent_round_trips_known_fields(tmp_path, patched_paths):
|
||||
"""update_agent reads through load_agent_config so all fields the loader
|
||||
knows about (name, description, model, tool_groups, skills) round-trip
|
||||
on a partial update.
|
||||
|
||||
Note: ``load_agent_config`` strips unknown fields before constructing
|
||||
AgentConfig, so legacy/extra YAML keys are NOT preserved across
|
||||
updates — by design.
|
||||
"""
|
||||
_seed_agent(tmp_path, description="legacy")
|
||||
|
||||
fake_cfg = AgentConfig(name="test-agent", description="legacy", skills=["s1"], tool_groups=["g1"], model="m1")
|
||||
fake_app_config = MagicMock()
|
||||
fake_app_config.get_model_config.return_value = object()
|
||||
with patch("deerflow.tools.builtins.update_agent_tool.load_agent_config", return_value=fake_cfg):
|
||||
with patch("deerflow.tools.builtins.update_agent_tool.get_app_config", return_value=fake_app_config):
|
||||
update_agent.func(runtime=_runtime(), description="bumped")
|
||||
|
||||
cfg = yaml.safe_load((_user_agent_dir(tmp_path) / "config.yaml").read_text())
|
||||
assert cfg["description"] == "bumped"
|
||||
assert cfg["skills"] == ["s1"]
|
||||
assert cfg["tool_groups"] == ["g1"]
|
||||
assert cfg["model"] == "m1"
|
||||
@@ -126,15 +126,18 @@ class TestWriteUploadFileNoSymlink:
|
||||
assert dest.read_bytes() == b"new contents"
|
||||
assert os.stat(dest).st_nlink == 1
|
||||
|
||||
def test_fails_closed_without_no_follow_support(self, tmp_path, monkeypatch):
|
||||
def test_fallback_without_no_follow_support_succeeds(self, tmp_path, monkeypatch):
|
||||
monkeypatch.delattr(os, "O_NOFOLLOW", raising=False)
|
||||
|
||||
with pytest.raises(UnsafeUploadPathError, match="O_NOFOLLOW"):
|
||||
write_upload_file_no_symlink(tmp_path, "notes.txt", b"hello")
|
||||
|
||||
assert not (tmp_path / "notes.txt").exists()
|
||||
# When O_NOFOLLOW is absent (Windows), the function falls back to
|
||||
# a dual-lstat + fstat approach and succeeds.
|
||||
result = write_upload_file_no_symlink(tmp_path, "notes.txt", b"hello")
|
||||
assert result == tmp_path / "notes.txt"
|
||||
assert (tmp_path / "notes.txt").read_bytes() == b"hello"
|
||||
|
||||
def test_open_uses_nonblocking_flag_when_available(self, tmp_path):
|
||||
if not hasattr(os, "O_NONBLOCK"):
|
||||
pytest.skip("O_NONBLOCK not available on this platform")
|
||||
with patch("deerflow.uploads.manager.os.open", side_effect=OSError(errno.ENXIO, "no reader")) as open_mock:
|
||||
with pytest.raises(UnsafeUploadPathError, match="Unsafe upload destination"):
|
||||
write_upload_file_no_symlink(tmp_path, "pipe.txt", b"hello")
|
||||
@@ -144,6 +147,8 @@ class TestWriteUploadFileNoSymlink:
|
||||
|
||||
@pytest.mark.parametrize("open_errno", [errno.ENXIO, errno.EAGAIN])
|
||||
def test_nonblocking_special_file_open_errors_are_unsafe(self, tmp_path, open_errno):
|
||||
if not hasattr(os, "O_NONBLOCK"):
|
||||
pytest.skip("O_NONBLOCK not available on this platform")
|
||||
with patch("deerflow.uploads.manager.os.open", side_effect=OSError(open_errno, "would block")):
|
||||
with pytest.raises(UnsafeUploadPathError, match="Unsafe upload destination"):
|
||||
write_upload_file_no_symlink(tmp_path, "pipe.txt", b"hello")
|
||||
|
||||
@@ -61,6 +61,39 @@ def test_upload_files_writes_thread_storage_and_skips_local_sandbox_sync(tmp_pat
|
||||
sandbox.update_file.assert_not_called()
|
||||
|
||||
|
||||
def test_upload_files_auto_renames_duplicate_form_filenames(tmp_path):
|
||||
thread_uploads_dir = tmp_path / "uploads"
|
||||
thread_uploads_dir.mkdir(parents=True)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.uses_thread_data_mounts = True
|
||||
|
||||
with (
|
||||
patch.object(uploads, "get_uploads_dir", return_value=thread_uploads_dir),
|
||||
patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
|
||||
patch.object(uploads, "get_sandbox_provider", return_value=provider),
|
||||
):
|
||||
result = asyncio.run(
|
||||
call_unwrapped(
|
||||
uploads.upload_files,
|
||||
"thread-local",
|
||||
request=MagicMock(),
|
||||
files=[
|
||||
UploadFile(filename="data.txt", file=BytesIO(b"first")),
|
||||
UploadFile(filename="data.txt", file=BytesIO(b"second")),
|
||||
],
|
||||
config=SimpleNamespace(),
|
||||
)
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert [file_info["filename"] for file_info in result.files] == ["data.txt", "data_1.txt"]
|
||||
assert "original_filename" not in result.files[0]
|
||||
assert result.files[1]["original_filename"] == "data.txt"
|
||||
assert (thread_uploads_dir / "data.txt").read_bytes() == b"first"
|
||||
assert (thread_uploads_dir / "data_1.txt").read_bytes() == b"second"
|
||||
|
||||
|
||||
def test_upload_files_skips_acquire_when_thread_data_is_mounted(tmp_path):
|
||||
thread_uploads_dir = tmp_path / "uploads"
|
||||
thread_uploads_dir.mkdir(parents=True)
|
||||
|
||||
Generated
+10
-10
@@ -788,7 +788,7 @@ requires-dist = [
|
||||
{ name = "lark-oapi", specifier = ">=1.4.0" },
|
||||
{ name = "markdown-to-mrkdwn", specifier = ">=0.3.1" },
|
||||
{ name = "pyjwt", specifier = ">=2.9.0" },
|
||||
{ name = "python-multipart", specifier = ">=0.0.26" },
|
||||
{ name = "python-multipart", specifier = ">=0.0.27" },
|
||||
{ name = "python-telegram-bot", specifier = ">=21.0" },
|
||||
{ name = "slack-sdk", specifier = ">=3.33.0" },
|
||||
{ name = "sse-starlette", specifier = ">=2.1.0" },
|
||||
@@ -1725,7 +1725,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "1.3.2"
|
||||
version = "1.3.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "jsonpatch" },
|
||||
@@ -1738,9 +1738,9 @@ dependencies = [
|
||||
{ name = "typing-extensions" },
|
||||
{ name = "uuid-utils" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a8/03/7219502e8ca728d65eb44d7a3eb60239230742a70dbfc9241b9bfd61c4ab/langchain_core-1.3.2.tar.gz", hash = "sha256:fd7a50b2f28ba561fd9d7f5d2760bc9e06cf00cdf820a3ccafe88a94ffa8d5b7", size = 911813, upload-time = "2026-04-24T15:49:23.699Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d3/ae/8b74458fc3850ec3d150eb9f45e857db129dafa801fb5cf173dfc9f8bbf3/langchain_core-1.3.3.tar.gz", hash = "sha256:fa510a5db8efdc0c6ff41c0939fb5c00a0183c11f6b84233e892e3227ff69182", size = 915041, upload-time = "2026-05-05T19:02:36.612Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7d/d5/8fa4431007cbb7cfed7590f4d6a5dea3ad724f4174d248f6642ef5ce7d05/langchain_core-1.3.2-py3-none-any.whl", hash = "sha256:d44a66127f9f8db735bdfd0ab9661bccb47a97113cfd3f2d89c74864422b7274", size = 542390, upload-time = "2026-04-24T15:49:21.991Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1f/01/4771b7ab2af1d1aba5b710bd8f13d9225c609425214b357590a17b01be77/langchain_core-1.3.3-py3-none-any.whl", hash = "sha256:18aae8506f37da7f74398492279a7d6efcee4f8e23c4c41c7af080eeb7ef7bd1", size = 543857, upload-time = "2026-05-05T19:02:34.52Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2145,14 +2145,14 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "mako"
|
||||
version = "1.3.11"
|
||||
version = "1.3.12"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "markupsafe" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/59/8a/805404d0c0b9f3d7a326475ca008db57aea9c5c9f2e1e39ed0faa335571c/mako-1.3.11.tar.gz", hash = "sha256:071eb4ab4c5010443152255d77db7faa6ce5916f35226eb02dc34479b6858069", size = 399811, upload-time = "2026-04-14T20:19:51.493Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/00/62/791b31e69ae182791ec67f04850f2f062716bbd205483d63a215f3e062d3/mako-1.3.12.tar.gz", hash = "sha256:9f778e93289bd410bb35daadeb4fc66d95a746f0b75777b942088b7fd7af550a", size = 400219, upload-time = "2026-04-28T19:01:08.512Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/68/a5/19d7aaa7e433713ffe881df33705925a196afb9532efc8475d26593921a6/mako-1.3.11-py3-none-any.whl", hash = "sha256:e372c6e333cf004aa736a15f425087ec977e1fcbd2966aae7f17c8dc1da27a77", size = 78503, upload-time = "2026-04-14T20:19:53.233Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bc/b1/a0ec7a5a9db730a08daef1fdfb8090435b82465abbf758a596f0ea88727e/mako-1.3.12-py3-none-any.whl", hash = "sha256:8f61569480282dbf557145ce441e4ba888be453c30989f879f0d652e39f53ea9", size = 78521, upload-time = "2026-04-28T19:01:10.393Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3532,11 +3532,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "python-multipart"
|
||||
version = "0.0.26"
|
||||
version = "0.0.27"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/88/71/b145a380824a960ebd60e1014256dbb7d2253f2316ff2d73dfd8928ec2c3/python_multipart-0.0.26.tar.gz", hash = "sha256:08fadc45918cd615e26846437f50c5d6d23304da32c341f289a617127b081f17", size = 43501, upload-time = "2026-04-10T14:09:59.473Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/69/9b/f23807317a113dc36e74e75eb265a02dd1a4d9082abc3c1064acd22997c4/python_multipart-0.0.27.tar.gz", hash = "sha256:9870a6a8c5a20a5bf4f07c017bd1489006ff8836cff097b6933355ee2b49b602", size = 44043, upload-time = "2026-04-27T10:51:26.649Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9a/22/f1925cdda983ab66fc8ec6ec8014b959262747e58bdca26a4e3d1da29d56/python_multipart-0.0.26-py3-none-any.whl", hash = "sha256:c0b169f8c4484c13b0dcf2ef0ec3a4adb255c4b7d18d8e420477d2b1dd03f185", size = 28847, upload-time = "2026-04-10T14:09:58.131Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/99/78/4126abcbdbd3c559d43e0db7f7b9173fc6befe45d39a2856cc0b8ec2a5a6/python_multipart-0.0.27-py3-none-any.whl", hash = "sha256:6fccfad17a27334bd0193681b369f476eda3409f17381a2d65aa7df3f7275645", size = 29254, upload-time = "2026-04-27T10:51:24.997Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
+49
-5
@@ -15,7 +15,7 @@
|
||||
# ============================================================================
|
||||
# Bump this number when the config schema changes.
|
||||
# Run `make config-upgrade` to merge new fields into your local config.yaml.
|
||||
config_version: 8
|
||||
config_version: 9
|
||||
|
||||
# ============================================================================
|
||||
# Logging
|
||||
@@ -30,7 +30,7 @@ log_level: info
|
||||
# When enabled, DeerFlow records input/output/total tokens per model call
|
||||
# and shows usage metadata in the workspace UI when providers return it.
|
||||
token_usage:
|
||||
enabled: false
|
||||
enabled: true
|
||||
|
||||
# ============================================================================
|
||||
# Models Configuration
|
||||
@@ -506,6 +506,29 @@ tools:
|
||||
tool_search:
|
||||
enabled: false
|
||||
|
||||
# ============================================================================
|
||||
# Loop Detection Configuration
|
||||
# ============================================================================
|
||||
# Detect and interrupt repeated identical tool-call loops.
|
||||
# Frequency thresholds are safety limits for repeated use of the same tool type.
|
||||
|
||||
loop_detection:
|
||||
enabled: true
|
||||
warn_threshold: 3
|
||||
hard_limit: 5
|
||||
window_size: 20
|
||||
max_tracked_threads: 100
|
||||
tool_freq_warn: 30
|
||||
tool_freq_hard_limit: 50
|
||||
# Per-tool overrides for tool_freq_warn / tool_freq_hard_limit. Values can be
|
||||
# higher or lower than the global defaults. Commonly used to raise thresholds
|
||||
# for high-frequency tools like bash in batch workflows (e.g. RNA-seq pipelines)
|
||||
# without weakening protection on every other tool.
|
||||
# tool_freq_overrides:
|
||||
# bash:
|
||||
# warn: 150
|
||||
# hard_limit: 300
|
||||
|
||||
# ============================================================================
|
||||
# Sandbox Configuration
|
||||
# ============================================================================
|
||||
@@ -578,6 +601,11 @@ sandbox:
|
||||
# # Optional: Prefix for container names (default: deer-flow-sandbox)
|
||||
# # container_prefix: deer-flow-sandbox
|
||||
#
|
||||
# # Optional: Automatically restart crashed sandbox containers (default: true)
|
||||
# # When enabled, a dead container is detected on the next tool call and
|
||||
# # transparently replaced with a fresh one. Set to false to disable.
|
||||
# # auto_restart: true
|
||||
#
|
||||
# # Optional: Additional mount directories from host to container
|
||||
# # NOTE: Skills directory is automatically mounted from skills.path to skills.container_path
|
||||
# # mounts:
|
||||
@@ -848,9 +876,25 @@ skill_evolution:
|
||||
#
|
||||
# Postgres mode: put your connection URL in .env as DATABASE_URL,
|
||||
# then reference it here with $DATABASE_URL.
|
||||
# Install the driver first:
|
||||
# Local: uv sync --extra postgres
|
||||
# Docker: UV_EXTRAS=postgres docker compose build
|
||||
#
|
||||
# Install the driver — Issue #2754 fix lands `UV_EXTRAS` in every code path:
|
||||
# Local `make dev` auto-detects from `database.backend: postgres` below
|
||||
# and passes `--extra postgres` to `uv sync` on every restart, so
|
||||
# the extra is no longer wiped. To opt in explicitly (or layer
|
||||
# extras like `postgres,ollama`), set in project-root .env:
|
||||
# UV_EXTRAS=postgres
|
||||
# Docker dev `make docker-start` reads `UV_EXTRAS` from project-root .env via
|
||||
# `env_file`. Set:
|
||||
# UV_EXTRAS=postgres
|
||||
# Multiple extras (`postgres,ollama`) supported here too — see
|
||||
# docker/dev-entrypoint.sh.
|
||||
# Docker img build-arg `UV_EXTRAS=postgres docker compose build` — single
|
||||
# extra only at build time (backend/Dockerfile passes the value
|
||||
# as one token to `--extra`).
|
||||
#
|
||||
# First-time bootstrap (before `make dev`):
|
||||
# cd backend && uv sync --all-packages --extra postgres
|
||||
# (--all-packages propagates the extra into workspace members — see PR #2584)
|
||||
#
|
||||
# NOTE: When both `checkpointer` and `database` are configured,
|
||||
# `checkpointer` takes precedence for LangGraph state persistence.
|
||||
|
||||
Executable
+85
@@ -0,0 +1,85 @@
|
||||
#!/usr/bin/env sh
|
||||
#
|
||||
# DeerFlow gateway dev entrypoint — runs inside the docker-compose-dev gateway
|
||||
# container. Extracted from docker/docker-compose-dev.yaml's inline `command:`
|
||||
# (PR #2767, addressing review on Issue #2754).
|
||||
#
|
||||
# Responsibilities:
|
||||
# 1. Resolve `--extra X` flags from UV_EXTRAS (comma- or whitespace-separated,
|
||||
# mirroring scripts/detect_uv_extras.py for parity with local `make dev`).
|
||||
# 2. Validate each extra against [A-Za-z][A-Za-z0-9_-]* so a stray shell
|
||||
# metacharacter in `.env` cannot reach `uv sync`.
|
||||
# 3. `uv sync --all-packages` so workspace member extras (deerflow-harness's
|
||||
# postgres extra in particular) are installed — see PR #2584.
|
||||
# 4. Self-heal: if the first sync fails, recreate .venv and retry once.
|
||||
# 5. Hand off to uvicorn with reload, replacing this shell so uvicorn becomes
|
||||
# PID 1 inside the container.
|
||||
#
|
||||
# Anchored at /bin/sh (not bash) since alpine-based base images may not ship
|
||||
# bash. Uses POSIX-only constructs throughout.
|
||||
|
||||
set -e
|
||||
|
||||
# `--print-extras` is a dry-run hook: parse + validate UV_EXTRAS, print the
|
||||
# resulting `--extra X` flags to stdout, and exit. Used by the unit test in
|
||||
# backend/tests/test_dev_entrypoint.py and useful for ad-hoc debugging.
|
||||
PRINT_EXTRAS_ONLY=0
|
||||
if [ "${1:-}" = "--print-extras" ]; then
|
||||
PRINT_EXTRAS_ONLY=1
|
||||
fi
|
||||
|
||||
# Mirror the legacy command's behavior: redirect both stdout and stderr to the
|
||||
# host-mounted log file (../logs/gateway.log → /app/logs/gateway.log). Skip
|
||||
# the redirect under --print-extras so the test runner can capture stdout.
|
||||
if [ "$PRINT_EXTRAS_ONLY" = "0" ]; then
|
||||
exec >/app/logs/gateway.log 2>&1
|
||||
fi
|
||||
|
||||
# ── Resolve extras ──────────────────────────────────────────────────────────
|
||||
|
||||
EXTRAS_FLAGS=""
|
||||
if [ -n "${UV_EXTRAS:-}" ]; then
|
||||
# Normalize comma → space, then split on whitespace via the unquoted `for`.
|
||||
for raw in $(printf '%s' "$UV_EXTRAS" | tr ',' ' '); do
|
||||
[ -z "$raw" ] && continue
|
||||
# Reject anything that does not look like an identifier.
|
||||
# Two patterns: leading non-letter, or any non-[A-Za-z0-9_-] character.
|
||||
case "$raw" in
|
||||
[!A-Za-z]* | *[!A-Za-z0-9_-]*)
|
||||
echo "[startup] UV_EXTRAS entry '$raw' is invalid (must match [A-Za-z][A-Za-z0-9_-]*) — aborting" >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
EXTRAS_FLAGS="$EXTRAS_FLAGS --extra $raw"
|
||||
done
|
||||
fi
|
||||
|
||||
if [ "$PRINT_EXTRAS_ONLY" = "1" ]; then
|
||||
# Trim leading space for tidier output, then exit.
|
||||
printf '%s\n' "${EXTRAS_FLAGS# }"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [ -n "$EXTRAS_FLAGS" ]; then
|
||||
echo "[startup] uv extras:$EXTRAS_FLAGS"
|
||||
fi
|
||||
|
||||
# ── Sync dependencies (with self-heal) ──────────────────────────────────────
|
||||
|
||||
cd /app/backend
|
||||
|
||||
# `--all-packages` propagates extras into workspace members (PR #2584).
|
||||
# `$EXTRAS_FLAGS` intentionally unquoted so each `--extra X` becomes its own arg.
|
||||
# shellcheck disable=SC2086 # word-splitting is intentional here
|
||||
if ! uv sync --all-packages $EXTRAS_FLAGS; then
|
||||
echo "[startup] uv sync failed; recreating .venv and retrying once"
|
||||
uv venv --allow-existing .venv
|
||||
# shellcheck disable=SC2086
|
||||
uv sync --all-packages $EXTRAS_FLAGS
|
||||
fi
|
||||
|
||||
# ── Hand off to uvicorn ─────────────────────────────────────────────────────
|
||||
|
||||
PYTHONPATH=. exec uv run uvicorn app.gateway.app:app \
|
||||
--host 0.0.0.0 --port 8001 \
|
||||
--reload --reload-include='*.yaml .env'
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user