Compare commits

..

5 Commits

Author SHA1 Message Date
jiangfeng.11 e60621d519 refactor(prompt): remove unused skill functions and clean up code 2026-04-11 11:03:47 +08:00
jiangfeng.11 f7a6ca8364 refactor(workspace): simplify sidebar state management using cookies 2026-04-11 10:38:05 +08:00
jiangfeng.11 2540acd5f7 Merge branch 'main' into release/2.0-rc 2026-04-11 10:34:31 +08:00
greatmengqi b2704525a0 feat(auth): release-validation pass for 2.0-rc — 12 blockers + simplify follow-ups (#2008)
* feat(auth): introduce backend auth module

Port RFC-001 authentication core from PR #1728:
- JWT token handling (create_access_token, decode_token, TokenPayload)
- Password hashing (bcrypt) with verify_password
- SQLite UserRepository with base interface
- Provider Factory pattern (LocalAuthProvider)
- CLI reset_admin tool
- Auth-specific errors (AuthErrorCode, TokenError, AuthErrorResponse)

Deps:
- bcrypt>=4.0.0
- pyjwt>=2.9.0
- email-validator>=2.0.0
- backend/uv.toml pins public PyPI index

Tests: 12 pure unit tests (test_auth_config.py, test_auth_errors.py).

Scope note: authz.py, test_auth.py, and test_auth_type_system.py are
deferred to commit 2 because they depend on middleware and deps wiring
that is not yet in place. Commit 1 stays "pure new files only" as the
spec mandates.

* feat(auth): wire auth end-to-end (middleware + frontend replacement)

Backend:
- Port auth_middleware, csrf_middleware, langgraph_auth, routers/auth
- Port authz decorator (owner_filter_key defaults to 'owner_id')
- Merge app.py: register AuthMiddleware + CSRFMiddleware + CORS, add
  _ensure_admin_user lifespan hook, _migrate_orphaned_threads helper,
  register auth router
- Merge deps.py: add get_local_provider, get_current_user_from_request,
  get_optional_user_from_request; keep get_current_user as thin str|None
  adapter for feedback router
- langgraph.json: add auth path pointing to langgraph_auth.py:auth
- Rename metadata['user_id'] -> metadata['owner_id'] in langgraph_auth
  (both metadata write and LangGraph filter dict) + test fixtures

Frontend:
- Delete better-auth library and api catch-all route
- Remove better-auth npm dependency and env vars (BETTER_AUTH_SECRET,
  BETTER_AUTH_GITHUB_*) from env.js
- Port frontend/src/core/auth/* (AuthProvider, gateway-config,
  proxy-policy, server-side getServerSideUser, types)
- Port frontend/src/core/api/fetcher.ts
- Port (auth)/layout, (auth)/login, (auth)/setup pages
- Rewrite workspace/layout.tsx as server component that calls
  getServerSideUser and wraps in AuthProvider
- Port workspace/workspace-content.tsx for the client-side sidebar logic

Tests:
- Port 5 auth test files (test_auth, test_auth_middleware,
  test_auth_type_system, test_ensure_admin, test_langgraph_auth)
- 176 auth tests PASS

After this commit: login/logout/registration flow works, but persistence
layer does not yet filter by owner_id. Commit 4 closes that gap.

* feat(auth): account settings page + i18n

- Port account-settings-page.tsx (change password, change email, logout)
- Wire into settings-dialog.tsx as new "account" section with UserIcon,
  rendered first in the section list
- Add i18n keys:
  - en-US/zh-CN: settings.sections.account ("Account" / "账号")
  - en-US/zh-CN: button.logout ("Log out" / "退出登录")
  - types.ts: matching type declarations

* feat(auth): enforce owner_id across 2.0-rc persistence layer

Add request-scoped contextvar-based owner filtering to threads_meta,
runs, run_events, and feedback repositories. Router code is unchanged
— isolation is enforced at the storage layer so that any caller that
forgets to pass owner_id still gets filtered results, and new routes
cannot accidentally leak data.

Core infrastructure
-------------------
- deerflow/runtime/user_context.py (new):
  - ContextVar[CurrentUser | None] with default None
  - runtime_checkable CurrentUser Protocol (structural subtype with .id)
  - set/reset/get/require helpers
  - AUTO sentinel + resolve_owner_id(value, method_name) for sentinel
    three-state resolution: AUTO reads contextvar, explicit str
    overrides, explicit None bypasses the filter (for migration/CLI)

Repository changes
------------------
- ThreadMetaRepository: create/get/search/update_*/delete gain
  owner_id=AUTO kwarg; read paths filter by owner, writes stamp it,
  mutations check ownership before applying
- RunRepository: put/get/list_by_thread/delete gain owner_id=AUTO kwarg
- FeedbackRepository: create/get/list_by_run/list_by_thread/delete
  gain owner_id=AUTO kwarg
- DbRunEventStore: list_messages/list_events/list_messages_by_run/
  count_messages/delete_by_thread/delete_by_run gain owner_id=AUTO
  kwarg. Write paths (put/put_batch) read contextvar softly: when a
  request-scoped user is available, owner_id is stamped; background
  worker writes without a user context pass None which is valid
  (orphan row to be bound by migration)

Schema
------
- persistence/models/run_event.py: RunEventRow.owner_id = Mapped[
  str | None] = mapped_column(String(64), nullable=True, index=True)
- No alembic migration needed: 2.0 ships fresh, Base.metadata.create_all
  picks up the new column automatically

Middleware
----------
- auth_middleware.py: after cookie check, call get_optional_user_from_
  request to load the real User, stamp it into request.state.user AND
  the contextvar via set_current_user, reset in a try/finally. Public
  paths and unauthenticated requests continue without contextvar, and
  @require_auth handles the strict 401 path

Test infrastructure
-------------------
- tests/conftest.py: @pytest.fixture(autouse=True) _auto_user_context
  sets a default SimpleNamespace(id="test-user-autouse") on every test
  unless marked @pytest.mark.no_auto_user. Keeps existing 20+
  persistence tests passing without modification
- pyproject.toml [tool.pytest.ini_options]: register no_auto_user
  marker so pytest does not emit warnings for opt-out tests
- tests/test_user_context.py: 6 tests covering three-state semantics,
  Protocol duck typing, and require/optional APIs
- tests/test_thread_meta_repo.py: one test updated to pass owner_id=
  None explicitly where it was previously relying on the old default

Test results
------------
- test_user_context.py: 6 passed
- test_auth*.py + test_langgraph_auth.py + test_ensure_admin.py: 127
- test_run_event_store / test_run_repository / test_thread_meta_repo
  / test_feedback: 92 passed
- Full backend suite: 1905 passed, 2 failed (both @requires_llm flaky
  integration tests unrelated to auth), 1 skipped

* feat(auth): extend orphan migration to 2.0-rc persistence tables

_ensure_admin_user now runs a three-step pipeline on every boot:

  Step 1 (fatal):     admin user exists / is created / password is reset
  Step 2 (non-fatal): LangGraph store orphan threads → admin
  Step 3 (non-fatal): SQL persistence tables → admin
    - threads_meta
    - runs
    - run_events
    - feedback

Each step is idempotent. The fatal/non-fatal split mirrors PR #1728's
original philosophy: admin creation failure blocks startup (the system
is unusable without an admin), whereas migration failures log a warning
and let the service proceed (a partial migration is recoverable; a
missing admin is not).

Key helpers
-----------
- _iter_store_items(store, namespace, *, page_size=500):
  async generator that cursor-paginates across LangGraph store pages.
  Fixes PR #1728's hardcoded limit=1000 bug that would silently lose
  orphans beyond the first page.

- _migrate_orphaned_threads(store, admin_user_id):
  Rewritten to use _iter_store_items. Returns the migrated count so the
  caller can log it; raises only on unhandled exceptions.

- _migrate_orphan_sql_tables(admin_user_id):
  Imports the 4 ORM models lazily, grabs the shared session factory,
  runs one UPDATE per table in a single transaction, commits once.
  No-op when no persistence backend is configured (in-memory dev).

Tests: test_ensure_admin.py (8 passed)

* test(auth): port AUTH test plan docs + lint/format pass

- Port backend/docs/AUTH_TEST_PLAN.md and AUTH_UPGRADE.md from PR #1728
- Rename metadata.user_id → metadata.owner_id in AUTH_TEST_PLAN.md
  (4 occurrences from the original PR doc)
- ruff auto-fix UP037 in sentinel type annotations: drop quotes around
  "str | None | _AutoSentinel" now that from __future__ import
  annotations makes them implicit string forms
- ruff format: 2 files (app/gateway/app.py, runtime/user_context.py)

Note on test coverage additions:
- conftest.py autouse fixture was already added in commit 4 (had to
  be co-located with the repository changes to keep pre-existing
  persistence tests passing)
- cross-user isolation E2E tests (test_owner_isolation.py) deferred
  — enforcement is already proven by the 98-test repository suite
  via the autouse fixture + explicit _AUTO sentinel exercises
- New test cases (TC-API-17..20, TC-ATK-13, TC-MIG-01..07) listed
  in AUTH_TEST_PLAN.md are deferred to a follow-up PR — they are
  manual-QA test cases rather than pytest code, and the spec-level
  coverage is already met by test_user_context.py + the 98-test
  repository suite.

Final test results:
- Auth suite (test_auth*, test_langgraph_auth, test_ensure_admin,
  test_user_context): 186 passed
- Persistence suite (test_run_event_store, test_run_repository,
  test_thread_meta_repo, test_feedback): 98 passed
- Lint: ruff check + ruff format both clean

* test(auth): add cross-user isolation test suite

10 tests exercising the storage-layer owner filter by manually
switching the user_context contextvar between two users. Verifies
the safety invariant:

  After a repository write with owner_id=A, a subsequent read with
  owner_id=B must not return the row, and vice versa.

Covers all 4 tables that own user-scoped data:

TC-API-17  threads_meta  — read, search, update, delete cross-user
TC-API-18  runs          — get, list_by_thread, delete cross-user
TC-API-19  run_events    — list_messages, list_events, count_messages,
                           delete_by_thread (CRITICAL: raw conversation
                           content leak vector)
TC-API-20  feedback      — get, list_by_run, delete cross-user

Plus two meta-tests verifying the sentinel pattern itself:
- AUTO + unset contextvar raises RuntimeError
- explicit owner_id=None bypasses the filter (migration escape hatch)

Architecture note
-----------------
These tests bypass the HTTP layer by design. The full chain
(cookie → middleware → contextvar → repository) is covered piecewise:

- test_auth_middleware.py: middleware sets contextvar from cookies
- test_owner_isolation.py: repositories enforce isolation when
  contextvar is set to different users

Together they prove the end-to-end safety property without the
ceremony of spinning up a full TestClient + in-memory DB for every
router endpoint.

Tests pass: 231 (full auth + persistence + isolation suite)
Lint: clean

* refactor(auth): migrate user repository to SQLAlchemy ORM

Move the users table into the shared persistence engine so auth
matches the pattern of threads_meta, runs, run_events, and feedback —
one engine, one session factory, one schema init codepath.

New files
---------
- persistence/user/__init__.py, persistence/user/model.py: UserRow
  ORM class with partial unique index on (oauth_provider, oauth_id)
- Registered in persistence/models/__init__.py so
  Base.metadata.create_all() picks it up

Modified
--------
- auth/repositories/sqlite.py: rewritten as async SQLAlchemy,
  identical constructor pattern to the other four repositories
  (def __init__(self, session_factory) + self._sf = session_factory)
- auth/config.py: drop users_db_path field — storage is configured
  through config.database like every other table
- deps.py/get_local_provider: construct SQLiteUserRepository with
  the shared session factory, fail fast if engine is not initialised
- tests/test_auth.py: rewrite test_sqlite_round_trip_new_fields to
  use the shared engine (init_engine + close_engine in a tempdir)
- tests/test_auth_type_system.py: add per-test autouse fixture that
  spins up a scratch engine and resets deps._cached_* singletons

* refactor(auth): remove SQL orphan migration (unused in supported scenarios)

The _migrate_orphan_sql_tables helper existed to bind NULL owner_id
rows in threads_meta, runs, run_events, and feedback to the admin on
first boot. But in every supported upgrade path, it's a no-op:

  1. Fresh install: create_all builds fresh tables, no legacy rows
  2. No-auth → with-auth (no existing persistence DB): persistence
     tables are created fresh by create_all, no legacy rows
  3. No-auth → with-auth (has existing persistence DB from #1930):
     NOT a supported upgrade path — "有 DB 到有 DB" schema evolution
     is out of scope; users wipe DB or run manual ALTER

So the SQL orphan migration never has anything to do in the
supported matrix. Delete the function, simplify _ensure_admin_user
from a 3-step pipeline to a 2-step one (admin creation + LangGraph
store orphan migration only).

LangGraph store orphan migration stays: it serves the real
"no-auth → with-auth" upgrade path where a user's existing LangGraph
thread metadata has no owner_id field and needs to be stamped with
the newly-created admin's id.

Tests: 284 passed (auth + persistence + isolation)
Lint: clean

* security(auth): write initial admin password to 0600 file instead of logs

CodeQL py/clear-text-logging-sensitive-data flagged 3 call sites that
logged the auto-generated admin password to stdout via logger.info().
Production log aggregators (ELK/Splunk/etc) would have captured those
cleartext secrets. Replace with a shared helper that writes to
.deer-flow/admin_initial_credentials.txt with mode 0600, and log only
the path.

New file
--------
- app/gateway/auth/credential_file.py: write_initial_credentials()
  helper. Takes email, password, and a "initial"/"reset" label.
  Creates .deer-flow/ if missing, writes a header comment plus the
  email+password, chmods 0o600, returns the absolute Path.

Modified
--------
- app/gateway/app.py: both _ensure_admin_user paths (fresh creation
  + needs_setup password reset) now write to file and log the path
- app/gateway/auth/reset_admin.py: rewritten to use the shared ORM
  repo (SQLiteUserRepository with session_factory) and the
  credential_file helper. The previous implementation was broken
  after the earlier ORM refactor — it still imported _get_users_conn
  and constructed SQLiteUserRepository() without a session factory.

No tests changed — the three password-log sites are all exercised
via existing test_ensure_admin.py which checks that startup
succeeds, not that a specific string appears in logs.

CodeQL alerts 272, 283, 284: all resolved.

* security(auth): strict JWT validation in middleware (fix junk cookie bypass)

AUTH_TEST_PLAN test 7.5.8 expects junk cookies to be rejected with
401. The previous middleware behaviour was "presence-only": check
that some access_token cookie exists, then pass through. In
combination with my Task-12 decision to skip @require_auth
decorators on routes, this created a gap where a request with any
cookie-shaped string (e.g. access_token=not-a-jwt) would bypass
authentication on routes that do not touch the repository
(/api/models, /api/mcp/config, /api/memory, /api/skills, …).

Fix: middleware now calls get_current_user_from_request() strictly
and catches the resulting HTTPException to render a 401 with the
proper fine-grained error code (token_invalid, token_expired,
user_not_found, …). On success it stamps request.state.user and
the contextvar so repository-layer owner filters work downstream.

The 4 old "_with_cookie_passes" tests in test_auth_middleware.py
were written for the presence-only behaviour; they asserted that
a junk cookie would make the handler return 200. They are renamed
to "_with_junk_cookie_rejected" and their assertions flipped to
401. The negative path (no cookie → 401 not_authenticated)
is unchanged.

Verified:
  no cookie       → 401 not_authenticated
  junk cookie     → 401 token_invalid     (the fixed bug)
  expired cookie  → 401 token_expired

Tests: 284 passed (auth + persistence + isolation)
Lint: clean

* security(auth): wire @require_permission(owner_check=True) on isolation routes

Apply the require_permission decorator to all 28 routes that take a
{thread_id} path parameter. Combined with the strict middleware
(previous commit), this gives the double-layer protection that
AUTH_TEST_PLAN test 7.5.9 documents:

  Layer 1 (AuthMiddleware): cookie + JWT validation, rejects junk
                            cookies and stamps request.state.user
  Layer 2 (@require_permission with owner_check=True): per-resource
                            ownership verification via
                            ThreadMetaStore.check_access — returns
                            404 if a different user owns the thread

The decorator's owner_check branch is rewritten to use the SQL
thread_meta_repo (the 2.0-rc persistence layer) instead of the
LangGraph store path that PR #1728 used (_store_get / get_store
in routers/threads.py). The inject_record convenience is dropped
— no caller in 2.0 needs the LangGraph blob, and the SQL repo has
a different shape.

Routes decorated (28 total):
- threads.py: delete, patch, get, get-state, post-state, post-history
- thread_runs.py: post-runs, post-runs-stream, post-runs-wait,
  list_runs, get_run, cancel_run, join_run, stream_existing_run,
  list_thread_messages, list_run_messages, list_run_events,
  thread_token_usage
- feedback.py: create, list, stats, delete
- uploads.py: upload (added Request param), list, delete
- artifacts.py: get_artifact
- suggestions.py: generate (renamed body parameter to avoid
  conflict with FastAPI Request)

Test fixes:
- test_suggestions_router.py: bypass the decorator via __wrapped__
  (the unit tests cover parsing logic, not auth — no point spinning
  up a thread_meta_repo just to test JSON unwrapping)
- test_auth_middleware.py 4 fake-cookie tests: already updated in
  the previous commit (745bf432)

Tests: 293 passed (auth + persistence + isolation + suggestions)
Lint: clean

* security(auth): defense-in-depth fixes from release validation pass

Eight findings caught while running the AUTH_TEST_PLAN end-to-end against
the deployed sg_dev stack. Each is a pre-condition for shipping
release/2.0-rc that the previous PRs missed.

Backend hardening
- routers/auth.py: rate limiter X-Real-IP now requires AUTH_TRUSTED_PROXIES
  whitelist (CIDR/IP allowlist). Without nginx in front, the previous code
  honored arbitrary X-Real-IP, letting an attacker rotate the header to
  fully bypass the per-IP login lockout.
- routers/auth.py: 36-entry common-password blocklist via Pydantic
  field_validator on RegisterRequest + ChangePasswordRequest. The shared
  _validate_strong_password helper keeps the constraint in one place.
- routers/threads.py: ThreadCreateRequest + ThreadPatchRequest strip
  server-reserved metadata keys (owner_id, user_id) via Pydantic
  field_validator so a forged value can never round-trip back to other
  clients reading the same thread. The actual ownership invariant stays
  on the threads_meta row; this closes the metadata-blob echo gap.
- authz.py + thread_meta/sql.py: require_permission gains a require_existing
  flag plumbed through check_access(require_existing=True). Destructive
  routes (DELETE/PATCH/state-update/runs/feedback) now treat a missing
  thread_meta row as 404 instead of "untracked legacy thread, allow",
  closing the cross-user delete-idempotence gap where any user could
  successfully DELETE another user's deleted thread.
- repositories/sqlite.py + base.py: update_user raises UserNotFoundError
  on a vanished row instead of silently returning the input. Concurrent
  delete during password reset can no longer look like a successful update.
- runtime/user_context.py: resolve_owner_id() coerces User.id (UUID) to
  str at the contextvar boundary so SQLAlchemy String(64) columns can
  bind it. The whole 2.0-rc isolation pipeline was previously broken
  end-to-end (POST /api/threads → 500 "type 'UUID' is not supported").
- persistence/engine.py: SQLAlchemy listener enables PRAGMA journal_mode=WAL,
  synchronous=NORMAL, foreign_keys=ON on every new SQLite connection.
  TC-UPG-06 in the test plan expects WAL; previous code shipped with the
  default 'delete' journal.
- auth_middleware.py: stamp request.state.auth = AuthContext(...) so
  @require_permission's short-circuit fires; previously every isolation
  request did a duplicate JWT decode + users SELECT. Also unifies the
  401 payload through AuthErrorResponse(...).model_dump().
- app.py: _ensure_admin_user restructure removes the noqa F821 scoping
  bug where 'password' was referenced outside the branch that defined it.
  New _announce_credentials helper absorbs the duplicate log block in
  the fresh-admin and reset-admin branches.

* fix(frontend+nginx): rollout CSRF on every state-changing client path

The frontend was 100% broken in gateway-pro mode for any user trying to
open a specific chat thread. Three cumulative bugs each silently
masked the next.

LangGraph SDK CSRF gap (api-client.ts)
- The Client constructor took only apiUrl, no defaultHeaders, no fetch
  interceptor. The SDK's internal fetch never sent X-CSRF-Token, so
  every state-changing /api/langgraph-compat/* call (runs/stream,
  threads/search, threads/{tid}/history, ...) hit CSRFMiddleware and
  got 403 before reaching the auth check. UI symptom: empty thread page
  with no error message; the SPA's hooks swallowed the rejection.
- Fix: pass an onRequest hook that injects X-CSRF-Token from the
  csrf_token cookie per request. Reading the cookie per call (not at
  construction time) handles login / logout / password-change cookie
  rotation transparently. The SDK's prepareFetchOptions calls
  onRequest for both regular requests AND streaming/SSE/reconnect, so
  the same hook covers runs.stream and runs.joinStream.

Raw fetch CSRF gap (7 files)
- Audit: 11 frontend fetch sites, only 2 included CSRF (login/setup +
  account-settings change-password). The other 7 routed through raw
  fetch() with no header — suggestions, memory, agents, mcp, skills,
  uploads, and the local thread cleanup hook all 403'd silently.
- Fix: enhance fetcher.ts:fetchWithAuth to auto-inject X-CSRF-Token on
  POST/PUT/DELETE/PATCH from a single shared readCsrfCookie() helper.
  Convert all 7 raw fetch() callers to fetchWithAuth so the contract
  is centrally enforced. api-client.ts and fetcher.ts share
  readCsrfCookie + STATE_CHANGING_METHODS to avoid drift.

nginx routing + buffering (nginx.local.conf)
- The auth feature shipped without updating the nginx config: per-API
  explicit location blocks but no /api/v1/auth/, /api/feedback, /api/runs.
  The frontend's client-side fetches to /api/v1/auth/login/local 404'd
  from the Next.js side because nginx routed /api/* to the frontend.
- Fix: add catch-all `location /api/` that proxies to the gateway.
  nginx longest-prefix matching keeps the explicit blocks (/api/models,
  /api/threads regex, /api/langgraph/, ...) winning for their paths.
- Fix: disable proxy_buffering + proxy_request_buffering for the
  frontend `location /` block. Without it, nginx tries to spool large
  Next.js chunks into /var/lib/nginx/proxy (root-owned) and fails with
  Permission denied → ERR_INCOMPLETE_CHUNKED_ENCODING → ChunkLoadError.

* test(auth): release-validation test infra and new coverage

Test fixtures and unit tests added during the validation pass.

Router test helpers (NEW: tests/_router_auth_helpers.py)
- make_authed_test_app(): builds a FastAPI test app with a stub
  middleware that stamps request.state.user + request.state.auth and a
  permissive thread_meta_repo mock. TestClient-based router tests
  (test_artifacts_router, test_threads_router) use it instead of bare
  FastAPI() so the new @require_permission(owner_check=True) decorators
  short-circuit cleanly.
- call_unwrapped(): walks the __wrapped__ chain to invoke the underlying
  handler without going through the authz wrappers. Direct-call tests
  (test_uploads_router) use it. Typed with ParamSpec so the wrapped
  signature flows through.

Backend test additions
- test_auth.py: 7 tests for the new _get_client_ip trust model (no
  proxy / trusted proxy / untrusted peer / XFF rejection / invalid
  CIDR / no client). 5 tests for the password blocklist (literal,
  case-insensitive, strong password accepted, change-password binding,
  short-password length-check still fires before blocklist).
  test_update_user_raises_when_row_concurrently_deleted: closes a
  shipped-without-coverage gap on the new UserNotFoundError contract.
- test_thread_meta_repo.py: 4 tests for check_access(require_existing=True)
  — strict missing-row denial, strict owner match, strict owner mismatch,
  strict null-owner still allowed (shared rows survive the tightening).
- test_ensure_admin.py: 3 tests for _migrate_orphaned_threads /
  _iter_store_items pagination, covering the TC-UPG-02 upgrade story
  end-to-end via mock store. Closes the gap where the cursor pagination
  was untested even though the previous PR rewrote it.
- test_threads_router.py: 5 tests for _strip_reserved_metadata
  (owner_id removal, user_id removal, safe-keys passthrough, empty
  input, both-stripped).
- test_auth_type_system.py: replace "password123" fixtures with
  Tr0ub4dor3a / AnotherStr0ngPwd! so the new password blocklist
  doesn't reject the test data.

* docs(auth): refresh TC-DOCKER-05 + document Docker validation gap

- AUTH_TEST_PLAN.md TC-DOCKER-05: the previous expectation
  ("admin password visible in docker logs") was stale after the simplify
  pass that moved credentials to a 0600 file. The grep "Password:" check
  would have silently failed and given a false sense of coverage. New
  expectation matches the actual file-based path: 0600 file in
  DEER_FLOW_HOME, log shows the path (not the secret), reverse-grep
  asserts no leaked password in container logs.
- NEW: docs/AUTH_TEST_DOCKER_GAP.md documents the only un-executed
  block in the test plan (TC-DOCKER-01..06). Reason: sg_dev validation
  host has no Docker daemon installed. The doc maps each Docker case
  to an already-validated bare-metal equivalent (TC-1.1, TC-REENT-01,
  TC-API-02 etc.) so the gap is auditable, and includes pre-flight
  reproduction steps for whoever has Docker available.

---------

Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
2026-04-09 11:29:32 +08:00
rayhpeng 00e0e9a49a feat(persistence): add unified persistence layer with event store, token tracking, and feedback (#1930)
* feat(persistence): add SQLAlchemy 2.0 async ORM scaffold

Introduce a unified database configuration (DatabaseConfig) that
controls both the LangGraph checkpointer and the DeerFlow application
persistence layer from a single `database:` config section.

New modules:
- deerflow.config.database_config — Pydantic config with memory/sqlite/postgres backends
- deerflow.persistence — async engine lifecycle, DeclarativeBase with to_dict mixin, Alembic skeleton
- deerflow.runtime.runs.store — RunStore ABC + MemoryRunStore implementation

Gateway integration initializes/tears down the persistence engine in
the existing langgraph_runtime() context manager. Legacy checkpointer
config is preserved for backward compatibility.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(persistence): add RunEventStore ABC + MemoryRunEventStore

Phase 2-A prerequisite for event storage: adds the unified run event
stream interface (RunEventStore) with an in-memory implementation,
RunEventsConfig, gateway integration, and comprehensive tests (27 cases).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(persistence): add ORM models, repositories, DB/JSONL event stores, RunJournal, and API endpoints

Phase 2-B: run persistence + event storage + token tracking.

- ORM models: RunRow (with token fields), ThreadMetaRow, RunEventRow
- RunRepository implements RunStore ABC via SQLAlchemy ORM
- ThreadMetaRepository with owner access control
- DbRunEventStore with trace content truncation and cursor pagination
- JsonlRunEventStore with per-run files and seq recovery from disk
- RunJournal (BaseCallbackHandler) captures LLM/tool/lifecycle events,
  accumulates token usage by caller type, buffers and flushes to store
- RunManager now accepts optional RunStore for persistent backing
- Worker creates RunJournal, writes human_message, injects callbacks
- Gateway deps use factory functions (RunRepository when DB available)
- New endpoints: messages, run messages, run events, token-usage
- ThreadCreateRequest gains assistant_id field
- 92 tests pass (33 new), zero regressions

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(persistence): add user feedback + follow-up run association

Phase 2-C: feedback and follow-up tracking.

- FeedbackRow ORM model (rating +1/-1, optional message_id, comment)
- FeedbackRepository with CRUD, list_by_run/thread, aggregate stats
- Feedback API endpoints: create, list, stats, delete
- follow_up_to_run_id in RunCreateRequest (explicit or auto-detected
  from latest successful run on the thread)
- Worker writes follow_up_to_run_id into human_message event metadata
- Gateway deps: feedback_repo factory + getter
- 17 new tests (14 FeedbackRepository + 3 follow-up association)
- 109 total tests pass, zero regressions

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* test+config: comprehensive Phase 2 test coverage + deprecate checkpointer config

- config.example.yaml: deprecate standalone checkpointer section, activate
  unified database:sqlite as default (drives both checkpointer + app data)
- New: test_thread_meta_repo.py (14 tests) — full ThreadMetaRepository coverage
  including check_access owner logic, list_by_owner pagination
- Extended test_run_repository.py (+4 tests) — completion preserves fields,
  list ordering desc, limit, owner_none returns all
- Extended test_run_journal.py (+8 tests) — on_chain_error, track_tokens=false,
  middleware no ai_message, unknown caller tokens, convenience fields,
  tool_error, non-summarization custom event
- Extended test_run_event_store.py (+7 tests) — DB batch seq continuity,
  make_run_event_store factory (memory/db/jsonl/fallback/unknown)
- Extended test_phase2b_integration.py (+4 tests) — create_or_reject persists,
  follow-up metadata, summarization in history, full DB-backed lifecycle
- Fixed DB integration test to use proper fake objects (not MagicMock)
  for JSON-serializable metadata
- 157 total Phase 2 tests pass, zero regressions

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* config: move default sqlite_dir to .deer-flow/data

Keep SQLite databases alongside other DeerFlow-managed data
(threads, memory) under the .deer-flow/ directory instead of a
top-level ./data folder.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(persistence): remove UTFJSON, use engine-level json_serializer + datetime.now()

- Replace custom UTFJSON type with standard sqlalchemy.JSON in all ORM
  models. Add json_serializer=json.dumps(ensure_ascii=False) to all
  create_async_engine calls so non-ASCII text (Chinese etc.) is stored
  as-is in both SQLite and Postgres.
- Change ORM datetime defaults from datetime.now(UTC) to datetime.now(),
  remove UTC imports.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(gateway): simplify deps.py with getter factory + inline repos

- Replace 6 identical getter functions with _require() factory.
- Inline 3 _make_*_repo() factories into langgraph_runtime(), call
  get_session_factory() once instead of 3 times.
- Add thread_meta upsert in start_run (services.py).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(docker): add UV_EXTRAS build arg for optional dependencies

Support installing optional dependency groups (e.g. postgres) at
Docker build time via UV_EXTRAS build arg:
  UV_EXTRAS=postgres docker compose build

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(journal): fix flush, token tracking, and consolidate tests

RunJournal fixes:
- _flush_sync: retain events in buffer when no event loop instead of
  dropping them; worker's finally block flushes via async flush().
- on_llm_end: add tool_calls filter and caller=="lead_agent" guard for
  ai_message events; mark message IDs for dedup with record_llm_usage.
- worker.py: persist completion data (tokens, message count) to RunStore
  in finally block.

Model factory:
- Auto-inject stream_usage=True for BaseChatOpenAI subclasses with
  custom api_base, so usage_metadata is populated in streaming responses.

Test consolidation:
- Delete test_phase2b_integration.py (redundant with existing tests).
- Move DB-backed lifecycle test into test_run_journal.py.
- Add tests for stream_usage injection in test_model_factory.py.
- Clean up executor/task_tool dead journal references.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(events): widen content type to str|dict in all store backends

Allow event content to be a dict (for structured OpenAI-format messages)
in addition to plain strings. Dict values are JSON-serialized for the DB
backend and deserialized on read; memory and JSONL backends handle dicts
natively. Trace truncation now serializes dicts to JSON before measuring.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(events): use metadata flag instead of heuristic for dict content detection

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(converters): add LangChain-to-OpenAI message format converters

Pure functions langchain_to_openai_message, langchain_to_openai_completion,
langchain_messages_to_openai, and _infer_finish_reason for converting
LangChain BaseMessage objects to OpenAI Chat Completions format, used by
RunJournal for event storage. 15 unit tests added.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(converters): handle empty list content as null, clean up test

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(events): human_message content uses OpenAI user message format

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* feat(events): ai_message uses OpenAI format, add ai_tool_call message event

- ai_message content now uses {"role": "assistant", "content": "..."} format
- New ai_tool_call message event emitted when lead_agent LLM responds with tool_calls
- ai_tool_call uses langchain_to_openai_message converter for consistent format
- Both events include finish_reason in metadata ("stop" or "tool_calls")

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(events): add tool_result message event with OpenAI tool message format

Cache tool_call_id from on_tool_start keyed by run_id as fallback for on_tool_end,
then emit a tool_result message event (role=tool, tool_call_id, content) after each
successful tool completion.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* feat(events): summary content uses OpenAI system message format

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(events): replace llm_start/llm_end with llm_request/llm_response in OpenAI format

Add on_chat_model_start to capture structured prompt messages as llm_request events.
Replace llm_end trace events with llm_response using OpenAI Chat Completions format.
Track llm_call_index to pair request/response events.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(events): add record_middleware method for middleware trace events

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* test(events): add full run sequence integration test for OpenAI content format

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* feat(events): align message events with checkpoint format and add middleware tag injection

- Message events (ai_message, ai_tool_call, tool_result, human_message) now use
  BaseMessage.model_dump() format, matching LangGraph checkpoint values.messages
- on_tool_end extracts tool_call_id/name/status from ToolMessage objects
- on_tool_error now emits tool_result message events with error status
- record_middleware uses middleware:{tag} event_type and middleware category
- Summarization custom events use middleware:summarize category
- TitleMiddleware injects middleware:title tag via get_config() inheritance
- SummarizationMiddleware model bound with middleware:summarize tag
- Worker writes human_message using HumanMessage.model_dump()

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(threads): switch search endpoint to threads_meta table and sync title

- POST /api/threads/search now queries threads_meta table directly,
  removing the two-phase Store + Checkpointer scan approach
- Add ThreadMetaRepository.search() with metadata/status filters
- Add ThreadMetaRepository.update_display_name() for title sync
- Worker syncs checkpoint title to threads_meta.display_name on run completion
- Map display_name to values.title in search response for API compatibility

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(threads): history endpoint reads messages from event store

- POST /api/threads/{thread_id}/history now combines two data sources:
  checkpointer for checkpoint_id, metadata, title, thread_data;
  event store for messages (complete history, not truncated by summarization)
- Strip internal LangGraph metadata keys from response
- Remove full channel_values serialization in favor of selective fields

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: remove duplicate optional-dependencies header in pyproject.toml

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(middleware): pass tagged config to TitleMiddleware ainvoke call

Without the config, the middleware:title tag was not injected,
causing the LLM response to be recorded as a lead_agent ai_message
in run_events.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: resolve merge conflict in .env.example

Keep both DATABASE_URL (from persistence-scaffold) and WECOM
credentials (from main) after the merge.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(persistence): address review feedback on PR #1851

- Fix naive datetime.now() → datetime.now(UTC) in all ORM models
- Fix seq race condition in DbRunEventStore.put() with FOR UPDATE
  and UNIQUE(thread_id, seq) constraint
- Encapsulate _store access in RunManager.update_run_completion()
- Deduplicate _store.put() logic in RunManager via _persist_to_store()
- Add update_run_completion to RunStore ABC + MemoryRunStore
- Wire follow_up_to_run_id through the full create path
- Add error recovery to RunJournal._flush_sync() lost-event scenario
- Add migration note for search_threads breaking change
- Fix test_checkpointer_none_fix mock to set database=None

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* chore: update uv.lock

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(persistence): address 22 review comments from CodeQL, Copilot, and Code Quality

Bug fixes:
- Sanitize log params to prevent log injection (CodeQL)
- Reset threads_meta.status to idle/error when run completes
- Attach messages only to latest checkpoint in /history response
- Write threads_meta on POST /threads so new threads appear in search

Lint fixes:
- Remove unused imports (journal.py, migrations/env.py, test_converters.py)
- Convert lambda to named function (engine.py, Ruff E731)
- Remove unused logger definitions in repos (Ruff F841)
- Add logging to JSONL decode errors and empty except blocks
- Separate assert side-effects in tests (CodeQL)
- Remove unused local variables in tests (Ruff F841)
- Fix max_trace_content truncation to use byte length, not char length

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* style: apply ruff format to persistence and runtime files

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Potential fix for pull request finding 'Statement has no effect'

Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>

* refactor(runtime): introduce RunContext to reduce run_agent parameter bloat

Extract checkpointer, store, event_store, run_events_config, thread_meta_repo,
and follow_up_to_run_id into a frozen RunContext dataclass. Add get_run_context()
in deps.py to build the base context from app.state singletons. start_run() uses
dataclasses.replace() to enrich per-run fields before passing ctx to run_agent.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(gateway): move sanitize_log_param to app/gateway/utils.py

Extract the log-injection sanitizer from routers/threads.py into a shared
utils module and rename to sanitize_log_param (public API). Eliminates the
reverse service → router import in services.py.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* perf: use SQL aggregation for feedback stats and thread token usage

Replace Python-side counting in FeedbackRepository.aggregate_by_run with
a single SELECT COUNT/SUM query. Add RunStore.aggregate_tokens_by_thread
abstract method with SQL GROUP BY implementation in RunRepository and
Python fallback in MemoryRunStore. Simplify the thread_token_usage
endpoint to delegate to the new method, eliminating the limit=10000
truncation risk.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* docs: annotate DbRunEventStore.put() as low-frequency path

Add docstring clarifying that put() opens a per-call transaction with
FOR UPDATE and should only be used for infrequent writes (currently
just the initial human_message event). High-throughput callers should
use put_batch() instead.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(threads): fall back to Store search when ThreadMetaRepository is unavailable

When database.backend=memory (default) or no SQL session factory is
configured, search_threads now queries the LangGraph Store instead of
returning 503. Returns empty list if neither Store nor repo is available.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(persistence): introduce ThreadMetaStore ABC for backend-agnostic thread metadata

Add ThreadMetaStore abstract base class with create/get/search/update/delete
interface. ThreadMetaRepository (SQL) now inherits from it. New
MemoryThreadMetaStore wraps LangGraph BaseStore for memory-mode deployments.

deps.py now always provides a non-None thread_meta_repo, eliminating all
`if thread_meta_repo is not None` guards in services.py, worker.py, and
routers/threads.py. search_threads no longer needs a Store fallback branch.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(history): read messages from checkpointer instead of RunEventStore

The /history endpoint now reads messages directly from the
checkpointer's channel_values (the authoritative source) instead of
querying RunEventStore.list_messages(). The RunEventStore API is
preserved for other consumers.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(persistence): address new Copilot review comments

- feedback.py: validate thread_id/run_id before deleting feedback
- jsonl.py: add path traversal protection with ID validation
- run_repo.py: parse `before` to datetime for PostgreSQL compat
- thread_meta_repo.py: fix pagination when metadata filter is active
- database_config.py: use resolve_path for sqlite_dir consistency

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Implement skill self-evolution and skill_manage flow (#1874)

* chore: ignore .worktrees directory

* Add skill_manage self-evolution flow

* Fix CI regressions for skill_manage

* Address PR review feedback for skill evolution

* fix(skill-evolution): preserve history on delete

* fix(skill-evolution): tighten scanner fallbacks

* docs: add skill_manage e2e evidence screenshot

* fix(skill-manage): avoid blocking fs ops in session runtime

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>

* fix(config): resolve sqlite_dir relative to CWD, not Paths.base_dir

resolve_path() resolves relative to Paths.base_dir (.deer-flow),
which double-nested the path to .deer-flow/.deer-flow/data/app.db.
Use Path.resolve() (CWD-relative) instead.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Feature/feishu receive file (#1608)

* feat(feishu): add channel file materialization hook for inbound messages

- Introduce Channel.receive_file(msg, thread_id) as a base method for file materialization; default is no-op.
- Implement FeishuChannel.receive_file to download files/images from Feishu messages, save to sandbox, and inject virtual paths into msg.text.
- Update ChannelManager to call receive_file for any channel if msg.files is present, enabling downstream model access to user-uploaded files.
- No impact on Slack/Telegram or other channels (they inherit the default no-op).

* style(backend): format code with ruff for lint compliance

- Auto-formatted packages/harness/deerflow/agents/factory.py and tests/test_create_deerflow_agent.py using `ruff format`
- Ensured both files conform to project linting standards
- Fixes CI lint check failures caused by code style issues

* fix(feishu): handle file write operation asynchronously to prevent blocking

* fix(feishu): rename GetMessageResourceRequest to _GetMessageResourceRequest and remove redundant code

* test(feishu): add tests for receive_file method and placeholder replacement

* fix(manager): remove unnecessary type casting for channel retrieval

* fix(feishu): update logging messages to reflect resource handling instead of image

* fix(feishu): sanitize filename by replacing invalid characters in file uploads

* fix(feishu): improve filename sanitization and reorder image key handling in message processing

* fix(feishu): add thread lock to prevent filename conflicts during file downloads

* fix(test): correct bad merge in test_feishu_parser.py

* chore: run ruff and apply formatting cleanup
fix(feishu): preserve rich-text attachment order and improve fallback filename handling

* fix(docker): restore gateway env vars and fix langgraph empty arg issue (#1915)

Two production docker-compose.yaml bugs prevent `make up` from working:

1. Gateway missing DEER_FLOW_CONFIG_PATH and DEER_FLOW_EXTENSIONS_CONFIG_PATH
   environment overrides. Added in fb2d99f (#1836) but accidentally reverted
   by ca2fb95 (#1847). Without them, gateway reads host paths from .env via
   env_file, causing FileNotFoundError inside the container.

2. Langgraph command fails when LANGGRAPH_ALLOW_BLOCKING is unset (default).
   Empty $${allow_blocking} inserts a bare space between flags, causing
   ' --no-reload' to be parsed as unexpected extra argument. Fix by building
   args string first and conditionally appending --allow-blocking.

Co-authored-by: cooper <cooperfu@tencent.com>

* fix(frontend): resolve invalid HTML nesting and tabnabbing vulnerabilities (#1904)

* fix(frontend): resolve invalid HTML nesting and tabnabbing vulnerabilities

Fix `<button>` inside `<a>` invalid HTML in artifact components and add
missing `noopener,noreferrer` to `window.open` calls to prevent reverse
tabnabbing.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix(frontend): address Copilot review on tabnabbing and double-tab-open

Remove redundant parent onClick on web_fetch ChainOfThoughtStep to
prevent opening two tabs on link click, and explicitly null out
window.opener after window.open() for defensive tabnabbing hardening.

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

* refactor(persistence): organize entities into per-entity directories

Restructure the persistence layer from horizontal "models/ + repositories/"
split into vertical entity-aligned directories. Each entity (thread_meta,
run, feedback) now owns its ORM model, abstract interface (where applicable),
and concrete implementations under a single directory with an aggregating
__init__.py for one-line imports.

Layout:
  persistence/thread_meta/{base,model,sql,memory}.py
  persistence/run/{model,sql}.py
  persistence/feedback/{model,sql}.py

models/__init__.py is kept as a facade so Alembic autogenerate continues to
discover all ORM tables via Base.metadata. RunEventRow remains under
models/run_event.py because its storage implementation lives in
runtime/events/store/db.py and has no matching repository directory.

The repositories/ directory is removed entirely. All call sites in
gateway/deps.py and tests are updated to import from the new entity
packages, e.g.:

    from deerflow.persistence.thread_meta import ThreadMetaRepository
    from deerflow.persistence.run import RunRepository
    from deerflow.persistence.feedback import FeedbackRepository

Full test suite passes (1690 passed, 14 skipped).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(gateway): sync thread rename and delete through ThreadMetaStore

The POST /threads/{id}/state endpoint previously synced title changes
only to the LangGraph Store via _store_upsert. In sqlite mode the search
endpoint reads from the ThreadMetaRepository SQL table, so renames never
appeared in /threads/search until the next agent run completed (worker.py
syncs title from checkpoint to thread_meta in its finally block).

Likewise the DELETE /threads/{id} endpoint cleaned up the filesystem,
Store, and checkpointer but left the threads_meta row orphaned in sqlite,
so deleted threads kept appearing in /threads/search.

Fix both endpoints by routing through the ThreadMetaStore abstraction
which already has the correct sqlite/memory implementations wired up by
deps.py. The rename path now calls update_display_name() and the delete
path calls delete() — both work uniformly across backends.

Verified end-to-end with curl in gateway mode against sqlite backend.
Existing test suite (1690 passed) and focused router/repo tests pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(gateway): route all thread metadata access through ThreadMetaStore

Following the rename/delete bug fix in PR1, migrate the remaining direct
LangGraph Store reads/writes in the threads router and services to the
ThreadMetaStore abstraction so that the sqlite and memory backends behave
identically and the legacy dual-write paths can be removed.

Migrated endpoints (threads.py):
- create_thread: idempotency check + write now use thread_meta_repo.get/create
  instead of dual-writing the LangGraph Store and the SQL row.
- get_thread: reads from thread_meta_repo.get; the checkpoint-only fallback
  for legacy threads is preserved.
- patch_thread: replaced _store_get/_store_put with thread_meta_repo.update_metadata.
- delete_thread_data: dropped the legacy store.adelete; thread_meta_repo.delete
  already covers it.

Removed dead code (services.py):
- _upsert_thread_in_store — redundant with the immediately following
  thread_meta_repo.create() call.
- _sync_thread_title_after_run — worker.py's finally block already syncs
  the title via thread_meta_repo.update_display_name() after each run.

Removed dead code (threads.py):
- _store_get / _store_put / _store_upsert helpers (no remaining callers).
- THREADS_NS constant.
- get_store import (router no longer touches the LangGraph Store directly).

New abstract method:
- ThreadMetaStore.update_metadata(thread_id, metadata) merges metadata into
  the thread's metadata field. Implemented in both ThreadMetaRepository (SQL,
  read-modify-write inside one session) and MemoryThreadMetaStore. Three new
  unit tests cover merge / empty / nonexistent behaviour.

Net change: -134 lines. Full test suite: 1693 passed, 14 skipped.
Verified end-to-end with curl in gateway mode against sqlite backend
(create / patch / get / rename / search / delete).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
Co-authored-by: DanielWalnut <45447813+hetaoBackend@users.noreply.github.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: JilongSun <965640067@qq.com>
Co-authored-by: jie <49781832+stan-fu@users.noreply.github.com>
Co-authored-by: cooper <cooperfu@tencent.com>
Co-authored-by: yangzheli <43645580+yangzheli@users.noreply.github.com>
2026-04-07 11:53:52 +08:00
249 changed files with 16588 additions and 6845 deletions
+4 -1
View File
@@ -24,7 +24,6 @@ INFOQUEST_API_KEY=your-infoquest-api-key
# SLACK_BOT_TOKEN=your-slack-bot-token
# SLACK_APP_TOKEN=your-slack-app-token
# TELEGRAM_BOT_TOKEN=your-telegram-bot-token
# DISCORD_BOT_TOKEN=your-discord-bot-token
# Enable LangSmith to monitor and debug your LLM calls, agent runs, and tool executions.
# LANGSMITH_TRACING=true
@@ -34,5 +33,9 @@ INFOQUEST_API_KEY=your-infoquest-api-key
# GitHub API Token
# GITHUB_TOKEN=your-github-token
# Database (only needed when config.yaml has database.backend: postgres)
# DATABASE_URL=postgresql://deerflow:password@localhost:5432/deerflow
#
# WECOM_BOT_ID=your-wecom-bot-id
# WECOM_BOT_SECRET=your-wecom-bot-secret
-63
View File
@@ -1,63 +0,0 @@
name: E2E Tests
on:
push:
branches: [ 'main' ]
paths:
- 'frontend/**'
- '.github/workflows/e2e-tests.yml'
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- 'frontend/**'
- '.github/workflows/e2e-tests.yml'
concurrency:
group: e2e-tests-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
e2e-tests:
if: ${{ github.event_name != 'pull_request' || github.event.pull_request.draft == false }}
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- name: Checkout
uses: actions/checkout@v6
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '22'
- name: Enable Corepack
run: corepack enable
- name: Use pinned pnpm version
run: corepack prepare pnpm@10.26.2 --activate
- name: Install frontend dependencies
working-directory: frontend
run: pnpm install --frozen-lockfile
- name: Install Playwright Chromium
working-directory: frontend
run: npx playwright install chromium --with-deps
- name: Run E2E tests
working-directory: frontend
run: pnpm exec playwright test
env:
SKIP_ENV_VALIDATION: '1'
- name: Upload Playwright report
uses: actions/upload-artifact@v4
if: ${{ !cancelled() }}
with:
name: playwright-report
path: frontend/playwright-report/
retention-days: 7
-43
View File
@@ -1,43 +0,0 @@
name: Frontend Unit Tests
on:
push:
branches: [ 'main' ]
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
concurrency:
group: frontend-unit-tests-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
frontend-unit-tests:
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- name: Checkout
uses: actions/checkout@v6
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '22'
- name: Enable Corepack
run: corepack enable
- name: Use pinned pnpm version
run: corepack prepare pnpm@10.26.2 --activate
- name: Install frontend dependencies
working-directory: frontend
run: pnpm install --frozen-lockfile
- name: Run unit tests of frontend
working-directory: frontend
run: make test
-2
View File
@@ -55,7 +55,5 @@ web/
backend/Dockerfile.langgraph
config.yaml.bak
.playwright-mcp
/frontend/test-results/
/frontend/playwright-report/
.gstack/
.worktrees
+6 -11
View File
@@ -298,24 +298,19 @@ Nginx (port 2026) ← Unified entry point
```bash
# Backend tests
cd backend
make test
uv run pytest
# Frontend unit tests
# Frontend checks
cd frontend
make test
# Frontend E2E tests (requires Chromium; builds and auto-starts the Next.js production server)
cd frontend
make test-e2e
pnpm check
```
### PR Regression Checks
Every pull request triggers the following CI workflows:
Every pull request runs the backend regression workflow at [.github/workflows/backend-unit-tests.yml](.github/workflows/backend-unit-tests.yml), including:
- **Backend unit tests** — [.github/workflows/backend-unit-tests.yml](.github/workflows/backend-unit-tests.yml)
- **Frontend unit tests** — [.github/workflows/frontend-unit-tests.yml](.github/workflows/frontend-unit-tests.yml)
- **Frontend E2E tests** — [.github/workflows/e2e-tests.yml](.github/workflows/e2e-tests.yml) (triggered only when `frontend/` files change)
- `tests/test_provisioner_kubeconfig.py`
- `tests/test_docker_sandbox_mode_detection.py`
## Code Style
-2
View File
@@ -658,8 +658,6 @@ This is the difference between a chatbot with tool access and an agent with an a
**Summarization**: Within a session, DeerFlow manages context aggressively — summarizing completed sub-tasks, offloading intermediate results to the filesystem, compressing what's no longer immediately relevant. This lets it stay sharp across long, multi-step tasks without blowing the context window.
**Strict Tool-Call Recovery**: When a provider or middleware interrupts a tool-call loop, DeerFlow now strips provider-level raw tool-call metadata on forced-stop assistant messages and injects placeholder tool results for dangling calls before the next model invocation. This keeps OpenAI-compatible reasoning models that strictly validate `tool_call_id` sequences from failing with malformed history errors.
### Long-Term Memory
Most agents forget everything the moment a conversation ends. DeerFlow remembers.
+10 -16
View File
@@ -156,26 +156,20 @@ from deerflow.config import get_app_config
### Middleware Chain
Lead-agent middlewares are assembled in strict append order across `packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py` (`build_lead_runtime_middlewares`) and `packages/harness/deerflow/agents/lead_agent/agent.py` (`_build_middlewares`):
Middlewares execute in strict order in `packages/harness/deerflow/agents/lead_agent/agent.py`:
1. **ThreadDataMiddleware** - Creates per-thread directories (`backend/.deer-flow/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local `.deer-flow/threads/{thread_id}` directory
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption), including raw provider tool-call payloads preserved only in `additional_kwargs["tool_calls"]`
5. **LLMErrorHandlingMiddleware** - Normalizes provider/model invocation failures into recoverable assistant-facing errors before later middleware/tool stages run
6. **GuardrailMiddleware** - Pre-tool-call authorization via pluggable `GuardrailProvider` protocol (optional, if `guardrails.enabled` in config). Evaluates each tool call and returns error ToolMessage on deny. Three provider options: built-in `AllowlistProvider` (zero deps), OAP policy providers (e.g. `aport-agent-guardrails`), or custom providers. See [docs/GUARDRAILS.md](docs/GUARDRAILS.md) for setup, usage, and how to implement a provider.
7. **SandboxAuditMiddleware** - Audits sandboxed shell/file operations for security logging before tool execution continues
8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional)
12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
15. **DeferredToolFilterMiddleware** - Hides deferred tool schemas from the bound model until tool search is enabled (optional)
16. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if `subagent_enabled`)
17. **LoopDetectionMiddleware** - Detects repeated tool-call loops; hard-stop responses clear both structured `tool_calls` and raw provider tool-call metadata before forcing a final text answer
18. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption)
5. **GuardrailMiddleware** - Pre-tool-call authorization via pluggable `GuardrailProvider` protocol (optional, if `guardrails.enabled` in config). Evaluates each tool call and returns error ToolMessage on deny. Three provider options: built-in `AllowlistProvider` (zero deps), OAP policy providers (e.g. `aport-agent-guardrails`), or custom providers. See [docs/GUARDRAILS.md](docs/GUARDRAILS.md) for setup, usage, and how to implement a provider.
6. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
7. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
8. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
9. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
10. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
11. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if subagent_enabled)
12. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
### Configuration System
+5 -1
View File
@@ -13,6 +13,9 @@ FROM python:3.12-slim-bookworm AS builder
ARG NODE_MAJOR=22
ARG APT_MIRROR
ARG UV_INDEX_URL
# Optional extras to install (e.g. "postgres" for PostgreSQL support)
# Usage: docker build --build-arg UV_EXTRAS=postgres ...
ARG UV_EXTRAS
# Optionally override apt mirror for restricted networks (e.g. APT_MIRROR=mirrors.aliyun.com)
RUN if [ -n "${APT_MIRROR}" ]; then \
@@ -43,8 +46,9 @@ WORKDIR /app
COPY backend ./backend
# Install dependencies with cache mount
# When UV_EXTRAS is set (e.g. "postgres"), installs optional dependencies.
RUN --mount=type=cache,target=/root/.cache/uv \
sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync"
sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync ${UV_EXTRAS:+--extra $UV_EXTRAS}"
# ── Stage 2: Dev ──────────────────────────────────────────────────────────────
# Retains compiler toolchain from builder so startup-time `uv sync` can build
-273
View File
@@ -1,273 +0,0 @@
"""Discord channel integration using discord.py."""
from __future__ import annotations
import asyncio
import logging
import threading
from typing import Any
from app.channels.base import Channel
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
_DISCORD_MAX_MESSAGE_LEN = 2000
class DiscordChannel(Channel):
"""Discord bot channel.
Configuration keys (in ``config.yaml`` under ``channels.discord``):
- ``bot_token``: Discord Bot token.
- ``allowed_guilds``: (optional) List of allowed Discord guild IDs. Empty = allow all.
"""
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
super().__init__(name="discord", bus=bus, config=config)
self._bot_token = str(config.get("bot_token", "")).strip()
self._allowed_guilds: set[int] = set()
for guild_id in config.get("allowed_guilds", []):
try:
self._allowed_guilds.add(int(guild_id))
except (TypeError, ValueError):
continue
self._client = None
self._thread: threading.Thread | None = None
self._discord_loop: asyncio.AbstractEventLoop | None = None
self._main_loop: asyncio.AbstractEventLoop | None = None
self._discord_module = None
async def start(self) -> None:
if self._running:
return
try:
import discord
except ImportError:
logger.error("discord.py is not installed. Install it with: uv add discord.py")
return
if not self._bot_token:
logger.error("Discord channel requires bot_token")
return
intents = discord.Intents.default()
intents.messages = True
intents.guilds = True
intents.message_content = True
client = discord.Client(
intents=intents,
allowed_mentions=discord.AllowedMentions.none(),
)
self._client = client
self._discord_module = discord
self._main_loop = asyncio.get_event_loop()
@client.event
async def on_message(message) -> None:
await self._on_message(message)
self._running = True
self.bus.subscribe_outbound(self._on_outbound)
self._thread = threading.Thread(target=self._run_client, daemon=True)
self._thread.start()
logger.info("Discord channel started")
async def stop(self) -> None:
self._running = False
self.bus.unsubscribe_outbound(self._on_outbound)
if self._client and self._discord_loop and self._discord_loop.is_running():
close_future = asyncio.run_coroutine_threadsafe(self._client.close(), self._discord_loop)
try:
await asyncio.wait_for(asyncio.wrap_future(close_future), timeout=10)
except TimeoutError:
logger.warning("[Discord] client close timed out after 10s")
except Exception:
logger.exception("[Discord] error while closing client")
if self._thread:
self._thread.join(timeout=10)
self._thread = None
self._client = None
self._discord_loop = None
self._discord_module = None
logger.info("Discord channel stopped")
async def send(self, msg: OutboundMessage) -> None:
target = await self._resolve_target(msg)
if target is None:
logger.error("[Discord] target not found for chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
return
text = msg.text or ""
for chunk in self._split_text(text):
send_future = asyncio.run_coroutine_threadsafe(target.send(chunk), self._discord_loop)
await asyncio.wrap_future(send_future)
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
target = await self._resolve_target(msg)
if target is None:
logger.error("[Discord] target not found for file upload chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
return False
if self._discord_module is None:
return False
try:
fp = open(str(attachment.actual_path), "rb") # noqa: SIM115
file = self._discord_module.File(fp, filename=attachment.filename)
send_future = asyncio.run_coroutine_threadsafe(target.send(file=file), self._discord_loop)
await asyncio.wrap_future(send_future)
logger.info("[Discord] file uploaded: %s", attachment.filename)
return True
except Exception:
logger.exception("[Discord] failed to upload file: %s", attachment.filename)
return False
async def _on_message(self, message) -> None:
if not self._running or not self._client:
return
if message.author.bot:
return
if self._client.user and message.author.id == self._client.user.id:
return
guild = message.guild
if self._allowed_guilds:
if guild is None or guild.id not in self._allowed_guilds:
return
text = (message.content or "").strip()
if not text:
return
if self._discord_module is None:
return
if isinstance(message.channel, self._discord_module.Thread):
chat_id = str(message.channel.parent_id or message.channel.id)
thread_id = str(message.channel.id)
else:
thread = await self._create_thread(message)
if thread is None:
return
chat_id = str(message.channel.id)
thread_id = str(thread.id)
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
inbound = self._make_inbound(
chat_id=chat_id,
user_id=str(message.author.id),
text=text,
msg_type=msg_type,
thread_ts=thread_id,
metadata={
"guild_id": str(guild.id) if guild else None,
"channel_id": str(message.channel.id),
"message_id": str(message.id),
},
)
inbound.topic_id = thread_id
if self._main_loop and self._main_loop.is_running():
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
def _run_client(self) -> None:
self._discord_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._discord_loop)
try:
self._discord_loop.run_until_complete(self._client.start(self._bot_token))
except Exception:
if self._running:
logger.exception("Discord client error")
finally:
try:
if self._client and not self._client.is_closed():
self._discord_loop.run_until_complete(self._client.close())
except Exception:
logger.exception("Error during Discord shutdown")
async def _create_thread(self, message):
try:
thread_name = f"deerflow-{message.author.display_name}-{message.id}"[:100]
return await message.create_thread(name=thread_name)
except Exception:
logger.exception("[Discord] failed to create thread for message=%s (threads may be disabled or missing permissions)", message.id)
try:
await message.channel.send("Could not create a thread for your message. Please check that threads are enabled in this channel.")
except Exception:
pass
return None
async def _resolve_target(self, msg: OutboundMessage):
if not self._client or not self._discord_loop:
return None
target_ids: list[str] = []
if msg.thread_ts:
target_ids.append(msg.thread_ts)
if msg.chat_id and msg.chat_id not in target_ids:
target_ids.append(msg.chat_id)
for raw_id in target_ids:
target = await self._get_channel_or_thread(raw_id)
if target is not None:
return target
return None
async def _get_channel_or_thread(self, raw_id: str):
if not self._client or not self._discord_loop:
return None
try:
target_id = int(raw_id)
except (TypeError, ValueError):
return None
get_future = asyncio.run_coroutine_threadsafe(self._fetch_channel(target_id), self._discord_loop)
try:
return await asyncio.wrap_future(get_future)
except Exception:
logger.exception("[Discord] failed to resolve target id=%s", raw_id)
return None
async def _fetch_channel(self, target_id: int):
if not self._client:
return None
channel = self._client.get_channel(target_id)
if channel is not None:
return channel
try:
return await self._client.fetch_channel(target_id)
except Exception:
return None
@staticmethod
def _split_text(text: str) -> list[str]:
if not text:
return [""]
chunks: list[str] = []
remaining = text
while len(remaining) > _DISCORD_MAX_MESSAGE_LEN:
split_at = remaining.rfind("\n", 0, _DISCORD_MAX_MESSAGE_LEN)
if split_at <= 0:
split_at = _DISCORD_MAX_MESSAGE_LEN
chunks.append(remaining[:split_at])
remaining = remaining[split_at:].lstrip("\n")
if remaining:
chunks.append(remaining)
return chunks
-1
View File
@@ -35,7 +35,6 @@ STREAM_UPDATE_MIN_INTERVAL_SECONDS = 0.35
THREAD_BUSY_MESSAGE = "This conversation is already processing another request. Please wait for it to finish and try again."
CHANNEL_CAPABILITIES = {
"discord": {"supports_streaming": False},
"feishu": {"supports_streaming": True},
"slack": {"supports_streaming": False},
"telegram": {"supports_streaming": False},
-1
View File
@@ -15,7 +15,6 @@ logger = logging.getLogger(__name__)
# Channel name → import path for lazy loading
_CHANNEL_REGISTRY: dict[str, str] = {
"discord": "app.channels.discord:DiscordChannel",
"feishu": "app.channels.feishu:FeishuChannel",
"slack": "app.channels.slack:SlackChannel",
"telegram": "app.channels.telegram:TelegramChannel",
+161 -1
View File
@@ -1,16 +1,23 @@
import logging
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from datetime import UTC
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.gateway.auth_middleware import AuthMiddleware
from app.gateway.config import get_gateway_config
from app.gateway.csrf_middleware import CSRFMiddleware
from app.gateway.deps import langgraph_runtime
from app.gateway.routers import (
agents,
artifacts,
assistants_compat,
auth,
channels,
feedback,
mcp,
memory,
models,
@@ -33,6 +40,125 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
async def _ensure_admin_user(app: FastAPI) -> None:
"""Auto-create the admin user on first boot if no users exist.
After admin creation, migrate orphan threads from the LangGraph
store (metadata.owner_id unset) to the admin account. This is the
"no-auth → with-auth" upgrade path: users who ran DeerFlow without
authentication have existing LangGraph thread data that needs an
owner assigned.
No SQL persistence migration is needed: the four owner_id columns
(threads_meta, runs, run_events, feedback) only come into existence
alongside the auth module via create_all, so freshly created tables
never contain NULL-owner rows. "Existing persistence DB + new auth"
is not a supported upgrade path — fresh install or wipe-and-retry.
Multi-worker safe: relies on SQLite UNIQUE constraint to resolve
races during admin creation. Only the worker that successfully
creates/updates the admin prints the password; losers silently skip.
"""
import secrets
from app.gateway.auth.credential_file import write_initial_credentials
from app.gateway.deps import get_local_provider
def _announce_credentials(email: str, password: str, *, label: str, headline: str) -> None:
"""Write the password to a 0600 file and log the path (never the secret)."""
cred_path = write_initial_credentials(email, password, label=label)
logger.info("=" * 60)
logger.info(" %s", headline)
logger.info(" Credentials written to: %s (mode 0600)", cred_path)
logger.info(" Change it after login: Settings -> Account")
logger.info("=" * 60)
provider = get_local_provider()
user_count = await provider.count_users()
admin = None
if user_count == 0:
password = secrets.token_urlsafe(16)
try:
admin = await provider.create_user(email="admin@deerflow.dev", password=password, system_role="admin", needs_setup=True)
except ValueError:
return # Another worker already created the admin.
_announce_credentials(admin.email, password, label="initial", headline="Admin account created on first boot")
else:
# Admin exists but setup never completed — reset password so operator
# can always find it in the console without needing the CLI.
# Multi-worker guard: if admin was created less than 30s ago, another
# worker just created it and will print the password — skip reset.
admin = await provider.get_user_by_email("admin@deerflow.dev")
if admin and admin.needs_setup:
import time
age = time.time() - admin.created_at.replace(tzinfo=UTC).timestamp()
if age >= 30:
from app.gateway.auth.password import hash_password_async
password = secrets.token_urlsafe(16)
admin.password_hash = await hash_password_async(password)
admin.token_version += 1
await provider.update_user(admin)
_announce_credentials(admin.email, password, label="reset", headline="Admin account setup incomplete — password reset")
if admin is None:
return # Nothing to bind orphans to.
admin_id = str(admin.id)
# LangGraph store orphan migration — non-fatal.
# This covers the "no-auth → with-auth" upgrade path for users
# whose existing LangGraph thread metadata has no owner_id set.
store = getattr(app.state, "store", None)
if store is not None:
try:
migrated = await _migrate_orphaned_threads(store, admin_id)
if migrated:
logger.info("Migrated %d orphan LangGraph thread(s) to admin", migrated)
except Exception:
logger.exception("LangGraph thread migration failed (non-fatal)")
async def _iter_store_items(store, namespace, *, page_size: int = 500):
"""Paginated async iterator over a LangGraph store namespace.
Replaces the old hardcoded ``limit=1000`` call with a cursor-style
loop so that environments with more than one page of orphans do
not silently lose data. Terminates when a page is empty OR when a
short page arrives (indicating the last page).
"""
offset = 0
while True:
batch = await store.asearch(namespace, limit=page_size, offset=offset)
if not batch:
return
for item in batch:
yield item
if len(batch) < page_size:
return
offset += page_size
async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
"""Migrate LangGraph store threads with no owner_id to the given admin.
Uses cursor pagination so all orphans are migrated regardless of
count. Returns the number of rows migrated.
"""
migrated = 0
async for item in _iter_store_items(store, ("threads",)):
metadata = item.value.get("metadata", {})
if not metadata.get("owner_id"):
metadata["owner_id"] = admin_user_id
item.value["metadata"] = metadata
await store.aput(("threads",), item.key, item.value)
migrated += 1
return migrated
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Application lifespan handler."""
@@ -52,6 +178,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
async with langgraph_runtime(app):
logger.info("LangGraph runtime initialised")
# Ensure admin user exists (auto-create on first boot)
# Must run AFTER langgraph_runtime so app.state.store is available for thread migration
await _ensure_admin_user(app)
# Start IM channel service if any channels are configured
try:
from app.channels.service import start_channel_service
@@ -163,7 +293,31 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
],
)
# CORS is handled by nginx - no need for FastAPI middleware
# Auth: reject unauthenticated requests to non-public paths (fail-closed safety net)
app.add_middleware(AuthMiddleware)
# CSRF: Double Submit Cookie pattern for state-changing requests
app.add_middleware(CSRFMiddleware)
# CORS: when GATEWAY_CORS_ORIGINS is set (dev without nginx), add CORS middleware.
# In production, nginx handles CORS and no middleware is needed.
cors_origins_env = os.environ.get("GATEWAY_CORS_ORIGINS", "")
if cors_origins_env:
cors_origins = [o.strip() for o in cors_origins_env.split(",") if o.strip()]
# Validate: wildcard origin with credentials is a security misconfiguration
for origin in cors_origins:
if origin == "*":
logger.error("GATEWAY_CORS_ORIGINS contains wildcard '*' with allow_credentials=True. This is a security misconfiguration — browsers will reject the response. Use explicit scheme://host:port origins instead.")
cors_origins = [o for o in cors_origins if o != "*"]
break
if cors_origins:
app.add_middleware(
CORSMiddleware,
allow_origins=cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Include routers
# Models API is mounted at /api/models
@@ -199,6 +353,12 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
# Assistants compatibility API (LangGraph Platform stub)
app.include_router(assistants_compat.router)
# Auth API is mounted at /api/v1/auth
app.include_router(auth.router)
# Feedback API is mounted at /api/threads/{thread_id}/runs/{run_id}/feedback
app.include_router(feedback.router)
# Thread Runs API (LangGraph Platform-compatible runs lifecycle)
app.include_router(thread_runs.router)
+42
View File
@@ -0,0 +1,42 @@
"""Authentication module for DeerFlow.
This module provides:
- JWT-based authentication
- Provider Factory pattern for extensible auth methods
- UserRepository interface for storage backends (SQLite)
"""
from app.gateway.auth.config import AuthConfig, get_auth_config, set_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
from app.gateway.auth.jwt import TokenPayload, create_access_token, decode_token
from app.gateway.auth.local_provider import LocalAuthProvider
from app.gateway.auth.models import User, UserResponse
from app.gateway.auth.password import hash_password, verify_password
from app.gateway.auth.providers import AuthProvider
from app.gateway.auth.repositories.base import UserRepository
__all__ = [
# Config
"AuthConfig",
"get_auth_config",
"set_auth_config",
# Errors
"AuthErrorCode",
"AuthErrorResponse",
"TokenError",
# JWT
"TokenPayload",
"create_access_token",
"decode_token",
# Password
"hash_password",
"verify_password",
# Models
"User",
"UserResponse",
# Providers
"AuthProvider",
"LocalAuthProvider",
# Repository
"UserRepository",
]
+57
View File
@@ -0,0 +1,57 @@
"""Authentication configuration for DeerFlow."""
import logging
import os
import secrets
from dotenv import load_dotenv
from pydantic import BaseModel, Field
load_dotenv()
logger = logging.getLogger(__name__)
class AuthConfig(BaseModel):
"""JWT and auth-related configuration. Parsed once at startup.
Note: the ``users`` table now lives in the shared persistence
database managed by ``deerflow.persistence.engine``. The old
``users_db_path`` config key has been removed — user storage is
configured through ``config.database`` like every other table.
"""
jwt_secret: str = Field(
...,
description="Secret key for JWT signing. MUST be set via AUTH_JWT_SECRET.",
)
token_expiry_days: int = Field(default=7, ge=1, le=30)
oauth_github_client_id: str | None = Field(default=None)
oauth_github_client_secret: str | None = Field(default=None)
_auth_config: AuthConfig | None = None
def get_auth_config() -> AuthConfig:
"""Get the global AuthConfig instance. Parses from env on first call."""
global _auth_config
if _auth_config is None:
jwt_secret = os.environ.get("AUTH_JWT_SECRET")
if not jwt_secret:
jwt_secret = secrets.token_urlsafe(32)
os.environ["AUTH_JWT_SECRET"] = jwt_secret
logger.warning(
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. "
"Sessions will be invalidated on restart. "
"For production, add AUTH_JWT_SECRET to your .env file: "
'python -c "import secrets; print(secrets.token_urlsafe(32))"'
)
_auth_config = AuthConfig(jwt_secret=jwt_secret)
return _auth_config
def set_auth_config(config: AuthConfig) -> None:
"""Set the global AuthConfig instance (for testing)."""
global _auth_config
_auth_config = config
@@ -0,0 +1,48 @@
"""Write initial admin credentials to a restricted file instead of logs.
Logging secrets to stdout/stderr is a well-known CodeQL finding
(py/clear-text-logging-sensitive-data) — in production those logs
get collected into ELK/Splunk/etc and become a secret sprawl
source. This helper writes the credential to a 0600 file that only
the process user can read, and returns the path so the caller can
log **the path** (not the password) for the operator to pick up.
"""
from __future__ import annotations
import os
from pathlib import Path
from deerflow.config.paths import get_paths
_CREDENTIAL_FILENAME = "admin_initial_credentials.txt"
def write_initial_credentials(email: str, password: str, *, label: str = "initial") -> Path:
"""Write the admin email + password to ``{base_dir}/admin_initial_credentials.txt``.
The file is created **atomically** with mode 0600 via ``os.open``
so the password is never world-readable, even for the single syscall
window between ``write_text`` and ``chmod``.
``label`` distinguishes "initial" (fresh creation) from "reset"
(password reset) in the file header so an operator picking up the
file after a restart can tell which event produced it.
Returns the absolute :class:`Path` to the file.
"""
target = get_paths().base_dir / _CREDENTIAL_FILENAME
target.parent.mkdir(parents=True, exist_ok=True)
content = (
f"# DeerFlow admin {label} credentials\n# This file is generated on first boot or password reset.\n# Change the password after login via Settings -> Account,\n# then delete this file.\n#\nemail: {email}\npassword: {password}\n"
)
# Atomic 0600 create-or-truncate. O_TRUNC (not O_EXCL) so the
# reset-password path can rewrite an existing file without a
# separate unlink-then-create dance.
fd = os.open(target, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(fd, "w", encoding="utf-8") as fh:
fh.write(content)
return target.resolve()
+44
View File
@@ -0,0 +1,44 @@
"""Typed error definitions for auth module.
AuthErrorCode: exhaustive enum of all auth failure conditions.
TokenError: exhaustive enum of JWT decode failures.
AuthErrorResponse: structured error payload for HTTP responses.
"""
from enum import StrEnum
from pydantic import BaseModel
class AuthErrorCode(StrEnum):
"""Exhaustive list of auth error conditions."""
INVALID_CREDENTIALS = "invalid_credentials"
TOKEN_EXPIRED = "token_expired"
TOKEN_INVALID = "token_invalid"
USER_NOT_FOUND = "user_not_found"
EMAIL_ALREADY_EXISTS = "email_already_exists"
PROVIDER_NOT_FOUND = "provider_not_found"
NOT_AUTHENTICATED = "not_authenticated"
class TokenError(StrEnum):
"""Exhaustive list of JWT decode failure reasons."""
EXPIRED = "expired"
INVALID_SIGNATURE = "invalid_signature"
MALFORMED = "malformed"
class AuthErrorResponse(BaseModel):
"""Structured error response — replaces bare `detail` strings."""
code: AuthErrorCode
message: str
def token_error_to_code(err: TokenError) -> AuthErrorCode:
"""Map TokenError to AuthErrorCode — single source of truth."""
if err == TokenError.EXPIRED:
return AuthErrorCode.TOKEN_EXPIRED
return AuthErrorCode.TOKEN_INVALID
+55
View File
@@ -0,0 +1,55 @@
"""JWT token creation and verification."""
from datetime import UTC, datetime, timedelta
import jwt
from pydantic import BaseModel
from app.gateway.auth.config import get_auth_config
from app.gateway.auth.errors import TokenError
class TokenPayload(BaseModel):
"""JWT token payload."""
sub: str # user_id
exp: datetime
iat: datetime | None = None
ver: int = 0 # token_version — must match User.token_version
def create_access_token(user_id: str, expires_delta: timedelta | None = None, token_version: int = 0) -> str:
"""Create a JWT access token.
Args:
user_id: The user's UUID as string
expires_delta: Optional custom expiry, defaults to 7 days
token_version: User's current token_version for invalidation
Returns:
Encoded JWT string
"""
config = get_auth_config()
expiry = expires_delta or timedelta(days=config.token_expiry_days)
now = datetime.now(UTC)
payload = {"sub": user_id, "exp": now + expiry, "iat": now, "ver": token_version}
return jwt.encode(payload, config.jwt_secret, algorithm="HS256")
def decode_token(token: str) -> TokenPayload | TokenError:
"""Decode and validate a JWT token.
Returns:
TokenPayload if valid, or a specific TokenError variant.
"""
config = get_auth_config()
try:
payload = jwt.decode(token, config.jwt_secret, algorithms=["HS256"])
return TokenPayload(**payload)
except jwt.ExpiredSignatureError:
return TokenError.EXPIRED
except jwt.InvalidSignatureError:
return TokenError.INVALID_SIGNATURE
except jwt.PyJWTError:
return TokenError.MALFORMED
@@ -0,0 +1,87 @@
"""Local email/password authentication provider."""
from app.gateway.auth.models import User
from app.gateway.auth.password import hash_password_async, verify_password_async
from app.gateway.auth.providers import AuthProvider
from app.gateway.auth.repositories.base import UserRepository
class LocalAuthProvider(AuthProvider):
"""Email/password authentication provider using local database."""
def __init__(self, repository: UserRepository):
"""Initialize with a UserRepository.
Args:
repository: UserRepository implementation (SQLite)
"""
self._repo = repository
async def authenticate(self, credentials: dict) -> User | None:
"""Authenticate with email and password.
Args:
credentials: dict with 'email' and 'password' keys
Returns:
User if authentication succeeds, None otherwise
"""
email = credentials.get("email")
password = credentials.get("password")
if not email or not password:
return None
user = await self._repo.get_user_by_email(email)
if user is None:
return None
if user.password_hash is None:
# OAuth user without local password
return None
if not await verify_password_async(password, user.password_hash):
return None
return user
async def get_user(self, user_id: str) -> User | None:
"""Get user by ID."""
return await self._repo.get_user_by_id(user_id)
async def create_user(self, email: str, password: str | None = None, system_role: str = "user", needs_setup: bool = False) -> User:
"""Create a new local user.
Args:
email: User email address
password: Plain text password (will be hashed)
system_role: Role to assign ("admin" or "user")
needs_setup: If True, user must complete setup on first login
Returns:
Created User instance
"""
password_hash = await hash_password_async(password) if password else None
user = User(
email=email,
password_hash=password_hash,
system_role=system_role,
needs_setup=needs_setup,
)
return await self._repo.create_user(user)
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
"""Get user by OAuth provider and ID."""
return await self._repo.get_user_by_oauth(provider, oauth_id)
async def count_users(self) -> int:
"""Return total number of registered users."""
return await self._repo.count_users()
async def update_user(self, user: User) -> User:
"""Update an existing user."""
return await self._repo.update_user(user)
async def get_user_by_email(self, email: str) -> User | None:
"""Get user by email."""
return await self._repo.get_user_by_email(email)
+41
View File
@@ -0,0 +1,41 @@
"""User Pydantic models for authentication."""
from datetime import UTC, datetime
from typing import Literal
from uuid import UUID, uuid4
from pydantic import BaseModel, ConfigDict, EmailStr, Field
def _utc_now() -> datetime:
"""Return current UTC time (timezone-aware)."""
return datetime.now(UTC)
class User(BaseModel):
"""Internal user representation."""
model_config = ConfigDict(from_attributes=True)
id: UUID = Field(default_factory=uuid4, description="Primary key")
email: EmailStr = Field(..., description="Unique email address")
password_hash: str | None = Field(None, description="bcrypt hash, nullable for OAuth users")
system_role: Literal["admin", "user"] = Field(default="user")
created_at: datetime = Field(default_factory=_utc_now)
# OAuth linkage (optional)
oauth_provider: str | None = Field(None, description="e.g. 'github', 'google'")
oauth_id: str | None = Field(None, description="User ID from OAuth provider")
# Auth lifecycle
needs_setup: bool = Field(default=False, description="True for auto-created admin until setup completes")
token_version: int = Field(default=0, description="Incremented on password change to invalidate old JWTs")
class UserResponse(BaseModel):
"""Response model for user info endpoint."""
id: str
email: str
system_role: Literal["admin", "user"]
needs_setup: bool = False
+33
View File
@@ -0,0 +1,33 @@
"""Password hashing utilities using bcrypt directly."""
import asyncio
import bcrypt
def hash_password(password: str) -> str:
"""Hash a password using bcrypt."""
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash."""
return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))
async def hash_password_async(password: str) -> str:
"""Hash a password using bcrypt (non-blocking).
Wraps the blocking bcrypt operation in a thread pool to avoid
blocking the event loop during password hashing.
"""
return await asyncio.to_thread(hash_password, password)
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash (non-blocking).
Wraps the blocking bcrypt operation in a thread pool to avoid
blocking the event loop during password verification.
"""
return await asyncio.to_thread(verify_password, plain_password, hashed_password)
+24
View File
@@ -0,0 +1,24 @@
"""Auth provider abstraction."""
from abc import ABC, abstractmethod
class AuthProvider(ABC):
"""Abstract base class for authentication providers."""
@abstractmethod
async def authenticate(self, credentials: dict) -> "User | None":
"""Authenticate user with given credentials.
Returns User if authentication succeeds, None otherwise.
"""
...
@abstractmethod
async def get_user(self, user_id: str) -> "User | None":
"""Retrieve user by ID."""
...
# Import User at runtime to avoid circular imports
from app.gateway.auth.models import User # noqa: E402
@@ -0,0 +1,97 @@
"""User repository interface for abstracting database operations."""
from abc import ABC, abstractmethod
from app.gateway.auth.models import User
class UserNotFoundError(LookupError):
"""Raised when a user repository operation targets a non-existent row.
Subclass of :class:`LookupError` so callers that already catch
``LookupError`` for "missing entity" can keep working unchanged,
while specific call sites can pin to this class to distinguish
"concurrent delete during update" from other lookups.
"""
class UserRepository(ABC):
"""Abstract interface for user data storage.
Implement this interface to support different storage backends
(SQLite)
"""
@abstractmethod
async def create_user(self, user: User) -> User:
"""Create a new user.
Args:
user: User object to create
Returns:
Created User with ID assigned
Raises:
ValueError: If email already exists
"""
...
@abstractmethod
async def get_user_by_id(self, user_id: str) -> User | None:
"""Get user by ID.
Args:
user_id: User UUID as string
Returns:
User if found, None otherwise
"""
...
@abstractmethod
async def get_user_by_email(self, email: str) -> User | None:
"""Get user by email.
Args:
email: User email address
Returns:
User if found, None otherwise
"""
...
@abstractmethod
async def update_user(self, user: User) -> User:
"""Update an existing user.
Args:
user: User object with updated fields
Returns:
Updated User
Raises:
UserNotFoundError: If no row exists for ``user.id``. This is
a hard failure (not a no-op) so callers cannot mistake a
concurrent-delete race for a successful update.
"""
...
@abstractmethod
async def count_users(self) -> int:
"""Return total number of registered users."""
...
@abstractmethod
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
"""Get user by OAuth provider and ID.
Args:
provider: OAuth provider name (e.g. 'github', 'google')
oauth_id: User ID from the OAuth provider
Returns:
User if found, None otherwise
"""
...
@@ -0,0 +1,122 @@
"""SQLAlchemy-backed UserRepository implementation.
Uses the shared async session factory from
``deerflow.persistence.engine`` — the ``users`` table lives in the
same database as ``threads_meta``, ``runs``, ``run_events``, and
``feedback``.
Constructor takes the session factory directly (same pattern as the
other four repositories in ``deerflow.persistence.*``). Callers
construct this after ``init_engine_from_config()`` has run.
"""
from __future__ import annotations
from datetime import UTC
from uuid import UUID
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.gateway.auth.models import User
from app.gateway.auth.repositories.base import UserNotFoundError, UserRepository
from deerflow.persistence.user.model import UserRow
class SQLiteUserRepository(UserRepository):
"""Async user repository backed by the shared SQLAlchemy engine."""
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._sf = session_factory
# ── Converters ────────────────────────────────────────────────────
@staticmethod
def _row_to_user(row: UserRow) -> User:
return User(
id=UUID(row.id),
email=row.email,
password_hash=row.password_hash,
system_role=row.system_role, # type: ignore[arg-type]
# SQLite loses tzinfo on read; reattach UTC so downstream
# code can compare timestamps reliably.
created_at=row.created_at if row.created_at.tzinfo else row.created_at.replace(tzinfo=UTC),
oauth_provider=row.oauth_provider,
oauth_id=row.oauth_id,
needs_setup=row.needs_setup,
token_version=row.token_version,
)
@staticmethod
def _user_to_row(user: User) -> UserRow:
return UserRow(
id=str(user.id),
email=user.email,
password_hash=user.password_hash,
system_role=user.system_role,
created_at=user.created_at,
oauth_provider=user.oauth_provider,
oauth_id=user.oauth_id,
needs_setup=user.needs_setup,
token_version=user.token_version,
)
# ── CRUD ──────────────────────────────────────────────────────────
async def create_user(self, user: User) -> User:
"""Insert a new user. Raises ``ValueError`` on duplicate email."""
row = self._user_to_row(user)
async with self._sf() as session:
session.add(row)
try:
await session.commit()
except IntegrityError as exc:
await session.rollback()
raise ValueError(f"Email already registered: {user.email}") from exc
return user
async def get_user_by_id(self, user_id: str) -> User | None:
async with self._sf() as session:
row = await session.get(UserRow, user_id)
return self._row_to_user(row) if row is not None else None
async def get_user_by_email(self, email: str) -> User | None:
stmt = select(UserRow).where(UserRow.email == email)
async with self._sf() as session:
result = await session.execute(stmt)
row = result.scalar_one_or_none()
return self._row_to_user(row) if row is not None else None
async def update_user(self, user: User) -> User:
async with self._sf() as session:
row = await session.get(UserRow, str(user.id))
if row is None:
# Hard fail on concurrent delete: callers (reset_admin,
# password change handlers, _ensure_admin_user) all
# fetched the user just before this call, so a missing
# row here means the row vanished underneath us. Silent
# success would let the caller log "password reset" for
# a row that no longer exists.
raise UserNotFoundError(f"User {user.id} no longer exists")
row.email = user.email
row.password_hash = user.password_hash
row.system_role = user.system_role
row.oauth_provider = user.oauth_provider
row.oauth_id = user.oauth_id
row.needs_setup = user.needs_setup
row.token_version = user.token_version
await session.commit()
return user
async def count_users(self) -> int:
stmt = select(func.count()).select_from(UserRow)
async with self._sf() as session:
return await session.scalar(stmt) or 0
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
stmt = select(UserRow).where(UserRow.oauth_provider == provider, UserRow.oauth_id == oauth_id)
async with self._sf() as session:
result = await session.execute(stmt)
row = result.scalar_one_or_none()
return self._row_to_user(row) if row is not None else None
+91
View File
@@ -0,0 +1,91 @@
"""CLI tool to reset an admin password.
Usage:
python -m app.gateway.auth.reset_admin
python -m app.gateway.auth.reset_admin --email admin@example.com
Writes the new password to ``.deer-flow/admin_initial_credentials.txt``
(mode 0600) instead of printing it, so CI / log aggregators never see
the cleartext secret.
"""
from __future__ import annotations
import argparse
import asyncio
import secrets
import sys
from sqlalchemy import select
from app.gateway.auth.credential_file import write_initial_credentials
from app.gateway.auth.password import hash_password
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
from deerflow.persistence.user.model import UserRow
async def _run(email: str | None) -> int:
from deerflow.config import get_app_config
from deerflow.persistence.engine import (
close_engine,
get_session_factory,
init_engine_from_config,
)
config = get_app_config()
await init_engine_from_config(config.database)
try:
sf = get_session_factory()
if sf is None:
print("Error: persistence engine not available (check config.database).", file=sys.stderr)
return 1
repo = SQLiteUserRepository(sf)
if email:
user = await repo.get_user_by_email(email)
else:
# Find first admin via direct SELECT — repository does not
# expose a "first admin" helper and we do not want to add
# one just for this CLI.
async with sf() as session:
stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
row = (await session.execute(stmt)).scalar_one_or_none()
if row is None:
user = None
else:
user = await repo.get_user_by_id(row.id)
if user is None:
if email:
print(f"Error: user '{email}' not found.", file=sys.stderr)
else:
print("Error: no admin user found.", file=sys.stderr)
return 1
new_password = secrets.token_urlsafe(16)
user.password_hash = hash_password(new_password)
user.token_version += 1
user.needs_setup = True
await repo.update_user(user)
cred_path = write_initial_credentials(user.email, new_password, label="reset")
print(f"Password reset for: {user.email}")
print(f"Credentials written to: {cred_path} (mode 0600)")
print("Next login will require setup (new email + password).")
return 0
finally:
await close_engine()
def main() -> None:
parser = argparse.ArgumentParser(description="Reset admin password")
parser.add_argument("--email", help="Admin email (default: first admin found)")
args = parser.parse_args()
exit_code = asyncio.run(_run(args.email))
sys.exit(exit_code)
if __name__ == "__main__":
main()
+117
View File
@@ -0,0 +1,117 @@
"""Global authentication middleware — fail-closed safety net.
Rejects unauthenticated requests to non-public paths with 401. When a
request passes the cookie check, resolves the JWT payload to a real
``User`` object and stamps it into both ``request.state.user`` and the
``deerflow.runtime.user_context`` contextvar so that repository-layer
owner filtering works automatically via the sentinel pattern.
Fine-grained permission checks remain in authz.py decorators.
"""
from collections.abc import Callable
from fastapi import HTTPException, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.types import ASGIApp
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
from deerflow.runtime.user_context import reset_current_user, set_current_user
# Paths that never require authentication.
_PUBLIC_PATH_PREFIXES: tuple[str, ...] = (
"/health",
"/docs",
"/redoc",
"/openapi.json",
)
# Exact auth paths that are public (login/register/status check).
# /api/v1/auth/me, /api/v1/auth/change-password etc. are NOT public.
_PUBLIC_EXACT_PATHS: frozenset[str] = frozenset(
{
"/api/v1/auth/login/local",
"/api/v1/auth/register",
"/api/v1/auth/logout",
"/api/v1/auth/setup-status",
}
)
def _is_public(path: str) -> bool:
stripped = path.rstrip("/")
if stripped in _PUBLIC_EXACT_PATHS:
return True
return any(path.startswith(prefix) for prefix in _PUBLIC_PATH_PREFIXES)
class AuthMiddleware(BaseHTTPMiddleware):
"""Strict auth gate: reject requests without a valid session.
Two-stage check for non-public paths:
1. Cookie presence — return 401 NOT_AUTHENTICATED if missing
2. JWT validation via ``get_optional_user_from_request`` — return 401
TOKEN_INVALID if the token is absent, malformed, expired, or the
signed user does not exist / is stale
On success, stamps ``request.state.user`` and the
``deerflow.runtime.user_context`` contextvar so that repository-layer
owner filters work downstream without every route needing a
``@require_auth`` decorator. Routes that need per-resource
authorization (e.g. "user A cannot read user B's thread by guessing
the URL") should additionally use ``@require_permission(...,
owner_check=True)`` for explicit enforcement — but authentication
itself is fully handled here.
"""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
async def dispatch(self, request: Request, call_next: Callable) -> Response:
if _is_public(request.url.path):
return await call_next(request)
# Non-public path: require session cookie
if not request.cookies.get("access_token"):
return JSONResponse(
status_code=401,
content={
"detail": AuthErrorResponse(
code=AuthErrorCode.NOT_AUTHENTICATED,
message="Authentication required",
).model_dump()
},
)
# Strict JWT validation: reject junk/expired tokens with 401
# right here instead of silently passing through. This closes
# the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8):
# without this, non-isolation routes like /api/models would
# accept any cookie-shaped string as authentication.
#
# We call the *strict* resolver so that fine-grained error
# codes (token_expired, token_invalid, user_not_found, …)
# propagate from AuthErrorCode, not get flattened into one
# generic code. BaseHTTPMiddleware doesn't let HTTPException
# bubble up, so we catch and render it as JSONResponse here.
from app.gateway.deps import get_current_user_from_request
try:
user = await get_current_user_from_request(request)
except HTTPException as exc:
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
# Stamp both request.state.user (for the contextvar pattern)
# and request.state.auth (so @require_permission's "auth is
# None" branch short-circuits instead of running the entire
# JWT-decode + DB-lookup pipeline a second time per request).
request.state.user = user
request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS)
token = set_current_user(user)
try:
return await call_next(request)
finally:
reset_current_user(token)
+262
View File
@@ -0,0 +1,262 @@
"""Authorization decorators and context for DeerFlow.
Inspired by LangGraph Auth system: https://github.com/langchain-ai/langgraph/blob/main/libs/sdk-py/langgraph_sdk/auth/__init__.py
**Usage:**
1. Use ``@require_auth`` on routes that need authentication
2. Use ``@require_permission("resource", "action", filter_key=...)`` for permission checks
3. The decorator chain processes from bottom to top
**Example:**
@router.get("/{thread_id}")
@require_auth
@require_permission("threads", "read", owner_check=True)
async def get_thread(thread_id: str, request: Request):
# User is authenticated and has threads:read permission
...
**Permission Model:**
- threads:read - View thread
- threads:write - Create/update thread
- threads:delete - Delete thread
- runs:create - Run agent
- runs:read - View run
- runs:cancel - Cancel run
"""
from __future__ import annotations
import functools
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
from fastapi import HTTPException, Request
if TYPE_CHECKING:
from app.gateway.auth.models import User
P = ParamSpec("P")
T = TypeVar("T")
# Permission constants
class Permissions:
"""Permission constants for resource:action format."""
# Threads
THREADS_READ = "threads:read"
THREADS_WRITE = "threads:write"
THREADS_DELETE = "threads:delete"
# Runs
RUNS_CREATE = "runs:create"
RUNS_READ = "runs:read"
RUNS_CANCEL = "runs:cancel"
class AuthContext:
"""Authentication context for the current request.
Stored in request.state.auth after require_auth decoration.
Attributes:
user: The authenticated user, or None if anonymous
permissions: List of permission strings (e.g., "threads:read")
"""
__slots__ = ("user", "permissions")
def __init__(self, user: User | None = None, permissions: list[str] | None = None):
self.user = user
self.permissions = permissions or []
@property
def is_authenticated(self) -> bool:
"""Check if user is authenticated."""
return self.user is not None
def has_permission(self, resource: str, action: str) -> bool:
"""Check if context has permission for resource:action.
Args:
resource: Resource name (e.g., "threads")
action: Action name (e.g., "read")
Returns:
True if user has permission
"""
permission = f"{resource}:{action}"
return permission in self.permissions
def require_user(self) -> User:
"""Get user or raise 401.
Raises:
HTTPException 401 if not authenticated
"""
if not self.user:
raise HTTPException(status_code=401, detail="Authentication required")
return self.user
def get_auth_context(request: Request) -> AuthContext | None:
"""Get AuthContext from request state."""
return getattr(request.state, "auth", None)
_ALL_PERMISSIONS: list[str] = [
Permissions.THREADS_READ,
Permissions.THREADS_WRITE,
Permissions.THREADS_DELETE,
Permissions.RUNS_CREATE,
Permissions.RUNS_READ,
Permissions.RUNS_CANCEL,
]
async def _authenticate(request: Request) -> AuthContext:
"""Authenticate request and return AuthContext.
Delegates to deps.get_optional_user_from_request() for the JWT→User pipeline.
Returns AuthContext with user=None for anonymous requests.
"""
from app.gateway.deps import get_optional_user_from_request
user = await get_optional_user_from_request(request)
if user is None:
return AuthContext(user=None, permissions=[])
# In future, permissions could be stored in user record
return AuthContext(user=user, permissions=_ALL_PERMISSIONS)
def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
"""Decorator that authenticates the request and sets AuthContext.
Must be placed ABOVE other decorators (executes after them).
Usage:
@router.get("/{thread_id}")
@require_auth # Bottom decorator (executes first after permission check)
@require_permission("threads", "read")
async def get_thread(thread_id: str, request: Request):
auth: AuthContext = request.state.auth
...
Raises:
ValueError: If 'request' parameter is missing
"""
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
request = kwargs.get("request")
if request is None:
raise ValueError("require_auth decorator requires 'request' parameter")
# Authenticate and set context
auth_context = await _authenticate(request)
request.state.auth = auth_context
return await func(*args, **kwargs)
return wrapper
def require_permission(
resource: str,
action: str,
owner_check: bool = False,
require_existing: bool = False,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Decorator that checks permission for resource:action.
Must be used AFTER @require_auth.
Args:
resource: Resource name (e.g., "threads", "runs")
action: Action name (e.g., "read", "write", "delete")
owner_check: If True, validates that the current user owns the resource.
Requires 'thread_id' path parameter and performs ownership check.
require_existing: Only meaningful with ``owner_check=True``. If True, a
missing ``threads_meta`` row counts as a denial (404)
instead of "untracked legacy thread, allow". Use on
**destructive / mutating** routes (DELETE, PATCH,
state-update) so a deleted thread can't be re-targeted
by another user via the missing-row code path.
Usage:
# Read-style: legacy untracked threads are allowed
@require_permission("threads", "read", owner_check=True)
async def get_thread(thread_id: str, request: Request):
...
# Destructive: thread row MUST exist and be owned by caller
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_thread(thread_id: str, request: Request):
...
Raises:
HTTPException 401: If authentication required but user is anonymous
HTTPException 403: If user lacks permission
HTTPException 404: If owner_check=True but user doesn't own the thread
ValueError: If owner_check=True but 'thread_id' parameter is missing
"""
def decorator(func: Callable[P, T]) -> Callable[P, T]:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
request = kwargs.get("request")
if request is None:
raise ValueError("require_permission decorator requires 'request' parameter")
auth: AuthContext = getattr(request.state, "auth", None)
if auth is None:
auth = await _authenticate(request)
request.state.auth = auth
if not auth.is_authenticated:
raise HTTPException(status_code=401, detail="Authentication required")
# Check permission
if not auth.has_permission(resource, action):
raise HTTPException(
status_code=403,
detail=f"Permission denied: {resource}:{action}",
)
# Owner check for thread-specific resources.
#
# 2.0-rc moved thread metadata into the SQL persistence layer
# (``threads_meta`` table). We verify ownership via
# ``ThreadMetaStore.check_access``: it returns True for
# missing rows (untracked legacy thread) and for rows whose
# ``owner_id`` is NULL (shared / pre-auth data), so this is
# strict-deny rather than strict-allow — only an *existing*
# row with a *different* owner_id triggers 404.
if owner_check:
thread_id = kwargs.get("thread_id")
if thread_id is None:
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
from app.gateway.deps import get_thread_meta_repo
thread_meta_repo = get_thread_meta_repo(request)
allowed = await thread_meta_repo.check_access(
thread_id,
str(auth.user.id),
require_existing=require_existing,
)
if not allowed:
raise HTTPException(
status_code=404,
detail=f"Thread {thread_id} not found",
)
return await func(*args, **kwargs)
return wrapper
return decorator
+112
View File
@@ -0,0 +1,112 @@
"""CSRF protection middleware for FastAPI.
Per RFC-001:
State-changing operations require CSRF protection.
"""
import secrets
from collections.abc import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.types import ASGIApp
CSRF_COOKIE_NAME = "csrf_token"
CSRF_HEADER_NAME = "X-CSRF-Token"
CSRF_TOKEN_LENGTH = 64 # bytes
def is_secure_request(request: Request) -> bool:
"""Detect whether the original client request was made over HTTPS."""
return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"
def generate_csrf_token() -> str:
"""Generate a secure random CSRF token."""
return secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
def should_check_csrf(request: Request) -> bool:
"""Determine if a request needs CSRF validation.
CSRF is checked for state-changing methods (POST, PUT, DELETE, PATCH).
GET, HEAD, OPTIONS, and TRACE are exempt per RFC 7231.
"""
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
return False
path = request.url.path.rstrip("/")
# Exempt /api/v1/auth/me endpoint
if path == "/api/v1/auth/me":
return False
return True
_AUTH_EXEMPT_PATHS: frozenset[str] = frozenset(
{
"/api/v1/auth/login/local",
"/api/v1/auth/logout",
"/api/v1/auth/register",
}
)
def is_auth_endpoint(request: Request) -> bool:
"""Check if the request is to an auth endpoint.
Auth endpoints don't need CSRF validation on first call (no token).
"""
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
class CSRFMiddleware(BaseHTTPMiddleware):
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
async def dispatch(self, request: Request, call_next: Callable) -> Response:
_is_auth = is_auth_endpoint(request)
if should_check_csrf(request) and not _is_auth:
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
header_token = request.headers.get(CSRF_HEADER_NAME)
if not cookie_token or not header_token:
return JSONResponse(
status_code=403,
content={"detail": "CSRF token missing. Include X-CSRF-Token header."},
)
if not secrets.compare_digest(cookie_token, header_token):
return JSONResponse(
status_code=403,
content={"detail": "CSRF token mismatch."},
)
response = await call_next(request)
# For auth endpoints that set up session, also set CSRF cookie
if _is_auth and request.method == "POST":
# Generate a new CSRF token for the session
csrf_token = generate_csrf_token()
is_https = is_secure_request(request)
response.set_cookie(
key=CSRF_COOKIE_NAME,
value=csrf_token,
httponly=False, # Must be JS-readable for Double Submit Cookie pattern
secure=is_https,
samesite="strict",
)
return response
def get_csrf_token(request: Request) -> str | None:
"""Get the CSRF token from the current request's cookies.
This is useful for server-side rendering where you need to embed
token in forms or headers.
"""
return request.cookies.get(CSRF_COOKIE_NAME)
+180 -25
View File
@@ -1,7 +1,8 @@
"""Centralized accessors for singleton objects stored on ``app.state``.
**Getters** (used by routers): raise 503 when a required dependency is
missing, except ``get_store`` which returns ``None``.
missing, except ``get_store`` and ``get_thread_meta_repo`` which return
``None``.
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
"""
@@ -10,10 +11,15 @@ from __future__ import annotations
from collections.abc import AsyncGenerator
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING
from fastapi import FastAPI, HTTPException, Request
from deerflow.runtime import RunManager, StreamBridge
from deerflow.runtime import RunContext, RunManager
if TYPE_CHECKING:
from app.gateway.auth.local_provider import LocalAuthProvider
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
@asynccontextmanager
@@ -26,45 +32,194 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
yield
"""
from deerflow.agents.checkpointer.async_provider import make_checkpointer
from deerflow.config import get_app_config
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
from deerflow.runtime import make_store, make_stream_bridge
from deerflow.runtime.events.store import make_run_event_store
async with AsyncExitStack() as stack:
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
# Initialize persistence engine BEFORE checkpointer so that
# auto-create-database logic runs first (postgres backend).
config = get_app_config()
await init_engine_from_config(config.database)
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
app.state.store = await stack.enter_async_context(make_store())
app.state.run_manager = RunManager()
yield
# Initialize repositories — one get_session_factory() call for all.
sf = get_session_factory()
if sf is not None:
from deerflow.persistence.feedback import FeedbackRepository
from deerflow.persistence.run import RunRepository
from deerflow.persistence.thread_meta import ThreadMetaRepository
app.state.run_store = RunRepository(sf)
app.state.feedback_repo = FeedbackRepository(sf)
app.state.thread_meta_repo = ThreadMetaRepository(sf)
else:
from deerflow.persistence.thread_meta import MemoryThreadMetaStore
from deerflow.runtime.runs.store.memory import MemoryRunStore
app.state.run_store = MemoryRunStore()
app.state.feedback_repo = None
app.state.thread_meta_repo = MemoryThreadMetaStore(app.state.store)
# Run event store (has its own factory with config-driven backend selection)
run_events_config = getattr(config, "run_events", None)
app.state.run_event_store = make_run_event_store(run_events_config)
# RunManager with store backing for persistence
app.state.run_manager = RunManager(store=app.state.run_store)
try:
yield
finally:
await close_engine()
# ---------------------------------------------------------------------------
# Getters called by routers per-request
# Getters -- called by routers per-request
# ---------------------------------------------------------------------------
def get_stream_bridge(request: Request) -> StreamBridge:
"""Return the global :class:`StreamBridge`, or 503."""
bridge = getattr(request.app.state, "stream_bridge", None)
if bridge is None:
raise HTTPException(status_code=503, detail="Stream bridge not available")
return bridge
def _require(attr: str, label: str):
"""Create a FastAPI dependency that returns ``app.state.<attr>`` or 503."""
def dep(request: Request):
val = getattr(request.app.state, attr, None)
if val is None:
raise HTTPException(status_code=503, detail=f"{label} not available")
return val
dep.__name__ = dep.__qualname__ = f"get_{attr}"
return dep
def get_run_manager(request: Request) -> RunManager:
"""Return the global :class:`RunManager`, or 503."""
mgr = getattr(request.app.state, "run_manager", None)
if mgr is None:
raise HTTPException(status_code=503, detail="Run manager not available")
return mgr
def get_checkpointer(request: Request):
"""Return the global checkpointer, or 503."""
cp = getattr(request.app.state, "checkpointer", None)
if cp is None:
raise HTTPException(status_code=503, detail="Checkpointer not available")
return cp
get_stream_bridge = _require("stream_bridge", "Stream bridge")
get_run_manager = _require("run_manager", "Run manager")
get_checkpointer = _require("checkpointer", "Checkpointer")
get_run_event_store = _require("run_event_store", "Run event store")
get_feedback_repo = _require("feedback_repo", "Feedback")
get_run_store = _require("run_store", "Run store")
def get_store(request: Request):
"""Return the global store (may be ``None`` if not configured)."""
return getattr(request.app.state, "store", None)
get_thread_meta_repo = _require("thread_meta_repo", "Thread metadata store")
def get_run_context(request: Request) -> RunContext:
"""Build a :class:`RunContext` from ``app.state`` singletons.
Returns a *base* context with infrastructure dependencies. Callers that
need per-run fields (e.g. ``follow_up_to_run_id``) should use
``dataclasses.replace(ctx, follow_up_to_run_id=...)`` before passing it
to :func:`run_agent`.
"""
from deerflow.config import get_app_config
return RunContext(
checkpointer=get_checkpointer(request),
store=get_store(request),
event_store=get_run_event_store(request),
run_events_config=getattr(get_app_config(), "run_events", None),
thread_meta_repo=get_thread_meta_repo(request),
)
# ---------------------------------------------------------------------------
# Auth helpers (used by authz.py and auth middleware)
# ---------------------------------------------------------------------------
# Cached singletons to avoid repeated instantiation per request
_cached_local_provider: LocalAuthProvider | None = None
_cached_repo: SQLiteUserRepository | None = None
def get_local_provider() -> LocalAuthProvider:
"""Get or create the cached LocalAuthProvider singleton.
Must be called after ``init_engine_from_config()`` — the shared
session factory is required to construct the user repository.
"""
global _cached_local_provider, _cached_repo
if _cached_repo is None:
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
from deerflow.persistence.engine import get_session_factory
sf = get_session_factory()
if sf is None:
raise RuntimeError("get_local_provider() called before init_engine_from_config(); cannot access users table")
_cached_repo = SQLiteUserRepository(sf)
if _cached_local_provider is None:
from app.gateway.auth.local_provider import LocalAuthProvider
_cached_local_provider = LocalAuthProvider(repository=_cached_repo)
return _cached_local_provider
async def get_current_user_from_request(request: Request):
"""Get the current authenticated user from the request cookie.
Raises HTTPException 401 if not authenticated.
"""
from app.gateway.auth import decode_token
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
access_token = request.cookies.get("access_token")
if not access_token:
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=AuthErrorCode.NOT_AUTHENTICATED, message="Not authenticated").model_dump(),
)
payload = decode_token(access_token)
if isinstance(payload, TokenError):
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=token_error_to_code(payload), message=f"Token error: {payload.value}").model_dump(),
)
provider = get_local_provider()
user = await provider.get_user(payload.sub)
if user is None:
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=AuthErrorCode.USER_NOT_FOUND, message="User not found").model_dump(),
)
# Token version mismatch → password was changed, token is stale
if user.token_version != payload.ver:
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=AuthErrorCode.TOKEN_INVALID, message="Token revoked (password changed)").model_dump(),
)
return user
async def get_optional_user_from_request(request: Request):
"""Get optional authenticated user from request.
Returns None if not authenticated.
"""
try:
return await get_current_user_from_request(request)
except HTTPException:
return None
async def get_current_user(request: Request) -> str | None:
"""Extract user_id from request cookie, or None if not authenticated.
Thin adapter that returns the string id for callers that only need
identification (e.g., ``feedback.py``). Full-user callers should use
``get_current_user_from_request`` or ``get_optional_user_from_request``.
"""
user = await get_optional_user_from_request(request)
return str(user.id) if user else None
+106
View File
@@ -0,0 +1,106 @@
"""LangGraph Server auth handler — shares JWT logic with Gateway.
Loaded by LangGraph Server via langgraph.json ``auth.path``.
Reuses the same ``decode_token`` / ``get_auth_config`` as Gateway,
so both modes validate tokens with the same secret and rules.
Two layers:
1. @auth.authenticate — validates JWT cookie, extracts user_id,
and enforces CSRF on state-changing methods (POST/PUT/DELETE/PATCH)
2. @auth.on — returns metadata filter so each user only sees own threads
"""
import secrets
from langgraph_sdk import Auth
from app.gateway.auth.errors import TokenError
from app.gateway.auth.jwt import decode_token
from app.gateway.deps import get_local_provider
auth = Auth()
# Methods that require CSRF validation (state-changing per RFC 7231).
_CSRF_METHODS = frozenset({"POST", "PUT", "DELETE", "PATCH"})
def _check_csrf(request) -> None:
"""Enforce Double Submit Cookie CSRF check for state-changing requests.
Mirrors Gateway's CSRFMiddleware logic so that LangGraph routes
proxied directly by nginx have the same CSRF protection.
"""
method = getattr(request, "method", "") or ""
if method.upper() not in _CSRF_METHODS:
return
cookie_token = request.cookies.get("csrf_token")
header_token = request.headers.get("x-csrf-token")
if not cookie_token or not header_token:
raise Auth.exceptions.HTTPException(
status_code=403,
detail="CSRF token missing. Include X-CSRF-Token header.",
)
if not secrets.compare_digest(cookie_token, header_token):
raise Auth.exceptions.HTTPException(
status_code=403,
detail="CSRF token mismatch.",
)
@auth.authenticate
async def authenticate(request):
"""Validate the session cookie, decode JWT, and check token_version.
Same validation chain as Gateway's get_current_user_from_request:
cookie → decode JWT → DB lookup → token_version match
Also enforces CSRF on state-changing methods.
"""
# CSRF check before authentication so forged cross-site requests
# are rejected early, even if the cookie carries a valid JWT.
_check_csrf(request)
token = request.cookies.get("access_token")
if not token:
raise Auth.exceptions.HTTPException(
status_code=401,
detail="Not authenticated",
)
payload = decode_token(token)
if isinstance(payload, TokenError):
raise Auth.exceptions.HTTPException(
status_code=401,
detail=f"Token error: {payload.value}",
)
user = await get_local_provider().get_user(payload.sub)
if user is None:
raise Auth.exceptions.HTTPException(
status_code=401,
detail="User not found",
)
if user.token_version != payload.ver:
raise Auth.exceptions.HTTPException(
status_code=401,
detail="Token revoked (password changed)",
)
return payload.sub
@auth.on
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
"""Inject owner_id metadata on writes; filter by owner_id on reads.
Gateway stores thread ownership as ``metadata.owner_id``.
This handler ensures LangGraph Server enforces the same isolation.
"""
# On create/update: stamp owner_id into metadata
metadata = value.setdefault("metadata", {})
metadata["owner_id"] = ctx.user.identity
# Return filter dict — LangGraph applies it to search/read/delete
return {"owner_id": ctx.user.identity}
-21
View File
@@ -8,7 +8,6 @@ import yaml
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from deerflow.config.agents_api_config import get_agents_api_config
from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul
from deerflow.config.paths import get_paths
@@ -74,15 +73,6 @@ def _normalize_agent_name(name: str) -> str:
return name.lower()
def _require_agents_api_enabled() -> None:
"""Reject access unless the custom-agent management API is explicitly enabled."""
if not get_agents_api_config().enabled:
raise HTTPException(
status_code=403,
detail=("Custom-agent management API is disabled. Set agents_api.enabled=true to expose agent and user-profile routes over HTTP."),
)
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False) -> AgentResponse:
"""Convert AgentConfig to AgentResponse."""
soul: str | None = None
@@ -110,8 +100,6 @@ async def list_agents() -> AgentsListResponse:
Returns:
List of all custom agents with their metadata and soul content.
"""
_require_agents_api_enabled()
try:
agents = list_custom_agents()
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True) for a in agents])
@@ -137,7 +125,6 @@ async def check_agent_name(name: str) -> dict:
Raises:
HTTPException: 422 if the name is invalid.
"""
_require_agents_api_enabled()
_validate_agent_name(name)
normalized = _normalize_agent_name(name)
available = not get_paths().agent_dir(normalized).exists()
@@ -162,7 +149,6 @@ async def get_agent(name: str) -> AgentResponse:
Raises:
HTTPException: 404 if agent not found.
"""
_require_agents_api_enabled()
_validate_agent_name(name)
name = _normalize_agent_name(name)
@@ -195,7 +181,6 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
Raises:
HTTPException: 409 if agent already exists, 422 if name is invalid.
"""
_require_agents_api_enabled()
_validate_agent_name(request.name)
normalized_name = _normalize_agent_name(request.name)
@@ -258,7 +243,6 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
Raises:
HTTPException: 404 if agent not found.
"""
_require_agents_api_enabled()
_validate_agent_name(name)
name = _normalize_agent_name(name)
@@ -331,8 +315,6 @@ async def get_user_profile() -> UserProfileResponse:
Returns:
UserProfileResponse with content=None if USER.md does not exist yet.
"""
_require_agents_api_enabled()
try:
user_md_path = get_paths().user_md_file
if not user_md_path.exists():
@@ -359,8 +341,6 @@ async def update_user_profile(request: UserProfileUpdateRequest) -> UserProfileR
Returns:
UserProfileResponse with the saved content.
"""
_require_agents_api_enabled()
try:
paths = get_paths()
paths.base_dir.mkdir(parents=True, exist_ok=True)
@@ -387,7 +367,6 @@ async def delete_agent(name: str) -> None:
Raises:
HTTPException: 404 if agent not found.
"""
_require_agents_api_enabled()
_validate_agent_name(name)
name = _normalize_agent_name(name)
+2
View File
@@ -7,6 +7,7 @@ from urllib.parse import quote
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import FileResponse, PlainTextResponse, Response
from app.gateway.authz import require_permission
from app.gateway.path_utils import resolve_thread_virtual_path
logger = logging.getLogger(__name__)
@@ -81,6 +82,7 @@ def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> byte
summary="Get Artifact File",
description="Retrieve an artifact file generated by the AI agent. Text and binary files can be viewed inline, while active web content is always downloaded.",
)
@require_permission("threads", "read", owner_check=True)
async def get_artifact(thread_id: str, path: str, request: Request, download: bool = False) -> Response:
"""Get an artifact file by its path.
+418
View File
@@ -0,0 +1,418 @@
"""Authentication endpoints."""
import logging
import os
import time
from ipaddress import ip_address, ip_network
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel, EmailStr, Field, field_validator
from app.gateway.auth import (
UserResponse,
create_access_token,
)
from app.gateway.auth.config import get_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
from app.gateway.csrf_middleware import is_secure_request
from app.gateway.deps import get_current_user_from_request, get_local_provider
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
# ── Request/Response Models ──────────────────────────────────────────────
class LoginResponse(BaseModel):
"""Response model for login — token only lives in HttpOnly cookie."""
expires_in: int # seconds
needs_setup: bool = False
# Top common-password blocklist. Drawn from the public SecLists "10k worst
# passwords" set, lowercased + length>=8 only (shorter ones already fail
# the min_length check). Kept tight on purpose: this is the **lower bound**
# defense, not a full HIBP / passlib check, and runs in-process per request.
_COMMON_PASSWORDS: frozenset[str] = frozenset(
{
"password",
"password1",
"password12",
"password123",
"password1234",
"12345678",
"123456789",
"1234567890",
"qwerty12",
"qwertyui",
"qwerty123",
"abc12345",
"abcd1234",
"iloveyou",
"letmein1",
"welcome1",
"welcome123",
"admin123",
"administrator",
"passw0rd",
"p@ssw0rd",
"monkey12",
"trustno1",
"sunshine",
"princess",
"football",
"baseball",
"superman",
"batman123",
"starwars",
"dragon123",
"master123",
"shadow12",
"michael1",
"jennifer",
"computer",
}
)
def _password_is_common(password: str) -> bool:
"""Case-insensitive blocklist check.
Lowercases the input so trivial mutations like ``Password`` /
``PASSWORD`` are also rejected. Does not normalize digit substitutions
(``p@ssw0rd`` is included as a literal entry instead) — keeping the
rule cheap and predictable.
"""
return password.lower() in _COMMON_PASSWORDS
def _validate_strong_password(value: str) -> str:
"""Pydantic field-validator body shared by Register + ChangePassword.
Constraint = function, not type-level mixin. The two request models
have no "is-a" relationship; they only share the password-strength
rule. Lifting it into a free function lets each model bind it via
``@field_validator(field_name)`` without inheritance gymnastics.
"""
if _password_is_common(value):
raise ValueError("Password is too common; choose a stronger password.")
return value
class RegisterRequest(BaseModel):
"""Request model for user registration."""
email: EmailStr
password: str = Field(..., min_length=8)
_strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v)))
class ChangePasswordRequest(BaseModel):
"""Request model for password change (also handles setup flow)."""
current_password: str
new_password: str = Field(..., min_length=8)
new_email: EmailStr | None = None
_strong_password = field_validator("new_password")(classmethod(lambda cls, v: _validate_strong_password(v)))
class MessageResponse(BaseModel):
"""Generic message response."""
message: str
# ── Helpers ───────────────────────────────────────────────────────────────
def _set_session_cookie(response: Response, token: str, request: Request) -> None:
"""Set the access_token HttpOnly cookie on the response."""
config = get_auth_config()
is_https = is_secure_request(request)
response.set_cookie(
key="access_token",
value=token,
httponly=True,
secure=is_https,
samesite="lax",
max_age=config.token_expiry_days * 24 * 3600 if is_https else None,
)
# ── Rate Limiting ────────────────────────────────────────────────────────
# In-process dict — not shared across workers. Sufficient for single-worker deployments.
_MAX_LOGIN_ATTEMPTS = 5
_LOCKOUT_SECONDS = 300 # 5 minutes
# ip → (fail_count, lock_until_timestamp)
_login_attempts: dict[str, tuple[int, float]] = {}
def _trusted_proxies() -> list:
"""Parse ``AUTH_TRUSTED_PROXIES`` env var into a list of ip_network objects.
Comma-separated CIDR or single-IP entries. Empty / unset = no proxy is
trusted (direct mode). Invalid entries are skipped with a logger warning.
Read live so env-var overrides take effect immediately and tests can
``monkeypatch.setenv`` without poking a module-level cache.
"""
raw = os.getenv("AUTH_TRUSTED_PROXIES", "").strip()
if not raw:
return []
nets = []
for entry in raw.split(","):
entry = entry.strip()
if not entry:
continue
try:
nets.append(ip_network(entry, strict=False))
except ValueError:
logger.warning("AUTH_TRUSTED_PROXIES: ignoring invalid entry %r", entry)
return nets
def _get_client_ip(request: Request) -> str:
"""Extract the real client IP for rate limiting.
Trust model:
- The TCP peer (``request.client.host``) is always the baseline. It is
whatever the kernel reports as the connecting socket — unforgeable
by the client itself.
- ``X-Real-IP`` is **only** honored if the TCP peer is in the
``AUTH_TRUSTED_PROXIES`` allowlist (set via env var, comma-separated
CIDR or single IPs). When set, the gateway is assumed to be behind a
reverse proxy (nginx, Cloudflare, ALB, …) that overwrites
``X-Real-IP`` with the original client address.
- With no ``AUTH_TRUSTED_PROXIES`` set, ``X-Real-IP`` is silently
ignored — closing the bypass where any client could rotate the
header to dodge per-IP rate limits in dev / direct-gateway mode.
``X-Forwarded-For`` is intentionally NOT used because it is naturally
client-controlled at the *first* hop and the trust chain is harder to
audit per-request.
"""
peer_host = request.client.host if request.client else None
trusted = _trusted_proxies()
if trusted and peer_host:
try:
peer_ip = ip_address(peer_host)
if any(peer_ip in net for net in trusted):
real_ip = request.headers.get("x-real-ip", "").strip()
if real_ip:
return real_ip
except ValueError:
# peer_host wasn't a parseable IP (e.g. "unknown") — fall through
pass
return peer_host or "unknown"
def _check_rate_limit(ip: str) -> None:
"""Raise 429 if the IP is currently locked out."""
record = _login_attempts.get(ip)
if record is None:
return
fail_count, lock_until = record
if fail_count >= _MAX_LOGIN_ATTEMPTS:
if time.time() < lock_until:
raise HTTPException(
status_code=429,
detail="Too many login attempts. Try again later.",
)
del _login_attempts[ip]
_MAX_TRACKED_IPS = 10000
def _record_login_failure(ip: str) -> None:
"""Record a failed login attempt for the given IP."""
# Evict expired lockouts when dict grows too large
if len(_login_attempts) >= _MAX_TRACKED_IPS:
now = time.time()
expired = [k for k, (c, t) in _login_attempts.items() if c >= _MAX_LOGIN_ATTEMPTS and now >= t]
for k in expired:
del _login_attempts[k]
# If still too large, evict cheapest-to-lose half: below-threshold
# IPs (lock_until=0.0) sort first, then earliest-expiring lockouts.
if len(_login_attempts) >= _MAX_TRACKED_IPS:
by_time = sorted(_login_attempts.items(), key=lambda kv: kv[1][1])
for k, _ in by_time[: len(by_time) // 2]:
del _login_attempts[k]
record = _login_attempts.get(ip)
if record is None:
_login_attempts[ip] = (1, 0.0)
else:
new_count = record[0] + 1
lock_until = time.time() + _LOCKOUT_SECONDS if new_count >= _MAX_LOGIN_ATTEMPTS else 0.0
_login_attempts[ip] = (new_count, lock_until)
def _record_login_success(ip: str) -> None:
"""Clear failure counter for the given IP on successful login."""
_login_attempts.pop(ip, None)
# ── Endpoints ─────────────────────────────────────────────────────────────
@router.post("/login/local", response_model=LoginResponse)
async def login_local(
request: Request,
response: Response,
form_data: OAuth2PasswordRequestForm = Depends(),
):
"""Local email/password login."""
client_ip = _get_client_ip(request)
_check_rate_limit(client_ip)
user = await get_local_provider().authenticate({"email": form_data.username, "password": form_data.password})
if user is None:
_record_login_failure(client_ip)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Incorrect email or password").model_dump(),
)
_record_login_success(client_ip)
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return LoginResponse(
expires_in=get_auth_config().token_expiry_days * 24 * 3600,
needs_setup=user.needs_setup,
)
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(request: Request, response: Response, body: RegisterRequest):
"""Register a new user account (always 'user' role).
Admin is auto-created on first boot. This endpoint creates regular users.
Auto-login by setting the session cookie.
"""
try:
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="user")
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already registered").model_dump(),
)
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
@router.post("/logout", response_model=MessageResponse)
async def logout(request: Request, response: Response):
"""Logout current user by clearing the cookie."""
response.delete_cookie(key="access_token", secure=is_secure_request(request), samesite="lax")
return MessageResponse(message="Successfully logged out")
@router.post("/change-password", response_model=MessageResponse)
async def change_password(request: Request, response: Response, body: ChangePasswordRequest):
"""Change password for the currently authenticated user.
Also handles the first-boot setup flow:
- If new_email is provided, updates email (checks uniqueness)
- If user.needs_setup is True and new_email is given, clears needs_setup
- Always increments token_version to invalidate old sessions
- Re-issues session cookie with new token_version
"""
from app.gateway.auth.password import hash_password_async, verify_password_async
user = await get_current_user_from_request(request)
if user.password_hash is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
if not await verify_password_async(body.current_password, user.password_hash):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Current password is incorrect").model_dump())
provider = get_local_provider()
# Update email if provided
if body.new_email is not None:
existing = await provider.get_user_by_email(body.new_email)
if existing and str(existing.id) != str(user.id):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already in use").model_dump())
user.email = body.new_email
# Update password + bump version
user.password_hash = await hash_password_async(body.new_password)
user.token_version += 1
# Clear setup flag if this is the setup flow
if user.needs_setup and body.new_email is not None:
user.needs_setup = False
await provider.update_user(user)
# Re-issue cookie with new token_version
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return MessageResponse(message="Password changed successfully")
@router.get("/me", response_model=UserResponse)
async def get_me(request: Request):
"""Get current authenticated user info."""
user = await get_current_user_from_request(request)
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
@router.get("/setup-status")
async def setup_status():
"""Check if admin account exists. Always False after first boot."""
user_count = await get_local_provider().count_users()
return {"needs_setup": user_count == 0}
# ── OAuth Endpoints (Future/Placeholder) ─────────────────────────────────
@router.get("/oauth/{provider}")
async def oauth_login(provider: str):
"""Initiate OAuth login flow.
Redirects to the OAuth provider's authorization URL.
Currently a placeholder - requires OAuth provider implementation.
"""
if provider not in ["github", "google"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported OAuth provider: {provider}",
)
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="OAuth login not yet implemented",
)
@router.get("/callback/{provider}")
async def oauth_callback(provider: str, code: str, state: str):
"""OAuth callback endpoint.
Handles the OAuth provider's callback after user authorization.
Currently a placeholder.
"""
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="OAuth callback not yet implemented",
)
+132
View File
@@ -0,0 +1,132 @@
"""Feedback endpoints — create, list, stats, delete.
Allows users to submit thumbs-up/down feedback on runs,
optionally scoped to a specific message.
"""
from __future__ import annotations
import logging
from typing import Any
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from app.gateway.deps import get_current_user, get_feedback_repo, get_run_store
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["feedback"])
# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------
class FeedbackCreateRequest(BaseModel):
rating: int = Field(..., description="Feedback rating: +1 (positive) or -1 (negative)")
comment: str | None = Field(default=None, description="Optional text feedback")
message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message")
class FeedbackResponse(BaseModel):
feedback_id: str
run_id: str
thread_id: str
owner_id: str | None = None
message_id: str | None = None
rating: int
comment: str | None = None
created_at: str = ""
class FeedbackStatsResponse(BaseModel):
run_id: str
total: int = 0
positive: int = 0
negative: int = 0
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def create_feedback(
thread_id: str,
run_id: str,
body: FeedbackCreateRequest,
request: Request,
) -> dict[str, Any]:
"""Submit feedback (thumbs-up/down) for a run."""
if body.rating not in (1, -1):
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
user_id = await get_current_user(request)
# Validate run exists and belongs to thread
run_store = get_run_store(request)
run = await run_store.get(run_id)
if run is None:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
if run.get("thread_id") != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
feedback_repo = get_feedback_repo(request)
return await feedback_repo.create(
run_id=run_id,
thread_id=thread_id,
rating=body.rating,
owner_id=user_id,
message_id=body.message_id,
comment=body.comment,
)
@router.get("/{thread_id}/runs/{run_id}/feedback", response_model=list[FeedbackResponse])
@require_permission("threads", "read", owner_check=True)
async def list_feedback(
thread_id: str,
run_id: str,
request: Request,
) -> list[dict[str, Any]]:
"""List all feedback for a run."""
feedback_repo = get_feedback_repo(request)
return await feedback_repo.list_by_run(thread_id, run_id)
@router.get("/{thread_id}/runs/{run_id}/feedback/stats", response_model=FeedbackStatsResponse)
@require_permission("threads", "read", owner_check=True)
async def feedback_stats(
thread_id: str,
run_id: str,
request: Request,
) -> dict[str, Any]:
"""Get aggregated feedback stats (positive/negative counts) for a run."""
feedback_repo = get_feedback_repo(request)
return await feedback_repo.aggregate_by_run(thread_id, run_id)
@router.delete("/{thread_id}/runs/{run_id}/feedback/{feedback_id}")
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_feedback(
thread_id: str,
run_id: str,
feedback_id: str,
request: Request,
) -> dict[str, bool]:
"""Delete a feedback record."""
feedback_repo = get_feedback_repo(request)
# Verify feedback belongs to the specified thread/run before deleting
existing = await feedback_repo.get(feedback_id)
if existing is None:
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
if existing.get("thread_id") != thread_id or existing.get("run_id") != run_id:
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found in run {run_id}")
deleted = await feedback_repo.delete(feedback_id)
if not deleted:
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
return {"success": True}
+5 -22
View File
@@ -17,17 +17,10 @@ class ModelResponse(BaseModel):
supports_reasoning_effort: bool = Field(default=False, description="Whether model supports reasoning effort")
class TokenUsageResponse(BaseModel):
"""Token usage display configuration."""
enabled: bool = Field(default=False, description="Whether token usage display is enabled")
class ModelsListResponse(BaseModel):
"""Response model for listing all models."""
models: list[ModelResponse]
token_usage: TokenUsageResponse
@router.get(
@@ -43,7 +36,7 @@ async def list_models() -> ModelsListResponse:
excluding sensitive fields like API keys and internal configuration.
Returns:
A list of all configured models with their metadata and token usage display settings.
A list of all configured models with their metadata.
Example Response:
```json
@@ -51,24 +44,17 @@ async def list_models() -> ModelsListResponse:
"models": [
{
"name": "gpt-4",
"model": "gpt-4",
"display_name": "GPT-4",
"description": "OpenAI GPT-4 model",
"supports_thinking": false,
"supports_reasoning_effort": false
"supports_thinking": false
},
{
"name": "claude-3-opus",
"model": "claude-3-opus",
"display_name": "Claude 3 Opus",
"description": "Anthropic Claude 3 Opus model",
"supports_thinking": true,
"supports_reasoning_effort": false
"supports_thinking": true
}
],
"token_usage": {
"enabled": true
}
]
}
```
"""
@@ -84,10 +70,7 @@ async def list_models() -> ModelsListResponse:
)
for model in config.models
]
return ModelsListResponse(
models=models,
token_usage=TokenUsageResponse(enabled=config.token_usage.enabled),
)
return ModelsListResponse(models=models)
@router.get(
+12 -18
View File
@@ -1,4 +1,3 @@
import errno
import json
import logging
import shutil
@@ -202,23 +201,18 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
ensure_custom_skill_is_editable(skill_name)
skill_dir = get_custom_skill_dir(skill_name)
prev_content = read_custom_skill_content(skill_name)
try:
append_history(
skill_name,
{
"action": "human_delete",
"author": "human",
"thread_id": None,
"file_path": "SKILL.md",
"prev_content": prev_content,
"new_content": None,
"scanner": {"decision": "allow", "reason": "Deletion requested."},
},
)
except OSError as e:
if not isinstance(e, PermissionError) and e.errno not in {errno.EACCES, errno.EPERM, errno.EROFS}:
raise
logger.warning("Skipping delete history write for custom skill %s due to readonly/permission failure; continuing with skill directory removal: %s", skill_name, e)
append_history(
skill_name,
{
"action": "human_delete",
"author": "human",
"thread_id": None,
"file_path": "SKILL.md",
"prev_content": prev_content,
"new_content": None,
"scanner": {"decision": "allow", "reason": "Deletion requested."},
},
)
shutil.rmtree(skill_dir)
await refresh_skills_system_prompt_cache_async()
return {"success": True}
+8 -6
View File
@@ -1,10 +1,11 @@
import json
import logging
from fastapi import APIRouter
from fastapi import APIRouter, Request
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
@@ -98,12 +99,13 @@ def _format_conversation(messages: list[SuggestionMessage]) -> str:
summary="Generate Follow-up Questions",
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
)
async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> SuggestionsResponse:
if not request.messages:
@require_permission("threads", "read", owner_check=True)
async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request) -> SuggestionsResponse:
if not body.messages:
return SuggestionsResponse(suggestions=[])
n = request.n
conversation = _format_conversation(request.messages)
n = body.n
conversation = _format_conversation(body.messages)
if not conversation:
return SuggestionsResponse(suggestions=[])
@@ -120,7 +122,7 @@ async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> S
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
try:
model = create_chat_model(name=request.model_name, thinking_enabled=False)
model = create_chat_model(name=body.model_name, thinking_enabled=False)
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
raw = _extract_response_text(response.content)
suggestions = _parse_json_string_list(raw) or []
+62 -1
View File
@@ -19,7 +19,8 @@ from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel, Field
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.services import sse_consumer, start_run
from deerflow.runtime import RunRecord, serialize_channel_values
@@ -53,6 +54,7 @@ class RunCreateRequest(BaseModel):
after_seconds: float | None = Field(default=None, description="Delayed execution")
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
follow_up_to_run_id: str | None = Field(default=None, description="Run ID this message follows up on. Auto-detected from latest successful run if not provided.")
class RunResponse(BaseModel):
@@ -92,6 +94,7 @@ def _record_to_response(record: RunRecord) -> RunResponse:
@router.post("/{thread_id}/runs", response_model=RunResponse)
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse:
"""Create a background run (returns immediately)."""
record = await start_run(body, thread_id, request)
@@ -99,6 +102,7 @@ async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -
@router.post("/{thread_id}/runs/stream")
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
"""Create a run and stream events via SSE.
@@ -126,6 +130,7 @@ async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -
@router.post("/{thread_id}/runs/wait", response_model=dict)
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
"""Create a run and block until it completes, returning the final state."""
record = await start_run(body, thread_id, request)
@@ -151,6 +156,7 @@ async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) ->
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
@require_permission("runs", "read", owner_check=True)
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
"""List all runs for a thread."""
run_mgr = get_run_manager(request)
@@ -159,6 +165,7 @@ async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
@require_permission("runs", "read", owner_check=True)
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
"""Get details of a specific run."""
run_mgr = get_run_manager(request)
@@ -169,6 +176,7 @@ async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
@router.post("/{thread_id}/runs/{run_id}/cancel")
@require_permission("runs", "cancel", owner_check=True, require_existing=True)
async def cancel_run(
thread_id: str,
run_id: str,
@@ -206,6 +214,7 @@ async def cancel_run(
@router.get("/{thread_id}/runs/{run_id}/join")
@require_permission("runs", "read", owner_check=True)
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
"""Join an existing run's SSE stream."""
bridge = get_stream_bridge(request)
@@ -226,6 +235,7 @@ async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingRe
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
@require_permission("runs", "read", owner_check=True)
async def stream_existing_run(
thread_id: str,
run_id: str,
@@ -265,3 +275,54 @@ async def stream_existing_run(
"X-Accel-Buffering": "no",
},
)
# ---------------------------------------------------------------------------
# Messages / Events / Token usage endpoints
# ---------------------------------------------------------------------------
@router.get("/{thread_id}/messages")
@require_permission("runs", "read", owner_check=True)
async def list_thread_messages(
thread_id: str,
request: Request,
limit: int = Query(default=50, le=200),
before_seq: int | None = Query(default=None),
after_seq: int | None = Query(default=None),
) -> list[dict]:
"""Return displayable messages for a thread (across all runs)."""
event_store = get_run_event_store(request)
return await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq)
@router.get("/{thread_id}/runs/{run_id}/messages")
@require_permission("runs", "read", owner_check=True)
async def list_run_messages(thread_id: str, run_id: str, request: Request) -> list[dict]:
"""Return displayable messages for a specific run."""
event_store = get_run_event_store(request)
return await event_store.list_messages_by_run(thread_id, run_id)
@router.get("/{thread_id}/runs/{run_id}/events")
@require_permission("runs", "read", owner_check=True)
async def list_run_events(
thread_id: str,
run_id: str,
request: Request,
event_types: str | None = Query(default=None),
limit: int = Query(default=500, le=2000),
) -> list[dict]:
"""Return the full event stream for a run (debug/audit)."""
event_store = get_run_event_store(request)
types = event_types.split(",") if event_types else None
return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)
@router.get("/{thread_id}/token-usage")
@require_permission("threads", "read", owner_check=True)
async def thread_token_usage(thread_id: str, request: Request) -> dict:
"""Thread-level token usage aggregation."""
run_store = get_run_store(request)
agg = await run_store.aggregate_tokens_by_thread(thread_id)
return {"thread_id": thread_id, **agg}
+171 -237
View File
@@ -18,23 +18,34 @@ import uuid
from typing import Any
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from app.gateway.deps import get_checkpointer, get_store
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer
from app.gateway.utils import sanitize_log_param
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime import serialize_channel_values
# ---------------------------------------------------------------------------
# Store namespace
# ---------------------------------------------------------------------------
THREADS_NS: tuple[str, ...] = ("threads",)
"""Namespace used by the Store for thread metadata records."""
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["threads"])
# Metadata keys that the server controls; clients are not allowed to set
# them. Pydantic ``@field_validator("metadata")`` strips them on every
# inbound model below so a malicious client cannot reflect a forged
# owner identity through the API surface. Defense-in-depth — the
# row-level invariant is still ``threads_meta.owner_id`` populated from
# the auth contextvar; this list closes the metadata-blob echo gap.
_SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"})
def _strip_reserved_metadata(metadata: dict[str, Any] | None) -> dict[str, Any]:
"""Return ``metadata`` with server-controlled keys removed."""
if not metadata:
return metadata or {}
return {k: v for k, v in metadata.items() if k not in _SERVER_RESERVED_METADATA_KEYS}
# ---------------------------------------------------------------------------
# Response / request models
# ---------------------------------------------------------------------------
@@ -63,8 +74,11 @@ class ThreadCreateRequest(BaseModel):
"""Request body for creating a thread."""
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
assistant_id: str | None = Field(default=None, description="Associate thread with an assistant")
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
_strip_reserved = field_validator("metadata")(classmethod(lambda cls, v: _strip_reserved_metadata(v)))
class ThreadSearchRequest(BaseModel):
"""Request body for searching threads."""
@@ -93,6 +107,8 @@ class ThreadPatchRequest(BaseModel):
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata to merge")
_strip_reserved = field_validator("metadata")(classmethod(lambda cls, v: _strip_reserved_metadata(v)))
class ThreadStateUpdateRequest(BaseModel):
"""Request body for updating thread state (human-in-the-loop resume)."""
@@ -135,61 +151,16 @@ def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDel
raise HTTPException(status_code=422, detail=str(exc)) from exc
except FileNotFoundError:
# Not critical — thread data may not exist on disk
logger.debug("No local thread data to delete for %s", thread_id)
logger.debug("No local thread data to delete for %s", sanitize_log_param(thread_id))
return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
except Exception as exc:
logger.exception("Failed to delete thread data for %s", thread_id)
logger.exception("Failed to delete thread data for %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc
logger.info("Deleted local thread data for %s", thread_id)
logger.info("Deleted local thread data for %s", sanitize_log_param(thread_id))
return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
async def _store_get(store, thread_id: str) -> dict | None:
"""Fetch a thread record from the Store; returns ``None`` if absent."""
item = await store.aget(THREADS_NS, thread_id)
return item.value if item is not None else None
async def _store_put(store, record: dict) -> None:
"""Write a thread record to the Store."""
await store.aput(THREADS_NS, record["thread_id"], record)
async def _store_upsert(store, thread_id: str, *, metadata: dict | None = None, values: dict | None = None) -> None:
"""Create or refresh a thread record in the Store.
On creation the record is written with ``status="idle"``. On update only
``updated_at`` (and optionally ``metadata`` / ``values``) are changed so
that existing fields are preserved.
``values`` carries the agent-state snapshot exposed to the frontend
(currently just ``{"title": "..."}``).
"""
now = time.time()
existing = await _store_get(store, thread_id)
if existing is None:
await _store_put(
store,
{
"thread_id": thread_id,
"status": "idle",
"created_at": now,
"updated_at": now,
"metadata": metadata or {},
"values": values or {},
},
)
else:
val = dict(existing)
val["updated_at"] = now
if metadata:
val.setdefault("metadata", {}).update(metadata)
if values:
val.setdefault("values", {}).update(values)
await _store_put(store, val)
def _derive_thread_status(checkpoint_tuple) -> str:
"""Derive thread status from checkpoint metadata."""
if checkpoint_tuple is None:
@@ -215,23 +186,19 @@ def _derive_thread_status(checkpoint_tuple) -> str:
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse:
"""Delete local persisted filesystem data for a thread.
Cleans DeerFlow-managed thread directories, removes checkpoint data,
and removes the thread record from the Store.
and removes the thread_meta row from the configured ThreadMetaStore
(sqlite or memory).
"""
from app.gateway.deps import get_thread_meta_repo
# Clean local filesystem
response = _delete_thread_data(thread_id)
# Remove from Store (best-effort)
store = get_store(request)
if store is not None:
try:
await store.adelete(THREADS_NS, thread_id)
except Exception:
logger.debug("Could not delete store record for thread %s (not critical)", thread_id)
# Remove checkpoints (best-effort)
checkpointer = getattr(request.app.state, "checkpointer", None)
if checkpointer is not None:
@@ -239,7 +206,15 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
if hasattr(checkpointer, "adelete_thread"):
await checkpointer.adelete_thread(thread_id)
except Exception:
logger.debug("Could not delete checkpoints for thread %s (not critical)", thread_id)
logger.debug("Could not delete checkpoints for thread %s (not critical)", sanitize_log_param(thread_id))
# Remove thread_meta row (best-effort) — required for sqlite backend
# so the deleted thread no longer appears in /threads/search.
try:
thread_meta_repo = get_thread_meta_repo(request)
await thread_meta_repo.delete(thread_id)
except Exception:
logger.debug("Could not delete thread_meta for %s (not critical)", sanitize_log_param(thread_id))
return response
@@ -248,43 +223,40 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse:
"""Create a new thread.
The thread record is written to the Store (for fast listing) and an
empty checkpoint is written to the checkpointer (for state reads).
Writes a thread_meta record (so the thread appears in /threads/search)
and an empty checkpoint (so state endpoints work immediately).
Idempotent: returns the existing record when ``thread_id`` already exists.
"""
store = get_store(request)
from app.gateway.deps import get_thread_meta_repo
checkpointer = get_checkpointer(request)
thread_meta_repo = get_thread_meta_repo(request)
thread_id = body.thread_id or str(uuid.uuid4())
now = time.time()
# ``body.metadata`` is already stripped of server-reserved keys by
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
# Idempotency: return existing record from Store when already present
if store is not None:
existing_record = await _store_get(store, thread_id)
if existing_record is not None:
return ThreadResponse(
thread_id=thread_id,
status=existing_record.get("status", "idle"),
created_at=str(existing_record.get("created_at", "")),
updated_at=str(existing_record.get("updated_at", "")),
metadata=existing_record.get("metadata", {}),
)
# Idempotency: return existing record when already present
existing_record = await thread_meta_repo.get(thread_id)
if existing_record is not None:
return ThreadResponse(
thread_id=thread_id,
status=existing_record.get("status", "idle"),
created_at=str(existing_record.get("created_at", "")),
updated_at=str(existing_record.get("updated_at", "")),
metadata=existing_record.get("metadata", {}),
)
# Write thread record to Store
if store is not None:
try:
await _store_put(
store,
{
"thread_id": thread_id,
"status": "idle",
"created_at": now,
"updated_at": now,
"metadata": body.metadata,
},
)
except Exception:
logger.exception("Failed to write thread %s to store", thread_id)
raise HTTPException(status_code=500, detail="Failed to create thread")
# Write thread_meta so the thread appears in /threads/search immediately
try:
await thread_meta_repo.create(
thread_id,
assistant_id=getattr(body, "assistant_id", None),
metadata=body.metadata,
)
except Exception:
logger.exception("Failed to write thread_meta for %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to create thread")
# Write an empty checkpoint so state endpoints work immediately
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
@@ -301,10 +273,10 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
}
await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {})
except Exception:
logger.exception("Failed to create checkpoint for thread %s", thread_id)
logger.exception("Failed to create checkpoint for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to create thread")
logger.info("Thread created: %s", thread_id)
logger.info("Thread created: %s", sanitize_log_param(thread_id))
return ThreadResponse(
thread_id=thread_id,
status="idle",
@@ -318,166 +290,91 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]:
"""Search and list threads.
Two-phase approach:
**Phase 1 — Store (fast path, O(threads))**: returns threads that were
created or run through this Gateway. Store records are tiny metadata
dicts so fetching all of them at once is cheap.
**Phase 2 — Checkpointer supplement (lazy migration)**: threads that
were created directly by LangGraph Server (and therefore absent from the
Store) are discovered here by iterating the shared checkpointer. Any
newly found thread is immediately written to the Store so that the next
search skips Phase 2 for that thread — the Store converges to a full
index over time without a one-shot migration job.
Delegates to the configured ThreadMetaStore implementation
(SQL-backed for sqlite/postgres, Store-backed for memory mode).
"""
store = get_store(request)
checkpointer = get_checkpointer(request)
from app.gateway.deps import get_thread_meta_repo
# -----------------------------------------------------------------------
# Phase 1: Store
# -----------------------------------------------------------------------
merged: dict[str, ThreadResponse] = {}
if store is not None:
try:
items = await store.asearch(THREADS_NS, limit=10_000)
except Exception:
logger.warning("Store search failed — falling back to checkpointer only", exc_info=True)
items = []
for item in items:
val = item.value
merged[val["thread_id"]] = ThreadResponse(
thread_id=val["thread_id"],
status=val.get("status", "idle"),
created_at=str(val.get("created_at", "")),
updated_at=str(val.get("updated_at", "")),
metadata=val.get("metadata", {}),
values=val.get("values", {}),
)
# -----------------------------------------------------------------------
# Phase 2: Checkpointer supplement
# Discovers threads not yet in the Store (e.g. created by LangGraph
# Server) and lazily migrates them so future searches skip this phase.
# -----------------------------------------------------------------------
try:
async for checkpoint_tuple in checkpointer.alist(None):
cfg = getattr(checkpoint_tuple, "config", {})
thread_id = cfg.get("configurable", {}).get("thread_id")
if not thread_id or thread_id in merged:
continue
# Skip sub-graph checkpoints (checkpoint_ns is non-empty for those)
if cfg.get("configurable", {}).get("checkpoint_ns", ""):
continue
ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
# Strip LangGraph internal keys from the user-visible metadata dict
user_meta = {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}
# Extract state values (title) from the checkpoint's channel_values
checkpoint_data = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint_data.get("channel_values", {})
ckpt_values = {}
if title := channel_values.get("title"):
ckpt_values["title"] = title
thread_resp = ThreadResponse(
thread_id=thread_id,
status=_derive_thread_status(checkpoint_tuple),
created_at=str(ckpt_meta.get("created_at", "")),
updated_at=str(ckpt_meta.get("updated_at", ckpt_meta.get("created_at", ""))),
metadata=user_meta,
values=ckpt_values,
)
merged[thread_id] = thread_resp
# Lazy migration — write to Store so the next search finds it there
if store is not None:
try:
await _store_upsert(store, thread_id, metadata=user_meta, values=ckpt_values or None)
except Exception:
logger.debug("Failed to migrate thread %s to store (non-fatal)", thread_id)
except Exception:
logger.exception("Checkpointer scan failed during thread search")
# Don't raise — return whatever was collected from Store + partial scan
# -----------------------------------------------------------------------
# Phase 3: Filter → sort → paginate
# -----------------------------------------------------------------------
results = list(merged.values())
if body.metadata:
results = [r for r in results if all(r.metadata.get(k) == v for k, v in body.metadata.items())]
if body.status:
results = [r for r in results if r.status == body.status]
results.sort(key=lambda r: r.updated_at, reverse=True)
return results[body.offset : body.offset + body.limit]
repo = get_thread_meta_repo(request)
rows = await repo.search(
metadata=body.metadata or None,
status=body.status,
limit=body.limit,
offset=body.offset,
)
return [
ThreadResponse(
thread_id=r["thread_id"],
status=r.get("status", "idle"),
created_at=r.get("created_at", ""),
updated_at=r.get("updated_at", ""),
metadata=r.get("metadata", {}),
values={"title": r["display_name"]} if r.get("display_name") else {},
interrupts={},
)
for r in rows
]
@router.patch("/{thread_id}", response_model=ThreadResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
"""Merge metadata into a thread record."""
store = get_store(request)
if store is None:
raise HTTPException(status_code=503, detail="Store not available")
from app.gateway.deps import get_thread_meta_repo
record = await _store_get(store, thread_id)
thread_meta_repo = get_thread_meta_repo(request)
record = await thread_meta_repo.get(thread_id)
if record is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
now = time.time()
updated = dict(record)
updated.setdefault("metadata", {}).update(body.metadata)
updated["updated_at"] = now
# ``body.metadata`` already stripped by ``ThreadPatchRequest._strip_reserved``.
try:
await _store_put(store, updated)
await thread_meta_repo.update_metadata(thread_id, body.metadata)
except Exception:
logger.exception("Failed to patch thread %s", thread_id)
logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to update thread")
# Re-read to get the merged metadata + refreshed updated_at
record = await thread_meta_repo.get(thread_id) or record
return ThreadResponse(
thread_id=thread_id,
status=updated.get("status", "idle"),
created_at=str(updated.get("created_at", "")),
updated_at=str(now),
metadata=updated.get("metadata", {}),
status=record.get("status", "idle"),
created_at=str(record.get("created_at", "")),
updated_at=str(record.get("updated_at", "")),
metadata=record.get("metadata", {}),
)
@router.get("/{thread_id}", response_model=ThreadResponse)
@require_permission("threads", "read", owner_check=True)
async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
"""Get thread info.
Reads metadata from the Store and derives the accurate execution
status from the checkpointer. Falls back to the checkpointer alone
for threads that pre-date Store adoption (backward compat).
Reads metadata from the ThreadMetaStore and derives the accurate
execution status from the checkpointer. Falls back to the checkpointer
alone for threads that pre-date ThreadMetaStore adoption (backward compat).
"""
store = get_store(request)
from app.gateway.deps import get_thread_meta_repo
thread_meta_repo = get_thread_meta_repo(request)
checkpointer = get_checkpointer(request)
record: dict | None = None
if store is not None:
record = await _store_get(store, thread_id)
record: dict | None = await thread_meta_repo.get(thread_id)
# Derive accurate status from the checkpointer
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
except Exception:
logger.exception("Failed to get checkpoint for thread %s", thread_id)
logger.exception("Failed to get checkpoint for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread")
if record is None and checkpoint_tuple is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
# If the thread exists in the checkpointer but not the store (e.g. legacy
# data), synthesize a minimal store record from the checkpoint metadata.
# If the thread exists in the checkpointer but not in thread_meta (e.g.
# legacy data created before thread_meta adoption), synthesize a minimal
# record from the checkpoint metadata.
if record is None and checkpoint_tuple is not None:
ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
record = {
@@ -506,6 +403,7 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
@require_permission("threads", "read", owner_check=True)
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
"""Get the latest state snapshot for a thread.
@@ -518,7 +416,7 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
except Exception:
logger.exception("Failed to get state for thread %s", thread_id)
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread state")
if checkpoint_tuple is None:
@@ -555,15 +453,19 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
"""Update thread state (e.g. for human-in-the-loop resume or title rename).
Writes a new checkpoint that merges *body.values* into the latest
channel values, then syncs any updated ``title`` field back to the Store
so that ``/threads/search`` reflects the change immediately.
channel values, then syncs any updated ``title`` field through the
ThreadMetaStore abstraction so that ``/threads/search`` reflects the
change immediately in both sqlite and memory backends.
"""
from app.gateway.deps import get_thread_meta_repo
checkpointer = get_checkpointer(request)
store = get_store(request)
thread_meta_repo = get_thread_meta_repo(request)
# checkpoint_ns must be present in the config for aput — default to ""
# (the root graph namespace). checkpoint_id is optional; omitting it
@@ -580,7 +482,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
try:
checkpoint_tuple = await checkpointer.aget_tuple(read_config)
except Exception:
logger.exception("Failed to get state for thread %s", thread_id)
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread state")
if checkpoint_tuple is None:
@@ -614,19 +516,22 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
try:
new_config = await checkpointer.aput(write_config, checkpoint, metadata, {})
except Exception:
logger.exception("Failed to update state for thread %s", thread_id)
logger.exception("Failed to update state for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to update thread state")
new_checkpoint_id: str | None = None
if isinstance(new_config, dict):
new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id")
# Sync title changes to the Store so /threads/search reflects them immediately.
if store is not None and body.values and "title" in body.values:
try:
await _store_upsert(store, thread_id, values={"title": body.values["title"]})
except Exception:
logger.debug("Failed to sync title to store for thread %s (non-fatal)", thread_id)
# Sync title changes through the ThreadMetaStore abstraction so /threads/search
# reflects them immediately in both sqlite and memory backends.
if body.values and "title" in body.values:
new_title = body.values["title"]
if new_title: # Skip empty strings and None
try:
await thread_meta_repo.update_display_name(thread_id, new_title)
except Exception:
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
return ThreadStateResponse(
values=serialize_channel_values(channel_values),
@@ -638,8 +543,16 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
@require_permission("threads", "read", owner_check=True)
async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
"""Get checkpoint history for a thread."""
"""Get checkpoint history for a thread.
Messages are read from the checkpointer's channel values (the
authoritative source) and serialized via
:func:`~deerflow.runtime.serialization.serialize_channel_values`.
Only the latest (first) checkpoint carries the ``messages`` key to
avoid duplicating them across every entry.
"""
checkpointer = get_checkpointer(request)
config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}
@@ -647,6 +560,7 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
config["configurable"]["checkpoint_id"] = body.before
entries: list[HistoryEntry] = []
is_latest_checkpoint = True
try:
async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
ckpt_config = getattr(checkpoint_tuple, "config", {})
@@ -661,22 +575,42 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
channel_values = checkpoint.get("channel_values", {})
# Build values from checkpoint channel_values
values: dict[str, Any] = {}
if title := channel_values.get("title"):
values["title"] = title
if thread_data := channel_values.get("thread_data"):
values["thread_data"] = thread_data
# Attach messages from checkpointer only for the latest checkpoint
if is_latest_checkpoint:
messages = channel_values.get("messages")
if messages:
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
is_latest_checkpoint = False
# Derive next tasks
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
# Strip LangGraph internal keys from metadata
user_meta = {k: v for k, v in metadata.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}
# Keep step for ordering context
if "step" in metadata:
user_meta["step"] = metadata["step"]
entries.append(
HistoryEntry(
checkpoint_id=checkpoint_id,
parent_checkpoint_id=parent_id,
metadata=metadata,
values=serialize_channel_values(channel_values),
metadata=user_meta,
values=values,
created_at=str(metadata.get("created_at", "")),
next=next_tasks,
)
)
except Exception:
logger.exception("Failed to get history for thread %s", thread_id)
logger.exception("Failed to get history for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread history")
return entries
+14 -42
View File
@@ -4,12 +4,12 @@ import logging
import os
import stat
from fastapi import APIRouter, File, HTTPException, UploadFile
from fastapi import APIRouter, File, HTTPException, Request, UploadFile
from pydantic import BaseModel
from deerflow.config.app_config import get_app_config
from app.gateway.authz import require_permission
from deerflow.config.paths import get_paths
from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
from deerflow.uploads.manager import (
PathTraversalError,
delete_file_safe,
@@ -54,37 +54,11 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
os.chmod(file_path, writable_mode, **chmod_kwargs)
def _uses_thread_data_mounts(sandbox_provider: SandboxProvider) -> bool:
return bool(getattr(sandbox_provider, "uses_thread_data_mounts", False))
def _get_uploads_config_value(key: str, default: object) -> object:
"""Read a value from the uploads config, supporting dict and attribute access."""
cfg = get_app_config()
uploads_cfg = getattr(cfg, "uploads", None)
if isinstance(uploads_cfg, dict):
return uploads_cfg.get(key, default)
return getattr(uploads_cfg, key, default)
def _auto_convert_documents_enabled() -> bool:
"""Return whether automatic host-side document conversion is enabled.
The secure default is disabled unless an operator explicitly opts in via
uploads.auto_convert_documents in config.yaml.
"""
try:
raw = _get_uploads_config_value("auto_convert_documents", False)
if isinstance(raw, str):
return raw.strip().lower() in {"1", "true", "yes", "on"}
return bool(raw)
except Exception:
return False
@router.post("", response_model=UploadResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def upload_files(
thread_id: str,
request: Request,
files: list[UploadFile] = File(...),
) -> UploadResponse:
"""Upload multiple files to a thread's uploads directory."""
@@ -99,12 +73,8 @@ async def upload_files(
uploaded_files = []
sandbox_provider = get_sandbox_provider()
sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
sandbox = None
if sync_to_sandbox:
sandbox_id = sandbox_provider.acquire(thread_id)
sandbox = sandbox_provider.get(sandbox_id)
auto_convert_documents = _auto_convert_documents_enabled()
sandbox_id = sandbox_provider.acquire(thread_id)
sandbox = sandbox_provider.get(sandbox_id)
for file in files:
if not file.filename:
@@ -123,7 +93,7 @@ async def upload_files(
virtual_path = upload_virtual_path(safe_filename)
if sync_to_sandbox and sandbox is not None:
if sandbox_id != "local":
_make_file_sandbox_writable(file_path)
sandbox.update_file(virtual_path, content)
@@ -138,12 +108,12 @@ async def upload_files(
logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
file_ext = file_path.suffix.lower()
if auto_convert_documents and file_ext in CONVERTIBLE_EXTENSIONS:
if file_ext in CONVERTIBLE_EXTENSIONS:
md_path = await convert_file_to_markdown(file_path)
if md_path:
md_virtual_path = upload_virtual_path(md_path.name)
if sync_to_sandbox and sandbox is not None:
if sandbox_id != "local":
_make_file_sandbox_writable(md_path)
sandbox.update_file(md_virtual_path, md_path.read_bytes())
@@ -166,7 +136,8 @@ async def upload_files(
@router.get("/list", response_model=dict)
async def list_uploaded_files(thread_id: str) -> dict:
@require_permission("threads", "read", owner_check=True)
async def list_uploaded_files(thread_id: str, request: Request) -> dict:
"""List all files in a thread's uploads directory."""
try:
uploads_dir = get_uploads_dir(thread_id)
@@ -184,7 +155,8 @@ async def list_uploaded_files(thread_id: str) -> dict:
@router.delete("/{filename}")
async def delete_uploaded_file(thread_id: str, filename: str) -> dict:
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_uploaded_file(thread_id: str, filename: str, request: Request) -> dict:
"""Delete a file from a thread's uploads directory."""
try:
uploads_dir = get_uploads_dir(thread_id)
+39 -83
View File
@@ -8,16 +8,17 @@ frames, and consuming stream bridge events. Router modules
from __future__ import annotations
import asyncio
import dataclasses
import json
import logging
import re
import time
from typing import Any
from fastapi import HTTPException, Request
from langchain_core.messages import HumanMessage
from app.gateway.deps import get_checkpointer, get_run_manager, get_store, get_stream_bridge
from app.gateway.deps import get_run_context, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.utils import sanitize_log_param
from deerflow.runtime import (
END_SENTINEL,
HEARTBEAT_SENTINEL,
@@ -171,71 +172,6 @@ def build_run_config(
# ---------------------------------------------------------------------------
async def _upsert_thread_in_store(store, thread_id: str, metadata: dict | None) -> None:
"""Create or refresh the thread record in the Store.
Called from :func:`start_run` so that threads created via the stateless
``/runs/stream`` endpoint (which never calls ``POST /threads``) still
appear in ``/threads/search`` results.
"""
# Deferred import to avoid circular import with the threads router module.
from app.gateway.routers.threads import _store_upsert
try:
await _store_upsert(store, thread_id, metadata=metadata)
except Exception:
logger.warning("Failed to upsert thread %s in store (non-fatal)", thread_id)
async def _sync_thread_title_after_run(
run_task: asyncio.Task,
thread_id: str,
checkpointer: Any,
store: Any,
) -> None:
"""Wait for *run_task* to finish, then persist the generated title to the Store.
TitleMiddleware writes the generated title to the LangGraph agent state
(checkpointer) but the Gateway's Store record is not updated automatically.
This coroutine closes that gap by reading the final checkpoint after the
run completes and syncing ``values.title`` into the Store record so that
subsequent ``/threads/search`` responses include the correct title.
Runs as a fire-and-forget :func:`asyncio.create_task`; failures are
logged at DEBUG level and never propagate.
"""
# Wait for the background run task to complete (any outcome).
# asyncio.wait does not propagate task exceptions — it just returns
# when the task is done, cancelled, or failed.
await asyncio.wait({run_task})
# Deferred import to avoid circular import with the threads router module.
from app.gateway.routers.threads import _store_get, _store_put
try:
ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
if ckpt_tuple is None:
return
channel_values = ckpt_tuple.checkpoint.get("channel_values", {})
title = channel_values.get("title")
if not title:
return
existing = await _store_get(store, thread_id)
if existing is None:
return
updated = dict(existing)
updated.setdefault("values", {})["title"] = title
updated["updated_at"] = time.time()
await _store_put(store, updated)
logger.debug("Synced title %r for thread %s", title, thread_id)
except Exception:
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id, exc_info=True)
async def start_run(
body: Any,
thread_id: str,
@@ -255,11 +191,25 @@ async def start_run(
"""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
checkpointer = get_checkpointer(request)
store = get_store(request)
run_ctx = get_run_context(request)
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
# Resolve follow_up_to_run_id: explicit from request, or auto-detect from latest successful run
follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None)
if follow_up_to_run_id is None:
run_store = get_run_store(request)
try:
recent_runs = await run_store.list_by_thread(thread_id, limit=1)
if recent_runs and recent_runs[0].get("status") == "success":
follow_up_to_run_id = recent_runs[0]["run_id"]
except Exception:
pass # Don't block run creation
# Enrich base context with per-run field
if follow_up_to_run_id:
run_ctx = dataclasses.replace(run_ctx, follow_up_to_run_id=follow_up_to_run_id)
try:
record = await run_mgr.create_or_reject(
thread_id,
@@ -268,17 +218,28 @@ async def start_run(
metadata=body.metadata or {},
kwargs={"input": body.input, "config": body.config},
multitask_strategy=body.multitask_strategy,
follow_up_to_run_id=follow_up_to_run_id,
)
except ConflictError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc
except UnsupportedStrategyError as exc:
raise HTTPException(status_code=501, detail=str(exc)) from exc
# Ensure the thread is visible in /threads/search, even for threads that
# were never explicitly created via POST /threads (e.g. stateless runs).
store = get_store(request)
if store is not None:
await _upsert_thread_in_store(store, thread_id, body.metadata)
# Upsert thread metadata so the thread appears in /threads/search,
# even for threads that were never explicitly created via POST /threads
# (e.g. stateless runs).
try:
existing = await run_ctx.thread_meta_repo.get(thread_id)
if existing is None:
await run_ctx.thread_meta_repo.create(
thread_id,
assistant_id=body.assistant_id,
metadata=body.metadata,
)
else:
await run_ctx.thread_meta_repo.update_status(thread_id, "running")
except Exception:
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
agent_factory = resolve_agent_factory(body.assistant_id)
graph_input = normalize_input(body.input)
@@ -298,8 +259,6 @@ async def start_run(
"is_plan_mode",
"subagent_enabled",
"max_concurrent_subagents",
"agent_name",
"is_bootstrap",
}
configurable = config.setdefault("configurable", {})
for key in _CONTEXT_CONFIGURABLE_KEYS:
@@ -313,8 +272,7 @@ async def start_run(
bridge,
run_mgr,
record,
checkpointer=checkpointer,
store=store,
ctx=run_ctx,
agent_factory=agent_factory,
graph_input=graph_input,
config=config,
@@ -326,11 +284,9 @@ async def start_run(
)
record.task = task
# After the run completes, sync the title generated by TitleMiddleware from
# the checkpointer into the Store record so that /threads/search returns the
# correct title instead of an empty values dict.
if store is not None:
asyncio.create_task(_sync_thread_title_after_run(task, thread_id, checkpointer, store))
# Title sync is handled by worker.py's finally block which reads the
# title from the checkpoint and calls thread_meta_repo.update_display_name
# after the run completes.
return record
+6
View File
@@ -0,0 +1,6 @@
"""Shared utility helpers for the Gateway layer."""
def sanitize_log_param(value: str) -> str:
"""Strip control characters to prevent log injection."""
return value.replace("\n", "").replace("\r", "").replace("\x00", "")
+77
View File
@@ -0,0 +1,77 @@
# Docker Test Gap (Section 七 7.4)
This file documents the only **un-executed** test cases from
`backend/docs/AUTH_TEST_PLAN.md` after the full release validation pass.
## Why this gap exists
The release validation environment (sg_dev: `10.251.229.92`) **does not have
a Docker daemon installed**. The TC-DOCKER cases are container-runtime
behavior tests that need an actual Docker engine to spin up
`docker/docker-compose.yaml` services.
```bash
$ ssh sg_dev "which docker; docker --version"
# (empty)
# bash: docker: command not found
```
All other test plan sections were executed against either:
- The local dev box (Mac, all services running locally), or
- The deployed sg_dev instance (gateway + frontend + nginx via SSH tunnel)
## Cases not executed
| Case | Title | What it covers | Why not run |
|---|---|---|---|
| TC-DOCKER-01 | `users.db` volume persistence | Verify the `DEER_FLOW_HOME` bind mount survives container restart | needs `docker compose up` |
| TC-DOCKER-02 | Session persistence across container restart | `AUTH_JWT_SECRET` env var keeps cookies valid after `docker compose down && up` | needs `docker compose down/up` |
| TC-DOCKER-03 | Per-worker rate limiter divergence | Confirms in-process `_login_attempts` dict doesn't share state across `gunicorn` workers (4 by default in the compose file); known limitation, documented | needs multi-worker container |
| TC-DOCKER-04 | IM channels skip AuthMiddleware | Verify Feishu/Slack/Telegram dispatchers run in-container against `http://langgraph:2024` without going through nginx | needs `docker logs` |
| TC-DOCKER-05 | Admin credentials surfacing | **Updated post-simplify** — was "log scrape", now "0600 credential file in `DEER_FLOW_HOME`". The file-based behavior is already validated by TC-1.1 + TC-UPG-13 on sg_dev (non-Docker), so the only Docker-specific gap is verifying the volume mount carries the file out to the host | needs container + host volume |
| TC-DOCKER-06 | Gateway-mode Docker deploy | `./scripts/deploy.sh --gateway` produces a 3-container topology (no `langgraph` container); same auth flow as standard mode | needs `docker compose --profile gateway` |
## Coverage already provided by non-Docker tests
The **auth-relevant** behavior in each Docker case is already exercised by
the test cases that ran on sg_dev or local:
| Docker case | Auth behavior covered by |
|---|---|
| TC-DOCKER-01 (volume persistence) | TC-REENT-01 on sg_dev (admin row survives gateway restart) — same SQLite file, just no container layer between |
| TC-DOCKER-02 (session persistence) | TC-API-02/03/06 (cookie roundtrip), plus TC-REENT-04 (multi-cookie) — JWT verification is process-state-free, container restart is equivalent to `pkill uvicorn && uv run uvicorn` |
| TC-DOCKER-03 (per-worker rate limit) | TC-GW-04 + TC-REENT-09 (single-worker rate limit + 5min expiry). The cross-worker divergence is an architectural property of the in-memory dict; no auth code path differs |
| TC-DOCKER-04 (IM channels skip auth) | Code-level only: `app/channels/manager.py` uses `langgraph_sdk` directly with no cookie handling. The langgraph_auth handler is bypassed by going through SDK, not HTTP |
| TC-DOCKER-05 (credential surfacing) | TC-1.1 on sg_dev (file at `~/deer-flow/backend/.deer-flow/admin_initial_credentials.txt`, mode 0600, password 22 chars) — the only Docker-unique step is whether the bind mount projects this path onto the host, which is a `docker compose` config check, not a runtime behavior change |
| TC-DOCKER-06 (gateway-mode container) | Section 七 7.2 covered by TC-GW-01..05 + Section 二 (gateway-mode auth flow on sg_dev) — same Gateway code, container is just a packaging change |
## Reproduction steps when Docker becomes available
Anyone with `docker` + `docker compose` installed can reproduce the gap by
running the test plan section verbatim. Pre-flight:
```bash
# Required on the host
docker --version # >=24.x
docker compose version # plugin >=2.x
# Required env var (otherwise sessions reset on every container restart)
echo "AUTH_JWT_SECRET=$(python3 -c 'import secrets; print(secrets.token_urlsafe(32))')" \
>> .env
# Optional: pin DEER_FLOW_HOME to a stable host path
echo "DEER_FLOW_HOME=$HOME/deer-flow-data" >> .env
```
Then run TC-DOCKER-01..06 from the test plan as written.
## Decision log
- **Not blocking the release.** The auth-relevant behavior in every Docker
case has an already-validated equivalent on bare metal. The gap is purely
about *container packaging* details (bind mounts, multi-worker, log
collection), not about whether the auth code paths work.
- **TC-DOCKER-05 was updated in place** in `AUTH_TEST_PLAN.md` to reflect
the post-simplify reality (credentials file → 0600 file, no log leak).
The old "grep 'Password:' in docker logs" expectation would have failed
silently and given a false sense of coverage.
File diff suppressed because it is too large Load Diff
+129
View File
@@ -0,0 +1,129 @@
# Authentication Upgrade Guide
DeerFlow 内置了认证模块。本文档面向从无认证版本升级的用户。
## 核心概念
认证模块采用**始终强制**策略:
- 首次启动时自动创建 admin 账号,随机密码打印到控制台日志
- 认证从一开始就是强制的,无竞争窗口
- 历史对话(升级前创建的 thread)自动迁移到 admin 名下
## 升级步骤
### 1. 更新代码
```bash
git pull origin main
cd backend && make install
```
### 2. 首次启动
```bash
make dev
```
控制台会输出:
```
============================================================
Admin account created on first boot
Email: admin@deerflow.dev
Password: aB3xK9mN_pQ7rT2w
Change it after login: Settings → Account
============================================================
```
如果未登录就重启了服务,不用担心——只要 setup 未完成,每次启动都会重置密码并重新打印到控制台。
### 3. 登录
访问 `http://localhost:2026/login`,使用控制台输出的邮箱和密码登录。
### 4. 修改密码
登录后进入 Settings → Account → Change Password。
### 5. 添加用户(可选)
其他用户通过 `/login` 页面注册,自动获得 **user** 角色。每个用户只能看到自己的对话。
## 安全机制
| 机制 | 说明 |
|------|------|
| JWT HttpOnly Cookie | Token 不暴露给 JavaScript,防止 XSS 窃取 |
| CSRF Double Submit Cookie | 所有 POST/PUT/DELETE 请求需携带 `X-CSRF-Token` |
| bcrypt 密码哈希 | 密码不以明文存储 |
| 多租户隔离 | 用户只能访问自己的 thread |
| HTTPS 自适应 | 检测 `x-forwarded-proto`,自动设置 `Secure` cookie 标志 |
## 常见操作
### 忘记密码
```bash
cd backend
# 重置 admin 密码
python -m app.gateway.auth.reset_admin
# 重置指定用户密码
python -m app.gateway.auth.reset_admin --email user@example.com
```
会输出新的随机密码。
### 完全重置
删除用户数据库,重启后自动创建新 admin:
```bash
rm -f backend/.deer-flow/users.db
# 重启服务,控制台输出新密码
```
## 数据存储
| 文件 | 内容 |
|------|------|
| `.deer-flow/users.db` | SQLite 用户数据库(密码哈希、角色) |
| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成临时密钥,重启后 session 失效) |
### 生产环境建议
```bash
# 生成持久化 JWT 密钥,避免重启后所有用户需重新登录
python -c "import secrets; print(secrets.token_urlsafe(32))"
# 将输出添加到 .env
# AUTH_JWT_SECRET=<生成的密钥>
```
## API 端点
| 端点 | 方法 | 说明 |
|------|------|------|
| `/api/v1/auth/login/local` | POST | 邮箱密码登录(OAuth2 form |
| `/api/v1/auth/register` | POST | 注册新用户(user 角色) |
| `/api/v1/auth/logout` | POST | 登出(清除 cookie |
| `/api/v1/auth/me` | GET | 获取当前用户信息 |
| `/api/v1/auth/change-password` | POST | 修改密码 |
| `/api/v1/auth/setup-status` | GET | 检查 admin 是否存在 |
## 兼容性
- **标准模式**`make dev`):完全兼容,admin 自动创建
- **Gateway 模式**`make dev-pro`):完全兼容
- **Docker 部署**:完全兼容,`.deer-flow/users.db` 需持久化卷挂载
- **IM 渠道**Feishu/Slack/Telegram):通过 LangGraph SDK 通信,不经过认证层
- **DeerFlowClient**(嵌入式):不经过 HTTP,不受认证影响
## 故障排查
| 症状 | 原因 | 解决 |
|------|------|------|
| 启动后没看到密码 | admin 已存在(非首次启动) | 用 `reset_admin` 重置,或删 `users.db` |
| 登录后 POST 返回 403 | CSRF token 缺失 | 确认前端已更新 |
| 重启后需要重新登录 | `AUTH_JWT_SECRET` 未持久化 | 在 `.env` 中设置固定密钥 |
+3 -6
View File
@@ -2,12 +2,12 @@
## 概述
DeerFlow 后端提供了完整的文件上传功能,支持多文件上传,并可选地将 Office 文档和 PDF 转换为 Markdown 格式。
DeerFlow 后端提供了完整的文件上传功能,支持多文件上传,并自动将 Office 文档和 PDF 转换为 Markdown 格式。
## 功能特性
- ✅ 支持多文件同时上传
-可选地转换文档为 MarkdownPDF、PPT、Excel、Word
-自动转换文档为 MarkdownPDF、PPT、Excel、Word
- ✅ 文件存储在线程隔离的目录中
- ✅ Agent 自动感知已上传的文件
- ✅ 支持文件列表查询和删除
@@ -86,7 +86,7 @@ DELETE /api/threads/{thread_id}/uploads/{filename}
## 支持的文档格式
以下格式在显式启用 `uploads.auto_convert_documents: true`会自动转换为 Markdown
以下格式会自动转换为 Markdown:
- PDF (`.pdf`)
- PowerPoint (`.ppt`, `.pptx`)
- Excel (`.xls`, `.xlsx`)
@@ -94,8 +94,6 @@ DELETE /api/threads/{thread_id}/uploads/{filename}
转换后的 Markdown 文件会保存在同一目录下,文件名为原文件名 + `.md` 扩展名。
默认情况下,自动转换是关闭的,以避免在网关主机上对不受信任的 Office/PDF 上传执行解析。只有在受信任部署中明确接受此风险时,才应将 `uploads.auto_convert_documents` 设置为 `true`
## Agent 集成
### 自动文件列举
@@ -209,7 +207,6 @@ backend/.deer-flow/threads/
- 最大文件大小:100MB(可在 nginx.conf 中配置 `client_max_body_size`
- 文件名安全性:系统会自动验证文件路径,防止目录遍历攻击
- 线程隔离:每个线程的上传文件相互隔离,无法跨线程访问
- 自动文档转换默认关闭;如需启用,需在 `config.yaml` 中显式设置 `uploads.auto_convert_documents: true`
## 技术实现
+3 -3
View File
@@ -11,7 +11,6 @@
- [x] Add Plan Mode with TodoList middleware
- [x] Add vision model support with ViewImageMiddleware
- [x] Skills system with SKILL.md format
- [x] Replace `time.sleep(5)` with `asyncio.sleep()` in `packages/harness/deerflow/tools/builtins/task_tool.py` (subagent polling)
## Planned Features
@@ -22,9 +21,10 @@
- [ ] Support for more document formats in upload
- [ ] Skill marketplace / remote skill installation
- [ ] Optimize async concurrency in agent hot path (IM channels multi-task scenario)
- [ ] Replace `subprocess.run()` with `asyncio.create_subprocess_shell()` in `packages/harness/deerflow/sandbox/local/local_sandbox.py`
- Replace `time.sleep(5)` with `asyncio.sleep()` in `packages/harness/deerflow/tools/builtins/task_tool.py` (subagent polling)
- Replace `subprocess.run()` with `asyncio.create_subprocess_shell()` in `packages/harness/deerflow/sandbox/local/local_sandbox.py`
- Replace sync `requests` with `httpx.AsyncClient` in community tools (tavily, jina_ai, firecrawl, infoquest, image_search)
- [x] Replace sync `model.invoke()` with async `model.ainvoke()` in title_middleware and memory updater
- Replace sync `model.invoke()` with async `model.ainvoke()` in title_middleware and memory updater
- Consider `asyncio.to_thread()` wrapper for remaining blocking file I/O
- For production: use `langgraph up` (multi-worker) instead of `langgraph dev` (single-worker)
+3
View File
@@ -8,6 +8,9 @@
"graphs": {
"lead_agent": "deerflow.agents:make_lead_agent"
},
"auth": {
"path": "./app/gateway/langgraph_auth.py:auth"
},
"checkpointer": {
"path": "./packages/harness/deerflow/agents/checkpointer/async_provider.py:make_checkpointer"
}
@@ -84,23 +84,76 @@ async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
@contextlib.asynccontextmanager
async def make_checkpointer() -> AsyncIterator[Checkpointer]:
"""Async context manager that yields a checkpointer for the caller's lifetime.
Resources are opened on enter and closed on exit — no global state::
async with make_checkpointer() as checkpointer:
app.state.checkpointer = checkpointer
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
if config.checkpointer is None:
async def _async_checkpointer_from_database(db_config) -> AsyncIterator[Checkpointer]:
"""Async context manager that constructs a checkpointer from unified DatabaseConfig."""
if db_config.backend == "memory":
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
return
async with _async_checkpointer(config.checkpointer) as saver:
yield saver
if db_config.backend == "sqlite":
try:
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
except ImportError as exc:
raise ImportError(SQLITE_INSTALL) from exc
conn_str = db_config.checkpointer_sqlite_path
ensure_sqlite_parent_dir(conn_str)
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
await saver.setup()
yield saver
return
if db_config.backend == "postgres":
try:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
if not db_config.postgres_url:
raise ValueError("database.postgres_url is required for the postgres backend")
async with AsyncPostgresSaver.from_conn_string(db_config.postgres_url) as saver:
await saver.setup()
yield saver
return
raise ValueError(f"Unknown database backend: {db_config.backend!r}")
@contextlib.asynccontextmanager
async def make_checkpointer() -> AsyncIterator[Checkpointer]:
"""Async context manager that yields a checkpointer for the caller's lifetime.
Resources are opened on enter and closed on exit -- no global state::
async with make_checkpointer() as checkpointer:
app.state.checkpointer = checkpointer
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
Priority:
1. Legacy ``checkpointer:`` config section (backward compatible)
2. Unified ``database:`` config section
3. Default InMemorySaver
"""
config = get_app_config()
# Legacy: standalone checkpointer config takes precedence
if config.checkpointer is not None:
async with _async_checkpointer(config.checkpointer) as saver:
yield saver
return
# Unified database config
db_config = getattr(config, "database", None)
if db_config is not None and db_config.backend != "memory":
async with _async_checkpointer_from_database(db_config) as saver:
yield saver
return
# Default: in-memory
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
@@ -27,7 +27,7 @@ from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config
from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
from deerflow.runtime.store._sqlite_utils import resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@@ -67,7 +67,6 @@ def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
raise ImportError(SQLITE_INSTALL) from exc
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
ensure_sqlite_parent_dir(conn_str)
with SqliteSaver.from_conn_string(conn_str) as saver:
saver.setup()
logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)
@@ -1,25 +1,22 @@
import logging
from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware
from langchain_core.runnables import RunnableConfig
from deerflow.agents.lead_agent.prompt import apply_prompt_template
from deerflow.agents.memory.summarization_hook import memory_flush_hook
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
from deerflow.agents.middlewares.summarization_middleware import BeforeSummarizationHook, DeerFlowSummarizationMiddleware
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
from deerflow.agents.middlewares.tool_error_handling_middleware import build_lead_runtime_middlewares
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import load_agent_config, validate_agent_name
from deerflow.config.agents_config import load_agent_config
from deerflow.config.app_config import get_app_config
from deerflow.config.memory_config import get_memory_config
from deerflow.config.summarization_config import get_summarization_config
from deerflow.models import create_chat_model
@@ -41,7 +38,7 @@ def _resolve_model_name(requested_model_name: str | None = None) -> str:
return default_model_name
def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None:
def _create_summarization_middleware() -> SummarizationMiddleware | None:
"""Create and configure the summarization middleware from config."""
config = get_summarization_config()
@@ -59,13 +56,15 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None
# Prepare keep parameter
keep = config.keep.to_tuple()
# Prepare model parameter
# Prepare model parameter.
# Bind "middleware:summarize" tag so RunJournal identifies these LLM calls
# as middleware rather than lead_agent (SummarizationMiddleware is a
# LangChain built-in, so we tag the model at creation time).
if config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False)
else:
# Use a lightweight model for summarization to save costs
# Falls back to default model if not explicitly specified
model = create_chat_model(thinking_enabled=False)
model = model.with_config(tags=["middleware:summarize"])
# Prepare kwargs
kwargs = {
@@ -80,11 +79,7 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None
if config.summary_prompt is not None:
kwargs["summary_prompt"] = config.summary_prompt
hooks: list[BeforeSummarizationHook] = []
if get_memory_config().enabled:
hooks.append(memory_flush_hook)
return DeerFlowSummarizationMiddleware(**kwargs, before_summarization=hooks)
return SummarizationMiddleware(**kwargs)
def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None:
@@ -291,7 +286,7 @@ def make_lead_agent(config: RunnableConfig):
subagent_enabled = cfg.get("subagent_enabled", False)
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
is_bootstrap = cfg.get("is_bootstrap", False)
agent_name = validate_agent_name(cfg.get("agent_name"))
agent_name = cfg.get("agent_name")
agent_config = load_agent_config(agent_name) if not is_bootstrap else None
# Custom agent model from agent config (if any), or None to let _resolve_model_name pick the default
@@ -332,7 +327,6 @@ def make_lead_agent(config: RunnableConfig):
"reasoning_effort": reasoning_effort,
"is_plan_mode": is_plan_mode,
"subagent_enabled": subagent_enabled,
"tool_groups": agent_config.tool_groups if agent_config else None,
}
)
@@ -1,109 +0,0 @@
"""Shared helpers for turning conversations into memory update inputs."""
from __future__ import annotations
import re
from copy import copy
from typing import Any
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
_CORRECTION_PATTERNS = (
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
re.compile(r"\btry again\b", re.IGNORECASE),
re.compile(r"\bredo\b", re.IGNORECASE),
re.compile(r"不对"),
re.compile(r"你理解错了"),
re.compile(r"你理解有误"),
re.compile(r"重试"),
re.compile(r"重新来"),
re.compile(r"换一种"),
re.compile(r"改用"),
)
_REINFORCEMENT_PATTERNS = (
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"对[,]?\s*就是这样(?:[。!?!?.]|$)"),
re.compile(r"完全正确(?:[。!?!?.]|$)"),
re.compile(r"(?:对[,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
re.compile(r"继续保持(?:[。!?!?.]|$)"),
)
def extract_message_text(message: Any) -> str:
"""Extract plain text from message content for filtering and signal detection."""
content = getattr(message, "content", "")
if isinstance(content, list):
text_parts: list[str] = []
for part in content:
if isinstance(part, str):
text_parts.append(part)
elif isinstance(part, dict):
text_val = part.get("text")
if isinstance(text_val, str):
text_parts.append(text_val)
return " ".join(text_parts)
return str(content)
def filter_messages_for_memory(messages: list[Any]) -> list[Any]:
"""Keep only user inputs and final assistant responses for memory updates."""
filtered = []
skip_next_ai = False
for msg in messages:
msg_type = getattr(msg, "type", None)
if msg_type == "human":
content_str = extract_message_text(msg)
if "<uploaded_files>" in content_str:
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
if not stripped:
skip_next_ai = True
continue
clean_msg = copy(msg)
clean_msg.content = stripped
filtered.append(clean_msg)
skip_next_ai = False
else:
filtered.append(msg)
skip_next_ai = False
elif msg_type == "ai":
tool_calls = getattr(msg, "tool_calls", None)
if not tool_calls:
if skip_next_ai:
skip_next_ai = False
continue
filtered.append(msg)
return filtered
def detect_correction(messages: list[Any]) -> bool:
"""Detect explicit user corrections in recent conversation turns."""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = extract_message_text(msg).strip()
if content and any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
return True
return False
def detect_reinforcement(messages: list[Any]) -> bool:
"""Detect explicit positive reinforcement signals in recent conversation turns."""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = extract_message_text(msg).strip()
if content and any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
return True
return False
@@ -61,88 +61,48 @@ class MemoryUpdateQueue:
return
with self._lock:
self._enqueue_locked(
existing_context = next(
(context for context in self._queue if context.thread_id == thread_id),
None,
)
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
context = ConversationContext(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
correction_detected=merged_correction_detected,
reinforcement_detected=merged_reinforcement_detected,
)
# Check if this thread already has a pending update
# If so, replace it with the newer one
self._queue = [c for c in self._queue if c.thread_id != thread_id]
self._queue.append(context)
# Reset or start the debounce timer
self._reset_timer()
logger.info("Memory update queued for thread %s, queue size: %d", thread_id, len(self._queue))
def add_nowait(
self,
thread_id: str,
messages: list[Any],
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> None:
"""Add a conversation and start processing immediately in the background."""
config = get_memory_config()
if not config.enabled:
return
with self._lock:
self._enqueue_locked(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
self._schedule_timer(0)
logger.info("Memory update queued for immediate processing on thread %s, queue size: %d", thread_id, len(self._queue))
def _enqueue_locked(
self,
*,
thread_id: str,
messages: list[Any],
agent_name: str | None,
correction_detected: bool,
reinforcement_detected: bool,
) -> None:
existing_context = next(
(context for context in self._queue if context.thread_id == thread_id),
None,
)
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
context = ConversationContext(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
correction_detected=merged_correction_detected,
reinforcement_detected=merged_reinforcement_detected,
)
self._queue = [c for c in self._queue if c.thread_id != thread_id]
self._queue.append(context)
def _reset_timer(self) -> None:
"""Reset the debounce timer."""
config = get_memory_config()
self._schedule_timer(config.debounce_seconds)
logger.debug("Memory update timer set for %ss", config.debounce_seconds)
def _schedule_timer(self, delay_seconds: float) -> None:
"""Schedule queue processing after the provided delay."""
# Cancel existing timer if any
if self._timer is not None:
self._timer.cancel()
# Start new timer
self._timer = threading.Timer(
delay_seconds,
config.debounce_seconds,
self._process_queue,
)
self._timer.daemon = True
self._timer.start()
logger.debug("Memory update timer set for %ss", config.debounce_seconds)
def _process_queue(self) -> None:
"""Process all queued conversation contexts."""
# Import here to avoid circular dependency
@@ -150,8 +110,8 @@ class MemoryUpdateQueue:
with self._lock:
if self._processing:
# Preserve immediate flush semantics even if another worker is active.
self._schedule_timer(0)
# Already processing, reschedule
self._reset_timer()
return
if not self._queue:
@@ -204,13 +164,6 @@ class MemoryUpdateQueue:
self._process_queue()
def flush_nowait(self) -> None:
"""Start queue processing immediately in a background thread."""
with self._lock:
# Daemon thread: queued messages may be lost if the process exits
# before _process_queue completes. Acceptable for best-effort memory updates.
self._schedule_timer(0)
def clear(self) -> None:
"""Clear the queue without processing.
@@ -4,7 +4,6 @@ import abc
import json
import logging
import threading
import uuid
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
@@ -67,8 +66,6 @@ class FileMemoryStorage(MemoryStorage):
# Per-agent memory cache: keyed by agent_name (None = global)
# Value: (memory_data, file_mtime)
self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
# Guards all reads and writes to _memory_cache across concurrent callers.
self._cache_lock = threading.Lock()
def _validate_agent_name(self, agent_name: str) -> None:
"""Validate that the agent name is safe to use in filesystem paths.
@@ -117,17 +114,14 @@ class FileMemoryStorage(MemoryStorage):
except OSError:
current_mtime = None
with self._cache_lock:
cached = self._memory_cache.get(agent_name)
if cached is not None and cached[1] == current_mtime:
return cached[0]
cached = self._memory_cache.get(agent_name)
memory_data = self._load_memory_from_file(agent_name)
with self._cache_lock:
if cached is None or cached[1] != current_mtime:
memory_data = self._load_memory_from_file(agent_name)
self._memory_cache[agent_name] = (memory_data, current_mtime)
return memory_data
return memory_data
return cached[0]
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
"""Reload memory data from file, forcing cache invalidation."""
@@ -139,8 +133,7 @@ class FileMemoryStorage(MemoryStorage):
except OSError:
mtime = None
with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, mtime)
self._memory_cache[agent_name] = (memory_data, mtime)
return memory_data
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
@@ -149,12 +142,9 @@ class FileMemoryStorage(MemoryStorage):
try:
file_path.parent.mkdir(parents=True, exist_ok=True)
# Shallow-copy before adding lastUpdated so the caller's dict is not
# mutated as a side-effect, and the cache reference is not silently
# updated before the file write succeeds.
memory_data = {**memory_data, "lastUpdated": utc_now_iso_z()}
memory_data["lastUpdated"] = utc_now_iso_z()
temp_path = file_path.with_suffix(f".{uuid.uuid4().hex}.tmp")
temp_path = file_path.with_suffix(".tmp")
with open(temp_path, "w", encoding="utf-8") as f:
json.dump(memory_data, f, indent=2, ensure_ascii=False)
@@ -165,8 +155,7 @@ class FileMemoryStorage(MemoryStorage):
except OSError:
mtime = None
with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, mtime)
self._memory_cache[agent_name] = (memory_data, mtime)
logger.info("Memory saved to %s", file_path)
return True
except OSError as e:
@@ -1,31 +0,0 @@
"""Hooks fired before summarization removes messages from state."""
from __future__ import annotations
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
from deerflow.agents.memory.queue import get_memory_queue
from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent
from deerflow.config.memory_config import get_memory_config
def memory_flush_hook(event: SummarizationEvent) -> None:
"""Flush messages about to be summarized into the memory queue."""
if not get_memory_config().enabled or not event.thread_id:
return
filtered_messages = filter_messages_for_memory(list(event.messages_to_summarize))
user_messages = [message for message in filtered_messages if getattr(message, "type", None) == "human"]
assistant_messages = [message for message in filtered_messages if getattr(message, "type", None) == "ai"]
if not user_messages or not assistant_messages:
return
correction_detected = detect_correction(filtered_messages)
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
queue = get_memory_queue()
queue.add_nowait(
thread_id=event.thread_id,
messages=filtered_messages,
agent_name=event.agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
@@ -1,15 +1,10 @@
"""Memory updater for reading, writing, and updating memory data."""
import asyncio
import atexit
import concurrent.futures
import copy
import json
import logging
import math
import re
import uuid
from collections.abc import Awaitable
from typing import Any
from deerflow.agents.memory.prompt import (
@@ -26,12 +21,6 @@ from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
_SYNC_MEMORY_UPDATER_EXECUTOR = concurrent.futures.ThreadPoolExecutor(
max_workers=4,
thread_name_prefix="memory-updater-sync",
)
atexit.register(lambda: _SYNC_MEMORY_UPDATER_EXECUTOR.shutdown(wait=False))
def _create_empty_memory() -> dict[str, Any]:
"""Backward-compatible wrapper around the storage-layer empty-memory factory."""
@@ -217,39 +206,6 @@ def _extract_text(content: Any) -> str:
return str(content)
def _run_async_update_sync(coro: Awaitable[bool]) -> bool:
"""Run an async memory update from sync code, including nested-loop contexts."""
handed_off = False
try:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop is not None and loop.is_running():
future = _SYNC_MEMORY_UPDATER_EXECUTOR.submit(asyncio.run, coro)
handed_off = True
return future.result()
handed_off = True
return asyncio.run(coro)
except Exception:
if not handed_off:
close = getattr(coro, "close", None)
if callable(close):
try:
close()
except Exception:
logger.debug(
"Failed to close un-awaited memory update coroutine",
exc_info=True,
)
logger.exception("Failed to run async memory update from sync context")
return False
# Matches sentences that describe a file-upload *event* rather than general
# file-related work. Deliberately narrow to avoid removing legitimate facts
# such as "User works with CSV files" or "prefers PDF export".
@@ -313,117 +269,6 @@ class MemoryUpdater:
model_name = self._model_name or config.model_name
return create_chat_model(name=model_name, thinking_enabled=False)
def _build_correction_hint(
self,
correction_detected: bool,
reinforcement_detected: bool,
) -> str:
"""Build optional prompt hints for correction and reinforcement signals."""
correction_hint = ""
if correction_detected:
correction_hint = (
"IMPORTANT: Explicit correction signals were detected in this conversation. "
"Pay special attention to what the agent got wrong, what the user corrected, "
"and record the correct approach as a fact with category "
'"correction" and confidence >= 0.95 when appropriate.'
)
if reinforcement_detected:
reinforcement_hint = (
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
"The user explicitly confirmed the agent's approach was correct or helpful. "
"Record the confirmed approach, style, or preference as a fact with category "
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
)
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
return correction_hint
def _prepare_update_prompt(
self,
messages: list[Any],
agent_name: str | None,
correction_detected: bool,
reinforcement_detected: bool,
) -> tuple[dict[str, Any], str] | None:
"""Load memory and build the update prompt for a conversation."""
config = get_memory_config()
if not config.enabled or not messages:
return None
current_memory = get_memory_data(agent_name)
conversation_text = format_conversation_for_update(messages)
if not conversation_text.strip():
return None
correction_hint = self._build_correction_hint(
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2),
conversation=conversation_text,
correction_hint=correction_hint,
)
return current_memory, prompt
def _finalize_update(
self,
current_memory: dict[str, Any],
response_content: Any,
thread_id: str | None,
agent_name: str | None,
) -> bool:
"""Parse the model response, apply updates, and persist memory."""
response_text = _extract_text(response_content).strip()
if response_text.startswith("```"):
lines = response_text.split("\n")
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
update_data = json.loads(response_text)
# Deep-copy before in-place mutation so a subsequent save() failure
# cannot corrupt the still-cached original object reference.
updated_memory = self._apply_updates(copy.deepcopy(current_memory), update_data, thread_id)
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
return get_memory_storage().save(updated_memory, agent_name)
async def aupdate_memory(
self,
messages: list[Any],
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> bool:
"""Update memory asynchronously based on conversation messages."""
try:
prepared = await asyncio.to_thread(
self._prepare_update_prompt,
messages=messages,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
if prepared is None:
return False
current_memory, prompt = prepared
model = self._get_model()
response = await model.ainvoke(prompt)
return await asyncio.to_thread(
self._finalize_update,
current_memory=current_memory,
response_content=response.content,
thread_id=thread_id,
agent_name=agent_name,
)
except json.JSONDecodeError as e:
logger.warning("Failed to parse LLM response for memory update: %s", e)
return False
except Exception as e:
logger.exception("Memory update failed: %s", e)
return False
def update_memory(
self,
messages: list[Any],
@@ -432,7 +277,7 @@ class MemoryUpdater:
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> bool:
"""Synchronously update memory via the async updater path.
"""Update memory based on conversation messages.
Args:
messages: List of conversation messages.
@@ -444,15 +289,78 @@ class MemoryUpdater:
Returns:
True if update was successful, False otherwise.
"""
return _run_async_update_sync(
self.aupdate_memory(
messages=messages,
thread_id=thread_id,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
config = get_memory_config()
if not config.enabled:
return False
if not messages:
return False
try:
# Get current memory
current_memory = get_memory_data(agent_name)
# Format conversation for prompt
conversation_text = format_conversation_for_update(messages)
if not conversation_text.strip():
return False
# Build prompt
correction_hint = ""
if correction_detected:
correction_hint = (
"IMPORTANT: Explicit correction signals were detected in this conversation. "
"Pay special attention to what the agent got wrong, what the user corrected, "
"and record the correct approach as a fact with category "
'"correction" and confidence >= 0.95 when appropriate.'
)
if reinforcement_detected:
reinforcement_hint = (
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
"The user explicitly confirmed the agent's approach was correct or helpful. "
"Record the confirmed approach, style, or preference as a fact with category "
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
)
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2),
conversation=conversation_text,
correction_hint=correction_hint,
)
)
# Call LLM
model = self._get_model()
response = model.invoke(prompt)
response_text = _extract_text(response.content).strip()
# Parse response
# Remove markdown code blocks if present
if response_text.startswith("```"):
lines = response_text.split("\n")
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
update_data = json.loads(response_text)
# Apply updates
updated_memory = self._apply_updates(current_memory, update_data, thread_id)
# Strip file-upload mentions from all summaries before saving.
# Uploaded files are session-scoped and won't exist in future sessions,
# so recording upload events in long-term memory causes the agent to
# try (and fail) to locate those files in subsequent conversations.
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
# Save
return get_memory_storage().save(updated_memory, agent_name)
except json.JSONDecodeError as e:
logger.warning("Failed to parse LLM response for memory update: %s", e)
return False
except Exception as e:
logger.exception("Memory update failed: %s", e)
return False
def _apply_updates(
self,
@@ -13,7 +13,6 @@ at the correct positions (immediately after each dangling AIMessage), not append
to the end of the message list as before_model + add_messages reducer would do.
"""
import json
import logging
from collections.abc import Awaitable, Callable
from typing import override
@@ -34,44 +33,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
offending AIMessage so the LLM receives a well-formed conversation.
"""
@staticmethod
def _message_tool_calls(msg) -> list[dict]:
"""Return normalized tool calls from structured fields or raw provider payloads."""
tool_calls = getattr(msg, "tool_calls", None) or []
if tool_calls:
return list(tool_calls)
raw_tool_calls = (getattr(msg, "additional_kwargs", None) or {}).get("tool_calls") or []
normalized: list[dict] = []
for raw_tc in raw_tool_calls:
if not isinstance(raw_tc, dict):
continue
function = raw_tc.get("function")
name = raw_tc.get("name")
if not name and isinstance(function, dict):
name = function.get("name")
args = raw_tc.get("args", {})
if not args and isinstance(function, dict):
raw_args = function.get("arguments")
if isinstance(raw_args, str):
try:
parsed_args = json.loads(raw_args)
except (TypeError, ValueError, json.JSONDecodeError):
parsed_args = {}
args = parsed_args if isinstance(parsed_args, dict) else {}
normalized.append(
{
"id": raw_tc.get("id"),
"name": name or "unknown",
"args": args if isinstance(args, dict) else {},
}
)
return normalized
def _build_patched_messages(self, messages: list) -> list | None:
"""Return a new message list with patches inserted at the correct positions.
@@ -90,7 +51,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
for msg in messages:
if getattr(msg, "type", None) != "ai":
continue
for tc in self._message_tool_calls(msg):
for tc in getattr(msg, "tool_calls", None) or []:
tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids:
needs_patch = True
@@ -109,7 +70,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
patched.append(msg)
if getattr(msg, "type", None) != "ai":
continue
for tc in self._message_tool_calls(msg):
for tc in getattr(msg, "tool_calls", None) or []:
tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
patched.append(
@@ -4,7 +4,6 @@ from __future__ import annotations
import asyncio
import logging
import threading
import time
from collections.abc import Awaitable, Callable
from email.utils import parsedate_to_datetime
@@ -20,8 +19,6 @@ from langchain.agents.middleware.types import (
from langchain_core.messages import AIMessage
from langgraph.errors import GraphBubbleUp
from deerflow.config import get_app_config
logger = logging.getLogger(__name__)
_RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
@@ -70,80 +67,6 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
retry_base_delay_ms: int = 1000
retry_cap_delay_ms: int = 8000
circuit_failure_threshold: int = 5
circuit_recovery_timeout_sec: int = 60
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
# Load Circuit Breaker configs from app config if available, fall back to defaults
try:
app_config = get_app_config()
self.circuit_failure_threshold = app_config.circuit_breaker.failure_threshold
self.circuit_recovery_timeout_sec = app_config.circuit_breaker.recovery_timeout_sec
except (FileNotFoundError, RuntimeError):
# Gracefully fall back to class defaults in test environments
pass
# Circuit Breaker state
self._circuit_lock = threading.Lock()
self._circuit_failure_count = 0
self._circuit_open_until = 0.0
self._circuit_state = "closed"
self._circuit_probe_in_flight = False
def _check_circuit(self) -> bool:
"""Returns True if circuit is OPEN (fast fail), False otherwise."""
with self._circuit_lock:
now = time.time()
if self._circuit_state == "open":
if now < self._circuit_open_until:
return True
self._circuit_state = "half_open"
self._circuit_probe_in_flight = False
if self._circuit_state == "half_open":
if self._circuit_probe_in_flight:
return True
self._circuit_probe_in_flight = True
return False
return False
def _record_success(self) -> None:
with self._circuit_lock:
if self._circuit_state != "closed" or self._circuit_failure_count > 0:
logger.info("Circuit breaker reset (Closed). LLM service recovered.")
self._circuit_failure_count = 0
self._circuit_open_until = 0.0
self._circuit_state = "closed"
self._circuit_probe_in_flight = False
def _record_failure(self) -> None:
with self._circuit_lock:
if self._circuit_state == "half_open":
self._circuit_open_until = time.time() + self.circuit_recovery_timeout_sec
self._circuit_state = "open"
self._circuit_probe_in_flight = False
logger.error(
"Circuit breaker probe failed (Open). Will probe again after %ds.",
self.circuit_recovery_timeout_sec,
)
return
self._circuit_failure_count += 1
if self._circuit_failure_count >= self.circuit_failure_threshold:
self._circuit_open_until = time.time() + self.circuit_recovery_timeout_sec
if self._circuit_state != "open":
self._circuit_state = "open"
self._circuit_probe_in_flight = False
logger.error(
"Circuit breaker tripped (Open). Threshold reached (%d). Will probe after %ds.",
self.circuit_failure_threshold,
self.circuit_recovery_timeout_sec,
)
def _classify_error(self, exc: BaseException) -> tuple[bool, str]:
detail = _extract_error_detail(exc)
lowered = detail.lower()
@@ -181,9 +104,6 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily"
return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s."
def _build_circuit_breaker_message(self) -> str:
return "The configured LLM provider is currently unavailable due to continuous failures. Circuit breaker is engaged to protect the system. Please wait a moment before trying again."
def _build_user_message(self, exc: BaseException, reason: str) -> str:
detail = _extract_error_detail(exc)
if reason == "quota":
@@ -218,20 +138,12 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
if self._check_circuit():
return AIMessage(content=self._build_circuit_breaker_message())
attempt = 1
while True:
try:
response = handler(request)
self._record_success()
return response
return handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
with self._circuit_lock:
if self._circuit_state == "half_open":
self._circuit_probe_in_flight = False
raise
except Exception as exc:
retriable, reason = self._classify_error(exc)
@@ -254,8 +166,6 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
_extract_error_detail(exc),
exc_info=exc,
)
if retriable:
self._record_failure()
return AIMessage(content=self._build_user_message(exc, reason))
@override
@@ -264,20 +174,12 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
if self._check_circuit():
return AIMessage(content=self._build_circuit_breaker_message())
attempt = 1
while True:
try:
response = await handler(request)
self._record_success()
return response
return await handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
with self._circuit_lock:
if self._circuit_state == "half_open":
self._circuit_probe_in_flight = False
raise
except Exception as exc:
retriable, reason = self._classify_error(exc)
@@ -300,8 +202,6 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
_extract_error_detail(exc),
exc_info=exc,
)
if retriable:
self._record_failure()
return AIMessage(content=self._build_user_message(exc, reason))
@@ -17,7 +17,6 @@ import json
import logging
import threading
from collections import OrderedDict, defaultdict
from copy import deepcopy
from typing import override
from langchain.agents import AgentState
@@ -32,8 +31,6 @@ _DEFAULT_WARN_THRESHOLD = 3 # inject warning after 3 identical calls
_DEFAULT_HARD_LIMIT = 5 # force-stop after 5 identical calls
_DEFAULT_WINDOW_SIZE = 20 # track last N tool calls
_DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit
_DEFAULT_TOOL_FREQ_WARN = 30 # warn after 30 calls to the same tool type
_DEFAULT_TOOL_FREQ_HARD_LIMIT = 50 # force-stop after 50 calls to the same tool type
def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]:
@@ -128,14 +125,8 @@ def _hash_tool_calls(tool_calls: list[dict]) -> str:
_WARNING_MSG = "[LOOP DETECTED] You are repeating the same tool calls. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."
_TOOL_FREQ_WARNING_MSG = (
"[LOOP DETECTED] You have called {tool_name} {count} times without producing a final answer. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."
)
_HARD_STOP_MSG = "[FORCED STOP] Repeated tool calls exceeded the safety limit. Producing final answer with results collected so far."
_TOOL_FREQ_HARD_STOP_MSG = "[FORCED STOP] Tool {tool_name} called {count} times — exceeded the per-tool safety limit. Producing final answer with results collected so far."
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
"""Detects and breaks repetitive tool call loops.
@@ -149,12 +140,6 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
Default: 20.
max_tracked_threads: Maximum number of threads to track before
evicting the least recently used. Default: 100.
tool_freq_warn: Number of calls to the same tool *type* (regardless
of arguments) before injecting a frequency warning. Catches
cross-file read loops that hash-based detection misses.
Default: 30.
tool_freq_hard_limit: Number of calls to the same tool type before
forcing a stop. Default: 50.
"""
def __init__(
@@ -163,23 +148,16 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
hard_limit: int = _DEFAULT_HARD_LIMIT,
window_size: int = _DEFAULT_WINDOW_SIZE,
max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN,
tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT,
):
super().__init__()
self.warn_threshold = warn_threshold
self.hard_limit = hard_limit
self.window_size = window_size
self.max_tracked_threads = max_tracked_threads
self.tool_freq_warn = tool_freq_warn
self.tool_freq_hard_limit = tool_freq_hard_limit
self._lock = threading.Lock()
# Per-thread tracking using OrderedDict for LRU eviction
self._history: OrderedDict[str, list[str]] = OrderedDict()
self._warned: dict[str, set[str]] = defaultdict(set)
# Per-thread, per-tool-type cumulative call counts
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
def _get_thread_id(self, runtime: Runtime) -> str:
"""Extract thread_id from runtime context for per-thread tracking."""
@@ -196,19 +174,11 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
while len(self._history) > self.max_tracked_threads:
evicted_id, _ = self._history.popitem(last=False)
self._warned.pop(evicted_id, None)
self._tool_freq.pop(evicted_id, None)
self._tool_freq_warned.pop(evicted_id, None)
logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)
def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
"""Track tool calls and check for loops.
Two detection layers:
1. **Hash-based** (existing): catches identical tool call sets.
2. **Frequency-based** (new): catches the same *tool type* being
called many times with varying arguments (e.g. ``read_file``
on 40 different files).
Returns:
(warning_message_or_none, should_hard_stop)
"""
@@ -243,7 +213,6 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
count = history.count(call_hash)
tool_names = [tc.get("name", "?") for tc in tool_calls]
# --- Layer 1: hash-based (identical call sets) ---
if count >= self.hard_limit:
logger.error(
"Loop hard limit reached — forcing stop",
@@ -270,40 +239,8 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
},
)
return _WARNING_MSG, False
# --- Layer 2: per-tool-type frequency ---
freq = self._tool_freq[thread_id]
for tc in tool_calls:
name = tc.get("name", "")
if not name:
continue
freq[name] += 1
tc_count = freq[name]
if tc_count >= self.tool_freq_hard_limit:
logger.error(
"Tool frequency hard limit reached — forcing stop",
extra={
"thread_id": thread_id,
"tool_name": name,
"count": tc_count,
},
)
return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True
if tc_count >= self.tool_freq_warn:
warned = self._tool_freq_warned[thread_id]
if name not in warned:
warned.add(name)
logger.warning(
"Tool frequency warning — too many calls to same tool type",
extra={
"thread_id": thread_id,
"tool_name": name,
"count": tc_count,
},
)
return _TOOL_FREQ_WARNING_MSG.format(tool_name=name, count=tc_count), False
# Warning already injected for this hash — suppress
return None, False
return None, False
@@ -324,26 +261,6 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
# Fallback: coerce unexpected types to str to avoid TypeError
return str(content) + f"\n\n{text}"
@staticmethod
def _build_hard_stop_update(last_msg, content: str | list) -> dict:
"""Clear tool-call metadata so forced-stop messages serialize as plain assistant text."""
update = {
"tool_calls": [],
"content": content,
}
additional_kwargs = dict(getattr(last_msg, "additional_kwargs", {}) or {})
for key in ("tool_calls", "function_call"):
additional_kwargs.pop(key, None)
update["additional_kwargs"] = additional_kwargs
response_metadata = deepcopy(getattr(last_msg, "response_metadata", {}) or {})
if response_metadata.get("finish_reason") == "tool_calls":
response_metadata["finish_reason"] = "stop"
update["response_metadata"] = response_metadata
return update
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
warning, hard_stop = self._track_and_check(state, runtime)
@@ -351,8 +268,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
# Strip tool_calls from the last AIMessage to force text output
messages = state.get("messages", [])
last_msg = messages[-1]
content = self._append_text(last_msg.content, warning or _HARD_STOP_MSG)
stripped_msg = last_msg.model_copy(update=self._build_hard_stop_update(last_msg, content))
stripped_msg = last_msg.model_copy(
update={
"tool_calls": [],
"content": self._append_text(last_msg.content, _HARD_STOP_MSG),
}
)
return {"messages": [stripped_msg]}
if warning:
@@ -380,10 +301,6 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
if thread_id:
self._history.pop(thread_id, None)
self._warned.pop(thread_id, None)
self._tool_freq.pop(thread_id, None)
self._tool_freq_warned.pop(thread_id, None)
else:
self._history.clear()
self._warned.clear()
self._tool_freq.clear()
self._tool_freq_warned.clear()
@@ -1,19 +1,50 @@
"""Middleware for memory mechanism."""
import logging
from typing import override
import re
from typing import Any, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
from deerflow.agents.memory.queue import get_memory_queue
from deerflow.config.memory_config import get_memory_config
logger = logging.getLogger(__name__)
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
_CORRECTION_PATTERNS = (
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
re.compile(r"\btry again\b", re.IGNORECASE),
re.compile(r"\bredo\b", re.IGNORECASE),
re.compile(r"不对"),
re.compile(r"你理解错了"),
re.compile(r"你理解有误"),
re.compile(r"重试"),
re.compile(r"重新来"),
re.compile(r"换一种"),
re.compile(r"改用"),
)
_REINFORCEMENT_PATTERNS = (
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"对[,]?\s*就是这样(?:[。!?!?.]|$)"),
re.compile(r"完全正确(?:[。!?!?.]|$)"),
re.compile(r"(?:对[,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
re.compile(r"继续保持(?:[。!?!?.]|$)"),
)
class MemoryMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
@@ -21,6 +52,125 @@ class MemoryMiddlewareState(AgentState):
pass
def _extract_message_text(message: Any) -> str:
"""Extract plain text from message content for filtering and signal detection."""
content = getattr(message, "content", "")
if isinstance(content, list):
text_parts: list[str] = []
for part in content:
if isinstance(part, str):
text_parts.append(part)
elif isinstance(part, dict):
text_val = part.get("text")
if isinstance(text_val, str):
text_parts.append(text_val)
return " ".join(text_parts)
return str(content)
def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
"""Filter messages to keep only user inputs and final assistant responses.
This filters out:
- Tool messages (intermediate tool call results)
- AI messages with tool_calls (intermediate steps, not final responses)
- The <uploaded_files> block injected by UploadsMiddleware into human messages
(file paths are session-scoped and must not persist in long-term memory).
The user's actual question is preserved; only turns whose content is entirely
the upload block (nothing remains after stripping) are dropped along with
their paired assistant response.
Only keeps:
- Human messages (with the ephemeral upload block removed)
- AI messages without tool_calls (final assistant responses), unless the
paired human turn was upload-only and had no real user text.
Args:
messages: List of all conversation messages.
Returns:
Filtered list containing only user inputs and final assistant responses.
"""
filtered = []
skip_next_ai = False
for msg in messages:
msg_type = getattr(msg, "type", None)
if msg_type == "human":
content_str = _extract_message_text(msg)
if "<uploaded_files>" in content_str:
# Strip the ephemeral upload block; keep the user's real question.
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
if not stripped:
# Nothing left — the entire turn was upload bookkeeping;
# skip it and the paired assistant response.
skip_next_ai = True
continue
# Rebuild the message with cleaned content so the user's question
# is still available for memory summarisation.
from copy import copy
clean_msg = copy(msg)
clean_msg.content = stripped
filtered.append(clean_msg)
skip_next_ai = False
else:
filtered.append(msg)
skip_next_ai = False
elif msg_type == "ai":
tool_calls = getattr(msg, "tool_calls", None)
if not tool_calls:
if skip_next_ai:
skip_next_ai = False
continue
filtered.append(msg)
# Skip tool messages and AI messages with tool_calls
return filtered
def detect_correction(messages: list[Any]) -> bool:
"""Detect explicit user corrections in recent conversation turns.
The queue keeps only one pending context per thread, so callers pass the
latest filtered message list. Checking only recent user turns keeps signal
detection conservative while avoiding stale corrections from long histories.
"""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = _extract_message_text(msg).strip()
if not content:
continue
if any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
return True
return False
def detect_reinforcement(messages: list[Any]) -> bool:
"""Detect explicit positive reinforcement signals in recent conversation turns.
Complements detect_correction() by identifying when the user confirms the
agent's approach was correct. This allows the memory system to record what
worked well, not just what went wrong.
The queue keeps only one pending context per thread, so callers pass the
latest filtered message list. Checking only recent user turns keeps signal
detection conservative while avoiding stale signals from long histories.
"""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = _extract_message_text(msg).strip()
if not content:
continue
if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
return True
return False
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
"""Middleware that queues conversation for memory update after agent execution.
@@ -73,7 +223,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
return None
# Filter to only keep user inputs and final assistant responses
filtered_messages = filter_messages_for_memory(messages)
filtered_messages = _filter_messages_for_memory(messages)
# Only queue if there's meaningful conversation
# At minimum need one user message and one assistant response
@@ -1,151 +0,0 @@
"""Summarization middleware extensions for DeerFlow."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Protocol, runtime_checkable
from langchain.agents import AgentState
from langchain.agents.middleware import SummarizationMiddleware
from langchain_core.messages import AnyMessage, RemoveMessage
from langgraph.config import get_config
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class SummarizationEvent:
"""Context emitted before conversation history is summarized away."""
messages_to_summarize: tuple[AnyMessage, ...]
preserved_messages: tuple[AnyMessage, ...]
thread_id: str | None
agent_name: str | None
runtime: Runtime
@runtime_checkable
class BeforeSummarizationHook(Protocol):
"""Hook invoked before summarization removes messages from state."""
def __call__(self, event: SummarizationEvent) -> None: ...
def _resolve_thread_id(runtime: Runtime) -> str | None:
"""Resolve the current thread ID from runtime context or LangGraph config."""
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id is None:
try:
config_data = get_config()
except RuntimeError:
return None
thread_id = config_data.get("configurable", {}).get("thread_id")
return thread_id
def _resolve_agent_name(runtime: Runtime) -> str | None:
"""Resolve the current agent name from runtime context or LangGraph config."""
agent_name = runtime.context.get("agent_name") if runtime.context else None
if agent_name is None:
try:
config_data = get_config()
except RuntimeError:
return None
agent_name = config_data.get("configurable", {}).get("agent_name")
return agent_name
class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
"""Summarization middleware with pre-compression hook dispatch."""
def __init__(
self,
*args,
before_summarization: list[BeforeSummarizationHook] | None = None,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self._before_summarization_hooks = before_summarization or []
def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._maybe_summarize(state, runtime)
async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return await self._amaybe_summarize(state, runtime)
def _maybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
messages = state["messages"]
self._ensure_message_ids(messages)
total_tokens = self.token_counter(messages)
if not self._should_summarize(messages, total_tokens):
return None
cutoff_index = self._determine_cutoff_index(messages)
if cutoff_index <= 0:
return None
messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
summary = self._create_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)
return {
"messages": [
RemoveMessage(id=REMOVE_ALL_MESSAGES),
*new_messages,
*preserved_messages,
]
}
async def _amaybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
messages = state["messages"]
self._ensure_message_ids(messages)
total_tokens = self.token_counter(messages)
if not self._should_summarize(messages, total_tokens):
return None
cutoff_index = self._determine_cutoff_index(messages)
if cutoff_index <= 0:
return None
messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
summary = await self._acreate_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)
return {
"messages": [
RemoveMessage(id=REMOVE_ALL_MESSAGES),
*new_messages,
*preserved_messages,
]
}
def _fire_hooks(
self,
messages_to_summarize: list[AnyMessage],
preserved_messages: list[AnyMessage],
runtime: Runtime,
) -> None:
if not self._before_summarization_hooks:
return
event = SummarizationEvent(
messages_to_summarize=tuple(messages_to_summarize),
preserved_messages=tuple(preserved_messages),
thread_id=_resolve_thread_id(runtime),
agent_name=_resolve_agent_name(runtime),
runtime=runtime,
)
for hook in self._before_summarization_hooks:
try:
hook(event)
except Exception:
hook_name = getattr(hook, "__name__", None) or type(hook).__name__
logger.exception("before_summarization hook %s failed", hook_name)
@@ -1,11 +1,11 @@
"""Middleware for automatic thread title generation."""
import logging
import re
from typing import NotRequired, override
from typing import Any, NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.config.title_config import get_title_config
@@ -78,7 +78,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")
user_msg = self._normalize_content(user_msg_content)
assistant_msg = self._strip_think_tags(self._normalize_content(assistant_msg_content))
assistant_msg = self._normalize_content(assistant_msg_content)
prompt = config.prompt_template.format(
max_words=config.max_words,
@@ -87,15 +87,10 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
)
return prompt, user_msg
def _strip_think_tags(self, text: str) -> str:
"""Remove <think>...</think> blocks emitted by reasoning models (e.g. minimax, DeepSeek-R1)."""
return re.sub(r"<think>[\s\S]*?</think>", "", text, flags=re.IGNORECASE).strip()
def _parse_title(self, content: object) -> str:
"""Normalize model output into a clean title string."""
config = get_title_config()
title_content = self._normalize_content(content)
title_content = self._strip_think_tags(title_content)
title = title_content.strip().strip('"').strip("'")
return title[: config.max_chars] if len(title) > config.max_chars else title
@@ -106,6 +101,20 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
return user_msg[:fallback_chars].rstrip() + "..."
return user_msg if user_msg else "New Conversation"
def _get_runnable_config(self) -> dict[str, Any]:
"""Inherit the parent RunnableConfig and add middleware tag.
This ensures RunJournal identifies LLM calls from this middleware
as ``middleware:title`` instead of ``lead_agent``.
"""
try:
parent = get_config()
except Exception:
parent = {}
config = {**parent}
config["tags"] = [*(config.get("tags") or []), "middleware:title"]
return config
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
"""Generate a local fallback title without blocking on an LLM call."""
if not self._should_generate_title(state):
@@ -127,7 +136,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
model = create_chat_model(name=config.model_name, thinking_enabled=False)
else:
model = create_chat_model(thinking_enabled=False)
response = await model.ainvoke(prompt)
response = await model.ainvoke(prompt, config=self._get_runnable_config())
title = self._parse_title(response.content)
if title:
return {"title": title}
@@ -1,14 +1,9 @@
"""Middleware that extends TodoListMiddleware with context-loss detection and premature-exit prevention.
"""Middleware that extends TodoListMiddleware with context-loss detection.
When the message history is truncated (e.g., by SummarizationMiddleware), the
original `write_todos` tool call and its ToolMessage can be scrolled out of the
active context window. This middleware detects that situation and injects a
reminder message so the model still knows about the outstanding todo list.
Additionally, this middleware prevents the agent from exiting the loop while
there are still incomplete todo items. When the model produces a final response
(no tool calls) but todos are not yet complete, the middleware injects a reminder
and jumps back to the model node to force continued engagement.
"""
from __future__ import annotations
@@ -17,7 +12,6 @@ from typing import Any, override
from langchain.agents.middleware import TodoListMiddleware
from langchain.agents.middleware.todo import PlanningState, Todo
from langchain.agents.middleware.types import hook_config
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.runtime import Runtime
@@ -40,11 +34,6 @@ def _reminder_in_messages(messages: list[Any]) -> bool:
return False
def _completion_reminder_count(messages: list[Any]) -> int:
"""Return the number of todo_completion_reminder HumanMessages in *messages*."""
return sum(1 for msg in messages if isinstance(msg, HumanMessage) and getattr(msg, "name", None) == "todo_completion_reminder")
def _format_todos(todos: list[Todo]) -> str:
"""Format a list of Todo items into a human-readable string."""
lines: list[str] = []
@@ -68,7 +57,7 @@ class TodoMiddleware(TodoListMiddleware):
def before_model(
self,
state: PlanningState,
runtime: Runtime,
runtime: Runtime, # noqa: ARG002
) -> dict[str, Any] | None:
"""Inject a todo-list reminder when write_todos has left the context window."""
todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment]
@@ -109,71 +98,3 @@ class TodoMiddleware(TodoListMiddleware):
) -> dict[str, Any] | None:
"""Async version of before_model."""
return self.before_model(state, runtime)
# Maximum number of completion reminders before allowing the agent to exit.
# This prevents infinite loops when the agent cannot make further progress.
_MAX_COMPLETION_REMINDERS = 2
@hook_config(can_jump_to=["model"])
@override
def after_model(
self,
state: PlanningState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Prevent premature agent exit when todo items are still incomplete.
In addition to the base class check for parallel ``write_todos`` calls,
this override intercepts model responses that have no tool calls while
there are still incomplete todo items. It injects a reminder
``HumanMessage`` and jumps back to the model node so the agent
continues working through the todo list.
A retry cap of ``_MAX_COMPLETION_REMINDERS`` (default 2) prevents
infinite loops when the agent cannot make further progress.
"""
# 1. Preserve base class logic (parallel write_todos detection).
base_result = super().after_model(state, runtime)
if base_result is not None:
return base_result
# 2. Only intervene when the agent wants to exit (no tool calls).
messages = state.get("messages") or []
last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None)
if not last_ai or last_ai.tool_calls:
return None
# 3. Allow exit when all todos are completed or there are no todos.
todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment]
if not todos or all(t.get("status") == "completed" for t in todos):
return None
# 4. Enforce a reminder cap to prevent infinite re-engagement loops.
if _completion_reminder_count(messages) >= self._MAX_COMPLETION_REMINDERS:
return None
# 5. Inject a reminder and force the agent back to the model.
incomplete = [t for t in todos if t.get("status") != "completed"]
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete)
reminder = HumanMessage(
name="todo_completion_reminder",
content=(
"<system_reminder>\n"
"You have incomplete todo items that must be finished before giving your final response:\n\n"
f"{incomplete_text}\n\n"
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
"as you finish them, and only respond when all items are done.\n"
"</system_reminder>"
),
)
return {"jump_to": "model", "messages": [reminder]}
@override
@hook_config(can_jump_to=["model"])
async def aafter_model(
self,
state: PlanningState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async version of after_model."""
return self.after_model(state, runtime)
@@ -262,25 +262,21 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
files_message = self._create_files_message(new_files, historical_files)
# Extract original content - handle both string and list formats
original_content = last_message.content
if isinstance(original_content, str):
# Simple case: string content, just prepend files message
updated_content = f"{files_message}\n\n{original_content}"
elif isinstance(original_content, list):
# Complex case: list content (multimodal), preserve all blocks
# Prepend files message as the first text block
files_block = {"type": "text", "text": f"{files_message}\n\n"}
# Keep all original blocks (including images)
updated_content = [files_block, *original_content]
else:
# Other types, preserve as-is
updated_content = original_content
original_content = ""
if isinstance(last_message.content, str):
original_content = last_message.content
elif isinstance(last_message.content, list):
text_parts = []
for block in last_message.content:
if isinstance(block, dict) and block.get("type") == "text":
text_parts.append(block.get("text", ""))
original_content = "\n".join(text_parts)
# Create new message with combined content.
# Preserve additional_kwargs (including files metadata) so the frontend
# can read structured file info from the streamed message.
updated_message = HumanMessage(
content=updated_content,
content=f"{files_message}\n\n{original_content}",
id=last_message.id,
additional_kwargs=last_message.additional_kwargs,
)
+1 -6
View File
@@ -722,10 +722,6 @@ class DeerFlowClient:
Dict with "models" key containing list of model info dicts,
matching the Gateway API ``ModelsListResponse`` schema.
"""
token_usage_enabled = getattr(getattr(self._app_config, "token_usage", None), "enabled", False)
if not isinstance(token_usage_enabled, bool):
token_usage_enabled = False
return {
"models": [
{
@@ -737,8 +733,7 @@ class DeerFlowClient:
"supports_reasoning_effort": getattr(model, "supports_reasoning_effort", False),
}
for model in self._app_config.models
],
"token_usage": {"enabled": token_usage_enabled},
]
}
def list_skills(self, enabled_only: bool = False) -> dict:
@@ -119,16 +119,6 @@ class AioSandboxProvider(SandboxProvider):
if self._config.get("idle_timeout", DEFAULT_IDLE_TIMEOUT) > 0:
self._start_idle_checker()
@property
def uses_thread_data_mounts(self) -> bool:
"""Whether thread workspace/uploads/outputs are visible via mounts.
Local container backends bind-mount the thread data directories, so files
written by the gateway are already visible when the sandbox starts.
Remote backends may require explicit file sync.
"""
return isinstance(self._backend, LocalContainerBackend)
# ── Factory methods ──────────────────────────────────────────────────
def _create_backend(self) -> SandboxBackend:
@@ -1,5 +1,3 @@
import asyncio
from langchain.tools import tool
from deerflow.community.jina_ai.jina_client import JinaClient
@@ -28,5 +26,5 @@ async def web_fetch_tool(url: str) -> str:
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout)
if isinstance(html_content, str) and html_content.startswith("Error:"):
return html_content
article = await asyncio.to_thread(readability_extractor.extract_article, html_content)
article = readability_extractor.extract_article(html_content)
return article.to_markdown()[:4096]
@@ -1,32 +0,0 @@
"""Configuration for the custom agents management API."""
from pydantic import BaseModel, Field
class AgentsApiConfig(BaseModel):
"""Configuration for custom-agent and user-profile management routes."""
enabled: bool = Field(
default=False,
description=("Whether to expose the custom-agent management API over HTTP. When disabled, the gateway rejects read/write access to custom agent SOUL.md, config, and USER.md prompt-management routes."),
)
_agents_api_config: AgentsApiConfig = AgentsApiConfig()
def get_agents_api_config() -> AgentsApiConfig:
"""Get the current agents API configuration."""
return _agents_api_config
def set_agents_api_config(config: AgentsApiConfig) -> None:
"""Set the agents API configuration."""
global _agents_api_config
_agents_api_config = config
def load_agents_api_config_from_dict(config_dict: dict) -> None:
"""Load agents API configuration from a dictionary."""
global _agents_api_config
_agents_api_config = AgentsApiConfig(**config_dict)
@@ -15,17 +15,6 @@ SOUL_FILENAME = "SOUL.md"
AGENT_NAME_PATTERN = re.compile(r"^[A-Za-z0-9-]+$")
def validate_agent_name(name: str | None) -> str | None:
"""Validate a custom agent name before using it in filesystem paths."""
if name is None:
return None
if not isinstance(name, str):
raise ValueError("Invalid agent name. Expected a string or None.")
if not AGENT_NAME_PATTERN.fullmatch(name):
raise ValueError(f"Invalid agent name '{name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
return name
class AgentConfig(BaseModel):
"""Configuration for a custom agent."""
@@ -57,7 +46,8 @@ def load_agent_config(name: str | None) -> AgentConfig | None:
if name is None:
return None
name = validate_agent_name(name)
if not AGENT_NAME_PATTERN.match(name):
raise ValueError(f"Invalid agent name '{name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
agent_dir = get_paths().agent_dir(name)
config_file = agent_dir / "config.yaml"
@@ -9,12 +9,13 @@ from dotenv import load_dotenv
from pydantic import BaseModel, ConfigDict, Field
from deerflow.config.acp_config import load_acp_config_from_dict
from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
from deerflow.config.database_config import DatabaseConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
from deerflow.config.model_config import ModelConfig
from deerflow.config.run_events_config import RunEventsConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.skill_evolution_config import SkillEvolutionConfig
from deerflow.config.skills_config import SkillsConfig
@@ -31,13 +32,6 @@ load_dotenv()
logger = logging.getLogger(__name__)
class CircuitBreakerConfig(BaseModel):
"""Configuration for the LLM Circuit Breaker."""
failure_threshold: int = Field(default=5, description="Number of consecutive failures before tripping the circuit")
recovery_timeout_sec: int = Field(default=60, description="Time in seconds before attempting to recover the circuit")
def _default_config_candidates() -> tuple[Path, ...]:
"""Return deterministic config.yaml locations without relying on cwd."""
backend_dir = Path(__file__).resolve().parents[4]
@@ -61,11 +55,11 @@ class AppConfig(BaseModel):
title: TitleConfig = Field(default_factory=TitleConfig, description="Automatic title generation configuration")
summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration")
memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration")
agents_api: AgentsApiConfig = Field(default_factory=AgentsApiConfig, description="Custom-agent management API configuration")
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
model_config = ConfigDict(extra="allow", frozen=False)
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
@@ -127,10 +121,6 @@ class AppConfig(BaseModel):
if "memory" in config_data:
load_memory_config_from_dict(config_data["memory"])
# Always refresh agents API config so removed config sections reset
# singleton-backed state to its default/disabled values on reload.
load_agents_api_config_from_dict(config_data.get("agents_api") or {})
# Load subagents config if present
if "subagents" in config_data:
load_subagents_config_from_dict(config_data["subagents"])
@@ -143,10 +133,6 @@ class AppConfig(BaseModel):
if "guardrails" in config_data:
load_guardrails_config_from_dict(config_data["guardrails"])
# Load circuit_breaker config if present
if "circuit_breaker" in config_data:
config_data["circuit_breaker"] = config_data["circuit_breaker"]
# Load checkpointer config if present
if "checkpointer" in config_data:
load_checkpointer_config_from_dict(config_data["checkpointer"])
@@ -0,0 +1,92 @@
"""Unified database backend configuration.
Controls BOTH the LangGraph checkpointer and the DeerFlow application
persistence layer (runs, threads metadata, users, etc.). The user
configures one backend; the system handles physical separation details.
SQLite mode: checkpointer and app use different .db files in the same
directory to avoid write-lock contention. This is automatic.
Postgres mode: both use the same database URL but maintain independent
connection pools with different lifecycles.
Memory mode: checkpointer uses MemorySaver, app uses in-memory stores.
No database is initialized.
Sensitive values (postgres_url) should use $VAR syntax in config.yaml
to reference environment variables from .env:
database:
backend: postgres
postgres_url: $DATABASE_URL
The $VAR resolution is handled by AppConfig.resolve_env_variables()
before this config is instantiated -- DatabaseConfig itself does not
need to do any environment variable processing.
"""
from __future__ import annotations
import os
from typing import Literal
from pydantic import BaseModel, Field
class DatabaseConfig(BaseModel):
backend: Literal["memory", "sqlite", "postgres"] = Field(
default="memory",
description=("Storage backend for both checkpointer and application data. 'memory' for development (no persistence across restarts), 'sqlite' for single-node deployment, 'postgres' for production multi-node deployment."),
)
sqlite_dir: str = Field(
default=".deer-flow/data",
description=("Directory for SQLite database files. Checkpointer uses {sqlite_dir}/checkpoints.db, application data uses {sqlite_dir}/app.db."),
)
postgres_url: str = Field(
default="",
description=(
"PostgreSQL connection URL, shared by checkpointer and app. "
"Use $DATABASE_URL in config.yaml to reference .env. "
"Example: postgresql://user:pass@host:5432/deerflow "
"(the +asyncpg driver suffix is added automatically where needed)."
),
)
echo_sql: bool = Field(
default=False,
description="Echo all SQL statements to log (debug only).",
)
pool_size: int = Field(
default=5,
description="Connection pool size for the app ORM engine (postgres only).",
)
# -- Derived helpers (not user-configured) --
@property
def _resolved_sqlite_dir(self) -> str:
"""Resolve sqlite_dir to an absolute path (relative to CWD)."""
from pathlib import Path
return str(Path(self.sqlite_dir).resolve())
@property
def checkpointer_sqlite_path(self) -> str:
"""SQLite file path for the LangGraph checkpointer."""
return os.path.join(self._resolved_sqlite_dir, "checkpoints.db")
@property
def app_sqlite_path(self) -> str:
"""SQLite file path for application ORM data."""
return os.path.join(self._resolved_sqlite_dir, "app.db")
@property
def app_sqlalchemy_url(self) -> str:
"""SQLAlchemy async URL for the application ORM engine."""
if self.backend == "sqlite":
return f"sqlite+aiosqlite:///{self.app_sqlite_path}"
if self.backend == "postgres":
url = self.postgres_url
if url.startswith("postgresql://"):
url = url.replace("postgresql://", "postgresql+asyncpg://", 1)
return url
raise ValueError(f"No SQLAlchemy URL for backend={self.backend!r}")
@@ -0,0 +1,33 @@
"""Run event storage configuration.
Controls where run events (messages + execution traces) are persisted.
Backends:
- memory: In-memory storage, data lost on restart. Suitable for
development and testing.
- db: SQL database via SQLAlchemy ORM. Provides full query capability.
Suitable for production deployments.
- jsonl: Append-only JSONL files. Lightweight alternative for
single-node deployments that need persistence without a database.
"""
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, Field
class RunEventsConfig(BaseModel):
backend: Literal["memory", "db", "jsonl"] = Field(
default="memory",
description="Storage backend for run events. 'memory' for development (no persistence), 'db' for production (SQL queries), 'jsonl' for lightweight single-node persistence.",
)
max_trace_content: int = Field(
default=10240,
description="Maximum trace content size in bytes before truncation (db backend only).",
)
track_token_usage: bool = Field(
default=True,
description="Whether RunJournal should accumulate token counts to RunRow.",
)
@@ -20,11 +20,6 @@ class SubagentOverrideConfig(BaseModel):
ge=1,
description="Maximum turns for this subagent (None = use global or builtin default)",
)
model: str | None = Field(
default=None,
min_length=1,
description="Model name for this subagent (None = inherit from parent agent)",
)
class SubagentsAppConfig(BaseModel):
@@ -59,20 +54,6 @@ class SubagentsAppConfig(BaseModel):
return override.timeout_seconds
return self.timeout_seconds
def get_model_for(self, agent_name: str) -> str | None:
"""Get the model override for a specific agent.
Args:
agent_name: The name of the subagent.
Returns:
Model name if overridden, None otherwise (subagent will inherit parent model).
"""
override = self.agents.get(agent_name)
if override is not None and override.model is not None:
return override.model
return None
def get_max_turns_for(self, agent_name: str, builtin_default: int) -> int:
"""Get the effective max_turns for a specific agent."""
override = self.agents.get(agent_name)
@@ -103,8 +84,6 @@ def load_subagents_config_from_dict(config_dict: dict) -> None:
parts.append(f"timeout={override.timeout_seconds}s")
if override.max_turns is not None:
parts.append(f"max_turns={override.max_turns}")
if override.model is not None:
parts.append(f"model={override.model}")
if parts:
overrides_summary[name] = ", ".join(parts)
@@ -118,13 +118,9 @@ def get_cached_mcp_tools() -> list[BaseTool]:
loop.run_until_complete(initialize_mcp_tools())
except RuntimeError:
# No event loop exists, create one
try:
asyncio.run(initialize_mcp_tools())
except Exception:
logger.exception("Failed to lazy-initialize MCP tools")
return []
except Exception:
logger.exception("Failed to lazy-initialize MCP tools")
asyncio.run(initialize_mcp_tools())
except Exception as e:
logger.error(f"Failed to lazy-initialize MCP tools: {e}")
return []
return _mcp_tools_cache or []
@@ -113,7 +113,16 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
elif "reasoning_effort" not in model_settings_from_config:
model_settings_from_config["reasoning_effort"] = "medium"
model_instance = model_class(**{**model_settings_from_config, **kwargs})
# Ensure stream_usage is enabled so that token usage metadata is available
# in streaming responses. LangChain's BaseChatOpenAI only defaults
# stream_usage=True when no custom base_url/api_base is set, so models
# hitting third-party endpoints (e.g. doubao, deepseek) silently lose
# usage data. We default it to True unless explicitly configured.
if "stream_usage" not in model_settings_from_config and "stream_usage" not in kwargs:
if "stream_usage" in getattr(model_class, "model_fields", {}):
model_settings_from_config["stream_usage"] = True
model_instance = model_class(**kwargs, **model_settings_from_config)
callbacks = build_tracing_callbacks()
if callbacks:
@@ -0,0 +1,13 @@
"""DeerFlow application persistence layer (SQLAlchemy 2.0 async ORM).
This module manages DeerFlow's own application data -- runs metadata,
thread ownership, cron jobs, users. It is completely separate from
LangGraph's checkpointer, which manages graph execution state.
Usage:
from deerflow.persistence import init_engine, close_engine, get_session_factory
"""
from deerflow.persistence.engine import close_engine, get_engine, get_session_factory, init_engine
__all__ = ["close_engine", "get_engine", "get_session_factory", "init_engine"]
@@ -0,0 +1,40 @@
"""SQLAlchemy declarative base with automatic to_dict support.
All DeerFlow ORM models inherit from this Base. It provides a generic
to_dict() method via SQLAlchemy's inspect() so individual models don't
need to write their own serialization logic.
LangGraph's checkpointer tables are NOT managed by this Base.
"""
from __future__ import annotations
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.orm import DeclarativeBase
class Base(DeclarativeBase):
"""Base class for all DeerFlow ORM models.
Provides:
- Automatic to_dict() via SQLAlchemy column inspection.
- Standard __repr__() showing all column values.
"""
def to_dict(self, *, exclude: set[str] | None = None) -> dict:
"""Convert ORM instance to plain dict.
Uses SQLAlchemy's inspect() to iterate mapped column attributes.
Args:
exclude: Optional set of column keys to omit.
Returns:
Dict of {column_key: value} for all mapped columns.
"""
exclude = exclude or set()
return {c.key: getattr(self, c.key) for c in sa_inspect(type(self)).mapper.column_attrs if c.key not in exclude}
def __repr__(self) -> str:
cols = ", ".join(f"{c.key}={getattr(self, c.key)!r}" for c in sa_inspect(type(self)).mapper.column_attrs)
return f"{type(self).__name__}({cols})"
@@ -0,0 +1,185 @@
"""Async SQLAlchemy engine lifecycle management.
Initializes at Gateway startup, provides session factory for
repositories, disposes at shutdown.
When database.backend="memory", init_engine is a no-op and
get_session_factory() returns None. Repositories must check for
None and fall back to in-memory implementations.
"""
from __future__ import annotations
import json
import logging
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
def _json_serializer(obj: object) -> str:
"""JSON serializer with ensure_ascii=False for Chinese character support."""
return json.dumps(obj, ensure_ascii=False)
logger = logging.getLogger(__name__)
_engine: AsyncEngine | None = None
_session_factory: async_sessionmaker[AsyncSession] | None = None
async def _auto_create_postgres_db(url: str) -> None:
"""Connect to the ``postgres`` maintenance DB and CREATE DATABASE.
The target database name is extracted from *url*. The connection is
made to the default ``postgres`` database on the same server using
``AUTOCOMMIT`` isolation (CREATE DATABASE cannot run inside a
transaction).
"""
from sqlalchemy import text
from sqlalchemy.engine.url import make_url
parsed = make_url(url)
db_name = parsed.database
if not db_name:
raise ValueError("Cannot auto-create database: no database name in URL")
# Connect to the default 'postgres' database to issue CREATE DATABASE
maint_url = parsed.set(database="postgres")
maint_engine = create_async_engine(maint_url, isolation_level="AUTOCOMMIT")
try:
async with maint_engine.connect() as conn:
await conn.execute(text(f'CREATE DATABASE "{db_name}"'))
logger.info("Auto-created PostgreSQL database: %s", db_name)
finally:
await maint_engine.dispose()
async def init_engine(
backend: str,
*,
url: str = "",
echo: bool = False,
pool_size: int = 5,
sqlite_dir: str = "",
) -> None:
"""Create the async engine and session factory, then auto-create tables.
Args:
backend: "memory", "sqlite", or "postgres".
url: SQLAlchemy async URL (for sqlite/postgres).
echo: Echo SQL to log.
pool_size: Postgres connection pool size.
sqlite_dir: Directory to create for SQLite (ensured to exist).
"""
global _engine, _session_factory
if backend == "memory":
logger.info("Persistence backend=memory -- ORM engine not initialized")
return
if backend == "postgres":
try:
import asyncpg # noqa: F401
except ImportError:
raise ImportError("database.backend is set to 'postgres' but asyncpg is not installed.\nInstall it with:\n uv sync --extra postgres\nOr switch to backend: sqlite in config.yaml for single-node deployment.") from None
if backend == "sqlite":
import os
from sqlalchemy import event
os.makedirs(sqlite_dir or ".", exist_ok=True)
_engine = create_async_engine(url, echo=echo, json_serializer=_json_serializer)
# Enable WAL on every new connection. SQLite PRAGMA settings are
# per-connection, so we wire the listener instead of running PRAGMA
# once at startup. WAL gives concurrent reads + writers without
# blocking and is the standard recommendation for any production
# SQLite deployment (TC-UPG-06 in AUTH_TEST_PLAN.md). The companion
# ``synchronous=NORMAL`` is the safe-and-fast pairing — fsync only
# at WAL checkpoint boundaries instead of every commit.
@event.listens_for(_engine.sync_engine, "connect")
def _enable_sqlite_wal(dbapi_conn, _record): # noqa: ARG001 — SQLAlchemy contract
cursor = dbapi_conn.cursor()
try:
cursor.execute("PRAGMA journal_mode=WAL;")
cursor.execute("PRAGMA synchronous=NORMAL;")
cursor.execute("PRAGMA foreign_keys=ON;")
finally:
cursor.close()
elif backend == "postgres":
_engine = create_async_engine(
url,
echo=echo,
pool_size=pool_size,
pool_pre_ping=True,
json_serializer=_json_serializer,
)
else:
raise ValueError(f"Unknown persistence backend: {backend!r}")
_session_factory = async_sessionmaker(_engine, expire_on_commit=False)
# Auto-create tables (dev convenience). Production should use Alembic.
from deerflow.persistence.base import Base
# Import all models so Base.metadata discovers them.
# When no models exist yet (scaffolding phase), this is a no-op.
try:
import deerflow.persistence.models # noqa: F401
except ImportError:
# Models package not yet available — tables won't be auto-created.
# This is expected during initial scaffolding or minimal installs.
logger.debug("deerflow.persistence.models not found; skipping auto-create tables")
try:
async with _engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
except Exception as exc:
if backend == "postgres" and "does not exist" in str(exc):
# Database not yet created — attempt to auto-create it, then retry.
await _auto_create_postgres_db(url)
# Rebuild engine against the now-existing database
await _engine.dispose()
_engine = create_async_engine(url, echo=echo, pool_size=pool_size, pool_pre_ping=True, json_serializer=_json_serializer)
_session_factory = async_sessionmaker(_engine, expire_on_commit=False)
async with _engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
else:
raise
logger.info("Persistence engine initialized: backend=%s", backend)
async def init_engine_from_config(config) -> None:
"""Convenience: init engine from a DatabaseConfig object."""
if config.backend == "memory":
await init_engine("memory")
return
await init_engine(
backend=config.backend,
url=config.app_sqlalchemy_url,
echo=config.echo_sql,
pool_size=config.pool_size,
sqlite_dir=config.sqlite_dir if config.backend == "sqlite" else "",
)
def get_session_factory() -> async_sessionmaker[AsyncSession] | None:
"""Return the async session factory, or None if backend=memory."""
return _session_factory
def get_engine() -> AsyncEngine | None:
"""Return the async engine, or None if not initialized."""
return _engine
async def close_engine() -> None:
"""Dispose the engine, release all connections."""
global _engine, _session_factory
if _engine is not None:
await _engine.dispose()
logger.info("Persistence engine closed")
_engine = None
_session_factory = None
@@ -0,0 +1,6 @@
"""Feedback persistence — ORM and SQL repository."""
from deerflow.persistence.feedback.model import FeedbackRow
from deerflow.persistence.feedback.sql import FeedbackRepository
__all__ = ["FeedbackRepository", "FeedbackRow"]
@@ -0,0 +1,30 @@
"""ORM model for user feedback on runs."""
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import DateTime, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from deerflow.persistence.base import Base
class FeedbackRow(Base):
__tablename__ = "feedback"
feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
message_id: Mapped[str | None] = mapped_column(String(64))
# message_id is an optional RunEventStore event identifier —
# allows feedback to target a specific message or the entire run
rating: Mapped[int] = mapped_column(nullable=False)
# +1 (thumbs-up) or -1 (thumbs-down)
comment: Mapped[str | None] = mapped_column(Text)
# Optional text feedback from the user
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
@@ -0,0 +1,139 @@
"""SQLAlchemy-backed feedback storage.
Each method acquires its own short-lived session.
"""
from __future__ import annotations
import uuid
from datetime import UTC, datetime
from sqlalchemy import case, func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.feedback.model import FeedbackRow
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
class FeedbackRepository:
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._sf = session_factory
@staticmethod
def _row_to_dict(row: FeedbackRow) -> dict:
d = row.to_dict()
val = d.get("created_at")
if isinstance(val, datetime):
d["created_at"] = val.isoformat()
return d
async def create(
self,
*,
run_id: str,
thread_id: str,
rating: int,
owner_id: str | None | _AutoSentinel = AUTO,
message_id: str | None = None,
comment: str | None = None,
) -> dict:
"""Create a feedback record. rating must be +1 or -1."""
if rating not in (1, -1):
raise ValueError(f"rating must be +1 or -1, got {rating}")
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.create")
row = FeedbackRow(
feedback_id=str(uuid.uuid4()),
run_id=run_id,
thread_id=thread_id,
owner_id=resolved_owner_id,
message_id=message_id,
rating=rating,
comment=comment,
created_at=datetime.now(UTC),
)
async with self._sf() as session:
session.add(row)
await session.commit()
await session.refresh(row)
return self._row_to_dict(row)
async def get(
self,
feedback_id: str,
*,
owner_id: str | None | _AutoSentinel = AUTO,
) -> dict | None:
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.get")
async with self._sf() as session:
row = await session.get(FeedbackRow, feedback_id)
if row is None:
return None
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
return None
return self._row_to_dict(row)
async def list_by_run(
self,
thread_id: str,
run_id: str,
*,
limit: int = 100,
owner_id: str | None | _AutoSentinel = AUTO,
) -> list[dict]:
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_run")
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
if resolved_owner_id is not None:
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def list_by_thread(
self,
thread_id: str,
*,
limit: int = 100,
owner_id: str | None | _AutoSentinel = AUTO,
) -> list[dict]:
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_thread")
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
if resolved_owner_id is not None:
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def delete(
self,
feedback_id: str,
*,
owner_id: str | None | _AutoSentinel = AUTO,
) -> bool:
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.delete")
async with self._sf() as session:
row = await session.get(FeedbackRow, feedback_id)
if row is None:
return False
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
return False
await session.delete(row)
await session.commit()
return True
async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict:
"""Aggregate feedback stats for a run using database-side counting."""
stmt = select(
func.count().label("total"),
func.coalesce(func.sum(case((FeedbackRow.rating == 1, 1), else_=0)), 0).label("positive"),
func.coalesce(func.sum(case((FeedbackRow.rating == -1, 1), else_=0)), 0).label("negative"),
).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
async with self._sf() as session:
row = (await session.execute(stmt)).one()
return {
"run_id": run_id,
"total": row.total,
"positive": row.positive,
"negative": row.negative,
}
@@ -0,0 +1,38 @@
[alembic]
script_location = %(here)s
# Default URL for offline mode / autogenerate.
# Runtime uses engine from DeerFlow config.
sqlalchemy.url = sqlite+aiosqlite:///./data/app.db
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
@@ -0,0 +1,65 @@
"""Alembic environment for DeerFlow application tables.
ONLY manages DeerFlow's tables (runs, threads_meta, cron_jobs, users).
LangGraph's checkpointer tables are managed by LangGraph itself -- they
have their own schema lifecycle and must not be touched by Alembic.
"""
from __future__ import annotations
import asyncio
import logging
from logging.config import fileConfig
from alembic import context
from sqlalchemy.ext.asyncio import create_async_engine
from deerflow.persistence.base import Base
# Import all models so metadata is populated.
try:
import deerflow.persistence.models # noqa: F401 — register ORM models with Base.metadata
except ImportError:
# Models not available — migration will work with existing metadata only.
logging.getLogger(__name__).warning("Could not import deerflow.persistence.models; Alembic may not detect all tables")
config = context.config
if config.config_file_name is not None:
fileConfig(config.config_file_name)
target_metadata = Base.metadata
def run_migrations_offline() -> None:
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
render_as_batch=True,
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection):
context.configure(
connection=connection,
target_metadata=target_metadata,
render_as_batch=True, # Required for SQLite ALTER TABLE support
)
with context.begin_transaction():
context.run_migrations()
async def run_migrations_online() -> None:
connectable = create_async_engine(config.get_main_option("sqlalchemy.url"))
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
if context.is_offline_mode():
run_migrations_offline()
else:
asyncio.run(run_migrations_online())
@@ -0,0 +1,23 @@
"""ORM model registration entry point.
Importing this module ensures all ORM models are registered with
``Base.metadata`` so Alembic autogenerate detects every table.
The actual ORM classes have moved to entity-specific subpackages:
- ``deerflow.persistence.thread_meta``
- ``deerflow.persistence.run``
- ``deerflow.persistence.feedback``
- ``deerflow.persistence.user``
``RunEventRow`` remains in ``deerflow.persistence.models.run_event`` because
its storage implementation lives in ``deerflow.runtime.events.store.db`` and
there is no matching entity directory.
"""
from deerflow.persistence.feedback.model import FeedbackRow
from deerflow.persistence.models.run_event import RunEventRow
from deerflow.persistence.run.model import RunRow
from deerflow.persistence.thread_meta.model import ThreadMetaRow
from deerflow.persistence.user.model import UserRow
__all__ = ["FeedbackRow", "RunEventRow", "RunRow", "ThreadMetaRow", "UserRow"]
@@ -0,0 +1,35 @@
"""ORM model for run events."""
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import JSON, DateTime, Index, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from deerflow.persistence.base import Base
class RunEventRow(Base):
__tablename__ = "run_events"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
thread_id: Mapped[str] = mapped_column(String(64), nullable=False)
run_id: Mapped[str] = mapped_column(String(64), nullable=False)
# Owner of the conversation this event belongs to. Nullable for data
# created before auth was introduced; populated by auth middleware on
# new writes and by the boot-time orphan migration on existing rows.
owner_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
category: Mapped[str] = mapped_column(String(16), nullable=False)
# "message" | "trace" | "lifecycle"
content: Mapped[str] = mapped_column(Text, default="")
event_metadata: Mapped[dict] = mapped_column(JSON, default=dict)
seq: Mapped[int] = mapped_column(nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
__table_args__ = (
UniqueConstraint("thread_id", "seq", name="uq_events_thread_seq"),
Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"),
Index("ix_events_run", "thread_id", "run_id", "seq"),
)
@@ -0,0 +1,6 @@
"""Run metadata persistence — ORM and SQL repository."""
from deerflow.persistence.run.model import RunRow
from deerflow.persistence.run.sql import RunRepository
__all__ = ["RunRepository", "RunRow"]
@@ -0,0 +1,49 @@
"""ORM model for run metadata."""
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import JSON, DateTime, Index, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from deerflow.persistence.base import Base
class RunRow(Base):
__tablename__ = "runs"
run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
assistant_id: Mapped[str | None] = mapped_column(String(128))
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
status: Mapped[str] = mapped_column(String(20), default="pending")
# "pending" | "running" | "success" | "error" | "timeout" | "interrupted"
model_name: Mapped[str | None] = mapped_column(String(128))
multitask_strategy: Mapped[str] = mapped_column(String(20), default="reject")
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
kwargs_json: Mapped[dict] = mapped_column(JSON, default=dict)
error: Mapped[str | None] = mapped_column(Text)
# Convenience fields (for listing pages without querying RunEventStore)
message_count: Mapped[int] = mapped_column(default=0)
first_human_message: Mapped[str | None] = mapped_column(Text)
last_ai_message: Mapped[str | None] = mapped_column(Text)
# Token usage (accumulated in-memory by RunJournal, written on run completion)
total_input_tokens: Mapped[int] = mapped_column(default=0)
total_output_tokens: Mapped[int] = mapped_column(default=0)
total_tokens: Mapped[int] = mapped_column(default=0)
llm_call_count: Mapped[int] = mapped_column(default=0)
lead_agent_tokens: Mapped[int] = mapped_column(default=0)
subagent_tokens: Mapped[int] = mapped_column(default=0)
middleware_tokens: Mapped[int] = mapped_column(default=0)
# Follow-up association
follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64))
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))
__table_args__ = (Index("ix_runs_thread_status", "thread_id", "status"),)
@@ -0,0 +1,255 @@
"""SQLAlchemy-backed RunStore implementation.
Each method acquires and releases its own short-lived session.
Run status updates happen from background workers that may live
minutes -- we don't hold connections across long execution.
"""
from __future__ import annotations
import json
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.run.model import RunRow
from deerflow.runtime.runs.store.base import RunStore
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
class RunRepository(RunStore):
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._sf = session_factory
@staticmethod
def _safe_json(obj: Any) -> Any:
"""Ensure obj is JSON-serializable. Falls back to model_dump() or str()."""
if obj is None:
return None
if isinstance(obj, (str, int, float, bool)):
return obj
if isinstance(obj, dict):
return {k: RunRepository._safe_json(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return [RunRepository._safe_json(v) for v in obj]
if hasattr(obj, "model_dump"):
try:
return obj.model_dump()
except Exception:
pass
if hasattr(obj, "dict"):
try:
return obj.dict()
except Exception:
pass
try:
json.dumps(obj)
return obj
except (TypeError, ValueError):
return str(obj)
@staticmethod
def _row_to_dict(row: RunRow) -> dict[str, Any]:
d = row.to_dict()
# Remap JSON columns to match RunStore interface
d["metadata"] = d.pop("metadata_json", {})
d["kwargs"] = d.pop("kwargs_json", {})
# Convert datetime to ISO string for consistency with MemoryRunStore
for key in ("created_at", "updated_at"):
val = d.get(key)
if isinstance(val, datetime):
d[key] = val.isoformat()
return d
async def put(
self,
run_id,
*,
thread_id,
assistant_id=None,
owner_id: str | None | _AutoSentinel = AUTO,
status="pending",
multitask_strategy="reject",
metadata=None,
kwargs=None,
error=None,
created_at=None,
follow_up_to_run_id=None,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.put")
now = datetime.now(UTC)
row = RunRow(
run_id=run_id,
thread_id=thread_id,
assistant_id=assistant_id,
owner_id=resolved_owner_id,
status=status,
multitask_strategy=multitask_strategy,
metadata_json=self._safe_json(metadata) or {},
kwargs_json=self._safe_json(kwargs) or {},
error=error,
follow_up_to_run_id=follow_up_to_run_id,
created_at=datetime.fromisoformat(created_at) if created_at else now,
updated_at=now,
)
async with self._sf() as session:
session.add(row)
await session.commit()
async def get(
self,
run_id,
*,
owner_id: str | None | _AutoSentinel = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.get")
async with self._sf() as session:
row = await session.get(RunRow, run_id)
if row is None:
return None
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
return None
return self._row_to_dict(row)
async def list_by_thread(
self,
thread_id,
*,
owner_id: str | None | _AutoSentinel = AUTO,
limit=100,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.list_by_thread")
stmt = select(RunRow).where(RunRow.thread_id == thread_id)
if resolved_owner_id is not None:
stmt = stmt.where(RunRow.owner_id == resolved_owner_id)
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def update_status(self, run_id, status, *, error=None):
values: dict[str, Any] = {"status": status, "updated_at": datetime.now(UTC)}
if error is not None:
values["error"] = error
async with self._sf() as session:
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
await session.commit()
async def delete(
self,
run_id,
*,
owner_id: str | None | _AutoSentinel = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.delete")
async with self._sf() as session:
row = await session.get(RunRow, run_id)
if row is None:
return
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
return
await session.delete(row)
await session.commit()
async def list_pending(self, *, before=None):
if before is None:
before_dt = datetime.now(UTC)
elif isinstance(before, datetime):
before_dt = before
else:
before_dt = datetime.fromisoformat(before)
stmt = select(RunRow).where(RunRow.status == "pending", RunRow.created_at <= before_dt).order_by(RunRow.created_at.asc())
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def update_run_completion(
self,
run_id: str,
*,
status: str,
total_input_tokens: int = 0,
total_output_tokens: int = 0,
total_tokens: int = 0,
llm_call_count: int = 0,
lead_agent_tokens: int = 0,
subagent_tokens: int = 0,
middleware_tokens: int = 0,
message_count: int = 0,
last_ai_message: str | None = None,
first_human_message: str | None = None,
error: str | None = None,
) -> None:
"""Update status + token usage + convenience fields on run completion."""
values: dict[str, Any] = {
"status": status,
"total_input_tokens": total_input_tokens,
"total_output_tokens": total_output_tokens,
"total_tokens": total_tokens,
"llm_call_count": llm_call_count,
"lead_agent_tokens": lead_agent_tokens,
"subagent_tokens": subagent_tokens,
"middleware_tokens": middleware_tokens,
"message_count": message_count,
"updated_at": datetime.now(UTC),
}
if last_ai_message is not None:
values["last_ai_message"] = last_ai_message[:2000]
if first_human_message is not None:
values["first_human_message"] = first_human_message[:2000]
if error is not None:
values["error"] = error
async with self._sf() as session:
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
await session.commit()
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
"""Aggregate token usage via a single SQL GROUP BY query."""
_completed = RunRow.status.in_(("success", "error"))
_thread = RunRow.thread_id == thread_id
stmt = (
select(
func.coalesce(RunRow.model_name, "unknown").label("model"),
func.count().label("runs"),
func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"),
func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"),
func.coalesce(func.sum(RunRow.total_output_tokens), 0).label("total_output_tokens"),
func.coalesce(func.sum(RunRow.lead_agent_tokens), 0).label("lead_agent"),
func.coalesce(func.sum(RunRow.subagent_tokens), 0).label("subagent"),
func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"),
)
.where(_thread, _completed)
.group_by(func.coalesce(RunRow.model_name, "unknown"))
)
async with self._sf() as session:
rows = (await session.execute(stmt)).all()
total_tokens = total_input = total_output = total_runs = 0
lead_agent = subagent = middleware = 0
by_model: dict[str, dict] = {}
for r in rows:
by_model[r.model] = {"tokens": r.total_tokens, "runs": r.runs}
total_tokens += r.total_tokens
total_input += r.total_input_tokens
total_output += r.total_output_tokens
total_runs += r.runs
lead_agent += r.lead_agent
subagent += r.subagent
middleware += r.middleware
return {
"total_tokens": total_tokens,
"total_input_tokens": total_input,
"total_output_tokens": total_output,
"total_runs": total_runs,
"by_model": by_model,
"by_caller": {
"lead_agent": lead_agent,
"subagent": subagent,
"middleware": middleware,
},
}
@@ -0,0 +1,13 @@
"""Thread metadata persistence — ORM, abstract store, and concrete implementations."""
from deerflow.persistence.thread_meta.base import ThreadMetaStore
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
from deerflow.persistence.thread_meta.model import ThreadMetaRow
from deerflow.persistence.thread_meta.sql import ThreadMetaRepository
__all__ = [
"MemoryThreadMetaStore",
"ThreadMetaRepository",
"ThreadMetaRow",
"ThreadMetaStore",
]
@@ -0,0 +1,60 @@
"""Abstract interface for thread metadata storage.
Implementations:
- ThreadMetaRepository: SQL-backed (sqlite / postgres via SQLAlchemy)
- MemoryThreadMetaStore: wraps LangGraph BaseStore (memory mode)
"""
from __future__ import annotations
import abc
class ThreadMetaStore(abc.ABC):
@abc.abstractmethod
async def create(
self,
thread_id: str,
*,
assistant_id: str | None = None,
owner_id: str | None = None,
display_name: str | None = None,
metadata: dict | None = None,
) -> dict:
pass
@abc.abstractmethod
async def get(self, thread_id: str) -> dict | None:
pass
@abc.abstractmethod
async def search(
self,
*,
metadata: dict | None = None,
status: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[dict]:
pass
@abc.abstractmethod
async def update_display_name(self, thread_id: str, display_name: str) -> None:
pass
@abc.abstractmethod
async def update_status(self, thread_id: str, status: str) -> None:
pass
@abc.abstractmethod
async def update_metadata(self, thread_id: str, metadata: dict) -> None:
"""Merge ``metadata`` into the thread's metadata field.
Existing keys are overwritten by the new values; keys absent from
``metadata`` are preserved. No-op if the thread does not exist.
"""
pass
@abc.abstractmethod
async def delete(self, thread_id: str) -> None:
pass
@@ -0,0 +1,120 @@
"""In-memory ThreadMetaStore backed by LangGraph BaseStore.
Used when database.backend=memory. Delegates to the LangGraph Store's
``("threads",)`` namespace — the same namespace used by the Gateway
router for thread records.
"""
from __future__ import annotations
import time
from typing import Any
from langgraph.store.base import BaseStore
from deerflow.persistence.thread_meta.base import ThreadMetaStore
THREADS_NS: tuple[str, ...] = ("threads",)
class MemoryThreadMetaStore(ThreadMetaStore):
def __init__(self, store: BaseStore) -> None:
self._store = store
async def create(
self,
thread_id: str,
*,
assistant_id: str | None = None,
owner_id: str | None = None,
display_name: str | None = None,
metadata: dict | None = None,
) -> dict:
now = time.time()
record: dict[str, Any] = {
"thread_id": thread_id,
"assistant_id": assistant_id,
"owner_id": owner_id,
"display_name": display_name,
"status": "idle",
"metadata": metadata or {},
"values": {},
"created_at": now,
"updated_at": now,
}
await self._store.aput(THREADS_NS, thread_id, record)
return record
async def get(self, thread_id: str) -> dict | None:
item = await self._store.aget(THREADS_NS, thread_id)
return item.value if item is not None else None
async def search(
self,
*,
metadata: dict | None = None,
status: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[dict]:
filter_dict: dict[str, Any] = {}
if metadata:
filter_dict.update(metadata)
if status:
filter_dict["status"] = status
items = await self._store.asearch(
THREADS_NS,
filter=filter_dict or None,
limit=limit,
offset=offset,
)
return [self._item_to_dict(item) for item in items]
async def update_display_name(self, thread_id: str, display_name: str) -> None:
item = await self._store.aget(THREADS_NS, thread_id)
if item is None:
return
record = dict(item.value)
record["display_name"] = display_name
record["updated_at"] = time.time()
await self._store.aput(THREADS_NS, thread_id, record)
async def update_status(self, thread_id: str, status: str) -> None:
item = await self._store.aget(THREADS_NS, thread_id)
if item is None:
return
record = dict(item.value)
record["status"] = status
record["updated_at"] = time.time()
await self._store.aput(THREADS_NS, thread_id, record)
async def update_metadata(self, thread_id: str, metadata: dict) -> None:
"""Merge ``metadata`` into the in-memory record. No-op if absent."""
item = await self._store.aget(THREADS_NS, thread_id)
if item is None:
return
record = dict(item.value)
merged = dict(record.get("metadata") or {})
merged.update(metadata)
record["metadata"] = merged
record["updated_at"] = time.time()
await self._store.aput(THREADS_NS, thread_id, record)
async def delete(self, thread_id: str) -> None:
await self._store.adelete(THREADS_NS, thread_id)
@staticmethod
def _item_to_dict(item) -> dict[str, Any]:
"""Convert a Store SearchItem to the dict format expected by callers."""
val = item.value
return {
"thread_id": item.key,
"assistant_id": val.get("assistant_id"),
"owner_id": val.get("owner_id"),
"display_name": val.get("display_name"),
"status": val.get("status", "idle"),
"metadata": val.get("metadata", {}),
"created_at": str(val.get("created_at", "")),
"updated_at": str(val.get("updated_at", "")),
}
@@ -0,0 +1,23 @@
"""ORM model for thread metadata."""
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import JSON, DateTime, String
from sqlalchemy.orm import Mapped, mapped_column
from deerflow.persistence.base import Base
class ThreadMetaRow(Base):
__tablename__ = "threads_meta"
thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
assistant_id: Mapped[str | None] = mapped_column(String(128), index=True)
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
display_name: Mapped[str | None] = mapped_column(String(256))
status: Mapped[str] = mapped_column(String(20), default="idle")
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))
@@ -0,0 +1,223 @@
"""SQLAlchemy-backed thread metadata repository."""
from __future__ import annotations
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.thread_meta.base import ThreadMetaStore
from deerflow.persistence.thread_meta.model import ThreadMetaRow
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
class ThreadMetaRepository(ThreadMetaStore):
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._sf = session_factory
@staticmethod
def _row_to_dict(row: ThreadMetaRow) -> dict[str, Any]:
d = row.to_dict()
d["metadata"] = d.pop("metadata_json", {})
for key in ("created_at", "updated_at"):
val = d.get(key)
if isinstance(val, datetime):
d[key] = val.isoformat()
return d
async def create(
self,
thread_id: str,
*,
assistant_id: str | None = None,
owner_id: str | None | _AutoSentinel = AUTO,
display_name: str | None = None,
metadata: dict | None = None,
) -> dict:
# Auto-resolve owner_id from contextvar when AUTO; explicit None
# creates an orphan row (used by migration scripts).
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.create")
now = datetime.now(UTC)
row = ThreadMetaRow(
thread_id=thread_id,
assistant_id=assistant_id,
owner_id=resolved_owner_id,
display_name=display_name,
metadata_json=metadata or {},
created_at=now,
updated_at=now,
)
async with self._sf() as session:
session.add(row)
await session.commit()
await session.refresh(row)
return self._row_to_dict(row)
async def get(
self,
thread_id: str,
*,
owner_id: str | None | _AutoSentinel = AUTO,
) -> dict | None:
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.get")
async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id)
if row is None:
return None
# Enforce owner filter unless explicitly bypassed (owner_id=None).
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
return None
return self._row_to_dict(row)
async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]:
stmt = select(ThreadMetaRow).where(ThreadMetaRow.owner_id == owner_id).order_by(ThreadMetaRow.updated_at.desc()).limit(limit).offset(offset)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def check_access(self, thread_id: str, owner_id: str, *, require_existing: bool = False) -> bool:
"""Check if ``owner_id`` has access to ``thread_id``.
Two modes — one row, two distinct semantics depending on what
the caller is about to do:
- ``require_existing=False`` (default, permissive):
Returns True for: row missing (untracked legacy thread),
``row.owner_id`` is None (shared / pre-auth data),
or ``row.owner_id == owner_id``. Use for **read-style**
decorators where treating an untracked thread as accessible
preserves backward-compat.
- ``require_existing=True`` (strict):
Returns True **only** when the row exists AND
(``row.owner_id == owner_id`` OR ``row.owner_id is None``).
Use for **destructive / mutating** decorators (DELETE, PATCH,
state-update) so a thread that has *already been deleted*
cannot be re-targeted by any caller — closing the
delete-idempotence cross-user gap where the row vanishing
made every other user appear to "own" it.
"""
async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id)
if row is None:
return not require_existing
if row.owner_id is None:
return True
return row.owner_id == owner_id
async def search(
self,
*,
metadata: dict | None = None,
status: str | None = None,
limit: int = 100,
offset: int = 0,
owner_id: str | None | _AutoSentinel = AUTO,
) -> list[dict]:
"""Search threads with optional metadata and status filters.
Owner filter is enforced by default: caller must be in a user
context. Pass ``owner_id=None`` to bypass (migration/CLI).
"""
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.search")
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
if resolved_owner_id is not None:
stmt = stmt.where(ThreadMetaRow.owner_id == resolved_owner_id)
if status:
stmt = stmt.where(ThreadMetaRow.status == status)
if metadata:
# When metadata filter is active, fetch a larger window and filter
# in Python. TODO(Phase 2): use JSON DB operators (Postgres @>,
# SQLite json_extract) for server-side filtering.
stmt = stmt.limit(limit * 5 + offset)
async with self._sf() as session:
result = await session.execute(stmt)
rows = [self._row_to_dict(r) for r in result.scalars()]
rows = [r for r in rows if all(r.get("metadata", {}).get(k) == v for k, v in metadata.items())]
return rows[offset : offset + limit]
else:
stmt = stmt.limit(limit).offset(offset)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_owner_id: str | None) -> bool:
"""Return True if the row exists and is owned (or filter bypassed)."""
if resolved_owner_id is None:
return True # explicit bypass
row = await session.get(ThreadMetaRow, thread_id)
return row is not None and row.owner_id == resolved_owner_id
async def update_display_name(
self,
thread_id: str,
display_name: str,
*,
owner_id: str | None | _AutoSentinel = AUTO,
) -> None:
"""Update the display_name (title) for a thread."""
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_display_name")
async with self._sf() as session:
if not await self._check_ownership(session, thread_id, resolved_owner_id):
return
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
await session.commit()
async def update_status(
self,
thread_id: str,
status: str,
*,
owner_id: str | None | _AutoSentinel = AUTO,
) -> None:
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_status")
async with self._sf() as session:
if not await self._check_ownership(session, thread_id, resolved_owner_id):
return
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
await session.commit()
async def update_metadata(
self,
thread_id: str,
metadata: dict,
*,
owner_id: str | None | _AutoSentinel = AUTO,
) -> None:
"""Merge ``metadata`` into ``metadata_json``.
Read-modify-write inside a single session/transaction so concurrent
callers see consistent state. No-op if the row does not exist or
the owner_id check fails.
"""
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_metadata")
async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id)
if row is None:
return
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
return
merged = dict(row.metadata_json or {})
merged.update(metadata)
row.metadata_json = merged
row.updated_at = datetime.now(UTC)
await session.commit()
async def delete(
self,
thread_id: str,
*,
owner_id: str | None | _AutoSentinel = AUTO,
) -> None:
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.delete")
async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id)
if row is None:
return
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
return
await session.delete(row)
await session.commit()
@@ -0,0 +1,12 @@
"""User storage subpackage.
Holds the ORM model for the ``users`` table. The concrete repository
implementation (``SQLiteUserRepository``) lives in the app layer
(``app.gateway.auth.repositories.sqlite``) because it converts
between the ORM row and the auth module's pydantic ``User`` class.
This keeps the harness package free of any dependency on app code.
"""
from deerflow.persistence.user.model import UserRow
__all__ = ["UserRow"]
@@ -0,0 +1,59 @@
"""ORM model for the users table.
Lives in the harness persistence package so it is picked up by
``Base.metadata.create_all()`` alongside ``threads_meta``, ``runs``,
``run_events``, and ``feedback``. Using the shared engine means:
- One SQLite/Postgres database, one connection pool
- One schema initialisation codepath
- Consistent async sessions across auth and persistence reads
"""
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import Boolean, DateTime, Index, String, text
from sqlalchemy.orm import Mapped, mapped_column
from deerflow.persistence.base import Base
class UserRow(Base):
__tablename__ = "users"
# UUIDs are stored as 36-char strings for cross-backend portability.
id: Mapped[str] = mapped_column(String(36), primary_key=True)
email: Mapped[str] = mapped_column(String(320), unique=True, nullable=False, index=True)
password_hash: Mapped[str | None] = mapped_column(String(128), nullable=True)
# "admin" | "user" — kept as plain string to avoid ALTER TABLE pain
# when new roles are introduced.
system_role: Mapped[str] = mapped_column(String(16), nullable=False, default="user")
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
default=lambda: datetime.now(UTC),
)
# OAuth linkage (optional). A partial unique index enforces one
# account per (provider, oauth_id) pair, leaving NULL/NULL rows
# unconstrained so plain password accounts can coexist.
oauth_provider: Mapped[str | None] = mapped_column(String(32), nullable=True)
oauth_id: Mapped[str | None] = mapped_column(String(128), nullable=True)
# Auth lifecycle flags
needs_setup: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
token_version: Mapped[int] = mapped_column(nullable=False, default=0)
__table_args__ = (
Index(
"idx_users_oauth_identity",
"oauth_provider",
"oauth_id",
unique=True,
sqlite_where=text("oauth_provider IS NOT NULL AND oauth_id IS NOT NULL"),
),
)
@@ -5,7 +5,7 @@ Re-exports the public API of :mod:`~deerflow.runtime.runs` and
directly from ``deerflow.runtime``.
"""
from .runs import ConflictError, DisconnectMode, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
from .runs import ConflictError, DisconnectMode, RunContext, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
from .serialization import serialize, serialize_channel_values, serialize_lc_object, serialize_messages_tuple
from .store import get_store, make_store, reset_store, store_context
from .stream_bridge import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, StreamBridge, StreamEvent, make_stream_bridge
@@ -14,6 +14,7 @@ __all__ = [
# runs
"ConflictError",
"DisconnectMode",
"RunContext",
"RunManager",
"RunRecord",
"RunStatus",
@@ -0,0 +1,134 @@
"""Pure functions to convert LangChain message objects to OpenAI Chat Completions format.
Used by RunJournal to build content dicts for event storage.
"""
from __future__ import annotations
import json
from typing import Any
_ROLE_MAP = {
"human": "user",
"ai": "assistant",
"system": "system",
"tool": "tool",
}
def langchain_to_openai_message(message: Any) -> dict:
"""Convert a single LangChain BaseMessage to an OpenAI message dict.
Handles:
- HumanMessage → {"role": "user", "content": "..."}
- AIMessage (text only) → {"role": "assistant", "content": "..."}
- AIMessage (with tool_calls) → {"role": "assistant", "content": null, "tool_calls": [...]}
- AIMessage (text + tool_calls) → both content and tool_calls present
- AIMessage (list content / multimodal) → content preserved as list
- SystemMessage → {"role": "system", "content": "..."}
- ToolMessage → {"role": "tool", "tool_call_id": "...", "content": "..."}
"""
msg_type = getattr(message, "type", "")
role = _ROLE_MAP.get(msg_type, msg_type)
content = getattr(message, "content", "")
if role == "tool":
return {
"role": "tool",
"tool_call_id": getattr(message, "tool_call_id", ""),
"content": content,
}
if role == "assistant":
tool_calls = getattr(message, "tool_calls", None) or []
result: dict = {"role": "assistant"}
if tool_calls:
openai_tool_calls = []
for tc in tool_calls:
args = tc.get("args", {})
openai_tool_calls.append(
{
"id": tc.get("id", ""),
"type": "function",
"function": {
"name": tc.get("name", ""),
"arguments": json.dumps(args) if not isinstance(args, str) else args,
},
}
)
# If no text content, set content to null per OpenAI spec
result["content"] = content if (isinstance(content, list) and content) or (isinstance(content, str) and content) else None
result["tool_calls"] = openai_tool_calls
else:
result["content"] = content
return result
# user / system / unknown
return {"role": role, "content": content}
def _infer_finish_reason(message: Any) -> str:
"""Infer OpenAI finish_reason from an AIMessage.
Returns "tool_calls" if tool_calls present, else looks in
response_metadata.finish_reason, else returns "stop".
"""
tool_calls = getattr(message, "tool_calls", None) or []
if tool_calls:
return "tool_calls"
resp_meta = getattr(message, "response_metadata", None) or {}
if isinstance(resp_meta, dict):
finish = resp_meta.get("finish_reason")
if finish:
return finish
return "stop"
def langchain_to_openai_completion(message: Any) -> dict:
"""Convert an AIMessage and its metadata to an OpenAI completion response dict.
Returns:
{
"id": message.id,
"model": message.response_metadata.get("model_name"),
"choices": [{"index": 0, "message": <openai_message>, "finish_reason": <inferred>}],
"usage": {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...} or None,
}
"""
resp_meta = getattr(message, "response_metadata", None) or {}
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
openai_msg = langchain_to_openai_message(message)
finish_reason = _infer_finish_reason(message)
usage_metadata = getattr(message, "usage_metadata", None)
if usage_metadata is not None:
input_tokens = usage_metadata.get("input_tokens", 0) or 0
output_tokens = usage_metadata.get("output_tokens", 0) or 0
usage: dict | None = {
"prompt_tokens": input_tokens,
"completion_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens,
}
else:
usage = None
return {
"id": getattr(message, "id", None),
"model": model_name,
"choices": [
{
"index": 0,
"message": openai_msg,
"finish_reason": finish_reason,
}
],
"usage": usage,
}
def langchain_messages_to_openai(messages: list) -> list[dict]:
"""Convert a list of LangChain BaseMessages to OpenAI message dicts."""
return [langchain_to_openai_message(m) for m in messages]
@@ -0,0 +1,4 @@
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.events.store.memory import MemoryRunEventStore
__all__ = ["MemoryRunEventStore", "RunEventStore"]
@@ -0,0 +1,26 @@
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.events.store.memory import MemoryRunEventStore
def make_run_event_store(config=None) -> RunEventStore:
"""Create a RunEventStore based on run_events.backend configuration."""
if config is None or config.backend == "memory":
return MemoryRunEventStore()
if config.backend == "db":
from deerflow.persistence.engine import get_session_factory
sf = get_session_factory()
if sf is None:
# database.backend=memory but run_events.backend=db -> fallback
return MemoryRunEventStore()
from deerflow.runtime.events.store.db import DbRunEventStore
return DbRunEventStore(sf, max_trace_content=config.max_trace_content)
if config.backend == "jsonl":
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
return JsonlRunEventStore()
raise ValueError(f"Unknown run_events backend: {config.backend!r}")
__all__ = ["MemoryRunEventStore", "RunEventStore", "make_run_event_store"]

Some files were not shown because too many files have changed in this diff Show More