Compare commits

..

1 Commits

Author SHA1 Message Date
greatmengqi 2eb45e9bb5 fix: thread app config through client and sync providers 2026-05-02 12:07:26 +08:00
164 changed files with 1265 additions and 10827 deletions
-14
View File
@@ -1,6 +1,3 @@
# Serper API Key (Google Search) - https://serper.dev
SERPER_API_KEY=your-serper-api-key
# TAVILY API Key
TAVILY_API_KEY=your-tavily-api-key
@@ -48,14 +45,3 @@ INFOQUEST_API_KEY=your-infoquest-api-key
# Set to "false" to disable Swagger UI, ReDoc, and OpenAPI schema in production
# GATEWAY_ENABLE_DOCS=false
# ── Frontend SSR → Gateway wiring ─────────────────────────────────────────────
# The Next.js server uses these to reach the Gateway during SSR (auth checks,
# /api/* rewrites). They default to localhost values that match `make dev` and
# `make start`, so most local users do not need to set them.
#
# Override only when the Gateway is not on localhost:8001 (e.g. when the
# frontend and gateway run on different hosts, in containers with a service
# alias, or behind a different port). docker-compose already sets these.
# DEER_FLOW_INTERNAL_GATEWAY_BASE_URL=http://localhost:8001
# DEER_FLOW_TRUSTED_ORIGINS=http://localhost:3000,http://localhost:2026
-101
View File
@@ -1,101 +0,0 @@
name: Publish Containers
on:
push:
tags:
- "v*"
jobs:
backend-container:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
attestations: write
id-token: write
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}-backend
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Log in to the Container registry
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 #v3.4.0
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@902fa8ec7d6ecbf8d84d538b9b233a880e428804 #v5.7.0
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=ref,event=tag
type=ref,event=branch
type=sha
type=raw,value=latest,enable={{is_default_branch}}
- name: Build and push Docker image
id: push
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 #v6.18.0
with:
context: .
file: backend/Dockerfile
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
- name: Generate artifact attestation
uses: actions/attest-build-provenance@v2
with:
subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}}
subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true
frontend-container:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
attestations: write
id-token: write
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}-frontend
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Log in to the Container registry
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 #v3.4.0
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@902fa8ec7d6ecbf8d84d538b9b233a880e428804 #v5.7.0
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=ref,event=tag
type=ref,event=branch
type=sha
type=raw,value=latest,enable={{is_default_branch}}
- name: Build and push Docker image
id: push
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 #v6.18.0
with:
context: .
file: frontend/Dockerfile
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
- name: Generate artifact attestation
uses: actions/attest-build-provenance@v2
with:
subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}}
subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true
+2 -6
View File
@@ -263,10 +263,8 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
- `present_files` - Make output files visible to user (only `/mnt/user-data/outputs`)
- `ask_clarification` - Request clarification (intercepted by ClarificationMiddleware → interrupts)
- `view_image` - Read image as base64 (added only if model supports vision)
- `setup_agent` - Bootstrap-only: persist a brand-new custom agent's `SOUL.md` and `config.yaml`. Bound only when `is_bootstrap=True`.
- `update_agent` - Custom-agent-only: persist self-updates to the current agent's `SOUL.md` / `config.yaml` from inside a normal chat (partial update + atomic write). Bound when `agent_name` is set and `is_bootstrap=False`.
4. **Subagent tool** (if enabled):
- `task` - Delegate to subagent (description, prompt, subagent_type)
- `task` - Delegate to subagent (description, prompt, subagent_type, max_turns)
**Community tools** (`packages/harness/deerflow/community/`):
- `tavily/` - Web search (5 results default) and web fetch (4KB limit)
@@ -356,11 +354,10 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
**Per-User Isolation**:
- Memory is stored per-user at `{base_dir}/users/{user_id}/memory.json`
- Per-agent per-user memory at `{base_dir}/users/{user_id}/agents/{agent_name}/memory.json`
- Custom agent definitions (`SOUL.md` + `config.yaml`) are also per-user at `{base_dir}/users/{user_id}/agents/{agent_name}/`. The legacy shared layout `{base_dir}/agents/{agent_name}/` remains read-only fallback for unmigrated installations
- `user_id` is resolved via `get_effective_user_id()` from `deerflow.runtime.user_context`
- In no-auth mode, `user_id` defaults to `"default"` (constant `DEFAULT_USER_ID`)
- Absolute `storage_path` in config opts out of per-user isolation
- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json`, `threads/`, and `agents/` into per-user layout. Supports `--dry-run` (preview changes) and `--user-id USER_ID` (assign unowned legacy data to a user, defaults to `default`).
- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json` and `threads/` into per-user layout; supports `--dry-run`
**Data Structure** (stored in `{base_dir}/users/{user_id}/memory.json`):
- **User Context**: `workContext`, `personalContext`, `topOfMind` (1-3 sentence summaries)
@@ -520,7 +517,6 @@ Multi-file upload with automatic document conversion:
- Rejects directory inputs before copying so uploads stay all-or-nothing
- Reuses one conversion worker per request when called from an active event loop
- Files stored in thread-isolated directories
- Duplicate filenames in a single upload request are auto-renamed with `_N` suffixes so later files do not truncate earlier files
- Agent receives uploaded file list via `UploadsMiddleware`
See [docs/FILE_UPLOAD.md](docs/FILE_UPLOAD.md) for details.
-10
View File
@@ -50,12 +50,6 @@ COPY backend ./backend
RUN --mount=type=cache,target=/root/.cache/uv \
sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync ${UV_EXTRAS:+--extra $UV_EXTRAS}"
# UTF-8 locale prevents UnicodeEncodeError on Chinese/emoji content in minimal
# containers where locale configuration may be missing and the default encoding is not UTF-8.
ENV LANG=C.UTF-8
ENV LC_ALL=C.UTF-8
ENV PYTHONIOENCODING=utf-8
# ── Stage 2: Dev ──────────────────────────────────────────────────────────────
# Retains compiler toolchain from builder so startup-time `uv sync` can build
# source distributions in development containers.
@@ -72,10 +66,6 @@ CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app
# Clean image without build-essential — reduces size (~200 MB) and attack surface.
FROM python:3.12-slim-bookworm
ENV LANG=C.UTF-8
ENV LC_ALL=C.UTF-8
ENV PYTHONIOENCODING=utf-8
# Copy Node.js runtime from builder (provides npx for MCP servers)
COPY --from=builder /usr/bin/node /usr/bin/node
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
+1 -1
View File
@@ -124,7 +124,7 @@ FastAPI application providing REST endpoints for frontend integration:
| `POST /api/memory/reload` | Force memory reload |
| `GET /api/memory/config` | Memory configuration |
| `GET /api/memory/status` | Combined config + data |
| `POST /api/threads/{id}/uploads` | Upload files (auto-converts PDF/PPT/Excel/Word to Markdown, rejects directory paths, auto-renames duplicate filenames in one request) |
| `POST /api/threads/{id}/uploads` | Upload files (auto-converts PDF/PPT/Excel/Word to Markdown, rejects directory paths) |
| `GET /api/threads/{id}/uploads/list` | List uploaded files |
| `DELETE /api/threads/{id}` | Delete DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
| `GET /api/threads/{id}/artifacts/{path}` | Serve generated artifacts |
+4 -42
View File
@@ -146,13 +146,6 @@ def _normalize_custom_agent_name(raw_value: str) -> str:
return normalized
def _strip_loop_warning_text(text: str) -> str:
"""Remove middleware-authored loop warning lines from display text."""
if "[LOOP DETECTED]" not in text:
return text
return "\n".join(line for line in text.splitlines() if "[LOOP DETECTED]" not in line).strip()
def _extract_response_text(result: dict | list) -> str:
"""Extract the last AI message text from a LangGraph runs.wait result.
@@ -162,7 +155,7 @@ def _extract_response_text(result: dict | list) -> str:
Handles special cases:
- Regular AI text responses
- Clarification interrupts (``ask_clarification`` tool messages)
- Strips loop-detection warnings attached to tool-call AI messages
- AI messages with tool_calls but no text content
"""
if isinstance(result, list):
messages = result
@@ -192,12 +185,7 @@ def _extract_response_text(result: dict | list) -> str:
# Regular AI message with text content
if msg_type == "ai":
content = msg.get("content", "")
has_tool_calls = bool(msg.get("tool_calls"))
if isinstance(content, str) and content:
if has_tool_calls:
content = _strip_loop_warning_text(content)
if not content:
continue
return content
# content can be a list of content blocks
if isinstance(content, list):
@@ -208,8 +196,6 @@ def _extract_response_text(result: dict | list) -> str:
elif isinstance(block, str):
parts.append(block)
text = "".join(parts)
if has_tool_calls:
text = _strip_loop_warning_text(text)
if text:
return text
return ""
@@ -434,13 +420,7 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
if not msg.files:
return []
from deerflow.uploads.manager import (
UnsafeUploadPathError,
claim_unique_filename,
ensure_uploads_dir,
normalize_filename,
write_upload_file_no_symlink,
)
from deerflow.uploads.manager import claim_unique_filename, ensure_uploads_dir, normalize_filename
uploads_dir = ensure_uploads_dir(thread_id)
seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()}
@@ -491,10 +471,7 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
dest = uploads_dir / safe_name
try:
dest = write_upload_file_no_symlink(uploads_dir, safe_name, data)
except UnsafeUploadPathError:
logger.warning("[Manager] skipping inbound file with unsafe destination: %s", safe_name)
continue
dest.write_bytes(data)
except Exception:
logger.exception("[Manager] failed to write inbound file: %s", dest)
continue
@@ -603,17 +580,6 @@ class ChannelManager:
user_layer.get("config"),
)
configurable = run_config.get("configurable")
if isinstance(configurable, Mapping):
configurable = dict(configurable)
else:
configurable = {}
run_config["configurable"] = configurable
# Pin channel-triggered runs to the root graph namespace so follow-up
# turns continue from the same conversation checkpoint.
configurable["checkpoint_ns"] = ""
configurable["thread_id"] = thread_id
run_context = _merge_dicts(
DEFAULT_RUN_CONTEXT,
self._default_session.get("context"),
@@ -997,11 +963,7 @@ class ChannelManager:
try:
async with httpx.AsyncClient() as http:
resp = await http.get(
f"{self._gateway_url}{path}",
timeout=10,
headers=create_internal_auth_headers(),
)
resp = await http.get(f"{self._gateway_url}{path}", timeout=10)
resp.raise_for_status()
data = resp.json()
except Exception:
+1 -112
View File
@@ -4,10 +4,8 @@ Per RFC-001:
State-changing operations require CSRF protection.
"""
import os
import secrets
from collections.abc import Callable
from urllib.parse import urlsplit
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
@@ -21,7 +19,7 @@ CSRF_TOKEN_LENGTH = 64 # bytes
def is_secure_request(request: Request) -> bool:
"""Detect whether the original client request was made over HTTPS."""
return _request_scheme(request) == "https"
return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"
def generate_csrf_token() -> str:
@@ -63,109 +61,6 @@ def is_auth_endpoint(request: Request) -> bool:
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
def _host_with_optional_port(hostname: str, port: int | None, scheme: str) -> str:
"""Return normalized host[:port], omitting default ports."""
host = hostname.lower()
if ":" in host and not host.startswith("["):
host = f"[{host}]"
if port is None or (scheme == "http" and port == 80) or (scheme == "https" and port == 443):
return host
return f"{host}:{port}"
def _normalize_origin(origin: str) -> str | None:
"""Return a normalized scheme://host[:port] origin, or None for invalid input."""
try:
parsed = urlsplit(origin.strip())
port = parsed.port
except ValueError:
return None
scheme = parsed.scheme.lower()
if scheme not in {"http", "https"} or not parsed.hostname:
return None
# Browser Origin is only scheme/host/port. Reject URL-shaped or credentialed values.
if parsed.username or parsed.password or parsed.path or parsed.query or parsed.fragment:
return None
return f"{scheme}://{_host_with_optional_port(parsed.hostname, port, scheme)}"
def _configured_cors_origins() -> set[str]:
"""Return explicit configured browser origins that may call auth routes."""
origins = set()
for raw_origin in os.environ.get("GATEWAY_CORS_ORIGINS", "").split(","):
origin = raw_origin.strip()
if not origin or origin == "*":
continue
normalized = _normalize_origin(origin)
if normalized:
origins.add(normalized)
return origins
def _first_header_value(value: str | None) -> str | None:
"""Return the first value from a comma-separated proxy header."""
if not value:
return None
first = value.split(",", 1)[0].strip()
return first or None
def _forwarded_param(request: Request, name: str) -> str | None:
"""Extract a parameter from the first RFC 7239 Forwarded header entry."""
forwarded = _first_header_value(request.headers.get("forwarded"))
if not forwarded:
return None
for part in forwarded.split(";"):
key, sep, value = part.strip().partition("=")
if sep and key.lower() == name:
return value.strip().strip('"') or None
return None
def _request_scheme(request: Request) -> str:
"""Resolve the original request scheme from trusted proxy headers."""
scheme = _forwarded_param(request, "proto") or _first_header_value(request.headers.get("x-forwarded-proto")) or request.url.scheme
return scheme.lower()
def _request_origin(request: Request) -> str | None:
"""Build the origin for the URL the browser is targeting."""
scheme = _request_scheme(request)
host = _forwarded_param(request, "host") or _first_header_value(request.headers.get("x-forwarded-host")) or request.headers.get("host") or request.url.netloc
forwarded_port = _first_header_value(request.headers.get("x-forwarded-port"))
if forwarded_port and ":" not in host.rsplit("]", 1)[-1]:
host = f"{host}:{forwarded_port}"
return _normalize_origin(f"{scheme}://{host}")
def is_allowed_auth_origin(request: Request) -> bool:
"""Allow auth POSTs only from the same origin or explicit configured origins.
Login/register/initialize are exempt from the double-submit token because
first-time browser clients do not have a CSRF token yet. They still create
a session cookie, so browser requests with a hostile Origin header must be
rejected to prevent login CSRF / session fixation. Requests without Origin
are allowed for non-browser clients such as curl and mobile integrations.
"""
origin = request.headers.get("origin")
if not origin:
return True
normalized_origin = _normalize_origin(origin)
if normalized_origin is None:
return False
request_origin = _request_origin(request)
return normalized_origin in _configured_cors_origins() or (request_origin is not None and normalized_origin == request_origin)
class CSRFMiddleware(BaseHTTPMiddleware):
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
@@ -175,12 +70,6 @@ class CSRFMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: Callable) -> Response:
_is_auth = is_auth_endpoint(request)
if should_check_csrf(request) and _is_auth and not is_allowed_auth_origin(request):
return JSONResponse(
status_code=403,
content={"detail": "Cross-site auth request denied."},
)
if should_check_csrf(request) and not _is_auth:
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
header_token = request.headers.get(CSRF_HEADER_NAME)
+18 -43
View File
@@ -11,7 +11,6 @@ from pydantic import BaseModel, Field
from deerflow.config.agents_api_config import get_agents_api_config
from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["agents"])
@@ -87,11 +86,11 @@ def _require_agents_api_enabled() -> None:
)
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False, *, user_id: str | None = None) -> AgentResponse:
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False) -> AgentResponse:
"""Convert AgentConfig to AgentResponse."""
soul: str | None = None
if include_soul:
soul = load_agent_soul(agent_cfg.name, user_id=user_id) or ""
soul = load_agent_soul(agent_cfg.name) or ""
return AgentResponse(
name=agent_cfg.name,
@@ -117,10 +116,9 @@ async def list_agents() -> AgentsListResponse:
"""
_require_agents_api_enabled()
user_id = get_effective_user_id()
try:
agents = list_custom_agents(user_id=user_id)
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True, user_id=user_id) for a in agents])
agents = list_custom_agents()
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True) for a in agents])
except Exception as e:
logger.error(f"Failed to list agents: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to list agents: {str(e)}")
@@ -146,12 +144,7 @@ async def check_agent_name(name: str) -> dict:
_require_agents_api_enabled()
_validate_agent_name(name)
normalized = _normalize_agent_name(name)
user_id = get_effective_user_id()
paths = get_paths()
# Treat the name as taken if either the per-user path or the legacy shared
# path holds an agent — picking a name that collides with an unmigrated
# legacy agent would shadow the legacy entry once migration runs.
available = not paths.user_agent_dir(user_id, normalized).exists() and not paths.agent_dir(normalized).exists()
available = not get_paths().agent_dir(normalized).exists()
return {"available": available, "name": normalized}
@@ -176,11 +169,10 @@ async def get_agent(name: str) -> AgentResponse:
_require_agents_api_enabled()
_validate_agent_name(name)
name = _normalize_agent_name(name)
user_id = get_effective_user_id()
try:
agent_cfg = load_agent_config(name, user_id=user_id)
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
agent_cfg = load_agent_config(name)
return _agent_config_to_response(agent_cfg, include_soul=True)
except FileNotFoundError:
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
except Exception as e:
@@ -210,13 +202,10 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
_require_agents_api_enabled()
_validate_agent_name(request.name)
normalized_name = _normalize_agent_name(request.name)
user_id = get_effective_user_id()
paths = get_paths()
agent_dir = paths.user_agent_dir(user_id, normalized_name)
legacy_dir = paths.agent_dir(normalized_name)
agent_dir = get_paths().agent_dir(normalized_name)
if agent_dir.exists() or legacy_dir.exists():
if agent_dir.exists():
raise HTTPException(status_code=409, detail=f"Agent '{normalized_name}' already exists")
try:
@@ -243,8 +232,8 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
logger.info(f"Created agent '{normalized_name}' at {agent_dir}")
agent_cfg = load_agent_config(normalized_name, user_id=user_id)
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
agent_cfg = load_agent_config(normalized_name)
return _agent_config_to_response(agent_cfg, include_soul=True)
except HTTPException:
raise
@@ -278,20 +267,13 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
_require_agents_api_enabled()
_validate_agent_name(name)
name = _normalize_agent_name(name)
user_id = get_effective_user_id()
try:
agent_cfg = load_agent_config(name, user_id=user_id)
agent_cfg = load_agent_config(name)
except FileNotFoundError:
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
paths = get_paths()
agent_dir = paths.user_agent_dir(user_id, name)
if not agent_dir.exists() and paths.agent_dir(name).exists():
raise HTTPException(
status_code=409,
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before updating."),
)
agent_dir = get_paths().agent_dir(name)
try:
# Update config if any config fields changed
@@ -332,8 +314,8 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
logger.info(f"Updated agent '{name}'")
refreshed_cfg = load_agent_config(name, user_id=user_id)
return _agent_config_to_response(refreshed_cfg, include_soul=True, user_id=user_id)
refreshed_cfg = load_agent_config(name)
return _agent_config_to_response(refreshed_cfg, include_soul=True)
except HTTPException:
raise
@@ -420,22 +402,15 @@ async def delete_agent(name: str) -> None:
name: The agent name.
Raises:
HTTPException: 404 if no per-user copy exists; 409 if only a legacy
shared copy exists (suggesting the migration script).
HTTPException: 404 if agent not found.
"""
_require_agents_api_enabled()
_validate_agent_name(name)
name = _normalize_agent_name(name)
user_id = get_effective_user_id()
paths = get_paths()
agent_dir = paths.user_agent_dir(user_id, name)
agent_dir = get_paths().agent_dir(name)
if not agent_dir.exists():
if paths.agent_dir(name).exists():
raise HTTPException(
status_code=409,
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before deleting."),
)
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
try:
+3 -24
View File
@@ -68,27 +68,6 @@ class RunResponse(BaseModel):
updated_at: str = ""
class ThreadTokenUsageModelBreakdown(BaseModel):
tokens: int = 0
runs: int = 0
class ThreadTokenUsageCallerBreakdown(BaseModel):
lead_agent: int = 0
subagent: int = 0
middleware: int = 0
class ThreadTokenUsageResponse(BaseModel):
thread_id: str
total_tokens: int = 0
total_input_tokens: int = 0
total_output_tokens: int = 0
total_runs: int = 0
by_model: dict[str, ThreadTokenUsageModelBreakdown] = Field(default_factory=dict)
by_caller: ThreadTokenUsageCallerBreakdown = Field(default_factory=ThreadTokenUsageCallerBreakdown)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@@ -389,10 +368,10 @@ async def list_run_events(
return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)
@router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse)
@router.get("/{thread_id}/token-usage")
@require_permission("threads", "read", owner_check=True)
async def thread_token_usage(thread_id: str, request: Request) -> ThreadTokenUsageResponse:
async def thread_token_usage(thread_id: str, request: Request) -> dict:
"""Thread-level token usage aggregation."""
run_store = get_run_store(request)
agg = await run_store.aggregate_tokens_by_thread(thread_id)
return ThreadTokenUsageResponse(thread_id=thread_id, **agg)
return {"thread_id": thread_id, **agg}
+21 -23
View File
@@ -13,11 +13,11 @@ matching the LangGraph Platform wire format expected by the
from __future__ import annotations
import logging
import time
import uuid
from typing import Any
from fastapi import APIRouter, HTTPException, Request
from langgraph.checkpoint.base import empty_checkpoint
from pydantic import BaseModel, Field, field_validator
from app.gateway.authz import require_permission
@@ -26,7 +26,6 @@ from app.gateway.utils import sanitize_log_param
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime import serialize_channel_values
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.utils.time import coerce_iso, now_iso
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["threads"])
@@ -234,7 +233,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
checkpointer = get_checkpointer(request)
thread_store = get_thread_store(request)
thread_id = body.thread_id or str(uuid.uuid4())
now = now_iso()
now = time.time()
# ``body.metadata`` is already stripped of server-reserved keys by
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
@@ -244,8 +243,8 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
return ThreadResponse(
thread_id=thread_id,
status=existing_record.get("status", "idle"),
created_at=coerce_iso(existing_record.get("created_at", "")),
updated_at=coerce_iso(existing_record.get("updated_at", "")),
created_at=str(existing_record.get("created_at", "")),
updated_at=str(existing_record.get("updated_at", "")),
metadata=existing_record.get("metadata", {}),
)
@@ -263,6 +262,8 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
# Write an empty checkpoint so state endpoints work immediately
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try:
from langgraph.checkpoint.base import empty_checkpoint
ckpt_metadata = {
"step": -1,
"source": "input",
@@ -280,8 +281,8 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
return ThreadResponse(
thread_id=thread_id,
status="idle",
created_at=now,
updated_at=now,
created_at=str(now),
updated_at=str(now),
metadata=body.metadata,
)
@@ -306,11 +307,8 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
ThreadResponse(
thread_id=r["thread_id"],
status=r.get("status", "idle"),
# ``coerce_iso`` heals legacy unix-second values that
# ``MemoryThreadMetaStore`` historically wrote with ``time.time()``;
# SQL-backed rows already arrive as ISO strings and pass through.
created_at=coerce_iso(r.get("created_at", "")),
updated_at=coerce_iso(r.get("updated_at", "")),
created_at=r.get("created_at", ""),
updated_at=r.get("updated_at", ""),
metadata=r.get("metadata", {}),
values={"title": r["display_name"]} if r.get("display_name") else {},
interrupts={},
@@ -342,8 +340,8 @@ async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Reques
return ThreadResponse(
thread_id=thread_id,
status=record.get("status", "idle"),
created_at=coerce_iso(record.get("created_at", "")),
updated_at=coerce_iso(record.get("updated_at", "")),
created_at=str(record.get("created_at", "")),
updated_at=str(record.get("updated_at", "")),
metadata=record.get("metadata", {}),
)
@@ -383,8 +381,8 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
record = {
"thread_id": thread_id,
"status": "idle",
"created_at": coerce_iso(ckpt_meta.get("created_at", "")),
"updated_at": coerce_iso(ckpt_meta.get("updated_at", ckpt_meta.get("created_at", ""))),
"created_at": ckpt_meta.get("created_at", ""),
"updated_at": ckpt_meta.get("updated_at", ckpt_meta.get("created_at", "")),
"metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
}
@@ -398,8 +396,8 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
return ThreadResponse(
thread_id=thread_id,
status=status,
created_at=coerce_iso(record.get("created_at", "")),
updated_at=coerce_iso(record.get("updated_at", "")),
created_at=str(record.get("created_at", "")),
updated_at=str(record.get("updated_at", "")),
metadata=record.get("metadata", {}),
values=serialize_channel_values(channel_values),
)
@@ -450,10 +448,10 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
values=values,
next=next_tasks,
metadata=metadata,
checkpoint={"id": checkpoint_id, "ts": coerce_iso(metadata.get("created_at", ""))},
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
checkpoint_id=checkpoint_id,
parent_checkpoint_id=parent_checkpoint_id,
created_at=coerce_iso(metadata.get("created_at", "")),
created_at=str(metadata.get("created_at", "")),
tasks=tasks,
)
@@ -503,7 +501,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
channel_values.update(body.values)
checkpoint["channel_values"] = channel_values
metadata["updated_at"] = now_iso()
metadata["updated_at"] = time.time()
if body.as_node:
metadata["source"] = "update"
@@ -544,7 +542,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
next=[],
metadata=metadata,
checkpoint_id=new_checkpoint_id,
created_at=coerce_iso(metadata.get("created_at", "")),
created_at=str(metadata.get("created_at", "")),
)
@@ -611,7 +609,7 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
parent_checkpoint_id=parent_id,
metadata=user_meta,
values=values,
created_at=coerce_iso(metadata.get("created_at", "")),
created_at=str(metadata.get("created_at", "")),
next=next_tasks,
)
)
+14 -44
View File
@@ -5,7 +5,7 @@ import os
import stat
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from pydantic import BaseModel, Field
from pydantic import BaseModel
from app.gateway.authz import require_permission
from app.gateway.deps import get_config
@@ -15,15 +15,12 @@ from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
from deerflow.uploads.manager import (
PathTraversalError,
UnsafeUploadPathError,
claim_unique_filename,
delete_file_safe,
enrich_file_listing,
ensure_uploads_dir,
get_uploads_dir,
list_files_in_dir,
normalize_filename,
open_upload_file_no_symlink,
upload_artifact_url,
upload_virtual_path,
)
@@ -45,7 +42,6 @@ class UploadResponse(BaseModel):
success: bool
files: list[dict[str, str]]
message: str
skipped_files: list[str] = Field(default_factory=list)
class UploadLimits(BaseModel):
@@ -120,18 +116,17 @@ def _cleanup_uploaded_paths(paths: list[os.PathLike[str] | str]) -> None:
logger.warning("Failed to clean up upload path after rejected request: %s", path, exc_info=True)
async def _write_upload_file_with_limits(
async def _write_upload_file_streaming(
file: UploadFile,
file_path: os.PathLike[str] | str,
*,
uploads_dir: os.PathLike[str] | str,
display_filename: str,
max_single_file_size: int,
max_total_size: int,
total_size: int,
) -> tuple[os.PathLike[str] | str, int, int]:
) -> tuple[int, int]:
file_size = 0
file_path, fh = open_upload_file_no_symlink(uploads_dir, display_filename)
try:
with open(file_path, "wb") as output:
while chunk := await file.read(UPLOAD_CHUNK_SIZE):
file_size += len(chunk)
total_size += len(chunk)
@@ -139,17 +134,8 @@ async def _write_upload_file_with_limits(
raise HTTPException(status_code=413, detail=f"File too large: {display_filename}")
if total_size > max_total_size:
raise HTTPException(status_code=413, detail="Total upload size too large")
fh.write(chunk)
except Exception:
fh.close()
try:
os.unlink(file_path)
except FileNotFoundError:
pass
raise
else:
fh.close()
return file_path, file_size, total_size
output.write(chunk)
return file_size, total_size
def _auto_convert_documents_enabled(app_config: AppConfig) -> bool:
@@ -191,12 +177,7 @@ async def upload_files(
uploaded_files = []
written_paths = []
sandbox_sync_targets = []
skipped_files = []
total_size = 0
# Track filenames within this request so duplicate form parts do not
# silently truncate each other. Existing uploads keep the historical
# overwrite behavior for a single replacement upload.
seen_filenames: set[str] = set()
sandbox_provider = get_sandbox_provider()
sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
@@ -213,22 +194,22 @@ async def upload_files(
continue
try:
original_filename = normalize_filename(file.filename)
safe_filename = claim_unique_filename(original_filename, seen_filenames)
safe_filename = normalize_filename(file.filename)
except ValueError:
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
continue
try:
file_path, file_size, total_size = await _write_upload_file_with_limits(
file_path = uploads_dir / safe_filename
written_paths.append(file_path)
file_size, total_size = await _write_upload_file_streaming(
file,
uploads_dir=uploads_dir,
file_path,
display_filename=safe_filename,
max_single_file_size=limits.max_file_size,
max_total_size=limits.max_total_size,
total_size=total_size,
)
written_paths.append(file_path)
virtual_path = upload_virtual_path(safe_filename)
@@ -242,8 +223,6 @@ async def upload_files(
"virtual_path": virtual_path,
"artifact_url": upload_artifact_url(thread_id, safe_filename),
}
if safe_filename != original_filename:
file_info["original_filename"] = original_filename
logger.info(f"Saved file: {safe_filename} ({file_size} bytes) to {file_info['path']}")
@@ -267,10 +246,6 @@ async def upload_files(
except HTTPException as e:
_cleanup_uploaded_paths(written_paths)
raise e
except UnsafeUploadPathError as e:
logger.warning("Skipping upload with unsafe destination %s: %s", file.filename, e)
skipped_files.append(safe_filename)
continue
except Exception as e:
logger.error(f"Failed to upload {file.filename}: {e}")
_cleanup_uploaded_paths(written_paths)
@@ -281,15 +256,10 @@ async def upload_files(
_make_file_sandbox_writable(file_path)
sandbox.update_file(virtual_path, file_path.read_bytes())
message = f"Successfully uploaded {len(uploaded_files)} file(s)"
if skipped_files:
message += f"; skipped {len(skipped_files)} unsafe file(s)"
return UploadResponse(
success=not skipped_files,
success=True,
files=uploaded_files,
message=message,
skipped_files=skipped_files,
message=f"Successfully uploaded {len(uploaded_files)} file(s)",
)
-19
View File
@@ -136,24 +136,6 @@ def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, An
runtime_context.setdefault(key, context[key])
def inject_authenticated_user_context(config: dict[str, Any], request: Request) -> None:
"""Stamp the authenticated user into the run context for background tools.
Tool execution may happen after the request handler has returned, so tools
that persist user-scoped files should not rely only on ambient ContextVars.
The value comes from server-side auth state, never from client context.
"""
user = getattr(request.state, "user", None)
user_id = getattr(user, "id", None)
if user_id is None:
return
runtime_context = config.setdefault("context", {})
if isinstance(runtime_context, dict):
runtime_context["user_id"] = str(user_id)
def resolve_agent_factory(assistant_id: str | None):
"""Resolve the agent factory callable from config.
@@ -306,7 +288,6 @@ async def start_run(
# that carries agent configuration (model_name, thinking_enabled, etc.).
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
merge_run_context_overrides(config, getattr(body, "context", None))
inject_authenticated_user_context(config, request)
stream_modes = normalize_stream_modes(body.stream_mode)
-20
View File
@@ -79,9 +79,7 @@ async def main():
from langgraph.runtime import Runtime
from deerflow.agents import make_lead_agent
from deerflow.config.paths import get_paths
from deerflow.mcp import initialize_mcp_tools
from deerflow.runtime.user_context import get_effective_user_id
# Initialize MCP tools at startup
try:
@@ -115,8 +113,6 @@ async def main():
print("Tip: `uv sync --group dev` to enable arrow-key & history support")
print("=" * 50)
seen_artifacts: set[str] = set()
while True:
try:
if session:
@@ -138,22 +134,6 @@ async def main():
last_message = result["messages"][-1]
print(f"\nAgent: {last_message.content}")
# Show files presented to the user this turn (new artifacts only)
artifacts = result.get("artifacts") or []
new_artifacts = [p for p in artifacts if p not in seen_artifacts]
if new_artifacts:
thread_id = config["configurable"]["thread_id"]
user_id = get_effective_user_id()
paths = get_paths()
print("\n[Presented files]")
for virtual in new_artifacts:
try:
physical = paths.resolve_virtual_path(thread_id, virtual, user_id=user_id)
print(f" - {virtual}\n{physical}")
except ValueError as exc:
print(f" - {virtual} (failed to resolve physical path: {exc})")
seen_artifacts.update(new_artifacts)
except (KeyboardInterrupt, EOFError):
print("\nGoodbye!")
break
@@ -173,7 +173,7 @@ def _assemble_from_features(
9. MemoryMiddleware (memory feature)
10. ViewImageMiddleware (vision feature)
11. SubagentLimitMiddleware (subagent feature)
12. LoopDetectionMiddleware (loop_detection feature)
12. LoopDetectionMiddleware (always)
13. ClarificationMiddleware (always last)
Two-phase ordering:
@@ -272,15 +272,10 @@ def _assemble_from_features(
extra_tools.append(task_tool)
# --- [12] LoopDetection ---
if feat.loop_detection is not False:
if isinstance(feat.loop_detection, AgentMiddleware):
chain.append(feat.loop_detection)
else:
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.config.loop_detection_config import LoopDetectionConfig
# --- [12] LoopDetection (always) ---
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
chain.append(LoopDetectionMiddleware.from_config(LoopDetectionConfig()))
chain.append(LoopDetectionMiddleware())
# --- [13] Clarification (always last among built-ins) ---
chain.append(ClarificationMiddleware())
@@ -31,7 +31,6 @@ class RuntimeFeatures:
vision: bool | AgentMiddleware = False
auto_title: bool | AgentMiddleware = False
guardrail: Literal[False] | AgentMiddleware = False
loop_detection: bool | AgentMiddleware = True
# ---------------------------------------------------------------------------
@@ -20,8 +20,6 @@ from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import load_agent_config, validate_agent_name
from deerflow.config.app_config import AppConfig, get_app_config
from deerflow.models import create_chat_model
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
from deerflow.skills.types import Skill
logger = logging.getLogger(__name__)
@@ -258,12 +256,6 @@ def _build_middlewares(
resolved_app_config = app_config or get_app_config()
middlewares = build_lead_runtime_middlewares(app_config=resolved_app_config, lazy_init=True)
# Always inject current date (and optionally memory) as <system-reminder> into the
# first HumanMessage to keep the system prompt fully static for prefix-cache reuse.
from deerflow.agents.middlewares.dynamic_context_middleware import DynamicContextMiddleware
middlewares.append(DynamicContextMiddleware(agent_name=agent_name, app_config=resolved_app_config))
# Add summarization middleware if enabled
summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config)
if summarization_middleware is not None:
@@ -305,9 +297,7 @@ def _build_middlewares(
middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents))
# LoopDetectionMiddleware — detect and break repetitive tool call loops
loop_detection_config = resolved_app_config.loop_detection
if loop_detection_config.enabled:
middlewares.append(LoopDetectionMiddleware.from_config(loop_detection_config))
middlewares.append(LoopDetectionMiddleware())
# Inject custom middlewares before ClarificationMiddleware
if custom_middlewares:
@@ -318,28 +308,6 @@ def _build_middlewares(
return middlewares
def _available_skill_names(agent_config, is_bootstrap: bool) -> set[str] | None:
if is_bootstrap:
return {"bootstrap"}
if agent_config and agent_config.skills is not None:
return set(agent_config.skills)
return None
def _load_enabled_skills_for_tool_policy(available_skills: set[str] | None, *, app_config: AppConfig) -> list[Skill]:
try:
from deerflow.agents.lead_agent.prompt import get_enabled_skills_for_config
skills = get_enabled_skills_for_config(app_config)
except Exception:
logger.exception("Failed to load skills for allowed-tools policy")
raise
if available_skills is None:
return skills
return [skill for skill in skills if skill.name in available_skills]
def make_lead_agent(config: RunnableConfig):
"""LangGraph graph factory; keep the signature compatible with LangGraph Server."""
runtime_config = _get_runtime_config(config)
@@ -350,7 +318,7 @@ def make_lead_agent(config: RunnableConfig):
def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
# Lazy import to avoid circular dependency
from deerflow.tools import get_available_tools
from deerflow.tools.builtins import setup_agent, update_agent
from deerflow.tools.builtins import setup_agent
cfg = _get_runtime_config(config)
resolved_app_config = app_config
@@ -365,7 +333,6 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
agent_name = validate_agent_name(cfg.get("agent_name"))
agent_config = load_agent_config(agent_name) if not is_bootstrap else None
available_skills = _available_skill_names(agent_config, is_bootstrap)
# Custom agent model from agent config (if any), or None to let _resolve_model_name pick the default
agent_model_name = agent_config.model if agent_config and agent_config.model else None
@@ -404,18 +371,15 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
"is_plan_mode": is_plan_mode,
"subagent_enabled": subagent_enabled,
"tool_groups": agent_config.tool_groups if agent_config else None,
"available_skills": sorted(available_skills) if available_skills is not None else None,
"available_skills": ["bootstrap"] if is_bootstrap else (agent_config.skills if agent_config and agent_config.skills is not None else None),
}
)
skills_for_tool_policy = _load_enabled_skills_for_tool_policy(available_skills, app_config=resolved_app_config)
if is_bootstrap:
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
tools = get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent]
return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config),
tools=filter_tools_by_skill_allowed_tools(tools, skills_for_tool_policy),
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent],
middleware=_build_middlewares(config, model_name=model_name, app_config=resolved_app_config),
system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled,
@@ -426,14 +390,15 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
state_schema=ThreadState,
)
# Custom agents can update their own SOUL.md / config via update_agent.
# The default agent (no agent_name) does not see this tool.
extra_tools = [update_agent] if agent_name else []
# Default lead agent (unchanged behavior)
tools = get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled, app_config=resolved_app_config)
return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config),
tools=filter_tools_by_skill_allowed_tools(tools + extra_tools, skills_for_tool_policy),
tools=get_available_tools(
model_name=model_name,
groups=agent_config.tool_groups if agent_config else None,
subagent_enabled=subagent_enabled,
app_config=resolved_app_config,
),
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name, app_config=resolved_app_config),
system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled,
@@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
import logging
import threading
from datetime import datetime
from functools import lru_cache
from typing import TYPE_CHECKING
@@ -19,7 +20,6 @@ logger = logging.getLogger(__name__)
_ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS = 5.0
_enabled_skills_lock = threading.Lock()
_enabled_skills_cache: list[Skill] | None = None
_enabled_skills_by_config_cache: dict[int, tuple[object, list[Skill]]] = {}
_enabled_skills_refresh_active = False
_enabled_skills_refresh_version = 0
_enabled_skills_refresh_event = threading.Event()
@@ -84,7 +84,6 @@ def _invalidate_enabled_skills_cache() -> threading.Event:
_get_cached_skills_prompt_section.cache_clear()
with _enabled_skills_lock:
_enabled_skills_cache = None
_enabled_skills_by_config_cache.clear()
_enabled_skills_refresh_version += 1
_enabled_skills_refresh_event.clear()
if _enabled_skills_refresh_active:
@@ -108,15 +107,6 @@ def warm_enabled_skills_cache(timeout_seconds: float = _ENABLED_SKILLS_REFRESH_W
def _get_enabled_skills():
return get_cached_enabled_skills()
def get_cached_enabled_skills() -> list[Skill]:
"""Return the cached enabled-skills list, kicking off a background refresh on miss.
Safe to call from request paths: never blocks on disk I/O. Returns an empty
list on cache miss; the next call will see the warmed result.
"""
with _enabled_skills_lock:
cached = _enabled_skills_cache
@@ -127,29 +117,17 @@ def get_cached_enabled_skills() -> list[Skill]:
return []
def get_enabled_skills_for_config(app_config: AppConfig | None = None) -> list[Skill]:
def _get_enabled_skills_for_config(app_config: AppConfig | None = None) -> list[Skill]:
"""Return enabled skills using the caller's config source.
When a concrete ``app_config`` is supplied, cache the loaded skills by that
config object's identity so request-scoped config injection still resolves
skill paths from the matching config without rescanning storage on every
agent factory call.
When a concrete ``app_config`` is supplied, bypass the global enabled-skills
cache so the skill list and skill paths are resolved from the same config
object. This keeps request-scoped config injection consistent even while the
release branch still supports global fallback paths.
"""
if app_config is None:
return _get_enabled_skills()
cache_key = id(app_config)
with _enabled_skills_lock:
cached = _enabled_skills_by_config_cache.get(cache_key)
if cached is not None:
cached_config, cached_skills = cached
if cached_config is app_config:
return list(cached_skills)
skills = list(get_or_new_skill_storage(app_config=app_config).load_skills(enabled_only=True))
with _enabled_skills_lock:
_enabled_skills_by_config_cache[cache_key] = (app_config, skills)
return list(skills)
return list(get_or_new_skill_storage(app_config=app_config).load_skills(enabled_only=True))
def _skill_mutability_label(category: SkillCategory | str) -> str:
@@ -366,7 +344,8 @@ You are {agent_name}, an open-source super agent.
</role>
{soul}
{self_update_section}
{memory_context}
<thinking_style>
- Think concisely and strategically about the user's request BEFORE taking action
- Break down the task: What is clear? What is ambiguous? What is missing?
@@ -625,7 +604,7 @@ You have access to skills that provide optimized workflows for specific tasks. E
def get_skills_prompt_section(available_skills: set[str] | None = None, *, app_config: AppConfig | None = None) -> str:
"""Generate the skills prompt section with available skills list."""
skills = get_enabled_skills_for_config(app_config)
skills = _get_enabled_skills_for_config(app_config)
if app_config is None:
try:
@@ -664,26 +643,6 @@ def get_agent_soul(agent_name: str | None) -> str:
return ""
def _build_self_update_section(agent_name: str | None) -> str:
"""Prompt block that teaches the custom agent to persist self-updates via update_agent."""
if not agent_name:
return ""
return f"""<self_update>
You are running as the custom agent **{agent_name}** with a persisted SOUL.md and config.yaml.
When the user asks you to update your own description, personality, behaviour, skill set, tool groups, or default model,
you MUST persist the change with the `update_agent` tool. Do NOT use `bash`, `write_file`, or any sandbox tool to edit
SOUL.md or config.yaml — those write into a temporary sandbox/tool workspace and the changes will be lost on the next turn.
Rules:
- Always pass the FULL replacement text for `soul` (no patch semantics). Start from your current SOUL above and apply the user's edits.
- Only pass the fields that should change. Omit the others to preserve them.
- Pass `skills=[]` to disable all skills, or omit `skills` to keep the existing whitelist.
- After `update_agent` returns successfully, tell the user the change is persisted and will take effect on the next turn.
</self_update>
"""
def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) -> str:
"""Generate <available-deferred-tools> block for the system prompt.
@@ -773,6 +732,9 @@ def apply_prompt_template(
available_skills: set[str] | None = None,
app_config: AppConfig | None = None,
) -> str:
# Get memory context
memory_context = _get_memory_context(agent_name, app_config=app_config)
# Include subagent section only if enabled (from runtime parameter)
n = max_concurrent_subagents
subagent_section = _build_subagent_section(n, app_config=app_config) if subagent_enabled else ""
@@ -806,18 +768,17 @@ def apply_prompt_template(
custom_mounts_section = _build_custom_mounts_section(app_config=app_config)
acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)
# Build and return the fully static system prompt.
# Memory and current date are injected per-turn via DynamicContextMiddleware
# as a <system-reminder> in the first HumanMessage, keeping this prompt
# identical across users and sessions for maximum prefix-cache reuse.
return SYSTEM_PROMPT_TEMPLATE.format(
# Format the prompt with dynamic skills and memory
prompt = SYSTEM_PROMPT_TEMPLATE.format(
agent_name=agent_name or "DeerFlow 2.0",
soul=get_agent_soul(agent_name),
self_update_section=_build_self_update_section(agent_name),
skills_section=skills_section,
deferred_tools_section=deferred_tools_section,
memory_context=memory_context,
subagent_section=subagent_section,
subagent_reminder=subagent_reminder,
subagent_thinking=subagent_thinking,
acp_section=acp_and_mounts_section,
)
return prompt + f"\n<current_date>{datetime.now().strftime('%Y-%m-%d, %A')}</current_date>"
@@ -1,204 +0,0 @@
"""Middleware to inject dynamic context (memory, current date) as a system-reminder.
The system prompt is kept fully static for maximum prefix-cache reuse across users
and sessions. The current date is always injected. Per-user memory is also injected
when ``memory.injection_enabled`` is True in the app config. Both are delivered once
per conversation as a dedicated <system-reminder> HumanMessage inserted before the
first user message (frozen-snapshot pattern).
When a conversation spans midnight the middleware detects the date change and injects
a lightweight date-update reminder as a separate HumanMessage before the current turn.
This correction is persisted so subsequent turns on the new day see a consistent history
and do not re-inject.
Reminder format:
<system-reminder>
<memory>...</memory>
<current_date>2026-05-08, Friday</current_date>
</system-reminder>
Date-update format:
<system-reminder>
<current_date>2026-05-09, Saturday</current_date>
</system-reminder>
"""
from __future__ import annotations
import logging
import re
import uuid
from datetime import datetime
from typing import TYPE_CHECKING, override
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
_DATE_RE = re.compile(r"<current_date>([^<]+)</current_date>")
_DYNAMIC_CONTEXT_REMINDER_KEY = "dynamic_context_reminder"
_SUMMARY_MESSAGE_NAME = "summary"
def _extract_date(content: str) -> str | None:
"""Return the first <current_date> value found in *content*, or None."""
m = _DATE_RE.search(content)
return m.group(1) if m else None
def is_dynamic_context_reminder(message: object) -> bool:
"""Return whether *message* is a hidden dynamic-context reminder."""
return isinstance(message, HumanMessage) and bool(message.additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY))
def _last_injected_date(messages: list) -> str | None:
"""Scan messages in reverse and return the most recently injected date.
Detection uses the ``dynamic_context_reminder`` additional_kwargs flag rather
than content substring matching, so user messages containing ``<system-reminder>``
are not mistakenly treated as injected reminders.
"""
for msg in reversed(messages):
if is_dynamic_context_reminder(msg):
content_str = msg.content if isinstance(msg.content, str) else str(msg.content)
return _extract_date(content_str)
return None
def _is_user_injection_target(message: object) -> bool:
"""Return whether *message* can receive a dynamic-context reminder."""
return isinstance(message, HumanMessage) and not is_dynamic_context_reminder(message) and message.name != _SUMMARY_MESSAGE_NAME
class DynamicContextMiddleware(AgentMiddleware):
"""Inject memory and current date into HumanMessages as a <system-reminder>.
First turn
----------
Prepends a full system-reminder (memory + date) to the first HumanMessage and
persists it (same message ID). The first message is then frozen for the whole
session — its content never changes again, so the prefix cache can hit on every
subsequent turn.
Midnight crossing
-----------------
If the conversation spans midnight, the current date differs from the date that
was injected earlier. In that case a lightweight date-update reminder is prepended
to the **current** (last) HumanMessage and persisted. Subsequent turns on the new
day see the corrected date in history and skip re-injection.
"""
def __init__(self, agent_name: str | None = None, *, app_config: AppConfig | None = None):
super().__init__()
self._agent_name = agent_name
self._app_config = app_config
def _build_full_reminder(self) -> str:
from deerflow.agents.lead_agent.prompt import _get_memory_context
# Memory injection is gated by injection_enabled; date is always included.
injection_enabled = self._app_config.memory.injection_enabled if self._app_config else True
memory_context = _get_memory_context(self._agent_name, app_config=self._app_config) if injection_enabled else ""
current_date = datetime.now().strftime("%Y-%m-%d, %A")
lines: list[str] = ["<system-reminder>"]
if memory_context:
lines.append(memory_context.strip())
lines.append("") # blank line separating memory from date
lines.append(f"<current_date>{current_date}</current_date>")
lines.append("</system-reminder>")
return "\n".join(lines)
def _build_date_update_reminder(self) -> str:
current_date = datetime.now().strftime("%Y-%m-%d, %A")
return "\n".join(
[
"<system-reminder>",
f"<current_date>{current_date}</current_date>",
"</system-reminder>",
]
)
@staticmethod
def _make_reminder_and_user_messages(original: HumanMessage, reminder_content: str) -> tuple[HumanMessage, HumanMessage]:
"""Return (reminder_msg, user_msg) using the ID-swap technique.
reminder_msg takes the original message's ID so that add_messages replaces it
in-place (preserving position). user_msg carries the original content with a
derived ``{id}__user`` ID and is appended immediately after by add_messages.
If the original message has no ID a stable UUID is generated so the derived
``{id}__user`` ID never collapses to the ambiguous ``None__user`` string.
"""
stable_id = original.id or str(uuid.uuid4())
reminder_msg = HumanMessage(
content=reminder_content,
id=stable_id,
additional_kwargs={"hide_from_ui": True, _DYNAMIC_CONTEXT_REMINDER_KEY: True},
)
user_msg = HumanMessage(
content=original.content,
id=f"{stable_id}__user",
name=original.name,
additional_kwargs=original.additional_kwargs,
)
return reminder_msg, user_msg
def _inject(self, state) -> dict | None:
messages = list(state.get("messages", []))
if not messages:
return None
current_date = datetime.now().strftime("%Y-%m-%d, %A")
last_date = _last_injected_date(messages)
logger.debug(
"DynamicContextMiddleware._inject: msg_count=%d last_date=%r current_date=%r",
len(messages),
last_date,
current_date,
)
if last_date is None:
# ── First turn: inject full reminder as a separate HumanMessage ─────
first_idx = next((i for i, m in enumerate(messages) if _is_user_injection_target(m)), None)
if first_idx is None:
return None
full_reminder = self._build_full_reminder()
logger.info(
"DynamicContextMiddleware: injecting full reminder (len=%d, has_memory=%s) into first HumanMessage id=%r",
len(full_reminder),
"<memory>" in full_reminder,
messages[first_idx].id,
)
reminder_msg, user_msg = self._make_reminder_and_user_messages(messages[first_idx], full_reminder)
return {"messages": [reminder_msg, user_msg]}
if last_date == current_date:
# ── Same day: nothing to do ──────────────────────────────────────────
return None
# ── Midnight crossed: inject date-update reminder as a separate HumanMessage ──
last_human_idx = next((i for i in reversed(range(len(messages))) if _is_user_injection_target(messages[i])), None)
if last_human_idx is None:
return None
reminder_msg, user_msg = self._make_reminder_and_user_messages(messages[last_human_idx], self._build_date_update_reminder())
logger.info("DynamicContextMiddleware: midnight crossing detected — injected date update before current turn")
return {"messages": [reminder_msg, user_msg]}
@override
def before_agent(self, state, runtime: Runtime) -> dict | None:
return self._inject(state)
@override
async def abefore_agent(self, state, runtime: Runtime) -> dict | None:
return self._inject(state)
@@ -12,23 +12,19 @@ Detection strategy:
response so the agent is forced to produce a final text answer.
"""
from __future__ import annotations
import hashlib
import json
import logging
import threading
from collections import OrderedDict, defaultdict
from copy import deepcopy
from typing import TYPE_CHECKING, override
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
if TYPE_CHECKING:
from deerflow.config.loop_detection_config import LoopDetectionConfig
logger = logging.getLogger(__name__)
# Defaults — can be overridden via constructor
@@ -144,9 +140,6 @@ _TOOL_FREQ_HARD_STOP_MSG = "[FORCED STOP] Tool {tool_name} called {count} times
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
"""Detects and breaks repetitive tool call loops.
Threshold parameters are validated upstream by :class:`LoopDetectionConfig`;
construct via :meth:`from_config` to ensure values pass Pydantic validation.
Args:
warn_threshold: Number of identical tool call sets before injecting
a warning message. Default: 3.
@@ -162,14 +155,6 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
Default: 30.
tool_freq_hard_limit: Number of calls to the same tool type before
forcing a stop. Default: 50.
tool_freq_overrides: Per-tool overrides for frequency thresholds,
keyed by tool name. Each value is a ``(warn, hard_limit)`` tuple
that replaces ``tool_freq_warn`` / ``tool_freq_hard_limit`` for
that specific tool. Tools not listed here fall back to the global
thresholds. Useful for raising limits on intentionally
high-frequency tools (e.g. ``bash`` in batch pipelines) without
weakening protection on all other tools. Default: ``None``
(no overrides).
"""
def __init__(
@@ -180,7 +165,6 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN,
tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT,
tool_freq_overrides: dict[str, tuple[int, int]] | None = None,
):
super().__init__()
self.warn_threshold = warn_threshold
@@ -189,26 +173,14 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
self.max_tracked_threads = max_tracked_threads
self.tool_freq_warn = tool_freq_warn
self.tool_freq_hard_limit = tool_freq_hard_limit
self._tool_freq_overrides: dict[str, tuple[int, int]] = tool_freq_overrides or {}
self._lock = threading.Lock()
# Per-thread tracking using OrderedDict for LRU eviction
self._history: OrderedDict[str, list[str]] = OrderedDict()
self._warned: dict[str, set[str]] = defaultdict(set)
# Per-thread, per-tool-type cumulative call counts
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
@classmethod
def from_config(cls, config: LoopDetectionConfig) -> LoopDetectionMiddleware:
"""Construct from a Pydantic-validated config, trusting its validation."""
return cls(
warn_threshold=config.warn_threshold,
hard_limit=config.hard_limit,
window_size=config.window_size,
max_tracked_threads=config.max_tracked_threads,
tool_freq_warn=config.tool_freq_warn,
tool_freq_hard_limit=config.tool_freq_hard_limit,
tool_freq_overrides={name: (o.warn, o.hard_limit) for name, o in config.tool_freq_overrides.items()},
)
def _get_thread_id(self, runtime: Runtime) -> str:
"""Extract thread_id from runtime context for per-thread tracking."""
thread_id = runtime.context.get("thread_id") if runtime.context else None
@@ -308,12 +280,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
freq[name] += 1
tc_count = freq[name]
if name in self._tool_freq_overrides:
eff_warn, eff_hard = self._tool_freq_overrides[name]
else:
eff_warn, eff_hard = self.tool_freq_warn, self.tool_freq_hard_limit
if tc_count >= eff_hard:
if tc_count >= self.tool_freq_hard_limit:
logger.error(
"Tool frequency hard limit reached — forcing stop",
extra={
@@ -324,7 +291,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
)
return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True
if tc_count >= eff_warn:
if tc_count >= self.tool_freq_warn:
warned = self._tool_freq_warned[thread_id]
if name not in warned:
warned.add(name)
@@ -389,30 +356,13 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
return {"messages": [stripped_msg]}
if warning:
# WORKAROUND for v2.0-m1 — see #2724.
#
# Append the warning to the AIMessage content instead of
# injecting a separate HumanMessage. Inserting any non-tool
# message between an AIMessage(tool_calls=...) and its
# ToolMessage responses breaks OpenAI/Moonshot strict pairing
# validation ("tool_call_ids did not have response messages")
# because the tools node has not run yet at after_model time.
# tool_calls are preserved so the tools node still executes.
#
# This is a temporary mitigation: mutating an existing
# AIMessage to carry framework-authored text leaks loop-warning
# text into downstream consumers (MemoryMiddleware fact
# extraction, TitleMiddleware, telemetry, model replay) as if
# the model said it. The proper fix is to defer warning
# injection from after_model to wrap_model_call so every prior
# ToolMessage is already in the request — see RFC #2517 (which
# lists "loop intervention does not leave invalid
# tool-call/tool-message state" as acceptance criteria) and
# the prototype on `fix/loop-detection-tool-call-pairing`.
messages = state.get("messages", [])
last_msg = messages[-1]
patched_msg = last_msg.model_copy(update={"content": self._append_text(last_msg.content, warning)})
return {"messages": [patched_msg]}
# Inject as HumanMessage instead of SystemMessage to avoid
# Anthropic's "multiple non-consecutive system messages" error.
# Anthropic models require system messages only at the start of
# the conversation; injecting one mid-conversation crashes
# langchain_anthropic's _format_messages(). HumanMessage works
# with all providers. See #1299.
return {"messages": [HumanMessage(content=warning, name="loop_warning")]}
return None
@@ -7,7 +7,6 @@ from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
from deerflow.subagents.executor import MAX_CONCURRENT_SUBAGENTS
logger = logging.getLogger(__name__)
@@ -64,7 +63,7 @@ class SubagentLimitMiddleware(AgentMiddleware[AgentState]):
logger.warning(f"Truncated {dropped_count} excess task tool call(s) from model response (limit: {self.max_concurrent})")
# Replace the AIMessage with truncated tool_calls (same id triggers replacement)
updated_msg = clone_ai_message_with_tool_calls(last_msg, truncated_tool_calls)
updated_msg = last_msg.model_copy(update={"tool_calls": truncated_tool_calls})
return {"messages": [updated_msg]}
@override
@@ -14,9 +14,6 @@ from langgraph.config import get_config
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
from deerflow.agents.middlewares.dynamic_context_middleware import is_dynamic_context_reminder
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
logger = logging.getLogger(__name__)
@@ -81,7 +78,10 @@ def _clone_ai_message(
content: Any | None = None,
) -> AIMessage:
"""Clone an AIMessage while replacing its tool_calls list and optional content."""
return clone_ai_message_with_tool_calls(message, tool_calls, content=content)
update: dict[str, Any] = {"tool_calls": tool_calls}
if content is not None:
update["content"] = content
return message.model_copy(update=update)
@dataclass
@@ -136,7 +136,6 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
return None
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
messages_to_summarize, preserved_messages = self._preserve_dynamic_context_reminders(messages_to_summarize, preserved_messages)
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
summary = self._create_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)
@@ -162,7 +161,6 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
return None
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
messages_to_summarize, preserved_messages = self._preserve_dynamic_context_reminders(messages_to_summarize, preserved_messages)
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
summary = await self._acreate_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)
@@ -182,24 +180,6 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
"""
return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
def _preserve_dynamic_context_reminders(
self,
messages_to_summarize: list[AnyMessage],
preserved_messages: list[AnyMessage],
) -> tuple[list[AnyMessage], list[AnyMessage]]:
"""Keep hidden dynamic-context reminders out of summary compression.
These reminders carry the current date and optional memory. If summarization
removes them, DynamicContextMiddleware can mistake the summary HumanMessage
for the first user message and inject the reminder in the wrong place.
"""
reminders = [msg for msg in messages_to_summarize if is_dynamic_context_reminder(msg)]
if not reminders:
return messages_to_summarize, preserved_messages
remaining = [msg for msg in messages_to_summarize if not is_dynamic_context_reminder(msg)]
return remaining, reminders + preserved_messages
def _partition_with_skill_rescue(
self,
messages: list[AnyMessage],
@@ -9,7 +9,6 @@ from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.agents.middlewares.dynamic_context_middleware import is_dynamic_context_reminder
from deerflow.config.title_config import get_title_config
from deerflow.models import create_chat_model
@@ -62,10 +61,6 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
return ""
@staticmethod
def _is_user_message_for_title(message: object) -> bool:
return getattr(message, "type", None) == "human" and not is_dynamic_context_reminder(message)
def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
"""Check if we should generate a title for this thread."""
config = self._get_title_config()
@@ -82,7 +77,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
return False
# Count user and assistant messages
user_messages = [m for m in messages if self._is_user_message_for_title(m)]
user_messages = [m for m in messages if m.type == "human"]
assistant_messages = [m for m in messages if m.type == "ai"]
# Generate title after first complete exchange
@@ -96,7 +91,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
config = self._get_title_config()
messages = state.get("messages", [])
user_msg_content = next((m.content for m in messages if self._is_user_message_for_title(m)), "")
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")
user_msg = self._normalize_content(user_msg_content)
@@ -1,303 +1,37 @@
"""Middleware for logging token usage and annotating step attribution."""
from __future__ import annotations
"""Middleware for logging LLM token usage."""
import logging
from collections import defaultdict
from typing import Any, override
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.todo import Todo
from langchain_core.messages import AIMessage
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
TOKEN_USAGE_ATTRIBUTION_KEY = "token_usage_attribution"
def _string_arg(value: Any) -> str | None:
if isinstance(value, str):
normalized = value.strip()
return normalized or None
return None
def _normalize_todos(value: Any) -> list[Todo]:
if not isinstance(value, list):
return []
normalized: list[Todo] = []
for item in value:
if not isinstance(item, dict):
continue
todo: Todo = {}
content = _string_arg(item.get("content"))
status = item.get("status")
if content is not None:
todo["content"] = content
if status in {"pending", "in_progress", "completed"}:
todo["status"] = status
normalized.append(todo)
return normalized
def _todo_action_kind(previous: Todo | None, current: Todo) -> str:
status = current.get("status")
previous_content = previous.get("content") if previous else None
current_content = current.get("content")
if previous is None:
if status == "completed":
return "todo_complete"
if status == "in_progress":
return "todo_start"
return "todo_update"
if previous_content != current_content:
return "todo_update"
if status == "completed":
return "todo_complete"
if status == "in_progress":
return "todo_start"
return "todo_update"
def _build_todo_actions(previous_todos: list[Todo], next_todos: list[Todo]) -> list[dict[str, Any]]:
# This is the single source of truth for precise write_todos token
# attribution. The frontend intentionally falls back to a generic
# "Update to-do list" label when this metadata is missing or malformed.
previous_by_content: dict[str, list[tuple[int, Todo]]] = defaultdict(list)
matched_previous_indices: set[int] = set()
for index, todo in enumerate(previous_todos):
content = todo.get("content")
if isinstance(content, str) and content:
previous_by_content[content].append((index, todo))
actions: list[dict[str, Any]] = []
for index, todo in enumerate(next_todos):
content = todo.get("content")
if not isinstance(content, str) or not content:
continue
previous_match: Todo | None = None
content_matches = previous_by_content.get(content)
if content_matches:
while content_matches and content_matches[0][0] in matched_previous_indices:
content_matches.pop(0)
if content_matches:
previous_index, previous_match = content_matches.pop(0)
matched_previous_indices.add(previous_index)
if previous_match is None and index < len(previous_todos) and index not in matched_previous_indices:
previous_match = previous_todos[index]
matched_previous_indices.add(index)
if previous_match is not None:
previous_content = previous_match.get("content")
previous_status = previous_match.get("status")
if previous_content == content and previous_status == todo.get("status"):
continue
actions.append(
{
"kind": _todo_action_kind(previous_match, todo),
"content": content,
}
)
for index, todo in enumerate(previous_todos):
if index in matched_previous_indices:
continue
content = todo.get("content")
if not isinstance(content, str) or not content:
continue
actions.append(
{
"kind": "todo_remove",
"content": content,
}
)
return actions
def _describe_tool_call(tool_call: dict[str, Any], todos: list[Todo]) -> list[dict[str, Any]]:
name = _string_arg(tool_call.get("name")) or "unknown"
args = tool_call.get("args") if isinstance(tool_call.get("args"), dict) else {}
tool_call_id = _string_arg(tool_call.get("id"))
if name == "write_todos":
next_todos = _normalize_todos(args.get("todos"))
actions = _build_todo_actions(todos, next_todos)
if not actions:
return [
{
"kind": "tool",
"tool_name": name,
"tool_call_id": tool_call_id,
}
]
return [
{
**action,
"tool_call_id": tool_call_id,
}
for action in actions
]
if name == "task":
return [
{
"kind": "subagent",
"description": _string_arg(args.get("description")),
"subagent_type": _string_arg(args.get("subagent_type")),
"tool_call_id": tool_call_id,
}
]
if name in {"web_search", "image_search"}:
query = _string_arg(args.get("query"))
return [
{
"kind": "search",
"tool_name": name,
"query": query,
"tool_call_id": tool_call_id,
}
]
if name == "present_files":
return [
{
"kind": "present_files",
"tool_call_id": tool_call_id,
}
]
if name == "ask_clarification":
return [
{
"kind": "clarification",
"tool_call_id": tool_call_id,
}
]
return [
{
"kind": "tool",
"tool_name": name,
"description": _string_arg(args.get("description")),
"tool_call_id": tool_call_id,
}
]
def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str:
if actions:
first_kind = actions[0].get("kind")
if len(actions) == 1 and first_kind in {"todo_start", "todo_complete", "todo_update", "todo_remove"}:
return "todo_update"
if len(actions) == 1 and first_kind == "subagent":
return "subagent_dispatch"
return "tool_batch"
if message.content:
return "final_answer"
return "thinking"
def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
tool_calls = getattr(message, "tool_calls", None) or []
actions: list[dict[str, Any]] = []
current_todos = list(todos)
for raw_tool_call in tool_calls:
if not isinstance(raw_tool_call, dict):
continue
described_actions = _describe_tool_call(raw_tool_call, current_todos)
actions.extend(described_actions)
if raw_tool_call.get("name") == "write_todos":
args = raw_tool_call.get("args") if isinstance(raw_tool_call.get("args"), dict) else {}
current_todos = _normalize_todos(args.get("todos"))
tool_call_ids: list[str] = []
for tool_call in tool_calls:
if not isinstance(tool_call, dict):
continue
tool_call_id = _string_arg(tool_call.get("id"))
if tool_call_id is not None:
tool_call_ids.append(tool_call_id)
return {
# Schema changes should remain additive where possible so older
# frontends can ignore unknown fields and fall back safely.
"version": 1,
"kind": _infer_step_kind(message, actions),
"shared_attribution": len(actions) > 1,
"tool_call_ids": tool_call_ids,
"actions": actions,
}
class TokenUsageMiddleware(AgentMiddleware):
"""Logs token usage from model responses and annotates the AI step."""
def _apply(self, state: AgentState) -> dict | None:
messages = state.get("messages", [])
if not messages:
return None
last = messages[-1]
if not isinstance(last, AIMessage):
return None
usage = getattr(last, "usage_metadata", None)
if usage:
input_token_details = usage.get("input_token_details") or {}
output_token_details = usage.get("output_token_details") or {}
detail_parts = []
if input_token_details:
detail_parts.append(f"input_token_details={input_token_details}")
if output_token_details:
detail_parts.append(f"output_token_details={output_token_details}")
detail_suffix = f" {' '.join(detail_parts)}" if detail_parts else ""
logger.info(
"LLM token usage: input=%s output=%s total=%s%s",
usage.get("input_tokens", "?"),
usage.get("output_tokens", "?"),
usage.get("total_tokens", "?"),
detail_suffix,
)
todos = state.get("todos") or []
attribution = _build_attribution(last, todos if isinstance(todos, list) else [])
additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
return None
additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
return {"messages": [updated_msg]}
"""Logs token usage from model response usage_metadata."""
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state)
return self._log_usage(state)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state)
return self._log_usage(state)
def _log_usage(self, state: AgentState) -> None:
messages = state.get("messages", [])
if not messages:
return None
last = messages[-1]
usage = getattr(last, "usage_metadata", None)
if usage:
logger.info(
"LLM token usage: input=%s output=%s total=%s",
usage.get("input_tokens", "?"),
usage.get("output_tokens", "?"),
usage.get("total_tokens", "?"),
)
return None
@@ -1,50 +0,0 @@
"""Helpers for keeping AIMessage tool-call metadata consistent."""
from __future__ import annotations
from typing import Any
from langchain_core.messages import AIMessage
def _raw_tool_call_id(raw_tool_call: Any) -> str | None:
if not isinstance(raw_tool_call, dict):
return None
raw_id = raw_tool_call.get("id")
return raw_id if isinstance(raw_id, str) and raw_id else None
def clone_ai_message_with_tool_calls(
message: AIMessage,
tool_calls: list[dict[str, Any]],
*,
content: Any | None = None,
) -> AIMessage:
"""Clone an AIMessage while keeping raw provider tool-call metadata in sync."""
kept_ids = {tc["id"] for tc in tool_calls if isinstance(tc.get("id"), str) and tc["id"]}
update: dict[str, Any] = {"tool_calls": tool_calls}
if content is not None:
update["content"] = content
additional_kwargs = dict(getattr(message, "additional_kwargs", {}) or {})
raw_tool_calls = additional_kwargs.get("tool_calls")
if isinstance(raw_tool_calls, list):
synced_raw_tool_calls = [raw_tc for raw_tc in raw_tool_calls if _raw_tool_call_id(raw_tc) in kept_ids]
if synced_raw_tool_calls:
additional_kwargs["tool_calls"] = synced_raw_tool_calls
else:
additional_kwargs.pop("tool_calls", None)
if not tool_calls:
additional_kwargs.pop("function_call", None)
update["additional_kwargs"] = additional_kwargs
response_metadata = dict(getattr(message, "response_metadata", {}) or {})
if not tool_calls and response_metadata.get("finish_reason") == "tool_calls":
response_metadata["finish_reason"] = "stop"
update["response_metadata"] = response_metadata
return message.model_copy(update=update)
+37 -106
View File
@@ -228,14 +228,21 @@ class DeerFlowClient:
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
kwargs: dict[str, Any] = {
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=self._app_config),
"tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled),
"middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares),
"middleware": _build_middlewares(
config,
model_name=model_name,
agent_name=self._agent_name,
custom_middlewares=self._middlewares,
app_config=self._app_config,
),
"system_prompt": apply_prompt_template(
subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents,
agent_name=self._agent_name,
available_skills=self._available_skills,
app_config=self._app_config,
),
"state_schema": ThreadState,
}
@@ -243,7 +250,7 @@ class DeerFlowClient:
if checkpointer is None:
from deerflow.runtime.checkpointer import get_checkpointer
checkpointer = get_checkpointer()
checkpointer = get_checkpointer(app_config=self._app_config)
if checkpointer is not None:
kwargs["checkpointer"] = checkpointer
@@ -251,12 +258,15 @@ class DeerFlowClient:
self._agent_config_key = key
logger.info("Agent created: agent_name=%s, model=%s, thinking=%s", self._agent_name, model_name, thinking_enabled)
@staticmethod
def _get_tools(*, model_name: str | None, subagent_enabled: bool):
def _get_tools(self, *, model_name: str | None, subagent_enabled: bool):
"""Lazy import to avoid circular dependency at module level."""
from deerflow.tools import get_available_tools
return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled)
return get_available_tools(
model_name=model_name,
subagent_enabled=subagent_enabled,
app_config=self._app_config,
)
@staticmethod
def _serialize_tool_calls(tool_calls) -> list[dict]:
@@ -264,35 +274,25 @@ class DeerFlowClient:
return [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in tool_calls]
@staticmethod
def _serialize_additional_kwargs(msg) -> dict[str, Any] | None:
"""Copy message additional_kwargs when present."""
additional_kwargs = getattr(msg, "additional_kwargs", None)
if isinstance(additional_kwargs, dict) and additional_kwargs:
return dict(additional_kwargs)
return None
@staticmethod
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None, additional_kwargs: dict[str, Any] | None = None) -> "StreamEvent":
"""Build a ``messages-tuple`` AI text event."""
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None) -> "StreamEvent":
"""Build a ``messages-tuple`` AI text event, attaching usage when present."""
data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id}
if usage:
data["usage_metadata"] = usage
if additional_kwargs:
data["additional_kwargs"] = additional_kwargs
return StreamEvent(type="messages-tuple", data=data)
@staticmethod
def _ai_tool_calls_event(msg_id: str | None, tool_calls, additional_kwargs: dict[str, Any] | None = None) -> "StreamEvent":
def _ai_tool_calls_event(msg_id: str | None, tool_calls) -> "StreamEvent":
"""Build a ``messages-tuple`` AI tool-calls event."""
data: dict[str, Any] = {
"type": "ai",
"content": "",
"id": msg_id,
"tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls),
}
if additional_kwargs:
data["additional_kwargs"] = additional_kwargs
return StreamEvent(type="messages-tuple", data=data)
return StreamEvent(
type="messages-tuple",
data={
"type": "ai",
"content": "",
"id": msg_id,
"tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls),
},
)
@staticmethod
def _tool_message_event(msg: ToolMessage) -> "StreamEvent":
@@ -317,30 +317,19 @@ class DeerFlowClient:
d["tool_calls"] = DeerFlowClient._serialize_tool_calls(msg.tool_calls)
if getattr(msg, "usage_metadata", None):
d["usage_metadata"] = msg.usage_metadata
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
d["additional_kwargs"] = additional_kwargs
return d
if isinstance(msg, ToolMessage):
d = {
return {
"type": "tool",
"content": DeerFlowClient._extract_text(msg.content),
"name": getattr(msg, "name", None),
"tool_call_id": getattr(msg, "tool_call_id", None),
"id": getattr(msg, "id", None),
}
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
d["additional_kwargs"] = additional_kwargs
return d
if isinstance(msg, HumanMessage):
d = {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)}
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
d["additional_kwargs"] = additional_kwargs
return d
return {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)}
if isinstance(msg, SystemMessage):
d = {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)}
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
d["additional_kwargs"] = additional_kwargs
return d
return {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)}
return {"type": "unknown", "content": str(msg), "id": getattr(msg, "id", None)}
@staticmethod
@@ -398,7 +387,7 @@ class DeerFlowClient:
if checkpointer is None:
from deerflow.runtime.checkpointer.provider import get_checkpointer
checkpointer = get_checkpointer()
checkpointer = get_checkpointer(app_config=self._app_config)
thread_info_map = {}
@@ -453,7 +442,7 @@ class DeerFlowClient:
if checkpointer is None:
from deerflow.runtime.checkpointer.provider import get_checkpointer
checkpointer = get_checkpointer()
checkpointer = get_checkpointer(app_config=self._app_config)
config = {"configurable": {"thread_id": thread_id}}
checkpoints = []
@@ -563,7 +552,6 @@ class DeerFlowClient:
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str}
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str, "usage_metadata": {...}}
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]}
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "additional_kwargs": {...}}
- type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str}
- type="end" data={"usage": {"input_tokens": int, "output_tokens": int, "total_tokens": int}}
"""
@@ -586,7 +574,6 @@ class DeerFlowClient:
# in both the final ``messages`` chunk and the values snapshot —
# count it only on whichever arrives first.
counted_usage_ids: set[str] = set()
sent_additional_kwargs_by_id: dict[str, dict[str, Any]] = {}
cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
def _account_usage(msg_id: str | None, usage: Any) -> dict | None:
@@ -616,20 +603,6 @@ class DeerFlowClient:
"total_tokens": total_tokens,
}
def _unsent_additional_kwargs(msg_id: str | None, additional_kwargs: dict[str, Any] | None) -> dict[str, Any] | None:
if not additional_kwargs:
return None
if not msg_id:
return additional_kwargs
sent = sent_additional_kwargs_by_id.setdefault(msg_id, {})
delta = {key: value for key, value in additional_kwargs.items() if sent.get(key) != value}
if not delta:
return None
sent.update(delta)
return delta
for item in self._agent.stream(
state,
config=config,
@@ -657,31 +630,17 @@ class DeerFlowClient:
if isinstance(msg_chunk, AIMessage):
text = self._extract_text(msg_chunk.content)
additional_kwargs = self._serialize_additional_kwargs(msg_chunk)
counted_usage = _account_usage(msg_id, msg_chunk.usage_metadata)
sent_additional_kwargs = False
if text:
if msg_id:
streamed_ids.add(msg_id)
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
yield self._ai_text_event(
msg_id,
text,
counted_usage,
additional_kwargs_delta,
)
sent_additional_kwargs = bool(additional_kwargs_delta)
yield self._ai_text_event(msg_id, text, counted_usage)
if msg_chunk.tool_calls:
if msg_id:
streamed_ids.add(msg_id)
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
yield self._ai_tool_calls_event(
msg_id,
msg_chunk.tool_calls,
additional_kwargs_delta,
)
yield self._ai_tool_calls_event(msg_id, msg_chunk.tool_calls)
elif isinstance(msg_chunk, ToolMessage):
if msg_id:
@@ -704,45 +663,17 @@ class DeerFlowClient:
if msg_id and msg_id in streamed_ids:
if isinstance(msg, AIMessage):
_account_usage(msg_id, getattr(msg, "usage_metadata", None))
additional_kwargs = self._serialize_additional_kwargs(msg)
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
if additional_kwargs_delta:
# Metadata-only follow-up: ``messages-tuple`` has no
# dedicated attribution event, so clients should
# merge this empty-content AI event by message id
# and ignore it for text rendering.
yield self._ai_text_event(msg_id, "", None, additional_kwargs_delta)
continue
if isinstance(msg, AIMessage):
counted_usage = _account_usage(msg_id, msg.usage_metadata)
additional_kwargs = self._serialize_additional_kwargs(msg)
sent_additional_kwargs = False
if msg.tool_calls:
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
yield self._ai_tool_calls_event(
msg_id,
msg.tool_calls,
additional_kwargs_delta,
)
sent_additional_kwargs = bool(additional_kwargs_delta)
yield self._ai_tool_calls_event(msg_id, msg.tool_calls)
text = self._extract_text(msg.content)
if text:
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
yield self._ai_text_event(
msg_id,
text,
counted_usage,
additional_kwargs_delta,
)
elif msg_id:
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
if not additional_kwargs_delta:
continue
# See the metadata-only follow-up convention above.
yield self._ai_text_event(msg_id, "", None, additional_kwargs_delta)
yield self._ai_text_event(msg_id, text, counted_usage)
elif isinstance(msg, ToolMessage):
yield self._tool_message_event(msg)
@@ -80,7 +80,6 @@ class AioSandboxProvider(SandboxProvider):
port: 8080 # Base port for local containers
container_prefix: deer-flow-sandbox
idle_timeout: 600 # Idle timeout in seconds (0 to disable)
auto_restart: true # Restart crashed containers automatically
replicas: 3 # Max concurrent sandbox containers (LRU eviction when exceeded)
mounts: # Volume mounts for local containers
- host_path: /path/on/host
@@ -165,14 +164,12 @@ class AioSandboxProvider(SandboxProvider):
idle_timeout = getattr(sandbox_config, "idle_timeout", None)
replicas = getattr(sandbox_config, "replicas", None)
auto_restart = getattr(sandbox_config, "auto_restart", True)
return {
"image": sandbox_config.image or DEFAULT_IMAGE,
"port": sandbox_config.port or DEFAULT_PORT,
"container_prefix": sandbox_config.container_prefix or DEFAULT_CONTAINER_PREFIX,
"idle_timeout": idle_timeout if idle_timeout is not None else DEFAULT_IDLE_TIMEOUT,
"auto_restart": auto_restart,
"replicas": replicas if replicas is not None else DEFAULT_REPLICAS,
"mounts": sandbox_config.mounts or [],
"environment": self._resolve_env_vars(sandbox_config.environment or {}),
@@ -611,57 +608,17 @@ class AioSandboxProvider(SandboxProvider):
def get(self, sandbox_id: str) -> Sandbox | None:
"""Get a sandbox by ID. Updates last activity timestamp.
When ``auto_restart`` is enabled (the default), the container's liveness
is verified on each lookup. If the underlying container has crashed, the
sandbox is evicted from all caches so that the next ``acquire()`` call will
transparently create a fresh container.
Args:
sandbox_id: The ID of the sandbox.
Returns:
The sandbox instance if found and alive, None otherwise.
The sandbox instance if found, None otherwise.
"""
with self._lock:
sandbox = self._sandboxes.get(sandbox_id)
if sandbox is None:
return None
self._last_activity[sandbox_id] = time.time()
auto_restart = self._config.get("auto_restart", True)
info = self._sandbox_infos.get(sandbox_id) if auto_restart else None
if not info:
return sandbox
if self._backend.is_alive(info):
return sandbox
info_to_destroy = None
with self._lock:
current_sandbox = self._sandboxes.get(sandbox_id)
current_info = self._sandbox_infos.get(sandbox_id)
if current_sandbox is None:
return None
if current_info is not info:
if sandbox is not None:
self._last_activity[sandbox_id] = time.time()
return current_sandbox
logger.warning(f"Sandbox {sandbox_id} container is not alive, evicting from cache for auto-restart")
self._sandboxes.pop(sandbox_id, None)
self._sandbox_infos.pop(sandbox_id, None)
self._last_activity.pop(sandbox_id, None)
self._warm_pool.pop(sandbox_id, None)
thread_ids = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
for tid in thread_ids:
del self._thread_sandboxes[tid]
info_to_destroy = info
if info_to_destroy:
try:
self._backend.destroy(info_to_destroy)
except Exception as e:
logger.warning(f"Failed to cleanup dead sandbox {sandbox_id}: {e}")
return None
return sandbox
def release(self, sandbox_id: str) -> None:
"""Release a sandbox from active use into the warm pool.
@@ -84,52 +84,8 @@ class RemoteSandboxBackend(SandboxBackend):
"""
return self._provisioner_discover(sandbox_id)
def list_running(self) -> list[SandboxInfo]:
"""Return all sandboxes currently managed by the provisioner.
Calls ``GET /api/sandboxes`` so that ``AioSandboxProvider._reconcile_orphans()``
can adopt pods that were created by a previous process and were never
explicitly destroyed.
Without this, a process restart silently orphans all existing k8s Pods —
they stay running forever because the idle checker only
tracks in-process state.
"""
return self._provisioner_list()
# ── Provisioner API calls ─────────────────────────────────────────────
def _provisioner_list(self) -> list[SandboxInfo]:
"""GET /api/sandboxes → list all running sandboxes."""
try:
resp = requests.get(f"{self._provisioner_url}/api/sandboxes", timeout=10)
resp.raise_for_status()
data = resp.json()
if not isinstance(data, dict):
logger.warning("Provisioner list_running returned non-dict payload: %r", type(data))
return []
sandboxes = data.get("sandboxes", [])
if not isinstance(sandboxes, list):
logger.warning("Provisioner list_running returned non-list sandboxes: %r", type(sandboxes))
return []
infos: list[SandboxInfo] = []
for sandbox in sandboxes:
if not isinstance(sandbox, dict):
logger.warning("Provisioner list_running entry is not a dict: %r", type(sandbox))
continue
sandbox_id = sandbox.get("sandbox_id")
sandbox_url = sandbox.get("sandbox_url")
if isinstance(sandbox_id, str) and sandbox_id and isinstance(sandbox_url, str) and sandbox_url:
infos.append(SandboxInfo(sandbox_id=sandbox_id, sandbox_url=sandbox_url))
logger.info("Provisioner list_running: %d sandbox(es) found", len(infos))
return infos
except requests.RequestException as exc:
logger.warning("Provisioner list_running failed: %s", exc)
return []
def _provisioner_create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
"""POST /api/sandboxes → create Pod + Service."""
try:
@@ -1,3 +0,0 @@
from .tools import web_search_tool
__all__ = ["web_search_tool"]
@@ -1,95 +0,0 @@
"""
Web Search Tool - Search the web using Serper (Google Search API).
Serper provides real-time Google Search results via a JSON API.
An API key is required. Sign up at https://serper.dev to get one.
"""
import json
import logging
import os
import httpx
from langchain.tools import tool
from deerflow.config import get_app_config
logger = logging.getLogger(__name__)
_SERPER_ENDPOINT = "https://google.serper.dev/search"
_api_key_warned = False
def _get_api_key() -> str | None:
config = get_app_config().get_tool_config("web_search")
if config is not None:
api_key = config.model_extra.get("api_key")
if isinstance(api_key, str) and api_key.strip():
return api_key
return os.getenv("SERPER_API_KEY")
@tool("web_search", parse_docstring=True)
def web_search_tool(query: str, max_results: int = 5) -> str:
"""Search the web for information using Google Search via Serper.
Args:
query: Search keywords describing what you want to find. Be specific for better results.
max_results: Maximum number of search results to return. Default is 5.
"""
global _api_key_warned
config = get_app_config().get_tool_config("web_search")
if config is not None and "max_results" in config.model_extra:
max_results = config.model_extra.get("max_results", max_results)
api_key = _get_api_key()
if not api_key:
if not _api_key_warned:
_api_key_warned = True
logger.warning("Serper API key is not set. Set SERPER_API_KEY in your environment or provide api_key in config.yaml. Sign up at https://serper.dev")
return json.dumps(
{"error": "SERPER_API_KEY is not configured", "query": query},
ensure_ascii=False,
)
headers = {
"X-API-KEY": api_key,
"Content-Type": "application/json",
}
payload = {"q": query, "num": max_results}
try:
with httpx.Client(timeout=30) as client:
response = client.post(_SERPER_ENDPOINT, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
except httpx.HTTPStatusError as e:
logger.error(f"Serper API returned HTTP {e.response.status_code}: {e.response.text}")
return json.dumps(
{"error": f"Serper API error: HTTP {e.response.status_code}", "query": query},
ensure_ascii=False,
)
except Exception as e:
logger.error(f"Serper search failed: {type(e).__name__}: {e}")
return json.dumps({"error": str(e), "query": query}, ensure_ascii=False)
organic = data.get("organic", [])
if not organic:
return json.dumps({"error": "No results found", "query": query}, ensure_ascii=False)
normalized_results = [
{
"title": r.get("title", ""),
"url": r.get("link", ""),
"content": r.get("snippet", ""),
}
for r in organic[:max_results]
]
output = {
"query": query,
"total_results": len(normalized_results),
"results": normalized_results,
}
return json.dumps(output, indent=2, ensure_ascii=False)
@@ -1,6 +1,5 @@
from .app_config import get_app_config
from .extensions_config import ExtensionsConfig, get_extensions_config
from .loop_detection_config import LoopDetectionConfig
from .memory_config import MemoryConfig, get_memory_config
from .paths import Paths, get_paths
from .skill_evolution_config import SkillEvolutionConfig
@@ -21,7 +20,6 @@ __all__ = [
"SkillsConfig",
"ExtensionsConfig",
"get_extensions_config",
"LoopDetectionConfig",
"MemoryConfig",
"get_memory_config",
"get_tracing_config",
@@ -1,22 +1,13 @@
"""Configuration and loaders for custom agents.
Custom agents are stored per-user under ``{base_dir}/users/{user_id}/agents/{name}/``.
A legacy shared layout at ``{base_dir}/agents/{name}/`` is still readable so that
installations that pre-date user isolation continue to work until they run the
``scripts/migrate_user_isolation.py`` migration. New writes always target the
per-user layout.
"""
"""Configuration and loaders for custom agents."""
import logging
import re
from pathlib import Path
from typing import Any
import yaml
from pydantic import BaseModel
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__)
@@ -49,47 +40,14 @@ class AgentConfig(BaseModel):
skills: list[str] | None = None
def resolve_agent_dir(name: str, *, user_id: str | None = None) -> Path:
"""Return the on-disk directory for an agent, preferring the per-user layout.
Resolution order:
1. ``{base_dir}/users/{user_id}/agents/{name}/`` (per-user, current layout).
2. ``{base_dir}/agents/{name}/`` (legacy shared layout — read-only fallback).
If neither exists, the per-user path is returned so callers that intend to
create the agent write into the new layout.
Args:
name: Validated agent name.
user_id: Owner of the agent. Defaults to the effective user from the
request context (or ``"default"`` in no-auth mode).
"""
paths = get_paths()
effective_user = user_id or get_effective_user_id()
user_path = paths.user_agent_dir(effective_user, name)
if user_path.exists():
return user_path
legacy_path = paths.agent_dir(name)
if legacy_path.exists():
return legacy_path
return user_path
def load_agent_config(name: str | None, *, user_id: str | None = None) -> AgentConfig | None:
def load_agent_config(name: str | None) -> AgentConfig | None:
"""Load the custom or default agent's config from its directory.
Reads from the per-user layout first; falls back to the legacy shared layout
for installations that have not yet been migrated.
Args:
name: The agent name.
user_id: Owner of the agent. Defaults to the effective user from the
current request context.
Returns:
AgentConfig instance, or ``None`` if ``name`` is ``None``.
AgentConfig instance.
Raises:
FileNotFoundError: If the agent directory or config.yaml does not exist.
@@ -100,7 +58,7 @@ def load_agent_config(name: str | None, *, user_id: str | None = None) -> AgentC
return None
name = validate_agent_name(name)
agent_dir = resolve_agent_dir(name, user_id=user_id)
agent_dir = get_paths().agent_dir(name)
config_file = agent_dir / "config.yaml"
if not agent_dir.exists():
@@ -126,7 +84,7 @@ def load_agent_config(name: str | None, *, user_id: str | None = None) -> AgentC
return AgentConfig(**data)
def load_agent_soul(agent_name: str | None, *, user_id: str | None = None) -> str | None:
def load_agent_soul(agent_name: str | None) -> str | None:
"""Read the SOUL.md file for a custom agent, if it exists.
SOUL.md defines the agent's personality, values, and behavioral guardrails.
@@ -134,16 +92,11 @@ def load_agent_soul(agent_name: str | None, *, user_id: str | None = None) -> st
Args:
agent_name: The name of the agent or None for the default agent.
user_id: Owner of the agent. Defaults to the effective user from the
current request context.
Returns:
The SOUL.md content as a string, or None if the file does not exist.
"""
if agent_name:
agent_dir = resolve_agent_dir(agent_name, user_id=user_id)
else:
agent_dir = get_paths().base_dir
agent_dir = get_paths().agent_dir(agent_name) if agent_name else get_paths().base_dir
soul_path = agent_dir / SOUL_FILENAME
if not soul_path.exists():
return None
@@ -151,50 +104,32 @@ def load_agent_soul(agent_name: str | None, *, user_id: str | None = None) -> st
return content or None
def list_custom_agents(*, user_id: str | None = None) -> list[AgentConfig]:
def list_custom_agents() -> list[AgentConfig]:
"""Scan the agents directory and return all valid custom agents.
Returns the union of agents in the per-user layout and the legacy shared
layout, so that pre-migration installations remain visible until they are
migrated. Per-user entries shadow legacy entries with the same name.
Args:
user_id: Owner whose agents to list. Defaults to the effective user
from the current request context.
Returns:
List of AgentConfig for each valid agent directory found.
"""
paths = get_paths()
effective_user = user_id or get_effective_user_id()
agents_dir = get_paths().agents_dir
if not agents_dir.exists():
return []
seen: set[str] = set()
agents: list[AgentConfig] = []
user_root = paths.user_agents_dir(effective_user)
legacy_root = paths.agents_dir
for root in (user_root, legacy_root):
if not root.exists():
for entry in sorted(agents_dir.iterdir()):
if not entry.is_dir():
continue
for entry in sorted(root.iterdir()):
if not entry.is_dir():
continue
if entry.name in seen:
continue
config_file = entry / "config.yaml"
if not config_file.exists():
logger.debug(f"Skipping {entry.name}: no config.yaml")
continue
try:
agent_cfg = load_agent_config(entry.name, user_id=effective_user)
if agent_cfg is None:
continue
agents.append(agent_cfg)
seen.add(entry.name)
except Exception as e:
logger.warning(f"Skipping agent '{entry.name}': {e}")
config_file = entry / "config.yaml"
if not config_file.exists():
logger.debug(f"Skipping {entry.name}: no config.yaml")
continue
try:
agent_cfg = load_agent_config(entry.name)
agents.append(agent_cfg)
except Exception as e:
logger.warning(f"Skipping agent '{entry.name}': {e}")
agents.sort(key=lambda a: a.name)
return agents
@@ -1,6 +1,5 @@
import logging
import os
from collections.abc import Mapping
from contextvars import ContextVar
from pathlib import Path
from typing import Any, Self
@@ -15,7 +14,6 @@ from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpo
from deerflow.config.database_config import DatabaseConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
from deerflow.config.loop_detection_config import LoopDetectionConfig
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
from deerflow.config.model_config import ModelConfig
from deerflow.config.run_events_config import RunEventsConfig
@@ -101,7 +99,6 @@ class AppConfig(BaseModel):
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
model_config = ConfigDict(extra="allow")
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
@@ -160,54 +157,56 @@ class AppConfig(BaseModel):
config_data = cls.resolve_env_variables(config_data)
cls._apply_database_defaults(config_data)
# Load title config if present
if "title" in config_data:
load_title_config_from_dict(config_data["title"])
# Load summarization config if present
if "summarization" in config_data:
load_summarization_config_from_dict(config_data["summarization"])
# Load memory config if present
if "memory" in config_data:
load_memory_config_from_dict(config_data["memory"])
# Always refresh agents API config so removed config sections reset
# singleton-backed state to its default/disabled values on reload.
load_agents_api_config_from_dict(config_data.get("agents_api") or {})
# Load subagents config if present
if "subagents" in config_data:
load_subagents_config_from_dict(config_data["subagents"])
# Load tool_search config if present
if "tool_search" in config_data:
load_tool_search_config_from_dict(config_data["tool_search"])
# Load guardrails config if present
if "guardrails" in config_data:
load_guardrails_config_from_dict(config_data["guardrails"])
# Load circuit_breaker config if present
if "circuit_breaker" in config_data:
config_data["circuit_breaker"] = config_data["circuit_breaker"]
# Load checkpointer config if present
if "checkpointer" in config_data:
load_checkpointer_config_from_dict(config_data["checkpointer"])
# Load stream bridge config if present
if "stream_bridge" in config_data:
load_stream_bridge_config_from_dict(config_data["stream_bridge"])
# Always refresh ACP agent config so removed entries do not linger across reloads.
load_acp_config_from_dict(config_data.get("acp_agents", {}))
# Load extensions config separately (it's in a different file)
extensions_config = ExtensionsConfig.from_file()
config_data["extensions"] = extensions_config.model_dump()
result = cls.model_validate(config_data)
acp_agents = cls._validate_acp_agents(config_data.get("acp_agents", {}))
cls._apply_singleton_configs(result, acp_agents)
return result
@classmethod
def _validate_acp_agents(
cls,
config_data: Mapping[str, Mapping[str, object]] | None,
) -> dict[str, ACPAgentConfig]:
if config_data is None:
config_data = {}
return {name: ACPAgentConfig(**cfg) for name, cfg in config_data.items()}
@classmethod
def _apply_singleton_configs(cls, config: Self, acp_agents: dict[str, ACPAgentConfig]) -> None:
from deerflow.config.checkpointer_config import get_checkpointer_config
previous_checkpointer_config = get_checkpointer_config()
load_title_config_from_dict(config.title.model_dump())
load_summarization_config_from_dict(config.summarization.model_dump())
load_memory_config_from_dict(config.memory.model_dump())
load_agents_api_config_from_dict(config.agents_api.model_dump())
load_subagents_config_from_dict(config.subagents.model_dump())
load_tool_search_config_from_dict(config.tool_search.model_dump())
load_guardrails_config_from_dict(config.guardrails.model_dump())
load_checkpointer_config_from_dict(config.checkpointer.model_dump() if config.checkpointer is not None else None)
load_stream_bridge_config_from_dict(config.stream_bridge.model_dump() if config.stream_bridge is not None else None)
load_acp_config_from_dict({name: agent.model_dump() for name, agent in acp_agents.items()})
if previous_checkpointer_config != config.checkpointer:
# These runtime singletons derive their backend from checkpointer config.
# Keep imports local to avoid cycles: both providers import get_app_config.
from deerflow.runtime.checkpointer import reset_checkpointer
from deerflow.runtime.store import reset_store
reset_checkpointer()
reset_store()
@classmethod
def _apply_database_defaults(cls, config_data: dict[str, Any]) -> None:
"""Apply config.yaml defaults for persistence when the section is absent."""
@@ -14,13 +14,12 @@ class CheckpointerConfig(BaseModel):
description="Checkpointer backend type. "
"'memory' is in-process only (lost on restart). "
"'sqlite' persists to a local file (requires langgraph-checkpoint-sqlite). "
"'postgres' persists to PostgreSQL (install with deerflow-harness[postgres])."
"'postgres' persists to PostgreSQL (requires langgraph-checkpoint-postgres)."
)
connection_string: str | None = Field(
default=None,
description="Connection string for sqlite (file path) or postgres (DSN). "
"Optional for sqlite and defaults to 'store.db' when omitted. "
"Required for postgres. "
"Required for sqlite and postgres types. "
"For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. "
"For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.",
)
@@ -41,10 +40,7 @@ def set_checkpointer_config(config: CheckpointerConfig | None) -> None:
_checkpointer_config = config
def load_checkpointer_config_from_dict(config_dict: dict | None) -> None:
def load_checkpointer_config_from_dict(config_dict: dict) -> None:
"""Load checkpointer configuration from a dictionary."""
global _checkpointer_config
if config_dict is None:
_checkpointer_config = None
return
_checkpointer_config = CheckpointerConfig(**config_dict)
@@ -1,73 +0,0 @@
"""Configuration for loop detection middleware."""
from pydantic import BaseModel, Field, model_validator
class ToolFreqOverride(BaseModel):
"""Per-tool frequency threshold override.
Can be higher or lower than the global defaults. Commonly used to raise
thresholds for high-frequency tools like bash in batch workflows (e.g.
RNA-seq pipelines) without weakening protection on every other tool.
"""
warn: int = Field(ge=1)
hard_limit: int = Field(ge=1)
@model_validator(mode="after")
def _validate(self) -> "ToolFreqOverride":
if self.hard_limit < self.warn:
raise ValueError("hard_limit must be >= warn")
return self
class LoopDetectionConfig(BaseModel):
"""Configuration for repetitive tool-call loop detection."""
enabled: bool = Field(
default=True,
description="Whether to enable repetitive tool-call loop detection",
)
warn_threshold: int = Field(
default=3,
ge=1,
description="Number of identical tool-call sets before injecting a warning",
)
hard_limit: int = Field(
default=5,
ge=1,
description="Number of identical tool-call sets before forcing a stop",
)
window_size: int = Field(
default=20,
ge=1,
description="Number of recent tool-call sets to track per thread",
)
max_tracked_threads: int = Field(
default=100,
ge=1,
description="Maximum number of thread histories to keep in memory",
)
tool_freq_warn: int = Field(
default=30,
ge=1,
description="Number of calls to the same tool type before injecting a frequency warning",
)
tool_freq_hard_limit: int = Field(
default=50,
ge=1,
description="Number of calls to the same tool type before forcing a stop",
)
tool_freq_overrides: dict[str, ToolFreqOverride] = Field(
default_factory=dict,
description=("Per-tool overrides for tool_freq_warn / tool_freq_hard_limit, keyed by tool name. Values can be higher or lower than the global defaults. Commonly used to raise thresholds for high-frequency tools like bash."),
)
@model_validator(mode="after")
def validate_thresholds(self) -> "LoopDetectionConfig":
"""Ensure hard stop cannot happen before the warning threshold."""
if self.hard_limit < self.warn_threshold:
raise ValueError("hard_limit must be greater than or equal to warn_threshold")
if self.tool_freq_hard_limit < self.tool_freq_warn:
raise ValueError("tool_freq_hard_limit must be greater than or equal to tool_freq_warn")
return self
@@ -132,20 +132,15 @@ class Paths:
@property
def agents_dir(self) -> Path:
"""Legacy root for shared (pre user-isolation) custom agents: `{base_dir}/agents/`.
New code should use :meth:`user_agents_dir` instead. This property remains
only as a read-side fallback for installations that have not yet run the
``migrate_user_isolation.py`` script.
"""
"""Root directory for all custom agents: `{base_dir}/agents/`."""
return self.base_dir / "agents"
def agent_dir(self, name: str) -> Path:
"""Legacy per-agent directory (no user isolation): `{base_dir}/agents/{name}/`."""
"""Directory for a specific agent: `{base_dir}/agents/{name}/`."""
return self.agents_dir / name.lower()
def agent_memory_file(self, name: str) -> Path:
"""Legacy per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
"""Per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
return self.agent_dir(name) / "memory.json"
def user_dir(self, user_id: str) -> Path:
@@ -156,17 +151,9 @@ class Paths:
"""Per-user memory file: `{base_dir}/users/{user_id}/memory.json`."""
return self.user_dir(user_id) / "memory.json"
def user_agents_dir(self, user_id: str) -> Path:
"""Per-user root for that user's custom agents: `{base_dir}/users/{user_id}/agents/`."""
return self.user_dir(user_id) / "agents"
def user_agent_dir(self, user_id: str, agent_name: str) -> Path:
"""Per-user per-agent directory: `{base_dir}/users/{user_id}/agents/{name}/`."""
return self.user_agents_dir(user_id) / agent_name.lower()
def user_agent_memory_file(self, user_id: str, agent_name: str) -> Path:
"""Per-user per-agent memory: `{base_dir}/users/{user_id}/agents/{name}/memory.json`."""
return self.user_agent_dir(user_id, agent_name) / "memory.json"
return self.user_dir(user_id) / "agents" / agent_name.lower() / "memory.json"
def thread_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
"""
@@ -23,9 +23,6 @@ class SandboxConfig(BaseModel):
replicas: Maximum number of concurrent sandbox containers (default: 3). When the limit is reached the least-recently-used sandbox is evicted to make room.
container_prefix: Prefix for container names (default: deer-flow-sandbox)
idle_timeout: Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.
auto_restart: Automatically restart sandbox containers that have crashed (default: true). When a tool call
detects the container is no longer alive, the sandbox is evicted from cache and transparently recreated
on the next acquire. Set to false to disable.
mounts: List of volume mounts to share directories with the container
environment: Environment variables to inject into the container (values starting with $ are resolved from host env)
"""
@@ -58,10 +55,6 @@ class SandboxConfig(BaseModel):
default=None,
description="Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.",
)
auto_restart: bool = Field(
default=True,
description="Automatically restart sandbox containers that have crashed. When a tool call detects the container is no longer alive, the sandbox is evicted from cache and transparently recreated on the next acquire.",
)
mounts: list[VolumeMountConfig] = Field(
default_factory=list,
description="List of volume mounts to share directories between host and container",
@@ -6,13 +6,6 @@ from pydantic import BaseModel, Field
from deerflow.config.runtime_paths import project_root, resolve_path
def _legacy_skills_candidates() -> tuple[Path, ...]:
"""Return source-tree skills locations for monorepo compatibility."""
backend_dir = Path(__file__).resolve().parents[4]
repo_root = backend_dir.parent
return (repo_root / "skills",)
class SkillsConfig(BaseModel):
"""Configuration for skills system"""
@@ -22,7 +15,7 @@ class SkillsConfig(BaseModel):
)
path: str | None = Field(
default=None,
description=("Path to skills directory. If not specified, defaults to `skills` under the caller project root, falling back to the legacy repo-root location for monorepo compatibility."),
description="Path to skills directory. If not specified, defaults to skills under the caller project root.",
)
container_path: str = Field(
default="/mnt/skills",
@@ -33,30 +26,15 @@ class SkillsConfig(BaseModel):
"""
Get the resolved skills directory path.
Resolution order:
1. Explicit ``path`` field
2. ``DEER_FLOW_SKILLS_PATH`` environment variable
3. ``skills`` under the caller project root (``project_root()``)
4. Legacy repo-root candidates for monorepo compatibility (``_legacy_skills_candidates``)
When none of (3) or (4) exist on disk, the project-root default is returned so callers
can still surface a stable "no skills" location without raising.
Returns:
Path to the skills directory
"""
if self.path:
# Use configured path (can be absolute or relative to project root)
return resolve_path(self.path)
if env_path := os.getenv("DEER_FLOW_SKILLS_PATH"):
return resolve_path(env_path)
project_default = project_root() / "skills"
if project_default.is_dir():
return project_default
for candidate in _legacy_skills_candidates():
if candidate.is_dir():
return candidate
return project_default
return project_root() / "skills"
def get_skill_container_path(self, skill_name: str, category: str = "public") -> str:
"""
@@ -40,10 +40,7 @@ def set_stream_bridge_config(config: StreamBridgeConfig | None) -> None:
_stream_bridge_config = config
def load_stream_bridge_config_from_dict(config_dict: dict | None) -> None:
def load_stream_bridge_config_from_dict(config_dict: dict) -> None:
"""Load stream bridge configuration from a dictionary."""
global _stream_bridge_config
if config_dict is None:
_stream_bridge_config = None
return
_stream_bridge_config = StreamBridgeConfig(**config_dict)
@@ -179,3 +179,9 @@ def load_subagents_config_from_dict(config_dict: dict) -> None:
overrides_summary or "none",
custom_agents_names or "none",
)
else:
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, no per-agent overrides",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
)
@@ -4,4 +4,4 @@ from pydantic import BaseModel, Field
class TokenUsageConfig(BaseModel):
"""Configuration for token usage tracking."""
enabled: bool = Field(default=True, description="Enable token usage tracking middleware")
enabled: bool = Field(default=False, description="Enable token usage tracking middleware")
@@ -196,10 +196,6 @@ class ClaudeChatModel(ChatAnthropic):
enforced by both the Anthropic API and AWS Bedrock. Breakpoints are
placed on the *last* eligible blocks because later breakpoints cover a
larger prefix and yield better cache hit rates.
The system prompt is expected to be fully static (no per-user memory or
current date). Dynamic context is injected per-turn via
DynamicContextMiddleware as a <system-reminder> in the first HumanMessage.
"""
MAX_CACHE_BREAKPOINTS = 4
@@ -27,34 +27,6 @@ from deerflow.models.credential_loader import CodexCliCredential, load_codex_cli
logger = logging.getLogger(__name__)
CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
def _build_usage_metadata(oai_usage: dict) -> dict:
"""Convert Codex/Responses API usage dict to LangChain usage_metadata format.
Maps OpenAI Responses API token usage fields to the dict structure that
LangChain AIMessage.usage_metadata expects. This avoids depending on
langchain_openai private helpers like ``_create_usage_metadata_responses``.
"""
input_tokens = oai_usage.get("input_tokens", 0)
output_tokens = oai_usage.get("output_tokens", 0)
total_tokens = oai_usage.get("total_tokens", input_tokens + output_tokens)
metadata: dict = {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": total_tokens,
}
input_details = oai_usage.get("input_tokens_details") or {}
output_details = oai_usage.get("output_tokens_details") or {}
cache_read = input_details.get("cached_tokens")
if cache_read is not None:
metadata["input_token_details"] = {"cache_read": cache_read}
reasoning = output_details.get("reasoning_tokens")
if reasoning is not None:
metadata["output_token_details"] = {"reasoning": reasoning}
return metadata
MAX_RETRIES = 3
@@ -374,7 +346,6 @@ class CodexChatModel(BaseChatModel):
)
usage = response.get("usage", {})
usage_metadata = _build_usage_metadata(usage) if usage else None
additional_kwargs = {}
if reasoning_content:
additional_kwargs["reasoning_content"] = reasoning_content
@@ -384,7 +355,6 @@ class CodexChatModel(BaseChatModel):
tool_calls=tool_calls if tool_calls else [],
invalid_tool_calls=invalid_tool_calls,
additional_kwargs=additional_kwargs,
usage_metadata=usage_metadata,
response_metadata={
"model": response.get("model", self.model),
"usage": usage,
@@ -81,16 +81,7 @@ async def init_engine(
try:
import asyncpg # noqa: F401
except ImportError:
raise ImportError(
"database.backend is set to 'postgres' but asyncpg is not installed.\n"
"Install it with:\n"
" cd backend && uv sync --all-packages --extra postgres\n"
"On the next `make dev` the postgres extra is auto-detected from\n"
"config.yaml (database.backend: postgres) and reinstalled, so it\n"
"will not be wiped again. Set UV_EXTRAS=postgres in .env to opt in\n"
"explicitly. Or switch to backend: sqlite in config.yaml for\n"
"single-node deployment."
) from None
raise ImportError("database.backend is set to 'postgres' but asyncpg is not installed.\nInstall it with:\n uv sync --extra postgres\nOr switch to backend: sqlite in config.yaml for single-node deployment.") from None
if backend == "sqlite":
import os
@@ -7,13 +7,13 @@ router for thread records.
from __future__ import annotations
import time
from typing import Any
from langgraph.store.base import BaseStore
from deerflow.persistence.thread_meta.base import ThreadMetaStore
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
from deerflow.utils.time import coerce_iso, now_iso
THREADS_NS: tuple[str, ...] = ("threads",)
@@ -48,7 +48,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
metadata: dict | None = None,
) -> dict:
resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.create")
now = now_iso()
now = time.time()
record: dict[str, Any] = {
"thread_id": thread_id,
"assistant_id": assistant_id,
@@ -106,7 +106,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
if record is None:
return
record["display_name"] = display_name
record["updated_at"] = now_iso()
record["updated_at"] = time.time()
await self._store.aput(THREADS_NS, thread_id, record)
async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
@@ -114,7 +114,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
if record is None:
return
record["status"] = status
record["updated_at"] = now_iso()
record["updated_at"] = time.time()
await self._store.aput(THREADS_NS, thread_id, record)
async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
@@ -124,7 +124,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
merged = dict(record.get("metadata") or {})
merged.update(metadata)
record["metadata"] = merged
record["updated_at"] = now_iso()
record["updated_at"] = time.time()
await self._store.aput(THREADS_NS, thread_id, record)
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
@@ -144,8 +144,6 @@ class MemoryThreadMetaStore(ThreadMetaStore):
"display_name": val.get("display_name"),
"status": val.get("status", "idle"),
"metadata": val.get("metadata", {}),
# ``coerce_iso`` heals legacy unix-second values written by
# earlier Gateway versions that called ``str(time.time())``.
"created_at": coerce_iso(val.get("created_at", "")),
"updated_at": coerce_iso(val.get("updated_at", "")),
"created_at": str(val.get("created_at", "")),
"updated_at": str(val.get("updated_at", "")),
}
@@ -25,7 +25,7 @@ from collections.abc import Iterator
from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig, get_app_config
from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
@@ -36,9 +36,7 @@ logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
SQLITE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite checkpointer. Install it with: uv add langgraph-checkpoint-sqlite"
POSTGRES_INSTALL = (
"langgraph-checkpoint-postgres is required for the PostgreSQL checkpointer. Install the package extra with: pip install 'deerflow-harness[postgres]' (or use: uv sync --all-packages --extra postgres when developing locally)"
)
POSTGRES_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL checkpointer. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool"
POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend"
# ---------------------------------------------------------------------------
@@ -100,9 +98,78 @@ def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
_checkpointer: Checkpointer | None = None
_checkpointer_ctx = None # open context manager keeping the connection alive
_explicit_checkpointers: dict[int, Checkpointer] = {}
_explicit_checkpointer_contexts: dict[int, object] = {}
def get_checkpointer() -> Checkpointer:
def _default_in_memory_checkpointer() -> Checkpointer:
from langgraph.checkpoint.memory import InMemorySaver
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
return InMemorySaver()
def _persistent_database_backend(db_config) -> str | None:
backend = getattr(db_config, "backend", None)
if backend in {"sqlite", "postgres"}:
return backend
return None
@contextlib.contextmanager
def _sync_checkpointer_from_database_cm(db_config) -> Iterator[Checkpointer]:
"""Context manager that creates a sync checkpointer from unified DatabaseConfig."""
backend = _persistent_database_backend(db_config)
if backend is None:
yield _default_in_memory_checkpointer()
return
if backend == "sqlite":
try:
from langgraph.checkpoint.sqlite import SqliteSaver
except ImportError as exc:
raise ImportError(SQLITE_INSTALL) from exc
conn_str = db_config.checkpointer_sqlite_path
ensure_sqlite_parent_dir(conn_str)
with SqliteSaver.from_conn_string(conn_str) as saver:
saver.setup()
logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)
yield saver
return
if backend == "postgres":
try:
from langgraph.checkpoint.postgres import PostgresSaver
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
if not db_config.postgres_url:
raise ValueError("database.postgres_url is required for the postgres backend")
with PostgresSaver.from_conn_string(db_config.postgres_url) as saver:
saver.setup()
logger.info("Checkpointer: using PostgresSaver")
yield saver
return
raise ValueError(f"Unknown database backend: {backend!r}")
def _build_checkpointer_from_app_config(app_config: AppConfig) -> tuple[Checkpointer, object | None]:
if app_config.checkpointer is not None:
ctx = _sync_checkpointer_cm(app_config.checkpointer)
return ctx.__enter__(), ctx
db_config = getattr(app_config, "database", None)
if _persistent_database_backend(db_config) is not None:
ctx = _sync_checkpointer_from_database_cm(db_config)
return ctx.__enter__(), ctx
return _default_in_memory_checkpointer(), None
def get_checkpointer(app_config: AppConfig | None = None) -> Checkpointer:
"""Return the global sync checkpointer singleton, creating it on first call.
Returns an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
@@ -113,6 +180,18 @@ def get_checkpointer() -> Checkpointer:
"""
global _checkpointer, _checkpointer_ctx
if app_config is not None:
cache_key = id(app_config)
cached = _explicit_checkpointers.get(cache_key)
if cached is not None:
return cached
explicit_checkpointer, explicit_ctx = _build_checkpointer_from_app_config(app_config)
_explicit_checkpointers[cache_key] = explicit_checkpointer
if explicit_ctx is not None:
_explicit_checkpointer_contexts[cache_key] = explicit_ctx
return explicit_checkpointer
if _checkpointer is not None:
return _checkpointer
@@ -123,28 +202,30 @@ def get_checkpointer() -> Checkpointer:
from deerflow.config.checkpointer_config import get_checkpointer_config
config = get_checkpointer_config()
global_app_config = _app_config
if config is None and _app_config is None:
if config is None and global_app_config is None:
# Only load app config lazily when neither the app config nor an explicit
# checkpointer config has been initialized yet. This keeps tests that
# intentionally set the global checkpointer config isolated from any
# ambient config.yaml on disk.
try:
get_app_config()
global_app_config = get_app_config()
except FileNotFoundError:
# In test environments without config.yaml, this is expected.
pass
config = get_checkpointer_config()
if config is None:
from langgraph.checkpoint.memory import InMemorySaver
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
_checkpointer = InMemorySaver()
if config is not None:
_checkpointer_ctx = _sync_checkpointer_cm(config)
_checkpointer = _checkpointer_ctx.__enter__()
return _checkpointer
_checkpointer_ctx = _sync_checkpointer_cm(config)
_checkpointer = _checkpointer_ctx.__enter__()
if global_app_config is not None:
_checkpointer, _checkpointer_ctx = _build_checkpointer_from_app_config(global_app_config)
return _checkpointer
_checkpointer = _default_in_memory_checkpointer()
return _checkpointer
@@ -163,6 +244,18 @@ def reset_checkpointer() -> None:
_checkpointer_ctx = None
_checkpointer = None
for cache_key, ctx in list(_explicit_checkpointer_contexts.items()):
try:
ctx.__exit__(None, None, None)
except Exception:
logger.warning("Error during explicit checkpointer cleanup", exc_info=True)
finally:
_explicit_checkpointer_contexts.pop(cache_key, None)
_explicit_checkpointers.pop(cache_key, None)
_explicit_checkpointers.clear()
_explicit_checkpointer_contexts.clear()
# ---------------------------------------------------------------------------
# Sync context manager
@@ -170,7 +263,7 @@ def reset_checkpointer() -> None:
@contextlib.contextmanager
def checkpointer_context() -> Iterator[Checkpointer]:
def checkpointer_context(app_config: AppConfig | None = None) -> Iterator[Checkpointer]:
"""Sync context manager that yields a checkpointer and cleans up on exit.
Unlike :func:`get_checkpointer`, this does **not** cache the instance
@@ -183,12 +276,16 @@ def checkpointer_context() -> Iterator[Checkpointer]:
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
if config.checkpointer is None:
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
resolved_app_config = app_config or get_app_config()
if resolved_app_config.checkpointer is not None:
with _sync_checkpointer_cm(resolved_app_config.checkpointer) as saver:
yield saver
return
with _sync_checkpointer_cm(config.checkpointer) as saver:
yield saver
db_config = getattr(resolved_app_config, "database", None)
if _persistent_database_backend(db_config) is not None:
with _sync_checkpointer_from_database_cm(db_config) as saver:
yield saver
return
yield _default_in_memory_checkpointer()
@@ -9,7 +9,6 @@ from __future__ import annotations
import json
import logging
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
@@ -34,21 +33,20 @@ class DbRunEventStore(RunEventStore):
if isinstance(val, datetime):
d["created_at"] = val.isoformat()
d.pop("id", None)
# Restore structured content that was JSON-serialized on write.
# Restore dict content that was JSON-serialized on write
raw = d.get("content", "")
metadata = d.get("metadata", {})
if isinstance(raw, str) and (metadata.get("content_is_json") or metadata.get("content_is_dict")):
if isinstance(raw, str) and d.get("metadata", {}).get("content_is_dict"):
try:
d["content"] = json.loads(raw)
except (json.JSONDecodeError, ValueError):
# Content looked like JSON but failed to parse;
# Content looked like JSON (content_is_dict flag) but failed to parse;
# keep the raw string as-is.
logger.debug("Failed to deserialize content as JSON for event seq=%s", d.get("seq"))
return d
def _truncate_trace(self, category: str, content: Any, metadata: dict | None) -> tuple[Any, dict]:
def _truncate_trace(self, category: str, content: str | dict, metadata: dict | None) -> tuple[str | dict, dict]:
if category == "trace":
text = content if isinstance(content, str) else json.dumps(content, default=str, ensure_ascii=False)
text = json.dumps(content, default=str, ensure_ascii=False) if isinstance(content, dict) else content
encoded = text.encode("utf-8")
if len(encoded) > self._max_trace_content:
# Truncate by bytes, then decode back (may cut a multi-byte char, so use errors="ignore")
@@ -56,18 +54,6 @@ class DbRunEventStore(RunEventStore):
metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)}
return content, metadata or {}
@staticmethod
def _content_to_db(content: Any, metadata: dict | None) -> tuple[str, dict]:
metadata = metadata or {}
if isinstance(content, str):
return content, metadata
db_content = json.dumps(content, default=str, ensure_ascii=False)
metadata = {**metadata, "content_is_json": True}
if isinstance(content, dict):
metadata["content_is_dict"] = True
return db_content, metadata
@staticmethod
def _user_id_from_context() -> str | None:
"""Soft read of user_id from contextvar for write paths.
@@ -96,7 +82,11 @@ class DbRunEventStore(RunEventStore):
the initial ``human_message`` event (once per run).
"""
content, metadata = self._truncate_trace(category, content, metadata)
db_content, metadata = self._content_to_db(content, metadata)
if isinstance(content, dict):
db_content = json.dumps(content, default=str, ensure_ascii=False)
metadata = {**(metadata or {}), "content_is_dict": True}
else:
db_content = content
user_id = self._user_id_from_context()
async with self._sf() as session:
async with session.begin():
@@ -138,7 +128,11 @@ class DbRunEventStore(RunEventStore):
category = e.get("category", "trace")
metadata = e.get("metadata")
content, metadata = self._truncate_trace(category, content, metadata)
db_content, metadata = self._content_to_db(content, metadata)
if isinstance(content, dict):
db_content = json.dumps(content, default=str, ensure_ascii=False)
metadata = {**(metadata or {}), "content_is_dict": True}
else:
db_content = content
row = RunEventRow(
thread_id=e["thread_id"],
run_id=e["run_id"],
@@ -6,10 +6,9 @@ import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING
from deerflow.utils.time import now_iso as _now_iso
from .schemas import DisconnectMode, RunStatus
if TYPE_CHECKING:
@@ -18,6 +17,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _now_iso() -> str:
return datetime.now(UTC).isoformat()
@dataclass
class RunRecord:
"""Mutable record for a single run."""
@@ -23,8 +23,6 @@ from dataclasses import dataclass, field
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal, cast
from langgraph.checkpoint.base import empty_checkpoint
if TYPE_CHECKING:
from langchain_core.messages import HumanMessage
@@ -444,12 +442,6 @@ async def _rollback_to_pre_run_checkpoint(
if checkpoint_to_restore.get("id") is None:
logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id)
return
restore_marker = _new_checkpoint_marker()
checkpoint_to_restore = {
**checkpoint_to_restore,
"id": restore_marker["id"],
"ts": restore_marker["ts"],
}
metadata = pre_run_snapshot.get("metadata", {})
metadata_to_restore = metadata if isinstance(metadata, dict) else {}
raw_checkpoint_ns = pre_run_snapshot.get("checkpoint_ns")
@@ -501,11 +493,6 @@ async def _rollback_to_pre_run_checkpoint(
)
def _new_checkpoint_marker() -> dict[str, str]:
marker = empty_checkpoint()
return {"id": marker["id"], "ts": marker["ts"]}
def _lg_mode_to_sse_event(mode: str) -> str:
"""Map LangGraph internal stream_mode name to SSE event name.
@@ -26,7 +26,7 @@ from collections.abc import Iterator
from langgraph.store.base import BaseStore
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig, get_app_config
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@@ -36,9 +36,7 @@ logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
SQLITE_STORE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite store. Install it with: uv add langgraph-checkpoint-sqlite"
POSTGRES_STORE_INSTALL = (
"langgraph-checkpoint-postgres is required for the PostgreSQL store. Install the package extra with: pip install 'deerflow-harness[postgres]' (or use: uv sync --all-packages --extra postgres when developing locally)"
)
POSTGRES_STORE_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL store. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool"
POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend"
# ---------------------------------------------------------------------------
@@ -100,9 +98,26 @@ def _sync_store_cm(config) -> Iterator[BaseStore]:
_store: BaseStore | None = None
_store_ctx = None # open context manager keeping the connection alive
_explicit_stores: dict[int, BaseStore] = {}
_explicit_store_contexts: dict[int, object] = {}
def get_store() -> BaseStore:
def _default_in_memory_store() -> BaseStore:
from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
return InMemoryStore()
def _build_store_from_app_config(app_config: AppConfig) -> tuple[BaseStore, object | None]:
if app_config.checkpointer is not None:
ctx = _sync_store_cm(app_config.checkpointer)
return ctx.__enter__(), ctx
return _default_in_memory_store(), None
def get_store(app_config: AppConfig | None = None) -> BaseStore:
"""Return the global sync Store singleton, creating it on first call.
Returns an :class:`~langgraph.store.memory.InMemoryStore` when no
@@ -114,6 +129,18 @@ def get_store() -> BaseStore:
"""
global _store, _store_ctx
if app_config is not None:
cache_key = id(app_config)
cached = _explicit_stores.get(cache_key)
if cached is not None:
return cached
explicit_store, explicit_ctx = _build_store_from_app_config(app_config)
_explicit_stores[cache_key] = explicit_store
if explicit_ctx is not None:
_explicit_store_contexts[cache_key] = explicit_ctx
return explicit_store
if _store is not None:
return _store
@@ -132,10 +159,7 @@ def get_store() -> BaseStore:
config = get_checkpointer_config()
if config is None:
from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
_store = InMemoryStore()
_store = _default_in_memory_store()
return _store
_store_ctx = _sync_store_cm(config)
@@ -158,6 +182,18 @@ def reset_store() -> None:
_store_ctx = None
_store = None
for cache_key, ctx in list(_explicit_store_contexts.items()):
try:
ctx.__exit__(None, None, None)
except Exception:
logger.warning("Error during explicit store cleanup", exc_info=True)
finally:
_explicit_store_contexts.pop(cache_key, None)
_explicit_stores.pop(cache_key, None)
_explicit_stores.clear()
_explicit_store_contexts.clear()
# ---------------------------------------------------------------------------
# Sync context manager
@@ -165,7 +201,7 @@ def reset_store() -> None:
@contextlib.contextmanager
def store_context() -> Iterator[BaseStore]:
def store_context(app_config: AppConfig | None = None) -> Iterator[BaseStore]:
"""Sync context manager that yields a Store and cleans up on exit.
Unlike :func:`get_store`, this does **not** cache the instance each
@@ -178,13 +214,10 @@ def store_context() -> Iterator[BaseStore]:
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
if config.checkpointer is None:
from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
yield InMemoryStore()
resolved_app_config = app_config or get_app_config()
if resolved_app_config.checkpointer is None:
yield _default_in_memory_store()
return
with _sync_store_cm(config.checkpointer) as store:
with _sync_store_cm(resolved_app_config.checkpointer) as store:
yield store
@@ -42,13 +42,6 @@ class LocalSandbox(Sandbox):
"""Return whether the selected shell is cmd.exe."""
return LocalSandbox._shell_name(shell) in {"cmd", "cmd.exe"}
@staticmethod
def _is_msys_shell(shell: str) -> bool:
"""Return whether the selected shell is a Git Bash/MSYS shell."""
normalized = shell.replace("\\", "/").lower()
shell_name = LocalSandbox._shell_name(shell)
return shell_name in {"sh.exe", "bash.exe"} and any(part in normalized for part in ("/git/", "/mingw", "/msys"))
@staticmethod
def _find_first_available_shell(candidates: tuple[str, ...]) -> str | None:
"""Return the first executable shell path or command found from candidates."""
@@ -310,19 +303,12 @@ class LocalSandbox(Sandbox):
shell = self._get_shell()
if os.name == "nt":
env = None
if self._is_powershell(shell):
args = [shell, "-NoProfile", "-Command", resolved_command]
elif self._is_cmd_shell(shell):
args = [shell, "/c", resolved_command]
else:
args = [shell, "-c", resolved_command]
if self._is_msys_shell(shell):
env = {
**os.environ,
"MSYS_NO_PATHCONV": "1",
"MSYS2_ARG_CONV_EXCL": "*",
}
result = subprocess.run(
args,
@@ -330,7 +316,6 @@ class LocalSandbox(Sandbox):
capture_output=True,
text=True,
timeout=600,
env=env,
)
else:
args = [shell, "-c", resolved_command]
@@ -3,9 +3,10 @@ import re
import shlex
from pathlib import Path
from langchain.tools import tool
from langchain.tools import ToolRuntime, tool
from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadDataState
from deerflow.agents.thread_state import ThreadDataState, ThreadState
from deerflow.config import get_app_config
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
from deerflow.sandbox.exceptions import (
@@ -18,7 +19,6 @@ from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
from deerflow.sandbox.search import GrepMatch
from deerflow.sandbox.security import LOCAL_HOST_BASH_DISABLED_MESSAGE, is_host_bash_allowed
from deerflow.tools.types import Runtime
_ABSOLUTE_PATH_PATTERN = re.compile(r"(?<![:\w])(?<!:/)/(?:[^\s\"'`;&|<>()]+)")
_FILE_URL_PATTERN = re.compile(r"\bfile://\S+", re.IGNORECASE)
@@ -419,7 +419,7 @@ def _join_path_preserving_style(base: str, relative: str) -> str:
return f"{stripped_base}{separator}{normalized_relative}"
def _sanitize_error(error: Exception, runtime: Runtime | None = None) -> str:
def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadState] | None" = None) -> str:
"""Sanitize an error message to avoid leaking host filesystem paths.
In local-sandbox mode, resolved host paths in the error string are masked
@@ -994,7 +994,7 @@ def _apply_cwd_prefix(command: str, thread_data: ThreadDataState | None) -> str:
return command
def get_thread_data(runtime: Runtime | None) -> ThreadDataState | None:
def get_thread_data(runtime: ToolRuntime[ContextT, ThreadState] | None) -> ThreadDataState | None:
"""Extract thread_data from runtime state."""
if runtime is None:
return None
@@ -1003,7 +1003,7 @@ def get_thread_data(runtime: Runtime | None) -> ThreadDataState | None:
return runtime.state.get("thread_data")
def is_local_sandbox(runtime: Runtime | None) -> bool:
def is_local_sandbox(runtime: ToolRuntime[ContextT, ThreadState] | None) -> bool:
"""Check if the current sandbox is a local sandbox.
Path replacement is only needed for local sandbox since aio sandbox
@@ -1019,7 +1019,7 @@ def is_local_sandbox(runtime: Runtime | None) -> bool:
return sandbox_state.get("sandbox_id") == "local"
def sandbox_from_runtime(runtime: Runtime | None = None) -> Sandbox:
def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = None) -> Sandbox:
"""Extract sandbox instance from tool runtime.
DEPRECATED: Use ensure_sandbox_initialized() for lazy initialization support.
@@ -1048,7 +1048,7 @@ def sandbox_from_runtime(runtime: Runtime | None = None) -> Sandbox:
return sandbox
def ensure_sandbox_initialized(runtime: Runtime | None = None) -> Sandbox:
def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | None = None) -> Sandbox:
"""Ensure sandbox is initialized, acquiring lazily if needed.
On first call, acquires a sandbox from the provider and stores it in runtime state.
@@ -1107,7 +1107,7 @@ def ensure_sandbox_initialized(runtime: Runtime | None = None) -> Sandbox:
return sandbox
def ensure_thread_directories_exist(runtime: Runtime | None) -> None:
def ensure_thread_directories_exist(runtime: ToolRuntime[ContextT, ThreadState] | None) -> None:
"""Ensure thread data directories (workspace, uploads, outputs) exist.
This function is called lazily when any sandbox tool is first used.
@@ -1221,7 +1221,7 @@ def _truncate_ls_output(output: str, max_chars: int) -> str:
@tool("bash", parse_docstring=True)
def bash_tool(runtime: Runtime, description: str, command: str) -> str:
def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, command: str) -> str:
"""Execute a bash command in a Linux environment.
@@ -1270,7 +1270,7 @@ def bash_tool(runtime: Runtime, description: str, command: str) -> str:
@tool("ls", parse_docstring=True)
def ls_tool(runtime: Runtime, description: str, path: str) -> str:
def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path: str) -> str:
"""List the contents of a directory up to 2 levels deep in tree format.
Args:
@@ -1318,7 +1318,7 @@ def ls_tool(runtime: Runtime, description: str, path: str) -> str:
@tool("glob", parse_docstring=True)
def glob_tool(
runtime: Runtime,
runtime: ToolRuntime[ContextT, ThreadState],
description: str,
pattern: str,
path: str,
@@ -1368,7 +1368,7 @@ def glob_tool(
@tool("grep", parse_docstring=True)
def grep_tool(
runtime: Runtime,
runtime: ToolRuntime[ContextT, ThreadState],
description: str,
pattern: str,
path: str,
@@ -1438,7 +1438,7 @@ def grep_tool(
@tool("read_file", parse_docstring=True)
def read_file_tool(
runtime: Runtime,
runtime: ToolRuntime[ContextT, ThreadState],
description: str,
path: str,
start_line: int | None = None,
@@ -1493,7 +1493,7 @@ def read_file_tool(
@tool("write_file", parse_docstring=True)
def write_file_tool(
runtime: Runtime,
runtime: ToolRuntime[ContextT, ThreadState],
description: str,
path: str,
content: str,
@@ -1533,7 +1533,7 @@ def write_file_tool(
@tool("str_replace", parse_docstring=True)
def str_replace_tool(
runtime: Runtime,
runtime: ToolRuntime[ContextT, ThreadState],
description: str,
path: str,
old_str: str,
@@ -9,29 +9,6 @@ from .types import SKILL_MD_FILE, Skill, SkillCategory
logger = logging.getLogger(__name__)
def parse_allowed_tools(raw: object, skill_file: Path) -> list[str] | None:
"""Parse the optional allowed-tools frontmatter field.
Returns None when the field is omitted. Returns a list when the field is a
YAML sequence of strings, including an empty list for explicit no-tool
skills. Raises ValueError for malformed values.
"""
if raw is None:
return None
if not isinstance(raw, list):
raise ValueError(f"allowed-tools in {skill_file} must be a list of strings")
allowed_tools: list[str] = []
for item in raw:
if not isinstance(item, str):
raise ValueError(f"allowed-tools in {skill_file} must contain only strings")
tool_name = item.strip()
if not tool_name:
raise ValueError(f"allowed-tools in {skill_file} cannot contain empty tool names")
allowed_tools.append(tool_name)
return allowed_tools
def parse_skill_file(skill_file: Path, category: SkillCategory, relative_path: Path | None = None) -> Skill | None:
"""Parse a SKILL.md file and extract metadata.
@@ -87,12 +64,6 @@ def parse_skill_file(skill_file: Path, category: SkillCategory, relative_path: P
if license_text is not None:
license_text = str(license_text).strip() or None
try:
allowed_tools = parse_allowed_tools(metadata.get("allowed-tools"), skill_file)
except ValueError as exc:
logger.error("Invalid allowed-tools in %s: %s", skill_file, exc)
return None
return Skill(
name=name,
description=description,
@@ -101,7 +72,6 @@ def parse_skill_file(skill_file: Path, category: SkillCategory, relative_path: P
skill_file=skill_file,
relative_path=relative_path or Path(skill_file.parent.name),
category=category,
allowed_tools=allowed_tools,
enabled=True, # Actual state comes from the extensions config file.
)
@@ -1,44 +0,0 @@
import logging
from typing import Protocol
from deerflow.skills.types import Skill
logger = logging.getLogger(__name__)
class NamedTool(Protocol):
name: str
def allowed_tool_names_for_skills(skills: list[Skill]) -> set[str] | None:
"""Return the union of explicit skill allowed-tools declarations.
None means legacy allow-all behavior. It is returned only when no loaded
skill declares allowed-tools. Once any skill declares the field, legacy
skills without the field contribute no tools instead of disabling the
explicit restrictions from other skills.
"""
if not skills:
return None
allowed: set[str] = set()
has_explicit_declaration = False
for skill in skills:
if skill.allowed_tools is None:
continue
has_explicit_declaration = True
if not skill.allowed_tools:
logger.info("Skill %s declared empty allowed-tools", skill.name)
allowed.update(skill.allowed_tools)
if not has_explicit_declaration:
return None
return allowed
def filter_tools_by_skill_allowed_tools[ToolT: NamedTool](tools: list[ToolT], skills: list[Skill]) -> list[ToolT]:
allowed = allowed_tool_names_for_skills(skills)
if allowed is None:
return tools
return [tool for tool in tools if tool.name in allowed]
@@ -27,7 +27,6 @@ class Skill:
skill_file: Path
relative_path: Path # Relative path from category root to skill directory
category: SkillCategory # 'public' or 'custom'
allowed_tools: list[str] | None = None
enabled: bool = False # Whether this skill is enabled
@property
@@ -8,7 +8,6 @@ from pathlib import Path
import yaml
from deerflow.skills.parser import parse_allowed_tools
from deerflow.skills.types import SKILL_MD_FILE
# Allowed properties in SKILL.md frontmatter
@@ -85,9 +84,4 @@ def _validate_skill_frontmatter(skill_dir: Path) -> tuple[bool, str, str | None]
if len(description) > 1024:
return False, f"Description is too long ({len(description)} characters). Maximum is 1024 characters.", None
try:
parse_allowed_tools(frontmatter.get("allowed-tools"), skill_md)
except ValueError as e:
return False, str(e).replace(str(skill_md), SKILL_MD_FILE), None
return True, "Skill is valid!", name
@@ -23,8 +23,6 @@ from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadSt
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
from deerflow.skills.types import Skill
from deerflow.subagents.config import SubagentConfig, resolve_subagent_model_name
logger = logging.getLogger(__name__)
@@ -262,16 +260,16 @@ class SubagentExecutor:
# Generate trace_id if not provided (for top-level calls)
self.trace_id = trace_id or str(uuid.uuid4())[:8]
self._base_tools = _filter_tools(
# Filter tools based on config
self.tools = _filter_tools(
tools,
config.tools,
config.disallowed_tools,
)
self.tools = self._base_tools
logger.info(f"[trace={self.trace_id}] SubagentExecutor initialized: {config.name} with {len(self.tools)} tools")
def _create_agent(self, tools: list[BaseTool] | None = None):
def _create_agent(self):
"""Create the agent instance."""
app_config = self.app_config or get_app_config()
if self.model_name is None:
@@ -285,44 +283,13 @@ class SubagentExecutor:
return create_agent(
model=model,
tools=tools if tools is not None else self.tools,
tools=self.tools,
middleware=middlewares,
system_prompt=self.config.system_prompt,
state_schema=ThreadState,
)
async def _load_skills(self) -> list[Skill]:
"""Load enabled skill metadata based on config.skills."""
if self.config.skills is not None and len(self.config.skills) == 0:
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} skills=[] — skipping skill loading")
return []
try:
from deerflow.skills.storage import get_or_new_skill_storage
storage_kwargs = {"app_config": self.app_config} if self.app_config is not None else {}
storage = await asyncio.to_thread(get_or_new_skill_storage, **storage_kwargs)
# Use asyncio.to_thread to avoid blocking the event loop (LangGraph ASGI requirement)
all_skills = await asyncio.to_thread(storage.load_skills, enabled_only=True)
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} loaded {len(all_skills)} enabled skills from disk")
except Exception:
logger.exception(f"[trace={self.trace_id}] Failed to load skills for subagent {self.config.name}")
raise
if not all_skills:
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} no enabled skills found")
return []
# Filter by config.skills whitelist
if self.config.skills is not None:
allowed = set(self.config.skills)
return [s for s in all_skills if s.name in allowed]
return all_skills
def _apply_skill_allowed_tools(self, skills: list[Skill]) -> list[BaseTool]:
return filter_tools_by_skill_allowed_tools(self._base_tools, skills)
async def _load_skill_messages(self, skills: list[Skill]) -> list[SystemMessage]:
async def _load_skill_messages(self) -> list[SystemMessage]:
"""Load skill content as conversation items based on config.skills.
Aligned with Codex's pattern: each subagent loads its own skills
@@ -336,6 +303,33 @@ class SubagentExecutor:
Returns:
List of SystemMessages containing skill content.
"""
if self.config.skills is not None and len(self.config.skills) == 0:
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} skills=[] — skipping skill loading")
return []
try:
from deerflow.skills.storage import get_or_new_skill_storage
storage_kwargs = {"app_config": self.app_config} if self.app_config is not None else {}
storage = await asyncio.to_thread(get_or_new_skill_storage, **storage_kwargs)
# Use asyncio.to_thread to avoid blocking the event loop (LangGraph ASGI requirement)
all_skills = await asyncio.to_thread(storage.load_skills, enabled_only=True)
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} loaded {len(all_skills)} enabled skills from disk")
except Exception:
logger.warning(f"[trace={self.trace_id}] Failed to load skills for subagent {self.config.name}", exc_info=True)
return []
if not all_skills:
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} no enabled skills found")
return []
# Filter by config.skills whitelist
if self.config.skills is not None:
allowed = set(self.config.skills)
skills = [s for s in all_skills if s.name in allowed]
else:
skills = all_skills
if not skills:
return []
@@ -353,21 +347,19 @@ class SubagentExecutor:
return messages
async def _build_initial_state(self, task: str) -> tuple[dict[str, Any], list[BaseTool]]:
async def _build_initial_state(self, task: str) -> dict[str, Any]:
"""Build the initial state for agent execution.
Args:
task: The task description.
Returns:
Initial state dictionary and tools filtered by loaded skill metadata.
Initial state dictionary.
"""
# Load skills as conversation items (Codex pattern)
skills = await self._load_skills()
filtered_tools = self._apply_skill_allowed_tools(skills)
skill_messages = await self._load_skill_messages(skills)
skill_messages = await self._load_skill_messages()
messages: list[Any] = []
messages: list = []
# Skill content injected as developer/system messages before the task
messages.extend(skill_messages)
# Then the actual task
@@ -383,7 +375,7 @@ class SubagentExecutor:
if self.thread_data is not None:
state["thread_data"] = self.thread_data
return state, filtered_tools
return state
async def _aexecute(self, task: str, result_holder: SubagentResult | None = None) -> SubagentResult:
"""Execute a task asynchronously.
@@ -413,8 +405,8 @@ class SubagentExecutor:
result.ai_messages = ai_messages
try:
state, filtered_tools = await self._build_initial_state(task)
agent = self._create_agent(filtered_tools)
agent = self._create_agent()
state = await self._build_initial_state(task)
# Build config with thread_id for sandbox access and recursion limit
run_config: RunnableConfig = {
@@ -2,12 +2,10 @@ from .clarification_tool import ask_clarification_tool
from .present_file_tool import present_file_tool
from .setup_agent_tool import setup_agent
from .task_tool import task_tool
from .update_agent_tool import update_agent
from .view_image_tool import view_image_tool
__all__ = [
"setup_agent",
"update_agent",
"present_file_tool",
"ask_clarification_tool",
"view_image_tool",
@@ -1,19 +1,20 @@
from pathlib import Path
from typing import Annotated
from langchain.tools import InjectedToolCallId, tool
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langchain_core.messages import ToolMessage
from langgraph.config import get_config
from langgraph.types import Command
from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadState
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.tools.types import Runtime
OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"
def _get_thread_id(runtime: Runtime) -> str | None:
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState]) -> str | None:
"""Resolve the current thread id from runtime context or RunnableConfig."""
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id:
@@ -31,7 +32,7 @@ def _get_thread_id(runtime: Runtime) -> str | None:
def _normalize_presented_filepath(
runtime: Runtime,
runtime: ToolRuntime[ContextT, ThreadState],
filepath: str,
) -> str:
"""Normalize a presented file path to the `/mnt/user-data/outputs/*` contract.
@@ -82,7 +83,7 @@ def _normalize_presented_filepath(
@tool("present_files", parse_docstring=True)
def present_file_tool(
runtime: Runtime,
runtime: ToolRuntime[ContextT, ThreadState],
filepaths: list[str],
tool_call_id: Annotated[str, InjectedToolCallId],
) -> Command:
@@ -3,28 +3,20 @@ import logging
import yaml
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langgraph.prebuilt import ToolRuntime
from langgraph.types import Command
from deerflow.config.agents_config import validate_agent_name
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.tools.types import Runtime
logger = logging.getLogger(__name__)
def _get_runtime_user_id(runtime: Runtime) -> str:
context_user_id = runtime.context.get("user_id") if runtime.context else None
if context_user_id:
return str(context_user_id)
return get_effective_user_id()
@tool
def setup_agent(
soul: str,
description: str,
runtime: Runtime,
runtime: ToolRuntime,
skills: list[str] | None = None,
) -> Command:
"""Setup the custom DeerFlow agent.
@@ -42,14 +34,7 @@ def setup_agent(
try:
agent_name = validate_agent_name(agent_name)
paths = get_paths()
if agent_name:
# Custom agents are persisted under the current user's bucket so
# different users do not see each other's agents.
user_id = _get_runtime_user_id(runtime)
agent_dir = paths.user_agent_dir(user_id, agent_name)
else:
# Default agent (no agent_name): SOUL.md lives at the global base dir.
agent_dir = paths.base_dir
agent_dir = paths.agent_dir(agent_name) if agent_name else paths.base_dir
is_new_dir = not agent_dir.exists()
agent_dir.mkdir(parents=True, exist_ok=True)
@@ -6,9 +6,11 @@ import uuid
from dataclasses import replace
from typing import TYPE_CHECKING, Annotated, Any, cast
from langchain.tools import InjectedToolCallId, tool
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langgraph.config import get_stream_writer
from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadState
from deerflow.config import get_app_config
from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed
from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config
@@ -19,7 +21,6 @@ from deerflow.subagents.executor import (
get_background_task_result,
request_cancel_background_task,
)
from deerflow.tools.types import Runtime
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
@@ -49,11 +50,12 @@ def _merge_skill_allowlists(parent: list[str] | None, child: list[str] | None) -
@tool("task", parse_docstring=True)
async def task_tool(
runtime: Runtime,
runtime: ToolRuntime[ContextT, ThreadState],
description: str,
prompt: str,
subagent_type: str,
tool_call_id: Annotated[str, InjectedToolCallId],
max_turns: int | None = None,
) -> str:
"""Delegate a task to a specialized subagent that runs in its own context.
@@ -89,6 +91,7 @@ async def task_tool(
description: A short (3-5 word) description of the task for logging/display. ALWAYS PROVIDE THIS PARAMETER FIRST.
prompt: The task description for the subagent. Be specific and clear about what needs to be done. ALWAYS PROVIDE THIS PARAMETER SECOND.
subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
max_turns: Optional maximum number of agent turns. Defaults to subagent's configured max.
"""
runtime_app_config = _get_runtime_app_config(runtime)
available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names()
@@ -110,6 +113,9 @@ async def task_tool(
# each subagent loads its own skills based on config, injected as conversation items).
# No longer appended to system_prompt here.
if max_turns is not None:
overrides["max_turns"] = max_turns
# Extract parent context from runtime
sandbox_state = None
thread_data = None
@@ -1,241 +0,0 @@
"""update_agent tool — let a custom agent persist updates to its own SOUL.md / config.
Bound to the lead agent only when ``runtime.context['agent_name']`` is set
(i.e. inside an existing custom agent's chat). The default agent does not see
this tool, and the bootstrap flow continues to use ``setup_agent`` for the
initial creation handshake.
The tool writes back to ``{base_dir}/users/{user_id}/agents/{agent_name}/{config.yaml,SOUL.md}``
so an agent created by one user is never visible to (or mutable by) another.
Writes are staged into temp files first; both files are renamed into place only
after both temp files are successfully written, so a partial failure cannot leave
config.yaml updated while SOUL.md still holds stale content.
"""
from __future__ import annotations
import logging
import tempfile
from pathlib import Path
from typing import Any
import yaml
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langgraph.types import Command
from deerflow.config.agents_config import load_agent_config, validate_agent_name
from deerflow.config.app_config import get_app_config
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.tools.types import Runtime
logger = logging.getLogger(__name__)
def _stage_temp(path: Path, text: str) -> Path:
"""Write ``text`` into a sibling temp file and return its path.
The caller is responsible for ``Path.replace``-ing the temp into the target
once every staged file is ready, or for unlinking it on failure.
"""
path.parent.mkdir(parents=True, exist_ok=True)
fd = tempfile.NamedTemporaryFile(
mode="w",
dir=path.parent,
suffix=".tmp",
delete=False,
encoding="utf-8",
)
try:
fd.write(text)
fd.flush()
fd.close()
return Path(fd.name)
except BaseException:
fd.close()
Path(fd.name).unlink(missing_ok=True)
raise
def _cleanup_temps(temps: list[Path]) -> None:
"""Best-effort removal of staged temp files."""
for tmp in temps:
try:
tmp.unlink(missing_ok=True)
except OSError:
logger.debug("Failed to clean up temp file %s", tmp, exc_info=True)
@tool
def update_agent(
runtime: Runtime,
soul: str | None = None,
description: str | None = None,
skills: list[str] | None = None,
tool_groups: list[str] | None = None,
model: str | None = None,
) -> Command:
"""Persist updates to the current custom agent's SOUL.md and config.yaml.
Use this when the user asks to refine the agent's identity, description,
skill whitelist, tool-group whitelist, or default model. Only the fields
you explicitly pass are updated; omitted fields keep their existing values.
Pass ``soul`` as the FULL replacement SOUL.md content there is no patch
semantics, so always start from the current SOUL and apply your edits.
Pass ``skills=[]`` to disable all skills for this agent. Omit ``skills``
entirely to keep the existing whitelist.
Args:
soul: Optional full replacement SOUL.md content.
description: Optional new one-line description.
skills: Optional skill whitelist. ``[]`` = no skills, omit = unchanged.
tool_groups: Optional tool-group whitelist. ``[]`` = empty, omit = unchanged.
model: Optional model override (must match a configured model name).
Returns:
Command with a ToolMessage describing the result. Changes take effect
on the next user turn (when the lead agent is rebuilt with the fresh
SOUL.md and config.yaml).
"""
tool_call_id = runtime.tool_call_id
agent_name_raw: str | None = runtime.context.get("agent_name") if runtime.context else None
def _err(message: str) -> Command:
return Command(update={"messages": [ToolMessage(content=f"Error: {message}", tool_call_id=tool_call_id)]})
if soul is None and description is None and skills is None and tool_groups is None and model is None:
return _err("No fields provided. Pass at least one of: soul, description, skills, tool_groups, model.")
try:
agent_name = validate_agent_name(agent_name_raw)
except ValueError as e:
return _err(str(e))
if not agent_name:
return _err("update_agent is only available inside a custom agent's chat. There is no agent_name in the current runtime context, so there is nothing to update. If you are inside the bootstrap flow, use setup_agent instead.")
# Resolve the active user so that updates only affect this user's agent.
# ``get_effective_user_id`` returns DEFAULT_USER_ID when no auth context
# is set (matching how memory and thread storage behave).
user_id = get_effective_user_id()
# Reject an unknown ``model`` *before* touching the filesystem. Otherwise
# ``_resolve_model_name`` silently falls back to the default at runtime
# and the user sees confusing repeated warnings on every later turn.
if model is not None and get_app_config().get_model_config(model) is None:
return _err(f"Unknown model '{model}'. Pass a model name that exists in config.yaml's models section.")
paths = get_paths()
agent_dir = paths.user_agent_dir(user_id, agent_name)
if not agent_dir.exists() and paths.agent_dir(agent_name).exists():
return _err(f"Agent '{agent_name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before updating.")
try:
existing_cfg = load_agent_config(agent_name, user_id=user_id)
except FileNotFoundError:
return _err(f"Agent '{agent_name}' does not exist for the current user. Use setup_agent to create a new agent first.")
except ValueError as e:
return _err(f"Agent '{agent_name}' has an unreadable config: {e}")
if existing_cfg is None:
return _err(f"Agent '{agent_name}' could not be loaded.")
updated_fields: list[str] = []
# Force the on-disk ``name`` to match the directory we are writing into,
# even if ``existing_cfg.name`` had drifted (e.g. from manual yaml edits).
config_data: dict[str, Any] = {"name": agent_name}
new_description = description if description is not None else existing_cfg.description
config_data["description"] = new_description
if description is not None and description != existing_cfg.description:
updated_fields.append("description")
new_model = model if model is not None else existing_cfg.model
if new_model is not None:
config_data["model"] = new_model
if model is not None and model != existing_cfg.model:
updated_fields.append("model")
new_tool_groups = tool_groups if tool_groups is not None else existing_cfg.tool_groups
if new_tool_groups is not None:
config_data["tool_groups"] = new_tool_groups
if tool_groups is not None and tool_groups != existing_cfg.tool_groups:
updated_fields.append("tool_groups")
new_skills = skills if skills is not None else existing_cfg.skills
if new_skills is not None:
config_data["skills"] = new_skills
if skills is not None and skills != existing_cfg.skills:
updated_fields.append("skills")
config_changed = bool({"description", "model", "tool_groups", "skills"} & set(updated_fields))
# Stage every file we intend to rewrite into a temp sibling. Only after
# *all* temp files exist do we rename them into place — so a failure on
# SOUL.md cannot leave config.yaml already replaced.
pending: list[tuple[Path, Path]] = []
staged_temps: list[Path] = []
try:
agent_dir.mkdir(parents=True, exist_ok=True)
if config_changed:
yaml_text = yaml.dump(config_data, default_flow_style=False, allow_unicode=True, sort_keys=False)
config_target = agent_dir / "config.yaml"
config_tmp = _stage_temp(config_target, yaml_text)
staged_temps.append(config_tmp)
pending.append((config_tmp, config_target))
if soul is not None:
soul_target = agent_dir / "SOUL.md"
soul_tmp = _stage_temp(soul_target, soul)
staged_temps.append(soul_tmp)
pending.append((soul_tmp, soul_target))
updated_fields.append("soul")
# Commit phase. ``Path.replace`` is atomic per file on POSIX/NTFS and
# the staging step above means any earlier failure has already been
# reported. The remaining failure mode is a crash *between* two
# ``replace`` calls, which is reported via the partial-write error
# branch below so the caller knows which files are now on disk.
committed: list[Path] = []
try:
for tmp, target in pending:
tmp.replace(target)
committed.append(target)
except Exception as e:
_cleanup_temps([t for t, _ in pending if t not in committed])
if committed:
logger.error(
"[update_agent] Partial write for agent '%s' (user=%s): committed=%s, failed during rename: %s",
agent_name,
user_id,
[p.name for p in committed],
e,
exc_info=True,
)
return _err(f"Partial update for agent '{agent_name}': {[p.name for p in committed]} were updated, but the rest failed ({e}). Re-run update_agent to retry the remaining fields.")
raise
except Exception as e:
_cleanup_temps(staged_temps)
logger.error("[update_agent] Failed to update agent '%s' (user=%s): %s", agent_name, user_id, e, exc_info=True)
return _err(f"Failed to update agent '{agent_name}': {e}")
if not updated_fields:
return Command(update={"messages": [ToolMessage(content=f"No changes applied to agent '{agent_name}'. The provided values matched the existing config.", tool_call_id=tool_call_id)]})
logger.info("[update_agent] Updated agent '%s' (user=%s) fields: %s", agent_name, user_id, updated_fields)
return Command(
update={
"messages": [
ToolMessage(
content=(f"Agent '{agent_name}' updated successfully. Changed: {', '.join(updated_fields)}. The new configuration takes effect on the next user turn."),
tool_call_id=tool_call_id,
)
]
}
)
@@ -3,13 +3,13 @@ import mimetypes
from pathlib import Path
from typing import Annotated
from langchain.tools import InjectedToolCallId, tool
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langchain_core.messages import ToolMessage
from langgraph.types import Command
from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadDataState
from deerflow.agents.thread_state import ThreadDataState, ThreadState
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
from deerflow.tools.types import Runtime
_ALLOWED_IMAGE_VIRTUAL_ROOTS = (
f"{VIRTUAL_PATH_PREFIX}/workspace",
@@ -48,7 +48,7 @@ def _sanitize_image_error(error: Exception, thread_data: ThreadDataState | None)
@tool("view_image", parse_docstring=True)
def view_image_tool(
runtime: Runtime,
runtime: ToolRuntime[ContextT, ThreadState],
image_path: str,
tool_call_id: Annotated[str, InjectedToolCallId],
) -> Command:
@@ -7,15 +7,16 @@ import logging
from typing import Any
from weakref import WeakValueDictionary
from langchain.tools import tool
from langchain.tools import ToolRuntime, tool
from langgraph.typing import ContextT
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.agents.thread_state import ThreadState
from deerflow.mcp.tools import _make_sync_tool_wrapper
from deerflow.skills.security_scanner import scan_skill_content
from deerflow.skills.storage import get_or_new_skill_storage
from deerflow.skills.storage.skill_storage import SkillStorage
from deerflow.skills.types import SKILL_MD_FILE
from deerflow.tools.types import Runtime
logger = logging.getLogger(__name__)
@@ -30,7 +31,7 @@ def _get_lock(name: str) -> asyncio.Lock:
return lock
def _get_thread_id(runtime: Runtime | None) -> str | None:
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState] | None) -> str | None:
if runtime is None:
return None
if runtime.context and runtime.context.get("thread_id"):
@@ -64,7 +65,7 @@ async def _to_thread(func, /, *args, **kwargs):
async def _skill_manage_impl(
runtime: Runtime,
runtime: ToolRuntime[ContextT, ThreadState],
action: str,
name: str,
content: str | None = None,
@@ -203,7 +204,7 @@ async def _skill_manage_impl(
@tool("skill_manage", parse_docstring=True)
async def skill_manage_tool(
runtime: Runtime,
runtime: ToolRuntime[ContextT, ThreadState],
action: str,
name: str,
content: str | None = None,
@@ -1,11 +0,0 @@
from typing import Any
from langchain.tools import ToolRuntime
from deerflow.agents.thread_state import ThreadState
# Concrete runtime type used by all DeerFlow tools.
# Using dict[str, Any] for the context parameter instead of the unbound ContextT
# TypeVar prevents PydanticSerializationUnexpectedValue warnings when LangChain
# calls model_dump() on a tool's auto-generated args_schema.
Runtime = ToolRuntime[dict[str, Any], ThreadState]
@@ -4,10 +4,8 @@ Pure business logic — no FastAPI/HTTP dependencies.
Both Gateway and Client delegate to these functions.
"""
import errno
import os
import re
import stat
from pathlib import Path
from urllib.parse import quote
@@ -19,10 +17,6 @@ class PathTraversalError(ValueError):
"""Raised when a path escapes its allowed base directory."""
class UnsafeUploadPathError(ValueError):
"""Raised when an upload destination is not a safe regular file path."""
# thread_id must be alphanumeric, hyphens, underscores, or dots only.
_SAFE_THREAD_ID = re.compile(r"^[a-zA-Z0-9._-]+$")
@@ -115,108 +109,6 @@ def validate_path_traversal(path: Path, base: Path) -> None:
raise PathTraversalError("Path traversal detected") from None
def open_upload_file_no_symlink(base_dir: Path, filename: str) -> tuple[Path, object]:
"""Open an upload destination for safe streaming writes.
Upload directories may be mounted into local sandboxes. A sandbox process can
therefore leave a symlink at a future upload filename. Normal ``Path.write_bytes``
follows that link and can overwrite files outside the uploads directory with
gateway privileges. This helper rejects symlink destinations using ``O_NOFOLLOW``
on POSIX. On Windows (which lacks ``O_NOFOLLOW``), it uses dual ``lstat`` checks
and ``fstat`` validation after ``open()`` to reduce the TOCTOU window; this does
not eliminate all races but makes exploitation significantly harder. Path-traversal
validation prevents escapes from *base_dir* in both cases.
"""
safe_name = normalize_filename(filename)
dest = base_dir / safe_name
try:
st = os.lstat(dest)
except FileNotFoundError:
st = None
if st is not None and not stat.S_ISREG(st.st_mode):
raise UnsafeUploadPathError(f"Upload destination is not a regular file: {safe_name}")
validate_path_traversal(dest, base_dir)
has_nofollow = hasattr(os, "O_NOFOLLOW")
if has_nofollow:
# POSIX: O_NOFOLLOW makes open() fail with ELOOP if dest is a symlink.
flags = os.O_WRONLY | os.O_CREAT | os.O_NOFOLLOW
if hasattr(os, "O_NONBLOCK"):
flags |= os.O_NONBLOCK
try:
fd = os.open(dest, flags, 0o600)
except OSError as exc:
if exc.errno in {errno.ELOOP, errno.EISDIR, errno.ENOTDIR, errno.ENXIO, errno.EAGAIN}:
raise UnsafeUploadPathError(f"Unsafe upload destination: {safe_name}") from exc
raise
try:
opened_stat = os.fstat(fd)
if not stat.S_ISREG(opened_stat.st_mode) or opened_stat.st_nlink != 1:
raise UnsafeUploadPathError(f"Upload destination is not an exclusive regular file: {safe_name}")
os.ftruncate(fd, 0)
fh = os.fdopen(fd, "wb")
fd = -1
finally:
if fd >= 0:
os.close(fd)
return dest, fh
# Windows: no O_NOFOLLOW available. Uses a second lstat immediately before open()
# to narrow the TOCTOU window, then fstat after open() as a further defence.
# Note: a narrow race window remains between the pre-open lstat and open(); the
# path-traversal check mitigates escapes from base_dir but cannot prevent an
# attacker who can atomically replace dest with a symlink after the check.
if st is not None and st.st_nlink > 1:
raise UnsafeUploadPathError(f"Upload destination has multiple links: {safe_name}")
flags = os.O_WRONLY | os.O_CREAT
if hasattr(os, "O_BINARY"):
flags |= os.O_BINARY
try:
pre_open_st = os.lstat(dest)
except FileNotFoundError:
pre_open_st = None
if pre_open_st is not None and not stat.S_ISREG(pre_open_st.st_mode):
raise UnsafeUploadPathError(f"Upload destination is not a regular file: {safe_name}")
if pre_open_st is not None and pre_open_st.st_nlink > 1:
raise UnsafeUploadPathError(f"Upload destination has multiple links: {safe_name}")
try:
fd = os.open(dest, flags, 0o600)
except OSError as exc:
if exc.errno in {errno.EISDIR, errno.ENOTDIR, errno.ENXIO, errno.EAGAIN}:
raise UnsafeUploadPathError(f"Unsafe upload destination: {safe_name}") from exc
raise
try:
opened_stat = os.fstat(fd)
if not stat.S_ISREG(opened_stat.st_mode) or opened_stat.st_nlink > 1:
raise UnsafeUploadPathError(f"Upload destination is not an exclusive regular file: {safe_name}")
os.ftruncate(fd, 0)
fh = os.fdopen(fd, "wb")
fd = -1
finally:
if fd >= 0:
os.close(fd)
return dest, fh
def write_upload_file_no_symlink(base_dir: Path, filename: str, data: bytes) -> Path:
"""Write upload bytes without following a pre-existing destination symlink."""
dest, fh = open_upload_file_no_symlink(base_dir, filename)
with fh:
fh.write(data)
return dest
def list_files_in_dir(directory: Path) -> dict:
"""List files (not directories) in *directory*.
@@ -1,75 +0,0 @@
"""ISO 8601 timestamp helpers for the Gateway and embedded runtime.
DeerFlow stores and serializes thread/run timestamps as ISO 8601 UTC
strings to match the LangGraph Platform schema (see
``langgraph_sdk.schema.Thread``, where ``created_at`` / ``updated_at``
are ``datetime`` and JSON-encode to ISO 8601). All timestamp generation
should funnel through :func:`now_iso` so the wire format stays
consistent across endpoints, the embedded ``RunManager``, and the
checkpoint metadata written by the Gateway.
:func:`coerce_iso` provides a forward-compatible read path for legacy
records that historically stored ``str(time.time())`` floats.
"""
from __future__ import annotations
import re
from datetime import UTC, datetime
__all__ = ["coerce_iso", "now_iso"]
_UNIX_TIMESTAMP_PATTERN = re.compile(r"^\d{10}(?:\.\d+)?$")
"""Matches the unix-timestamp string shape historically written by
``str(time.time())`` (10-digit seconds with optional fractional part).
The 10-digit anchor avoids accidentally rewriting ISO years like
``"2026"`` and stays valid until the year 2286.
"""
def now_iso() -> str:
"""Return the current UTC time as an ISO 8601 string.
Example: ``"2026-04-27T03:19:46.511479+00:00"``.
"""
return datetime.now(UTC).isoformat()
def coerce_iso(value: object) -> str:
"""Best-effort coerce a stored timestamp to an ISO 8601 string.
Translates legacy unix-timestamp floats / strings written by older
DeerFlow versions into ISO without a one-shot migration. ISO strings
pass through unchanged; ``datetime`` instances are normalised to UTC
(tz-naive values are assumed to be UTC) and emitted via
``isoformat()`` so the wire format always uses the ``T`` separator;
empty values become ``""``; unrecognised values are stringified as a
last resort.
"""
if value is None or value == "":
return ""
if isinstance(value, bool):
# ``bool`` is a subclass of ``int`` — treat as garbage, not 0/1.
return str(value)
if isinstance(value, datetime):
# ``datetime`` must be handled before the ``int``/``float`` check;
# str(datetime) would produce ``"YYYY-MM-DD HH:MM:SS+00:00"``
# (space separator), which breaks strict ISO 8601 consumers.
if value.tzinfo is None:
value = value.replace(tzinfo=UTC)
else:
value = value.astimezone(UTC)
return value.isoformat()
if isinstance(value, (int, float)):
try:
return datetime.fromtimestamp(float(value), UTC).isoformat()
except (ValueError, OverflowError, OSError):
return str(value)
if isinstance(value, str):
if _UNIX_TIMESTAMP_PATTERN.match(value):
try:
return datetime.fromtimestamp(float(value), UTC).isoformat()
except (ValueError, OverflowError, OSError):
return value
return value
return str(value)
+2 -1
View File
@@ -8,7 +8,7 @@ dependencies = [
"deerflow-harness",
"fastapi>=0.115.0",
"httpx>=0.28.0",
"python-multipart>=0.0.27",
"python-multipart>=0.0.26",
"sse-starlette>=2.1.0",
"uvicorn[standard]>=0.34.0",
"lark-oapi>=1.4.0",
@@ -47,3 +47,4 @@ members = ["packages/harness"]
[tool.uv.sources]
deerflow-harness = { workspace = true }
+3 -86
View File
@@ -1,7 +1,7 @@
"""One-time migration: move legacy thread dirs and memory into per-user layout.
Usage:
PYTHONPATH=. python scripts/migrate_user_isolation.py [--dry-run] [--user-id USER_ID]
PYTHONPATH=. python scripts/migrate_user_isolation.py [--dry-run]
The script is idempotent re-running it after a successful migration is a no-op.
"""
@@ -69,67 +69,6 @@ def migrate_thread_dirs(
return report
def migrate_agents(
paths: Paths,
user_id: str = "default",
*,
dry_run: bool = False,
) -> list[dict]:
"""Move legacy custom-agent directories into per-user layout.
Legacy layout: ``{base_dir}/agents/{name}/``
Per-user layout: ``{base_dir}/users/{user_id}/agents/{name}/``
Pre-existing per-user agents take precedence: if a destination already
exists for an agent name, the legacy copy is moved to
``{base_dir}/migration-conflicts/agents/{name}/`` for manual review.
Args:
paths: Paths instance.
user_id: Target user to receive the legacy agents (defaults to
``"default"``, matching ``DEFAULT_USER_ID`` for no-auth setups).
dry_run: If True, only log what would happen.
Returns:
List of migration report entries, one per legacy agent directory found.
"""
report: list[dict] = []
legacy_agents = paths.agents_dir
if not legacy_agents.exists():
logger.info("No legacy agents directory found — nothing to migrate.")
return report
for agent_dir in sorted(legacy_agents.iterdir()):
if not agent_dir.is_dir():
continue
agent_name = agent_dir.name
dest = paths.user_agent_dir(user_id, agent_name)
entry = {"agent": agent_name, "user_id": user_id, "action": ""}
if dest.exists():
conflicts_dir = paths.base_dir / "migration-conflicts" / "agents" / agent_name
entry["action"] = f"conflict -> {conflicts_dir}"
if not dry_run:
conflicts_dir.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(agent_dir), str(conflicts_dir))
logger.warning("Conflict for agent %s: moved legacy copy to %s", agent_name, conflicts_dir)
else:
entry["action"] = f"moved -> {dest}"
if not dry_run:
dest.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(agent_dir), str(dest))
logger.info("Migrated agent %s -> user %s", agent_name, user_id)
report.append(entry)
# Clean up empty legacy agents dir
if not dry_run and legacy_agents.exists() and not any(legacy_agents.iterdir()):
legacy_agents.rmdir()
return report
def migrate_memory(
paths: Paths,
user_id: str = "default",
@@ -188,12 +127,6 @@ def _build_owner_map_from_db(paths: Paths) -> dict[str, str]:
def main() -> None:
parser = argparse.ArgumentParser(description="Migrate DeerFlow data to per-user layout")
parser.add_argument("--dry-run", action="store_true", help="Log actions without making changes")
parser.add_argument(
"--user-id",
default="default",
metavar="USER_ID",
help=("User ID to claim un-owned legacy data (global memory.json and legacy custom agents). Defaults to 'default'. In multi-user installs, set this to the operator account that should inherit those legacy artifacts."),
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
@@ -201,42 +134,26 @@ def main() -> None:
paths = get_paths()
logger.info("Base directory: %s", paths.base_dir)
logger.info("Dry run: %s", args.dry_run)
logger.info("Claiming un-owned legacy data for user_id=%s", args.user_id)
owner_map = _build_owner_map_from_db(paths)
logger.info("Found %d thread ownership records in DB", len(owner_map))
report = migrate_thread_dirs(paths, owner_map, dry_run=args.dry_run)
migrate_memory(paths, user_id=args.user_id, dry_run=args.dry_run)
agent_report = migrate_agents(paths, user_id=args.user_id, dry_run=args.dry_run)
migrate_memory(paths, user_id="default", dry_run=args.dry_run)
if report:
logger.info("Thread migration report:")
logger.info("Migration report:")
for entry in report:
logger.info(" thread=%s user=%s action=%s", entry["thread_id"], entry["user_id"], entry["action"])
else:
logger.info("No threads to migrate.")
if agent_report:
logger.info("Agent migration report:")
for entry in agent_report:
logger.info(" agent=%s user=%s action=%s", entry["agent"], entry["user_id"], entry["action"])
else:
logger.info("No agents to migrate.")
unowned = [e for e in report if e["user_id"] == "default"]
if unowned:
logger.warning("%d thread(s) had no owner and were assigned to 'default':", len(unowned))
for e in unowned:
logger.warning(" %s", e["thread_id"])
if agent_report:
logger.warning(
"%d legacy agent(s) were assigned to '%s'. If those agents belonged to other users, move them manually under {base_dir}/users/<user_id>/agents/.",
len(agent_report),
args.user_id,
)
if __name__ == "__main__":
main()
@@ -1,210 +0,0 @@
"""Tests for AioSandboxProvider auto-restart of crashed containers."""
import importlib
import threading
from unittest.mock import MagicMock, patch
def _import_provider():
return importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
def _make_provider(*, auto_restart=True, alive=True):
"""Build a minimal AioSandboxProvider with a mock backend.
Args:
auto_restart: Value for the auto_restart config key.
alive: Whether the mock backend reports containers as alive.
"""
mod = _import_provider()
with patch.object(mod.AioSandboxProvider, "_start_idle_checker"):
provider = mod.AioSandboxProvider.__new__(mod.AioSandboxProvider)
provider._config = {"auto_restart": auto_restart}
provider._lock = threading.Lock()
provider._sandboxes = {}
provider._sandbox_infos = {}
provider._thread_sandboxes = {}
provider._thread_locks = {}
provider._last_activity = {}
provider._warm_pool = {}
provider._shutdown_called = False
provider._idle_checker_stop = threading.Event()
backend = MagicMock()
backend.is_alive.return_value = alive
provider._backend = backend
return provider, backend
def _seed_sandbox(provider, sandbox_id="dead-beef", thread_id="thread-1"):
"""Insert a sandbox into the provider's caches as if it were acquired."""
sandbox = MagicMock()
info = MagicMock()
provider._sandboxes[sandbox_id] = sandbox
provider._sandbox_infos[sandbox_id] = info
provider._last_activity[sandbox_id] = 0.0
if thread_id:
provider._thread_sandboxes[thread_id] = sandbox_id
return sandbox, info
# ── get() returns sandbox when container is alive ──────────────────────────
def test_get_returns_sandbox_when_container_alive():
"""When auto_restart is on and the container is alive, get() returns the sandbox."""
provider, backend = _make_provider(auto_restart=True, alive=True)
sandbox, _ = _seed_sandbox(provider)
result = provider.get("dead-beef")
assert result is sandbox
backend.is_alive.assert_called_once()
def test_get_returns_sandbox_when_auto_restart_disabled():
"""When auto_restart is off, get() skips the health check entirely."""
provider, backend = _make_provider(auto_restart=False)
sandbox, _ = _seed_sandbox(provider)
result = provider.get("dead-beef")
assert result is sandbox
backend.is_alive.assert_not_called()
# ── get() evicts dead sandbox when auto_restart is on ──────────────────────
def test_get_evicts_dead_sandbox_when_auto_restart_enabled():
"""When the container is dead and auto_restart is on, get() returns None and cleans caches."""
provider, backend = _make_provider(auto_restart=True, alive=False)
_, info = _seed_sandbox(provider, sandbox_id="dead-beef", thread_id="thread-1")
result = provider.get("dead-beef")
assert result is None
assert "dead-beef" not in provider._sandboxes
assert "dead-beef" not in provider._sandbox_infos
assert "dead-beef" not in provider._last_activity
assert "thread-1" not in provider._thread_sandboxes
backend.destroy.assert_called_once_with(info)
def test_get_returns_dead_sandbox_when_auto_restart_disabled():
"""When auto_restart is off, get() returns the cached sandbox even if the container is dead."""
provider, backend = _make_provider(auto_restart=False, alive=False)
sandbox, _ = _seed_sandbox(provider)
result = provider.get("dead-beef")
assert result is sandbox
# Caches are untouched
assert "dead-beef" in provider._sandboxes
def test_get_eviction_cleans_multiple_thread_mappings():
"""A sandbox mapped to multiple thread IDs has all mappings cleaned on eviction."""
provider, backend = _make_provider(auto_restart=True, alive=False)
_seed_sandbox(provider, sandbox_id="sid-1", thread_id="t-a")
# Manually add a second thread mapping to the same sandbox
provider._thread_sandboxes["t-b"] = "sid-1"
result = provider.get("sid-1")
assert result is None
assert "t-a" not in provider._thread_sandboxes
assert "t-b" not in provider._thread_sandboxes
# ── get() does not check health for unknown sandbox IDs ────────────────────
def test_get_returns_none_for_unknown_id():
"""If the sandbox_id is not in cache, get() returns None without checking health."""
provider, backend = _make_provider(auto_restart=True, alive=True)
result = provider.get("nonexistent")
assert result is None
backend.is_alive.assert_not_called()
# ── get() handles missing sandbox_info gracefully ──────────────────────────
def test_get_handles_missing_info_gracefully():
"""If sandbox is cached but info is missing, get() skips the health check."""
provider, backend = _make_provider(auto_restart=True, alive=False)
sandbox = MagicMock()
provider._sandboxes["sid-x"] = sandbox
provider._sandbox_infos.pop("sid-x", None) # Ensure no info
provider._last_activity["sid-x"] = 0.0
result = provider.get("sid-x")
# No info → cannot call is_alive → sandbox returned as-is
assert result is sandbox
backend.is_alive.assert_not_called()
def test_get_liveness_check_runs_outside_provider_lock():
"""get() should not hold the provider lock while checking backend liveness."""
provider, backend = _make_provider(auto_restart=True, alive=False)
_seed_sandbox(provider, sandbox_id="sid-locked", thread_id="thread-1")
def _assert_lock_not_held(_):
assert not provider._lock.locked()
return False
backend.is_alive.side_effect = _assert_lock_not_held
assert provider.get("sid-locked") is None
def test_get_still_evicts_when_backend_destroy_fails():
"""Cleanup errors should not keep stale sandbox state in memory."""
provider, backend = _make_provider(auto_restart=True, alive=False)
_seed_sandbox(provider, sandbox_id="sid-fail", thread_id="thread-1")
backend.destroy.side_effect = RuntimeError("boom")
assert provider.get("sid-fail") is None
assert "sid-fail" not in provider._sandboxes
assert "sid-fail" not in provider._sandbox_infos
assert "thread-1" not in provider._thread_sandboxes
backend.destroy.assert_called_once()
# ── Integration: eviction clears caches for recreation ─────────────────────
def test_eviction_clears_all_caches_for_recreation():
"""After eviction, all caches are clean so _acquire_internal can recreate.
This verifies the preconditions for transparent restart: when get() evicts
a dead sandbox, the next _acquire_internal call will find no cached entry,
no warm-pool entry, and fall through to _create_sandbox.
"""
provider, backend = _make_provider(auto_restart=True, alive=False)
_seed_sandbox(provider, sandbox_id="sid-1", thread_id="thread-1")
# Before eviction: caches populated
assert "sid-1" in provider._sandboxes
assert "sid-1" in provider._sandbox_infos
assert "thread-1" in provider._thread_sandboxes
# get() detects the dead container and evicts
assert provider.get("sid-1") is None
# After eviction: all caches clean
assert "sid-1" not in provider._sandboxes
assert "sid-1" not in provider._sandbox_infos
assert "thread-1" not in provider._thread_sandboxes
assert "sid-1" not in provider._warm_pool
# _acquire_internal for the same thread would find nothing cached
# and generate the deterministic ID, then discover fails (container
# is gone), falling through to _create_sandbox — a fresh start.
+1 -213
View File
@@ -4,40 +4,10 @@ import json
import os
from pathlib import Path
import pytest
import yaml
from pydantic import ValidationError
import deerflow.config.app_config as app_config_module
from deerflow.config.acp_config import load_acp_config_from_dict
from deerflow.config.agents_api_config import get_agents_api_config, load_agents_api_config_from_dict
from deerflow.config.agents_api_config import get_agents_api_config
from deerflow.config.app_config import AppConfig, get_app_config, reset_app_config
from deerflow.config.checkpointer_config import get_checkpointer_config, load_checkpointer_config_from_dict
from deerflow.config.guardrails_config import get_guardrails_config, load_guardrails_config_from_dict
from deerflow.config.memory_config import get_memory_config, load_memory_config_from_dict
from deerflow.config.stream_bridge_config import get_stream_bridge_config, load_stream_bridge_config_from_dict
from deerflow.config.subagents_config import get_subagents_app_config, load_subagents_config_from_dict
from deerflow.config.summarization_config import get_summarization_config, load_summarization_config_from_dict
from deerflow.config.title_config import get_title_config, load_title_config_from_dict
from deerflow.config.tool_search_config import get_tool_search_config, load_tool_search_config_from_dict
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
from deerflow.runtime.store import get_store, reset_store
def _reset_config_singletons() -> None:
load_title_config_from_dict({})
load_summarization_config_from_dict({})
load_memory_config_from_dict({})
load_agents_api_config_from_dict({})
load_subagents_config_from_dict({})
load_tool_search_config_from_dict({})
load_guardrails_config_from_dict({})
load_checkpointer_config_from_dict(None)
load_stream_bridge_config_from_dict(None)
load_acp_config_from_dict({})
reset_checkpointer()
reset_store()
reset_app_config()
def _write_config(path: Path, *, model_name: str, supports_thinking: bool) -> None:
@@ -83,23 +53,6 @@ def _write_config_with_agents_api(
path.write_text(yaml.safe_dump(config), encoding="utf-8")
def _write_config_with_sections(path: Path, sections: dict | None = None) -> None:
config = {
"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"},
"models": [
{
"name": "first-model",
"use": "langchain_openai:ChatOpenAI",
"model": "gpt-test",
}
],
}
if sections:
config.update(sections)
path.write_text(yaml.safe_dump(config), encoding="utf-8")
def _write_extensions_config(path: Path) -> None:
path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8")
@@ -222,168 +175,3 @@ def test_get_app_config_resets_agents_api_config_when_section_removed(tmp_path,
assert get_agents_api_config().enabled is False
finally:
reset_app_config()
def test_get_app_config_resets_singleton_configs_when_sections_removed(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config_with_sections(
config_path,
{
"title": {"enabled": False, "max_words": 3},
"summarization": {"enabled": True},
"memory": {"enabled": False, "max_facts": 50},
"subagents": {"timeout_seconds": 42, "agents": {"reviewer": {"max_turns": 2}}},
"tool_search": {"enabled": True},
"guardrails": {"enabled": True, "fail_closed": False},
"checkpointer": {"type": "memory"},
"stream_bridge": {"type": "memory", "queue_maxsize": 12},
},
)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
reset_app_config()
try:
get_app_config()
assert get_title_config().enabled is False
assert get_summarization_config().enabled is True
assert get_memory_config().enabled is False
assert get_subagents_app_config().timeout_seconds == 42
assert get_tool_search_config().enabled is True
assert get_guardrails_config().enabled is True
assert get_checkpointer_config() is not None
assert get_stream_bridge_config() is not None
_write_config_with_sections(config_path)
next_mtime = config_path.stat().st_mtime + 5
os.utime(config_path, (next_mtime, next_mtime))
get_app_config()
assert get_title_config().enabled is True
assert get_summarization_config().enabled is False
assert get_memory_config().enabled is True
assert get_subagents_app_config().timeout_seconds == 900
assert get_tool_search_config().enabled is False
assert get_guardrails_config().enabled is False
assert get_checkpointer_config() is None
assert get_stream_bridge_config() is None
finally:
_reset_config_singletons()
def test_get_app_config_resets_persistence_runtime_singletons_when_checkpointer_removed(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config_with_sections(config_path, {"checkpointer": {"type": "memory"}})
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
reset_checkpointer()
reset_store()
reset_app_config()
try:
get_app_config()
initial_checkpointer = get_checkpointer()
initial_store = get_store()
_write_config_with_sections(config_path)
next_mtime = config_path.stat().st_mtime + 5
os.utime(config_path, (next_mtime, next_mtime))
get_app_config()
assert get_checkpointer_config() is None
assert get_checkpointer() is not initial_checkpointer
assert get_store() is not initial_store
finally:
_reset_config_singletons()
def test_get_app_config_keeps_persistence_runtime_singletons_when_checkpointer_unchanged(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config_with_sections(
config_path,
{
"title": {"enabled": False},
"checkpointer": {"type": "memory"},
},
)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
_reset_config_singletons()
try:
get_app_config()
initial_checkpointer = get_checkpointer()
initial_store = get_store()
_write_config_with_sections(
config_path,
{
"title": {"enabled": True},
"checkpointer": {"type": "memory"},
},
)
next_mtime = config_path.stat().st_mtime + 5
os.utime(config_path, (next_mtime, next_mtime))
get_app_config()
assert get_checkpointer() is initial_checkpointer
assert get_store() is initial_store
finally:
_reset_config_singletons()
def test_get_app_config_does_not_mutate_singletons_when_reload_validation_fails(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config_with_sections(
config_path,
{
"title": {"enabled": False},
"tool_search": {"enabled": True},
"checkpointer": {"type": "memory"},
},
)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
_reset_config_singletons()
try:
previous_app_config = get_app_config()
initial_checkpointer = get_checkpointer()
initial_store = get_store()
_write_config_with_sections(
config_path,
{
"title": False,
"tool_search": False,
"checkpointer": {"type": "memory"},
},
)
next_mtime = config_path.stat().st_mtime + 5
os.utime(config_path, (next_mtime, next_mtime))
with pytest.raises(ValidationError):
get_app_config()
assert app_config_module._app_config is previous_app_config
assert get_title_config().enabled is False
assert get_tool_search_config().enabled is True
assert get_checkpointer_config() is not None
assert get_checkpointer() is initial_checkpointer
assert get_store() is initial_store
finally:
_reset_config_singletons()
+1 -105
View File
@@ -3,12 +3,11 @@
from __future__ import annotations
import asyncio
import os
from pathlib import Path
from unittest.mock import MagicMock, patch
from app.channels.base import Channel
from app.channels.message_bus import InboundMessage, MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.message_bus import MessageBus, OutboundMessage, ResolvedAttachment
def _run(coro):
@@ -249,109 +248,6 @@ class TestResolveAttachments:
assert result[0].filename == "data.csv"
# ---------------------------------------------------------------------------
# Inbound file ingestion tests
# ---------------------------------------------------------------------------
class TestInboundFileIngestion:
def test_rejects_preexisting_symlink_destination(self, tmp_path):
from app.channels import manager
uploads_dir = tmp_path / "uploads"
uploads_dir.mkdir()
outside_file = tmp_path / "outside-created.txt"
(uploads_dir / "victim.txt").symlink_to(outside_file)
msg = InboundMessage(
channel_name="test-channel",
chat_id="chat-1",
user_id="user-1",
text="see attachment",
files=[{"filename": "victim.txt", "url": "https://example.invalid/victim.txt"}],
)
async def fake_reader(file_info, client):
return b"attacker data"
with (
patch("deerflow.uploads.manager.ensure_uploads_dir", return_value=uploads_dir),
patch.dict(manager.INBOUND_FILE_READERS, {"test-channel": fake_reader}, clear=False),
):
result = _run(manager._ingest_inbound_files("thread-1", msg))
assert result == []
assert not outside_file.exists()
assert (uploads_dir / "victim.txt").is_symlink()
def test_rejects_dangling_symlink_destination(self, tmp_path):
from app.channels import manager
uploads_dir = tmp_path / "uploads"
uploads_dir.mkdir()
missing_target = tmp_path / "missing-created.txt"
(uploads_dir / "victim.txt").symlink_to(missing_target)
msg = InboundMessage(
channel_name="test-channel",
chat_id="chat-1",
user_id="user-1",
text="see attachment",
files=[{"filename": "victim.txt", "url": "https://example.invalid/victim.txt"}],
)
async def fake_reader(file_info, client):
return b"attacker data"
with (
patch("deerflow.uploads.manager.ensure_uploads_dir", return_value=uploads_dir),
patch.dict(manager.INBOUND_FILE_READERS, {"test-channel": fake_reader}, clear=False),
):
result = _run(manager._ingest_inbound_files("thread-1", msg))
assert result == []
assert not missing_target.exists()
assert (uploads_dir / "victim.txt").is_symlink()
def test_hardlinked_existing_file_is_not_overwritten(self, tmp_path):
from app.channels import manager
uploads_dir = tmp_path / "uploads"
uploads_dir.mkdir()
outside_file = tmp_path / "outside-created.txt"
outside_file.write_text("protected", encoding="utf-8")
os.link(outside_file, uploads_dir / "victim.txt")
msg = InboundMessage(
channel_name="test-channel",
chat_id="chat-1",
user_id="user-1",
text="see attachment",
files=[{"filename": "victim.txt", "url": "https://example.invalid/victim.txt"}],
)
async def fake_reader(file_info, client):
return b"new attachment data"
with (
patch("deerflow.uploads.manager.ensure_uploads_dir", return_value=uploads_dir),
patch.dict(manager.INBOUND_FILE_READERS, {"test-channel": fake_reader}, clear=False),
):
result = _run(manager._ingest_inbound_files("thread-1", msg))
assert result == [
{
"filename": "victim_1.txt",
"size": len(b"new attachment data"),
"path": "/mnt/user-data/uploads/victim_1.txt",
"is_image": False,
}
]
assert outside_file.read_text(encoding="utf-8") == "protected"
assert (uploads_dir / "victim.txt").read_text(encoding="utf-8") == "protected"
assert (uploads_dir / "victim_1.txt").read_bytes() == b"new attachment data"
# ---------------------------------------------------------------------------
# Channel base class _on_outbound with attachments
# ---------------------------------------------------------------------------
-199
View File
@@ -372,37 +372,6 @@ class TestExtractResponseText:
# Should return "" (no text in current turn), NOT "Hi there!" from previous turn
assert _extract_response_text(result) == ""
def test_does_not_publish_loop_warning_on_tool_calling_ai_message(self):
"""Loop-detection warning text on a tool-calling AI message is middleware-authored."""
from app.channels.manager import _extract_response_text
result = {
"messages": [
{"type": "human", "content": "search the repo"},
{
"type": "ai",
"content": "[LOOP DETECTED] You are repeating the same tool calls.",
"tool_calls": [{"name": "grep", "args": {"pattern": "TODO"}, "id": "call_1"}],
},
]
}
assert _extract_response_text(result) == ""
def test_preserves_visible_text_when_stripping_loop_warning(self):
from app.channels.manager import _extract_response_text
result = {
"messages": [
{"type": "human", "content": "prepare the report"},
{
"type": "ai",
"content": "Here is the report.\n\n[LOOP DETECTED] You are repeating the same tool calls.",
"tool_calls": [{"name": "present_files", "args": {"filepaths": ["/mnt/user-data/outputs/report.md"]}, "id": "call_1"}],
},
]
}
assert _extract_response_text(result) == "Here is the report."
# ---------------------------------------------------------------------------
# ChannelManager tests
@@ -466,47 +435,6 @@ class TestChannelManager:
assert headers["Cookie"] == f"csrf_token={csrf_token}"
assert headers["X-DeerFlow-Internal-Token"]
def test_fetch_gateway_includes_internal_auth_headers(self, monkeypatch):
from app.channels.manager import ChannelManager
class MockResponse:
def raise_for_status(self):
return None
def json(self):
return {"models": [{"name": "default"}]}
class MockAsyncClient:
def __init__(self, *args, **kwargs):
return None
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def get(self, url, **kwargs):
calls.append({"url": url, **kwargs})
return MockResponse()
calls = []
monkeypatch.setattr("app.channels.manager.httpx.AsyncClient", MockAsyncClient)
async def go():
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
manager = ChannelManager(bus=bus, store=store, gateway_url="http://gateway:8001")
reply = await manager._fetch_gateway("/api/models", "models")
assert reply == "Available models:\n• default"
assert calls[0]["url"] == "http://gateway:8001/api/models"
assert calls[0]["timeout"] == 10
assert calls[0]["headers"]["X-DeerFlow-Internal-Token"]
_run(go())
def test_handle_chat_calls_channel_receive_file_for_inbound_files(self, monkeypatch):
from app.channels.manager import ChannelManager
@@ -602,8 +530,6 @@ class TestChannelManager:
assert call_args[0][0] == "test-thread-123" # thread_id
assert call_args[0][1] == "lead_agent" # assistant_id
assert call_args[1]["input"]["messages"][0]["content"] == "hi"
assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == ""
assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123"
assert len(outbound_received) == 1
assert outbound_received[0].text == "Hello from agent!"
@@ -735,135 +661,12 @@ class TestChannelManager:
call_args = mock_client.runs.wait.call_args
assert call_args[0][1] == "lead_agent"
assert call_args[1]["config"]["recursion_limit"] == 55
assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == ""
assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123"
assert call_args[1]["context"]["thinking_enabled"] is False
assert call_args[1]["context"]["subagent_enabled"] is True
assert call_args[1]["context"]["agent_name"] == "mobile-agent"
_run(go())
def test_clarification_follow_up_preserves_history(self):
"""Conversation should continue after ask_clarification instead of resetting history."""
from app.channels.manager import ChannelManager
async def go():
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
manager = ChannelManager(bus=bus, store=store)
outbound_received = []
async def capture_outbound(msg):
outbound_received.append(msg)
bus.subscribe_outbound(capture_outbound)
history_by_checkpoint: dict[tuple[str, str], list[str]] = {}
async def _runs_wait(thread_id, assistant_id, *, input, config, context):
del assistant_id, context # unused in this test, kept for signature parity
checkpoint_ns = config.get("configurable", {}).get("checkpoint_ns")
key = (thread_id, str(checkpoint_ns))
history = history_by_checkpoint.setdefault(key, [])
human_text = input["messages"][0]["content"]
history.append(human_text)
if len(history) == 1:
return {
"messages": [
{"type": "human", "content": history[0]},
{
"type": "ai",
"content": "",
"tool_calls": [
{
"name": "ask_clarification",
"args": {"question": "Which environment should I use?"},
}
],
},
{
"type": "tool",
"name": "ask_clarification",
"content": "Which environment should I use?",
},
]
}
if len(history) == 2 and history[0] == "Deploy my app" and history[1] == "prod":
return {
"messages": [
{"type": "human", "content": history[0]},
{
"type": "ai",
"content": "",
"tool_calls": [
{
"name": "ask_clarification",
"args": {"question": "Which environment should I use?"},
}
],
},
{
"type": "tool",
"name": "ask_clarification",
"content": "Which environment should I use?",
},
{"type": "human", "content": history[1]},
{"type": "ai", "content": "Got it. I will deploy to prod."},
]
}
return {
"messages": [
{"type": "human", "content": history[-1]},
{"type": "ai", "content": "History missing; clarification repeated."},
]
}
mock_client = MagicMock()
mock_client.threads.create = AsyncMock(return_value={"thread_id": "clarify-thread-1"})
mock_client.threads.get = AsyncMock(return_value={"thread_id": "clarify-thread-1"})
mock_client.runs.wait = AsyncMock(side_effect=_runs_wait)
manager._client = mock_client
await manager.start()
await bus.publish_inbound(
InboundMessage(
channel_name="test",
chat_id="chat1",
user_id="user1",
text="Deploy my app",
)
)
await _wait_for(lambda: len(outbound_received) >= 1)
await bus.publish_inbound(
InboundMessage(
channel_name="test",
chat_id="chat1",
user_id="user1",
text="prod",
)
)
await _wait_for(lambda: len(outbound_received) >= 2)
await manager.stop()
assert outbound_received[0].text == "Which environment should I use?"
assert outbound_received[1].text == "Got it. I will deploy to prod."
assert mock_client.runs.wait.call_count == 2
first_call = mock_client.runs.wait.call_args_list[0]
second_call = mock_client.runs.wait.call_args_list[1]
assert first_call.kwargs["config"]["configurable"]["checkpoint_ns"] == ""
assert second_call.kwargs["config"]["configurable"]["checkpoint_ns"] == ""
_run(go())
def test_handle_chat_uses_user_session_overrides(self):
from app.channels.manager import ChannelManager
@@ -1540,8 +1343,6 @@ class TestChannelManager:
call_args = mock_client.runs.stream.call_args
assert call_args[1]["input"]["messages"][0]["content"] == "hello"
assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == ""
assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123"
assert call_args[1]["context"]["is_bootstrap"] is True
# Final message should be published
+49 -41
View File
@@ -1,8 +1,7 @@
"""Unit tests for checkpointer config, packaging metadata, and factories."""
"""Unit tests for checkpointer config and singleton factory."""
import sys
import tomllib
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -15,8 +14,6 @@ from deerflow.config.checkpointer_config import (
set_checkpointer_config,
)
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
from deerflow.runtime.checkpointer.provider import POSTGRES_INSTALL
from deerflow.runtime.store.provider import POSTGRES_STORE_INSTALL
@pytest.fixture(autouse=True)
@@ -71,42 +68,6 @@ class TestCheckpointerConfig:
with pytest.raises(Exception):
load_checkpointer_config_from_dict({"type": "unknown"})
def test_connection_string_description_matches_runtime_defaults(self):
description = CheckpointerConfig.model_fields["connection_string"].description
assert description is not None
assert "Optional for sqlite" in description
assert "defaults to 'store.db'" in description
assert "Required for postgres" in description
class TestHarnessPackaging:
def test_pyproject_declares_postgres_extra(self):
pyproject_path = Path(__file__).resolve().parents[1] / "packages" / "harness" / "pyproject.toml"
data = tomllib.loads(pyproject_path.read_text())
optional_dependencies = data["project"]["optional-dependencies"]
assert "postgres" in optional_dependencies
assert optional_dependencies["postgres"] == [
"asyncpg>=0.29",
"langgraph-checkpoint-postgres>=3.0.5",
"psycopg[binary]>=3.3.3",
"psycopg-pool>=3.3.0",
]
def test_workspace_pyproject_forwards_postgres_extra_to_harness(self):
pyproject_path = Path(__file__).resolve().parents[1] / "pyproject.toml"
data = tomllib.loads(pyproject_path.read_text())
optional_dependencies = data["project"]["optional-dependencies"]
assert optional_dependencies["postgres"] == ["deerflow-harness[postgres]"]
def test_postgres_missing_dependency_messages_recommend_package_extra(self):
assert "deerflow-harness[postgres]" in POSTGRES_INSTALL
assert "deerflow-harness[postgres]" in POSTGRES_STORE_INSTALL
assert "uv sync --all-packages --extra postgres" in POSTGRES_INSTALL
assert "uv sync --all-packages --extra postgres" in POSTGRES_STORE_INSTALL
# ---------------------------------------------------------------------------
# Factory tests
@@ -143,6 +104,53 @@ class TestGetCheckpointer:
cp2 = get_checkpointer()
assert cp1 is not cp2
def test_explicit_app_config_bypasses_global_config_lookup(self):
from langgraph.checkpoint.memory import InMemorySaver
explicit_config = SimpleNamespace(
checkpointer=CheckpointerConfig(type="memory"),
database=SimpleNamespace(backend="memory"),
)
with patch(
"deerflow.runtime.checkpointer.provider.get_app_config",
side_effect=AssertionError("ambient get_app_config() must not be used when app_config is explicit"),
):
cp = get_checkpointer(app_config=explicit_config)
assert isinstance(cp, InMemorySaver)
def test_explicit_app_config_uses_unified_database_sqlite_backend(self):
explicit_config = SimpleNamespace(
checkpointer=None,
database=SimpleNamespace(backend="sqlite", checkpointer_sqlite_path="/tmp/explicit/deerflow.db"),
)
mock_saver_instance = MagicMock()
mock_cm = MagicMock()
mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance)
mock_cm.__exit__ = MagicMock(return_value=False)
mock_saver_cls = MagicMock()
mock_saver_cls.from_conn_string = MagicMock(return_value=mock_cm)
mock_module = MagicMock()
mock_module.SqliteSaver = mock_saver_cls
with (
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}),
patch(
"deerflow.runtime.checkpointer.provider.get_app_config",
side_effect=AssertionError("ambient get_app_config() must not be used when app_config is explicit"),
),
patch("deerflow.runtime.checkpointer.provider.ensure_sqlite_parent_dir") as mock_ensure,
):
cp = get_checkpointer(app_config=explicit_config)
assert cp is mock_saver_instance
mock_ensure.assert_called_once_with("/tmp/explicit/deerflow.db")
mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/explicit/deerflow.db")
def test_sqlite_raises_when_package_missing(self):
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": None}):
+22 -79
View File
@@ -437,85 +437,6 @@ class TestStream:
call_kwargs = agent.stream.call_args.kwargs
assert "messages" in call_kwargs["stream_mode"]
def test_stream_emits_additional_kwargs_updates_for_streamed_ai_messages(self, client):
"""stream() emits a follow-up AI event when attribution metadata arrives via values."""
assembled = AIMessage(
content="Hello!",
id="ai-1",
additional_kwargs={
"token_usage_attribution": {
"version": 1,
"kind": "final_answer",
"shared_attribution": False,
"actions": [],
}
},
)
agent = MagicMock()
agent.stream.return_value = iter(
[
("messages", (AIMessageChunk(content="Hello!", id="ai-1"), {})),
("values", {"messages": [HumanMessage(content="hi", id="h-1"), assembled]}),
]
)
with (
patch.object(client, "_ensure_agent"),
patch.object(client, "_agent", agent),
):
events = list(client.stream("hi", thread_id="t-stream-kwargs"))
ai_events = [event for event in events if event.type == "messages-tuple" and event.data.get("type") == "ai" and event.data.get("id") == "ai-1"]
assert any(event.data.get("content") == "Hello!" for event in ai_events)
assert any(event.data.get("additional_kwargs", {}).get("token_usage_attribution", {}).get("kind") == "final_answer" for event in ai_events)
def test_stream_emits_new_additional_kwargs_after_prior_metadata(self, client):
"""stream() emits later attribution metadata even after earlier kwargs for the same id."""
attribution = {
"version": 1,
"kind": "final_answer",
"shared_attribution": False,
"actions": [],
}
assembled = AIMessage(
content="Hello!",
id="ai-1",
additional_kwargs={
"reasoning_content": "Thinking first.",
"token_usage_attribution": attribution,
},
)
agent = MagicMock()
agent.stream.return_value = iter(
[
(
"messages",
(
AIMessageChunk(
content="Hello!",
id="ai-1",
additional_kwargs={"reasoning_content": "Thinking first."},
),
{},
),
),
("values", {"messages": [HumanMessage(content="hi", id="h-1"), assembled]}),
]
)
with (
patch.object(client, "_ensure_agent"),
patch.object(client, "_agent", agent),
):
events = list(client.stream("hi", thread_id="t-stream-kwargs-delta"))
ai_events = [event for event in events if event.type == "messages-tuple" and event.data.get("type") == "ai" and event.data.get("id") == "ai-1"]
metadata_events = [event for event in ai_events if event.data.get("additional_kwargs")]
assert metadata_events[0].data["additional_kwargs"] == {"reasoning_content": "Thinking first."}
assert metadata_events[1].data["content"] == ""
assert metadata_events[1].data["additional_kwargs"] == {"token_usage_attribution": attribution}
def test_chat_accumulates_streamed_deltas(self, client):
"""chat() concatenates per-id deltas from messages mode."""
agent = MagicMock()
@@ -927,6 +848,28 @@ class TestEnsureAgent:
assert mock_apply_prompt.call_args.kwargs.get("agent_name") == "custom-agent"
assert mock_apply_prompt.call_args.kwargs.get("available_skills") == {"test_skill"}
def test_threads_explicit_app_config_to_dependencies(self, client):
"""Client-owned AppConfig must flow into model/tool/prompt/checkpointer composition."""
mock_agent = MagicMock()
mock_checkpointer = MagicMock()
config = client._get_runnable_config("t1")
with (
patch("deerflow.client.create_chat_model", return_value=MagicMock()) as mock_create_chat_model,
patch("deerflow.client.create_agent", return_value=mock_agent),
patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
patch("deerflow.tools.get_available_tools", return_value=[]) as mock_get_available_tools,
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=mock_checkpointer) as mock_get_checkpointer,
):
client._ensure_agent(config)
assert mock_create_chat_model.call_args.kwargs["app_config"] is client._app_config
assert mock_build_middlewares.call_args.kwargs["app_config"] is client._app_config
assert mock_apply_prompt.call_args.kwargs["app_config"] is client._app_config
assert mock_get_available_tools.call_args.kwargs["app_config"] is client._app_config
assert mock_get_checkpointer.call_args.kwargs["app_config"] is client._app_config
def test_uses_default_checkpointer_when_available(self, client):
mock_agent = MagicMock()
mock_checkpointer = MagicMock()
@@ -1,53 +0,0 @@
"""Tests for DeerFlowClient message serialization helpers."""
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.client import DeerFlowClient
def test_serialize_ai_message_preserves_additional_kwargs():
message = AIMessage(
content="done",
additional_kwargs={
"token_usage_attribution": {
"version": 1,
"kind": "final_answer",
"shared_attribution": False,
"actions": [],
}
},
usage_metadata={"input_tokens": 12, "output_tokens": 3, "total_tokens": 15},
)
serialized = DeerFlowClient._serialize_message(message)
assert serialized["type"] == "ai"
assert serialized["usage_metadata"] == {
"input_tokens": 12,
"output_tokens": 3,
"total_tokens": 15,
}
assert serialized["additional_kwargs"] == {
"token_usage_attribution": {
"version": 1,
"kind": "final_answer",
"shared_attribution": False,
"actions": [],
}
}
def test_serialize_human_message_preserves_additional_kwargs():
message = HumanMessage(
content="hello",
additional_kwargs={"files": [{"name": "diagram.png"}]},
)
serialized = DeerFlowClient._serialize_message(message)
assert serialized == {
"type": "human",
"content": "hello",
"id": None,
"additional_kwargs": {"files": [{"name": "diagram.png"}]},
}
-30
View File
@@ -82,36 +82,6 @@ def test_parse_response_text_content():
assert result.generations[0].message.content == "Hello world"
def test_parse_response_populates_usage_metadata():
model = _make_model()
response = {
"output": [
{
"type": "message",
"content": [{"type": "output_text", "text": "Hello world"}],
}
],
"usage": {
"input_tokens": 10,
"output_tokens": 5,
"total_tokens": 15,
"input_tokens_details": {"cached_tokens": 3},
"output_tokens_details": {"reasoning_tokens": 2},
},
"model": "gpt-5.4",
}
result = model._parse_response(response)
meta = result.generations[0].message.usage_metadata
assert meta is not None
assert meta["input_tokens"] == 10
assert meta["output_tokens"] == 5
assert meta["total_tokens"] == 15
assert meta["input_token_details"]["cache_read"] == 3
assert meta["output_token_details"]["reasoning"] == 2
def test_parse_response_reasoning_content():
model = _make_model()
response = {
@@ -192,7 +192,6 @@ def test_agent_features_defaults():
assert f.vision is False
assert f.auto_title is False
assert f.guardrail is False
assert f.loop_detection is True
# ---------------------------------------------------------------------------
@@ -631,51 +630,6 @@ def test_loop_detection_before_clarification(mock_create_agent):
assert loop_idx == clar_idx - 1
# ---------------------------------------------------------------------------
# 30b. loop_detection=False skips LoopDetectionMiddleware
# ---------------------------------------------------------------------------
@patch("deerflow.agents.factory.create_agent")
def test_loop_detection_disabled(mock_create_agent):
mock_create_agent.return_value = MagicMock()
create_deerflow_agent(
_make_mock_model(),
features=RuntimeFeatures(sandbox=False, loop_detection=False),
)
call_kwargs = mock_create_agent.call_args[1]
mw_types = [type(m).__name__ for m in call_kwargs["middleware"]]
assert "LoopDetectionMiddleware" not in mw_types
# ---------------------------------------------------------------------------
# 30c. loop_detection=<custom AgentMiddleware> replaces the default
# ---------------------------------------------------------------------------
@patch("deerflow.agents.factory.create_agent")
def test_loop_detection_custom_middleware(mock_create_agent):
from langchain.agents.middleware import AgentMiddleware as AM
mock_create_agent.return_value = MagicMock()
class MyLoopDetection(AM):
pass
custom = MyLoopDetection()
create_deerflow_agent(
_make_mock_model(),
features=RuntimeFeatures(sandbox=False, loop_detection=custom),
)
call_kwargs = mock_create_agent.call_args[1]
middleware = call_kwargs["middleware"]
assert custom in middleware
mw_types = [type(m).__name__ for m in middleware]
# Default LoopDetectionMiddleware must not also appear.
assert "LoopDetectionMiddleware" not in mw_types
# Custom replacement still sits immediately before ClarificationMiddleware.
assert mw_types[-1] == "ClarificationMiddleware"
assert mw_types[-2] == "MyLoopDetection"
# ---------------------------------------------------------------------------
# 31. plan_mode=True adds TodoMiddleware
# ---------------------------------------------------------------------------
-2
View File
@@ -85,8 +85,6 @@ def test_load_claude_code_credential_from_override_path(tmp_path, monkeypatch):
def test_load_claude_code_credential_ignores_directory_path(tmp_path, monkeypatch):
_clear_claude_code_env(monkeypatch)
# Redirect HOME so the default ~/.claude/.credentials.json doesn't exist
monkeypatch.setenv("HOME", str(tmp_path))
cred_dir = tmp_path / "claude-creds-dir"
cred_dir.mkdir()
monkeypatch.setenv("CLAUDE_CODE_CREDENTIALS_PATH", str(cred_dir))
-235
View File
@@ -1,235 +0,0 @@
"""Tests for CSRF middleware."""
from fastapi import FastAPI
from starlette.testclient import TestClient
from app.gateway.csrf_middleware import CSRFMiddleware
def _make_app() -> FastAPI:
app = FastAPI()
app.add_middleware(CSRFMiddleware)
@app.post("/api/v1/auth/login/local")
async def login_local():
return {"ok": True}
@app.post("/api/v1/auth/register")
async def register():
return {"ok": True}
@app.post("/api/threads/abc/runs/stream")
async def protected_mutation():
return {"ok": True}
return app
def test_auth_post_rejects_cross_origin_browser_request():
"""CSRF-exempt auth routes must not accept hostile browser origins.
Login/register endpoints intentionally skip the double-submit token because
first-time callers do not have a token yet. They still set an auth session,
so a hostile cross-site form POST must be rejected to avoid login CSRF /
session fixation.
"""
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://evil.example"},
)
assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
def test_auth_post_allows_same_origin_browser_request():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example"},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_rejects_malformed_origin_with_path():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example/path"},
)
assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
assert response.cookies.get("csrf_token") is None
def test_auth_post_rejects_malformed_origin_with_invalid_port():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example:bad"},
)
assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
assert response.cookies.get("csrf_token") is None
def test_auth_post_allows_same_origin_default_port_equivalence():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example:443"},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_allows_forwarded_same_origin():
client = TestClient(_make_app(), base_url="http://internal:8000")
response = client.post(
"/api/v1/auth/login/local",
headers={
"Origin": "https://deerflow.example",
"X-Forwarded-Proto": "https",
"X-Forwarded-Host": "deerflow.example, internal:8000",
},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_allows_forwarded_same_origin_with_non_default_port():
client = TestClient(_make_app(), base_url="http://internal:8000")
response = client.post(
"/api/v1/auth/login/local",
headers={
"Origin": "http://localhost:2026",
"X-Forwarded-Proto": "http",
"X-Forwarded-Host": "localhost:2026",
},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_allows_rfc_forwarded_same_origin():
client = TestClient(_make_app(), base_url="http://internal:8000")
response = client.post(
"/api/v1/auth/login/local",
headers={
"Origin": "https://deerflow.example",
"Forwarded": "proto=https;host=deerflow.example",
},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
assert "secure" in response.headers["set-cookie"].lower()
def test_auth_post_allows_explicit_configured_origin(monkeypatch):
monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "https://app.example")
client = TestClient(_make_app(), base_url="https://api.example")
response = client.post(
"/api/v1/auth/register",
headers={"Origin": "https://app.example"},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_does_not_treat_wildcard_cors_as_allowed_origin(monkeypatch):
monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "*")
client = TestClient(_make_app(), base_url="https://api.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://evil.example"},
)
assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
def test_auth_post_sets_strict_samesite_csrf_cookie():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example"},
)
assert response.status_code == 200
set_cookie = response.headers["set-cookie"].lower()
assert "csrf_token=" in set_cookie
assert "samesite=strict" in set_cookie
assert "secure" in set_cookie
def test_auth_post_without_origin_still_allows_non_browser_clients():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post("/api/v1/auth/login/local")
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_non_auth_mutation_still_requires_double_submit_token():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/threads/abc/runs/stream",
headers={"Origin": "https://deerflow.example"},
)
assert response.status_code == 403
assert response.json()["detail"] == "CSRF token missing. Include X-CSRF-Token header."
def test_non_auth_mutation_allows_valid_double_submit_token():
client = TestClient(_make_app(), base_url="https://deerflow.example")
client.cookies.set("csrf_token", "known-token")
response = client.post(
"/api/threads/abc/runs/stream",
headers={
"Origin": "https://deerflow.example",
"X-CSRF-Token": "known-token",
},
)
assert response.status_code == 200
def test_non_auth_mutation_rejects_mismatched_double_submit_token():
client = TestClient(_make_app(), base_url="https://deerflow.example")
client.cookies.set("csrf_token", "cookie-token")
response = client.post(
"/api/threads/abc/runs/stream",
headers={
"Origin": "https://deerflow.example",
"X-CSRF-Token": "header-token",
},
)
assert response.status_code == 403
assert response.json()["detail"] == "CSRF token mismatch."
+2 -16
View File
@@ -537,10 +537,7 @@ class TestAgentsAPI:
def test_create_persists_files_on_disk(self, agent_client, tmp_path):
agent_client.post("/api/agents", json={"name": "disk-check", "soul": "disk soul"})
# tests/conftest.py installs an autouse fixture that sets the
# contextvar to "test-user-autouse", so the agent is persisted under
# users/test-user-autouse/agents/ rather than the legacy shared dir.
agent_dir = tmp_path / "users" / "test-user-autouse" / "agents" / "disk-check"
agent_dir = tmp_path / "agents" / "disk-check"
assert agent_dir.exists()
assert (agent_dir / "config.yaml").exists()
assert (agent_dir / "SOUL.md").exists()
@@ -548,23 +545,12 @@ class TestAgentsAPI:
def test_delete_removes_files_from_disk(self, agent_client, tmp_path):
agent_client.post("/api/agents", json={"name": "remove-me", "soul": "bye"})
agent_dir = tmp_path / "users" / "test-user-autouse" / "agents" / "remove-me"
agent_dir = tmp_path / "agents" / "remove-me"
assert agent_dir.exists()
agent_client.delete("/api/agents/remove-me")
assert not agent_dir.exists()
def test_create_rejects_legacy_name_collision(self, agent_client, tmp_path):
"""An unmigrated legacy agent must still block name collision so that
running the migration script later won't shadow the legacy entry."""
legacy_dir = tmp_path / "agents" / "legacy-agent"
legacy_dir.mkdir(parents=True)
(legacy_dir / "config.yaml").write_text("name: legacy-agent\n", encoding="utf-8")
(legacy_dir / "SOUL.md").write_text("legacy soul", encoding="utf-8")
response = agent_client.post("/api/agents", json={"name": "legacy-agent", "soul": "x"})
assert response.status_code == 409
# ===========================================================================
# 9. Gateway API User Profile endpoints
-201
View File
@@ -1,201 +0,0 @@
"""Unit tests for scripts/detect_uv_extras.py.
The detector resolves uv extras for `make dev` so that postgres (and any
future opt-in extras) are not wiped on every restart see Issue #2754.
"""
from __future__ import annotations
import importlib.util
from pathlib import Path
import pytest
REPO_ROOT = Path(__file__).resolve().parents[2]
DETECT_SCRIPT_PATH = REPO_ROOT / "scripts" / "detect_uv_extras.py"
spec = importlib.util.spec_from_file_location("deerflow_detect_uv_extras", DETECT_SCRIPT_PATH)
assert spec is not None and spec.loader is not None
detect = importlib.util.module_from_spec(spec)
spec.loader.exec_module(detect)
@pytest.fixture
def isolated_cwd(tmp_path, monkeypatch):
"""Isolate `find_config_file()` from the real repo by chdir + clearing env."""
monkeypatch.chdir(tmp_path)
monkeypatch.delenv("UV_EXTRAS", raising=False)
monkeypatch.delenv("DEER_FLOW_CONFIG_PATH", raising=False)
return tmp_path
def test_parse_env_extras_supports_comma_and_whitespace():
assert detect.parse_env_extras("postgres") == ["postgres"]
assert detect.parse_env_extras("postgres,ollama") == ["postgres", "ollama"]
assert detect.parse_env_extras("postgres ollama") == ["postgres", "ollama"]
assert detect.parse_env_extras(" postgres , ollama ,") == ["postgres", "ollama"]
assert detect.parse_env_extras("") == []
def test_parse_env_extras_drops_shell_metacharacters(capsys):
"""A `.env` value containing shell injection bait must not pass through.
The whitelist guarantees the *bytes* that reach `uv sync` cannot include
shell metacharacters. Any name that looks identifier-like still survives
(uv itself will reject unknown extras with its own error), but `;`, `&`,
backticks, parentheses, slashes, etc. are stripped.
"""
# Pure-metacharacter inputs collapse to empty.
assert detect.parse_env_extras(";") == []
assert detect.parse_env_extras("$(whoami)") == []
assert detect.parse_env_extras("`echo bad`") == []
assert detect.parse_env_extras("postgres;evil") == [] # single token, contains `;`
# Splitting on whitespace yields ['rm'] which is identifier-shaped, but the
# destructive bits (`;`, `-rf`, `/`) are dropped.
assert detect.parse_env_extras("; rm -rf /") == ["rm"]
err = capsys.readouterr().err
assert "ignoring invalid UV_EXTRAS entry" in err
assert "';'" in err # confirms the dangerous token was reported and dropped
def test_parse_env_extras_rejects_leading_digits_and_punctuation():
"""Names must start with a letter — pyproject extras follow this shape."""
assert detect.parse_env_extras("1postgres") == []
assert detect.parse_env_extras("-postgres") == []
# Hyphens and underscores inside the name are fine.
assert detect.parse_env_extras("post_gres") == ["post_gres"]
assert detect.parse_env_extras("post-gres") == ["post-gres"]
def test_format_flags_emits_one_flag_per_extra():
assert detect.format_flags([]) == ""
assert detect.format_flags(["postgres"]) == "--extra postgres"
assert detect.format_flags(["postgres", "ollama"]) == "--extra postgres --extra ollama"
def test_strip_comment_preserves_quoted_hash():
assert detect._strip_comment("backend: postgres # trailing") == "backend: postgres"
assert detect._strip_comment('name: "value#with-hash"') == 'name: "value#with-hash"'
assert detect._strip_comment("# whole line comment") == ""
def test_section_value_finds_nested_key():
yaml_lines = [
"database:",
" backend: postgres",
" postgres_url: $DATABASE_URL",
"",
"checkpointer:",
" type: sqlite",
]
assert detect.section_value(yaml_lines, "database", "backend") == "postgres"
assert detect.section_value(yaml_lines, "checkpointer", "type") == "sqlite"
assert detect.section_value(yaml_lines, "database", "missing") is None
assert detect.section_value(yaml_lines, "absent_section", "anything") is None
def test_section_value_ignores_commented_lines():
yaml_lines = [
"# database:",
"# backend: postgres",
"database:",
" backend: sqlite",
]
assert detect.section_value(yaml_lines, "database", "backend") == "sqlite"
def test_section_value_strips_quotes():
yaml_lines = [
"database:",
' backend: "postgres"',
]
assert detect.section_value(yaml_lines, "database", "backend") == "postgres"
def test_section_value_does_not_descend_into_grandchildren():
yaml_lines = [
"database:",
" backend: sqlite",
" nested:",
" backend: postgres",
]
# Only the immediate child level counts — keeps the parser predictable.
assert detect.section_value(yaml_lines, "database", "backend") == "sqlite"
def test_detect_from_config_postgres_via_database(tmp_path):
cfg = tmp_path / "config.yaml"
cfg.write_text("database:\n backend: postgres\n postgres_url: $DATABASE_URL\n")
assert detect.detect_from_config(cfg) == ["postgres"]
def test_detect_from_config_postgres_via_checkpointer(tmp_path):
cfg = tmp_path / "config.yaml"
cfg.write_text("checkpointer:\n type: postgres\n connection_string: postgresql://localhost/db\n")
assert detect.detect_from_config(cfg) == ["postgres"]
def test_detect_from_config_sqlite_returns_no_extras(tmp_path):
cfg = tmp_path / "config.yaml"
cfg.write_text("database:\n backend: sqlite\n sqlite_dir: .deer-flow/data\n")
assert detect.detect_from_config(cfg) == []
def test_detect_from_config_dedupes_when_both_present(tmp_path):
cfg = tmp_path / "config.yaml"
cfg.write_text("checkpointer:\n type: postgres\ndatabase:\n backend: postgres\n")
# Sorted unique extras, no double-counting.
assert detect.detect_from_config(cfg) == ["postgres"]
def test_detect_from_config_missing_file_returns_empty(tmp_path):
assert detect.detect_from_config(tmp_path / "does-not-exist.yaml") == []
def test_resolve_extras_env_overrides_config(isolated_cwd, monkeypatch):
cfg = isolated_cwd / "config.yaml"
cfg.write_text("database:\n backend: sqlite\n")
monkeypatch.setenv("UV_EXTRAS", "postgres")
assert detect.resolve_extras() == ["postgres"]
def test_resolve_extras_env_supports_multiple(isolated_cwd, monkeypatch):
monkeypatch.setenv("UV_EXTRAS", "postgres,ollama")
assert detect.resolve_extras() == ["postgres", "ollama"]
def test_resolve_extras_falls_back_to_config(isolated_cwd):
(isolated_cwd / "config.yaml").write_text("database:\n backend: postgres\n")
assert detect.resolve_extras() == ["postgres"]
def test_resolve_extras_respects_explicit_config_path(tmp_path, monkeypatch):
monkeypatch.delenv("UV_EXTRAS", raising=False)
elsewhere = tmp_path / "elsewhere.yaml"
elsewhere.write_text("database:\n backend: postgres\n")
monkeypatch.chdir(tmp_path)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(elsewhere))
assert detect.resolve_extras() == ["postgres"]
def test_resolve_extras_no_config_no_env(isolated_cwd):
assert detect.resolve_extras() == []
def test_resolve_extras_finds_backend_subdir_config(isolated_cwd):
sub = isolated_cwd / "backend"
sub.mkdir()
(sub / "config.yaml").write_text("database:\n backend: postgres\n")
assert detect.resolve_extras() == ["postgres"]
def test_resolve_extras_root_config_takes_precedence(isolated_cwd):
(isolated_cwd / "config.yaml").write_text("database:\n backend: sqlite\n")
sub = isolated_cwd / "backend"
sub.mkdir()
(sub / "config.yaml").write_text("database:\n backend: postgres\n")
# Root config.yaml is checked first, matching the precedence in serve.sh.
assert detect.resolve_extras() == []
-102
View File
@@ -1,102 +0,0 @@
"""Unit tests for docker/dev-entrypoint.sh (UV_EXTRAS validation + parsing).
Exercises the script via its `--print-extras` dry-run hook so we don't actually
launch uvicorn or hit /app/logs. Together with test_detect_uv_extras.py these
cover both the local make-dev path and the docker-compose-dev path with the
same shape see PR #2767 / Issue #2754.
"""
from __future__ import annotations
import os
import subprocess
from pathlib import Path
import pytest
REPO_ROOT = Path(__file__).resolve().parents[2]
ENTRYPOINT = REPO_ROOT / "docker" / "dev-entrypoint.sh"
def _run(uv_extras: str | None) -> subprocess.CompletedProcess[str]:
"""Invoke `dev-entrypoint.sh --print-extras` with UV_EXTRAS set."""
env = os.environ.copy()
env.pop("UV_EXTRAS", None)
if uv_extras is not None:
env["UV_EXTRAS"] = uv_extras
return subprocess.run(
["sh", str(ENTRYPOINT), "--print-extras"],
env=env,
capture_output=True,
text=True,
check=False,
)
def test_entrypoint_script_exists_and_is_posix_sh():
assert ENTRYPOINT.is_file()
# Catch syntax errors before runtime — `sh -n` is a parse-only check.
proc = subprocess.run(["sh", "-n", str(ENTRYPOINT)], capture_output=True, text=True, check=False)
assert proc.returncode == 0, proc.stderr
def test_no_uv_extras_yields_empty_flags():
proc = _run(None)
assert proc.returncode == 0
assert proc.stdout.strip() == ""
def test_single_extra():
proc = _run("postgres")
assert proc.returncode == 0
assert proc.stdout.strip() == "--extra postgres"
def test_multi_extra_comma_separated():
proc = _run("postgres,ollama")
assert proc.returncode == 0
assert proc.stdout.strip() == "--extra postgres --extra ollama"
def test_multi_extra_whitespace_separated():
proc = _run("postgres ollama")
assert proc.returncode == 0
assert proc.stdout.strip() == "--extra postgres --extra ollama"
def test_multi_extra_mixed_separators():
proc = _run(" postgres , ollama ,")
assert proc.returncode == 0
assert proc.stdout.strip() == "--extra postgres --extra ollama"
def test_empty_string_yields_empty_flags():
proc = _run("")
assert proc.returncode == 0
assert proc.stdout.strip() == ""
@pytest.mark.parametrize(
"bad_value",
[
"; rm -rf /", # the canonical injection attempt
"$(whoami)", # command substitution
"`echo bad`", # backticks
"postgres;evil", # mixed legal+illegal in a single token
"1postgres", # leading digit
"-postgres", # leading hyphen
"post gres extra/path", # contains slash
],
)
def test_metacharacters_abort_with_nonzero_exit(bad_value):
proc = _run(bad_value)
assert proc.returncode != 0, f"expected abort for {bad_value!r}, got 0"
assert "is invalid" in proc.stderr
assert proc.stdout.strip() == ""
def test_underscores_and_hyphens_in_name_are_allowed():
"""Mirrors uv's accepted shape for `[project.optional-dependencies]` keys."""
proc = _run("post_gres,post-gres")
assert proc.returncode == 0
assert proc.stdout.strip() == "--extra post_gres --extra post-gres"
@@ -1,336 +0,0 @@
"""Tests for DynamicContextMiddleware.
Verifies that memory and current date are injected as a <system-reminder> into
the first HumanMessage exactly once per session (frozen-snapshot pattern).
"""
from types import SimpleNamespace
from unittest import mock
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.agents.middlewares.dynamic_context_middleware import (
_DYNAMIC_CONTEXT_REMINDER_KEY,
DynamicContextMiddleware,
)
_SYSTEM_REMINDER_TAG = "<system-reminder>"
def _make_middleware(**kwargs) -> DynamicContextMiddleware:
return DynamicContextMiddleware(**kwargs)
def _fake_runtime():
return SimpleNamespace(context={})
def _reminder_msg(content: str, msg_id: str) -> HumanMessage:
"""Build a reminder HumanMessage the way the middleware would produce it."""
return HumanMessage(
content=content,
id=msg_id,
additional_kwargs={"hide_from_ui": True, _DYNAMIC_CONTEXT_REMINDER_KEY: True},
)
# ---------------------------------------------------------------------------
# Basic injection
# ---------------------------------------------------------------------------
def test_injects_system_reminder_into_first_human_message():
mw = _make_middleware()
state = {"messages": [HumanMessage(content="Hello", id="msg-1")]}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
assert result is not None
updated_msgs = result["messages"]
assert len(updated_msgs) == 2
reminder_msg = updated_msgs[0]
assert isinstance(reminder_msg, HumanMessage)
assert reminder_msg.id == "msg-1" # takes the original ID (position swap)
assert reminder_msg.additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
assert _SYSTEM_REMINDER_TAG in reminder_msg.content
assert "<current_date>2026-05-08, Friday</current_date>" in reminder_msg.content
assert "Hello" not in reminder_msg.content # reminder only — no user text
user_msg = updated_msgs[1]
assert isinstance(user_msg, HumanMessage)
assert user_msg.id == "msg-1__user" # derived ID
assert user_msg.content == "Hello"
def test_memory_included_when_present():
mw = _make_middleware()
state = {"messages": [HumanMessage(content="Hi", id="msg-1")]}
with (
mock.patch(
"deerflow.agents.lead_agent.prompt._get_memory_context",
return_value="<memory>\nUser prefers Python.\n</memory>",
),
mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt,
):
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
# Reminder is the first returned message; user query is the second
reminder_content = result["messages"][0].content
assert "User prefers Python." in reminder_content
assert "<current_date>2026-05-08, Friday</current_date>" in reminder_content
assert result["messages"][1].content == "Hi"
# ---------------------------------------------------------------------------
# Frozen-snapshot: no re-injection within a session
# ---------------------------------------------------------------------------
def test_skips_injection_if_already_present():
"""Second turn: separate reminder message already present → no update."""
mw = _make_middleware()
reminder_content = "<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>"
state = {
"messages": [
_reminder_msg(reminder_content, "msg-1"),
HumanMessage(content="Hello", id="msg-1__user"),
AIMessage(content="Hi there"),
HumanMessage(content="Follow-up", id="msg-2"),
]
}
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
assert result is None # no update needed
def test_injects_only_into_first_human_message_not_later_ones():
"""Reminder targets the first HumanMessage; subsequent messages are not touched."""
mw = _make_middleware()
state = {
"messages": [
HumanMessage(content="First", id="msg-1"),
AIMessage(content="Reply"),
HumanMessage(content="Second", id="msg-2"),
]
}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
assert result is not None
msgs = result["messages"]
# Only the two injected messages are returned (reminder + original first query)
assert len(msgs) == 2
assert msgs[0].id == "msg-1" # reminder takes first message's ID
assert msgs[0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
assert _SYSTEM_REMINDER_TAG in msgs[0].content
assert msgs[1].id == "msg-1__user" # original content with derived ID
assert msgs[1].content == "First"
# "Second" (msg-2) is not in the returned update — it is left unchanged
assert all(m.id != "msg-2" for m in msgs)
def test_summary_human_message_is_not_used_as_injection_target():
"""After summarization, the synthetic summary HumanMessage is not a user turn."""
mw = _make_middleware()
state = {
"messages": [
HumanMessage(content="Here is a summary of the conversation to date:\n\n...", id="summary-1", name="summary"),
AIMessage(content="Earlier reply"),
HumanMessage(content="Follow-up", id="msg-2"),
]
}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
assert result is not None
msgs = result["messages"]
assert len(msgs) == 2
assert msgs[0].id == "msg-2"
assert msgs[0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
assert msgs[1].id == "msg-2__user"
assert msgs[1].content == "Follow-up"
# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
def test_no_messages_returns_none():
mw = _make_middleware()
result = mw.before_agent({"messages": []}, _fake_runtime())
assert result is None
def test_no_human_message_returns_none():
mw = _make_middleware()
state = {"messages": [AIMessage(content="assistant only")]}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""):
result = mw.before_agent(state, _fake_runtime())
assert result is None
def test_list_content_message_handled_as_separate_reminder():
"""List-content (e.g. multi-modal) messages remain intact; reminder is a separate message."""
mw = _make_middleware()
original_content = [{"type": "text", "text": "Hello"}]
state = {"messages": [HumanMessage(content=original_content, id="msg-1")]}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
assert result is not None
msgs = result["messages"]
assert len(msgs) == 2
# Reminder is a plain string message with the flag set
assert isinstance(msgs[0].content, str)
assert msgs[0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
assert _SYSTEM_REMINDER_TAG in msgs[0].content
# Original list-content message is untouched
assert msgs[1].content == original_content
def test_reminder_uses_original_id_user_message_uses_derived_id():
"""Reminder takes original ID (position swap); user message gets {id}__user."""
mw = _make_middleware()
original_id = "original-id-abc"
state = {"messages": [HumanMessage(content="Hello", id=original_id)]}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
assert result["messages"][0].id == original_id
assert result["messages"][1].id == f"{original_id}__user"
def test_message_without_id_gets_stable_uuid():
"""If the original HumanMessage has no ID, a UUID is generated and used consistently."""
mw = _make_middleware()
state = {"messages": [HumanMessage(content="Hello", id=None)]}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
assert result is not None
reminder_id = result["messages"][0].id
user_id = result["messages"][1].id
assert reminder_id is not None
assert reminder_id != "None"
assert user_id == f"{reminder_id}__user"
def test_user_message_containing_system_reminder_tag_does_not_prevent_injection():
"""A user message containing '<system-reminder>' must not be mistaken for a reminder."""
mw = _make_middleware()
state = {
"messages": [
HumanMessage(content="What is <system-reminder>?", id="msg-1"),
]
}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
# Injection must happen — the user message does NOT carry the reminder flag
assert result is not None
assert result["messages"][0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
# ---------------------------------------------------------------------------
# Midnight crossing
# ---------------------------------------------------------------------------
def test_midnight_crossing_injects_date_update_as_separate_message():
"""When the date has changed, a separate date-update reminder is injected before
the current turn's HumanMessage using the ID-swap technique."""
mw = _make_middleware()
reminder_content = "<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>"
state = {
"messages": [
_reminder_msg(reminder_content, "msg-1"),
HumanMessage(content="Hello", id="msg-1__user"),
AIMessage(content="Response"),
HumanMessage(content="Good morning", id="msg-2"),
]
}
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-09, Saturday"
result = mw.before_agent(state, _fake_runtime())
assert result is not None
msgs = result["messages"]
assert len(msgs) == 2
# Date-update reminder takes the current message's ID
assert msgs[0].id == "msg-2"
assert msgs[0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
assert _SYSTEM_REMINDER_TAG in msgs[0].content
assert "<current_date>2026-05-09, Saturday</current_date>" in msgs[0].content
assert "Good morning" not in msgs[0].content # reminder only
# Original user text appended with derived ID
assert msgs[1].id == "msg-2__user"
assert msgs[1].content == "Good morning"
def test_midnight_crossing_id_swap():
"""Date-update reminder uses original ID; user message uses {id}__user."""
mw = _make_middleware()
reminder_content = "<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>"
state = {
"messages": [
_reminder_msg(reminder_content, "msg-1"),
HumanMessage(content="Next day message", id="msg-2"),
]
}
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-09, Saturday"
result = mw.before_agent(state, _fake_runtime())
assert result["messages"][0].id == "msg-2"
assert result["messages"][1].id == "msg-2__user"
def test_no_second_midnight_injection_once_date_updated():
"""After a midnight update is persisted, the same-day path skips re-injection."""
mw = _make_middleware()
date_update_content = "<system-reminder>\n<current_date>2026-05-09, Saturday</current_date>\n</system-reminder>"
state = {
"messages": [
_reminder_msg(
"<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>",
"msg-1",
),
HumanMessage(content="Hello", id="msg-1__user"),
AIMessage(content="Response"),
_reminder_msg(date_update_content, "msg-2"),
HumanMessage(content="Good morning", id="msg-2__user"),
AIMessage(content="Good morning!"),
HumanMessage(content="Third turn", id="msg-3"),
]
}
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-09, Saturday"
result = mw.before_agent(state, _fake_runtime())
assert result is None # same day as last injected date → no update
@@ -50,7 +50,7 @@ def test_nginx_routes_official_langgraph_prefix_to_gateway_api():
assert "/api/langgraph-compat" not in content
assert "proxy_pass http://langgraph" not in content
assert "rewrite ^/api/langgraph/(.*) /api/$1 break;" in content
assert "proxy_pass http://gateway" in content or "proxy_pass http://$gateway_upstream" in content
assert "proxy_pass http://gateway" in content
def test_frontend_rewrites_langgraph_prefix_to_gateway():
-15
View File
@@ -324,21 +324,6 @@ def test_context_does_not_override_existing_configurable():
assert config["configurable"]["subagent_enabled"] is True
def test_inject_authenticated_user_context_overrides_client_user_id():
"""Run context should carry the authenticated user, not client-supplied user_id."""
from types import SimpleNamespace
from app.gateway.services import build_run_config, inject_authenticated_user_context
config = build_run_config("thread-1", None, None)
config["context"] = {"user_id": "spoofed-client"}
request = SimpleNamespace(state=SimpleNamespace(user=SimpleNamespace(id="auth-user-42")))
inject_authenticated_user_context(config, request)
assert config["context"]["user_id"] == "auth-user-42"
# ---------------------------------------------------------------------------
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
# ---------------------------------------------------------------------------
@@ -8,20 +8,17 @@ from unittest.mock import MagicMock
import pytest
from deerflow.agents.lead_agent import agent as lead_agent_module
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.config.app_config import AppConfig
from deerflow.config.loop_detection_config import LoopDetectionConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.model_config import ModelConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.summarization_config import SummarizationConfig
def _make_app_config(models: list[ModelConfig], loop_detection: LoopDetectionConfig | None = None) -> AppConfig:
def _make_app_config(models: list[ModelConfig]) -> AppConfig:
return AppConfig(
models=models,
sandbox=SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider"),
loop_detection=loop_detection or LoopDetectionConfig(),
)
@@ -343,59 +340,6 @@ def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypa
assert middlewares[0] == "base-middleware"
def test_build_middlewares_uses_loop_detection_config(monkeypatch):
app_config = _make_app_config(
[_make_model("safe-model", supports_thinking=False)],
loop_detection=LoopDetectionConfig(
warn_threshold=7,
hard_limit=9,
window_size=30,
max_tracked_threads=40,
tool_freq_warn=50,
tool_freq_hard_limit=60,
),
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(lead_agent_module, "build_lead_runtime_middlewares", lambda *, app_config, lazy_init=True: [])
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda *, app_config=None: None)
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
middlewares = lead_agent_module._build_middlewares(
{"configurable": {"is_plan_mode": False, "subagent_enabled": False}},
model_name="safe-model",
app_config=app_config,
)
loop_detection = next(m for m in middlewares if isinstance(m, LoopDetectionMiddleware))
assert loop_detection.warn_threshold == 7
assert loop_detection.hard_limit == 9
assert loop_detection.window_size == 30
assert loop_detection.max_tracked_threads == 40
assert loop_detection.tool_freq_warn == 50
assert loop_detection.tool_freq_hard_limit == 60
def test_build_middlewares_omits_loop_detection_when_disabled(monkeypatch):
app_config = _make_app_config(
[_make_model("safe-model", supports_thinking=False)],
loop_detection=LoopDetectionConfig(enabled=False),
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(lead_agent_module, "build_lead_runtime_middlewares", lambda *, app_config, lazy_init=True: [])
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda *, app_config=None: None)
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
middlewares = lead_agent_module._build_middlewares(
{"configurable": {"is_plan_mode": False, "subagent_enabled": False}},
model_name="safe-model",
app_config=app_config,
)
assert not any(isinstance(m, LoopDetectionMiddleware) for m in middlewares)
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
app_config = _make_app_config([_make_model("model-masswork", supports_thinking=False)])
app_config.summarization = SummarizationConfig(enabled=True, model_name="model-masswork")
+3 -70
View File
@@ -1,37 +1,22 @@
import threading
from types import SimpleNamespace
from typing import cast
import anyio
from deerflow.agents.lead_agent import prompt as prompt_module
from deerflow.config.app_config import AppConfig
from deerflow.config.subagents_config import CustomSubagentConfig, SubagentsAppConfig
from deerflow.skills.types import Skill, SkillCategory
from deerflow.skills.types import Skill
def _set_skills_cache_state(*, skills=None, active=False, version=0):
prompt_module._get_cached_skills_prompt_section.cache_clear()
with prompt_module._enabled_skills_lock:
prompt_module._enabled_skills_cache = skills
prompt_module._enabled_skills_by_config_cache.clear()
prompt_module._enabled_skills_refresh_active = active
prompt_module._enabled_skills_refresh_version = version
prompt_module._enabled_skills_refresh_event.clear()
def test_build_self_update_section_empty_for_default_agent():
assert prompt_module._build_self_update_section(None) == ""
def test_build_self_update_section_present_for_custom_agent():
section = prompt_module._build_self_update_section("my-agent")
assert "<self_update>" in section
assert "my-agent" in section
assert "update_agent" in section
def test_build_custom_mounts_section_returns_empty_when_no_mounts(monkeypatch):
config = SimpleNamespace(sandbox=SimpleNamespace(mounts=[]))
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
@@ -235,7 +220,7 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
skill_dir=skill_dir,
skill_file=skill_dir / "SKILL.md",
relative_path=skill_dir.relative_to(tmp_path),
category=SkillCategory.CUSTOM,
category="custom",
enabled=True,
)
@@ -255,58 +240,6 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
_set_skills_cache_state()
def test_explicit_config_enabled_skills_are_cached_by_config_identity(monkeypatch, tmp_path):
def make_skill(name: str) -> Skill:
skill_dir = tmp_path / name
return Skill(
name=name,
description=f"Description for {name}",
license="MIT",
skill_dir=skill_dir,
skill_file=skill_dir / "SKILL.md",
relative_path=skill_dir.relative_to(tmp_path),
category=SkillCategory.CUSTOM,
enabled=True,
)
config = cast(
AppConfig,
cast(
object,
SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=False),
),
),
)
load_count = 0
def fake_get_or_new_skill_storage(**kwargs):
nonlocal load_count
assert kwargs == {"app_config": config}
def load_skills(*, enabled_only):
nonlocal load_count
load_count += 1
assert enabled_only is True
return [make_skill("cached-skill")]
return SimpleNamespace(load_skills=load_skills)
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", fake_get_or_new_skill_storage)
_set_skills_cache_state()
try:
first = prompt_module.get_skills_prompt_section(app_config=config)
second = prompt_module.get_skills_prompt_section(app_config=config)
assert "cached-skill" in first
assert "cached-skill" in second
assert load_count == 1
finally:
_set_skills_cache_state()
def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_path):
started = threading.Event()
release = threading.Event()
@@ -324,7 +257,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa
skill_dir=skill_dir,
skill_file=skill_dir / "SKILL.md",
relative_path=skill_dir.relative_to(tmp_path),
category=SkillCategory.CUSTOM,
category="custom",
enabled=True,
)
+1 -111
View File
@@ -6,12 +6,7 @@ from deerflow.config.agents_config import AgentConfig
from deerflow.skills.types import Skill
class NamedTool:
def __init__(self, name: str):
self.name = name
def _make_skill(name: str, allowed_tools: list[str] | None = None) -> Skill:
def _make_skill(name: str) -> Skill:
return Skill(
name=name,
description=f"Description for {name}",
@@ -20,7 +15,6 @@ def _make_skill(name: str, allowed_tools: list[str] | None = None) -> Skill:
skill_file=Path(f"/tmp/{name}/SKILL.md"),
relative_path=Path(name),
category="public",
allowed_tools=allowed_tools,
enabled=True,
)
@@ -138,7 +132,6 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(lead_agent_module, "_load_enabled_skills_for_tool_policy", lambda available_skills, *, app_config: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
@@ -171,106 +164,3 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["skill1"]))
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
assert captured_skills[-1] == {"skill1"}
def test_make_lead_agent_filters_tools_from_available_skills(monkeypatch):
from unittest.mock import MagicMock
from deerflow.agents.lead_agent import agent as lead_agent_module
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
monkeypatch.setattr(lead_agent_module, "apply_prompt_template", lambda **kwargs: "mock_prompt")
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["restricted", "legacy"]))
monkeypatch.setattr(lead_agent_module, "_load_enabled_skills_for_tool_policy", lambda available_skills, *, app_config: [_make_skill("restricted", ["read_file"]), _make_skill("legacy", None)])
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")])
mock_app_config = MagicMock()
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
agent_kwargs = lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
assert [tool.name for tool in agent_kwargs["tools"]] == ["read_file"]
def test_make_lead_agent_all_legacy_skills_preserve_all_tools(monkeypatch):
from unittest.mock import MagicMock
from deerflow.agents.lead_agent import agent as lead_agent_module
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
monkeypatch.setattr(lead_agent_module, "apply_prompt_template", lambda **kwargs: "mock_prompt")
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=None))
monkeypatch.setattr(lead_agent_module, "_load_enabled_skills_for_tool_policy", lambda available_skills, *, app_config: [_make_skill("legacy", None)])
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [NamedTool("bash"), NamedTool("read_file")])
mock_app_config = MagicMock()
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
agent_kwargs = lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
assert [tool.name for tool in agent_kwargs["tools"]] == ["bash", "read_file", "update_agent"]
def test_make_lead_agent_enforces_allowed_tools_when_skill_cache_is_cold(monkeypatch):
from unittest.mock import MagicMock
from deerflow.agents.lead_agent import agent as lead_agent_module
from deerflow.agents.lead_agent import prompt as prompt_module
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
monkeypatch.setattr(lead_agent_module, "apply_prompt_template", lambda **kwargs: "mock_prompt")
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["restricted"]))
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")])
mock_app_config = MagicMock()
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
mock_storage = SimpleNamespace(load_skills=lambda *, enabled_only: [_make_skill("restricted", ["read_file"])])
with prompt_module._enabled_skills_lock:
prompt_module._enabled_skills_cache = None
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", lambda app_config=None, **kwargs: mock_storage)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
agent_kwargs = lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
assert [tool.name for tool in agent_kwargs["tools"]] == ["read_file"]
def test_make_lead_agent_fails_closed_when_skill_policy_load_fails(monkeypatch):
from unittest.mock import MagicMock
import pytest
from deerflow.agents.lead_agent import agent as lead_agent_module
from deerflow.agents.lead_agent import prompt as prompt_module
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
create_agent_mock = MagicMock()
monkeypatch.setattr(lead_agent_module, "create_agent", create_agent_mock)
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["restricted"]))
mock_app_config = MagicMock()
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
def fail_storage(*args, **kwargs):
raise RuntimeError("skill storage unavailable")
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", fail_storage)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
with pytest.raises(RuntimeError, match="skill storage unavailable"):
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
create_agent_mock.assert_not_called()
@@ -105,7 +105,6 @@ def test_execute_command_uses_powershell_command_mode_on_windows(monkeypatch):
"capture_output": True,
"text": True,
"timeout": 600,
"env": None,
},
)
]
@@ -119,7 +118,6 @@ def test_execute_command_uses_posix_shell_command_mode_on_windows(monkeypatch):
return SimpleNamespace(stdout="ok", stderr="", returncode=0)
monkeypatch.setattr(local_sandbox.os, "name", "nt")
monkeypatch.setattr(local_sandbox.os, "environ", {"PATH": r"C:\Program Files\Git\bin"})
monkeypatch.setattr(LocalSandbox, "_get_shell", staticmethod(lambda: r"C:\Program Files\Git\bin\sh.exe"))
monkeypatch.setattr(local_sandbox.subprocess, "run", fake_run)
@@ -134,33 +132,11 @@ def test_execute_command_uses_posix_shell_command_mode_on_windows(monkeypatch):
"capture_output": True,
"text": True,
"timeout": 600,
"env": {
"PATH": r"C:\Program Files\Git\bin",
"MSYS_NO_PATHCONV": "1",
"MSYS2_ARG_CONV_EXCL": "*",
},
},
)
]
def test_execute_command_does_not_set_msys_env_for_non_msys_posix_shell_on_windows(monkeypatch):
calls: list[tuple[object, dict]] = []
def fake_run(*args, **kwargs):
calls.append((args[0], kwargs))
return SimpleNamespace(stdout="ok", stderr="", returncode=0)
monkeypatch.setattr(local_sandbox.os, "name", "nt")
monkeypatch.setattr(LocalSandbox, "_get_shell", staticmethod(lambda: r"C:\tools\busybox\sh.exe"))
monkeypatch.setattr(local_sandbox.subprocess, "run", fake_run)
output = LocalSandbox("t").execute_command("echo /mnt/skills/demo")
assert output == "ok"
assert calls[0][1]["env"] is None
def test_execute_command_uses_cmd_command_mode_on_windows(monkeypatch):
calls: list[tuple[object, dict]] = []
@@ -183,7 +159,6 @@ def test_execute_command_uses_cmd_command_mode_on_windows(monkeypatch):
"capture_output": True,
"text": True,
"timeout": 600,
"env": None,
},
)
]
@@ -1,72 +0,0 @@
"""Tests for loop detection configuration."""
import pytest
from deerflow.config.loop_detection_config import LoopDetectionConfig
class TestLoopDetectionConfig:
def test_defaults_match_middleware_defaults(self):
config = LoopDetectionConfig()
assert config.enabled is True
assert config.warn_threshold == 3
assert config.hard_limit == 5
assert config.window_size == 20
assert config.max_tracked_threads == 100
assert config.tool_freq_warn == 30
assert config.tool_freq_hard_limit == 50
def test_accepts_custom_values(self):
config = LoopDetectionConfig(
enabled=False,
warn_threshold=10,
hard_limit=20,
window_size=50,
max_tracked_threads=200,
tool_freq_warn=60,
tool_freq_hard_limit=80,
)
assert config.enabled is False
assert config.warn_threshold == 10
assert config.hard_limit == 20
assert config.window_size == 50
assert config.max_tracked_threads == 200
assert config.tool_freq_warn == 60
assert config.tool_freq_hard_limit == 80
def test_rejects_zero_thresholds(self):
with pytest.raises(ValueError):
LoopDetectionConfig(warn_threshold=0)
with pytest.raises(ValueError):
LoopDetectionConfig(hard_limit=0)
with pytest.raises(ValueError):
LoopDetectionConfig(tool_freq_warn=0)
with pytest.raises(ValueError):
LoopDetectionConfig(tool_freq_hard_limit=0)
def test_rejects_hard_limit_below_warn_threshold(self):
with pytest.raises(ValueError, match="hard_limit"):
LoopDetectionConfig(warn_threshold=5, hard_limit=4)
def test_rejects_tool_freq_hard_limit_below_warn_threshold(self):
with pytest.raises(ValueError, match="tool_freq_hard_limit"):
LoopDetectionConfig(tool_freq_warn=5, tool_freq_hard_limit=4)
def test_tool_freq_override_valid(self):
config = LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 150, "hard_limit": 300}})
override = config.tool_freq_overrides["bash"]
assert override.warn == 150
assert override.hard_limit == 300
def test_tool_freq_override_rejects_zero_warn(self):
with pytest.raises(ValueError):
LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 0, "hard_limit": 10}})
def test_tool_freq_override_rejects_hard_limit_below_warn(self):
with pytest.raises(ValueError, match="hard_limit"):
LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 100, "hard_limit": 50}})
+4 -112
View File
@@ -3,7 +3,7 @@
import copy
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from deerflow.agents.middlewares.loop_detection_middleware import (
_HARD_STOP_MSG,
@@ -146,42 +146,14 @@ class TestLoopDetection:
for _ in range(2):
mw._apply(_make_state(tool_calls=call), runtime)
# Third identical call triggers warning. The warning is appended to
# the AIMessage content (tool_calls preserved) — never inserted as a
# separate HumanMessage between the AIMessage(tool_calls) and its
# ToolMessage responses, which would break OpenAI/Moonshot strict
# tool-call pairing validation.
# Third identical call triggers warning
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
msgs = result["messages"]
assert len(msgs) == 1
assert isinstance(msgs[0], AIMessage)
assert len(msgs[0].tool_calls) == len(call)
assert msgs[0].tool_calls[0]["id"] == call[0]["id"]
assert isinstance(msgs[0], HumanMessage)
assert "LOOP DETECTED" in msgs[0].content
def test_warn_does_not_break_tool_call_pairing(self):
"""Regression: the warn branch must NOT inject a non-tool message
after an AIMessage(tool_calls=...). Moonshot/OpenAI reject the next
request with 'tool_call_ids did not have response messages' if any
non-tool message is wedged between the AIMessage and its ToolMessage
responses. See #2029.
"""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime = _make_runtime()
call = [_bash_call("ls")]
for _ in range(2):
mw._apply(_make_state(tool_calls=call), runtime)
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
msgs = result["messages"]
assert len(msgs) == 1
assert isinstance(msgs[0], AIMessage)
assert len(msgs[0].tool_calls) == len(call)
assert msgs[0].tool_calls[0]["id"] == call[0]["id"]
def test_warn_only_injected_once(self):
"""Warning for the same hash should only be injected once per thread."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
@@ -511,11 +483,7 @@ class TestToolFrequencyDetection:
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_4.py")]), runtime)
assert result is not None
msg = result["messages"][0]
# Warning is appended to the AIMessage content; tool_calls preserved
# so the tools node still runs and Moonshot/OpenAI tool-call pairing
# validation does not break.
assert isinstance(msg, AIMessage)
assert msg.tool_calls
assert isinstance(msg, HumanMessage)
assert "read_file" in msg.content
assert "LOOP DETECTED" in msg.content
@@ -648,37 +616,6 @@ class TestToolFrequencyDetection:
assert result is not None
assert "read_file" in result["messages"][0].content
def test_override_tool_uses_override_thresholds(self):
"""A tool in tool_freq_overrides uses its own thresholds, not the global ones."""
mw = LoopDetectionMiddleware(
tool_freq_warn=5,
tool_freq_hard_limit=10,
tool_freq_overrides={"bash": (50, 100)},
)
runtime = _make_runtime()
# 10 bash calls — would hit global hard_limit=10, but bash override is 100
for i in range(10):
result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime)
assert result is None, f"unexpected trigger on call {i + 1}"
def test_non_override_tool_falls_back_to_global(self):
"""A tool NOT in tool_freq_overrides uses the global warn/hard_limit."""
mw = LoopDetectionMiddleware(
tool_freq_warn=3,
tool_freq_hard_limit=6,
tool_freq_overrides={"bash": (50, 100)},
)
runtime = _make_runtime()
for i in range(2):
mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime)
# 3rd read_file call hits global warn=3 (read_file has no override)
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
assert result is not None
assert "read_file" in result["messages"][0].content
def test_hash_detection_takes_priority(self):
"""Hash-based hard stop fires before frequency check for identical calls."""
mw = LoopDetectionMiddleware(
@@ -699,48 +636,3 @@ class TestToolFrequencyDetection:
msg = result["messages"][0]
assert isinstance(msg, AIMessage)
assert _HARD_STOP_MSG in msg.content
class TestFromConfig:
"""Tests for LoopDetectionMiddleware.from_config — the sole validated construction path."""
@staticmethod
def _config(**kwargs):
from deerflow.config.loop_detection_config import LoopDetectionConfig
return LoopDetectionConfig(**kwargs)
def test_scalar_fields_mapped(self):
config = self._config(
warn_threshold=4,
hard_limit=8,
window_size=15,
max_tracked_threads=50,
tool_freq_warn=20,
tool_freq_hard_limit=40,
)
mw = LoopDetectionMiddleware.from_config(config)
assert mw.warn_threshold == 4
assert mw.hard_limit == 8
assert mw.window_size == 15
assert mw.max_tracked_threads == 50
assert mw.tool_freq_warn == 20
assert mw.tool_freq_hard_limit == 40
def test_overrides_converted_to_tuples(self):
config = self._config(tool_freq_overrides={"bash": {"warn": 50, "hard_limit": 100}})
mw = LoopDetectionMiddleware.from_config(config)
assert mw._tool_freq_overrides == {"bash": (50, 100)}
def test_empty_overrides(self):
mw = LoopDetectionMiddleware.from_config(self._config())
assert mw._tool_freq_overrides == {}
def test_constructed_middleware_detects_loops(self):
mw = LoopDetectionMiddleware.from_config(self._config(warn_threshold=2, hard_limit=4))
runtime = _make_runtime()
call = [_bash_call("ls")]
mw._apply(_make_state(tool_calls=call), runtime)
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
@@ -125,68 +125,3 @@ class TestMigrateMemory:
from scripts.migrate_user_isolation import migrate_memory
migrate_memory(paths, user_id="default") # should not raise
class TestMigrateAgents:
@staticmethod
def _seed_legacy_agent(paths: Paths, name: str, *, soul: str = "soul", description: str = "d") -> Path:
legacy_dir = paths.agents_dir / name
legacy_dir.mkdir(parents=True, exist_ok=True)
(legacy_dir / "config.yaml").write_text(f"name: {name}\ndescription: {description}\n", encoding="utf-8")
(legacy_dir / "SOUL.md").write_text(soul, encoding="utf-8")
return legacy_dir
def test_moves_legacy_into_user_layout(self, base_dir: Path, paths: Paths):
self._seed_legacy_agent(paths, "agent-a", soul="soul-a")
self._seed_legacy_agent(paths, "agent-b", soul="soul-b")
from scripts.migrate_user_isolation import migrate_agents
report = migrate_agents(paths, user_id="default")
assert {entry["agent"] for entry in report} == {"agent-a", "agent-b"}
for entry in report:
assert entry["user_id"] == "default"
assert "moved -> " in entry["action"]
for name, soul in [("agent-a", "soul-a"), ("agent-b", "soul-b")]:
dest = paths.user_agent_dir("default", name)
assert dest.exists(), f"{name} should have moved into the per-user layout"
assert (dest / "SOUL.md").read_text() == soul
# Legacy agents/ root is cleaned up once empty.
assert not paths.agents_dir.exists()
def test_dry_run_does_not_move(self, base_dir: Path, paths: Paths):
legacy_dir = self._seed_legacy_agent(paths, "agent-a")
from scripts.migrate_user_isolation import migrate_agents
report = migrate_agents(paths, user_id="default", dry_run=True)
assert len(report) == 1
assert legacy_dir.exists(), "dry-run must not touch the filesystem"
assert not paths.user_agent_dir("default", "agent-a").exists()
def test_existing_destination_is_treated_as_conflict(self, base_dir: Path, paths: Paths):
self._seed_legacy_agent(paths, "agent-a", soul="legacy soul")
dest = paths.user_agent_dir("default", "agent-a")
dest.mkdir(parents=True)
(dest / "SOUL.md").write_text("preexisting", encoding="utf-8")
from scripts.migrate_user_isolation import migrate_agents
report = migrate_agents(paths, user_id="default")
assert report[0]["action"].startswith("conflict -> ")
# Per-user destination must be left untouched.
assert (dest / "SOUL.md").read_text() == "preexisting"
# Legacy copy lands under migration-conflicts/agents/.
conflicts_dir = paths.base_dir / "migration-conflicts" / "agents" / "agent-a"
assert (conflicts_dir / "SOUL.md").read_text() == "legacy soul"
def test_no_legacy_dir_is_noop(self, base_dir: Path, paths: Paths):
from scripts.migrate_user_isolation import migrate_agents
report = migrate_agents(paths, user_id="default")
assert report == []
@@ -50,21 +50,6 @@ class TestUserAgentMemoryFile:
assert paths.user_agent_memory_file("bob", "MyAgent") == expected
class TestUserAgentDir:
def test_user_agents_dir(self, paths: Paths):
assert paths.user_agents_dir("alice") == paths.base_dir / "users" / "alice" / "agents"
def test_user_agent_dir(self, paths: Paths):
assert paths.user_agent_dir("alice", "code-reviewer") == paths.base_dir / "users" / "alice" / "agents" / "code-reviewer"
def test_user_agent_dir_lowercases_name(self, paths: Paths):
assert paths.user_agent_dir("alice", "CodeReviewer") == paths.base_dir / "users" / "alice" / "agents" / "codereviewer"
def test_user_agent_dir_validates_user_id(self, paths: Paths):
with pytest.raises(ValueError, match="Invalid user_id"):
paths.user_agent_dir("../escape", "myagent")
class TestUserThreadDir:
def test_user_thread_dir(self, paths: Paths):
expected = paths.base_dir / "users" / "u1" / "threads" / "t1"
+9 -6
View File
@@ -8,9 +8,7 @@ Tests:
5. Postgres missing-dep error message
"""
import sys
from datetime import UTC, datetime
from unittest.mock import patch
import pytest
@@ -223,8 +221,13 @@ class TestEngineLifecycle:
"""If asyncpg is not installed, error message tells user what to do."""
from deerflow.persistence.engine import init_engine
with (
patch.dict(sys.modules, {"asyncpg": None}),
pytest.raises(ImportError, match="uv sync --all-packages --extra postgres"),
):
try:
import asyncpg # noqa: F401
pytest.skip("asyncpg is installed -- cannot test missing-dep path")
except ImportError:
# asyncpg is not installed — this is the expected state for this test.
# We proceed to verify that init_engine raises an actionable ImportError.
pass # noqa: S110 — intentionally ignored
with pytest.raises(ImportError, match="uv sync --extra postgres"):
await init_engine("postgres", url="postgresql+asyncpg://x:x@localhost/x")
@@ -1,293 +0,0 @@
from __future__ import annotations
import pytest
import requests
from deerflow.community.aio_sandbox.remote_backend import RemoteSandboxBackend
from deerflow.community.aio_sandbox.sandbox_info import SandboxInfo
class _StubResponse:
def __init__(
self,
*,
status_code: int = 200,
payload: object | None = None,
json_exc: Exception | None = None,
):
self.status_code = status_code
self._payload = {} if payload is None else payload
self._json_exc = json_exc
self.ok = 200 <= status_code < 400
self.text = ""
def raise_for_status(self) -> None:
if self.status_code >= 400:
raise requests.HTTPError(f"HTTP {self.status_code}")
def json(self) -> object:
if self._json_exc is not None:
raise self._json_exc
return self._payload
def test_list_running_delegates_to_provisioner_list(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
sandbox_info = SandboxInfo(sandbox_id="test-id", sandbox_url="http://localhost:8080")
def mock_list():
return [sandbox_info]
monkeypatch.setattr(backend, "_provisioner_list", mock_list)
assert backend.list_running() == [sandbox_info]
def test_provisioner_list_returns_sandbox_infos_and_filters_invalid_entries(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
assert url == "http://provisioner:8002/api/sandboxes"
assert timeout == 10
return _StubResponse(
payload={
"sandboxes": [
{"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"},
{"sandbox_id": "missing-url"},
{"sandbox_url": "http://k3s:31002"},
]
}
)
monkeypatch.setattr(requests, "get", mock_get)
infos = backend._provisioner_list()
assert len(infos) == 1
assert infos[0].sandbox_id == "abc123"
assert infos[0].sandbox_url == "http://k3s:31001"
def test_provisioner_list_returns_empty_on_request_exception(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
raise requests.RequestException("network down")
monkeypatch.setattr(requests, "get", mock_get)
assert backend._provisioner_list() == []
def test_provisioner_list_returns_empty_when_payload_is_not_dict(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
return _StubResponse(payload=[{"sandbox_id": "abc", "sandbox_url": "http://k3s:31001"}])
monkeypatch.setattr(requests, "get", mock_get)
assert backend._provisioner_list() == []
def test_provisioner_list_returns_empty_when_sandboxes_is_not_list(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
return _StubResponse(payload={"sandboxes": {"sandbox_id": "abc"}})
monkeypatch.setattr(requests, "get", mock_get)
assert backend._provisioner_list() == []
def test_provisioner_list_skips_non_dict_sandbox_entries(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
return _StubResponse(
payload={
"sandboxes": [
{"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"},
"bad-entry",
123,
None,
]
}
)
monkeypatch.setattr(requests, "get", mock_get)
infos = backend._provisioner_list()
assert len(infos) == 1
assert infos[0].sandbox_id == "abc123"
assert infos[0].sandbox_url == "http://k3s:31001"
def test_create_delegates_to_provisioner_create(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
expected = SandboxInfo(sandbox_id="abc123", sandbox_url="http://k3s:31001")
def mock_create(thread_id: str, sandbox_id: str, extra_mounts=None):
assert thread_id == "thread-1"
assert sandbox_id == "abc123"
assert extra_mounts == [("/host", "/container", False)]
return expected
monkeypatch.setattr(backend, "_provisioner_create", mock_create)
result = backend.create("thread-1", "abc123", extra_mounts=[("/host", "/container", False)])
assert result == expected
def test_provisioner_create_returns_sandbox_info(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_post(url: str, json: dict, timeout: int):
assert url == "http://provisioner:8002/api/sandboxes"
assert json == {"sandbox_id": "abc123", "thread_id": "thread-1"}
assert timeout == 30
return _StubResponse(payload={"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"})
monkeypatch.setattr(requests, "post", mock_post)
info = backend._provisioner_create("thread-1", "abc123")
assert info.sandbox_id == "abc123"
assert info.sandbox_url == "http://k3s:31001"
def test_provisioner_create_raises_runtime_error_on_request_exception(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_post(url: str, json: dict, timeout: int):
raise requests.RequestException("boom")
monkeypatch.setattr(requests, "post", mock_post)
with pytest.raises(RuntimeError, match="Provisioner create failed"):
backend._provisioner_create("thread-1", "abc123")
def test_destroy_delegates_to_provisioner_destroy(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
called: list[str] = []
def mock_destroy(sandbox_id: str):
called.append(sandbox_id)
monkeypatch.setattr(backend, "_provisioner_destroy", mock_destroy)
backend.destroy(SandboxInfo(sandbox_id="abc123", sandbox_url="http://k3s:31001"))
assert called == ["abc123"]
def test_provisioner_destroy_calls_delete(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_delete(url: str, timeout: int):
assert url == "http://provisioner:8002/api/sandboxes/abc123"
assert timeout == 15
return _StubResponse(status_code=200)
monkeypatch.setattr(requests, "delete", mock_delete)
backend._provisioner_destroy("abc123")
def test_provisioner_destroy_swallows_request_exception(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_delete(url: str, timeout: int):
raise requests.RequestException("network down")
monkeypatch.setattr(requests, "delete", mock_delete)
backend._provisioner_destroy("abc123")
def test_is_alive_delegates_to_provisioner_is_alive(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_is_alive(sandbox_id: str):
assert sandbox_id == "abc123"
return True
monkeypatch.setattr(backend, "_provisioner_is_alive", mock_is_alive)
alive = backend.is_alive(SandboxInfo(sandbox_id="abc123", sandbox_url="http://k3s:31001"))
assert alive is True
def test_provisioner_is_alive_true_only_when_status_running(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get_running(url: str, timeout: int):
return _StubResponse(payload={"status": "Running"})
monkeypatch.setattr(requests, "get", mock_get_running)
assert backend._provisioner_is_alive("abc123") is True
def mock_get_pending(url: str, timeout: int):
return _StubResponse(payload={"status": "Pending"})
monkeypatch.setattr(requests, "get", mock_get_pending)
assert backend._provisioner_is_alive("abc123") is False
def test_provisioner_is_alive_returns_false_on_request_exception(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
raise requests.RequestException("boom")
monkeypatch.setattr(requests, "get", mock_get)
assert backend._provisioner_is_alive("abc123") is False
def test_discover_delegates_to_provisioner_discover(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
expected = SandboxInfo(sandbox_id="abc123", sandbox_url="http://k3s:31001")
def mock_discover(sandbox_id: str):
assert sandbox_id == "abc123"
return expected
monkeypatch.setattr(backend, "_provisioner_discover", mock_discover)
result = backend.discover("abc123")
assert result == expected
def test_provisioner_discover_returns_none_on_404(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
return _StubResponse(status_code=404)
monkeypatch.setattr(requests, "get", mock_get)
assert backend._provisioner_discover("abc123") is None
def test_provisioner_discover_returns_info_on_success(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
return _StubResponse(payload={"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"})
monkeypatch.setattr(requests, "get", mock_get)
info = backend._provisioner_discover("abc123")
assert info is not None
assert info.sandbox_id == "abc123"
assert info.sandbox_url == "http://k3s:31001"
def test_provisioner_discover_returns_none_on_request_exception(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
raise requests.RequestException("boom")
monkeypatch.setattr(requests, "get", mock_get)
assert backend._provisioner_discover("abc123") is None
-71
View File
@@ -310,28 +310,6 @@ class TestDbRunEventStore:
await close_engine()
@pytest.mark.anyio
async def test_structured_content_round_trips(self, tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
from deerflow.runtime.events.store.db import DbRunEventStore
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
s = DbRunEventStore(get_session_factory())
content = [{"type": "text", "text": "hello"}, {"type": "image_url", "image_url": {"url": "https://example.test/a.png"}}]
record = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content=content)
assert record["content"] == content
assert record["metadata"]["content_is_json"] is True
assert "content_is_dict" not in record["metadata"]
messages = await s.list_messages("t1")
assert messages[0]["content"] == content
assert messages[0]["metadata"]["content_is_json"] is True
await close_engine()
@pytest.mark.anyio
async def test_pagination(self, tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
@@ -395,55 +373,6 @@ class TestDbRunEventStore:
assert seqs == list(range(1, 51))
await close_engine()
@pytest.mark.anyio
async def test_put_batch_accepts_structured_content(self, tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
from deerflow.runtime.events.store.db import DbRunEventStore
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
s = DbRunEventStore(get_session_factory())
content = [{"messages": [{"type": "ai", "content": ""}]}]
results = await s.put_batch(
[
{
"thread_id": "t1",
"run_id": "r1",
"event_type": "run.end",
"category": "outputs",
"content": content,
}
]
)
assert results[0]["content"] == content
assert results[0]["metadata"]["content_is_json"] is True
events = await s.list_events("t1", "r1")
assert events[0]["content"] == content
assert events[0]["metadata"]["content_is_json"] is True
await close_engine()
@pytest.mark.anyio
async def test_dict_content_keeps_legacy_metadata_flag(self, tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
from deerflow.runtime.events.store.db import DbRunEventStore
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
s = DbRunEventStore(get_session_factory())
content = {"status": "success"}
record = await s.put(thread_id="t1", run_id="r1", event_type="run.end", category="outputs", content=content)
assert record["content"] == content
assert record["metadata"]["content_is_json"] is True
assert record["metadata"]["content_is_dict"] is True
await close_engine()
# -- Factory tests --
-55
View File
@@ -166,61 +166,6 @@ class TestRunRepository:
assert row["total_tokens"] == 100
await _cleanup()
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("success-run", thread_id="t1", status="running")
await repo.update_run_completion(
"success-run",
status="success",
total_input_tokens=70,
total_output_tokens=30,
total_tokens=100,
lead_agent_tokens=80,
subagent_tokens=15,
middleware_tokens=5,
)
await repo.put("error-run", thread_id="t1", status="running")
await repo.update_run_completion(
"error-run",
status="error",
total_input_tokens=20,
total_output_tokens=30,
total_tokens=50,
lead_agent_tokens=40,
subagent_tokens=10,
)
await repo.put("running-run", thread_id="t1", status="running")
await repo.update_run_completion(
"running-run",
status="running",
total_input_tokens=900,
total_output_tokens=99,
total_tokens=999,
lead_agent_tokens=999,
)
await repo.put("other-thread-run", thread_id="t2", status="running")
await repo.update_run_completion(
"other-thread-run",
status="success",
total_tokens=888,
lead_agent_tokens=888,
)
agg = await repo.aggregate_tokens_by_thread("t1")
assert agg["total_tokens"] == 150
assert agg["total_input_tokens"] == 90
assert agg["total_output_tokens"] == 60
assert agg["total_runs"] == 2
assert agg["by_model"] == {"unknown": {"tokens": 150, "runs": 2}}
assert agg["by_caller"] == {
"lead_agent": 120,
"subagent": 25,
"middleware": 5,
}
await _cleanup()
@pytest.mark.anyio
async def test_list_by_thread_ordered_desc(self, tmp_path):
"""list_by_thread returns newest first."""
+16 -61
View File
@@ -3,8 +3,6 @@ from types import SimpleNamespace
from unittest.mock import AsyncMock, call
import pytest
from langgraph.checkpoint.base import empty_checkpoint
from langgraph.checkpoint.memory import InMemorySaver
from deerflow.runtime.runs.manager import RunManager
from deerflow.runtime.runs.schemas import RunStatus
@@ -18,14 +16,6 @@ class FakeCheckpointer:
self.aput_writes = AsyncMock()
def _make_checkpoint(checkpoint_id: str, messages: list[str], version: int):
checkpoint = empty_checkpoint()
checkpoint["id"] = checkpoint_id
checkpoint["channel_values"] = {"messages": messages}
checkpoint["channel_versions"] = {"messages": version}
return checkpoint
def test_build_runtime_context_includes_app_config_when_present():
app_config = object()
@@ -120,16 +110,16 @@ async def test_rollback_restores_snapshot_without_deleting_thread():
)
checkpointer.adelete_thread.assert_not_awaited()
checkpointer.aput.assert_awaited_once()
restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
assert restored_checkpoint["id"] != "ckpt-1"
assert "channel_versions" in restored_checkpoint
assert "channel_values" in restored_checkpoint
assert restored_checkpoint["channel_versions"] == {"messages": 3}
assert restored_checkpoint["channel_values"] == {"messages": ["before"]}
assert restored_metadata == {"source": "input"}
assert new_versions == {"messages": 3}
checkpointer.aput.assert_awaited_once_with(
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
{
"id": "ckpt-1",
"channel_versions": {"messages": 3},
"channel_values": {"messages": ["before"]},
},
{"source": "input"},
{"messages": 3},
)
assert checkpointer.aput_writes.await_args_list == [
call(
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
@@ -144,40 +134,6 @@ async def test_rollback_restores_snapshot_without_deleting_thread():
]
@pytest.mark.anyio
async def test_rollback_restored_checkpoint_becomes_latest_with_real_checkpointer():
checkpointer = InMemorySaver()
thread_config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
before_checkpoint = _make_checkpoint("0001", ["before"], 1)
before_config = checkpointer.put(thread_config, before_checkpoint, {"step": 1}, {"messages": 1})
after_checkpoint = _make_checkpoint("0002", ["after"], 2)
after_config = checkpointer.put(before_config, after_checkpoint, {"step": 2}, {"messages": 2})
checkpointer.put_writes(after_config, [("messages", "pending-after")], task_id="task-after")
await _rollback_to_pre_run_checkpoint(
checkpointer=checkpointer,
thread_id="thread-1",
run_id="run-1",
pre_run_checkpoint_id="0001",
pre_run_snapshot={
"checkpoint_ns": "",
"checkpoint": before_checkpoint,
"metadata": {"step": 1},
"pending_writes": [("task-before", "messages", "pending-before")],
},
snapshot_capture_failed=False,
)
latest = checkpointer.get_tuple(thread_config)
assert latest is not None
assert latest.config["configurable"]["checkpoint_id"] != "0001"
assert latest.config["configurable"]["checkpoint_id"] != "0002"
assert latest.checkpoint["channel_values"] == {"messages": ["before"]}
assert latest.pending_writes == [("task-before", "messages", "pending-before")]
assert ("task-after", "messages", "pending-after") not in latest.pending_writes
@pytest.mark.anyio
async def test_rollback_deletes_thread_when_no_snapshot_exists():
checkpointer = FakeCheckpointer(put_result=None)
@@ -238,13 +194,12 @@ async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace():
snapshot_capture_failed=False,
)
checkpointer.aput.assert_awaited_once()
restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
assert restored_checkpoint["id"] != "ckpt-1"
assert restored_checkpoint["channel_versions"] == {}
assert restored_metadata == {}
assert new_versions == {}
checkpointer.aput.assert_awaited_once_with(
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
{"id": "ckpt-1", "channel_versions": {}},
{},
{},
)
@pytest.mark.anyio
-36
View File
@@ -7,7 +7,6 @@ import yaml
from deerflow.config import app_config as app_config_module
from deerflow.config import extensions_config as extensions_config_module
from deerflow.config import skills_config as skills_config_module
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.paths import Paths
@@ -36,7 +35,6 @@ def test_default_runtime_paths_resolve_from_current_project(tmp_path: Path, monk
encoding="utf-8",
)
(tmp_path / "extensions_config.json").write_text('{"mcpServers": {}, "skills": {}}', encoding="utf-8")
(tmp_path / "skills").mkdir()
assert AppConfig.resolve_config_path() == tmp_path / "config.yaml"
assert ExtensionsConfig.resolve_config_path() == tmp_path / "extensions_config.json"
@@ -123,40 +121,6 @@ def test_app_config_falls_back_to_legacy_when_project_root_lacks_config(tmp_path
assert AppConfig.resolve_config_path() == legacy_backend_config
def test_skills_config_falls_back_to_legacy_when_project_root_lacks_skills(tmp_path: Path, monkeypatch):
"""When DEER_FLOW_PROJECT_ROOT is unset and cwd has no `skills/`, the legacy
repo-root candidate must be used so monorepo runs (cwd=backend/) keep finding
`<repo>/skills` instead of `<repo>/backend/skills` (regression test for #2694)."""
_clear_path_env(monkeypatch)
cwd = tmp_path / "cwd"
cwd.mkdir()
monkeypatch.chdir(cwd)
legacy_skills = tmp_path / "legacy-repo" / "skills"
legacy_skills.mkdir(parents=True)
monkeypatch.setattr(
skills_config_module,
"_legacy_skills_candidates",
lambda: (legacy_skills,),
)
assert SkillsConfig().get_skills_path() == legacy_skills
def test_skills_config_returns_project_default_when_neither_exists(tmp_path: Path, monkeypatch):
"""When nothing exists, fall back to the project-root default path so callers
surface a stable empty location instead of silently picking a stale legacy dir."""
_clear_path_env(monkeypatch)
cwd = tmp_path / "cwd"
cwd.mkdir()
monkeypatch.chdir(cwd)
monkeypatch.setattr(skills_config_module, "_legacy_skills_candidates", lambda: ())
assert SkillsConfig().get_skills_path() == cwd / "skills"
def test_extensions_config_falls_back_to_legacy_when_project_root_lacks_file(tmp_path: Path, monkeypatch):
"""ExtensionsConfig should hit the legacy backend/repo-root locations when
the caller project root has no extensions_config.json/mcp_config.json."""
-308
View File
@@ -1,308 +0,0 @@
"""Unit tests for the Serper community web search tool."""
import json
from unittest.mock import MagicMock, patch
import httpx
import pytest
@pytest.fixture(autouse=True)
def reset_api_key_warned():
"""Reset the module-level warning flag before each test."""
import deerflow.community.serper.tools as serper_mod
serper_mod._api_key_warned = False
yield
serper_mod._api_key_warned = False
@pytest.fixture
def mock_config_with_key():
with patch("deerflow.community.serper.tools.get_app_config") as mock:
tool_config = MagicMock()
tool_config.model_extra = {"api_key": "test-serper-key", "max_results": 5}
mock.return_value.get_tool_config.return_value = tool_config
yield mock
@pytest.fixture
def mock_config_no_key():
with patch("deerflow.community.serper.tools.get_app_config") as mock:
tool_config = MagicMock()
tool_config.model_extra = {}
mock.return_value.get_tool_config.return_value = tool_config
yield mock
def _make_serper_response(organic: list) -> MagicMock:
mock_resp = MagicMock()
mock_resp.json.return_value = {"organic": organic}
mock_resp.raise_for_status = MagicMock()
return mock_resp
class TestGetApiKey:
def test_returns_config_key_when_present(self):
with patch("deerflow.community.serper.tools.get_app_config") as mock:
tool_config = MagicMock()
tool_config.model_extra = {"api_key": "from-config"}
mock.return_value.get_tool_config.return_value = tool_config
from deerflow.community.serper.tools import _get_api_key
assert _get_api_key() == "from-config"
def test_falls_back_to_env_when_config_key_empty(self):
with patch("deerflow.community.serper.tools.get_app_config") as mock:
tool_config = MagicMock()
tool_config.model_extra = {"api_key": ""}
mock.return_value.get_tool_config.return_value = tool_config
with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}):
from deerflow.community.serper.tools import _get_api_key
assert _get_api_key() == "env-key"
def test_falls_back_to_env_when_config_key_whitespace(self):
with patch("deerflow.community.serper.tools.get_app_config") as mock:
tool_config = MagicMock()
tool_config.model_extra = {"api_key": " "}
mock.return_value.get_tool_config.return_value = tool_config
with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}):
from deerflow.community.serper.tools import _get_api_key
assert _get_api_key() == "env-key"
def test_falls_back_to_env_when_config_key_null(self):
with patch("deerflow.community.serper.tools.get_app_config") as mock:
tool_config = MagicMock()
tool_config.model_extra = {"api_key": None}
mock.return_value.get_tool_config.return_value = tool_config
with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}):
from deerflow.community.serper.tools import _get_api_key
assert _get_api_key() == "env-key"
def test_falls_back_to_env_when_no_config(self):
with patch("deerflow.community.serper.tools.get_app_config") as mock:
mock.return_value.get_tool_config.return_value = None
with patch.dict("os.environ", {"SERPER_API_KEY": "env-only"}):
from deerflow.community.serper.tools import _get_api_key
assert _get_api_key() == "env-only"
def test_returns_none_when_no_key_anywhere(self):
with patch("deerflow.community.serper.tools.get_app_config") as mock:
mock.return_value.get_tool_config.return_value = None
with patch.dict("os.environ", {}, clear=True):
import os
os.environ.pop("SERPER_API_KEY", None)
from deerflow.community.serper.tools import _get_api_key
assert _get_api_key() is None
class TestWebSearchTool:
def test_basic_search_returns_normalized_results(self, mock_config_with_key):
organic = [
{"title": "Result 1", "link": "https://example.com/1", "snippet": "Snippet 1"},
{"title": "Result 2", "link": "https://example.com/2", "snippet": "Snippet 2"},
]
mock_resp = _make_serper_response(organic)
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
from deerflow.community.serper.tools import web_search_tool
result = web_search_tool.invoke({"query": "python tutorial"})
parsed = json.loads(result)
assert parsed["query"] == "python tutorial"
assert parsed["total_results"] == 2
assert parsed["results"][0]["title"] == "Result 1"
assert parsed["results"][0]["url"] == "https://example.com/1"
assert parsed["results"][0]["content"] == "Snippet 1"
def test_respects_max_results_from_config(self, mock_config_with_key):
mock_config_with_key.return_value.get_tool_config.return_value.model_extra = {
"api_key": "test-key",
"max_results": 3,
}
organic = [{"title": f"R{i}", "link": f"https://x.com/{i}", "snippet": f"S{i}"} for i in range(10)]
mock_resp = _make_serper_response(organic)
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
from deerflow.community.serper.tools import web_search_tool
result = web_search_tool.invoke({"query": "test"})
parsed = json.loads(result)
assert parsed["total_results"] == 3
assert len(parsed["results"]) == 3
def test_max_results_parameter_accepted(self, mock_config_no_key):
"""Tool accepts max_results as a call parameter when config does not override it."""
organic = [{"title": f"R{i}", "link": f"https://x.com/{i}", "snippet": f"S{i}"} for i in range(10)]
mock_resp = _make_serper_response(organic)
with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}):
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
from deerflow.community.serper.tools import web_search_tool
result = web_search_tool.invoke({"query": "test", "max_results": 2})
parsed = json.loads(result)
assert parsed["total_results"] == 2
def test_config_max_results_overrides_parameter(self):
"""Config max_results overrides the parameter passed at call time, matching ddg_search behaviour."""
with patch("deerflow.community.serper.tools.get_app_config") as mock:
tool_config = MagicMock()
tool_config.model_extra = {"api_key": "test-key", "max_results": 3}
mock.return_value.get_tool_config.return_value = tool_config
organic = [{"title": f"R{i}", "link": f"https://x.com/{i}", "snippet": f"S{i}"} for i in range(10)]
mock_resp = _make_serper_response(organic)
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
from deerflow.community.serper.tools import web_search_tool
result = web_search_tool.invoke({"query": "test", "max_results": 8})
parsed = json.loads(result)
assert parsed["total_results"] == 3
def test_empty_organic_returns_error_json(self, mock_config_with_key):
"""Empty organic list returns structured error, matching ddg_search convention."""
mock_resp = _make_serper_response([])
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
from deerflow.community.serper.tools import web_search_tool
result = web_search_tool.invoke({"query": "no results"})
parsed = json.loads(result)
assert "error" in parsed
assert parsed["error"] == "No results found"
assert parsed["query"] == "no results"
def test_missing_api_key_returns_error_json(self, mock_config_no_key):
with patch.dict("os.environ", {}, clear=True):
import os
os.environ.pop("SERPER_API_KEY", None)
from deerflow.community.serper.tools import web_search_tool
result = web_search_tool.invoke({"query": "test"})
parsed = json.loads(result)
assert "error" in parsed
assert "SERPER_API_KEY" in parsed["error"]
def test_missing_api_key_logs_warning_once(self, mock_config_no_key, caplog):
import logging
with patch.dict("os.environ", {}, clear=True):
import os
os.environ.pop("SERPER_API_KEY", None)
from deerflow.community.serper.tools import web_search_tool
with caplog.at_level(logging.WARNING, logger="deerflow.community.serper.tools"):
web_search_tool.invoke({"query": "q1"})
web_search_tool.invoke({"query": "q2"})
warnings = [r for r in caplog.records if r.levelno == logging.WARNING]
assert len(warnings) == 1
def test_http_error_returns_structured_error(self, mock_config_with_key):
mock_error_response = MagicMock()
mock_error_response.status_code = 403
mock_error_response.text = "Forbidden"
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = httpx.HTTPStatusError("403", request=MagicMock(), response=mock_error_response)
from deerflow.community.serper.tools import web_search_tool
result = web_search_tool.invoke({"query": "test"})
parsed = json.loads(result)
assert "error" in parsed
assert "403" in parsed["error"]
def test_network_exception_returns_error_json(self, mock_config_with_key):
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = Exception("timeout")
from deerflow.community.serper.tools import web_search_tool
result = web_search_tool.invoke({"query": "test"})
parsed = json.loads(result)
assert "error" in parsed
def test_sends_correct_headers_and_payload(self, mock_config_with_key):
organic = [{"title": "T", "link": "https://x.com", "snippet": "S"}]
mock_resp = _make_serper_response(organic)
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
mock_post = mock_client_cls.return_value.__enter__.return_value.post
mock_post.return_value = mock_resp
from deerflow.community.serper.tools import web_search_tool
web_search_tool.invoke({"query": "hello world"})
call_kwargs = mock_post.call_args
headers = call_kwargs.kwargs["headers"]
payload = call_kwargs.kwargs["json"]
assert headers["X-API-KEY"] == "test-serper-key"
assert payload["q"] == "hello world"
assert payload["num"] == 5
def test_uses_env_key_when_config_absent(self):
with patch("deerflow.community.serper.tools.get_app_config") as mock:
mock.return_value.get_tool_config.return_value = None
with patch.dict("os.environ", {"SERPER_API_KEY": "env-only-key"}):
organic = [{"title": "T", "link": "https://x.com", "snippet": "S"}]
mock_resp = _make_serper_response(organic)
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
mock_post = mock_client_cls.return_value.__enter__.return_value.post
mock_post.return_value = mock_resp
from deerflow.community.serper.tools import web_search_tool
web_search_tool.invoke({"query": "env key test"})
headers = mock_post.call_args.kwargs["headers"]
assert headers["X-API-KEY"] == "env-only-key"
def test_partial_fields_in_organic_result(self, mock_config_with_key):
"""Missing title/link/snippet should default to empty string."""
organic = [{}]
mock_resp = _make_serper_response(organic)
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
from deerflow.community.serper.tools import web_search_tool
result = web_search_tool.invoke({"query": "test"})
parsed = json.loads(result)
assert parsed["results"][0] == {"title": "", "url": "", "content": ""}

Some files were not shown because too many files have changed in this diff Show More