Compare commits

..

1 Commits

Author SHA1 Message Date
greatmengqi edf345cd72 refactor(config): eliminate global mutable state, wire DeerFlowContext into runtime
- Freeze all config models (AppConfig + 15 sub-configs) with frozen=True
- Purify from_file() — remove 9 load_*_from_dict() side-effect calls
- Replace mtime/reload/push/pop machinery with single ContextVar + init_app_config()
- Delete 10 sub-module globals and their getters/setters/loaders
- Migrate 50+ consumers from get_*_config() to get_app_config().xxx

- Expand DeerFlowContext: app_config + thread_id + agent_name (frozen dataclass)
- Wire into Gateway runtime (worker.py) and DeerFlowClient via context= parameter
- Remove sandbox_id from runtime.context — flows through ThreadState.sandbox only
- Middleware/tools access runtime.context directly via Runtime[DeerFlowContext] generic
- resolve_context() retained at server entry points for LangGraph Server fallback
2026-04-14 01:18:19 +08:00
234 changed files with 6080 additions and 13795 deletions
-63
View File
@@ -1,63 +0,0 @@
name: E2E Tests
on:
push:
branches: [ 'main' ]
paths:
- 'frontend/**'
- '.github/workflows/e2e-tests.yml'
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- 'frontend/**'
- '.github/workflows/e2e-tests.yml'
concurrency:
group: e2e-tests-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
e2e-tests:
if: ${{ github.event_name != 'pull_request' || github.event.pull_request.draft == false }}
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- name: Checkout
uses: actions/checkout@v6
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '22'
- name: Enable Corepack
run: corepack enable
- name: Use pinned pnpm version
run: corepack prepare pnpm@10.26.2 --activate
- name: Install frontend dependencies
working-directory: frontend
run: pnpm install --frozen-lockfile
- name: Install Playwright Chromium
working-directory: frontend
run: npx playwright install chromium --with-deps
- name: Run E2E tests
working-directory: frontend
run: pnpm exec playwright test
env:
SKIP_ENV_VALIDATION: '1'
- name: Upload Playwright report
uses: actions/upload-artifact@v4
if: ${{ !cancelled() }}
with:
name: playwright-report
path: frontend/playwright-report/
retention-days: 7
-43
View File
@@ -1,43 +0,0 @@
name: Frontend Unit Tests
on:
push:
branches: [ 'main' ]
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
concurrency:
group: frontend-unit-tests-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
frontend-unit-tests:
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- name: Checkout
uses: actions/checkout@v6
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '22'
- name: Enable Corepack
run: corepack enable
- name: Use pinned pnpm version
run: corepack prepare pnpm@10.26.2 --activate
- name: Install frontend dependencies
working-directory: frontend
run: pnpm install --frozen-lockfile
- name: Run unit tests of frontend
working-directory: frontend
run: make test
-3
View File
@@ -40,7 +40,6 @@ coverage/
skills/custom/*
logs/
log/
debug.log
# Local git hooks (keep only on this machine, do not push)
.githooks/
@@ -56,7 +55,5 @@ web/
backend/Dockerfile.langgraph
config.yaml.bak
.playwright-mcp
/frontend/test-results/
/frontend/playwright-report/
.gstack/
.worktrees
-33
View File
@@ -1,33 +0,0 @@
repos:
# Backend: ruff lint + format via uv (uses the same ruff version as backend deps)
- repo: local
hooks:
- id: ruff
name: ruff lint
entry: bash -c 'cd backend && uv run ruff check --fix "${@/#backend\//}"' --
language: system
types_or: [python]
files: ^backend/
- id: ruff-format
name: ruff format
entry: bash -c 'cd backend && uv run ruff format "${@/#backend\//}"' --
language: system
types_or: [python]
files: ^backend/
# Frontend: eslint + prettier (must run from frontend/ for node_modules resolution)
- repo: local
hooks:
- id: frontend-eslint
name: eslint (frontend)
entry: bash -c 'cd frontend && npx eslint --fix "${@/#frontend\//}"' --
language: system
types_or: [javascript, tsx, ts]
files: ^frontend/
- id: frontend-prettier
name: prettier (frontend)
entry: bash -c 'cd frontend && npx prettier --write "${@/#frontend\//}"' --
language: system
files: ^frontend/
types_or: [javascript, tsx, ts, json, css]
+7 -12
View File
@@ -166,7 +166,7 @@ Required tools:
1. **Configure the application** (same as Docker setup above)
2. **Install dependencies** (this also sets up pre-commit hooks):
2. **Install dependencies**:
```bash
make install
```
@@ -298,24 +298,19 @@ Nginx (port 2026) ← Unified entry point
```bash
# Backend tests
cd backend
make test
uv run pytest
# Frontend unit tests
# Frontend checks
cd frontend
make test
# Frontend E2E tests (requires Chromium; builds and auto-starts the Next.js production server)
cd frontend
make test-e2e
pnpm check
```
### PR Regression Checks
Every pull request triggers the following CI workflows:
Every pull request runs the backend regression workflow at [.github/workflows/backend-unit-tests.yml](.github/workflows/backend-unit-tests.yml), including:
- **Backend unit tests** — [.github/workflows/backend-unit-tests.yml](.github/workflows/backend-unit-tests.yml)
- **Frontend unit tests** — [.github/workflows/frontend-unit-tests.yml](.github/workflows/frontend-unit-tests.yml)
- **Frontend E2E tests** — [.github/workflows/e2e-tests.yml](.github/workflows/e2e-tests.yml) (triggered only when `frontend/` files change)
- `tests/test_provisioner_kubeconfig.py`
- `tests/test_docker_sandbox_mode_detection.py`
## Code Style
+2 -4
View File
@@ -23,7 +23,7 @@ help:
@echo " make config - Generate local config files (aborts if config already exists)"
@echo " make config-upgrade - Merge new fields from config.example.yaml into config.yaml"
@echo " make check - Check if all required tools are installed"
@echo " make install - Install all dependencies (frontend + backend + pre-commit hooks)"
@echo " make install - Install all dependencies (frontend + backend)"
@echo " make setup-sandbox - Pre-pull sandbox container image (recommended)"
@echo " make dev - Start all services in development mode (with hot-reloading)"
@echo " make dev-pro - Start in dev + Gateway mode (experimental, no LangGraph server)"
@@ -73,8 +73,6 @@ install:
@cd backend && uv sync
@echo "Installing frontend dependencies..."
@cd frontend && pnpm install
@echo "Installing pre-commit hooks..."
@$(BACKEND_UV_RUN) --with pre-commit pre-commit install
@echo "✓ All dependencies installed"
@echo ""
@echo "=========================================="
@@ -101,7 +99,7 @@ setup-sandbox:
echo ""; \
if command -v container >/dev/null 2>&1 && [ "$$(uname)" = "Darwin" ]; then \
echo "Detected Apple Container on macOS, pulling image..."; \
container image pull "$$IMAGE" || echo "⚠ Apple Container pull failed, will try Docker"; \
container pull "$$IMAGE" || echo "⚠ Apple Container pull failed, will try Docker"; \
fi; \
if command -v docker >/dev/null 2>&1; then \
echo "Pulling image using Docker..."; \
+1 -3
View File
@@ -264,7 +264,7 @@ On Windows, run the local development flow from Git Bash. Native `cmd.exe` and P
2. **Install dependencies**:
```bash
make install # Install backend + frontend dependencies + pre-commit hooks
make install # Install backend + frontend dependencies
```
3. **(Optional) Pre-pull sandbox image**:
@@ -658,8 +658,6 @@ This is the difference between a chatbot with tool access and an agent with an a
**Summarization**: Within a session, DeerFlow manages context aggressively — summarizing completed sub-tasks, offloading intermediate results to the filesystem, compressing what's no longer immediately relevant. This lets it stay sharp across long, multi-step tasks without blowing the context window.
**Strict Tool-Call Recovery**: When a provider or middleware interrupts a tool-call loop, DeerFlow now strips provider-level raw tool-call metadata on forced-stop assistant messages and injects placeholder tool results for dangling calls before the next model invocation. This keeps OpenAI-compatible reasoning models that strictly validate `tool_call_id` sequences from failing with malformed history errors.
### Long-Term Memory
Most agents forget everything the moment a conversation ends. DeerFlow remembers.
+13 -17
View File
@@ -156,26 +156,20 @@ from deerflow.config import get_app_config
### Middleware Chain
Lead-agent middlewares are assembled in strict append order across `packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py` (`build_lead_runtime_middlewares`) and `packages/harness/deerflow/agents/lead_agent/agent.py` (`_build_middlewares`):
Middlewares execute in strict order in `packages/harness/deerflow/agents/lead_agent/agent.py`:
1. **ThreadDataMiddleware** - Creates per-thread directories (`backend/.deer-flow/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local `.deer-flow/threads/{thread_id}` directory
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption), including raw provider tool-call payloads preserved only in `additional_kwargs["tool_calls"]`
5. **LLMErrorHandlingMiddleware** - Normalizes provider/model invocation failures into recoverable assistant-facing errors before later middleware/tool stages run
6. **GuardrailMiddleware** - Pre-tool-call authorization via pluggable `GuardrailProvider` protocol (optional, if `guardrails.enabled` in config). Evaluates each tool call and returns error ToolMessage on deny. Three provider options: built-in `AllowlistProvider` (zero deps), OAP policy providers (e.g. `aport-agent-guardrails`), or custom providers. See [docs/GUARDRAILS.md](docs/GUARDRAILS.md) for setup, usage, and how to implement a provider.
7. **SandboxAuditMiddleware** - Audits sandboxed shell/file operations for security logging before tool execution continues
8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional)
12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
15. **DeferredToolFilterMiddleware** - Hides deferred tool schemas from the bound model until tool search is enabled (optional)
16. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if `subagent_enabled`)
17. **LoopDetectionMiddleware** - Detects repeated tool-call loops; hard-stop responses clear both structured `tool_calls` and raw provider tool-call metadata before forcing a final text answer
18. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption)
5. **GuardrailMiddleware** - Pre-tool-call authorization via pluggable `GuardrailProvider` protocol (optional, if `guardrails.enabled` in config). Evaluates each tool call and returns error ToolMessage on deny. Three provider options: built-in `AllowlistProvider` (zero deps), OAP policy providers (e.g. `aport-agent-guardrails`), or custom providers. See [docs/GUARDRAILS.md](docs/GUARDRAILS.md) for setup, usage, and how to implement a provider.
6. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
7. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
8. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
9. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
10. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
11. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if subagent_enabled)
12. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
### Configuration System
@@ -185,7 +179,9 @@ Setup: Copy `config.example.yaml` to `config.yaml` in the **project root** direc
**Config Versioning**: `config.example.yaml` has a `config_version` field. On startup, `AppConfig.from_file()` compares user version vs example version and emits a warning if outdated. Missing `config_version` = version 0. Run `make config-upgrade` to auto-merge missing fields. When changing the config schema, bump `config_version` in `config.example.yaml`.
**Config Caching**: `get_app_config()` caches the parsed config, but automatically reloads it when the resolved config path changes or the file's mtime increases. This keeps Gateway and LangGraph reads aligned with `config.yaml` edits without requiring a manual process restart.
**Config Lifecycle**: All config models are `frozen=True` (immutable after construction). `AppConfig.from_file()` is a pure function — no side effects on sub-module globals. `get_app_config()` is backed by a single `ContextVar`, set once via `init_app_config()` at process startup. To update config at runtime (e.g., Gateway API updates MCP/Skills), construct a new `AppConfig.from_file()` and call `init_app_config()` again. No mtime detection, no auto-reload.
**DeerFlowContext**: Per-invocation typed context for the agent execution path, injected via LangGraph `Runtime[DeerFlowContext]`. Holds `app_config: AppConfig`, `thread_id: str`, `agent_name: str | None`. Gateway runtime and `DeerFlowClient` construct full `DeerFlowContext` at invoke time; LangGraph Server path uses a fallback via `resolve_context()`. Middleware and tools access context through `resolve_context(runtime)` which returns a typed `DeerFlowContext` regardless of entry point. Mutable runtime state (`sandbox_id`) flows through `ThreadState.sandbox`, not context.
Configuration priority:
1. Explicit `config_path` argument
+3 -22
View File
@@ -23,16 +23,6 @@ _CHANNEL_REGISTRY: dict[str, str] = {
"wecom": "app.channels.wecom:WeComChannel",
}
# Keys that indicate a user has configured credentials for a channel.
_CHANNEL_CREDENTIAL_KEYS: dict[str, list[str]] = {
"discord": ["bot_token"],
"feishu": ["app_id", "app_secret"],
"slack": ["bot_token", "app_token"],
"telegram": ["bot_token"],
"wecom": ["bot_id", "bot_secret"],
"wechat": ["bot_token"],
}
_CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
@@ -77,9 +67,9 @@ class ChannelService:
@classmethod
def from_app_config(cls) -> ChannelService:
"""Create a ChannelService from the application config."""
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
config = get_app_config()
config = AppConfig.current()
channels_config = {}
# extra fields are allowed by AppConfig (extra="allow")
extra = config.model_extra or {}
@@ -98,16 +88,7 @@ class ChannelService:
if not isinstance(channel_config, dict):
continue
if not channel_config.get("enabled", False):
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
has_creds = any(not isinstance(channel_config.get(k), bool) and channel_config.get(k) is not None and str(channel_config[k]).strip() for k in cred_keys)
if has_creds:
logger.warning(
"Channel '%s' has credentials configured but is disabled. Set enabled: true under channels.%s in config.yaml to activate it.",
name,
name,
)
else:
logger.info("Channel %s is disabled, skipping", name)
logger.info("Channel %s is disabled, skipping", name)
continue
await self._start_channel(name, channel_config)
+2 -20
View File
@@ -16,31 +16,13 @@ logger = logging.getLogger(__name__)
_slack_md_converter = SlackMarkdownConverter()
def _normalize_allowed_users(allowed_users: Any) -> set[str]:
if allowed_users is None:
return set()
if isinstance(allowed_users, str):
values = [allowed_users]
elif isinstance(allowed_users, list | tuple | set):
values = allowed_users
else:
logger.warning(
"Slack allowed_users should be a list of Slack user IDs or a single Slack user ID string; treating %s as one string value",
type(allowed_users).__name__,
)
values = [allowed_users]
return {str(user_id) for user_id in values if str(user_id)}
class SlackChannel(Channel):
"""Slack IM channel using Socket Mode (WebSocket, no public IP).
Configuration keys (in ``config.yaml`` under ``channels.slack``):
- ``bot_token``: Slack Bot User OAuth Token (xoxb-...).
- ``app_token``: Slack App-Level Token (xapp-...) for Socket Mode.
- ``allowed_users``: (optional) List of allowed Slack user IDs, or a
single Slack user ID string as shorthand. Empty = allow all. Other
scalar values are treated as a single string with a warning.
- ``allowed_users``: (optional) List of allowed Slack user IDs. Empty = allow all.
"""
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
@@ -48,7 +30,7 @@ class SlackChannel(Channel):
self._socket_client = None
self._web_client = None
self._loop: asyncio.AbstractEventLoop | None = None
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
self._allowed_users: set[str] = {str(user_id) for user_id in config.get("allowed_users", [])}
async def start(self) -> None:
if self._running:
+4 -18
View File
@@ -1,4 +1,3 @@
import asyncio
import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
@@ -22,7 +21,7 @@ from app.gateway.routers import (
threads,
uploads,
)
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
# Configure logging
logging.basicConfig(
@@ -33,11 +32,6 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
# Upper bound (seconds) each lifespan shutdown hook is allowed to run.
# Bounds worker exit time so uvicorn's reload supervisor does not keep
# firing signals into a worker that is stuck waiting for shutdown cleanup.
_SHUTDOWN_HOOK_TIMEOUT_SECONDS = 5.0
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
@@ -45,7 +39,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# Load config and check necessary environment variables at startup
try:
get_app_config()
AppConfig.current()
logger.info("Configuration loaded successfully")
except Exception as e:
error_msg = f"Failed to load configuration during gateway startup: {e}"
@@ -69,19 +63,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
yield
# Stop channel service on shutdown (bounded to prevent worker hang)
# Stop channel service on shutdown
try:
from app.channels.service import stop_channel_service
await asyncio.wait_for(
stop_channel_service(),
timeout=_SHUTDOWN_HOOK_TIMEOUT_SECONDS,
)
except TimeoutError:
logger.warning(
"Channel service shutdown exceeded %.1fs; proceeding with worker exit.",
_SHUTDOWN_HOOK_TIMEOUT_SECONDS,
)
await stop_channel_service()
except Exception:
logger.exception("Failed to stop channel service")
+4 -42
View File
@@ -8,7 +8,6 @@ import yaml
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from deerflow.config.agents_api_config import get_agents_api_config
from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul
from deerflow.config.paths import get_paths
@@ -25,7 +24,6 @@ class AgentResponse(BaseModel):
description: str = Field(default="", description="Agent description")
model: str | None = Field(default=None, description="Optional model override")
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
skills: list[str] | None = Field(default=None, description="Optional skill whitelist (None=all, []=none)")
soul: str | None = Field(default=None, description="SOUL.md content")
@@ -42,7 +40,6 @@ class AgentCreateRequest(BaseModel):
description: str = Field(default="", description="Agent description")
model: str | None = Field(default=None, description="Optional model override")
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
skills: list[str] | None = Field(default=None, description="Optional skill whitelist (None=all enabled, []=none)")
soul: str = Field(default="", description="SOUL.md content — agent personality and behavioral guardrails")
@@ -52,7 +49,6 @@ class AgentUpdateRequest(BaseModel):
description: str | None = Field(default=None, description="Updated description")
model: str | None = Field(default=None, description="Updated model override")
tool_groups: list[str] | None = Field(default=None, description="Updated tool group whitelist")
skills: list[str] | None = Field(default=None, description="Updated skill whitelist (None=all, []=none)")
soul: str | None = Field(default=None, description="Updated SOUL.md content")
@@ -77,15 +73,6 @@ def _normalize_agent_name(name: str) -> str:
return name.lower()
def _require_agents_api_enabled() -> None:
"""Reject access unless the custom-agent management API is explicitly enabled."""
if not get_agents_api_config().enabled:
raise HTTPException(
status_code=403,
detail=("Custom-agent management API is disabled. Set agents_api.enabled=true to expose agent and user-profile routes over HTTP."),
)
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False) -> AgentResponse:
"""Convert AgentConfig to AgentResponse."""
soul: str | None = None
@@ -97,7 +84,6 @@ def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False
description=agent_cfg.description,
model=agent_cfg.model,
tool_groups=agent_cfg.tool_groups,
skills=agent_cfg.skills,
soul=soul,
)
@@ -114,8 +100,6 @@ async def list_agents() -> AgentsListResponse:
Returns:
List of all custom agents with their metadata and soul content.
"""
_require_agents_api_enabled()
try:
agents = list_custom_agents()
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True) for a in agents])
@@ -141,7 +125,6 @@ async def check_agent_name(name: str) -> dict:
Raises:
HTTPException: 422 if the name is invalid.
"""
_require_agents_api_enabled()
_validate_agent_name(name)
normalized = _normalize_agent_name(name)
available = not get_paths().agent_dir(normalized).exists()
@@ -166,7 +149,6 @@ async def get_agent(name: str) -> AgentResponse:
Raises:
HTTPException: 404 if agent not found.
"""
_require_agents_api_enabled()
_validate_agent_name(name)
name = _normalize_agent_name(name)
@@ -199,7 +181,6 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
Raises:
HTTPException: 409 if agent already exists, 422 if name is invalid.
"""
_require_agents_api_enabled()
_validate_agent_name(request.name)
normalized_name = _normalize_agent_name(request.name)
@@ -219,8 +200,6 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
config_data["model"] = request.model
if request.tool_groups is not None:
config_data["tool_groups"] = request.tool_groups
if request.skills is not None:
config_data["skills"] = request.skills
config_file = agent_dir / "config.yaml"
with open(config_file, "w", encoding="utf-8") as f:
@@ -264,7 +243,6 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
Raises:
HTTPException: 404 if agent not found.
"""
_require_agents_api_enabled()
_validate_agent_name(name)
name = _normalize_agent_name(name)
@@ -277,32 +255,21 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
try:
# Update config if any config fields changed
# Use model_fields_set to distinguish "field omitted" from "explicitly set to null".
# This is critical for skills where None means "inherit all" (not "don't change").
fields_set = request.model_fields_set
config_changed = bool(fields_set & {"description", "model", "tool_groups", "skills"})
config_changed = any(v is not None for v in [request.description, request.model, request.tool_groups])
if config_changed:
updated: dict = {
"name": agent_cfg.name,
"description": request.description if "description" in fields_set else agent_cfg.description,
"description": request.description if request.description is not None else agent_cfg.description,
}
new_model = request.model if "model" in fields_set else agent_cfg.model
new_model = request.model if request.model is not None else agent_cfg.model
if new_model is not None:
updated["model"] = new_model
new_tool_groups = request.tool_groups if "tool_groups" in fields_set else agent_cfg.tool_groups
new_tool_groups = request.tool_groups if request.tool_groups is not None else agent_cfg.tool_groups
if new_tool_groups is not None:
updated["tool_groups"] = new_tool_groups
# skills: None = inherit all, [] = no skills, ["a","b"] = whitelist
if "skills" in fields_set:
new_skills = request.skills
else:
new_skills = agent_cfg.skills
if new_skills is not None:
updated["skills"] = new_skills
config_file = agent_dir / "config.yaml"
with open(config_file, "w", encoding="utf-8") as f:
yaml.dump(updated, f, default_flow_style=False, allow_unicode=True)
@@ -348,8 +315,6 @@ async def get_user_profile() -> UserProfileResponse:
Returns:
UserProfileResponse with content=None if USER.md does not exist yet.
"""
_require_agents_api_enabled()
try:
user_md_path = get_paths().user_md_file
if not user_md_path.exists():
@@ -376,8 +341,6 @@ async def update_user_profile(request: UserProfileUpdateRequest) -> UserProfileR
Returns:
UserProfileResponse with the saved content.
"""
_require_agents_api_enabled()
try:
paths = get_paths()
paths.base_dir.mkdir(parents=True, exist_ok=True)
@@ -404,7 +367,6 @@ async def delete_agent(name: str) -> None:
Raises:
HTTPException: 404 if agent not found.
"""
_require_agents_api_enabled()
_validate_agent_name(name)
name = _normalize_agent_name(name)
+9 -7
View File
@@ -6,7 +6,8 @@ from typing import Literal
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from deerflow.config.extensions_config import ExtensionsConfig, get_extensions_config, reload_extensions_config
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["mcp"])
@@ -90,9 +91,9 @@ async def get_mcp_configuration() -> McpConfigResponse:
}
```
"""
config = get_extensions_config()
ext = AppConfig.current().extensions
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in config.mcp_servers.items()})
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in ext.mcp_servers.items()})
@router.put(
@@ -143,12 +144,12 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
# Load current config to preserve skills configuration
current_config = get_extensions_config()
current_ext = AppConfig.current().extensions
# Convert request to dict format for JSON serialization
config_data = {
"mcpServers": {name: server.model_dump() for name, server in request.mcp_servers.items()},
"skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()},
"skills": {name: {"enabled": skill.enabled} for name, skill in current_ext.skills.items()},
}
# Write the configuration to file
@@ -161,8 +162,9 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
# will detect config file changes via mtime and reinitialize MCP tools automatically
# Reload the configuration and update the global cache
reloaded_config = reload_extensions_config()
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded_config.mcp_servers.items()})
AppConfig.init(AppConfig.from_file())
reloaded_ext = AppConfig.current().extensions
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded_ext.mcp_servers.items()})
except Exception as e:
logger.error(f"Failed to update MCP configuration: {e}", exc_info=True)
+3 -3
View File
@@ -12,7 +12,7 @@ from deerflow.agents.memory.updater import (
reload_memory_data,
update_memory_fact,
)
from deerflow.config.memory_config import get_memory_config
from deerflow.config.app_config import AppConfig
router = APIRouter(prefix="/api", tags=["memory"])
@@ -311,7 +311,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
}
```
"""
config = get_memory_config()
config = AppConfig.current().memory
return MemoryConfigResponse(
enabled=config.enabled,
storage_path=config.storage_path,
@@ -336,7 +336,7 @@ async def get_memory_status() -> MemoryStatusResponse:
Returns:
Combined memory configuration and current data.
"""
config = get_memory_config()
config = AppConfig.current().memory
memory_data = get_memory_data()
return MemoryStatusResponse(
+8 -25
View File
@@ -1,7 +1,7 @@
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
router = APIRouter(prefix="/api", tags=["models"])
@@ -17,17 +17,10 @@ class ModelResponse(BaseModel):
supports_reasoning_effort: bool = Field(default=False, description="Whether model supports reasoning effort")
class TokenUsageResponse(BaseModel):
"""Token usage display configuration."""
enabled: bool = Field(default=False, description="Whether token usage display is enabled")
class ModelsListResponse(BaseModel):
"""Response model for listing all models."""
models: list[ModelResponse]
token_usage: TokenUsageResponse
@router.get(
@@ -43,7 +36,7 @@ async def list_models() -> ModelsListResponse:
excluding sensitive fields like API keys and internal configuration.
Returns:
A list of all configured models with their metadata and token usage display settings.
A list of all configured models with their metadata.
Example Response:
```json
@@ -51,28 +44,21 @@ async def list_models() -> ModelsListResponse:
"models": [
{
"name": "gpt-4",
"model": "gpt-4",
"display_name": "GPT-4",
"description": "OpenAI GPT-4 model",
"supports_thinking": false,
"supports_reasoning_effort": false
"supports_thinking": false
},
{
"name": "claude-3-opus",
"model": "claude-3-opus",
"display_name": "Claude 3 Opus",
"description": "Anthropic Claude 3 Opus model",
"supports_thinking": true,
"supports_reasoning_effort": false
"supports_thinking": true
}
],
"token_usage": {
"enabled": true
}
]
}
```
"""
config = get_app_config()
config = AppConfig.current()
models = [
ModelResponse(
name=model.name,
@@ -84,10 +70,7 @@ async def list_models() -> ModelsListResponse:
)
for model in config.models
]
return ModelsListResponse(
models=models,
token_usage=TokenUsageResponse(enabled=config.token_usage.enabled),
)
return ModelsListResponse(models=models)
@router.get(
@@ -118,7 +101,7 @@ async def get_model(model_name: str) -> ModelResponse:
}
```
"""
config = get_app_config()
config = AppConfig.current()
model = config.get_model_config(model_name)
if model is None:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
+19 -24
View File
@@ -1,4 +1,3 @@
import errno
import json
import logging
import shutil
@@ -9,7 +8,8 @@ from pydantic import BaseModel, Field
from app.gateway.path_utils import resolve_thread_virtual_path
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig
from deerflow.skills import Skill, load_skills
from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
from deerflow.skills.manager import (
@@ -202,23 +202,18 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
ensure_custom_skill_is_editable(skill_name)
skill_dir = get_custom_skill_dir(skill_name)
prev_content = read_custom_skill_content(skill_name)
try:
append_history(
skill_name,
{
"action": "human_delete",
"author": "human",
"thread_id": None,
"file_path": "SKILL.md",
"prev_content": prev_content,
"new_content": None,
"scanner": {"decision": "allow", "reason": "Deletion requested."},
},
)
except OSError as e:
if not isinstance(e, PermissionError) and e.errno not in {errno.EACCES, errno.EPERM, errno.EROFS}:
raise
logger.warning("Skipping delete history write for custom skill %s due to readonly/permission failure; continuing with skill directory removal: %s", skill_name, e)
append_history(
skill_name,
{
"action": "human_delete",
"author": "human",
"thread_id": None,
"file_path": "SKILL.md",
"prev_content": prev_content,
"new_content": None,
"scanner": {"decision": "allow", "reason": "Deletion requested."},
},
)
shutil.rmtree(skill_dir)
await refresh_skills_system_prompt_cache_async()
return {"success": True}
@@ -331,19 +326,19 @@ async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillRes
config_path = Path.cwd().parent / "extensions_config.json"
logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
extensions_config = get_extensions_config()
extensions_config.skills[skill_name] = SkillStateConfig(enabled=request.enabled)
ext = AppConfig.current().extensions
ext.skills[skill_name] = SkillStateConfig(enabled=request.enabled)
config_data = {
"mcpServers": {name: server.model_dump() for name, server in extensions_config.mcp_servers.items()},
"skills": {name: {"enabled": skill_config.enabled} for name, skill_config in extensions_config.skills.items()},
"mcpServers": {name: server.model_dump() for name, server in ext.mcp_servers.items()},
"skills": {name: {"enabled": skill_config.enabled} for name, skill_config in ext.skills.items()},
}
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config_data, f, indent=2)
logger.info(f"Skills configuration updated and saved to: {config_path}")
reload_extensions_config()
AppConfig.init(AppConfig.from_file())
await refresh_skills_system_prompt_cache_async()
skills = load_skills(enabled_only=False)
+1 -1
View File
@@ -121,7 +121,7 @@ async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> S
try:
model = create_chat_model(name=request.model_name, thinking_enabled=False)
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)], config={"run_name": "suggest_agent"})
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
raw = _extract_response_text(response.content)
suggestions = _parse_json_string_list(raw) or []
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
+6 -39
View File
@@ -7,9 +7,8 @@ import stat
from fastapi import APIRouter, File, HTTPException, UploadFile
from pydantic import BaseModel
from deerflow.config.app_config import get_app_config
from deerflow.config.paths import get_paths
from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
from deerflow.uploads.manager import (
PathTraversalError,
delete_file_safe,
@@ -54,34 +53,6 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
os.chmod(file_path, writable_mode, **chmod_kwargs)
def _uses_thread_data_mounts(sandbox_provider: SandboxProvider) -> bool:
return bool(getattr(sandbox_provider, "uses_thread_data_mounts", False))
def _get_uploads_config_value(key: str, default: object) -> object:
"""Read a value from the uploads config, supporting dict and attribute access."""
cfg = get_app_config()
uploads_cfg = getattr(cfg, "uploads", None)
if isinstance(uploads_cfg, dict):
return uploads_cfg.get(key, default)
return getattr(uploads_cfg, key, default)
def _auto_convert_documents_enabled() -> bool:
"""Return whether automatic host-side document conversion is enabled.
The secure default is disabled unless an operator explicitly opts in via
uploads.auto_convert_documents in config.yaml.
"""
try:
raw = _get_uploads_config_value("auto_convert_documents", False)
if isinstance(raw, str):
return raw.strip().lower() in {"1", "true", "yes", "on"}
return bool(raw)
except Exception:
return False
@router.post("", response_model=UploadResponse)
async def upload_files(
thread_id: str,
@@ -99,12 +70,8 @@ async def upload_files(
uploaded_files = []
sandbox_provider = get_sandbox_provider()
sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
sandbox = None
if sync_to_sandbox:
sandbox_id = sandbox_provider.acquire(thread_id)
sandbox = sandbox_provider.get(sandbox_id)
auto_convert_documents = _auto_convert_documents_enabled()
sandbox_id = sandbox_provider.acquire(thread_id)
sandbox = sandbox_provider.get(sandbox_id)
for file in files:
if not file.filename:
@@ -123,7 +90,7 @@ async def upload_files(
virtual_path = upload_virtual_path(safe_filename)
if sync_to_sandbox and sandbox is not None:
if sandbox_id != "local":
_make_file_sandbox_writable(file_path)
sandbox.update_file(virtual_path, content)
@@ -138,12 +105,12 @@ async def upload_files(
logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
file_ext = file_path.suffix.lower()
if auto_convert_documents and file_ext in CONVERTIBLE_EXTENSIONS:
if file_ext in CONVERTIBLE_EXTENSIONS:
md_path = await convert_file_to_markdown(file_path)
if md_path:
md_virtual_path = upload_virtual_path(md_path.name)
if sync_to_sandbox and sandbox is not None:
if sandbox_id != "local":
_make_file_sandbox_writable(md_path)
sandbox.update_file(md_virtual_path, md_path.read_bytes())
+15 -34
View File
@@ -12,7 +12,6 @@ import json
import logging
import re
import time
from collections.abc import Mapping
from typing import Any
from fastapi import HTTPException, Request
@@ -102,10 +101,9 @@ def resolve_agent_factory(assistant_id: str | None):
"""Resolve the agent factory callable from config.
Custom agents are implemented as ``lead_agent`` + an ``agent_name``
injected into ``configurable`` or ``context`` — see
:func:`build_run_config`. All ``assistant_id`` values therefore map to the
same factory; the routing happens inside ``make_lead_agent`` when it reads
``cfg["agent_name"]``.
injected into ``configurable`` — see :func:`build_run_config`. All
``assistant_id`` values therefore map to the same factory; the routing
happens inside ``make_lead_agent`` when it reads ``cfg["agent_name"]``.
"""
from deerflow.agents.lead_agent.agent import make_lead_agent
@@ -122,12 +120,10 @@ def build_run_config(
"""Build a RunnableConfig dict for the agent.
When *assistant_id* refers to a custom agent (anything other than
``"lead_agent"`` / ``None``), the name is forwarded as ``agent_name`` in
whichever runtime options container is active: ``context`` for
LangGraph >= 0.6.0 requests, otherwise ``configurable``.
``make_lead_agent`` reads this key to load the matching
``agents/<name>/SOUL.md`` and per-agent config — without it the agent
silently runs as the default lead agent.
``"lead_agent"`` / ``None``), the name is forwarded as
``configurable["agent_name"]``. ``make_lead_agent`` reads this key to
load the matching ``agents/<name>/SOUL.md`` and per-agent config —
without it the agent silently runs as the default lead agent.
This mirrors the channel manager's ``_resolve_run_params`` logic so that
the LangGraph Platform-compatible HTTP API and the IM channel path behave
@@ -146,14 +142,7 @@ def build_run_config(
thread_id,
list(request_config.get("configurable", {}).keys()),
)
context_value = request_config["context"]
if context_value is None:
context = {}
elif isinstance(context_value, Mapping):
context = dict(context_value)
else:
raise ValueError("request config 'context' must be a mapping or null.")
config["context"] = context
config["context"] = request_config["context"]
else:
configurable = {"thread_id": thread_id}
configurable.update(request_config.get("configurable", {}))
@@ -165,19 +154,13 @@ def build_run_config(
config["configurable"] = {"thread_id": thread_id}
# Inject custom agent name when the caller specified a non-default assistant.
# Honour an explicit agent_name in the active runtime options container.
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID:
normalized = assistant_id.strip().lower().replace("_", "-")
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
if "configurable" in config:
target = config["configurable"]
elif "context" in config:
target = config["context"]
else:
target = config.setdefault("configurable", {})
if target is not None and "agent_name" not in target:
target["agent_name"] = normalized
# Honour an explicit configurable["agent_name"] in the request if already set.
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID and "configurable" in config:
if "agent_name" not in config["configurable"]:
normalized = assistant_id.strip().lower().replace("_", "-")
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
config["configurable"]["agent_name"] = normalized
if metadata:
config.setdefault("metadata", {}).update(metadata)
return config
@@ -315,8 +298,6 @@ async def start_run(
"is_plan_mode",
"subagent_enabled",
"max_concurrent_subagents",
"agent_name",
"is_bootstrap",
}
configurable = config.setdefault("configurable", {})
for key in _CONTEXT_CONFIGURABLE_KEYS:
+13 -78
View File
@@ -19,78 +19,24 @@ import asyncio
import logging
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage
try:
from prompt_toolkit import PromptSession
from prompt_toolkit.history import InMemoryHistory
_HAS_PROMPT_TOOLKIT = True
except ImportError:
_HAS_PROMPT_TOOLKIT = False
from deerflow.agents import make_lead_agent
load_dotenv()
_LOG_FMT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
_LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"
def _logging_level_from_config(name: str) -> int:
"""Map ``config.yaml`` ``log_level`` string to a ``logging`` level constant."""
mapping = logging.getLevelNamesMapping()
return mapping.get((name or "info").strip().upper(), logging.INFO)
def _setup_logging(log_level: str) -> None:
"""Send application logs to ``debug.log`` at *log_level*; do not print them on the console.
Idempotent: any pre-existing handlers on the root logger (e.g. installed by
``logging.basicConfig`` in transitively imported modules) are removed so the
debug session output only lands in ``debug.log``.
"""
level = _logging_level_from_config(log_level)
root = logging.root
for h in list(root.handlers):
root.removeHandler(h)
h.close()
root.setLevel(level)
file_handler = logging.FileHandler("debug.log", mode="a", encoding="utf-8")
file_handler.setLevel(level)
file_handler.setFormatter(logging.Formatter(_LOG_FMT, datefmt=_LOG_DATEFMT))
root.addHandler(file_handler)
def _update_logging_level(log_level: str) -> None:
"""Update the root logger and existing handlers to *log_level*."""
level = _logging_level_from_config(log_level)
root = logging.root
root.setLevel(level)
for handler in root.handlers:
handler.setLevel(level)
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
async def main():
# Install file logging first so warnings emitted while loading config do not
# leak onto the interactive terminal via Python's lastResort handler.
_setup_logging("info")
from deerflow.config import get_app_config
app_config = get_app_config()
_update_logging_level(app_config.log_level)
# Delay the rest of the deerflow imports until *after* logging is installed
# so that any import-time side effects (e.g. deerflow.agents starts a
# background skill-loader thread on import) emit logs to debug.log instead
# of leaking onto the interactive terminal via Python's lastResort handler.
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
from deerflow.agents import make_lead_agent
from deerflow.mcp import initialize_mcp_tools
# Initialize MCP tools at startup
try:
from deerflow.mcp import initialize_mcp_tools
await initialize_mcp_tools()
except Exception as e:
print(f"Warning: Failed to initialize MCP tools: {e}")
@@ -106,27 +52,16 @@ async def main():
}
}
runtime = Runtime(context={"thread_id": config["configurable"]["thread_id"]})
config["configurable"]["__pregel_runtime"] = runtime
agent = make_lead_agent(config)
session = PromptSession(history=InMemoryHistory()) if _HAS_PROMPT_TOOLKIT else None
print("=" * 50)
print("Lead Agent Debug Mode")
print("Type 'quit' or 'exit' to stop")
print(f"Logs: debug.log (log_level={app_config.log_level})")
if not _HAS_PROMPT_TOOLKIT:
print("Tip: `uv sync --group dev` to enable arrow-key & history support")
print("=" * 50)
while True:
try:
if session:
user_input = (await session.prompt_async("\nYou: ")).strip()
else:
user_input = input("\nYou: ").strip()
user_input = input("\nYou: ").strip()
if not user_input:
continue
if user_input.lower() in ("quit", "exit"):
@@ -135,15 +70,15 @@ async def main():
# Invoke the agent
state = {"messages": [HumanMessage(content=user_input)]}
result = await agent.ainvoke(state, config=config)
result = await agent.ainvoke(state, config=config, context={"thread_id": "debug-thread-001"})
# Print the response
if result.get("messages"):
last_message = result["messages"][-1]
print(f"\nAgent: {last_message.content}")
except (KeyboardInterrupt, EOFError):
print("\nGoodbye!")
except KeyboardInterrupt:
print("\nInterrupted. Goodbye!")
break
except Exception as e:
print(f"\nError: {e}")
+1 -1
View File
@@ -199,7 +199,7 @@ class ThreadState(AgentState):
│ Built-in Tools │ │ Configured Tools │ │ MCP Tools │
│ (packages/harness/deerflow/tools/) │ │ (config.yaml) │ │ (extensions.json) │
├─────────────────────┤ ├─────────────────────┤ ├─────────────────────┤
│ - present_files │ │ - web_search │ │ - github │
│ - present_file │ │ - web_search │ │ - github │
│ - ask_clarification │ │ - web_fetch │ │ - filesystem │
│ - view_image │ │ - bash │ │ - postgres │
│ │ │ - read_file │ │ - brave-search │
+3 -6
View File
@@ -2,12 +2,12 @@
## 概述
DeerFlow 后端提供了完整的文件上传功能,支持多文件上传,并可选地将 Office 文档和 PDF 转换为 Markdown 格式。
DeerFlow 后端提供了完整的文件上传功能,支持多文件上传,并自动将 Office 文档和 PDF 转换为 Markdown 格式。
## 功能特性
- ✅ 支持多文件同时上传
- ✅ 可选地转换文档为 MarkdownPDF、PPT、Excel、Word
- ✅ 自动转换文档为 MarkdownPDF、PPT、Excel、Word
- ✅ 文件存储在线程隔离的目录中
- ✅ Agent 自动感知已上传的文件
- ✅ 支持文件列表查询和删除
@@ -86,7 +86,7 @@ DELETE /api/threads/{thread_id}/uploads/{filename}
## 支持的文档格式
以下格式在显式启用 `uploads.auto_convert_documents: true`会自动转换为 Markdown
以下格式会自动转换为 Markdown:
- PDF (`.pdf`)
- PowerPoint (`.ppt`, `.pptx`)
- Excel (`.xls`, `.xlsx`)
@@ -94,8 +94,6 @@ DELETE /api/threads/{thread_id}/uploads/{filename}
转换后的 Markdown 文件会保存在同一目录下,文件名为原文件名 + `.md` 扩展名。
默认情况下,自动转换是关闭的,以避免在网关主机上对不受信任的 Office/PDF 上传执行解析。只有在受信任部署中明确接受此风险时,才应将 `uploads.auto_convert_documents` 设置为 `true`
## Agent 集成
### 自动文件列举
@@ -209,7 +207,6 @@ backend/.deer-flow/threads/
- 最大文件大小:100MB(可在 nginx.conf 中配置 `client_max_body_size`
- 文件名安全性:系统会自动验证文件路径,防止目录遍历攻击
- 线程隔离:每个线程的上传文件相互隔离,无法跨线程访问
- 自动文档转换默认关闭;如需启用,需在 `config.yaml` 中显式设置 `uploads.auto_convert_documents: true`
## 技术实现
+1 -1
View File
@@ -296,7 +296,7 @@ These are the tool names your provider will see in `request.tool_name`:
| `web_search` | Web search query |
| `web_fetch` | Fetch URL content |
| `image_search` | Image search |
| `present_files` | Present file to user |
| `present_file` | Present file to user |
| `view_image` | Display image |
| `ask_clarification` | Ask user a question |
| `task` | Delegate to subagent |
-35
View File
@@ -45,41 +45,6 @@ Example:
}
```
## Custom Tool Interceptors
You can register custom interceptors that run before every MCP tool call. This is useful for injecting per-request headers (e.g., user auth tokens from the LangGraph execution context), logging, or metrics.
Declare interceptors in `extensions_config.json` using the `mcpInterceptors` field:
```json
{
"mcpInterceptors": [
"my_package.mcp.auth:build_auth_interceptor"
],
"mcpServers": { ... }
}
```
Each entry is a Python import path in `module:variable` format (resolved via `resolve_variable`). The variable must be a **no-arg builder function** that returns an async interceptor compatible with `MultiServerMCPClient`s `tool_interceptors` interface, or `None` to skip.
Example interceptor that injects auth headers from LangGraph metadata:
```python
def build_auth_interceptor():
async def interceptor(request, handler):
from langgraph.config import get_config
metadata = get_config().get("metadata", {})
headers = dict(request.headers or {})
if token := metadata.get("auth_token"):
headers["X-Auth-Token"] = token
return await handler(request.override(headers=headers))
return interceptor
```
- A single string value is accepted and normalized to a one-element list.
- Invalid paths or builder failures are logged as warnings without blocking other interceptors.
- The builder return value must be `callable`; non-callable values are skipped with a warning.
## How It Works
MCP servers expose tools that are automatically discovered and integrated into DeerFlows agent system at runtime. Once enabled, these tools become available to agents without additional code changes.
+3 -3
View File
@@ -11,7 +11,6 @@
- [x] Add Plan Mode with TodoList middleware
- [x] Add vision model support with ViewImageMiddleware
- [x] Skills system with SKILL.md format
- [x] Replace `time.sleep(5)` with `asyncio.sleep()` in `packages/harness/deerflow/tools/builtins/task_tool.py` (subagent polling)
## Planned Features
@@ -22,9 +21,10 @@
- [ ] Support for more document formats in upload
- [ ] Skill marketplace / remote skill installation
- [ ] Optimize async concurrency in agent hot path (IM channels multi-task scenario)
- [ ] Replace `subprocess.run()` with `asyncio.create_subprocess_shell()` in `packages/harness/deerflow/sandbox/local/local_sandbox.py`
- Replace `time.sleep(5)` with `asyncio.sleep()` in `packages/harness/deerflow/tools/builtins/task_tool.py` (subagent polling)
- Replace `subprocess.run()` with `asyncio.create_subprocess_shell()` in `packages/harness/deerflow/sandbox/local/local_sandbox.py`
- Replace sync `requests` with `httpx.AsyncClient` in community tools (tavily, jina_ai, firecrawl, infoquest, image_search)
- [x] Replace sync `model.invoke()` with async `model.ainvoke()` in title_middleware and memory updater
- Replace sync `model.invoke()` with async `model.ainvoke()` in title_middleware and memory updater
- Consider `asyncio.to_thread()` wrapper for remaining blocking file I/O
- For production: use `langgraph up` (multi-worker) instead of `langgraph dev` (single-worker)
-28
View File
@@ -41,13 +41,6 @@ summarization:
# Custom summary prompt (optional)
summary_prompt: null
# Tool names treated as skill file reads for skill rescue
skill_file_read_tool_names:
- read_file
- read
- view
- cat
```
### Configuration Options
@@ -132,26 +125,6 @@ keep:
- **Default**: `null` (uses LangChain's default prompt)
- **Description**: Custom prompt template for generating summaries. The prompt should guide the model to extract the most important context.
#### `preserve_recent_skill_count`
- **Type**: Integer (≥ 0)
- **Default**: `5`
- **Description**: Number of most-recently-loaded skill files (tool results whose tool name is in `skill_file_read_tool_names` and whose target path is under `skills.container_path`, e.g. `/mnt/skills/...`) that are rescued from summarization. Prevents the agent from losing skill instructions after compression. Set to `0` to disable skill rescue entirely.
#### `preserve_recent_skill_tokens`
- **Type**: Integer (≥ 0)
- **Default**: `25000`
- **Description**: Total token budget reserved for rescued skill reads. Once this budget is exhausted, older skill bundles are allowed to be summarized.
#### `preserve_recent_skill_tokens_per_skill`
- **Type**: Integer (≥ 0)
- **Default**: `5000`
- **Description**: Per-skill token cap. Any individual skill read whose tool result exceeds this size is not rescued (it falls through to the summarizer like ordinary content).
#### `skill_file_read_tool_names`
- **Type**: List of strings
- **Default**: `["read_file", "read", "view", "cat"]`
- **Description**: Tool names treated as skill file reads during summarization rescue. A tool call is only eligible for skill rescue when its name appears in this list and its target path is under `skills.container_path`.
**Default Prompt Behavior:**
The default LangChain prompt instructs the model to:
- Extract highest quality/most relevant context
@@ -174,7 +147,6 @@ The default LangChain prompt instructs the model to:
- A single summary message is added
- Recent messages are preserved
6. **AI/Tool Pair Protection**: The system ensures AI messages and their corresponding tool messages stay together
7. **Skill Rescue**: Before the summary is generated, the most recently loaded skill files (tool results whose tool name is in `skill_file_read_tool_names` and whose target path is under `skills.container_path`) are lifted out of the summarization set and prepended to the preserved tail. Selection walks newest-first under three budgets: `preserve_recent_skill_count`, `preserve_recent_skill_tokens`, and `preserve_recent_skill_tokens_per_skill`. The triggering AIMessage and all of its paired ToolMessages move together so tool_call ↔ tool_result pairing stays intact.
### Token Counting
@@ -29,7 +29,7 @@ from deerflow.agents.checkpointer.provider import (
POSTGRES_INSTALL,
SQLITE_INSTALL,
)
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@@ -94,7 +94,7 @@ async def make_checkpointer() -> AsyncIterator[Checkpointer]:
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
config = AppConfig.current()
if config.checkpointer is None:
from langgraph.checkpoint.memory import InMemorySaver
@@ -25,9 +25,9 @@ from collections.abc import Iterator
from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
from deerflow.runtime.store._sqlite_utils import resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@@ -67,7 +67,6 @@ def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
raise ImportError(SQLITE_INSTALL) from exc
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
ensure_sqlite_parent_dir(conn_str)
with SqliteSaver.from_conn_string(conn_str) as saver:
saver.setup()
logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)
@@ -114,25 +113,10 @@ def get_checkpointer() -> Checkpointer:
if _checkpointer is not None:
return _checkpointer
# Ensure app config is loaded before checking checkpointer config
# This prevents returning InMemorySaver when config.yaml actually has a checkpointer section
# but hasn't been loaded yet
from deerflow.config.app_config import _app_config
from deerflow.config.checkpointer_config import get_checkpointer_config
config = get_checkpointer_config()
if config is None and _app_config is None:
# Only load app config lazily when neither the app config nor an explicit
# checkpointer config has been initialized yet. This keeps tests that
# intentionally set the global checkpointer config isolated from any
# ambient config.yaml on disk.
try:
get_app_config()
except FileNotFoundError:
# In test environments without config.yaml, this is expected.
pass
config = get_checkpointer_config()
try:
config = AppConfig.current().checkpointer
except (LookupError, FileNotFoundError):
config = None
if config is None:
from langgraph.checkpoint.memory import InMemorySaver
@@ -181,7 +165,7 @@ def checkpointer_context() -> Iterator[Checkpointer]:
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
config = AppConfig.current()
if config.checkpointer is None:
from langgraph.checkpoint.memory import InMemorySaver
@@ -1,43 +1,32 @@
import logging
from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware
from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph
from deerflow.agents.lead_agent.prompt import apply_prompt_template
from deerflow.agents.memory.summarization_hook import memory_flush_hook
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
from deerflow.agents.middlewares.summarization_middleware import BeforeSummarizationHook, DeerFlowSummarizationMiddleware
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
from deerflow.agents.middlewares.tool_error_handling_middleware import build_lead_runtime_middlewares
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import load_agent_config, validate_agent_name
from deerflow.config.app_config import get_app_config
from deerflow.config.memory_config import get_memory_config
from deerflow.config.summarization_config import get_summarization_config
from deerflow.config.agents_config import load_agent_config
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
def _get_runtime_config(config: RunnableConfig) -> dict:
"""Merge legacy configurable options with LangGraph runtime context."""
cfg = dict(config.get("configurable", {}) or {})
context = config.get("context", {}) or {}
if isinstance(context, dict):
cfg.update(context)
return cfg
def _resolve_model_name(requested_model_name: str | None = None) -> str:
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
app_config = get_app_config()
app_config = AppConfig.current()
default_model_name = app_config.models[0].name if app_config.models else None
if default_model_name is None:
raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
@@ -50,9 +39,9 @@ def _resolve_model_name(requested_model_name: str | None = None) -> str:
return default_model_name
def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None:
def _create_summarization_middleware() -> SummarizationMiddleware | None:
"""Create and configure the summarization middleware from config."""
config = get_summarization_config()
config = AppConfig.current().summarization
if not config.enabled:
return None
@@ -89,28 +78,7 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None
if config.summary_prompt is not None:
kwargs["summary_prompt"] = config.summary_prompt
hooks: list[BeforeSummarizationHook] = []
if get_memory_config().enabled:
hooks.append(memory_flush_hook)
# The logic below relies on two assumptions holding true: this factory is
# the sole entry point for DeerFlowSummarizationMiddleware, and the runtime
# config is not expected to change after startup.
try:
skills_container_path = get_app_config().skills.container_path or "/mnt/skills"
except Exception:
logger.exception("Failed to resolve skills container path; falling back to default")
skills_container_path = "/mnt/skills"
return DeerFlowSummarizationMiddleware(
**kwargs,
skills_container_path=skills_container_path,
skill_file_read_tool_names=config.skill_file_read_tool_names,
before_summarization=hooks,
preserve_recent_skill_count=config.preserve_recent_skill_count,
preserve_recent_skill_tokens=config.preserve_recent_skill_tokens,
preserve_recent_skill_tokens_per_skill=config.preserve_recent_skill_tokens_per_skill,
)
return SummarizationMiddleware(**kwargs)
def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None:
@@ -257,14 +225,13 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
middlewares.append(summarization_middleware)
# Add TodoList middleware if plan mode is enabled
cfg = _get_runtime_config(config)
is_plan_mode = cfg.get("is_plan_mode", False)
is_plan_mode = config.get("configurable", {}).get("is_plan_mode", False)
todo_list_middleware = _create_todo_list_middleware(is_plan_mode)
if todo_list_middleware is not None:
middlewares.append(todo_list_middleware)
# Add TokenUsageMiddleware when token_usage tracking is enabled
if get_app_config().token_usage.enabled:
if AppConfig.current().token_usage.enabled:
middlewares.append(TokenUsageMiddleware())
# Add TitleMiddleware
@@ -275,7 +242,7 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
# Add ViewImageMiddleware only if the current model supports vision.
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
app_config = get_app_config()
app_config = AppConfig.current()
model_config = app_config.get_model_config(model_name) if model_name else None
if model_config is not None and model_config.supports_vision:
middlewares.append(ViewImageMiddleware())
@@ -287,9 +254,9 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
middlewares.append(DeferredToolFilterMiddleware())
# Add SubagentLimitMiddleware to truncate excess parallel task calls
subagent_enabled = cfg.get("subagent_enabled", False)
subagent_enabled = config.get("configurable", {}).get("subagent_enabled", False)
if subagent_enabled:
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
max_concurrent_subagents = config.get("configurable", {}).get("max_concurrent_subagents", 3)
middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents))
# LoopDetectionMiddleware — detect and break repetitive tool call loops
@@ -304,12 +271,12 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
return middlewares
def make_lead_agent(config: RunnableConfig):
def make_lead_agent(config: RunnableConfig) -> CompiledStateGraph:
# Lazy import to avoid circular dependency
from deerflow.tools import get_available_tools
from deerflow.tools.builtins import setup_agent
cfg = _get_runtime_config(config)
cfg = config.get("configurable", {})
thinking_enabled = cfg.get("thinking_enabled", True)
reasoning_effort = cfg.get("reasoning_effort", None)
@@ -318,7 +285,7 @@ def make_lead_agent(config: RunnableConfig):
subagent_enabled = cfg.get("subagent_enabled", False)
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
is_bootstrap = cfg.get("is_bootstrap", False)
agent_name = validate_agent_name(cfg.get("agent_name"))
agent_name = cfg.get("agent_name")
agent_config = load_agent_config(agent_name) if not is_bootstrap else None
# Custom agent model from agent config (if any), or None to let _resolve_model_name pick the default
@@ -327,7 +294,7 @@ def make_lead_agent(config: RunnableConfig):
# Final model name resolution: request → agent config → global default, with fallback for unknown names
model_name = _resolve_model_name(requested_model_name or agent_model_name)
app_config = get_app_config()
app_config = AppConfig.current()
model_config = app_config.get_model_config(model_name)
if model_config is None:
@@ -359,8 +326,6 @@ def make_lead_agent(config: RunnableConfig):
"reasoning_effort": reasoning_effort,
"is_plan_mode": is_plan_mode,
"subagent_enabled": subagent_enabled,
"tool_groups": agent_config.tool_groups if agent_config else None,
"available_skills": ["bootstrap"] if is_bootstrap else (agent_config.skills if agent_config and agent_config.skills is not None else None),
}
)
@@ -372,6 +337,7 @@ def make_lead_agent(config: RunnableConfig):
middleware=_build_middlewares(config, model_name=model_name),
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])),
state_schema=ThreadState,
context_schema=DeerFlowContext,
)
# Default lead agent (unchanged behavior)
@@ -383,4 +349,5 @@ def make_lead_agent(config: RunnableConfig):
subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
),
state_schema=ThreadState,
context_schema=DeerFlowContext,
)
@@ -5,6 +5,7 @@ from datetime import datetime
from functools import lru_cache
from deerflow.config.agents_config import load_agent_soul
from deerflow.config.app_config import AppConfig
from deerflow.skills import load_skills
from deerflow.skills.types import Skill
from deerflow.subagents import get_available_subagent_names
@@ -164,36 +165,6 @@ Skip simple one-off tasks.
"""
def _build_available_subagents_description(available_names: list[str], bash_available: bool) -> str:
"""Dynamically build subagent type descriptions from registry.
Mirrors Codex's pattern where agent_type_description is dynamically generated
from all registered roles, so the LLM knows about every available type.
"""
# Built-in descriptions (kept for backward compatibility with existing prompt quality)
builtin_descriptions = {
"general-purpose": "For ANY non-trivial task - web research, code exploration, file operations, analysis, etc.",
"bash": (
"For command execution (git, build, test, deploy operations)" if bash_available else "Not available in the current sandbox configuration. Use direct file/web tools or switch to AioSandboxProvider for isolated shell access."
),
}
# Lazy import moved outside loop to avoid repeated import overhead
from deerflow.subagents.registry import get_subagent_config
lines = []
for name in available_names:
if name in builtin_descriptions:
lines.append(f"- **{name}**: {builtin_descriptions[name]}")
else:
config = get_subagent_config(name)
if config is not None:
desc = config.description.split("\n")[0].strip() # First line only for brevity
lines.append(f"- **{name}**: {desc}")
return "\n".join(lines)
def _build_subagent_section(max_concurrent: int) -> str:
"""Build the subagent system prompt section with dynamic concurrency limit.
@@ -204,12 +175,13 @@ def _build_subagent_section(max_concurrent: int) -> str:
Formatted subagent section string.
"""
n = max_concurrent
available_names = get_available_subagent_names()
bash_available = "bash" in available_names
# Dynamically build subagent type descriptions from registry (aligned with Codex's
# agent_type_description pattern where all registered roles are listed in the tool spec).
available_subagents = _build_available_subagents_description(available_names, bash_available)
bash_available = "bash" in get_available_subagent_names()
available_subagents = (
"- **general-purpose**: For ANY non-trivial task - web research, code exploration, file operations, analysis, etc.\n- **bash**: For command execution (git, build, test, deploy operations)"
if bash_available
else "- **general-purpose**: For ANY non-trivial task - web research, code exploration, file operations, analysis, etc.\n"
"- **bash**: Not available in the current sandbox configuration. Use direct file/web tools or switch to AioSandboxProvider for isolated shell access."
)
direct_tool_examples = "bash, ls, read_file, web_search, etc." if bash_available else "ls, read_file, web_search, etc."
direct_execution_example = (
'# User asks: "Run the tests"\n# Thinking: Cannot decompose into parallel sub-tasks\n# → Execute directly\n\nbash("npm test") # Direct execution, not task()'
@@ -449,7 +421,7 @@ You: "Deploying to staging..." [proceed]
- Treat `/mnt/user-data/workspace` as your default current working directory for coding and file-editing tasks
- When writing scripts or commands that create/read files from the workspace, prefer relative paths such as `hello.txt`, `../uploads/data.csv`, and `../outputs/report.md`
- Avoid hardcoding `/mnt/user-data/...` inside generated scripts when a relative path from the workspace is enough
- Final deliverables must be copied to `/mnt/user-data/outputs` and presented using `present_files` tool
- Final deliverables must be copied to `/mnt/user-data/outputs` and presented using `present_file` tool
{acp_section}
</working_directory>
@@ -547,9 +519,8 @@ def _get_memory_context(agent_name: str | None = None) -> str:
"""
try:
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
from deerflow.config.memory_config import get_memory_config
config = get_memory_config()
config = AppConfig.current().memory
if not config.enabled or not config.injection_enabled:
return ""
@@ -605,9 +576,7 @@ def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
skills = _get_enabled_skills()
try:
from deerflow.config import get_app_config
config = get_app_config()
config = AppConfig.current()
container_base_path = config.skills.container_path
skill_evolution_enabled = config.skill_evolution.enabled
except Exception:
@@ -646,9 +615,7 @@ def get_deferred_tools_prompt_section() -> str:
from deerflow.tools.builtins.tool_search import get_deferred_registry
try:
from deerflow.config import get_app_config
if not get_app_config().tool_search.enabled:
if not AppConfig.current().tool_search.enabled:
return ""
except Exception:
return ""
@@ -664,9 +631,7 @@ def get_deferred_tools_prompt_section() -> str:
def _build_acp_section() -> str:
"""Build the ACP agent prompt section, only if ACP agents are configured."""
try:
from deerflow.config.acp_config import get_acp_agents
agents = get_acp_agents()
agents = AppConfig.current().acp_agents
if not agents:
return ""
except Exception:
@@ -677,16 +642,14 @@ def _build_acp_section() -> str:
"- ACP agents (e.g. codex, claude_code) run in their own independent workspace — NOT in `/mnt/user-data/`\n"
"- When writing prompts for ACP agents, describe the task only — do NOT reference `/mnt/user-data` paths\n"
"- ACP agent results are accessible at `/mnt/acp-workspace/` (read-only) — use `ls`, `read_file`, or `bash cp` to retrieve output files\n"
"- To deliver ACP output to the user: copy from `/mnt/acp-workspace/<file>` to `/mnt/user-data/outputs/<file>`, then use `present_files`"
"- To deliver ACP output to the user: copy from `/mnt/acp-workspace/<file>` to `/mnt/user-data/outputs/<file>`, then use `present_file`"
)
def _build_custom_mounts_section() -> str:
"""Build a prompt section for explicitly configured sandbox mounts."""
try:
from deerflow.config import get_app_config
mounts = get_app_config().sandbox.mounts or []
mounts = AppConfig.current().sandbox.mounts or []
except Exception:
logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
return ""
@@ -1,109 +0,0 @@
"""Shared helpers for turning conversations into memory update inputs."""
from __future__ import annotations
import re
from copy import copy
from typing import Any
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
_CORRECTION_PATTERNS = (
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
re.compile(r"\btry again\b", re.IGNORECASE),
re.compile(r"\bredo\b", re.IGNORECASE),
re.compile(r"不对"),
re.compile(r"你理解错了"),
re.compile(r"你理解有误"),
re.compile(r"重试"),
re.compile(r"重新来"),
re.compile(r"换一种"),
re.compile(r"改用"),
)
_REINFORCEMENT_PATTERNS = (
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"对[,]?\s*就是这样(?:[。!?!?.]|$)"),
re.compile(r"完全正确(?:[。!?!?.]|$)"),
re.compile(r"(?:对[,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
re.compile(r"继续保持(?:[。!?!?.]|$)"),
)
def extract_message_text(message: Any) -> str:
"""Extract plain text from message content for filtering and signal detection."""
content = getattr(message, "content", "")
if isinstance(content, list):
text_parts: list[str] = []
for part in content:
if isinstance(part, str):
text_parts.append(part)
elif isinstance(part, dict):
text_val = part.get("text")
if isinstance(text_val, str):
text_parts.append(text_val)
return " ".join(text_parts)
return str(content)
def filter_messages_for_memory(messages: list[Any]) -> list[Any]:
"""Keep only user inputs and final assistant responses for memory updates."""
filtered = []
skip_next_ai = False
for msg in messages:
msg_type = getattr(msg, "type", None)
if msg_type == "human":
content_str = extract_message_text(msg)
if "<uploaded_files>" in content_str:
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
if not stripped:
skip_next_ai = True
continue
clean_msg = copy(msg)
clean_msg.content = stripped
filtered.append(clean_msg)
skip_next_ai = False
else:
filtered.append(msg)
skip_next_ai = False
elif msg_type == "ai":
tool_calls = getattr(msg, "tool_calls", None)
if not tool_calls:
if skip_next_ai:
skip_next_ai = False
continue
filtered.append(msg)
return filtered
def detect_correction(messages: list[Any]) -> bool:
"""Detect explicit user corrections in recent conversation turns."""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = extract_message_text(msg).strip()
if content and any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
return True
return False
def detect_reinforcement(messages: list[Any]) -> bool:
"""Detect explicit positive reinforcement signals in recent conversation turns."""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = extract_message_text(msg).strip()
if content and any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
return True
return False
@@ -7,7 +7,7 @@ from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any
from deerflow.config.memory_config import get_memory_config
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
@@ -56,93 +56,53 @@ class MemoryUpdateQueue:
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
"""
config = get_memory_config()
config = AppConfig.current().memory
if not config.enabled:
return
with self._lock:
self._enqueue_locked(
existing_context = next(
(context for context in self._queue if context.thread_id == thread_id),
None,
)
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
context = ConversationContext(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
correction_detected=merged_correction_detected,
reinforcement_detected=merged_reinforcement_detected,
)
# Check if this thread already has a pending update
# If so, replace it with the newer one
self._queue = [c for c in self._queue if c.thread_id != thread_id]
self._queue.append(context)
# Reset or start the debounce timer
self._reset_timer()
logger.info("Memory update queued for thread %s, queue size: %d", thread_id, len(self._queue))
def add_nowait(
self,
thread_id: str,
messages: list[Any],
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> None:
"""Add a conversation and start processing immediately in the background."""
config = get_memory_config()
if not config.enabled:
return
with self._lock:
self._enqueue_locked(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
self._schedule_timer(0)
logger.info("Memory update queued for immediate processing on thread %s, queue size: %d", thread_id, len(self._queue))
def _enqueue_locked(
self,
*,
thread_id: str,
messages: list[Any],
agent_name: str | None,
correction_detected: bool,
reinforcement_detected: bool,
) -> None:
existing_context = next(
(context for context in self._queue if context.thread_id == thread_id),
None,
)
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
context = ConversationContext(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
correction_detected=merged_correction_detected,
reinforcement_detected=merged_reinforcement_detected,
)
self._queue = [c for c in self._queue if c.thread_id != thread_id]
self._queue.append(context)
def _reset_timer(self) -> None:
"""Reset the debounce timer."""
config = get_memory_config()
self._schedule_timer(config.debounce_seconds)
config = AppConfig.current().memory
logger.debug("Memory update timer set for %ss", config.debounce_seconds)
def _schedule_timer(self, delay_seconds: float) -> None:
"""Schedule queue processing after the provided delay."""
# Cancel existing timer if any
if self._timer is not None:
self._timer.cancel()
# Start new timer
self._timer = threading.Timer(
delay_seconds,
config.debounce_seconds,
self._process_queue,
)
self._timer.daemon = True
self._timer.start()
logger.debug("Memory update timer set for %ss", config.debounce_seconds)
def _process_queue(self) -> None:
"""Process all queued conversation contexts."""
# Import here to avoid circular dependency
@@ -150,8 +110,8 @@ class MemoryUpdateQueue:
with self._lock:
if self._processing:
# Preserve immediate flush semantics even if another worker is active.
self._schedule_timer(0)
# Already processing, reschedule
self._reset_timer()
return
if not self._queue:
@@ -204,13 +164,6 @@ class MemoryUpdateQueue:
self._process_queue()
def flush_nowait(self) -> None:
"""Start queue processing immediately in a background thread."""
with self._lock:
# Daemon thread: queued messages may be lost if the process exits
# before _process_queue completes. Acceptable for best-effort memory updates.
self._schedule_timer(0)
def clear(self) -> None:
"""Clear the queue without processing.
@@ -4,13 +4,12 @@ import abc
import json
import logging
import threading
import uuid
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from deerflow.config.agents_config import AGENT_NAME_PATTERN
from deerflow.config.memory_config import get_memory_config
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import get_paths
logger = logging.getLogger(__name__)
@@ -67,8 +66,6 @@ class FileMemoryStorage(MemoryStorage):
# Per-agent memory cache: keyed by agent_name (None = global)
# Value: (memory_data, file_mtime)
self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
# Guards all reads and writes to _memory_cache across concurrent callers.
self._cache_lock = threading.Lock()
def _validate_agent_name(self, agent_name: str) -> None:
"""Validate that the agent name is safe to use in filesystem paths.
@@ -87,7 +84,7 @@ class FileMemoryStorage(MemoryStorage):
self._validate_agent_name(agent_name)
return get_paths().agent_memory_file(agent_name)
config = get_memory_config()
config = AppConfig.current().memory
if config.storage_path:
p = Path(config.storage_path)
return p if p.is_absolute() else get_paths().base_dir / p
@@ -117,17 +114,14 @@ class FileMemoryStorage(MemoryStorage):
except OSError:
current_mtime = None
with self._cache_lock:
cached = self._memory_cache.get(agent_name)
if cached is not None and cached[1] == current_mtime:
return cached[0]
cached = self._memory_cache.get(agent_name)
memory_data = self._load_memory_from_file(agent_name)
with self._cache_lock:
if cached is None or cached[1] != current_mtime:
memory_data = self._load_memory_from_file(agent_name)
self._memory_cache[agent_name] = (memory_data, current_mtime)
return memory_data
return memory_data
return cached[0]
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
"""Reload memory data from file, forcing cache invalidation."""
@@ -139,8 +133,7 @@ class FileMemoryStorage(MemoryStorage):
except OSError:
mtime = None
with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, mtime)
self._memory_cache[agent_name] = (memory_data, mtime)
return memory_data
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
@@ -149,12 +142,9 @@ class FileMemoryStorage(MemoryStorage):
try:
file_path.parent.mkdir(parents=True, exist_ok=True)
# Shallow-copy before adding lastUpdated so the caller's dict is not
# mutated as a side-effect, and the cache reference is not silently
# updated before the file write succeeds.
memory_data = {**memory_data, "lastUpdated": utc_now_iso_z()}
memory_data["lastUpdated"] = utc_now_iso_z()
temp_path = file_path.with_suffix(f".{uuid.uuid4().hex}.tmp")
temp_path = file_path.with_suffix(".tmp")
with open(temp_path, "w", encoding="utf-8") as f:
json.dump(memory_data, f, indent=2, ensure_ascii=False)
@@ -165,8 +155,7 @@ class FileMemoryStorage(MemoryStorage):
except OSError:
mtime = None
with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, mtime)
self._memory_cache[agent_name] = (memory_data, mtime)
logger.info("Memory saved to %s", file_path)
return True
except OSError as e:
@@ -188,7 +177,7 @@ def get_memory_storage() -> MemoryStorage:
if _storage_instance is not None:
return _storage_instance
config = get_memory_config()
config = AppConfig.current().memory
storage_class_path = config.storage_class
try:
@@ -1,31 +0,0 @@
"""Hooks fired before summarization removes messages from state."""
from __future__ import annotations
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
from deerflow.agents.memory.queue import get_memory_queue
from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent
from deerflow.config.memory_config import get_memory_config
def memory_flush_hook(event: SummarizationEvent) -> None:
"""Flush messages about to be summarized into the memory queue."""
if not get_memory_config().enabled or not event.thread_id:
return
filtered_messages = filter_messages_for_memory(list(event.messages_to_summarize))
user_messages = [message for message in filtered_messages if getattr(message, "type", None) == "human"]
assistant_messages = [message for message in filtered_messages if getattr(message, "type", None) == "ai"]
if not user_messages or not assistant_messages:
return
correction_detected = detect_correction(filtered_messages)
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
queue = get_memory_queue()
queue.add_nowait(
thread_id=event.thread_id,
messages=filtered_messages,
agent_name=event.agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
@@ -1,15 +1,10 @@
"""Memory updater for reading, writing, and updating memory data."""
import asyncio
import atexit
import concurrent.futures
import copy
import json
import logging
import math
import re
import uuid
from collections.abc import Awaitable
from typing import Any
from deerflow.agents.memory.prompt import (
@@ -21,17 +16,11 @@ from deerflow.agents.memory.storage import (
get_memory_storage,
utc_now_iso_z,
)
from deerflow.config.memory_config import get_memory_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
_SYNC_MEMORY_UPDATER_EXECUTOR = concurrent.futures.ThreadPoolExecutor(
max_workers=4,
thread_name_prefix="memory-updater-sync",
)
atexit.register(lambda: _SYNC_MEMORY_UPDATER_EXECUTOR.shutdown(wait=False))
def _create_empty_memory() -> dict[str, Any]:
"""Backward-compatible wrapper around the storage-layer empty-memory factory."""
@@ -217,39 +206,6 @@ def _extract_text(content: Any) -> str:
return str(content)
def _run_async_update_sync(coro: Awaitable[bool]) -> bool:
"""Run an async memory update from sync code, including nested-loop contexts."""
handed_off = False
try:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop is not None and loop.is_running():
future = _SYNC_MEMORY_UPDATER_EXECUTOR.submit(asyncio.run, coro)
handed_off = True
return future.result()
handed_off = True
return asyncio.run(coro)
except Exception:
if not handed_off:
close = getattr(coro, "close", None)
if callable(close):
try:
close()
except Exception:
logger.debug(
"Failed to close un-awaited memory update coroutine",
exc_info=True,
)
logger.exception("Failed to run async memory update from sync context")
return False
# Matches sentences that describe a file-upload *event* rather than general
# file-related work. Deliberately narrow to avoid removing legitimate facts
# such as "User works with CSV files" or "prefers PDF export".
@@ -309,121 +265,10 @@ class MemoryUpdater:
def _get_model(self):
"""Get the model for memory updates."""
config = get_memory_config()
config = AppConfig.current().memory
model_name = self._model_name or config.model_name
return create_chat_model(name=model_name, thinking_enabled=False)
def _build_correction_hint(
self,
correction_detected: bool,
reinforcement_detected: bool,
) -> str:
"""Build optional prompt hints for correction and reinforcement signals."""
correction_hint = ""
if correction_detected:
correction_hint = (
"IMPORTANT: Explicit correction signals were detected in this conversation. "
"Pay special attention to what the agent got wrong, what the user corrected, "
"and record the correct approach as a fact with category "
'"correction" and confidence >= 0.95 when appropriate.'
)
if reinforcement_detected:
reinforcement_hint = (
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
"The user explicitly confirmed the agent's approach was correct or helpful. "
"Record the confirmed approach, style, or preference as a fact with category "
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
)
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
return correction_hint
def _prepare_update_prompt(
self,
messages: list[Any],
agent_name: str | None,
correction_detected: bool,
reinforcement_detected: bool,
) -> tuple[dict[str, Any], str] | None:
"""Load memory and build the update prompt for a conversation."""
config = get_memory_config()
if not config.enabled or not messages:
return None
current_memory = get_memory_data(agent_name)
conversation_text = format_conversation_for_update(messages)
if not conversation_text.strip():
return None
correction_hint = self._build_correction_hint(
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2),
conversation=conversation_text,
correction_hint=correction_hint,
)
return current_memory, prompt
def _finalize_update(
self,
current_memory: dict[str, Any],
response_content: Any,
thread_id: str | None,
agent_name: str | None,
) -> bool:
"""Parse the model response, apply updates, and persist memory."""
response_text = _extract_text(response_content).strip()
if response_text.startswith("```"):
lines = response_text.split("\n")
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
update_data = json.loads(response_text)
# Deep-copy before in-place mutation so a subsequent save() failure
# cannot corrupt the still-cached original object reference.
updated_memory = self._apply_updates(copy.deepcopy(current_memory), update_data, thread_id)
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
return get_memory_storage().save(updated_memory, agent_name)
async def aupdate_memory(
self,
messages: list[Any],
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> bool:
"""Update memory asynchronously based on conversation messages."""
try:
prepared = await asyncio.to_thread(
self._prepare_update_prompt,
messages=messages,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
if prepared is None:
return False
current_memory, prompt = prepared
model = self._get_model()
response = await model.ainvoke(prompt, config={"run_name": "memory_agent"})
return await asyncio.to_thread(
self._finalize_update,
current_memory=current_memory,
response_content=response.content,
thread_id=thread_id,
agent_name=agent_name,
)
except json.JSONDecodeError as e:
logger.warning("Failed to parse LLM response for memory update: %s", e)
return False
except Exception as e:
logger.exception("Memory update failed: %s", e)
return False
def update_memory(
self,
messages: list[Any],
@@ -432,7 +277,7 @@ class MemoryUpdater:
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> bool:
"""Synchronously update memory via the async updater path.
"""Update memory based on conversation messages.
Args:
messages: List of conversation messages.
@@ -444,15 +289,78 @@ class MemoryUpdater:
Returns:
True if update was successful, False otherwise.
"""
return _run_async_update_sync(
self.aupdate_memory(
messages=messages,
thread_id=thread_id,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
config = AppConfig.current().memory
if not config.enabled:
return False
if not messages:
return False
try:
# Get current memory
current_memory = get_memory_data(agent_name)
# Format conversation for prompt
conversation_text = format_conversation_for_update(messages)
if not conversation_text.strip():
return False
# Build prompt
correction_hint = ""
if correction_detected:
correction_hint = (
"IMPORTANT: Explicit correction signals were detected in this conversation. "
"Pay special attention to what the agent got wrong, what the user corrected, "
"and record the correct approach as a fact with category "
'"correction" and confidence >= 0.95 when appropriate.'
)
if reinforcement_detected:
reinforcement_hint = (
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
"The user explicitly confirmed the agent's approach was correct or helpful. "
"Record the confirmed approach, style, or preference as a fact with category "
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
)
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2),
conversation=conversation_text,
correction_hint=correction_hint,
)
)
# Call LLM
model = self._get_model()
response = model.invoke(prompt)
response_text = _extract_text(response.content).strip()
# Parse response
# Remove markdown code blocks if present
if response_text.startswith("```"):
lines = response_text.split("\n")
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
update_data = json.loads(response_text)
# Apply updates
updated_memory = self._apply_updates(current_memory, update_data, thread_id)
# Strip file-upload mentions from all summaries before saving.
# Uploaded files are session-scoped and won't exist in future sessions,
# so recording upload events in long-term memory causes the agent to
# try (and fail) to locate those files in subsequent conversations.
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
# Save
return get_memory_storage().save(updated_memory, agent_name)
except json.JSONDecodeError as e:
logger.warning("Failed to parse LLM response for memory update: %s", e)
return False
except Exception as e:
logger.exception("Memory update failed: %s", e)
return False
def _apply_updates(
self,
@@ -470,7 +378,7 @@ class MemoryUpdater:
Returns:
Updated memory data.
"""
config = get_memory_config()
config = AppConfig.current().memory
now = utc_now_iso_z()
# Update user sections
@@ -3,7 +3,6 @@
import json
import logging
from collections.abc import Callable
from hashlib import sha256
from typing import override
from langchain.agents import AgentState
@@ -37,13 +36,6 @@ class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
state_schema = ClarificationMiddlewareState
def _stable_message_id(self, tool_call_id: str, formatted_message: str) -> str:
"""Build a deterministic message ID so retried clarification calls replace, not append."""
if tool_call_id:
return f"clarification:{tool_call_id}"
digest = sha256(formatted_message.encode("utf-8")).hexdigest()[:16]
return f"clarification:{digest}"
def _is_chinese(self, text: str) -> bool:
"""Check if text contains Chinese characters.
@@ -139,7 +131,6 @@ class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
# Create a ToolMessage with the formatted question
# This will be added to the message history
tool_message = ToolMessage(
id=self._stable_message_id(tool_call_id, formatted_message),
content=formatted_message,
tool_call_id=tool_call_id,
name="ask_clarification",
@@ -13,7 +13,6 @@ at the correct positions (immediately after each dangling AIMessage), not append
to the end of the message list as before_model + add_messages reducer would do.
"""
import json
import logging
from collections.abc import Awaitable, Callable
from typing import override
@@ -34,44 +33,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
offending AIMessage so the LLM receives a well-formed conversation.
"""
@staticmethod
def _message_tool_calls(msg) -> list[dict]:
"""Return normalized tool calls from structured fields or raw provider payloads."""
tool_calls = getattr(msg, "tool_calls", None) or []
if tool_calls:
return list(tool_calls)
raw_tool_calls = (getattr(msg, "additional_kwargs", None) or {}).get("tool_calls") or []
normalized: list[dict] = []
for raw_tc in raw_tool_calls:
if not isinstance(raw_tc, dict):
continue
function = raw_tc.get("function")
name = raw_tc.get("name")
if not name and isinstance(function, dict):
name = function.get("name")
args = raw_tc.get("args", {})
if not args and isinstance(function, dict):
raw_args = function.get("arguments")
if isinstance(raw_args, str):
try:
parsed_args = json.loads(raw_args)
except (TypeError, ValueError, json.JSONDecodeError):
parsed_args = {}
args = parsed_args if isinstance(parsed_args, dict) else {}
normalized.append(
{
"id": raw_tc.get("id"),
"name": name or "unknown",
"args": args if isinstance(args, dict) else {},
}
)
return normalized
def _build_patched_messages(self, messages: list) -> list | None:
"""Return a new message list with patches inserted at the correct positions.
@@ -90,7 +51,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
for msg in messages:
if getattr(msg, "type", None) != "ai":
continue
for tc in self._message_tool_calls(msg):
for tc in getattr(msg, "tool_calls", None) or []:
tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids:
needs_patch = True
@@ -109,7 +70,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
patched.append(msg)
if getattr(msg, "type", None) != "ai":
continue
for tc in self._message_tool_calls(msg):
for tc in getattr(msg, "tool_calls", None) or []:
tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
patched.append(
@@ -16,9 +16,6 @@ from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
from langchain_core.messages import ToolMessage
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
logger = logging.getLogger(__name__)
@@ -38,7 +35,7 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
if not registry:
return request
deferred_names = registry.deferred_names
deferred_names = {e.name for e in registry.entries}
active_tools = [t for t in request.tools if getattr(t, "name", None) not in deferred_names]
if len(active_tools) < len(request.tools):
@@ -46,28 +43,6 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
return request.override(tools=active_tools)
def _blocked_tool_message(self, request: ToolCallRequest) -> ToolMessage | None:
from deerflow.tools.builtins.tool_search import get_deferred_registry
registry = get_deferred_registry()
if not registry:
return None
tool_name = str(request.tool_call.get("name") or "")
if not tool_name:
return None
if not registry.contains(tool_name):
return None
tool_call_id = str(request.tool_call.get("id") or "missing_tool_call_id")
return ToolMessage(
content=(f"Error: Tool '{tool_name}' is deferred and has not been promoted yet. Call tool_search first to expose and promote this tool's schema, then retry."),
tool_call_id=tool_call_id,
name=tool_name,
status="error",
)
@override
def wrap_model_call(
self,
@@ -76,17 +51,6 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
) -> ModelCallResult:
return handler(self._filter_tools(request))
@override
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
blocked = self._blocked_tool_message(request)
if blocked is not None:
return blocked
return handler(request)
@override
async def awrap_model_call(
self,
@@ -94,14 +58,3 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
return await handler(self._filter_tools(request))
@override
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
blocked = self._blocked_tool_message(request)
if blocked is not None:
return blocked
return await handler(request)
@@ -4,7 +4,6 @@ from __future__ import annotations
import asyncio
import logging
import threading
import time
from collections.abc import Awaitable, Callable
from email.utils import parsedate_to_datetime
@@ -20,8 +19,6 @@ from langchain.agents.middleware.types import (
from langchain_core.messages import AIMessage
from langgraph.errors import GraphBubbleUp
from deerflow.config import get_app_config
logger = logging.getLogger(__name__)
_RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
@@ -70,80 +67,6 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
retry_base_delay_ms: int = 1000
retry_cap_delay_ms: int = 8000
circuit_failure_threshold: int = 5
circuit_recovery_timeout_sec: int = 60
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
# Load Circuit Breaker configs from app config if available, fall back to defaults
try:
app_config = get_app_config()
self.circuit_failure_threshold = app_config.circuit_breaker.failure_threshold
self.circuit_recovery_timeout_sec = app_config.circuit_breaker.recovery_timeout_sec
except (FileNotFoundError, RuntimeError):
# Gracefully fall back to class defaults in test environments
pass
# Circuit Breaker state
self._circuit_lock = threading.Lock()
self._circuit_failure_count = 0
self._circuit_open_until = 0.0
self._circuit_state = "closed"
self._circuit_probe_in_flight = False
def _check_circuit(self) -> bool:
"""Returns True if circuit is OPEN (fast fail), False otherwise."""
with self._circuit_lock:
now = time.time()
if self._circuit_state == "open":
if now < self._circuit_open_until:
return True
self._circuit_state = "half_open"
self._circuit_probe_in_flight = False
if self._circuit_state == "half_open":
if self._circuit_probe_in_flight:
return True
self._circuit_probe_in_flight = True
return False
return False
def _record_success(self) -> None:
with self._circuit_lock:
if self._circuit_state != "closed" or self._circuit_failure_count > 0:
logger.info("Circuit breaker reset (Closed). LLM service recovered.")
self._circuit_failure_count = 0
self._circuit_open_until = 0.0
self._circuit_state = "closed"
self._circuit_probe_in_flight = False
def _record_failure(self) -> None:
with self._circuit_lock:
if self._circuit_state == "half_open":
self._circuit_open_until = time.time() + self.circuit_recovery_timeout_sec
self._circuit_state = "open"
self._circuit_probe_in_flight = False
logger.error(
"Circuit breaker probe failed (Open). Will probe again after %ds.",
self.circuit_recovery_timeout_sec,
)
return
self._circuit_failure_count += 1
if self._circuit_failure_count >= self.circuit_failure_threshold:
self._circuit_open_until = time.time() + self.circuit_recovery_timeout_sec
if self._circuit_state != "open":
self._circuit_state = "open"
self._circuit_probe_in_flight = False
logger.error(
"Circuit breaker tripped (Open). Threshold reached (%d). Will probe after %ds.",
self.circuit_failure_threshold,
self.circuit_recovery_timeout_sec,
)
def _classify_error(self, exc: BaseException) -> tuple[bool, str]:
detail = _extract_error_detail(exc)
lowered = detail.lower()
@@ -160,8 +83,6 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
"APITimeoutError",
"APIConnectionError",
"InternalServerError",
"ReadError", # httpx.ReadError: connection dropped mid-stream
"RemoteProtocolError", # httpx: server closed connection unexpectedly
}:
return True, "transient"
if status_code in _RETRIABLE_STATUS_CODES:
@@ -183,9 +104,6 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily"
return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s."
def _build_circuit_breaker_message(self) -> str:
return "The configured LLM provider is currently unavailable due to continuous failures. Circuit breaker is engaged to protect the system. Please wait a moment before trying again."
def _build_user_message(self, exc: BaseException, reason: str) -> str:
detail = _extract_error_detail(exc)
if reason == "quota":
@@ -220,20 +138,12 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
if self._check_circuit():
return AIMessage(content=self._build_circuit_breaker_message())
attempt = 1
while True:
try:
response = handler(request)
self._record_success()
return response
return handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
with self._circuit_lock:
if self._circuit_state == "half_open":
self._circuit_probe_in_flight = False
raise
except Exception as exc:
retriable, reason = self._classify_error(exc)
@@ -256,8 +166,6 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
_extract_error_detail(exc),
exc_info=exc,
)
if retriable:
self._record_failure()
return AIMessage(content=self._build_user_message(exc, reason))
@override
@@ -266,20 +174,12 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
if self._check_circuit():
return AIMessage(content=self._build_circuit_breaker_message())
attempt = 1
while True:
try:
response = await handler(request)
self._record_success()
return response
return await handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
with self._circuit_lock:
if self._circuit_state == "half_open":
self._circuit_probe_in_flight = False
raise
except Exception as exc:
retriable, reason = self._classify_error(exc)
@@ -302,8 +202,6 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
_extract_error_detail(exc),
exc_info=exc,
)
if retriable:
self._record_failure()
return AIMessage(content=self._build_user_message(exc, reason))
@@ -17,7 +17,6 @@ import json
import logging
import threading
from collections import OrderedDict, defaultdict
from copy import deepcopy
from typing import override
from langchain.agents import AgentState
@@ -25,7 +24,7 @@ from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
from deerflow.utils.runtime import get_thread_id
from deerflow.config.deer_flow_context import DeerFlowContext
logger = logging.getLogger(__name__)
@@ -183,9 +182,9 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
def _get_thread_id(self, runtime: Runtime) -> str:
def _get_thread_id(self, runtime: Runtime[DeerFlowContext]) -> str:
"""Extract thread_id from runtime context for per-thread tracking."""
return get_thread_id(runtime) or "default"
return runtime.context.thread_id or "default"
def _evict_if_needed(self) -> None:
"""Evict least recently used threads if over the limit.
@@ -323,26 +322,6 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
# Fallback: coerce unexpected types to str to avoid TypeError
return str(content) + f"\n\n{text}"
@staticmethod
def _build_hard_stop_update(last_msg, content: str | list) -> dict:
"""Clear tool-call metadata so forced-stop messages serialize as plain assistant text."""
update = {
"tool_calls": [],
"content": content,
}
additional_kwargs = dict(getattr(last_msg, "additional_kwargs", {}) or {})
for key in ("tool_calls", "function_call"):
additional_kwargs.pop(key, None)
update["additional_kwargs"] = additional_kwargs
response_metadata = deepcopy(getattr(last_msg, "response_metadata", {}) or {})
if response_metadata.get("finish_reason") == "tool_calls":
response_metadata["finish_reason"] = "stop"
update["response_metadata"] = response_metadata
return update
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
warning, hard_stop = self._track_and_check(state, runtime)
@@ -350,8 +329,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
# Strip tool_calls from the last AIMessage to force text output
messages = state.get("messages", [])
last_msg = messages[-1]
content = self._append_text(last_msg.content, warning or _HARD_STOP_MSG)
stripped_msg = last_msg.model_copy(update=self._build_hard_stop_update(last_msg, content))
stripped_msg = last_msg.model_copy(
update={
"tool_calls": [],
"content": self._append_text(last_msg.content, warning),
}
)
return {"messages": [stripped_msg]}
if warning:
@@ -366,11 +349,11 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
return None
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
def after_model(self, state: AgentState, runtime: Runtime[DeerFlowContext]) -> dict | None:
return self._apply(state, runtime)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
async def aafter_model(self, state: AgentState, runtime: Runtime[DeerFlowContext]) -> dict | None:
return self._apply(state, runtime)
def reset(self, thread_id: str | None = None) -> None:
@@ -1,19 +1,49 @@
"""Middleware for memory mechanism."""
import logging
from typing import override
import re
from typing import Any, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
from deerflow.agents.memory.queue import get_memory_queue
from deerflow.config.memory_config import get_memory_config
from deerflow.utils.runtime import get_thread_id
from deerflow.config.deer_flow_context import DeerFlowContext
logger = logging.getLogger(__name__)
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
_CORRECTION_PATTERNS = (
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
re.compile(r"\btry again\b", re.IGNORECASE),
re.compile(r"\bredo\b", re.IGNORECASE),
re.compile(r"不对"),
re.compile(r"你理解错了"),
re.compile(r"你理解有误"),
re.compile(r"重试"),
re.compile(r"重新来"),
re.compile(r"换一种"),
re.compile(r"改用"),
)
_REINFORCEMENT_PATTERNS = (
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"对[,]?\s*就是这样(?:[。!?!?.]|$)"),
re.compile(r"完全正确(?:[。!?!?.]|$)"),
re.compile(r"(?:对[,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
re.compile(r"继续保持(?:[。!?!?.]|$)"),
)
class MemoryMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
@@ -21,6 +51,125 @@ class MemoryMiddlewareState(AgentState):
pass
def _extract_message_text(message: Any) -> str:
"""Extract plain text from message content for filtering and signal detection."""
content = getattr(message, "content", "")
if isinstance(content, list):
text_parts: list[str] = []
for part in content:
if isinstance(part, str):
text_parts.append(part)
elif isinstance(part, dict):
text_val = part.get("text")
if isinstance(text_val, str):
text_parts.append(text_val)
return " ".join(text_parts)
return str(content)
def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
"""Filter messages to keep only user inputs and final assistant responses.
This filters out:
- Tool messages (intermediate tool call results)
- AI messages with tool_calls (intermediate steps, not final responses)
- The <uploaded_files> block injected by UploadsMiddleware into human messages
(file paths are session-scoped and must not persist in long-term memory).
The user's actual question is preserved; only turns whose content is entirely
the upload block (nothing remains after stripping) are dropped along with
their paired assistant response.
Only keeps:
- Human messages (with the ephemeral upload block removed)
- AI messages without tool_calls (final assistant responses), unless the
paired human turn was upload-only and had no real user text.
Args:
messages: List of all conversation messages.
Returns:
Filtered list containing only user inputs and final assistant responses.
"""
filtered = []
skip_next_ai = False
for msg in messages:
msg_type = getattr(msg, "type", None)
if msg_type == "human":
content_str = _extract_message_text(msg)
if "<uploaded_files>" in content_str:
# Strip the ephemeral upload block; keep the user's real question.
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
if not stripped:
# Nothing left — the entire turn was upload bookkeeping;
# skip it and the paired assistant response.
skip_next_ai = True
continue
# Rebuild the message with cleaned content so the user's question
# is still available for memory summarisation.
from copy import copy
clean_msg = copy(msg)
clean_msg.content = stripped
filtered.append(clean_msg)
skip_next_ai = False
else:
filtered.append(msg)
skip_next_ai = False
elif msg_type == "ai":
tool_calls = getattr(msg, "tool_calls", None)
if not tool_calls:
if skip_next_ai:
skip_next_ai = False
continue
filtered.append(msg)
# Skip tool messages and AI messages with tool_calls
return filtered
def detect_correction(messages: list[Any]) -> bool:
"""Detect explicit user corrections in recent conversation turns.
The queue keeps only one pending context per thread, so callers pass the
latest filtered message list. Checking only recent user turns keeps signal
detection conservative while avoiding stale corrections from long histories.
"""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = _extract_message_text(msg).strip()
if not content:
continue
if any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
return True
return False
def detect_reinforcement(messages: list[Any]) -> bool:
"""Detect explicit positive reinforcement signals in recent conversation turns.
Complements detect_correction() by identifying when the user confirms the
agent's approach was correct. This allows the memory system to record what
worked well, not just what went wrong.
The queue keeps only one pending context per thread, so callers pass the
latest filtered message list. Checking only recent user turns keeps signal
detection conservative while avoiding stale signals from long histories.
"""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = _extract_message_text(msg).strip()
if not content:
continue
if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
return True
return False
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
"""Middleware that queues conversation for memory update after agent execution.
@@ -43,7 +192,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
self._agent_name = agent_name
@override
def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None:
def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
"""Queue conversation for memory update after agent completes.
Args:
@@ -53,14 +202,13 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
Returns:
None (no state changes needed from this middleware).
"""
config = get_memory_config()
if not config.enabled:
memory_config = runtime.context.app_config.memory
if not memory_config.enabled:
return None
# Resolve thread ID from the runtime or configured fallback sources
thread_id = get_thread_id(runtime)
thread_id = runtime.context.thread_id
if not thread_id:
logger.debug("No thread_id could be resolved from runtime/config, skipping memory update")
logger.debug("No thread_id in context, skipping memory update")
return None
# Get messages from state
@@ -70,7 +218,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
return None
# Filter to only keep user inputs and final assistant responses
filtered_messages = filter_messages_for_memory(messages)
filtered_messages = _filter_messages_for_memory(messages)
# Only queue if there's meaningful conversation
# At minimum need one user message and one assistant response
@@ -14,7 +14,6 @@ from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
from deerflow.agents.thread_state import ThreadState
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__)
@@ -219,7 +218,15 @@ class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
# ------------------------------------------------------------------
def _get_thread_id(self, request: ToolCallRequest) -> str | None:
return get_thread_id(request.runtime)
runtime = request.runtime # ToolRuntime; may be None-like in tests
if runtime is None:
return None
ctx = getattr(runtime, "context", None) or {}
thread_id = ctx.get("thread_id") if isinstance(ctx, dict) else None
if thread_id is None:
cfg = getattr(runtime, "config", None) or {}
thread_id = cfg.get("configurable", {}).get("thread_id")
return thread_id
_AUDIT_COMMAND_LIMIT = 200
@@ -1,337 +0,0 @@
"""Summarization middleware extensions for DeerFlow."""
from __future__ import annotations
import logging
from collections.abc import Collection
from dataclasses import dataclass
from typing import Any, Protocol, runtime_checkable
from langchain.agents import AgentState
from langchain.agents.middleware import SummarizationMiddleware
from langchain_core.messages import AIMessage, AnyMessage, RemoveMessage, ToolMessage
from langgraph.config import get_config
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class SummarizationEvent:
"""Context emitted before conversation history is summarized away."""
messages_to_summarize: tuple[AnyMessage, ...]
preserved_messages: tuple[AnyMessage, ...]
thread_id: str | None
agent_name: str | None
runtime: Runtime
@runtime_checkable
class BeforeSummarizationHook(Protocol):
"""Hook invoked before summarization removes messages from state."""
def __call__(self, event: SummarizationEvent) -> None: ...
def _resolve_agent_name(runtime: Runtime) -> str | None:
"""Resolve the current agent name from runtime context or LangGraph config."""
agent_name = runtime.context.get("agent_name") if runtime.context else None
if agent_name is None:
try:
config_data = get_config()
except RuntimeError:
return None
agent_name = config_data.get("configurable", {}).get("agent_name")
return agent_name
def _tool_call_path(tool_call: dict[str, Any]) -> str | None:
"""Best-effort extraction of a file path argument from a read_file-like tool call."""
args = tool_call.get("args") or {}
if not isinstance(args, dict):
return None
for key in ("path", "file_path", "filepath"):
value = args.get(key)
if isinstance(value, str) and value:
return value
return None
def _clone_ai_message(
message: AIMessage,
tool_calls: list[dict[str, Any]],
*,
content: Any | None = None,
) -> AIMessage:
"""Clone an AIMessage while replacing its tool_calls list and optional content."""
update: dict[str, Any] = {"tool_calls": tool_calls}
if content is not None:
update["content"] = content
return message.model_copy(update=update)
@dataclass
class _SkillBundle:
"""Skill-related tool calls and tool results associated with one AIMessage."""
ai_index: int
skill_tool_indices: tuple[int, ...]
skill_tool_call_ids: frozenset[str]
skill_tool_tokens: int
skill_key: str
class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
"""Summarization middleware with pre-compression hook dispatch and skill rescue."""
def __init__(
self,
*args,
skills_container_path: str | None = None,
skill_file_read_tool_names: Collection[str] | None = None,
before_summarization: list[BeforeSummarizationHook] | None = None,
preserve_recent_skill_count: int = 5,
preserve_recent_skill_tokens: int = 25_000,
preserve_recent_skill_tokens_per_skill: int = 5_000,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self._skills_container_path = skills_container_path or "/mnt/skills"
self._skill_file_read_tool_names = frozenset(skill_file_read_tool_names or {"read_file", "read", "view", "cat"})
self._before_summarization_hooks = before_summarization or []
self._preserve_recent_skill_count = max(0, preserve_recent_skill_count)
self._preserve_recent_skill_tokens = max(0, preserve_recent_skill_tokens)
self._preserve_recent_skill_tokens_per_skill = max(0, preserve_recent_skill_tokens_per_skill)
def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._maybe_summarize(state, runtime)
async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return await self._amaybe_summarize(state, runtime)
def _maybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
messages = state["messages"]
self._ensure_message_ids(messages)
total_tokens = self.token_counter(messages)
if not self._should_summarize(messages, total_tokens):
return None
cutoff_index = self._determine_cutoff_index(messages)
if cutoff_index <= 0:
return None
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
summary = self._create_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)
return {
"messages": [
RemoveMessage(id=REMOVE_ALL_MESSAGES),
*new_messages,
*preserved_messages,
]
}
async def _amaybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
messages = state["messages"]
self._ensure_message_ids(messages)
total_tokens = self.token_counter(messages)
if not self._should_summarize(messages, total_tokens):
return None
cutoff_index = self._determine_cutoff_index(messages)
if cutoff_index <= 0:
return None
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
summary = await self._acreate_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)
return {
"messages": [
RemoveMessage(id=REMOVE_ALL_MESSAGES),
*new_messages,
*preserved_messages,
]
}
def _partition_with_skill_rescue(
self,
messages: list[AnyMessage],
cutoff_index: int,
) -> tuple[list[AnyMessage], list[AnyMessage]]:
"""Partition like the parent, then rescue recently-loaded skill bundles."""
to_summarize, preserved = self._partition_messages(messages, cutoff_index)
if self._preserve_recent_skill_count == 0 or self._preserve_recent_skill_tokens == 0 or not to_summarize:
return to_summarize, preserved
try:
bundles = self._find_skill_bundles(to_summarize, self._skills_container_path)
except Exception:
logger.exception("Skill-preserving summarization rescue failed; falling back to default partition")
return to_summarize, preserved
if not bundles:
return to_summarize, preserved
rescue_bundles = self._select_bundles_to_rescue(bundles)
if not rescue_bundles:
return to_summarize, preserved
bundles_by_ai_index = {bundle.ai_index: bundle for bundle in rescue_bundles}
rescue_tool_indices = {idx for bundle in rescue_bundles for idx in bundle.skill_tool_indices}
rescued: list[AnyMessage] = []
remaining: list[AnyMessage] = []
for i, msg in enumerate(to_summarize):
bundle = bundles_by_ai_index.get(i)
if bundle is not None and isinstance(msg, AIMessage):
rescued_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") in bundle.skill_tool_call_ids]
remaining_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") not in bundle.skill_tool_call_ids]
if rescued_tool_calls:
rescued.append(_clone_ai_message(msg, rescued_tool_calls, content=""))
if remaining_tool_calls or msg.content:
remaining.append(_clone_ai_message(msg, remaining_tool_calls))
continue
if i in rescue_tool_indices:
rescued.append(msg)
continue
remaining.append(msg)
return remaining, rescued + preserved
def _find_skill_bundles(
self,
messages: list[AnyMessage],
skills_root: str,
) -> list[_SkillBundle]:
"""Locate AIMessage + paired ToolMessage groups that load skill files."""
bundles: list[_SkillBundle] = []
n = len(messages)
i = 0
while i < n:
msg = messages[i]
if not (isinstance(msg, AIMessage) and msg.tool_calls):
i += 1
continue
tool_calls = list(msg.tool_calls)
skill_paths_by_id: dict[str, str] = {}
for tc in tool_calls:
if self._is_skill_tool_call(tc, skills_root):
tc_id = tc.get("id")
path = _tool_call_path(tc)
if tc_id and path:
skill_paths_by_id[tc_id] = path
if not skill_paths_by_id:
i += 1
continue
skill_tool_tokens = 0
skill_key_parts: list[str] = []
skill_tool_indices: list[int] = []
matched_skill_call_ids: set[str] = set()
j = i + 1
while j < n and isinstance(messages[j], ToolMessage):
j += 1
for k in range(i + 1, j):
tool_msg = messages[k]
if isinstance(tool_msg, ToolMessage) and tool_msg.tool_call_id in skill_paths_by_id:
skill_tool_tokens += self.token_counter([tool_msg])
skill_key_parts.append(skill_paths_by_id[tool_msg.tool_call_id])
skill_tool_indices.append(k)
matched_skill_call_ids.add(tool_msg.tool_call_id)
if not skill_tool_indices:
i = j
continue
bundles.append(
_SkillBundle(
ai_index=i,
skill_tool_indices=tuple(skill_tool_indices),
skill_tool_call_ids=frozenset(matched_skill_call_ids),
skill_tool_tokens=skill_tool_tokens,
skill_key="|".join(sorted(skill_key_parts)),
)
)
i = j
return bundles
def _select_bundles_to_rescue(self, bundles: list[_SkillBundle]) -> list[_SkillBundle]:
"""Pick bundles to keep, walking newest-first under count/token budgets."""
selected: list[_SkillBundle] = []
if not bundles:
return selected
seen_skill_keys: set[str] = set()
total_tokens = 0
kept = 0
for bundle in reversed(bundles):
if kept >= self._preserve_recent_skill_count:
break
if bundle.skill_key in seen_skill_keys:
continue
if bundle.skill_tool_tokens > self._preserve_recent_skill_tokens_per_skill:
continue
if total_tokens + bundle.skill_tool_tokens > self._preserve_recent_skill_tokens:
continue
selected.append(bundle)
total_tokens += bundle.skill_tool_tokens
kept += 1
seen_skill_keys.add(bundle.skill_key)
selected.reverse()
return selected
def _is_skill_tool_call(self, tool_call: dict[str, Any], skills_root: str) -> bool:
"""Return True when ``tool_call`` reads a file under the configured skills root."""
name = tool_call.get("name") or ""
if name not in self._skill_file_read_tool_names:
return False
path = _tool_call_path(tool_call)
if not path:
return False
normalized_root = skills_root.rstrip("/")
return path == normalized_root or path.startswith(normalized_root + "/")
def _fire_hooks(
self,
messages_to_summarize: list[AnyMessage],
preserved_messages: list[AnyMessage],
runtime: Runtime,
) -> None:
if not self._before_summarization_hooks:
return
event = SummarizationEvent(
messages_to_summarize=tuple(messages_to_summarize),
preserved_messages=tuple(preserved_messages),
thread_id=get_thread_id(runtime),
agent_name=_resolve_agent_name(runtime),
runtime=runtime,
)
for hook in self._before_summarization_hooks:
try:
hook(event)
except Exception:
hook_name = getattr(hook, "__name__", None) or type(hook).__name__
logger.exception("before_summarization hook %s failed", hook_name)
@@ -6,8 +6,8 @@ from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from deerflow.agents.thread_state import ThreadDataState
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.paths import Paths, get_paths
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__)
@@ -74,10 +74,10 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
return self._get_thread_paths(thread_id)
@override
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
thread_id = get_thread_id(runtime)
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
thread_id = runtime.context.thread_id
if thread_id is None:
if not thread_id:
raise ValueError("Thread ID is required in runtime context or config.configurable")
if self._lazy_init:
@@ -1,14 +1,13 @@
"""Middleware for automatic thread title generation."""
import logging
import re
from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from deerflow.config.title_config import get_title_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
@@ -46,7 +45,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
"""Check if we should generate a title for this thread."""
config = get_title_config()
config = AppConfig.current().title
if not config.enabled:
return False
@@ -71,14 +70,14 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
"""
config = get_title_config()
config = AppConfig.current().title
messages = state.get("messages", [])
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")
user_msg = self._normalize_content(user_msg_content)
assistant_msg = self._strip_think_tags(self._normalize_content(assistant_msg_content))
assistant_msg = self._normalize_content(assistant_msg_content)
prompt = config.prompt_template.format(
max_words=config.max_words,
@@ -87,20 +86,15 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
)
return prompt, user_msg
def _strip_think_tags(self, text: str) -> str:
"""Remove <think>...</think> blocks emitted by reasoning models (e.g. minimax, DeepSeek-R1)."""
return re.sub(r"<think>[\s\S]*?</think>", "", text, flags=re.IGNORECASE).strip()
def _parse_title(self, content: object) -> str:
"""Normalize model output into a clean title string."""
config = get_title_config()
config = AppConfig.current().title
title_content = self._normalize_content(content)
title_content = self._strip_think_tags(title_content)
title = title_content.strip().strip('"').strip("'")
return title[: config.max_chars] if len(title) > config.max_chars else title
def _fallback_title(self, user_msg: str) -> str:
config = get_title_config()
config = AppConfig.current().title
fallback_chars = min(config.max_chars, 50)
if len(user_msg) > fallback_chars:
return user_msg[:fallback_chars].rstrip() + "..."
@@ -119,7 +113,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
if not self._should_generate_title(state):
return None
config = get_title_config()
config = AppConfig.current().title
prompt, user_msg = self._build_title_prompt(state)
try:
@@ -127,7 +121,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
model = create_chat_model(name=config.model_name, thinking_enabled=False)
else:
model = create_chat_model(thinking_enabled=False)
response = await model.ainvoke(prompt, config={"run_name": "title_agent"})
response = await model.ainvoke(prompt)
title = self._parse_title(response.content)
if title:
return {"title": title}
@@ -1,14 +1,9 @@
"""Middleware that extends TodoListMiddleware with context-loss detection and premature-exit prevention.
"""Middleware that extends TodoListMiddleware with context-loss detection.
When the message history is truncated (e.g., by SummarizationMiddleware), the
original `write_todos` tool call and its ToolMessage can be scrolled out of the
active context window. This middleware detects that situation and injects a
reminder message so the model still knows about the outstanding todo list.
Additionally, this middleware prevents the agent from exiting the loop while
there are still incomplete todo items. When the model produces a final response
(no tool calls) but todos are not yet complete, the middleware injects a reminder
and jumps back to the model node to force continued engagement.
"""
from __future__ import annotations
@@ -17,7 +12,6 @@ from typing import Any, override
from langchain.agents.middleware import TodoListMiddleware
from langchain.agents.middleware.todo import PlanningState, Todo
from langchain.agents.middleware.types import hook_config
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.runtime import Runtime
@@ -40,11 +34,6 @@ def _reminder_in_messages(messages: list[Any]) -> bool:
return False
def _completion_reminder_count(messages: list[Any]) -> int:
"""Return the number of todo_completion_reminder HumanMessages in *messages*."""
return sum(1 for msg in messages if isinstance(msg, HumanMessage) and getattr(msg, "name", None) == "todo_completion_reminder")
def _format_todos(todos: list[Todo]) -> str:
"""Format a list of Todo items into a human-readable string."""
lines: list[str] = []
@@ -68,7 +57,7 @@ class TodoMiddleware(TodoListMiddleware):
def before_model(
self,
state: PlanningState,
runtime: Runtime,
runtime: Runtime, # noqa: ARG002
) -> dict[str, Any] | None:
"""Inject a todo-list reminder when write_todos has left the context window."""
todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment]
@@ -109,71 +98,3 @@ class TodoMiddleware(TodoListMiddleware):
) -> dict[str, Any] | None:
"""Async version of before_model."""
return self.before_model(state, runtime)
# Maximum number of completion reminders before allowing the agent to exit.
# This prevents infinite loops when the agent cannot make further progress.
_MAX_COMPLETION_REMINDERS = 2
@hook_config(can_jump_to=["model"])
@override
def after_model(
self,
state: PlanningState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Prevent premature agent exit when todo items are still incomplete.
In addition to the base class check for parallel ``write_todos`` calls,
this override intercepts model responses that have no tool calls while
there are still incomplete todo items. It injects a reminder
``HumanMessage`` and jumps back to the model node so the agent
continues working through the todo list.
A retry cap of ``_MAX_COMPLETION_REMINDERS`` (default 2) prevents
infinite loops when the agent cannot make further progress.
"""
# 1. Preserve base class logic (parallel write_todos detection).
base_result = super().after_model(state, runtime)
if base_result is not None:
return base_result
# 2. Only intervene when the agent wants to exit (no tool calls).
messages = state.get("messages") or []
last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None)
if not last_ai or last_ai.tool_calls:
return None
# 3. Allow exit when all todos are completed or there are no todos.
todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment]
if not todos or all(t.get("status") == "completed" for t in todos):
return None
# 4. Enforce a reminder cap to prevent infinite re-engagement loops.
if _completion_reminder_count(messages) >= self._MAX_COMPLETION_REMINDERS:
return None
# 5. Inject a reminder and force the agent back to the model.
incomplete = [t for t in todos if t.get("status") != "completed"]
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete)
reminder = HumanMessage(
name="todo_completion_reminder",
content=(
"<system_reminder>\n"
"You have incomplete todo items that must be finished before giving your final response:\n\n"
f"{incomplete_text}\n\n"
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
"as you finish them, and only respond when all items are done.\n"
"</system_reminder>"
),
)
return {"jump_to": "model", "messages": [reminder]}
@override
@hook_config(can_jump_to=["model"])
async def aafter_model(
self,
state: PlanningState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async version of after_model."""
return self.after_model(state, runtime)
@@ -94,9 +94,9 @@ def _build_runtime_middlewares(
middlewares.append(LLMErrorHandlingMiddleware())
# Guardrail middleware (if configured)
from deerflow.config.guardrails_config import get_guardrails_config
from deerflow.config.app_config import AppConfig
guardrails_config = get_guardrails_config()
guardrails_config = AppConfig.current().guardrails
if guardrails_config.enabled and guardrails_config.provider:
import inspect
@@ -9,9 +9,9 @@ from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.paths import Paths, get_paths
from deerflow.utils.file_conversion import extract_outline
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__)
@@ -185,7 +185,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
return files if files else None
@override
def before_agent(self, state: UploadsMiddlewareState, runtime: Runtime) -> dict | None:
def before_agent(self, state: UploadsMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
"""Inject uploaded files information before agent execution.
New files come from the current message's additional_kwargs.files.
@@ -214,7 +214,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
return None
# Resolve uploads directory for existence checks
thread_id = get_thread_id(runtime)
thread_id = runtime.context.thread_id
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
# Get newly uploaded files from the current message's additional_kwargs.files
+19 -26
View File
@@ -36,8 +36,9 @@ from deerflow.agents.lead_agent.agent import _build_middlewares
from deerflow.agents.lead_agent.prompt import apply_prompt_template
from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import AGENT_NAME_PATTERN
from deerflow.config.app_config import get_app_config, reload_app_config
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig
from deerflow.config.paths import get_paths
from deerflow.models import create_chat_model
from deerflow.skills.installer import install_skill_from_archive
@@ -141,8 +142,8 @@ class DeerFlowClient:
middlewares: Optional list of custom middlewares to inject into the agent.
"""
if config_path is not None:
reload_app_config(config_path)
self._app_config = get_app_config()
AppConfig.init(AppConfig.from_file(config_path))
self._app_config = AppConfig.current()
if agent_name is not None and not AGENT_NAME_PATTERN.match(agent_name):
raise ValueError(f"Invalid agent name '{agent_name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
@@ -551,9 +552,7 @@ class DeerFlowClient:
self._ensure_agent(config)
state: dict[str, Any] = {"messages": [HumanMessage(content=message)]}
context = {"thread_id": thread_id}
if self._agent_name:
context["agent_name"] = self._agent_name
context = DeerFlowContext(app_config=self._app_config, thread_id=thread_id, agent_name=self._agent_name)
seen_ids: set[str] = set()
# Cross-mode handoff: ids already streamed via LangGraph ``messages``
@@ -722,10 +721,6 @@ class DeerFlowClient:
Dict with "models" key containing list of model info dicts,
matching the Gateway API ``ModelsListResponse`` schema.
"""
token_usage_enabled = getattr(getattr(self._app_config, "token_usage", None), "enabled", False)
if not isinstance(token_usage_enabled, bool):
token_usage_enabled = False
return {
"models": [
{
@@ -737,8 +732,7 @@ class DeerFlowClient:
"supports_reasoning_effort": getattr(model, "supports_reasoning_effort", False),
}
for model in self._app_config.models
],
"token_usage": {"enabled": token_usage_enabled},
]
}
def list_skills(self, enabled_only: bool = False) -> dict:
@@ -821,8 +815,8 @@ class DeerFlowClient:
Dict with "mcp_servers" key mapping server name to config,
matching the Gateway API ``McpConfigResponse`` schema.
"""
config = get_extensions_config()
return {"mcp_servers": {name: server.model_dump() for name, server in config.mcp_servers.items()}}
ext = AppConfig.current().extensions
return {"mcp_servers": {name: server.model_dump() for name, server in ext.mcp_servers.items()}}
def update_mcp_config(self, mcp_servers: dict[str, dict]) -> dict:
"""Update MCP server configurations.
@@ -844,18 +838,19 @@ class DeerFlowClient:
if config_path is None:
raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.")
current_config = get_extensions_config()
current_ext = AppConfig.current().extensions
config_data = {
"mcpServers": mcp_servers,
"skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()},
"skills": {name: {"enabled": skill.enabled} for name, skill in current_ext.skills.items()},
}
self._atomic_write_json(config_path, config_data)
self._agent = None
self._agent_config_key = None
reloaded = reload_extensions_config()
AppConfig.init(AppConfig.from_file())
reloaded = AppConfig.current().extensions
return {"mcp_servers": {name: server.model_dump() for name, server in reloaded.mcp_servers.items()}}
# ------------------------------------------------------------------
@@ -909,19 +904,19 @@ class DeerFlowClient:
if config_path is None:
raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.")
extensions_config = get_extensions_config()
extensions_config.skills[name] = SkillStateConfig(enabled=enabled)
ext = AppConfig.current().extensions
ext.skills[name] = SkillStateConfig(enabled=enabled)
config_data = {
"mcpServers": {n: s.model_dump() for n, s in extensions_config.mcp_servers.items()},
"skills": {n: {"enabled": sc.enabled} for n, sc in extensions_config.skills.items()},
"mcpServers": {n: s.model_dump() for n, s in ext.mcp_servers.items()},
"skills": {n: {"enabled": sc.enabled} for n, sc in ext.skills.items()},
}
self._atomic_write_json(config_path, config_data)
self._agent = None
self._agent_config_key = None
reload_extensions_config()
AppConfig.init(AppConfig.from_file())
updated = next((s for s in load_skills(enabled_only=False) if s.name == name), None)
if updated is None:
@@ -1004,9 +999,7 @@ class DeerFlowClient:
Returns:
Memory config dict.
"""
from deerflow.config.memory_config import get_memory_config
config = get_memory_config()
config = AppConfig.current().memory
return {
"enabled": config.enabled,
"storage_path": config.storage_path,
@@ -25,7 +25,7 @@ except ImportError: # pragma: no cover - Windows fallback
fcntl = None # type: ignore[assignment]
import msvcrt
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import SandboxProvider
@@ -119,16 +119,6 @@ class AioSandboxProvider(SandboxProvider):
if self._config.get("idle_timeout", DEFAULT_IDLE_TIMEOUT) > 0:
self._start_idle_checker()
@property
def uses_thread_data_mounts(self) -> bool:
"""Whether thread workspace/uploads/outputs are visible via mounts.
Local container backends bind-mount the thread data directories, so files
written by the gateway are already visible when the sandbox starts.
Remote backends may require explicit file sync.
"""
return isinstance(self._backend, LocalContainerBackend)
# ── Factory methods ──────────────────────────────────────────────────
def _create_backend(self) -> SandboxBackend:
@@ -158,7 +148,7 @@ class AioSandboxProvider(SandboxProvider):
def _load_config(self) -> dict:
"""Load sandbox configuration from app config."""
config = get_app_config()
config = AppConfig.current()
sandbox_config = config.sandbox
idle_timeout = getattr(sandbox_config, "idle_timeout", None)
@@ -289,7 +279,7 @@ class AioSandboxProvider(SandboxProvider):
so the host Docker daemon can resolve the path.
"""
try:
config = get_app_config()
config = AppConfig.current()
skills_path = config.skills.get_skills_path()
container_path = config.skills.container_path
@@ -7,7 +7,7 @@ import logging
from langchain.tools import tool
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
@@ -63,7 +63,7 @@ def web_search_tool(
query: Search keywords describing what you want to find. Be specific for better results.
max_results: Maximum number of results to return. Default is 5.
"""
config = get_app_config().get_tool_config("web_search")
config = AppConfig.current().get_tool_config("web_search")
# Override max_results from config if set
if config is not None and "max_results" in config.model_extra:
@@ -3,11 +3,11 @@ import json
from exa_py import Exa
from langchain.tools import tool
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
def _get_exa_client(tool_name: str = "web_search") -> Exa:
config = get_app_config().get_tool_config(tool_name)
config = AppConfig.current().get_tool_config(tool_name)
api_key = None
if config is not None and "api_key" in config.model_extra:
api_key = config.model_extra.get("api_key")
@@ -22,7 +22,7 @@ def web_search_tool(query: str) -> str:
query: The query to search for.
"""
try:
config = get_app_config().get_tool_config("web_search")
config = AppConfig.current().get_tool_config("web_search")
max_results = 5
search_type = "auto"
contents_max_characters = 1000
@@ -3,11 +3,11 @@ import json
from firecrawl import FirecrawlApp
from langchain.tools import tool
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
def _get_firecrawl_client(tool_name: str = "web_search") -> FirecrawlApp:
config = get_app_config().get_tool_config(tool_name)
config = AppConfig.current().get_tool_config(tool_name)
api_key = None
if config is not None and "api_key" in config.model_extra:
api_key = config.model_extra.get("api_key")
@@ -22,7 +22,7 @@ def web_search_tool(query: str) -> str:
query: The query to search for.
"""
try:
config = get_app_config().get_tool_config("web_search")
config = AppConfig.current().get_tool_config("web_search")
max_results = 5
if config is not None:
max_results = config.model_extra.get("max_results", max_results)
@@ -7,7 +7,7 @@ import logging
from langchain.tools import tool
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
@@ -99,7 +99,7 @@ def image_search_tool(
type_image: Image type filter. Options: "photo", "clipart", "gif", "transparent", "line". Use "photo" for realistic references.
layout: Layout filter. Options: "Square", "Tall", "Wide". Choose based on your generation needs.
"""
config = get_app_config().get_tool_config("image_search")
config = AppConfig.current().get_tool_config("image_search")
# Override max_results from config if set
if config is not None and "max_results" in config.model_extra:
@@ -1,6 +1,6 @@
from langchain.tools import tool
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.utils.readability import ReadabilityExtractor
from .infoquest_client import InfoQuestClient
@@ -9,12 +9,12 @@ readability_extractor = ReadabilityExtractor()
def _get_infoquest_client() -> InfoQuestClient:
search_config = get_app_config().get_tool_config("web_search")
search_config = AppConfig.current().get_tool_config("web_search")
search_time_range = -1
if search_config is not None and "search_time_range" in search_config.model_extra:
search_time_range = search_config.model_extra.get("search_time_range")
fetch_config = get_app_config().get_tool_config("web_fetch")
fetch_config = AppConfig.current().get_tool_config("web_fetch")
fetch_time = -1
if fetch_config is not None and "fetch_time" in fetch_config.model_extra:
fetch_time = fetch_config.model_extra.get("fetch_time")
@@ -25,7 +25,7 @@ def _get_infoquest_client() -> InfoQuestClient:
if fetch_config is not None and "navigation_timeout" in fetch_config.model_extra:
navigation_timeout = fetch_config.model_extra.get("navigation_timeout")
image_search_config = get_app_config().get_tool_config("image_search")
image_search_config = AppConfig.current().get_tool_config("image_search")
image_search_time_range = -1
if image_search_config is not None and "image_search_time_range" in image_search_config.model_extra:
image_search_time_range = image_search_config.model_extra.get("image_search_time_range")
@@ -38,6 +38,6 @@ class JinaClient:
return response.text
except Exception as e:
error_message = f"Request to Jina API failed: {type(e).__name__}: {e}"
logger.warning(error_message)
error_message = f"Request to Jina API failed: {str(e)}"
logger.exception(error_message)
return f"Error: {error_message}"
@@ -1,9 +1,7 @@
import asyncio
from langchain.tools import tool
from deerflow.community.jina_ai.jina_client import JinaClient
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.utils.readability import ReadabilityExtractor
readability_extractor = ReadabilityExtractor()
@@ -22,11 +20,11 @@ async def web_fetch_tool(url: str) -> str:
"""
jina_client = JinaClient()
timeout = 10
config = get_app_config().get_tool_config("web_fetch")
config = AppConfig.current().get_tool_config("web_fetch")
if config is not None and "timeout" in config.model_extra:
timeout = config.model_extra.get("timeout")
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout)
if isinstance(html_content, str) and html_content.startswith("Error:"):
return html_content
article = await asyncio.to_thread(readability_extractor.extract_article, html_content)
article = readability_extractor.extract_article(html_content)
return article.to_markdown()[:4096]
@@ -3,11 +3,11 @@ import json
from langchain.tools import tool
from tavily import TavilyClient
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
def _get_tavily_client() -> TavilyClient:
config = get_app_config().get_tool_config("web_search")
config = AppConfig.current().get_tool_config("web_search")
api_key = None
if config is not None and "api_key" in config.model_extra:
api_key = config.model_extra.get("api_key")
@@ -21,7 +21,7 @@ def web_search_tool(query: str) -> str:
Args:
query: The query to search for.
"""
config = get_app_config().get_tool_config("web_search")
config = AppConfig.current().get_tool_config("web_search")
max_results = 5
if config is not None and "max_results" in config.model_extra:
max_results = config.model_extra.get("max_results")
@@ -1,6 +1,6 @@
from .app_config import get_app_config
from .extensions_config import ExtensionsConfig, get_extensions_config
from .memory_config import MemoryConfig, get_memory_config
from .app_config import AppConfig
from .extensions_config import ExtensionsConfig
from .memory_config import MemoryConfig
from .paths import Paths, get_paths
from .skill_evolution_config import SkillEvolutionConfig
from .skills_config import SkillsConfig
@@ -13,18 +13,16 @@ from .tracing_config import (
)
__all__ = [
"get_app_config",
"SkillEvolutionConfig",
"Paths",
"get_paths",
"SkillsConfig",
"AppConfig",
"ExtensionsConfig",
"get_extensions_config",
"MemoryConfig",
"get_memory_config",
"get_tracing_config",
"get_explicitly_enabled_tracing_providers",
"Paths",
"SkillEvolutionConfig",
"SkillsConfig",
"get_enabled_tracing_providers",
"get_explicitly_enabled_tracing_providers",
"get_paths",
"get_tracing_config",
"is_tracing_enabled",
"validate_enabled_tracing_providers",
]
@@ -1,16 +1,13 @@
"""ACP (Agent Client Protocol) agent configuration loaded from config.yaml."""
import logging
from collections.abc import Mapping
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
from pydantic import BaseModel, ConfigDict, Field
class ACPAgentConfig(BaseModel):
"""Configuration for a single ACP-compatible agent."""
model_config = ConfigDict(frozen=True)
command: str = Field(description="Command to launch the ACP agent subprocess")
args: list[str] = Field(default_factory=list, description="Additional command arguments")
env: dict[str, str] = Field(default_factory=dict, description="Environment variables to inject into the agent subprocess. Values starting with $ are resolved from host environment variables.")
@@ -24,28 +21,3 @@ class ACPAgentConfig(BaseModel):
"are denied — the agent must be configured to operate without requesting permissions."
),
)
_acp_agents: dict[str, ACPAgentConfig] = {}
def get_acp_agents() -> dict[str, ACPAgentConfig]:
"""Get the currently configured ACP agents.
Returns:
Mapping of agent name -> ACPAgentConfig. Empty dict if no ACP agents are configured.
"""
return _acp_agents
def load_acp_config_from_dict(config_dict: Mapping[str, Mapping[str, object]] | None) -> None:
"""Load ACP agent configuration from a dictionary (typically from config.yaml).
Args:
config_dict: Mapping of agent name -> config fields.
"""
global _acp_agents
if config_dict is None:
config_dict = {}
_acp_agents = {name: ACPAgentConfig(**cfg) for name, cfg in config_dict.items()}
logger.info("ACP config loaded: %d agent(s): %s", len(_acp_agents), list(_acp_agents.keys()))
@@ -1,32 +0,0 @@
"""Configuration for the custom agents management API."""
from pydantic import BaseModel, Field
class AgentsApiConfig(BaseModel):
"""Configuration for custom-agent and user-profile management routes."""
enabled: bool = Field(
default=False,
description=("Whether to expose the custom-agent management API over HTTP. When disabled, the gateway rejects read/write access to custom agent SOUL.md, config, and USER.md prompt-management routes."),
)
_agents_api_config: AgentsApiConfig = AgentsApiConfig()
def get_agents_api_config() -> AgentsApiConfig:
"""Get the current agents API configuration."""
return _agents_api_config
def set_agents_api_config(config: AgentsApiConfig) -> None:
"""Set the agents API configuration."""
global _agents_api_config
_agents_api_config = config
def load_agents_api_config_from_dict(config_dict: dict) -> None:
"""Load agents API configuration from a dictionary."""
global _agents_api_config
_agents_api_config = AgentsApiConfig(**config_dict)
@@ -5,7 +5,7 @@ import re
from typing import Any
import yaml
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from deerflow.config.paths import get_paths
@@ -15,20 +15,11 @@ SOUL_FILENAME = "SOUL.md"
AGENT_NAME_PATTERN = re.compile(r"^[A-Za-z0-9-]+$")
def validate_agent_name(name: str | None) -> str | None:
"""Validate a custom agent name before using it in filesystem paths."""
if name is None:
return None
if not isinstance(name, str):
raise ValueError("Invalid agent name. Expected a string or None.")
if not AGENT_NAME_PATTERN.fullmatch(name):
raise ValueError(f"Invalid agent name '{name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
return name
class AgentConfig(BaseModel):
"""Configuration for a custom agent."""
model_config = ConfigDict(frozen=True)
name: str
description: str = ""
model: str | None = None
@@ -57,7 +48,8 @@ def load_agent_config(name: str | None) -> AgentConfig | None:
if name is None:
return None
name = validate_agent_name(name)
if not AGENT_NAME_PATTERN.match(name):
raise ValueError(f"Invalid agent name '{name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
agent_dir = get_paths().agent_dir(name)
config_file = agent_dir / "config.yaml"
@@ -1,43 +1,37 @@
from __future__ import annotations
import logging
import os
from contextvars import ContextVar
from pathlib import Path
from typing import Any, Self
from typing import Any, ClassVar, Self
import yaml
from dotenv import load_dotenv
from pydantic import BaseModel, ConfigDict, Field
from deerflow.config.acp_config import load_acp_config_from_dict
from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
from deerflow.config.acp_config import ACPAgentConfig
from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
from deerflow.config.guardrails_config import GuardrailsConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.model_config import ModelConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.skill_evolution_config import SkillEvolutionConfig
from deerflow.config.skills_config import SkillsConfig
from deerflow.config.stream_bridge_config import StreamBridgeConfig, load_stream_bridge_config_from_dict
from deerflow.config.subagents_config import SubagentsAppConfig, load_subagents_config_from_dict
from deerflow.config.summarization_config import SummarizationConfig, load_summarization_config_from_dict
from deerflow.config.title_config import TitleConfig, load_title_config_from_dict
from deerflow.config.stream_bridge_config import StreamBridgeConfig
from deerflow.config.subagents_config import SubagentsAppConfig
from deerflow.config.summarization_config import SummarizationConfig
from deerflow.config.title_config import TitleConfig
from deerflow.config.token_usage_config import TokenUsageConfig
from deerflow.config.tool_config import ToolConfig, ToolGroupConfig
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
from deerflow.config.tool_search_config import ToolSearchConfig
load_dotenv()
logger = logging.getLogger(__name__)
class CircuitBreakerConfig(BaseModel):
"""Configuration for the LLM Circuit Breaker."""
failure_threshold: int = Field(default=5, description="Number of consecutive failures before tripping the circuit")
recovery_timeout_sec: int = Field(default=60, description="Time in seconds before attempting to recover the circuit")
def _default_config_candidates() -> tuple[Path, ...]:
"""Return deterministic config.yaml locations without relying on cwd."""
backend_dir = Path(__file__).resolve().parents[4]
@@ -61,13 +55,12 @@ class AppConfig(BaseModel):
title: TitleConfig = Field(default_factory=TitleConfig, description="Automatic title generation configuration")
summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration")
memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration")
agents_api: AgentsApiConfig = Field(default_factory=AgentsApiConfig, description="Custom-agent management API configuration")
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
model_config = ConfigDict(extra="allow", frozen=False)
model_config = ConfigDict(extra="allow", frozen=True)
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
acp_agents: dict[str, ACPAgentConfig] = Field(default_factory=dict, description="ACP agent configurations keyed by agent name")
@classmethod
def resolve_config_path(cls, config_path: str | None = None) -> Path:
@@ -115,49 +108,6 @@ class AppConfig(BaseModel):
config_data = cls.resolve_env_variables(config_data)
# Load title config if present
if "title" in config_data:
load_title_config_from_dict(config_data["title"])
# Load summarization config if present
if "summarization" in config_data:
load_summarization_config_from_dict(config_data["summarization"])
# Load memory config if present
if "memory" in config_data:
load_memory_config_from_dict(config_data["memory"])
# Always refresh agents API config so removed config sections reset
# singleton-backed state to its default/disabled values on reload.
load_agents_api_config_from_dict(config_data.get("agents_api") or {})
# Load subagents config if present
if "subagents" in config_data:
load_subagents_config_from_dict(config_data["subagents"])
# Load tool_search config if present
if "tool_search" in config_data:
load_tool_search_config_from_dict(config_data["tool_search"])
# Load guardrails config if present
if "guardrails" in config_data:
load_guardrails_config_from_dict(config_data["guardrails"])
# Load circuit_breaker config if present
if "circuit_breaker" in config_data:
config_data["circuit_breaker"] = config_data["circuit_breaker"]
# Load checkpointer config if present
if "checkpointer" in config_data:
load_checkpointer_config_from_dict(config_data["checkpointer"])
# Load stream bridge config if present
if "stream_bridge" in config_data:
load_stream_bridge_config_from_dict(config_data["stream_bridge"])
# Always refresh ACP agent config so removed entries do not linger across reloads.
load_acp_config_from_dict(config_data.get("acp_agents", {}))
# Load extensions config separately (it's in a different file)
extensions_config = ExtensionsConfig.from_file()
config_data["extensions"] = extensions_config.model_dump()
@@ -268,130 +218,26 @@ class AppConfig(BaseModel):
"""
return next((group for group in self.tool_groups if group.name == name), None)
# -- Lifecycle (class-level singleton via ContextVar) --
_app_config: AppConfig | None = None
_app_config_path: Path | None = None
_app_config_mtime: float | None = None
_app_config_is_custom = False
_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None)
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", default=())
_current: ClassVar[ContextVar[AppConfig]] = ContextVar("deerflow_app_config")
@classmethod
def init(cls, config: AppConfig) -> None:
"""Set the AppConfig for the current context. Call once at process startup."""
cls._current.set(config)
def _get_config_mtime(config_path: Path) -> float | None:
"""Get the modification time of a config file if it exists."""
try:
return config_path.stat().st_mtime
except OSError:
return None
@classmethod
def current(cls) -> AppConfig:
"""Get the current AppConfig.
def _load_and_cache_app_config(config_path: str | None = None) -> AppConfig:
"""Load config from disk and refresh cache metadata."""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
resolved_path = AppConfig.resolve_config_path(config_path)
_app_config = AppConfig.from_file(str(resolved_path))
_app_config_path = resolved_path
_app_config_mtime = _get_config_mtime(resolved_path)
_app_config_is_custom = False
return _app_config
def get_app_config() -> AppConfig:
"""Get the DeerFlow config instance.
Returns a cached singleton instance and automatically reloads it when the
underlying config file path or modification time changes. Use
`reload_app_config()` to force a reload, or `reset_app_config()` to clear
the cache.
"""
global _app_config, _app_config_path, _app_config_mtime
runtime_override = _current_app_config.get()
if runtime_override is not None:
return runtime_override
if _app_config is not None and _app_config_is_custom:
return _app_config
resolved_path = AppConfig.resolve_config_path()
current_mtime = _get_config_mtime(resolved_path)
should_reload = _app_config is None or _app_config_path != resolved_path or _app_config_mtime != current_mtime
if should_reload:
if _app_config_path == resolved_path and _app_config_mtime is not None and current_mtime is not None and _app_config_mtime != current_mtime:
logger.info(
"Config file has been modified (mtime: %s -> %s), reloading AppConfig",
_app_config_mtime,
current_mtime,
)
_load_and_cache_app_config(str(resolved_path))
return _app_config
def reload_app_config(config_path: str | None = None) -> AppConfig:
"""Reload the config from file and update the cached instance.
This is useful when the config file has been modified and you want
to pick up the changes without restarting the application.
Args:
config_path: Optional path to config file. If not provided,
uses the default resolution strategy.
Returns:
The newly loaded AppConfig instance.
"""
return _load_and_cache_app_config(config_path)
def reset_app_config() -> None:
"""Reset the cached config instance.
This clears the singleton cache, causing the next call to
`get_app_config()` to reload from file. Useful for testing
or when switching between different configurations.
"""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
_app_config = None
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = False
def set_app_config(config: AppConfig) -> None:
"""Set a custom config instance.
This allows injecting a custom or mock config for testing purposes.
Args:
config: The AppConfig instance to use.
"""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
_app_config = config
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = True
def peek_current_app_config() -> AppConfig | None:
"""Return the runtime-scoped AppConfig override, if one is active."""
return _current_app_config.get()
def push_current_app_config(config: AppConfig) -> None:
"""Push a runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
_current_app_config_stack.set(stack + (_current_app_config.get(),))
_current_app_config.set(config)
def pop_current_app_config() -> None:
"""Pop the latest runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
if not stack:
_current_app_config.set(None)
return
previous = stack[-1]
_current_app_config_stack.set(stack[:-1])
_current_app_config.set(previous)
Auto-initializes from config file on first access for backward compatibility.
Prefer calling AppConfig.init() explicitly at process startup.
"""
try:
return cls._current.get()
except LookupError:
logger.debug("AppConfig not initialized, auto-loading from file")
config = cls.from_file()
cls._current.set(config)
return config
@@ -2,7 +2,7 @@
from typing import Literal
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
CheckpointerType = Literal["memory", "sqlite", "postgres"]
@@ -10,6 +10,8 @@ CheckpointerType = Literal["memory", "sqlite", "postgres"]
class CheckpointerConfig(BaseModel):
"""Configuration for LangGraph state persistence checkpointer."""
model_config = ConfigDict(frozen=True)
type: CheckpointerType = Field(
description="Checkpointer backend type. "
"'memory' is in-process only (lost on restart). "
@@ -23,24 +25,3 @@ class CheckpointerConfig(BaseModel):
"For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. "
"For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.",
)
# Global configuration instance — None means no checkpointer is configured.
_checkpointer_config: CheckpointerConfig | None = None
def get_checkpointer_config() -> CheckpointerConfig | None:
"""Get the current checkpointer configuration, or None if not configured."""
return _checkpointer_config
def set_checkpointer_config(config: CheckpointerConfig | None) -> None:
"""Set the checkpointer configuration."""
global _checkpointer_config
_checkpointer_config = config
def load_checkpointer_config_from_dict(config_dict: dict) -> None:
"""Load checkpointer configuration from a dictionary."""
global _checkpointer_config
_checkpointer_config = CheckpointerConfig(**config_dict)
@@ -0,0 +1,59 @@
"""Per-invocation context for DeerFlow agent execution.
Injected via LangGraph Runtime. Middleware and tools access this
via Runtime[DeerFlowContext] parameters, through resolve_context().
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
@dataclass(frozen=True)
class DeerFlowContext:
"""Typed, immutable, per-invocation context injected via LangGraph Runtime.
Fields are all known at run start and never change during execution.
Mutable runtime state (e.g. sandbox_id) flows through ThreadState, not here.
"""
app_config: Any # AppConfig — typed as Any to avoid circular import at module level
thread_id: str
agent_name: str | None = None
def resolve_context(runtime: Any) -> DeerFlowContext:
"""Extract or construct DeerFlowContext from runtime.
Gateway/Client paths: runtime.context is already DeerFlowContext return directly.
LangGraph Server / legacy dict path: construct from dict context or configurable fallback.
"""
ctx = getattr(runtime, "context", None)
if isinstance(ctx, DeerFlowContext):
return ctx
from deerflow.config.app_config import AppConfig
# Try dict context first (legacy path, tests), then configurable
if isinstance(ctx, dict):
return DeerFlowContext(
app_config=AppConfig.current(),
thread_id=ctx.get("thread_id", ""),
agent_name=ctx.get("agent_name"),
)
# No context at all — fall back to LangGraph configurable
try:
from langgraph.config import get_config
cfg = get_config().get("configurable", {})
except RuntimeError:
# Outside runnable context (e.g. unit tests)
cfg = {}
return DeerFlowContext(
app_config=AppConfig.current(),
thread_id=cfg.get("thread_id", ""),
agent_name=cfg.get("agent_name"),
)
@@ -11,6 +11,8 @@ from pydantic import BaseModel, ConfigDict, Field
class McpOAuthConfig(BaseModel):
"""OAuth configuration for an MCP server (HTTP/SSE transports)."""
model_config = ConfigDict(extra="allow", frozen=True)
enabled: bool = Field(default=True, description="Whether OAuth token injection is enabled")
token_url: str = Field(description="OAuth token endpoint URL")
grant_type: Literal["client_credentials", "refresh_token"] = Field(
@@ -28,12 +30,13 @@ class McpOAuthConfig(BaseModel):
default_token_type: str = Field(default="Bearer", description="Default token type when missing in token response")
refresh_skew_seconds: int = Field(default=60, description="Refresh token this many seconds before expiry")
extra_token_params: dict[str, str] = Field(default_factory=dict, description="Additional form params sent to token endpoint")
model_config = ConfigDict(extra="allow")
class McpServerConfig(BaseModel):
"""Configuration for a single MCP server."""
model_config = ConfigDict(extra="allow", frozen=True)
enabled: bool = Field(default=True, description="Whether this MCP server is enabled")
type: str = Field(default="stdio", description="Transport type: 'stdio', 'sse', or 'http'")
command: str | None = Field(default=None, description="Command to execute to start the MCP server (for stdio type)")
@@ -43,12 +46,13 @@ class McpServerConfig(BaseModel):
headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers to send (for sse or http type)")
oauth: McpOAuthConfig | None = Field(default=None, description="OAuth configuration (for sse or http type)")
description: str = Field(default="", description="Human-readable description of what this MCP server provides")
model_config = ConfigDict(extra="allow")
class SkillStateConfig(BaseModel):
"""Configuration for a single skill's state."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(default=True, description="Whether this skill is enabled")
@@ -64,7 +68,7 @@ class ExtensionsConfig(BaseModel):
default_factory=dict,
description="Map of skill name to state configuration",
)
model_config = ConfigDict(extra="allow", populate_by_name=True)
model_config = ConfigDict(extra="allow", frozen=True, populate_by_name=True)
@classmethod
def resolve_config_path(cls, config_path: str | None = None) -> Path | None:
@@ -195,62 +199,3 @@ class ExtensionsConfig(BaseModel):
# Default to enable for public & custom skill
return skill_category in ("public", "custom")
return skill_config.enabled
_extensions_config: ExtensionsConfig | None = None
def get_extensions_config() -> ExtensionsConfig:
"""Get the extensions config instance.
Returns a cached singleton instance. Use `reload_extensions_config()` to reload
from file, or `reset_extensions_config()` to clear the cache.
Returns:
The cached ExtensionsConfig instance.
"""
global _extensions_config
if _extensions_config is None:
_extensions_config = ExtensionsConfig.from_file()
return _extensions_config
def reload_extensions_config(config_path: str | None = None) -> ExtensionsConfig:
"""Reload the extensions config from file and update the cached instance.
This is useful when the config file has been modified and you want
to pick up the changes without restarting the application.
Args:
config_path: Optional path to extensions config file. If not provided,
uses the default resolution strategy.
Returns:
The newly loaded ExtensionsConfig instance.
"""
global _extensions_config
_extensions_config = ExtensionsConfig.from_file(config_path)
return _extensions_config
def reset_extensions_config() -> None:
"""Reset the cached extensions config instance.
This clears the singleton cache, causing the next call to
`get_extensions_config()` to reload from file. Useful for testing
or when switching between different configurations.
"""
global _extensions_config
_extensions_config = None
def set_extensions_config(config: ExtensionsConfig) -> None:
"""Set a custom extensions config instance.
This allows injecting a custom or mock config for testing purposes.
Args:
config: The ExtensionsConfig instance to use.
"""
global _extensions_config
_extensions_config = config
@@ -1,11 +1,13 @@
"""Configuration for pre-tool-call authorization."""
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class GuardrailProviderConfig(BaseModel):
"""Configuration for a guardrail provider."""
model_config = ConfigDict(frozen=True)
use: str = Field(description="Class path (e.g. 'deerflow.guardrails.builtin:AllowlistProvider')")
config: dict = Field(default_factory=dict, description="Provider-specific settings passed as kwargs")
@@ -18,31 +20,9 @@ class GuardrailsConfig(BaseModel):
agent's passport reference, and returns an allow/deny decision.
"""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(default=False, description="Enable guardrail middleware")
fail_closed: bool = Field(default=True, description="Block tool calls if provider errors")
passport: str | None = Field(default=None, description="OAP passport path or hosted agent ID")
provider: GuardrailProviderConfig | None = Field(default=None, description="Guardrail provider configuration")
_guardrails_config: GuardrailsConfig | None = None
def get_guardrails_config() -> GuardrailsConfig:
"""Get the guardrails config, returning defaults if not loaded."""
global _guardrails_config
if _guardrails_config is None:
_guardrails_config = GuardrailsConfig()
return _guardrails_config
def load_guardrails_config_from_dict(data: dict) -> GuardrailsConfig:
"""Load guardrails config from a dict (called during AppConfig loading)."""
global _guardrails_config
_guardrails_config = GuardrailsConfig.model_validate(data)
return _guardrails_config
def reset_guardrails_config() -> None:
"""Reset the cached config instance. Used in tests to prevent singleton leaks."""
global _guardrails_config
_guardrails_config = None
@@ -1,11 +1,13 @@
"""Configuration for memory mechanism."""
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class MemoryConfig(BaseModel):
"""Configuration for global memory mechanism."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(
default=True,
description="Whether to enable memory mechanism",
@@ -59,24 +61,3 @@ class MemoryConfig(BaseModel):
le=8000,
description="Maximum tokens to use for memory injection",
)
# Global configuration instance
_memory_config: MemoryConfig = MemoryConfig()
def get_memory_config() -> MemoryConfig:
"""Get the current memory configuration."""
return _memory_config
def set_memory_config(config: MemoryConfig) -> None:
"""Set the memory configuration."""
global _memory_config
_memory_config = config
def load_memory_config_from_dict(config_dict: dict) -> None:
"""Load memory configuration from a dictionary."""
global _memory_config
_memory_config = MemoryConfig(**config_dict)
@@ -12,7 +12,7 @@ class ModelConfig(BaseModel):
description="Class path of the model provider(e.g. langchain_openai.ChatOpenAI)",
)
model: str = Field(..., description="Model name")
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra="allow", frozen=True)
use_responses_api: bool | None = Field(
default=None,
description="Whether to route OpenAI ChatOpenAI calls through the /v1/responses API",
@@ -4,6 +4,8 @@ from pydantic import BaseModel, ConfigDict, Field
class VolumeMountConfig(BaseModel):
"""Configuration for a volume mount."""
model_config = ConfigDict(frozen=True)
host_path: str = Field(..., description="Path on the host machine")
container_path: str = Field(..., description="Path inside the container")
read_only: bool = Field(default=False, description="Whether the mount is read-only")
@@ -80,4 +82,4 @@ class SandboxConfig(BaseModel):
description="Maximum characters to keep from ls tool output. Output exceeding this limit is head-truncated. Set to 0 to disable truncation.",
)
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra="allow", frozen=True)
@@ -1,9 +1,11 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class SkillEvolutionConfig(BaseModel):
"""Configuration for agent-managed skill evolution."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(
default=False,
description="Whether the agent can create and modify skills under skills/custom.",
@@ -1,6 +1,6 @@
from pathlib import Path
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
def _default_repo_root() -> Path:
@@ -11,6 +11,8 @@ def _default_repo_root() -> Path:
class SkillsConfig(BaseModel):
"""Configuration for skills system"""
model_config = ConfigDict(frozen=True)
path: str | None = Field(
default=None,
description="Path to skills directory. If not specified, defaults to ../skills relative to backend directory",
@@ -2,7 +2,7 @@
from typing import Literal
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
StreamBridgeType = Literal["memory", "redis"]
@@ -10,6 +10,8 @@ StreamBridgeType = Literal["memory", "redis"]
class StreamBridgeConfig(BaseModel):
"""Configuration for the stream bridge that connects agent workers to SSE endpoints."""
model_config = ConfigDict(frozen=True)
type: StreamBridgeType = Field(
default="memory",
description="Stream bridge backend type. 'memory' uses in-process asyncio.Queue (single-process only). 'redis' uses Redis Streams (planned for Phase 2, not yet implemented).",
@@ -22,25 +24,3 @@ class StreamBridgeConfig(BaseModel):
default=256,
description="Maximum number of events buffered per run in the memory bridge.",
)
# Global configuration instance — None means no stream bridge is configured
# (falls back to memory with defaults).
_stream_bridge_config: StreamBridgeConfig | None = None
def get_stream_bridge_config() -> StreamBridgeConfig | None:
"""Get the current stream bridge configuration, or None if not configured."""
return _stream_bridge_config
def set_stream_bridge_config(config: StreamBridgeConfig | None) -> None:
"""Set the stream bridge configuration."""
global _stream_bridge_config
_stream_bridge_config = config
def load_stream_bridge_config_from_dict(config_dict: dict) -> None:
"""Load stream bridge configuration from a dictionary."""
global _stream_bridge_config
_stream_bridge_config = StreamBridgeConfig(**config_dict)
@@ -1,15 +1,13 @@
"""Configuration for the subagent system loaded from config.yaml."""
import logging
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
from pydantic import BaseModel, ConfigDict, Field
class SubagentOverrideConfig(BaseModel):
"""Per-agent configuration overrides."""
model_config = ConfigDict(frozen=True)
timeout_seconds: int | None = Field(
default=None,
ge=1,
@@ -20,57 +18,13 @@ class SubagentOverrideConfig(BaseModel):
ge=1,
description="Maximum turns for this subagent (None = use global or builtin default)",
)
model: str | None = Field(
default=None,
min_length=1,
description="Model name for this subagent (None = inherit from parent agent)",
)
skills: list[str] | None = Field(
default=None,
description="Skill names whitelist for this subagent (None = inherit all enabled skills, [] = no skills)",
)
class CustomSubagentConfig(BaseModel):
"""User-defined subagent type declared in config.yaml."""
description: str = Field(
description="When the lead agent should delegate to this subagent",
)
system_prompt: str = Field(
description="System prompt that guides the subagent's behavior",
)
tools: list[str] | None = Field(
default=None,
description="Tool names whitelist (None = inherit all tools from parent)",
)
disallowed_tools: list[str] | None = Field(
default_factory=lambda: ["task", "ask_clarification", "present_files"],
description="Tool names to deny",
)
skills: list[str] | None = Field(
default=None,
description="Skill names whitelist (None = inherit all enabled skills, [] = no skills)",
)
model: str = Field(
default="inherit",
description="Model to use - 'inherit' uses parent's model",
)
max_turns: int = Field(
default=50,
ge=1,
description="Maximum number of agent turns before stopping",
)
timeout_seconds: int = Field(
default=900,
ge=1,
description="Maximum execution time in seconds",
)
class SubagentsAppConfig(BaseModel):
"""Configuration for the subagent system."""
model_config = ConfigDict(frozen=True)
timeout_seconds: int = Field(
default=900,
ge=1,
@@ -85,10 +39,6 @@ class SubagentsAppConfig(BaseModel):
default_factory=dict,
description="Per-agent configuration overrides keyed by agent name",
)
custom_agents: dict[str, CustomSubagentConfig] = Field(
default_factory=dict,
description="User-defined subagent types keyed by agent name",
)
def get_timeout_for(self, agent_name: str) -> int:
"""Get the effective timeout for a specific agent.
@@ -104,20 +54,6 @@ class SubagentsAppConfig(BaseModel):
return override.timeout_seconds
return self.timeout_seconds
def get_model_for(self, agent_name: str) -> str | None:
"""Get the model override for a specific agent.
Args:
agent_name: The name of the subagent.
Returns:
Model name if overridden, None otherwise (subagent will inherit parent model).
"""
override = self.agents.get(agent_name)
if override is not None and override.model is not None:
return override.model
return None
def get_max_turns_for(self, agent_name: str, builtin_default: int) -> int:
"""Get the effective max_turns for a specific agent."""
override = self.agents.get(agent_name)
@@ -126,62 +62,3 @@ class SubagentsAppConfig(BaseModel):
if self.max_turns is not None:
return self.max_turns
return builtin_default
def get_skills_for(self, agent_name: str) -> list[str] | None:
"""Get the skills override for a specific agent.
Args:
agent_name: The name of the subagent.
Returns:
Skill names whitelist if overridden, None otherwise (subagent will inherit all enabled skills).
"""
override = self.agents.get(agent_name)
if override is not None and override.skills is not None:
return override.skills
return None
_subagents_config: SubagentsAppConfig = SubagentsAppConfig()
def get_subagents_app_config() -> SubagentsAppConfig:
"""Get the current subagents configuration."""
return _subagents_config
def load_subagents_config_from_dict(config_dict: dict) -> None:
"""Load subagents configuration from a dictionary."""
global _subagents_config
_subagents_config = SubagentsAppConfig(**config_dict)
overrides_summary = {}
for name, override in _subagents_config.agents.items():
parts = []
if override.timeout_seconds is not None:
parts.append(f"timeout={override.timeout_seconds}s")
if override.max_turns is not None:
parts.append(f"max_turns={override.max_turns}")
if override.model is not None:
parts.append(f"model={override.model}")
if override.skills is not None:
parts.append(f"skills={override.skills}")
if parts:
overrides_summary[name] = ", ".join(parts)
custom_agents_names = list(_subagents_config.custom_agents.keys())
if overrides_summary or custom_agents_names:
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, per-agent overrides=%s, custom_agents=%s",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
overrides_summary or "none",
custom_agents_names or "none",
)
else:
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, no per-agent overrides",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
)
@@ -2,7 +2,7 @@
from typing import Literal
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
ContextSizeType = Literal["fraction", "tokens", "messages"]
@@ -10,6 +10,8 @@ ContextSizeType = Literal["fraction", "tokens", "messages"]
class ContextSize(BaseModel):
"""Context size specification for trigger or keep parameters."""
model_config = ConfigDict(frozen=True)
type: ContextSizeType = Field(description="Type of context size specification")
value: int | float = Field(description="Value for the context size specification")
@@ -21,6 +23,8 @@ class ContextSize(BaseModel):
class SummarizationConfig(BaseModel):
"""Configuration for automatic conversation summarization."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(
default=False,
description="Whether to enable automatic conversation summarization",
@@ -51,43 +55,3 @@ class SummarizationConfig(BaseModel):
default=None,
description="Custom prompt template for generating summaries. If not provided, uses the default LangChain prompt.",
)
preserve_recent_skill_count: int = Field(
default=5,
ge=0,
description="Number of most-recently-loaded skill files to exclude from summarization. Set to 0 to disable skill preservation.",
)
preserve_recent_skill_tokens: int = Field(
default=25000,
ge=0,
description="Total token budget reserved for recently-loaded skill files that must be preserved across summarization.",
)
preserve_recent_skill_tokens_per_skill: int = Field(
default=5000,
ge=0,
description="Per-skill token cap when preserving skill files across summarization. Skill reads above this size are not rescued.",
)
skill_file_read_tool_names: list[str] = Field(
default_factory=lambda: ["read_file", "read", "view", "cat"],
description="Tool names treated as skill file reads when preserving recently-loaded skills across summarization.",
)
# Global configuration instance
_summarization_config: SummarizationConfig = SummarizationConfig()
def get_summarization_config() -> SummarizationConfig:
"""Get the current summarization configuration."""
return _summarization_config
def set_summarization_config(config: SummarizationConfig) -> None:
"""Set the summarization configuration."""
global _summarization_config
_summarization_config = config
def load_summarization_config_from_dict(config_dict: dict) -> None:
"""Load summarization configuration from a dictionary."""
global _summarization_config
_summarization_config = SummarizationConfig(**config_dict)
@@ -1,11 +1,13 @@
"""Configuration for automatic thread title generation."""
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class TitleConfig(BaseModel):
"""Configuration for automatic thread title generation."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(
default=True,
description="Whether to enable automatic title generation",
@@ -30,24 +32,3 @@ class TitleConfig(BaseModel):
default=("Generate a concise title (max {max_words} words) for this conversation.\nUser: {user_msg}\nAssistant: {assistant_msg}\n\nReturn ONLY the title, no quotes, no explanation."),
description="Prompt template for title generation",
)
# Global configuration instance
_title_config: TitleConfig = TitleConfig()
def get_title_config() -> TitleConfig:
"""Get the current title configuration."""
return _title_config
def set_title_config(config: TitleConfig) -> None:
"""Set the title configuration."""
global _title_config
_title_config = config
def load_title_config_from_dict(config_dict: dict) -> None:
"""Load title configuration from a dictionary."""
global _title_config
_title_config = TitleConfig(**config_dict)
@@ -1,7 +1,9 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class TokenUsageConfig(BaseModel):
"""Configuration for token usage tracking."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(default=False, description="Enable token usage tracking middleware")
@@ -5,7 +5,7 @@ class ToolGroupConfig(BaseModel):
"""Config section for a tool group"""
name: str = Field(..., description="Unique name for the tool group")
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra="allow", frozen=True)
class ToolConfig(BaseModel):
@@ -17,4 +17,4 @@ class ToolConfig(BaseModel):
...,
description="Variable name of the tool provider(e.g. deerflow.sandbox.tools:bash_tool)",
)
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra="allow", frozen=True)
@@ -1,6 +1,6 @@
"""Configuration for deferred tool loading via tool_search."""
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class ToolSearchConfig(BaseModel):
@@ -11,25 +11,9 @@ class ToolSearchConfig(BaseModel):
via the tool_search tool at runtime.
"""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(
default=False,
description="Defer tools and enable tool_search",
)
_tool_search_config: ToolSearchConfig | None = None
def get_tool_search_config() -> ToolSearchConfig:
"""Get the tool search config, loading from AppConfig if needed."""
global _tool_search_config
if _tool_search_config is None:
_tool_search_config = ToolSearchConfig()
return _tool_search_config
def load_tool_search_config_from_dict(data: dict) -> ToolSearchConfig:
"""Load tool search config from a dict (called during AppConfig loading)."""
global _tool_search_config
_tool_search_config = ToolSearchConfig.model_validate(data)
return _tool_search_config
@@ -1,7 +1,7 @@
import os
import threading
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
_config_lock = threading.Lock()
@@ -9,6 +9,8 @@ _config_lock = threading.Lock()
class LangSmithTracingConfig(BaseModel):
"""Configuration for LangSmith tracing."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(...)
api_key: str | None = Field(...)
project: str = Field(...)
@@ -26,6 +28,8 @@ class LangSmithTracingConfig(BaseModel):
class LangfuseTracingConfig(BaseModel):
"""Configuration for Langfuse tracing."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(...)
public_key: str | None = Field(...)
secret_key: str | None = Field(...)
@@ -50,6 +54,8 @@ class LangfuseTracingConfig(BaseModel):
class TracingConfig(BaseModel):
"""Tracing configuration for supported providers."""
model_config = ConfigDict(frozen=True)
langsmith: LangSmithTracingConfig = Field(...)
langfuse: LangfuseTracingConfig = Field(...)
@@ -118,13 +118,9 @@ def get_cached_mcp_tools() -> list[BaseTool]:
loop.run_until_complete(initialize_mcp_tools())
except RuntimeError:
# No event loop exists, create one
try:
asyncio.run(initialize_mcp_tools())
except Exception:
logger.exception("Failed to lazy-initialize MCP tools")
return []
except Exception:
logger.exception("Failed to lazy-initialize MCP tools")
asyncio.run(initialize_mcp_tools())
except Exception as e:
logger.error(f"Failed to lazy-initialize MCP tools: {e}")
return []
return _mcp_tools_cache or []
@@ -12,7 +12,6 @@ from langchain_core.tools import BaseTool
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.mcp.client import build_servers_config
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
from deerflow.reflection import resolve_variable
logger = logging.getLogger(__name__)
@@ -96,27 +95,6 @@ async def get_mcp_tools() -> list[BaseTool]:
if oauth_interceptor is not None:
tool_interceptors.append(oauth_interceptor)
# Load custom interceptors declared in extensions_config.json
# Format: "mcpInterceptors": ["pkg.module:builder_func", ...]
raw_interceptor_paths = (extensions_config.model_extra or {}).get("mcpInterceptors")
if isinstance(raw_interceptor_paths, str):
raw_interceptor_paths = [raw_interceptor_paths]
elif not isinstance(raw_interceptor_paths, list):
if raw_interceptor_paths is not None:
logger.warning(f"mcpInterceptors must be a list of strings, got {type(raw_interceptor_paths).__name__}; skipping")
raw_interceptor_paths = []
for interceptor_path in raw_interceptor_paths:
try:
builder = resolve_variable(interceptor_path)
interceptor = builder()
if callable(interceptor):
tool_interceptors.append(interceptor)
logger.info(f"Loaded MCP interceptor: {interceptor_path}")
elif interceptor is not None:
logger.warning(f"Builder {interceptor_path} returned non-callable {type(interceptor).__name__}; skipping")
except Exception as e:
logger.warning(f"Failed to load MCP interceptor {interceptor_path}: {e}", exc_info=True)
client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors, tool_name_prefix=True)
# Get all tools from all servers
@@ -190,33 +190,23 @@ class ClaudeChatModel(ChatAnthropic):
)
def _apply_prompt_caching(self, payload: dict) -> None:
"""Apply ephemeral cache_control to system, recent messages, and last tool definition.
Uses a budget of MAX_CACHE_BREAKPOINTS (4) breakpoints the hard limit
enforced by both the Anthropic API and AWS Bedrock. Breakpoints are
placed on the *last* eligible blocks because later breakpoints cover a
larger prefix and yield better cache hit rates.
"""
MAX_CACHE_BREAKPOINTS = 4
# Collect candidate blocks in document order:
# 1. system text blocks
# 2. content blocks of the last prompt_cache_size messages
# 3. the last tool definition
candidates: list[dict] = []
# 1. System blocks
"""Apply ephemeral cache_control to system and recent messages."""
# Cache system messages
system = payload.get("system")
if system and isinstance(system, list):
for block in system:
if isinstance(block, dict) and block.get("type") == "text":
candidates.append(block)
block["cache_control"] = {"type": "ephemeral"}
elif system and isinstance(system, str):
new_block: dict = {"type": "text", "text": system}
payload["system"] = [new_block]
candidates.append(new_block)
payload["system"] = [
{
"type": "text",
"text": system,
"cache_control": {"type": "ephemeral"},
}
]
# 2. Recent message blocks
# Cache recent messages
messages = payload.get("messages", [])
cache_start = max(0, len(messages) - self.prompt_cache_size)
for i in range(cache_start, len(messages)):
@@ -227,21 +217,20 @@ class ClaudeChatModel(ChatAnthropic):
if isinstance(content, list):
for block in content:
if isinstance(block, dict):
candidates.append(block)
block["cache_control"] = {"type": "ephemeral"}
elif isinstance(content, str) and content:
new_block = {"type": "text", "text": content}
msg["content"] = [new_block]
candidates.append(new_block)
msg["content"] = [
{
"type": "text",
"text": content,
"cache_control": {"type": "ephemeral"},
}
]
# 3. Last tool definition
# Cache the last tool definition
tools = payload.get("tools", [])
if tools and isinstance(tools[-1], dict):
candidates.append(tools[-1])
# Apply cache_control only to the last MAX_CACHE_BREAKPOINTS candidates
# to stay within the API limit.
for block in candidates[-MAX_CACHE_BREAKPOINTS:]:
block["cache_control"] = {"type": "ephemeral"}
tools[-1]["cache_control"] = {"type": "ephemeral"}
def _apply_thinking_budget(self, payload: dict) -> None:
"""Auto-allocate thinking budget (80% of max_tokens)."""
@@ -2,7 +2,7 @@ import logging
from langchain.chat_models import BaseChatModel
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_class
from deerflow.tracing import build_tracing_callbacks
@@ -30,22 +30,6 @@ def _vllm_disable_chat_template_kwargs(chat_template_kwargs: dict) -> dict:
return disable_kwargs
def _enable_stream_usage_by_default(model_use_path: str, model_settings_from_config: dict) -> None:
"""Enable stream usage for OpenAI-compatible models unless explicitly configured.
LangChain only auto-enables ``stream_usage`` for OpenAI models when no custom
base URL or client is configured. DeerFlow frequently uses OpenAI-compatible
gateways, so token usage tracking would otherwise stay empty and the
TokenUsageMiddleware would have nothing to log.
"""
if model_use_path != "langchain_openai:ChatOpenAI":
return
if "stream_usage" in model_settings_from_config:
return
if "base_url" in model_settings_from_config or "openai_api_base" in model_settings_from_config:
model_settings_from_config["stream_usage"] = True
def create_chat_model(name: str | None = None, thinking_enabled: bool = False, **kwargs) -> BaseChatModel:
"""Create a chat model instance from the config.
@@ -55,7 +39,7 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
Returns:
A chat model instance.
"""
config = get_app_config()
config = AppConfig.current()
if name is None:
name = config.models[0].name
model_config = config.get_model_config(name)
@@ -113,8 +97,6 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
kwargs.pop("reasoning_effort", None)
model_settings_from_config.pop("reasoning_effort", None)
_enable_stream_usage_by_default(model_config.use, model_settings_from_config)
# For Codex Responses API models: map thinking mode to reasoning_effort
from deerflow.models.openai_codex_provider import CodexChatModel
@@ -131,12 +113,6 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
elif "reasoning_effort" not in model_settings_from_config:
model_settings_from_config["reasoning_effort"] = "medium"
# For MindIE models: enforce conservative retry defaults.
# Timeout normalization is handled inside MindIEChatModel itself.
if getattr(model_class, "__name__", "") == "MindIEChatModel":
# Enforce max_retries constraint to prevent cascading timeouts.
model_settings_from_config["max_retries"] = model_settings_from_config.get("max_retries", 1)
model_instance = model_class(**{**model_settings_from_config, **kwargs})
callbacks = build_tracing_callbacks()
@@ -1,237 +0,0 @@
import ast
import json
import re
import uuid
from collections.abc import Iterator
import httpx
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, ToolMessage
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_openai import ChatOpenAI
def _fix_messages(messages: list) -> list:
"""Sanitize incoming messages for MindIE compatibility.
MindIE's chat template may fail to parse LangChain's native tool_calls
or ToolMessage roles, resulting in 0-token generation errors. This function
flattens multi-modal list contents into strings and converts tool-related
messages into raw text with XML tags expected by the underlying model.
"""
fixed = []
for msg in messages:
# Flatten content if it's a list of blocks
if isinstance(msg.content, list):
parts = []
for block in msg.content:
if isinstance(block, str):
parts.append(block)
elif isinstance(block, dict) and block.get("type") == "text":
parts.append(block.get("text", ""))
text = "".join(parts)
else:
text = msg.content or ""
# Convert AIMessage with tool_calls to raw XML text format
if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", []):
xml_parts = []
for tool in msg.tool_calls:
args_xml = " ".join(f"<parameter={k}>{json.dumps(v, ensure_ascii=False)}</parameter>" for k, v in tool.get("args", {}).items())
xml_parts.append(f"<tool_call> <function={tool['name']}> {args_xml} </function> </tool_call>")
full_text = f"{text}\n" + "\n".join(xml_parts) if text else "\n".join(xml_parts)
fixed.append(AIMessage(content=full_text.strip() or " "))
continue
# Wrap tool execution results in XML tags and convert to HumanMessage
if isinstance(msg, ToolMessage):
tool_result_text = f"<tool_response>\n{text}\n</tool_response>"
fixed.append(HumanMessage(content=tool_result_text))
continue
# Fallback to prevent completely empty message content
if not text.strip():
text = " "
fixed.append(msg.model_copy(update={"content": text}))
return fixed
def _parse_xml_tool_call_to_dict(content: str) -> tuple[str, list[dict]]:
"""Parse XML-style tool calls from model output into LangChain dicts.
Args:
content: The raw text output from the model.
Returns:
A tuple containing the cleaned text (with XML blocks removed) and
a list of tool call dictionaries formatted for LangChain.
"""
if not isinstance(content, str) or "<tool_call>" not in content:
return content, []
tool_calls = []
clean_parts: list[str] = []
cursor = 0
for start, end, inner_content in _iter_tool_call_blocks(content):
clean_parts.append(content[cursor:start])
cursor = end
func_match = re.search(r"<function=([^>]+)>", inner_content)
if not func_match:
continue
function_name = func_match.group(1).strip()
args = {}
param_pattern = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)
for param_match in param_pattern.finditer(inner_content):
key = param_match.group(1).strip()
raw_value = param_match.group(2).strip()
# Attempt to deserialize string values into native Python types
# to satisfy downstream Pydantic validation.
parsed_value = raw_value
if raw_value.startswith(("[", "{")) or raw_value in ("true", "false", "null") or raw_value.isdigit():
try:
parsed_value = json.loads(raw_value)
except json.JSONDecodeError:
try:
parsed_value = ast.literal_eval(raw_value)
except (ValueError, SyntaxError):
pass
args[key] = parsed_value
tool_calls.append({"name": function_name, "args": args, "id": f"call_{uuid.uuid4().hex[:10]}"})
clean_parts.append(content[cursor:])
return "".join(clean_parts).strip(), tool_calls
def _iter_tool_call_blocks(content: str) -> Iterator[tuple[int, int, str]]:
"""Iterate `<tool_call>...</tool_call>` blocks and tolerate nesting."""
token_pattern = re.compile(r"</?tool_call>")
depth = 0
block_start = -1
for match in token_pattern.finditer(content):
token = match.group(0)
if token == "<tool_call>":
if depth == 0:
block_start = match.start()
depth += 1
continue
if depth == 0:
continue
depth -= 1
if depth == 0 and block_start != -1:
block_end = match.end()
inner_start = block_start + len("<tool_call>")
inner_end = match.start()
yield block_start, block_end, content[inner_start:inner_end]
block_start = -1
def _decode_escaped_newlines_outside_fences(content: str) -> str:
"""Decode literal `\\n` outside fenced code blocks."""
if "\\n" not in content:
return content
parts = re.split(r"(```[\s\S]*?```)", content)
for idx, part in enumerate(parts):
if part.startswith("```"):
continue
parts[idx] = part.replace("\\n", "\n")
return "".join(parts)
class MindIEChatModel(ChatOpenAI):
"""Chat model adapter for MindIE engine.
Addresses compatibility issues including:
- Flattening multimodal list contents to strings.
- Intercepting and parsing hardcoded XML tool calls into LangChain standard.
- Handling stream=True dropping choices when tools are present by falling back
to non-streaming generation and yielding simulated chunks.
- Fixing over-escaped newline characters from gateway responses.
"""
def __init__(self, **kwargs):
"""Normalize timeout kwargs without creating long-lived clients."""
connect_timeout = kwargs.pop("connect_timeout", 30.0)
read_timeout = kwargs.pop("read_timeout", 900.0)
write_timeout = kwargs.pop("write_timeout", 60.0)
pool_timeout = kwargs.pop("pool_timeout", 30.0)
kwargs.setdefault(
"timeout",
httpx.Timeout(
connect=connect_timeout,
read=read_timeout,
write=write_timeout,
pool=pool_timeout,
),
)
super().__init__(**kwargs)
def _patch_result_with_tools(self, result: ChatResult) -> ChatResult:
"""Apply post-generation fixes to the model result."""
for gen in result.generations:
msg = gen.message
if isinstance(msg.content, str):
# Keep escaped newlines inside fenced code blocks untouched.
msg.content = _decode_escaped_newlines_outside_fences(msg.content)
if "<tool_call>" in msg.content:
clean_content, extracted_tools = _parse_xml_tool_call_to_dict(msg.content)
if extracted_tools:
msg.content = clean_content
if getattr(msg, "tool_calls", None) is None:
msg.tool_calls = []
msg.tool_calls.extend(extracted_tools)
return result
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
result = super()._generate(_fix_messages(messages), stop=stop, run_manager=run_manager, **kwargs)
return self._patch_result_with_tools(result)
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
result = await super()._agenerate(_fix_messages(messages), stop=stop, run_manager=run_manager, **kwargs)
return self._patch_result_with_tools(result)
async def _astream(self, messages, stop=None, run_manager=None, **kwargs):
# Route standard queries to native streaming for lower TTFB
if not kwargs.get("tools"):
async for chunk in super()._astream(_fix_messages(messages), stop=stop, run_manager=run_manager, **kwargs):
if isinstance(chunk.message.content, str):
chunk.message.content = _decode_escaped_newlines_outside_fences(chunk.message.content)
yield chunk
return
# Fallback for tool-enabled requests:
# MindIE currently drops choices when stream=True and tools are present.
# We await the full generation and yield chunks to simulate streaming.
result = await self._agenerate(messages, stop=stop, run_manager=run_manager, **kwargs)
for gen in result.generations:
msg = gen.message
content = msg.content
standard_tool_calls = getattr(msg, "tool_calls", [])
# Yield text in chunks to allow downstream UI/Markdown parsers to render smoothly
if isinstance(content, str) and content:
chunk_size = 15
for i in range(0, len(content), chunk_size):
chunk_text = content[i : i + chunk_size]
chunk_msg = AIMessageChunk(content=chunk_text, id=msg.id, response_metadata=msg.response_metadata if i == 0 else {})
yield ChatGenerationChunk(message=chunk_msg, generation_info=gen.generation_info if i == 0 else None)
if standard_tool_calls:
yield ChatGenerationChunk(message=AIMessageChunk(content="", id=msg.id, tool_calls=standard_tool_calls, invalid_tool_calls=getattr(msg, "invalid_tool_calls", [])))
else:
chunk_msg = AIMessageChunk(content=content, id=msg.id, tool_calls=standard_tool_calls, invalid_tool_calls=getattr(msg, "invalid_tool_calls", []))
yield ChatGenerationChunk(message=chunk_msg, generation_info=gen.generation_info)
@@ -21,6 +21,8 @@ import inspect
import logging
from typing import Any, Literal
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.runtime.serialization import serialize
from deerflow.runtime.stream_bridge import StreamBridge
@@ -98,17 +100,14 @@ async def run_agent(
# 3. Build the agent
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
# Inject runtime context so middlewares can access thread_id
# (langgraph-cli does this automatically; we must do it manually)
runtime = Runtime(context={"thread_id": thread_id}, store=store)
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
# prefers it over ``configurable`` for thread-level data), make
# sure ``thread_id`` is available there too.
if "context" in config and isinstance(config["context"], dict):
config["context"].setdefault("thread_id", thread_id)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
# Construct typed context for the agent run.
# LangGraph's astream(context=...) injects this into Runtime.context
# so middleware/tools can access it via resolve_context().
deer_flow_context = DeerFlowContext(
app_config=AppConfig.current(),
thread_id=thread_id,
)
runnable_config = RunnableConfig(**config)
agent = agent_factory(config=runnable_config)
@@ -155,7 +154,7 @@ async def run_agent(
if len(lg_modes) == 1 and not stream_subgraphs:
# Single mode, no subgraphs: astream yields raw chunks
single_mode = lg_modes[0]
async for chunk in agent.astream(graph_input, config=runnable_config, stream_mode=single_mode):
async for chunk in agent.astream(graph_input, config=runnable_config, context=deer_flow_context, stream_mode=single_mode):
if record.abort_event.is_set():
logger.info("Run %s abort requested — stopping", run_id)
break
@@ -166,6 +165,7 @@ async def run_agent(
async for item in agent.astream(
graph_input,
config=runnable_config,
context=deer_flow_context,
stream_mode=lg_modes,
subgraphs=stream_subgraphs,
):
@@ -23,7 +23,7 @@ from collections.abc import AsyncIterator
from langgraph.store.base import BaseStore
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.runtime.store.provider import POSTGRES_CONN_REQUIRED, POSTGRES_STORE_INSTALL, SQLITE_STORE_INSTALL, ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@@ -100,7 +100,7 @@ async def make_store() -> AsyncIterator[BaseStore]:
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
``checkpointer`` section is configured (emits a WARNING in that case).
"""
config = get_app_config()
config = AppConfig.current()
if config.checkpointer is None:
from langgraph.store.memory import InMemoryStore
@@ -26,7 +26,7 @@ from collections.abc import Iterator
from langgraph.store.base import BaseStore
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@@ -115,19 +115,10 @@ def get_store() -> BaseStore:
if _store is not None:
return _store
# Lazily load app config, mirroring the checkpointer singleton pattern so
# that tests that set the global checkpointer config explicitly remain isolated.
from deerflow.config.app_config import _app_config
from deerflow.config.checkpointer_config import get_checkpointer_config
config = get_checkpointer_config()
if config is None and _app_config is None:
try:
get_app_config()
except FileNotFoundError:
pass
config = get_checkpointer_config()
try:
config = AppConfig.current().checkpointer
except (LookupError, FileNotFoundError):
config = None
if config is None:
from langgraph.store.memory import InMemoryStore
@@ -176,7 +167,7 @@ def store_context() -> Iterator[BaseStore]:
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
config = AppConfig.current()
if config.checkpointer is None:
from langgraph.store.memory import InMemoryStore
@@ -17,7 +17,7 @@ import contextlib
import logging
from collections.abc import AsyncIterator
from deerflow.config.stream_bridge_config import get_stream_bridge_config
from deerflow.config.app_config import AppConfig
from .base import StreamBridge
@@ -32,7 +32,7 @@ async def make_stream_bridge(config=None) -> AsyncIterator[StreamBridge]:
provided and nothing is set globally.
"""
if config is None:
config = get_stream_bridge_config()
config = AppConfig.current().stream_bridge
if config is None or config.type == "memory":
from deerflow.runtime.stream_bridge.memory import MemoryStreamBridge
@@ -288,10 +288,10 @@ class LocalSandbox(Sandbox):
timeout=600,
)
else:
args = [shell, "-c", resolved_command]
result = subprocess.run(
args,
shell=False,
resolved_command,
executable=shell,
shell=True,
capture_output=True,
text=True,
timeout=600,
@@ -11,8 +11,6 @@ _singleton: LocalSandbox | None = None
class LocalSandboxProvider(SandboxProvider):
uses_thread_data_mounts = True
def __init__(self):
"""Initialize the local sandbox provider with path mappings."""
self._path_mappings = self._setup_path_mappings()
@@ -31,9 +29,9 @@ class LocalSandboxProvider(SandboxProvider):
# Map skills container path to local skills directory
try:
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
config = get_app_config()
config = AppConfig.current()
skills_path = config.skills.get_skills_path()
container_path = config.skills.container_path
@@ -6,8 +6,8 @@ from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from deerflow.agents.thread_state import SandboxState, ThreadDataState
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.sandbox import get_sandbox_provider
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__)
@@ -50,15 +50,15 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
return sandbox_id
@override
def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
# Skip acquisition if lazy_init is enabled
if self._lazy_init:
return super().before_agent(state, runtime)
# Eager initialization (original behavior)
if "sandbox" not in state or state["sandbox"] is None:
thread_id = get_thread_id(runtime)
if thread_id is None:
thread_id = runtime.context.thread_id
if not thread_id:
return super().before_agent(state, runtime)
sandbox_id = self._acquire_sandbox(thread_id)
logger.info(f"Assigned sandbox {sandbox_id} to thread {thread_id}")
@@ -66,7 +66,7 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
return super().before_agent(state, runtime)
@override
def after_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
def after_agent(self, state: SandboxMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
sandbox = state.get("sandbox")
if sandbox is not None:
sandbox_id = sandbox["sandbox_id"]
@@ -74,11 +74,5 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
get_sandbox_provider().release(sandbox_id)
return None
if (runtime.context or {}).get("sandbox_id") is not None:
sandbox_id = runtime.context.get("sandbox_id")
logger.info(f"Releasing sandbox {sandbox_id} from context")
get_sandbox_provider().release(sandbox_id)
return None
# No sandbox to release
return super().after_agent(state, runtime)
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_class
from deerflow.sandbox.sandbox import Sandbox
@@ -8,8 +8,6 @@ from deerflow.sandbox.sandbox import Sandbox
class SandboxProvider(ABC):
"""Abstract base class for sandbox providers"""
uses_thread_data_mounts: bool = False
@abstractmethod
def acquire(self, thread_id: str | None = None) -> str:
"""Acquire a sandbox environment and return its ID.
@@ -52,7 +50,7 @@ def get_sandbox_provider(**kwargs) -> SandboxProvider:
"""
global _default_sandbox_provider
if _default_sandbox_provider is None:
config = get_app_config()
config = AppConfig.current()
cls = resolve_class(config.sandbox.use, SandboxProvider)
_default_sandbox_provider = cls(**kwargs)
return _default_sandbox_provider
@@ -1,6 +1,6 @@
"""Security helpers for sandbox capability gating."""
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
_LOCAL_SANDBOX_PROVIDER_MARKERS = (
"deerflow.sandbox.local:LocalSandboxProvider",
@@ -23,7 +23,7 @@ LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE = (
def uses_local_sandbox_provider(config=None) -> bool:
"""Return True when the active sandbox provider is the host-local provider."""
if config is None:
config = get_app_config()
config = AppConfig.current()
sandbox_cfg = getattr(config, "sandbox", None)
sandbox_use = getattr(sandbox_cfg, "use", "")
@@ -35,7 +35,7 @@ def uses_local_sandbox_provider(config=None) -> bool:
def is_host_bash_allowed(config=None) -> bool:
"""Return whether host bash execution is explicitly allowed."""
if config is None:
config = get_app_config()
config = AppConfig.current()
sandbox_cfg = getattr(config, "sandbox", None)
if sandbox_cfg is None:
@@ -7,7 +7,7 @@ from langchain.tools import ToolRuntime, tool
from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadDataState, ThreadState
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
from deerflow.sandbox.exceptions import (
SandboxError,
@@ -19,7 +19,6 @@ from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
from deerflow.sandbox.search import GrepMatch
from deerflow.sandbox.security import LOCAL_HOST_BASH_DISABLED_MESSAGE, is_host_bash_allowed
from deerflow.utils.runtime import get_thread_id
_ABSOLUTE_PATH_PATTERN = re.compile(r"(?<![:\w])(?<!:/)/(?:[^\s\"'`;&|<>()]+)")
_FILE_URL_PATTERN = re.compile(r"\bfile://\S+", re.IGNORECASE)
@@ -51,9 +50,7 @@ def _get_skills_container_path() -> str:
if cached is not None:
return cached
try:
from deerflow.config import get_app_config
value = get_app_config().skills.container_path
value = AppConfig.current().skills.container_path
_get_skills_container_path._cached = value # type: ignore[attr-defined]
return value
except Exception:
@@ -72,9 +69,7 @@ def _get_skills_host_path() -> str | None:
if cached is not None:
return cached
try:
from deerflow.config import get_app_config
config = get_app_config()
config = AppConfig.current()
skills_path = config.skills.get_skills_path()
if skills_path.exists():
value = str(skills_path)
@@ -133,9 +128,7 @@ def _get_custom_mounts():
try:
from pathlib import Path
from deerflow.config import get_app_config
config = get_app_config()
config = AppConfig.current()
mounts = []
if config.sandbox and config.sandbox.mounts:
# Only include mounts whose host_path exists, consistent with
@@ -275,9 +268,7 @@ def _get_mcp_allowed_paths() -> list[str]:
"""Get the list of allowed paths from MCP config for file system server."""
allowed_paths = []
try:
from deerflow.config.extensions_config import get_extensions_config
extensions_config = get_extensions_config()
extensions_config = AppConfig.current().extensions
for _, server in extensions_config.mcp_servers.items():
if not server.enabled:
@@ -302,7 +293,7 @@ def _get_mcp_allowed_paths() -> list[str]:
def _get_tool_config_int(name: str, key: str, default: int) -> int:
try:
tool_config = get_app_config().get_tool_config(name)
tool_config = AppConfig.current().get_tool_config(name)
if tool_config is not None and key in tool_config.model_extra:
value = tool_config.model_extra.get(key)
if isinstance(value, int):
@@ -810,8 +801,6 @@ def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = No
if sandbox is None:
raise SandboxNotFoundError(f"Sandbox with ID '{sandbox_id}' not found", sandbox_id=sandbox_id)
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for downstream use
return sandbox
@@ -846,15 +835,13 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
if sandbox_id is not None:
sandbox = get_sandbox_provider().get(sandbox_id)
if sandbox is not None:
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
return sandbox
# Sandbox was released, fall through to acquire new one
# Lazy acquisition: get thread_id and acquire sandbox
thread_id = get_thread_id(runtime)
if thread_id is None:
raise SandboxRuntimeError("Thread ID not available in runtime context, runtime config, or LangGraph config")
thread_id = runtime.context.thread_id
if not thread_id:
raise SandboxRuntimeError("Thread ID not available in runtime context")
provider = get_sandbox_provider()
sandbox_id = provider.acquire(thread_id)
@@ -867,8 +854,6 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
if sandbox is None:
raise SandboxNotFoundError("Sandbox not found after acquisition", sandbox_id=sandbox_id)
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
return sandbox
@@ -1010,18 +995,14 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
command = _apply_cwd_prefix(command, thread_data)
output = sandbox.execute_command(command)
try:
from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
sandbox_cfg = AppConfig.current().sandbox
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
except Exception:
max_chars = 20000
return _truncate_bash_output(mask_local_paths_in_output(output, thread_data), max_chars)
ensure_thread_directories_exist(runtime)
try:
from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
sandbox_cfg = AppConfig.current().sandbox
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
except Exception:
max_chars = 20000
@@ -1046,7 +1027,6 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime)
requested_path = path
thread_data = None
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
validate_local_tool_path(path, thread_data, read_only=True)
@@ -1061,12 +1041,8 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
if not children:
return "(empty)"
output = "\n".join(children)
if thread_data is not None:
output = mask_local_paths_in_output(output, thread_data)
try:
from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
sandbox_cfg = AppConfig.current().sandbox
max_chars = sandbox_cfg.ls_output_max_chars if sandbox_cfg else 20000
except Exception:
max_chars = 20000
@@ -1237,9 +1213,7 @@ def read_file_tool(
if start_line is not None and end_line is not None:
content = "\n".join(content.splitlines()[start_line - 1 : end_line])
try:
from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
sandbox_cfg = AppConfig.current().sandbox
max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000
except Exception:
max_chars = 50000
@@ -42,9 +42,9 @@ def load_skills(skills_path: Path | None = None, use_config: bool = True, enable
if skills_path is None:
if use_config:
try:
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
config = get_app_config()
config = AppConfig.current()
skills_path = config.skills.get_skills_path()
except Exception:
# Fallback to default if config fails
@@ -9,7 +9,7 @@ from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.skills.loader import load_skills
from deerflow.skills.validation import _validate_skill_frontmatter
@@ -21,7 +21,7 @@ _SKILL_NAME_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")
def get_skills_root_dir() -> Path:
return get_app_config().skills.get_skills_path()
return AppConfig.current().skills.get_skills_path()
def get_public_skills_dir() -> Path:
@@ -2,24 +2,21 @@ import logging
import re
from pathlib import Path
import yaml
from .types import Skill
logger = logging.getLogger(__name__)
def parse_skill_file(skill_file: Path, category: str, relative_path: Path | None = None) -> Skill | None:
"""Parse a SKILL.md file and extract metadata.
"""
Parse a SKILL.md file and extract metadata.
Args:
skill_file: Path to the SKILL.md file.
category: Category of the skill ('public' or 'custom').
relative_path: Relative path from the category root to the skill
directory. Defaults to the skill directory name when omitted.
skill_file: Path to the SKILL.md file
category: Category of the skill ('public' or 'custom')
Returns:
Skill object if parsing succeeds, None otherwise.
Skill object if parsing succeeds, None otherwise
"""
if not skill_file.exists() or skill_file.name != "SKILL.md":
return None
@@ -27,42 +24,90 @@ def parse_skill_file(skill_file: Path, category: str, relative_path: Path | None
try:
content = skill_file.read_text(encoding="utf-8")
# Extract YAML front-matter block between leading ``---`` fences.
# Extract YAML front matter
# Pattern: ---\nkey: value\n---
front_matter_match = re.match(r"^---\s*\n(.*?)\n---\s*\n", content, re.DOTALL)
if not front_matter_match:
return None
front_matter_text = front_matter_match.group(1)
front_matter = front_matter_match.group(1)
try:
metadata = yaml.safe_load(front_matter_text)
except yaml.YAMLError as exc:
logger.error("Invalid YAML front-matter in %s: %s", skill_file, exc)
return None
# Parse YAML front matter with basic multiline string support
metadata = {}
lines = front_matter.split("\n")
current_key = None
current_value = []
is_multiline = False
multiline_style = None
indent_level = None
if not isinstance(metadata, dict):
logger.error("Front-matter in %s is not a YAML mapping", skill_file)
return None
for line in lines:
if is_multiline:
if not line.strip():
current_value.append("")
continue
# Extract required fields. Both must be non-empty strings.
current_indent = len(line) - len(line.lstrip())
if indent_level is None:
if current_indent > 0:
indent_level = current_indent
current_value.append(line[indent_level:])
continue
elif current_indent >= indent_level:
current_value.append(line[indent_level:])
continue
# If we reach here, it's either a new key or the end of multiline
if current_key and is_multiline:
if multiline_style == "|":
metadata[current_key] = "\n".join(current_value).rstrip()
else:
text = "\n".join(current_value).rstrip()
# Replace single newlines with spaces for folded blocks
metadata[current_key] = re.sub(r"(?<!\n)\n(?!\n)", " ", text)
current_key = None
current_value = []
is_multiline = False
multiline_style = None
indent_level = None
if not line.strip():
continue
if ":" in line:
# Handle nested dicts simply by ignoring indentation for now,
# or just extracting top-level keys
key, value = line.split(":", 1)
key = key.strip()
value = value.strip()
if value in (">", "|"):
current_key = key
is_multiline = True
multiline_style = value
current_value = []
indent_level = None
else:
metadata[key] = value
if current_key and is_multiline:
if multiline_style == "|":
metadata[current_key] = "\n".join(current_value).rstrip()
else:
text = "\n".join(current_value).rstrip()
metadata[current_key] = re.sub(r"(?<!\n)\n(?!\n)", " ", text)
# Extract required fields
name = metadata.get("name")
description = metadata.get("description")
if not name or not isinstance(name, str):
return None
if not description or not isinstance(description, str):
return None
# Normalise: strip surrounding whitespace that YAML may preserve.
name = name.strip()
description = description.strip()
if not name or not description:
return None
license_text = metadata.get("license")
if license_text is not None:
license_text = str(license_text).strip() or None
return Skill(
name=name,
@@ -72,9 +117,9 @@ def parse_skill_file(skill_file: Path, category: str, relative_path: Path | None
skill_file=skill_file,
relative_path=relative_path or Path(skill_file.parent.name),
category=category,
enabled=True, # Actual state comes from the extensions config file.
enabled=True, # Default to enabled, actual state comes from config file
)
except Exception:
logger.exception("Unexpected error parsing skill file %s", skill_file)
except Exception as e:
logger.error("Error parsing skill file %s: %s", skill_file, e)
return None
@@ -7,7 +7,7 @@ import logging
import re
from dataclasses import dataclass
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
@@ -47,15 +47,14 @@ async def scan_skill_content(content: str, *, executable: bool = False, location
prompt = f"Location: {location}\nExecutable: {str(executable).lower()}\n\nReview this content:\n-----\n{content}\n-----"
try:
config = get_app_config()
config = AppConfig.current()
model_name = config.skill_evolution.moderation_model_name
model = create_chat_model(name=model_name, thinking_enabled=False) if model_name else create_chat_model(thinking_enabled=False)
response = await model.ainvoke(
[
{"role": "system", "content": rubric},
{"role": "user", "content": prompt},
],
config={"run_name": "security_agent"},
]
)
parsed = _extract_json_object(str(getattr(response, "content", "") or ""))
if parsed and parsed.get("decision") in {"allow", "warn", "block"}:
@@ -13,8 +13,6 @@ class SubagentConfig:
system_prompt: The system prompt that guides the subagent's behavior.
tools: Optional list of tool names to allow. If None, inherits all tools.
disallowed_tools: Optional list of tool names to deny.
skills: Optional list of skill names to load. If None, inherits all enabled skills.
If an empty list, no skills are loaded.
model: Model to use - 'inherit' uses parent's model.
max_turns: Maximum number of agent turns before stopping.
timeout_seconds: Maximum execution time in seconds (default: 900 = 15 minutes).
@@ -25,7 +23,6 @@ class SubagentConfig:
system_prompt: str
tools: list[str] | None = None
disallowed_tools: list[str] | None = field(default_factory=lambda: ["task"])
skills: list[str] | None = None
model: str = "inherit"
max_turns: int = 50
timeout_seconds: int = 900

Some files were not shown because too many files have changed in this diff Show More