Compare commits

..

7 Commits

Author SHA1 Message Date
greatmengqi edf345cd72 refactor(config): eliminate global mutable state, wire DeerFlowContext into runtime
- Freeze all config models (AppConfig + 15 sub-configs) with frozen=True
- Purify from_file() — remove 9 load_*_from_dict() side-effect calls
- Replace mtime/reload/push/pop machinery with single ContextVar + init_app_config()
- Delete 10 sub-module globals and their getters/setters/loaders
- Migrate 50+ consumers from get_*_config() to get_app_config().xxx

- Expand DeerFlowContext: app_config + thread_id + agent_name (frozen dataclass)
- Wire into Gateway runtime (worker.py) and DeerFlowClient via context= parameter
- Remove sandbox_id from runtime.context — flows through ThreadState.sandbox only
- Middleware/tools access runtime.context directly via Runtime[DeerFlowContext] generic
- resolve_context() retained at server entry points for LangGraph Server fallback
2026-04-14 01:18:19 +08:00
Matt Van Horn c4d273a68a feat(channels): add Discord channel integration (#1806)
* feat(channels): add Discord channel integration

Add a Discord bot channel following the existing Telegram/Slack pattern.
The bot listens for messages, creates conversation threads, and relays
responses back to Discord with 2000-char message splitting.

- DiscordChannel extends Channel base class
- Lazy imports discord.py with install hint
- Thread-based conversations (each Discord thread maps to a DeerFlow thread)
- Allowed guilds filter for access control
- File attachment support via discord.File
- Registered in service.py and manager.py

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(channels): address Copilot review suggestions for Discord integration

- Disable @everyone/@here mentions via AllowedMentions.none()
- Add 10s timeout to client close to prevent shutdown hangs
- Log publish_inbound errors via future callback instead of silently dropping
- Open file handle on caller thread to avoid cross-thread ownership issues
- Notify user in channel when thread creation fails

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(discord): resolve lint errors in Discord channel

- Replace asyncio.TimeoutError with builtin TimeoutError (UP041)
- Remove extraneous f-string prefix (F541)
- Apply ruff format

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(tests): remove fake langgraph_sdk shim from test_discord_channel

The module-level sys.modules.setdefault shim installed a fake
langgraph_sdk.errors.ConflictError during pytest collection. Because
pytest imports all test modules before running them, test_channels.py
then imported the fake ConflictError instead of the real one.

In test_handle_feishu_stream_conflict_sends_busy_message, the test
constructs ConflictError(message, response=..., body=...). The fake
only subclasses Exception (which takes no kwargs), so the construction
raised TypeError. The manager's _is_thread_busy_error check then saw a
TypeError instead of a ConflictError and fell through to the generic
'An error occurred' message.

langgraph_sdk is a real dependency, so the shim is unnecessary.
Removing it makes both test files import the same real ConflictError
and the full suite pass (1773 passed, 15 skipped).

---------

Co-authored-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-11 17:48:04 +08:00
Jason dc50a7fdfb fix(sandbox): resolve paths in read_file/write_file content for LocalSandbox (#1935)
* fix(sandbox): resolve paths in read_file/write_file content for LocalSandbox

In LocalSandbox mode, read_file and write_file now transform
container paths in file content, matching the path handling
behavior of bash tool.

- write_file: resolves virtual paths in content to system paths
  before writing, so scripts with /mnt/user-data paths work
  when executed
- read_file: reverse-resolves system paths back to virtual
  paths in returned content for consistency

This fixes scenarios where agents write Python scripts with
virtual paths, then execute them via bash tool expecting the
paths to work.

Fixes #1778

* fix(sandbox): address Copilot review — dedicated content resolver + forward-slash safety + tests

- Extract _resolve_paths_in_content() separate from _resolve_paths_in_command()
  to decouple file-content path resolution from shell-command parsing
- Normalize resolved paths to forward slashes to avoid Windows backslash
  escape issues in source files (e.g. \U in Python string literals)
- Add 4 focused tests: write resolves content, forward-slash guarantee,
  read reverse-resolves content, and write→read roundtrip

* style: fix ruff lint — remove extraneous f-string prefix

* fix(sandbox): only reverse-resolve paths in agent-written files

read_file previously applied _reverse_resolve_paths_in_output to ALL
file content, which could silently rewrite paths in user uploads and
external tool output (Willem Jiang review on #1935).

Now tracks files written through write_file in _agent_written_paths.
Only those files get reverse-resolved on read. Non-agent files are
returned as-is.

---------

Co-authored-by: JasonOA888 <JasonOA888@users.noreply.github.com>
2026-04-11 17:41:36 +08:00
ZHANG Ning 5b633449f8 fix(middleware): add per-tool-type frequency detection to LoopDetectionMiddleware (#1988)
* fix(middleware): add per-tool-type frequency detection to LoopDetectionMiddleware

The existing hash-based loop detection only catches identical tool call
sets. When the agent calls the same tool type (e.g. read_file) on many
different files, each call produces a unique hash and bypasses detection.
This causes the agent to exhaust recursion_limit, consuming 150K-225K
tokens per failed run.

Add a second detection layer that tracks cumulative call counts per tool
type per thread. Warns at 30 calls (configurable) and forces stop at 50.
The hard stop message now uses the actual returned message instead of a
hardcoded constant, so both hash-based and frequency-based stops produce
accurate diagnostics.

Also fix _apply() to use the warning message returned by
_track_and_check() for hard stops, instead of always using _HARD_STOP_MSG.

Closes #1987

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix(lint): remove unused imports and fix line length

- Remove unused _TOOL_FREQ_HARD_STOP_MSG and _TOOL_FREQ_WARNING_MSG
  imports from test file (F401)
- Break long _TOOL_FREQ_WARNING_MSG string to fit within 240 char limit (E501)

* style: apply ruff format

* test: add LRU eviction and per-thread reset coverage for frequency state

Address review feedback from @WillemJiang:
- Verify _tool_freq and _tool_freq_warned are cleaned on LRU eviction
- Add test for reset(thread_id=...) clearing only the target thread's
  frequency state while leaving others intact

* fix(makefile): route Windows shell-script targets through Git Bash (#2060)

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Asish Kumar <87874775+officialasishkumar@users.noreply.github.com>
2026-04-11 17:33:27 +08:00
yorick 02569136df fix(sandbox): improve sandbox security and preserve multimodal content (#2114)
* fix: improve sandbox security and preserve multimodal content

* Add unit test modifications for test_injects_uploaded_files_tag_into_list_content

* format updated_content

* Add regression tests for multimodal upload content and host bash default safety
2026-04-11 16:52:10 +08:00
dependabot[bot] 024ac0e464 chore(deps): bump langsmith from 0.5.2 to 0.5.18 in /frontend (#2110)
Bumps [langsmith](https://github.com/langchain-ai/langsmith-sdk) from 0.5.2 to 0.5.18.
- [Release notes](https://github.com/langchain-ai/langsmith-sdk/releases)
- [Commits](https://github.com/langchain-ai/langsmith-sdk/commits)

---
updated-dependencies:
- dependency-name: langsmith
  dependency-version: 0.5.18
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-11 14:51:21 +08:00
dependabot[bot] 19030928e0 chore(deps): bump langchain-core from 1.2.17 to 1.2.28 in /backend (#2109)
Bumps [langchain-core](https://github.com/langchain-ai/langchain) from 1.2.17 to 1.2.28.
- [Release notes](https://github.com/langchain-ai/langchain/releases)
- [Commits](https://github.com/langchain-ai/langchain/compare/langchain-core==1.2.17...langchain-core==1.2.28)

---
updated-dependencies:
- dependency-name: langchain-core
  dependency-version: 1.2.28
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-11 14:49:54 +08:00
364 changed files with 6886 additions and 30257 deletions
+1 -4
View File
@@ -24,6 +24,7 @@ INFOQUEST_API_KEY=your-infoquest-api-key
# SLACK_BOT_TOKEN=your-slack-bot-token # SLACK_BOT_TOKEN=your-slack-bot-token
# SLACK_APP_TOKEN=your-slack-app-token # SLACK_APP_TOKEN=your-slack-app-token
# TELEGRAM_BOT_TOKEN=your-telegram-bot-token # TELEGRAM_BOT_TOKEN=your-telegram-bot-token
# DISCORD_BOT_TOKEN=your-discord-bot-token
# Enable LangSmith to monitor and debug your LLM calls, agent runs, and tool executions. # Enable LangSmith to monitor and debug your LLM calls, agent runs, and tool executions.
# LANGSMITH_TRACING=true # LANGSMITH_TRACING=true
@@ -33,9 +34,5 @@ INFOQUEST_API_KEY=your-infoquest-api-key
# GitHub API Token # GitHub API Token
# GITHUB_TOKEN=your-github-token # GITHUB_TOKEN=your-github-token
# Database (only needed when config.yaml has database.backend: postgres)
# DATABASE_URL=postgresql://deerflow:password@localhost:5432/deerflow
#
# WECOM_BOT_ID=your-wecom-bot-id # WECOM_BOT_ID=your-wecom-bot-id
# WECOM_BOT_SECRET=your-wecom-bot-secret # WECOM_BOT_SECRET=your-wecom-bot-secret
+11 -21
View File
@@ -158,7 +158,7 @@ from deerflow.config import get_app_config
Middlewares execute in strict order in `packages/harness/deerflow/agents/lead_agent/agent.py`: Middlewares execute in strict order in `packages/harness/deerflow/agents/lead_agent/agent.py`:
1. **ThreadDataMiddleware** - Creates per-thread directories under the user's isolation scope (`backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); resolves `user_id` via `get_effective_user_id()` (falls back to `"default"` in no-auth mode); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local thread directory 1. **ThreadDataMiddleware** - Creates per-thread directories (`backend/.deer-flow/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local `.deer-flow/threads/{thread_id}` directory
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation 2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state 3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption) 4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption)
@@ -179,7 +179,9 @@ Setup: Copy `config.example.yaml` to `config.yaml` in the **project root** direc
**Config Versioning**: `config.example.yaml` has a `config_version` field. On startup, `AppConfig.from_file()` compares user version vs example version and emits a warning if outdated. Missing `config_version` = version 0. Run `make config-upgrade` to auto-merge missing fields. When changing the config schema, bump `config_version` in `config.example.yaml`. **Config Versioning**: `config.example.yaml` has a `config_version` field. On startup, `AppConfig.from_file()` compares user version vs example version and emits a warning if outdated. Missing `config_version` = version 0. Run `make config-upgrade` to auto-merge missing fields. When changing the config schema, bump `config_version` in `config.example.yaml`.
**Config Caching**: `get_app_config()` caches the parsed config, but automatically reloads it when the resolved config path changes or the file's mtime increases. This keeps Gateway and LangGraph reads aligned with `config.yaml` edits without requiring a manual process restart. **Config Lifecycle**: All config models are `frozen=True` (immutable after construction). `AppConfig.from_file()` is a pure function — no side effects on sub-module globals. `get_app_config()` is backed by a single `ContextVar`, set once via `init_app_config()` at process startup. To update config at runtime (e.g., Gateway API updates MCP/Skills), construct a new `AppConfig.from_file()` and call `init_app_config()` again. No mtime detection, no auto-reload.
**DeerFlowContext**: Per-invocation typed context for the agent execution path, injected via LangGraph `Runtime[DeerFlowContext]`. Holds `app_config: AppConfig`, `thread_id: str`, `agent_name: str | None`. Gateway runtime and `DeerFlowClient` construct full `DeerFlowContext` at invoke time; LangGraph Server path uses a fallback via `resolve_context()`. Middleware and tools access context through `resolve_context(runtime)` which returns a typed `DeerFlowContext` regardless of entry point. Mutable runtime state (`sandbox_id`) flows through `ThreadState.sandbox`, not context.
Configuration priority: Configuration priority:
1. Explicit `config_path` argument 1. Explicit `config_path` argument
@@ -216,9 +218,6 @@ FastAPI application on port 8001 with health check at `GET /health`.
| **Threads** (`/api/threads/{id}`) | `DELETE /` - remove DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail | | **Threads** (`/api/threads/{id}`) | `DELETE /` - remove DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
| **Artifacts** (`/api/threads/{id}/artifacts`) | `GET /{path}` - serve artifacts; active content types (`text/html`, `application/xhtml+xml`, `image/svg+xml`) are always forced as download attachments to reduce XSS risk; `?download=true` still forces download for other file types | | **Artifacts** (`/api/threads/{id}/artifacts`) | `GET /{path}` - serve artifacts; active content types (`text/html`, `application/xhtml+xml`, `image/svg+xml`) are always forced as download attachments to reduce XSS risk; `?download=true` still forces download for other file types |
| **Suggestions** (`/api/threads/{id}/suggestions`) | `POST /` - generate follow-up questions; rich list/block model content is normalized before JSON parsing | | **Suggestions** (`/api/threads/{id}/suggestions`) | `POST /` - generate follow-up questions; rich list/block model content is normalized before JSON parsing |
| **Thread Runs** (`/api/threads/{id}/runs`) | `POST /` - create background run; `POST /stream` - create + SSE stream; `POST /wait` - create + block; `GET /` - list runs; `GET /{rid}` - run details; `POST /{rid}/cancel` - cancel; `GET /{rid}/join` - join SSE; `GET /{rid}/messages` - paginated messages `{data, has_more}`; `GET /{rid}/events` - full event stream; `GET /../messages` - thread messages with feedback; `GET /../token-usage` - aggregate tokens |
| **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific |
| **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id |
Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → Gateway. Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → Gateway.
@@ -232,7 +231,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
**Virtual Path System**: **Virtual Path System**:
- Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills` - Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills`
- Physical: `backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/...`, `deer-flow/skills/` - Physical: `backend/.deer-flow/threads/{thread_id}/user-data/...`, `deer-flow/skills/`
- Translation: `replace_virtual_path()` / `replace_virtual_paths_in_command()` - Translation: `replace_virtual_path()` / `replace_virtual_paths_in_command()`
- Detection: `is_local_sandbox()` checks `sandbox_id == "local"` - Detection: `is_local_sandbox()` checks `sandbox_id == "local"`
@@ -272,7 +271,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
- `invoke_acp_agent` - Invokes external ACP-compatible agents from `config.yaml` - `invoke_acp_agent` - Invokes external ACP-compatible agents from `config.yaml`
- ACP launchers must be real ACP adapters. The standard `codex` CLI is not ACP-compatible by itself; configure a wrapper such as `npx -y @zed-industries/codex-acp` or an installed `codex-acp` binary - ACP launchers must be real ACP adapters. The standard `codex` CLI is not ACP-compatible by itself; configure a wrapper such as `npx -y @zed-industries/codex-acp` or an installed `codex-acp` binary
- Missing ACP executables now return an actionable error message instead of a raw `[Errno 2]` - Missing ACP executables now return an actionable error message instead of a raw `[Errno 2]`
- Each ACP agent uses a per-thread workspace at `{base_dir}/users/{user_id}/threads/{thread_id}/acp-workspace/`. The workspace is accessible to the lead agent via the virtual path `/mnt/acp-workspace/` (read-only). In docker sandbox mode, the directory is volume-mounted into the container at `/mnt/acp-workspace` (read-only); in local sandbox mode, path translation is handled by `tools.py` - Each ACP agent uses a per-thread workspace at `{base_dir}/threads/{thread_id}/acp-workspace/`. The workspace is accessible to the lead agent via the virtual path `/mnt/acp-workspace/` (read-only). In docker sandbox mode, the directory is volume-mounted into the container at `/mnt/acp-workspace` (read-only); in local sandbox mode, path translation is handled by `tools.py`
- `image_search/` - Image search via DuckDuckGo - `image_search/` - Image search via DuckDuckGo
### MCP System (`packages/harness/deerflow/mcp/`) ### MCP System (`packages/harness/deerflow/mcp/`)
@@ -341,27 +340,18 @@ Bridges external messaging platforms (Feishu, Slack, Telegram) to the DeerFlow a
**Components**: **Components**:
- `updater.py` - LLM-based memory updates with fact extraction, whitespace-normalized fact deduplication (trims leading/trailing whitespace before comparing), and atomic file I/O - `updater.py` - LLM-based memory updates with fact extraction, whitespace-normalized fact deduplication (trims leading/trailing whitespace before comparing), and atomic file I/O
- `queue.py` - Debounced update queue (per-thread deduplication, configurable wait time); captures `user_id` at enqueue time so it survives the `threading.Timer` boundary - `queue.py` - Debounced update queue (per-thread deduplication, configurable wait time)
- `prompt.py` - Prompt templates for memory updates - `prompt.py` - Prompt templates for memory updates
- `storage.py` - File-based storage with per-user isolation; cache keyed by `(user_id, agent_name)` tuple
**Per-User Isolation**: **Data Structure** (stored in `backend/.deer-flow/memory.json`):
- Memory is stored per-user at `{base_dir}/users/{user_id}/memory.json`
- Per-agent per-user memory at `{base_dir}/users/{user_id}/agents/{agent_name}/memory.json`
- `user_id` is resolved via `get_effective_user_id()` from `deerflow.runtime.user_context`
- In no-auth mode, `user_id` defaults to `"default"` (constant `DEFAULT_USER_ID`)
- Absolute `storage_path` in config opts out of per-user isolation
- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json` and `threads/` into per-user layout; supports `--dry-run`
**Data Structure** (stored in `{base_dir}/users/{user_id}/memory.json`):
- **User Context**: `workContext`, `personalContext`, `topOfMind` (1-3 sentence summaries) - **User Context**: `workContext`, `personalContext`, `topOfMind` (1-3 sentence summaries)
- **History**: `recentMonths`, `earlierContext`, `longTermBackground` - **History**: `recentMonths`, `earlierContext`, `longTermBackground`
- **Facts**: Discrete facts with `id`, `content`, `category` (preference/knowledge/context/behavior/goal), `confidence` (0-1), `createdAt`, `source` - **Facts**: Discrete facts with `id`, `content`, `category` (preference/knowledge/context/behavior/goal), `confidence` (0-1), `createdAt`, `source`
**Workflow**: **Workflow**:
1. `MemoryMiddleware` filters messages (user inputs + final AI responses), captures `user_id` via `get_effective_user_id()`, and queues conversation with the captured `user_id` 1. `MemoryMiddleware` filters messages (user inputs + final AI responses) and queues conversation
2. Queue debounces (30s default), batches updates, deduplicates per-thread 2. Queue debounces (30s default), batches updates, deduplicates per-thread
3. Background thread invokes LLM to extract context updates and facts, using the stored `user_id` (not the contextvar, which is unavailable on timer threads) 3. Background thread invokes LLM to extract context updates and facts
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append 4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt 5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
@@ -369,7 +359,7 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_
**Configuration** (`config.yaml``memory`): **Configuration** (`config.yaml``memory`):
- `enabled` / `injection_enabled` - Master switches - `enabled` / `injection_enabled` - Master switches
- `storage_path` - Path to memory.json (absolute path opts out of per-user isolation) - `storage_path` - Path to memory.json
- `debounce_seconds` - Wait time before processing (default: 30) - `debounce_seconds` - Wait time before processing (default: 30)
- `model_name` - LLM for updates (null = default model) - `model_name` - LLM for updates (null = default model)
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7) - `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
+1 -5
View File
@@ -13,9 +13,6 @@ FROM python:3.12-slim-bookworm AS builder
ARG NODE_MAJOR=22 ARG NODE_MAJOR=22
ARG APT_MIRROR ARG APT_MIRROR
ARG UV_INDEX_URL ARG UV_INDEX_URL
# Optional extras to install (e.g. "postgres" for PostgreSQL support)
# Usage: docker build --build-arg UV_EXTRAS=postgres ...
ARG UV_EXTRAS
# Optionally override apt mirror for restricted networks (e.g. APT_MIRROR=mirrors.aliyun.com) # Optionally override apt mirror for restricted networks (e.g. APT_MIRROR=mirrors.aliyun.com)
RUN if [ -n "${APT_MIRROR}" ]; then \ RUN if [ -n "${APT_MIRROR}" ]; then \
@@ -46,9 +43,8 @@ WORKDIR /app
COPY backend ./backend COPY backend ./backend
# Install dependencies with cache mount # Install dependencies with cache mount
# When UV_EXTRAS is set (e.g. "postgres"), installs optional dependencies.
RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \
sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync ${UV_EXTRAS:+--extra $UV_EXTRAS}" sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync"
# ── Stage 2: Dev ────────────────────────────────────────────────────────────── # ── Stage 2: Dev ──────────────────────────────────────────────────────────────
# Retains compiler toolchain from builder so startup-time `uv sync` can build # Retains compiler toolchain from builder so startup-time `uv sync` can build
+273
View File
@@ -0,0 +1,273 @@
"""Discord channel integration using discord.py."""
from __future__ import annotations
import asyncio
import logging
import threading
from typing import Any
from app.channels.base import Channel
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
_DISCORD_MAX_MESSAGE_LEN = 2000
class DiscordChannel(Channel):
"""Discord bot channel.
Configuration keys (in ``config.yaml`` under ``channels.discord``):
- ``bot_token``: Discord Bot token.
- ``allowed_guilds``: (optional) List of allowed Discord guild IDs. Empty = allow all.
"""
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
super().__init__(name="discord", bus=bus, config=config)
self._bot_token = str(config.get("bot_token", "")).strip()
self._allowed_guilds: set[int] = set()
for guild_id in config.get("allowed_guilds", []):
try:
self._allowed_guilds.add(int(guild_id))
except (TypeError, ValueError):
continue
self._client = None
self._thread: threading.Thread | None = None
self._discord_loop: asyncio.AbstractEventLoop | None = None
self._main_loop: asyncio.AbstractEventLoop | None = None
self._discord_module = None
async def start(self) -> None:
if self._running:
return
try:
import discord
except ImportError:
logger.error("discord.py is not installed. Install it with: uv add discord.py")
return
if not self._bot_token:
logger.error("Discord channel requires bot_token")
return
intents = discord.Intents.default()
intents.messages = True
intents.guilds = True
intents.message_content = True
client = discord.Client(
intents=intents,
allowed_mentions=discord.AllowedMentions.none(),
)
self._client = client
self._discord_module = discord
self._main_loop = asyncio.get_event_loop()
@client.event
async def on_message(message) -> None:
await self._on_message(message)
self._running = True
self.bus.subscribe_outbound(self._on_outbound)
self._thread = threading.Thread(target=self._run_client, daemon=True)
self._thread.start()
logger.info("Discord channel started")
async def stop(self) -> None:
self._running = False
self.bus.unsubscribe_outbound(self._on_outbound)
if self._client and self._discord_loop and self._discord_loop.is_running():
close_future = asyncio.run_coroutine_threadsafe(self._client.close(), self._discord_loop)
try:
await asyncio.wait_for(asyncio.wrap_future(close_future), timeout=10)
except TimeoutError:
logger.warning("[Discord] client close timed out after 10s")
except Exception:
logger.exception("[Discord] error while closing client")
if self._thread:
self._thread.join(timeout=10)
self._thread = None
self._client = None
self._discord_loop = None
self._discord_module = None
logger.info("Discord channel stopped")
async def send(self, msg: OutboundMessage) -> None:
target = await self._resolve_target(msg)
if target is None:
logger.error("[Discord] target not found for chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
return
text = msg.text or ""
for chunk in self._split_text(text):
send_future = asyncio.run_coroutine_threadsafe(target.send(chunk), self._discord_loop)
await asyncio.wrap_future(send_future)
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
target = await self._resolve_target(msg)
if target is None:
logger.error("[Discord] target not found for file upload chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
return False
if self._discord_module is None:
return False
try:
fp = open(str(attachment.actual_path), "rb") # noqa: SIM115
file = self._discord_module.File(fp, filename=attachment.filename)
send_future = asyncio.run_coroutine_threadsafe(target.send(file=file), self._discord_loop)
await asyncio.wrap_future(send_future)
logger.info("[Discord] file uploaded: %s", attachment.filename)
return True
except Exception:
logger.exception("[Discord] failed to upload file: %s", attachment.filename)
return False
async def _on_message(self, message) -> None:
if not self._running or not self._client:
return
if message.author.bot:
return
if self._client.user and message.author.id == self._client.user.id:
return
guild = message.guild
if self._allowed_guilds:
if guild is None or guild.id not in self._allowed_guilds:
return
text = (message.content or "").strip()
if not text:
return
if self._discord_module is None:
return
if isinstance(message.channel, self._discord_module.Thread):
chat_id = str(message.channel.parent_id or message.channel.id)
thread_id = str(message.channel.id)
else:
thread = await self._create_thread(message)
if thread is None:
return
chat_id = str(message.channel.id)
thread_id = str(thread.id)
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
inbound = self._make_inbound(
chat_id=chat_id,
user_id=str(message.author.id),
text=text,
msg_type=msg_type,
thread_ts=thread_id,
metadata={
"guild_id": str(guild.id) if guild else None,
"channel_id": str(message.channel.id),
"message_id": str(message.id),
},
)
inbound.topic_id = thread_id
if self._main_loop and self._main_loop.is_running():
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
def _run_client(self) -> None:
self._discord_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._discord_loop)
try:
self._discord_loop.run_until_complete(self._client.start(self._bot_token))
except Exception:
if self._running:
logger.exception("Discord client error")
finally:
try:
if self._client and not self._client.is_closed():
self._discord_loop.run_until_complete(self._client.close())
except Exception:
logger.exception("Error during Discord shutdown")
async def _create_thread(self, message):
try:
thread_name = f"deerflow-{message.author.display_name}-{message.id}"[:100]
return await message.create_thread(name=thread_name)
except Exception:
logger.exception("[Discord] failed to create thread for message=%s (threads may be disabled or missing permissions)", message.id)
try:
await message.channel.send("Could not create a thread for your message. Please check that threads are enabled in this channel.")
except Exception:
pass
return None
async def _resolve_target(self, msg: OutboundMessage):
if not self._client or not self._discord_loop:
return None
target_ids: list[str] = []
if msg.thread_ts:
target_ids.append(msg.thread_ts)
if msg.chat_id and msg.chat_id not in target_ids:
target_ids.append(msg.chat_id)
for raw_id in target_ids:
target = await self._get_channel_or_thread(raw_id)
if target is not None:
return target
return None
async def _get_channel_or_thread(self, raw_id: str):
if not self._client or not self._discord_loop:
return None
try:
target_id = int(raw_id)
except (TypeError, ValueError):
return None
get_future = asyncio.run_coroutine_threadsafe(self._fetch_channel(target_id), self._discord_loop)
try:
return await asyncio.wrap_future(get_future)
except Exception:
logger.exception("[Discord] failed to resolve target id=%s", raw_id)
return None
async def _fetch_channel(self, target_id: int):
if not self._client:
return None
channel = self._client.get_channel(target_id)
if channel is not None:
return channel
try:
return await self._client.fetch_channel(target_id)
except Exception:
return None
@staticmethod
def _split_text(text: str) -> list[str]:
if not text:
return [""]
chunks: list[str] = []
remaining = text
while len(remaining) > _DISCORD_MAX_MESSAGE_LEN:
split_at = remaining.rfind("\n", 0, _DISCORD_MAX_MESSAGE_LEN)
if split_at <= 0:
split_at = _DISCORD_MAX_MESSAGE_LEN
chunks.append(remaining[:split_at])
remaining = remaining[split_at:].lstrip("\n")
if remaining:
chunks.append(remaining)
return chunks
+2 -4
View File
@@ -13,7 +13,6 @@ from app.channels.base import Channel
from app.channels.commands import KNOWN_CHANNEL_COMMANDS from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox_provider import get_sandbox_provider from deerflow.sandbox.sandbox_provider import get_sandbox_provider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -345,9 +344,8 @@ class FeishuChannel(Channel):
return f"Failed to obtain the [{type}]" return f"Failed to obtain the [{type}]"
paths = get_paths() paths = get_paths()
user_id = get_effective_user_id() paths.ensure_thread_dirs(thread_id)
paths.ensure_thread_dirs(thread_id, user_id=user_id) uploads_dir = paths.sandbox_uploads_dir(thread_id).resolve()
uploads_dir = paths.sandbox_uploads_dir(thread_id, user_id=user_id).resolve()
ext = "png" if type == "image" else "bin" ext = "png" if type == "image" else "bin"
raw_filename = getattr(response, "file_name", "") or f"feishu_{file_key[-12:]}.{ext}" raw_filename = getattr(response, "file_name", "") or f"feishu_{file_key[-12:]}.{ext}"
+3 -4
View File
@@ -17,7 +17,6 @@ from langgraph_sdk.errors import ConflictError
from app.channels.commands import KNOWN_CHANNEL_COMMANDS from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.store import ChannelStore from app.channels.store import ChannelStore
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -36,6 +35,7 @@ STREAM_UPDATE_MIN_INTERVAL_SECONDS = 0.35
THREAD_BUSY_MESSAGE = "This conversation is already processing another request. Please wait for it to finish and try again." THREAD_BUSY_MESSAGE = "This conversation is already processing another request. Please wait for it to finish and try again."
CHANNEL_CAPABILITIES = { CHANNEL_CAPABILITIES = {
"discord": {"supports_streaming": False},
"feishu": {"supports_streaming": True}, "feishu": {"supports_streaming": True},
"slack": {"supports_streaming": False}, "slack": {"supports_streaming": False},
"telegram": {"supports_streaming": False}, "telegram": {"supports_streaming": False},
@@ -342,15 +342,14 @@ def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedA
attachments: list[ResolvedAttachment] = [] attachments: list[ResolvedAttachment] = []
paths = get_paths() paths = get_paths()
user_id = get_effective_user_id() outputs_dir = paths.sandbox_outputs_dir(thread_id).resolve()
outputs_dir = paths.sandbox_outputs_dir(thread_id, user_id=user_id).resolve()
for virtual_path in artifacts: for virtual_path in artifacts:
# Security: only allow files from the agent outputs directory # Security: only allow files from the agent outputs directory
if not virtual_path.startswith(_OUTPUTS_VIRTUAL_PREFIX): if not virtual_path.startswith(_OUTPUTS_VIRTUAL_PREFIX):
logger.warning("[Manager] rejected non-outputs artifact path: %s", virtual_path) logger.warning("[Manager] rejected non-outputs artifact path: %s", virtual_path)
continue continue
try: try:
actual = paths.resolve_virtual_path(thread_id, virtual_path, user_id=user_id) actual = paths.resolve_virtual_path(thread_id, virtual_path)
# Verify the resolved path is actually under the outputs directory # Verify the resolved path is actually under the outputs directory
# (guards against path-traversal even after prefix check) # (guards against path-traversal even after prefix check)
try: try:
+3 -2
View File
@@ -15,6 +15,7 @@ logger = logging.getLogger(__name__)
# Channel name → import path for lazy loading # Channel name → import path for lazy loading
_CHANNEL_REGISTRY: dict[str, str] = { _CHANNEL_REGISTRY: dict[str, str] = {
"discord": "app.channels.discord:DiscordChannel",
"feishu": "app.channels.feishu:FeishuChannel", "feishu": "app.channels.feishu:FeishuChannel",
"slack": "app.channels.slack:SlackChannel", "slack": "app.channels.slack:SlackChannel",
"telegram": "app.channels.telegram:TelegramChannel", "telegram": "app.channels.telegram:TelegramChannel",
@@ -66,9 +67,9 @@ class ChannelService:
@classmethod @classmethod
def from_app_config(cls) -> ChannelService: def from_app_config(cls) -> ChannelService:
"""Create a ChannelService from the application config.""" """Create a ChannelService from the application config."""
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import AppConfig
config = get_app_config() config = AppConfig.current()
channels_config = {} channels_config = {}
# extra fields are allowed by AppConfig (extra="allow") # extra fields are allowed by AppConfig (extra="allow")
extra = config.model_extra or {} extra = config.model_extra or {}
+3 -145
View File
@@ -1,22 +1,16 @@
import logging import logging
import os
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.gateway.auth_middleware import AuthMiddleware
from app.gateway.config import get_gateway_config from app.gateway.config import get_gateway_config
from app.gateway.csrf_middleware import CSRFMiddleware
from app.gateway.deps import langgraph_runtime from app.gateway.deps import langgraph_runtime
from app.gateway.routers import ( from app.gateway.routers import (
agents, agents,
artifacts, artifacts,
assistants_compat, assistants_compat,
auth,
channels, channels,
feedback,
mcp, mcp,
memory, memory,
models, models,
@@ -27,7 +21,7 @@ from app.gateway.routers import (
threads, threads,
uploads, uploads,
) )
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import AppConfig
# Configure logging # Configure logging
logging.basicConfig( logging.basicConfig(
@@ -39,115 +33,13 @@ logging.basicConfig(
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def _ensure_admin_user(app: FastAPI) -> None:
"""Startup hook: handle first boot and migrate orphan threads otherwise.
After admin creation, migrate orphan threads from the LangGraph
store (metadata.user_id unset) to the admin account. This is the
"no-auth → with-auth" upgrade path: users who ran DeerFlow without
authentication have existing LangGraph thread data that needs an
owner assigned.
First boot (no admin exists):
- Does NOT create any user accounts automatically.
- The operator must visit ``/setup`` to create the first admin.
Subsequent boots (admin already exists):
- Runs the one-time "no-auth → with-auth" orphan thread migration for
existing LangGraph thread metadata that has no owner_id.
No SQL persistence migration is needed: the four user_id columns
(threads_meta, runs, run_events, feedback) only come into existence
alongside the auth module via create_all, so freshly created tables
never contain NULL-owner rows.
"""
from sqlalchemy import select
from app.gateway.deps import get_local_provider
from deerflow.persistence.engine import get_session_factory
from deerflow.persistence.user.model import UserRow
provider = get_local_provider()
admin_count = await provider.count_admin_users()
if admin_count == 0:
logger.info("=" * 60)
logger.info(" First boot detected — no admin account exists.")
logger.info(" Visit /setup to complete admin account creation.")
logger.info("=" * 60)
return
# Admin already exists — run orphan thread migration for any
# LangGraph thread metadata that pre-dates the auth module.
sf = get_session_factory()
if sf is None:
return
async with sf() as session:
stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
row = (await session.execute(stmt)).scalar_one_or_none()
if row is None:
return # Should not happen (admin_count > 0 above), but be safe.
admin_id = str(row.id)
# LangGraph store orphan migration — non-fatal.
# This covers the "no-auth → with-auth" upgrade path for users
# whose existing LangGraph thread metadata has no user_id set.
store = getattr(app.state, "store", None)
if store is not None:
try:
migrated = await _migrate_orphaned_threads(store, admin_id)
if migrated:
logger.info("Migrated %d orphan LangGraph thread(s) to admin", migrated)
except Exception:
logger.exception("LangGraph thread migration failed (non-fatal)")
async def _iter_store_items(store, namespace, *, page_size: int = 500):
"""Paginated async iterator over a LangGraph store namespace.
Replaces the old hardcoded ``limit=1000`` call with a cursor-style
loop so that environments with more than one page of orphans do
not silently lose data. Terminates when a page is empty OR when a
short page arrives (indicating the last page).
"""
offset = 0
while True:
batch = await store.asearch(namespace, limit=page_size, offset=offset)
if not batch:
return
for item in batch:
yield item
if len(batch) < page_size:
return
offset += page_size
async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
"""Migrate LangGraph store threads with no user_id to the given admin.
Uses cursor pagination so all orphans are migrated regardless of
count. Returns the number of rows migrated.
"""
migrated = 0
async for item in _iter_store_items(store, ("threads",)):
metadata = item.value.get("metadata", {})
if not metadata.get("user_id"):
metadata["user_id"] = admin_user_id
item.value["metadata"] = metadata
await store.aput(("threads",), item.key, item.value)
migrated += 1
return migrated
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Application lifespan handler.""" """Application lifespan handler."""
# Load config and check necessary environment variables at startup # Load config and check necessary environment variables at startup
try: try:
get_app_config() AppConfig.current()
logger.info("Configuration loaded successfully") logger.info("Configuration loaded successfully")
except Exception as e: except Exception as e:
error_msg = f"Failed to load configuration during gateway startup: {e}" error_msg = f"Failed to load configuration during gateway startup: {e}"
@@ -160,10 +52,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
async with langgraph_runtime(app): async with langgraph_runtime(app):
logger.info("LangGraph runtime initialised") logger.info("LangGraph runtime initialised")
# Ensure admin user exists (auto-create on first boot)
# Must run AFTER langgraph_runtime so app.state.store is available for thread migration
await _ensure_admin_user(app)
# Start IM channel service if any channels are configured # Start IM channel service if any channels are configured
try: try:
from app.channels.service import start_channel_service from app.channels.service import start_channel_service
@@ -275,31 +163,7 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
], ],
) )
# Auth: reject unauthenticated requests to non-public paths (fail-closed safety net) # CORS is handled by nginx - no need for FastAPI middleware
app.add_middleware(AuthMiddleware)
# CSRF: Double Submit Cookie pattern for state-changing requests
app.add_middleware(CSRFMiddleware)
# CORS: when GATEWAY_CORS_ORIGINS is set (dev without nginx), add CORS middleware.
# In production, nginx handles CORS and no middleware is needed.
cors_origins_env = os.environ.get("GATEWAY_CORS_ORIGINS", "")
if cors_origins_env:
cors_origins = [o.strip() for o in cors_origins_env.split(",") if o.strip()]
# Validate: wildcard origin with credentials is a security misconfiguration
for origin in cors_origins:
if origin == "*":
logger.error("GATEWAY_CORS_ORIGINS contains wildcard '*' with allow_credentials=True. This is a security misconfiguration — browsers will reject the response. Use explicit scheme://host:port origins instead.")
cors_origins = [o for o in cors_origins if o != "*"]
break
if cors_origins:
app.add_middleware(
CORSMiddleware,
allow_origins=cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Include routers # Include routers
# Models API is mounted at /api/models # Models API is mounted at /api/models
@@ -335,12 +199,6 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
# Assistants compatibility API (LangGraph Platform stub) # Assistants compatibility API (LangGraph Platform stub)
app.include_router(assistants_compat.router) app.include_router(assistants_compat.router)
# Auth API is mounted at /api/v1/auth
app.include_router(auth.router)
# Feedback API is mounted at /api/threads/{thread_id}/runs/{run_id}/feedback
app.include_router(feedback.router)
# Thread Runs API (LangGraph Platform-compatible runs lifecycle) # Thread Runs API (LangGraph Platform-compatible runs lifecycle)
app.include_router(thread_runs.router) app.include_router(thread_runs.router)
-42
View File
@@ -1,42 +0,0 @@
"""Authentication module for DeerFlow.
This module provides:
- JWT-based authentication
- Provider Factory pattern for extensible auth methods
- UserRepository interface for storage backends (SQLite)
"""
from app.gateway.auth.config import AuthConfig, get_auth_config, set_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
from app.gateway.auth.jwt import TokenPayload, create_access_token, decode_token
from app.gateway.auth.local_provider import LocalAuthProvider
from app.gateway.auth.models import User, UserResponse
from app.gateway.auth.password import hash_password, verify_password
from app.gateway.auth.providers import AuthProvider
from app.gateway.auth.repositories.base import UserRepository
__all__ = [
# Config
"AuthConfig",
"get_auth_config",
"set_auth_config",
# Errors
"AuthErrorCode",
"AuthErrorResponse",
"TokenError",
# JWT
"TokenPayload",
"create_access_token",
"decode_token",
# Password
"hash_password",
"verify_password",
# Models
"User",
"UserResponse",
# Providers
"AuthProvider",
"LocalAuthProvider",
# Repository
"UserRepository",
]
-57
View File
@@ -1,57 +0,0 @@
"""Authentication configuration for DeerFlow."""
import logging
import os
import secrets
from dotenv import load_dotenv
from pydantic import BaseModel, Field
load_dotenv()
logger = logging.getLogger(__name__)
class AuthConfig(BaseModel):
"""JWT and auth-related configuration. Parsed once at startup.
Note: the ``users`` table now lives in the shared persistence
database managed by ``deerflow.persistence.engine``. The old
``users_db_path`` config key has been removed — user storage is
configured through ``config.database`` like every other table.
"""
jwt_secret: str = Field(
...,
description="Secret key for JWT signing. MUST be set via AUTH_JWT_SECRET.",
)
token_expiry_days: int = Field(default=7, ge=1, le=30)
oauth_github_client_id: str | None = Field(default=None)
oauth_github_client_secret: str | None = Field(default=None)
_auth_config: AuthConfig | None = None
def get_auth_config() -> AuthConfig:
"""Get the global AuthConfig instance. Parses from env on first call."""
global _auth_config
if _auth_config is None:
jwt_secret = os.environ.get("AUTH_JWT_SECRET")
if not jwt_secret:
jwt_secret = secrets.token_urlsafe(32)
os.environ["AUTH_JWT_SECRET"] = jwt_secret
logger.warning(
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. "
"Sessions will be invalidated on restart. "
"For production, add AUTH_JWT_SECRET to your .env file: "
'python -c "import secrets; print(secrets.token_urlsafe(32))"'
)
_auth_config = AuthConfig(jwt_secret=jwt_secret)
return _auth_config
def set_auth_config(config: AuthConfig) -> None:
"""Set the global AuthConfig instance (for testing)."""
global _auth_config
_auth_config = config
@@ -1,48 +0,0 @@
"""Write initial admin credentials to a restricted file instead of logs.
Logging secrets to stdout/stderr is a well-known CodeQL finding
(py/clear-text-logging-sensitive-data) — in production those logs
get collected into ELK/Splunk/etc and become a secret sprawl
source. This helper writes the credential to a 0600 file that only
the process user can read, and returns the path so the caller can
log **the path** (not the password) for the operator to pick up.
"""
from __future__ import annotations
import os
from pathlib import Path
from deerflow.config.paths import get_paths
_CREDENTIAL_FILENAME = "admin_initial_credentials.txt"
def write_initial_credentials(email: str, password: str, *, label: str = "initial") -> Path:
"""Write the admin email + password to ``{base_dir}/admin_initial_credentials.txt``.
The file is created **atomically** with mode 0600 via ``os.open``
so the password is never world-readable, even for the single syscall
window between ``write_text`` and ``chmod``.
``label`` distinguishes "initial" (fresh creation) from "reset"
(password reset) in the file header so an operator picking up the
file after a restart can tell which event produced it.
Returns the absolute :class:`Path` to the file.
"""
target = get_paths().base_dir / _CREDENTIAL_FILENAME
target.parent.mkdir(parents=True, exist_ok=True)
content = (
f"# DeerFlow admin {label} credentials\n# This file is generated on first boot or password reset.\n# Change the password after login via Settings -> Account,\n# then delete this file.\n#\nemail: {email}\npassword: {password}\n"
)
# Atomic 0600 create-or-truncate. O_TRUNC (not O_EXCL) so the
# reset-password path can rewrite an existing file without a
# separate unlink-then-create dance.
fd = os.open(target, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(fd, "w", encoding="utf-8") as fh:
fh.write(content)
return target.resolve()
-45
View File
@@ -1,45 +0,0 @@
"""Typed error definitions for auth module.
AuthErrorCode: exhaustive enum of all auth failure conditions.
TokenError: exhaustive enum of JWT decode failures.
AuthErrorResponse: structured error payload for HTTP responses.
"""
from enum import StrEnum
from pydantic import BaseModel
class AuthErrorCode(StrEnum):
"""Exhaustive list of auth error conditions."""
INVALID_CREDENTIALS = "invalid_credentials"
TOKEN_EXPIRED = "token_expired"
TOKEN_INVALID = "token_invalid"
USER_NOT_FOUND = "user_not_found"
EMAIL_ALREADY_EXISTS = "email_already_exists"
PROVIDER_NOT_FOUND = "provider_not_found"
NOT_AUTHENTICATED = "not_authenticated"
SYSTEM_ALREADY_INITIALIZED = "system_already_initialized"
class TokenError(StrEnum):
"""Exhaustive list of JWT decode failure reasons."""
EXPIRED = "expired"
INVALID_SIGNATURE = "invalid_signature"
MALFORMED = "malformed"
class AuthErrorResponse(BaseModel):
"""Structured error response — replaces bare `detail` strings."""
code: AuthErrorCode
message: str
def token_error_to_code(err: TokenError) -> AuthErrorCode:
"""Map TokenError to AuthErrorCode — single source of truth."""
if err == TokenError.EXPIRED:
return AuthErrorCode.TOKEN_EXPIRED
return AuthErrorCode.TOKEN_INVALID
-55
View File
@@ -1,55 +0,0 @@
"""JWT token creation and verification."""
from datetime import UTC, datetime, timedelta
import jwt
from pydantic import BaseModel
from app.gateway.auth.config import get_auth_config
from app.gateway.auth.errors import TokenError
class TokenPayload(BaseModel):
"""JWT token payload."""
sub: str # user_id
exp: datetime
iat: datetime | None = None
ver: int = 0 # token_version — must match User.token_version
def create_access_token(user_id: str, expires_delta: timedelta | None = None, token_version: int = 0) -> str:
"""Create a JWT access token.
Args:
user_id: The user's UUID as string
expires_delta: Optional custom expiry, defaults to 7 days
token_version: User's current token_version for invalidation
Returns:
Encoded JWT string
"""
config = get_auth_config()
expiry = expires_delta or timedelta(days=config.token_expiry_days)
now = datetime.now(UTC)
payload = {"sub": user_id, "exp": now + expiry, "iat": now, "ver": token_version}
return jwt.encode(payload, config.jwt_secret, algorithm="HS256")
def decode_token(token: str) -> TokenPayload | TokenError:
"""Decode and validate a JWT token.
Returns:
TokenPayload if valid, or a specific TokenError variant.
"""
config = get_auth_config()
try:
payload = jwt.decode(token, config.jwt_secret, algorithms=["HS256"])
return TokenPayload(**payload)
except jwt.ExpiredSignatureError:
return TokenError.EXPIRED
except jwt.InvalidSignatureError:
return TokenError.INVALID_SIGNATURE
except jwt.PyJWTError:
return TokenError.MALFORMED
@@ -1,91 +0,0 @@
"""Local email/password authentication provider."""
from app.gateway.auth.models import User
from app.gateway.auth.password import hash_password_async, verify_password_async
from app.gateway.auth.providers import AuthProvider
from app.gateway.auth.repositories.base import UserRepository
class LocalAuthProvider(AuthProvider):
"""Email/password authentication provider using local database."""
def __init__(self, repository: UserRepository):
"""Initialize with a UserRepository.
Args:
repository: UserRepository implementation (SQLite)
"""
self._repo = repository
async def authenticate(self, credentials: dict) -> User | None:
"""Authenticate with email and password.
Args:
credentials: dict with 'email' and 'password' keys
Returns:
User if authentication succeeds, None otherwise
"""
email = credentials.get("email")
password = credentials.get("password")
if not email or not password:
return None
user = await self._repo.get_user_by_email(email)
if user is None:
return None
if user.password_hash is None:
# OAuth user without local password
return None
if not await verify_password_async(password, user.password_hash):
return None
return user
async def get_user(self, user_id: str) -> User | None:
"""Get user by ID."""
return await self._repo.get_user_by_id(user_id)
async def create_user(self, email: str, password: str | None = None, system_role: str = "user", needs_setup: bool = False) -> User:
"""Create a new local user.
Args:
email: User email address
password: Plain text password (will be hashed)
system_role: Role to assign ("admin" or "user")
needs_setup: If True, user must complete setup on first login
Returns:
Created User instance
"""
password_hash = await hash_password_async(password) if password else None
user = User(
email=email,
password_hash=password_hash,
system_role=system_role,
needs_setup=needs_setup,
)
return await self._repo.create_user(user)
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
"""Get user by OAuth provider and ID."""
return await self._repo.get_user_by_oauth(provider, oauth_id)
async def count_users(self) -> int:
"""Return total number of registered users."""
return await self._repo.count_users()
async def count_admin_users(self) -> int:
"""Return number of admin users."""
return await self._repo.count_admin_users()
async def update_user(self, user: User) -> User:
"""Update an existing user."""
return await self._repo.update_user(user)
async def get_user_by_email(self, email: str) -> User | None:
"""Get user by email."""
return await self._repo.get_user_by_email(email)
-41
View File
@@ -1,41 +0,0 @@
"""User Pydantic models for authentication."""
from datetime import UTC, datetime
from typing import Literal
from uuid import UUID, uuid4
from pydantic import BaseModel, ConfigDict, EmailStr, Field
def _utc_now() -> datetime:
"""Return current UTC time (timezone-aware)."""
return datetime.now(UTC)
class User(BaseModel):
"""Internal user representation."""
model_config = ConfigDict(from_attributes=True)
id: UUID = Field(default_factory=uuid4, description="Primary key")
email: EmailStr = Field(..., description="Unique email address")
password_hash: str | None = Field(None, description="bcrypt hash, nullable for OAuth users")
system_role: Literal["admin", "user"] = Field(default="user")
created_at: datetime = Field(default_factory=_utc_now)
# OAuth linkage (optional)
oauth_provider: str | None = Field(None, description="e.g. 'github', 'google'")
oauth_id: str | None = Field(None, description="User ID from OAuth provider")
# Auth lifecycle
needs_setup: bool = Field(default=False, description="True for auto-created admin until setup completes")
token_version: int = Field(default=0, description="Incremented on password change to invalidate old JWTs")
class UserResponse(BaseModel):
"""Response model for user info endpoint."""
id: str
email: str
system_role: Literal["admin", "user"]
needs_setup: bool = False
-33
View File
@@ -1,33 +0,0 @@
"""Password hashing utilities using bcrypt directly."""
import asyncio
import bcrypt
def hash_password(password: str) -> str:
"""Hash a password using bcrypt."""
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash."""
return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))
async def hash_password_async(password: str) -> str:
"""Hash a password using bcrypt (non-blocking).
Wraps the blocking bcrypt operation in a thread pool to avoid
blocking the event loop during password hashing.
"""
return await asyncio.to_thread(hash_password, password)
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash (non-blocking).
Wraps the blocking bcrypt operation in a thread pool to avoid
blocking the event loop during password verification.
"""
return await asyncio.to_thread(verify_password, plain_password, hashed_password)
-24
View File
@@ -1,24 +0,0 @@
"""Auth provider abstraction."""
from abc import ABC, abstractmethod
class AuthProvider(ABC):
"""Abstract base class for authentication providers."""
@abstractmethod
async def authenticate(self, credentials: dict) -> "User | None":
"""Authenticate user with given credentials.
Returns User if authentication succeeds, None otherwise.
"""
...
@abstractmethod
async def get_user(self, user_id: str) -> "User | None":
"""Retrieve user by ID."""
...
# Import User at runtime to avoid circular imports
from app.gateway.auth.models import User # noqa: E402
@@ -1,102 +0,0 @@
"""User repository interface for abstracting database operations."""
from abc import ABC, abstractmethod
from app.gateway.auth.models import User
class UserNotFoundError(LookupError):
"""Raised when a user repository operation targets a non-existent row.
Subclass of :class:`LookupError` so callers that already catch
``LookupError`` for "missing entity" can keep working unchanged,
while specific call sites can pin to this class to distinguish
"concurrent delete during update" from other lookups.
"""
class UserRepository(ABC):
"""Abstract interface for user data storage.
Implement this interface to support different storage backends
(SQLite)
"""
@abstractmethod
async def create_user(self, user: User) -> User:
"""Create a new user.
Args:
user: User object to create
Returns:
Created User with ID assigned
Raises:
ValueError: If email already exists
"""
...
@abstractmethod
async def get_user_by_id(self, user_id: str) -> User | None:
"""Get user by ID.
Args:
user_id: User UUID as string
Returns:
User if found, None otherwise
"""
...
@abstractmethod
async def get_user_by_email(self, email: str) -> User | None:
"""Get user by email.
Args:
email: User email address
Returns:
User if found, None otherwise
"""
...
@abstractmethod
async def update_user(self, user: User) -> User:
"""Update an existing user.
Args:
user: User object with updated fields
Returns:
Updated User
Raises:
UserNotFoundError: If no row exists for ``user.id``. This is
a hard failure (not a no-op) so callers cannot mistake a
concurrent-delete race for a successful update.
"""
...
@abstractmethod
async def count_users(self) -> int:
"""Return total number of registered users."""
...
@abstractmethod
async def count_admin_users(self) -> int:
"""Return number of users with system_role == 'admin'."""
...
@abstractmethod
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
"""Get user by OAuth provider and ID.
Args:
provider: OAuth provider name (e.g. 'github', 'google')
oauth_id: User ID from the OAuth provider
Returns:
User if found, None otherwise
"""
...
@@ -1,127 +0,0 @@
"""SQLAlchemy-backed UserRepository implementation.
Uses the shared async session factory from
``deerflow.persistence.engine`` — the ``users`` table lives in the
same database as ``threads_meta``, ``runs``, ``run_events``, and
``feedback``.
Constructor takes the session factory directly (same pattern as the
other four repositories in ``deerflow.persistence.*``). Callers
construct this after ``init_engine_from_config()`` has run.
"""
from __future__ import annotations
from datetime import UTC
from uuid import UUID
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.gateway.auth.models import User
from app.gateway.auth.repositories.base import UserNotFoundError, UserRepository
from deerflow.persistence.user.model import UserRow
class SQLiteUserRepository(UserRepository):
"""Async user repository backed by the shared SQLAlchemy engine."""
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._sf = session_factory
# ── Converters ────────────────────────────────────────────────────
@staticmethod
def _row_to_user(row: UserRow) -> User:
return User(
id=UUID(row.id),
email=row.email,
password_hash=row.password_hash,
system_role=row.system_role, # type: ignore[arg-type]
# SQLite loses tzinfo on read; reattach UTC so downstream
# code can compare timestamps reliably.
created_at=row.created_at if row.created_at.tzinfo else row.created_at.replace(tzinfo=UTC),
oauth_provider=row.oauth_provider,
oauth_id=row.oauth_id,
needs_setup=row.needs_setup,
token_version=row.token_version,
)
@staticmethod
def _user_to_row(user: User) -> UserRow:
return UserRow(
id=str(user.id),
email=user.email,
password_hash=user.password_hash,
system_role=user.system_role,
created_at=user.created_at,
oauth_provider=user.oauth_provider,
oauth_id=user.oauth_id,
needs_setup=user.needs_setup,
token_version=user.token_version,
)
# ── CRUD ──────────────────────────────────────────────────────────
async def create_user(self, user: User) -> User:
"""Insert a new user. Raises ``ValueError`` on duplicate email."""
row = self._user_to_row(user)
async with self._sf() as session:
session.add(row)
try:
await session.commit()
except IntegrityError as exc:
await session.rollback()
raise ValueError(f"Email already registered: {user.email}") from exc
return user
async def get_user_by_id(self, user_id: str) -> User | None:
async with self._sf() as session:
row = await session.get(UserRow, user_id)
return self._row_to_user(row) if row is not None else None
async def get_user_by_email(self, email: str) -> User | None:
stmt = select(UserRow).where(UserRow.email == email)
async with self._sf() as session:
result = await session.execute(stmt)
row = result.scalar_one_or_none()
return self._row_to_user(row) if row is not None else None
async def update_user(self, user: User) -> User:
async with self._sf() as session:
row = await session.get(UserRow, str(user.id))
if row is None:
# Hard fail on concurrent delete: callers (reset_admin,
# password change handlers, _ensure_admin_user) all
# fetched the user just before this call, so a missing
# row here means the row vanished underneath us. Silent
# success would let the caller log "password reset" for
# a row that no longer exists.
raise UserNotFoundError(f"User {user.id} no longer exists")
row.email = user.email
row.password_hash = user.password_hash
row.system_role = user.system_role
row.oauth_provider = user.oauth_provider
row.oauth_id = user.oauth_id
row.needs_setup = user.needs_setup
row.token_version = user.token_version
await session.commit()
return user
async def count_users(self) -> int:
stmt = select(func.count()).select_from(UserRow)
async with self._sf() as session:
return await session.scalar(stmt) or 0
async def count_admin_users(self) -> int:
stmt = select(func.count()).select_from(UserRow).where(UserRow.system_role == "admin")
async with self._sf() as session:
return await session.scalar(stmt) or 0
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
stmt = select(UserRow).where(UserRow.oauth_provider == provider, UserRow.oauth_id == oauth_id)
async with self._sf() as session:
result = await session.execute(stmt)
row = result.scalar_one_or_none()
return self._row_to_user(row) if row is not None else None
-91
View File
@@ -1,91 +0,0 @@
"""CLI tool to reset an admin password.
Usage:
python -m app.gateway.auth.reset_admin
python -m app.gateway.auth.reset_admin --email admin@example.com
Writes the new password to ``.deer-flow/admin_initial_credentials.txt``
(mode 0600) instead of printing it, so CI / log aggregators never see
the cleartext secret.
"""
from __future__ import annotations
import argparse
import asyncio
import secrets
import sys
from sqlalchemy import select
from app.gateway.auth.credential_file import write_initial_credentials
from app.gateway.auth.password import hash_password
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
from deerflow.persistence.user.model import UserRow
async def _run(email: str | None) -> int:
from deerflow.config import get_app_config
from deerflow.persistence.engine import (
close_engine,
get_session_factory,
init_engine_from_config,
)
config = get_app_config()
await init_engine_from_config(config.database)
try:
sf = get_session_factory()
if sf is None:
print("Error: persistence engine not available (check config.database).", file=sys.stderr)
return 1
repo = SQLiteUserRepository(sf)
if email:
user = await repo.get_user_by_email(email)
else:
# Find first admin via direct SELECT — repository does not
# expose a "first admin" helper and we do not want to add
# one just for this CLI.
async with sf() as session:
stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
row = (await session.execute(stmt)).scalar_one_or_none()
if row is None:
user = None
else:
user = await repo.get_user_by_id(row.id)
if user is None:
if email:
print(f"Error: user '{email}' not found.", file=sys.stderr)
else:
print("Error: no admin user found.", file=sys.stderr)
return 1
new_password = secrets.token_urlsafe(16)
user.password_hash = hash_password(new_password)
user.token_version += 1
user.needs_setup = True
await repo.update_user(user)
cred_path = write_initial_credentials(user.email, new_password, label="reset")
print(f"Password reset for: {user.email}")
print(f"Credentials written to: {cred_path} (mode 0600)")
print("Next login will require setup (new email + password).")
return 0
finally:
await close_engine()
def main() -> None:
parser = argparse.ArgumentParser(description="Reset admin password")
parser.add_argument("--email", help="Admin email (default: first admin found)")
args = parser.parse_args()
exit_code = asyncio.run(_run(args.email))
sys.exit(exit_code)
if __name__ == "__main__":
main()
-118
View File
@@ -1,118 +0,0 @@
"""Global authentication middleware — fail-closed safety net.
Rejects unauthenticated requests to non-public paths with 401. When a
request passes the cookie check, resolves the JWT payload to a real
``User`` object and stamps it into both ``request.state.user`` and the
``deerflow.runtime.user_context`` contextvar so that repository-layer
owner filtering works automatically via the sentinel pattern.
Fine-grained permission checks remain in authz.py decorators.
"""
from collections.abc import Callable
from fastapi import HTTPException, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.types import ASGIApp
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
from deerflow.runtime.user_context import reset_current_user, set_current_user
# Paths that never require authentication.
_PUBLIC_PATH_PREFIXES: tuple[str, ...] = (
"/health",
"/docs",
"/redoc",
"/openapi.json",
)
# Exact auth paths that are public (login/register/status check).
# /api/v1/auth/me, /api/v1/auth/change-password etc. are NOT public.
_PUBLIC_EXACT_PATHS: frozenset[str] = frozenset(
{
"/api/v1/auth/login/local",
"/api/v1/auth/register",
"/api/v1/auth/logout",
"/api/v1/auth/setup-status",
"/api/v1/auth/initialize",
}
)
def _is_public(path: str) -> bool:
stripped = path.rstrip("/")
if stripped in _PUBLIC_EXACT_PATHS:
return True
return any(path.startswith(prefix) for prefix in _PUBLIC_PATH_PREFIXES)
class AuthMiddleware(BaseHTTPMiddleware):
"""Strict auth gate: reject requests without a valid session.
Two-stage check for non-public paths:
1. Cookie presence — return 401 NOT_AUTHENTICATED if missing
2. JWT validation via ``get_optional_user_from_request`` — return 401
TOKEN_INVALID if the token is absent, malformed, expired, or the
signed user does not exist / is stale
On success, stamps ``request.state.user`` and the
``deerflow.runtime.user_context`` contextvar so that repository-layer
owner filters work downstream without every route needing a
``@require_auth`` decorator. Routes that need per-resource
authorization (e.g. "user A cannot read user B's thread by guessing
the URL") should additionally use ``@require_permission(...,
owner_check=True)`` for explicit enforcement — but authentication
itself is fully handled here.
"""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
async def dispatch(self, request: Request, call_next: Callable) -> Response:
if _is_public(request.url.path):
return await call_next(request)
# Non-public path: require session cookie
if not request.cookies.get("access_token"):
return JSONResponse(
status_code=401,
content={
"detail": AuthErrorResponse(
code=AuthErrorCode.NOT_AUTHENTICATED,
message="Authentication required",
).model_dump()
},
)
# Strict JWT validation: reject junk/expired tokens with 401
# right here instead of silently passing through. This closes
# the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8):
# without this, non-isolation routes like /api/models would
# accept any cookie-shaped string as authentication.
#
# We call the *strict* resolver so that fine-grained error
# codes (token_expired, token_invalid, user_not_found, …)
# propagate from AuthErrorCode, not get flattened into one
# generic code. BaseHTTPMiddleware doesn't let HTTPException
# bubble up, so we catch and render it as JSONResponse here.
from app.gateway.deps import get_current_user_from_request
try:
user = await get_current_user_from_request(request)
except HTTPException as exc:
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
# Stamp both request.state.user (for the contextvar pattern)
# and request.state.auth (so @require_permission's "auth is
# None" branch short-circuits instead of running the entire
# JWT-decode + DB-lookup pipeline a second time per request).
request.state.user = user
request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS)
token = set_current_user(user)
try:
return await call_next(request)
finally:
reset_current_user(token)
-262
View File
@@ -1,262 +0,0 @@
"""Authorization decorators and context for DeerFlow.
Inspired by LangGraph Auth system: https://github.com/langchain-ai/langgraph/blob/main/libs/sdk-py/langgraph_sdk/auth/__init__.py
**Usage:**
1. Use ``@require_auth`` on routes that need authentication
2. Use ``@require_permission("resource", "action", filter_key=...)`` for permission checks
3. The decorator chain processes from bottom to top
**Example:**
@router.get("/{thread_id}")
@require_auth
@require_permission("threads", "read", owner_check=True)
async def get_thread(thread_id: str, request: Request):
# User is authenticated and has threads:read permission
...
**Permission Model:**
- threads:read - View thread
- threads:write - Create/update thread
- threads:delete - Delete thread
- runs:create - Run agent
- runs:read - View run
- runs:cancel - Cancel run
"""
from __future__ import annotations
import functools
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
from fastapi import HTTPException, Request
if TYPE_CHECKING:
from app.gateway.auth.models import User
P = ParamSpec("P")
T = TypeVar("T")
# Permission constants
class Permissions:
"""Permission constants for resource:action format."""
# Threads
THREADS_READ = "threads:read"
THREADS_WRITE = "threads:write"
THREADS_DELETE = "threads:delete"
# Runs
RUNS_CREATE = "runs:create"
RUNS_READ = "runs:read"
RUNS_CANCEL = "runs:cancel"
class AuthContext:
"""Authentication context for the current request.
Stored in request.state.auth after require_auth decoration.
Attributes:
user: The authenticated user, or None if anonymous
permissions: List of permission strings (e.g., "threads:read")
"""
__slots__ = ("user", "permissions")
def __init__(self, user: User | None = None, permissions: list[str] | None = None):
self.user = user
self.permissions = permissions or []
@property
def is_authenticated(self) -> bool:
"""Check if user is authenticated."""
return self.user is not None
def has_permission(self, resource: str, action: str) -> bool:
"""Check if context has permission for resource:action.
Args:
resource: Resource name (e.g., "threads")
action: Action name (e.g., "read")
Returns:
True if user has permission
"""
permission = f"{resource}:{action}"
return permission in self.permissions
def require_user(self) -> User:
"""Get user or raise 401.
Raises:
HTTPException 401 if not authenticated
"""
if not self.user:
raise HTTPException(status_code=401, detail="Authentication required")
return self.user
def get_auth_context(request: Request) -> AuthContext | None:
"""Get AuthContext from request state."""
return getattr(request.state, "auth", None)
_ALL_PERMISSIONS: list[str] = [
Permissions.THREADS_READ,
Permissions.THREADS_WRITE,
Permissions.THREADS_DELETE,
Permissions.RUNS_CREATE,
Permissions.RUNS_READ,
Permissions.RUNS_CANCEL,
]
async def _authenticate(request: Request) -> AuthContext:
"""Authenticate request and return AuthContext.
Delegates to deps.get_optional_user_from_request() for the JWT→User pipeline.
Returns AuthContext with user=None for anonymous requests.
"""
from app.gateway.deps import get_optional_user_from_request
user = await get_optional_user_from_request(request)
if user is None:
return AuthContext(user=None, permissions=[])
# In future, permissions could be stored in user record
return AuthContext(user=user, permissions=_ALL_PERMISSIONS)
def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
"""Decorator that authenticates the request and sets AuthContext.
Must be placed ABOVE other decorators (executes after them).
Usage:
@router.get("/{thread_id}")
@require_auth # Bottom decorator (executes first after permission check)
@require_permission("threads", "read")
async def get_thread(thread_id: str, request: Request):
auth: AuthContext = request.state.auth
...
Raises:
ValueError: If 'request' parameter is missing
"""
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
request = kwargs.get("request")
if request is None:
raise ValueError("require_auth decorator requires 'request' parameter")
# Authenticate and set context
auth_context = await _authenticate(request)
request.state.auth = auth_context
return await func(*args, **kwargs)
return wrapper
def require_permission(
resource: str,
action: str,
owner_check: bool = False,
require_existing: bool = False,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Decorator that checks permission for resource:action.
Must be used AFTER @require_auth.
Args:
resource: Resource name (e.g., "threads", "runs")
action: Action name (e.g., "read", "write", "delete")
owner_check: If True, validates that the current user owns the resource.
Requires 'thread_id' path parameter and performs ownership check.
require_existing: Only meaningful with ``owner_check=True``. If True, a
missing ``threads_meta`` row counts as a denial (404)
instead of "untracked legacy thread, allow". Use on
**destructive / mutating** routes (DELETE, PATCH,
state-update) so a deleted thread can't be re-targeted
by another user via the missing-row code path.
Usage:
# Read-style: legacy untracked threads are allowed
@require_permission("threads", "read", owner_check=True)
async def get_thread(thread_id: str, request: Request):
...
# Destructive: thread row MUST exist and be owned by caller
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_thread(thread_id: str, request: Request):
...
Raises:
HTTPException 401: If authentication required but user is anonymous
HTTPException 403: If user lacks permission
HTTPException 404: If owner_check=True but user doesn't own the thread
ValueError: If owner_check=True but 'thread_id' parameter is missing
"""
def decorator(func: Callable[P, T]) -> Callable[P, T]:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
request = kwargs.get("request")
if request is None:
raise ValueError("require_permission decorator requires 'request' parameter")
auth: AuthContext = getattr(request.state, "auth", None)
if auth is None:
auth = await _authenticate(request)
request.state.auth = auth
if not auth.is_authenticated:
raise HTTPException(status_code=401, detail="Authentication required")
# Check permission
if not auth.has_permission(resource, action):
raise HTTPException(
status_code=403,
detail=f"Permission denied: {resource}:{action}",
)
# Owner check for thread-specific resources.
#
# 2.0-rc moved thread metadata into the SQL persistence layer
# (``threads_meta`` table). We verify ownership via
# ``ThreadMetaStore.check_access``: it returns True for
# missing rows (untracked legacy thread) and for rows whose
# ``user_id`` is NULL (shared / pre-auth data), so this is
# strict-deny rather than strict-allow — only an *existing*
# row with a *different* user_id triggers 404.
if owner_check:
thread_id = kwargs.get("thread_id")
if thread_id is None:
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
from app.gateway.deps import get_thread_store
thread_store = get_thread_store(request)
allowed = await thread_store.check_access(
thread_id,
str(auth.user.id),
require_existing=require_existing,
)
if not allowed:
raise HTTPException(
status_code=404,
detail=f"Thread {thread_id} not found",
)
return await func(*args, **kwargs)
return wrapper
return decorator
-113
View File
@@ -1,113 +0,0 @@
"""CSRF protection middleware for FastAPI.
Per RFC-001:
State-changing operations require CSRF protection.
"""
import secrets
from collections.abc import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.types import ASGIApp
CSRF_COOKIE_NAME = "csrf_token"
CSRF_HEADER_NAME = "X-CSRF-Token"
CSRF_TOKEN_LENGTH = 64 # bytes
def is_secure_request(request: Request) -> bool:
"""Detect whether the original client request was made over HTTPS."""
return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"
def generate_csrf_token() -> str:
"""Generate a secure random CSRF token."""
return secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
def should_check_csrf(request: Request) -> bool:
"""Determine if a request needs CSRF validation.
CSRF is checked for state-changing methods (POST, PUT, DELETE, PATCH).
GET, HEAD, OPTIONS, and TRACE are exempt per RFC 7231.
"""
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
return False
path = request.url.path.rstrip("/")
# Exempt /api/v1/auth/me endpoint
if path == "/api/v1/auth/me":
return False
return True
_AUTH_EXEMPT_PATHS: frozenset[str] = frozenset(
{
"/api/v1/auth/login/local",
"/api/v1/auth/logout",
"/api/v1/auth/register",
"/api/v1/auth/initialize",
}
)
def is_auth_endpoint(request: Request) -> bool:
"""Check if the request is to an auth endpoint.
Auth endpoints don't need CSRF validation on first call (no token).
"""
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
class CSRFMiddleware(BaseHTTPMiddleware):
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
async def dispatch(self, request: Request, call_next: Callable) -> Response:
_is_auth = is_auth_endpoint(request)
if should_check_csrf(request) and not _is_auth:
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
header_token = request.headers.get(CSRF_HEADER_NAME)
if not cookie_token or not header_token:
return JSONResponse(
status_code=403,
content={"detail": "CSRF token missing. Include X-CSRF-Token header."},
)
if not secrets.compare_digest(cookie_token, header_token):
return JSONResponse(
status_code=403,
content={"detail": "CSRF token mismatch."},
)
response = await call_next(request)
# For auth endpoints that set up session, also set CSRF cookie
if _is_auth and request.method == "POST":
# Generate a new CSRF token for the session
csrf_token = generate_csrf_token()
is_https = is_secure_request(request)
response.set_cookie(
key=CSRF_COOKIE_NAME,
value=csrf_token,
httponly=False, # Must be JS-readable for Double Submit Cookie pattern
secure=is_https,
samesite="strict",
)
return response
def get_csrf_token(request: Request) -> str | None:
"""Get the CSRF token from the current request's cookies.
This is useful for server-side rendering where you need to embed
token in forms or headers.
"""
return request.cookies.get(CSRF_COOKIE_NAME)
+25 -189
View File
@@ -8,25 +8,12 @@ Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncGenerator, Callable from collections.abc import AsyncGenerator
from contextlib import AsyncExitStack, asynccontextmanager from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, TypeVar, cast
from fastapi import FastAPI, HTTPException, Request from fastapi import FastAPI, HTTPException, Request
from langgraph.types import Checkpointer
from deerflow.persistence.feedback import FeedbackRepository from deerflow.runtime import RunManager, StreamBridge
from deerflow.runtime import RunContext, RunManager, StreamBridge
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.runs.store.base import RunStore
if TYPE_CHECKING:
from app.gateway.auth.local_provider import LocalAuthProvider
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
from deerflow.persistence.thread_meta.base import ThreadMetaStore
T = TypeVar("T")
@asynccontextmanager @asynccontextmanager
@@ -38,52 +25,15 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
async with langgraph_runtime(app): async with langgraph_runtime(app):
yield yield
""" """
from deerflow.config import get_app_config from deerflow.agents.checkpointer.async_provider import make_checkpointer
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
from deerflow.runtime import make_store, make_stream_bridge from deerflow.runtime import make_store, make_stream_bridge
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
from deerflow.runtime.events.store import make_run_event_store
async with AsyncExitStack() as stack: async with AsyncExitStack() as stack:
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge()) app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
# Initialize persistence engine BEFORE checkpointer so that
# auto-create-database logic runs first (postgres backend).
config = get_app_config()
await init_engine_from_config(config.database)
app.state.checkpointer = await stack.enter_async_context(make_checkpointer()) app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
app.state.store = await stack.enter_async_context(make_store()) app.state.store = await stack.enter_async_context(make_store())
app.state.run_manager = RunManager()
# Initialize repositories — one get_session_factory() call for all. yield
sf = get_session_factory()
if sf is not None:
from deerflow.persistence.feedback import FeedbackRepository
from deerflow.persistence.run import RunRepository
app.state.run_store = RunRepository(sf)
app.state.feedback_repo = FeedbackRepository(sf)
else:
from deerflow.runtime.runs.store.memory import MemoryRunStore
app.state.run_store = MemoryRunStore()
app.state.feedback_repo = None
from deerflow.persistence.thread_meta import make_thread_store
app.state.thread_store = make_thread_store(sf, app.state.store)
# Run event store (has its own factory with config-driven backend selection)
run_events_config = getattr(config, "run_events", None)
app.state.run_event_store = make_run_event_store(run_events_config)
# RunManager with store backing for persistence
app.state.run_manager = RunManager(store=app.state.run_store)
try:
yield
finally:
await close_engine()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -91,144 +41,30 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _require(attr: str, label: str) -> Callable[[Request], T]: def get_stream_bridge(request: Request) -> StreamBridge:
"""Create a FastAPI dependency that returns ``app.state.<attr>`` or 503.""" """Return the global :class:`StreamBridge`, or 503."""
bridge = getattr(request.app.state, "stream_bridge", None)
def dep(request: Request) -> T: if bridge is None:
val = getattr(request.app.state, attr, None) raise HTTPException(status_code=503, detail="Stream bridge not available")
if val is None: return bridge
raise HTTPException(status_code=503, detail=f"{label} not available")
return cast(T, val)
dep.__name__ = dep.__qualname__ = f"get_{attr}"
return dep
get_stream_bridge: Callable[[Request], StreamBridge] = _require("stream_bridge", "Stream bridge") def get_run_manager(request: Request) -> RunManager:
get_run_manager: Callable[[Request], RunManager] = _require("run_manager", "Run manager") """Return the global :class:`RunManager`, or 503."""
get_checkpointer: Callable[[Request], Checkpointer] = _require("checkpointer", "Checkpointer") mgr = getattr(request.app.state, "run_manager", None)
get_run_event_store: Callable[[Request], RunEventStore] = _require("run_event_store", "Run event store") if mgr is None:
get_feedback_repo: Callable[[Request], FeedbackRepository] = _require("feedback_repo", "Feedback") raise HTTPException(status_code=503, detail="Run manager not available")
get_run_store: Callable[[Request], RunStore] = _require("run_store", "Run store") return mgr
def get_checkpointer(request: Request):
"""Return the global checkpointer, or 503."""
cp = getattr(request.app.state, "checkpointer", None)
if cp is None:
raise HTTPException(status_code=503, detail="Checkpointer not available")
return cp
def get_store(request: Request): def get_store(request: Request):
"""Return the global store (may be ``None`` if not configured).""" """Return the global store (may be ``None`` if not configured)."""
return getattr(request.app.state, "store", None) return getattr(request.app.state, "store", None)
def get_thread_store(request: Request) -> ThreadMetaStore:
"""Return the thread metadata store (SQL or memory-backed)."""
val = getattr(request.app.state, "thread_store", None)
if val is None:
raise HTTPException(status_code=503, detail="Thread metadata store not available")
return val
def get_run_context(request: Request) -> RunContext:
"""Build a :class:`RunContext` from ``app.state`` singletons.
Returns a *base* context with infrastructure dependencies.
"""
from deerflow.config import get_app_config
return RunContext(
checkpointer=get_checkpointer(request),
store=get_store(request),
event_store=get_run_event_store(request),
run_events_config=getattr(get_app_config(), "run_events", None),
thread_store=get_thread_store(request),
)
# ---------------------------------------------------------------------------
# Auth helpers (used by authz.py and auth middleware)
# ---------------------------------------------------------------------------
# Cached singletons to avoid repeated instantiation per request
_cached_local_provider: LocalAuthProvider | None = None
_cached_repo: SQLiteUserRepository | None = None
def get_local_provider() -> LocalAuthProvider:
"""Get or create the cached LocalAuthProvider singleton.
Must be called after ``init_engine_from_config()`` — the shared
session factory is required to construct the user repository.
"""
global _cached_local_provider, _cached_repo
if _cached_repo is None:
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
from deerflow.persistence.engine import get_session_factory
sf = get_session_factory()
if sf is None:
raise RuntimeError("get_local_provider() called before init_engine_from_config(); cannot access users table")
_cached_repo = SQLiteUserRepository(sf)
if _cached_local_provider is None:
from app.gateway.auth.local_provider import LocalAuthProvider
_cached_local_provider = LocalAuthProvider(repository=_cached_repo)
return _cached_local_provider
async def get_current_user_from_request(request: Request):
"""Get the current authenticated user from the request cookie.
Raises HTTPException 401 if not authenticated.
"""
from app.gateway.auth import decode_token
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
access_token = request.cookies.get("access_token")
if not access_token:
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=AuthErrorCode.NOT_AUTHENTICATED, message="Not authenticated").model_dump(),
)
payload = decode_token(access_token)
if isinstance(payload, TokenError):
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=token_error_to_code(payload), message=f"Token error: {payload.value}").model_dump(),
)
provider = get_local_provider()
user = await provider.get_user(payload.sub)
if user is None:
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=AuthErrorCode.USER_NOT_FOUND, message="User not found").model_dump(),
)
# Token version mismatch → password was changed, token is stale
if user.token_version != payload.ver:
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=AuthErrorCode.TOKEN_INVALID, message="Token revoked (password changed)").model_dump(),
)
return user
async def get_optional_user_from_request(request: Request):
"""Get optional authenticated user from request.
Returns None if not authenticated.
"""
try:
return await get_current_user_from_request(request)
except HTTPException:
return None
async def get_current_user(request: Request) -> str | None:
"""Extract user_id from request cookie, or None if not authenticated.
Thin adapter that returns the string id for callers that only need
identification (e.g., ``feedback.py``). Full-user callers should use
``get_current_user_from_request`` or ``get_optional_user_from_request``.
"""
user = await get_optional_user_from_request(request)
return str(user.id) if user else None
-106
View File
@@ -1,106 +0,0 @@
"""LangGraph Server auth handler — shares JWT logic with Gateway.
Loaded by LangGraph Server via langgraph.json ``auth.path``.
Reuses the same ``decode_token`` / ``get_auth_config`` as Gateway,
so both modes validate tokens with the same secret and rules.
Two layers:
1. @auth.authenticate — validates JWT cookie, extracts user_id,
and enforces CSRF on state-changing methods (POST/PUT/DELETE/PATCH)
2. @auth.on — returns metadata filter so each user only sees own threads
"""
import secrets
from langgraph_sdk import Auth
from app.gateway.auth.errors import TokenError
from app.gateway.auth.jwt import decode_token
from app.gateway.deps import get_local_provider
auth = Auth()
# Methods that require CSRF validation (state-changing per RFC 7231).
_CSRF_METHODS = frozenset({"POST", "PUT", "DELETE", "PATCH"})
def _check_csrf(request) -> None:
"""Enforce Double Submit Cookie CSRF check for state-changing requests.
Mirrors Gateway's CSRFMiddleware logic so that LangGraph routes
proxied directly by nginx have the same CSRF protection.
"""
method = getattr(request, "method", "") or ""
if method.upper() not in _CSRF_METHODS:
return
cookie_token = request.cookies.get("csrf_token")
header_token = request.headers.get("x-csrf-token")
if not cookie_token or not header_token:
raise Auth.exceptions.HTTPException(
status_code=403,
detail="CSRF token missing. Include X-CSRF-Token header.",
)
if not secrets.compare_digest(cookie_token, header_token):
raise Auth.exceptions.HTTPException(
status_code=403,
detail="CSRF token mismatch.",
)
@auth.authenticate
async def authenticate(request):
"""Validate the session cookie, decode JWT, and check token_version.
Same validation chain as Gateway's get_current_user_from_request:
cookie → decode JWT → DB lookup → token_version match
Also enforces CSRF on state-changing methods.
"""
# CSRF check before authentication so forged cross-site requests
# are rejected early, even if the cookie carries a valid JWT.
_check_csrf(request)
token = request.cookies.get("access_token")
if not token:
raise Auth.exceptions.HTTPException(
status_code=401,
detail="Not authenticated",
)
payload = decode_token(token)
if isinstance(payload, TokenError):
raise Auth.exceptions.HTTPException(
status_code=401,
detail=f"Token error: {payload.value}",
)
user = await get_local_provider().get_user(payload.sub)
if user is None:
raise Auth.exceptions.HTTPException(
status_code=401,
detail="User not found",
)
if user.token_version != payload.ver:
raise Auth.exceptions.HTTPException(
status_code=401,
detail="Token revoked (password changed)",
)
return payload.sub
@auth.on
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
"""Inject user_id metadata on writes; filter by user_id on reads.
Gateway stores thread ownership as ``metadata.user_id``.
This handler ensures LangGraph Server enforces the same isolation.
"""
# On create/update: stamp user_id into metadata
metadata = value.setdefault("metadata", {})
metadata["user_id"] = ctx.user.identity
# Return filter dict — LangGraph applies it to search/read/delete
return {"user_id": ctx.user.identity}
+1 -2
View File
@@ -5,7 +5,6 @@ from pathlib import Path
from fastapi import HTTPException from fastapi import HTTPException
from deerflow.config.paths import get_paths from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path: def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
@@ -23,7 +22,7 @@ def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
HTTPException: If the path is invalid or outside allowed directories. HTTPException: If the path is invalid or outside allowed directories.
""" """
try: try:
return get_paths().resolve_virtual_path(thread_id, virtual_path, user_id=get_effective_user_id()) return get_paths().resolve_virtual_path(thread_id, virtual_path)
except ValueError as e: except ValueError as e:
status = 403 if "traversal" in str(e) else 400 status = 403 if "traversal" in str(e) else 400
raise HTTPException(status_code=status, detail=str(e)) raise HTTPException(status_code=status, detail=str(e))
-2
View File
@@ -7,7 +7,6 @@ from urllib.parse import quote
from fastapi import APIRouter, HTTPException, Request from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import FileResponse, PlainTextResponse, Response from fastapi.responses import FileResponse, PlainTextResponse, Response
from app.gateway.authz import require_permission
from app.gateway.path_utils import resolve_thread_virtual_path from app.gateway.path_utils import resolve_thread_virtual_path
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -82,7 +81,6 @@ def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> byte
summary="Get Artifact File", summary="Get Artifact File",
description="Retrieve an artifact file generated by the AI agent. Text and binary files can be viewed inline, while active web content is always downloaded.", description="Retrieve an artifact file generated by the AI agent. Text and binary files can be viewed inline, while active web content is always downloaded.",
) )
@require_permission("threads", "read", owner_check=True)
async def get_artifact(thread_id: str, path: str, request: Request, download: bool = False) -> Response: async def get_artifact(thread_id: str, path: str, request: Request, download: bool = False) -> Response:
"""Get an artifact file by its path. """Get an artifact file by its path.
-459
View File
@@ -1,459 +0,0 @@
"""Authentication endpoints."""
import logging
import os
import time
from ipaddress import ip_address, ip_network
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel, EmailStr, Field, field_validator
from app.gateway.auth import (
UserResponse,
create_access_token,
)
from app.gateway.auth.config import get_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
from app.gateway.csrf_middleware import is_secure_request
from app.gateway.deps import get_current_user_from_request, get_local_provider
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
# ── Request/Response Models ──────────────────────────────────────────────
class LoginResponse(BaseModel):
"""Response model for login — token only lives in HttpOnly cookie."""
expires_in: int # seconds
needs_setup: bool = False
# Top common-password blocklist. Drawn from the public SecLists "10k worst
# passwords" set, lowercased + length>=8 only (shorter ones already fail
# the min_length check). Kept tight on purpose: this is the **lower bound**
# defense, not a full HIBP / passlib check, and runs in-process per request.
_COMMON_PASSWORDS: frozenset[str] = frozenset(
{
"password",
"password1",
"password12",
"password123",
"password1234",
"12345678",
"123456789",
"1234567890",
"qwerty12",
"qwertyui",
"qwerty123",
"abc12345",
"abcd1234",
"iloveyou",
"letmein1",
"welcome1",
"welcome123",
"admin123",
"administrator",
"passw0rd",
"p@ssw0rd",
"monkey12",
"trustno1",
"sunshine",
"princess",
"football",
"baseball",
"superman",
"batman123",
"starwars",
"dragon123",
"master123",
"shadow12",
"michael1",
"jennifer",
"computer",
}
)
def _password_is_common(password: str) -> bool:
"""Case-insensitive blocklist check.
Lowercases the input so trivial mutations like ``Password`` /
``PASSWORD`` are also rejected. Does not normalize digit substitutions
(``p@ssw0rd`` is included as a literal entry instead) — keeping the
rule cheap and predictable.
"""
return password.lower() in _COMMON_PASSWORDS
def _validate_strong_password(value: str) -> str:
"""Pydantic field-validator body shared by Register + ChangePassword.
Constraint = function, not type-level mixin. The two request models
have no "is-a" relationship; they only share the password-strength
rule. Lifting it into a free function lets each model bind it via
``@field_validator(field_name)`` without inheritance gymnastics.
"""
if _password_is_common(value):
raise ValueError("Password is too common; choose a stronger password.")
return value
class RegisterRequest(BaseModel):
"""Request model for user registration."""
email: EmailStr
password: str = Field(..., min_length=8)
_strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v)))
class ChangePasswordRequest(BaseModel):
"""Request model for password change (also handles setup flow)."""
current_password: str
new_password: str = Field(..., min_length=8)
new_email: EmailStr | None = None
_strong_password = field_validator("new_password")(classmethod(lambda cls, v: _validate_strong_password(v)))
class MessageResponse(BaseModel):
"""Generic message response."""
message: str
# ── Helpers ───────────────────────────────────────────────────────────────
def _set_session_cookie(response: Response, token: str, request: Request) -> None:
"""Set the access_token HttpOnly cookie on the response."""
config = get_auth_config()
is_https = is_secure_request(request)
response.set_cookie(
key="access_token",
value=token,
httponly=True,
secure=is_https,
samesite="lax",
max_age=config.token_expiry_days * 24 * 3600 if is_https else None,
)
# ── Rate Limiting ────────────────────────────────────────────────────────
# In-process dict — not shared across workers. Sufficient for single-worker deployments.
_MAX_LOGIN_ATTEMPTS = 5
_LOCKOUT_SECONDS = 300 # 5 minutes
# ip → (fail_count, lock_until_timestamp)
_login_attempts: dict[str, tuple[int, float]] = {}
def _trusted_proxies() -> list:
"""Parse ``AUTH_TRUSTED_PROXIES`` env var into a list of ip_network objects.
Comma-separated CIDR or single-IP entries. Empty / unset = no proxy is
trusted (direct mode). Invalid entries are skipped with a logger warning.
Read live so env-var overrides take effect immediately and tests can
``monkeypatch.setenv`` without poking a module-level cache.
"""
raw = os.getenv("AUTH_TRUSTED_PROXIES", "").strip()
if not raw:
return []
nets = []
for entry in raw.split(","):
entry = entry.strip()
if not entry:
continue
try:
nets.append(ip_network(entry, strict=False))
except ValueError:
logger.warning("AUTH_TRUSTED_PROXIES: ignoring invalid entry %r", entry)
return nets
def _get_client_ip(request: Request) -> str:
"""Extract the real client IP for rate limiting.
Trust model:
- The TCP peer (``request.client.host``) is always the baseline. It is
whatever the kernel reports as the connecting socket — unforgeable
by the client itself.
- ``X-Real-IP`` is **only** honored if the TCP peer is in the
``AUTH_TRUSTED_PROXIES`` allowlist (set via env var, comma-separated
CIDR or single IPs). When set, the gateway is assumed to be behind a
reverse proxy (nginx, Cloudflare, ALB, …) that overwrites
``X-Real-IP`` with the original client address.
- With no ``AUTH_TRUSTED_PROXIES`` set, ``X-Real-IP`` is silently
ignored — closing the bypass where any client could rotate the
header to dodge per-IP rate limits in dev / direct-gateway mode.
``X-Forwarded-For`` is intentionally NOT used because it is naturally
client-controlled at the *first* hop and the trust chain is harder to
audit per-request.
"""
peer_host = request.client.host if request.client else None
trusted = _trusted_proxies()
if trusted and peer_host:
try:
peer_ip = ip_address(peer_host)
if any(peer_ip in net for net in trusted):
real_ip = request.headers.get("x-real-ip", "").strip()
if real_ip:
return real_ip
except ValueError:
# peer_host wasn't a parseable IP (e.g. "unknown") — fall through
pass
return peer_host or "unknown"
def _check_rate_limit(ip: str) -> None:
"""Raise 429 if the IP is currently locked out."""
record = _login_attempts.get(ip)
if record is None:
return
fail_count, lock_until = record
if fail_count >= _MAX_LOGIN_ATTEMPTS:
if time.time() < lock_until:
raise HTTPException(
status_code=429,
detail="Too many login attempts. Try again later.",
)
del _login_attempts[ip]
_MAX_TRACKED_IPS = 10000
def _record_login_failure(ip: str) -> None:
"""Record a failed login attempt for the given IP."""
# Evict expired lockouts when dict grows too large
if len(_login_attempts) >= _MAX_TRACKED_IPS:
now = time.time()
expired = [k for k, (c, t) in _login_attempts.items() if c >= _MAX_LOGIN_ATTEMPTS and now >= t]
for k in expired:
del _login_attempts[k]
# If still too large, evict cheapest-to-lose half: below-threshold
# IPs (lock_until=0.0) sort first, then earliest-expiring lockouts.
if len(_login_attempts) >= _MAX_TRACKED_IPS:
by_time = sorted(_login_attempts.items(), key=lambda kv: kv[1][1])
for k, _ in by_time[: len(by_time) // 2]:
del _login_attempts[k]
record = _login_attempts.get(ip)
if record is None:
_login_attempts[ip] = (1, 0.0)
else:
new_count = record[0] + 1
lock_until = time.time() + _LOCKOUT_SECONDS if new_count >= _MAX_LOGIN_ATTEMPTS else 0.0
_login_attempts[ip] = (new_count, lock_until)
def _record_login_success(ip: str) -> None:
"""Clear failure counter for the given IP on successful login."""
_login_attempts.pop(ip, None)
# ── Endpoints ─────────────────────────────────────────────────────────────
@router.post("/login/local", response_model=LoginResponse)
async def login_local(
request: Request,
response: Response,
form_data: OAuth2PasswordRequestForm = Depends(),
):
"""Local email/password login."""
client_ip = _get_client_ip(request)
_check_rate_limit(client_ip)
user = await get_local_provider().authenticate({"email": form_data.username, "password": form_data.password})
if user is None:
_record_login_failure(client_ip)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Incorrect email or password").model_dump(),
)
_record_login_success(client_ip)
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return LoginResponse(
expires_in=get_auth_config().token_expiry_days * 24 * 3600,
needs_setup=user.needs_setup,
)
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(request: Request, response: Response, body: RegisterRequest):
"""Register a new user account (always 'user' role).
Admin is auto-created on first boot. This endpoint creates regular users.
Auto-login by setting the session cookie.
"""
try:
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="user")
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already registered").model_dump(),
)
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
@router.post("/logout", response_model=MessageResponse)
async def logout(request: Request, response: Response):
"""Logout current user by clearing the cookie."""
response.delete_cookie(key="access_token", secure=is_secure_request(request), samesite="lax")
return MessageResponse(message="Successfully logged out")
@router.post("/change-password", response_model=MessageResponse)
async def change_password(request: Request, response: Response, body: ChangePasswordRequest):
"""Change password for the currently authenticated user.
Also handles the first-boot setup flow:
- If new_email is provided, updates email (checks uniqueness)
- If user.needs_setup is True and new_email is given, clears needs_setup
- Always increments token_version to invalidate old sessions
- Re-issues session cookie with new token_version
"""
from app.gateway.auth.password import hash_password_async, verify_password_async
user = await get_current_user_from_request(request)
if user.password_hash is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
if not await verify_password_async(body.current_password, user.password_hash):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Current password is incorrect").model_dump())
provider = get_local_provider()
# Update email if provided
if body.new_email is not None:
existing = await provider.get_user_by_email(body.new_email)
if existing and str(existing.id) != str(user.id):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already in use").model_dump())
user.email = body.new_email
# Update password + bump version
user.password_hash = await hash_password_async(body.new_password)
user.token_version += 1
# Clear setup flag if this is the setup flow
if user.needs_setup and body.new_email is not None:
user.needs_setup = False
await provider.update_user(user)
# Re-issue cookie with new token_version
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return MessageResponse(message="Password changed successfully")
@router.get("/me", response_model=UserResponse)
async def get_me(request: Request):
"""Get current authenticated user info."""
user = await get_current_user_from_request(request)
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
@router.get("/setup-status")
async def setup_status():
"""Check if an admin account exists. Returns needs_setup=True when no admin exists."""
admin_count = await get_local_provider().count_admin_users()
return {"needs_setup": admin_count == 0}
class InitializeAdminRequest(BaseModel):
"""Request model for first-boot admin account creation."""
email: EmailStr
password: str = Field(..., min_length=8)
_strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v)))
@router.post("/initialize", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def initialize_admin(request: Request, response: Response, body: InitializeAdminRequest):
"""Create the first admin account on initial system setup.
Only callable when no admin exists. Returns 409 Conflict if an admin
already exists.
On success, the admin account is created with ``needs_setup=False`` and
the session cookie is set.
"""
admin_count = await get_local_provider().count_admin_users()
if admin_count > 0:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=AuthErrorResponse(code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED, message="System already initialized").model_dump(),
)
try:
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="admin", needs_setup=False)
except ValueError:
# DB unique-constraint race: another concurrent request beat us.
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=AuthErrorResponse(code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED, message="System already initialized").model_dump(),
)
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
# ── OAuth Endpoints (Future/Placeholder) ─────────────────────────────────
@router.get("/oauth/{provider}")
async def oauth_login(provider: str):
"""Initiate OAuth login flow.
Redirects to the OAuth provider's authorization URL.
Currently a placeholder - requires OAuth provider implementation.
"""
if provider not in ["github", "google"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported OAuth provider: {provider}",
)
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="OAuth login not yet implemented",
)
@router.get("/callback/{provider}")
async def oauth_callback(provider: str, code: str, state: str):
"""OAuth callback endpoint.
Handles the OAuth provider's callback after user authorization.
Currently a placeholder.
"""
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="OAuth callback not yet implemented",
)
-188
View File
@@ -1,188 +0,0 @@
"""Feedback endpoints — create, list, stats, delete.
Allows users to submit thumbs-up/down feedback on runs,
optionally scoped to a specific message.
"""
from __future__ import annotations
import logging
from typing import Any
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from app.gateway.deps import get_current_user, get_feedback_repo, get_run_store
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["feedback"])
# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------
class FeedbackCreateRequest(BaseModel):
rating: int = Field(..., description="Feedback rating: +1 (positive) or -1 (negative)")
comment: str | None = Field(default=None, description="Optional text feedback")
message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message")
class FeedbackUpsertRequest(BaseModel):
rating: int = Field(..., description="Feedback rating: +1 (positive) or -1 (negative)")
comment: str | None = Field(default=None, description="Optional text feedback")
class FeedbackResponse(BaseModel):
feedback_id: str
run_id: str
thread_id: str
user_id: str | None = None
message_id: str | None = None
rating: int
comment: str | None = None
created_at: str = ""
class FeedbackStatsResponse(BaseModel):
run_id: str
total: int = 0
positive: int = 0
negative: int = 0
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.put("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def upsert_feedback(
thread_id: str,
run_id: str,
body: FeedbackUpsertRequest,
request: Request,
) -> dict[str, Any]:
"""Create or update feedback for a run (idempotent)."""
if body.rating not in (1, -1):
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
user_id = await get_current_user(request)
run_store = get_run_store(request)
run = await run_store.get(run_id)
if run is None:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
if run.get("thread_id") != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
feedback_repo = get_feedback_repo(request)
return await feedback_repo.upsert(
run_id=run_id,
thread_id=thread_id,
rating=body.rating,
user_id=user_id,
comment=body.comment,
)
@router.delete("/{thread_id}/runs/{run_id}/feedback")
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_run_feedback(
thread_id: str,
run_id: str,
request: Request,
) -> dict[str, bool]:
"""Delete the current user's feedback for a run."""
user_id = await get_current_user(request)
feedback_repo = get_feedback_repo(request)
deleted = await feedback_repo.delete_by_run(
thread_id=thread_id,
run_id=run_id,
user_id=user_id,
)
if not deleted:
raise HTTPException(status_code=404, detail="No feedback found for this run")
return {"success": True}
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def create_feedback(
thread_id: str,
run_id: str,
body: FeedbackCreateRequest,
request: Request,
) -> dict[str, Any]:
"""Submit feedback (thumbs-up/down) for a run."""
if body.rating not in (1, -1):
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
user_id = await get_current_user(request)
# Validate run exists and belongs to thread
run_store = get_run_store(request)
run = await run_store.get(run_id)
if run is None:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
if run.get("thread_id") != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
feedback_repo = get_feedback_repo(request)
return await feedback_repo.create(
run_id=run_id,
thread_id=thread_id,
rating=body.rating,
user_id=user_id,
message_id=body.message_id,
comment=body.comment,
)
@router.get("/{thread_id}/runs/{run_id}/feedback", response_model=list[FeedbackResponse])
@require_permission("threads", "read", owner_check=True)
async def list_feedback(
thread_id: str,
run_id: str,
request: Request,
) -> list[dict[str, Any]]:
"""List all feedback for a run."""
feedback_repo = get_feedback_repo(request)
return await feedback_repo.list_by_run(thread_id, run_id)
@router.get("/{thread_id}/runs/{run_id}/feedback/stats", response_model=FeedbackStatsResponse)
@require_permission("threads", "read", owner_check=True)
async def feedback_stats(
thread_id: str,
run_id: str,
request: Request,
) -> dict[str, Any]:
"""Get aggregated feedback stats (positive/negative counts) for a run."""
feedback_repo = get_feedback_repo(request)
return await feedback_repo.aggregate_by_run(thread_id, run_id)
@router.delete("/{thread_id}/runs/{run_id}/feedback/{feedback_id}")
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_feedback(
thread_id: str,
run_id: str,
feedback_id: str,
request: Request,
) -> dict[str, bool]:
"""Delete a feedback record."""
feedback_repo = get_feedback_repo(request)
# Verify feedback belongs to the specified thread/run before deleting
existing = await feedback_repo.get(feedback_id)
if existing is None:
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
if existing.get("thread_id") != thread_id or existing.get("run_id") != run_id:
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found in run {run_id}")
deleted = await feedback_repo.delete(feedback_id)
if not deleted:
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
return {"success": True}
+9 -7
View File
@@ -6,7 +6,8 @@ from typing import Literal
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from deerflow.config.extensions_config import ExtensionsConfig, get_extensions_config, reload_extensions_config from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["mcp"]) router = APIRouter(prefix="/api", tags=["mcp"])
@@ -90,9 +91,9 @@ async def get_mcp_configuration() -> McpConfigResponse:
} }
``` ```
""" """
config = get_extensions_config() ext = AppConfig.current().extensions
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in config.mcp_servers.items()}) return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in ext.mcp_servers.items()})
@router.put( @router.put(
@@ -143,12 +144,12 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
logger.info(f"No existing extensions config found. Creating new config at: {config_path}") logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
# Load current config to preserve skills configuration # Load current config to preserve skills configuration
current_config = get_extensions_config() current_ext = AppConfig.current().extensions
# Convert request to dict format for JSON serialization # Convert request to dict format for JSON serialization
config_data = { config_data = {
"mcpServers": {name: server.model_dump() for name, server in request.mcp_servers.items()}, "mcpServers": {name: server.model_dump() for name, server in request.mcp_servers.items()},
"skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()}, "skills": {name: {"enabled": skill.enabled} for name, skill in current_ext.skills.items()},
} }
# Write the configuration to file # Write the configuration to file
@@ -161,8 +162,9 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
# will detect config file changes via mtime and reinitialize MCP tools automatically # will detect config file changes via mtime and reinitialize MCP tools automatically
# Reload the configuration and update the global cache # Reload the configuration and update the global cache
reloaded_config = reload_extensions_config() AppConfig.init(AppConfig.from_file())
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded_config.mcp_servers.items()}) reloaded_ext = AppConfig.current().extensions
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded_ext.mcp_servers.items()})
except Exception as e: except Exception as e:
logger.error(f"Failed to update MCP configuration: {e}", exc_info=True) logger.error(f"Failed to update MCP configuration: {e}", exc_info=True)
+10 -13
View File
@@ -12,8 +12,7 @@ from deerflow.agents.memory.updater import (
reload_memory_data, reload_memory_data,
update_memory_fact, update_memory_fact,
) )
from deerflow.config.memory_config import get_memory_config from deerflow.config.app_config import AppConfig
from deerflow.runtime.user_context import get_effective_user_id
router = APIRouter(prefix="/api", tags=["memory"]) router = APIRouter(prefix="/api", tags=["memory"])
@@ -148,7 +147,7 @@ async def get_memory() -> MemoryResponse:
} }
``` ```
""" """
memory_data = get_memory_data(user_id=get_effective_user_id()) memory_data = get_memory_data()
return MemoryResponse(**memory_data) return MemoryResponse(**memory_data)
@@ -168,7 +167,7 @@ async def reload_memory() -> MemoryResponse:
Returns: Returns:
The reloaded memory data. The reloaded memory data.
""" """
memory_data = reload_memory_data(user_id=get_effective_user_id()) memory_data = reload_memory_data()
return MemoryResponse(**memory_data) return MemoryResponse(**memory_data)
@@ -182,7 +181,7 @@ async def reload_memory() -> MemoryResponse:
async def clear_memory() -> MemoryResponse: async def clear_memory() -> MemoryResponse:
"""Clear all persisted memory data.""" """Clear all persisted memory data."""
try: try:
memory_data = clear_memory_data(user_id=get_effective_user_id()) memory_data = clear_memory_data()
except OSError as exc: except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
@@ -203,7 +202,6 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
content=request.content, content=request.content,
category=request.category, category=request.category,
confidence=request.confidence, confidence=request.confidence,
user_id=get_effective_user_id(),
) )
except ValueError as exc: except ValueError as exc:
raise _map_memory_fact_value_error(exc) from exc raise _map_memory_fact_value_error(exc) from exc
@@ -223,7 +221,7 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse: async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
"""Delete a single fact from memory by fact id.""" """Delete a single fact from memory by fact id."""
try: try:
memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id()) memory_data = delete_memory_fact(fact_id)
except KeyError as exc: except KeyError as exc:
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
except OSError as exc: except OSError as exc:
@@ -247,7 +245,6 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
content=request.content, content=request.content,
category=request.category, category=request.category,
confidence=request.confidence, confidence=request.confidence,
user_id=get_effective_user_id(),
) )
except ValueError as exc: except ValueError as exc:
raise _map_memory_fact_value_error(exc) from exc raise _map_memory_fact_value_error(exc) from exc
@@ -268,7 +265,7 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
) )
async def export_memory() -> MemoryResponse: async def export_memory() -> MemoryResponse:
"""Export the current memory data.""" """Export the current memory data."""
memory_data = get_memory_data(user_id=get_effective_user_id()) memory_data = get_memory_data()
return MemoryResponse(**memory_data) return MemoryResponse(**memory_data)
@@ -282,7 +279,7 @@ async def export_memory() -> MemoryResponse:
async def import_memory(request: MemoryResponse) -> MemoryResponse: async def import_memory(request: MemoryResponse) -> MemoryResponse:
"""Import and persist memory data.""" """Import and persist memory data."""
try: try:
memory_data = import_memory_data(request.model_dump(), user_id=get_effective_user_id()) memory_data = import_memory_data(request.model_dump())
except OSError as exc: except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
@@ -314,7 +311,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
} }
``` ```
""" """
config = get_memory_config() config = AppConfig.current().memory
return MemoryConfigResponse( return MemoryConfigResponse(
enabled=config.enabled, enabled=config.enabled,
storage_path=config.storage_path, storage_path=config.storage_path,
@@ -339,8 +336,8 @@ async def get_memory_status() -> MemoryStatusResponse:
Returns: Returns:
Combined memory configuration and current data. Combined memory configuration and current data.
""" """
config = get_memory_config() config = AppConfig.current().memory
memory_data = get_memory_data(user_id=get_effective_user_id()) memory_data = get_memory_data()
return MemoryStatusResponse( return MemoryStatusResponse(
config=MemoryConfigResponse( config=MemoryConfigResponse(
+3 -3
View File
@@ -1,7 +1,7 @@
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
router = APIRouter(prefix="/api", tags=["models"]) router = APIRouter(prefix="/api", tags=["models"])
@@ -58,7 +58,7 @@ async def list_models() -> ModelsListResponse:
} }
``` ```
""" """
config = get_app_config() config = AppConfig.current()
models = [ models = [
ModelResponse( ModelResponse(
name=model.name, name=model.name,
@@ -101,7 +101,7 @@ async def get_model(model_name: str) -> ModelResponse:
} }
``` ```
""" """
config = get_app_config() config = AppConfig.current()
model = config.get_model_config(model_name) model = config.get_model_config(model_name)
if model is None: if model is None:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
+2 -58
View File
@@ -11,11 +11,10 @@ import asyncio
import logging import logging
import uuid import uuid
from fastapi import APIRouter, HTTPException, Query, Request from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from app.gateway.authz import require_permission from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.routers.thread_runs import RunCreateRequest from app.gateway.routers.thread_runs import RunCreateRequest
from app.gateway.services import sse_consumer, start_run from app.gateway.services import sse_consumer, start_run
from deerflow.runtime import serialize_channel_values from deerflow.runtime import serialize_channel_values
@@ -86,58 +85,3 @@ async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
logger.exception("Failed to fetch final state for run %s", record.run_id) logger.exception("Failed to fetch final state for run %s", record.run_id)
return {"status": record.status.value, "error": record.error} return {"status": record.status.value, "error": record.error}
# ---------------------------------------------------------------------------
# Run-scoped read endpoints
# ---------------------------------------------------------------------------
async def _resolve_run(run_id: str, request: Request) -> dict:
"""Fetch run by run_id with user ownership check. Raises 404 if not found."""
run_store = get_run_store(request)
record = await run_store.get(run_id) # user_id=AUTO filters by contextvar
if record is None:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
return record
@router.get("/{run_id}/messages")
@require_permission("runs", "read")
async def run_messages(
run_id: str,
request: Request,
limit: int = Query(default=50, le=200, ge=1),
before_seq: int | None = Query(default=None),
after_seq: int | None = Query(default=None),
) -> dict:
"""Return paginated messages for a run (cursor-based).
Pagination:
- after_seq: messages with seq > after_seq (forward)
- before_seq: messages with seq < before_seq (backward)
- neither: latest messages
Response: { data: [...], has_more: bool }
"""
run = await _resolve_run(run_id, request)
event_store = get_run_event_store(request)
rows = await event_store.list_messages_by_run(
run["thread_id"],
run_id,
limit=limit + 1,
before_seq=before_seq,
after_seq=after_seq,
)
has_more = len(rows) > limit
data = rows[:limit] if has_more else rows
return {"data": data, "has_more": has_more}
@router.get("/{run_id}/feedback")
@require_permission("runs", "read")
async def run_feedback(run_id: str, request: Request) -> list[dict]:
"""Return all feedback for a run."""
run = await _resolve_run(run_id, request)
feedback_repo = get_feedback_repo(request)
return await feedback_repo.list_by_run(run["thread_id"], run_id)
+7 -6
View File
@@ -8,7 +8,8 @@ from pydantic import BaseModel, Field
from app.gateway.path_utils import resolve_thread_virtual_path from app.gateway.path_utils import resolve_thread_virtual_path
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig
from deerflow.skills import Skill, load_skills from deerflow.skills import Skill, load_skills
from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
from deerflow.skills.manager import ( from deerflow.skills.manager import (
@@ -325,19 +326,19 @@ async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillRes
config_path = Path.cwd().parent / "extensions_config.json" config_path = Path.cwd().parent / "extensions_config.json"
logger.info(f"No existing extensions config found. Creating new config at: {config_path}") logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
extensions_config = get_extensions_config() ext = AppConfig.current().extensions
extensions_config.skills[skill_name] = SkillStateConfig(enabled=request.enabled) ext.skills[skill_name] = SkillStateConfig(enabled=request.enabled)
config_data = { config_data = {
"mcpServers": {name: server.model_dump() for name, server in extensions_config.mcp_servers.items()}, "mcpServers": {name: server.model_dump() for name, server in ext.mcp_servers.items()},
"skills": {name: {"enabled": skill_config.enabled} for name, skill_config in extensions_config.skills.items()}, "skills": {name: {"enabled": skill_config.enabled} for name, skill_config in ext.skills.items()},
} }
with open(config_path, "w", encoding="utf-8") as f: with open(config_path, "w", encoding="utf-8") as f:
json.dump(config_data, f, indent=2) json.dump(config_data, f, indent=2)
logger.info(f"Skills configuration updated and saved to: {config_path}") logger.info(f"Skills configuration updated and saved to: {config_path}")
reload_extensions_config() AppConfig.init(AppConfig.from_file())
await refresh_skills_system_prompt_cache_async() await refresh_skills_system_prompt_cache_async()
skills = load_skills(enabled_only=False) skills = load_skills(enabled_only=False)
+6 -8
View File
@@ -1,11 +1,10 @@
import json import json
import logging import logging
from fastapi import APIRouter, Request from fastapi import APIRouter
from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -99,13 +98,12 @@ def _format_conversation(messages: list[SuggestionMessage]) -> str:
summary="Generate Follow-up Questions", summary="Generate Follow-up Questions",
description="Generate short follow-up questions a user might ask next, based on recent conversation context.", description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
) )
@require_permission("threads", "read", owner_check=True) async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> SuggestionsResponse:
async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request) -> SuggestionsResponse: if not request.messages:
if not body.messages:
return SuggestionsResponse(suggestions=[]) return SuggestionsResponse(suggestions=[])
n = body.n n = request.n
conversation = _format_conversation(body.messages) conversation = _format_conversation(request.messages)
if not conversation: if not conversation:
return SuggestionsResponse(suggestions=[]) return SuggestionsResponse(suggestions=[])
@@ -122,7 +120,7 @@ async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions" user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
try: try:
model = create_chat_model(name=body.model_name, thinking_enabled=False) model = create_chat_model(name=request.model_name, thinking_enabled=False)
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)]) response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
raw = _extract_response_text(response.content) raw = _extract_response_text(response.content)
suggestions = _parse_json_string_list(raw) or [] suggestions = _parse_json_string_list(raw) or []
+1 -111
View File
@@ -19,8 +19,7 @@ from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.gateway.authz import require_permission from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.services import sse_consumer, start_run from app.gateway.services import sse_consumer, start_run
from deerflow.runtime import RunRecord, serialize_channel_values from deerflow.runtime import RunRecord, serialize_channel_values
@@ -93,7 +92,6 @@ def _record_to_response(record: RunRecord) -> RunResponse:
@router.post("/{thread_id}/runs", response_model=RunResponse) @router.post("/{thread_id}/runs", response_model=RunResponse)
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse: async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse:
"""Create a background run (returns immediately).""" """Create a background run (returns immediately)."""
record = await start_run(body, thread_id, request) record = await start_run(body, thread_id, request)
@@ -101,7 +99,6 @@ async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -
@router.post("/{thread_id}/runs/stream") @router.post("/{thread_id}/runs/stream")
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse: async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
"""Create a run and stream events via SSE. """Create a run and stream events via SSE.
@@ -129,7 +126,6 @@ async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -
@router.post("/{thread_id}/runs/wait", response_model=dict) @router.post("/{thread_id}/runs/wait", response_model=dict)
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict: async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
"""Create a run and block until it completes, returning the final state.""" """Create a run and block until it completes, returning the final state."""
record = await start_run(body, thread_id, request) record = await start_run(body, thread_id, request)
@@ -155,7 +151,6 @@ async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) ->
@router.get("/{thread_id}/runs", response_model=list[RunResponse]) @router.get("/{thread_id}/runs", response_model=list[RunResponse])
@require_permission("runs", "read", owner_check=True)
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]: async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
"""List all runs for a thread.""" """List all runs for a thread."""
run_mgr = get_run_manager(request) run_mgr = get_run_manager(request)
@@ -164,7 +159,6 @@ async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse) @router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
@require_permission("runs", "read", owner_check=True)
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse: async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
"""Get details of a specific run.""" """Get details of a specific run."""
run_mgr = get_run_manager(request) run_mgr = get_run_manager(request)
@@ -175,7 +169,6 @@ async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
@router.post("/{thread_id}/runs/{run_id}/cancel") @router.post("/{thread_id}/runs/{run_id}/cancel")
@require_permission("runs", "cancel", owner_check=True, require_existing=True)
async def cancel_run( async def cancel_run(
thread_id: str, thread_id: str,
run_id: str, run_id: str,
@@ -213,7 +206,6 @@ async def cancel_run(
@router.get("/{thread_id}/runs/{run_id}/join") @router.get("/{thread_id}/runs/{run_id}/join")
@require_permission("runs", "read", owner_check=True)
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse: async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
"""Join an existing run's SSE stream.""" """Join an existing run's SSE stream."""
bridge = get_stream_bridge(request) bridge = get_stream_bridge(request)
@@ -234,7 +226,6 @@ async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingRe
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None) @router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
@require_permission("runs", "read", owner_check=True)
async def stream_existing_run( async def stream_existing_run(
thread_id: str, thread_id: str,
run_id: str, run_id: str,
@@ -274,104 +265,3 @@ async def stream_existing_run(
"X-Accel-Buffering": "no", "X-Accel-Buffering": "no",
}, },
) )
# ---------------------------------------------------------------------------
# Messages / Events / Token usage endpoints
# ---------------------------------------------------------------------------
@router.get("/{thread_id}/messages")
@require_permission("runs", "read", owner_check=True)
async def list_thread_messages(
thread_id: str,
request: Request,
limit: int = Query(default=50, le=200),
before_seq: int | None = Query(default=None),
after_seq: int | None = Query(default=None),
) -> list[dict]:
"""Return displayable messages for a thread (across all runs), with feedback attached."""
event_store = get_run_event_store(request)
messages = await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq)
# Attach feedback to the last AI message of each run
feedback_repo = get_feedback_repo(request)
user_id = await get_current_user(request)
feedback_map = await feedback_repo.list_by_thread_grouped(thread_id, user_id=user_id)
# Find the last ai_message per run_id
last_ai_per_run: dict[str, int] = {} # run_id -> index in messages list
for i, msg in enumerate(messages):
if msg.get("event_type") == "ai_message":
last_ai_per_run[msg["run_id"]] = i
# Attach feedback field
last_ai_indices = set(last_ai_per_run.values())
for i, msg in enumerate(messages):
if i in last_ai_indices:
run_id = msg["run_id"]
fb = feedback_map.get(run_id)
msg["feedback"] = (
{
"feedback_id": fb["feedback_id"],
"rating": fb["rating"],
"comment": fb.get("comment"),
}
if fb
else None
)
else:
msg["feedback"] = None
return messages
@router.get("/{thread_id}/runs/{run_id}/messages")
@require_permission("runs", "read", owner_check=True)
async def list_run_messages(
thread_id: str,
run_id: str,
request: Request,
limit: int = Query(default=50, le=200, ge=1),
before_seq: int | None = Query(default=None),
after_seq: int | None = Query(default=None),
) -> dict:
"""Return paginated messages for a specific run.
Response: { data: [...], has_more: bool }
"""
event_store = get_run_event_store(request)
rows = await event_store.list_messages_by_run(
thread_id,
run_id,
limit=limit + 1,
before_seq=before_seq,
after_seq=after_seq,
)
has_more = len(rows) > limit
data = rows[:limit] if has_more else rows
return {"data": data, "has_more": has_more}
@router.get("/{thread_id}/runs/{run_id}/events")
@require_permission("runs", "read", owner_check=True)
async def list_run_events(
thread_id: str,
run_id: str,
request: Request,
event_types: str | None = Query(default=None),
limit: int = Query(default=500, le=2000),
) -> list[dict]:
"""Return the full event stream for a run (debug/audit)."""
event_store = get_run_event_store(request)
types = event_types.split(",") if event_types else None
return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)
@router.get("/{thread_id}/token-usage")
@require_permission("threads", "read", owner_check=True)
async def thread_token_usage(thread_id: str, request: Request) -> dict:
"""Thread-level token usage aggregation."""
run_store = get_run_store(request)
agg = await run_store.aggregate_tokens_by_thread(thread_id)
return {"thread_id": thread_id, **agg}
+242 -181
View File
@@ -13,41 +13,28 @@ matching the LangGraph Platform wire format expected by the
from __future__ import annotations from __future__ import annotations
import logging import logging
import re
import time import time
import uuid import uuid
from typing import Any from typing import Any
from fastapi import APIRouter, HTTPException, Request from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field
from app.gateway.authz import require_permission from app.gateway.deps import get_checkpointer, get_store
from app.gateway.deps import get_checkpointer
from app.gateway.utils import sanitize_log_param
from deerflow.config.paths import Paths, get_paths from deerflow.config.paths import Paths, get_paths
from deerflow.runtime import serialize_channel_values from deerflow.runtime import serialize_channel_values
from deerflow.runtime.user_context import get_effective_user_id
# ---------------------------------------------------------------------------
# Store namespace
# ---------------------------------------------------------------------------
THREADS_NS: tuple[str, ...] = ("threads",)
"""Namespace used by the Store for thread metadata records."""
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["threads"]) router = APIRouter(prefix="/api/threads", tags=["threads"])
# Metadata keys that the server controls; clients are not allowed to set
# them. Pydantic ``@field_validator("metadata")`` strips them on every
# inbound model below so a malicious client cannot reflect a forged
# owner identity through the API surface. Defense-in-depth — the
# row-level invariant is still ``threads_meta.user_id`` populated from
# the auth contextvar; this list closes the metadata-blob echo gap.
_SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"})
def _strip_reserved_metadata(metadata: dict[str, Any] | None) -> dict[str, Any]:
"""Return ``metadata`` with server-controlled keys removed."""
if not metadata:
return metadata or {}
return {k: v for k, v in metadata.items() if k not in _SERVER_RESERVED_METADATA_KEYS}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Response / request models # Response / request models
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -76,11 +63,8 @@ class ThreadCreateRequest(BaseModel):
"""Request body for creating a thread.""" """Request body for creating a thread."""
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)") thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
assistant_id: str | None = Field(default=None, description="Associate thread with an assistant")
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata") metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
_strip_reserved = field_validator("metadata")(classmethod(lambda cls, v: _strip_reserved_metadata(v)))
class ThreadSearchRequest(BaseModel): class ThreadSearchRequest(BaseModel):
"""Request body for searching threads.""" """Request body for searching threads."""
@@ -109,8 +93,6 @@ class ThreadPatchRequest(BaseModel):
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata to merge") metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata to merge")
_strip_reserved = field_validator("metadata")(classmethod(lambda cls, v: _strip_reserved_metadata(v)))
class ThreadStateUpdateRequest(BaseModel): class ThreadStateUpdateRequest(BaseModel):
"""Request body for updating thread state (human-in-the-loop resume).""" """Request body for updating thread state (human-in-the-loop resume)."""
@@ -144,25 +126,70 @@ class ThreadHistoryRequest(BaseModel):
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _delete_thread_data(thread_id: str, paths: Paths | None = None, *, user_id: str | None = None) -> ThreadDeleteResponse: def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDeleteResponse:
"""Delete local persisted filesystem data for a thread.""" """Delete local persisted filesystem data for a thread."""
path_manager = paths or get_paths() path_manager = paths or get_paths()
try: try:
path_manager.delete_thread_dir(thread_id, user_id=user_id) path_manager.delete_thread_dir(thread_id)
except ValueError as exc: except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc raise HTTPException(status_code=422, detail=str(exc)) from exc
except FileNotFoundError: except FileNotFoundError:
# Not critical — thread data may not exist on disk # Not critical — thread data may not exist on disk
logger.debug("No local thread data to delete for %s", sanitize_log_param(thread_id)) logger.debug("No local thread data to delete for %s", thread_id)
return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}") return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
except Exception as exc: except Exception as exc:
logger.exception("Failed to delete thread data for %s", sanitize_log_param(thread_id)) logger.exception("Failed to delete thread data for %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc
logger.info("Deleted local thread data for %s", sanitize_log_param(thread_id)) logger.info("Deleted local thread data for %s", thread_id)
return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}") return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
async def _store_get(store, thread_id: str) -> dict | None:
"""Fetch a thread record from the Store; returns ``None`` if absent."""
item = await store.aget(THREADS_NS, thread_id)
return item.value if item is not None else None
async def _store_put(store, record: dict) -> None:
"""Write a thread record to the Store."""
await store.aput(THREADS_NS, record["thread_id"], record)
async def _store_upsert(store, thread_id: str, *, metadata: dict | None = None, values: dict | None = None) -> None:
"""Create or refresh a thread record in the Store.
On creation the record is written with ``status="idle"``. On update only
``updated_at`` (and optionally ``metadata`` / ``values``) are changed so
that existing fields are preserved.
``values`` carries the agent-state snapshot exposed to the frontend
(currently just ``{"title": "..."}``).
"""
now = time.time()
existing = await _store_get(store, thread_id)
if existing is None:
await _store_put(
store,
{
"thread_id": thread_id,
"status": "idle",
"created_at": now,
"updated_at": now,
"metadata": metadata or {},
"values": values or {},
},
)
else:
val = dict(existing)
val["updated_at"] = now
if metadata:
val.setdefault("metadata", {}).update(metadata)
if values:
val.setdefault("values", {}).update(values)
await _store_put(store, val)
def _derive_thread_status(checkpoint_tuple) -> str: def _derive_thread_status(checkpoint_tuple) -> str:
"""Derive thread status from checkpoint metadata.""" """Derive thread status from checkpoint metadata."""
if checkpoint_tuple is None: if checkpoint_tuple is None:
@@ -188,18 +215,22 @@ def _derive_thread_status(checkpoint_tuple) -> str:
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse) @router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse: async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse:
"""Delete local persisted filesystem data for a thread. """Delete local persisted filesystem data for a thread.
Cleans DeerFlow-managed thread directories, removes checkpoint data, Cleans DeerFlow-managed thread directories, removes checkpoint data,
and removes the thread_meta row from the configured ThreadMetaStore and removes the thread record from the Store.
(sqlite or memory).
""" """
from app.gateway.deps import get_thread_store
# Clean local filesystem # Clean local filesystem
response = _delete_thread_data(thread_id, user_id=get_effective_user_id()) response = _delete_thread_data(thread_id)
# Remove from Store (best-effort)
store = get_store(request)
if store is not None:
try:
await store.adelete(THREADS_NS, thread_id)
except Exception:
logger.debug("Could not delete store record for thread %s (not critical)", thread_id)
# Remove checkpoints (best-effort) # Remove checkpoints (best-effort)
checkpointer = getattr(request.app.state, "checkpointer", None) checkpointer = getattr(request.app.state, "checkpointer", None)
@@ -208,15 +239,7 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
if hasattr(checkpointer, "adelete_thread"): if hasattr(checkpointer, "adelete_thread"):
await checkpointer.adelete_thread(thread_id) await checkpointer.adelete_thread(thread_id)
except Exception: except Exception:
logger.debug("Could not delete checkpoints for thread %s (not critical)", sanitize_log_param(thread_id)) logger.debug("Could not delete checkpoints for thread %s (not critical)", thread_id)
# Remove thread_meta row (best-effort) — required for sqlite backend
# so the deleted thread no longer appears in /threads/search.
try:
thread_store = get_thread_store(request)
await thread_store.delete(thread_id)
except Exception:
logger.debug("Could not delete thread_meta for %s (not critical)", sanitize_log_param(thread_id))
return response return response
@@ -225,40 +248,43 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse: async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse:
"""Create a new thread. """Create a new thread.
Writes a thread_meta record (so the thread appears in /threads/search) The thread record is written to the Store (for fast listing) and an
and an empty checkpoint (so state endpoints work immediately). empty checkpoint is written to the checkpointer (for state reads).
Idempotent: returns the existing record when ``thread_id`` already exists. Idempotent: returns the existing record when ``thread_id`` already exists.
""" """
from app.gateway.deps import get_thread_store store = get_store(request)
checkpointer = get_checkpointer(request) checkpointer = get_checkpointer(request)
thread_store = get_thread_store(request)
thread_id = body.thread_id or str(uuid.uuid4()) thread_id = body.thread_id or str(uuid.uuid4())
now = time.time() now = time.time()
# ``body.metadata`` is already stripped of server-reserved keys by
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
# Idempotency: return existing record when already present # Idempotency: return existing record from Store when already present
existing_record = await thread_store.get(thread_id) if store is not None:
if existing_record is not None: existing_record = await _store_get(store, thread_id)
return ThreadResponse( if existing_record is not None:
thread_id=thread_id, return ThreadResponse(
status=existing_record.get("status", "idle"), thread_id=thread_id,
created_at=str(existing_record.get("created_at", "")), status=existing_record.get("status", "idle"),
updated_at=str(existing_record.get("updated_at", "")), created_at=str(existing_record.get("created_at", "")),
metadata=existing_record.get("metadata", {}), updated_at=str(existing_record.get("updated_at", "")),
) metadata=existing_record.get("metadata", {}),
)
# Write thread_meta so the thread appears in /threads/search immediately # Write thread record to Store
try: if store is not None:
await thread_store.create( try:
thread_id, await _store_put(
assistant_id=getattr(body, "assistant_id", None), store,
metadata=body.metadata, {
) "thread_id": thread_id,
except Exception: "status": "idle",
logger.exception("Failed to write thread_meta for %s", sanitize_log_param(thread_id)) "created_at": now,
raise HTTPException(status_code=500, detail="Failed to create thread") "updated_at": now,
"metadata": body.metadata,
},
)
except Exception:
logger.exception("Failed to write thread %s to store", thread_id)
raise HTTPException(status_code=500, detail="Failed to create thread")
# Write an empty checkpoint so state endpoints work immediately # Write an empty checkpoint so state endpoints work immediately
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
@@ -275,10 +301,10 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
} }
await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {}) await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {})
except Exception: except Exception:
logger.exception("Failed to create checkpoint for thread %s", sanitize_log_param(thread_id)) logger.exception("Failed to create checkpoint for thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to create thread") raise HTTPException(status_code=500, detail="Failed to create thread")
logger.info("Thread created: %s", sanitize_log_param(thread_id)) logger.info("Thread created: %s", thread_id)
return ThreadResponse( return ThreadResponse(
thread_id=thread_id, thread_id=thread_id,
status="idle", status="idle",
@@ -292,91 +318,166 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]: async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]:
"""Search and list threads. """Search and list threads.
Delegates to the configured ThreadMetaStore implementation Two-phase approach:
(SQL-backed for sqlite/postgres, Store-backed for memory mode).
"""
from app.gateway.deps import get_thread_store
repo = get_thread_store(request) **Phase 1 — Store (fast path, O(threads))**: returns threads that were
rows = await repo.search( created or run through this Gateway. Store records are tiny metadata
metadata=body.metadata or None, dicts so fetching all of them at once is cheap.
status=body.status,
limit=body.limit, **Phase 2 — Checkpointer supplement (lazy migration)**: threads that
offset=body.offset, were created directly by LangGraph Server (and therefore absent from the
) Store) are discovered here by iterating the shared checkpointer. Any
return [ newly found thread is immediately written to the Store so that the next
ThreadResponse( search skips Phase 2 for that thread — the Store converges to a full
thread_id=r["thread_id"], index over time without a one-shot migration job.
status=r.get("status", "idle"), """
created_at=r.get("created_at", ""), store = get_store(request)
updated_at=r.get("updated_at", ""), checkpointer = get_checkpointer(request)
metadata=r.get("metadata", {}),
values={"title": r["display_name"]} if r.get("display_name") else {}, # -----------------------------------------------------------------------
interrupts={}, # Phase 1: Store
) # -----------------------------------------------------------------------
for r in rows merged: dict[str, ThreadResponse] = {}
]
if store is not None:
try:
items = await store.asearch(THREADS_NS, limit=10_000)
except Exception:
logger.warning("Store search failed — falling back to checkpointer only", exc_info=True)
items = []
for item in items:
val = item.value
merged[val["thread_id"]] = ThreadResponse(
thread_id=val["thread_id"],
status=val.get("status", "idle"),
created_at=str(val.get("created_at", "")),
updated_at=str(val.get("updated_at", "")),
metadata=val.get("metadata", {}),
values=val.get("values", {}),
)
# -----------------------------------------------------------------------
# Phase 2: Checkpointer supplement
# Discovers threads not yet in the Store (e.g. created by LangGraph
# Server) and lazily migrates them so future searches skip this phase.
# -----------------------------------------------------------------------
try:
async for checkpoint_tuple in checkpointer.alist(None):
cfg = getattr(checkpoint_tuple, "config", {})
thread_id = cfg.get("configurable", {}).get("thread_id")
if not thread_id or thread_id in merged:
continue
# Skip sub-graph checkpoints (checkpoint_ns is non-empty for those)
if cfg.get("configurable", {}).get("checkpoint_ns", ""):
continue
ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
# Strip LangGraph internal keys from the user-visible metadata dict
user_meta = {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}
# Extract state values (title) from the checkpoint's channel_values
checkpoint_data = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint_data.get("channel_values", {})
ckpt_values = {}
if title := channel_values.get("title"):
ckpt_values["title"] = title
thread_resp = ThreadResponse(
thread_id=thread_id,
status=_derive_thread_status(checkpoint_tuple),
created_at=str(ckpt_meta.get("created_at", "")),
updated_at=str(ckpt_meta.get("updated_at", ckpt_meta.get("created_at", ""))),
metadata=user_meta,
values=ckpt_values,
)
merged[thread_id] = thread_resp
# Lazy migration — write to Store so the next search finds it there
if store is not None:
try:
await _store_upsert(store, thread_id, metadata=user_meta, values=ckpt_values or None)
except Exception:
logger.debug("Failed to migrate thread %s to store (non-fatal)", thread_id)
except Exception:
logger.exception("Checkpointer scan failed during thread search")
# Don't raise — return whatever was collected from Store + partial scan
# -----------------------------------------------------------------------
# Phase 3: Filter → sort → paginate
# -----------------------------------------------------------------------
results = list(merged.values())
if body.metadata:
results = [r for r in results if all(r.metadata.get(k) == v for k, v in body.metadata.items())]
if body.status:
results = [r for r in results if r.status == body.status]
results.sort(key=lambda r: r.updated_at, reverse=True)
return results[body.offset : body.offset + body.limit]
@router.patch("/{thread_id}", response_model=ThreadResponse) @router.patch("/{thread_id}", response_model=ThreadResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse: async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
"""Merge metadata into a thread record.""" """Merge metadata into a thread record."""
from app.gateway.deps import get_thread_store store = get_store(request)
if store is None:
raise HTTPException(status_code=503, detail="Store not available")
thread_store = get_thread_store(request) record = await _store_get(store, thread_id)
record = await thread_store.get(thread_id)
if record is None: if record is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
# ``body.metadata`` already stripped by ``ThreadPatchRequest._strip_reserved``. now = time.time()
updated = dict(record)
updated.setdefault("metadata", {}).update(body.metadata)
updated["updated_at"] = now
try: try:
await thread_store.update_metadata(thread_id, body.metadata) await _store_put(store, updated)
except Exception: except Exception:
logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id)) logger.exception("Failed to patch thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to update thread") raise HTTPException(status_code=500, detail="Failed to update thread")
# Re-read to get the merged metadata + refreshed updated_at
record = await thread_store.get(thread_id) or record
return ThreadResponse( return ThreadResponse(
thread_id=thread_id, thread_id=thread_id,
status=record.get("status", "idle"), status=updated.get("status", "idle"),
created_at=str(record.get("created_at", "")), created_at=str(updated.get("created_at", "")),
updated_at=str(record.get("updated_at", "")), updated_at=str(now),
metadata=record.get("metadata", {}), metadata=updated.get("metadata", {}),
) )
@router.get("/{thread_id}", response_model=ThreadResponse) @router.get("/{thread_id}", response_model=ThreadResponse)
@require_permission("threads", "read", owner_check=True)
async def get_thread(thread_id: str, request: Request) -> ThreadResponse: async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
"""Get thread info. """Get thread info.
Reads metadata from the ThreadMetaStore and derives the accurate Reads metadata from the Store and derives the accurate execution
execution status from the checkpointer. Falls back to the checkpointer status from the checkpointer. Falls back to the checkpointer alone
alone for threads that pre-date ThreadMetaStore adoption (backward compat). for threads that pre-date Store adoption (backward compat).
""" """
from app.gateway.deps import get_thread_store store = get_store(request)
thread_store = get_thread_store(request)
checkpointer = get_checkpointer(request) checkpointer = get_checkpointer(request)
record: dict | None = await thread_store.get(thread_id) record: dict | None = None
if store is not None:
record = await _store_get(store, thread_id)
# Derive accurate status from the checkpointer # Derive accurate status from the checkpointer
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try: try:
checkpoint_tuple = await checkpointer.aget_tuple(config) checkpoint_tuple = await checkpointer.aget_tuple(config)
except Exception: except Exception:
logger.exception("Failed to get checkpoint for thread %s", sanitize_log_param(thread_id)) logger.exception("Failed to get checkpoint for thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to get thread") raise HTTPException(status_code=500, detail="Failed to get thread")
if record is None and checkpoint_tuple is None: if record is None and checkpoint_tuple is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
# If the thread exists in the checkpointer but not in thread_meta (e.g. # If the thread exists in the checkpointer but not the store (e.g. legacy
# legacy data created before thread_meta adoption), synthesize a minimal # data), synthesize a minimal store record from the checkpoint metadata.
# record from the checkpoint metadata.
if record is None and checkpoint_tuple is not None: if record is None and checkpoint_tuple is not None:
ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {} ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
record = { record = {
@@ -404,9 +505,7 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
) )
# ---------------------------------------------------------------------------
@router.get("/{thread_id}/state", response_model=ThreadStateResponse) @router.get("/{thread_id}/state", response_model=ThreadStateResponse)
@require_permission("threads", "read", owner_check=True)
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse: async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
"""Get the latest state snapshot for a thread. """Get the latest state snapshot for a thread.
@@ -419,7 +518,7 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
try: try:
checkpoint_tuple = await checkpointer.aget_tuple(config) checkpoint_tuple = await checkpointer.aget_tuple(config)
except Exception: except Exception:
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id)) logger.exception("Failed to get state for thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to get thread state") raise HTTPException(status_code=500, detail="Failed to get thread state")
if checkpoint_tuple is None: if checkpoint_tuple is None:
@@ -443,10 +542,8 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")] next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw] tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
values = serialize_channel_values(channel_values)
return ThreadStateResponse( return ThreadStateResponse(
values=values, values=serialize_channel_values(channel_values),
next=next_tasks, next=next_tasks,
metadata=metadata, metadata=metadata,
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))}, checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
@@ -458,19 +555,15 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
@router.post("/{thread_id}/state", response_model=ThreadStateResponse) @router.post("/{thread_id}/state", response_model=ThreadStateResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse: async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
"""Update thread state (e.g. for human-in-the-loop resume or title rename). """Update thread state (e.g. for human-in-the-loop resume or title rename).
Writes a new checkpoint that merges *body.values* into the latest Writes a new checkpoint that merges *body.values* into the latest
channel values, then syncs any updated ``title`` field through the channel values, then syncs any updated ``title`` field back to the Store
ThreadMetaStore abstraction so that ``/threads/search`` reflects the so that ``/threads/search`` reflects the change immediately.
change immediately in both sqlite and memory backends.
""" """
from app.gateway.deps import get_thread_store
checkpointer = get_checkpointer(request) checkpointer = get_checkpointer(request)
thread_store = get_thread_store(request) store = get_store(request)
# checkpoint_ns must be present in the config for aput — default to "" # checkpoint_ns must be present in the config for aput — default to ""
# (the root graph namespace). checkpoint_id is optional; omitting it # (the root graph namespace). checkpoint_id is optional; omitting it
@@ -487,7 +580,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
try: try:
checkpoint_tuple = await checkpointer.aget_tuple(read_config) checkpoint_tuple = await checkpointer.aget_tuple(read_config)
except Exception: except Exception:
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id)) logger.exception("Failed to get state for thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to get thread state") raise HTTPException(status_code=500, detail="Failed to get thread state")
if checkpoint_tuple is None: if checkpoint_tuple is None:
@@ -521,22 +614,19 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
try: try:
new_config = await checkpointer.aput(write_config, checkpoint, metadata, {}) new_config = await checkpointer.aput(write_config, checkpoint, metadata, {})
except Exception: except Exception:
logger.exception("Failed to update state for thread %s", sanitize_log_param(thread_id)) logger.exception("Failed to update state for thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to update thread state") raise HTTPException(status_code=500, detail="Failed to update thread state")
new_checkpoint_id: str | None = None new_checkpoint_id: str | None = None
if isinstance(new_config, dict): if isinstance(new_config, dict):
new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id") new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id")
# Sync title changes through the ThreadMetaStore abstraction so /threads/search # Sync title changes to the Store so /threads/search reflects them immediately.
# reflects them immediately in both sqlite and memory backends. if store is not None and body.values and "title" in body.values:
if body.values and "title" in body.values: try:
new_title = body.values["title"] await _store_upsert(store, thread_id, values={"title": body.values["title"]})
if new_title: # Skip empty strings and None except Exception:
try: logger.debug("Failed to sync title to store for thread %s (non-fatal)", thread_id)
await thread_store.update_display_name(thread_id, new_title)
except Exception:
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
return ThreadStateResponse( return ThreadStateResponse(
values=serialize_channel_values(channel_values), values=serialize_channel_values(channel_values),
@@ -548,16 +638,8 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
@router.post("/{thread_id}/history", response_model=list[HistoryEntry]) @router.post("/{thread_id}/history", response_model=list[HistoryEntry])
@require_permission("threads", "read", owner_check=True)
async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]: async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
"""Get checkpoint history for a thread. """Get checkpoint history for a thread."""
Messages are read from the checkpointer's channel values (the
authoritative source) and serialized via
:func:`~deerflow.runtime.serialization.serialize_channel_values`.
Only the latest (first) checkpoint carries the ``messages`` key to
avoid duplicating them across every entry.
"""
checkpointer = get_checkpointer(request) checkpointer = get_checkpointer(request)
config: dict[str, Any] = {"configurable": {"thread_id": thread_id}} config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}
@@ -565,7 +647,6 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
config["configurable"]["checkpoint_id"] = body.before config["configurable"]["checkpoint_id"] = body.before
entries: list[HistoryEntry] = [] entries: list[HistoryEntry] = []
is_latest_checkpoint = True
try: try:
async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit): async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
ckpt_config = getattr(checkpoint_tuple, "config", {}) ckpt_config = getattr(checkpoint_tuple, "config", {})
@@ -580,42 +661,22 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
channel_values = checkpoint.get("channel_values", {}) channel_values = checkpoint.get("channel_values", {})
# Build values from checkpoint channel_values
values: dict[str, Any] = {}
if title := channel_values.get("title"):
values["title"] = title
if thread_data := channel_values.get("thread_data"):
values["thread_data"] = thread_data
# Attach messages only to the latest checkpoint entry.
if is_latest_checkpoint:
messages = channel_values.get("messages")
if messages:
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
is_latest_checkpoint = False
# Derive next tasks # Derive next tasks
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or [] tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")] next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
# Strip LangGraph internal keys from metadata
user_meta = {k: v for k, v in metadata.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}
# Keep step for ordering context
if "step" in metadata:
user_meta["step"] = metadata["step"]
entries.append( entries.append(
HistoryEntry( HistoryEntry(
checkpoint_id=checkpoint_id, checkpoint_id=checkpoint_id,
parent_checkpoint_id=parent_id, parent_checkpoint_id=parent_id,
metadata=user_meta, metadata=metadata,
values=values, values=serialize_channel_values(channel_values),
created_at=str(metadata.get("created_at", "")), created_at=str(metadata.get("created_at", "")),
next=next_tasks, next=next_tasks,
) )
) )
except Exception: except Exception:
logger.exception("Failed to get history for thread %s", sanitize_log_param(thread_id)) logger.exception("Failed to get history for thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to get thread history") raise HTTPException(status_code=500, detail="Failed to get thread history")
return entries return entries
+5 -11
View File
@@ -4,12 +4,10 @@ import logging
import os import os
import stat import stat
from fastapi import APIRouter, File, HTTPException, Request, UploadFile from fastapi import APIRouter, File, HTTPException, UploadFile
from pydantic import BaseModel from pydantic import BaseModel
from app.gateway.authz import require_permission
from deerflow.config.paths import get_paths from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox_provider import get_sandbox_provider from deerflow.sandbox.sandbox_provider import get_sandbox_provider
from deerflow.uploads.manager import ( from deerflow.uploads.manager import (
PathTraversalError, PathTraversalError,
@@ -56,10 +54,8 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
@router.post("", response_model=UploadResponse) @router.post("", response_model=UploadResponse)
@require_permission("threads", "write", owner_check=True, require_existing=False)
async def upload_files( async def upload_files(
thread_id: str, thread_id: str,
request: Request,
files: list[UploadFile] = File(...), files: list[UploadFile] = File(...),
) -> UploadResponse: ) -> UploadResponse:
"""Upload multiple files to a thread's uploads directory.""" """Upload multiple files to a thread's uploads directory."""
@@ -70,7 +66,7 @@ async def upload_files(
uploads_dir = ensure_uploads_dir(thread_id) uploads_dir = ensure_uploads_dir(thread_id)
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
uploaded_files = [] uploaded_files = []
sandbox_provider = get_sandbox_provider() sandbox_provider = get_sandbox_provider()
@@ -137,8 +133,7 @@ async def upload_files(
@router.get("/list", response_model=dict) @router.get("/list", response_model=dict)
@require_permission("threads", "read", owner_check=True) async def list_uploaded_files(thread_id: str) -> dict:
async def list_uploaded_files(thread_id: str, request: Request) -> dict:
"""List all files in a thread's uploads directory.""" """List all files in a thread's uploads directory."""
try: try:
uploads_dir = get_uploads_dir(thread_id) uploads_dir = get_uploads_dir(thread_id)
@@ -148,7 +143,7 @@ async def list_uploaded_files(thread_id: str, request: Request) -> dict:
enrich_file_listing(result, thread_id) enrich_file_listing(result, thread_id)
# Gateway additionally includes the sandbox-relative path. # Gateway additionally includes the sandbox-relative path.
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
for f in result["files"]: for f in result["files"]:
f["path"] = str(sandbox_uploads / f["filename"]) f["path"] = str(sandbox_uploads / f["filename"])
@@ -156,8 +151,7 @@ async def list_uploaded_files(thread_id: str, request: Request) -> dict:
@router.delete("/{filename}") @router.delete("/{filename}")
@require_permission("threads", "delete", owner_check=True, require_existing=True) async def delete_uploaded_file(thread_id: str, filename: str) -> dict:
async def delete_uploaded_file(thread_id: str, filename: str, request: Request) -> dict:
"""Delete a file from a thread's uploads directory.""" """Delete a file from a thread's uploads directory."""
try: try:
uploads_dir = get_uploads_dir(thread_id) uploads_dir = get_uploads_dir(thread_id)
+81 -23
View File
@@ -8,17 +8,16 @@ frames, and consuming stream bridge events. Router modules
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import dataclasses
import json import json
import logging import logging
import re import re
import time
from typing import Any from typing import Any
from fastapi import HTTPException, Request from fastapi import HTTPException, Request
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from app.gateway.deps import get_run_context, get_run_manager, get_run_store, get_stream_bridge from app.gateway.deps import get_checkpointer, get_run_manager, get_store, get_stream_bridge
from app.gateway.utils import sanitize_log_param
from deerflow.runtime import ( from deerflow.runtime import (
END_SENTINEL, END_SENTINEL,
HEARTBEAT_SENTINEL, HEARTBEAT_SENTINEL,
@@ -172,6 +171,71 @@ def build_run_config(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def _upsert_thread_in_store(store, thread_id: str, metadata: dict | None) -> None:
"""Create or refresh the thread record in the Store.
Called from :func:`start_run` so that threads created via the stateless
``/runs/stream`` endpoint (which never calls ``POST /threads``) still
appear in ``/threads/search`` results.
"""
# Deferred import to avoid circular import with the threads router module.
from app.gateway.routers.threads import _store_upsert
try:
await _store_upsert(store, thread_id, metadata=metadata)
except Exception:
logger.warning("Failed to upsert thread %s in store (non-fatal)", thread_id)
async def _sync_thread_title_after_run(
run_task: asyncio.Task,
thread_id: str,
checkpointer: Any,
store: Any,
) -> None:
"""Wait for *run_task* to finish, then persist the generated title to the Store.
TitleMiddleware writes the generated title to the LangGraph agent state
(checkpointer) but the Gateway's Store record is not updated automatically.
This coroutine closes that gap by reading the final checkpoint after the
run completes and syncing ``values.title`` into the Store record so that
subsequent ``/threads/search`` responses include the correct title.
Runs as a fire-and-forget :func:`asyncio.create_task`; failures are
logged at DEBUG level and never propagate.
"""
# Wait for the background run task to complete (any outcome).
# asyncio.wait does not propagate task exceptions — it just returns
# when the task is done, cancelled, or failed.
await asyncio.wait({run_task})
# Deferred import to avoid circular import with the threads router module.
from app.gateway.routers.threads import _store_get, _store_put
try:
ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
if ckpt_tuple is None:
return
channel_values = ckpt_tuple.checkpoint.get("channel_values", {})
title = channel_values.get("title")
if not title:
return
existing = await _store_get(store, thread_id)
if existing is None:
return
updated = dict(existing)
updated.setdefault("values", {})["title"] = title
updated["updated_at"] = time.time()
await _store_put(store, updated)
logger.debug("Synced title %r for thread %s", title, thread_id)
except Exception:
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id, exc_info=True)
async def start_run( async def start_run(
body: Any, body: Any,
thread_id: str, thread_id: str,
@@ -191,7 +255,8 @@ async def start_run(
""" """
bridge = get_stream_bridge(request) bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request) run_mgr = get_run_manager(request)
run_ctx = get_run_context(request) checkpointer = get_checkpointer(request)
store = get_store(request)
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_ disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
@@ -209,21 +274,11 @@ async def start_run(
except UnsupportedStrategyError as exc: except UnsupportedStrategyError as exc:
raise HTTPException(status_code=501, detail=str(exc)) from exc raise HTTPException(status_code=501, detail=str(exc)) from exc
# Upsert thread metadata so the thread appears in /threads/search, # Ensure the thread is visible in /threads/search, even for threads that
# even for threads that were never explicitly created via POST /threads # were never explicitly created via POST /threads (e.g. stateless runs).
# (e.g. stateless runs). store = get_store(request)
try: if store is not None:
existing = await run_ctx.thread_store.get(thread_id) await _upsert_thread_in_store(store, thread_id, body.metadata)
if existing is None:
await run_ctx.thread_store.create(
thread_id,
assistant_id=body.assistant_id,
metadata=body.metadata,
)
else:
await run_ctx.thread_store.update_status(thread_id, "running")
except Exception:
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
agent_factory = resolve_agent_factory(body.assistant_id) agent_factory = resolve_agent_factory(body.assistant_id)
graph_input = normalize_input(body.input) graph_input = normalize_input(body.input)
@@ -256,7 +311,8 @@ async def start_run(
bridge, bridge,
run_mgr, run_mgr,
record, record,
ctx=run_ctx, checkpointer=checkpointer,
store=store,
agent_factory=agent_factory, agent_factory=agent_factory,
graph_input=graph_input, graph_input=graph_input,
config=config, config=config,
@@ -268,9 +324,11 @@ async def start_run(
) )
record.task = task record.task = task
# Title sync is handled by worker.py's finally block which reads the # After the run completes, sync the title generated by TitleMiddleware from
# title from the checkpoint and calls thread_store.update_display_name # the checkpointer into the Store record so that /threads/search returns the
# after the run completes. # correct title instead of an empty values dict.
if store is not None:
asyncio.create_task(_sync_thread_title_after_run(task, thread_id, checkpointer, store))
return record return record
-6
View File
@@ -1,6 +0,0 @@
"""Shared utility helpers for the Gateway layer."""
def sanitize_log_param(value: str) -> str:
"""Strip control characters to prevent log injection."""
return value.replace("\n", "").replace("\r", "").replace("\x00", "")
-77
View File
@@ -1,77 +0,0 @@
# Docker Test Gap (Section 七 7.4)
This file documents the only **un-executed** test cases from
`backend/docs/AUTH_TEST_PLAN.md` after the full release validation pass.
## Why this gap exists
The release validation environment (sg_dev: `10.251.229.92`) **does not have
a Docker daemon installed**. The TC-DOCKER cases are container-runtime
behavior tests that need an actual Docker engine to spin up
`docker/docker-compose.yaml` services.
```bash
$ ssh sg_dev "which docker; docker --version"
# (empty)
# bash: docker: command not found
```
All other test plan sections were executed against either:
- The local dev box (Mac, all services running locally), or
- The deployed sg_dev instance (gateway + frontend + nginx via SSH tunnel)
## Cases not executed
| Case | Title | What it covers | Why not run |
|---|---|---|---|
| TC-DOCKER-01 | `users.db` volume persistence | Verify the `DEER_FLOW_HOME` bind mount survives container restart | needs `docker compose up` |
| TC-DOCKER-02 | Session persistence across container restart | `AUTH_JWT_SECRET` env var keeps cookies valid after `docker compose down && up` | needs `docker compose down/up` |
| TC-DOCKER-03 | Per-worker rate limiter divergence | Confirms in-process `_login_attempts` dict doesn't share state across `gunicorn` workers (4 by default in the compose file); known limitation, documented | needs multi-worker container |
| TC-DOCKER-04 | IM channels skip AuthMiddleware | Verify Feishu/Slack/Telegram dispatchers run in-container against `http://langgraph:2024` without going through nginx | needs `docker logs` |
| TC-DOCKER-05 | Admin credentials surfacing | **Updated post-simplify** — was "log scrape", now "0600 credential file in `DEER_FLOW_HOME`". The file-based behavior is already validated by TC-1.1 + TC-UPG-13 on sg_dev (non-Docker), so the only Docker-specific gap is verifying the volume mount carries the file out to the host | needs container + host volume |
| TC-DOCKER-06 | Gateway-mode Docker deploy | `./scripts/deploy.sh --gateway` produces a 3-container topology (no `langgraph` container); same auth flow as standard mode | needs `docker compose --profile gateway` |
## Coverage already provided by non-Docker tests
The **auth-relevant** behavior in each Docker case is already exercised by
the test cases that ran on sg_dev or local:
| Docker case | Auth behavior covered by |
|---|---|
| TC-DOCKER-01 (volume persistence) | TC-REENT-01 on sg_dev (admin row survives gateway restart) — same SQLite file, just no container layer between |
| TC-DOCKER-02 (session persistence) | TC-API-02/03/06 (cookie roundtrip), plus TC-REENT-04 (multi-cookie) — JWT verification is process-state-free, container restart is equivalent to `pkill uvicorn && uv run uvicorn` |
| TC-DOCKER-03 (per-worker rate limit) | TC-GW-04 + TC-REENT-09 (single-worker rate limit + 5min expiry). The cross-worker divergence is an architectural property of the in-memory dict; no auth code path differs |
| TC-DOCKER-04 (IM channels skip auth) | Code-level only: `app/channels/manager.py` uses `langgraph_sdk` directly with no cookie handling. The langgraph_auth handler is bypassed by going through SDK, not HTTP |
| TC-DOCKER-05 (credential surfacing) | TC-1.1 on sg_dev (file at `~/deer-flow/backend/.deer-flow/admin_initial_credentials.txt`, mode 0600, password 22 chars) — the only Docker-unique step is whether the bind mount projects this path onto the host, which is a `docker compose` config check, not a runtime behavior change |
| TC-DOCKER-06 (gateway-mode container) | Section 七 7.2 covered by TC-GW-01..05 + Section 二 (gateway-mode auth flow on sg_dev) — same Gateway code, container is just a packaging change |
## Reproduction steps when Docker becomes available
Anyone with `docker` + `docker compose` installed can reproduce the gap by
running the test plan section verbatim. Pre-flight:
```bash
# Required on the host
docker --version # >=24.x
docker compose version # plugin >=2.x
# Required env var (otherwise sessions reset on every container restart)
echo "AUTH_JWT_SECRET=$(python3 -c 'import secrets; print(secrets.token_urlsafe(32))')" \
>> .env
# Optional: pin DEER_FLOW_HOME to a stable host path
echo "DEER_FLOW_HOME=$HOME/deer-flow-data" >> .env
```
Then run TC-DOCKER-01..06 from the test plan as written.
## Decision log
- **Not blocking the release.** The auth-relevant behavior in every Docker
case has an already-validated equivalent on bare metal. The gap is purely
about *container packaging* details (bind mounts, multi-worker, log
collection), not about whether the auth code paths work.
- **TC-DOCKER-05 was updated in place** in `AUTH_TEST_PLAN.md` to reflect
the post-simplify reality (credentials file → 0600 file, no log leak).
The old "grep 'Password:' in docker logs" expectation would have failed
silently and given a false sense of coverage.
File diff suppressed because it is too large Load Diff
-129
View File
@@ -1,129 +0,0 @@
# Authentication Upgrade Guide
DeerFlow 内置了认证模块。本文档面向从无认证版本升级的用户。
## 核心概念
认证模块采用**始终强制**策略:
- 首次启动时自动创建 admin 账号,随机密码打印到控制台日志
- 认证从一开始就是强制的,无竞争窗口
- 历史对话(升级前创建的 thread)自动迁移到 admin 名下
## 升级步骤
### 1. 更新代码
```bash
git pull origin main
cd backend && make install
```
### 2. 首次启动
```bash
make dev
```
控制台会输出:
```
============================================================
Admin account created on first boot
Email: admin@deerflow.dev
Password: aB3xK9mN_pQ7rT2w
Change it after login: Settings → Account
============================================================
```
如果未登录就重启了服务,不用担心——只要 setup 未完成,每次启动都会重置密码并重新打印到控制台。
### 3. 登录
访问 `http://localhost:2026/login`,使用控制台输出的邮箱和密码登录。
### 4. 修改密码
登录后进入 Settings → Account → Change Password。
### 5. 添加用户(可选)
其他用户通过 `/login` 页面注册,自动获得 **user** 角色。每个用户只能看到自己的对话。
## 安全机制
| 机制 | 说明 |
|------|------|
| JWT HttpOnly Cookie | Token 不暴露给 JavaScript,防止 XSS 窃取 |
| CSRF Double Submit Cookie | 所有 POST/PUT/DELETE 请求需携带 `X-CSRF-Token` |
| bcrypt 密码哈希 | 密码不以明文存储 |
| 多租户隔离 | 用户只能访问自己的 thread |
| HTTPS 自适应 | 检测 `x-forwarded-proto`,自动设置 `Secure` cookie 标志 |
## 常见操作
### 忘记密码
```bash
cd backend
# 重置 admin 密码
python -m app.gateway.auth.reset_admin
# 重置指定用户密码
python -m app.gateway.auth.reset_admin --email user@example.com
```
会输出新的随机密码。
### 完全重置
删除用户数据库,重启后自动创建新 admin:
```bash
rm -f backend/.deer-flow/users.db
# 重启服务,控制台输出新密码
```
## 数据存储
| 文件 | 内容 |
|------|------|
| `.deer-flow/users.db` | SQLite 用户数据库(密码哈希、角色) |
| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成临时密钥,重启后 session 失效) |
### 生产环境建议
```bash
# 生成持久化 JWT 密钥,避免重启后所有用户需重新登录
python -c "import secrets; print(secrets.token_urlsafe(32))"
# 将输出添加到 .env
# AUTH_JWT_SECRET=<生成的密钥>
```
## API 端点
| 端点 | 方法 | 说明 |
|------|------|------|
| `/api/v1/auth/login/local` | POST | 邮箱密码登录(OAuth2 form |
| `/api/v1/auth/register` | POST | 注册新用户(user 角色) |
| `/api/v1/auth/logout` | POST | 登出(清除 cookie |
| `/api/v1/auth/me` | GET | 获取当前用户信息 |
| `/api/v1/auth/change-password` | POST | 修改密码 |
| `/api/v1/auth/setup-status` | GET | 检查 admin 是否存在 |
## 兼容性
- **标准模式**`make dev`):完全兼容,admin 自动创建
- **Gateway 模式**`make dev-pro`):完全兼容
- **Docker 部署**:完全兼容,`.deer-flow/users.db` 需持久化卷挂载
- **IM 渠道**Feishu/Slack/Telegram):通过 LangGraph SDK 通信,不经过认证层
- **DeerFlowClient**(嵌入式):不经过 HTTP,不受认证影响
## 故障排查
| 症状 | 原因 | 解决 |
|------|------|------|
| 启动后没看到密码 | admin 已存在(非首次启动) | 用 `reset_admin` 重置,或删 `users.db` |
| 登录后 POST 返回 403 | CSRF token 缺失 | 确认前端已更新 |
| 重启后需要重新登录 | `AUTH_JWT_SECRET` 未持久化 | 在 `.env` 中设置固定密钥 |
@@ -124,7 +124,7 @@ title:
# checkpointer.py # checkpointer.py
from langgraph.checkpoint.sqlite import SqliteSaver from langgraph.checkpoint.sqlite import SqliteSaver
checkpointer = SqliteSaver.from_conn_string("deerflow.db") checkpointer = SqliteSaver.from_conn_string("checkpoints.db")
``` ```
```json ```json
-3
View File
@@ -8,9 +8,6 @@
"graphs": { "graphs": {
"lead_agent": "deerflow.agents:make_lead_agent" "lead_agent": "deerflow.agents:make_lead_agent"
}, },
"auth": {
"path": "./app/gateway/langgraph_auth.py:auth"
},
"checkpointer": { "checkpointer": {
"path": "./packages/harness/deerflow/agents/checkpointer/async_provider.py:make_checkpointer" "path": "./packages/harness/deerflow/agents/checkpointer/async_provider.py:make_checkpointer"
} }
@@ -1,3 +1,4 @@
from .checkpointer import get_checkpointer, make_checkpointer, reset_checkpointer
from .factory import create_deerflow_agent from .factory import create_deerflow_agent
from .features import Next, Prev, RuntimeFeatures from .features import Next, Prev, RuntimeFeatures
from .lead_agent import make_lead_agent from .lead_agent import make_lead_agent
@@ -17,4 +18,7 @@ __all__ = [
"make_lead_agent", "make_lead_agent",
"SandboxState", "SandboxState",
"ThreadState", "ThreadState",
"get_checkpointer",
"reset_checkpointer",
"make_checkpointer",
] ]
@@ -7,12 +7,12 @@ Supported backends: memory, sqlite, postgres.
Usage (e.g. FastAPI lifespan):: Usage (e.g. FastAPI lifespan)::
from deerflow.runtime.checkpointer.async_provider import make_checkpointer from deerflow.agents.checkpointer.async_provider import make_checkpointer
async with make_checkpointer() as checkpointer: async with make_checkpointer() as checkpointer:
app.state.checkpointer = checkpointer # InMemorySaver if not configured app.state.checkpointer = checkpointer # InMemorySaver if not configured
For sync usage see :mod:`deerflow.runtime.checkpointer.provider`. For sync usage see :mod:`deerflow.agents.checkpointer.provider`.
""" """
from __future__ import annotations from __future__ import annotations
@@ -24,12 +24,12 @@ from collections.abc import AsyncIterator
from langgraph.types import Checkpointer from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config from deerflow.agents.checkpointer.provider import (
from deerflow.runtime.checkpointer.provider import (
POSTGRES_CONN_REQUIRED, POSTGRES_CONN_REQUIRED,
POSTGRES_INSTALL, POSTGRES_INSTALL,
SQLITE_INSTALL, SQLITE_INSTALL,
) )
from deerflow.config.app_config import AppConfig
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -83,77 +83,24 @@ async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@contextlib.asynccontextmanager
async def _async_checkpointer_from_database(db_config) -> AsyncIterator[Checkpointer]:
"""Async context manager that constructs a checkpointer from unified DatabaseConfig."""
if db_config.backend == "memory":
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
return
if db_config.backend == "sqlite":
try:
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
except ImportError as exc:
raise ImportError(SQLITE_INSTALL) from exc
conn_str = db_config.checkpointer_sqlite_path
ensure_sqlite_parent_dir(conn_str)
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
await saver.setup()
yield saver
return
if db_config.backend == "postgres":
try:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
if not db_config.postgres_url:
raise ValueError("database.postgres_url is required for the postgres backend")
async with AsyncPostgresSaver.from_conn_string(db_config.postgres_url) as saver:
await saver.setup()
yield saver
return
raise ValueError(f"Unknown database backend: {db_config.backend!r}")
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def make_checkpointer() -> AsyncIterator[Checkpointer]: async def make_checkpointer() -> AsyncIterator[Checkpointer]:
"""Async context manager that yields a checkpointer for the caller's lifetime. """Async context manager that yields a checkpointer for the caller's lifetime.
Resources are opened on enter and closed on exit -- no global state:: Resources are opened on enter and closed on exit no global state::
async with make_checkpointer() as checkpointer: async with make_checkpointer() as checkpointer:
app.state.checkpointer = checkpointer app.state.checkpointer = checkpointer
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*. Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
Priority:
1. Legacy ``checkpointer:`` config section (backward compatible)
2. Unified ``database:`` config section
3. Default InMemorySaver
""" """
config = get_app_config() config = AppConfig.current()
# Legacy: standalone checkpointer config takes precedence if config.checkpointer is None:
if config.checkpointer is not None: from langgraph.checkpoint.memory import InMemorySaver
async with _async_checkpointer(config.checkpointer) as saver:
yield saver
return
# Unified database config yield InMemorySaver()
db_config = getattr(config, "database", None) return
if db_config is not None and db_config.backend != "memory":
async with _async_checkpointer_from_database(db_config) as saver:
yield saver
return
# Default: in-memory async with _async_checkpointer(config.checkpointer) as saver:
from langgraph.checkpoint.memory import InMemorySaver yield saver
yield InMemorySaver()
@@ -7,7 +7,7 @@ Supported backends: memory, sqlite, postgres.
Usage:: Usage::
from deerflow.runtime.checkpointer.provider import get_checkpointer, checkpointer_context from deerflow.agents.checkpointer.provider import get_checkpointer, checkpointer_context
# Singleton — reused across calls, closed on process exit # Singleton — reused across calls, closed on process exit
cp = get_checkpointer() cp = get_checkpointer()
@@ -25,7 +25,7 @@ from collections.abc import Iterator
from langgraph.types import Checkpointer from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.config.checkpointer_config import CheckpointerConfig from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.runtime.store._sqlite_utils import resolve_sqlite_conn_str from deerflow.runtime.store._sqlite_utils import resolve_sqlite_conn_str
@@ -113,25 +113,10 @@ def get_checkpointer() -> Checkpointer:
if _checkpointer is not None: if _checkpointer is not None:
return _checkpointer return _checkpointer
# Ensure app config is loaded before checking checkpointer config try:
# This prevents returning InMemorySaver when config.yaml actually has a checkpointer section config = AppConfig.current().checkpointer
# but hasn't been loaded yet except (LookupError, FileNotFoundError):
from deerflow.config.app_config import _app_config config = None
from deerflow.config.checkpointer_config import get_checkpointer_config
config = get_checkpointer_config()
if config is None and _app_config is None:
# Only load app config lazily when neither the app config nor an explicit
# checkpointer config has been initialized yet. This keeps tests that
# intentionally set the global checkpointer config isolated from any
# ambient config.yaml on disk.
try:
get_app_config()
except FileNotFoundError:
# In test environments without config.yaml, this is expected.
pass
config = get_checkpointer_config()
if config is None: if config is None:
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
@@ -180,7 +165,7 @@ def checkpointer_context() -> Iterator[Checkpointer]:
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*. Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
""" """
config = get_app_config() config = AppConfig.current()
if config.checkpointer is None: if config.checkpointer is None:
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
@@ -1,15 +1,15 @@
import logging import logging
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph
from deerflow.agents.lead_agent.prompt import apply_prompt_template from deerflow.agents.lead_agent.prompt import apply_prompt_template
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
from deerflow.agents.middlewares.summarization_middleware import SummarizationMiddleware
from deerflow.agents.middlewares.title_middleware import TitleMiddleware from deerflow.agents.middlewares.title_middleware import TitleMiddleware
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
@@ -17,8 +17,8 @@ from deerflow.agents.middlewares.tool_error_handling_middleware import build_lea
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
from deerflow.agents.thread_state import ThreadState from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import load_agent_config from deerflow.config.agents_config import load_agent_config
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.config.summarization_config import get_summarization_config from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
def _resolve_model_name(requested_model_name: str | None = None) -> str: def _resolve_model_name(requested_model_name: str | None = None) -> str:
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured.""" """Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
app_config = get_app_config() app_config = AppConfig.current()
default_model_name = app_config.models[0].name if app_config.models else None default_model_name = app_config.models[0].name if app_config.models else None
if default_model_name is None: if default_model_name is None:
raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.") raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
@@ -41,7 +41,7 @@ def _resolve_model_name(requested_model_name: str | None = None) -> str:
def _create_summarization_middleware() -> SummarizationMiddleware | None: def _create_summarization_middleware() -> SummarizationMiddleware | None:
"""Create and configure the summarization middleware from config.""" """Create and configure the summarization middleware from config."""
config = get_summarization_config() config = AppConfig.current().summarization
if not config.enabled: if not config.enabled:
return None return None
@@ -57,15 +57,13 @@ def _create_summarization_middleware() -> SummarizationMiddleware | None:
# Prepare keep parameter # Prepare keep parameter
keep = config.keep.to_tuple() keep = config.keep.to_tuple()
# Prepare model parameter. # Prepare model parameter
# Bind "middleware:summarize" tag so RunJournal identifies these LLM calls
# as middleware rather than lead_agent (SummarizationMiddleware is a
# LangChain built-in, so we tag the model at creation time).
if config.model_name: if config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False) model = create_chat_model(name=config.model_name, thinking_enabled=False)
else: else:
# Use a lightweight model for summarization to save costs
# Falls back to default model if not explicitly specified
model = create_chat_model(thinking_enabled=False) model = create_chat_model(thinking_enabled=False)
model = model.with_config(tags=["middleware:summarize"])
# Prepare kwargs # Prepare kwargs
kwargs = { kwargs = {
@@ -233,7 +231,7 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
middlewares.append(todo_list_middleware) middlewares.append(todo_list_middleware)
# Add TokenUsageMiddleware when token_usage tracking is enabled # Add TokenUsageMiddleware when token_usage tracking is enabled
if get_app_config().token_usage.enabled: if AppConfig.current().token_usage.enabled:
middlewares.append(TokenUsageMiddleware()) middlewares.append(TokenUsageMiddleware())
# Add TitleMiddleware # Add TitleMiddleware
@@ -244,7 +242,7 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
# Add ViewImageMiddleware only if the current model supports vision. # Add ViewImageMiddleware only if the current model supports vision.
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values. # Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
app_config = get_app_config() app_config = AppConfig.current()
model_config = app_config.get_model_config(model_name) if model_name else None model_config = app_config.get_model_config(model_name) if model_name else None
if model_config is not None and model_config.supports_vision: if model_config is not None and model_config.supports_vision:
middlewares.append(ViewImageMiddleware()) middlewares.append(ViewImageMiddleware())
@@ -273,7 +271,7 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
return middlewares return middlewares
def make_lead_agent(config: RunnableConfig): def make_lead_agent(config: RunnableConfig) -> CompiledStateGraph:
# Lazy import to avoid circular dependency # Lazy import to avoid circular dependency
from deerflow.tools import get_available_tools from deerflow.tools import get_available_tools
from deerflow.tools.builtins import setup_agent from deerflow.tools.builtins import setup_agent
@@ -296,7 +294,7 @@ def make_lead_agent(config: RunnableConfig):
# Final model name resolution: request → agent config → global default, with fallback for unknown names # Final model name resolution: request → agent config → global default, with fallback for unknown names
model_name = _resolve_model_name(requested_model_name or agent_model_name) model_name = _resolve_model_name(requested_model_name or agent_model_name)
app_config = get_app_config() app_config = AppConfig.current()
model_config = app_config.get_model_config(model_name) model_config = app_config.get_model_config(model_name)
if model_config is None: if model_config is None:
@@ -339,6 +337,7 @@ def make_lead_agent(config: RunnableConfig):
middleware=_build_middlewares(config, model_name=model_name), middleware=_build_middlewares(config, model_name=model_name),
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])), system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])),
state_schema=ThreadState, state_schema=ThreadState,
context_schema=DeerFlowContext,
) )
# Default lead agent (unchanged behavior) # Default lead agent (unchanged behavior)
@@ -350,4 +349,5 @@ def make_lead_agent(config: RunnableConfig):
subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
), ),
state_schema=ThreadState, state_schema=ThreadState,
context_schema=DeerFlowContext,
) )
@@ -5,6 +5,7 @@ from datetime import datetime
from functools import lru_cache from functools import lru_cache
from deerflow.config.agents_config import load_agent_soul from deerflow.config.agents_config import load_agent_soul
from deerflow.config.app_config import AppConfig
from deerflow.skills import load_skills from deerflow.skills import load_skills
from deerflow.skills.types import Skill from deerflow.skills.types import Skill
from deerflow.subagents import get_available_subagent_names from deerflow.subagents import get_available_subagent_names
@@ -518,14 +519,12 @@ def _get_memory_context(agent_name: str | None = None) -> str:
""" """
try: try:
from deerflow.agents.memory import format_memory_for_injection, get_memory_data from deerflow.agents.memory import format_memory_for_injection, get_memory_data
from deerflow.config.memory_config import get_memory_config
from deerflow.runtime.user_context import get_effective_user_id
config = get_memory_config() config = AppConfig.current().memory
if not config.enabled or not config.injection_enabled: if not config.enabled or not config.injection_enabled:
return "" return ""
memory_data = get_memory_data(agent_name, user_id=get_effective_user_id()) memory_data = get_memory_data(agent_name)
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens) memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
if not memory_content.strip(): if not memory_content.strip():
@@ -577,9 +576,7 @@ def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
skills = _get_enabled_skills() skills = _get_enabled_skills()
try: try:
from deerflow.config import get_app_config config = AppConfig.current()
config = get_app_config()
container_base_path = config.skills.container_path container_base_path = config.skills.container_path
skill_evolution_enabled = config.skill_evolution.enabled skill_evolution_enabled = config.skill_evolution.enabled
except Exception: except Exception:
@@ -618,9 +615,7 @@ def get_deferred_tools_prompt_section() -> str:
from deerflow.tools.builtins.tool_search import get_deferred_registry from deerflow.tools.builtins.tool_search import get_deferred_registry
try: try:
from deerflow.config import get_app_config if not AppConfig.current().tool_search.enabled:
if not get_app_config().tool_search.enabled:
return "" return ""
except Exception: except Exception:
return "" return ""
@@ -636,9 +631,7 @@ def get_deferred_tools_prompt_section() -> str:
def _build_acp_section() -> str: def _build_acp_section() -> str:
"""Build the ACP agent prompt section, only if ACP agents are configured.""" """Build the ACP agent prompt section, only if ACP agents are configured."""
try: try:
from deerflow.config.acp_config import get_acp_agents agents = AppConfig.current().acp_agents
agents = get_acp_agents()
if not agents: if not agents:
return "" return ""
except Exception: except Exception:
@@ -656,9 +649,7 @@ def _build_acp_section() -> str:
def _build_custom_mounts_section() -> str: def _build_custom_mounts_section() -> str:
"""Build a prompt section for explicitly configured sandbox mounts.""" """Build a prompt section for explicitly configured sandbox mounts."""
try: try:
from deerflow.config import get_app_config mounts = AppConfig.current().sandbox.mounts or []
mounts = get_app_config().sandbox.mounts or []
except Exception: except Exception:
logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt") logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
return "" return ""
@@ -7,7 +7,7 @@ from dataclasses import dataclass, field
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any from typing import Any
from deerflow.config.memory_config import get_memory_config from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -20,7 +20,6 @@ class ConversationContext:
messages: list[Any] messages: list[Any]
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC)) timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
agent_name: str | None = None agent_name: str | None = None
user_id: str | None = None
correction_detected: bool = False correction_detected: bool = False
reinforcement_detected: bool = False reinforcement_detected: bool = False
@@ -45,7 +44,6 @@ class MemoryUpdateQueue:
thread_id: str, thread_id: str,
messages: list[Any], messages: list[Any],
agent_name: str | None = None, agent_name: str | None = None,
user_id: str | None = None,
correction_detected: bool = False, correction_detected: bool = False,
reinforcement_detected: bool = False, reinforcement_detected: bool = False,
) -> None: ) -> None:
@@ -55,13 +53,10 @@ class MemoryUpdateQueue:
thread_id: The thread ID. thread_id: The thread ID.
messages: The conversation messages. messages: The conversation messages.
agent_name: If provided, memory is stored per-agent. If None, uses global memory. agent_name: If provided, memory is stored per-agent. If None, uses global memory.
user_id: The user ID captured at enqueue time. Stored in ConversationContext so it
survives the threading.Timer boundary (ContextVar does not propagate across
raw threads).
correction_detected: Whether recent turns include an explicit correction signal. correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal. reinforcement_detected: Whether recent turns include a positive reinforcement signal.
""" """
config = get_memory_config() config = AppConfig.current().memory
if not config.enabled: if not config.enabled:
return return
@@ -76,7 +71,6 @@ class MemoryUpdateQueue:
thread_id=thread_id, thread_id=thread_id,
messages=messages, messages=messages,
agent_name=agent_name, agent_name=agent_name,
user_id=user_id,
correction_detected=merged_correction_detected, correction_detected=merged_correction_detected,
reinforcement_detected=merged_reinforcement_detected, reinforcement_detected=merged_reinforcement_detected,
) )
@@ -93,7 +87,7 @@ class MemoryUpdateQueue:
def _reset_timer(self) -> None: def _reset_timer(self) -> None:
"""Reset the debounce timer.""" """Reset the debounce timer."""
config = get_memory_config() config = AppConfig.current().memory
# Cancel existing timer if any # Cancel existing timer if any
if self._timer is not None: if self._timer is not None:
@@ -142,7 +136,6 @@ class MemoryUpdateQueue:
agent_name=context.agent_name, agent_name=context.agent_name,
correction_detected=context.correction_detected, correction_detected=context.correction_detected,
reinforcement_detected=context.reinforcement_detected, reinforcement_detected=context.reinforcement_detected,
user_id=context.user_id,
) )
if success: if success:
logger.info("Memory updated successfully for thread %s", context.thread_id) logger.info("Memory updated successfully for thread %s", context.thread_id)
@@ -9,7 +9,7 @@ from pathlib import Path
from typing import Any from typing import Any
from deerflow.config.agents_config import AGENT_NAME_PATTERN from deerflow.config.agents_config import AGENT_NAME_PATTERN
from deerflow.config.memory_config import get_memory_config from deerflow.config.app_config import AppConfig
from deerflow.config.paths import get_paths from deerflow.config.paths import get_paths
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -43,17 +43,17 @@ class MemoryStorage(abc.ABC):
"""Abstract base class for memory storage providers.""" """Abstract base class for memory storage providers."""
@abc.abstractmethod @abc.abstractmethod
def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: def load(self, agent_name: str | None = None) -> dict[str, Any]:
"""Load memory data for the given agent.""" """Load memory data for the given agent."""
pass pass
@abc.abstractmethod @abc.abstractmethod
def reload(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: def reload(self, agent_name: str | None = None) -> dict[str, Any]:
"""Force reload memory data for the given agent.""" """Force reload memory data for the given agent."""
pass pass
@abc.abstractmethod @abc.abstractmethod
def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool: def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
"""Save memory data for the given agent.""" """Save memory data for the given agent."""
pass pass
@@ -63,9 +63,9 @@ class FileMemoryStorage(MemoryStorage):
def __init__(self): def __init__(self):
"""Initialize the file memory storage.""" """Initialize the file memory storage."""
# Per-user/agent memory cache: keyed by (user_id, agent_name) tuple (None = global) # Per-agent memory cache: keyed by agent_name (None = global)
# Value: (memory_data, file_mtime) # Value: (memory_data, file_mtime)
self._memory_cache: dict[tuple[str | None, str | None], tuple[dict[str, Any], float | None]] = {} self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
def _validate_agent_name(self, agent_name: str) -> None: def _validate_agent_name(self, agent_name: str) -> None:
"""Validate that the agent name is safe to use in filesystem paths. """Validate that the agent name is safe to use in filesystem paths.
@@ -78,29 +78,21 @@ class FileMemoryStorage(MemoryStorage):
if not AGENT_NAME_PATTERN.match(agent_name): if not AGENT_NAME_PATTERN.match(agent_name):
raise ValueError(f"Invalid agent name {agent_name!r}: names must match {AGENT_NAME_PATTERN.pattern}") raise ValueError(f"Invalid agent name {agent_name!r}: names must match {AGENT_NAME_PATTERN.pattern}")
def _get_memory_file_path(self, agent_name: str | None = None, *, user_id: str | None = None) -> Path: def _get_memory_file_path(self, agent_name: str | None = None) -> Path:
"""Get the path to the memory file.""" """Get the path to the memory file."""
if user_id is not None:
if agent_name is not None:
self._validate_agent_name(agent_name)
return get_paths().user_agent_memory_file(user_id, agent_name)
config = get_memory_config()
if config.storage_path and Path(config.storage_path).is_absolute():
return Path(config.storage_path)
return get_paths().user_memory_file(user_id)
# Legacy: no user_id
if agent_name is not None: if agent_name is not None:
self._validate_agent_name(agent_name) self._validate_agent_name(agent_name)
return get_paths().agent_memory_file(agent_name) return get_paths().agent_memory_file(agent_name)
config = get_memory_config()
config = AppConfig.current().memory
if config.storage_path: if config.storage_path:
p = Path(config.storage_path) p = Path(config.storage_path)
return p if p.is_absolute() else get_paths().base_dir / p return p if p.is_absolute() else get_paths().base_dir / p
return get_paths().memory_file return get_paths().memory_file
def _load_memory_from_file(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: def _load_memory_from_file(self, agent_name: str | None = None) -> dict[str, Any]:
"""Load memory data from file.""" """Load memory data from file."""
file_path = self._get_memory_file_path(agent_name, user_id=user_id) file_path = self._get_memory_file_path(agent_name)
if not file_path.exists(): if not file_path.exists():
return create_empty_memory() return create_empty_memory()
@@ -113,42 +105,40 @@ class FileMemoryStorage(MemoryStorage):
logger.warning("Failed to load memory file: %s", e) logger.warning("Failed to load memory file: %s", e)
return create_empty_memory() return create_empty_memory()
def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: def load(self, agent_name: str | None = None) -> dict[str, Any]:
"""Load memory data (cached with file modification time check).""" """Load memory data (cached with file modification time check)."""
file_path = self._get_memory_file_path(agent_name, user_id=user_id) file_path = self._get_memory_file_path(agent_name)
try: try:
current_mtime = file_path.stat().st_mtime if file_path.exists() else None current_mtime = file_path.stat().st_mtime if file_path.exists() else None
except OSError: except OSError:
current_mtime = None current_mtime = None
cache_key = (user_id, agent_name) cached = self._memory_cache.get(agent_name)
cached = self._memory_cache.get(cache_key)
if cached is None or cached[1] != current_mtime: if cached is None or cached[1] != current_mtime:
memory_data = self._load_memory_from_file(agent_name, user_id=user_id) memory_data = self._load_memory_from_file(agent_name)
self._memory_cache[cache_key] = (memory_data, current_mtime) self._memory_cache[agent_name] = (memory_data, current_mtime)
return memory_data return memory_data
return cached[0] return cached[0]
def reload(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: def reload(self, agent_name: str | None = None) -> dict[str, Any]:
"""Reload memory data from file, forcing cache invalidation.""" """Reload memory data from file, forcing cache invalidation."""
file_path = self._get_memory_file_path(agent_name, user_id=user_id) file_path = self._get_memory_file_path(agent_name)
memory_data = self._load_memory_from_file(agent_name, user_id=user_id) memory_data = self._load_memory_from_file(agent_name)
try: try:
mtime = file_path.stat().st_mtime if file_path.exists() else None mtime = file_path.stat().st_mtime if file_path.exists() else None
except OSError: except OSError:
mtime = None mtime = None
cache_key = (user_id, agent_name) self._memory_cache[agent_name] = (memory_data, mtime)
self._memory_cache[cache_key] = (memory_data, mtime)
return memory_data return memory_data
def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool: def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
"""Save memory data to file and update cache.""" """Save memory data to file and update cache."""
file_path = self._get_memory_file_path(agent_name, user_id=user_id) file_path = self._get_memory_file_path(agent_name)
try: try:
file_path.parent.mkdir(parents=True, exist_ok=True) file_path.parent.mkdir(parents=True, exist_ok=True)
@@ -165,8 +155,7 @@ class FileMemoryStorage(MemoryStorage):
except OSError: except OSError:
mtime = None mtime = None
cache_key = (user_id, agent_name) self._memory_cache[agent_name] = (memory_data, mtime)
self._memory_cache[cache_key] = (memory_data, mtime)
logger.info("Memory saved to %s", file_path) logger.info("Memory saved to %s", file_path)
return True return True
except OSError as e: except OSError as e:
@@ -188,7 +177,7 @@ def get_memory_storage() -> MemoryStorage:
if _storage_instance is not None: if _storage_instance is not None:
return _storage_instance return _storage_instance
config = get_memory_config() config = AppConfig.current().memory
storage_class_path = config.storage_class storage_class_path = config.storage_class
try: try:
@@ -16,7 +16,7 @@ from deerflow.agents.memory.storage import (
get_memory_storage, get_memory_storage,
utc_now_iso_z, utc_now_iso_z,
) )
from deerflow.config.memory_config import get_memory_config from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -27,28 +27,27 @@ def _create_empty_memory() -> dict[str, Any]:
return create_empty_memory() return create_empty_memory()
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool: def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
"""Backward-compatible wrapper around the configured memory storage save path.""" """Backward-compatible wrapper around the configured memory storage save path."""
return get_memory_storage().save(memory_data, agent_name, user_id=user_id) return get_memory_storage().save(memory_data, agent_name)
def get_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: def get_memory_data(agent_name: str | None = None) -> dict[str, Any]:
"""Get the current memory data via storage provider.""" """Get the current memory data via storage provider."""
return get_memory_storage().load(agent_name, user_id=user_id) return get_memory_storage().load(agent_name)
def reload_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]:
"""Reload memory data via storage provider.""" """Reload memory data via storage provider."""
return get_memory_storage().reload(agent_name, user_id=user_id) return get_memory_storage().reload(agent_name)
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None) -> dict[str, Any]:
"""Persist imported memory data via storage provider. """Persist imported memory data via storage provider.
Args: Args:
memory_data: Full memory payload to persist. memory_data: Full memory payload to persist.
agent_name: If provided, imports into per-agent memory. agent_name: If provided, imports into per-agent memory.
user_id: If provided, scopes memory to a specific user.
Returns: Returns:
The saved memory data after storage normalization. The saved memory data after storage normalization.
@@ -57,15 +56,15 @@ def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = Non
OSError: If persisting the imported memory fails. OSError: If persisting the imported memory fails.
""" """
storage = get_memory_storage() storage = get_memory_storage()
if not storage.save(memory_data, agent_name, user_id=user_id): if not storage.save(memory_data, agent_name):
raise OSError("Failed to save imported memory data") raise OSError("Failed to save imported memory data")
return storage.load(agent_name, user_id=user_id) return storage.load(agent_name)
def clear_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: def clear_memory_data(agent_name: str | None = None) -> dict[str, Any]:
"""Clear all stored memory data and persist an empty structure.""" """Clear all stored memory data and persist an empty structure."""
cleared_memory = create_empty_memory() cleared_memory = create_empty_memory()
if not _save_memory_to_file(cleared_memory, agent_name, user_id=user_id): if not _save_memory_to_file(cleared_memory, agent_name):
raise OSError("Failed to save cleared memory data") raise OSError("Failed to save cleared memory data")
return cleared_memory return cleared_memory
@@ -82,8 +81,6 @@ def create_memory_fact(
category: str = "context", category: str = "context",
confidence: float = 0.5, confidence: float = 0.5,
agent_name: str | None = None, agent_name: str | None = None,
*,
user_id: str | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Create a new fact and persist the updated memory data.""" """Create a new fact and persist the updated memory data."""
normalized_content = content.strip() normalized_content = content.strip()
@@ -93,7 +90,7 @@ def create_memory_fact(
normalized_category = category.strip() or "context" normalized_category = category.strip() or "context"
validated_confidence = _validate_confidence(confidence) validated_confidence = _validate_confidence(confidence)
now = utc_now_iso_z() now = utc_now_iso_z()
memory_data = get_memory_data(agent_name, user_id=user_id) memory_data = get_memory_data(agent_name)
updated_memory = dict(memory_data) updated_memory = dict(memory_data)
facts = list(memory_data.get("facts", [])) facts = list(memory_data.get("facts", []))
facts.append( facts.append(
@@ -108,15 +105,15 @@ def create_memory_fact(
) )
updated_memory["facts"] = facts updated_memory["facts"] = facts
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id): if not _save_memory_to_file(updated_memory, agent_name):
raise OSError("Failed to save memory data after creating fact") raise OSError("Failed to save memory data after creating fact")
return updated_memory return updated_memory
def delete_memory_fact(fact_id: str, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str, Any]:
"""Delete a fact by its id and persist the updated memory data.""" """Delete a fact by its id and persist the updated memory data."""
memory_data = get_memory_data(agent_name, user_id=user_id) memory_data = get_memory_data(agent_name)
facts = memory_data.get("facts", []) facts = memory_data.get("facts", [])
updated_facts = [fact for fact in facts if fact.get("id") != fact_id] updated_facts = [fact for fact in facts if fact.get("id") != fact_id]
if len(updated_facts) == len(facts): if len(updated_facts) == len(facts):
@@ -125,7 +122,7 @@ def delete_memory_fact(fact_id: str, agent_name: str | None = None, *, user_id:
updated_memory = dict(memory_data) updated_memory = dict(memory_data)
updated_memory["facts"] = updated_facts updated_memory["facts"] = updated_facts
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id): if not _save_memory_to_file(updated_memory, agent_name):
raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'") raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'")
return updated_memory return updated_memory
@@ -137,11 +134,9 @@ def update_memory_fact(
category: str | None = None, category: str | None = None,
confidence: float | None = None, confidence: float | None = None,
agent_name: str | None = None, agent_name: str | None = None,
*,
user_id: str | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Update an existing fact and persist the updated memory data.""" """Update an existing fact and persist the updated memory data."""
memory_data = get_memory_data(agent_name, user_id=user_id) memory_data = get_memory_data(agent_name)
updated_memory = dict(memory_data) updated_memory = dict(memory_data)
updated_facts: list[dict[str, Any]] = [] updated_facts: list[dict[str, Any]] = []
found = False found = False
@@ -168,7 +163,7 @@ def update_memory_fact(
updated_memory["facts"] = updated_facts updated_memory["facts"] = updated_facts
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id): if not _save_memory_to_file(updated_memory, agent_name):
raise OSError(f"Failed to save memory data after updating fact '{fact_id}'") raise OSError(f"Failed to save memory data after updating fact '{fact_id}'")
return updated_memory return updated_memory
@@ -270,7 +265,7 @@ class MemoryUpdater:
def _get_model(self): def _get_model(self):
"""Get the model for memory updates.""" """Get the model for memory updates."""
config = get_memory_config() config = AppConfig.current().memory
model_name = self._model_name or config.model_name model_name = self._model_name or config.model_name
return create_chat_model(name=model_name, thinking_enabled=False) return create_chat_model(name=model_name, thinking_enabled=False)
@@ -281,7 +276,6 @@ class MemoryUpdater:
agent_name: str | None = None, agent_name: str | None = None,
correction_detected: bool = False, correction_detected: bool = False,
reinforcement_detected: bool = False, reinforcement_detected: bool = False,
user_id: str | None = None,
) -> bool: ) -> bool:
"""Update memory based on conversation messages. """Update memory based on conversation messages.
@@ -291,12 +285,11 @@ class MemoryUpdater:
agent_name: If provided, updates per-agent memory. If None, updates global memory. agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal. correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal. reinforcement_detected: Whether recent turns include a positive reinforcement signal.
user_id: If provided, scopes memory to a specific user.
Returns: Returns:
True if update was successful, False otherwise. True if update was successful, False otherwise.
""" """
config = get_memory_config() config = AppConfig.current().memory
if not config.enabled: if not config.enabled:
return False return False
@@ -305,7 +298,7 @@ class MemoryUpdater:
try: try:
# Get current memory # Get current memory
current_memory = get_memory_data(agent_name, user_id=user_id) current_memory = get_memory_data(agent_name)
# Format conversation for prompt # Format conversation for prompt
conversation_text = format_conversation_for_update(messages) conversation_text = format_conversation_for_update(messages)
@@ -360,7 +353,7 @@ class MemoryUpdater:
updated_memory = _strip_upload_mentions_from_memory(updated_memory) updated_memory = _strip_upload_mentions_from_memory(updated_memory)
# Save # Save
return get_memory_storage().save(updated_memory, agent_name, user_id=user_id) return get_memory_storage().save(updated_memory, agent_name)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.warning("Failed to parse LLM response for memory update: %s", e) logger.warning("Failed to parse LLM response for memory update: %s", e)
@@ -385,7 +378,7 @@ class MemoryUpdater:
Returns: Returns:
Updated memory data. Updated memory data.
""" """
config = get_memory_config() config = AppConfig.current().memory
now = utc_now_iso_z() now = utc_now_iso_z()
# Update user sections # Update user sections
@@ -462,7 +455,6 @@ def update_memory_from_conversation(
agent_name: str | None = None, agent_name: str | None = None,
correction_detected: bool = False, correction_detected: bool = False,
reinforcement_detected: bool = False, reinforcement_detected: bool = False,
user_id: str | None = None,
) -> bool: ) -> bool:
"""Convenience function to update memory from a conversation. """Convenience function to update memory from a conversation.
@@ -472,10 +464,9 @@ def update_memory_from_conversation(
agent_name: If provided, updates per-agent memory. If None, updates global memory. agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal. correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal. reinforcement_detected: Whether recent turns include a positive reinforcement signal.
user_id: If provided, scopes memory to a specific user.
Returns: Returns:
True if successful, False otherwise. True if successful, False otherwise.
""" """
updater = MemoryUpdater() updater = MemoryUpdater()
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected, user_id=user_id) return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected)
@@ -24,6 +24,8 @@ from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.config.deer_flow_context import DeerFlowContext
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Defaults — can be overridden via constructor # Defaults — can be overridden via constructor
@@ -31,6 +33,8 @@ _DEFAULT_WARN_THRESHOLD = 3 # inject warning after 3 identical calls
_DEFAULT_HARD_LIMIT = 5 # force-stop after 5 identical calls _DEFAULT_HARD_LIMIT = 5 # force-stop after 5 identical calls
_DEFAULT_WINDOW_SIZE = 20 # track last N tool calls _DEFAULT_WINDOW_SIZE = 20 # track last N tool calls
_DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit _DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit
_DEFAULT_TOOL_FREQ_WARN = 30 # warn after 30 calls to the same tool type
_DEFAULT_TOOL_FREQ_HARD_LIMIT = 50 # force-stop after 50 calls to the same tool type
def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]: def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]:
@@ -125,8 +129,14 @@ def _hash_tool_calls(tool_calls: list[dict]) -> str:
_WARNING_MSG = "[LOOP DETECTED] You are repeating the same tool calls. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far." _WARNING_MSG = "[LOOP DETECTED] You are repeating the same tool calls. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."
_TOOL_FREQ_WARNING_MSG = (
"[LOOP DETECTED] You have called {tool_name} {count} times without producing a final answer. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."
)
_HARD_STOP_MSG = "[FORCED STOP] Repeated tool calls exceeded the safety limit. Producing final answer with results collected so far." _HARD_STOP_MSG = "[FORCED STOP] Repeated tool calls exceeded the safety limit. Producing final answer with results collected so far."
_TOOL_FREQ_HARD_STOP_MSG = "[FORCED STOP] Tool {tool_name} called {count} times — exceeded the per-tool safety limit. Producing final answer with results collected so far."
class LoopDetectionMiddleware(AgentMiddleware[AgentState]): class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
"""Detects and breaks repetitive tool call loops. """Detects and breaks repetitive tool call loops.
@@ -140,6 +150,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
Default: 20. Default: 20.
max_tracked_threads: Maximum number of threads to track before max_tracked_threads: Maximum number of threads to track before
evicting the least recently used. Default: 100. evicting the least recently used. Default: 100.
tool_freq_warn: Number of calls to the same tool *type* (regardless
of arguments) before injecting a frequency warning. Catches
cross-file read loops that hash-based detection misses.
Default: 30.
tool_freq_hard_limit: Number of calls to the same tool type before
forcing a stop. Default: 50.
""" """
def __init__( def __init__(
@@ -148,23 +164,27 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
hard_limit: int = _DEFAULT_HARD_LIMIT, hard_limit: int = _DEFAULT_HARD_LIMIT,
window_size: int = _DEFAULT_WINDOW_SIZE, window_size: int = _DEFAULT_WINDOW_SIZE,
max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS, max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN,
tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT,
): ):
super().__init__() super().__init__()
self.warn_threshold = warn_threshold self.warn_threshold = warn_threshold
self.hard_limit = hard_limit self.hard_limit = hard_limit
self.window_size = window_size self.window_size = window_size
self.max_tracked_threads = max_tracked_threads self.max_tracked_threads = max_tracked_threads
self.tool_freq_warn = tool_freq_warn
self.tool_freq_hard_limit = tool_freq_hard_limit
self._lock = threading.Lock() self._lock = threading.Lock()
# Per-thread tracking using OrderedDict for LRU eviction # Per-thread tracking using OrderedDict for LRU eviction
self._history: OrderedDict[str, list[str]] = OrderedDict() self._history: OrderedDict[str, list[str]] = OrderedDict()
self._warned: dict[str, set[str]] = defaultdict(set) self._warned: dict[str, set[str]] = defaultdict(set)
# Per-thread, per-tool-type cumulative call counts
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
def _get_thread_id(self, runtime: Runtime) -> str: def _get_thread_id(self, runtime: Runtime[DeerFlowContext]) -> str:
"""Extract thread_id from runtime context for per-thread tracking.""" """Extract thread_id from runtime context for per-thread tracking."""
thread_id = runtime.context.get("thread_id") if runtime.context else None return runtime.context.thread_id or "default"
if thread_id:
return thread_id
return "default"
def _evict_if_needed(self) -> None: def _evict_if_needed(self) -> None:
"""Evict least recently used threads if over the limit. """Evict least recently used threads if over the limit.
@@ -174,11 +194,19 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
while len(self._history) > self.max_tracked_threads: while len(self._history) > self.max_tracked_threads:
evicted_id, _ = self._history.popitem(last=False) evicted_id, _ = self._history.popitem(last=False)
self._warned.pop(evicted_id, None) self._warned.pop(evicted_id, None)
self._tool_freq.pop(evicted_id, None)
self._tool_freq_warned.pop(evicted_id, None)
logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id) logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)
def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]: def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
"""Track tool calls and check for loops. """Track tool calls and check for loops.
Two detection layers:
1. **Hash-based** (existing): catches identical tool call sets.
2. **Frequency-based** (new): catches the same *tool type* being
called many times with varying arguments (e.g. ``read_file``
on 40 different files).
Returns: Returns:
(warning_message_or_none, should_hard_stop) (warning_message_or_none, should_hard_stop)
""" """
@@ -213,6 +241,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
count = history.count(call_hash) count = history.count(call_hash)
tool_names = [tc.get("name", "?") for tc in tool_calls] tool_names = [tc.get("name", "?") for tc in tool_calls]
# --- Layer 1: hash-based (identical call sets) ---
if count >= self.hard_limit: if count >= self.hard_limit:
logger.error( logger.error(
"Loop hard limit reached — forcing stop", "Loop hard limit reached — forcing stop",
@@ -239,8 +268,40 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
}, },
) )
return _WARNING_MSG, False return _WARNING_MSG, False
# Warning already injected for this hash — suppress
return None, False # --- Layer 2: per-tool-type frequency ---
freq = self._tool_freq[thread_id]
for tc in tool_calls:
name = tc.get("name", "")
if not name:
continue
freq[name] += 1
tc_count = freq[name]
if tc_count >= self.tool_freq_hard_limit:
logger.error(
"Tool frequency hard limit reached — forcing stop",
extra={
"thread_id": thread_id,
"tool_name": name,
"count": tc_count,
},
)
return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True
if tc_count >= self.tool_freq_warn:
warned = self._tool_freq_warned[thread_id]
if name not in warned:
warned.add(name)
logger.warning(
"Tool frequency warning — too many calls to same tool type",
extra={
"thread_id": thread_id,
"tool_name": name,
"count": tc_count,
},
)
return _TOOL_FREQ_WARNING_MSG.format(tool_name=name, count=tc_count), False
return None, False return None, False
@@ -271,7 +332,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
stripped_msg = last_msg.model_copy( stripped_msg = last_msg.model_copy(
update={ update={
"tool_calls": [], "tool_calls": [],
"content": self._append_text(last_msg.content, _HARD_STOP_MSG), "content": self._append_text(last_msg.content, warning),
} }
) )
return {"messages": [stripped_msg]} return {"messages": [stripped_msg]}
@@ -283,16 +344,16 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
# the conversation; injecting one mid-conversation crashes # the conversation; injecting one mid-conversation crashes
# langchain_anthropic's _format_messages(). HumanMessage works # langchain_anthropic's _format_messages(). HumanMessage works
# with all providers. See #1299. # with all providers. See #1299.
return {"messages": [HumanMessage(content=warning, name="loop_warning")]} return {"messages": [HumanMessage(content=warning)]}
return None return None
@override @override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None: def after_model(self, state: AgentState, runtime: Runtime[DeerFlowContext]) -> dict | None:
return self._apply(state, runtime) return self._apply(state, runtime)
@override @override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None: async def aafter_model(self, state: AgentState, runtime: Runtime[DeerFlowContext]) -> dict | None:
return self._apply(state, runtime) return self._apply(state, runtime)
def reset(self, thread_id: str | None = None) -> None: def reset(self, thread_id: str | None = None) -> None:
@@ -301,6 +362,10 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
if thread_id: if thread_id:
self._history.pop(thread_id, None) self._history.pop(thread_id, None)
self._warned.pop(thread_id, None) self._warned.pop(thread_id, None)
self._tool_freq.pop(thread_id, None)
self._tool_freq_warned.pop(thread_id, None)
else: else:
self._history.clear() self._history.clear()
self._warned.clear() self._warned.clear()
self._tool_freq.clear()
self._tool_freq_warned.clear()
@@ -6,12 +6,10 @@ from typing import Any, override
from langchain.agents import AgentState from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.agents.memory.queue import get_memory_queue from deerflow.agents.memory.queue import get_memory_queue
from deerflow.config.memory_config import get_memory_config from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -194,7 +192,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
self._agent_name = agent_name self._agent_name = agent_name
@override @override
def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None: def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
"""Queue conversation for memory update after agent completes. """Queue conversation for memory update after agent completes.
Args: Args:
@@ -204,15 +202,11 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
Returns: Returns:
None (no state changes needed from this middleware). None (no state changes needed from this middleware).
""" """
config = get_memory_config() memory_config = runtime.context.app_config.memory
if not config.enabled: if not memory_config.enabled:
return None return None
# Get thread ID from runtime context first, then fall back to LangGraph's configurable metadata thread_id = runtime.context.thread_id
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id is None:
config_data = get_config()
thread_id = config_data.get("configurable", {}).get("thread_id")
if not thread_id: if not thread_id:
logger.debug("No thread_id in context, skipping memory update") logger.debug("No thread_id in context, skipping memory update")
return None return None
@@ -237,16 +231,11 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
# Queue the filtered conversation for memory update # Queue the filtered conversation for memory update
correction_detected = detect_correction(filtered_messages) correction_detected = detect_correction(filtered_messages)
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages) reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
# Capture user_id at enqueue time while the request context is still alive.
# threading.Timer fires on a different thread where ContextVar values are not
# propagated, so we must store user_id explicitly in ConversationContext.
user_id = get_effective_user_id()
queue = get_memory_queue() queue = get_memory_queue()
queue.add( queue.add(
thread_id=thread_id, thread_id=thread_id,
messages=filtered_messages, messages=filtered_messages,
agent_name=self._agent_name, agent_name=self._agent_name,
user_id=user_id,
correction_detected=correction_detected, correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected, reinforcement_detected=reinforcement_detected,
) )
@@ -1,13 +0,0 @@
from typing import override
from langchain.agents.middleware import SummarizationMiddleware as BaseSummarizationMiddleware
from langchain_core.messages.human import HumanMessage
class SummarizationMiddleware(BaseSummarizationMiddleware):
@override
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
"""Override the base implementation to let the human message with the special name 'summary'.
And this message will be ignored to display in the frontend, but still can be used as context for the model.
"""
return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
@@ -1,16 +1,13 @@
import logging import logging
from datetime import UTC, datetime
from typing import NotRequired, override from typing import NotRequired, override
from langchain.agents import AgentState from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.config import get_config
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.agents.thread_state import ThreadDataState from deerflow.agents.thread_state import ThreadDataState
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.paths import Paths, get_paths from deerflow.config.paths import Paths, get_paths
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -49,70 +46,50 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
self._paths = Paths(base_dir) if base_dir else get_paths() self._paths = Paths(base_dir) if base_dir else get_paths()
self._lazy_init = lazy_init self._lazy_init = lazy_init
def _get_thread_paths(self, thread_id: str, user_id: str | None = None) -> dict[str, str]: def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
"""Get the paths for a thread's data directories. """Get the paths for a thread's data directories.
Args: Args:
thread_id: The thread ID. thread_id: The thread ID.
user_id: Optional user ID for per-user path isolation.
Returns: Returns:
Dictionary with workspace_path, uploads_path, and outputs_path. Dictionary with workspace_path, uploads_path, and outputs_path.
""" """
return { return {
"workspace_path": str(self._paths.sandbox_work_dir(thread_id, user_id=user_id)), "workspace_path": str(self._paths.sandbox_work_dir(thread_id)),
"uploads_path": str(self._paths.sandbox_uploads_dir(thread_id, user_id=user_id)), "uploads_path": str(self._paths.sandbox_uploads_dir(thread_id)),
"outputs_path": str(self._paths.sandbox_outputs_dir(thread_id, user_id=user_id)), "outputs_path": str(self._paths.sandbox_outputs_dir(thread_id)),
} }
def _create_thread_directories(self, thread_id: str, user_id: str | None = None) -> dict[str, str]: def _create_thread_directories(self, thread_id: str) -> dict[str, str]:
"""Create the thread data directories. """Create the thread data directories.
Args: Args:
thread_id: The thread ID. thread_id: The thread ID.
user_id: Optional user ID for per-user path isolation.
Returns: Returns:
Dictionary with the created directory paths. Dictionary with the created directory paths.
""" """
self._paths.ensure_thread_dirs(thread_id, user_id=user_id) self._paths.ensure_thread_dirs(thread_id)
return self._get_thread_paths(thread_id, user_id=user_id) return self._get_thread_paths(thread_id)
@override @override
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None: def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
context = runtime.context or {} thread_id = runtime.context.thread_id
thread_id = context.get("thread_id")
if thread_id is None:
config = get_config()
thread_id = config.get("configurable", {}).get("thread_id")
if thread_id is None: if not thread_id:
raise ValueError("Thread ID is required in runtime context or config.configurable") raise ValueError("Thread ID is required in runtime context or config.configurable")
user_id = get_effective_user_id()
if self._lazy_init: if self._lazy_init:
# Lazy initialization: only compute paths, don't create directories # Lazy initialization: only compute paths, don't create directories
paths = self._get_thread_paths(thread_id, user_id=user_id) paths = self._get_thread_paths(thread_id)
else: else:
# Eager initialization: create directories immediately # Eager initialization: create directories immediately
paths = self._create_thread_directories(thread_id, user_id=user_id) paths = self._create_thread_directories(thread_id)
logger.debug("Created thread data directories for thread %s", thread_id) logger.debug("Created thread data directories for thread %s", thread_id)
messages = list(state.get("messages", []))
last_message = messages[-1] if messages else None
if last_message and isinstance(last_message, HumanMessage):
messages[-1] = HumanMessage(
content=last_message.content,
id=last_message.id,
name=last_message.name or "user-input",
additional_kwargs={**last_message.additional_kwargs, "run_id": runtime.context.get("run_id"), "timestamp": datetime.now(UTC).isoformat()},
)
return { return {
"thread_data": { "thread_data": {
**paths, **paths,
}, }
"messages": messages,
} }
@@ -1,14 +1,13 @@
"""Middleware for automatic thread title generation.""" """Middleware for automatic thread title generation."""
import logging import logging
from typing import Any, NotRequired, override from typing import NotRequired, override
from langchain.agents import AgentState from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.config.title_config import get_title_config from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -46,7 +45,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
def _should_generate_title(self, state: TitleMiddlewareState) -> bool: def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
"""Check if we should generate a title for this thread.""" """Check if we should generate a title for this thread."""
config = get_title_config() config = AppConfig.current().title
if not config.enabled: if not config.enabled:
return False return False
@@ -71,7 +70,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
Returns (prompt_string, user_msg) so callers can use user_msg as fallback. Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
""" """
config = get_title_config() config = AppConfig.current().title
messages = state.get("messages", []) messages = state.get("messages", [])
user_msg_content = next((m.content for m in messages if m.type == "human"), "") user_msg_content = next((m.content for m in messages if m.type == "human"), "")
@@ -89,32 +88,18 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
def _parse_title(self, content: object) -> str: def _parse_title(self, content: object) -> str:
"""Normalize model output into a clean title string.""" """Normalize model output into a clean title string."""
config = get_title_config() config = AppConfig.current().title
title_content = self._normalize_content(content) title_content = self._normalize_content(content)
title = title_content.strip().strip('"').strip("'") title = title_content.strip().strip('"').strip("'")
return title[: config.max_chars] if len(title) > config.max_chars else title return title[: config.max_chars] if len(title) > config.max_chars else title
def _fallback_title(self, user_msg: str) -> str: def _fallback_title(self, user_msg: str) -> str:
config = get_title_config() config = AppConfig.current().title
fallback_chars = min(config.max_chars, 50) fallback_chars = min(config.max_chars, 50)
if len(user_msg) > fallback_chars: if len(user_msg) > fallback_chars:
return user_msg[:fallback_chars].rstrip() + "..." return user_msg[:fallback_chars].rstrip() + "..."
return user_msg if user_msg else "New Conversation" return user_msg if user_msg else "New Conversation"
def _get_runnable_config(self) -> dict[str, Any]:
"""Inherit the parent RunnableConfig and add middleware tag.
This ensures RunJournal identifies LLM calls from this middleware
as ``middleware:title`` instead of ``lead_agent``.
"""
try:
parent = get_config()
except Exception:
parent = {}
config = {**parent}
config["tags"] = [*(config.get("tags") or []), "middleware:title"]
return config
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None: def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
"""Generate a local fallback title without blocking on an LLM call.""" """Generate a local fallback title without blocking on an LLM call."""
if not self._should_generate_title(state): if not self._should_generate_title(state):
@@ -128,7 +113,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
if not self._should_generate_title(state): if not self._should_generate_title(state):
return None return None
config = get_title_config() config = AppConfig.current().title
prompt, user_msg = self._build_title_prompt(state) prompt, user_msg = self._build_title_prompt(state)
try: try:
@@ -136,7 +121,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
model = create_chat_model(name=config.model_name, thinking_enabled=False) model = create_chat_model(name=config.model_name, thinking_enabled=False)
else: else:
model = create_chat_model(thinking_enabled=False) model = create_chat_model(thinking_enabled=False)
response = await model.ainvoke(prompt, config=self._get_runnable_config()) response = await model.ainvoke(prompt)
title = self._parse_title(response.content) title = self._parse_title(response.content)
if title: if title:
return {"title": title} return {"title": title}
@@ -94,9 +94,9 @@ def _build_runtime_middlewares(
middlewares.append(LLMErrorHandlingMiddleware()) middlewares.append(LLMErrorHandlingMiddleware())
# Guardrail middleware (if configured) # Guardrail middleware (if configured)
from deerflow.config.guardrails_config import get_guardrails_config from deerflow.config.app_config import AppConfig
guardrails_config = get_guardrails_config() guardrails_config = AppConfig.current().guardrails
if guardrails_config.enabled and guardrails_config.provider: if guardrails_config.enabled and guardrails_config.provider:
import inspect import inspect
@@ -9,8 +9,8 @@ from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.paths import Paths, get_paths from deerflow.config.paths import Paths, get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.utils.file_conversion import extract_outline from deerflow.utils.file_conversion import extract_outline
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -185,7 +185,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
return files if files else None return files if files else None
@override @override
def before_agent(self, state: UploadsMiddlewareState, runtime: Runtime) -> dict | None: def before_agent(self, state: UploadsMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
"""Inject uploaded files information before agent execution. """Inject uploaded files information before agent execution.
New files come from the current message's additional_kwargs.files. New files come from the current message's additional_kwargs.files.
@@ -214,15 +214,8 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
return None return None
# Resolve uploads directory for existence checks # Resolve uploads directory for existence checks
thread_id = (runtime.context or {}).get("thread_id") thread_id = runtime.context.thread_id
if thread_id is None: uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
try:
from langgraph.config import get_config
thread_id = get_config().get("configurable", {}).get("thread_id")
except RuntimeError:
pass # get_config() raises outside a runnable context (e.g. unit tests)
uploads_dir = self._paths.sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) if thread_id else None
# Get newly uploaded files from the current message's additional_kwargs.files # Get newly uploaded files from the current message's additional_kwargs.files
new_files = self._files_from_kwargs(last_message, uploads_dir) or [] new_files = self._files_from_kwargs(last_message, uploads_dir) or []
@@ -263,23 +256,26 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
files_message = self._create_files_message(new_files, historical_files) files_message = self._create_files_message(new_files, historical_files)
# Extract original content - handle both string and list formats # Extract original content - handle both string and list formats
original_content = "" original_content = last_message.content
if isinstance(last_message.content, str): if isinstance(original_content, str):
original_content = last_message.content # Simple case: string content, just prepend files message
elif isinstance(last_message.content, list): updated_content = f"{files_message}\n\n{original_content}"
text_parts = [] elif isinstance(original_content, list):
for block in last_message.content: # Complex case: list content (multimodal), preserve all blocks
if isinstance(block, dict) and block.get("type") == "text": # Prepend files message as the first text block
text_parts.append(block.get("text", "")) files_block = {"type": "text", "text": f"{files_message}\n\n"}
original_content = "\n".join(text_parts) # Keep all original blocks (including images)
updated_content = [files_block, *original_content]
else:
# Other types, preserve as-is
updated_content = original_content
# Create new message with combined content. # Create new message with combined content.
# Preserve additional_kwargs (including files metadata) so the frontend # Preserve additional_kwargs (including files metadata) so the frontend
# can read structured file info from the streamed message. # can read structured file info from the streamed message.
updated_message = HumanMessage( updated_message = HumanMessage(
content=f"{files_message}\n\n{original_content}", content=updated_content,
id=last_message.id, id=last_message.id,
name=last_message.name,
additional_kwargs=last_message.additional_kwargs, additional_kwargs=last_message.additional_kwargs,
) )
+27 -30
View File
@@ -36,11 +36,11 @@ from deerflow.agents.lead_agent.agent import _build_middlewares
from deerflow.agents.lead_agent.prompt import apply_prompt_template from deerflow.agents.lead_agent.prompt import apply_prompt_template
from deerflow.agents.thread_state import ThreadState from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import AGENT_NAME_PATTERN from deerflow.config.agents_config import AGENT_NAME_PATTERN
from deerflow.config.app_config import get_app_config, reload_app_config from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig
from deerflow.config.paths import get_paths from deerflow.config.paths import get_paths
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.skills.installer import install_skill_from_archive from deerflow.skills.installer import install_skill_from_archive
from deerflow.uploads.manager import ( from deerflow.uploads.manager import (
claim_unique_filename, claim_unique_filename,
@@ -142,8 +142,8 @@ class DeerFlowClient:
middlewares: Optional list of custom middlewares to inject into the agent. middlewares: Optional list of custom middlewares to inject into the agent.
""" """
if config_path is not None: if config_path is not None:
reload_app_config(config_path) AppConfig.init(AppConfig.from_file(config_path))
self._app_config = get_app_config() self._app_config = AppConfig.current()
if agent_name is not None and not AGENT_NAME_PATTERN.match(agent_name): if agent_name is not None and not AGENT_NAME_PATTERN.match(agent_name):
raise ValueError(f"Invalid agent name '{agent_name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}") raise ValueError(f"Invalid agent name '{agent_name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
@@ -241,7 +241,7 @@ class DeerFlowClient:
} }
checkpointer = self._checkpointer checkpointer = self._checkpointer
if checkpointer is None: if checkpointer is None:
from deerflow.runtime.checkpointer import get_checkpointer from deerflow.agents.checkpointer import get_checkpointer
checkpointer = get_checkpointer() checkpointer = get_checkpointer()
if checkpointer is not None: if checkpointer is not None:
@@ -375,7 +375,7 @@ class DeerFlowClient:
""" """
checkpointer = self._checkpointer checkpointer = self._checkpointer
if checkpointer is None: if checkpointer is None:
from deerflow.runtime.checkpointer.provider import get_checkpointer from deerflow.agents.checkpointer.provider import get_checkpointer
checkpointer = get_checkpointer() checkpointer = get_checkpointer()
@@ -430,7 +430,7 @@ class DeerFlowClient:
""" """
checkpointer = self._checkpointer checkpointer = self._checkpointer
if checkpointer is None: if checkpointer is None:
from deerflow.runtime.checkpointer.provider import get_checkpointer from deerflow.agents.checkpointer.provider import get_checkpointer
checkpointer = get_checkpointer() checkpointer = get_checkpointer()
@@ -552,9 +552,7 @@ class DeerFlowClient:
self._ensure_agent(config) self._ensure_agent(config)
state: dict[str, Any] = {"messages": [HumanMessage(content=message)]} state: dict[str, Any] = {"messages": [HumanMessage(content=message)]}
context = {"thread_id": thread_id} context = DeerFlowContext(app_config=self._app_config, thread_id=thread_id, agent_name=self._agent_name)
if self._agent_name:
context["agent_name"] = self._agent_name
seen_ids: set[str] = set() seen_ids: set[str] = set()
# Cross-mode handoff: ids already streamed via LangGraph ``messages`` # Cross-mode handoff: ids already streamed via LangGraph ``messages``
@@ -770,19 +768,19 @@ class DeerFlowClient:
""" """
from deerflow.agents.memory.updater import get_memory_data from deerflow.agents.memory.updater import get_memory_data
return get_memory_data(user_id=get_effective_user_id()) return get_memory_data()
def export_memory(self) -> dict: def export_memory(self) -> dict:
"""Export current memory data for backup or transfer.""" """Export current memory data for backup or transfer."""
from deerflow.agents.memory.updater import get_memory_data from deerflow.agents.memory.updater import get_memory_data
return get_memory_data(user_id=get_effective_user_id()) return get_memory_data()
def import_memory(self, memory_data: dict) -> dict: def import_memory(self, memory_data: dict) -> dict:
"""Import and persist full memory data.""" """Import and persist full memory data."""
from deerflow.agents.memory.updater import import_memory_data from deerflow.agents.memory.updater import import_memory_data
return import_memory_data(memory_data, user_id=get_effective_user_id()) return import_memory_data(memory_data)
def get_model(self, name: str) -> dict | None: def get_model(self, name: str) -> dict | None:
"""Get a specific model's configuration by name. """Get a specific model's configuration by name.
@@ -817,8 +815,8 @@ class DeerFlowClient:
Dict with "mcp_servers" key mapping server name to config, Dict with "mcp_servers" key mapping server name to config,
matching the Gateway API ``McpConfigResponse`` schema. matching the Gateway API ``McpConfigResponse`` schema.
""" """
config = get_extensions_config() ext = AppConfig.current().extensions
return {"mcp_servers": {name: server.model_dump() for name, server in config.mcp_servers.items()}} return {"mcp_servers": {name: server.model_dump() for name, server in ext.mcp_servers.items()}}
def update_mcp_config(self, mcp_servers: dict[str, dict]) -> dict: def update_mcp_config(self, mcp_servers: dict[str, dict]) -> dict:
"""Update MCP server configurations. """Update MCP server configurations.
@@ -840,18 +838,19 @@ class DeerFlowClient:
if config_path is None: if config_path is None:
raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.") raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.")
current_config = get_extensions_config() current_ext = AppConfig.current().extensions
config_data = { config_data = {
"mcpServers": mcp_servers, "mcpServers": mcp_servers,
"skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()}, "skills": {name: {"enabled": skill.enabled} for name, skill in current_ext.skills.items()},
} }
self._atomic_write_json(config_path, config_data) self._atomic_write_json(config_path, config_data)
self._agent = None self._agent = None
self._agent_config_key = None self._agent_config_key = None
reloaded = reload_extensions_config() AppConfig.init(AppConfig.from_file())
reloaded = AppConfig.current().extensions
return {"mcp_servers": {name: server.model_dump() for name, server in reloaded.mcp_servers.items()}} return {"mcp_servers": {name: server.model_dump() for name, server in reloaded.mcp_servers.items()}}
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@@ -905,19 +904,19 @@ class DeerFlowClient:
if config_path is None: if config_path is None:
raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.") raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.")
extensions_config = get_extensions_config() ext = AppConfig.current().extensions
extensions_config.skills[name] = SkillStateConfig(enabled=enabled) ext.skills[name] = SkillStateConfig(enabled=enabled)
config_data = { config_data = {
"mcpServers": {n: s.model_dump() for n, s in extensions_config.mcp_servers.items()}, "mcpServers": {n: s.model_dump() for n, s in ext.mcp_servers.items()},
"skills": {n: {"enabled": sc.enabled} for n, sc in extensions_config.skills.items()}, "skills": {n: {"enabled": sc.enabled} for n, sc in ext.skills.items()},
} }
self._atomic_write_json(config_path, config_data) self._atomic_write_json(config_path, config_data)
self._agent = None self._agent = None
self._agent_config_key = None self._agent_config_key = None
reload_extensions_config() AppConfig.init(AppConfig.from_file())
updated = next((s for s in load_skills(enabled_only=False) if s.name == name), None) updated = next((s for s in load_skills(enabled_only=False) if s.name == name), None)
if updated is None: if updated is None:
@@ -957,13 +956,13 @@ class DeerFlowClient:
""" """
from deerflow.agents.memory.updater import reload_memory_data from deerflow.agents.memory.updater import reload_memory_data
return reload_memory_data(user_id=get_effective_user_id()) return reload_memory_data()
def clear_memory(self) -> dict: def clear_memory(self) -> dict:
"""Clear all persisted memory data.""" """Clear all persisted memory data."""
from deerflow.agents.memory.updater import clear_memory_data from deerflow.agents.memory.updater import clear_memory_data
return clear_memory_data(user_id=get_effective_user_id()) return clear_memory_data()
def create_memory_fact(self, content: str, category: str = "context", confidence: float = 0.5) -> dict: def create_memory_fact(self, content: str, category: str = "context", confidence: float = 0.5) -> dict:
"""Create a single fact manually.""" """Create a single fact manually."""
@@ -1000,9 +999,7 @@ class DeerFlowClient:
Returns: Returns:
Memory config dict. Memory config dict.
""" """
from deerflow.config.memory_config import get_memory_config config = AppConfig.current().memory
config = get_memory_config()
return { return {
"enabled": config.enabled, "enabled": config.enabled,
"storage_path": config.storage_path, "storage_path": config.storage_path,
@@ -1180,7 +1177,7 @@ class DeerFlowClient:
ValueError: If the path is invalid. ValueError: If the path is invalid.
""" """
try: try:
actual = get_paths().resolve_virtual_path(thread_id, path, user_id=get_effective_user_id()) actual = get_paths().resolve_virtual_path(thread_id, path)
except ValueError as exc: except ValueError as exc:
if "traversal" in str(exc): if "traversal" in str(exc):
from deerflow.uploads.manager import PathTraversalError from deerflow.uploads.manager import PathTraversalError
@@ -25,9 +25,8 @@ except ImportError: # pragma: no cover - Windows fallback
fcntl = None # type: ignore[assignment] fcntl = None # type: ignore[assignment]
import msvcrt import msvcrt
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox import Sandbox from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import SandboxProvider from deerflow.sandbox.sandbox_provider import SandboxProvider
@@ -149,7 +148,7 @@ class AioSandboxProvider(SandboxProvider):
def _load_config(self) -> dict: def _load_config(self) -> dict:
"""Load sandbox configuration from app config.""" """Load sandbox configuration from app config."""
config = get_app_config() config = AppConfig.current()
sandbox_config = config.sandbox sandbox_config = config.sandbox
idle_timeout = getattr(sandbox_config, "idle_timeout", None) idle_timeout = getattr(sandbox_config, "idle_timeout", None)
@@ -261,16 +260,15 @@ class AioSandboxProvider(SandboxProvider):
mounted Docker socket (DooD), the host Docker daemon can resolve the paths. mounted Docker socket (DooD), the host Docker daemon can resolve the paths.
""" """
paths = get_paths() paths = get_paths()
user_id = get_effective_user_id() paths.ensure_thread_dirs(thread_id)
paths.ensure_thread_dirs(thread_id, user_id=user_id)
return [ return [
(paths.host_sandbox_work_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/workspace", False), (paths.host_sandbox_work_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/workspace", False),
(paths.host_sandbox_uploads_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/uploads", False), (paths.host_sandbox_uploads_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/uploads", False),
(paths.host_sandbox_outputs_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/outputs", False), (paths.host_sandbox_outputs_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/outputs", False),
# ACP workspace: read-only inside the sandbox (lead agent reads results; # ACP workspace: read-only inside the sandbox (lead agent reads results;
# the ACP subprocess writes from the host side, not from within the container). # the ACP subprocess writes from the host side, not from within the container).
(paths.host_acp_workspace_dir(thread_id, user_id=user_id), "/mnt/acp-workspace", True), (paths.host_acp_workspace_dir(thread_id), "/mnt/acp-workspace", True),
] ]
@staticmethod @staticmethod
@@ -281,7 +279,7 @@ class AioSandboxProvider(SandboxProvider):
so the host Docker daemon can resolve the path. so the host Docker daemon can resolve the path.
""" """
try: try:
config = get_app_config() config = AppConfig.current()
skills_path = config.skills.get_skills_path() skills_path = config.skills.get_skills_path()
container_path = config.skills.container_path container_path = config.skills.container_path
@@ -482,9 +480,8 @@ class AioSandboxProvider(SandboxProvider):
across multiple processes, preventing container-name conflicts. across multiple processes, preventing container-name conflicts.
""" """
paths = get_paths() paths = get_paths()
user_id = get_effective_user_id() paths.ensure_thread_dirs(thread_id)
paths.ensure_thread_dirs(thread_id, user_id=user_id) lock_path = paths.thread_dir(thread_id) / f"{sandbox_id}.lock"
lock_path = paths.thread_dir(thread_id, user_id=user_id) / f"{sandbox_id}.lock"
with open(lock_path, "a", encoding="utf-8") as lock_file: with open(lock_path, "a", encoding="utf-8") as lock_file:
locked = False locked = False
@@ -7,7 +7,7 @@ import logging
from langchain.tools import tool from langchain.tools import tool
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -63,7 +63,7 @@ def web_search_tool(
query: Search keywords describing what you want to find. Be specific for better results. query: Search keywords describing what you want to find. Be specific for better results.
max_results: Maximum number of results to return. Default is 5. max_results: Maximum number of results to return. Default is 5.
""" """
config = get_app_config().get_tool_config("web_search") config = AppConfig.current().get_tool_config("web_search")
# Override max_results from config if set # Override max_results from config if set
if config is not None and "max_results" in config.model_extra: if config is not None and "max_results" in config.model_extra:
@@ -3,11 +3,11 @@ import json
from exa_py import Exa from exa_py import Exa
from langchain.tools import tool from langchain.tools import tool
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
def _get_exa_client(tool_name: str = "web_search") -> Exa: def _get_exa_client(tool_name: str = "web_search") -> Exa:
config = get_app_config().get_tool_config(tool_name) config = AppConfig.current().get_tool_config(tool_name)
api_key = None api_key = None
if config is not None and "api_key" in config.model_extra: if config is not None and "api_key" in config.model_extra:
api_key = config.model_extra.get("api_key") api_key = config.model_extra.get("api_key")
@@ -22,7 +22,7 @@ def web_search_tool(query: str) -> str:
query: The query to search for. query: The query to search for.
""" """
try: try:
config = get_app_config().get_tool_config("web_search") config = AppConfig.current().get_tool_config("web_search")
max_results = 5 max_results = 5
search_type = "auto" search_type = "auto"
contents_max_characters = 1000 contents_max_characters = 1000
@@ -3,11 +3,11 @@ import json
from firecrawl import FirecrawlApp from firecrawl import FirecrawlApp
from langchain.tools import tool from langchain.tools import tool
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
def _get_firecrawl_client(tool_name: str = "web_search") -> FirecrawlApp: def _get_firecrawl_client(tool_name: str = "web_search") -> FirecrawlApp:
config = get_app_config().get_tool_config(tool_name) config = AppConfig.current().get_tool_config(tool_name)
api_key = None api_key = None
if config is not None and "api_key" in config.model_extra: if config is not None and "api_key" in config.model_extra:
api_key = config.model_extra.get("api_key") api_key = config.model_extra.get("api_key")
@@ -22,7 +22,7 @@ def web_search_tool(query: str) -> str:
query: The query to search for. query: The query to search for.
""" """
try: try:
config = get_app_config().get_tool_config("web_search") config = AppConfig.current().get_tool_config("web_search")
max_results = 5 max_results = 5
if config is not None: if config is not None:
max_results = config.model_extra.get("max_results", max_results) max_results = config.model_extra.get("max_results", max_results)
@@ -7,7 +7,7 @@ import logging
from langchain.tools import tool from langchain.tools import tool
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -99,7 +99,7 @@ def image_search_tool(
type_image: Image type filter. Options: "photo", "clipart", "gif", "transparent", "line". Use "photo" for realistic references. type_image: Image type filter. Options: "photo", "clipart", "gif", "transparent", "line". Use "photo" for realistic references.
layout: Layout filter. Options: "Square", "Tall", "Wide". Choose based on your generation needs. layout: Layout filter. Options: "Square", "Tall", "Wide". Choose based on your generation needs.
""" """
config = get_app_config().get_tool_config("image_search") config = AppConfig.current().get_tool_config("image_search")
# Override max_results from config if set # Override max_results from config if set
if config is not None and "max_results" in config.model_extra: if config is not None and "max_results" in config.model_extra:
@@ -1,6 +1,6 @@
from langchain.tools import tool from langchain.tools import tool
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.utils.readability import ReadabilityExtractor from deerflow.utils.readability import ReadabilityExtractor
from .infoquest_client import InfoQuestClient from .infoquest_client import InfoQuestClient
@@ -9,12 +9,12 @@ readability_extractor = ReadabilityExtractor()
def _get_infoquest_client() -> InfoQuestClient: def _get_infoquest_client() -> InfoQuestClient:
search_config = get_app_config().get_tool_config("web_search") search_config = AppConfig.current().get_tool_config("web_search")
search_time_range = -1 search_time_range = -1
if search_config is not None and "search_time_range" in search_config.model_extra: if search_config is not None and "search_time_range" in search_config.model_extra:
search_time_range = search_config.model_extra.get("search_time_range") search_time_range = search_config.model_extra.get("search_time_range")
fetch_config = get_app_config().get_tool_config("web_fetch") fetch_config = AppConfig.current().get_tool_config("web_fetch")
fetch_time = -1 fetch_time = -1
if fetch_config is not None and "fetch_time" in fetch_config.model_extra: if fetch_config is not None and "fetch_time" in fetch_config.model_extra:
fetch_time = fetch_config.model_extra.get("fetch_time") fetch_time = fetch_config.model_extra.get("fetch_time")
@@ -25,7 +25,7 @@ def _get_infoquest_client() -> InfoQuestClient:
if fetch_config is not None and "navigation_timeout" in fetch_config.model_extra: if fetch_config is not None and "navigation_timeout" in fetch_config.model_extra:
navigation_timeout = fetch_config.model_extra.get("navigation_timeout") navigation_timeout = fetch_config.model_extra.get("navigation_timeout")
image_search_config = get_app_config().get_tool_config("image_search") image_search_config = AppConfig.current().get_tool_config("image_search")
image_search_time_range = -1 image_search_time_range = -1
if image_search_config is not None and "image_search_time_range" in image_search_config.model_extra: if image_search_config is not None and "image_search_time_range" in image_search_config.model_extra:
image_search_time_range = image_search_config.model_extra.get("image_search_time_range") image_search_time_range = image_search_config.model_extra.get("image_search_time_range")
@@ -1,7 +1,7 @@
from langchain.tools import tool from langchain.tools import tool
from deerflow.community.jina_ai.jina_client import JinaClient from deerflow.community.jina_ai.jina_client import JinaClient
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.utils.readability import ReadabilityExtractor from deerflow.utils.readability import ReadabilityExtractor
readability_extractor = ReadabilityExtractor() readability_extractor = ReadabilityExtractor()
@@ -20,7 +20,7 @@ async def web_fetch_tool(url: str) -> str:
""" """
jina_client = JinaClient() jina_client = JinaClient()
timeout = 10 timeout = 10
config = get_app_config().get_tool_config("web_fetch") config = AppConfig.current().get_tool_config("web_fetch")
if config is not None and "timeout" in config.model_extra: if config is not None and "timeout" in config.model_extra:
timeout = config.model_extra.get("timeout") timeout = config.model_extra.get("timeout")
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout) html_content = await jina_client.crawl(url, return_format="html", timeout=timeout)
@@ -3,11 +3,11 @@ import json
from langchain.tools import tool from langchain.tools import tool
from tavily import TavilyClient from tavily import TavilyClient
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
def _get_tavily_client() -> TavilyClient: def _get_tavily_client() -> TavilyClient:
config = get_app_config().get_tool_config("web_search") config = AppConfig.current().get_tool_config("web_search")
api_key = None api_key = None
if config is not None and "api_key" in config.model_extra: if config is not None and "api_key" in config.model_extra:
api_key = config.model_extra.get("api_key") api_key = config.model_extra.get("api_key")
@@ -21,7 +21,7 @@ def web_search_tool(query: str) -> str:
Args: Args:
query: The query to search for. query: The query to search for.
""" """
config = get_app_config().get_tool_config("web_search") config = AppConfig.current().get_tool_config("web_search")
max_results = 5 max_results = 5
if config is not None and "max_results" in config.model_extra: if config is not None and "max_results" in config.model_extra:
max_results = config.model_extra.get("max_results") max_results = config.model_extra.get("max_results")
@@ -1,6 +1,6 @@
from .app_config import get_app_config from .app_config import AppConfig
from .extensions_config import ExtensionsConfig, get_extensions_config from .extensions_config import ExtensionsConfig
from .memory_config import MemoryConfig, get_memory_config from .memory_config import MemoryConfig
from .paths import Paths, get_paths from .paths import Paths, get_paths
from .skill_evolution_config import SkillEvolutionConfig from .skill_evolution_config import SkillEvolutionConfig
from .skills_config import SkillsConfig from .skills_config import SkillsConfig
@@ -13,18 +13,16 @@ from .tracing_config import (
) )
__all__ = [ __all__ = [
"get_app_config", "AppConfig",
"SkillEvolutionConfig",
"Paths",
"get_paths",
"SkillsConfig",
"ExtensionsConfig", "ExtensionsConfig",
"get_extensions_config",
"MemoryConfig", "MemoryConfig",
"get_memory_config", "Paths",
"get_tracing_config", "SkillEvolutionConfig",
"get_explicitly_enabled_tracing_providers", "SkillsConfig",
"get_enabled_tracing_providers", "get_enabled_tracing_providers",
"get_explicitly_enabled_tracing_providers",
"get_paths",
"get_tracing_config",
"is_tracing_enabled", "is_tracing_enabled",
"validate_enabled_tracing_providers", "validate_enabled_tracing_providers",
] ]
@@ -1,16 +1,13 @@
"""ACP (Agent Client Protocol) agent configuration loaded from config.yaml.""" """ACP (Agent Client Protocol) agent configuration loaded from config.yaml."""
import logging from pydantic import BaseModel, ConfigDict, Field
from collections.abc import Mapping
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class ACPAgentConfig(BaseModel): class ACPAgentConfig(BaseModel):
"""Configuration for a single ACP-compatible agent.""" """Configuration for a single ACP-compatible agent."""
model_config = ConfigDict(frozen=True)
command: str = Field(description="Command to launch the ACP agent subprocess") command: str = Field(description="Command to launch the ACP agent subprocess")
args: list[str] = Field(default_factory=list, description="Additional command arguments") args: list[str] = Field(default_factory=list, description="Additional command arguments")
env: dict[str, str] = Field(default_factory=dict, description="Environment variables to inject into the agent subprocess. Values starting with $ are resolved from host environment variables.") env: dict[str, str] = Field(default_factory=dict, description="Environment variables to inject into the agent subprocess. Values starting with $ are resolved from host environment variables.")
@@ -24,28 +21,3 @@ class ACPAgentConfig(BaseModel):
"are denied — the agent must be configured to operate without requesting permissions." "are denied — the agent must be configured to operate without requesting permissions."
), ),
) )
_acp_agents: dict[str, ACPAgentConfig] = {}
def get_acp_agents() -> dict[str, ACPAgentConfig]:
"""Get the currently configured ACP agents.
Returns:
Mapping of agent name -> ACPAgentConfig. Empty dict if no ACP agents are configured.
"""
return _acp_agents
def load_acp_config_from_dict(config_dict: Mapping[str, Mapping[str, object]] | None) -> None:
"""Load ACP agent configuration from a dictionary (typically from config.yaml).
Args:
config_dict: Mapping of agent name -> config fields.
"""
global _acp_agents
if config_dict is None:
config_dict = {}
_acp_agents = {name: ACPAgentConfig(**cfg) for name, cfg in config_dict.items()}
logger.info("ACP config loaded: %d agent(s): %s", len(_acp_agents), list(_acp_agents.keys()))
@@ -5,7 +5,7 @@ import re
from typing import Any from typing import Any
import yaml import yaml
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from deerflow.config.paths import get_paths from deerflow.config.paths import get_paths
@@ -18,6 +18,8 @@ AGENT_NAME_PATTERN = re.compile(r"^[A-Za-z0-9-]+$")
class AgentConfig(BaseModel): class AgentConfig(BaseModel):
"""Configuration for a custom agent.""" """Configuration for a custom agent."""
model_config = ConfigDict(frozen=True)
name: str name: str
description: str = "" description: str = ""
model: str | None = None model: str | None = None
@@ -1,31 +1,31 @@
from __future__ import annotations
import logging import logging
import os import os
from contextvars import ContextVar from contextvars import ContextVar
from pathlib import Path from pathlib import Path
from typing import Any, Self from typing import Any, ClassVar, Self
import yaml import yaml
from dotenv import load_dotenv from dotenv import load_dotenv
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from deerflow.config.acp_config import load_acp_config_from_dict from deerflow.config.acp_config import ACPAgentConfig
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.config.database_config import DatabaseConfig
from deerflow.config.extensions_config import ExtensionsConfig from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict from deerflow.config.guardrails_config import GuardrailsConfig
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict from deerflow.config.memory_config import MemoryConfig
from deerflow.config.model_config import ModelConfig from deerflow.config.model_config import ModelConfig
from deerflow.config.run_events_config import RunEventsConfig
from deerflow.config.sandbox_config import SandboxConfig from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.skill_evolution_config import SkillEvolutionConfig from deerflow.config.skill_evolution_config import SkillEvolutionConfig
from deerflow.config.skills_config import SkillsConfig from deerflow.config.skills_config import SkillsConfig
from deerflow.config.stream_bridge_config import StreamBridgeConfig, load_stream_bridge_config_from_dict from deerflow.config.stream_bridge_config import StreamBridgeConfig
from deerflow.config.subagents_config import SubagentsAppConfig, load_subagents_config_from_dict from deerflow.config.subagents_config import SubagentsAppConfig
from deerflow.config.summarization_config import SummarizationConfig, load_summarization_config_from_dict from deerflow.config.summarization_config import SummarizationConfig
from deerflow.config.title_config import TitleConfig, load_title_config_from_dict from deerflow.config.title_config import TitleConfig
from deerflow.config.token_usage_config import TokenUsageConfig from deerflow.config.token_usage_config import TokenUsageConfig
from deerflow.config.tool_config import ToolConfig, ToolGroupConfig from deerflow.config.tool_config import ToolConfig, ToolGroupConfig
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict from deerflow.config.tool_search_config import ToolSearchConfig
load_dotenv() load_dotenv()
@@ -57,11 +57,10 @@ class AppConfig(BaseModel):
memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration") memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration")
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration") subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration") guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
model_config = ConfigDict(extra="allow", frozen=False) model_config = ConfigDict(extra="allow", frozen=True)
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration") checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration") stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
acp_agents: dict[str, ACPAgentConfig] = Field(default_factory=dict, description="ACP agent configurations keyed by agent name")
@classmethod @classmethod
def resolve_config_path(cls, config_path: str | None = None) -> Path: def resolve_config_path(cls, config_path: str | None = None) -> Path:
@@ -109,41 +108,6 @@ class AppConfig(BaseModel):
config_data = cls.resolve_env_variables(config_data) config_data = cls.resolve_env_variables(config_data)
# Load title config if present
if "title" in config_data:
load_title_config_from_dict(config_data["title"])
# Load summarization config if present
if "summarization" in config_data:
load_summarization_config_from_dict(config_data["summarization"])
# Load memory config if present
if "memory" in config_data:
load_memory_config_from_dict(config_data["memory"])
# Load subagents config if present
if "subagents" in config_data:
load_subagents_config_from_dict(config_data["subagents"])
# Load tool_search config if present
if "tool_search" in config_data:
load_tool_search_config_from_dict(config_data["tool_search"])
# Load guardrails config if present
if "guardrails" in config_data:
load_guardrails_config_from_dict(config_data["guardrails"])
# Load checkpointer config if present
if "checkpointer" in config_data:
load_checkpointer_config_from_dict(config_data["checkpointer"])
# Load stream bridge config if present
if "stream_bridge" in config_data:
load_stream_bridge_config_from_dict(config_data["stream_bridge"])
# Always refresh ACP agent config so removed entries do not linger across reloads.
load_acp_config_from_dict(config_data.get("acp_agents", {}))
# Load extensions config separately (it's in a different file) # Load extensions config separately (it's in a different file)
extensions_config = ExtensionsConfig.from_file() extensions_config = ExtensionsConfig.from_file()
config_data["extensions"] = extensions_config.model_dump() config_data["extensions"] = extensions_config.model_dump()
@@ -254,130 +218,26 @@ class AppConfig(BaseModel):
""" """
return next((group for group in self.tool_groups if group.name == name), None) return next((group for group in self.tool_groups if group.name == name), None)
# -- Lifecycle (class-level singleton via ContextVar) --
_app_config: AppConfig | None = None _current: ClassVar[ContextVar[AppConfig]] = ContextVar("deerflow_app_config")
_app_config_path: Path | None = None
_app_config_mtime: float | None = None
_app_config_is_custom = False
_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None)
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", default=())
@classmethod
def init(cls, config: AppConfig) -> None:
"""Set the AppConfig for the current context. Call once at process startup."""
cls._current.set(config)
def _get_config_mtime(config_path: Path) -> float | None: @classmethod
"""Get the modification time of a config file if it exists.""" def current(cls) -> AppConfig:
try: """Get the current AppConfig.
return config_path.stat().st_mtime
except OSError:
return None
Auto-initializes from config file on first access for backward compatibility.
def _load_and_cache_app_config(config_path: str | None = None) -> AppConfig: Prefer calling AppConfig.init() explicitly at process startup.
"""Load config from disk and refresh cache metadata.""" """
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom try:
return cls._current.get()
resolved_path = AppConfig.resolve_config_path(config_path) except LookupError:
_app_config = AppConfig.from_file(str(resolved_path)) logger.debug("AppConfig not initialized, auto-loading from file")
_app_config_path = resolved_path config = cls.from_file()
_app_config_mtime = _get_config_mtime(resolved_path) cls._current.set(config)
_app_config_is_custom = False return config
return _app_config
def get_app_config() -> AppConfig:
"""Get the DeerFlow config instance.
Returns a cached singleton instance and automatically reloads it when the
underlying config file path or modification time changes. Use
`reload_app_config()` to force a reload, or `reset_app_config()` to clear
the cache.
"""
global _app_config, _app_config_path, _app_config_mtime
runtime_override = _current_app_config.get()
if runtime_override is not None:
return runtime_override
if _app_config is not None and _app_config_is_custom:
return _app_config
resolved_path = AppConfig.resolve_config_path()
current_mtime = _get_config_mtime(resolved_path)
should_reload = _app_config is None or _app_config_path != resolved_path or _app_config_mtime != current_mtime
if should_reload:
if _app_config_path == resolved_path and _app_config_mtime is not None and current_mtime is not None and _app_config_mtime != current_mtime:
logger.info(
"Config file has been modified (mtime: %s -> %s), reloading AppConfig",
_app_config_mtime,
current_mtime,
)
_load_and_cache_app_config(str(resolved_path))
return _app_config
def reload_app_config(config_path: str | None = None) -> AppConfig:
"""Reload the config from file and update the cached instance.
This is useful when the config file has been modified and you want
to pick up the changes without restarting the application.
Args:
config_path: Optional path to config file. If not provided,
uses the default resolution strategy.
Returns:
The newly loaded AppConfig instance.
"""
return _load_and_cache_app_config(config_path)
def reset_app_config() -> None:
"""Reset the cached config instance.
This clears the singleton cache, causing the next call to
`get_app_config()` to reload from file. Useful for testing
or when switching between different configurations.
"""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
_app_config = None
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = False
def set_app_config(config: AppConfig) -> None:
"""Set a custom config instance.
This allows injecting a custom or mock config for testing purposes.
Args:
config: The AppConfig instance to use.
"""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
_app_config = config
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = True
def peek_current_app_config() -> AppConfig | None:
"""Return the runtime-scoped AppConfig override, if one is active."""
return _current_app_config.get()
def push_current_app_config(config: AppConfig) -> None:
"""Push a runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
_current_app_config_stack.set(stack + (_current_app_config.get(),))
_current_app_config.set(config)
def pop_current_app_config() -> None:
"""Pop the latest runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
if not stack:
_current_app_config.set(None)
return
previous = stack[-1]
_current_app_config_stack.set(stack[:-1])
_current_app_config.set(previous)
@@ -2,7 +2,7 @@
from typing import Literal from typing import Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
CheckpointerType = Literal["memory", "sqlite", "postgres"] CheckpointerType = Literal["memory", "sqlite", "postgres"]
@@ -10,6 +10,8 @@ CheckpointerType = Literal["memory", "sqlite", "postgres"]
class CheckpointerConfig(BaseModel): class CheckpointerConfig(BaseModel):
"""Configuration for LangGraph state persistence checkpointer.""" """Configuration for LangGraph state persistence checkpointer."""
model_config = ConfigDict(frozen=True)
type: CheckpointerType = Field( type: CheckpointerType = Field(
description="Checkpointer backend type. " description="Checkpointer backend type. "
"'memory' is in-process only (lost on restart). " "'memory' is in-process only (lost on restart). "
@@ -23,24 +25,3 @@ class CheckpointerConfig(BaseModel):
"For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. " "For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. "
"For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.", "For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.",
) )
# Global configuration instance — None means no checkpointer is configured.
_checkpointer_config: CheckpointerConfig | None = None
def get_checkpointer_config() -> CheckpointerConfig | None:
"""Get the current checkpointer configuration, or None if not configured."""
return _checkpointer_config
def set_checkpointer_config(config: CheckpointerConfig | None) -> None:
"""Set the checkpointer configuration."""
global _checkpointer_config
_checkpointer_config = config
def load_checkpointer_config_from_dict(config_dict: dict) -> None:
"""Load checkpointer configuration from a dictionary."""
global _checkpointer_config
_checkpointer_config = CheckpointerConfig(**config_dict)
@@ -1,102 +0,0 @@
"""Unified database backend configuration.
Controls BOTH the LangGraph checkpointer and the DeerFlow application
persistence layer (runs, threads metadata, users, etc.). The user
configures one backend; the system handles physical separation details.
SQLite mode: checkpointer and app share a single .db file
({sqlite_dir}/deerflow.db) with WAL journal mode enabled on every
connection. WAL allows concurrent readers and a single writer without
blocking, making a unified file safe for both workloads. Writers
that contend for the lock wait via the default 5-second sqlite3
busy timeout rather than failing immediately.
Postgres mode: both use the same database URL but maintain independent
connection pools with different lifecycles.
Memory mode: checkpointer uses MemorySaver, app uses in-memory stores.
No database is initialized.
Sensitive values (postgres_url) should use $VAR syntax in config.yaml
to reference environment variables from .env:
database:
backend: postgres
postgres_url: $DATABASE_URL
The $VAR resolution is handled by AppConfig.resolve_env_variables()
before this config is instantiated -- DatabaseConfig itself does not
need to do any environment variable processing.
"""
from __future__ import annotations
import os
from typing import Literal
from pydantic import BaseModel, Field
class DatabaseConfig(BaseModel):
backend: Literal["memory", "sqlite", "postgres"] = Field(
default="memory",
description=("Storage backend for both checkpointer and application data. 'memory' for development (no persistence across restarts), 'sqlite' for single-node deployment, 'postgres' for production multi-node deployment."),
)
sqlite_dir: str = Field(
default=".deer-flow/data",
description=("Directory for the SQLite database file. Both checkpointer and application data share {sqlite_dir}/deerflow.db."),
)
postgres_url: str = Field(
default="",
description=(
"PostgreSQL connection URL, shared by checkpointer and app. "
"Use $DATABASE_URL in config.yaml to reference .env. "
"Example: postgresql://user:pass@host:5432/deerflow "
"(the +asyncpg driver suffix is added automatically where needed)."
),
)
echo_sql: bool = Field(
default=False,
description="Echo all SQL statements to log (debug only).",
)
pool_size: int = Field(
default=5,
description="Connection pool size for the app ORM engine (postgres only).",
)
# -- Derived helpers (not user-configured) --
@property
def _resolved_sqlite_dir(self) -> str:
"""Resolve sqlite_dir to an absolute path (relative to CWD)."""
from pathlib import Path
return str(Path(self.sqlite_dir).resolve())
@property
def sqlite_path(self) -> str:
"""Unified SQLite file path shared by checkpointer and app."""
return os.path.join(self._resolved_sqlite_dir, "deerflow.db")
# Backward-compatible aliases
@property
def checkpointer_sqlite_path(self) -> str:
"""SQLite file path for the LangGraph checkpointer (alias for sqlite_path)."""
return self.sqlite_path
@property
def app_sqlite_path(self) -> str:
"""SQLite file path for application ORM data (alias for sqlite_path)."""
return self.sqlite_path
@property
def app_sqlalchemy_url(self) -> str:
"""SQLAlchemy async URL for the application ORM engine."""
if self.backend == "sqlite":
return f"sqlite+aiosqlite:///{self.sqlite_path}"
if self.backend == "postgres":
url = self.postgres_url
if url.startswith("postgresql://"):
url = url.replace("postgresql://", "postgresql+asyncpg://", 1)
return url
raise ValueError(f"No SQLAlchemy URL for backend={self.backend!r}")
@@ -0,0 +1,59 @@
"""Per-invocation context for DeerFlow agent execution.
Injected via LangGraph Runtime. Middleware and tools access this
via Runtime[DeerFlowContext] parameters, through resolve_context().
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
@dataclass(frozen=True)
class DeerFlowContext:
"""Typed, immutable, per-invocation context injected via LangGraph Runtime.
Fields are all known at run start and never change during execution.
Mutable runtime state (e.g. sandbox_id) flows through ThreadState, not here.
"""
app_config: Any # AppConfig — typed as Any to avoid circular import at module level
thread_id: str
agent_name: str | None = None
def resolve_context(runtime: Any) -> DeerFlowContext:
"""Extract or construct DeerFlowContext from runtime.
Gateway/Client paths: runtime.context is already DeerFlowContext → return directly.
LangGraph Server / legacy dict path: construct from dict context or configurable fallback.
"""
ctx = getattr(runtime, "context", None)
if isinstance(ctx, DeerFlowContext):
return ctx
from deerflow.config.app_config import AppConfig
# Try dict context first (legacy path, tests), then configurable
if isinstance(ctx, dict):
return DeerFlowContext(
app_config=AppConfig.current(),
thread_id=ctx.get("thread_id", ""),
agent_name=ctx.get("agent_name"),
)
# No context at all — fall back to LangGraph configurable
try:
from langgraph.config import get_config
cfg = get_config().get("configurable", {})
except RuntimeError:
# Outside runnable context (e.g. unit tests)
cfg = {}
return DeerFlowContext(
app_config=AppConfig.current(),
thread_id=cfg.get("thread_id", ""),
agent_name=cfg.get("agent_name"),
)
@@ -11,6 +11,8 @@ from pydantic import BaseModel, ConfigDict, Field
class McpOAuthConfig(BaseModel): class McpOAuthConfig(BaseModel):
"""OAuth configuration for an MCP server (HTTP/SSE transports).""" """OAuth configuration for an MCP server (HTTP/SSE transports)."""
model_config = ConfigDict(extra="allow", frozen=True)
enabled: bool = Field(default=True, description="Whether OAuth token injection is enabled") enabled: bool = Field(default=True, description="Whether OAuth token injection is enabled")
token_url: str = Field(description="OAuth token endpoint URL") token_url: str = Field(description="OAuth token endpoint URL")
grant_type: Literal["client_credentials", "refresh_token"] = Field( grant_type: Literal["client_credentials", "refresh_token"] = Field(
@@ -28,12 +30,13 @@ class McpOAuthConfig(BaseModel):
default_token_type: str = Field(default="Bearer", description="Default token type when missing in token response") default_token_type: str = Field(default="Bearer", description="Default token type when missing in token response")
refresh_skew_seconds: int = Field(default=60, description="Refresh token this many seconds before expiry") refresh_skew_seconds: int = Field(default=60, description="Refresh token this many seconds before expiry")
extra_token_params: dict[str, str] = Field(default_factory=dict, description="Additional form params sent to token endpoint") extra_token_params: dict[str, str] = Field(default_factory=dict, description="Additional form params sent to token endpoint")
model_config = ConfigDict(extra="allow")
class McpServerConfig(BaseModel): class McpServerConfig(BaseModel):
"""Configuration for a single MCP server.""" """Configuration for a single MCP server."""
model_config = ConfigDict(extra="allow", frozen=True)
enabled: bool = Field(default=True, description="Whether this MCP server is enabled") enabled: bool = Field(default=True, description="Whether this MCP server is enabled")
type: str = Field(default="stdio", description="Transport type: 'stdio', 'sse', or 'http'") type: str = Field(default="stdio", description="Transport type: 'stdio', 'sse', or 'http'")
command: str | None = Field(default=None, description="Command to execute to start the MCP server (for stdio type)") command: str | None = Field(default=None, description="Command to execute to start the MCP server (for stdio type)")
@@ -43,12 +46,13 @@ class McpServerConfig(BaseModel):
headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers to send (for sse or http type)") headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers to send (for sse or http type)")
oauth: McpOAuthConfig | None = Field(default=None, description="OAuth configuration (for sse or http type)") oauth: McpOAuthConfig | None = Field(default=None, description="OAuth configuration (for sse or http type)")
description: str = Field(default="", description="Human-readable description of what this MCP server provides") description: str = Field(default="", description="Human-readable description of what this MCP server provides")
model_config = ConfigDict(extra="allow")
class SkillStateConfig(BaseModel): class SkillStateConfig(BaseModel):
"""Configuration for a single skill's state.""" """Configuration for a single skill's state."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(default=True, description="Whether this skill is enabled") enabled: bool = Field(default=True, description="Whether this skill is enabled")
@@ -64,7 +68,7 @@ class ExtensionsConfig(BaseModel):
default_factory=dict, default_factory=dict,
description="Map of skill name to state configuration", description="Map of skill name to state configuration",
) )
model_config = ConfigDict(extra="allow", populate_by_name=True) model_config = ConfigDict(extra="allow", frozen=True, populate_by_name=True)
@classmethod @classmethod
def resolve_config_path(cls, config_path: str | None = None) -> Path | None: def resolve_config_path(cls, config_path: str | None = None) -> Path | None:
@@ -195,62 +199,3 @@ class ExtensionsConfig(BaseModel):
# Default to enable for public & custom skill # Default to enable for public & custom skill
return skill_category in ("public", "custom") return skill_category in ("public", "custom")
return skill_config.enabled return skill_config.enabled
_extensions_config: ExtensionsConfig | None = None
def get_extensions_config() -> ExtensionsConfig:
"""Get the extensions config instance.
Returns a cached singleton instance. Use `reload_extensions_config()` to reload
from file, or `reset_extensions_config()` to clear the cache.
Returns:
The cached ExtensionsConfig instance.
"""
global _extensions_config
if _extensions_config is None:
_extensions_config = ExtensionsConfig.from_file()
return _extensions_config
def reload_extensions_config(config_path: str | None = None) -> ExtensionsConfig:
"""Reload the extensions config from file and update the cached instance.
This is useful when the config file has been modified and you want
to pick up the changes without restarting the application.
Args:
config_path: Optional path to extensions config file. If not provided,
uses the default resolution strategy.
Returns:
The newly loaded ExtensionsConfig instance.
"""
global _extensions_config
_extensions_config = ExtensionsConfig.from_file(config_path)
return _extensions_config
def reset_extensions_config() -> None:
"""Reset the cached extensions config instance.
This clears the singleton cache, causing the next call to
`get_extensions_config()` to reload from file. Useful for testing
or when switching between different configurations.
"""
global _extensions_config
_extensions_config = None
def set_extensions_config(config: ExtensionsConfig) -> None:
"""Set a custom extensions config instance.
This allows injecting a custom or mock config for testing purposes.
Args:
config: The ExtensionsConfig instance to use.
"""
global _extensions_config
_extensions_config = config
@@ -1,11 +1,13 @@
"""Configuration for pre-tool-call authorization.""" """Configuration for pre-tool-call authorization."""
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
class GuardrailProviderConfig(BaseModel): class GuardrailProviderConfig(BaseModel):
"""Configuration for a guardrail provider.""" """Configuration for a guardrail provider."""
model_config = ConfigDict(frozen=True)
use: str = Field(description="Class path (e.g. 'deerflow.guardrails.builtin:AllowlistProvider')") use: str = Field(description="Class path (e.g. 'deerflow.guardrails.builtin:AllowlistProvider')")
config: dict = Field(default_factory=dict, description="Provider-specific settings passed as kwargs") config: dict = Field(default_factory=dict, description="Provider-specific settings passed as kwargs")
@@ -18,31 +20,9 @@ class GuardrailsConfig(BaseModel):
agent's passport reference, and returns an allow/deny decision. agent's passport reference, and returns an allow/deny decision.
""" """
model_config = ConfigDict(frozen=True)
enabled: bool = Field(default=False, description="Enable guardrail middleware") enabled: bool = Field(default=False, description="Enable guardrail middleware")
fail_closed: bool = Field(default=True, description="Block tool calls if provider errors") fail_closed: bool = Field(default=True, description="Block tool calls if provider errors")
passport: str | None = Field(default=None, description="OAP passport path or hosted agent ID") passport: str | None = Field(default=None, description="OAP passport path or hosted agent ID")
provider: GuardrailProviderConfig | None = Field(default=None, description="Guardrail provider configuration") provider: GuardrailProviderConfig | None = Field(default=None, description="Guardrail provider configuration")
_guardrails_config: GuardrailsConfig | None = None
def get_guardrails_config() -> GuardrailsConfig:
"""Get the guardrails config, returning defaults if not loaded."""
global _guardrails_config
if _guardrails_config is None:
_guardrails_config = GuardrailsConfig()
return _guardrails_config
def load_guardrails_config_from_dict(data: dict) -> GuardrailsConfig:
"""Load guardrails config from a dict (called during AppConfig loading)."""
global _guardrails_config
_guardrails_config = GuardrailsConfig.model_validate(data)
return _guardrails_config
def reset_guardrails_config() -> None:
"""Reset the cached config instance. Used in tests to prevent singleton leaks."""
global _guardrails_config
_guardrails_config = None
@@ -1,11 +1,13 @@
"""Configuration for memory mechanism.""" """Configuration for memory mechanism."""
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
class MemoryConfig(BaseModel): class MemoryConfig(BaseModel):
"""Configuration for global memory mechanism.""" """Configuration for global memory mechanism."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field( enabled: bool = Field(
default=True, default=True,
description="Whether to enable memory mechanism", description="Whether to enable memory mechanism",
@@ -14,9 +16,8 @@ class MemoryConfig(BaseModel):
default="", default="",
description=( description=(
"Path to store memory data. " "Path to store memory data. "
"If empty, defaults to per-user memory at `{base_dir}/users/{user_id}/memory.json`. " "If empty, defaults to `{base_dir}/memory.json` (see Paths.memory_file). "
"Absolute paths are used as-is and opt out of per-user isolation " "Absolute paths are used as-is. "
"(all users share the same file). "
"Relative paths are resolved against `Paths.base_dir` " "Relative paths are resolved against `Paths.base_dir` "
"(not the backend working directory). " "(not the backend working directory). "
"Note: if you previously set this to `.deer-flow/memory.json`, " "Note: if you previously set this to `.deer-flow/memory.json`, "
@@ -60,24 +61,3 @@ class MemoryConfig(BaseModel):
le=8000, le=8000,
description="Maximum tokens to use for memory injection", description="Maximum tokens to use for memory injection",
) )
# Global configuration instance
_memory_config: MemoryConfig = MemoryConfig()
def get_memory_config() -> MemoryConfig:
"""Get the current memory configuration."""
return _memory_config
def set_memory_config(config: MemoryConfig) -> None:
"""Set the memory configuration."""
global _memory_config
_memory_config = config
def load_memory_config_from_dict(config_dict: dict) -> None:
"""Load memory configuration from a dictionary."""
global _memory_config
_memory_config = MemoryConfig(**config_dict)
@@ -12,7 +12,7 @@ class ModelConfig(BaseModel):
description="Class path of the model provider(e.g. langchain_openai.ChatOpenAI)", description="Class path of the model provider(e.g. langchain_openai.ChatOpenAI)",
) )
model: str = Field(..., description="Model name") model: str = Field(..., description="Model name")
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow", frozen=True)
use_responses_api: bool | None = Field( use_responses_api: bool | None = Field(
default=None, default=None,
description="Whether to route OpenAI ChatOpenAI calls through the /v1/responses API", description="Whether to route OpenAI ChatOpenAI calls through the /v1/responses API",
@@ -7,7 +7,6 @@ from pathlib import Path, PureWindowsPath
VIRTUAL_PATH_PREFIX = "/mnt/user-data" VIRTUAL_PATH_PREFIX = "/mnt/user-data"
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$") _SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
_SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
def _default_local_base_dir() -> Path: def _default_local_base_dir() -> Path:
@@ -23,13 +22,6 @@ def _validate_thread_id(thread_id: str) -> str:
return thread_id return thread_id
def _validate_user_id(user_id: str) -> str:
"""Validate a user ID before using it in filesystem paths."""
if not _SAFE_USER_ID_RE.match(user_id):
raise ValueError(f"Invalid user_id {user_id!r}: only alphanumeric characters, hyphens, and underscores are allowed.")
return user_id
def _join_host_path(base: str, *parts: str) -> str: def _join_host_path(base: str, *parts: str) -> str:
"""Join host filesystem path segments while preserving native style. """Join host filesystem path segments while preserving native style.
@@ -142,63 +134,44 @@ class Paths:
"""Per-agent memory file: `{base_dir}/agents/{name}/memory.json`.""" """Per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
return self.agent_dir(name) / "memory.json" return self.agent_dir(name) / "memory.json"
def user_dir(self, user_id: str) -> Path: def thread_dir(self, thread_id: str) -> Path:
"""Directory for a specific user: `{base_dir}/users/{user_id}/`."""
return self.base_dir / "users" / _validate_user_id(user_id)
def user_memory_file(self, user_id: str) -> Path:
"""Per-user memory file: `{base_dir}/users/{user_id}/memory.json`."""
return self.user_dir(user_id) / "memory.json"
def user_agent_memory_file(self, user_id: str, agent_name: str) -> Path:
"""Per-user per-agent memory: `{base_dir}/users/{user_id}/agents/{name}/memory.json`."""
return self.user_dir(user_id) / "agents" / agent_name.lower() / "memory.json"
def thread_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
""" """
Host path for a thread's data. Host path for a thread's data: `{base_dir}/threads/{thread_id}/`
When *user_id* is provided:
`{base_dir}/users/{user_id}/threads/{thread_id}/`
Otherwise (legacy layout):
`{base_dir}/threads/{thread_id}/`
This directory contains a `user-data/` subdirectory that is mounted This directory contains a `user-data/` subdirectory that is mounted
as `/mnt/user-data/` inside the sandbox. as `/mnt/user-data/` inside the sandbox.
Raises: Raises:
ValueError: If `thread_id` or `user_id` contains unsafe characters (path ValueError: If `thread_id` contains unsafe characters (path separators
separators or `..`) that could cause directory traversal. or `..`) that could cause directory traversal.
""" """
if user_id is not None:
return self.user_dir(user_id) / "threads" / _validate_thread_id(thread_id)
return self.base_dir / "threads" / _validate_thread_id(thread_id) return self.base_dir / "threads" / _validate_thread_id(thread_id)
def sandbox_work_dir(self, thread_id: str, *, user_id: str | None = None) -> Path: def sandbox_work_dir(self, thread_id: str) -> Path:
""" """
Host path for the agent's workspace directory. Host path for the agent's workspace directory.
Host: `{base_dir}/threads/{thread_id}/user-data/workspace/` Host: `{base_dir}/threads/{thread_id}/user-data/workspace/`
Sandbox: `/mnt/user-data/workspace/` Sandbox: `/mnt/user-data/workspace/`
""" """
return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "workspace" return self.thread_dir(thread_id) / "user-data" / "workspace"
def sandbox_uploads_dir(self, thread_id: str, *, user_id: str | None = None) -> Path: def sandbox_uploads_dir(self, thread_id: str) -> Path:
""" """
Host path for user-uploaded files. Host path for user-uploaded files.
Host: `{base_dir}/threads/{thread_id}/user-data/uploads/` Host: `{base_dir}/threads/{thread_id}/user-data/uploads/`
Sandbox: `/mnt/user-data/uploads/` Sandbox: `/mnt/user-data/uploads/`
""" """
return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "uploads" return self.thread_dir(thread_id) / "user-data" / "uploads"
def sandbox_outputs_dir(self, thread_id: str, *, user_id: str | None = None) -> Path: def sandbox_outputs_dir(self, thread_id: str) -> Path:
""" """
Host path for agent-generated artifacts. Host path for agent-generated artifacts.
Host: `{base_dir}/threads/{thread_id}/user-data/outputs/` Host: `{base_dir}/threads/{thread_id}/user-data/outputs/`
Sandbox: `/mnt/user-data/outputs/` Sandbox: `/mnt/user-data/outputs/`
""" """
return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "outputs" return self.thread_dir(thread_id) / "user-data" / "outputs"
def acp_workspace_dir(self, thread_id: str, *, user_id: str | None = None) -> Path: def acp_workspace_dir(self, thread_id: str) -> Path:
""" """
Host path for the ACP workspace of a specific thread. Host path for the ACP workspace of a specific thread.
Host: `{base_dir}/threads/{thread_id}/acp-workspace/` Host: `{base_dir}/threads/{thread_id}/acp-workspace/`
@@ -207,43 +180,41 @@ class Paths:
Each thread gets its own isolated ACP workspace so that concurrent Each thread gets its own isolated ACP workspace so that concurrent
sessions cannot read each other's ACP agent outputs. sessions cannot read each other's ACP agent outputs.
""" """
return self.thread_dir(thread_id, user_id=user_id) / "acp-workspace" return self.thread_dir(thread_id) / "acp-workspace"
def sandbox_user_data_dir(self, thread_id: str, *, user_id: str | None = None) -> Path: def sandbox_user_data_dir(self, thread_id: str) -> Path:
""" """
Host path for the user-data root. Host path for the user-data root.
Host: `{base_dir}/threads/{thread_id}/user-data/` Host: `{base_dir}/threads/{thread_id}/user-data/`
Sandbox: `/mnt/user-data/` Sandbox: `/mnt/user-data/`
""" """
return self.thread_dir(thread_id, user_id=user_id) / "user-data" return self.thread_dir(thread_id) / "user-data"
def host_thread_dir(self, thread_id: str, *, user_id: str | None = None) -> str: def host_thread_dir(self, thread_id: str) -> str:
"""Host path for a thread directory, preserving Windows path syntax.""" """Host path for a thread directory, preserving Windows path syntax."""
if user_id is not None:
return _join_host_path(self._host_base_dir_str(), "users", _validate_user_id(user_id), "threads", _validate_thread_id(thread_id))
return _join_host_path(self._host_base_dir_str(), "threads", _validate_thread_id(thread_id)) return _join_host_path(self._host_base_dir_str(), "threads", _validate_thread_id(thread_id))
def host_sandbox_user_data_dir(self, thread_id: str, *, user_id: str | None = None) -> str: def host_sandbox_user_data_dir(self, thread_id: str) -> str:
"""Host path for a thread's user-data root.""" """Host path for a thread's user-data root."""
return _join_host_path(self.host_thread_dir(thread_id, user_id=user_id), "user-data") return _join_host_path(self.host_thread_dir(thread_id), "user-data")
def host_sandbox_work_dir(self, thread_id: str, *, user_id: str | None = None) -> str: def host_sandbox_work_dir(self, thread_id: str) -> str:
"""Host path for the workspace mount source.""" """Host path for the workspace mount source."""
return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "workspace") return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "workspace")
def host_sandbox_uploads_dir(self, thread_id: str, *, user_id: str | None = None) -> str: def host_sandbox_uploads_dir(self, thread_id: str) -> str:
"""Host path for the uploads mount source.""" """Host path for the uploads mount source."""
return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "uploads") return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "uploads")
def host_sandbox_outputs_dir(self, thread_id: str, *, user_id: str | None = None) -> str: def host_sandbox_outputs_dir(self, thread_id: str) -> str:
"""Host path for the outputs mount source.""" """Host path for the outputs mount source."""
return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "outputs") return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "outputs")
def host_acp_workspace_dir(self, thread_id: str, *, user_id: str | None = None) -> str: def host_acp_workspace_dir(self, thread_id: str) -> str:
"""Host path for the ACP workspace mount source.""" """Host path for the ACP workspace mount source."""
return _join_host_path(self.host_thread_dir(thread_id, user_id=user_id), "acp-workspace") return _join_host_path(self.host_thread_dir(thread_id), "acp-workspace")
def ensure_thread_dirs(self, thread_id: str, *, user_id: str | None = None) -> None: def ensure_thread_dirs(self, thread_id: str) -> None:
"""Create all standard sandbox directories for a thread. """Create all standard sandbox directories for a thread.
Directories are created with mode 0o777 so that sandbox containers Directories are created with mode 0o777 so that sandbox containers
@@ -257,24 +228,24 @@ class Paths:
ACP agent invocation. ACP agent invocation.
""" """
for d in [ for d in [
self.sandbox_work_dir(thread_id, user_id=user_id), self.sandbox_work_dir(thread_id),
self.sandbox_uploads_dir(thread_id, user_id=user_id), self.sandbox_uploads_dir(thread_id),
self.sandbox_outputs_dir(thread_id, user_id=user_id), self.sandbox_outputs_dir(thread_id),
self.acp_workspace_dir(thread_id, user_id=user_id), self.acp_workspace_dir(thread_id),
]: ]:
d.mkdir(parents=True, exist_ok=True) d.mkdir(parents=True, exist_ok=True)
d.chmod(0o777) d.chmod(0o777)
def delete_thread_dir(self, thread_id: str, *, user_id: str | None = None) -> None: def delete_thread_dir(self, thread_id: str) -> None:
"""Delete all persisted data for a thread. """Delete all persisted data for a thread.
The operation is idempotent: missing thread directories are ignored. The operation is idempotent: missing thread directories are ignored.
""" """
thread_dir = self.thread_dir(thread_id, user_id=user_id) thread_dir = self.thread_dir(thread_id)
if thread_dir.exists(): if thread_dir.exists():
shutil.rmtree(thread_dir) shutil.rmtree(thread_dir)
def resolve_virtual_path(self, thread_id: str, virtual_path: str, *, user_id: str | None = None) -> Path: def resolve_virtual_path(self, thread_id: str, virtual_path: str) -> Path:
"""Resolve a sandbox virtual path to the actual host filesystem path. """Resolve a sandbox virtual path to the actual host filesystem path.
Args: Args:
@@ -282,7 +253,6 @@ class Paths:
virtual_path: Virtual path as seen inside the sandbox, e.g. virtual_path: Virtual path as seen inside the sandbox, e.g.
``/mnt/user-data/outputs/report.pdf``. ``/mnt/user-data/outputs/report.pdf``.
Leading slashes are stripped before matching. Leading slashes are stripped before matching.
user_id: Optional user ID for user-scoped path resolution.
Returns: Returns:
The resolved absolute host filesystem path. The resolved absolute host filesystem path.
@@ -300,7 +270,7 @@ class Paths:
raise ValueError(f"Path must start with /{prefix}") raise ValueError(f"Path must start with /{prefix}")
relative = stripped[len(prefix) :].lstrip("/") relative = stripped[len(prefix) :].lstrip("/")
base = self.sandbox_user_data_dir(thread_id, user_id=user_id).resolve() base = self.sandbox_user_data_dir(thread_id).resolve()
actual = (base / relative).resolve() actual = (base / relative).resolve()
try: try:
@@ -1,33 +0,0 @@
"""Run event storage configuration.
Controls where run events (messages + execution traces) are persisted.
Backends:
- memory: In-memory storage, data lost on restart. Suitable for
development and testing.
- db: SQL database via SQLAlchemy ORM. Provides full query capability.
Suitable for production deployments.
- jsonl: Append-only JSONL files. Lightweight alternative for
single-node deployments that need persistence without a database.
"""
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, Field
class RunEventsConfig(BaseModel):
backend: Literal["memory", "db", "jsonl"] = Field(
default="memory",
description="Storage backend for run events. 'memory' for development (no persistence), 'db' for production (SQL queries), 'jsonl' for lightweight single-node persistence.",
)
max_trace_content: int = Field(
default=10240,
description="Maximum trace content size in bytes before truncation (db backend only).",
)
track_token_usage: bool = Field(
default=True,
description="Whether RunJournal should accumulate token counts to RunRow.",
)
@@ -4,6 +4,8 @@ from pydantic import BaseModel, ConfigDict, Field
class VolumeMountConfig(BaseModel): class VolumeMountConfig(BaseModel):
"""Configuration for a volume mount.""" """Configuration for a volume mount."""
model_config = ConfigDict(frozen=True)
host_path: str = Field(..., description="Path on the host machine") host_path: str = Field(..., description="Path on the host machine")
container_path: str = Field(..., description="Path inside the container") container_path: str = Field(..., description="Path inside the container")
read_only: bool = Field(default=False, description="Whether the mount is read-only") read_only: bool = Field(default=False, description="Whether the mount is read-only")
@@ -80,4 +82,4 @@ class SandboxConfig(BaseModel):
description="Maximum characters to keep from ls tool output. Output exceeding this limit is head-truncated. Set to 0 to disable truncation.", description="Maximum characters to keep from ls tool output. Output exceeding this limit is head-truncated. Set to 0 to disable truncation.",
) )
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow", frozen=True)
@@ -1,9 +1,11 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
class SkillEvolutionConfig(BaseModel): class SkillEvolutionConfig(BaseModel):
"""Configuration for agent-managed skill evolution.""" """Configuration for agent-managed skill evolution."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field( enabled: bool = Field(
default=False, default=False,
description="Whether the agent can create and modify skills under skills/custom.", description="Whether the agent can create and modify skills under skills/custom.",
@@ -1,6 +1,6 @@
from pathlib import Path from pathlib import Path
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
def _default_repo_root() -> Path: def _default_repo_root() -> Path:
@@ -11,6 +11,8 @@ def _default_repo_root() -> Path:
class SkillsConfig(BaseModel): class SkillsConfig(BaseModel):
"""Configuration for skills system""" """Configuration for skills system"""
model_config = ConfigDict(frozen=True)
path: str | None = Field( path: str | None = Field(
default=None, default=None,
description="Path to skills directory. If not specified, defaults to ../skills relative to backend directory", description="Path to skills directory. If not specified, defaults to ../skills relative to backend directory",
@@ -2,7 +2,7 @@
from typing import Literal from typing import Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
StreamBridgeType = Literal["memory", "redis"] StreamBridgeType = Literal["memory", "redis"]
@@ -10,6 +10,8 @@ StreamBridgeType = Literal["memory", "redis"]
class StreamBridgeConfig(BaseModel): class StreamBridgeConfig(BaseModel):
"""Configuration for the stream bridge that connects agent workers to SSE endpoints.""" """Configuration for the stream bridge that connects agent workers to SSE endpoints."""
model_config = ConfigDict(frozen=True)
type: StreamBridgeType = Field( type: StreamBridgeType = Field(
default="memory", default="memory",
description="Stream bridge backend type. 'memory' uses in-process asyncio.Queue (single-process only). 'redis' uses Redis Streams (planned for Phase 2, not yet implemented).", description="Stream bridge backend type. 'memory' uses in-process asyncio.Queue (single-process only). 'redis' uses Redis Streams (planned for Phase 2, not yet implemented).",
@@ -22,25 +24,3 @@ class StreamBridgeConfig(BaseModel):
default=256, default=256,
description="Maximum number of events buffered per run in the memory bridge.", description="Maximum number of events buffered per run in the memory bridge.",
) )
# Global configuration instance — None means no stream bridge is configured
# (falls back to memory with defaults).
_stream_bridge_config: StreamBridgeConfig | None = None
def get_stream_bridge_config() -> StreamBridgeConfig | None:
"""Get the current stream bridge configuration, or None if not configured."""
return _stream_bridge_config
def set_stream_bridge_config(config: StreamBridgeConfig | None) -> None:
"""Set the stream bridge configuration."""
global _stream_bridge_config
_stream_bridge_config = config
def load_stream_bridge_config_from_dict(config_dict: dict) -> None:
"""Load stream bridge configuration from a dictionary."""
global _stream_bridge_config
_stream_bridge_config = StreamBridgeConfig(**config_dict)
@@ -1,15 +1,13 @@
"""Configuration for the subagent system loaded from config.yaml.""" """Configuration for the subagent system loaded from config.yaml."""
import logging from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class SubagentOverrideConfig(BaseModel): class SubagentOverrideConfig(BaseModel):
"""Per-agent configuration overrides.""" """Per-agent configuration overrides."""
model_config = ConfigDict(frozen=True)
timeout_seconds: int | None = Field( timeout_seconds: int | None = Field(
default=None, default=None,
ge=1, ge=1,
@@ -25,6 +23,8 @@ class SubagentOverrideConfig(BaseModel):
class SubagentsAppConfig(BaseModel): class SubagentsAppConfig(BaseModel):
"""Configuration for the subagent system.""" """Configuration for the subagent system."""
model_config = ConfigDict(frozen=True)
timeout_seconds: int = Field( timeout_seconds: int = Field(
default=900, default=900,
ge=1, ge=1,
@@ -62,41 +62,3 @@ class SubagentsAppConfig(BaseModel):
if self.max_turns is not None: if self.max_turns is not None:
return self.max_turns return self.max_turns
return builtin_default return builtin_default
_subagents_config: SubagentsAppConfig = SubagentsAppConfig()
def get_subagents_app_config() -> SubagentsAppConfig:
"""Get the current subagents configuration."""
return _subagents_config
def load_subagents_config_from_dict(config_dict: dict) -> None:
"""Load subagents configuration from a dictionary."""
global _subagents_config
_subagents_config = SubagentsAppConfig(**config_dict)
overrides_summary = {}
for name, override in _subagents_config.agents.items():
parts = []
if override.timeout_seconds is not None:
parts.append(f"timeout={override.timeout_seconds}s")
if override.max_turns is not None:
parts.append(f"max_turns={override.max_turns}")
if parts:
overrides_summary[name] = ", ".join(parts)
if overrides_summary:
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, per-agent overrides=%s",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
overrides_summary,
)
else:
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, no per-agent overrides",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
)
@@ -2,7 +2,7 @@
from typing import Literal from typing import Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
ContextSizeType = Literal["fraction", "tokens", "messages"] ContextSizeType = Literal["fraction", "tokens", "messages"]
@@ -10,6 +10,8 @@ ContextSizeType = Literal["fraction", "tokens", "messages"]
class ContextSize(BaseModel): class ContextSize(BaseModel):
"""Context size specification for trigger or keep parameters.""" """Context size specification for trigger or keep parameters."""
model_config = ConfigDict(frozen=True)
type: ContextSizeType = Field(description="Type of context size specification") type: ContextSizeType = Field(description="Type of context size specification")
value: int | float = Field(description="Value for the context size specification") value: int | float = Field(description="Value for the context size specification")
@@ -21,6 +23,8 @@ class ContextSize(BaseModel):
class SummarizationConfig(BaseModel): class SummarizationConfig(BaseModel):
"""Configuration for automatic conversation summarization.""" """Configuration for automatic conversation summarization."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field( enabled: bool = Field(
default=False, default=False,
description="Whether to enable automatic conversation summarization", description="Whether to enable automatic conversation summarization",
@@ -51,24 +55,3 @@ class SummarizationConfig(BaseModel):
default=None, default=None,
description="Custom prompt template for generating summaries. If not provided, uses the default LangChain prompt.", description="Custom prompt template for generating summaries. If not provided, uses the default LangChain prompt.",
) )
# Global configuration instance
_summarization_config: SummarizationConfig = SummarizationConfig()
def get_summarization_config() -> SummarizationConfig:
"""Get the current summarization configuration."""
return _summarization_config
def set_summarization_config(config: SummarizationConfig) -> None:
"""Set the summarization configuration."""
global _summarization_config
_summarization_config = config
def load_summarization_config_from_dict(config_dict: dict) -> None:
"""Load summarization configuration from a dictionary."""
global _summarization_config
_summarization_config = SummarizationConfig(**config_dict)
@@ -1,11 +1,13 @@
"""Configuration for automatic thread title generation.""" """Configuration for automatic thread title generation."""
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
class TitleConfig(BaseModel): class TitleConfig(BaseModel):
"""Configuration for automatic thread title generation.""" """Configuration for automatic thread title generation."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field( enabled: bool = Field(
default=True, default=True,
description="Whether to enable automatic title generation", description="Whether to enable automatic title generation",
@@ -30,24 +32,3 @@ class TitleConfig(BaseModel):
default=("Generate a concise title (max {max_words} words) for this conversation.\nUser: {user_msg}\nAssistant: {assistant_msg}\n\nReturn ONLY the title, no quotes, no explanation."), default=("Generate a concise title (max {max_words} words) for this conversation.\nUser: {user_msg}\nAssistant: {assistant_msg}\n\nReturn ONLY the title, no quotes, no explanation."),
description="Prompt template for title generation", description="Prompt template for title generation",
) )
# Global configuration instance
_title_config: TitleConfig = TitleConfig()
def get_title_config() -> TitleConfig:
"""Get the current title configuration."""
return _title_config
def set_title_config(config: TitleConfig) -> None:
"""Set the title configuration."""
global _title_config
_title_config = config
def load_title_config_from_dict(config_dict: dict) -> None:
"""Load title configuration from a dictionary."""
global _title_config
_title_config = TitleConfig(**config_dict)
@@ -1,7 +1,9 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
class TokenUsageConfig(BaseModel): class TokenUsageConfig(BaseModel):
"""Configuration for token usage tracking.""" """Configuration for token usage tracking."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(default=False, description="Enable token usage tracking middleware") enabled: bool = Field(default=False, description="Enable token usage tracking middleware")
@@ -5,7 +5,7 @@ class ToolGroupConfig(BaseModel):
"""Config section for a tool group""" """Config section for a tool group"""
name: str = Field(..., description="Unique name for the tool group") name: str = Field(..., description="Unique name for the tool group")
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow", frozen=True)
class ToolConfig(BaseModel): class ToolConfig(BaseModel):
@@ -17,4 +17,4 @@ class ToolConfig(BaseModel):
..., ...,
description="Variable name of the tool provider(e.g. deerflow.sandbox.tools:bash_tool)", description="Variable name of the tool provider(e.g. deerflow.sandbox.tools:bash_tool)",
) )
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow", frozen=True)
@@ -1,6 +1,6 @@
"""Configuration for deferred tool loading via tool_search.""" """Configuration for deferred tool loading via tool_search."""
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
class ToolSearchConfig(BaseModel): class ToolSearchConfig(BaseModel):
@@ -11,25 +11,9 @@ class ToolSearchConfig(BaseModel):
via the tool_search tool at runtime. via the tool_search tool at runtime.
""" """
model_config = ConfigDict(frozen=True)
enabled: bool = Field( enabled: bool = Field(
default=False, default=False,
description="Defer tools and enable tool_search", description="Defer tools and enable tool_search",
) )
_tool_search_config: ToolSearchConfig | None = None
def get_tool_search_config() -> ToolSearchConfig:
"""Get the tool search config, loading from AppConfig if needed."""
global _tool_search_config
if _tool_search_config is None:
_tool_search_config = ToolSearchConfig()
return _tool_search_config
def load_tool_search_config_from_dict(data: dict) -> ToolSearchConfig:
"""Load tool search config from a dict (called during AppConfig loading)."""
global _tool_search_config
_tool_search_config = ToolSearchConfig.model_validate(data)
return _tool_search_config
@@ -1,7 +1,7 @@
import os import os
import threading import threading
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
_config_lock = threading.Lock() _config_lock = threading.Lock()
@@ -9,6 +9,8 @@ _config_lock = threading.Lock()
class LangSmithTracingConfig(BaseModel): class LangSmithTracingConfig(BaseModel):
"""Configuration for LangSmith tracing.""" """Configuration for LangSmith tracing."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(...) enabled: bool = Field(...)
api_key: str | None = Field(...) api_key: str | None = Field(...)
project: str = Field(...) project: str = Field(...)
@@ -26,6 +28,8 @@ class LangSmithTracingConfig(BaseModel):
class LangfuseTracingConfig(BaseModel): class LangfuseTracingConfig(BaseModel):
"""Configuration for Langfuse tracing.""" """Configuration for Langfuse tracing."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(...) enabled: bool = Field(...)
public_key: str | None = Field(...) public_key: str | None = Field(...)
secret_key: str | None = Field(...) secret_key: str | None = Field(...)
@@ -50,6 +54,8 @@ class LangfuseTracingConfig(BaseModel):
class TracingConfig(BaseModel): class TracingConfig(BaseModel):
"""Tracing configuration for supported providers.""" """Tracing configuration for supported providers."""
model_config = ConfigDict(frozen=True)
langsmith: LangSmithTracingConfig = Field(...) langsmith: LangSmithTracingConfig = Field(...)
langfuse: LangfuseTracingConfig = Field(...) langfuse: LangfuseTracingConfig = Field(...)
@@ -2,7 +2,7 @@ import logging
from langchain.chat_models import BaseChatModel from langchain.chat_models import BaseChatModel
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_class from deerflow.reflection import resolve_class
from deerflow.tracing import build_tracing_callbacks from deerflow.tracing import build_tracing_callbacks
@@ -39,7 +39,7 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
Returns: Returns:
A chat model instance. A chat model instance.
""" """
config = get_app_config() config = AppConfig.current()
if name is None: if name is None:
name = config.models[0].name name = config.models[0].name
model_config = config.get_model_config(name) model_config = config.get_model_config(name)
@@ -113,16 +113,7 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
elif "reasoning_effort" not in model_settings_from_config: elif "reasoning_effort" not in model_settings_from_config:
model_settings_from_config["reasoning_effort"] = "medium" model_settings_from_config["reasoning_effort"] = "medium"
# Ensure stream_usage is enabled so that token usage metadata is available model_instance = model_class(**{**model_settings_from_config, **kwargs})
# in streaming responses. LangChain's BaseChatOpenAI only defaults
# stream_usage=True when no custom base_url/api_base is set, so models
# hitting third-party endpoints (e.g. doubao, deepseek) silently lose
# usage data. We default it to True unless explicitly configured.
if "stream_usage" not in model_settings_from_config and "stream_usage" not in kwargs:
if "stream_usage" in getattr(model_class, "model_fields", {}):
model_settings_from_config["stream_usage"] = True
model_instance = model_class(**kwargs, **model_settings_from_config)
callbacks = build_tracing_callbacks() callbacks = build_tracing_callbacks()
if callbacks: if callbacks:
@@ -1,13 +0,0 @@
"""DeerFlow application persistence layer (SQLAlchemy 2.0 async ORM).
This module manages DeerFlow's own application data -- runs metadata,
thread ownership, cron jobs, users. It is completely separate from
LangGraph's checkpointer, which manages graph execution state.
Usage:
from deerflow.persistence import init_engine, close_engine, get_session_factory
"""
from deerflow.persistence.engine import close_engine, get_engine, get_session_factory, init_engine
__all__ = ["close_engine", "get_engine", "get_session_factory", "init_engine"]
@@ -1,40 +0,0 @@
"""SQLAlchemy declarative base with automatic to_dict support.
All DeerFlow ORM models inherit from this Base. It provides a generic
to_dict() method via SQLAlchemy's inspect() so individual models don't
need to write their own serialization logic.
LangGraph's checkpointer tables are NOT managed by this Base.
"""
from __future__ import annotations
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.orm import DeclarativeBase
class Base(DeclarativeBase):
"""Base class for all DeerFlow ORM models.
Provides:
- Automatic to_dict() via SQLAlchemy column inspection.
- Standard __repr__() showing all column values.
"""
def to_dict(self, *, exclude: set[str] | None = None) -> dict:
"""Convert ORM instance to plain dict.
Uses SQLAlchemy's inspect() to iterate mapped column attributes.
Args:
exclude: Optional set of column keys to omit.
Returns:
Dict of {column_key: value} for all mapped columns.
"""
exclude = exclude or set()
return {c.key: getattr(self, c.key) for c in sa_inspect(type(self)).mapper.column_attrs if c.key not in exclude}
def __repr__(self) -> str:
cols = ", ".join(f"{c.key}={getattr(self, c.key)!r}" for c in sa_inspect(type(self)).mapper.column_attrs)
return f"{type(self).__name__}({cols})"
@@ -1,190 +0,0 @@
"""Async SQLAlchemy engine lifecycle management.
Initializes at Gateway startup, provides session factory for
repositories, disposes at shutdown.
When database.backend="memory", init_engine is a no-op and
get_session_factory() returns None. Repositories must check for
None and fall back to in-memory implementations.
"""
from __future__ import annotations
import json
import logging
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
def _json_serializer(obj: object) -> str:
"""JSON serializer with ensure_ascii=False for Chinese character support."""
return json.dumps(obj, ensure_ascii=False)
logger = logging.getLogger(__name__)
_engine: AsyncEngine | None = None
_session_factory: async_sessionmaker[AsyncSession] | None = None
async def _auto_create_postgres_db(url: str) -> None:
"""Connect to the ``postgres`` maintenance DB and CREATE DATABASE.
The target database name is extracted from *url*. The connection is
made to the default ``postgres`` database on the same server using
``AUTOCOMMIT`` isolation (CREATE DATABASE cannot run inside a
transaction).
"""
from sqlalchemy import text
from sqlalchemy.engine.url import make_url
parsed = make_url(url)
db_name = parsed.database
if not db_name:
raise ValueError("Cannot auto-create database: no database name in URL")
# Connect to the default 'postgres' database to issue CREATE DATABASE
maint_url = parsed.set(database="postgres")
maint_engine = create_async_engine(maint_url, isolation_level="AUTOCOMMIT")
try:
async with maint_engine.connect() as conn:
await conn.execute(text(f'CREATE DATABASE "{db_name}"'))
logger.info("Auto-created PostgreSQL database: %s", db_name)
finally:
await maint_engine.dispose()
async def init_engine(
backend: str,
*,
url: str = "",
echo: bool = False,
pool_size: int = 5,
sqlite_dir: str = "",
) -> None:
"""Create the async engine and session factory, then auto-create tables.
Args:
backend: "memory", "sqlite", or "postgres".
url: SQLAlchemy async URL (for sqlite/postgres).
echo: Echo SQL to log.
pool_size: Postgres connection pool size.
sqlite_dir: Directory to create for SQLite (ensured to exist).
"""
global _engine, _session_factory
if backend == "memory":
logger.info("Persistence backend=memory -- ORM engine not initialized")
return
if backend == "postgres":
try:
import asyncpg # noqa: F401
except ImportError:
raise ImportError("database.backend is set to 'postgres' but asyncpg is not installed.\nInstall it with:\n uv sync --extra postgres\nOr switch to backend: sqlite in config.yaml for single-node deployment.") from None
if backend == "sqlite":
import os
from sqlalchemy import event
os.makedirs(sqlite_dir or ".", exist_ok=True)
_engine = create_async_engine(url, echo=echo, json_serializer=_json_serializer)
# Enable WAL on every new connection. SQLite PRAGMA settings are
# per-connection, so we wire the listener instead of running PRAGMA
# once at startup. WAL gives concurrent reads + writers without
# blocking and is the standard recommendation for any production
# SQLite deployment (TC-UPG-06 in AUTH_TEST_PLAN.md). The companion
# ``synchronous=NORMAL`` is the safe-and-fast pairing — fsync only
# at WAL checkpoint boundaries instead of every commit.
# Note: we do not set PRAGMA busy_timeout here — Python's sqlite3
# driver already defaults to a 5-second busy timeout (see the
# ``timeout`` kwarg of ``sqlite3.connect``), and aiosqlite /
# SQLAlchemy's aiosqlite dialect inherit that default. Setting
# it again would be a no-op.
@event.listens_for(_engine.sync_engine, "connect")
def _enable_sqlite_wal(dbapi_conn, _record): # noqa: ARG001 — SQLAlchemy contract
cursor = dbapi_conn.cursor()
try:
cursor.execute("PRAGMA journal_mode=WAL;")
cursor.execute("PRAGMA synchronous=NORMAL;")
cursor.execute("PRAGMA foreign_keys=ON;")
finally:
cursor.close()
elif backend == "postgres":
_engine = create_async_engine(
url,
echo=echo,
pool_size=pool_size,
pool_pre_ping=True,
json_serializer=_json_serializer,
)
else:
raise ValueError(f"Unknown persistence backend: {backend!r}")
_session_factory = async_sessionmaker(_engine, expire_on_commit=False)
# Auto-create tables (dev convenience). Production should use Alembic.
from deerflow.persistence.base import Base
# Import all models so Base.metadata discovers them.
# When no models exist yet (scaffolding phase), this is a no-op.
try:
import deerflow.persistence.models # noqa: F401
except ImportError:
# Models package not yet available — tables won't be auto-created.
# This is expected during initial scaffolding or minimal installs.
logger.debug("deerflow.persistence.models not found; skipping auto-create tables")
try:
async with _engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
except Exception as exc:
if backend == "postgres" and "does not exist" in str(exc):
# Database not yet created — attempt to auto-create it, then retry.
await _auto_create_postgres_db(url)
# Rebuild engine against the now-existing database
await _engine.dispose()
_engine = create_async_engine(url, echo=echo, pool_size=pool_size, pool_pre_ping=True, json_serializer=_json_serializer)
_session_factory = async_sessionmaker(_engine, expire_on_commit=False)
async with _engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
else:
raise
logger.info("Persistence engine initialized: backend=%s", backend)
async def init_engine_from_config(config) -> None:
"""Convenience: init engine from a DatabaseConfig object."""
if config.backend == "memory":
await init_engine("memory")
return
await init_engine(
backend=config.backend,
url=config.app_sqlalchemy_url,
echo=config.echo_sql,
pool_size=config.pool_size,
sqlite_dir=config.sqlite_dir if config.backend == "sqlite" else "",
)
def get_session_factory() -> async_sessionmaker[AsyncSession] | None:
"""Return the async session factory, or None if backend=memory."""
return _session_factory
def get_engine() -> AsyncEngine | None:
"""Return the async engine, or None if not initialized."""
return _engine
async def close_engine() -> None:
"""Dispose the engine, release all connections."""
global _engine, _session_factory
if _engine is not None:
await _engine.dispose()
logger.info("Persistence engine closed")
_engine = None
_session_factory = None
@@ -1,6 +0,0 @@
"""Feedback persistence — ORM and SQL repository."""
from deerflow.persistence.feedback.model import FeedbackRow
from deerflow.persistence.feedback.sql import FeedbackRepository
__all__ = ["FeedbackRepository", "FeedbackRow"]

Some files were not shown because too many files have changed in this diff Show More