Compare commits
51 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| dad3997459 | |||
| b67c2a4e56 | |||
| 94da8f67d7 | |||
| 5127f08e1a | |||
| dfa4eb0c1a | |||
| 08ee7adeba | |||
| 1c96a6afc8 | |||
| 417416087b | |||
| 881ff71252 | |||
| f76e4e35c8 | |||
| 0d1053ca44 | |||
| 4063dd7157 | |||
| 7a3c58a733 | |||
| 1edc9d9fae | |||
| 7caf03e97c | |||
| 41b04a556f | |||
| c1b7f1d189 | |||
| 109490da25 | |||
| 14c0a32ee6 | |||
| 70737af7cd | |||
| 2b1fcb3e43 | |||
| 7de9b5828b | |||
| 37db689349 | |||
| bd45cb2846 | |||
| 5fd0e6ac89 | |||
| 530bda7107 | |||
| 6c220a9aef | |||
| daa3ffc29b | |||
| 27559f3675 | |||
| cef4224381 | |||
| 2b0e62f679 | |||
| 1336872b15 | |||
| 4ead2c6b19 | |||
| 59c4a3f0a4 | |||
| e8675f266d | |||
| 680187ddc2 | |||
| aded753de3 | |||
| 028493bfd8 | |||
| 8e48b7e85c | |||
| af6e48ccaa | |||
| b10eb7bafc | |||
| d02f762ab0 | |||
| 82e7936d36 | |||
| 222a7773cb | |||
| f80ac961ec | |||
| 44ab21fc44 | |||
| e543bbf5d6 | |||
| ca3332f8bf | |||
| bb8b234d85 | |||
| 17447fccbe | |||
| 866d1ca409 |
@@ -1,3 +1,6 @@
|
||||
# Serper API Key (Google Search) - https://serper.dev
|
||||
SERPER_API_KEY=your-serper-api-key
|
||||
|
||||
# TAVILY API Key
|
||||
TAVILY_API_KEY=your-tavily-api-key
|
||||
|
||||
@@ -45,3 +48,14 @@ INFOQUEST_API_KEY=your-infoquest-api-key
|
||||
|
||||
# Set to "false" to disable Swagger UI, ReDoc, and OpenAPI schema in production
|
||||
# GATEWAY_ENABLE_DOCS=false
|
||||
|
||||
# ── Frontend SSR → Gateway wiring ─────────────────────────────────────────────
|
||||
# The Next.js server uses these to reach the Gateway during SSR (auth checks,
|
||||
# /api/* rewrites). They default to localhost values that match `make dev` and
|
||||
# `make start`, so most local users do not need to set them.
|
||||
#
|
||||
# Override only when the Gateway is not on localhost:8001 (e.g. when the
|
||||
# frontend and gateway run on different hosts, in containers with a service
|
||||
# alias, or behind a different port). docker-compose already sets these.
|
||||
# DEER_FLOW_INTERNAL_GATEWAY_BASE_URL=http://localhost:8001
|
||||
# DEER_FLOW_TRUSTED_ORIGINS=http://localhost:3000,http://localhost:2026
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
name: Publish Containers
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "v*"
|
||||
|
||||
jobs:
|
||||
|
||||
backend-container:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
attestations: write
|
||||
id-token: write
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}-backend
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 #v3.4.0
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@902fa8ec7d6ecbf8d84d538b9b233a880e428804 #v5.7.0
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
tags: |
|
||||
type=ref,event=tag
|
||||
type=ref,event=branch
|
||||
type=sha
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
- name: Build and push Docker image
|
||||
id: push
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 #v6.18.0
|
||||
with:
|
||||
context: .
|
||||
file: backend/Dockerfile
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
- name: Generate artifact attestation
|
||||
uses: actions/attest-build-provenance@v2
|
||||
with:
|
||||
subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}}
|
||||
subject-digest: ${{ steps.push.outputs.digest }}
|
||||
push-to-registry: true
|
||||
|
||||
frontend-container:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
attestations: write
|
||||
id-token: write
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}-frontend
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 #v3.4.0
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@902fa8ec7d6ecbf8d84d538b9b233a880e428804 #v5.7.0
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
tags: |
|
||||
type=ref,event=tag
|
||||
type=ref,event=branch
|
||||
type=sha
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
- name: Build and push Docker image
|
||||
id: push
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 #v6.18.0
|
||||
with:
|
||||
context: .
|
||||
file: frontend/Dockerfile
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
- name: Generate artifact attestation
|
||||
uses: actions/attest-build-provenance@v2
|
||||
with:
|
||||
subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}}
|
||||
subject-digest: ${{ steps.push.outputs.digest }}
|
||||
push-to-registry: true
|
||||
+6
-2
@@ -263,8 +263,10 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
|
||||
- `present_files` - Make output files visible to user (only `/mnt/user-data/outputs`)
|
||||
- `ask_clarification` - Request clarification (intercepted by ClarificationMiddleware → interrupts)
|
||||
- `view_image` - Read image as base64 (added only if model supports vision)
|
||||
- `setup_agent` - Bootstrap-only: persist a brand-new custom agent's `SOUL.md` and `config.yaml`. Bound only when `is_bootstrap=True`.
|
||||
- `update_agent` - Custom-agent-only: persist self-updates to the current agent's `SOUL.md` / `config.yaml` from inside a normal chat (partial update + atomic write). Bound when `agent_name` is set and `is_bootstrap=False`.
|
||||
4. **Subagent tool** (if enabled):
|
||||
- `task` - Delegate to subagent (description, prompt, subagent_type, max_turns)
|
||||
- `task` - Delegate to subagent (description, prompt, subagent_type)
|
||||
|
||||
**Community tools** (`packages/harness/deerflow/community/`):
|
||||
- `tavily/` - Web search (5 results default) and web fetch (4KB limit)
|
||||
@@ -354,10 +356,11 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
|
||||
**Per-User Isolation**:
|
||||
- Memory is stored per-user at `{base_dir}/users/{user_id}/memory.json`
|
||||
- Per-agent per-user memory at `{base_dir}/users/{user_id}/agents/{agent_name}/memory.json`
|
||||
- Custom agent definitions (`SOUL.md` + `config.yaml`) are also per-user at `{base_dir}/users/{user_id}/agents/{agent_name}/`. The legacy shared layout `{base_dir}/agents/{agent_name}/` remains read-only fallback for unmigrated installations
|
||||
- `user_id` is resolved via `get_effective_user_id()` from `deerflow.runtime.user_context`
|
||||
- In no-auth mode, `user_id` defaults to `"default"` (constant `DEFAULT_USER_ID`)
|
||||
- Absolute `storage_path` in config opts out of per-user isolation
|
||||
- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json` and `threads/` into per-user layout; supports `--dry-run`
|
||||
- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json`, `threads/`, and `agents/` into per-user layout. Supports `--dry-run` (preview changes) and `--user-id USER_ID` (assign unowned legacy data to a user, defaults to `default`).
|
||||
|
||||
**Data Structure** (stored in `{base_dir}/users/{user_id}/memory.json`):
|
||||
- **User Context**: `workContext`, `personalContext`, `topOfMind` (1-3 sentence summaries)
|
||||
@@ -517,6 +520,7 @@ Multi-file upload with automatic document conversion:
|
||||
- Rejects directory inputs before copying so uploads stay all-or-nothing
|
||||
- Reuses one conversion worker per request when called from an active event loop
|
||||
- Files stored in thread-isolated directories
|
||||
- Duplicate filenames in a single upload request are auto-renamed with `_N` suffixes so later files do not truncate earlier files
|
||||
- Agent receives uploaded file list via `UploadsMiddleware`
|
||||
|
||||
See [docs/FILE_UPLOAD.md](docs/FILE_UPLOAD.md) for details.
|
||||
|
||||
@@ -50,6 +50,12 @@ COPY backend ./backend
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync ${UV_EXTRAS:+--extra $UV_EXTRAS}"
|
||||
|
||||
# UTF-8 locale prevents UnicodeEncodeError on Chinese/emoji content in minimal
|
||||
# containers where locale configuration may be missing and the default encoding is not UTF-8.
|
||||
ENV LANG=C.UTF-8
|
||||
ENV LC_ALL=C.UTF-8
|
||||
ENV PYTHONIOENCODING=utf-8
|
||||
|
||||
# ── Stage 2: Dev ──────────────────────────────────────────────────────────────
|
||||
# Retains compiler toolchain from builder so startup-time `uv sync` can build
|
||||
# source distributions in development containers.
|
||||
@@ -66,6 +72,10 @@ CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app
|
||||
# Clean image without build-essential — reduces size (~200 MB) and attack surface.
|
||||
FROM python:3.12-slim-bookworm
|
||||
|
||||
ENV LANG=C.UTF-8
|
||||
ENV LC_ALL=C.UTF-8
|
||||
ENV PYTHONIOENCODING=utf-8
|
||||
|
||||
# Copy Node.js runtime from builder (provides npx for MCP servers)
|
||||
COPY --from=builder /usr/bin/node /usr/bin/node
|
||||
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
|
||||
|
||||
+1
-1
@@ -124,7 +124,7 @@ FastAPI application providing REST endpoints for frontend integration:
|
||||
| `POST /api/memory/reload` | Force memory reload |
|
||||
| `GET /api/memory/config` | Memory configuration |
|
||||
| `GET /api/memory/status` | Combined config + data |
|
||||
| `POST /api/threads/{id}/uploads` | Upload files (auto-converts PDF/PPT/Excel/Word to Markdown, rejects directory paths) |
|
||||
| `POST /api/threads/{id}/uploads` | Upload files (auto-converts PDF/PPT/Excel/Word to Markdown, rejects directory paths, auto-renames duplicate filenames in one request) |
|
||||
| `GET /api/threads/{id}/uploads/list` | List uploaded files |
|
||||
| `DELETE /api/threads/{id}` | Delete DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
|
||||
| `GET /api/threads/{id}/artifacts/{path}` | Serve generated artifacts |
|
||||
|
||||
@@ -146,6 +146,13 @@ def _normalize_custom_agent_name(raw_value: str) -> str:
|
||||
return normalized
|
||||
|
||||
|
||||
def _strip_loop_warning_text(text: str) -> str:
|
||||
"""Remove middleware-authored loop warning lines from display text."""
|
||||
if "[LOOP DETECTED]" not in text:
|
||||
return text
|
||||
return "\n".join(line for line in text.splitlines() if "[LOOP DETECTED]" not in line).strip()
|
||||
|
||||
|
||||
def _extract_response_text(result: dict | list) -> str:
|
||||
"""Extract the last AI message text from a LangGraph runs.wait result.
|
||||
|
||||
@@ -155,7 +162,7 @@ def _extract_response_text(result: dict | list) -> str:
|
||||
Handles special cases:
|
||||
- Regular AI text responses
|
||||
- Clarification interrupts (``ask_clarification`` tool messages)
|
||||
- AI messages with tool_calls but no text content
|
||||
- Strips loop-detection warnings attached to tool-call AI messages
|
||||
"""
|
||||
if isinstance(result, list):
|
||||
messages = result
|
||||
@@ -185,7 +192,12 @@ def _extract_response_text(result: dict | list) -> str:
|
||||
# Regular AI message with text content
|
||||
if msg_type == "ai":
|
||||
content = msg.get("content", "")
|
||||
has_tool_calls = bool(msg.get("tool_calls"))
|
||||
if isinstance(content, str) and content:
|
||||
if has_tool_calls:
|
||||
content = _strip_loop_warning_text(content)
|
||||
if not content:
|
||||
continue
|
||||
return content
|
||||
# content can be a list of content blocks
|
||||
if isinstance(content, list):
|
||||
@@ -196,6 +208,8 @@ def _extract_response_text(result: dict | list) -> str:
|
||||
elif isinstance(block, str):
|
||||
parts.append(block)
|
||||
text = "".join(parts)
|
||||
if has_tool_calls:
|
||||
text = _strip_loop_warning_text(text)
|
||||
if text:
|
||||
return text
|
||||
return ""
|
||||
@@ -420,7 +434,13 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
|
||||
if not msg.files:
|
||||
return []
|
||||
|
||||
from deerflow.uploads.manager import claim_unique_filename, ensure_uploads_dir, normalize_filename
|
||||
from deerflow.uploads.manager import (
|
||||
UnsafeUploadPathError,
|
||||
claim_unique_filename,
|
||||
ensure_uploads_dir,
|
||||
normalize_filename,
|
||||
write_upload_file_no_symlink,
|
||||
)
|
||||
|
||||
uploads_dir = ensure_uploads_dir(thread_id)
|
||||
seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()}
|
||||
@@ -471,7 +491,10 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
|
||||
|
||||
dest = uploads_dir / safe_name
|
||||
try:
|
||||
dest.write_bytes(data)
|
||||
dest = write_upload_file_no_symlink(uploads_dir, safe_name, data)
|
||||
except UnsafeUploadPathError:
|
||||
logger.warning("[Manager] skipping inbound file with unsafe destination: %s", safe_name)
|
||||
continue
|
||||
except Exception:
|
||||
logger.exception("[Manager] failed to write inbound file: %s", dest)
|
||||
continue
|
||||
@@ -580,6 +603,17 @@ class ChannelManager:
|
||||
user_layer.get("config"),
|
||||
)
|
||||
|
||||
configurable = run_config.get("configurable")
|
||||
if isinstance(configurable, Mapping):
|
||||
configurable = dict(configurable)
|
||||
else:
|
||||
configurable = {}
|
||||
run_config["configurable"] = configurable
|
||||
# Pin channel-triggered runs to the root graph namespace so follow-up
|
||||
# turns continue from the same conversation checkpoint.
|
||||
configurable["checkpoint_ns"] = ""
|
||||
configurable["thread_id"] = thread_id
|
||||
|
||||
run_context = _merge_dicts(
|
||||
DEFAULT_RUN_CONTEXT,
|
||||
self._default_session.get("context"),
|
||||
@@ -963,7 +997,11 @@ class ChannelManager:
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as http:
|
||||
resp = await http.get(f"{self._gateway_url}{path}", timeout=10)
|
||||
resp = await http.get(
|
||||
f"{self._gateway_url}{path}",
|
||||
timeout=10,
|
||||
headers=create_internal_auth_headers(),
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
except Exception:
|
||||
|
||||
@@ -4,8 +4,10 @@ Per RFC-001:
|
||||
State-changing operations require CSRF protection.
|
||||
"""
|
||||
|
||||
import os
|
||||
import secrets
|
||||
from collections.abc import Callable
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
@@ -19,7 +21,7 @@ CSRF_TOKEN_LENGTH = 64 # bytes
|
||||
|
||||
def is_secure_request(request: Request) -> bool:
|
||||
"""Detect whether the original client request was made over HTTPS."""
|
||||
return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"
|
||||
return _request_scheme(request) == "https"
|
||||
|
||||
|
||||
def generate_csrf_token() -> str:
|
||||
@@ -61,6 +63,109 @@ def is_auth_endpoint(request: Request) -> bool:
|
||||
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
|
||||
|
||||
|
||||
def _host_with_optional_port(hostname: str, port: int | None, scheme: str) -> str:
|
||||
"""Return normalized host[:port], omitting default ports."""
|
||||
host = hostname.lower()
|
||||
if ":" in host and not host.startswith("["):
|
||||
host = f"[{host}]"
|
||||
|
||||
if port is None or (scheme == "http" and port == 80) or (scheme == "https" and port == 443):
|
||||
return host
|
||||
return f"{host}:{port}"
|
||||
|
||||
|
||||
def _normalize_origin(origin: str) -> str | None:
|
||||
"""Return a normalized scheme://host[:port] origin, or None for invalid input."""
|
||||
try:
|
||||
parsed = urlsplit(origin.strip())
|
||||
port = parsed.port
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
scheme = parsed.scheme.lower()
|
||||
if scheme not in {"http", "https"} or not parsed.hostname:
|
||||
return None
|
||||
|
||||
# Browser Origin is only scheme/host/port. Reject URL-shaped or credentialed values.
|
||||
if parsed.username or parsed.password or parsed.path or parsed.query or parsed.fragment:
|
||||
return None
|
||||
|
||||
return f"{scheme}://{_host_with_optional_port(parsed.hostname, port, scheme)}"
|
||||
|
||||
|
||||
def _configured_cors_origins() -> set[str]:
|
||||
"""Return explicit configured browser origins that may call auth routes."""
|
||||
origins = set()
|
||||
for raw_origin in os.environ.get("GATEWAY_CORS_ORIGINS", "").split(","):
|
||||
origin = raw_origin.strip()
|
||||
if not origin or origin == "*":
|
||||
continue
|
||||
normalized = _normalize_origin(origin)
|
||||
if normalized:
|
||||
origins.add(normalized)
|
||||
return origins
|
||||
|
||||
|
||||
def _first_header_value(value: str | None) -> str | None:
|
||||
"""Return the first value from a comma-separated proxy header."""
|
||||
if not value:
|
||||
return None
|
||||
first = value.split(",", 1)[0].strip()
|
||||
return first or None
|
||||
|
||||
|
||||
def _forwarded_param(request: Request, name: str) -> str | None:
|
||||
"""Extract a parameter from the first RFC 7239 Forwarded header entry."""
|
||||
forwarded = _first_header_value(request.headers.get("forwarded"))
|
||||
if not forwarded:
|
||||
return None
|
||||
|
||||
for part in forwarded.split(";"):
|
||||
key, sep, value = part.strip().partition("=")
|
||||
if sep and key.lower() == name:
|
||||
return value.strip().strip('"') or None
|
||||
return None
|
||||
|
||||
|
||||
def _request_scheme(request: Request) -> str:
|
||||
"""Resolve the original request scheme from trusted proxy headers."""
|
||||
scheme = _forwarded_param(request, "proto") or _first_header_value(request.headers.get("x-forwarded-proto")) or request.url.scheme
|
||||
return scheme.lower()
|
||||
|
||||
|
||||
def _request_origin(request: Request) -> str | None:
|
||||
"""Build the origin for the URL the browser is targeting."""
|
||||
scheme = _request_scheme(request)
|
||||
host = _forwarded_param(request, "host") or _first_header_value(request.headers.get("x-forwarded-host")) or request.headers.get("host") or request.url.netloc
|
||||
|
||||
forwarded_port = _first_header_value(request.headers.get("x-forwarded-port"))
|
||||
if forwarded_port and ":" not in host.rsplit("]", 1)[-1]:
|
||||
host = f"{host}:{forwarded_port}"
|
||||
|
||||
return _normalize_origin(f"{scheme}://{host}")
|
||||
|
||||
|
||||
def is_allowed_auth_origin(request: Request) -> bool:
|
||||
"""Allow auth POSTs only from the same origin or explicit configured origins.
|
||||
|
||||
Login/register/initialize are exempt from the double-submit token because
|
||||
first-time browser clients do not have a CSRF token yet. They still create
|
||||
a session cookie, so browser requests with a hostile Origin header must be
|
||||
rejected to prevent login CSRF / session fixation. Requests without Origin
|
||||
are allowed for non-browser clients such as curl and mobile integrations.
|
||||
"""
|
||||
origin = request.headers.get("origin")
|
||||
if not origin:
|
||||
return True
|
||||
|
||||
normalized_origin = _normalize_origin(origin)
|
||||
if normalized_origin is None:
|
||||
return False
|
||||
|
||||
request_origin = _request_origin(request)
|
||||
return normalized_origin in _configured_cors_origins() or (request_origin is not None and normalized_origin == request_origin)
|
||||
|
||||
|
||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
|
||||
|
||||
@@ -70,6 +175,12 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
_is_auth = is_auth_endpoint(request)
|
||||
|
||||
if should_check_csrf(request) and _is_auth and not is_allowed_auth_origin(request):
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Cross-site auth request denied."},
|
||||
)
|
||||
|
||||
if should_check_csrf(request) and not _is_auth:
|
||||
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
|
||||
header_token = request.headers.get(CSRF_HEADER_NAME)
|
||||
|
||||
@@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
|
||||
from deerflow.config.agents_api_config import get_agents_api_config
|
||||
from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api", tags=["agents"])
|
||||
@@ -86,11 +87,11 @@ def _require_agents_api_enabled() -> None:
|
||||
)
|
||||
|
||||
|
||||
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False) -> AgentResponse:
|
||||
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False, *, user_id: str | None = None) -> AgentResponse:
|
||||
"""Convert AgentConfig to AgentResponse."""
|
||||
soul: str | None = None
|
||||
if include_soul:
|
||||
soul = load_agent_soul(agent_cfg.name) or ""
|
||||
soul = load_agent_soul(agent_cfg.name, user_id=user_id) or ""
|
||||
|
||||
return AgentResponse(
|
||||
name=agent_cfg.name,
|
||||
@@ -116,9 +117,10 @@ async def list_agents() -> AgentsListResponse:
|
||||
"""
|
||||
_require_agents_api_enabled()
|
||||
|
||||
user_id = get_effective_user_id()
|
||||
try:
|
||||
agents = list_custom_agents()
|
||||
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True) for a in agents])
|
||||
agents = list_custom_agents(user_id=user_id)
|
||||
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True, user_id=user_id) for a in agents])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list agents: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list agents: {str(e)}")
|
||||
@@ -144,7 +146,12 @@ async def check_agent_name(name: str) -> dict:
|
||||
_require_agents_api_enabled()
|
||||
_validate_agent_name(name)
|
||||
normalized = _normalize_agent_name(name)
|
||||
available = not get_paths().agent_dir(normalized).exists()
|
||||
user_id = get_effective_user_id()
|
||||
paths = get_paths()
|
||||
# Treat the name as taken if either the per-user path or the legacy shared
|
||||
# path holds an agent — picking a name that collides with an unmigrated
|
||||
# legacy agent would shadow the legacy entry once migration runs.
|
||||
available = not paths.user_agent_dir(user_id, normalized).exists() and not paths.agent_dir(normalized).exists()
|
||||
return {"available": available, "name": normalized}
|
||||
|
||||
|
||||
@@ -169,10 +176,11 @@ async def get_agent(name: str) -> AgentResponse:
|
||||
_require_agents_api_enabled()
|
||||
_validate_agent_name(name)
|
||||
name = _normalize_agent_name(name)
|
||||
user_id = get_effective_user_id()
|
||||
|
||||
try:
|
||||
agent_cfg = load_agent_config(name)
|
||||
return _agent_config_to_response(agent_cfg, include_soul=True)
|
||||
agent_cfg = load_agent_config(name, user_id=user_id)
|
||||
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
||||
except Exception as e:
|
||||
@@ -202,10 +210,13 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
|
||||
_require_agents_api_enabled()
|
||||
_validate_agent_name(request.name)
|
||||
normalized_name = _normalize_agent_name(request.name)
|
||||
user_id = get_effective_user_id()
|
||||
paths = get_paths()
|
||||
|
||||
agent_dir = get_paths().agent_dir(normalized_name)
|
||||
agent_dir = paths.user_agent_dir(user_id, normalized_name)
|
||||
legacy_dir = paths.agent_dir(normalized_name)
|
||||
|
||||
if agent_dir.exists():
|
||||
if agent_dir.exists() or legacy_dir.exists():
|
||||
raise HTTPException(status_code=409, detail=f"Agent '{normalized_name}' already exists")
|
||||
|
||||
try:
|
||||
@@ -232,8 +243,8 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
|
||||
|
||||
logger.info(f"Created agent '{normalized_name}' at {agent_dir}")
|
||||
|
||||
agent_cfg = load_agent_config(normalized_name)
|
||||
return _agent_config_to_response(agent_cfg, include_soul=True)
|
||||
agent_cfg = load_agent_config(normalized_name, user_id=user_id)
|
||||
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -267,13 +278,20 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
|
||||
_require_agents_api_enabled()
|
||||
_validate_agent_name(name)
|
||||
name = _normalize_agent_name(name)
|
||||
user_id = get_effective_user_id()
|
||||
|
||||
try:
|
||||
agent_cfg = load_agent_config(name)
|
||||
agent_cfg = load_agent_config(name, user_id=user_id)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
||||
|
||||
agent_dir = get_paths().agent_dir(name)
|
||||
paths = get_paths()
|
||||
agent_dir = paths.user_agent_dir(user_id, name)
|
||||
if not agent_dir.exists() and paths.agent_dir(name).exists():
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before updating."),
|
||||
)
|
||||
|
||||
try:
|
||||
# Update config if any config fields changed
|
||||
@@ -314,8 +332,8 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
|
||||
|
||||
logger.info(f"Updated agent '{name}'")
|
||||
|
||||
refreshed_cfg = load_agent_config(name)
|
||||
return _agent_config_to_response(refreshed_cfg, include_soul=True)
|
||||
refreshed_cfg = load_agent_config(name, user_id=user_id)
|
||||
return _agent_config_to_response(refreshed_cfg, include_soul=True, user_id=user_id)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -402,15 +420,22 @@ async def delete_agent(name: str) -> None:
|
||||
name: The agent name.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if agent not found.
|
||||
HTTPException: 404 if no per-user copy exists; 409 if only a legacy
|
||||
shared copy exists (suggesting the migration script).
|
||||
"""
|
||||
_require_agents_api_enabled()
|
||||
_validate_agent_name(name)
|
||||
name = _normalize_agent_name(name)
|
||||
|
||||
agent_dir = get_paths().agent_dir(name)
|
||||
user_id = get_effective_user_id()
|
||||
paths = get_paths()
|
||||
agent_dir = paths.user_agent_dir(user_id, name)
|
||||
|
||||
if not agent_dir.exists():
|
||||
if paths.agent_dir(name).exists():
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before deleting."),
|
||||
)
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
||||
|
||||
try:
|
||||
|
||||
@@ -68,6 +68,27 @@ class RunResponse(BaseModel):
|
||||
updated_at: str = ""
|
||||
|
||||
|
||||
class ThreadTokenUsageModelBreakdown(BaseModel):
|
||||
tokens: int = 0
|
||||
runs: int = 0
|
||||
|
||||
|
||||
class ThreadTokenUsageCallerBreakdown(BaseModel):
|
||||
lead_agent: int = 0
|
||||
subagent: int = 0
|
||||
middleware: int = 0
|
||||
|
||||
|
||||
class ThreadTokenUsageResponse(BaseModel):
|
||||
thread_id: str
|
||||
total_tokens: int = 0
|
||||
total_input_tokens: int = 0
|
||||
total_output_tokens: int = 0
|
||||
total_runs: int = 0
|
||||
by_model: dict[str, ThreadTokenUsageModelBreakdown] = Field(default_factory=dict)
|
||||
by_caller: ThreadTokenUsageCallerBreakdown = Field(default_factory=ThreadTokenUsageCallerBreakdown)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -368,10 +389,10 @@ async def list_run_events(
|
||||
return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/token-usage")
|
||||
@router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse)
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def thread_token_usage(thread_id: str, request: Request) -> dict:
|
||||
async def thread_token_usage(thread_id: str, request: Request) -> ThreadTokenUsageResponse:
|
||||
"""Thread-level token usage aggregation."""
|
||||
run_store = get_run_store(request)
|
||||
agg = await run_store.aggregate_tokens_by_thread(thread_id)
|
||||
return {"thread_id": thread_id, **agg}
|
||||
return ThreadTokenUsageResponse(thread_id=thread_id, **agg)
|
||||
|
||||
@@ -13,11 +13,11 @@ matching the LangGraph Platform wire format expected by the
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from langgraph.checkpoint.base import empty_checkpoint
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
@@ -26,6 +26,7 @@ from app.gateway.utils import sanitize_log_param
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.runtime import serialize_channel_values
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.utils.time import coerce_iso, now_iso
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/threads", tags=["threads"])
|
||||
@@ -233,7 +234,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
checkpointer = get_checkpointer(request)
|
||||
thread_store = get_thread_store(request)
|
||||
thread_id = body.thread_id or str(uuid.uuid4())
|
||||
now = time.time()
|
||||
now = now_iso()
|
||||
# ``body.metadata`` is already stripped of server-reserved keys by
|
||||
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
||||
|
||||
@@ -243,8 +244,8 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=existing_record.get("status", "idle"),
|
||||
created_at=str(existing_record.get("created_at", "")),
|
||||
updated_at=str(existing_record.get("updated_at", "")),
|
||||
created_at=coerce_iso(existing_record.get("created_at", "")),
|
||||
updated_at=coerce_iso(existing_record.get("updated_at", "")),
|
||||
metadata=existing_record.get("metadata", {}),
|
||||
)
|
||||
|
||||
@@ -262,8 +263,6 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
# Write an empty checkpoint so state endpoints work immediately
|
||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
try:
|
||||
from langgraph.checkpoint.base import empty_checkpoint
|
||||
|
||||
ckpt_metadata = {
|
||||
"step": -1,
|
||||
"source": "input",
|
||||
@@ -281,8 +280,8 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status="idle",
|
||||
created_at=str(now),
|
||||
updated_at=str(now),
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
metadata=body.metadata,
|
||||
)
|
||||
|
||||
@@ -307,8 +306,11 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
|
||||
ThreadResponse(
|
||||
thread_id=r["thread_id"],
|
||||
status=r.get("status", "idle"),
|
||||
created_at=r.get("created_at", ""),
|
||||
updated_at=r.get("updated_at", ""),
|
||||
# ``coerce_iso`` heals legacy unix-second values that
|
||||
# ``MemoryThreadMetaStore`` historically wrote with ``time.time()``;
|
||||
# SQL-backed rows already arrive as ISO strings and pass through.
|
||||
created_at=coerce_iso(r.get("created_at", "")),
|
||||
updated_at=coerce_iso(r.get("updated_at", "")),
|
||||
metadata=r.get("metadata", {}),
|
||||
values={"title": r["display_name"]} if r.get("display_name") else {},
|
||||
interrupts={},
|
||||
@@ -340,8 +342,8 @@ async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Reques
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=record.get("status", "idle"),
|
||||
created_at=str(record.get("created_at", "")),
|
||||
updated_at=str(record.get("updated_at", "")),
|
||||
created_at=coerce_iso(record.get("created_at", "")),
|
||||
updated_at=coerce_iso(record.get("updated_at", "")),
|
||||
metadata=record.get("metadata", {}),
|
||||
)
|
||||
|
||||
@@ -381,8 +383,8 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
||||
record = {
|
||||
"thread_id": thread_id,
|
||||
"status": "idle",
|
||||
"created_at": ckpt_meta.get("created_at", ""),
|
||||
"updated_at": ckpt_meta.get("updated_at", ckpt_meta.get("created_at", "")),
|
||||
"created_at": coerce_iso(ckpt_meta.get("created_at", "")),
|
||||
"updated_at": coerce_iso(ckpt_meta.get("updated_at", ckpt_meta.get("created_at", ""))),
|
||||
"metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
|
||||
}
|
||||
|
||||
@@ -396,8 +398,8 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=status,
|
||||
created_at=str(record.get("created_at", "")),
|
||||
updated_at=str(record.get("updated_at", "")),
|
||||
created_at=coerce_iso(record.get("created_at", "")),
|
||||
updated_at=coerce_iso(record.get("updated_at", "")),
|
||||
metadata=record.get("metadata", {}),
|
||||
values=serialize_channel_values(channel_values),
|
||||
)
|
||||
@@ -448,10 +450,10 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
|
||||
values=values,
|
||||
next=next_tasks,
|
||||
metadata=metadata,
|
||||
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
|
||||
checkpoint={"id": checkpoint_id, "ts": coerce_iso(metadata.get("created_at", ""))},
|
||||
checkpoint_id=checkpoint_id,
|
||||
parent_checkpoint_id=parent_checkpoint_id,
|
||||
created_at=str(metadata.get("created_at", "")),
|
||||
created_at=coerce_iso(metadata.get("created_at", "")),
|
||||
tasks=tasks,
|
||||
)
|
||||
|
||||
@@ -501,7 +503,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
||||
channel_values.update(body.values)
|
||||
|
||||
checkpoint["channel_values"] = channel_values
|
||||
metadata["updated_at"] = time.time()
|
||||
metadata["updated_at"] = now_iso()
|
||||
|
||||
if body.as_node:
|
||||
metadata["source"] = "update"
|
||||
@@ -542,7 +544,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
||||
next=[],
|
||||
metadata=metadata,
|
||||
checkpoint_id=new_checkpoint_id,
|
||||
created_at=str(metadata.get("created_at", "")),
|
||||
created_at=coerce_iso(metadata.get("created_at", "")),
|
||||
)
|
||||
|
||||
|
||||
@@ -609,7 +611,7 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
|
||||
parent_checkpoint_id=parent_id,
|
||||
metadata=user_meta,
|
||||
values=values,
|
||||
created_at=str(metadata.get("created_at", "")),
|
||||
created_at=coerce_iso(metadata.get("created_at", "")),
|
||||
next=next_tasks,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
import stat
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
from app.gateway.deps import get_config
|
||||
@@ -15,12 +15,15 @@ from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
|
||||
from deerflow.uploads.manager import (
|
||||
PathTraversalError,
|
||||
UnsafeUploadPathError,
|
||||
claim_unique_filename,
|
||||
delete_file_safe,
|
||||
enrich_file_listing,
|
||||
ensure_uploads_dir,
|
||||
get_uploads_dir,
|
||||
list_files_in_dir,
|
||||
normalize_filename,
|
||||
open_upload_file_no_symlink,
|
||||
upload_artifact_url,
|
||||
upload_virtual_path,
|
||||
)
|
||||
@@ -42,6 +45,7 @@ class UploadResponse(BaseModel):
|
||||
success: bool
|
||||
files: list[dict[str, str]]
|
||||
message: str
|
||||
skipped_files: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class UploadLimits(BaseModel):
|
||||
@@ -116,17 +120,18 @@ def _cleanup_uploaded_paths(paths: list[os.PathLike[str] | str]) -> None:
|
||||
logger.warning("Failed to clean up upload path after rejected request: %s", path, exc_info=True)
|
||||
|
||||
|
||||
async def _write_upload_file_streaming(
|
||||
async def _write_upload_file_with_limits(
|
||||
file: UploadFile,
|
||||
file_path: os.PathLike[str] | str,
|
||||
*,
|
||||
uploads_dir: os.PathLike[str] | str,
|
||||
display_filename: str,
|
||||
max_single_file_size: int,
|
||||
max_total_size: int,
|
||||
total_size: int,
|
||||
) -> tuple[int, int]:
|
||||
) -> tuple[os.PathLike[str] | str, int, int]:
|
||||
file_size = 0
|
||||
with open(file_path, "wb") as output:
|
||||
file_path, fh = open_upload_file_no_symlink(uploads_dir, display_filename)
|
||||
try:
|
||||
while chunk := await file.read(UPLOAD_CHUNK_SIZE):
|
||||
file_size += len(chunk)
|
||||
total_size += len(chunk)
|
||||
@@ -134,8 +139,17 @@ async def _write_upload_file_streaming(
|
||||
raise HTTPException(status_code=413, detail=f"File too large: {display_filename}")
|
||||
if total_size > max_total_size:
|
||||
raise HTTPException(status_code=413, detail="Total upload size too large")
|
||||
output.write(chunk)
|
||||
return file_size, total_size
|
||||
fh.write(chunk)
|
||||
except Exception:
|
||||
fh.close()
|
||||
try:
|
||||
os.unlink(file_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
raise
|
||||
else:
|
||||
fh.close()
|
||||
return file_path, file_size, total_size
|
||||
|
||||
|
||||
def _auto_convert_documents_enabled(app_config: AppConfig) -> bool:
|
||||
@@ -177,7 +191,12 @@ async def upload_files(
|
||||
uploaded_files = []
|
||||
written_paths = []
|
||||
sandbox_sync_targets = []
|
||||
skipped_files = []
|
||||
total_size = 0
|
||||
# Track filenames within this request so duplicate form parts do not
|
||||
# silently truncate each other. Existing uploads keep the historical
|
||||
# overwrite behavior for a single replacement upload.
|
||||
seen_filenames: set[str] = set()
|
||||
|
||||
sandbox_provider = get_sandbox_provider()
|
||||
sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
|
||||
@@ -194,22 +213,22 @@ async def upload_files(
|
||||
continue
|
||||
|
||||
try:
|
||||
safe_filename = normalize_filename(file.filename)
|
||||
original_filename = normalize_filename(file.filename)
|
||||
safe_filename = claim_unique_filename(original_filename, seen_filenames)
|
||||
except ValueError:
|
||||
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
|
||||
continue
|
||||
|
||||
try:
|
||||
file_path = uploads_dir / safe_filename
|
||||
written_paths.append(file_path)
|
||||
file_size, total_size = await _write_upload_file_streaming(
|
||||
file_path, file_size, total_size = await _write_upload_file_with_limits(
|
||||
file,
|
||||
file_path,
|
||||
uploads_dir=uploads_dir,
|
||||
display_filename=safe_filename,
|
||||
max_single_file_size=limits.max_file_size,
|
||||
max_total_size=limits.max_total_size,
|
||||
total_size=total_size,
|
||||
)
|
||||
written_paths.append(file_path)
|
||||
|
||||
virtual_path = upload_virtual_path(safe_filename)
|
||||
|
||||
@@ -223,6 +242,8 @@ async def upload_files(
|
||||
"virtual_path": virtual_path,
|
||||
"artifact_url": upload_artifact_url(thread_id, safe_filename),
|
||||
}
|
||||
if safe_filename != original_filename:
|
||||
file_info["original_filename"] = original_filename
|
||||
|
||||
logger.info(f"Saved file: {safe_filename} ({file_size} bytes) to {file_info['path']}")
|
||||
|
||||
@@ -246,6 +267,10 @@ async def upload_files(
|
||||
except HTTPException as e:
|
||||
_cleanup_uploaded_paths(written_paths)
|
||||
raise e
|
||||
except UnsafeUploadPathError as e:
|
||||
logger.warning("Skipping upload with unsafe destination %s: %s", file.filename, e)
|
||||
skipped_files.append(safe_filename)
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload {file.filename}: {e}")
|
||||
_cleanup_uploaded_paths(written_paths)
|
||||
@@ -256,10 +281,15 @@ async def upload_files(
|
||||
_make_file_sandbox_writable(file_path)
|
||||
sandbox.update_file(virtual_path, file_path.read_bytes())
|
||||
|
||||
message = f"Successfully uploaded {len(uploaded_files)} file(s)"
|
||||
if skipped_files:
|
||||
message += f"; skipped {len(skipped_files)} unsafe file(s)"
|
||||
|
||||
return UploadResponse(
|
||||
success=True,
|
||||
success=not skipped_files,
|
||||
files=uploaded_files,
|
||||
message=f"Successfully uploaded {len(uploaded_files)} file(s)",
|
||||
message=message,
|
||||
skipped_files=skipped_files,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -136,6 +136,24 @@ def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, An
|
||||
runtime_context.setdefault(key, context[key])
|
||||
|
||||
|
||||
def inject_authenticated_user_context(config: dict[str, Any], request: Request) -> None:
|
||||
"""Stamp the authenticated user into the run context for background tools.
|
||||
|
||||
Tool execution may happen after the request handler has returned, so tools
|
||||
that persist user-scoped files should not rely only on ambient ContextVars.
|
||||
The value comes from server-side auth state, never from client context.
|
||||
"""
|
||||
|
||||
user = getattr(request.state, "user", None)
|
||||
user_id = getattr(user, "id", None)
|
||||
if user_id is None:
|
||||
return
|
||||
|
||||
runtime_context = config.setdefault("context", {})
|
||||
if isinstance(runtime_context, dict):
|
||||
runtime_context["user_id"] = str(user_id)
|
||||
|
||||
|
||||
def resolve_agent_factory(assistant_id: str | None):
|
||||
"""Resolve the agent factory callable from config.
|
||||
|
||||
@@ -288,6 +306,7 @@ async def start_run(
|
||||
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
||||
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
||||
merge_run_context_overrides(config, getattr(body, "context", None))
|
||||
inject_authenticated_user_context(config, request)
|
||||
|
||||
stream_modes = normalize_stream_modes(body.stream_mode)
|
||||
|
||||
|
||||
@@ -79,7 +79,9 @@ async def main():
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents import make_lead_agent
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.mcp import initialize_mcp_tools
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
# Initialize MCP tools at startup
|
||||
try:
|
||||
@@ -113,6 +115,8 @@ async def main():
|
||||
print("Tip: `uv sync --group dev` to enable arrow-key & history support")
|
||||
print("=" * 50)
|
||||
|
||||
seen_artifacts: set[str] = set()
|
||||
|
||||
while True:
|
||||
try:
|
||||
if session:
|
||||
@@ -134,6 +138,22 @@ async def main():
|
||||
last_message = result["messages"][-1]
|
||||
print(f"\nAgent: {last_message.content}")
|
||||
|
||||
# Show files presented to the user this turn (new artifacts only)
|
||||
artifacts = result.get("artifacts") or []
|
||||
new_artifacts = [p for p in artifacts if p not in seen_artifacts]
|
||||
if new_artifacts:
|
||||
thread_id = config["configurable"]["thread_id"]
|
||||
user_id = get_effective_user_id()
|
||||
paths = get_paths()
|
||||
print("\n[Presented files]")
|
||||
for virtual in new_artifacts:
|
||||
try:
|
||||
physical = paths.resolve_virtual_path(thread_id, virtual, user_id=user_id)
|
||||
print(f" - {virtual}\n → {physical}")
|
||||
except ValueError as exc:
|
||||
print(f" - {virtual} (failed to resolve physical path: {exc})")
|
||||
seen_artifacts.update(new_artifacts)
|
||||
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
|
||||
@@ -173,7 +173,7 @@ def _assemble_from_features(
|
||||
9. MemoryMiddleware (memory feature)
|
||||
10. ViewImageMiddleware (vision feature)
|
||||
11. SubagentLimitMiddleware (subagent feature)
|
||||
12. LoopDetectionMiddleware (always)
|
||||
12. LoopDetectionMiddleware (loop_detection feature)
|
||||
13. ClarificationMiddleware (always last)
|
||||
|
||||
Two-phase ordering:
|
||||
@@ -272,10 +272,15 @@ def _assemble_from_features(
|
||||
|
||||
extra_tools.append(task_tool)
|
||||
|
||||
# --- [12] LoopDetection (always) ---
|
||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||
# --- [12] LoopDetection ---
|
||||
if feat.loop_detection is not False:
|
||||
if isinstance(feat.loop_detection, AgentMiddleware):
|
||||
chain.append(feat.loop_detection)
|
||||
else:
|
||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
||||
|
||||
chain.append(LoopDetectionMiddleware())
|
||||
chain.append(LoopDetectionMiddleware.from_config(LoopDetectionConfig()))
|
||||
|
||||
# --- [13] Clarification (always last among built-ins) ---
|
||||
chain.append(ClarificationMiddleware())
|
||||
|
||||
@@ -31,6 +31,7 @@ class RuntimeFeatures:
|
||||
vision: bool | AgentMiddleware = False
|
||||
auto_title: bool | AgentMiddleware = False
|
||||
guardrail: Literal[False] | AgentMiddleware = False
|
||||
loop_detection: bool | AgentMiddleware = True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -20,6 +20,8 @@ from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.config.agents_config import load_agent_config, validate_agent_name
|
||||
from deerflow.config.app_config import AppConfig, get_app_config
|
||||
from deerflow.models import create_chat_model
|
||||
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
|
||||
from deerflow.skills.types import Skill
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -256,6 +258,12 @@ def _build_middlewares(
|
||||
resolved_app_config = app_config or get_app_config()
|
||||
middlewares = build_lead_runtime_middlewares(app_config=resolved_app_config, lazy_init=True)
|
||||
|
||||
# Always inject current date (and optionally memory) as <system-reminder> into the
|
||||
# first HumanMessage to keep the system prompt fully static for prefix-cache reuse.
|
||||
from deerflow.agents.middlewares.dynamic_context_middleware import DynamicContextMiddleware
|
||||
|
||||
middlewares.append(DynamicContextMiddleware(agent_name=agent_name, app_config=resolved_app_config))
|
||||
|
||||
# Add summarization middleware if enabled
|
||||
summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config)
|
||||
if summarization_middleware is not None:
|
||||
@@ -297,7 +305,9 @@ def _build_middlewares(
|
||||
middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents))
|
||||
|
||||
# LoopDetectionMiddleware — detect and break repetitive tool call loops
|
||||
middlewares.append(LoopDetectionMiddleware())
|
||||
loop_detection_config = resolved_app_config.loop_detection
|
||||
if loop_detection_config.enabled:
|
||||
middlewares.append(LoopDetectionMiddleware.from_config(loop_detection_config))
|
||||
|
||||
# Inject custom middlewares before ClarificationMiddleware
|
||||
if custom_middlewares:
|
||||
@@ -308,6 +318,28 @@ def _build_middlewares(
|
||||
return middlewares
|
||||
|
||||
|
||||
def _available_skill_names(agent_config, is_bootstrap: bool) -> set[str] | None:
|
||||
if is_bootstrap:
|
||||
return {"bootstrap"}
|
||||
if agent_config and agent_config.skills is not None:
|
||||
return set(agent_config.skills)
|
||||
return None
|
||||
|
||||
|
||||
def _load_enabled_skills_for_tool_policy(available_skills: set[str] | None, *, app_config: AppConfig) -> list[Skill]:
|
||||
try:
|
||||
from deerflow.agents.lead_agent.prompt import get_enabled_skills_for_config
|
||||
|
||||
skills = get_enabled_skills_for_config(app_config)
|
||||
except Exception:
|
||||
logger.exception("Failed to load skills for allowed-tools policy")
|
||||
raise
|
||||
|
||||
if available_skills is None:
|
||||
return skills
|
||||
return [skill for skill in skills if skill.name in available_skills]
|
||||
|
||||
|
||||
def make_lead_agent(config: RunnableConfig):
|
||||
"""LangGraph graph factory; keep the signature compatible with LangGraph Server."""
|
||||
runtime_config = _get_runtime_config(config)
|
||||
@@ -318,7 +350,7 @@ def make_lead_agent(config: RunnableConfig):
|
||||
def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
|
||||
# Lazy import to avoid circular dependency
|
||||
from deerflow.tools import get_available_tools
|
||||
from deerflow.tools.builtins import setup_agent
|
||||
from deerflow.tools.builtins import setup_agent, update_agent
|
||||
|
||||
cfg = _get_runtime_config(config)
|
||||
resolved_app_config = app_config
|
||||
@@ -333,6 +365,7 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
|
||||
agent_name = validate_agent_name(cfg.get("agent_name"))
|
||||
|
||||
agent_config = load_agent_config(agent_name) if not is_bootstrap else None
|
||||
available_skills = _available_skill_names(agent_config, is_bootstrap)
|
||||
# Custom agent model from agent config (if any), or None to let _resolve_model_name pick the default
|
||||
agent_model_name = agent_config.model if agent_config and agent_config.model else None
|
||||
|
||||
@@ -371,15 +404,18 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
|
||||
"is_plan_mode": is_plan_mode,
|
||||
"subagent_enabled": subagent_enabled,
|
||||
"tool_groups": agent_config.tool_groups if agent_config else None,
|
||||
"available_skills": ["bootstrap"] if is_bootstrap else (agent_config.skills if agent_config and agent_config.skills is not None else None),
|
||||
"available_skills": sorted(available_skills) if available_skills is not None else None,
|
||||
}
|
||||
)
|
||||
|
||||
skills_for_tool_policy = _load_enabled_skills_for_tool_policy(available_skills, app_config=resolved_app_config)
|
||||
|
||||
if is_bootstrap:
|
||||
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
|
||||
tools = get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent]
|
||||
return create_agent(
|
||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config),
|
||||
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent],
|
||||
tools=filter_tools_by_skill_allowed_tools(tools, skills_for_tool_policy),
|
||||
middleware=_build_middlewares(config, model_name=model_name, app_config=resolved_app_config),
|
||||
system_prompt=apply_prompt_template(
|
||||
subagent_enabled=subagent_enabled,
|
||||
@@ -390,15 +426,14 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
|
||||
state_schema=ThreadState,
|
||||
)
|
||||
|
||||
# Custom agents can update their own SOUL.md / config via update_agent.
|
||||
# The default agent (no agent_name) does not see this tool.
|
||||
extra_tools = [update_agent] if agent_name else []
|
||||
# Default lead agent (unchanged behavior)
|
||||
tools = get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled, app_config=resolved_app_config)
|
||||
return create_agent(
|
||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config),
|
||||
tools=get_available_tools(
|
||||
model_name=model_name,
|
||||
groups=agent_config.tool_groups if agent_config else None,
|
||||
subagent_enabled=subagent_enabled,
|
||||
app_config=resolved_app_config,
|
||||
),
|
||||
tools=filter_tools_by_skill_allowed_tools(tools + extra_tools, skills_for_tool_policy),
|
||||
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name, app_config=resolved_app_config),
|
||||
system_prompt=apply_prompt_template(
|
||||
subagent_enabled=subagent_enabled,
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -20,6 +19,7 @@ logger = logging.getLogger(__name__)
|
||||
_ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS = 5.0
|
||||
_enabled_skills_lock = threading.Lock()
|
||||
_enabled_skills_cache: list[Skill] | None = None
|
||||
_enabled_skills_by_config_cache: dict[int, tuple[object, list[Skill]]] = {}
|
||||
_enabled_skills_refresh_active = False
|
||||
_enabled_skills_refresh_version = 0
|
||||
_enabled_skills_refresh_event = threading.Event()
|
||||
@@ -84,6 +84,7 @@ def _invalidate_enabled_skills_cache() -> threading.Event:
|
||||
_get_cached_skills_prompt_section.cache_clear()
|
||||
with _enabled_skills_lock:
|
||||
_enabled_skills_cache = None
|
||||
_enabled_skills_by_config_cache.clear()
|
||||
_enabled_skills_refresh_version += 1
|
||||
_enabled_skills_refresh_event.clear()
|
||||
if _enabled_skills_refresh_active:
|
||||
@@ -107,6 +108,15 @@ def warm_enabled_skills_cache(timeout_seconds: float = _ENABLED_SKILLS_REFRESH_W
|
||||
|
||||
|
||||
def _get_enabled_skills():
|
||||
return get_cached_enabled_skills()
|
||||
|
||||
|
||||
def get_cached_enabled_skills() -> list[Skill]:
|
||||
"""Return the cached enabled-skills list, kicking off a background refresh on miss.
|
||||
|
||||
Safe to call from request paths: never blocks on disk I/O. Returns an empty
|
||||
list on cache miss; the next call will see the warmed result.
|
||||
"""
|
||||
with _enabled_skills_lock:
|
||||
cached = _enabled_skills_cache
|
||||
|
||||
@@ -117,17 +127,29 @@ def _get_enabled_skills():
|
||||
return []
|
||||
|
||||
|
||||
def _get_enabled_skills_for_config(app_config: AppConfig | None = None) -> list[Skill]:
|
||||
def get_enabled_skills_for_config(app_config: AppConfig | None = None) -> list[Skill]:
|
||||
"""Return enabled skills using the caller's config source.
|
||||
|
||||
When a concrete ``app_config`` is supplied, bypass the global enabled-skills
|
||||
cache so the skill list and skill paths are resolved from the same config
|
||||
object. This keeps request-scoped config injection consistent even while the
|
||||
release branch still supports global fallback paths.
|
||||
When a concrete ``app_config`` is supplied, cache the loaded skills by that
|
||||
config object's identity so request-scoped config injection still resolves
|
||||
skill paths from the matching config without rescanning storage on every
|
||||
agent factory call.
|
||||
"""
|
||||
if app_config is None:
|
||||
return _get_enabled_skills()
|
||||
return list(get_or_new_skill_storage(app_config=app_config).load_skills(enabled_only=True))
|
||||
|
||||
cache_key = id(app_config)
|
||||
with _enabled_skills_lock:
|
||||
cached = _enabled_skills_by_config_cache.get(cache_key)
|
||||
if cached is not None:
|
||||
cached_config, cached_skills = cached
|
||||
if cached_config is app_config:
|
||||
return list(cached_skills)
|
||||
|
||||
skills = list(get_or_new_skill_storage(app_config=app_config).load_skills(enabled_only=True))
|
||||
with _enabled_skills_lock:
|
||||
_enabled_skills_by_config_cache[cache_key] = (app_config, skills)
|
||||
return list(skills)
|
||||
|
||||
|
||||
def _skill_mutability_label(category: SkillCategory | str) -> str:
|
||||
@@ -344,8 +366,7 @@ You are {agent_name}, an open-source super agent.
|
||||
</role>
|
||||
|
||||
{soul}
|
||||
{memory_context}
|
||||
|
||||
{self_update_section}
|
||||
<thinking_style>
|
||||
- Think concisely and strategically about the user's request BEFORE taking action
|
||||
- Break down the task: What is clear? What is ambiguous? What is missing?
|
||||
@@ -604,7 +625,7 @@ You have access to skills that provide optimized workflows for specific tasks. E
|
||||
|
||||
def get_skills_prompt_section(available_skills: set[str] | None = None, *, app_config: AppConfig | None = None) -> str:
|
||||
"""Generate the skills prompt section with available skills list."""
|
||||
skills = _get_enabled_skills_for_config(app_config)
|
||||
skills = get_enabled_skills_for_config(app_config)
|
||||
|
||||
if app_config is None:
|
||||
try:
|
||||
@@ -643,6 +664,26 @@ def get_agent_soul(agent_name: str | None) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _build_self_update_section(agent_name: str | None) -> str:
|
||||
"""Prompt block that teaches the custom agent to persist self-updates via update_agent."""
|
||||
if not agent_name:
|
||||
return ""
|
||||
return f"""<self_update>
|
||||
You are running as the custom agent **{agent_name}** with a persisted SOUL.md and config.yaml.
|
||||
|
||||
When the user asks you to update your own description, personality, behaviour, skill set, tool groups, or default model,
|
||||
you MUST persist the change with the `update_agent` tool. Do NOT use `bash`, `write_file`, or any sandbox tool to edit
|
||||
SOUL.md or config.yaml — those write into a temporary sandbox/tool workspace and the changes will be lost on the next turn.
|
||||
|
||||
Rules:
|
||||
- Always pass the FULL replacement text for `soul` (no patch semantics). Start from your current SOUL above and apply the user's edits.
|
||||
- Only pass the fields that should change. Omit the others to preserve them.
|
||||
- Pass `skills=[]` to disable all skills, or omit `skills` to keep the existing whitelist.
|
||||
- After `update_agent` returns successfully, tell the user the change is persisted and will take effect on the next turn.
|
||||
</self_update>
|
||||
"""
|
||||
|
||||
|
||||
def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) -> str:
|
||||
"""Generate <available-deferred-tools> block for the system prompt.
|
||||
|
||||
@@ -732,9 +773,6 @@ def apply_prompt_template(
|
||||
available_skills: set[str] | None = None,
|
||||
app_config: AppConfig | None = None,
|
||||
) -> str:
|
||||
# Get memory context
|
||||
memory_context = _get_memory_context(agent_name, app_config=app_config)
|
||||
|
||||
# Include subagent section only if enabled (from runtime parameter)
|
||||
n = max_concurrent_subagents
|
||||
subagent_section = _build_subagent_section(n, app_config=app_config) if subagent_enabled else ""
|
||||
@@ -768,17 +806,18 @@ def apply_prompt_template(
|
||||
custom_mounts_section = _build_custom_mounts_section(app_config=app_config)
|
||||
acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)
|
||||
|
||||
# Format the prompt with dynamic skills and memory
|
||||
prompt = SYSTEM_PROMPT_TEMPLATE.format(
|
||||
# Build and return the fully static system prompt.
|
||||
# Memory and current date are injected per-turn via DynamicContextMiddleware
|
||||
# as a <system-reminder> in the first HumanMessage, keeping this prompt
|
||||
# identical across users and sessions for maximum prefix-cache reuse.
|
||||
return SYSTEM_PROMPT_TEMPLATE.format(
|
||||
agent_name=agent_name or "DeerFlow 2.0",
|
||||
soul=get_agent_soul(agent_name),
|
||||
self_update_section=_build_self_update_section(agent_name),
|
||||
skills_section=skills_section,
|
||||
deferred_tools_section=deferred_tools_section,
|
||||
memory_context=memory_context,
|
||||
subagent_section=subagent_section,
|
||||
subagent_reminder=subagent_reminder,
|
||||
subagent_thinking=subagent_thinking,
|
||||
acp_section=acp_and_mounts_section,
|
||||
)
|
||||
|
||||
return prompt + f"\n<current_date>{datetime.now().strftime('%Y-%m-%d, %A')}</current_date>"
|
||||
|
||||
@@ -0,0 +1,204 @@
|
||||
"""Middleware to inject dynamic context (memory, current date) as a system-reminder.
|
||||
|
||||
The system prompt is kept fully static for maximum prefix-cache reuse across users
|
||||
and sessions. The current date is always injected. Per-user memory is also injected
|
||||
when ``memory.injection_enabled`` is True in the app config. Both are delivered once
|
||||
per conversation as a dedicated <system-reminder> HumanMessage inserted before the
|
||||
first user message (frozen-snapshot pattern).
|
||||
|
||||
When a conversation spans midnight the middleware detects the date change and injects
|
||||
a lightweight date-update reminder as a separate HumanMessage before the current turn.
|
||||
This correction is persisted so subsequent turns on the new day see a consistent history
|
||||
and do not re-inject.
|
||||
|
||||
Reminder format:
|
||||
|
||||
<system-reminder>
|
||||
<memory>...</memory>
|
||||
|
||||
<current_date>2026-05-08, Friday</current_date>
|
||||
</system-reminder>
|
||||
|
||||
Date-update format:
|
||||
|
||||
<system-reminder>
|
||||
<current_date>2026-05-09, Saturday</current_date>
|
||||
</system-reminder>
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, override
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.config.app_config import AppConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DATE_RE = re.compile(r"<current_date>([^<]+)</current_date>")
|
||||
_DYNAMIC_CONTEXT_REMINDER_KEY = "dynamic_context_reminder"
|
||||
_SUMMARY_MESSAGE_NAME = "summary"
|
||||
|
||||
|
||||
def _extract_date(content: str) -> str | None:
|
||||
"""Return the first <current_date> value found in *content*, or None."""
|
||||
m = _DATE_RE.search(content)
|
||||
return m.group(1) if m else None
|
||||
|
||||
|
||||
def is_dynamic_context_reminder(message: object) -> bool:
|
||||
"""Return whether *message* is a hidden dynamic-context reminder."""
|
||||
return isinstance(message, HumanMessage) and bool(message.additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY))
|
||||
|
||||
|
||||
def _last_injected_date(messages: list) -> str | None:
|
||||
"""Scan messages in reverse and return the most recently injected date.
|
||||
|
||||
Detection uses the ``dynamic_context_reminder`` additional_kwargs flag rather
|
||||
than content substring matching, so user messages containing ``<system-reminder>``
|
||||
are not mistakenly treated as injected reminders.
|
||||
"""
|
||||
for msg in reversed(messages):
|
||||
if is_dynamic_context_reminder(msg):
|
||||
content_str = msg.content if isinstance(msg.content, str) else str(msg.content)
|
||||
return _extract_date(content_str)
|
||||
return None
|
||||
|
||||
|
||||
def _is_user_injection_target(message: object) -> bool:
|
||||
"""Return whether *message* can receive a dynamic-context reminder."""
|
||||
return isinstance(message, HumanMessage) and not is_dynamic_context_reminder(message) and message.name != _SUMMARY_MESSAGE_NAME
|
||||
|
||||
|
||||
class DynamicContextMiddleware(AgentMiddleware):
|
||||
"""Inject memory and current date into HumanMessages as a <system-reminder>.
|
||||
|
||||
First turn
|
||||
----------
|
||||
Prepends a full system-reminder (memory + date) to the first HumanMessage and
|
||||
persists it (same message ID). The first message is then frozen for the whole
|
||||
session — its content never changes again, so the prefix cache can hit on every
|
||||
subsequent turn.
|
||||
|
||||
Midnight crossing
|
||||
-----------------
|
||||
If the conversation spans midnight, the current date differs from the date that
|
||||
was injected earlier. In that case a lightweight date-update reminder is prepended
|
||||
to the **current** (last) HumanMessage and persisted. Subsequent turns on the new
|
||||
day see the corrected date in history and skip re-injection.
|
||||
"""
|
||||
|
||||
def __init__(self, agent_name: str | None = None, *, app_config: AppConfig | None = None):
|
||||
super().__init__()
|
||||
self._agent_name = agent_name
|
||||
self._app_config = app_config
|
||||
|
||||
def _build_full_reminder(self) -> str:
|
||||
from deerflow.agents.lead_agent.prompt import _get_memory_context
|
||||
|
||||
# Memory injection is gated by injection_enabled; date is always included.
|
||||
injection_enabled = self._app_config.memory.injection_enabled if self._app_config else True
|
||||
memory_context = _get_memory_context(self._agent_name, app_config=self._app_config) if injection_enabled else ""
|
||||
current_date = datetime.now().strftime("%Y-%m-%d, %A")
|
||||
|
||||
lines: list[str] = ["<system-reminder>"]
|
||||
if memory_context:
|
||||
lines.append(memory_context.strip())
|
||||
lines.append("") # blank line separating memory from date
|
||||
lines.append(f"<current_date>{current_date}</current_date>")
|
||||
lines.append("</system-reminder>")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _build_date_update_reminder(self) -> str:
|
||||
current_date = datetime.now().strftime("%Y-%m-%d, %A")
|
||||
return "\n".join(
|
||||
[
|
||||
"<system-reminder>",
|
||||
f"<current_date>{current_date}</current_date>",
|
||||
"</system-reminder>",
|
||||
]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _make_reminder_and_user_messages(original: HumanMessage, reminder_content: str) -> tuple[HumanMessage, HumanMessage]:
|
||||
"""Return (reminder_msg, user_msg) using the ID-swap technique.
|
||||
|
||||
reminder_msg takes the original message's ID so that add_messages replaces it
|
||||
in-place (preserving position). user_msg carries the original content with a
|
||||
derived ``{id}__user`` ID and is appended immediately after by add_messages.
|
||||
|
||||
If the original message has no ID a stable UUID is generated so the derived
|
||||
``{id}__user`` ID never collapses to the ambiguous ``None__user`` string.
|
||||
"""
|
||||
stable_id = original.id or str(uuid.uuid4())
|
||||
reminder_msg = HumanMessage(
|
||||
content=reminder_content,
|
||||
id=stable_id,
|
||||
additional_kwargs={"hide_from_ui": True, _DYNAMIC_CONTEXT_REMINDER_KEY: True},
|
||||
)
|
||||
user_msg = HumanMessage(
|
||||
content=original.content,
|
||||
id=f"{stable_id}__user",
|
||||
name=original.name,
|
||||
additional_kwargs=original.additional_kwargs,
|
||||
)
|
||||
return reminder_msg, user_msg
|
||||
|
||||
def _inject(self, state) -> dict | None:
|
||||
messages = list(state.get("messages", []))
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d, %A")
|
||||
last_date = _last_injected_date(messages)
|
||||
logger.debug(
|
||||
"DynamicContextMiddleware._inject: msg_count=%d last_date=%r current_date=%r",
|
||||
len(messages),
|
||||
last_date,
|
||||
current_date,
|
||||
)
|
||||
|
||||
if last_date is None:
|
||||
# ── First turn: inject full reminder as a separate HumanMessage ─────
|
||||
first_idx = next((i for i, m in enumerate(messages) if _is_user_injection_target(m)), None)
|
||||
if first_idx is None:
|
||||
return None
|
||||
full_reminder = self._build_full_reminder()
|
||||
logger.info(
|
||||
"DynamicContextMiddleware: injecting full reminder (len=%d, has_memory=%s) into first HumanMessage id=%r",
|
||||
len(full_reminder),
|
||||
"<memory>" in full_reminder,
|
||||
messages[first_idx].id,
|
||||
)
|
||||
reminder_msg, user_msg = self._make_reminder_and_user_messages(messages[first_idx], full_reminder)
|
||||
return {"messages": [reminder_msg, user_msg]}
|
||||
|
||||
if last_date == current_date:
|
||||
# ── Same day: nothing to do ──────────────────────────────────────────
|
||||
return None
|
||||
|
||||
# ── Midnight crossed: inject date-update reminder as a separate HumanMessage ──
|
||||
last_human_idx = next((i for i in reversed(range(len(messages))) if _is_user_injection_target(messages[i])), None)
|
||||
if last_human_idx is None:
|
||||
return None
|
||||
|
||||
reminder_msg, user_msg = self._make_reminder_and_user_messages(messages[last_human_idx], self._build_date_update_reminder())
|
||||
logger.info("DynamicContextMiddleware: midnight crossing detected — injected date update before current turn")
|
||||
return {"messages": [reminder_msg, user_msg]}
|
||||
|
||||
@override
|
||||
def before_agent(self, state, runtime: Runtime) -> dict | None:
|
||||
return self._inject(state)
|
||||
|
||||
@override
|
||||
async def abefore_agent(self, state, runtime: Runtime) -> dict | None:
|
||||
return self._inject(state)
|
||||
@@ -12,19 +12,23 @@ Detection strategy:
|
||||
response so the agent is forced to produce a final text answer.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict, defaultdict
|
||||
from copy import deepcopy
|
||||
from typing import override
|
||||
from typing import TYPE_CHECKING, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Defaults — can be overridden via constructor
|
||||
@@ -140,6 +144,9 @@ _TOOL_FREQ_HARD_STOP_MSG = "[FORCED STOP] Tool {tool_name} called {count} times
|
||||
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Detects and breaks repetitive tool call loops.
|
||||
|
||||
Threshold parameters are validated upstream by :class:`LoopDetectionConfig`;
|
||||
construct via :meth:`from_config` to ensure values pass Pydantic validation.
|
||||
|
||||
Args:
|
||||
warn_threshold: Number of identical tool call sets before injecting
|
||||
a warning message. Default: 3.
|
||||
@@ -155,6 +162,14 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
Default: 30.
|
||||
tool_freq_hard_limit: Number of calls to the same tool type before
|
||||
forcing a stop. Default: 50.
|
||||
tool_freq_overrides: Per-tool overrides for frequency thresholds,
|
||||
keyed by tool name. Each value is a ``(warn, hard_limit)`` tuple
|
||||
that replaces ``tool_freq_warn`` / ``tool_freq_hard_limit`` for
|
||||
that specific tool. Tools not listed here fall back to the global
|
||||
thresholds. Useful for raising limits on intentionally
|
||||
high-frequency tools (e.g. ``bash`` in batch pipelines) without
|
||||
weakening protection on all other tools. Default: ``None``
|
||||
(no overrides).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -165,6 +180,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
|
||||
tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN,
|
||||
tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT,
|
||||
tool_freq_overrides: dict[str, tuple[int, int]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.warn_threshold = warn_threshold
|
||||
@@ -173,14 +189,26 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
self.max_tracked_threads = max_tracked_threads
|
||||
self.tool_freq_warn = tool_freq_warn
|
||||
self.tool_freq_hard_limit = tool_freq_hard_limit
|
||||
self._tool_freq_overrides: dict[str, tuple[int, int]] = tool_freq_overrides or {}
|
||||
self._lock = threading.Lock()
|
||||
# Per-thread tracking using OrderedDict for LRU eviction
|
||||
self._history: OrderedDict[str, list[str]] = OrderedDict()
|
||||
self._warned: dict[str, set[str]] = defaultdict(set)
|
||||
# Per-thread, per-tool-type cumulative call counts
|
||||
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
||||
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: LoopDetectionConfig) -> LoopDetectionMiddleware:
|
||||
"""Construct from a Pydantic-validated config, trusting its validation."""
|
||||
return cls(
|
||||
warn_threshold=config.warn_threshold,
|
||||
hard_limit=config.hard_limit,
|
||||
window_size=config.window_size,
|
||||
max_tracked_threads=config.max_tracked_threads,
|
||||
tool_freq_warn=config.tool_freq_warn,
|
||||
tool_freq_hard_limit=config.tool_freq_hard_limit,
|
||||
tool_freq_overrides={name: (o.warn, o.hard_limit) for name, o in config.tool_freq_overrides.items()},
|
||||
)
|
||||
|
||||
def _get_thread_id(self, runtime: Runtime) -> str:
|
||||
"""Extract thread_id from runtime context for per-thread tracking."""
|
||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
||||
@@ -280,7 +308,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
freq[name] += 1
|
||||
tc_count = freq[name]
|
||||
|
||||
if tc_count >= self.tool_freq_hard_limit:
|
||||
if name in self._tool_freq_overrides:
|
||||
eff_warn, eff_hard = self._tool_freq_overrides[name]
|
||||
else:
|
||||
eff_warn, eff_hard = self.tool_freq_warn, self.tool_freq_hard_limit
|
||||
|
||||
if tc_count >= eff_hard:
|
||||
logger.error(
|
||||
"Tool frequency hard limit reached — forcing stop",
|
||||
extra={
|
||||
@@ -291,7 +324,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
)
|
||||
return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True
|
||||
|
||||
if tc_count >= self.tool_freq_warn:
|
||||
if tc_count >= eff_warn:
|
||||
warned = self._tool_freq_warned[thread_id]
|
||||
if name not in warned:
|
||||
warned.add(name)
|
||||
@@ -356,13 +389,30 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
return {"messages": [stripped_msg]}
|
||||
|
||||
if warning:
|
||||
# Inject as HumanMessage instead of SystemMessage to avoid
|
||||
# Anthropic's "multiple non-consecutive system messages" error.
|
||||
# Anthropic models require system messages only at the start of
|
||||
# the conversation; injecting one mid-conversation crashes
|
||||
# langchain_anthropic's _format_messages(). HumanMessage works
|
||||
# with all providers. See #1299.
|
||||
return {"messages": [HumanMessage(content=warning, name="loop_warning")]}
|
||||
# WORKAROUND for v2.0-m1 — see #2724.
|
||||
#
|
||||
# Append the warning to the AIMessage content instead of
|
||||
# injecting a separate HumanMessage. Inserting any non-tool
|
||||
# message between an AIMessage(tool_calls=...) and its
|
||||
# ToolMessage responses breaks OpenAI/Moonshot strict pairing
|
||||
# validation ("tool_call_ids did not have response messages")
|
||||
# because the tools node has not run yet at after_model time.
|
||||
# tool_calls are preserved so the tools node still executes.
|
||||
#
|
||||
# This is a temporary mitigation: mutating an existing
|
||||
# AIMessage to carry framework-authored text leaks loop-warning
|
||||
# text into downstream consumers (MemoryMiddleware fact
|
||||
# extraction, TitleMiddleware, telemetry, model replay) as if
|
||||
# the model said it. The proper fix is to defer warning
|
||||
# injection from after_model to wrap_model_call so every prior
|
||||
# ToolMessage is already in the request — see RFC #2517 (which
|
||||
# lists "loop intervention does not leave invalid
|
||||
# tool-call/tool-message state" as acceptance criteria) and
|
||||
# the prototype on `fix/loop-detection-tool-call-pairing`.
|
||||
messages = state.get("messages", [])
|
||||
last_msg = messages[-1]
|
||||
patched_msg = last_msg.model_copy(update={"content": self._append_text(last_msg.content, warning)})
|
||||
return {"messages": [patched_msg]}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
|
||||
from deerflow.subagents.executor import MAX_CONCURRENT_SUBAGENTS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -63,7 +64,7 @@ class SubagentLimitMiddleware(AgentMiddleware[AgentState]):
|
||||
logger.warning(f"Truncated {dropped_count} excess task tool call(s) from model response (limit: {self.max_concurrent})")
|
||||
|
||||
# Replace the AIMessage with truncated tool_calls (same id triggers replacement)
|
||||
updated_msg = last_msg.model_copy(update={"tool_calls": truncated_tool_calls})
|
||||
updated_msg = clone_ai_message_with_tool_calls(last_msg, truncated_tool_calls)
|
||||
return {"messages": [updated_msg]}
|
||||
|
||||
@override
|
||||
|
||||
@@ -14,6 +14,9 @@ from langgraph.config import get_config
|
||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.middlewares.dynamic_context_middleware import is_dynamic_context_reminder
|
||||
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -78,10 +81,7 @@ def _clone_ai_message(
|
||||
content: Any | None = None,
|
||||
) -> AIMessage:
|
||||
"""Clone an AIMessage while replacing its tool_calls list and optional content."""
|
||||
update: dict[str, Any] = {"tool_calls": tool_calls}
|
||||
if content is not None:
|
||||
update["content"] = content
|
||||
return message.model_copy(update=update)
|
||||
return clone_ai_message_with_tool_calls(message, tool_calls, content=content)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -136,6 +136,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
||||
return None
|
||||
|
||||
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
|
||||
messages_to_summarize, preserved_messages = self._preserve_dynamic_context_reminders(messages_to_summarize, preserved_messages)
|
||||
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
|
||||
summary = self._create_summary(messages_to_summarize)
|
||||
new_messages = self._build_new_messages(summary)
|
||||
@@ -161,6 +162,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
||||
return None
|
||||
|
||||
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
|
||||
messages_to_summarize, preserved_messages = self._preserve_dynamic_context_reminders(messages_to_summarize, preserved_messages)
|
||||
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
|
||||
summary = await self._acreate_summary(messages_to_summarize)
|
||||
new_messages = self._build_new_messages(summary)
|
||||
@@ -180,6 +182,24 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
||||
"""
|
||||
return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
|
||||
|
||||
def _preserve_dynamic_context_reminders(
|
||||
self,
|
||||
messages_to_summarize: list[AnyMessage],
|
||||
preserved_messages: list[AnyMessage],
|
||||
) -> tuple[list[AnyMessage], list[AnyMessage]]:
|
||||
"""Keep hidden dynamic-context reminders out of summary compression.
|
||||
|
||||
These reminders carry the current date and optional memory. If summarization
|
||||
removes them, DynamicContextMiddleware can mistake the summary HumanMessage
|
||||
for the first user message and inject the reminder in the wrong place.
|
||||
"""
|
||||
reminders = [msg for msg in messages_to_summarize if is_dynamic_context_reminder(msg)]
|
||||
if not reminders:
|
||||
return messages_to_summarize, preserved_messages
|
||||
|
||||
remaining = [msg for msg in messages_to_summarize if not is_dynamic_context_reminder(msg)]
|
||||
return remaining, reminders + preserved_messages
|
||||
|
||||
def _partition_with_skill_rescue(
|
||||
self,
|
||||
messages: list[AnyMessage],
|
||||
|
||||
@@ -9,6 +9,7 @@ from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.middlewares.dynamic_context_middleware import is_dynamic_context_reminder
|
||||
from deerflow.config.title_config import get_title_config
|
||||
from deerflow.models import create_chat_model
|
||||
|
||||
@@ -61,6 +62,10 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _is_user_message_for_title(message: object) -> bool:
|
||||
return getattr(message, "type", None) == "human" and not is_dynamic_context_reminder(message)
|
||||
|
||||
def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
|
||||
"""Check if we should generate a title for this thread."""
|
||||
config = self._get_title_config()
|
||||
@@ -77,7 +82,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
return False
|
||||
|
||||
# Count user and assistant messages
|
||||
user_messages = [m for m in messages if m.type == "human"]
|
||||
user_messages = [m for m in messages if self._is_user_message_for_title(m)]
|
||||
assistant_messages = [m for m in messages if m.type == "ai"]
|
||||
|
||||
# Generate title after first complete exchange
|
||||
@@ -91,7 +96,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
config = self._get_title_config()
|
||||
messages = state.get("messages", [])
|
||||
|
||||
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
|
||||
user_msg_content = next((m.content for m in messages if self._is_user_message_for_title(m)), "")
|
||||
assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")
|
||||
|
||||
user_msg = self._normalize_content(user_msg_content)
|
||||
|
||||
@@ -1,37 +1,303 @@
|
||||
"""Middleware for logging LLM token usage."""
|
||||
"""Middleware for logging token usage and annotating step attribution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import override
|
||||
from collections import defaultdict
|
||||
from typing import Any, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.todo import Todo
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TOKEN_USAGE_ATTRIBUTION_KEY = "token_usage_attribution"
|
||||
|
||||
|
||||
def _string_arg(value: Any) -> str | None:
|
||||
if isinstance(value, str):
|
||||
normalized = value.strip()
|
||||
return normalized or None
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_todos(value: Any) -> list[Todo]:
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
|
||||
normalized: list[Todo] = []
|
||||
for item in value:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
todo: Todo = {}
|
||||
content = _string_arg(item.get("content"))
|
||||
status = item.get("status")
|
||||
|
||||
if content is not None:
|
||||
todo["content"] = content
|
||||
if status in {"pending", "in_progress", "completed"}:
|
||||
todo["status"] = status
|
||||
|
||||
normalized.append(todo)
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def _todo_action_kind(previous: Todo | None, current: Todo) -> str:
|
||||
status = current.get("status")
|
||||
previous_content = previous.get("content") if previous else None
|
||||
current_content = current.get("content")
|
||||
|
||||
if previous is None:
|
||||
if status == "completed":
|
||||
return "todo_complete"
|
||||
if status == "in_progress":
|
||||
return "todo_start"
|
||||
return "todo_update"
|
||||
|
||||
if previous_content != current_content:
|
||||
return "todo_update"
|
||||
|
||||
if status == "completed":
|
||||
return "todo_complete"
|
||||
if status == "in_progress":
|
||||
return "todo_start"
|
||||
return "todo_update"
|
||||
|
||||
|
||||
def _build_todo_actions(previous_todos: list[Todo], next_todos: list[Todo]) -> list[dict[str, Any]]:
|
||||
# This is the single source of truth for precise write_todos token
|
||||
# attribution. The frontend intentionally falls back to a generic
|
||||
# "Update to-do list" label when this metadata is missing or malformed.
|
||||
previous_by_content: dict[str, list[tuple[int, Todo]]] = defaultdict(list)
|
||||
matched_previous_indices: set[int] = set()
|
||||
|
||||
for index, todo in enumerate(previous_todos):
|
||||
content = todo.get("content")
|
||||
if isinstance(content, str) and content:
|
||||
previous_by_content[content].append((index, todo))
|
||||
|
||||
actions: list[dict[str, Any]] = []
|
||||
|
||||
for index, todo in enumerate(next_todos):
|
||||
content = todo.get("content")
|
||||
if not isinstance(content, str) or not content:
|
||||
continue
|
||||
|
||||
previous_match: Todo | None = None
|
||||
content_matches = previous_by_content.get(content)
|
||||
if content_matches:
|
||||
while content_matches and content_matches[0][0] in matched_previous_indices:
|
||||
content_matches.pop(0)
|
||||
if content_matches:
|
||||
previous_index, previous_match = content_matches.pop(0)
|
||||
matched_previous_indices.add(previous_index)
|
||||
|
||||
if previous_match is None and index < len(previous_todos) and index not in matched_previous_indices:
|
||||
previous_match = previous_todos[index]
|
||||
matched_previous_indices.add(index)
|
||||
|
||||
if previous_match is not None:
|
||||
previous_content = previous_match.get("content")
|
||||
previous_status = previous_match.get("status")
|
||||
if previous_content == content and previous_status == todo.get("status"):
|
||||
continue
|
||||
|
||||
actions.append(
|
||||
{
|
||||
"kind": _todo_action_kind(previous_match, todo),
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
|
||||
for index, todo in enumerate(previous_todos):
|
||||
if index in matched_previous_indices:
|
||||
continue
|
||||
|
||||
content = todo.get("content")
|
||||
if not isinstance(content, str) or not content:
|
||||
continue
|
||||
|
||||
actions.append(
|
||||
{
|
||||
"kind": "todo_remove",
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
|
||||
return actions
|
||||
|
||||
|
||||
def _describe_tool_call(tool_call: dict[str, Any], todos: list[Todo]) -> list[dict[str, Any]]:
|
||||
name = _string_arg(tool_call.get("name")) or "unknown"
|
||||
args = tool_call.get("args") if isinstance(tool_call.get("args"), dict) else {}
|
||||
tool_call_id = _string_arg(tool_call.get("id"))
|
||||
|
||||
if name == "write_todos":
|
||||
next_todos = _normalize_todos(args.get("todos"))
|
||||
actions = _build_todo_actions(todos, next_todos)
|
||||
if not actions:
|
||||
return [
|
||||
{
|
||||
"kind": "tool",
|
||||
"tool_name": name,
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
]
|
||||
return [
|
||||
{
|
||||
**action,
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
for action in actions
|
||||
]
|
||||
|
||||
if name == "task":
|
||||
return [
|
||||
{
|
||||
"kind": "subagent",
|
||||
"description": _string_arg(args.get("description")),
|
||||
"subagent_type": _string_arg(args.get("subagent_type")),
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
]
|
||||
|
||||
if name in {"web_search", "image_search"}:
|
||||
query = _string_arg(args.get("query"))
|
||||
return [
|
||||
{
|
||||
"kind": "search",
|
||||
"tool_name": name,
|
||||
"query": query,
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
]
|
||||
|
||||
if name == "present_files":
|
||||
return [
|
||||
{
|
||||
"kind": "present_files",
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
]
|
||||
|
||||
if name == "ask_clarification":
|
||||
return [
|
||||
{
|
||||
"kind": "clarification",
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
]
|
||||
|
||||
return [
|
||||
{
|
||||
"kind": "tool",
|
||||
"tool_name": name,
|
||||
"description": _string_arg(args.get("description")),
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str:
|
||||
if actions:
|
||||
first_kind = actions[0].get("kind")
|
||||
if len(actions) == 1 and first_kind in {"todo_start", "todo_complete", "todo_update", "todo_remove"}:
|
||||
return "todo_update"
|
||||
if len(actions) == 1 and first_kind == "subagent":
|
||||
return "subagent_dispatch"
|
||||
return "tool_batch"
|
||||
|
||||
if message.content:
|
||||
return "final_answer"
|
||||
return "thinking"
|
||||
|
||||
|
||||
def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
|
||||
tool_calls = getattr(message, "tool_calls", None) or []
|
||||
actions: list[dict[str, Any]] = []
|
||||
current_todos = list(todos)
|
||||
|
||||
for raw_tool_call in tool_calls:
|
||||
if not isinstance(raw_tool_call, dict):
|
||||
continue
|
||||
|
||||
described_actions = _describe_tool_call(raw_tool_call, current_todos)
|
||||
actions.extend(described_actions)
|
||||
|
||||
if raw_tool_call.get("name") == "write_todos":
|
||||
args = raw_tool_call.get("args") if isinstance(raw_tool_call.get("args"), dict) else {}
|
||||
current_todos = _normalize_todos(args.get("todos"))
|
||||
|
||||
tool_call_ids: list[str] = []
|
||||
for tool_call in tool_calls:
|
||||
if not isinstance(tool_call, dict):
|
||||
continue
|
||||
|
||||
tool_call_id = _string_arg(tool_call.get("id"))
|
||||
if tool_call_id is not None:
|
||||
tool_call_ids.append(tool_call_id)
|
||||
|
||||
return {
|
||||
# Schema changes should remain additive where possible so older
|
||||
# frontends can ignore unknown fields and fall back safely.
|
||||
"version": 1,
|
||||
"kind": _infer_step_kind(message, actions),
|
||||
"shared_attribution": len(actions) > 1,
|
||||
"tool_call_ids": tool_call_ids,
|
||||
"actions": actions,
|
||||
}
|
||||
|
||||
|
||||
class TokenUsageMiddleware(AgentMiddleware):
|
||||
"""Logs token usage from model response usage_metadata."""
|
||||
"""Logs token usage from model responses and annotates the AI step."""
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._log_usage(state)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._log_usage(state)
|
||||
|
||||
def _log_usage(self, state: AgentState) -> None:
|
||||
def _apply(self, state: AgentState) -> dict | None:
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage):
|
||||
return None
|
||||
|
||||
usage = getattr(last, "usage_metadata", None)
|
||||
if usage:
|
||||
input_token_details = usage.get("input_token_details") or {}
|
||||
output_token_details = usage.get("output_token_details") or {}
|
||||
detail_parts = []
|
||||
if input_token_details:
|
||||
detail_parts.append(f"input_token_details={input_token_details}")
|
||||
if output_token_details:
|
||||
detail_parts.append(f"output_token_details={output_token_details}")
|
||||
detail_suffix = f" {' '.join(detail_parts)}" if detail_parts else ""
|
||||
logger.info(
|
||||
"LLM token usage: input=%s output=%s total=%s",
|
||||
"LLM token usage: input=%s output=%s total=%s%s",
|
||||
usage.get("input_tokens", "?"),
|
||||
usage.get("output_tokens", "?"),
|
||||
usage.get("total_tokens", "?"),
|
||||
detail_suffix,
|
||||
)
|
||||
return None
|
||||
|
||||
todos = state.get("todos") or []
|
||||
attribution = _build_attribution(last, todos if isinstance(todos, list) else [])
|
||||
additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
|
||||
|
||||
if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
|
||||
return None
|
||||
|
||||
additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
|
||||
updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
|
||||
return {"messages": [updated_msg]}
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._apply(state)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._apply(state)
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
"""Helpers for keeping AIMessage tool-call metadata consistent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
def _raw_tool_call_id(raw_tool_call: Any) -> str | None:
|
||||
if not isinstance(raw_tool_call, dict):
|
||||
return None
|
||||
|
||||
raw_id = raw_tool_call.get("id")
|
||||
return raw_id if isinstance(raw_id, str) and raw_id else None
|
||||
|
||||
|
||||
def clone_ai_message_with_tool_calls(
|
||||
message: AIMessage,
|
||||
tool_calls: list[dict[str, Any]],
|
||||
*,
|
||||
content: Any | None = None,
|
||||
) -> AIMessage:
|
||||
"""Clone an AIMessage while keeping raw provider tool-call metadata in sync."""
|
||||
kept_ids = {tc["id"] for tc in tool_calls if isinstance(tc.get("id"), str) and tc["id"]}
|
||||
|
||||
update: dict[str, Any] = {"tool_calls": tool_calls}
|
||||
if content is not None:
|
||||
update["content"] = content
|
||||
|
||||
additional_kwargs = dict(getattr(message, "additional_kwargs", {}) or {})
|
||||
raw_tool_calls = additional_kwargs.get("tool_calls")
|
||||
if isinstance(raw_tool_calls, list):
|
||||
synced_raw_tool_calls = [raw_tc for raw_tc in raw_tool_calls if _raw_tool_call_id(raw_tc) in kept_ids]
|
||||
if synced_raw_tool_calls:
|
||||
additional_kwargs["tool_calls"] = synced_raw_tool_calls
|
||||
else:
|
||||
additional_kwargs.pop("tool_calls", None)
|
||||
|
||||
if not tool_calls:
|
||||
additional_kwargs.pop("function_call", None)
|
||||
|
||||
update["additional_kwargs"] = additional_kwargs
|
||||
|
||||
response_metadata = dict(getattr(message, "response_metadata", {}) or {})
|
||||
if not tool_calls and response_metadata.get("finish_reason") == "tool_calls":
|
||||
response_metadata["finish_reason"] = "stop"
|
||||
update["response_metadata"] = response_metadata
|
||||
|
||||
return message.model_copy(update=update)
|
||||
@@ -264,25 +264,35 @@ class DeerFlowClient:
|
||||
return [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in tool_calls]
|
||||
|
||||
@staticmethod
|
||||
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None) -> "StreamEvent":
|
||||
"""Build a ``messages-tuple`` AI text event, attaching usage when present."""
|
||||
def _serialize_additional_kwargs(msg) -> dict[str, Any] | None:
|
||||
"""Copy message additional_kwargs when present."""
|
||||
additional_kwargs = getattr(msg, "additional_kwargs", None)
|
||||
if isinstance(additional_kwargs, dict) and additional_kwargs:
|
||||
return dict(additional_kwargs)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None, additional_kwargs: dict[str, Any] | None = None) -> "StreamEvent":
|
||||
"""Build a ``messages-tuple`` AI text event."""
|
||||
data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id}
|
||||
if usage:
|
||||
data["usage_metadata"] = usage
|
||||
if additional_kwargs:
|
||||
data["additional_kwargs"] = additional_kwargs
|
||||
return StreamEvent(type="messages-tuple", data=data)
|
||||
|
||||
@staticmethod
|
||||
def _ai_tool_calls_event(msg_id: str | None, tool_calls) -> "StreamEvent":
|
||||
def _ai_tool_calls_event(msg_id: str | None, tool_calls, additional_kwargs: dict[str, Any] | None = None) -> "StreamEvent":
|
||||
"""Build a ``messages-tuple`` AI tool-calls event."""
|
||||
return StreamEvent(
|
||||
type="messages-tuple",
|
||||
data={
|
||||
"type": "ai",
|
||||
"content": "",
|
||||
"id": msg_id,
|
||||
"tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls),
|
||||
},
|
||||
)
|
||||
data: dict[str, Any] = {
|
||||
"type": "ai",
|
||||
"content": "",
|
||||
"id": msg_id,
|
||||
"tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls),
|
||||
}
|
||||
if additional_kwargs:
|
||||
data["additional_kwargs"] = additional_kwargs
|
||||
return StreamEvent(type="messages-tuple", data=data)
|
||||
|
||||
@staticmethod
|
||||
def _tool_message_event(msg: ToolMessage) -> "StreamEvent":
|
||||
@@ -307,19 +317,30 @@ class DeerFlowClient:
|
||||
d["tool_calls"] = DeerFlowClient._serialize_tool_calls(msg.tool_calls)
|
||||
if getattr(msg, "usage_metadata", None):
|
||||
d["usage_metadata"] = msg.usage_metadata
|
||||
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
|
||||
d["additional_kwargs"] = additional_kwargs
|
||||
return d
|
||||
if isinstance(msg, ToolMessage):
|
||||
return {
|
||||
d = {
|
||||
"type": "tool",
|
||||
"content": DeerFlowClient._extract_text(msg.content),
|
||||
"name": getattr(msg, "name", None),
|
||||
"tool_call_id": getattr(msg, "tool_call_id", None),
|
||||
"id": getattr(msg, "id", None),
|
||||
}
|
||||
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
|
||||
d["additional_kwargs"] = additional_kwargs
|
||||
return d
|
||||
if isinstance(msg, HumanMessage):
|
||||
return {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||
d = {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
|
||||
d["additional_kwargs"] = additional_kwargs
|
||||
return d
|
||||
if isinstance(msg, SystemMessage):
|
||||
return {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||
d = {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
|
||||
d["additional_kwargs"] = additional_kwargs
|
||||
return d
|
||||
return {"type": "unknown", "content": str(msg), "id": getattr(msg, "id", None)}
|
||||
|
||||
@staticmethod
|
||||
@@ -542,6 +563,7 @@ class DeerFlowClient:
|
||||
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str}
|
||||
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str, "usage_metadata": {...}}
|
||||
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]}
|
||||
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "additional_kwargs": {...}}
|
||||
- type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str}
|
||||
- type="end" data={"usage": {"input_tokens": int, "output_tokens": int, "total_tokens": int}}
|
||||
"""
|
||||
@@ -564,6 +586,7 @@ class DeerFlowClient:
|
||||
# in both the final ``messages`` chunk and the values snapshot —
|
||||
# count it only on whichever arrives first.
|
||||
counted_usage_ids: set[str] = set()
|
||||
sent_additional_kwargs_by_id: dict[str, dict[str, Any]] = {}
|
||||
cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||
|
||||
def _account_usage(msg_id: str | None, usage: Any) -> dict | None:
|
||||
@@ -593,6 +616,20 @@ class DeerFlowClient:
|
||||
"total_tokens": total_tokens,
|
||||
}
|
||||
|
||||
def _unsent_additional_kwargs(msg_id: str | None, additional_kwargs: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
if not additional_kwargs:
|
||||
return None
|
||||
if not msg_id:
|
||||
return additional_kwargs
|
||||
|
||||
sent = sent_additional_kwargs_by_id.setdefault(msg_id, {})
|
||||
delta = {key: value for key, value in additional_kwargs.items() if sent.get(key) != value}
|
||||
if not delta:
|
||||
return None
|
||||
|
||||
sent.update(delta)
|
||||
return delta
|
||||
|
||||
for item in self._agent.stream(
|
||||
state,
|
||||
config=config,
|
||||
@@ -620,17 +657,31 @@ class DeerFlowClient:
|
||||
|
||||
if isinstance(msg_chunk, AIMessage):
|
||||
text = self._extract_text(msg_chunk.content)
|
||||
additional_kwargs = self._serialize_additional_kwargs(msg_chunk)
|
||||
counted_usage = _account_usage(msg_id, msg_chunk.usage_metadata)
|
||||
sent_additional_kwargs = False
|
||||
|
||||
if text:
|
||||
if msg_id:
|
||||
streamed_ids.add(msg_id)
|
||||
yield self._ai_text_event(msg_id, text, counted_usage)
|
||||
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
|
||||
yield self._ai_text_event(
|
||||
msg_id,
|
||||
text,
|
||||
counted_usage,
|
||||
additional_kwargs_delta,
|
||||
)
|
||||
sent_additional_kwargs = bool(additional_kwargs_delta)
|
||||
|
||||
if msg_chunk.tool_calls:
|
||||
if msg_id:
|
||||
streamed_ids.add(msg_id)
|
||||
yield self._ai_tool_calls_event(msg_id, msg_chunk.tool_calls)
|
||||
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
|
||||
yield self._ai_tool_calls_event(
|
||||
msg_id,
|
||||
msg_chunk.tool_calls,
|
||||
additional_kwargs_delta,
|
||||
)
|
||||
|
||||
elif isinstance(msg_chunk, ToolMessage):
|
||||
if msg_id:
|
||||
@@ -653,17 +704,45 @@ class DeerFlowClient:
|
||||
if msg_id and msg_id in streamed_ids:
|
||||
if isinstance(msg, AIMessage):
|
||||
_account_usage(msg_id, getattr(msg, "usage_metadata", None))
|
||||
additional_kwargs = self._serialize_additional_kwargs(msg)
|
||||
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
|
||||
if additional_kwargs_delta:
|
||||
# Metadata-only follow-up: ``messages-tuple`` has no
|
||||
# dedicated attribution event, so clients should
|
||||
# merge this empty-content AI event by message id
|
||||
# and ignore it for text rendering.
|
||||
yield self._ai_text_event(msg_id, "", None, additional_kwargs_delta)
|
||||
continue
|
||||
|
||||
if isinstance(msg, AIMessage):
|
||||
counted_usage = _account_usage(msg_id, msg.usage_metadata)
|
||||
additional_kwargs = self._serialize_additional_kwargs(msg)
|
||||
sent_additional_kwargs = False
|
||||
|
||||
if msg.tool_calls:
|
||||
yield self._ai_tool_calls_event(msg_id, msg.tool_calls)
|
||||
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
|
||||
yield self._ai_tool_calls_event(
|
||||
msg_id,
|
||||
msg.tool_calls,
|
||||
additional_kwargs_delta,
|
||||
)
|
||||
sent_additional_kwargs = bool(additional_kwargs_delta)
|
||||
|
||||
text = self._extract_text(msg.content)
|
||||
if text:
|
||||
yield self._ai_text_event(msg_id, text, counted_usage)
|
||||
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
|
||||
yield self._ai_text_event(
|
||||
msg_id,
|
||||
text,
|
||||
counted_usage,
|
||||
additional_kwargs_delta,
|
||||
)
|
||||
elif msg_id:
|
||||
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
|
||||
if not additional_kwargs_delta:
|
||||
continue
|
||||
# See the metadata-only follow-up convention above.
|
||||
yield self._ai_text_event(msg_id, "", None, additional_kwargs_delta)
|
||||
|
||||
elif isinstance(msg, ToolMessage):
|
||||
yield self._tool_message_event(msg)
|
||||
|
||||
@@ -80,6 +80,7 @@ class AioSandboxProvider(SandboxProvider):
|
||||
port: 8080 # Base port for local containers
|
||||
container_prefix: deer-flow-sandbox
|
||||
idle_timeout: 600 # Idle timeout in seconds (0 to disable)
|
||||
auto_restart: true # Restart crashed containers automatically
|
||||
replicas: 3 # Max concurrent sandbox containers (LRU eviction when exceeded)
|
||||
mounts: # Volume mounts for local containers
|
||||
- host_path: /path/on/host
|
||||
@@ -164,12 +165,14 @@ class AioSandboxProvider(SandboxProvider):
|
||||
|
||||
idle_timeout = getattr(sandbox_config, "idle_timeout", None)
|
||||
replicas = getattr(sandbox_config, "replicas", None)
|
||||
auto_restart = getattr(sandbox_config, "auto_restart", True)
|
||||
|
||||
return {
|
||||
"image": sandbox_config.image or DEFAULT_IMAGE,
|
||||
"port": sandbox_config.port or DEFAULT_PORT,
|
||||
"container_prefix": sandbox_config.container_prefix or DEFAULT_CONTAINER_PREFIX,
|
||||
"idle_timeout": idle_timeout if idle_timeout is not None else DEFAULT_IDLE_TIMEOUT,
|
||||
"auto_restart": auto_restart,
|
||||
"replicas": replicas if replicas is not None else DEFAULT_REPLICAS,
|
||||
"mounts": sandbox_config.mounts or [],
|
||||
"environment": self._resolve_env_vars(sandbox_config.environment or {}),
|
||||
@@ -608,18 +611,58 @@ class AioSandboxProvider(SandboxProvider):
|
||||
def get(self, sandbox_id: str) -> Sandbox | None:
|
||||
"""Get a sandbox by ID. Updates last activity timestamp.
|
||||
|
||||
When ``auto_restart`` is enabled (the default), the container's liveness
|
||||
is verified on each lookup. If the underlying container has crashed, the
|
||||
sandbox is evicted from all caches so that the next ``acquire()`` call will
|
||||
transparently create a fresh container.
|
||||
|
||||
Args:
|
||||
sandbox_id: The ID of the sandbox.
|
||||
|
||||
Returns:
|
||||
The sandbox instance if found, None otherwise.
|
||||
The sandbox instance if found and alive, None otherwise.
|
||||
"""
|
||||
with self._lock:
|
||||
sandbox = self._sandboxes.get(sandbox_id)
|
||||
if sandbox is not None:
|
||||
self._last_activity[sandbox_id] = time.time()
|
||||
if sandbox is None:
|
||||
return None
|
||||
self._last_activity[sandbox_id] = time.time()
|
||||
auto_restart = self._config.get("auto_restart", True)
|
||||
info = self._sandbox_infos.get(sandbox_id) if auto_restart else None
|
||||
|
||||
if not info:
|
||||
return sandbox
|
||||
|
||||
if self._backend.is_alive(info):
|
||||
return sandbox
|
||||
|
||||
info_to_destroy = None
|
||||
with self._lock:
|
||||
current_sandbox = self._sandboxes.get(sandbox_id)
|
||||
current_info = self._sandbox_infos.get(sandbox_id)
|
||||
if current_sandbox is None:
|
||||
return None
|
||||
if current_info is not info:
|
||||
self._last_activity[sandbox_id] = time.time()
|
||||
return current_sandbox
|
||||
|
||||
logger.warning(f"Sandbox {sandbox_id} container is not alive, evicting from cache for auto-restart")
|
||||
self._sandboxes.pop(sandbox_id, None)
|
||||
self._sandbox_infos.pop(sandbox_id, None)
|
||||
self._last_activity.pop(sandbox_id, None)
|
||||
self._warm_pool.pop(sandbox_id, None)
|
||||
thread_ids = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
|
||||
for tid in thread_ids:
|
||||
del self._thread_sandboxes[tid]
|
||||
info_to_destroy = info
|
||||
|
||||
if info_to_destroy:
|
||||
try:
|
||||
self._backend.destroy(info_to_destroy)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cleanup dead sandbox {sandbox_id}: {e}")
|
||||
return None
|
||||
|
||||
def release(self, sandbox_id: str) -> None:
|
||||
"""Release a sandbox from active use into the warm pool.
|
||||
|
||||
|
||||
@@ -84,8 +84,52 @@ class RemoteSandboxBackend(SandboxBackend):
|
||||
"""
|
||||
return self._provisioner_discover(sandbox_id)
|
||||
|
||||
def list_running(self) -> list[SandboxInfo]:
|
||||
"""Return all sandboxes currently managed by the provisioner.
|
||||
|
||||
Calls ``GET /api/sandboxes`` so that ``AioSandboxProvider._reconcile_orphans()``
|
||||
can adopt pods that were created by a previous process and were never
|
||||
explicitly destroyed.
|
||||
Without this, a process restart silently orphans all existing k8s Pods —
|
||||
they stay running forever because the idle checker only
|
||||
tracks in-process state.
|
||||
"""
|
||||
return self._provisioner_list()
|
||||
|
||||
# ── Provisioner API calls ─────────────────────────────────────────────
|
||||
|
||||
def _provisioner_list(self) -> list[SandboxInfo]:
|
||||
"""GET /api/sandboxes → list all running sandboxes."""
|
||||
try:
|
||||
resp = requests.get(f"{self._provisioner_url}/api/sandboxes", timeout=10)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
if not isinstance(data, dict):
|
||||
logger.warning("Provisioner list_running returned non-dict payload: %r", type(data))
|
||||
return []
|
||||
|
||||
sandboxes = data.get("sandboxes", [])
|
||||
if not isinstance(sandboxes, list):
|
||||
logger.warning("Provisioner list_running returned non-list sandboxes: %r", type(sandboxes))
|
||||
return []
|
||||
|
||||
infos: list[SandboxInfo] = []
|
||||
for sandbox in sandboxes:
|
||||
if not isinstance(sandbox, dict):
|
||||
logger.warning("Provisioner list_running entry is not a dict: %r", type(sandbox))
|
||||
continue
|
||||
|
||||
sandbox_id = sandbox.get("sandbox_id")
|
||||
sandbox_url = sandbox.get("sandbox_url")
|
||||
if isinstance(sandbox_id, str) and sandbox_id and isinstance(sandbox_url, str) and sandbox_url:
|
||||
infos.append(SandboxInfo(sandbox_id=sandbox_id, sandbox_url=sandbox_url))
|
||||
|
||||
logger.info("Provisioner list_running: %d sandbox(es) found", len(infos))
|
||||
return infos
|
||||
except requests.RequestException as exc:
|
||||
logger.warning("Provisioner list_running failed: %s", exc)
|
||||
return []
|
||||
|
||||
def _provisioner_create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
|
||||
"""POST /api/sandboxes → create Pod + Service."""
|
||||
try:
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .tools import web_search_tool
|
||||
|
||||
__all__ = ["web_search_tool"]
|
||||
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
Web Search Tool - Search the web using Serper (Google Search API).
|
||||
|
||||
Serper provides real-time Google Search results via a JSON API.
|
||||
An API key is required. Sign up at https://serper.dev to get one.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
import httpx
|
||||
from langchain.tools import tool
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SERPER_ENDPOINT = "https://google.serper.dev/search"
|
||||
_api_key_warned = False
|
||||
|
||||
|
||||
def _get_api_key() -> str | None:
|
||||
config = get_app_config().get_tool_config("web_search")
|
||||
if config is not None:
|
||||
api_key = config.model_extra.get("api_key")
|
||||
if isinstance(api_key, str) and api_key.strip():
|
||||
return api_key
|
||||
return os.getenv("SERPER_API_KEY")
|
||||
|
||||
|
||||
@tool("web_search", parse_docstring=True)
|
||||
def web_search_tool(query: str, max_results: int = 5) -> str:
|
||||
"""Search the web for information using Google Search via Serper.
|
||||
|
||||
Args:
|
||||
query: Search keywords describing what you want to find. Be specific for better results.
|
||||
max_results: Maximum number of search results to return. Default is 5.
|
||||
"""
|
||||
global _api_key_warned
|
||||
|
||||
config = get_app_config().get_tool_config("web_search")
|
||||
if config is not None and "max_results" in config.model_extra:
|
||||
max_results = config.model_extra.get("max_results", max_results)
|
||||
|
||||
api_key = _get_api_key()
|
||||
if not api_key:
|
||||
if not _api_key_warned:
|
||||
_api_key_warned = True
|
||||
logger.warning("Serper API key is not set. Set SERPER_API_KEY in your environment or provide api_key in config.yaml. Sign up at https://serper.dev")
|
||||
return json.dumps(
|
||||
{"error": "SERPER_API_KEY is not configured", "query": query},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
headers = {
|
||||
"X-API-KEY": api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {"q": query, "num": max_results}
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=30) as client:
|
||||
response = client.post(_SERPER_ENDPOINT, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"Serper API returned HTTP {e.response.status_code}: {e.response.text}")
|
||||
return json.dumps(
|
||||
{"error": f"Serper API error: HTTP {e.response.status_code}", "query": query},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Serper search failed: {type(e).__name__}: {e}")
|
||||
return json.dumps({"error": str(e), "query": query}, ensure_ascii=False)
|
||||
|
||||
organic = data.get("organic", [])
|
||||
if not organic:
|
||||
return json.dumps({"error": "No results found", "query": query}, ensure_ascii=False)
|
||||
|
||||
normalized_results = [
|
||||
{
|
||||
"title": r.get("title", ""),
|
||||
"url": r.get("link", ""),
|
||||
"content": r.get("snippet", ""),
|
||||
}
|
||||
for r in organic[:max_results]
|
||||
]
|
||||
|
||||
output = {
|
||||
"query": query,
|
||||
"total_results": len(normalized_results),
|
||||
"results": normalized_results,
|
||||
}
|
||||
return json.dumps(output, indent=2, ensure_ascii=False)
|
||||
@@ -1,5 +1,6 @@
|
||||
from .app_config import get_app_config
|
||||
from .extensions_config import ExtensionsConfig, get_extensions_config
|
||||
from .loop_detection_config import LoopDetectionConfig
|
||||
from .memory_config import MemoryConfig, get_memory_config
|
||||
from .paths import Paths, get_paths
|
||||
from .skill_evolution_config import SkillEvolutionConfig
|
||||
@@ -20,6 +21,7 @@ __all__ = [
|
||||
"SkillsConfig",
|
||||
"ExtensionsConfig",
|
||||
"get_extensions_config",
|
||||
"LoopDetectionConfig",
|
||||
"MemoryConfig",
|
||||
"get_memory_config",
|
||||
"get_tracing_config",
|
||||
|
||||
@@ -1,13 +1,22 @@
|
||||
"""Configuration and loaders for custom agents."""
|
||||
"""Configuration and loaders for custom agents.
|
||||
|
||||
Custom agents are stored per-user under ``{base_dir}/users/{user_id}/agents/{name}/``.
|
||||
A legacy shared layout at ``{base_dir}/agents/{name}/`` is still readable so that
|
||||
installations that pre-date user isolation continue to work until they run the
|
||||
``scripts/migrate_user_isolation.py`` migration. New writes always target the
|
||||
per-user layout.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -40,14 +49,47 @@ class AgentConfig(BaseModel):
|
||||
skills: list[str] | None = None
|
||||
|
||||
|
||||
def load_agent_config(name: str | None) -> AgentConfig | None:
|
||||
def resolve_agent_dir(name: str, *, user_id: str | None = None) -> Path:
|
||||
"""Return the on-disk directory for an agent, preferring the per-user layout.
|
||||
|
||||
Resolution order:
|
||||
1. ``{base_dir}/users/{user_id}/agents/{name}/`` (per-user, current layout).
|
||||
2. ``{base_dir}/agents/{name}/`` (legacy shared layout — read-only fallback).
|
||||
|
||||
If neither exists, the per-user path is returned so callers that intend to
|
||||
create the agent write into the new layout.
|
||||
|
||||
Args:
|
||||
name: Validated agent name.
|
||||
user_id: Owner of the agent. Defaults to the effective user from the
|
||||
request context (or ``"default"`` in no-auth mode).
|
||||
"""
|
||||
paths = get_paths()
|
||||
effective_user = user_id or get_effective_user_id()
|
||||
user_path = paths.user_agent_dir(effective_user, name)
|
||||
if user_path.exists():
|
||||
return user_path
|
||||
|
||||
legacy_path = paths.agent_dir(name)
|
||||
if legacy_path.exists():
|
||||
return legacy_path
|
||||
|
||||
return user_path
|
||||
|
||||
|
||||
def load_agent_config(name: str | None, *, user_id: str | None = None) -> AgentConfig | None:
|
||||
"""Load the custom or default agent's config from its directory.
|
||||
|
||||
Reads from the per-user layout first; falls back to the legacy shared layout
|
||||
for installations that have not yet been migrated.
|
||||
|
||||
Args:
|
||||
name: The agent name.
|
||||
user_id: Owner of the agent. Defaults to the effective user from the
|
||||
current request context.
|
||||
|
||||
Returns:
|
||||
AgentConfig instance.
|
||||
AgentConfig instance, or ``None`` if ``name`` is ``None``.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the agent directory or config.yaml does not exist.
|
||||
@@ -58,7 +100,7 @@ def load_agent_config(name: str | None) -> AgentConfig | None:
|
||||
return None
|
||||
|
||||
name = validate_agent_name(name)
|
||||
agent_dir = get_paths().agent_dir(name)
|
||||
agent_dir = resolve_agent_dir(name, user_id=user_id)
|
||||
config_file = agent_dir / "config.yaml"
|
||||
|
||||
if not agent_dir.exists():
|
||||
@@ -84,7 +126,7 @@ def load_agent_config(name: str | None) -> AgentConfig | None:
|
||||
return AgentConfig(**data)
|
||||
|
||||
|
||||
def load_agent_soul(agent_name: str | None) -> str | None:
|
||||
def load_agent_soul(agent_name: str | None, *, user_id: str | None = None) -> str | None:
|
||||
"""Read the SOUL.md file for a custom agent, if it exists.
|
||||
|
||||
SOUL.md defines the agent's personality, values, and behavioral guardrails.
|
||||
@@ -92,11 +134,16 @@ def load_agent_soul(agent_name: str | None) -> str | None:
|
||||
|
||||
Args:
|
||||
agent_name: The name of the agent or None for the default agent.
|
||||
user_id: Owner of the agent. Defaults to the effective user from the
|
||||
current request context.
|
||||
|
||||
Returns:
|
||||
The SOUL.md content as a string, or None if the file does not exist.
|
||||
"""
|
||||
agent_dir = get_paths().agent_dir(agent_name) if agent_name else get_paths().base_dir
|
||||
if agent_name:
|
||||
agent_dir = resolve_agent_dir(agent_name, user_id=user_id)
|
||||
else:
|
||||
agent_dir = get_paths().base_dir
|
||||
soul_path = agent_dir / SOUL_FILENAME
|
||||
if not soul_path.exists():
|
||||
return None
|
||||
@@ -104,32 +151,50 @@ def load_agent_soul(agent_name: str | None) -> str | None:
|
||||
return content or None
|
||||
|
||||
|
||||
def list_custom_agents() -> list[AgentConfig]:
|
||||
def list_custom_agents(*, user_id: str | None = None) -> list[AgentConfig]:
|
||||
"""Scan the agents directory and return all valid custom agents.
|
||||
|
||||
Returns the union of agents in the per-user layout and the legacy shared
|
||||
layout, so that pre-migration installations remain visible until they are
|
||||
migrated. Per-user entries shadow legacy entries with the same name.
|
||||
|
||||
Args:
|
||||
user_id: Owner whose agents to list. Defaults to the effective user
|
||||
from the current request context.
|
||||
|
||||
Returns:
|
||||
List of AgentConfig for each valid agent directory found.
|
||||
"""
|
||||
agents_dir = get_paths().agents_dir
|
||||
|
||||
if not agents_dir.exists():
|
||||
return []
|
||||
paths = get_paths()
|
||||
effective_user = user_id or get_effective_user_id()
|
||||
|
||||
seen: set[str] = set()
|
||||
agents: list[AgentConfig] = []
|
||||
|
||||
for entry in sorted(agents_dir.iterdir()):
|
||||
if not entry.is_dir():
|
||||
user_root = paths.user_agents_dir(effective_user)
|
||||
legacy_root = paths.agents_dir
|
||||
|
||||
for root in (user_root, legacy_root):
|
||||
if not root.exists():
|
||||
continue
|
||||
for entry in sorted(root.iterdir()):
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
if entry.name in seen:
|
||||
continue
|
||||
config_file = entry / "config.yaml"
|
||||
if not config_file.exists():
|
||||
logger.debug(f"Skipping {entry.name}: no config.yaml")
|
||||
continue
|
||||
|
||||
config_file = entry / "config.yaml"
|
||||
if not config_file.exists():
|
||||
logger.debug(f"Skipping {entry.name}: no config.yaml")
|
||||
continue
|
||||
|
||||
try:
|
||||
agent_cfg = load_agent_config(entry.name)
|
||||
agents.append(agent_cfg)
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping agent '{entry.name}': {e}")
|
||||
try:
|
||||
agent_cfg = load_agent_config(entry.name, user_id=effective_user)
|
||||
if agent_cfg is None:
|
||||
continue
|
||||
agents.append(agent_cfg)
|
||||
seen.add(entry.name)
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping agent '{entry.name}': {e}")
|
||||
|
||||
agents.sort(key=lambda a: a.name)
|
||||
return agents
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from contextvars import ContextVar
|
||||
from pathlib import Path
|
||||
from typing import Any, Self
|
||||
@@ -14,6 +15,7 @@ from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpo
|
||||
from deerflow.config.database_config import DatabaseConfig
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
|
||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
||||
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
|
||||
from deerflow.config.model_config import ModelConfig
|
||||
from deerflow.config.run_events_config import RunEventsConfig
|
||||
@@ -99,6 +101,7 @@ class AppConfig(BaseModel):
|
||||
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
|
||||
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
|
||||
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
|
||||
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
|
||||
model_config = ConfigDict(extra="allow")
|
||||
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
|
||||
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
|
||||
@@ -157,56 +160,54 @@ class AppConfig(BaseModel):
|
||||
config_data = cls.resolve_env_variables(config_data)
|
||||
cls._apply_database_defaults(config_data)
|
||||
|
||||
# Load title config if present
|
||||
if "title" in config_data:
|
||||
load_title_config_from_dict(config_data["title"])
|
||||
|
||||
# Load summarization config if present
|
||||
if "summarization" in config_data:
|
||||
load_summarization_config_from_dict(config_data["summarization"])
|
||||
|
||||
# Load memory config if present
|
||||
if "memory" in config_data:
|
||||
load_memory_config_from_dict(config_data["memory"])
|
||||
|
||||
# Always refresh agents API config so removed config sections reset
|
||||
# singleton-backed state to its default/disabled values on reload.
|
||||
load_agents_api_config_from_dict(config_data.get("agents_api") or {})
|
||||
|
||||
# Load subagents config if present
|
||||
if "subagents" in config_data:
|
||||
load_subagents_config_from_dict(config_data["subagents"])
|
||||
|
||||
# Load tool_search config if present
|
||||
if "tool_search" in config_data:
|
||||
load_tool_search_config_from_dict(config_data["tool_search"])
|
||||
|
||||
# Load guardrails config if present
|
||||
if "guardrails" in config_data:
|
||||
load_guardrails_config_from_dict(config_data["guardrails"])
|
||||
|
||||
# Load circuit_breaker config if present
|
||||
if "circuit_breaker" in config_data:
|
||||
config_data["circuit_breaker"] = config_data["circuit_breaker"]
|
||||
|
||||
# Load checkpointer config if present
|
||||
if "checkpointer" in config_data:
|
||||
load_checkpointer_config_from_dict(config_data["checkpointer"])
|
||||
|
||||
# Load stream bridge config if present
|
||||
if "stream_bridge" in config_data:
|
||||
load_stream_bridge_config_from_dict(config_data["stream_bridge"])
|
||||
|
||||
# Always refresh ACP agent config so removed entries do not linger across reloads.
|
||||
load_acp_config_from_dict(config_data.get("acp_agents", {}))
|
||||
|
||||
# Load extensions config separately (it's in a different file)
|
||||
extensions_config = ExtensionsConfig.from_file()
|
||||
config_data["extensions"] = extensions_config.model_dump()
|
||||
|
||||
result = cls.model_validate(config_data)
|
||||
acp_agents = cls._validate_acp_agents(config_data.get("acp_agents", {}))
|
||||
cls._apply_singleton_configs(result, acp_agents)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _validate_acp_agents(
|
||||
cls,
|
||||
config_data: Mapping[str, Mapping[str, object]] | None,
|
||||
) -> dict[str, ACPAgentConfig]:
|
||||
if config_data is None:
|
||||
config_data = {}
|
||||
return {name: ACPAgentConfig(**cfg) for name, cfg in config_data.items()}
|
||||
|
||||
@classmethod
|
||||
def _apply_singleton_configs(cls, config: Self, acp_agents: dict[str, ACPAgentConfig]) -> None:
|
||||
from deerflow.config.checkpointer_config import get_checkpointer_config
|
||||
|
||||
previous_checkpointer_config = get_checkpointer_config()
|
||||
|
||||
load_title_config_from_dict(config.title.model_dump())
|
||||
load_summarization_config_from_dict(config.summarization.model_dump())
|
||||
load_memory_config_from_dict(config.memory.model_dump())
|
||||
load_agents_api_config_from_dict(config.agents_api.model_dump())
|
||||
load_subagents_config_from_dict(config.subagents.model_dump())
|
||||
load_tool_search_config_from_dict(config.tool_search.model_dump())
|
||||
load_guardrails_config_from_dict(config.guardrails.model_dump())
|
||||
load_checkpointer_config_from_dict(config.checkpointer.model_dump() if config.checkpointer is not None else None)
|
||||
load_stream_bridge_config_from_dict(config.stream_bridge.model_dump() if config.stream_bridge is not None else None)
|
||||
load_acp_config_from_dict({name: agent.model_dump() for name, agent in acp_agents.items()})
|
||||
|
||||
if previous_checkpointer_config != config.checkpointer:
|
||||
# These runtime singletons derive their backend from checkpointer config.
|
||||
# Keep imports local to avoid cycles: both providers import get_app_config.
|
||||
from deerflow.runtime.checkpointer import reset_checkpointer
|
||||
from deerflow.runtime.store import reset_store
|
||||
|
||||
reset_checkpointer()
|
||||
reset_store()
|
||||
|
||||
@classmethod
|
||||
def _apply_database_defaults(cls, config_data: dict[str, Any]) -> None:
|
||||
"""Apply config.yaml defaults for persistence when the section is absent."""
|
||||
|
||||
@@ -14,12 +14,13 @@ class CheckpointerConfig(BaseModel):
|
||||
description="Checkpointer backend type. "
|
||||
"'memory' is in-process only (lost on restart). "
|
||||
"'sqlite' persists to a local file (requires langgraph-checkpoint-sqlite). "
|
||||
"'postgres' persists to PostgreSQL (requires langgraph-checkpoint-postgres)."
|
||||
"'postgres' persists to PostgreSQL (install with deerflow-harness[postgres])."
|
||||
)
|
||||
connection_string: str | None = Field(
|
||||
default=None,
|
||||
description="Connection string for sqlite (file path) or postgres (DSN). "
|
||||
"Required for sqlite and postgres types. "
|
||||
"Optional for sqlite and defaults to 'store.db' when omitted. "
|
||||
"Required for postgres. "
|
||||
"For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. "
|
||||
"For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.",
|
||||
)
|
||||
@@ -40,7 +41,10 @@ def set_checkpointer_config(config: CheckpointerConfig | None) -> None:
|
||||
_checkpointer_config = config
|
||||
|
||||
|
||||
def load_checkpointer_config_from_dict(config_dict: dict) -> None:
|
||||
def load_checkpointer_config_from_dict(config_dict: dict | None) -> None:
|
||||
"""Load checkpointer configuration from a dictionary."""
|
||||
global _checkpointer_config
|
||||
if config_dict is None:
|
||||
_checkpointer_config = None
|
||||
return
|
||||
_checkpointer_config = CheckpointerConfig(**config_dict)
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
"""Configuration for loop detection middleware."""
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class ToolFreqOverride(BaseModel):
|
||||
"""Per-tool frequency threshold override.
|
||||
|
||||
Can be higher or lower than the global defaults. Commonly used to raise
|
||||
thresholds for high-frequency tools like bash in batch workflows (e.g.
|
||||
RNA-seq pipelines) without weakening protection on every other tool.
|
||||
"""
|
||||
|
||||
warn: int = Field(ge=1)
|
||||
hard_limit: int = Field(ge=1)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate(self) -> "ToolFreqOverride":
|
||||
if self.hard_limit < self.warn:
|
||||
raise ValueError("hard_limit must be >= warn")
|
||||
return self
|
||||
|
||||
|
||||
class LoopDetectionConfig(BaseModel):
|
||||
"""Configuration for repetitive tool-call loop detection."""
|
||||
|
||||
enabled: bool = Field(
|
||||
default=True,
|
||||
description="Whether to enable repetitive tool-call loop detection",
|
||||
)
|
||||
warn_threshold: int = Field(
|
||||
default=3,
|
||||
ge=1,
|
||||
description="Number of identical tool-call sets before injecting a warning",
|
||||
)
|
||||
hard_limit: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
description="Number of identical tool-call sets before forcing a stop",
|
||||
)
|
||||
window_size: int = Field(
|
||||
default=20,
|
||||
ge=1,
|
||||
description="Number of recent tool-call sets to track per thread",
|
||||
)
|
||||
max_tracked_threads: int = Field(
|
||||
default=100,
|
||||
ge=1,
|
||||
description="Maximum number of thread histories to keep in memory",
|
||||
)
|
||||
tool_freq_warn: int = Field(
|
||||
default=30,
|
||||
ge=1,
|
||||
description="Number of calls to the same tool type before injecting a frequency warning",
|
||||
)
|
||||
tool_freq_hard_limit: int = Field(
|
||||
default=50,
|
||||
ge=1,
|
||||
description="Number of calls to the same tool type before forcing a stop",
|
||||
)
|
||||
tool_freq_overrides: dict[str, ToolFreqOverride] = Field(
|
||||
default_factory=dict,
|
||||
description=("Per-tool overrides for tool_freq_warn / tool_freq_hard_limit, keyed by tool name. Values can be higher or lower than the global defaults. Commonly used to raise thresholds for high-frequency tools like bash."),
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_thresholds(self) -> "LoopDetectionConfig":
|
||||
"""Ensure hard stop cannot happen before the warning threshold."""
|
||||
if self.hard_limit < self.warn_threshold:
|
||||
raise ValueError("hard_limit must be greater than or equal to warn_threshold")
|
||||
if self.tool_freq_hard_limit < self.tool_freq_warn:
|
||||
raise ValueError("tool_freq_hard_limit must be greater than or equal to tool_freq_warn")
|
||||
return self
|
||||
@@ -132,15 +132,20 @@ class Paths:
|
||||
|
||||
@property
|
||||
def agents_dir(self) -> Path:
|
||||
"""Root directory for all custom agents: `{base_dir}/agents/`."""
|
||||
"""Legacy root for shared (pre user-isolation) custom agents: `{base_dir}/agents/`.
|
||||
|
||||
New code should use :meth:`user_agents_dir` instead. This property remains
|
||||
only as a read-side fallback for installations that have not yet run the
|
||||
``migrate_user_isolation.py`` script.
|
||||
"""
|
||||
return self.base_dir / "agents"
|
||||
|
||||
def agent_dir(self, name: str) -> Path:
|
||||
"""Directory for a specific agent: `{base_dir}/agents/{name}/`."""
|
||||
"""Legacy per-agent directory (no user isolation): `{base_dir}/agents/{name}/`."""
|
||||
return self.agents_dir / name.lower()
|
||||
|
||||
def agent_memory_file(self, name: str) -> Path:
|
||||
"""Per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
|
||||
"""Legacy per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
|
||||
return self.agent_dir(name) / "memory.json"
|
||||
|
||||
def user_dir(self, user_id: str) -> Path:
|
||||
@@ -151,9 +156,17 @@ class Paths:
|
||||
"""Per-user memory file: `{base_dir}/users/{user_id}/memory.json`."""
|
||||
return self.user_dir(user_id) / "memory.json"
|
||||
|
||||
def user_agents_dir(self, user_id: str) -> Path:
|
||||
"""Per-user root for that user's custom agents: `{base_dir}/users/{user_id}/agents/`."""
|
||||
return self.user_dir(user_id) / "agents"
|
||||
|
||||
def user_agent_dir(self, user_id: str, agent_name: str) -> Path:
|
||||
"""Per-user per-agent directory: `{base_dir}/users/{user_id}/agents/{name}/`."""
|
||||
return self.user_agents_dir(user_id) / agent_name.lower()
|
||||
|
||||
def user_agent_memory_file(self, user_id: str, agent_name: str) -> Path:
|
||||
"""Per-user per-agent memory: `{base_dir}/users/{user_id}/agents/{name}/memory.json`."""
|
||||
return self.user_dir(user_id) / "agents" / agent_name.lower() / "memory.json"
|
||||
return self.user_agent_dir(user_id, agent_name) / "memory.json"
|
||||
|
||||
def thread_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
||||
"""
|
||||
|
||||
@@ -23,6 +23,9 @@ class SandboxConfig(BaseModel):
|
||||
replicas: Maximum number of concurrent sandbox containers (default: 3). When the limit is reached the least-recently-used sandbox is evicted to make room.
|
||||
container_prefix: Prefix for container names (default: deer-flow-sandbox)
|
||||
idle_timeout: Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.
|
||||
auto_restart: Automatically restart sandbox containers that have crashed (default: true). When a tool call
|
||||
detects the container is no longer alive, the sandbox is evicted from cache and transparently recreated
|
||||
on the next acquire. Set to false to disable.
|
||||
mounts: List of volume mounts to share directories with the container
|
||||
environment: Environment variables to inject into the container (values starting with $ are resolved from host env)
|
||||
"""
|
||||
@@ -55,6 +58,10 @@ class SandboxConfig(BaseModel):
|
||||
default=None,
|
||||
description="Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.",
|
||||
)
|
||||
auto_restart: bool = Field(
|
||||
default=True,
|
||||
description="Automatically restart sandbox containers that have crashed. When a tool call detects the container is no longer alive, the sandbox is evicted from cache and transparently recreated on the next acquire.",
|
||||
)
|
||||
mounts: list[VolumeMountConfig] = Field(
|
||||
default_factory=list,
|
||||
description="List of volume mounts to share directories between host and container",
|
||||
|
||||
@@ -6,6 +6,13 @@ from pydantic import BaseModel, Field
|
||||
from deerflow.config.runtime_paths import project_root, resolve_path
|
||||
|
||||
|
||||
def _legacy_skills_candidates() -> tuple[Path, ...]:
|
||||
"""Return source-tree skills locations for monorepo compatibility."""
|
||||
backend_dir = Path(__file__).resolve().parents[4]
|
||||
repo_root = backend_dir.parent
|
||||
return (repo_root / "skills",)
|
||||
|
||||
|
||||
class SkillsConfig(BaseModel):
|
||||
"""Configuration for skills system"""
|
||||
|
||||
@@ -15,7 +22,7 @@ class SkillsConfig(BaseModel):
|
||||
)
|
||||
path: str | None = Field(
|
||||
default=None,
|
||||
description="Path to skills directory. If not specified, defaults to skills under the caller project root.",
|
||||
description=("Path to skills directory. If not specified, defaults to `skills` under the caller project root, falling back to the legacy repo-root location for monorepo compatibility."),
|
||||
)
|
||||
container_path: str = Field(
|
||||
default="/mnt/skills",
|
||||
@@ -26,15 +33,30 @@ class SkillsConfig(BaseModel):
|
||||
"""
|
||||
Get the resolved skills directory path.
|
||||
|
||||
Returns:
|
||||
Path to the skills directory
|
||||
Resolution order:
|
||||
1. Explicit ``path`` field
|
||||
2. ``DEER_FLOW_SKILLS_PATH`` environment variable
|
||||
3. ``skills`` under the caller project root (``project_root()``)
|
||||
4. Legacy repo-root candidates for monorepo compatibility (``_legacy_skills_candidates``)
|
||||
|
||||
When none of (3) or (4) exist on disk, the project-root default is returned so callers
|
||||
can still surface a stable "no skills" location without raising.
|
||||
"""
|
||||
if self.path:
|
||||
# Use configured path (can be absolute or relative to project root)
|
||||
return resolve_path(self.path)
|
||||
if env_path := os.getenv("DEER_FLOW_SKILLS_PATH"):
|
||||
return resolve_path(env_path)
|
||||
return project_root() / "skills"
|
||||
|
||||
project_default = project_root() / "skills"
|
||||
if project_default.is_dir():
|
||||
return project_default
|
||||
|
||||
for candidate in _legacy_skills_candidates():
|
||||
if candidate.is_dir():
|
||||
return candidate
|
||||
|
||||
return project_default
|
||||
|
||||
def get_skill_container_path(self, skill_name: str, category: str = "public") -> str:
|
||||
"""
|
||||
|
||||
@@ -40,7 +40,10 @@ def set_stream_bridge_config(config: StreamBridgeConfig | None) -> None:
|
||||
_stream_bridge_config = config
|
||||
|
||||
|
||||
def load_stream_bridge_config_from_dict(config_dict: dict) -> None:
|
||||
def load_stream_bridge_config_from_dict(config_dict: dict | None) -> None:
|
||||
"""Load stream bridge configuration from a dictionary."""
|
||||
global _stream_bridge_config
|
||||
if config_dict is None:
|
||||
_stream_bridge_config = None
|
||||
return
|
||||
_stream_bridge_config = StreamBridgeConfig(**config_dict)
|
||||
|
||||
@@ -179,9 +179,3 @@ def load_subagents_config_from_dict(config_dict: dict) -> None:
|
||||
overrides_summary or "none",
|
||||
custom_agents_names or "none",
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Subagents config loaded: default timeout=%ss, default max_turns=%s, no per-agent overrides",
|
||||
_subagents_config.timeout_seconds,
|
||||
_subagents_config.max_turns,
|
||||
)
|
||||
|
||||
@@ -4,4 +4,4 @@ from pydantic import BaseModel, Field
|
||||
class TokenUsageConfig(BaseModel):
|
||||
"""Configuration for token usage tracking."""
|
||||
|
||||
enabled: bool = Field(default=False, description="Enable token usage tracking middleware")
|
||||
enabled: bool = Field(default=True, description="Enable token usage tracking middleware")
|
||||
|
||||
@@ -196,6 +196,10 @@ class ClaudeChatModel(ChatAnthropic):
|
||||
enforced by both the Anthropic API and AWS Bedrock. Breakpoints are
|
||||
placed on the *last* eligible blocks because later breakpoints cover a
|
||||
larger prefix and yield better cache hit rates.
|
||||
|
||||
The system prompt is expected to be fully static (no per-user memory or
|
||||
current date). Dynamic context is injected per-turn via
|
||||
DynamicContextMiddleware as a <system-reminder> in the first HumanMessage.
|
||||
"""
|
||||
MAX_CACHE_BREAKPOINTS = 4
|
||||
|
||||
|
||||
@@ -27,6 +27,34 @@ from deerflow.models.credential_loader import CodexCliCredential, load_codex_cli
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
|
||||
|
||||
|
||||
def _build_usage_metadata(oai_usage: dict) -> dict:
|
||||
"""Convert Codex/Responses API usage dict to LangChain usage_metadata format.
|
||||
|
||||
Maps OpenAI Responses API token usage fields to the dict structure that
|
||||
LangChain AIMessage.usage_metadata expects. This avoids depending on
|
||||
langchain_openai private helpers like ``_create_usage_metadata_responses``.
|
||||
"""
|
||||
input_tokens = oai_usage.get("input_tokens", 0)
|
||||
output_tokens = oai_usage.get("output_tokens", 0)
|
||||
total_tokens = oai_usage.get("total_tokens", input_tokens + output_tokens)
|
||||
metadata: dict = {
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
}
|
||||
input_details = oai_usage.get("input_tokens_details") or {}
|
||||
output_details = oai_usage.get("output_tokens_details") or {}
|
||||
cache_read = input_details.get("cached_tokens")
|
||||
if cache_read is not None:
|
||||
metadata["input_token_details"] = {"cache_read": cache_read}
|
||||
reasoning = output_details.get("reasoning_tokens")
|
||||
if reasoning is not None:
|
||||
metadata["output_token_details"] = {"reasoning": reasoning}
|
||||
return metadata
|
||||
|
||||
|
||||
MAX_RETRIES = 3
|
||||
|
||||
|
||||
@@ -346,6 +374,7 @@ class CodexChatModel(BaseChatModel):
|
||||
)
|
||||
|
||||
usage = response.get("usage", {})
|
||||
usage_metadata = _build_usage_metadata(usage) if usage else None
|
||||
additional_kwargs = {}
|
||||
if reasoning_content:
|
||||
additional_kwargs["reasoning_content"] = reasoning_content
|
||||
@@ -355,6 +384,7 @@ class CodexChatModel(BaseChatModel):
|
||||
tool_calls=tool_calls if tool_calls else [],
|
||||
invalid_tool_calls=invalid_tool_calls,
|
||||
additional_kwargs=additional_kwargs,
|
||||
usage_metadata=usage_metadata,
|
||||
response_metadata={
|
||||
"model": response.get("model", self.model),
|
||||
"usage": usage,
|
||||
|
||||
@@ -81,7 +81,16 @@ async def init_engine(
|
||||
try:
|
||||
import asyncpg # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError("database.backend is set to 'postgres' but asyncpg is not installed.\nInstall it with:\n uv sync --extra postgres\nOr switch to backend: sqlite in config.yaml for single-node deployment.") from None
|
||||
raise ImportError(
|
||||
"database.backend is set to 'postgres' but asyncpg is not installed.\n"
|
||||
"Install it with:\n"
|
||||
" cd backend && uv sync --all-packages --extra postgres\n"
|
||||
"On the next `make dev` the postgres extra is auto-detected from\n"
|
||||
"config.yaml (database.backend: postgres) and reinstalled, so it\n"
|
||||
"will not be wiped again. Set UV_EXTRAS=postgres in .env to opt in\n"
|
||||
"explicitly. Or switch to backend: sqlite in config.yaml for\n"
|
||||
"single-node deployment."
|
||||
) from None
|
||||
|
||||
if backend == "sqlite":
|
||||
import os
|
||||
|
||||
@@ -7,13 +7,13 @@ router for thread records.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from langgraph.store.base import BaseStore
|
||||
|
||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||
from deerflow.utils.time import coerce_iso, now_iso
|
||||
|
||||
THREADS_NS: tuple[str, ...] = ("threads",)
|
||||
|
||||
@@ -48,7 +48,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
metadata: dict | None = None,
|
||||
) -> dict:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.create")
|
||||
now = time.time()
|
||||
now = now_iso()
|
||||
record: dict[str, Any] = {
|
||||
"thread_id": thread_id,
|
||||
"assistant_id": assistant_id,
|
||||
@@ -106,7 +106,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
if record is None:
|
||||
return
|
||||
record["display_name"] = display_name
|
||||
record["updated_at"] = time.time()
|
||||
record["updated_at"] = now_iso()
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
|
||||
async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
@@ -114,7 +114,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
if record is None:
|
||||
return
|
||||
record["status"] = status
|
||||
record["updated_at"] = time.time()
|
||||
record["updated_at"] = now_iso()
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
|
||||
async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
@@ -124,7 +124,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
merged = dict(record.get("metadata") or {})
|
||||
merged.update(metadata)
|
||||
record["metadata"] = merged
|
||||
record["updated_at"] = time.time()
|
||||
record["updated_at"] = now_iso()
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
|
||||
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
@@ -144,6 +144,8 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
"display_name": val.get("display_name"),
|
||||
"status": val.get("status", "idle"),
|
||||
"metadata": val.get("metadata", {}),
|
||||
"created_at": str(val.get("created_at", "")),
|
||||
"updated_at": str(val.get("updated_at", "")),
|
||||
# ``coerce_iso`` heals legacy unix-second values written by
|
||||
# earlier Gateway versions that called ``str(time.time())``.
|
||||
"created_at": coerce_iso(val.get("created_at", "")),
|
||||
"updated_at": coerce_iso(val.get("updated_at", "")),
|
||||
}
|
||||
|
||||
@@ -36,7 +36,9 @@ logger = logging.getLogger(__name__)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SQLITE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite checkpointer. Install it with: uv add langgraph-checkpoint-sqlite"
|
||||
POSTGRES_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL checkpointer. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool"
|
||||
POSTGRES_INSTALL = (
|
||||
"langgraph-checkpoint-postgres is required for the PostgreSQL checkpointer. Install the package extra with: pip install 'deerflow-harness[postgres]' (or use: uv sync --all-packages --extra postgres when developing locally)"
|
||||
)
|
||||
POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -9,6 +9,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
@@ -33,20 +34,21 @@ class DbRunEventStore(RunEventStore):
|
||||
if isinstance(val, datetime):
|
||||
d["created_at"] = val.isoformat()
|
||||
d.pop("id", None)
|
||||
# Restore dict content that was JSON-serialized on write
|
||||
# Restore structured content that was JSON-serialized on write.
|
||||
raw = d.get("content", "")
|
||||
if isinstance(raw, str) and d.get("metadata", {}).get("content_is_dict"):
|
||||
metadata = d.get("metadata", {})
|
||||
if isinstance(raw, str) and (metadata.get("content_is_json") or metadata.get("content_is_dict")):
|
||||
try:
|
||||
d["content"] = json.loads(raw)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
# Content looked like JSON (content_is_dict flag) but failed to parse;
|
||||
# Content looked like JSON but failed to parse;
|
||||
# keep the raw string as-is.
|
||||
logger.debug("Failed to deserialize content as JSON for event seq=%s", d.get("seq"))
|
||||
return d
|
||||
|
||||
def _truncate_trace(self, category: str, content: str | dict, metadata: dict | None) -> tuple[str | dict, dict]:
|
||||
def _truncate_trace(self, category: str, content: Any, metadata: dict | None) -> tuple[Any, dict]:
|
||||
if category == "trace":
|
||||
text = json.dumps(content, default=str, ensure_ascii=False) if isinstance(content, dict) else content
|
||||
text = content if isinstance(content, str) else json.dumps(content, default=str, ensure_ascii=False)
|
||||
encoded = text.encode("utf-8")
|
||||
if len(encoded) > self._max_trace_content:
|
||||
# Truncate by bytes, then decode back (may cut a multi-byte char, so use errors="ignore")
|
||||
@@ -54,6 +56,18 @@ class DbRunEventStore(RunEventStore):
|
||||
metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)}
|
||||
return content, metadata or {}
|
||||
|
||||
@staticmethod
|
||||
def _content_to_db(content: Any, metadata: dict | None) -> tuple[str, dict]:
|
||||
metadata = metadata or {}
|
||||
if isinstance(content, str):
|
||||
return content, metadata
|
||||
|
||||
db_content = json.dumps(content, default=str, ensure_ascii=False)
|
||||
metadata = {**metadata, "content_is_json": True}
|
||||
if isinstance(content, dict):
|
||||
metadata["content_is_dict"] = True
|
||||
return db_content, metadata
|
||||
|
||||
@staticmethod
|
||||
def _user_id_from_context() -> str | None:
|
||||
"""Soft read of user_id from contextvar for write paths.
|
||||
@@ -82,11 +96,7 @@ class DbRunEventStore(RunEventStore):
|
||||
the initial ``human_message`` event (once per run).
|
||||
"""
|
||||
content, metadata = self._truncate_trace(category, content, metadata)
|
||||
if isinstance(content, dict):
|
||||
db_content = json.dumps(content, default=str, ensure_ascii=False)
|
||||
metadata = {**(metadata or {}), "content_is_dict": True}
|
||||
else:
|
||||
db_content = content
|
||||
db_content, metadata = self._content_to_db(content, metadata)
|
||||
user_id = self._user_id_from_context()
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
@@ -128,11 +138,7 @@ class DbRunEventStore(RunEventStore):
|
||||
category = e.get("category", "trace")
|
||||
metadata = e.get("metadata")
|
||||
content, metadata = self._truncate_trace(category, content, metadata)
|
||||
if isinstance(content, dict):
|
||||
db_content = json.dumps(content, default=str, ensure_ascii=False)
|
||||
metadata = {**(metadata or {}), "content_is_dict": True}
|
||||
else:
|
||||
db_content = content
|
||||
db_content, metadata = self._content_to_db(content, metadata)
|
||||
row = RunEventRow(
|
||||
thread_id=e["thread_id"],
|
||||
run_id=e["run_id"],
|
||||
|
||||
@@ -6,9 +6,10 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from deerflow.utils.time import now_iso as _now_iso
|
||||
|
||||
from .schemas import DisconnectMode, RunStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -17,10 +18,6 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(UTC).isoformat()
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunRecord:
|
||||
"""Mutable record for a single run."""
|
||||
|
||||
@@ -23,6 +23,8 @@ from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from langgraph.checkpoint.base import empty_checkpoint
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
@@ -442,6 +444,12 @@ async def _rollback_to_pre_run_checkpoint(
|
||||
if checkpoint_to_restore.get("id") is None:
|
||||
logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id)
|
||||
return
|
||||
restore_marker = _new_checkpoint_marker()
|
||||
checkpoint_to_restore = {
|
||||
**checkpoint_to_restore,
|
||||
"id": restore_marker["id"],
|
||||
"ts": restore_marker["ts"],
|
||||
}
|
||||
metadata = pre_run_snapshot.get("metadata", {})
|
||||
metadata_to_restore = metadata if isinstance(metadata, dict) else {}
|
||||
raw_checkpoint_ns = pre_run_snapshot.get("checkpoint_ns")
|
||||
@@ -493,6 +501,11 @@ async def _rollback_to_pre_run_checkpoint(
|
||||
)
|
||||
|
||||
|
||||
def _new_checkpoint_marker() -> dict[str, str]:
|
||||
marker = empty_checkpoint()
|
||||
return {"id": marker["id"], "ts": marker["ts"]}
|
||||
|
||||
|
||||
def _lg_mode_to_sse_event(mode: str) -> str:
|
||||
"""Map LangGraph internal stream_mode name to SSE event name.
|
||||
|
||||
|
||||
@@ -36,7 +36,9 @@ logger = logging.getLogger(__name__)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SQLITE_STORE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite store. Install it with: uv add langgraph-checkpoint-sqlite"
|
||||
POSTGRES_STORE_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL store. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool"
|
||||
POSTGRES_STORE_INSTALL = (
|
||||
"langgraph-checkpoint-postgres is required for the PostgreSQL store. Install the package extra with: pip install 'deerflow-harness[postgres]' (or use: uv sync --all-packages --extra postgres when developing locally)"
|
||||
)
|
||||
POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -42,6 +42,13 @@ class LocalSandbox(Sandbox):
|
||||
"""Return whether the selected shell is cmd.exe."""
|
||||
return LocalSandbox._shell_name(shell) in {"cmd", "cmd.exe"}
|
||||
|
||||
@staticmethod
|
||||
def _is_msys_shell(shell: str) -> bool:
|
||||
"""Return whether the selected shell is a Git Bash/MSYS shell."""
|
||||
normalized = shell.replace("\\", "/").lower()
|
||||
shell_name = LocalSandbox._shell_name(shell)
|
||||
return shell_name in {"sh.exe", "bash.exe"} and any(part in normalized for part in ("/git/", "/mingw", "/msys"))
|
||||
|
||||
@staticmethod
|
||||
def _find_first_available_shell(candidates: tuple[str, ...]) -> str | None:
|
||||
"""Return the first executable shell path or command found from candidates."""
|
||||
@@ -303,12 +310,19 @@ class LocalSandbox(Sandbox):
|
||||
shell = self._get_shell()
|
||||
|
||||
if os.name == "nt":
|
||||
env = None
|
||||
if self._is_powershell(shell):
|
||||
args = [shell, "-NoProfile", "-Command", resolved_command]
|
||||
elif self._is_cmd_shell(shell):
|
||||
args = [shell, "/c", resolved_command]
|
||||
else:
|
||||
args = [shell, "-c", resolved_command]
|
||||
if self._is_msys_shell(shell):
|
||||
env = {
|
||||
**os.environ,
|
||||
"MSYS_NO_PATHCONV": "1",
|
||||
"MSYS2_ARG_CONV_EXCL": "*",
|
||||
}
|
||||
|
||||
result = subprocess.run(
|
||||
args,
|
||||
@@ -316,6 +330,7 @@ class LocalSandbox(Sandbox):
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=600,
|
||||
env=env,
|
||||
)
|
||||
else:
|
||||
args = [shell, "-c", resolved_command]
|
||||
|
||||
@@ -3,10 +3,9 @@ import re
|
||||
import shlex
|
||||
from pathlib import Path
|
||||
|
||||
from langchain.tools import ToolRuntime, tool
|
||||
from langgraph.typing import ContextT
|
||||
from langchain.tools import tool
|
||||
|
||||
from deerflow.agents.thread_state import ThreadDataState, ThreadState
|
||||
from deerflow.agents.thread_state import ThreadDataState
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
|
||||
from deerflow.sandbox.exceptions import (
|
||||
@@ -19,6 +18,7 @@ from deerflow.sandbox.sandbox import Sandbox
|
||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||
from deerflow.sandbox.search import GrepMatch
|
||||
from deerflow.sandbox.security import LOCAL_HOST_BASH_DISABLED_MESSAGE, is_host_bash_allowed
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
_ABSOLUTE_PATH_PATTERN = re.compile(r"(?<![:\w])(?<!:/)/(?:[^\s\"'`;&|<>()]+)")
|
||||
_FILE_URL_PATTERN = re.compile(r"\bfile://\S+", re.IGNORECASE)
|
||||
@@ -419,7 +419,7 @@ def _join_path_preserving_style(base: str, relative: str) -> str:
|
||||
return f"{stripped_base}{separator}{normalized_relative}"
|
||||
|
||||
|
||||
def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadState] | None" = None) -> str:
|
||||
def _sanitize_error(error: Exception, runtime: Runtime | None = None) -> str:
|
||||
"""Sanitize an error message to avoid leaking host filesystem paths.
|
||||
|
||||
In local-sandbox mode, resolved host paths in the error string are masked
|
||||
@@ -994,7 +994,7 @@ def _apply_cwd_prefix(command: str, thread_data: ThreadDataState | None) -> str:
|
||||
return command
|
||||
|
||||
|
||||
def get_thread_data(runtime: ToolRuntime[ContextT, ThreadState] | None) -> ThreadDataState | None:
|
||||
def get_thread_data(runtime: Runtime | None) -> ThreadDataState | None:
|
||||
"""Extract thread_data from runtime state."""
|
||||
if runtime is None:
|
||||
return None
|
||||
@@ -1003,7 +1003,7 @@ def get_thread_data(runtime: ToolRuntime[ContextT, ThreadState] | None) -> Threa
|
||||
return runtime.state.get("thread_data")
|
||||
|
||||
|
||||
def is_local_sandbox(runtime: ToolRuntime[ContextT, ThreadState] | None) -> bool:
|
||||
def is_local_sandbox(runtime: Runtime | None) -> bool:
|
||||
"""Check if the current sandbox is a local sandbox.
|
||||
|
||||
Path replacement is only needed for local sandbox since aio sandbox
|
||||
@@ -1019,7 +1019,7 @@ def is_local_sandbox(runtime: ToolRuntime[ContextT, ThreadState] | None) -> bool
|
||||
return sandbox_state.get("sandbox_id") == "local"
|
||||
|
||||
|
||||
def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = None) -> Sandbox:
|
||||
def sandbox_from_runtime(runtime: Runtime | None = None) -> Sandbox:
|
||||
"""Extract sandbox instance from tool runtime.
|
||||
|
||||
DEPRECATED: Use ensure_sandbox_initialized() for lazy initialization support.
|
||||
@@ -1048,7 +1048,7 @@ def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = No
|
||||
return sandbox
|
||||
|
||||
|
||||
def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | None = None) -> Sandbox:
|
||||
def ensure_sandbox_initialized(runtime: Runtime | None = None) -> Sandbox:
|
||||
"""Ensure sandbox is initialized, acquiring lazily if needed.
|
||||
|
||||
On first call, acquires a sandbox from the provider and stores it in runtime state.
|
||||
@@ -1107,7 +1107,7 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
|
||||
return sandbox
|
||||
|
||||
|
||||
def ensure_thread_directories_exist(runtime: ToolRuntime[ContextT, ThreadState] | None) -> None:
|
||||
def ensure_thread_directories_exist(runtime: Runtime | None) -> None:
|
||||
"""Ensure thread data directories (workspace, uploads, outputs) exist.
|
||||
|
||||
This function is called lazily when any sandbox tool is first used.
|
||||
@@ -1221,7 +1221,7 @@ def _truncate_ls_output(output: str, max_chars: int) -> str:
|
||||
|
||||
|
||||
@tool("bash", parse_docstring=True)
|
||||
def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, command: str) -> str:
|
||||
def bash_tool(runtime: Runtime, description: str, command: str) -> str:
|
||||
"""Execute a bash command in a Linux environment.
|
||||
|
||||
|
||||
@@ -1270,7 +1270,7 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
|
||||
|
||||
|
||||
@tool("ls", parse_docstring=True)
|
||||
def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path: str) -> str:
|
||||
def ls_tool(runtime: Runtime, description: str, path: str) -> str:
|
||||
"""List the contents of a directory up to 2 levels deep in tree format.
|
||||
|
||||
Args:
|
||||
@@ -1318,7 +1318,7 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
|
||||
|
||||
@tool("glob", parse_docstring=True)
|
||||
def glob_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
description: str,
|
||||
pattern: str,
|
||||
path: str,
|
||||
@@ -1368,7 +1368,7 @@ def glob_tool(
|
||||
|
||||
@tool("grep", parse_docstring=True)
|
||||
def grep_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
description: str,
|
||||
pattern: str,
|
||||
path: str,
|
||||
@@ -1438,7 +1438,7 @@ def grep_tool(
|
||||
|
||||
@tool("read_file", parse_docstring=True)
|
||||
def read_file_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
description: str,
|
||||
path: str,
|
||||
start_line: int | None = None,
|
||||
@@ -1493,7 +1493,7 @@ def read_file_tool(
|
||||
|
||||
@tool("write_file", parse_docstring=True)
|
||||
def write_file_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
description: str,
|
||||
path: str,
|
||||
content: str,
|
||||
@@ -1533,7 +1533,7 @@ def write_file_tool(
|
||||
|
||||
@tool("str_replace", parse_docstring=True)
|
||||
def str_replace_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
description: str,
|
||||
path: str,
|
||||
old_str: str,
|
||||
|
||||
@@ -9,6 +9,29 @@ from .types import SKILL_MD_FILE, Skill, SkillCategory
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_allowed_tools(raw: object, skill_file: Path) -> list[str] | None:
|
||||
"""Parse the optional allowed-tools frontmatter field.
|
||||
|
||||
Returns None when the field is omitted. Returns a list when the field is a
|
||||
YAML sequence of strings, including an empty list for explicit no-tool
|
||||
skills. Raises ValueError for malformed values.
|
||||
"""
|
||||
if raw is None:
|
||||
return None
|
||||
if not isinstance(raw, list):
|
||||
raise ValueError(f"allowed-tools in {skill_file} must be a list of strings")
|
||||
|
||||
allowed_tools: list[str] = []
|
||||
for item in raw:
|
||||
if not isinstance(item, str):
|
||||
raise ValueError(f"allowed-tools in {skill_file} must contain only strings")
|
||||
tool_name = item.strip()
|
||||
if not tool_name:
|
||||
raise ValueError(f"allowed-tools in {skill_file} cannot contain empty tool names")
|
||||
allowed_tools.append(tool_name)
|
||||
return allowed_tools
|
||||
|
||||
|
||||
def parse_skill_file(skill_file: Path, category: SkillCategory, relative_path: Path | None = None) -> Skill | None:
|
||||
"""Parse a SKILL.md file and extract metadata.
|
||||
|
||||
@@ -64,6 +87,12 @@ def parse_skill_file(skill_file: Path, category: SkillCategory, relative_path: P
|
||||
if license_text is not None:
|
||||
license_text = str(license_text).strip() or None
|
||||
|
||||
try:
|
||||
allowed_tools = parse_allowed_tools(metadata.get("allowed-tools"), skill_file)
|
||||
except ValueError as exc:
|
||||
logger.error("Invalid allowed-tools in %s: %s", skill_file, exc)
|
||||
return None
|
||||
|
||||
return Skill(
|
||||
name=name,
|
||||
description=description,
|
||||
@@ -72,6 +101,7 @@ def parse_skill_file(skill_file: Path, category: SkillCategory, relative_path: P
|
||||
skill_file=skill_file,
|
||||
relative_path=relative_path or Path(skill_file.parent.name),
|
||||
category=category,
|
||||
allowed_tools=allowed_tools,
|
||||
enabled=True, # Actual state comes from the extensions config file.
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
import logging
|
||||
from typing import Protocol
|
||||
|
||||
from deerflow.skills.types import Skill
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NamedTool(Protocol):
|
||||
name: str
|
||||
|
||||
|
||||
def allowed_tool_names_for_skills(skills: list[Skill]) -> set[str] | None:
|
||||
"""Return the union of explicit skill allowed-tools declarations.
|
||||
|
||||
None means legacy allow-all behavior. It is returned only when no loaded
|
||||
skill declares allowed-tools. Once any skill declares the field, legacy
|
||||
skills without the field contribute no tools instead of disabling the
|
||||
explicit restrictions from other skills.
|
||||
"""
|
||||
if not skills:
|
||||
return None
|
||||
|
||||
allowed: set[str] = set()
|
||||
has_explicit_declaration = False
|
||||
for skill in skills:
|
||||
if skill.allowed_tools is None:
|
||||
continue
|
||||
has_explicit_declaration = True
|
||||
if not skill.allowed_tools:
|
||||
logger.info("Skill %s declared empty allowed-tools", skill.name)
|
||||
allowed.update(skill.allowed_tools)
|
||||
|
||||
if not has_explicit_declaration:
|
||||
return None
|
||||
return allowed
|
||||
|
||||
|
||||
def filter_tools_by_skill_allowed_tools[ToolT: NamedTool](tools: list[ToolT], skills: list[Skill]) -> list[ToolT]:
|
||||
allowed = allowed_tool_names_for_skills(skills)
|
||||
if allowed is None:
|
||||
return tools
|
||||
|
||||
return [tool for tool in tools if tool.name in allowed]
|
||||
@@ -27,6 +27,7 @@ class Skill:
|
||||
skill_file: Path
|
||||
relative_path: Path # Relative path from category root to skill directory
|
||||
category: SkillCategory # 'public' or 'custom'
|
||||
allowed_tools: list[str] | None = None
|
||||
enabled: bool = False # Whether this skill is enabled
|
||||
|
||||
@property
|
||||
|
||||
@@ -8,6 +8,7 @@ from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from deerflow.skills.parser import parse_allowed_tools
|
||||
from deerflow.skills.types import SKILL_MD_FILE
|
||||
|
||||
# Allowed properties in SKILL.md frontmatter
|
||||
@@ -84,4 +85,9 @@ def _validate_skill_frontmatter(skill_dir: Path) -> tuple[bool, str, str | None]
|
||||
if len(description) > 1024:
|
||||
return False, f"Description is too long ({len(description)} characters). Maximum is 1024 characters.", None
|
||||
|
||||
try:
|
||||
parse_allowed_tools(frontmatter.get("allowed-tools"), skill_md)
|
||||
except ValueError as e:
|
||||
return False, str(e).replace(str(skill_md), SKILL_MD_FILE), None
|
||||
|
||||
return True, "Skill is valid!", name
|
||||
|
||||
@@ -23,6 +23,8 @@ from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadSt
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.models import create_chat_model
|
||||
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
|
||||
from deerflow.skills.types import Skill
|
||||
from deerflow.subagents.config import SubagentConfig, resolve_subagent_model_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -260,16 +262,16 @@ class SubagentExecutor:
|
||||
# Generate trace_id if not provided (for top-level calls)
|
||||
self.trace_id = trace_id or str(uuid.uuid4())[:8]
|
||||
|
||||
# Filter tools based on config
|
||||
self.tools = _filter_tools(
|
||||
self._base_tools = _filter_tools(
|
||||
tools,
|
||||
config.tools,
|
||||
config.disallowed_tools,
|
||||
)
|
||||
self.tools = self._base_tools
|
||||
|
||||
logger.info(f"[trace={self.trace_id}] SubagentExecutor initialized: {config.name} with {len(self.tools)} tools")
|
||||
|
||||
def _create_agent(self):
|
||||
def _create_agent(self, tools: list[BaseTool] | None = None):
|
||||
"""Create the agent instance."""
|
||||
app_config = self.app_config or get_app_config()
|
||||
if self.model_name is None:
|
||||
@@ -283,26 +285,14 @@ class SubagentExecutor:
|
||||
|
||||
return create_agent(
|
||||
model=model,
|
||||
tools=self.tools,
|
||||
tools=tools if tools is not None else self.tools,
|
||||
middleware=middlewares,
|
||||
system_prompt=self.config.system_prompt,
|
||||
state_schema=ThreadState,
|
||||
)
|
||||
|
||||
async def _load_skill_messages(self) -> list[SystemMessage]:
|
||||
"""Load skill content as conversation items based on config.skills.
|
||||
|
||||
Aligned with Codex's pattern: each subagent loads its own skills
|
||||
per-session and injects them as conversation items (developer messages),
|
||||
not as system prompt text. The config.skills whitelist controls which
|
||||
skills are loaded:
|
||||
- None: load all enabled skills
|
||||
- []: no skills
|
||||
- ["skill-a", "skill-b"]: only these skills
|
||||
|
||||
Returns:
|
||||
List of SystemMessages containing skill content.
|
||||
"""
|
||||
async def _load_skills(self) -> list[Skill]:
|
||||
"""Load enabled skill metadata based on config.skills."""
|
||||
if self.config.skills is not None and len(self.config.skills) == 0:
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} skills=[] — skipping skill loading")
|
||||
return []
|
||||
@@ -316,8 +306,8 @@ class SubagentExecutor:
|
||||
all_skills = await asyncio.to_thread(storage.load_skills, enabled_only=True)
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} loaded {len(all_skills)} enabled skills from disk")
|
||||
except Exception:
|
||||
logger.warning(f"[trace={self.trace_id}] Failed to load skills for subagent {self.config.name}", exc_info=True)
|
||||
return []
|
||||
logger.exception(f"[trace={self.trace_id}] Failed to load skills for subagent {self.config.name}")
|
||||
raise
|
||||
|
||||
if not all_skills:
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} no enabled skills found")
|
||||
@@ -326,10 +316,26 @@ class SubagentExecutor:
|
||||
# Filter by config.skills whitelist
|
||||
if self.config.skills is not None:
|
||||
allowed = set(self.config.skills)
|
||||
skills = [s for s in all_skills if s.name in allowed]
|
||||
else:
|
||||
skills = all_skills
|
||||
return [s for s in all_skills if s.name in allowed]
|
||||
return all_skills
|
||||
|
||||
def _apply_skill_allowed_tools(self, skills: list[Skill]) -> list[BaseTool]:
|
||||
return filter_tools_by_skill_allowed_tools(self._base_tools, skills)
|
||||
|
||||
async def _load_skill_messages(self, skills: list[Skill]) -> list[SystemMessage]:
|
||||
"""Load skill content as conversation items based on config.skills.
|
||||
|
||||
Aligned with Codex's pattern: each subagent loads its own skills
|
||||
per-session and injects them as conversation items (developer messages),
|
||||
not as system prompt text. The config.skills whitelist controls which
|
||||
skills are loaded:
|
||||
- None: load all enabled skills
|
||||
- []: no skills
|
||||
- ["skill-a", "skill-b"]: only these skills
|
||||
|
||||
Returns:
|
||||
List of SystemMessages containing skill content.
|
||||
"""
|
||||
if not skills:
|
||||
return []
|
||||
|
||||
@@ -347,19 +353,21 @@ class SubagentExecutor:
|
||||
|
||||
return messages
|
||||
|
||||
async def _build_initial_state(self, task: str) -> dict[str, Any]:
|
||||
async def _build_initial_state(self, task: str) -> tuple[dict[str, Any], list[BaseTool]]:
|
||||
"""Build the initial state for agent execution.
|
||||
|
||||
Args:
|
||||
task: The task description.
|
||||
|
||||
Returns:
|
||||
Initial state dictionary.
|
||||
Initial state dictionary and tools filtered by loaded skill metadata.
|
||||
"""
|
||||
# Load skills as conversation items (Codex pattern)
|
||||
skill_messages = await self._load_skill_messages()
|
||||
skills = await self._load_skills()
|
||||
filtered_tools = self._apply_skill_allowed_tools(skills)
|
||||
skill_messages = await self._load_skill_messages(skills)
|
||||
|
||||
messages: list = []
|
||||
messages: list[Any] = []
|
||||
# Skill content injected as developer/system messages before the task
|
||||
messages.extend(skill_messages)
|
||||
# Then the actual task
|
||||
@@ -375,7 +383,7 @@ class SubagentExecutor:
|
||||
if self.thread_data is not None:
|
||||
state["thread_data"] = self.thread_data
|
||||
|
||||
return state
|
||||
return state, filtered_tools
|
||||
|
||||
async def _aexecute(self, task: str, result_holder: SubagentResult | None = None) -> SubagentResult:
|
||||
"""Execute a task asynchronously.
|
||||
@@ -405,8 +413,8 @@ class SubagentExecutor:
|
||||
result.ai_messages = ai_messages
|
||||
|
||||
try:
|
||||
agent = self._create_agent()
|
||||
state = await self._build_initial_state(task)
|
||||
state, filtered_tools = await self._build_initial_state(task)
|
||||
agent = self._create_agent(filtered_tools)
|
||||
|
||||
# Build config with thread_id for sandbox access and recursion limit
|
||||
run_config: RunnableConfig = {
|
||||
|
||||
@@ -2,10 +2,12 @@ from .clarification_tool import ask_clarification_tool
|
||||
from .present_file_tool import present_file_tool
|
||||
from .setup_agent_tool import setup_agent
|
||||
from .task_tool import task_tool
|
||||
from .update_agent_tool import update_agent
|
||||
from .view_image_tool import view_image_tool
|
||||
|
||||
__all__ = [
|
||||
"setup_agent",
|
||||
"update_agent",
|
||||
"present_file_tool",
|
||||
"ask_clarification_tool",
|
||||
"view_image_tool",
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
||||
from langchain.tools import InjectedToolCallId, tool
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.config import get_config
|
||||
from langgraph.types import Command
|
||||
from langgraph.typing import ContextT
|
||||
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"
|
||||
|
||||
|
||||
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState]) -> str | None:
|
||||
def _get_thread_id(runtime: Runtime) -> str | None:
|
||||
"""Resolve the current thread id from runtime context or RunnableConfig."""
|
||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
||||
if thread_id:
|
||||
@@ -32,7 +31,7 @@ def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState]) -> str | None:
|
||||
|
||||
|
||||
def _normalize_presented_filepath(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
filepath: str,
|
||||
) -> str:
|
||||
"""Normalize a presented file path to the `/mnt/user-data/outputs/*` contract.
|
||||
@@ -83,7 +82,7 @@ def _normalize_presented_filepath(
|
||||
|
||||
@tool("present_files", parse_docstring=True)
|
||||
def present_file_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
filepaths: list[str],
|
||||
tool_call_id: Annotated[str, InjectedToolCallId],
|
||||
) -> Command:
|
||||
|
||||
@@ -3,20 +3,28 @@ import logging
|
||||
import yaml
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.prebuilt import ToolRuntime
|
||||
from langgraph.types import Command
|
||||
|
||||
from deerflow.config.agents_config import validate_agent_name
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_runtime_user_id(runtime: Runtime) -> str:
|
||||
context_user_id = runtime.context.get("user_id") if runtime.context else None
|
||||
if context_user_id:
|
||||
return str(context_user_id)
|
||||
return get_effective_user_id()
|
||||
|
||||
|
||||
@tool
|
||||
def setup_agent(
|
||||
soul: str,
|
||||
description: str,
|
||||
runtime: ToolRuntime,
|
||||
runtime: Runtime,
|
||||
skills: list[str] | None = None,
|
||||
) -> Command:
|
||||
"""Setup the custom DeerFlow agent.
|
||||
@@ -34,7 +42,14 @@ def setup_agent(
|
||||
try:
|
||||
agent_name = validate_agent_name(agent_name)
|
||||
paths = get_paths()
|
||||
agent_dir = paths.agent_dir(agent_name) if agent_name else paths.base_dir
|
||||
if agent_name:
|
||||
# Custom agents are persisted under the current user's bucket so
|
||||
# different users do not see each other's agents.
|
||||
user_id = _get_runtime_user_id(runtime)
|
||||
agent_dir = paths.user_agent_dir(user_id, agent_name)
|
||||
else:
|
||||
# Default agent (no agent_name): SOUL.md lives at the global base dir.
|
||||
agent_dir = paths.base_dir
|
||||
is_new_dir = not agent_dir.exists()
|
||||
agent_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@@ -6,11 +6,9 @@ import uuid
|
||||
from dataclasses import replace
|
||||
from typing import TYPE_CHECKING, Annotated, Any, cast
|
||||
|
||||
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
||||
from langchain.tools import InjectedToolCallId, tool
|
||||
from langgraph.config import get_stream_writer
|
||||
from langgraph.typing import ContextT
|
||||
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed
|
||||
from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config
|
||||
@@ -21,6 +19,7 @@ from deerflow.subagents.executor import (
|
||||
get_background_task_result,
|
||||
request_cancel_background_task,
|
||||
)
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.config.app_config import AppConfig
|
||||
@@ -50,12 +49,11 @@ def _merge_skill_allowlists(parent: list[str] | None, child: list[str] | None) -
|
||||
|
||||
@tool("task", parse_docstring=True)
|
||||
async def task_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
description: str,
|
||||
prompt: str,
|
||||
subagent_type: str,
|
||||
tool_call_id: Annotated[str, InjectedToolCallId],
|
||||
max_turns: int | None = None,
|
||||
) -> str:
|
||||
"""Delegate a task to a specialized subagent that runs in its own context.
|
||||
|
||||
@@ -91,7 +89,6 @@ async def task_tool(
|
||||
description: A short (3-5 word) description of the task for logging/display. ALWAYS PROVIDE THIS PARAMETER FIRST.
|
||||
prompt: The task description for the subagent. Be specific and clear about what needs to be done. ALWAYS PROVIDE THIS PARAMETER SECOND.
|
||||
subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
|
||||
max_turns: Optional maximum number of agent turns. Defaults to subagent's configured max.
|
||||
"""
|
||||
runtime_app_config = _get_runtime_app_config(runtime)
|
||||
available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names()
|
||||
@@ -113,9 +110,6 @@ async def task_tool(
|
||||
# each subagent loads its own skills based on config, injected as conversation items).
|
||||
# No longer appended to system_prompt here.
|
||||
|
||||
if max_turns is not None:
|
||||
overrides["max_turns"] = max_turns
|
||||
|
||||
# Extract parent context from runtime
|
||||
sandbox_state = None
|
||||
thread_data = None
|
||||
|
||||
@@ -0,0 +1,241 @@
|
||||
"""update_agent tool — let a custom agent persist updates to its own SOUL.md / config.
|
||||
|
||||
Bound to the lead agent only when ``runtime.context['agent_name']`` is set
|
||||
(i.e. inside an existing custom agent's chat). The default agent does not see
|
||||
this tool, and the bootstrap flow continues to use ``setup_agent`` for the
|
||||
initial creation handshake.
|
||||
|
||||
The tool writes back to ``{base_dir}/users/{user_id}/agents/{agent_name}/{config.yaml,SOUL.md}``
|
||||
so an agent created by one user is never visible to (or mutable by) another.
|
||||
Writes are staged into temp files first; both files are renamed into place only
|
||||
after both temp files are successfully written, so a partial failure cannot leave
|
||||
config.yaml updated while SOUL.md still holds stale content.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import Command
|
||||
|
||||
from deerflow.config.agents_config import load_agent_config, validate_agent_name
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _stage_temp(path: Path, text: str) -> Path:
|
||||
"""Write ``text`` into a sibling temp file and return its path.
|
||||
|
||||
The caller is responsible for ``Path.replace``-ing the temp into the target
|
||||
once every staged file is ready, or for unlinking it on failure.
|
||||
"""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fd = tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
dir=path.parent,
|
||||
suffix=".tmp",
|
||||
delete=False,
|
||||
encoding="utf-8",
|
||||
)
|
||||
try:
|
||||
fd.write(text)
|
||||
fd.flush()
|
||||
fd.close()
|
||||
return Path(fd.name)
|
||||
except BaseException:
|
||||
fd.close()
|
||||
Path(fd.name).unlink(missing_ok=True)
|
||||
raise
|
||||
|
||||
|
||||
def _cleanup_temps(temps: list[Path]) -> None:
|
||||
"""Best-effort removal of staged temp files."""
|
||||
for tmp in temps:
|
||||
try:
|
||||
tmp.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
logger.debug("Failed to clean up temp file %s", tmp, exc_info=True)
|
||||
|
||||
|
||||
@tool
|
||||
def update_agent(
|
||||
runtime: Runtime,
|
||||
soul: str | None = None,
|
||||
description: str | None = None,
|
||||
skills: list[str] | None = None,
|
||||
tool_groups: list[str] | None = None,
|
||||
model: str | None = None,
|
||||
) -> Command:
|
||||
"""Persist updates to the current custom agent's SOUL.md and config.yaml.
|
||||
|
||||
Use this when the user asks to refine the agent's identity, description,
|
||||
skill whitelist, tool-group whitelist, or default model. Only the fields
|
||||
you explicitly pass are updated; omitted fields keep their existing values.
|
||||
|
||||
Pass ``soul`` as the FULL replacement SOUL.md content — there is no patch
|
||||
semantics, so always start from the current SOUL and apply your edits.
|
||||
|
||||
Pass ``skills=[]`` to disable all skills for this agent. Omit ``skills``
|
||||
entirely to keep the existing whitelist.
|
||||
|
||||
Args:
|
||||
soul: Optional full replacement SOUL.md content.
|
||||
description: Optional new one-line description.
|
||||
skills: Optional skill whitelist. ``[]`` = no skills, omit = unchanged.
|
||||
tool_groups: Optional tool-group whitelist. ``[]`` = empty, omit = unchanged.
|
||||
model: Optional model override (must match a configured model name).
|
||||
|
||||
Returns:
|
||||
Command with a ToolMessage describing the result. Changes take effect
|
||||
on the next user turn (when the lead agent is rebuilt with the fresh
|
||||
SOUL.md and config.yaml).
|
||||
"""
|
||||
tool_call_id = runtime.tool_call_id
|
||||
agent_name_raw: str | None = runtime.context.get("agent_name") if runtime.context else None
|
||||
|
||||
def _err(message: str) -> Command:
|
||||
return Command(update={"messages": [ToolMessage(content=f"Error: {message}", tool_call_id=tool_call_id)]})
|
||||
|
||||
if soul is None and description is None and skills is None and tool_groups is None and model is None:
|
||||
return _err("No fields provided. Pass at least one of: soul, description, skills, tool_groups, model.")
|
||||
|
||||
try:
|
||||
agent_name = validate_agent_name(agent_name_raw)
|
||||
except ValueError as e:
|
||||
return _err(str(e))
|
||||
|
||||
if not agent_name:
|
||||
return _err("update_agent is only available inside a custom agent's chat. There is no agent_name in the current runtime context, so there is nothing to update. If you are inside the bootstrap flow, use setup_agent instead.")
|
||||
|
||||
# Resolve the active user so that updates only affect this user's agent.
|
||||
# ``get_effective_user_id`` returns DEFAULT_USER_ID when no auth context
|
||||
# is set (matching how memory and thread storage behave).
|
||||
user_id = get_effective_user_id()
|
||||
|
||||
# Reject an unknown ``model`` *before* touching the filesystem. Otherwise
|
||||
# ``_resolve_model_name`` silently falls back to the default at runtime
|
||||
# and the user sees confusing repeated warnings on every later turn.
|
||||
if model is not None and get_app_config().get_model_config(model) is None:
|
||||
return _err(f"Unknown model '{model}'. Pass a model name that exists in config.yaml's models section.")
|
||||
|
||||
paths = get_paths()
|
||||
agent_dir = paths.user_agent_dir(user_id, agent_name)
|
||||
if not agent_dir.exists() and paths.agent_dir(agent_name).exists():
|
||||
return _err(f"Agent '{agent_name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before updating.")
|
||||
|
||||
try:
|
||||
existing_cfg = load_agent_config(agent_name, user_id=user_id)
|
||||
except FileNotFoundError:
|
||||
return _err(f"Agent '{agent_name}' does not exist for the current user. Use setup_agent to create a new agent first.")
|
||||
except ValueError as e:
|
||||
return _err(f"Agent '{agent_name}' has an unreadable config: {e}")
|
||||
|
||||
if existing_cfg is None:
|
||||
return _err(f"Agent '{agent_name}' could not be loaded.")
|
||||
|
||||
updated_fields: list[str] = []
|
||||
|
||||
# Force the on-disk ``name`` to match the directory we are writing into,
|
||||
# even if ``existing_cfg.name`` had drifted (e.g. from manual yaml edits).
|
||||
config_data: dict[str, Any] = {"name": agent_name}
|
||||
new_description = description if description is not None else existing_cfg.description
|
||||
config_data["description"] = new_description
|
||||
if description is not None and description != existing_cfg.description:
|
||||
updated_fields.append("description")
|
||||
|
||||
new_model = model if model is not None else existing_cfg.model
|
||||
if new_model is not None:
|
||||
config_data["model"] = new_model
|
||||
if model is not None and model != existing_cfg.model:
|
||||
updated_fields.append("model")
|
||||
|
||||
new_tool_groups = tool_groups if tool_groups is not None else existing_cfg.tool_groups
|
||||
if new_tool_groups is not None:
|
||||
config_data["tool_groups"] = new_tool_groups
|
||||
if tool_groups is not None and tool_groups != existing_cfg.tool_groups:
|
||||
updated_fields.append("tool_groups")
|
||||
|
||||
new_skills = skills if skills is not None else existing_cfg.skills
|
||||
if new_skills is not None:
|
||||
config_data["skills"] = new_skills
|
||||
if skills is not None and skills != existing_cfg.skills:
|
||||
updated_fields.append("skills")
|
||||
|
||||
config_changed = bool({"description", "model", "tool_groups", "skills"} & set(updated_fields))
|
||||
|
||||
# Stage every file we intend to rewrite into a temp sibling. Only after
|
||||
# *all* temp files exist do we rename them into place — so a failure on
|
||||
# SOUL.md cannot leave config.yaml already replaced.
|
||||
pending: list[tuple[Path, Path]] = []
|
||||
staged_temps: list[Path] = []
|
||||
|
||||
try:
|
||||
agent_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if config_changed:
|
||||
yaml_text = yaml.dump(config_data, default_flow_style=False, allow_unicode=True, sort_keys=False)
|
||||
config_target = agent_dir / "config.yaml"
|
||||
config_tmp = _stage_temp(config_target, yaml_text)
|
||||
staged_temps.append(config_tmp)
|
||||
pending.append((config_tmp, config_target))
|
||||
|
||||
if soul is not None:
|
||||
soul_target = agent_dir / "SOUL.md"
|
||||
soul_tmp = _stage_temp(soul_target, soul)
|
||||
staged_temps.append(soul_tmp)
|
||||
pending.append((soul_tmp, soul_target))
|
||||
updated_fields.append("soul")
|
||||
|
||||
# Commit phase. ``Path.replace`` is atomic per file on POSIX/NTFS and
|
||||
# the staging step above means any earlier failure has already been
|
||||
# reported. The remaining failure mode is a crash *between* two
|
||||
# ``replace`` calls, which is reported via the partial-write error
|
||||
# branch below so the caller knows which files are now on disk.
|
||||
committed: list[Path] = []
|
||||
try:
|
||||
for tmp, target in pending:
|
||||
tmp.replace(target)
|
||||
committed.append(target)
|
||||
except Exception as e:
|
||||
_cleanup_temps([t for t, _ in pending if t not in committed])
|
||||
if committed:
|
||||
logger.error(
|
||||
"[update_agent] Partial write for agent '%s' (user=%s): committed=%s, failed during rename: %s",
|
||||
agent_name,
|
||||
user_id,
|
||||
[p.name for p in committed],
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
return _err(f"Partial update for agent '{agent_name}': {[p.name for p in committed]} were updated, but the rest failed ({e}). Re-run update_agent to retry the remaining fields.")
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
_cleanup_temps(staged_temps)
|
||||
logger.error("[update_agent] Failed to update agent '%s' (user=%s): %s", agent_name, user_id, e, exc_info=True)
|
||||
return _err(f"Failed to update agent '{agent_name}': {e}")
|
||||
|
||||
if not updated_fields:
|
||||
return Command(update={"messages": [ToolMessage(content=f"No changes applied to agent '{agent_name}'. The provided values matched the existing config.", tool_call_id=tool_call_id)]})
|
||||
|
||||
logger.info("[update_agent] Updated agent '%s' (user=%s) fields: %s", agent_name, user_id, updated_fields)
|
||||
return Command(
|
||||
update={
|
||||
"messages": [
|
||||
ToolMessage(
|
||||
content=(f"Agent '{agent_name}' updated successfully. Changed: {', '.join(updated_fields)}. The new configuration takes effect on the next user turn."),
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
@@ -3,13 +3,13 @@ import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
||||
from langchain.tools import InjectedToolCallId, tool
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
from langgraph.typing import ContextT
|
||||
|
||||
from deerflow.agents.thread_state import ThreadDataState, ThreadState
|
||||
from deerflow.agents.thread_state import ThreadDataState
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
_ALLOWED_IMAGE_VIRTUAL_ROOTS = (
|
||||
f"{VIRTUAL_PATH_PREFIX}/workspace",
|
||||
@@ -48,7 +48,7 @@ def _sanitize_image_error(error: Exception, thread_data: ThreadDataState | None)
|
||||
|
||||
@tool("view_image", parse_docstring=True)
|
||||
def view_image_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
image_path: str,
|
||||
tool_call_id: Annotated[str, InjectedToolCallId],
|
||||
) -> Command:
|
||||
|
||||
@@ -7,16 +7,15 @@ import logging
|
||||
from typing import Any
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from langchain.tools import ToolRuntime, tool
|
||||
from langgraph.typing import ContextT
|
||||
from langchain.tools import tool
|
||||
|
||||
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.mcp.tools import _make_sync_tool_wrapper
|
||||
from deerflow.skills.security_scanner import scan_skill_content
|
||||
from deerflow.skills.storage import get_or_new_skill_storage
|
||||
from deerflow.skills.storage.skill_storage import SkillStorage
|
||||
from deerflow.skills.types import SKILL_MD_FILE
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,7 +30,7 @@ def _get_lock(name: str) -> asyncio.Lock:
|
||||
return lock
|
||||
|
||||
|
||||
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState] | None) -> str | None:
|
||||
def _get_thread_id(runtime: Runtime | None) -> str | None:
|
||||
if runtime is None:
|
||||
return None
|
||||
if runtime.context and runtime.context.get("thread_id"):
|
||||
@@ -65,7 +64,7 @@ async def _to_thread(func, /, *args, **kwargs):
|
||||
|
||||
|
||||
async def _skill_manage_impl(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
action: str,
|
||||
name: str,
|
||||
content: str | None = None,
|
||||
@@ -204,7 +203,7 @@ async def _skill_manage_impl(
|
||||
|
||||
@tool("skill_manage", parse_docstring=True)
|
||||
async def skill_manage_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
action: str,
|
||||
name: str,
|
||||
content: str | None = None,
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
|
||||
# Concrete runtime type used by all DeerFlow tools.
|
||||
# Using dict[str, Any] for the context parameter instead of the unbound ContextT
|
||||
# TypeVar prevents PydanticSerializationUnexpectedValue warnings when LangChain
|
||||
# calls model_dump() on a tool's auto-generated args_schema.
|
||||
Runtime = ToolRuntime[dict[str, Any], ThreadState]
|
||||
@@ -4,8 +4,10 @@ Pure business logic — no FastAPI/HTTP dependencies.
|
||||
Both Gateway and Client delegate to these functions.
|
||||
"""
|
||||
|
||||
import errno
|
||||
import os
|
||||
import re
|
||||
import stat
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote
|
||||
|
||||
@@ -17,6 +19,10 @@ class PathTraversalError(ValueError):
|
||||
"""Raised when a path escapes its allowed base directory."""
|
||||
|
||||
|
||||
class UnsafeUploadPathError(ValueError):
|
||||
"""Raised when an upload destination is not a safe regular file path."""
|
||||
|
||||
|
||||
# thread_id must be alphanumeric, hyphens, underscores, or dots only.
|
||||
_SAFE_THREAD_ID = re.compile(r"^[a-zA-Z0-9._-]+$")
|
||||
|
||||
@@ -109,6 +115,108 @@ def validate_path_traversal(path: Path, base: Path) -> None:
|
||||
raise PathTraversalError("Path traversal detected") from None
|
||||
|
||||
|
||||
def open_upload_file_no_symlink(base_dir: Path, filename: str) -> tuple[Path, object]:
|
||||
"""Open an upload destination for safe streaming writes.
|
||||
|
||||
Upload directories may be mounted into local sandboxes. A sandbox process can
|
||||
therefore leave a symlink at a future upload filename. Normal ``Path.write_bytes``
|
||||
follows that link and can overwrite files outside the uploads directory with
|
||||
gateway privileges. This helper rejects symlink destinations using ``O_NOFOLLOW``
|
||||
on POSIX. On Windows (which lacks ``O_NOFOLLOW``), it uses dual ``lstat`` checks
|
||||
and ``fstat`` validation after ``open()`` to reduce the TOCTOU window; this does
|
||||
not eliminate all races but makes exploitation significantly harder. Path-traversal
|
||||
validation prevents escapes from *base_dir* in both cases.
|
||||
"""
|
||||
safe_name = normalize_filename(filename)
|
||||
dest = base_dir / safe_name
|
||||
|
||||
try:
|
||||
st = os.lstat(dest)
|
||||
except FileNotFoundError:
|
||||
st = None
|
||||
|
||||
if st is not None and not stat.S_ISREG(st.st_mode):
|
||||
raise UnsafeUploadPathError(f"Upload destination is not a regular file: {safe_name}")
|
||||
|
||||
validate_path_traversal(dest, base_dir)
|
||||
|
||||
has_nofollow = hasattr(os, "O_NOFOLLOW")
|
||||
|
||||
if has_nofollow:
|
||||
# POSIX: O_NOFOLLOW makes open() fail with ELOOP if dest is a symlink.
|
||||
flags = os.O_WRONLY | os.O_CREAT | os.O_NOFOLLOW
|
||||
if hasattr(os, "O_NONBLOCK"):
|
||||
flags |= os.O_NONBLOCK
|
||||
|
||||
try:
|
||||
fd = os.open(dest, flags, 0o600)
|
||||
except OSError as exc:
|
||||
if exc.errno in {errno.ELOOP, errno.EISDIR, errno.ENOTDIR, errno.ENXIO, errno.EAGAIN}:
|
||||
raise UnsafeUploadPathError(f"Unsafe upload destination: {safe_name}") from exc
|
||||
raise
|
||||
|
||||
try:
|
||||
opened_stat = os.fstat(fd)
|
||||
if not stat.S_ISREG(opened_stat.st_mode) or opened_stat.st_nlink != 1:
|
||||
raise UnsafeUploadPathError(f"Upload destination is not an exclusive regular file: {safe_name}")
|
||||
os.ftruncate(fd, 0)
|
||||
fh = os.fdopen(fd, "wb")
|
||||
fd = -1
|
||||
finally:
|
||||
if fd >= 0:
|
||||
os.close(fd)
|
||||
return dest, fh
|
||||
|
||||
# Windows: no O_NOFOLLOW available. Uses a second lstat immediately before open()
|
||||
# to narrow the TOCTOU window, then fstat after open() as a further defence.
|
||||
# Note: a narrow race window remains between the pre-open lstat and open(); the
|
||||
# path-traversal check mitigates escapes from base_dir but cannot prevent an
|
||||
# attacker who can atomically replace dest with a symlink after the check.
|
||||
if st is not None and st.st_nlink > 1:
|
||||
raise UnsafeUploadPathError(f"Upload destination has multiple links: {safe_name}")
|
||||
|
||||
flags = os.O_WRONLY | os.O_CREAT
|
||||
if hasattr(os, "O_BINARY"):
|
||||
flags |= os.O_BINARY
|
||||
|
||||
try:
|
||||
pre_open_st = os.lstat(dest)
|
||||
except FileNotFoundError:
|
||||
pre_open_st = None
|
||||
|
||||
if pre_open_st is not None and not stat.S_ISREG(pre_open_st.st_mode):
|
||||
raise UnsafeUploadPathError(f"Upload destination is not a regular file: {safe_name}")
|
||||
if pre_open_st is not None and pre_open_st.st_nlink > 1:
|
||||
raise UnsafeUploadPathError(f"Upload destination has multiple links: {safe_name}")
|
||||
|
||||
try:
|
||||
fd = os.open(dest, flags, 0o600)
|
||||
except OSError as exc:
|
||||
if exc.errno in {errno.EISDIR, errno.ENOTDIR, errno.ENXIO, errno.EAGAIN}:
|
||||
raise UnsafeUploadPathError(f"Unsafe upload destination: {safe_name}") from exc
|
||||
raise
|
||||
|
||||
try:
|
||||
opened_stat = os.fstat(fd)
|
||||
if not stat.S_ISREG(opened_stat.st_mode) or opened_stat.st_nlink > 1:
|
||||
raise UnsafeUploadPathError(f"Upload destination is not an exclusive regular file: {safe_name}")
|
||||
os.ftruncate(fd, 0)
|
||||
fh = os.fdopen(fd, "wb")
|
||||
fd = -1
|
||||
finally:
|
||||
if fd >= 0:
|
||||
os.close(fd)
|
||||
return dest, fh
|
||||
|
||||
|
||||
def write_upload_file_no_symlink(base_dir: Path, filename: str, data: bytes) -> Path:
|
||||
"""Write upload bytes without following a pre-existing destination symlink."""
|
||||
dest, fh = open_upload_file_no_symlink(base_dir, filename)
|
||||
with fh:
|
||||
fh.write(data)
|
||||
return dest
|
||||
|
||||
|
||||
def list_files_in_dir(directory: Path) -> dict:
|
||||
"""List files (not directories) in *directory*.
|
||||
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
"""ISO 8601 timestamp helpers for the Gateway and embedded runtime.
|
||||
|
||||
DeerFlow stores and serializes thread/run timestamps as ISO 8601 UTC
|
||||
strings to match the LangGraph Platform schema (see
|
||||
``langgraph_sdk.schema.Thread``, where ``created_at`` / ``updated_at``
|
||||
are ``datetime`` and JSON-encode to ISO 8601). All timestamp generation
|
||||
should funnel through :func:`now_iso` so the wire format stays
|
||||
consistent across endpoints, the embedded ``RunManager``, and the
|
||||
checkpoint metadata written by the Gateway.
|
||||
|
||||
:func:`coerce_iso` provides a forward-compatible read path for legacy
|
||||
records that historically stored ``str(time.time())`` floats.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import UTC, datetime
|
||||
|
||||
__all__ = ["coerce_iso", "now_iso"]
|
||||
|
||||
_UNIX_TIMESTAMP_PATTERN = re.compile(r"^\d{10}(?:\.\d+)?$")
|
||||
"""Matches the unix-timestamp string shape historically written by
|
||||
``str(time.time())`` (10-digit seconds with optional fractional part).
|
||||
The 10-digit anchor avoids accidentally rewriting ISO years like
|
||||
``"2026"`` and stays valid until the year 2286.
|
||||
"""
|
||||
|
||||
|
||||
def now_iso() -> str:
|
||||
"""Return the current UTC time as an ISO 8601 string.
|
||||
|
||||
Example: ``"2026-04-27T03:19:46.511479+00:00"``.
|
||||
"""
|
||||
return datetime.now(UTC).isoformat()
|
||||
|
||||
|
||||
def coerce_iso(value: object) -> str:
|
||||
"""Best-effort coerce a stored timestamp to an ISO 8601 string.
|
||||
|
||||
Translates legacy unix-timestamp floats / strings written by older
|
||||
DeerFlow versions into ISO without a one-shot migration. ISO strings
|
||||
pass through unchanged; ``datetime`` instances are normalised to UTC
|
||||
(tz-naive values are assumed to be UTC) and emitted via
|
||||
``isoformat()`` so the wire format always uses the ``T`` separator;
|
||||
empty values become ``""``; unrecognised values are stringified as a
|
||||
last resort.
|
||||
"""
|
||||
if value is None or value == "":
|
||||
return ""
|
||||
if isinstance(value, bool):
|
||||
# ``bool`` is a subclass of ``int`` — treat as garbage, not 0/1.
|
||||
return str(value)
|
||||
if isinstance(value, datetime):
|
||||
# ``datetime`` must be handled before the ``int``/``float`` check;
|
||||
# str(datetime) would produce ``"YYYY-MM-DD HH:MM:SS+00:00"``
|
||||
# (space separator), which breaks strict ISO 8601 consumers.
|
||||
if value.tzinfo is None:
|
||||
value = value.replace(tzinfo=UTC)
|
||||
else:
|
||||
value = value.astimezone(UTC)
|
||||
return value.isoformat()
|
||||
if isinstance(value, (int, float)):
|
||||
try:
|
||||
return datetime.fromtimestamp(float(value), UTC).isoformat()
|
||||
except (ValueError, OverflowError, OSError):
|
||||
return str(value)
|
||||
if isinstance(value, str):
|
||||
if _UNIX_TIMESTAMP_PATTERN.match(value):
|
||||
try:
|
||||
return datetime.fromtimestamp(float(value), UTC).isoformat()
|
||||
except (ValueError, OverflowError, OSError):
|
||||
return value
|
||||
return value
|
||||
return str(value)
|
||||
@@ -8,7 +8,7 @@ dependencies = [
|
||||
"deerflow-harness",
|
||||
"fastapi>=0.115.0",
|
||||
"httpx>=0.28.0",
|
||||
"python-multipart>=0.0.26",
|
||||
"python-multipart>=0.0.27",
|
||||
"sse-starlette>=2.1.0",
|
||||
"uvicorn[standard]>=0.34.0",
|
||||
"lark-oapi>=1.4.0",
|
||||
@@ -47,4 +47,3 @@ members = ["packages/harness"]
|
||||
|
||||
[tool.uv.sources]
|
||||
deerflow-harness = { workspace = true }
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""One-time migration: move legacy thread dirs and memory into per-user layout.
|
||||
|
||||
Usage:
|
||||
PYTHONPATH=. python scripts/migrate_user_isolation.py [--dry-run]
|
||||
PYTHONPATH=. python scripts/migrate_user_isolation.py [--dry-run] [--user-id USER_ID]
|
||||
|
||||
The script is idempotent — re-running it after a successful migration is a no-op.
|
||||
"""
|
||||
@@ -69,6 +69,67 @@ def migrate_thread_dirs(
|
||||
return report
|
||||
|
||||
|
||||
def migrate_agents(
|
||||
paths: Paths,
|
||||
user_id: str = "default",
|
||||
*,
|
||||
dry_run: bool = False,
|
||||
) -> list[dict]:
|
||||
"""Move legacy custom-agent directories into per-user layout.
|
||||
|
||||
Legacy layout: ``{base_dir}/agents/{name}/``
|
||||
Per-user layout: ``{base_dir}/users/{user_id}/agents/{name}/``
|
||||
|
||||
Pre-existing per-user agents take precedence: if a destination already
|
||||
exists for an agent name, the legacy copy is moved to
|
||||
``{base_dir}/migration-conflicts/agents/{name}/`` for manual review.
|
||||
|
||||
Args:
|
||||
paths: Paths instance.
|
||||
user_id: Target user to receive the legacy agents (defaults to
|
||||
``"default"``, matching ``DEFAULT_USER_ID`` for no-auth setups).
|
||||
dry_run: If True, only log what would happen.
|
||||
|
||||
Returns:
|
||||
List of migration report entries, one per legacy agent directory found.
|
||||
"""
|
||||
report: list[dict] = []
|
||||
legacy_agents = paths.agents_dir
|
||||
if not legacy_agents.exists():
|
||||
logger.info("No legacy agents directory found — nothing to migrate.")
|
||||
return report
|
||||
|
||||
for agent_dir in sorted(legacy_agents.iterdir()):
|
||||
if not agent_dir.is_dir():
|
||||
continue
|
||||
agent_name = agent_dir.name
|
||||
dest = paths.user_agent_dir(user_id, agent_name)
|
||||
|
||||
entry = {"agent": agent_name, "user_id": user_id, "action": ""}
|
||||
|
||||
if dest.exists():
|
||||
conflicts_dir = paths.base_dir / "migration-conflicts" / "agents" / agent_name
|
||||
entry["action"] = f"conflict -> {conflicts_dir}"
|
||||
if not dry_run:
|
||||
conflicts_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(agent_dir), str(conflicts_dir))
|
||||
logger.warning("Conflict for agent %s: moved legacy copy to %s", agent_name, conflicts_dir)
|
||||
else:
|
||||
entry["action"] = f"moved -> {dest}"
|
||||
if not dry_run:
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(agent_dir), str(dest))
|
||||
logger.info("Migrated agent %s -> user %s", agent_name, user_id)
|
||||
|
||||
report.append(entry)
|
||||
|
||||
# Clean up empty legacy agents dir
|
||||
if not dry_run and legacy_agents.exists() and not any(legacy_agents.iterdir()):
|
||||
legacy_agents.rmdir()
|
||||
|
||||
return report
|
||||
|
||||
|
||||
def migrate_memory(
|
||||
paths: Paths,
|
||||
user_id: str = "default",
|
||||
@@ -127,6 +188,12 @@ def _build_owner_map_from_db(paths: Paths) -> dict[str, str]:
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Migrate DeerFlow data to per-user layout")
|
||||
parser.add_argument("--dry-run", action="store_true", help="Log actions without making changes")
|
||||
parser.add_argument(
|
||||
"--user-id",
|
||||
default="default",
|
||||
metavar="USER_ID",
|
||||
help=("User ID to claim un-owned legacy data (global memory.json and legacy custom agents). Defaults to 'default'. In multi-user installs, set this to the operator account that should inherit those legacy artifacts."),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
@@ -134,26 +201,42 @@ def main() -> None:
|
||||
paths = get_paths()
|
||||
logger.info("Base directory: %s", paths.base_dir)
|
||||
logger.info("Dry run: %s", args.dry_run)
|
||||
logger.info("Claiming un-owned legacy data for user_id=%s", args.user_id)
|
||||
|
||||
owner_map = _build_owner_map_from_db(paths)
|
||||
logger.info("Found %d thread ownership records in DB", len(owner_map))
|
||||
|
||||
report = migrate_thread_dirs(paths, owner_map, dry_run=args.dry_run)
|
||||
migrate_memory(paths, user_id="default", dry_run=args.dry_run)
|
||||
migrate_memory(paths, user_id=args.user_id, dry_run=args.dry_run)
|
||||
agent_report = migrate_agents(paths, user_id=args.user_id, dry_run=args.dry_run)
|
||||
|
||||
if report:
|
||||
logger.info("Migration report:")
|
||||
logger.info("Thread migration report:")
|
||||
for entry in report:
|
||||
logger.info(" thread=%s user=%s action=%s", entry["thread_id"], entry["user_id"], entry["action"])
|
||||
else:
|
||||
logger.info("No threads to migrate.")
|
||||
|
||||
if agent_report:
|
||||
logger.info("Agent migration report:")
|
||||
for entry in agent_report:
|
||||
logger.info(" agent=%s user=%s action=%s", entry["agent"], entry["user_id"], entry["action"])
|
||||
else:
|
||||
logger.info("No agents to migrate.")
|
||||
|
||||
unowned = [e for e in report if e["user_id"] == "default"]
|
||||
if unowned:
|
||||
logger.warning("%d thread(s) had no owner and were assigned to 'default':", len(unowned))
|
||||
for e in unowned:
|
||||
logger.warning(" %s", e["thread_id"])
|
||||
|
||||
if agent_report:
|
||||
logger.warning(
|
||||
"%d legacy agent(s) were assigned to '%s'. If those agents belonged to other users, move them manually under {base_dir}/users/<user_id>/agents/.",
|
||||
len(agent_report),
|
||||
args.user_id,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
"""Tests for AioSandboxProvider auto-restart of crashed containers."""
|
||||
|
||||
import importlib
|
||||
import threading
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def _import_provider():
|
||||
return importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||
|
||||
|
||||
def _make_provider(*, auto_restart=True, alive=True):
|
||||
"""Build a minimal AioSandboxProvider with a mock backend.
|
||||
|
||||
Args:
|
||||
auto_restart: Value for the auto_restart config key.
|
||||
alive: Whether the mock backend reports containers as alive.
|
||||
"""
|
||||
mod = _import_provider()
|
||||
with patch.object(mod.AioSandboxProvider, "_start_idle_checker"):
|
||||
provider = mod.AioSandboxProvider.__new__(mod.AioSandboxProvider)
|
||||
provider._config = {"auto_restart": auto_restart}
|
||||
provider._lock = threading.Lock()
|
||||
provider._sandboxes = {}
|
||||
provider._sandbox_infos = {}
|
||||
provider._thread_sandboxes = {}
|
||||
provider._thread_locks = {}
|
||||
provider._last_activity = {}
|
||||
provider._warm_pool = {}
|
||||
provider._shutdown_called = False
|
||||
provider._idle_checker_stop = threading.Event()
|
||||
|
||||
backend = MagicMock()
|
||||
backend.is_alive.return_value = alive
|
||||
provider._backend = backend
|
||||
|
||||
return provider, backend
|
||||
|
||||
|
||||
def _seed_sandbox(provider, sandbox_id="dead-beef", thread_id="thread-1"):
|
||||
"""Insert a sandbox into the provider's caches as if it were acquired."""
|
||||
sandbox = MagicMock()
|
||||
info = MagicMock()
|
||||
|
||||
provider._sandboxes[sandbox_id] = sandbox
|
||||
provider._sandbox_infos[sandbox_id] = info
|
||||
provider._last_activity[sandbox_id] = 0.0
|
||||
if thread_id:
|
||||
provider._thread_sandboxes[thread_id] = sandbox_id
|
||||
|
||||
return sandbox, info
|
||||
|
||||
|
||||
# ── get() returns sandbox when container is alive ──────────────────────────
|
||||
|
||||
|
||||
def test_get_returns_sandbox_when_container_alive():
|
||||
"""When auto_restart is on and the container is alive, get() returns the sandbox."""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=True)
|
||||
sandbox, _ = _seed_sandbox(provider)
|
||||
|
||||
result = provider.get("dead-beef")
|
||||
|
||||
assert result is sandbox
|
||||
backend.is_alive.assert_called_once()
|
||||
|
||||
|
||||
def test_get_returns_sandbox_when_auto_restart_disabled():
|
||||
"""When auto_restart is off, get() skips the health check entirely."""
|
||||
provider, backend = _make_provider(auto_restart=False)
|
||||
sandbox, _ = _seed_sandbox(provider)
|
||||
|
||||
result = provider.get("dead-beef")
|
||||
|
||||
assert result is sandbox
|
||||
backend.is_alive.assert_not_called()
|
||||
|
||||
|
||||
# ── get() evicts dead sandbox when auto_restart is on ──────────────────────
|
||||
|
||||
|
||||
def test_get_evicts_dead_sandbox_when_auto_restart_enabled():
|
||||
"""When the container is dead and auto_restart is on, get() returns None and cleans caches."""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=False)
|
||||
_, info = _seed_sandbox(provider, sandbox_id="dead-beef", thread_id="thread-1")
|
||||
|
||||
result = provider.get("dead-beef")
|
||||
|
||||
assert result is None
|
||||
assert "dead-beef" not in provider._sandboxes
|
||||
assert "dead-beef" not in provider._sandbox_infos
|
||||
assert "dead-beef" not in provider._last_activity
|
||||
assert "thread-1" not in provider._thread_sandboxes
|
||||
backend.destroy.assert_called_once_with(info)
|
||||
|
||||
|
||||
def test_get_returns_dead_sandbox_when_auto_restart_disabled():
|
||||
"""When auto_restart is off, get() returns the cached sandbox even if the container is dead."""
|
||||
provider, backend = _make_provider(auto_restart=False, alive=False)
|
||||
sandbox, _ = _seed_sandbox(provider)
|
||||
|
||||
result = provider.get("dead-beef")
|
||||
|
||||
assert result is sandbox
|
||||
# Caches are untouched
|
||||
assert "dead-beef" in provider._sandboxes
|
||||
|
||||
|
||||
def test_get_eviction_cleans_multiple_thread_mappings():
|
||||
"""A sandbox mapped to multiple thread IDs has all mappings cleaned on eviction."""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=False)
|
||||
_seed_sandbox(provider, sandbox_id="sid-1", thread_id="t-a")
|
||||
# Manually add a second thread mapping to the same sandbox
|
||||
provider._thread_sandboxes["t-b"] = "sid-1"
|
||||
|
||||
result = provider.get("sid-1")
|
||||
|
||||
assert result is None
|
||||
assert "t-a" not in provider._thread_sandboxes
|
||||
assert "t-b" not in provider._thread_sandboxes
|
||||
|
||||
|
||||
# ── get() does not check health for unknown sandbox IDs ────────────────────
|
||||
|
||||
|
||||
def test_get_returns_none_for_unknown_id():
|
||||
"""If the sandbox_id is not in cache, get() returns None without checking health."""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=True)
|
||||
|
||||
result = provider.get("nonexistent")
|
||||
|
||||
assert result is None
|
||||
backend.is_alive.assert_not_called()
|
||||
|
||||
|
||||
# ── get() handles missing sandbox_info gracefully ──────────────────────────
|
||||
|
||||
|
||||
def test_get_handles_missing_info_gracefully():
|
||||
"""If sandbox is cached but info is missing, get() skips the health check."""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=False)
|
||||
sandbox = MagicMock()
|
||||
provider._sandboxes["sid-x"] = sandbox
|
||||
provider._sandbox_infos.pop("sid-x", None) # Ensure no info
|
||||
provider._last_activity["sid-x"] = 0.0
|
||||
|
||||
result = provider.get("sid-x")
|
||||
|
||||
# No info → cannot call is_alive → sandbox returned as-is
|
||||
assert result is sandbox
|
||||
backend.is_alive.assert_not_called()
|
||||
|
||||
|
||||
def test_get_liveness_check_runs_outside_provider_lock():
|
||||
"""get() should not hold the provider lock while checking backend liveness."""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=False)
|
||||
_seed_sandbox(provider, sandbox_id="sid-locked", thread_id="thread-1")
|
||||
|
||||
def _assert_lock_not_held(_):
|
||||
assert not provider._lock.locked()
|
||||
return False
|
||||
|
||||
backend.is_alive.side_effect = _assert_lock_not_held
|
||||
|
||||
assert provider.get("sid-locked") is None
|
||||
|
||||
|
||||
def test_get_still_evicts_when_backend_destroy_fails():
|
||||
"""Cleanup errors should not keep stale sandbox state in memory."""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=False)
|
||||
_seed_sandbox(provider, sandbox_id="sid-fail", thread_id="thread-1")
|
||||
backend.destroy.side_effect = RuntimeError("boom")
|
||||
|
||||
assert provider.get("sid-fail") is None
|
||||
assert "sid-fail" not in provider._sandboxes
|
||||
assert "sid-fail" not in provider._sandbox_infos
|
||||
assert "thread-1" not in provider._thread_sandboxes
|
||||
backend.destroy.assert_called_once()
|
||||
|
||||
|
||||
# ── Integration: eviction clears caches for recreation ─────────────────────
|
||||
|
||||
|
||||
def test_eviction_clears_all_caches_for_recreation():
|
||||
"""After eviction, all caches are clean so _acquire_internal can recreate.
|
||||
|
||||
This verifies the preconditions for transparent restart: when get() evicts
|
||||
a dead sandbox, the next _acquire_internal call will find no cached entry,
|
||||
no warm-pool entry, and fall through to _create_sandbox.
|
||||
"""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=False)
|
||||
_seed_sandbox(provider, sandbox_id="sid-1", thread_id="thread-1")
|
||||
|
||||
# Before eviction: caches populated
|
||||
assert "sid-1" in provider._sandboxes
|
||||
assert "sid-1" in provider._sandbox_infos
|
||||
assert "thread-1" in provider._thread_sandboxes
|
||||
|
||||
# get() detects the dead container and evicts
|
||||
assert provider.get("sid-1") is None
|
||||
|
||||
# After eviction: all caches clean
|
||||
assert "sid-1" not in provider._sandboxes
|
||||
assert "sid-1" not in provider._sandbox_infos
|
||||
assert "thread-1" not in provider._thread_sandboxes
|
||||
assert "sid-1" not in provider._warm_pool
|
||||
|
||||
# _acquire_internal for the same thread would find nothing cached
|
||||
# and generate the deterministic ID, then discover fails (container
|
||||
# is gone), falling through to _create_sandbox — a fresh start.
|
||||
@@ -4,10 +4,40 @@ import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from pydantic import ValidationError
|
||||
|
||||
from deerflow.config.agents_api_config import get_agents_api_config
|
||||
import deerflow.config.app_config as app_config_module
|
||||
from deerflow.config.acp_config import load_acp_config_from_dict
|
||||
from deerflow.config.agents_api_config import get_agents_api_config, load_agents_api_config_from_dict
|
||||
from deerflow.config.app_config import AppConfig, get_app_config, reset_app_config
|
||||
from deerflow.config.checkpointer_config import get_checkpointer_config, load_checkpointer_config_from_dict
|
||||
from deerflow.config.guardrails_config import get_guardrails_config, load_guardrails_config_from_dict
|
||||
from deerflow.config.memory_config import get_memory_config, load_memory_config_from_dict
|
||||
from deerflow.config.stream_bridge_config import get_stream_bridge_config, load_stream_bridge_config_from_dict
|
||||
from deerflow.config.subagents_config import get_subagents_app_config, load_subagents_config_from_dict
|
||||
from deerflow.config.summarization_config import get_summarization_config, load_summarization_config_from_dict
|
||||
from deerflow.config.title_config import get_title_config, load_title_config_from_dict
|
||||
from deerflow.config.tool_search_config import get_tool_search_config, load_tool_search_config_from_dict
|
||||
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
|
||||
from deerflow.runtime.store import get_store, reset_store
|
||||
|
||||
|
||||
def _reset_config_singletons() -> None:
|
||||
load_title_config_from_dict({})
|
||||
load_summarization_config_from_dict({})
|
||||
load_memory_config_from_dict({})
|
||||
load_agents_api_config_from_dict({})
|
||||
load_subagents_config_from_dict({})
|
||||
load_tool_search_config_from_dict({})
|
||||
load_guardrails_config_from_dict({})
|
||||
load_checkpointer_config_from_dict(None)
|
||||
load_stream_bridge_config_from_dict(None)
|
||||
load_acp_config_from_dict({})
|
||||
reset_checkpointer()
|
||||
reset_store()
|
||||
reset_app_config()
|
||||
|
||||
|
||||
def _write_config(path: Path, *, model_name: str, supports_thinking: bool) -> None:
|
||||
@@ -53,6 +83,23 @@ def _write_config_with_agents_api(
|
||||
path.write_text(yaml.safe_dump(config), encoding="utf-8")
|
||||
|
||||
|
||||
def _write_config_with_sections(path: Path, sections: dict | None = None) -> None:
|
||||
config = {
|
||||
"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"},
|
||||
"models": [
|
||||
{
|
||||
"name": "first-model",
|
||||
"use": "langchain_openai:ChatOpenAI",
|
||||
"model": "gpt-test",
|
||||
}
|
||||
],
|
||||
}
|
||||
if sections:
|
||||
config.update(sections)
|
||||
|
||||
path.write_text(yaml.safe_dump(config), encoding="utf-8")
|
||||
|
||||
|
||||
def _write_extensions_config(path: Path) -> None:
|
||||
path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8")
|
||||
|
||||
@@ -175,3 +222,168 @@ def test_get_app_config_resets_agents_api_config_when_section_removed(tmp_path,
|
||||
assert get_agents_api_config().enabled is False
|
||||
finally:
|
||||
reset_app_config()
|
||||
|
||||
|
||||
def test_get_app_config_resets_singleton_configs_when_sections_removed(tmp_path, monkeypatch):
|
||||
config_path = tmp_path / "config.yaml"
|
||||
extensions_path = tmp_path / "extensions_config.json"
|
||||
_write_extensions_config(extensions_path)
|
||||
_write_config_with_sections(
|
||||
config_path,
|
||||
{
|
||||
"title": {"enabled": False, "max_words": 3},
|
||||
"summarization": {"enabled": True},
|
||||
"memory": {"enabled": False, "max_facts": 50},
|
||||
"subagents": {"timeout_seconds": 42, "agents": {"reviewer": {"max_turns": 2}}},
|
||||
"tool_search": {"enabled": True},
|
||||
"guardrails": {"enabled": True, "fail_closed": False},
|
||||
"checkpointer": {"type": "memory"},
|
||||
"stream_bridge": {"type": "memory", "queue_maxsize": 12},
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
|
||||
reset_app_config()
|
||||
|
||||
try:
|
||||
get_app_config()
|
||||
assert get_title_config().enabled is False
|
||||
assert get_summarization_config().enabled is True
|
||||
assert get_memory_config().enabled is False
|
||||
assert get_subagents_app_config().timeout_seconds == 42
|
||||
assert get_tool_search_config().enabled is True
|
||||
assert get_guardrails_config().enabled is True
|
||||
assert get_checkpointer_config() is not None
|
||||
assert get_stream_bridge_config() is not None
|
||||
|
||||
_write_config_with_sections(config_path)
|
||||
next_mtime = config_path.stat().st_mtime + 5
|
||||
os.utime(config_path, (next_mtime, next_mtime))
|
||||
|
||||
get_app_config()
|
||||
assert get_title_config().enabled is True
|
||||
assert get_summarization_config().enabled is False
|
||||
assert get_memory_config().enabled is True
|
||||
assert get_subagents_app_config().timeout_seconds == 900
|
||||
assert get_tool_search_config().enabled is False
|
||||
assert get_guardrails_config().enabled is False
|
||||
assert get_checkpointer_config() is None
|
||||
assert get_stream_bridge_config() is None
|
||||
finally:
|
||||
_reset_config_singletons()
|
||||
|
||||
|
||||
def test_get_app_config_resets_persistence_runtime_singletons_when_checkpointer_removed(tmp_path, monkeypatch):
|
||||
config_path = tmp_path / "config.yaml"
|
||||
extensions_path = tmp_path / "extensions_config.json"
|
||||
_write_extensions_config(extensions_path)
|
||||
_write_config_with_sections(config_path, {"checkpointer": {"type": "memory"}})
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
|
||||
reset_checkpointer()
|
||||
reset_store()
|
||||
reset_app_config()
|
||||
|
||||
try:
|
||||
get_app_config()
|
||||
initial_checkpointer = get_checkpointer()
|
||||
initial_store = get_store()
|
||||
|
||||
_write_config_with_sections(config_path)
|
||||
next_mtime = config_path.stat().st_mtime + 5
|
||||
os.utime(config_path, (next_mtime, next_mtime))
|
||||
|
||||
get_app_config()
|
||||
|
||||
assert get_checkpointer_config() is None
|
||||
assert get_checkpointer() is not initial_checkpointer
|
||||
assert get_store() is not initial_store
|
||||
finally:
|
||||
_reset_config_singletons()
|
||||
|
||||
|
||||
def test_get_app_config_keeps_persistence_runtime_singletons_when_checkpointer_unchanged(tmp_path, monkeypatch):
|
||||
config_path = tmp_path / "config.yaml"
|
||||
extensions_path = tmp_path / "extensions_config.json"
|
||||
_write_extensions_config(extensions_path)
|
||||
_write_config_with_sections(
|
||||
config_path,
|
||||
{
|
||||
"title": {"enabled": False},
|
||||
"checkpointer": {"type": "memory"},
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
|
||||
_reset_config_singletons()
|
||||
|
||||
try:
|
||||
get_app_config()
|
||||
initial_checkpointer = get_checkpointer()
|
||||
initial_store = get_store()
|
||||
|
||||
_write_config_with_sections(
|
||||
config_path,
|
||||
{
|
||||
"title": {"enabled": True},
|
||||
"checkpointer": {"type": "memory"},
|
||||
},
|
||||
)
|
||||
next_mtime = config_path.stat().st_mtime + 5
|
||||
os.utime(config_path, (next_mtime, next_mtime))
|
||||
|
||||
get_app_config()
|
||||
|
||||
assert get_checkpointer() is initial_checkpointer
|
||||
assert get_store() is initial_store
|
||||
finally:
|
||||
_reset_config_singletons()
|
||||
|
||||
|
||||
def test_get_app_config_does_not_mutate_singletons_when_reload_validation_fails(tmp_path, monkeypatch):
|
||||
config_path = tmp_path / "config.yaml"
|
||||
extensions_path = tmp_path / "extensions_config.json"
|
||||
_write_extensions_config(extensions_path)
|
||||
_write_config_with_sections(
|
||||
config_path,
|
||||
{
|
||||
"title": {"enabled": False},
|
||||
"tool_search": {"enabled": True},
|
||||
"checkpointer": {"type": "memory"},
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
|
||||
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
|
||||
_reset_config_singletons()
|
||||
|
||||
try:
|
||||
previous_app_config = get_app_config()
|
||||
initial_checkpointer = get_checkpointer()
|
||||
initial_store = get_store()
|
||||
|
||||
_write_config_with_sections(
|
||||
config_path,
|
||||
{
|
||||
"title": False,
|
||||
"tool_search": False,
|
||||
"checkpointer": {"type": "memory"},
|
||||
},
|
||||
)
|
||||
next_mtime = config_path.stat().st_mtime + 5
|
||||
os.utime(config_path, (next_mtime, next_mtime))
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
get_app_config()
|
||||
|
||||
assert app_config_module._app_config is previous_app_config
|
||||
assert get_title_config().enabled is False
|
||||
assert get_tool_search_config().enabled is True
|
||||
assert get_checkpointer_config() is not None
|
||||
assert get_checkpointer() is initial_checkpointer
|
||||
assert get_store() is initial_store
|
||||
finally:
|
||||
_reset_config_singletons()
|
||||
|
||||
@@ -3,11 +3,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.message_bus import MessageBus, OutboundMessage, ResolvedAttachment
|
||||
from app.channels.message_bus import InboundMessage, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
|
||||
def _run(coro):
|
||||
@@ -248,6 +249,109 @@ class TestResolveAttachments:
|
||||
assert result[0].filename == "data.csv"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inbound file ingestion tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInboundFileIngestion:
|
||||
def test_rejects_preexisting_symlink_destination(self, tmp_path):
|
||||
from app.channels import manager
|
||||
|
||||
uploads_dir = tmp_path / "uploads"
|
||||
uploads_dir.mkdir()
|
||||
outside_file = tmp_path / "outside-created.txt"
|
||||
(uploads_dir / "victim.txt").symlink_to(outside_file)
|
||||
|
||||
msg = InboundMessage(
|
||||
channel_name="test-channel",
|
||||
chat_id="chat-1",
|
||||
user_id="user-1",
|
||||
text="see attachment",
|
||||
files=[{"filename": "victim.txt", "url": "https://example.invalid/victim.txt"}],
|
||||
)
|
||||
|
||||
async def fake_reader(file_info, client):
|
||||
return b"attacker data"
|
||||
|
||||
with (
|
||||
patch("deerflow.uploads.manager.ensure_uploads_dir", return_value=uploads_dir),
|
||||
patch.dict(manager.INBOUND_FILE_READERS, {"test-channel": fake_reader}, clear=False),
|
||||
):
|
||||
result = _run(manager._ingest_inbound_files("thread-1", msg))
|
||||
|
||||
assert result == []
|
||||
assert not outside_file.exists()
|
||||
assert (uploads_dir / "victim.txt").is_symlink()
|
||||
|
||||
def test_rejects_dangling_symlink_destination(self, tmp_path):
|
||||
from app.channels import manager
|
||||
|
||||
uploads_dir = tmp_path / "uploads"
|
||||
uploads_dir.mkdir()
|
||||
missing_target = tmp_path / "missing-created.txt"
|
||||
(uploads_dir / "victim.txt").symlink_to(missing_target)
|
||||
|
||||
msg = InboundMessage(
|
||||
channel_name="test-channel",
|
||||
chat_id="chat-1",
|
||||
user_id="user-1",
|
||||
text="see attachment",
|
||||
files=[{"filename": "victim.txt", "url": "https://example.invalid/victim.txt"}],
|
||||
)
|
||||
|
||||
async def fake_reader(file_info, client):
|
||||
return b"attacker data"
|
||||
|
||||
with (
|
||||
patch("deerflow.uploads.manager.ensure_uploads_dir", return_value=uploads_dir),
|
||||
patch.dict(manager.INBOUND_FILE_READERS, {"test-channel": fake_reader}, clear=False),
|
||||
):
|
||||
result = _run(manager._ingest_inbound_files("thread-1", msg))
|
||||
|
||||
assert result == []
|
||||
assert not missing_target.exists()
|
||||
assert (uploads_dir / "victim.txt").is_symlink()
|
||||
|
||||
def test_hardlinked_existing_file_is_not_overwritten(self, tmp_path):
|
||||
from app.channels import manager
|
||||
|
||||
uploads_dir = tmp_path / "uploads"
|
||||
uploads_dir.mkdir()
|
||||
outside_file = tmp_path / "outside-created.txt"
|
||||
outside_file.write_text("protected", encoding="utf-8")
|
||||
os.link(outside_file, uploads_dir / "victim.txt")
|
||||
|
||||
msg = InboundMessage(
|
||||
channel_name="test-channel",
|
||||
chat_id="chat-1",
|
||||
user_id="user-1",
|
||||
text="see attachment",
|
||||
files=[{"filename": "victim.txt", "url": "https://example.invalid/victim.txt"}],
|
||||
)
|
||||
|
||||
async def fake_reader(file_info, client):
|
||||
return b"new attachment data"
|
||||
|
||||
with (
|
||||
patch("deerflow.uploads.manager.ensure_uploads_dir", return_value=uploads_dir),
|
||||
patch.dict(manager.INBOUND_FILE_READERS, {"test-channel": fake_reader}, clear=False),
|
||||
):
|
||||
result = _run(manager._ingest_inbound_files("thread-1", msg))
|
||||
|
||||
assert result == [
|
||||
{
|
||||
"filename": "victim_1.txt",
|
||||
"size": len(b"new attachment data"),
|
||||
"path": "/mnt/user-data/uploads/victim_1.txt",
|
||||
"is_image": False,
|
||||
}
|
||||
]
|
||||
assert outside_file.read_text(encoding="utf-8") == "protected"
|
||||
assert (uploads_dir / "victim.txt").read_text(encoding="utf-8") == "protected"
|
||||
assert (uploads_dir / "victim_1.txt").read_bytes() == b"new attachment data"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Channel base class _on_outbound with attachments
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -372,6 +372,37 @@ class TestExtractResponseText:
|
||||
# Should return "" (no text in current turn), NOT "Hi there!" from previous turn
|
||||
assert _extract_response_text(result) == ""
|
||||
|
||||
def test_does_not_publish_loop_warning_on_tool_calling_ai_message(self):
|
||||
"""Loop-detection warning text on a tool-calling AI message is middleware-authored."""
|
||||
from app.channels.manager import _extract_response_text
|
||||
|
||||
result = {
|
||||
"messages": [
|
||||
{"type": "human", "content": "search the repo"},
|
||||
{
|
||||
"type": "ai",
|
||||
"content": "[LOOP DETECTED] You are repeating the same tool calls.",
|
||||
"tool_calls": [{"name": "grep", "args": {"pattern": "TODO"}, "id": "call_1"}],
|
||||
},
|
||||
]
|
||||
}
|
||||
assert _extract_response_text(result) == ""
|
||||
|
||||
def test_preserves_visible_text_when_stripping_loop_warning(self):
|
||||
from app.channels.manager import _extract_response_text
|
||||
|
||||
result = {
|
||||
"messages": [
|
||||
{"type": "human", "content": "prepare the report"},
|
||||
{
|
||||
"type": "ai",
|
||||
"content": "Here is the report.\n\n[LOOP DETECTED] You are repeating the same tool calls.",
|
||||
"tool_calls": [{"name": "present_files", "args": {"filepaths": ["/mnt/user-data/outputs/report.md"]}, "id": "call_1"}],
|
||||
},
|
||||
]
|
||||
}
|
||||
assert _extract_response_text(result) == "Here is the report."
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChannelManager tests
|
||||
@@ -435,6 +466,47 @@ class TestChannelManager:
|
||||
assert headers["Cookie"] == f"csrf_token={csrf_token}"
|
||||
assert headers["X-DeerFlow-Internal-Token"]
|
||||
|
||||
def test_fetch_gateway_includes_internal_auth_headers(self, monkeypatch):
|
||||
from app.channels.manager import ChannelManager
|
||||
|
||||
class MockResponse:
|
||||
def raise_for_status(self):
|
||||
return None
|
||||
|
||||
def json(self):
|
||||
return {"models": [{"name": "default"}]}
|
||||
|
||||
class MockAsyncClient:
|
||||
def __init__(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def get(self, url, **kwargs):
|
||||
calls.append({"url": url, **kwargs})
|
||||
return MockResponse()
|
||||
|
||||
calls = []
|
||||
monkeypatch.setattr("app.channels.manager.httpx.AsyncClient", MockAsyncClient)
|
||||
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||
manager = ChannelManager(bus=bus, store=store, gateway_url="http://gateway:8001")
|
||||
|
||||
reply = await manager._fetch_gateway("/api/models", "models")
|
||||
|
||||
assert reply == "Available models:\n• default"
|
||||
assert calls[0]["url"] == "http://gateway:8001/api/models"
|
||||
assert calls[0]["timeout"] == 10
|
||||
assert calls[0]["headers"]["X-DeerFlow-Internal-Token"]
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_handle_chat_calls_channel_receive_file_for_inbound_files(self, monkeypatch):
|
||||
from app.channels.manager import ChannelManager
|
||||
|
||||
@@ -530,6 +602,8 @@ class TestChannelManager:
|
||||
assert call_args[0][0] == "test-thread-123" # thread_id
|
||||
assert call_args[0][1] == "lead_agent" # assistant_id
|
||||
assert call_args[1]["input"]["messages"][0]["content"] == "hi"
|
||||
assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == ""
|
||||
assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123"
|
||||
|
||||
assert len(outbound_received) == 1
|
||||
assert outbound_received[0].text == "Hello from agent!"
|
||||
@@ -661,12 +735,135 @@ class TestChannelManager:
|
||||
call_args = mock_client.runs.wait.call_args
|
||||
assert call_args[0][1] == "lead_agent"
|
||||
assert call_args[1]["config"]["recursion_limit"] == 55
|
||||
assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == ""
|
||||
assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123"
|
||||
assert call_args[1]["context"]["thinking_enabled"] is False
|
||||
assert call_args[1]["context"]["subagent_enabled"] is True
|
||||
assert call_args[1]["context"]["agent_name"] == "mobile-agent"
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_clarification_follow_up_preserves_history(self):
|
||||
"""Conversation should continue after ask_clarification instead of resetting history."""
|
||||
from app.channels.manager import ChannelManager
|
||||
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||
manager = ChannelManager(bus=bus, store=store)
|
||||
|
||||
outbound_received = []
|
||||
|
||||
async def capture_outbound(msg):
|
||||
outbound_received.append(msg)
|
||||
|
||||
bus.subscribe_outbound(capture_outbound)
|
||||
|
||||
history_by_checkpoint: dict[tuple[str, str], list[str]] = {}
|
||||
|
||||
async def _runs_wait(thread_id, assistant_id, *, input, config, context):
|
||||
del assistant_id, context # unused in this test, kept for signature parity
|
||||
|
||||
checkpoint_ns = config.get("configurable", {}).get("checkpoint_ns")
|
||||
key = (thread_id, str(checkpoint_ns))
|
||||
history = history_by_checkpoint.setdefault(key, [])
|
||||
|
||||
human_text = input["messages"][0]["content"]
|
||||
history.append(human_text)
|
||||
|
||||
if len(history) == 1:
|
||||
return {
|
||||
"messages": [
|
||||
{"type": "human", "content": history[0]},
|
||||
{
|
||||
"type": "ai",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"name": "ask_clarification",
|
||||
"args": {"question": "Which environment should I use?"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"type": "tool",
|
||||
"name": "ask_clarification",
|
||||
"content": "Which environment should I use?",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
if len(history) == 2 and history[0] == "Deploy my app" and history[1] == "prod":
|
||||
return {
|
||||
"messages": [
|
||||
{"type": "human", "content": history[0]},
|
||||
{
|
||||
"type": "ai",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"name": "ask_clarification",
|
||||
"args": {"question": "Which environment should I use?"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"type": "tool",
|
||||
"name": "ask_clarification",
|
||||
"content": "Which environment should I use?",
|
||||
},
|
||||
{"type": "human", "content": history[1]},
|
||||
{"type": "ai", "content": "Got it. I will deploy to prod."},
|
||||
]
|
||||
}
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
{"type": "human", "content": history[-1]},
|
||||
{"type": "ai", "content": "History missing; clarification repeated."},
|
||||
]
|
||||
}
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.threads.create = AsyncMock(return_value={"thread_id": "clarify-thread-1"})
|
||||
mock_client.threads.get = AsyncMock(return_value={"thread_id": "clarify-thread-1"})
|
||||
mock_client.runs.wait = AsyncMock(side_effect=_runs_wait)
|
||||
manager._client = mock_client
|
||||
|
||||
await manager.start()
|
||||
|
||||
await bus.publish_inbound(
|
||||
InboundMessage(
|
||||
channel_name="test",
|
||||
chat_id="chat1",
|
||||
user_id="user1",
|
||||
text="Deploy my app",
|
||||
)
|
||||
)
|
||||
await _wait_for(lambda: len(outbound_received) >= 1)
|
||||
|
||||
await bus.publish_inbound(
|
||||
InboundMessage(
|
||||
channel_name="test",
|
||||
chat_id="chat1",
|
||||
user_id="user1",
|
||||
text="prod",
|
||||
)
|
||||
)
|
||||
await _wait_for(lambda: len(outbound_received) >= 2)
|
||||
await manager.stop()
|
||||
|
||||
assert outbound_received[0].text == "Which environment should I use?"
|
||||
assert outbound_received[1].text == "Got it. I will deploy to prod."
|
||||
|
||||
assert mock_client.runs.wait.call_count == 2
|
||||
first_call = mock_client.runs.wait.call_args_list[0]
|
||||
second_call = mock_client.runs.wait.call_args_list[1]
|
||||
assert first_call.kwargs["config"]["configurable"]["checkpoint_ns"] == ""
|
||||
assert second_call.kwargs["config"]["configurable"]["checkpoint_ns"] == ""
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_handle_chat_uses_user_session_overrides(self):
|
||||
from app.channels.manager import ChannelManager
|
||||
|
||||
@@ -1343,6 +1540,8 @@ class TestChannelManager:
|
||||
call_args = mock_client.runs.stream.call_args
|
||||
|
||||
assert call_args[1]["input"]["messages"][0]["content"] == "hello"
|
||||
assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == ""
|
||||
assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123"
|
||||
assert call_args[1]["context"]["is_bootstrap"] is True
|
||||
|
||||
# Final message should be published
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Unit tests for checkpointer config and singleton factory."""
|
||||
"""Unit tests for checkpointer config, packaging metadata, and factories."""
|
||||
|
||||
import sys
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -13,6 +15,8 @@ from deerflow.config.checkpointer_config import (
|
||||
set_checkpointer_config,
|
||||
)
|
||||
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
|
||||
from deerflow.runtime.checkpointer.provider import POSTGRES_INSTALL
|
||||
from deerflow.runtime.store.provider import POSTGRES_STORE_INSTALL
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -67,6 +71,42 @@ class TestCheckpointerConfig:
|
||||
with pytest.raises(Exception):
|
||||
load_checkpointer_config_from_dict({"type": "unknown"})
|
||||
|
||||
def test_connection_string_description_matches_runtime_defaults(self):
|
||||
description = CheckpointerConfig.model_fields["connection_string"].description
|
||||
|
||||
assert description is not None
|
||||
assert "Optional for sqlite" in description
|
||||
assert "defaults to 'store.db'" in description
|
||||
assert "Required for postgres" in description
|
||||
|
||||
|
||||
class TestHarnessPackaging:
|
||||
def test_pyproject_declares_postgres_extra(self):
|
||||
pyproject_path = Path(__file__).resolve().parents[1] / "packages" / "harness" / "pyproject.toml"
|
||||
data = tomllib.loads(pyproject_path.read_text())
|
||||
|
||||
optional_dependencies = data["project"]["optional-dependencies"]
|
||||
assert "postgres" in optional_dependencies
|
||||
assert optional_dependencies["postgres"] == [
|
||||
"asyncpg>=0.29",
|
||||
"langgraph-checkpoint-postgres>=3.0.5",
|
||||
"psycopg[binary]>=3.3.3",
|
||||
"psycopg-pool>=3.3.0",
|
||||
]
|
||||
|
||||
def test_workspace_pyproject_forwards_postgres_extra_to_harness(self):
|
||||
pyproject_path = Path(__file__).resolve().parents[1] / "pyproject.toml"
|
||||
data = tomllib.loads(pyproject_path.read_text())
|
||||
|
||||
optional_dependencies = data["project"]["optional-dependencies"]
|
||||
assert optional_dependencies["postgres"] == ["deerflow-harness[postgres]"]
|
||||
|
||||
def test_postgres_missing_dependency_messages_recommend_package_extra(self):
|
||||
assert "deerflow-harness[postgres]" in POSTGRES_INSTALL
|
||||
assert "deerflow-harness[postgres]" in POSTGRES_STORE_INSTALL
|
||||
assert "uv sync --all-packages --extra postgres" in POSTGRES_INSTALL
|
||||
assert "uv sync --all-packages --extra postgres" in POSTGRES_STORE_INSTALL
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Factory tests
|
||||
|
||||
@@ -437,6 +437,85 @@ class TestStream:
|
||||
call_kwargs = agent.stream.call_args.kwargs
|
||||
assert "messages" in call_kwargs["stream_mode"]
|
||||
|
||||
def test_stream_emits_additional_kwargs_updates_for_streamed_ai_messages(self, client):
|
||||
"""stream() emits a follow-up AI event when attribution metadata arrives via values."""
|
||||
assembled = AIMessage(
|
||||
content="Hello!",
|
||||
id="ai-1",
|
||||
additional_kwargs={
|
||||
"token_usage_attribution": {
|
||||
"version": 1,
|
||||
"kind": "final_answer",
|
||||
"shared_attribution": False,
|
||||
"actions": [],
|
||||
}
|
||||
},
|
||||
)
|
||||
agent = MagicMock()
|
||||
agent.stream.return_value = iter(
|
||||
[
|
||||
("messages", (AIMessageChunk(content="Hello!", id="ai-1"), {})),
|
||||
("values", {"messages": [HumanMessage(content="hi", id="h-1"), assembled]}),
|
||||
]
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(client, "_ensure_agent"),
|
||||
patch.object(client, "_agent", agent),
|
||||
):
|
||||
events = list(client.stream("hi", thread_id="t-stream-kwargs"))
|
||||
|
||||
ai_events = [event for event in events if event.type == "messages-tuple" and event.data.get("type") == "ai" and event.data.get("id") == "ai-1"]
|
||||
assert any(event.data.get("content") == "Hello!" for event in ai_events)
|
||||
assert any(event.data.get("additional_kwargs", {}).get("token_usage_attribution", {}).get("kind") == "final_answer" for event in ai_events)
|
||||
|
||||
def test_stream_emits_new_additional_kwargs_after_prior_metadata(self, client):
|
||||
"""stream() emits later attribution metadata even after earlier kwargs for the same id."""
|
||||
attribution = {
|
||||
"version": 1,
|
||||
"kind": "final_answer",
|
||||
"shared_attribution": False,
|
||||
"actions": [],
|
||||
}
|
||||
assembled = AIMessage(
|
||||
content="Hello!",
|
||||
id="ai-1",
|
||||
additional_kwargs={
|
||||
"reasoning_content": "Thinking first.",
|
||||
"token_usage_attribution": attribution,
|
||||
},
|
||||
)
|
||||
agent = MagicMock()
|
||||
agent.stream.return_value = iter(
|
||||
[
|
||||
(
|
||||
"messages",
|
||||
(
|
||||
AIMessageChunk(
|
||||
content="Hello!",
|
||||
id="ai-1",
|
||||
additional_kwargs={"reasoning_content": "Thinking first."},
|
||||
),
|
||||
{},
|
||||
),
|
||||
),
|
||||
("values", {"messages": [HumanMessage(content="hi", id="h-1"), assembled]}),
|
||||
]
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(client, "_ensure_agent"),
|
||||
patch.object(client, "_agent", agent),
|
||||
):
|
||||
events = list(client.stream("hi", thread_id="t-stream-kwargs-delta"))
|
||||
|
||||
ai_events = [event for event in events if event.type == "messages-tuple" and event.data.get("type") == "ai" and event.data.get("id") == "ai-1"]
|
||||
metadata_events = [event for event in ai_events if event.data.get("additional_kwargs")]
|
||||
|
||||
assert metadata_events[0].data["additional_kwargs"] == {"reasoning_content": "Thinking first."}
|
||||
assert metadata_events[1].data["content"] == ""
|
||||
assert metadata_events[1].data["additional_kwargs"] == {"token_usage_attribution": attribution}
|
||||
|
||||
def test_chat_accumulates_streamed_deltas(self, client):
|
||||
"""chat() concatenates per-id deltas from messages mode."""
|
||||
agent = MagicMock()
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
"""Tests for DeerFlowClient message serialization helpers."""
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from deerflow.client import DeerFlowClient
|
||||
|
||||
|
||||
def test_serialize_ai_message_preserves_additional_kwargs():
|
||||
message = AIMessage(
|
||||
content="done",
|
||||
additional_kwargs={
|
||||
"token_usage_attribution": {
|
||||
"version": 1,
|
||||
"kind": "final_answer",
|
||||
"shared_attribution": False,
|
||||
"actions": [],
|
||||
}
|
||||
},
|
||||
usage_metadata={"input_tokens": 12, "output_tokens": 3, "total_tokens": 15},
|
||||
)
|
||||
|
||||
serialized = DeerFlowClient._serialize_message(message)
|
||||
|
||||
assert serialized["type"] == "ai"
|
||||
assert serialized["usage_metadata"] == {
|
||||
"input_tokens": 12,
|
||||
"output_tokens": 3,
|
||||
"total_tokens": 15,
|
||||
}
|
||||
assert serialized["additional_kwargs"] == {
|
||||
"token_usage_attribution": {
|
||||
"version": 1,
|
||||
"kind": "final_answer",
|
||||
"shared_attribution": False,
|
||||
"actions": [],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_serialize_human_message_preserves_additional_kwargs():
|
||||
message = HumanMessage(
|
||||
content="hello",
|
||||
additional_kwargs={"files": [{"name": "diagram.png"}]},
|
||||
)
|
||||
|
||||
serialized = DeerFlowClient._serialize_message(message)
|
||||
|
||||
assert serialized == {
|
||||
"type": "human",
|
||||
"content": "hello",
|
||||
"id": None,
|
||||
"additional_kwargs": {"files": [{"name": "diagram.png"}]},
|
||||
}
|
||||
@@ -82,6 +82,36 @@ def test_parse_response_text_content():
|
||||
assert result.generations[0].message.content == "Hello world"
|
||||
|
||||
|
||||
def test_parse_response_populates_usage_metadata():
|
||||
model = _make_model()
|
||||
response = {
|
||||
"output": [
|
||||
{
|
||||
"type": "message",
|
||||
"content": [{"type": "output_text", "text": "Hello world"}],
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 5,
|
||||
"total_tokens": 15,
|
||||
"input_tokens_details": {"cached_tokens": 3},
|
||||
"output_tokens_details": {"reasoning_tokens": 2},
|
||||
},
|
||||
"model": "gpt-5.4",
|
||||
}
|
||||
|
||||
result = model._parse_response(response)
|
||||
|
||||
meta = result.generations[0].message.usage_metadata
|
||||
assert meta is not None
|
||||
assert meta["input_tokens"] == 10
|
||||
assert meta["output_tokens"] == 5
|
||||
assert meta["total_tokens"] == 15
|
||||
assert meta["input_token_details"]["cache_read"] == 3
|
||||
assert meta["output_token_details"]["reasoning"] == 2
|
||||
|
||||
|
||||
def test_parse_response_reasoning_content():
|
||||
model = _make_model()
|
||||
response = {
|
||||
|
||||
@@ -192,6 +192,7 @@ def test_agent_features_defaults():
|
||||
assert f.vision is False
|
||||
assert f.auto_title is False
|
||||
assert f.guardrail is False
|
||||
assert f.loop_detection is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -630,6 +631,51 @@ def test_loop_detection_before_clarification(mock_create_agent):
|
||||
assert loop_idx == clar_idx - 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 30b. loop_detection=False skips LoopDetectionMiddleware
|
||||
# ---------------------------------------------------------------------------
|
||||
@patch("deerflow.agents.factory.create_agent")
|
||||
def test_loop_detection_disabled(mock_create_agent):
|
||||
mock_create_agent.return_value = MagicMock()
|
||||
create_deerflow_agent(
|
||||
_make_mock_model(),
|
||||
features=RuntimeFeatures(sandbox=False, loop_detection=False),
|
||||
)
|
||||
|
||||
call_kwargs = mock_create_agent.call_args[1]
|
||||
mw_types = [type(m).__name__ for m in call_kwargs["middleware"]]
|
||||
assert "LoopDetectionMiddleware" not in mw_types
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 30c. loop_detection=<custom AgentMiddleware> replaces the default
|
||||
# ---------------------------------------------------------------------------
|
||||
@patch("deerflow.agents.factory.create_agent")
|
||||
def test_loop_detection_custom_middleware(mock_create_agent):
|
||||
from langchain.agents.middleware import AgentMiddleware as AM
|
||||
|
||||
mock_create_agent.return_value = MagicMock()
|
||||
|
||||
class MyLoopDetection(AM):
|
||||
pass
|
||||
|
||||
custom = MyLoopDetection()
|
||||
create_deerflow_agent(
|
||||
_make_mock_model(),
|
||||
features=RuntimeFeatures(sandbox=False, loop_detection=custom),
|
||||
)
|
||||
|
||||
call_kwargs = mock_create_agent.call_args[1]
|
||||
middleware = call_kwargs["middleware"]
|
||||
assert custom in middleware
|
||||
mw_types = [type(m).__name__ for m in middleware]
|
||||
# Default LoopDetectionMiddleware must not also appear.
|
||||
assert "LoopDetectionMiddleware" not in mw_types
|
||||
# Custom replacement still sits immediately before ClarificationMiddleware.
|
||||
assert mw_types[-1] == "ClarificationMiddleware"
|
||||
assert mw_types[-2] == "MyLoopDetection"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 31. plan_mode=True adds TodoMiddleware
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -85,6 +85,8 @@ def test_load_claude_code_credential_from_override_path(tmp_path, monkeypatch):
|
||||
|
||||
def test_load_claude_code_credential_ignores_directory_path(tmp_path, monkeypatch):
|
||||
_clear_claude_code_env(monkeypatch)
|
||||
# Redirect HOME so the default ~/.claude/.credentials.json doesn't exist
|
||||
monkeypatch.setenv("HOME", str(tmp_path))
|
||||
cred_dir = tmp_path / "claude-creds-dir"
|
||||
cred_dir.mkdir()
|
||||
monkeypatch.setenv("CLAUDE_CODE_CREDENTIALS_PATH", str(cred_dir))
|
||||
|
||||
@@ -0,0 +1,235 @@
|
||||
"""Tests for CSRF middleware."""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from app.gateway.csrf_middleware import CSRFMiddleware
|
||||
|
||||
|
||||
def _make_app() -> FastAPI:
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
@app.post("/api/v1/auth/login/local")
|
||||
async def login_local():
|
||||
return {"ok": True}
|
||||
|
||||
@app.post("/api/v1/auth/register")
|
||||
async def register():
|
||||
return {"ok": True}
|
||||
|
||||
@app.post("/api/threads/abc/runs/stream")
|
||||
async def protected_mutation():
|
||||
return {"ok": True}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def test_auth_post_rejects_cross_origin_browser_request():
|
||||
"""CSRF-exempt auth routes must not accept hostile browser origins.
|
||||
|
||||
Login/register endpoints intentionally skip the double-submit token because
|
||||
first-time callers do not have a token yet. They still set an auth session,
|
||||
so a hostile cross-site form POST must be rejected to avoid login CSRF /
|
||||
session fixation.
|
||||
"""
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
headers={"Origin": "https://evil.example"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "Cross-site auth request denied."
|
||||
|
||||
|
||||
def test_auth_post_allows_same_origin_browser_request():
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
headers={"Origin": "https://deerflow.example"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.cookies.get("csrf_token")
|
||||
|
||||
|
||||
def test_auth_post_rejects_malformed_origin_with_path():
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
headers={"Origin": "https://deerflow.example/path"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "Cross-site auth request denied."
|
||||
assert response.cookies.get("csrf_token") is None
|
||||
|
||||
|
||||
def test_auth_post_rejects_malformed_origin_with_invalid_port():
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
headers={"Origin": "https://deerflow.example:bad"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "Cross-site auth request denied."
|
||||
assert response.cookies.get("csrf_token") is None
|
||||
|
||||
|
||||
def test_auth_post_allows_same_origin_default_port_equivalence():
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
headers={"Origin": "https://deerflow.example:443"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.cookies.get("csrf_token")
|
||||
|
||||
|
||||
def test_auth_post_allows_forwarded_same_origin():
|
||||
client = TestClient(_make_app(), base_url="http://internal:8000")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
headers={
|
||||
"Origin": "https://deerflow.example",
|
||||
"X-Forwarded-Proto": "https",
|
||||
"X-Forwarded-Host": "deerflow.example, internal:8000",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.cookies.get("csrf_token")
|
||||
|
||||
|
||||
def test_auth_post_allows_forwarded_same_origin_with_non_default_port():
|
||||
client = TestClient(_make_app(), base_url="http://internal:8000")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
headers={
|
||||
"Origin": "http://localhost:2026",
|
||||
"X-Forwarded-Proto": "http",
|
||||
"X-Forwarded-Host": "localhost:2026",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.cookies.get("csrf_token")
|
||||
|
||||
|
||||
def test_auth_post_allows_rfc_forwarded_same_origin():
|
||||
client = TestClient(_make_app(), base_url="http://internal:8000")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
headers={
|
||||
"Origin": "https://deerflow.example",
|
||||
"Forwarded": "proto=https;host=deerflow.example",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.cookies.get("csrf_token")
|
||||
assert "secure" in response.headers["set-cookie"].lower()
|
||||
|
||||
|
||||
def test_auth_post_allows_explicit_configured_origin(monkeypatch):
|
||||
monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "https://app.example")
|
||||
client = TestClient(_make_app(), base_url="https://api.example")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/register",
|
||||
headers={"Origin": "https://app.example"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.cookies.get("csrf_token")
|
||||
|
||||
|
||||
def test_auth_post_does_not_treat_wildcard_cors_as_allowed_origin(monkeypatch):
|
||||
monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "*")
|
||||
client = TestClient(_make_app(), base_url="https://api.example")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
headers={"Origin": "https://evil.example"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "Cross-site auth request denied."
|
||||
|
||||
|
||||
def test_auth_post_sets_strict_samesite_csrf_cookie():
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
headers={"Origin": "https://deerflow.example"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
set_cookie = response.headers["set-cookie"].lower()
|
||||
assert "csrf_token=" in set_cookie
|
||||
assert "samesite=strict" in set_cookie
|
||||
assert "secure" in set_cookie
|
||||
|
||||
|
||||
def test_auth_post_without_origin_still_allows_non_browser_clients():
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
|
||||
response = client.post("/api/v1/auth/login/local")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.cookies.get("csrf_token")
|
||||
|
||||
|
||||
def test_non_auth_mutation_still_requires_double_submit_token():
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
|
||||
response = client.post(
|
||||
"/api/threads/abc/runs/stream",
|
||||
headers={"Origin": "https://deerflow.example"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "CSRF token missing. Include X-CSRF-Token header."
|
||||
|
||||
|
||||
def test_non_auth_mutation_allows_valid_double_submit_token():
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
client.cookies.set("csrf_token", "known-token")
|
||||
|
||||
response = client.post(
|
||||
"/api/threads/abc/runs/stream",
|
||||
headers={
|
||||
"Origin": "https://deerflow.example",
|
||||
"X-CSRF-Token": "known-token",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_non_auth_mutation_rejects_mismatched_double_submit_token():
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
client.cookies.set("csrf_token", "cookie-token")
|
||||
|
||||
response = client.post(
|
||||
"/api/threads/abc/runs/stream",
|
||||
headers={
|
||||
"Origin": "https://deerflow.example",
|
||||
"X-CSRF-Token": "header-token",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "CSRF token mismatch."
|
||||
@@ -537,7 +537,10 @@ class TestAgentsAPI:
|
||||
def test_create_persists_files_on_disk(self, agent_client, tmp_path):
|
||||
agent_client.post("/api/agents", json={"name": "disk-check", "soul": "disk soul"})
|
||||
|
||||
agent_dir = tmp_path / "agents" / "disk-check"
|
||||
# tests/conftest.py installs an autouse fixture that sets the
|
||||
# contextvar to "test-user-autouse", so the agent is persisted under
|
||||
# users/test-user-autouse/agents/ rather than the legacy shared dir.
|
||||
agent_dir = tmp_path / "users" / "test-user-autouse" / "agents" / "disk-check"
|
||||
assert agent_dir.exists()
|
||||
assert (agent_dir / "config.yaml").exists()
|
||||
assert (agent_dir / "SOUL.md").exists()
|
||||
@@ -545,12 +548,23 @@ class TestAgentsAPI:
|
||||
|
||||
def test_delete_removes_files_from_disk(self, agent_client, tmp_path):
|
||||
agent_client.post("/api/agents", json={"name": "remove-me", "soul": "bye"})
|
||||
agent_dir = tmp_path / "agents" / "remove-me"
|
||||
agent_dir = tmp_path / "users" / "test-user-autouse" / "agents" / "remove-me"
|
||||
assert agent_dir.exists()
|
||||
|
||||
agent_client.delete("/api/agents/remove-me")
|
||||
assert not agent_dir.exists()
|
||||
|
||||
def test_create_rejects_legacy_name_collision(self, agent_client, tmp_path):
|
||||
"""An unmigrated legacy agent must still block name collision so that
|
||||
running the migration script later won't shadow the legacy entry."""
|
||||
legacy_dir = tmp_path / "agents" / "legacy-agent"
|
||||
legacy_dir.mkdir(parents=True)
|
||||
(legacy_dir / "config.yaml").write_text("name: legacy-agent\n", encoding="utf-8")
|
||||
(legacy_dir / "SOUL.md").write_text("legacy soul", encoding="utf-8")
|
||||
|
||||
response = agent_client.post("/api/agents", json={"name": "legacy-agent", "soul": "x"})
|
||||
assert response.status_code == 409
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 9. Gateway API – User Profile endpoints
|
||||
|
||||
@@ -0,0 +1,201 @@
|
||||
"""Unit tests for scripts/detect_uv_extras.py.
|
||||
|
||||
The detector resolves uv extras for `make dev` so that postgres (and any
|
||||
future opt-in extras) are not wiped on every restart — see Issue #2754.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
DETECT_SCRIPT_PATH = REPO_ROOT / "scripts" / "detect_uv_extras.py"
|
||||
|
||||
|
||||
spec = importlib.util.spec_from_file_location("deerflow_detect_uv_extras", DETECT_SCRIPT_PATH)
|
||||
assert spec is not None and spec.loader is not None
|
||||
detect = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(detect)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def isolated_cwd(tmp_path, monkeypatch):
|
||||
"""Isolate `find_config_file()` from the real repo by chdir + clearing env."""
|
||||
monkeypatch.chdir(tmp_path)
|
||||
monkeypatch.delenv("UV_EXTRAS", raising=False)
|
||||
monkeypatch.delenv("DEER_FLOW_CONFIG_PATH", raising=False)
|
||||
return tmp_path
|
||||
|
||||
|
||||
def test_parse_env_extras_supports_comma_and_whitespace():
|
||||
assert detect.parse_env_extras("postgres") == ["postgres"]
|
||||
assert detect.parse_env_extras("postgres,ollama") == ["postgres", "ollama"]
|
||||
assert detect.parse_env_extras("postgres ollama") == ["postgres", "ollama"]
|
||||
assert detect.parse_env_extras(" postgres , ollama ,") == ["postgres", "ollama"]
|
||||
assert detect.parse_env_extras("") == []
|
||||
|
||||
|
||||
def test_parse_env_extras_drops_shell_metacharacters(capsys):
|
||||
"""A `.env` value containing shell injection bait must not pass through.
|
||||
|
||||
The whitelist guarantees the *bytes* that reach `uv sync` cannot include
|
||||
shell metacharacters. Any name that looks identifier-like still survives
|
||||
(uv itself will reject unknown extras with its own error), but `;`, `&`,
|
||||
backticks, parentheses, slashes, etc. are stripped.
|
||||
"""
|
||||
# Pure-metacharacter inputs collapse to empty.
|
||||
assert detect.parse_env_extras(";") == []
|
||||
assert detect.parse_env_extras("$(whoami)") == []
|
||||
assert detect.parse_env_extras("`echo bad`") == []
|
||||
assert detect.parse_env_extras("postgres;evil") == [] # single token, contains `;`
|
||||
# Splitting on whitespace yields ['rm'] which is identifier-shaped, but the
|
||||
# destructive bits (`;`, `-rf`, `/`) are dropped.
|
||||
assert detect.parse_env_extras("; rm -rf /") == ["rm"]
|
||||
err = capsys.readouterr().err
|
||||
assert "ignoring invalid UV_EXTRAS entry" in err
|
||||
assert "';'" in err # confirms the dangerous token was reported and dropped
|
||||
|
||||
|
||||
def test_parse_env_extras_rejects_leading_digits_and_punctuation():
|
||||
"""Names must start with a letter — pyproject extras follow this shape."""
|
||||
assert detect.parse_env_extras("1postgres") == []
|
||||
assert detect.parse_env_extras("-postgres") == []
|
||||
# Hyphens and underscores inside the name are fine.
|
||||
assert detect.parse_env_extras("post_gres") == ["post_gres"]
|
||||
assert detect.parse_env_extras("post-gres") == ["post-gres"]
|
||||
|
||||
|
||||
def test_format_flags_emits_one_flag_per_extra():
|
||||
assert detect.format_flags([]) == ""
|
||||
assert detect.format_flags(["postgres"]) == "--extra postgres"
|
||||
assert detect.format_flags(["postgres", "ollama"]) == "--extra postgres --extra ollama"
|
||||
|
||||
|
||||
def test_strip_comment_preserves_quoted_hash():
|
||||
assert detect._strip_comment("backend: postgres # trailing") == "backend: postgres"
|
||||
assert detect._strip_comment('name: "value#with-hash"') == 'name: "value#with-hash"'
|
||||
assert detect._strip_comment("# whole line comment") == ""
|
||||
|
||||
|
||||
def test_section_value_finds_nested_key():
|
||||
yaml_lines = [
|
||||
"database:",
|
||||
" backend: postgres",
|
||||
" postgres_url: $DATABASE_URL",
|
||||
"",
|
||||
"checkpointer:",
|
||||
" type: sqlite",
|
||||
]
|
||||
assert detect.section_value(yaml_lines, "database", "backend") == "postgres"
|
||||
assert detect.section_value(yaml_lines, "checkpointer", "type") == "sqlite"
|
||||
assert detect.section_value(yaml_lines, "database", "missing") is None
|
||||
assert detect.section_value(yaml_lines, "absent_section", "anything") is None
|
||||
|
||||
|
||||
def test_section_value_ignores_commented_lines():
|
||||
yaml_lines = [
|
||||
"# database:",
|
||||
"# backend: postgres",
|
||||
"database:",
|
||||
" backend: sqlite",
|
||||
]
|
||||
assert detect.section_value(yaml_lines, "database", "backend") == "sqlite"
|
||||
|
||||
|
||||
def test_section_value_strips_quotes():
|
||||
yaml_lines = [
|
||||
"database:",
|
||||
' backend: "postgres"',
|
||||
]
|
||||
assert detect.section_value(yaml_lines, "database", "backend") == "postgres"
|
||||
|
||||
|
||||
def test_section_value_does_not_descend_into_grandchildren():
|
||||
yaml_lines = [
|
||||
"database:",
|
||||
" backend: sqlite",
|
||||
" nested:",
|
||||
" backend: postgres",
|
||||
]
|
||||
# Only the immediate child level counts — keeps the parser predictable.
|
||||
assert detect.section_value(yaml_lines, "database", "backend") == "sqlite"
|
||||
|
||||
|
||||
def test_detect_from_config_postgres_via_database(tmp_path):
|
||||
cfg = tmp_path / "config.yaml"
|
||||
cfg.write_text("database:\n backend: postgres\n postgres_url: $DATABASE_URL\n")
|
||||
assert detect.detect_from_config(cfg) == ["postgres"]
|
||||
|
||||
|
||||
def test_detect_from_config_postgres_via_checkpointer(tmp_path):
|
||||
cfg = tmp_path / "config.yaml"
|
||||
cfg.write_text("checkpointer:\n type: postgres\n connection_string: postgresql://localhost/db\n")
|
||||
assert detect.detect_from_config(cfg) == ["postgres"]
|
||||
|
||||
|
||||
def test_detect_from_config_sqlite_returns_no_extras(tmp_path):
|
||||
cfg = tmp_path / "config.yaml"
|
||||
cfg.write_text("database:\n backend: sqlite\n sqlite_dir: .deer-flow/data\n")
|
||||
assert detect.detect_from_config(cfg) == []
|
||||
|
||||
|
||||
def test_detect_from_config_dedupes_when_both_present(tmp_path):
|
||||
cfg = tmp_path / "config.yaml"
|
||||
cfg.write_text("checkpointer:\n type: postgres\ndatabase:\n backend: postgres\n")
|
||||
# Sorted unique extras, no double-counting.
|
||||
assert detect.detect_from_config(cfg) == ["postgres"]
|
||||
|
||||
|
||||
def test_detect_from_config_missing_file_returns_empty(tmp_path):
|
||||
assert detect.detect_from_config(tmp_path / "does-not-exist.yaml") == []
|
||||
|
||||
|
||||
def test_resolve_extras_env_overrides_config(isolated_cwd, monkeypatch):
|
||||
cfg = isolated_cwd / "config.yaml"
|
||||
cfg.write_text("database:\n backend: sqlite\n")
|
||||
monkeypatch.setenv("UV_EXTRAS", "postgres")
|
||||
|
||||
assert detect.resolve_extras() == ["postgres"]
|
||||
|
||||
|
||||
def test_resolve_extras_env_supports_multiple(isolated_cwd, monkeypatch):
|
||||
monkeypatch.setenv("UV_EXTRAS", "postgres,ollama")
|
||||
assert detect.resolve_extras() == ["postgres", "ollama"]
|
||||
|
||||
|
||||
def test_resolve_extras_falls_back_to_config(isolated_cwd):
|
||||
(isolated_cwd / "config.yaml").write_text("database:\n backend: postgres\n")
|
||||
assert detect.resolve_extras() == ["postgres"]
|
||||
|
||||
|
||||
def test_resolve_extras_respects_explicit_config_path(tmp_path, monkeypatch):
|
||||
monkeypatch.delenv("UV_EXTRAS", raising=False)
|
||||
elsewhere = tmp_path / "elsewhere.yaml"
|
||||
elsewhere.write_text("database:\n backend: postgres\n")
|
||||
monkeypatch.chdir(tmp_path)
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(elsewhere))
|
||||
|
||||
assert detect.resolve_extras() == ["postgres"]
|
||||
|
||||
|
||||
def test_resolve_extras_no_config_no_env(isolated_cwd):
|
||||
assert detect.resolve_extras() == []
|
||||
|
||||
|
||||
def test_resolve_extras_finds_backend_subdir_config(isolated_cwd):
|
||||
sub = isolated_cwd / "backend"
|
||||
sub.mkdir()
|
||||
(sub / "config.yaml").write_text("database:\n backend: postgres\n")
|
||||
assert detect.resolve_extras() == ["postgres"]
|
||||
|
||||
|
||||
def test_resolve_extras_root_config_takes_precedence(isolated_cwd):
|
||||
(isolated_cwd / "config.yaml").write_text("database:\n backend: sqlite\n")
|
||||
sub = isolated_cwd / "backend"
|
||||
sub.mkdir()
|
||||
(sub / "config.yaml").write_text("database:\n backend: postgres\n")
|
||||
# Root config.yaml is checked first, matching the precedence in serve.sh.
|
||||
assert detect.resolve_extras() == []
|
||||
@@ -0,0 +1,102 @@
|
||||
"""Unit tests for docker/dev-entrypoint.sh (UV_EXTRAS validation + parsing).
|
||||
|
||||
Exercises the script via its `--print-extras` dry-run hook so we don't actually
|
||||
launch uvicorn or hit /app/logs. Together with test_detect_uv_extras.py these
|
||||
cover both the local make-dev path and the docker-compose-dev path with the
|
||||
same shape — see PR #2767 / Issue #2754.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
ENTRYPOINT = REPO_ROOT / "docker" / "dev-entrypoint.sh"
|
||||
|
||||
|
||||
def _run(uv_extras: str | None) -> subprocess.CompletedProcess[str]:
|
||||
"""Invoke `dev-entrypoint.sh --print-extras` with UV_EXTRAS set."""
|
||||
env = os.environ.copy()
|
||||
env.pop("UV_EXTRAS", None)
|
||||
if uv_extras is not None:
|
||||
env["UV_EXTRAS"] = uv_extras
|
||||
return subprocess.run(
|
||||
["sh", str(ENTRYPOINT), "--print-extras"],
|
||||
env=env,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
|
||||
|
||||
def test_entrypoint_script_exists_and_is_posix_sh():
|
||||
assert ENTRYPOINT.is_file()
|
||||
# Catch syntax errors before runtime — `sh -n` is a parse-only check.
|
||||
proc = subprocess.run(["sh", "-n", str(ENTRYPOINT)], capture_output=True, text=True, check=False)
|
||||
assert proc.returncode == 0, proc.stderr
|
||||
|
||||
|
||||
def test_no_uv_extras_yields_empty_flags():
|
||||
proc = _run(None)
|
||||
assert proc.returncode == 0
|
||||
assert proc.stdout.strip() == ""
|
||||
|
||||
|
||||
def test_single_extra():
|
||||
proc = _run("postgres")
|
||||
assert proc.returncode == 0
|
||||
assert proc.stdout.strip() == "--extra postgres"
|
||||
|
||||
|
||||
def test_multi_extra_comma_separated():
|
||||
proc = _run("postgres,ollama")
|
||||
assert proc.returncode == 0
|
||||
assert proc.stdout.strip() == "--extra postgres --extra ollama"
|
||||
|
||||
|
||||
def test_multi_extra_whitespace_separated():
|
||||
proc = _run("postgres ollama")
|
||||
assert proc.returncode == 0
|
||||
assert proc.stdout.strip() == "--extra postgres --extra ollama"
|
||||
|
||||
|
||||
def test_multi_extra_mixed_separators():
|
||||
proc = _run(" postgres , ollama ,")
|
||||
assert proc.returncode == 0
|
||||
assert proc.stdout.strip() == "--extra postgres --extra ollama"
|
||||
|
||||
|
||||
def test_empty_string_yields_empty_flags():
|
||||
proc = _run("")
|
||||
assert proc.returncode == 0
|
||||
assert proc.stdout.strip() == ""
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"bad_value",
|
||||
[
|
||||
"; rm -rf /", # the canonical injection attempt
|
||||
"$(whoami)", # command substitution
|
||||
"`echo bad`", # backticks
|
||||
"postgres;evil", # mixed legal+illegal in a single token
|
||||
"1postgres", # leading digit
|
||||
"-postgres", # leading hyphen
|
||||
"post gres extra/path", # contains slash
|
||||
],
|
||||
)
|
||||
def test_metacharacters_abort_with_nonzero_exit(bad_value):
|
||||
proc = _run(bad_value)
|
||||
assert proc.returncode != 0, f"expected abort for {bad_value!r}, got 0"
|
||||
assert "is invalid" in proc.stderr
|
||||
assert proc.stdout.strip() == ""
|
||||
|
||||
|
||||
def test_underscores_and_hyphens_in_name_are_allowed():
|
||||
"""Mirrors uv's accepted shape for `[project.optional-dependencies]` keys."""
|
||||
proc = _run("post_gres,post-gres")
|
||||
assert proc.returncode == 0
|
||||
assert proc.stdout.strip() == "--extra post_gres --extra post-gres"
|
||||
@@ -0,0 +1,336 @@
|
||||
"""Tests for DynamicContextMiddleware.
|
||||
|
||||
Verifies that memory and current date are injected as a <system-reminder> into
|
||||
the first HumanMessage exactly once per session (frozen-snapshot pattern).
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest import mock
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from deerflow.agents.middlewares.dynamic_context_middleware import (
|
||||
_DYNAMIC_CONTEXT_REMINDER_KEY,
|
||||
DynamicContextMiddleware,
|
||||
)
|
||||
|
||||
_SYSTEM_REMINDER_TAG = "<system-reminder>"
|
||||
|
||||
|
||||
def _make_middleware(**kwargs) -> DynamicContextMiddleware:
|
||||
return DynamicContextMiddleware(**kwargs)
|
||||
|
||||
|
||||
def _fake_runtime():
|
||||
return SimpleNamespace(context={})
|
||||
|
||||
|
||||
def _reminder_msg(content: str, msg_id: str) -> HumanMessage:
|
||||
"""Build a reminder HumanMessage the way the middleware would produce it."""
|
||||
return HumanMessage(
|
||||
content=content,
|
||||
id=msg_id,
|
||||
additional_kwargs={"hide_from_ui": True, _DYNAMIC_CONTEXT_REMINDER_KEY: True},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Basic injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_injects_system_reminder_into_first_human_message():
|
||||
mw = _make_middleware()
|
||||
state = {"messages": [HumanMessage(content="Hello", id="msg-1")]}
|
||||
|
||||
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
assert result is not None
|
||||
updated_msgs = result["messages"]
|
||||
assert len(updated_msgs) == 2
|
||||
|
||||
reminder_msg = updated_msgs[0]
|
||||
assert isinstance(reminder_msg, HumanMessage)
|
||||
assert reminder_msg.id == "msg-1" # takes the original ID (position swap)
|
||||
assert reminder_msg.additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
|
||||
assert _SYSTEM_REMINDER_TAG in reminder_msg.content
|
||||
assert "<current_date>2026-05-08, Friday</current_date>" in reminder_msg.content
|
||||
assert "Hello" not in reminder_msg.content # reminder only — no user text
|
||||
|
||||
user_msg = updated_msgs[1]
|
||||
assert isinstance(user_msg, HumanMessage)
|
||||
assert user_msg.id == "msg-1__user" # derived ID
|
||||
assert user_msg.content == "Hello"
|
||||
|
||||
|
||||
def test_memory_included_when_present():
|
||||
mw = _make_middleware()
|
||||
state = {"messages": [HumanMessage(content="Hi", id="msg-1")]}
|
||||
|
||||
with (
|
||||
mock.patch(
|
||||
"deerflow.agents.lead_agent.prompt._get_memory_context",
|
||||
return_value="<memory>\nUser prefers Python.\n</memory>",
|
||||
),
|
||||
mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt,
|
||||
):
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
# Reminder is the first returned message; user query is the second
|
||||
reminder_content = result["messages"][0].content
|
||||
assert "User prefers Python." in reminder_content
|
||||
assert "<current_date>2026-05-08, Friday</current_date>" in reminder_content
|
||||
assert result["messages"][1].content == "Hi"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Frozen-snapshot: no re-injection within a session
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_skips_injection_if_already_present():
|
||||
"""Second turn: separate reminder message already present → no update."""
|
||||
mw = _make_middleware()
|
||||
reminder_content = "<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>"
|
||||
state = {
|
||||
"messages": [
|
||||
_reminder_msg(reminder_content, "msg-1"),
|
||||
HumanMessage(content="Hello", id="msg-1__user"),
|
||||
AIMessage(content="Hi there"),
|
||||
HumanMessage(content="Follow-up", id="msg-2"),
|
||||
]
|
||||
}
|
||||
|
||||
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
assert result is None # no update needed
|
||||
|
||||
|
||||
def test_injects_only_into_first_human_message_not_later_ones():
|
||||
"""Reminder targets the first HumanMessage; subsequent messages are not touched."""
|
||||
mw = _make_middleware()
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="First", id="msg-1"),
|
||||
AIMessage(content="Reply"),
|
||||
HumanMessage(content="Second", id="msg-2"),
|
||||
]
|
||||
}
|
||||
|
||||
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
assert result is not None
|
||||
msgs = result["messages"]
|
||||
# Only the two injected messages are returned (reminder + original first query)
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0].id == "msg-1" # reminder takes first message's ID
|
||||
assert msgs[0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
|
||||
assert _SYSTEM_REMINDER_TAG in msgs[0].content
|
||||
assert msgs[1].id == "msg-1__user" # original content with derived ID
|
||||
assert msgs[1].content == "First"
|
||||
# "Second" (msg-2) is not in the returned update — it is left unchanged
|
||||
assert all(m.id != "msg-2" for m in msgs)
|
||||
|
||||
|
||||
def test_summary_human_message_is_not_used_as_injection_target():
|
||||
"""After summarization, the synthetic summary HumanMessage is not a user turn."""
|
||||
mw = _make_middleware()
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="Here is a summary of the conversation to date:\n\n...", id="summary-1", name="summary"),
|
||||
AIMessage(content="Earlier reply"),
|
||||
HumanMessage(content="Follow-up", id="msg-2"),
|
||||
]
|
||||
}
|
||||
|
||||
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
assert result is not None
|
||||
msgs = result["messages"]
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0].id == "msg-2"
|
||||
assert msgs[0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
|
||||
assert msgs[1].id == "msg-2__user"
|
||||
assert msgs[1].content == "Follow-up"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_no_messages_returns_none():
|
||||
mw = _make_middleware()
|
||||
result = mw.before_agent({"messages": []}, _fake_runtime())
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_no_human_message_returns_none():
|
||||
mw = _make_middleware()
|
||||
state = {"messages": [AIMessage(content="assistant only")]}
|
||||
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""):
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_list_content_message_handled_as_separate_reminder():
|
||||
"""List-content (e.g. multi-modal) messages remain intact; reminder is a separate message."""
|
||||
mw = _make_middleware()
|
||||
original_content = [{"type": "text", "text": "Hello"}]
|
||||
state = {"messages": [HumanMessage(content=original_content, id="msg-1")]}
|
||||
|
||||
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
assert result is not None
|
||||
msgs = result["messages"]
|
||||
assert len(msgs) == 2
|
||||
# Reminder is a plain string message with the flag set
|
||||
assert isinstance(msgs[0].content, str)
|
||||
assert msgs[0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
|
||||
assert _SYSTEM_REMINDER_TAG in msgs[0].content
|
||||
# Original list-content message is untouched
|
||||
assert msgs[1].content == original_content
|
||||
|
||||
|
||||
def test_reminder_uses_original_id_user_message_uses_derived_id():
|
||||
"""Reminder takes original ID (position swap); user message gets {id}__user."""
|
||||
mw = _make_middleware()
|
||||
original_id = "original-id-abc"
|
||||
state = {"messages": [HumanMessage(content="Hello", id=original_id)]}
|
||||
|
||||
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
assert result["messages"][0].id == original_id
|
||||
assert result["messages"][1].id == f"{original_id}__user"
|
||||
|
||||
|
||||
def test_message_without_id_gets_stable_uuid():
|
||||
"""If the original HumanMessage has no ID, a UUID is generated and used consistently."""
|
||||
mw = _make_middleware()
|
||||
state = {"messages": [HumanMessage(content="Hello", id=None)]}
|
||||
|
||||
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
assert result is not None
|
||||
reminder_id = result["messages"][0].id
|
||||
user_id = result["messages"][1].id
|
||||
assert reminder_id is not None
|
||||
assert reminder_id != "None"
|
||||
assert user_id == f"{reminder_id}__user"
|
||||
|
||||
|
||||
def test_user_message_containing_system_reminder_tag_does_not_prevent_injection():
|
||||
"""A user message containing '<system-reminder>' must not be mistaken for a reminder."""
|
||||
mw = _make_middleware()
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="What is <system-reminder>?", id="msg-1"),
|
||||
]
|
||||
}
|
||||
|
||||
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
# Injection must happen — the user message does NOT carry the reminder flag
|
||||
assert result is not None
|
||||
assert result["messages"][0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Midnight crossing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_midnight_crossing_injects_date_update_as_separate_message():
|
||||
"""When the date has changed, a separate date-update reminder is injected before
|
||||
the current turn's HumanMessage using the ID-swap technique."""
|
||||
mw = _make_middleware()
|
||||
reminder_content = "<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>"
|
||||
state = {
|
||||
"messages": [
|
||||
_reminder_msg(reminder_content, "msg-1"),
|
||||
HumanMessage(content="Hello", id="msg-1__user"),
|
||||
AIMessage(content="Response"),
|
||||
HumanMessage(content="Good morning", id="msg-2"),
|
||||
]
|
||||
}
|
||||
|
||||
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-09, Saturday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
assert result is not None
|
||||
msgs = result["messages"]
|
||||
assert len(msgs) == 2
|
||||
|
||||
# Date-update reminder takes the current message's ID
|
||||
assert msgs[0].id == "msg-2"
|
||||
assert msgs[0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
|
||||
assert _SYSTEM_REMINDER_TAG in msgs[0].content
|
||||
assert "<current_date>2026-05-09, Saturday</current_date>" in msgs[0].content
|
||||
assert "Good morning" not in msgs[0].content # reminder only
|
||||
|
||||
# Original user text appended with derived ID
|
||||
assert msgs[1].id == "msg-2__user"
|
||||
assert msgs[1].content == "Good morning"
|
||||
|
||||
|
||||
def test_midnight_crossing_id_swap():
|
||||
"""Date-update reminder uses original ID; user message uses {id}__user."""
|
||||
mw = _make_middleware()
|
||||
reminder_content = "<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>"
|
||||
state = {
|
||||
"messages": [
|
||||
_reminder_msg(reminder_content, "msg-1"),
|
||||
HumanMessage(content="Next day message", id="msg-2"),
|
||||
]
|
||||
}
|
||||
|
||||
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-09, Saturday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
assert result["messages"][0].id == "msg-2"
|
||||
assert result["messages"][1].id == "msg-2__user"
|
||||
|
||||
|
||||
def test_no_second_midnight_injection_once_date_updated():
|
||||
"""After a midnight update is persisted, the same-day path skips re-injection."""
|
||||
mw = _make_middleware()
|
||||
date_update_content = "<system-reminder>\n<current_date>2026-05-09, Saturday</current_date>\n</system-reminder>"
|
||||
state = {
|
||||
"messages": [
|
||||
_reminder_msg(
|
||||
"<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>",
|
||||
"msg-1",
|
||||
),
|
||||
HumanMessage(content="Hello", id="msg-1__user"),
|
||||
AIMessage(content="Response"),
|
||||
_reminder_msg(date_update_content, "msg-2"),
|
||||
HumanMessage(content="Good morning", id="msg-2__user"),
|
||||
AIMessage(content="Good morning!"),
|
||||
HumanMessage(content="Third turn", id="msg-3"),
|
||||
]
|
||||
}
|
||||
|
||||
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
|
||||
mock_dt.now.return_value.strftime.return_value = "2026-05-09, Saturday"
|
||||
result = mw.before_agent(state, _fake_runtime())
|
||||
|
||||
assert result is None # same day as last injected date → no update
|
||||
@@ -50,7 +50,7 @@ def test_nginx_routes_official_langgraph_prefix_to_gateway_api():
|
||||
assert "/api/langgraph-compat" not in content
|
||||
assert "proxy_pass http://langgraph" not in content
|
||||
assert "rewrite ^/api/langgraph/(.*) /api/$1 break;" in content
|
||||
assert "proxy_pass http://gateway" in content
|
||||
assert "proxy_pass http://gateway" in content or "proxy_pass http://$gateway_upstream" in content
|
||||
|
||||
|
||||
def test_frontend_rewrites_langgraph_prefix_to_gateway():
|
||||
|
||||
@@ -324,6 +324,21 @@ def test_context_does_not_override_existing_configurable():
|
||||
assert config["configurable"]["subagent_enabled"] is True
|
||||
|
||||
|
||||
def test_inject_authenticated_user_context_overrides_client_user_id():
|
||||
"""Run context should carry the authenticated user, not client-supplied user_id."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.gateway.services import build_run_config, inject_authenticated_user_context
|
||||
|
||||
config = build_run_config("thread-1", None, None)
|
||||
config["context"] = {"user_id": "spoofed-client"}
|
||||
request = SimpleNamespace(state=SimpleNamespace(user=SimpleNamespace(id="auth-user-42")))
|
||||
|
||||
inject_authenticated_user_context(config, request)
|
||||
|
||||
assert config["context"]["user_id"] == "auth-user-42"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -8,17 +8,20 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.model_config import ModelConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.config.summarization_config import SummarizationConfig
|
||||
|
||||
|
||||
def _make_app_config(models: list[ModelConfig]) -> AppConfig:
|
||||
def _make_app_config(models: list[ModelConfig], loop_detection: LoopDetectionConfig | None = None) -> AppConfig:
|
||||
return AppConfig(
|
||||
models=models,
|
||||
sandbox=SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider"),
|
||||
loop_detection=loop_detection or LoopDetectionConfig(),
|
||||
)
|
||||
|
||||
|
||||
@@ -340,6 +343,59 @@ def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypa
|
||||
assert middlewares[0] == "base-middleware"
|
||||
|
||||
|
||||
def test_build_middlewares_uses_loop_detection_config(monkeypatch):
|
||||
app_config = _make_app_config(
|
||||
[_make_model("safe-model", supports_thinking=False)],
|
||||
loop_detection=LoopDetectionConfig(
|
||||
warn_threshold=7,
|
||||
hard_limit=9,
|
||||
window_size=30,
|
||||
max_tracked_threads=40,
|
||||
tool_freq_warn=50,
|
||||
tool_freq_hard_limit=60,
|
||||
),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
monkeypatch.setattr(lead_agent_module, "build_lead_runtime_middlewares", lambda *, app_config, lazy_init=True: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda *, app_config=None: None)
|
||||
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
|
||||
|
||||
middlewares = lead_agent_module._build_middlewares(
|
||||
{"configurable": {"is_plan_mode": False, "subagent_enabled": False}},
|
||||
model_name="safe-model",
|
||||
app_config=app_config,
|
||||
)
|
||||
|
||||
loop_detection = next(m for m in middlewares if isinstance(m, LoopDetectionMiddleware))
|
||||
assert loop_detection.warn_threshold == 7
|
||||
assert loop_detection.hard_limit == 9
|
||||
assert loop_detection.window_size == 30
|
||||
assert loop_detection.max_tracked_threads == 40
|
||||
assert loop_detection.tool_freq_warn == 50
|
||||
assert loop_detection.tool_freq_hard_limit == 60
|
||||
|
||||
|
||||
def test_build_middlewares_omits_loop_detection_when_disabled(monkeypatch):
|
||||
app_config = _make_app_config(
|
||||
[_make_model("safe-model", supports_thinking=False)],
|
||||
loop_detection=LoopDetectionConfig(enabled=False),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
|
||||
monkeypatch.setattr(lead_agent_module, "build_lead_runtime_middlewares", lambda *, app_config, lazy_init=True: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda *, app_config=None: None)
|
||||
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
|
||||
|
||||
middlewares = lead_agent_module._build_middlewares(
|
||||
{"configurable": {"is_plan_mode": False, "subagent_enabled": False}},
|
||||
model_name="safe-model",
|
||||
app_config=app_config,
|
||||
)
|
||||
|
||||
assert not any(isinstance(m, LoopDetectionMiddleware) for m in middlewares)
|
||||
|
||||
|
||||
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
|
||||
app_config = _make_app_config([_make_model("model-masswork", supports_thinking=False)])
|
||||
app_config.summarization = SummarizationConfig(enabled=True, model_name="model-masswork")
|
||||
|
||||
@@ -1,22 +1,37 @@
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
|
||||
import anyio
|
||||
|
||||
from deerflow.agents.lead_agent import prompt as prompt_module
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.subagents_config import CustomSubagentConfig, SubagentsAppConfig
|
||||
from deerflow.skills.types import Skill
|
||||
from deerflow.skills.types import Skill, SkillCategory
|
||||
|
||||
|
||||
def _set_skills_cache_state(*, skills=None, active=False, version=0):
|
||||
prompt_module._get_cached_skills_prompt_section.cache_clear()
|
||||
with prompt_module._enabled_skills_lock:
|
||||
prompt_module._enabled_skills_cache = skills
|
||||
prompt_module._enabled_skills_by_config_cache.clear()
|
||||
prompt_module._enabled_skills_refresh_active = active
|
||||
prompt_module._enabled_skills_refresh_version = version
|
||||
prompt_module._enabled_skills_refresh_event.clear()
|
||||
|
||||
|
||||
def test_build_self_update_section_empty_for_default_agent():
|
||||
assert prompt_module._build_self_update_section(None) == ""
|
||||
|
||||
|
||||
def test_build_self_update_section_present_for_custom_agent():
|
||||
section = prompt_module._build_self_update_section("my-agent")
|
||||
|
||||
assert "<self_update>" in section
|
||||
assert "my-agent" in section
|
||||
assert "update_agent" in section
|
||||
|
||||
|
||||
def test_build_custom_mounts_section_returns_empty_when_no_mounts(monkeypatch):
|
||||
config = SimpleNamespace(sandbox=SimpleNamespace(mounts=[]))
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
@@ -220,7 +235,7 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
|
||||
skill_dir=skill_dir,
|
||||
skill_file=skill_dir / "SKILL.md",
|
||||
relative_path=skill_dir.relative_to(tmp_path),
|
||||
category="custom",
|
||||
category=SkillCategory.CUSTOM,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
@@ -240,6 +255,58 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
|
||||
_set_skills_cache_state()
|
||||
|
||||
|
||||
def test_explicit_config_enabled_skills_are_cached_by_config_identity(monkeypatch, tmp_path):
|
||||
def make_skill(name: str) -> Skill:
|
||||
skill_dir = tmp_path / name
|
||||
return Skill(
|
||||
name=name,
|
||||
description=f"Description for {name}",
|
||||
license="MIT",
|
||||
skill_dir=skill_dir,
|
||||
skill_file=skill_dir / "SKILL.md",
|
||||
relative_path=skill_dir.relative_to(tmp_path),
|
||||
category=SkillCategory.CUSTOM,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
config = cast(
|
||||
AppConfig,
|
||||
cast(
|
||||
object,
|
||||
SimpleNamespace(
|
||||
skills=SimpleNamespace(container_path="/mnt/skills"),
|
||||
skill_evolution=SimpleNamespace(enabled=False),
|
||||
),
|
||||
),
|
||||
)
|
||||
load_count = 0
|
||||
|
||||
def fake_get_or_new_skill_storage(**kwargs):
|
||||
nonlocal load_count
|
||||
assert kwargs == {"app_config": config}
|
||||
|
||||
def load_skills(*, enabled_only):
|
||||
nonlocal load_count
|
||||
load_count += 1
|
||||
assert enabled_only is True
|
||||
return [make_skill("cached-skill")]
|
||||
|
||||
return SimpleNamespace(load_skills=load_skills)
|
||||
|
||||
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", fake_get_or_new_skill_storage)
|
||||
_set_skills_cache_state()
|
||||
|
||||
try:
|
||||
first = prompt_module.get_skills_prompt_section(app_config=config)
|
||||
second = prompt_module.get_skills_prompt_section(app_config=config)
|
||||
|
||||
assert "cached-skill" in first
|
||||
assert "cached-skill" in second
|
||||
assert load_count == 1
|
||||
finally:
|
||||
_set_skills_cache_state()
|
||||
|
||||
|
||||
def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_path):
|
||||
started = threading.Event()
|
||||
release = threading.Event()
|
||||
@@ -257,7 +324,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa
|
||||
skill_dir=skill_dir,
|
||||
skill_file=skill_dir / "SKILL.md",
|
||||
relative_path=skill_dir.relative_to(tmp_path),
|
||||
category="custom",
|
||||
category=SkillCategory.CUSTOM,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -6,7 +6,12 @@ from deerflow.config.agents_config import AgentConfig
|
||||
from deerflow.skills.types import Skill
|
||||
|
||||
|
||||
def _make_skill(name: str) -> Skill:
|
||||
class NamedTool:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
|
||||
def _make_skill(name: str, allowed_tools: list[str] | None = None) -> Skill:
|
||||
return Skill(
|
||||
name=name,
|
||||
description=f"Description for {name}",
|
||||
@@ -15,6 +20,7 @@ def _make_skill(name: str) -> Skill:
|
||||
skill_file=Path(f"/tmp/{name}/SKILL.md"),
|
||||
relative_path=Path(name),
|
||||
category="public",
|
||||
allowed_tools=allowed_tools,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
@@ -132,6 +138,7 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_load_enabled_skills_for_tool_policy", lambda available_skills, *, app_config: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
|
||||
|
||||
@@ -164,3 +171,106 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["skill1"]))
|
||||
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
assert captured_skills[-1] == {"skill1"}
|
||||
|
||||
|
||||
def test_make_lead_agent_filters_tools_from_available_skills(monkeypatch):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "apply_prompt_template", lambda **kwargs: "mock_prompt")
|
||||
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["restricted", "legacy"]))
|
||||
monkeypatch.setattr(lead_agent_module, "_load_enabled_skills_for_tool_policy", lambda available_skills, *, app_config: [_make_skill("restricted", ["read_file"]), _make_skill("legacy", None)])
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")])
|
||||
|
||||
mock_app_config = MagicMock()
|
||||
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
|
||||
|
||||
agent_kwargs = lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
|
||||
assert [tool.name for tool in agent_kwargs["tools"]] == ["read_file"]
|
||||
|
||||
|
||||
def test_make_lead_agent_all_legacy_skills_preserve_all_tools(monkeypatch):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "apply_prompt_template", lambda **kwargs: "mock_prompt")
|
||||
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=None))
|
||||
monkeypatch.setattr(lead_agent_module, "_load_enabled_skills_for_tool_policy", lambda available_skills, *, app_config: [_make_skill("legacy", None)])
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [NamedTool("bash"), NamedTool("read_file")])
|
||||
|
||||
mock_app_config = MagicMock()
|
||||
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
|
||||
|
||||
agent_kwargs = lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
|
||||
assert [tool.name for tool in agent_kwargs["tools"]] == ["bash", "read_file", "update_agent"]
|
||||
|
||||
|
||||
def test_make_lead_agent_enforces_allowed_tools_when_skill_cache_is_cold(monkeypatch):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
from deerflow.agents.lead_agent import prompt as prompt_module
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "apply_prompt_template", lambda **kwargs: "mock_prompt")
|
||||
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["restricted"]))
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")])
|
||||
|
||||
mock_app_config = MagicMock()
|
||||
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
|
||||
mock_storage = SimpleNamespace(load_skills=lambda *, enabled_only: [_make_skill("restricted", ["read_file"])])
|
||||
|
||||
with prompt_module._enabled_skills_lock:
|
||||
prompt_module._enabled_skills_cache = None
|
||||
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", lambda app_config=None, **kwargs: mock_storage)
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
|
||||
|
||||
agent_kwargs = lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
|
||||
assert [tool.name for tool in agent_kwargs["tools"]] == ["read_file"]
|
||||
|
||||
|
||||
def test_make_lead_agent_fails_closed_when_skill_policy_load_fails(monkeypatch):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
from deerflow.agents.lead_agent import prompt as prompt_module
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
|
||||
create_agent_mock = MagicMock()
|
||||
monkeypatch.setattr(lead_agent_module, "create_agent", create_agent_mock)
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["restricted"]))
|
||||
|
||||
mock_app_config = MagicMock()
|
||||
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
|
||||
|
||||
def fail_storage(*args, **kwargs):
|
||||
raise RuntimeError("skill storage unavailable")
|
||||
|
||||
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", fail_storage)
|
||||
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
|
||||
|
||||
with pytest.raises(RuntimeError, match="skill storage unavailable"):
|
||||
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
|
||||
create_agent_mock.assert_not_called()
|
||||
|
||||
@@ -105,6 +105,7 @@ def test_execute_command_uses_powershell_command_mode_on_windows(monkeypatch):
|
||||
"capture_output": True,
|
||||
"text": True,
|
||||
"timeout": 600,
|
||||
"env": None,
|
||||
},
|
||||
)
|
||||
]
|
||||
@@ -118,6 +119,7 @@ def test_execute_command_uses_posix_shell_command_mode_on_windows(monkeypatch):
|
||||
return SimpleNamespace(stdout="ok", stderr="", returncode=0)
|
||||
|
||||
monkeypatch.setattr(local_sandbox.os, "name", "nt")
|
||||
monkeypatch.setattr(local_sandbox.os, "environ", {"PATH": r"C:\Program Files\Git\bin"})
|
||||
monkeypatch.setattr(LocalSandbox, "_get_shell", staticmethod(lambda: r"C:\Program Files\Git\bin\sh.exe"))
|
||||
monkeypatch.setattr(local_sandbox.subprocess, "run", fake_run)
|
||||
|
||||
@@ -132,11 +134,33 @@ def test_execute_command_uses_posix_shell_command_mode_on_windows(monkeypatch):
|
||||
"capture_output": True,
|
||||
"text": True,
|
||||
"timeout": 600,
|
||||
"env": {
|
||||
"PATH": r"C:\Program Files\Git\bin",
|
||||
"MSYS_NO_PATHCONV": "1",
|
||||
"MSYS2_ARG_CONV_EXCL": "*",
|
||||
},
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def test_execute_command_does_not_set_msys_env_for_non_msys_posix_shell_on_windows(monkeypatch):
|
||||
calls: list[tuple[object, dict]] = []
|
||||
|
||||
def fake_run(*args, **kwargs):
|
||||
calls.append((args[0], kwargs))
|
||||
return SimpleNamespace(stdout="ok", stderr="", returncode=0)
|
||||
|
||||
monkeypatch.setattr(local_sandbox.os, "name", "nt")
|
||||
monkeypatch.setattr(LocalSandbox, "_get_shell", staticmethod(lambda: r"C:\tools\busybox\sh.exe"))
|
||||
monkeypatch.setattr(local_sandbox.subprocess, "run", fake_run)
|
||||
|
||||
output = LocalSandbox("t").execute_command("echo /mnt/skills/demo")
|
||||
|
||||
assert output == "ok"
|
||||
assert calls[0][1]["env"] is None
|
||||
|
||||
|
||||
def test_execute_command_uses_cmd_command_mode_on_windows(monkeypatch):
|
||||
calls: list[tuple[object, dict]] = []
|
||||
|
||||
@@ -159,6 +183,7 @@ def test_execute_command_uses_cmd_command_mode_on_windows(monkeypatch):
|
||||
"capture_output": True,
|
||||
"text": True,
|
||||
"timeout": 600,
|
||||
"env": None,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
"""Tests for loop detection configuration."""
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
||||
|
||||
|
||||
class TestLoopDetectionConfig:
|
||||
def test_defaults_match_middleware_defaults(self):
|
||||
config = LoopDetectionConfig()
|
||||
|
||||
assert config.enabled is True
|
||||
assert config.warn_threshold == 3
|
||||
assert config.hard_limit == 5
|
||||
assert config.window_size == 20
|
||||
assert config.max_tracked_threads == 100
|
||||
assert config.tool_freq_warn == 30
|
||||
assert config.tool_freq_hard_limit == 50
|
||||
|
||||
def test_accepts_custom_values(self):
|
||||
config = LoopDetectionConfig(
|
||||
enabled=False,
|
||||
warn_threshold=10,
|
||||
hard_limit=20,
|
||||
window_size=50,
|
||||
max_tracked_threads=200,
|
||||
tool_freq_warn=60,
|
||||
tool_freq_hard_limit=80,
|
||||
)
|
||||
|
||||
assert config.enabled is False
|
||||
assert config.warn_threshold == 10
|
||||
assert config.hard_limit == 20
|
||||
assert config.window_size == 50
|
||||
assert config.max_tracked_threads == 200
|
||||
assert config.tool_freq_warn == 60
|
||||
assert config.tool_freq_hard_limit == 80
|
||||
|
||||
def test_rejects_zero_thresholds(self):
|
||||
with pytest.raises(ValueError):
|
||||
LoopDetectionConfig(warn_threshold=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
LoopDetectionConfig(hard_limit=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
LoopDetectionConfig(tool_freq_warn=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
LoopDetectionConfig(tool_freq_hard_limit=0)
|
||||
|
||||
def test_rejects_hard_limit_below_warn_threshold(self):
|
||||
with pytest.raises(ValueError, match="hard_limit"):
|
||||
LoopDetectionConfig(warn_threshold=5, hard_limit=4)
|
||||
|
||||
def test_rejects_tool_freq_hard_limit_below_warn_threshold(self):
|
||||
with pytest.raises(ValueError, match="tool_freq_hard_limit"):
|
||||
LoopDetectionConfig(tool_freq_warn=5, tool_freq_hard_limit=4)
|
||||
|
||||
def test_tool_freq_override_valid(self):
|
||||
config = LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 150, "hard_limit": 300}})
|
||||
override = config.tool_freq_overrides["bash"]
|
||||
assert override.warn == 150
|
||||
assert override.hard_limit == 300
|
||||
|
||||
def test_tool_freq_override_rejects_zero_warn(self):
|
||||
with pytest.raises(ValueError):
|
||||
LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 0, "hard_limit": 10}})
|
||||
|
||||
def test_tool_freq_override_rejects_hard_limit_below_warn(self):
|
||||
with pytest.raises(ValueError, match="hard_limit"):
|
||||
LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 100, "hard_limit": 50}})
|
||||
@@ -3,7 +3,7 @@
|
||||
import copy
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.messages import AIMessage, SystemMessage
|
||||
|
||||
from deerflow.agents.middlewares.loop_detection_middleware import (
|
||||
_HARD_STOP_MSG,
|
||||
@@ -146,14 +146,42 @@ class TestLoopDetection:
|
||||
for _ in range(2):
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
|
||||
# Third identical call triggers warning
|
||||
# Third identical call triggers warning. The warning is appended to
|
||||
# the AIMessage content (tool_calls preserved) — never inserted as a
|
||||
# separate HumanMessage between the AIMessage(tool_calls) and its
|
||||
# ToolMessage responses, which would break OpenAI/Moonshot strict
|
||||
# tool-call pairing validation.
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is not None
|
||||
msgs = result["messages"]
|
||||
assert len(msgs) == 1
|
||||
assert isinstance(msgs[0], HumanMessage)
|
||||
assert isinstance(msgs[0], AIMessage)
|
||||
assert len(msgs[0].tool_calls) == len(call)
|
||||
assert msgs[0].tool_calls[0]["id"] == call[0]["id"]
|
||||
assert "LOOP DETECTED" in msgs[0].content
|
||||
|
||||
def test_warn_does_not_break_tool_call_pairing(self):
|
||||
"""Regression: the warn branch must NOT inject a non-tool message
|
||||
after an AIMessage(tool_calls=...). Moonshot/OpenAI reject the next
|
||||
request with 'tool_call_ids did not have response messages' if any
|
||||
non-tool message is wedged between the AIMessage and its ToolMessage
|
||||
responses. See #2029.
|
||||
"""
|
||||
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
|
||||
runtime = _make_runtime()
|
||||
call = [_bash_call("ls")]
|
||||
|
||||
for _ in range(2):
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is not None
|
||||
msgs = result["messages"]
|
||||
assert len(msgs) == 1
|
||||
assert isinstance(msgs[0], AIMessage)
|
||||
assert len(msgs[0].tool_calls) == len(call)
|
||||
assert msgs[0].tool_calls[0]["id"] == call[0]["id"]
|
||||
|
||||
def test_warn_only_injected_once(self):
|
||||
"""Warning for the same hash should only be injected once per thread."""
|
||||
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
|
||||
@@ -483,7 +511,11 @@ class TestToolFrequencyDetection:
|
||||
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_4.py")]), runtime)
|
||||
assert result is not None
|
||||
msg = result["messages"][0]
|
||||
assert isinstance(msg, HumanMessage)
|
||||
# Warning is appended to the AIMessage content; tool_calls preserved
|
||||
# so the tools node still runs and Moonshot/OpenAI tool-call pairing
|
||||
# validation does not break.
|
||||
assert isinstance(msg, AIMessage)
|
||||
assert msg.tool_calls
|
||||
assert "read_file" in msg.content
|
||||
assert "LOOP DETECTED" in msg.content
|
||||
|
||||
@@ -616,6 +648,37 @@ class TestToolFrequencyDetection:
|
||||
assert result is not None
|
||||
assert "read_file" in result["messages"][0].content
|
||||
|
||||
def test_override_tool_uses_override_thresholds(self):
|
||||
"""A tool in tool_freq_overrides uses its own thresholds, not the global ones."""
|
||||
mw = LoopDetectionMiddleware(
|
||||
tool_freq_warn=5,
|
||||
tool_freq_hard_limit=10,
|
||||
tool_freq_overrides={"bash": (50, 100)},
|
||||
)
|
||||
runtime = _make_runtime()
|
||||
|
||||
# 10 bash calls — would hit global hard_limit=10, but bash override is 100
|
||||
for i in range(10):
|
||||
result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime)
|
||||
assert result is None, f"unexpected trigger on call {i + 1}"
|
||||
|
||||
def test_non_override_tool_falls_back_to_global(self):
|
||||
"""A tool NOT in tool_freq_overrides uses the global warn/hard_limit."""
|
||||
mw = LoopDetectionMiddleware(
|
||||
tool_freq_warn=3,
|
||||
tool_freq_hard_limit=6,
|
||||
tool_freq_overrides={"bash": (50, 100)},
|
||||
)
|
||||
runtime = _make_runtime()
|
||||
|
||||
for i in range(2):
|
||||
mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime)
|
||||
|
||||
# 3rd read_file call hits global warn=3 (read_file has no override)
|
||||
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
|
||||
assert result is not None
|
||||
assert "read_file" in result["messages"][0].content
|
||||
|
||||
def test_hash_detection_takes_priority(self):
|
||||
"""Hash-based hard stop fires before frequency check for identical calls."""
|
||||
mw = LoopDetectionMiddleware(
|
||||
@@ -636,3 +699,48 @@ class TestToolFrequencyDetection:
|
||||
msg = result["messages"][0]
|
||||
assert isinstance(msg, AIMessage)
|
||||
assert _HARD_STOP_MSG in msg.content
|
||||
|
||||
|
||||
class TestFromConfig:
|
||||
"""Tests for LoopDetectionMiddleware.from_config — the sole validated construction path."""
|
||||
|
||||
@staticmethod
|
||||
def _config(**kwargs):
|
||||
from deerflow.config.loop_detection_config import LoopDetectionConfig
|
||||
|
||||
return LoopDetectionConfig(**kwargs)
|
||||
|
||||
def test_scalar_fields_mapped(self):
|
||||
config = self._config(
|
||||
warn_threshold=4,
|
||||
hard_limit=8,
|
||||
window_size=15,
|
||||
max_tracked_threads=50,
|
||||
tool_freq_warn=20,
|
||||
tool_freq_hard_limit=40,
|
||||
)
|
||||
mw = LoopDetectionMiddleware.from_config(config)
|
||||
assert mw.warn_threshold == 4
|
||||
assert mw.hard_limit == 8
|
||||
assert mw.window_size == 15
|
||||
assert mw.max_tracked_threads == 50
|
||||
assert mw.tool_freq_warn == 20
|
||||
assert mw.tool_freq_hard_limit == 40
|
||||
|
||||
def test_overrides_converted_to_tuples(self):
|
||||
config = self._config(tool_freq_overrides={"bash": {"warn": 50, "hard_limit": 100}})
|
||||
mw = LoopDetectionMiddleware.from_config(config)
|
||||
assert mw._tool_freq_overrides == {"bash": (50, 100)}
|
||||
|
||||
def test_empty_overrides(self):
|
||||
mw = LoopDetectionMiddleware.from_config(self._config())
|
||||
assert mw._tool_freq_overrides == {}
|
||||
|
||||
def test_constructed_middleware_detects_loops(self):
|
||||
mw = LoopDetectionMiddleware.from_config(self._config(warn_threshold=2, hard_limit=4))
|
||||
runtime = _make_runtime()
|
||||
call = [_bash_call("ls")]
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is not None
|
||||
assert "LOOP DETECTED" in result["messages"][0].content
|
||||
|
||||
@@ -125,3 +125,68 @@ class TestMigrateMemory:
|
||||
from scripts.migrate_user_isolation import migrate_memory
|
||||
|
||||
migrate_memory(paths, user_id="default") # should not raise
|
||||
|
||||
|
||||
class TestMigrateAgents:
|
||||
@staticmethod
|
||||
def _seed_legacy_agent(paths: Paths, name: str, *, soul: str = "soul", description: str = "d") -> Path:
|
||||
legacy_dir = paths.agents_dir / name
|
||||
legacy_dir.mkdir(parents=True, exist_ok=True)
|
||||
(legacy_dir / "config.yaml").write_text(f"name: {name}\ndescription: {description}\n", encoding="utf-8")
|
||||
(legacy_dir / "SOUL.md").write_text(soul, encoding="utf-8")
|
||||
return legacy_dir
|
||||
|
||||
def test_moves_legacy_into_user_layout(self, base_dir: Path, paths: Paths):
|
||||
self._seed_legacy_agent(paths, "agent-a", soul="soul-a")
|
||||
self._seed_legacy_agent(paths, "agent-b", soul="soul-b")
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_agents
|
||||
|
||||
report = migrate_agents(paths, user_id="default")
|
||||
|
||||
assert {entry["agent"] for entry in report} == {"agent-a", "agent-b"}
|
||||
for entry in report:
|
||||
assert entry["user_id"] == "default"
|
||||
assert "moved -> " in entry["action"]
|
||||
|
||||
for name, soul in [("agent-a", "soul-a"), ("agent-b", "soul-b")]:
|
||||
dest = paths.user_agent_dir("default", name)
|
||||
assert dest.exists(), f"{name} should have moved into the per-user layout"
|
||||
assert (dest / "SOUL.md").read_text() == soul
|
||||
|
||||
# Legacy agents/ root is cleaned up once empty.
|
||||
assert not paths.agents_dir.exists()
|
||||
|
||||
def test_dry_run_does_not_move(self, base_dir: Path, paths: Paths):
|
||||
legacy_dir = self._seed_legacy_agent(paths, "agent-a")
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_agents
|
||||
|
||||
report = migrate_agents(paths, user_id="default", dry_run=True)
|
||||
|
||||
assert len(report) == 1
|
||||
assert legacy_dir.exists(), "dry-run must not touch the filesystem"
|
||||
assert not paths.user_agent_dir("default", "agent-a").exists()
|
||||
|
||||
def test_existing_destination_is_treated_as_conflict(self, base_dir: Path, paths: Paths):
|
||||
self._seed_legacy_agent(paths, "agent-a", soul="legacy soul")
|
||||
dest = paths.user_agent_dir("default", "agent-a")
|
||||
dest.mkdir(parents=True)
|
||||
(dest / "SOUL.md").write_text("preexisting", encoding="utf-8")
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_agents
|
||||
|
||||
report = migrate_agents(paths, user_id="default")
|
||||
|
||||
assert report[0]["action"].startswith("conflict -> ")
|
||||
# Per-user destination must be left untouched.
|
||||
assert (dest / "SOUL.md").read_text() == "preexisting"
|
||||
# Legacy copy lands under migration-conflicts/agents/.
|
||||
conflicts_dir = paths.base_dir / "migration-conflicts" / "agents" / "agent-a"
|
||||
assert (conflicts_dir / "SOUL.md").read_text() == "legacy soul"
|
||||
|
||||
def test_no_legacy_dir_is_noop(self, base_dir: Path, paths: Paths):
|
||||
from scripts.migrate_user_isolation import migrate_agents
|
||||
|
||||
report = migrate_agents(paths, user_id="default")
|
||||
assert report == []
|
||||
|
||||
@@ -50,6 +50,21 @@ class TestUserAgentMemoryFile:
|
||||
assert paths.user_agent_memory_file("bob", "MyAgent") == expected
|
||||
|
||||
|
||||
class TestUserAgentDir:
|
||||
def test_user_agents_dir(self, paths: Paths):
|
||||
assert paths.user_agents_dir("alice") == paths.base_dir / "users" / "alice" / "agents"
|
||||
|
||||
def test_user_agent_dir(self, paths: Paths):
|
||||
assert paths.user_agent_dir("alice", "code-reviewer") == paths.base_dir / "users" / "alice" / "agents" / "code-reviewer"
|
||||
|
||||
def test_user_agent_dir_lowercases_name(self, paths: Paths):
|
||||
assert paths.user_agent_dir("alice", "CodeReviewer") == paths.base_dir / "users" / "alice" / "agents" / "codereviewer"
|
||||
|
||||
def test_user_agent_dir_validates_user_id(self, paths: Paths):
|
||||
with pytest.raises(ValueError, match="Invalid user_id"):
|
||||
paths.user_agent_dir("../escape", "myagent")
|
||||
|
||||
|
||||
class TestUserThreadDir:
|
||||
def test_user_thread_dir(self, paths: Paths):
|
||||
expected = paths.base_dir / "users" / "u1" / "threads" / "t1"
|
||||
|
||||
@@ -8,7 +8,9 @@ Tests:
|
||||
5. Postgres missing-dep error message
|
||||
"""
|
||||
|
||||
import sys
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -221,13 +223,8 @@ class TestEngineLifecycle:
|
||||
"""If asyncpg is not installed, error message tells user what to do."""
|
||||
from deerflow.persistence.engine import init_engine
|
||||
|
||||
try:
|
||||
import asyncpg # noqa: F401
|
||||
|
||||
pytest.skip("asyncpg is installed -- cannot test missing-dep path")
|
||||
except ImportError:
|
||||
# asyncpg is not installed — this is the expected state for this test.
|
||||
# We proceed to verify that init_engine raises an actionable ImportError.
|
||||
pass # noqa: S110 — intentionally ignored
|
||||
with pytest.raises(ImportError, match="uv sync --extra postgres"):
|
||||
with (
|
||||
patch.dict(sys.modules, {"asyncpg": None}),
|
||||
pytest.raises(ImportError, match="uv sync --all-packages --extra postgres"),
|
||||
):
|
||||
await init_engine("postgres", url="postgresql+asyncpg://x:x@localhost/x")
|
||||
|
||||
@@ -0,0 +1,293 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from deerflow.community.aio_sandbox.remote_backend import RemoteSandboxBackend
|
||||
from deerflow.community.aio_sandbox.sandbox_info import SandboxInfo
|
||||
|
||||
|
||||
class _StubResponse:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
status_code: int = 200,
|
||||
payload: object | None = None,
|
||||
json_exc: Exception | None = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self._payload = {} if payload is None else payload
|
||||
self._json_exc = json_exc
|
||||
self.ok = 200 <= status_code < 400
|
||||
self.text = ""
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
if self.status_code >= 400:
|
||||
raise requests.HTTPError(f"HTTP {self.status_code}")
|
||||
|
||||
def json(self) -> object:
|
||||
if self._json_exc is not None:
|
||||
raise self._json_exc
|
||||
return self._payload
|
||||
|
||||
|
||||
def test_list_running_delegates_to_provisioner_list(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
sandbox_info = SandboxInfo(sandbox_id="test-id", sandbox_url="http://localhost:8080")
|
||||
|
||||
def mock_list():
|
||||
return [sandbox_info]
|
||||
|
||||
monkeypatch.setattr(backend, "_provisioner_list", mock_list)
|
||||
|
||||
assert backend.list_running() == [sandbox_info]
|
||||
|
||||
|
||||
def test_provisioner_list_returns_sandbox_infos_and_filters_invalid_entries(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get(url: str, timeout: int):
|
||||
assert url == "http://provisioner:8002/api/sandboxes"
|
||||
assert timeout == 10
|
||||
return _StubResponse(
|
||||
payload={
|
||||
"sandboxes": [
|
||||
{"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"},
|
||||
{"sandbox_id": "missing-url"},
|
||||
{"sandbox_url": "http://k3s:31002"},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
|
||||
infos = backend._provisioner_list()
|
||||
assert len(infos) == 1
|
||||
assert infos[0].sandbox_id == "abc123"
|
||||
assert infos[0].sandbox_url == "http://k3s:31001"
|
||||
|
||||
|
||||
def test_provisioner_list_returns_empty_on_request_exception(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get(url: str, timeout: int):
|
||||
raise requests.RequestException("network down")
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
|
||||
assert backend._provisioner_list() == []
|
||||
|
||||
|
||||
def test_provisioner_list_returns_empty_when_payload_is_not_dict(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get(url: str, timeout: int):
|
||||
return _StubResponse(payload=[{"sandbox_id": "abc", "sandbox_url": "http://k3s:31001"}])
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
|
||||
assert backend._provisioner_list() == []
|
||||
|
||||
|
||||
def test_provisioner_list_returns_empty_when_sandboxes_is_not_list(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get(url: str, timeout: int):
|
||||
return _StubResponse(payload={"sandboxes": {"sandbox_id": "abc"}})
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
|
||||
assert backend._provisioner_list() == []
|
||||
|
||||
|
||||
def test_provisioner_list_skips_non_dict_sandbox_entries(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get(url: str, timeout: int):
|
||||
return _StubResponse(
|
||||
payload={
|
||||
"sandboxes": [
|
||||
{"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"},
|
||||
"bad-entry",
|
||||
123,
|
||||
None,
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
|
||||
infos = backend._provisioner_list()
|
||||
assert len(infos) == 1
|
||||
assert infos[0].sandbox_id == "abc123"
|
||||
assert infos[0].sandbox_url == "http://k3s:31001"
|
||||
|
||||
|
||||
def test_create_delegates_to_provisioner_create(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
expected = SandboxInfo(sandbox_id="abc123", sandbox_url="http://k3s:31001")
|
||||
|
||||
def mock_create(thread_id: str, sandbox_id: str, extra_mounts=None):
|
||||
assert thread_id == "thread-1"
|
||||
assert sandbox_id == "abc123"
|
||||
assert extra_mounts == [("/host", "/container", False)]
|
||||
return expected
|
||||
|
||||
monkeypatch.setattr(backend, "_provisioner_create", mock_create)
|
||||
|
||||
result = backend.create("thread-1", "abc123", extra_mounts=[("/host", "/container", False)])
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_provisioner_create_returns_sandbox_info(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_post(url: str, json: dict, timeout: int):
|
||||
assert url == "http://provisioner:8002/api/sandboxes"
|
||||
assert json == {"sandbox_id": "abc123", "thread_id": "thread-1"}
|
||||
assert timeout == 30
|
||||
return _StubResponse(payload={"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"})
|
||||
|
||||
monkeypatch.setattr(requests, "post", mock_post)
|
||||
|
||||
info = backend._provisioner_create("thread-1", "abc123")
|
||||
assert info.sandbox_id == "abc123"
|
||||
assert info.sandbox_url == "http://k3s:31001"
|
||||
|
||||
|
||||
def test_provisioner_create_raises_runtime_error_on_request_exception(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_post(url: str, json: dict, timeout: int):
|
||||
raise requests.RequestException("boom")
|
||||
|
||||
monkeypatch.setattr(requests, "post", mock_post)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Provisioner create failed"):
|
||||
backend._provisioner_create("thread-1", "abc123")
|
||||
|
||||
|
||||
def test_destroy_delegates_to_provisioner_destroy(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
called: list[str] = []
|
||||
|
||||
def mock_destroy(sandbox_id: str):
|
||||
called.append(sandbox_id)
|
||||
|
||||
monkeypatch.setattr(backend, "_provisioner_destroy", mock_destroy)
|
||||
|
||||
backend.destroy(SandboxInfo(sandbox_id="abc123", sandbox_url="http://k3s:31001"))
|
||||
assert called == ["abc123"]
|
||||
|
||||
|
||||
def test_provisioner_destroy_calls_delete(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_delete(url: str, timeout: int):
|
||||
assert url == "http://provisioner:8002/api/sandboxes/abc123"
|
||||
assert timeout == 15
|
||||
return _StubResponse(status_code=200)
|
||||
|
||||
monkeypatch.setattr(requests, "delete", mock_delete)
|
||||
|
||||
backend._provisioner_destroy("abc123")
|
||||
|
||||
|
||||
def test_provisioner_destroy_swallows_request_exception(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_delete(url: str, timeout: int):
|
||||
raise requests.RequestException("network down")
|
||||
|
||||
monkeypatch.setattr(requests, "delete", mock_delete)
|
||||
|
||||
backend._provisioner_destroy("abc123")
|
||||
|
||||
|
||||
def test_is_alive_delegates_to_provisioner_is_alive(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_is_alive(sandbox_id: str):
|
||||
assert sandbox_id == "abc123"
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(backend, "_provisioner_is_alive", mock_is_alive)
|
||||
|
||||
alive = backend.is_alive(SandboxInfo(sandbox_id="abc123", sandbox_url="http://k3s:31001"))
|
||||
assert alive is True
|
||||
|
||||
|
||||
def test_provisioner_is_alive_true_only_when_status_running(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get_running(url: str, timeout: int):
|
||||
return _StubResponse(payload={"status": "Running"})
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get_running)
|
||||
assert backend._provisioner_is_alive("abc123") is True
|
||||
|
||||
def mock_get_pending(url: str, timeout: int):
|
||||
return _StubResponse(payload={"status": "Pending"})
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get_pending)
|
||||
assert backend._provisioner_is_alive("abc123") is False
|
||||
|
||||
|
||||
def test_provisioner_is_alive_returns_false_on_request_exception(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get(url: str, timeout: int):
|
||||
raise requests.RequestException("boom")
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
assert backend._provisioner_is_alive("abc123") is False
|
||||
|
||||
|
||||
def test_discover_delegates_to_provisioner_discover(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
expected = SandboxInfo(sandbox_id="abc123", sandbox_url="http://k3s:31001")
|
||||
|
||||
def mock_discover(sandbox_id: str):
|
||||
assert sandbox_id == "abc123"
|
||||
return expected
|
||||
|
||||
monkeypatch.setattr(backend, "_provisioner_discover", mock_discover)
|
||||
|
||||
result = backend.discover("abc123")
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_provisioner_discover_returns_none_on_404(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get(url: str, timeout: int):
|
||||
return _StubResponse(status_code=404)
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
|
||||
assert backend._provisioner_discover("abc123") is None
|
||||
|
||||
|
||||
def test_provisioner_discover_returns_info_on_success(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get(url: str, timeout: int):
|
||||
return _StubResponse(payload={"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"})
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
|
||||
info = backend._provisioner_discover("abc123")
|
||||
assert info is not None
|
||||
assert info.sandbox_id == "abc123"
|
||||
assert info.sandbox_url == "http://k3s:31001"
|
||||
|
||||
|
||||
def test_provisioner_discover_returns_none_on_request_exception(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get(url: str, timeout: int):
|
||||
raise requests.RequestException("boom")
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
|
||||
assert backend._provisioner_discover("abc123") is None
|
||||
@@ -310,6 +310,28 @@ class TestDbRunEventStore:
|
||||
|
||||
await close_engine()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_structured_content_round_trips(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
s = DbRunEventStore(get_session_factory())
|
||||
|
||||
content = [{"type": "text", "text": "hello"}, {"type": "image_url", "image_url": {"url": "https://example.test/a.png"}}]
|
||||
record = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content=content)
|
||||
|
||||
assert record["content"] == content
|
||||
assert record["metadata"]["content_is_json"] is True
|
||||
assert "content_is_dict" not in record["metadata"]
|
||||
|
||||
messages = await s.list_messages("t1")
|
||||
assert messages[0]["content"] == content
|
||||
assert messages[0]["metadata"]["content_is_json"] is True
|
||||
|
||||
await close_engine()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_pagination(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
@@ -373,6 +395,55 @@ class TestDbRunEventStore:
|
||||
assert seqs == list(range(1, 51))
|
||||
await close_engine()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_put_batch_accepts_structured_content(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
s = DbRunEventStore(get_session_factory())
|
||||
|
||||
content = [{"messages": [{"type": "ai", "content": ""}]}]
|
||||
results = await s.put_batch(
|
||||
[
|
||||
{
|
||||
"thread_id": "t1",
|
||||
"run_id": "r1",
|
||||
"event_type": "run.end",
|
||||
"category": "outputs",
|
||||
"content": content,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
assert results[0]["content"] == content
|
||||
assert results[0]["metadata"]["content_is_json"] is True
|
||||
|
||||
events = await s.list_events("t1", "r1")
|
||||
assert events[0]["content"] == content
|
||||
assert events[0]["metadata"]["content_is_json"] is True
|
||||
|
||||
await close_engine()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_dict_content_keeps_legacy_metadata_flag(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
s = DbRunEventStore(get_session_factory())
|
||||
|
||||
content = {"status": "success"}
|
||||
record = await s.put(thread_id="t1", run_id="r1", event_type="run.end", category="outputs", content=content)
|
||||
|
||||
assert record["content"] == content
|
||||
assert record["metadata"]["content_is_json"] is True
|
||||
assert record["metadata"]["content_is_dict"] is True
|
||||
|
||||
await close_engine()
|
||||
|
||||
|
||||
# -- Factory tests --
|
||||
|
||||
|
||||
@@ -166,6 +166,61 @@ class TestRunRepository:
|
||||
assert row["total_tokens"] == 100
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("success-run", thread_id="t1", status="running")
|
||||
await repo.update_run_completion(
|
||||
"success-run",
|
||||
status="success",
|
||||
total_input_tokens=70,
|
||||
total_output_tokens=30,
|
||||
total_tokens=100,
|
||||
lead_agent_tokens=80,
|
||||
subagent_tokens=15,
|
||||
middleware_tokens=5,
|
||||
)
|
||||
await repo.put("error-run", thread_id="t1", status="running")
|
||||
await repo.update_run_completion(
|
||||
"error-run",
|
||||
status="error",
|
||||
total_input_tokens=20,
|
||||
total_output_tokens=30,
|
||||
total_tokens=50,
|
||||
lead_agent_tokens=40,
|
||||
subagent_tokens=10,
|
||||
)
|
||||
await repo.put("running-run", thread_id="t1", status="running")
|
||||
await repo.update_run_completion(
|
||||
"running-run",
|
||||
status="running",
|
||||
total_input_tokens=900,
|
||||
total_output_tokens=99,
|
||||
total_tokens=999,
|
||||
lead_agent_tokens=999,
|
||||
)
|
||||
await repo.put("other-thread-run", thread_id="t2", status="running")
|
||||
await repo.update_run_completion(
|
||||
"other-thread-run",
|
||||
status="success",
|
||||
total_tokens=888,
|
||||
lead_agent_tokens=888,
|
||||
)
|
||||
|
||||
agg = await repo.aggregate_tokens_by_thread("t1")
|
||||
|
||||
assert agg["total_tokens"] == 150
|
||||
assert agg["total_input_tokens"] == 90
|
||||
assert agg["total_output_tokens"] == 60
|
||||
assert agg["total_runs"] == 2
|
||||
assert agg["by_model"] == {"unknown": {"tokens": 150, "runs": 2}}
|
||||
assert agg["by_caller"] == {
|
||||
"lead_agent": 120,
|
||||
"subagent": 25,
|
||||
"middleware": 5,
|
||||
}
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_ordered_desc(self, tmp_path):
|
||||
"""list_by_thread returns newest first."""
|
||||
|
||||
@@ -3,6 +3,8 @@ from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, call
|
||||
|
||||
import pytest
|
||||
from langgraph.checkpoint.base import empty_checkpoint
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from deerflow.runtime.runs.manager import RunManager
|
||||
from deerflow.runtime.runs.schemas import RunStatus
|
||||
@@ -16,6 +18,14 @@ class FakeCheckpointer:
|
||||
self.aput_writes = AsyncMock()
|
||||
|
||||
|
||||
def _make_checkpoint(checkpoint_id: str, messages: list[str], version: int):
|
||||
checkpoint = empty_checkpoint()
|
||||
checkpoint["id"] = checkpoint_id
|
||||
checkpoint["channel_values"] = {"messages": messages}
|
||||
checkpoint["channel_versions"] = {"messages": version}
|
||||
return checkpoint
|
||||
|
||||
|
||||
def test_build_runtime_context_includes_app_config_when_present():
|
||||
app_config = object()
|
||||
|
||||
@@ -110,16 +120,16 @@ async def test_rollback_restores_snapshot_without_deleting_thread():
|
||||
)
|
||||
|
||||
checkpointer.adelete_thread.assert_not_awaited()
|
||||
checkpointer.aput.assert_awaited_once_with(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
|
||||
{
|
||||
"id": "ckpt-1",
|
||||
"channel_versions": {"messages": 3},
|
||||
"channel_values": {"messages": ["before"]},
|
||||
},
|
||||
{"source": "input"},
|
||||
{"messages": 3},
|
||||
)
|
||||
checkpointer.aput.assert_awaited_once()
|
||||
restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
|
||||
assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
|
||||
assert restored_checkpoint["id"] != "ckpt-1"
|
||||
assert "channel_versions" in restored_checkpoint
|
||||
assert "channel_values" in restored_checkpoint
|
||||
assert restored_checkpoint["channel_versions"] == {"messages": 3}
|
||||
assert restored_checkpoint["channel_values"] == {"messages": ["before"]}
|
||||
assert restored_metadata == {"source": "input"}
|
||||
assert new_versions == {"messages": 3}
|
||||
assert checkpointer.aput_writes.await_args_list == [
|
||||
call(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
|
||||
@@ -134,6 +144,40 @@ async def test_rollback_restores_snapshot_without_deleting_thread():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_restored_checkpoint_becomes_latest_with_real_checkpointer():
|
||||
checkpointer = InMemorySaver()
|
||||
thread_config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
|
||||
before_checkpoint = _make_checkpoint("0001", ["before"], 1)
|
||||
before_config = checkpointer.put(thread_config, before_checkpoint, {"step": 1}, {"messages": 1})
|
||||
after_checkpoint = _make_checkpoint("0002", ["after"], 2)
|
||||
after_config = checkpointer.put(before_config, after_checkpoint, {"step": 2}, {"messages": 2})
|
||||
checkpointer.put_writes(after_config, [("messages", "pending-after")], task_id="task-after")
|
||||
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id="0001",
|
||||
pre_run_snapshot={
|
||||
"checkpoint_ns": "",
|
||||
"checkpoint": before_checkpoint,
|
||||
"metadata": {"step": 1},
|
||||
"pending_writes": [("task-before", "messages", "pending-before")],
|
||||
},
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
latest = checkpointer.get_tuple(thread_config)
|
||||
|
||||
assert latest is not None
|
||||
assert latest.config["configurable"]["checkpoint_id"] != "0001"
|
||||
assert latest.config["configurable"]["checkpoint_id"] != "0002"
|
||||
assert latest.checkpoint["channel_values"] == {"messages": ["before"]}
|
||||
assert latest.pending_writes == [("task-before", "messages", "pending-before")]
|
||||
assert ("task-after", "messages", "pending-after") not in latest.pending_writes
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_deletes_thread_when_no_snapshot_exists():
|
||||
checkpointer = FakeCheckpointer(put_result=None)
|
||||
@@ -194,12 +238,13 @@ async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace():
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
checkpointer.aput.assert_awaited_once_with(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
|
||||
{"id": "ckpt-1", "channel_versions": {}},
|
||||
{},
|
||||
{},
|
||||
)
|
||||
checkpointer.aput.assert_awaited_once()
|
||||
restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
|
||||
assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
|
||||
assert restored_checkpoint["id"] != "ckpt-1"
|
||||
assert restored_checkpoint["channel_versions"] == {}
|
||||
assert restored_metadata == {}
|
||||
assert new_versions == {}
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
|
||||
@@ -7,6 +7,7 @@ import yaml
|
||||
|
||||
from deerflow.config import app_config as app_config_module
|
||||
from deerflow.config import extensions_config as extensions_config_module
|
||||
from deerflow.config import skills_config as skills_config_module
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
from deerflow.config.paths import Paths
|
||||
@@ -35,6 +36,7 @@ def test_default_runtime_paths_resolve_from_current_project(tmp_path: Path, monk
|
||||
encoding="utf-8",
|
||||
)
|
||||
(tmp_path / "extensions_config.json").write_text('{"mcpServers": {}, "skills": {}}', encoding="utf-8")
|
||||
(tmp_path / "skills").mkdir()
|
||||
|
||||
assert AppConfig.resolve_config_path() == tmp_path / "config.yaml"
|
||||
assert ExtensionsConfig.resolve_config_path() == tmp_path / "extensions_config.json"
|
||||
@@ -121,6 +123,40 @@ def test_app_config_falls_back_to_legacy_when_project_root_lacks_config(tmp_path
|
||||
assert AppConfig.resolve_config_path() == legacy_backend_config
|
||||
|
||||
|
||||
def test_skills_config_falls_back_to_legacy_when_project_root_lacks_skills(tmp_path: Path, monkeypatch):
|
||||
"""When DEER_FLOW_PROJECT_ROOT is unset and cwd has no `skills/`, the legacy
|
||||
repo-root candidate must be used so monorepo runs (cwd=backend/) keep finding
|
||||
`<repo>/skills` instead of `<repo>/backend/skills` (regression test for #2694)."""
|
||||
_clear_path_env(monkeypatch)
|
||||
cwd = tmp_path / "cwd"
|
||||
cwd.mkdir()
|
||||
monkeypatch.chdir(cwd)
|
||||
|
||||
legacy_skills = tmp_path / "legacy-repo" / "skills"
|
||||
legacy_skills.mkdir(parents=True)
|
||||
|
||||
monkeypatch.setattr(
|
||||
skills_config_module,
|
||||
"_legacy_skills_candidates",
|
||||
lambda: (legacy_skills,),
|
||||
)
|
||||
|
||||
assert SkillsConfig().get_skills_path() == legacy_skills
|
||||
|
||||
|
||||
def test_skills_config_returns_project_default_when_neither_exists(tmp_path: Path, monkeypatch):
|
||||
"""When nothing exists, fall back to the project-root default path so callers
|
||||
surface a stable empty location instead of silently picking a stale legacy dir."""
|
||||
_clear_path_env(monkeypatch)
|
||||
cwd = tmp_path / "cwd"
|
||||
cwd.mkdir()
|
||||
monkeypatch.chdir(cwd)
|
||||
|
||||
monkeypatch.setattr(skills_config_module, "_legacy_skills_candidates", lambda: ())
|
||||
|
||||
assert SkillsConfig().get_skills_path() == cwd / "skills"
|
||||
|
||||
|
||||
def test_extensions_config_falls_back_to_legacy_when_project_root_lacks_file(tmp_path: Path, monkeypatch):
|
||||
"""ExtensionsConfig should hit the legacy backend/repo-root locations when
|
||||
the caller project root has no extensions_config.json/mcp_config.json."""
|
||||
|
||||
@@ -0,0 +1,308 @@
|
||||
"""Unit tests for the Serper community web search tool."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_api_key_warned():
|
||||
"""Reset the module-level warning flag before each test."""
|
||||
import deerflow.community.serper.tools as serper_mod
|
||||
|
||||
serper_mod._api_key_warned = False
|
||||
yield
|
||||
serper_mod._api_key_warned = False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_with_key():
|
||||
with patch("deerflow.community.serper.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": "test-serper-key", "max_results": 5}
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_no_key():
|
||||
with patch("deerflow.community.serper.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {}
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
yield mock
|
||||
|
||||
|
||||
def _make_serper_response(organic: list) -> MagicMock:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = {"organic": organic}
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
return mock_resp
|
||||
|
||||
|
||||
class TestGetApiKey:
|
||||
def test_returns_config_key_when_present(self):
|
||||
with patch("deerflow.community.serper.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": "from-config"}
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
|
||||
from deerflow.community.serper.tools import _get_api_key
|
||||
|
||||
assert _get_api_key() == "from-config"
|
||||
|
||||
def test_falls_back_to_env_when_config_key_empty(self):
|
||||
with patch("deerflow.community.serper.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": ""}
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}):
|
||||
from deerflow.community.serper.tools import _get_api_key
|
||||
|
||||
assert _get_api_key() == "env-key"
|
||||
|
||||
def test_falls_back_to_env_when_config_key_whitespace(self):
|
||||
with patch("deerflow.community.serper.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": " "}
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}):
|
||||
from deerflow.community.serper.tools import _get_api_key
|
||||
|
||||
assert _get_api_key() == "env-key"
|
||||
|
||||
def test_falls_back_to_env_when_config_key_null(self):
|
||||
with patch("deerflow.community.serper.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": None}
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}):
|
||||
from deerflow.community.serper.tools import _get_api_key
|
||||
|
||||
assert _get_api_key() == "env-key"
|
||||
|
||||
def test_falls_back_to_env_when_no_config(self):
|
||||
with patch("deerflow.community.serper.tools.get_app_config") as mock:
|
||||
mock.return_value.get_tool_config.return_value = None
|
||||
with patch.dict("os.environ", {"SERPER_API_KEY": "env-only"}):
|
||||
from deerflow.community.serper.tools import _get_api_key
|
||||
|
||||
assert _get_api_key() == "env-only"
|
||||
|
||||
def test_returns_none_when_no_key_anywhere(self):
|
||||
with patch("deerflow.community.serper.tools.get_app_config") as mock:
|
||||
mock.return_value.get_tool_config.return_value = None
|
||||
with patch.dict("os.environ", {}, clear=True):
|
||||
import os
|
||||
|
||||
os.environ.pop("SERPER_API_KEY", None)
|
||||
from deerflow.community.serper.tools import _get_api_key
|
||||
|
||||
assert _get_api_key() is None
|
||||
|
||||
|
||||
class TestWebSearchTool:
|
||||
def test_basic_search_returns_normalized_results(self, mock_config_with_key):
|
||||
organic = [
|
||||
{"title": "Result 1", "link": "https://example.com/1", "snippet": "Snippet 1"},
|
||||
{"title": "Result 2", "link": "https://example.com/2", "snippet": "Snippet 2"},
|
||||
]
|
||||
mock_resp = _make_serper_response(organic)
|
||||
|
||||
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "python tutorial"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert parsed["query"] == "python tutorial"
|
||||
assert parsed["total_results"] == 2
|
||||
assert parsed["results"][0]["title"] == "Result 1"
|
||||
assert parsed["results"][0]["url"] == "https://example.com/1"
|
||||
assert parsed["results"][0]["content"] == "Snippet 1"
|
||||
|
||||
def test_respects_max_results_from_config(self, mock_config_with_key):
|
||||
mock_config_with_key.return_value.get_tool_config.return_value.model_extra = {
|
||||
"api_key": "test-key",
|
||||
"max_results": 3,
|
||||
}
|
||||
organic = [{"title": f"R{i}", "link": f"https://x.com/{i}", "snippet": f"S{i}"} for i in range(10)]
|
||||
mock_resp = _make_serper_response(organic)
|
||||
|
||||
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert parsed["total_results"] == 3
|
||||
assert len(parsed["results"]) == 3
|
||||
|
||||
def test_max_results_parameter_accepted(self, mock_config_no_key):
|
||||
"""Tool accepts max_results as a call parameter when config does not override it."""
|
||||
organic = [{"title": f"R{i}", "link": f"https://x.com/{i}", "snippet": f"S{i}"} for i in range(10)]
|
||||
mock_resp = _make_serper_response(organic)
|
||||
|
||||
with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}):
|
||||
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test", "max_results": 2})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert parsed["total_results"] == 2
|
||||
|
||||
def test_config_max_results_overrides_parameter(self):
|
||||
"""Config max_results overrides the parameter passed at call time, matching ddg_search behaviour."""
|
||||
with patch("deerflow.community.serper.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": "test-key", "max_results": 3}
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
|
||||
organic = [{"title": f"R{i}", "link": f"https://x.com/{i}", "snippet": f"S{i}"} for i in range(10)]
|
||||
mock_resp = _make_serper_response(organic)
|
||||
|
||||
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test", "max_results": 8})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert parsed["total_results"] == 3
|
||||
|
||||
def test_empty_organic_returns_error_json(self, mock_config_with_key):
|
||||
"""Empty organic list returns structured error, matching ddg_search convention."""
|
||||
mock_resp = _make_serper_response([])
|
||||
|
||||
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "no results"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert "error" in parsed
|
||||
assert parsed["error"] == "No results found"
|
||||
assert parsed["query"] == "no results"
|
||||
|
||||
def test_missing_api_key_returns_error_json(self, mock_config_no_key):
|
||||
with patch.dict("os.environ", {}, clear=True):
|
||||
import os
|
||||
|
||||
os.environ.pop("SERPER_API_KEY", None)
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert "error" in parsed
|
||||
assert "SERPER_API_KEY" in parsed["error"]
|
||||
|
||||
def test_missing_api_key_logs_warning_once(self, mock_config_no_key, caplog):
|
||||
import logging
|
||||
|
||||
with patch.dict("os.environ", {}, clear=True):
|
||||
import os
|
||||
|
||||
os.environ.pop("SERPER_API_KEY", None)
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="deerflow.community.serper.tools"):
|
||||
web_search_tool.invoke({"query": "q1"})
|
||||
web_search_tool.invoke({"query": "q2"})
|
||||
|
||||
warnings = [r for r in caplog.records if r.levelno == logging.WARNING]
|
||||
assert len(warnings) == 1
|
||||
|
||||
def test_http_error_returns_structured_error(self, mock_config_with_key):
|
||||
mock_error_response = MagicMock()
|
||||
mock_error_response.status_code = 403
|
||||
mock_error_response.text = "Forbidden"
|
||||
|
||||
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = httpx.HTTPStatusError("403", request=MagicMock(), response=mock_error_response)
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert "error" in parsed
|
||||
assert "403" in parsed["error"]
|
||||
|
||||
def test_network_exception_returns_error_json(self, mock_config_with_key):
|
||||
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = Exception("timeout")
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert "error" in parsed
|
||||
|
||||
def test_sends_correct_headers_and_payload(self, mock_config_with_key):
|
||||
organic = [{"title": "T", "link": "https://x.com", "snippet": "S"}]
|
||||
mock_resp = _make_serper_response(organic)
|
||||
|
||||
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
|
||||
mock_post = mock_client_cls.return_value.__enter__.return_value.post
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
web_search_tool.invoke({"query": "hello world"})
|
||||
|
||||
call_kwargs = mock_post.call_args
|
||||
headers = call_kwargs.kwargs["headers"]
|
||||
payload = call_kwargs.kwargs["json"]
|
||||
|
||||
assert headers["X-API-KEY"] == "test-serper-key"
|
||||
assert payload["q"] == "hello world"
|
||||
assert payload["num"] == 5
|
||||
|
||||
def test_uses_env_key_when_config_absent(self):
|
||||
with patch("deerflow.community.serper.tools.get_app_config") as mock:
|
||||
mock.return_value.get_tool_config.return_value = None
|
||||
with patch.dict("os.environ", {"SERPER_API_KEY": "env-only-key"}):
|
||||
organic = [{"title": "T", "link": "https://x.com", "snippet": "S"}]
|
||||
mock_resp = _make_serper_response(organic)
|
||||
|
||||
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
|
||||
mock_post = mock_client_cls.return_value.__enter__.return_value.post
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
web_search_tool.invoke({"query": "env key test"})
|
||||
headers = mock_post.call_args.kwargs["headers"]
|
||||
|
||||
assert headers["X-API-KEY"] == "env-only-key"
|
||||
|
||||
def test_partial_fields_in_organic_result(self, mock_config_with_key):
|
||||
"""Missing title/link/snippet should default to empty string."""
|
||||
organic = [{}]
|
||||
mock_resp = _make_serper_response(organic)
|
||||
|
||||
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert parsed["results"][0] == {"title": "", "url": "", "content": ""}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user