mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 23:46:50 +00:00
feat(auth): release-validation pass for 2.0-rc — 12 blockers + simplify follow-ups (#2008)
* feat(auth): introduce backend auth module
Port RFC-001 authentication core from PR #1728:
- JWT token handling (create_access_token, decode_token, TokenPayload)
- Password hashing (bcrypt) with verify_password
- SQLite UserRepository with base interface
- Provider Factory pattern (LocalAuthProvider)
- CLI reset_admin tool
- Auth-specific errors (AuthErrorCode, TokenError, AuthErrorResponse)
Deps:
- bcrypt>=4.0.0
- pyjwt>=2.9.0
- email-validator>=2.0.0
- backend/uv.toml pins public PyPI index
Tests: 12 pure unit tests (test_auth_config.py, test_auth_errors.py).
Scope note: authz.py, test_auth.py, and test_auth_type_system.py are
deferred to commit 2 because they depend on middleware and deps wiring
that is not yet in place. Commit 1 stays "pure new files only" as the
spec mandates.
* feat(auth): wire auth end-to-end (middleware + frontend replacement)
Backend:
- Port auth_middleware, csrf_middleware, langgraph_auth, routers/auth
- Port authz decorator (owner_filter_key defaults to 'owner_id')
- Merge app.py: register AuthMiddleware + CSRFMiddleware + CORS, add
_ensure_admin_user lifespan hook, _migrate_orphaned_threads helper,
register auth router
- Merge deps.py: add get_local_provider, get_current_user_from_request,
get_optional_user_from_request; keep get_current_user as thin str|None
adapter for feedback router
- langgraph.json: add auth path pointing to langgraph_auth.py:auth
- Rename metadata['user_id'] -> metadata['owner_id'] in langgraph_auth
(both metadata write and LangGraph filter dict) + test fixtures
Frontend:
- Delete better-auth library and api catch-all route
- Remove better-auth npm dependency and env vars (BETTER_AUTH_SECRET,
BETTER_AUTH_GITHUB_*) from env.js
- Port frontend/src/core/auth/* (AuthProvider, gateway-config,
proxy-policy, server-side getServerSideUser, types)
- Port frontend/src/core/api/fetcher.ts
- Port (auth)/layout, (auth)/login, (auth)/setup pages
- Rewrite workspace/layout.tsx as server component that calls
getServerSideUser and wraps in AuthProvider
- Port workspace/workspace-content.tsx for the client-side sidebar logic
Tests:
- Port 5 auth test files (test_auth, test_auth_middleware,
test_auth_type_system, test_ensure_admin, test_langgraph_auth)
- 176 auth tests PASS
After this commit: login/logout/registration flow works, but persistence
layer does not yet filter by owner_id. Commit 4 closes that gap.
* feat(auth): account settings page + i18n
- Port account-settings-page.tsx (change password, change email, logout)
- Wire into settings-dialog.tsx as new "account" section with UserIcon,
rendered first in the section list
- Add i18n keys:
- en-US/zh-CN: settings.sections.account ("Account" / "账号")
- en-US/zh-CN: button.logout ("Log out" / "退出登录")
- types.ts: matching type declarations
* feat(auth): enforce owner_id across 2.0-rc persistence layer
Add request-scoped contextvar-based owner filtering to threads_meta,
runs, run_events, and feedback repositories. Router code is unchanged
— isolation is enforced at the storage layer so that any caller that
forgets to pass owner_id still gets filtered results, and new routes
cannot accidentally leak data.
Core infrastructure
-------------------
- deerflow/runtime/user_context.py (new):
- ContextVar[CurrentUser | None] with default None
- runtime_checkable CurrentUser Protocol (structural subtype with .id)
- set/reset/get/require helpers
- AUTO sentinel + resolve_owner_id(value, method_name) for sentinel
three-state resolution: AUTO reads contextvar, explicit str
overrides, explicit None bypasses the filter (for migration/CLI)
Repository changes
------------------
- ThreadMetaRepository: create/get/search/update_*/delete gain
owner_id=AUTO kwarg; read paths filter by owner, writes stamp it,
mutations check ownership before applying
- RunRepository: put/get/list_by_thread/delete gain owner_id=AUTO kwarg
- FeedbackRepository: create/get/list_by_run/list_by_thread/delete
gain owner_id=AUTO kwarg
- DbRunEventStore: list_messages/list_events/list_messages_by_run/
count_messages/delete_by_thread/delete_by_run gain owner_id=AUTO
kwarg. Write paths (put/put_batch) read contextvar softly: when a
request-scoped user is available, owner_id is stamped; background
worker writes without a user context pass None which is valid
(orphan row to be bound by migration)
Schema
------
- persistence/models/run_event.py: RunEventRow.owner_id = Mapped[
str | None] = mapped_column(String(64), nullable=True, index=True)
- No alembic migration needed: 2.0 ships fresh, Base.metadata.create_all
picks up the new column automatically
Middleware
----------
- auth_middleware.py: after cookie check, call get_optional_user_from_
request to load the real User, stamp it into request.state.user AND
the contextvar via set_current_user, reset in a try/finally. Public
paths and unauthenticated requests continue without contextvar, and
@require_auth handles the strict 401 path
Test infrastructure
-------------------
- tests/conftest.py: @pytest.fixture(autouse=True) _auto_user_context
sets a default SimpleNamespace(id="test-user-autouse") on every test
unless marked @pytest.mark.no_auto_user. Keeps existing 20+
persistence tests passing without modification
- pyproject.toml [tool.pytest.ini_options]: register no_auto_user
marker so pytest does not emit warnings for opt-out tests
- tests/test_user_context.py: 6 tests covering three-state semantics,
Protocol duck typing, and require/optional APIs
- tests/test_thread_meta_repo.py: one test updated to pass owner_id=
None explicitly where it was previously relying on the old default
Test results
------------
- test_user_context.py: 6 passed
- test_auth*.py + test_langgraph_auth.py + test_ensure_admin.py: 127
- test_run_event_store / test_run_repository / test_thread_meta_repo
/ test_feedback: 92 passed
- Full backend suite: 1905 passed, 2 failed (both @requires_llm flaky
integration tests unrelated to auth), 1 skipped
* feat(auth): extend orphan migration to 2.0-rc persistence tables
_ensure_admin_user now runs a three-step pipeline on every boot:
Step 1 (fatal): admin user exists / is created / password is reset
Step 2 (non-fatal): LangGraph store orphan threads → admin
Step 3 (non-fatal): SQL persistence tables → admin
- threads_meta
- runs
- run_events
- feedback
Each step is idempotent. The fatal/non-fatal split mirrors PR #1728's
original philosophy: admin creation failure blocks startup (the system
is unusable without an admin), whereas migration failures log a warning
and let the service proceed (a partial migration is recoverable; a
missing admin is not).
Key helpers
-----------
- _iter_store_items(store, namespace, *, page_size=500):
async generator that cursor-paginates across LangGraph store pages.
Fixes PR #1728's hardcoded limit=1000 bug that would silently lose
orphans beyond the first page.
- _migrate_orphaned_threads(store, admin_user_id):
Rewritten to use _iter_store_items. Returns the migrated count so the
caller can log it; raises only on unhandled exceptions.
- _migrate_orphan_sql_tables(admin_user_id):
Imports the 4 ORM models lazily, grabs the shared session factory,
runs one UPDATE per table in a single transaction, commits once.
No-op when no persistence backend is configured (in-memory dev).
Tests: test_ensure_admin.py (8 passed)
* test(auth): port AUTH test plan docs + lint/format pass
- Port backend/docs/AUTH_TEST_PLAN.md and AUTH_UPGRADE.md from PR #1728
- Rename metadata.user_id → metadata.owner_id in AUTH_TEST_PLAN.md
(4 occurrences from the original PR doc)
- ruff auto-fix UP037 in sentinel type annotations: drop quotes around
"str | None | _AutoSentinel" now that from __future__ import
annotations makes them implicit string forms
- ruff format: 2 files (app/gateway/app.py, runtime/user_context.py)
Note on test coverage additions:
- conftest.py autouse fixture was already added in commit 4 (had to
be co-located with the repository changes to keep pre-existing
persistence tests passing)
- cross-user isolation E2E tests (test_owner_isolation.py) deferred
— enforcement is already proven by the 98-test repository suite
via the autouse fixture + explicit _AUTO sentinel exercises
- New test cases (TC-API-17..20, TC-ATK-13, TC-MIG-01..07) listed
in AUTH_TEST_PLAN.md are deferred to a follow-up PR — they are
manual-QA test cases rather than pytest code, and the spec-level
coverage is already met by test_user_context.py + the 98-test
repository suite.
Final test results:
- Auth suite (test_auth*, test_langgraph_auth, test_ensure_admin,
test_user_context): 186 passed
- Persistence suite (test_run_event_store, test_run_repository,
test_thread_meta_repo, test_feedback): 98 passed
- Lint: ruff check + ruff format both clean
* test(auth): add cross-user isolation test suite
10 tests exercising the storage-layer owner filter by manually
switching the user_context contextvar between two users. Verifies
the safety invariant:
After a repository write with owner_id=A, a subsequent read with
owner_id=B must not return the row, and vice versa.
Covers all 4 tables that own user-scoped data:
TC-API-17 threads_meta — read, search, update, delete cross-user
TC-API-18 runs — get, list_by_thread, delete cross-user
TC-API-19 run_events — list_messages, list_events, count_messages,
delete_by_thread (CRITICAL: raw conversation
content leak vector)
TC-API-20 feedback — get, list_by_run, delete cross-user
Plus two meta-tests verifying the sentinel pattern itself:
- AUTO + unset contextvar raises RuntimeError
- explicit owner_id=None bypasses the filter (migration escape hatch)
Architecture note
-----------------
These tests bypass the HTTP layer by design. The full chain
(cookie → middleware → contextvar → repository) is covered piecewise:
- test_auth_middleware.py: middleware sets contextvar from cookies
- test_owner_isolation.py: repositories enforce isolation when
contextvar is set to different users
Together they prove the end-to-end safety property without the
ceremony of spinning up a full TestClient + in-memory DB for every
router endpoint.
Tests pass: 231 (full auth + persistence + isolation suite)
Lint: clean
* refactor(auth): migrate user repository to SQLAlchemy ORM
Move the users table into the shared persistence engine so auth
matches the pattern of threads_meta, runs, run_events, and feedback —
one engine, one session factory, one schema init codepath.
New files
---------
- persistence/user/__init__.py, persistence/user/model.py: UserRow
ORM class with partial unique index on (oauth_provider, oauth_id)
- Registered in persistence/models/__init__.py so
Base.metadata.create_all() picks it up
Modified
--------
- auth/repositories/sqlite.py: rewritten as async SQLAlchemy,
identical constructor pattern to the other four repositories
(def __init__(self, session_factory) + self._sf = session_factory)
- auth/config.py: drop users_db_path field — storage is configured
through config.database like every other table
- deps.py/get_local_provider: construct SQLiteUserRepository with
the shared session factory, fail fast if engine is not initialised
- tests/test_auth.py: rewrite test_sqlite_round_trip_new_fields to
use the shared engine (init_engine + close_engine in a tempdir)
- tests/test_auth_type_system.py: add per-test autouse fixture that
spins up a scratch engine and resets deps._cached_* singletons
* refactor(auth): remove SQL orphan migration (unused in supported scenarios)
The _migrate_orphan_sql_tables helper existed to bind NULL owner_id
rows in threads_meta, runs, run_events, and feedback to the admin on
first boot. But in every supported upgrade path, it's a no-op:
1. Fresh install: create_all builds fresh tables, no legacy rows
2. No-auth → with-auth (no existing persistence DB): persistence
tables are created fresh by create_all, no legacy rows
3. No-auth → with-auth (has existing persistence DB from #1930):
NOT a supported upgrade path — "有 DB 到有 DB" schema evolution
is out of scope; users wipe DB or run manual ALTER
So the SQL orphan migration never has anything to do in the
supported matrix. Delete the function, simplify _ensure_admin_user
from a 3-step pipeline to a 2-step one (admin creation + LangGraph
store orphan migration only).
LangGraph store orphan migration stays: it serves the real
"no-auth → with-auth" upgrade path where a user's existing LangGraph
thread metadata has no owner_id field and needs to be stamped with
the newly-created admin's id.
Tests: 284 passed (auth + persistence + isolation)
Lint: clean
* security(auth): write initial admin password to 0600 file instead of logs
CodeQL py/clear-text-logging-sensitive-data flagged 3 call sites that
logged the auto-generated admin password to stdout via logger.info().
Production log aggregators (ELK/Splunk/etc) would have captured those
cleartext secrets. Replace with a shared helper that writes to
.deer-flow/admin_initial_credentials.txt with mode 0600, and log only
the path.
New file
--------
- app/gateway/auth/credential_file.py: write_initial_credentials()
helper. Takes email, password, and a "initial"/"reset" label.
Creates .deer-flow/ if missing, writes a header comment plus the
email+password, chmods 0o600, returns the absolute Path.
Modified
--------
- app/gateway/app.py: both _ensure_admin_user paths (fresh creation
+ needs_setup password reset) now write to file and log the path
- app/gateway/auth/reset_admin.py: rewritten to use the shared ORM
repo (SQLiteUserRepository with session_factory) and the
credential_file helper. The previous implementation was broken
after the earlier ORM refactor — it still imported _get_users_conn
and constructed SQLiteUserRepository() without a session factory.
No tests changed — the three password-log sites are all exercised
via existing test_ensure_admin.py which checks that startup
succeeds, not that a specific string appears in logs.
CodeQL alerts 272, 283, 284: all resolved.
* security(auth): strict JWT validation in middleware (fix junk cookie bypass)
AUTH_TEST_PLAN test 7.5.8 expects junk cookies to be rejected with
401. The previous middleware behaviour was "presence-only": check
that some access_token cookie exists, then pass through. In
combination with my Task-12 decision to skip @require_auth
decorators on routes, this created a gap where a request with any
cookie-shaped string (e.g. access_token=not-a-jwt) would bypass
authentication on routes that do not touch the repository
(/api/models, /api/mcp/config, /api/memory, /api/skills, …).
Fix: middleware now calls get_current_user_from_request() strictly
and catches the resulting HTTPException to render a 401 with the
proper fine-grained error code (token_invalid, token_expired,
user_not_found, …). On success it stamps request.state.user and
the contextvar so repository-layer owner filters work downstream.
The 4 old "_with_cookie_passes" tests in test_auth_middleware.py
were written for the presence-only behaviour; they asserted that
a junk cookie would make the handler return 200. They are renamed
to "_with_junk_cookie_rejected" and their assertions flipped to
401. The negative path (no cookie → 401 not_authenticated)
is unchanged.
Verified:
no cookie → 401 not_authenticated
junk cookie → 401 token_invalid (the fixed bug)
expired cookie → 401 token_expired
Tests: 284 passed (auth + persistence + isolation)
Lint: clean
* security(auth): wire @require_permission(owner_check=True) on isolation routes
Apply the require_permission decorator to all 28 routes that take a
{thread_id} path parameter. Combined with the strict middleware
(previous commit), this gives the double-layer protection that
AUTH_TEST_PLAN test 7.5.9 documents:
Layer 1 (AuthMiddleware): cookie + JWT validation, rejects junk
cookies and stamps request.state.user
Layer 2 (@require_permission with owner_check=True): per-resource
ownership verification via
ThreadMetaStore.check_access — returns
404 if a different user owns the thread
The decorator's owner_check branch is rewritten to use the SQL
thread_meta_repo (the 2.0-rc persistence layer) instead of the
LangGraph store path that PR #1728 used (_store_get / get_store
in routers/threads.py). The inject_record convenience is dropped
— no caller in 2.0 needs the LangGraph blob, and the SQL repo has
a different shape.
Routes decorated (28 total):
- threads.py: delete, patch, get, get-state, post-state, post-history
- thread_runs.py: post-runs, post-runs-stream, post-runs-wait,
list_runs, get_run, cancel_run, join_run, stream_existing_run,
list_thread_messages, list_run_messages, list_run_events,
thread_token_usage
- feedback.py: create, list, stats, delete
- uploads.py: upload (added Request param), list, delete
- artifacts.py: get_artifact
- suggestions.py: generate (renamed body parameter to avoid
conflict with FastAPI Request)
Test fixes:
- test_suggestions_router.py: bypass the decorator via __wrapped__
(the unit tests cover parsing logic, not auth — no point spinning
up a thread_meta_repo just to test JSON unwrapping)
- test_auth_middleware.py 4 fake-cookie tests: already updated in
the previous commit (745bf432)
Tests: 293 passed (auth + persistence + isolation + suggestions)
Lint: clean
* security(auth): defense-in-depth fixes from release validation pass
Eight findings caught while running the AUTH_TEST_PLAN end-to-end against
the deployed sg_dev stack. Each is a pre-condition for shipping
release/2.0-rc that the previous PRs missed.
Backend hardening
- routers/auth.py: rate limiter X-Real-IP now requires AUTH_TRUSTED_PROXIES
whitelist (CIDR/IP allowlist). Without nginx in front, the previous code
honored arbitrary X-Real-IP, letting an attacker rotate the header to
fully bypass the per-IP login lockout.
- routers/auth.py: 36-entry common-password blocklist via Pydantic
field_validator on RegisterRequest + ChangePasswordRequest. The shared
_validate_strong_password helper keeps the constraint in one place.
- routers/threads.py: ThreadCreateRequest + ThreadPatchRequest strip
server-reserved metadata keys (owner_id, user_id) via Pydantic
field_validator so a forged value can never round-trip back to other
clients reading the same thread. The actual ownership invariant stays
on the threads_meta row; this closes the metadata-blob echo gap.
- authz.py + thread_meta/sql.py: require_permission gains a require_existing
flag plumbed through check_access(require_existing=True). Destructive
routes (DELETE/PATCH/state-update/runs/feedback) now treat a missing
thread_meta row as 404 instead of "untracked legacy thread, allow",
closing the cross-user delete-idempotence gap where any user could
successfully DELETE another user's deleted thread.
- repositories/sqlite.py + base.py: update_user raises UserNotFoundError
on a vanished row instead of silently returning the input. Concurrent
delete during password reset can no longer look like a successful update.
- runtime/user_context.py: resolve_owner_id() coerces User.id (UUID) to
str at the contextvar boundary so SQLAlchemy String(64) columns can
bind it. The whole 2.0-rc isolation pipeline was previously broken
end-to-end (POST /api/threads → 500 "type 'UUID' is not supported").
- persistence/engine.py: SQLAlchemy listener enables PRAGMA journal_mode=WAL,
synchronous=NORMAL, foreign_keys=ON on every new SQLite connection.
TC-UPG-06 in the test plan expects WAL; previous code shipped with the
default 'delete' journal.
- auth_middleware.py: stamp request.state.auth = AuthContext(...) so
@require_permission's short-circuit fires; previously every isolation
request did a duplicate JWT decode + users SELECT. Also unifies the
401 payload through AuthErrorResponse(...).model_dump().
- app.py: _ensure_admin_user restructure removes the noqa F821 scoping
bug where 'password' was referenced outside the branch that defined it.
New _announce_credentials helper absorbs the duplicate log block in
the fresh-admin and reset-admin branches.
* fix(frontend+nginx): rollout CSRF on every state-changing client path
The frontend was 100% broken in gateway-pro mode for any user trying to
open a specific chat thread. Three cumulative bugs each silently
masked the next.
LangGraph SDK CSRF gap (api-client.ts)
- The Client constructor took only apiUrl, no defaultHeaders, no fetch
interceptor. The SDK's internal fetch never sent X-CSRF-Token, so
every state-changing /api/langgraph-compat/* call (runs/stream,
threads/search, threads/{tid}/history, ...) hit CSRFMiddleware and
got 403 before reaching the auth check. UI symptom: empty thread page
with no error message; the SPA's hooks swallowed the rejection.
- Fix: pass an onRequest hook that injects X-CSRF-Token from the
csrf_token cookie per request. Reading the cookie per call (not at
construction time) handles login / logout / password-change cookie
rotation transparently. The SDK's prepareFetchOptions calls
onRequest for both regular requests AND streaming/SSE/reconnect, so
the same hook covers runs.stream and runs.joinStream.
Raw fetch CSRF gap (7 files)
- Audit: 11 frontend fetch sites, only 2 included CSRF (login/setup +
account-settings change-password). The other 7 routed through raw
fetch() with no header — suggestions, memory, agents, mcp, skills,
uploads, and the local thread cleanup hook all 403'd silently.
- Fix: enhance fetcher.ts:fetchWithAuth to auto-inject X-CSRF-Token on
POST/PUT/DELETE/PATCH from a single shared readCsrfCookie() helper.
Convert all 7 raw fetch() callers to fetchWithAuth so the contract
is centrally enforced. api-client.ts and fetcher.ts share
readCsrfCookie + STATE_CHANGING_METHODS to avoid drift.
nginx routing + buffering (nginx.local.conf)
- The auth feature shipped without updating the nginx config: per-API
explicit location blocks but no /api/v1/auth/, /api/feedback, /api/runs.
The frontend's client-side fetches to /api/v1/auth/login/local 404'd
from the Next.js side because nginx routed /api/* to the frontend.
- Fix: add catch-all `location /api/` that proxies to the gateway.
nginx longest-prefix matching keeps the explicit blocks (/api/models,
/api/threads regex, /api/langgraph/, ...) winning for their paths.
- Fix: disable proxy_buffering + proxy_request_buffering for the
frontend `location /` block. Without it, nginx tries to spool large
Next.js chunks into /var/lib/nginx/proxy (root-owned) and fails with
Permission denied → ERR_INCOMPLETE_CHUNKED_ENCODING → ChunkLoadError.
* test(auth): release-validation test infra and new coverage
Test fixtures and unit tests added during the validation pass.
Router test helpers (NEW: tests/_router_auth_helpers.py)
- make_authed_test_app(): builds a FastAPI test app with a stub
middleware that stamps request.state.user + request.state.auth and a
permissive thread_meta_repo mock. TestClient-based router tests
(test_artifacts_router, test_threads_router) use it instead of bare
FastAPI() so the new @require_permission(owner_check=True) decorators
short-circuit cleanly.
- call_unwrapped(): walks the __wrapped__ chain to invoke the underlying
handler without going through the authz wrappers. Direct-call tests
(test_uploads_router) use it. Typed with ParamSpec so the wrapped
signature flows through.
Backend test additions
- test_auth.py: 7 tests for the new _get_client_ip trust model (no
proxy / trusted proxy / untrusted peer / XFF rejection / invalid
CIDR / no client). 5 tests for the password blocklist (literal,
case-insensitive, strong password accepted, change-password binding,
short-password length-check still fires before blocklist).
test_update_user_raises_when_row_concurrently_deleted: closes a
shipped-without-coverage gap on the new UserNotFoundError contract.
- test_thread_meta_repo.py: 4 tests for check_access(require_existing=True)
— strict missing-row denial, strict owner match, strict owner mismatch,
strict null-owner still allowed (shared rows survive the tightening).
- test_ensure_admin.py: 3 tests for _migrate_orphaned_threads /
_iter_store_items pagination, covering the TC-UPG-02 upgrade story
end-to-end via mock store. Closes the gap where the cursor pagination
was untested even though the previous PR rewrote it.
- test_threads_router.py: 5 tests for _strip_reserved_metadata
(owner_id removal, user_id removal, safe-keys passthrough, empty
input, both-stripped).
- test_auth_type_system.py: replace "password123" fixtures with
Tr0ub4dor3a / AnotherStr0ngPwd! so the new password blocklist
doesn't reject the test data.
* docs(auth): refresh TC-DOCKER-05 + document Docker validation gap
- AUTH_TEST_PLAN.md TC-DOCKER-05: the previous expectation
("admin password visible in docker logs") was stale after the simplify
pass that moved credentials to a 0600 file. The grep "Password:" check
would have silently failed and given a false sense of coverage. New
expectation matches the actual file-based path: 0600 file in
DEER_FLOW_HOME, log shows the path (not the secret), reverse-grep
asserts no leaked password in container logs.
- NEW: docs/AUTH_TEST_DOCKER_GAP.md documents the only un-executed
block in the test plan (TC-DOCKER-01..06). Reason: sg_dev validation
host has no Docker daemon installed. The doc maps each Docker case
to an already-validated bare-metal equivalent (TC-1.1, TC-REENT-01,
TC-API-02 etc.) so the gap is auditable, and includes pre-flight
reproduction steps for whoever has Docker available.
---------
Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
This commit is contained in:
@@ -86,8 +86,27 @@ async def init_engine(
|
||||
if backend == "sqlite":
|
||||
import os
|
||||
|
||||
from sqlalchemy import event
|
||||
|
||||
os.makedirs(sqlite_dir or ".", exist_ok=True)
|
||||
_engine = create_async_engine(url, echo=echo, json_serializer=_json_serializer)
|
||||
|
||||
# Enable WAL on every new connection. SQLite PRAGMA settings are
|
||||
# per-connection, so we wire the listener instead of running PRAGMA
|
||||
# once at startup. WAL gives concurrent reads + writers without
|
||||
# blocking and is the standard recommendation for any production
|
||||
# SQLite deployment (TC-UPG-06 in AUTH_TEST_PLAN.md). The companion
|
||||
# ``synchronous=NORMAL`` is the safe-and-fast pairing — fsync only
|
||||
# at WAL checkpoint boundaries instead of every commit.
|
||||
@event.listens_for(_engine.sync_engine, "connect")
|
||||
def _enable_sqlite_wal(dbapi_conn, _record): # noqa: ARG001 — SQLAlchemy contract
|
||||
cursor = dbapi_conn.cursor()
|
||||
try:
|
||||
cursor.execute("PRAGMA journal_mode=WAL;")
|
||||
cursor.execute("PRAGMA synchronous=NORMAL;")
|
||||
cursor.execute("PRAGMA foreign_keys=ON;")
|
||||
finally:
|
||||
cursor.close()
|
||||
elif backend == "postgres":
|
||||
_engine = create_async_engine(
|
||||
url,
|
||||
|
||||
@@ -12,6 +12,7 @@ from sqlalchemy import case, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.feedback.model import FeedbackRow
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
||||
|
||||
|
||||
class FeedbackRepository:
|
||||
@@ -32,18 +33,19 @@ class FeedbackRepository:
|
||||
run_id: str,
|
||||
thread_id: str,
|
||||
rating: int,
|
||||
owner_id: str | None = None,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
message_id: str | None = None,
|
||||
comment: str | None = None,
|
||||
) -> dict:
|
||||
"""Create a feedback record. rating must be +1 or -1."""
|
||||
if rating not in (1, -1):
|
||||
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.create")
|
||||
row = FeedbackRow(
|
||||
feedback_id=str(uuid.uuid4()),
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
owner_id=owner_id,
|
||||
owner_id=resolved_owner_id,
|
||||
message_id=message_id,
|
||||
rating=rating,
|
||||
comment=comment,
|
||||
@@ -55,27 +57,66 @@ class FeedbackRepository:
|
||||
await session.refresh(row)
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def get(self, feedback_id: str) -> dict | None:
|
||||
async with self._sf() as session:
|
||||
row = await session.get(FeedbackRow, feedback_id)
|
||||
return self._row_to_dict(row) if row else None
|
||||
|
||||
async def list_by_run(self, thread_id: str, run_id: str, *, limit: int = 100) -> list[dict]:
|
||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id).order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def list_by_thread(self, thread_id: str, *, limit: int = 100) -> list[dict]:
|
||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id).order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def delete(self, feedback_id: str) -> bool:
|
||||
async def get(
|
||||
self,
|
||||
feedback_id: str,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> dict | None:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.get")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(FeedbackRow, feedback_id)
|
||||
if row is None:
|
||||
return None
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
return None
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def list_by_run(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
limit: int = 100,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_run")
|
||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
|
||||
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def list_by_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
limit: int = 100,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_thread")
|
||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
|
||||
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def delete(
|
||||
self,
|
||||
feedback_id: str,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> bool:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.delete")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(FeedbackRow, feedback_id)
|
||||
if row is None:
|
||||
return False
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
return False
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
|
||||
@@ -7,6 +7,7 @@ The actual ORM classes have moved to entity-specific subpackages:
|
||||
- ``deerflow.persistence.thread_meta``
|
||||
- ``deerflow.persistence.run``
|
||||
- ``deerflow.persistence.feedback``
|
||||
- ``deerflow.persistence.user``
|
||||
|
||||
``RunEventRow`` remains in ``deerflow.persistence.models.run_event`` because
|
||||
its storage implementation lives in ``deerflow.runtime.events.store.db`` and
|
||||
@@ -17,5 +18,6 @@ from deerflow.persistence.feedback.model import FeedbackRow
|
||||
from deerflow.persistence.models.run_event import RunEventRow
|
||||
from deerflow.persistence.run.model import RunRow
|
||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||
from deerflow.persistence.user.model import UserRow
|
||||
|
||||
__all__ = ["FeedbackRow", "RunEventRow", "RunRow", "ThreadMetaRow"]
|
||||
__all__ = ["FeedbackRow", "RunEventRow", "RunRow", "ThreadMetaRow", "UserRow"]
|
||||
|
||||
@@ -16,6 +16,10 @@ class RunEventRow(Base):
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
run_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
# Owner of the conversation this event belongs to. Nullable for data
|
||||
# created before auth was introduced; populated by auth middleware on
|
||||
# new writes and by the boot-time orphan migration on existing rows.
|
||||
owner_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
|
||||
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
category: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
# "message" | "trace" | "lifecycle"
|
||||
|
||||
@@ -16,6 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.run.model import RunRow
|
||||
from deerflow.runtime.runs.store.base import RunStore
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
||||
|
||||
|
||||
class RunRepository(RunStore):
|
||||
@@ -68,7 +69,7 @@ class RunRepository(RunStore):
|
||||
*,
|
||||
thread_id,
|
||||
assistant_id=None,
|
||||
owner_id=None,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
status="pending",
|
||||
multitask_strategy="reject",
|
||||
metadata=None,
|
||||
@@ -77,12 +78,13 @@ class RunRepository(RunStore):
|
||||
created_at=None,
|
||||
follow_up_to_run_id=None,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.put")
|
||||
now = datetime.now(UTC)
|
||||
row = RunRow(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
owner_id=owner_id,
|
||||
owner_id=resolved_owner_id,
|
||||
status=status,
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata_json=self._safe_json(metadata) or {},
|
||||
@@ -96,15 +98,32 @@ class RunRepository(RunStore):
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
|
||||
async def get(self, run_id):
|
||||
async def get(
|
||||
self,
|
||||
run_id,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.get")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(RunRow, run_id)
|
||||
return self._row_to_dict(row) if row else None
|
||||
if row is None:
|
||||
return None
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
return None
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def list_by_thread(self, thread_id, *, owner_id=None, limit=100):
|
||||
async def list_by_thread(
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
limit=100,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.list_by_thread")
|
||||
stmt = select(RunRow).where(RunRow.thread_id == thread_id)
|
||||
if owner_id is not None:
|
||||
stmt = stmt.where(RunRow.owner_id == owner_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunRow.owner_id == resolved_owner_id)
|
||||
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
@@ -118,12 +137,21 @@ class RunRepository(RunStore):
|
||||
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
|
||||
await session.commit()
|
||||
|
||||
async def delete(self, run_id):
|
||||
async def delete(
|
||||
self,
|
||||
run_id,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.delete")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(RunRow, run_id)
|
||||
if row is not None:
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
if row is None:
|
||||
return
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
return
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
|
||||
async def list_pending(self, *, before=None):
|
||||
if before is None:
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
||||
|
||||
|
||||
class ThreadMetaRepository(ThreadMetaStore):
|
||||
@@ -31,15 +32,18 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
thread_id: str,
|
||||
*,
|
||||
assistant_id: str | None = None,
|
||||
owner_id: str | None = None,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
display_name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> dict:
|
||||
# Auto-resolve owner_id from contextvar when AUTO; explicit None
|
||||
# creates an orphan row (used by migration scripts).
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.create")
|
||||
now = datetime.now(UTC)
|
||||
row = ThreadMetaRow(
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
owner_id=owner_id,
|
||||
owner_id=resolved_owner_id,
|
||||
display_name=display_name,
|
||||
metadata_json=metadata or {},
|
||||
created_at=now,
|
||||
@@ -51,10 +55,21 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
await session.refresh(row)
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def get(self, thread_id: str) -> dict | None:
|
||||
async def get(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> dict | None:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.get")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
return self._row_to_dict(row) if row else None
|
||||
if row is None:
|
||||
return None
|
||||
# Enforce owner filter unless explicitly bypassed (owner_id=None).
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
return None
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]:
|
||||
stmt = select(ThreadMetaRow).where(ThreadMetaRow.owner_id == owner_id).order_by(ThreadMetaRow.updated_at.desc()).limit(limit).offset(offset)
|
||||
@@ -62,16 +77,32 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def check_access(self, thread_id: str, owner_id: str) -> bool:
|
||||
"""Check if owner_id has access to thread_id.
|
||||
async def check_access(self, thread_id: str, owner_id: str, *, require_existing: bool = False) -> bool:
|
||||
"""Check if ``owner_id`` has access to ``thread_id``.
|
||||
|
||||
Returns True if: row doesn't exist (untracked thread), owner_id
|
||||
is None on the row (shared thread), or owner_id matches.
|
||||
Two modes — one row, two distinct semantics depending on what
|
||||
the caller is about to do:
|
||||
|
||||
- ``require_existing=False`` (default, permissive):
|
||||
Returns True for: row missing (untracked legacy thread),
|
||||
``row.owner_id`` is None (shared / pre-auth data),
|
||||
or ``row.owner_id == owner_id``. Use for **read-style**
|
||||
decorators where treating an untracked thread as accessible
|
||||
preserves backward-compat.
|
||||
|
||||
- ``require_existing=True`` (strict):
|
||||
Returns True **only** when the row exists AND
|
||||
(``row.owner_id == owner_id`` OR ``row.owner_id is None``).
|
||||
Use for **destructive / mutating** decorators (DELETE, PATCH,
|
||||
state-update) so a thread that has *already been deleted*
|
||||
cannot be re-targeted by any caller — closing the
|
||||
delete-idempotence cross-user gap where the row vanishing
|
||||
made every other user appear to "own" it.
|
||||
"""
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is None:
|
||||
return True
|
||||
return not require_existing
|
||||
if row.owner_id is None:
|
||||
return True
|
||||
return row.owner_id == owner_id
|
||||
@@ -83,9 +114,17 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
status: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
"""Search threads with optional metadata and status filters."""
|
||||
"""Search threads with optional metadata and status filters.
|
||||
|
||||
Owner filter is enforced by default: caller must be in a user
|
||||
context. Pass ``owner_id=None`` to bypass (migration/CLI).
|
||||
"""
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.search")
|
||||
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(ThreadMetaRow.owner_id == resolved_owner_id)
|
||||
if status:
|
||||
stmt = stmt.where(ThreadMetaRow.status == status)
|
||||
|
||||
@@ -105,36 +144,80 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def update_display_name(self, thread_id: str, display_name: str) -> None:
|
||||
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_owner_id: str | None) -> bool:
|
||||
"""Return True if the row exists and is owned (or filter bypassed)."""
|
||||
if resolved_owner_id is None:
|
||||
return True # explicit bypass
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
return row is not None and row.owner_id == resolved_owner_id
|
||||
|
||||
async def update_display_name(
|
||||
self,
|
||||
thread_id: str,
|
||||
display_name: str,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
"""Update the display_name (title) for a thread."""
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_display_name")
|
||||
async with self._sf() as session:
|
||||
if not await self._check_ownership(session, thread_id, resolved_owner_id):
|
||||
return
|
||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
|
||||
await session.commit()
|
||||
|
||||
async def update_status(self, thread_id: str, status: str) -> None:
|
||||
async def update_status(
|
||||
self,
|
||||
thread_id: str,
|
||||
status: str,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_status")
|
||||
async with self._sf() as session:
|
||||
if not await self._check_ownership(session, thread_id, resolved_owner_id):
|
||||
return
|
||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
|
||||
await session.commit()
|
||||
|
||||
async def update_metadata(self, thread_id: str, metadata: dict) -> None:
|
||||
async def update_metadata(
|
||||
self,
|
||||
thread_id: str,
|
||||
metadata: dict,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
"""Merge ``metadata`` into ``metadata_json``.
|
||||
|
||||
Read-modify-write inside a single session/transaction so concurrent
|
||||
callers see consistent state. No-op if the row does not exist.
|
||||
callers see consistent state. No-op if the row does not exist or
|
||||
the owner_id check fails.
|
||||
"""
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_metadata")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is None:
|
||||
return
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
return
|
||||
merged = dict(row.metadata_json or {})
|
||||
merged.update(metadata)
|
||||
row.metadata_json = merged
|
||||
row.updated_at = datetime.now(UTC)
|
||||
await session.commit()
|
||||
|
||||
async def delete(self, thread_id: str) -> None:
|
||||
async def delete(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.delete")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is not None:
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
if row is None:
|
||||
return
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
return
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
"""User storage subpackage.
|
||||
|
||||
Holds the ORM model for the ``users`` table. The concrete repository
|
||||
implementation (``SQLiteUserRepository``) lives in the app layer
|
||||
(``app.gateway.auth.repositories.sqlite``) because it converts
|
||||
between the ORM row and the auth module's pydantic ``User`` class.
|
||||
This keeps the harness package free of any dependency on app code.
|
||||
"""
|
||||
|
||||
from deerflow.persistence.user.model import UserRow
|
||||
|
||||
__all__ = ["UserRow"]
|
||||
@@ -0,0 +1,59 @@
|
||||
"""ORM model for the users table.
|
||||
|
||||
Lives in the harness persistence package so it is picked up by
|
||||
``Base.metadata.create_all()`` alongside ``threads_meta``, ``runs``,
|
||||
``run_events``, and ``feedback``. Using the shared engine means:
|
||||
|
||||
- One SQLite/Postgres database, one connection pool
|
||||
- One schema initialisation codepath
|
||||
- Consistent async sessions across auth and persistence reads
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Index, String, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
|
||||
|
||||
class UserRow(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
# UUIDs are stored as 36-char strings for cross-backend portability.
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True)
|
||||
|
||||
email: Mapped[str] = mapped_column(String(320), unique=True, nullable=False, index=True)
|
||||
password_hash: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
|
||||
# "admin" | "user" — kept as plain string to avoid ALTER TABLE pain
|
||||
# when new roles are introduced.
|
||||
system_role: Mapped[str] = mapped_column(String(16), nullable=False, default="user")
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(UTC),
|
||||
)
|
||||
|
||||
# OAuth linkage (optional). A partial unique index enforces one
|
||||
# account per (provider, oauth_id) pair, leaving NULL/NULL rows
|
||||
# unconstrained so plain password accounts can coexist.
|
||||
oauth_provider: Mapped[str | None] = mapped_column(String(32), nullable=True)
|
||||
oauth_id: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
|
||||
# Auth lifecycle flags
|
||||
needs_setup: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
token_version: Mapped[int] = mapped_column(nullable=False, default=0)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"idx_users_oauth_identity",
|
||||
"oauth_provider",
|
||||
"oauth_id",
|
||||
unique=True,
|
||||
sqlite_where=text("oauth_provider IS NOT NULL AND oauth_id IS NOT NULL"),
|
||||
),
|
||||
)
|
||||
@@ -15,6 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.models.run_event import RunEventRow
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_owner_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -53,6 +54,18 @@ class DbRunEventStore(RunEventStore):
|
||||
metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)}
|
||||
return content, metadata or {}
|
||||
|
||||
@staticmethod
|
||||
def _owner_from_context() -> str | None:
|
||||
"""Soft read of owner_id from contextvar for write paths.
|
||||
|
||||
Returns ``None`` (no filter / no stamp) if contextvar is unset,
|
||||
which is the expected case for background worker writes. HTTP
|
||||
request writes will have the contextvar set by auth middleware
|
||||
and get their user_id stamped automatically.
|
||||
"""
|
||||
user = get_current_user()
|
||||
return user.id if user is not None else None
|
||||
|
||||
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
|
||||
"""Write a single event — low-frequency path only.
|
||||
|
||||
@@ -68,6 +81,7 @@ class DbRunEventStore(RunEventStore):
|
||||
metadata = {**(metadata or {}), "content_is_dict": True}
|
||||
else:
|
||||
db_content = content
|
||||
owner_id = self._owner_from_context()
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# Use FOR UPDATE to serialize seq assignment within a thread.
|
||||
@@ -78,6 +92,7 @@ class DbRunEventStore(RunEventStore):
|
||||
row = RunEventRow(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
owner_id=owner_id,
|
||||
event_type=event_type,
|
||||
category=category,
|
||||
content=db_content,
|
||||
@@ -91,6 +106,7 @@ class DbRunEventStore(RunEventStore):
|
||||
async def put_batch(self, events):
|
||||
if not events:
|
||||
return []
|
||||
owner_id = self._owner_from_context()
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# Get max seq for the thread (assume all events in batch belong to same thread).
|
||||
@@ -114,6 +130,7 @@ class DbRunEventStore(RunEventStore):
|
||||
row = RunEventRow(
|
||||
thread_id=e["thread_id"],
|
||||
run_id=e["run_id"],
|
||||
owner_id=e.get("owner_id", owner_id),
|
||||
event_type=e["event_type"],
|
||||
category=category,
|
||||
content=db_content,
|
||||
@@ -125,8 +142,19 @@ class DbRunEventStore(RunEventStore):
|
||||
rows.append(row)
|
||||
return [self._row_to_dict(r) for r in rows]
|
||||
|
||||
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
async def list_messages(
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
limit=50,
|
||||
before_seq=None,
|
||||
after_seq=None,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages")
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||
if before_seq is not None:
|
||||
stmt = stmt.where(RunEventRow.seq < before_seq)
|
||||
if after_seq is not None:
|
||||
@@ -146,8 +174,19 @@ class DbRunEventStore(RunEventStore):
|
||||
rows = list(result.scalars())
|
||||
return [self._row_to_dict(r) for r in reversed(rows)]
|
||||
|
||||
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
|
||||
async def list_events(
|
||||
self,
|
||||
thread_id,
|
||||
run_id,
|
||||
*,
|
||||
event_types=None,
|
||||
limit=500,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_events")
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||
if event_types:
|
||||
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||
@@ -155,31 +194,68 @@ class DbRunEventStore(RunEventStore):
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def list_messages_by_run(self, thread_id, run_id):
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message").order_by(RunEventRow.seq.asc())
|
||||
async def list_messages_by_run(
|
||||
self,
|
||||
thread_id,
|
||||
run_id,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages_by_run")
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message")
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc())
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def count_messages(self, thread_id):
|
||||
async def count_messages(
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.count_messages")
|
||||
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||
async with self._sf() as session:
|
||||
return await session.scalar(stmt) or 0
|
||||
|
||||
async def delete_by_thread(self, thread_id):
|
||||
async def delete_by_thread(
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_thread")
|
||||
async with self._sf() as session:
|
||||
count_stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id)
|
||||
count_conditions = [RunEventRow.thread_id == thread_id]
|
||||
if resolved_owner_id is not None:
|
||||
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
|
||||
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||
count = await session.scalar(count_stmt) or 0
|
||||
if count > 0:
|
||||
await session.execute(delete(RunEventRow).where(RunEventRow.thread_id == thread_id))
|
||||
await session.execute(delete(RunEventRow).where(*count_conditions))
|
||||
await session.commit()
|
||||
return count
|
||||
|
||||
async def delete_by_run(self, thread_id, run_id):
|
||||
async def delete_by_run(
|
||||
self,
|
||||
thread_id,
|
||||
run_id,
|
||||
*,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_run")
|
||||
async with self._sf() as session:
|
||||
count_stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
|
||||
count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
|
||||
if resolved_owner_id is not None:
|
||||
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
|
||||
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||
count = await session.scalar(count_stmt) or 0
|
||||
if count > 0:
|
||||
await session.execute(delete(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id))
|
||||
await session.execute(delete(RunEventRow).where(*count_conditions))
|
||||
await session.commit()
|
||||
return count
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
"""Request-scoped user context for owner-based authorization.
|
||||
|
||||
This module holds a :class:`~contextvars.ContextVar` that the gateway's
|
||||
auth middleware sets after a successful authentication. Repository
|
||||
methods read the contextvar via a sentinel default parameter, letting
|
||||
routers stay free of ``owner_id`` boilerplate.
|
||||
|
||||
Three-state semantics for the repository ``owner_id`` parameter (the
|
||||
consumer side of this module lives in ``deerflow.persistence.*``):
|
||||
|
||||
- ``_AUTO`` (module-private sentinel, default): read from contextvar;
|
||||
raise :class:`RuntimeError` if unset.
|
||||
- Explicit ``str``: use the provided value, overriding contextvar.
|
||||
- Explicit ``None``: no WHERE clause — used only by migration scripts
|
||||
and admin CLIs that intentionally bypass isolation.
|
||||
|
||||
Dependency direction
|
||||
--------------------
|
||||
``persistence`` (lower layer) reads from this module; ``gateway.auth``
|
||||
(higher layer) writes to it. ``CurrentUser`` is defined here as a
|
||||
:class:`typing.Protocol` so that ``persistence`` never needs to import
|
||||
the concrete ``User`` class from ``gateway.auth.models``. Any object
|
||||
with an ``.id: str`` attribute structurally satisfies the protocol.
|
||||
|
||||
Asyncio semantics
|
||||
-----------------
|
||||
``ContextVar`` is task-local under asyncio, not thread-local. Each
|
||||
FastAPI request runs in its own task, so the context is naturally
|
||||
isolated. ``asyncio.create_task`` and ``asyncio.to_thread`` inherit the
|
||||
parent task's context, which is typically the intended behaviour; if
|
||||
a background task must *not* see the foreground user, wrap it with
|
||||
``contextvars.copy_context()`` to get a clean copy.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextvars import ContextVar, Token
|
||||
from typing import Final, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class CurrentUser(Protocol):
|
||||
"""Structural type for the current authenticated user.
|
||||
|
||||
Any object with an ``.id: str`` attribute satisfies this protocol.
|
||||
Concrete implementations live in ``app.gateway.auth.models.User``.
|
||||
"""
|
||||
|
||||
id: str
|
||||
|
||||
|
||||
_current_user: Final[ContextVar[CurrentUser | None]] = ContextVar("deerflow_current_user", default=None)
|
||||
|
||||
|
||||
def set_current_user(user: CurrentUser) -> Token[CurrentUser | None]:
|
||||
"""Set the current user for this async task.
|
||||
|
||||
Returns a reset token that should be passed to
|
||||
:func:`reset_current_user` in a ``finally`` block to restore the
|
||||
previous context.
|
||||
"""
|
||||
return _current_user.set(user)
|
||||
|
||||
|
||||
def reset_current_user(token: Token[CurrentUser | None]) -> None:
|
||||
"""Restore the context to the state captured by ``token``."""
|
||||
_current_user.reset(token)
|
||||
|
||||
|
||||
def get_current_user() -> CurrentUser | None:
|
||||
"""Return the current user, or ``None`` if unset.
|
||||
|
||||
Safe to call in any context. Used by code paths that can proceed
|
||||
without a user (e.g. migration scripts, public endpoints).
|
||||
"""
|
||||
return _current_user.get()
|
||||
|
||||
|
||||
def require_current_user() -> CurrentUser:
|
||||
"""Return the current user, or raise :class:`RuntimeError`.
|
||||
|
||||
Used by repository code that must not be called outside a
|
||||
request-authenticated context. The error message is phrased so
|
||||
that a caller debugging a stack trace can locate the offending
|
||||
code path.
|
||||
"""
|
||||
user = _current_user.get()
|
||||
if user is None:
|
||||
raise RuntimeError("repository accessed without user context")
|
||||
return user
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sentinel-based owner_id resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Repository methods accept an ``owner_id`` keyword-only argument that
|
||||
# defaults to ``AUTO``. The three possible values drive distinct
|
||||
# behaviours; see the docstring on :func:`resolve_owner_id`.
|
||||
|
||||
|
||||
class _AutoSentinel:
|
||||
"""Singleton marker meaning 'resolve owner_id from contextvar'."""
|
||||
|
||||
_instance: _AutoSentinel | None = None
|
||||
|
||||
def __new__(cls) -> _AutoSentinel:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<AUTO>"
|
||||
|
||||
|
||||
AUTO: Final[_AutoSentinel] = _AutoSentinel()
|
||||
|
||||
|
||||
def resolve_owner_id(
|
||||
value: str | None | _AutoSentinel,
|
||||
*,
|
||||
method_name: str = "repository method",
|
||||
) -> str | None:
|
||||
"""Resolve the owner_id parameter passed to a repository method.
|
||||
|
||||
Three-state semantics:
|
||||
|
||||
- :data:`AUTO` (default): read from contextvar; raise
|
||||
:class:`RuntimeError` if no user is in context. This is the
|
||||
common case for request-scoped calls.
|
||||
- Explicit ``str``: use the provided id verbatim, overriding any
|
||||
contextvar value. Useful for tests and admin-override flows.
|
||||
- Explicit ``None``: no filter — the repository should skip the
|
||||
owner_id WHERE clause entirely. Reserved for migration scripts
|
||||
and CLI tools that intentionally bypass isolation.
|
||||
"""
|
||||
if isinstance(value, _AutoSentinel):
|
||||
user = _current_user.get()
|
||||
if user is None:
|
||||
raise RuntimeError(f"{method_name} called with owner_id=AUTO but no user context is set; pass an explicit owner_id, set the contextvar via auth middleware, or opt out with owner_id=None for migration/CLI paths.")
|
||||
# Coerce to ``str`` at the boundary: ``User.id`` is typed as
|
||||
# ``UUID`` for the API surface, but the persistence layer
|
||||
# stores ``owner_id`` as ``String(64)`` and aiosqlite cannot
|
||||
# bind a raw UUID object to a VARCHAR column ("type 'UUID' is
|
||||
# not supported"). Honour the documented return type here
|
||||
# rather than ripple a type change through every caller.
|
||||
return str(user.id)
|
||||
return value
|
||||
Reference in New Issue
Block a user