Compare commits

9 Commits

Author SHA1 Message Date
Willem Jiang 2d5f0787de Update lint-check.yml with the job setting 2026-06-11 00:07:36 +08:00
Huixin615 5819bd8a59 fix(frontend): paginate workspace chat list beyond 50 threads (#3482) (#3485)
* fix(frontend): paginate workspace chat list beyond 50 threads (#3482)

The sidebar 'Recent chats' and /workspace/chats list were hard-capped
at the first 50 threads returned by threads.search. Replace the
single-shot useThreads() consumers with useInfiniteThreads() and add
an IntersectionObserver sentinel to each list so further pages are
fetched on demand.

In search mode on the chats page, the sentinel is replaced by an
explicit 'Load more' button to prevent the observer from draining the
entire backend list while the filtered view stays empty.

- Add useInfiniteThreads + page-size constant and pure cache helpers
  (map/filterInfiniteThreadsCache, getInfiniteThreadsNextPageParam)
- Mirror rename / delete / stream-finish updates into the new
  infinite cache so optimistic UI stays consistent
- Extend the e2e mock to honour limit/offset slicing
- Unit tests for the cache helpers and pagination boundary
- Playwright e2e covering chats page + sidebar load-more, and the
  search-mode guard against runaway auto-pagination
- Add en/zh i18n entries for the search-mode load-more button

Fixes #3482

* docs(frontend): clarify infinite-threads offset semantics and test post-delete invariant

- Add docstring to getInfiniteThreadsNextPageParam explaining that TanStack
  Query freezes the returned offset into pageParams once, so optimistic cache
  mutations that shrink page lengths (filterInfiniteThreadsCache on delete)
  cannot retroactively move the offset backwards. Delete/rename paths
  reconcile against the backend via invalidateQueries in onSettled.
- Add unit test covering the post-delete invariant.
- Fix misleading comment in thread-list-infinite-scroll.spec.ts: the
  thread-search mock does not sort by updated_at; it returns the array in
  the order provided.

Addresses Copilot CR comments on #3485.

* fix(frontend): mirror onCreated upsert into infinite cache; add sidebar Load-older button

Address review feedback on #3485:

- New upsertThreadInInfiniteCache helper; useThreadStream onCreated now
  upserts into both the legacy ['threads','search'] cache and the new
  infinite cache, so a freshly created thread appears in the sidebar
  immediately during streaming instead of only after the run finishes
  and onSettled invalidates the query. Restores parity with main.
- Sidebar Recent Chats now exposes a visible 'Load older chats' button
  alongside the IntersectionObserver sentinel, so keyboard-only users
  and environments where IO is unavailable can still reach older
  conversations.
- Add zh-CN / en-US / types entry for chats.loadOlderChats.
- Cover the new helper with 3 unit tests (no-op on uninitialised cache,
  prepend new thread to first page, merge with existing entry without
  duplication).
2026-06-10 23:59:38 +08:00
hataa b3c2cc42cf fix(agents): require config.yaml in resolve_agent_dir to skip memory-only directories (#3390) (#3481)
When memory is enabled, the first conversation with a legacy shared agent
creates a per-user agent directory containing only memory.json (no
config.yaml). On the second turn, resolve_agent_dir() returned this
incomplete directory, causing load_agent_config() to fail with
"Agent config not found".

Require config.yaml to exist alongside the directory for both the
per-user and legacy paths, so that memory-only directories fall
through correctly. This aligns resolve_agent_dir with the existing
config.yaml check in list_custom_agents.

Refs: https://github.com/bytedance/deer-flow/issues/3390
2026-06-10 23:57:17 +08:00
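The lookup rule described in this commit can be sketched as follows. This is a minimal illustration, not the project's `resolve_agent_dir`: the function name, its two-argument signature, and the idea of passing both candidate directories explicitly are assumptions made for the example.

```python
from __future__ import annotations

from pathlib import Path


def resolve_agent_dir_sketch(per_user_dir: Path, legacy_dir: Path) -> Path | None:
    """Return the first candidate directory that actually contains config.yaml.

    Hypothetical helper: the real resolve_agent_dir signature is not shown in
    the commit message, so both candidates are passed in explicitly here.
    """
    for candidate in (per_user_dir, legacy_dir):
        # A directory that holds only memory.json is not a usable agent
        # directory, so it must fall through to the next candidate.
        if (candidate / "config.yaml").is_file():
            return candidate
    return None
```

With this rule a per-user directory created only for memory storage no longer shadows the legacy shared agent that owns the actual configuration.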
Ryker_Feng 167ef4512f feat(memory): add memory.token_counting config to avoid tiktoken network dependency (#3429) (#3465)
* feat(memory): add memory.token_counting config to avoid tiktoken network dependency (#3429)

Add a `memory.token_counting` option (`tiktoken` | `char`) so deployments in
network-restricted environments can opt out of tiktoken entirely. In `char`
mode the memory-injection budget uses a network-free character-based estimate
and never triggers the BPE download from openaipublic.blob.core.windows.net,
which could otherwise block for tens of minutes (see #3402).

Also harden the default `tiktoken` path:
- cache an in-flight LOADING sentinel so concurrent callers fall back
  immediately instead of spawning more blocking get_encoding threads when the
  first load is still running (e.g. under the 5s startup warm-up timeout);
- cache failures with a timestamp and retry after a cooldown so a transient
  network outage self-heals back to accurate counting without a restart;
- skip startup warm-up entirely in char mode.

The new config is surfaced via the memory config API and config.example.yaml
(config_version bumped). Default remains `tiktoken`, so existing deployments
are unaffected.

* fix(memory): use CJK-aware char token estimate and address review feedback

- Replace the flat len(text)//4 fallback with a CJK-aware estimate so
  Chinese/Japanese/Korean memory content does not over-fill the injection budget
- Document the internal tiktoken retry cooldown and char-mode escape hatch
- Sync CLAUDE.md / config.example.yaml / MEMORY_IMPROVEMENTS.md wording
- Fix MemoryConfigResponse mocks/assertions and add CJK estimate tests
2026-06-10 23:26:15 +08:00
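A rough sketch of what a CJK-aware character estimate might look like, using the ratios quoted in the updated docs (about 2 characters per token for CJK, about 4 for everything else). The function names and the exact code point ranges are assumptions for illustration; the real helper may classify characters differently.

```python
import math

# Code point ranges treated as CJK in this sketch. This list is an assumption;
# the project's helper may cover more (or different) Unicode blocks.
_CJK_RANGES = (
    (0x3040, 0x30FF),  # Hiragana and Katakana
    (0x4E00, 0x9FFF),  # CJK Unified Ideographs
    (0xAC00, 0xD7AF),  # Hangul Syllables
)


def _is_cjk(ch: str) -> bool:
    cp = ord(ch)
    return any(lo <= cp <= hi for lo, hi in _CJK_RANGES)


def estimate_tokens_char(text: str) -> int:
    """Network-free token estimate: ~2 chars/token for CJK, ~4 for the rest."""
    cjk = sum(1 for ch in text if _is_cjk(ch))
    other = len(text) - cjk
    return math.ceil(cjk / 2 + other / 4)
```

Compared with a flat `len(text) // 4`, this counts CJK-heavy text as roughly twice as expensive, which keeps such content from over-filling the injection budget.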
Xinmin Zeng ba9cc5e972 fix(gateway): enforce thread ownership on stateless run endpoints (#3473)
POST /api/runs/stream and /api/runs/wait accept thread_id in the request
body but performed no owner authorization, letting any authenticated user
start runs on -- and read /wait checkpoint channel_values from -- another
user's thread (cross-user IDOR, #3472).

The @require_permission(owner_check=True) decorator resolves ownership from
the thread_id *path* param, so it cannot cover these body-param endpoints.
Enforce ownership inside start_run() before create_or_reject via
ThreadMetaStore.check_access: missing rows (auto-created temp threads) and
NULL-owner rows stay accessible, while a thread owned by another user
returns 404 (matching thread_runs.py). The internal system role (IM
channels acting for platform users) is exempt.

Closes #3472
2026-06-10 23:03:39 +08:00
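The access rule described above can be modelled in a few lines. This is a simplified in-memory sketch, not `ThreadMetaStore.check_access` itself: the dict-based store, the function names, and the `"internal"` role value are all assumptions made for the example.

```python
from __future__ import annotations

# Hypothetical value: the real INTERNAL_SYSTEM_ROLE constant is not shown in the diff.
INTERNAL_SYSTEM_ROLE_SKETCH = "internal"


def check_access_sketch(owners: dict[str, str | None], thread_id: str, user_id: str) -> bool:
    """owners maps thread_id -> owner id; a None value models a NULL-owner row."""
    if thread_id not in owners:
        return True  # missing row: auto-created temp thread stays accessible
    owner = owners[thread_id]
    return owner is None or owner == user_id


def authorize_run_sketch(owners: dict[str, str | None], thread_id: str, user_id: str, system_role: str) -> int:
    """Return the HTTP status the stateless run endpoints would answer with."""
    if system_role == INTERNAL_SYSTEM_ROLE_SKETCH:
        return 200  # IM channels act on behalf of platform users and are exempt
    # 404 rather than 403, so thread ids owned by other users cannot be enumerated.
    return 200 if check_access_sketch(owners, thread_id, user_id) else 404
```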
Xinmin Zeng 05ae4467ae fix(docker): default Gateway to a single worker to prevent multi-worker breakage (#3475)
The default `make up` started the Gateway with `--workers 4`, but run state
(RunManager and the stream bridge) is held in-process and nginx uses no sticky
sessions. With the default config, same-run requests scatter across workers that
each keep their own run state, breaking run cancellation (409), SSE reconnect
(hangs on heartbeats), multitask de-duplication, and IM channels (duplicate
replies). The shared cross-worker stream bridge does not exist yet.

Default GATEWAY_WORKERS to 1 so the out-of-the-box deployment is correct,
document the single-worker boundary in the README, and add a regression test
pinning the default while keeping it overridable. This is a stop-gap, not a
multi-worker implementation; the full fix (shared run state + stream bridge) is
tracked in #3191.

Refs #3239, #3260
2026-06-10 21:36:25 +08:00
DanielWalnut 2b795265e7 fix: align auth-disabled mode and mock history loading (#3471)
* fix: align auth-disabled mode and mock history loading

* fix: address auth-disabled review feedback

* test: cover auth-disabled backend contract

* style: format frontend tests

* fix: address follow-up review comments
2026-06-10 16:11:00 +08:00
Nan Gao a57d05fe0a fix runtime journal run lifecycle events (#3470) 2026-06-10 08:33:29 +08:00
Lucy Shen ae9e8bc0bf fix(sandbox): make missing sandbox.mounts host_path a loud ERROR (#3244) (#3250)
In Docker production deployments, LocalSandboxProvider runs inside the
deer-flow-gateway container, so any `sandbox.mounts[].host_path` from
config.yaml is resolved against the gateway container's filesystem — not
the host machine. When the path isn't also bind-mounted into the gateway
service, the mount was silently dropped with only a WARNING log, leaving
agents reading an empty directory in production while the same config
worked under `make dev`.

Escalate the missing-host_path branch to logger.error with explicit
guidance about Docker bind mounts and docker-compose, so the failure is
hard to miss in default log configurations. Skip behaviour is preserved
to avoid breaking existing deployments.

Also clarify the misleading `VolumeMountConfig.host_path` field
description so it documents reality for both providers:

  - LocalSandboxProvider checks host_path from inside the gateway process
    (host in `make dev`, container in `make up`).
  - AioSandboxProvider (DooD) passes host_path straight to `docker -v`
    for the sandbox container, where the host Docker daemon resolves it
    from the host machine's perspective.

config.example.yaml's `sandbox.mounts` comment gets a Note: block
pointing operators at the docker-compose bind-mount requirement so the
Docker-mode gotcha is discoverable from the canonical template.

Adds a regression test that:
  - confirms missing host_path is still skipped (no behaviour break);
  - asserts an ERROR record is emitted referencing the offending paths;
  - asserts the message contains actionable Docker/gateway/docker-compose
    keywords so future refactors can't quietly downgrade it.

Refs: https://github.com/bytedance/deer-flow/issues/3244
2026-06-09 23:16:14 +08:00
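The skip-but-log-loudly behaviour can be sketched as below. This is an assumption-laden illustration, not the `LocalSandboxProvider` code: the function name, the logger name, and the message wording are invented for the example, and mounts are reduced to plain path strings.

```python
import logging
from pathlib import Path

logger = logging.getLogger("sandbox_mounts_sketch")  # hypothetical logger name


def filter_existing_mounts(host_paths: list[str]) -> list[str]:
    """Keep mounts whose host_path exists; report the rest at ERROR level."""
    existing = [p for p in host_paths if Path(p).exists()]
    missing = [p for p in host_paths if not Path(p).exists()]
    if missing:
        # Skipping (rather than raising) preserves existing deployments, but
        # ERROR makes the dropped mount visible in default log configurations.
        logger.error(
            "sandbox.mounts host_path not found, skipping: %s. In Docker "
            "deployments the path is resolved inside the gateway container; "
            "bind-mount it into the gateway service in docker-compose.",
            ", ".join(missing),
        )
    return existing
```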
51 changed files with 2016 additions and 117 deletions
+1 -1
@@ -10,7 +10,7 @@ permissions:
   contents: read
 jobs:
-  lint:
+  lint-backend:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v6
+3
@@ -247,6 +247,9 @@ Access: http://localhost:2026
The unified nginx endpoint is same-origin by default and does not emit browser CORS headers. If you run a split-origin or port-forwarded browser client, set `GATEWAY_CORS_ORIGINS` to comma-separated exact origins such as `http://localhost:3000`; the Gateway then applies the CORS allowlist and matching CSRF origin checks. The unified nginx endpoint is same-origin by default and does not emit browser CORS headers. If you run a split-origin or port-forwarded browser client, set `GATEWAY_CORS_ORIGINS` to comma-separated exact origins such as `http://localhost:3000`; the Gateway then applies the CORS allowlist and matching CSRF origin checks.
> [!IMPORTANT]
> The Gateway holds run state (RunManager and the stream bridge) in process, so production defaults to a single Gateway worker (`GATEWAY_WORKERS=1`). Raising the worker count without a shared cross-worker stream bridge — which is not yet available — breaks run cancellation, SSE reconnects, request de-duplication, and IM channels, because nginx uses no sticky sessions and each worker keeps its own run state. Scale a single worker up with more CPU/RAM (or move the database and sandbox onto dedicated tiers) instead of raising `GATEWAY_WORKERS`.
See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide. See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
#### Option 2: Local Development #### Option 2: Local Development
+7
@@ -429,6 +429,12 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append 4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt 5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
**Token counting** (`packages/harness/deerflow/agents/memory/prompt.py`):
- `_count_tokens` budgets the injection. In default `tiktoken` mode, the encoding is loaded lazily and cached.
- Failed tiktoken loads are cached with a timestamp. During the fixed cooldown (`_TIKTOKEN_RETRY_COOLDOWN_S`, 600s), callers fall back to char estimation immediately instead of re-triggering the blocking BPE download; after the cooldown, transient outages can self-heal without a restart.
- In-flight loads are cached as a LOADING sentinel so concurrent callers fall back instead of spawning more blocking threads.
- Set `memory.token_counting: char` to skip tiktoken entirely and use the network-free CJK-aware char estimate.
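The caching rules in the bullets above (cache failures with a timestamp, retry after a cooldown, mark in-flight loads with a sentinel) can be sketched generically. The class name, the injectable `loader` and `clock` parameters, and returning `None` as the "fall back to char estimation" signal are assumptions for illustration; the real module keeps this state in module-level globals around `tiktoken.get_encoding`.

```python
import threading
import time

_LOADING = object()  # sentinel: a load for this key is already in progress


class CooldownCache:
    """Cache successful loads forever and failed loads for cooldown_s seconds."""

    def __init__(self, loader, cooldown_s=600.0, clock=time.monotonic):
        self._loader = loader
        self._cooldown_s = cooldown_s
        self._clock = clock
        self._lock = threading.Lock()
        self._entries = {}

    def get(self, key):
        with self._lock:
            entry = self._entries.get(key)
            if entry is _LOADING:
                return None  # someone else is loading: fall back immediately
            if isinstance(entry, tuple):  # (None, failed_at) marks a failure
                if self._clock() - entry[1] < self._cooldown_s:
                    return None  # still cooling down: do not retry yet
            elif entry is not None:
                return entry  # cached success
            self._entries[key] = _LOADING
        try:
            value = self._loader(key)  # potentially slow; runs outside the lock
        except Exception:
            with self._lock:
                self._entries[key] = (None, self._clock())
            return None
        with self._lock:
            self._entries[key] = value
        return value
```

Running the loader outside the lock is what lets concurrent callers observe the sentinel and fall back instead of queueing behind a blocked download.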
Focused regression coverage for the updater lives in `backend/tests/test_memory_updater.py`. Focused regression coverage for the updater lives in `backend/tests/test_memory_updater.py`.
**Configuration** (`config.yaml` → `memory`): **Configuration** (`config.yaml` → `memory`):
@@ -438,6 +444,7 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_
- `model_name` - LLM for updates (null = default model) - `model_name` - LLM for updates (null = default model)
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7) - `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
- `max_injection_tokens` - Token limit for prompt injection (2000) - `max_injection_tokens` - Token limit for prompt injection (2000)
- `token_counting` - Token counting strategy for the injection budget: `tiktoken` (default, accurate but may download BPE data from a public endpoint on first use — can block for a long time in network-restricted environments, see issues #3402/#3429) or `char` (network-free CJK-aware char estimate, never touches tiktoken)
### Reflection System (`packages/harness/deerflow/reflection/`) ### Reflection System (`packages/harness/deerflow/reflection/`)
+10 -2
@@ -6,6 +6,7 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from app.gateway.auth_disabled import warn_if_auth_disabled_enabled
from app.gateway.auth_middleware import AuthMiddleware from app.gateway.auth_middleware import AuthMiddleware
from app.gateway.config import get_gateway_config from app.gateway.config import get_gateway_config
from app.gateway.csrf_middleware import CSRFMiddleware, get_configured_cors_origins from app.gateway.csrf_middleware import CSRFMiddleware, get_configured_cors_origins
@@ -172,6 +173,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
startup_config = get_app_config() startup_config = get_app_config()
apply_logging_level(startup_config.log_level) apply_logging_level(startup_config.log_level)
logger.info("Configuration loaded successfully") logger.info("Configuration loaded successfully")
warn_if_auth_disabled_enabled()
except Exception as e: except Exception as e:
error_msg = f"Failed to load configuration during gateway startup: {e}" error_msg = f"Failed to load configuration during gateway startup: {e}"
logger.exception(error_msg) logger.exception(error_msg)
@@ -182,6 +184,12 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# Pre-warm tiktoken encoding cache so the first memory-injection request # Pre-warm tiktoken encoding cache so the first memory-injection request
# never blocks on the BPE data download (which hits an OpenAI/Azure URL # never blocks on the BPE data download (which hits an OpenAI/Azure URL
# that may be unreachable in restricted networks — see issue #3402). # that may be unreachable in restricted networks — see issue #3402).
# When memory.token_counting is "char", token counting never touches
# tiktoken, so skip the warm-up entirely (avoids even the 5s probe in
# network-restricted deployments — see issue #3429).
if startup_config.memory.token_counting == "char":
logger.info("memory.token_counting='char'; skipping tiktoken warm-up (network-free token estimation)")
else:
try: try:
from deerflow.agents.memory.prompt import warm_tiktoken_cache from deerflow.agents.memory.prompt import warm_tiktoken_cache
@@ -192,9 +200,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
if warmed: if warmed:
logger.info("tiktoken encoding cache warmed successfully") logger.info("tiktoken encoding cache warmed successfully")
else: else:
logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback") logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback until tiktoken loads successfully")
except TimeoutError: except TimeoutError:
logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback") logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback until tiktoken loads successfully")
except Exception: except Exception:
logger.warning("tiktoken warm-up skipped", exc_info=True) logger.warning("tiktoken warm-up skipped", exc_info=True)
+54
@@ -0,0 +1,54 @@
"""Shared helpers for local/E2E auth-disabled mode."""

from __future__ import annotations

import logging
import os
from types import SimpleNamespace

AUTH_DISABLED_ENV_VAR = "DEER_FLOW_AUTH_DISABLED"
AUTH_DISABLED_USER_ID = "e2e-user"
AUTH_DISABLED_USER_EMAIL = "e2e@test.local"
AUTH_SOURCE_SESSION = "session"
AUTH_SOURCE_INTERNAL = "internal"
AUTH_SOURCE_AUTH_DISABLED = "auth_disabled"
_PRODUCTION_ENV_VARS: tuple[str, ...] = ("DEER_FLOW_ENV", "ENVIRONMENT")
_PRODUCTION_ENV_VALUES: frozenset[str] = frozenset({"prod", "production"})

logger = logging.getLogger(__name__)


def is_explicit_production_environment() -> bool:
    return any(os.environ.get(name, "").strip().lower() in _PRODUCTION_ENV_VALUES for name in _PRODUCTION_ENV_VARS)


def is_auth_disabled_requested() -> bool:
    return os.environ.get(AUTH_DISABLED_ENV_VAR) == "1"


def is_auth_disabled() -> bool:
    return is_auth_disabled_requested() and not is_explicit_production_environment()


def warn_if_auth_disabled_enabled() -> None:
    if not is_auth_disabled():
        return
    logger.warning(
        "%s=1 is active: authentication is bypassed and anonymous requests run as synthetic admin user %r. Do not enable this in shared or production deployments.",
        AUTH_DISABLED_ENV_VAR,
        AUTH_DISABLED_USER_ID,
    )


def get_auth_disabled_user():
    return SimpleNamespace(
        id=AUTH_DISABLED_USER_ID,
        email=AUTH_DISABLED_USER_EMAIL,
        password_hash=None,
        system_role="admin",
        needs_setup=False,
        token_version=0,
    )
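To make the gating in this module easier to reason about, here is a small stand-alone sketch of the same decision that takes the environment as a mapping instead of reading `os.environ`. The function name `is_auth_disabled_for` is hypothetical and exists only for this illustration; the variable names and accepted values are copied from the module above.

```python
from collections.abc import Mapping

_PRODUCTION_ENV_VARS = ("DEER_FLOW_ENV", "ENVIRONMENT")
_PRODUCTION_ENV_VALUES = frozenset({"prod", "production"})


def is_auth_disabled_for(env: Mapping[str, str]) -> bool:
    """Mirror of is_auth_disabled() over an explicit environment mapping."""
    # Only the exact string "1" requests the bypass.
    requested = env.get("DEER_FLOW_AUTH_DISABLED") == "1"
    # Production markers are matched case-insensitively after trimming.
    production = any(
        env.get(name, "").strip().lower() in _PRODUCTION_ENV_VALUES
        for name in _PRODUCTION_ENV_VARS
    )
    return requested and not production
```

The point of the second check is that an explicit production marker wins: the bypass stays off even if `DEER_FLOW_AUTH_DISABLED=1` leaks into a production environment.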
+31 -14
@@ -17,6 +17,13 @@ from starlette.responses import JSONResponse
from starlette.types import ASGIApp from starlette.types import ASGIApp
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
from app.gateway.auth_disabled import (
AUTH_SOURCE_AUTH_DISABLED,
AUTH_SOURCE_INTERNAL,
AUTH_SOURCE_SESSION,
get_auth_disabled_user,
is_auth_disabled,
)
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token
from deerflow.runtime.user_context import reset_current_user, set_current_user from deerflow.runtime.user_context import reset_current_user, set_current_user
@@ -80,18 +87,14 @@ class AuthMiddleware(BaseHTTPMiddleware):
if is_valid_internal_auth_token(request.headers.get(INTERNAL_AUTH_HEADER_NAME)): if is_valid_internal_auth_token(request.headers.get(INTERNAL_AUTH_HEADER_NAME)):
internal_user = get_internal_user() internal_user = get_internal_user()
# Non-public path: require session cookie auth_source = AUTH_SOURCE_SESSION
if internal_user is None and not request.cookies.get("access_token"): access_token = request.cookies.get("access_token")
return JSONResponse(
status_code=401,
content={
"detail": AuthErrorResponse(
code=AuthErrorCode.NOT_AUTHENTICATED,
message="Authentication required",
).model_dump()
},
)
# Non-public path: require session cookie
if internal_user is not None:
user = internal_user
auth_source = AUTH_SOURCE_INTERNAL
elif access_token:
# Strict JWT validation: reject junk/expired tokens with 401 # Strict JWT validation: reject junk/expired tokens with 401
# right here instead of silently passing through. This closes # right here instead of silently passing through. This closes
# the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8): # the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8):
@@ -105,19 +108,33 @@ class AuthMiddleware(BaseHTTPMiddleware):
# bubble up, so we catch and render it as JSONResponse here. # bubble up, so we catch and render it as JSONResponse here.
from app.gateway.deps import get_current_user_from_request from app.gateway.deps import get_current_user_from_request
if internal_user is not None:
user = internal_user
else:
try: try:
user = await get_current_user_from_request(request) user = await get_current_user_from_request(request)
except HTTPException as exc: except HTTPException as exc:
if not is_auth_disabled():
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail}) return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
user = get_auth_disabled_user()
auth_source = AUTH_SOURCE_AUTH_DISABLED
elif is_auth_disabled():
user = get_auth_disabled_user()
auth_source = AUTH_SOURCE_AUTH_DISABLED
else:
return JSONResponse(
status_code=401,
content={
"detail": AuthErrorResponse(
code=AuthErrorCode.NOT_AUTHENTICATED,
message="Authentication required",
).model_dump()
},
)
# Stamp both request.state.user (for the contextvar pattern) # Stamp both request.state.user (for the contextvar pattern)
# and request.state.auth (so @require_permission's "auth is # and request.state.auth (so @require_permission's "auth is
# None" branch short-circuits instead of running the entire # None" branch short-circuits instead of running the entire
# JWT-decode + DB-lookup pipeline a second time per request). # JWT-decode + DB-lookup pipeline a second time per request).
request.state.user = user request.state.user = user
request.state.auth_source = auth_source
request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS) request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS)
token = set_current_user(user) token = set_current_user(user)
try: try:
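The branch order introduced in this hunk (internal token, then session cookie, then auth-disabled fallback, then 401) can be summarised as a pure function. This is a simplified model for reading the diff, not the middleware itself: the function name and boolean parameters are assumptions, and a `None` result stands for "respond with an error instead of continuing".

```python
from __future__ import annotations


def resolve_auth_source(*, internal_ok: bool, has_cookie: bool, cookie_valid: bool, auth_disabled: bool) -> str | None:
    """Return the auth_source the middleware would stamp, or None for an error response."""
    if internal_ok:
        return "internal"
    if has_cookie:
        if cookie_valid:
            return "session"
        # A junk or expired cookie is rejected unless auth-disabled mode is active.
        return "auth_disabled" if auth_disabled else None
    # No cookie at all: only auth-disabled mode lets the request through.
    return "auth_disabled" if auth_disabled else None
```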
+5
@@ -14,6 +14,8 @@ from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from starlette.types import ASGIApp from starlette.types import ASGIApp
from app.gateway.auth_disabled import is_auth_disabled
CSRF_COOKIE_NAME = "csrf_token" CSRF_COOKIE_NAME = "csrf_token"
CSRF_HEADER_NAME = "X-CSRF-Token" CSRF_HEADER_NAME = "X-CSRF-Token"
CSRF_TOKEN_LENGTH = 64 # bytes CSRF_TOKEN_LENGTH = 64 # bytes
@@ -38,6 +40,9 @@ def should_check_csrf(request: Request) -> bool:
if request.method not in ("POST", "PUT", "DELETE", "PATCH"): if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
return False return False
if is_auth_disabled():
return False
path = request.url.path.rstrip("/") path = request.url.path.rstrip("/")
# Exempt /api/v1/auth/me endpoint # Exempt /api/v1/auth/me endpoint
if path == "/api/v1/auth/me": if path == "/api/v1/auth/me":
+11
@@ -331,6 +331,17 @@ async def get_current_user_from_request(request: Request):
Raises HTTPException 401 if not authenticated. Raises HTTPException 401 if not authenticated.
""" """
state = getattr(request, "state", None)
state_user = getattr(state, "user", None)
from app.gateway.auth_disabled import AUTH_SOURCE_AUTH_DISABLED, AUTH_SOURCE_INTERNAL, AUTH_SOURCE_SESSION
if state_user is not None and getattr(state, "auth_source", None) in {
AUTH_SOURCE_SESSION,
AUTH_SOURCE_AUTH_DISABLED,
AUTH_SOURCE_INTERNAL,
}:
return state_user
from app.gateway.auth import decode_token from app.gateway.auth import decode_token
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
+7
@@ -20,6 +20,7 @@ from langgraph_sdk import Auth
from app.gateway.auth.errors import TokenError from app.gateway.auth.errors import TokenError
from app.gateway.auth.jwt import decode_token from app.gateway.auth.jwt import decode_token
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID, is_auth_disabled
from app.gateway.deps import get_local_provider from app.gateway.deps import get_local_provider
auth = Auth() auth = Auth()
@@ -38,6 +39,9 @@ def _check_csrf(request) -> None:
if method.upper() not in _CSRF_METHODS: if method.upper() not in _CSRF_METHODS:
return return
if is_auth_disabled():
return
cookie_token = request.cookies.get("csrf_token") cookie_token = request.cookies.get("csrf_token")
header_token = request.headers.get("x-csrf-token") header_token = request.headers.get("x-csrf-token")
@@ -66,6 +70,9 @@ async def authenticate(request):
# are rejected early, even if the cookie carries a valid JWT. # are rejected early, even if the cookie carries a valid JWT.
_check_csrf(request) _check_csrf(request)
if is_auth_disabled():
return AUTH_DISABLED_USER_ID
token = request.cookies.get("access_token") token = request.cookies.get("access_token")
if not token: if not token:
raise Auth.exceptions.HTTPException( raise Auth.exceptions.HTTPException(
+10
@@ -341,9 +341,19 @@ async def change_password(request: Request, response: Response, body: ChangePass
- Re-issues session cookie with new token_version - Re-issues session cookie with new token_version
""" """
from app.gateway.auth.password import hash_password_async, verify_password_async from app.gateway.auth.password import hash_password_async, verify_password_async
from app.gateway.auth_disabled import AUTH_SOURCE_AUTH_DISABLED
user = await get_current_user_from_request(request) user = await get_current_user_from_request(request)
if getattr(request.state, "auth_source", None) == AUTH_SOURCE_AUTH_DISABLED:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=AuthErrorResponse(
code=AuthErrorCode.INVALID_CREDENTIALS,
message="Password changes are not available when DEER_FLOW_AUTH_DISABLED=1.",
).model_dump(),
)
if user.password_hash is None: if user.password_hash is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump()) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
+5 -1
@@ -98,6 +98,7 @@ class MemoryConfigResponse(BaseModel):
fact_confidence_threshold: float = Field(..., description="Minimum confidence threshold for facts") fact_confidence_threshold: float = Field(..., description="Minimum confidence threshold for facts")
injection_enabled: bool = Field(..., description="Whether memory injection is enabled") injection_enabled: bool = Field(..., description="Whether memory injection is enabled")
max_injection_tokens: int = Field(..., description="Maximum tokens for memory injection") max_injection_tokens: int = Field(..., description="Maximum tokens for memory injection")
token_counting: str = Field(..., description="Token counting strategy for memory injection ('tiktoken' or 'char')")
class MemoryStatusResponse(BaseModel): class MemoryStatusResponse(BaseModel):
@@ -310,7 +311,8 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
"max_facts": 100, "max_facts": 100,
"fact_confidence_threshold": 0.7, "fact_confidence_threshold": 0.7,
"injection_enabled": true, "injection_enabled": true,
"max_injection_tokens": 2000 "max_injection_tokens": 2000,
"token_counting": "tiktoken"
} }
``` ```
""" """
@@ -323,6 +325,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
fact_confidence_threshold=config.fact_confidence_threshold, fact_confidence_threshold=config.fact_confidence_threshold,
injection_enabled=config.injection_enabled, injection_enabled=config.injection_enabled,
max_injection_tokens=config.max_injection_tokens, max_injection_tokens=config.max_injection_tokens,
token_counting=config.token_counting,
) )
@@ -351,6 +354,7 @@ async def get_memory_status() -> MemoryStatusResponse:
fact_confidence_threshold=config.fact_confidence_threshold, fact_confidence_threshold=config.fact_confidence_threshold,
injection_enabled=config.injection_enabled, injection_enabled=config.injection_enabled,
max_injection_tokens=config.max_injection_tokens, max_injection_tokens=config.max_injection_tokens,
token_counting=config.token_counting,
), ),
data=MemoryResponse(**memory_data), data=MemoryResponse(**memory_data),
) )
+15
@@ -315,6 +315,21 @@ async def start_run(
detail=f"Model {model_name!r} is not in the configured model allowlist", detail=f"Model {model_name!r} is not in the configured model allowlist",
) )
# Stateless run endpoints carry thread_id in the request *body*, so the
# @require_permission(owner_check=True) decorator -- which resolves ownership
# from the path param -- cannot protect them. Enforce thread ownership here,
# before any run is created, so one user cannot start runs on (or read /wait
# checkpoint state from) another user's thread. Missing rows (auto-created
# temp threads) and NULL-owner rows (shared / pre-auth data) stay accessible
# via check_access; only a thread already owned by another user is rejected
# with 404, matching thread_runs.py's anti-enumeration behaviour. Internal
# channel runs act on behalf of IM users they do not own (see
# inject_authenticated_user_context), so the internal system role is exempt.
user = getattr(request.state, "user", None)
if user is not None and getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE:
if not await run_ctx.thread_store.check_access(thread_id, str(user.id)):
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
try: try:
record = await run_mgr.create_or_reject( record = await run_mgr.create_or_reject(
thread_id, thread_id,
+2 -1
@@ -31,7 +31,8 @@ Current injection format:
Token counting: Token counting:
- Uses `tiktoken` (`cl100k_base`) when available - Uses `tiktoken` (`cl100k_base`) when available
- Falls back to `len(text) // 4` if tokenizer import fails - Falls back to a network-free CJK-aware character estimate if tokenizer import or encoding load fails
(CJK characters count as ~2 chars/token, other characters as ~4 chars/token)
## Known Gap ## Known Gap
@@ -586,7 +586,11 @@ def _get_memory_context(agent_name: str | None = None, *, app_config: AppConfig
return "" return ""
memory_data = get_memory_data(agent_name, user_id=get_effective_user_id()) memory_data = get_memory_data(agent_name, user_id=get_effective_user_id())
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens) memory_content = format_memory_for_injection(
memory_data,
max_tokens=config.max_injection_tokens,
use_tiktoken=(config.token_counting == "tiktoken"),
)
if not memory_content.strip(): if not memory_content.strip():
return "" return ""
@@ -5,7 +5,9 @@ from __future__ import annotations
import logging import logging
import math import math
import re import re
from typing import Any import threading
import time
from typing import Any, cast
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -169,7 +171,26 @@ Return ONLY valid JSON."""
 # subsequent calls are a dict lookup (no network I/O). Pre-warming at
 # startup via :func:`warm_tiktoken_cache` avoids blocking a request on the
 # (potentially slow) first ``get_encoding`` call.
-_tiktoken_encoding_cache: dict[str, tiktoken.Encoding] = {}
+#
+# A *failed* load is cached as a ``(None, monotonic_timestamp)`` tuple so that
+# a network-restricted environment does not re-attempt the blocking BPE
+# download on every subsequent call. After ``_TIKTOKEN_RETRY_COOLDOWN_S`` the
+# failure is allowed to expire so a transient network outage can self-heal back
+# to accurate tiktoken counting without a process restart. A load already in
+# progress is cached as ``_TIKTOKEN_ENCODING_LOADING`` so concurrent callers
+# fall back immediately instead of spawning more blocking
+# ``tiktoken.get_encoding`` threads. Use the ``memory.token_counting: char``
+# config to skip tiktoken entirely.
+_TIKTOKEN_ENCODING_MISSING = object()
+_TIKTOKEN_ENCODING_LOADING = object()
+# Cooldown before a *failed* tiktoken load is re-attempted. This is an internal
+# tuning constant rather than a user-facing config: it only affects how quickly
+# the default ``tiktoken`` mode self-heals after a transient network outage.
+# Deployments that want to avoid tiktoken's network dependency entirely should
+# set ``memory.token_counting: char`` instead of tuning this value.
+_TIKTOKEN_RETRY_COOLDOWN_S = 600.0
+_tiktoken_encoding_cache: dict[str, Any] = {}
+_tiktoken_encoding_cache_lock = threading.Lock()
 def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encoding | None:
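The failure-cooldown cache described in the comment block above can be sketched in isolation. This is a minimal illustration, not the project's code: `get_or_load`, the injectable `loader` and `now` parameters, and the module-level names are all hypothetical, chosen so the pattern can be exercised without tiktoken or a network.

```python
from __future__ import annotations

import threading
import time
from typing import Any, Callable

_MISSING = object()  # sentinel: no cache entry yet
_LOADING = object()  # sentinel: another caller is already loading

RETRY_COOLDOWN_S = 600.0  # same value as _TIKTOKEN_RETRY_COOLDOWN_S in the diff

_cache: dict[str, Any] = {}
_cache_lock = threading.Lock()


def get_or_load(
    name: str,
    loader: Callable[[str], Any],
    now: Callable[[], float] = time.monotonic,
) -> Any | None:
    """Return the cached resource, or None while loading or cooling down.

    Assumption for this sketch: a real resource is never a tuple, because
    tuples are reserved for cached failures of the form (None, failed_at).
    """
    with _cache_lock:
        cached = _cache.get(name, _MISSING)
        if cached is _LOADING:
            return None  # a load is in flight; caller should fall back
        if isinstance(cached, tuple):
            _, failed_at = cached
            if now() - failed_at < RETRY_COOLDOWN_S:
                return None  # still inside the cooldown window
            cached = _MISSING  # failure expired; allow one retry
        if cached is not _MISSING:
            return cached
        _cache[name] = _LOADING
    try:
        # The potentially slow call runs outside the lock.
        resource = loader(name)
    except Exception:
        with _cache_lock:
            _cache[name] = (None, now())
        return None
    with _cache_lock:
        _cache[name] = resource
    return resource
```

Holding the lock only around dictionary reads and writes keeps a slow load from blocking callers that merely want to learn that a load is in progress.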
@@ -181,44 +202,91 @@ def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encod
     download can block for tens of minutes before the OS TCP timeout kicks in.
     The caller must therefore be prepared for this to block and should run it
     off the event loop (e.g. via ``asyncio.to_thread``).
+    A failed load is remembered (with a timestamp) so subsequent calls fall
+    back immediately to character-based estimation instead of re-triggering the
+    blocking download. The failure expires after ``_TIKTOKEN_RETRY_COOLDOWN_S``
+    so a transient outage can self-heal without a restart. A load already in
+    progress is also remembered so that a timed-out caller does not leave a
+    window where later requests start more blocking ``get_encoding`` calls.
     """
     if not TIKTOKEN_AVAILABLE:
         return None
-    cached = _tiktoken_encoding_cache.get(encoding_name)
-    if cached is not None:
-        return cached
+    with _tiktoken_encoding_cache_lock:
+        cached = _tiktoken_encoding_cache.get(encoding_name, _TIKTOKEN_ENCODING_MISSING)
+        if cached is _TIKTOKEN_ENCODING_LOADING:
+            return None
+        if isinstance(cached, tuple):
+            # Cached failure: (None, failed_at). Retry only after cooldown.
+            _, failed_at = cached
+            if time.monotonic() - failed_at < _TIKTOKEN_RETRY_COOLDOWN_S:
+                return None
+            cached = _TIKTOKEN_ENCODING_MISSING
+        if cached is not _TIKTOKEN_ENCODING_MISSING:
+            return cast("tiktoken.Encoding", cached)
+        _tiktoken_encoding_cache[encoding_name] = _TIKTOKEN_ENCODING_LOADING
     try:
         encoding = tiktoken.get_encoding(encoding_name)
-        _tiktoken_encoding_cache[encoding_name] = encoding
-        return encoding
     except Exception:
         logger.warning("Failed to load tiktoken encoding %r; falling back to char-based estimation", encoding_name, exc_info=True)
+        with _tiktoken_encoding_cache_lock:
+            _tiktoken_encoding_cache[encoding_name] = (None, time.monotonic())
         return None
+    with _tiktoken_encoding_cache_lock:
+        _tiktoken_encoding_cache[encoding_name] = encoding
+    return encoding
-def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
+def _char_based_token_estimate(text: str) -> int:
+    """Network-free token estimate that accounts for CJK density.
+    The plain ``len(text) // 4`` heuristic is reasonable for English/code
+    (~4 chars per token) but significantly under-estimates token counts for
+    Chinese, Japanese, and Korean text, where the ratio is closer to 1.5-2
+    characters per token. Counting CJK characters separately (~2 chars per
+    token) avoids over-filling the injection budget for CJK-heavy memory
+    content.
+    """
+    cjk = sum(
+        1
+        for ch in text
+        if "\u4e00" <= ch <= "\u9fff"  # CJK Unified Ideographs
+        or "\u3040" <= ch <= "\u30ff"  # Hiragana + Katakana
+        or "\uac00" <= ch <= "\ud7a3"  # Hangul syllables
+    )
+    return (len(text) - cjk) // 4 + cjk // 2
+def _count_tokens(text: str, encoding_name: str = "cl100k_base", *, use_tiktoken: bool = True) -> int:
     """Count tokens in text using tiktoken.
     Args:
         text: The text to count tokens for.
         encoding_name: The encoding to use (default: cl100k_base for GPT-4/3.5).
+        use_tiktoken: When ``False``, skip tiktoken entirely and use the
+            network-free character-based estimate. This guarantees no BPE
+            download is attempted (see ``memory.token_counting`` config).
     Returns:
         The number of tokens in the text.
     """
+    if not use_tiktoken:
+        return _char_based_token_estimate(text)
     encoding = _get_tiktoken_encoding(encoding_name)
     if encoding is None:
-        # Fallback to character-based estimation if tiktoken is not available
-        # or the encoding failed to load.
-        return len(text) // 4
+        # Fallback to CJK-aware character estimation if tiktoken is not
+        # available or the encoding failed to load.
+        return _char_based_token_estimate(text)
     try:
         return len(encoding.encode(text))
     except Exception:
-        # Fallback to character-based estimation on error
-        return len(text) // 4
+        # Fallback to CJK-aware character estimation on error.
+        return _char_based_token_estimate(text)
 def warm_tiktoken_cache() -> bool:
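To make the effect of the CJK-aware fallback concrete, the estimate from this hunk can be copied into a standalone function and compared with the old `len(text) // 4` heuristic. The function name without the leading underscore and the sample strings are illustrative choices, not part of the change.

```python
def char_based_token_estimate(text: str) -> int:
    """Standalone copy of the estimate in the diff, for experimentation."""
    cjk = sum(
        1
        for ch in text
        if "\u4e00" <= ch <= "\u9fff"  # CJK Unified Ideographs
        or "\u3040" <= ch <= "\u30ff"  # Hiragana + Katakana
        or "\uac00" <= ch <= "\ud7a3"  # Hangul syllables
    )
    # Non-CJK characters: ~4 chars per token; CJK characters: ~2 per token.
    return (len(text) - cjk) // 4 + cjk // 2


english = "The quick brown fox jumps over the lazy dog."  # 44 characters
chinese = "今天天气很好我们去公园散步吧"  # 14 CJK characters

print(char_based_token_estimate(english), len(english) // 4)  # 11 11
print(char_based_token_estimate(chinese), len(chinese) // 4)  # 7 3
```

For pure English text the two heuristics agree. For the Chinese sample the old heuristic reports 3 tokens where the CJK-aware estimate reports 7, which is the under-count the docstring warns about.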
@@ -248,12 +316,15 @@ def _coerce_confidence(value: Any, default: float = 0.0) -> float:
     return max(0.0, min(1.0, confidence))
-def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str:
+def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000, *, use_tiktoken: bool = True) -> str:
     """Format memory data for injection into system prompt.
     Args:
         memory_data: The memory data dictionary.
         max_tokens: Maximum tokens to use (counted via tiktoken for accuracy).
+        use_tiktoken: When ``False``, all token counting uses the network-free
+            character-based estimate instead of tiktoken (see
+            ``memory.token_counting`` config). Defaults to ``True``.
     Returns:
         Formatted memory string for system prompt injection.
@@ -315,10 +386,10 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
         # Compute token count for existing sections once, then account
         # incrementally for each fact line to avoid full-string re-tokenization.
         base_text = "\n\n".join(sections)
-        base_tokens = _count_tokens(base_text) if base_text else 0
+        base_tokens = _count_tokens(base_text, use_tiktoken=use_tiktoken) if base_text else 0
         # Account for the separator between existing sections and the facts section.
         facts_header = "Facts:\n"
-        separator_tokens = _count_tokens("\n\n" + facts_header) if base_text else _count_tokens(facts_header)
+        separator_tokens = _count_tokens("\n\n" + facts_header, use_tiktoken=use_tiktoken) if base_text else _count_tokens(facts_header, use_tiktoken=use_tiktoken)
         running_tokens = base_tokens + separator_tokens
         fact_lines: list[str] = []
@@ -339,7 +410,7 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
             # Each additional line is preceded by a newline (except the first).
             line_text = ("\n" + line) if fact_lines else line
-            line_tokens = _count_tokens(line_text)
+            line_tokens = _count_tokens(line_text, use_tiktoken=use_tiktoken)
             if running_tokens + line_tokens <= max_tokens:
                 fact_lines.append(line)
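The incremental budgeting visible in the two hunks above (count the base text once, then add each fact line's cost to a running total) can be sketched on its own. The helper names `count_tokens` and `select_fact_lines` are hypothetical, the counter is a plain 4-characters-per-token stand-in, and stopping at the first line that does not fit is an assumption, since the surrounding loop is not shown in the diff.

```python
def count_tokens(text: str) -> int:
    # Stand-in counter (assumption): roughly 4 characters per token.
    return len(text) // 4


def select_fact_lines(base_text: str, lines: list[str], max_tokens: int) -> list[str]:
    """Keep fact lines while the running token total stays within budget."""
    facts_header = "Facts:\n"
    base_tokens = count_tokens(base_text) if base_text else 0
    # The header is preceded by a blank line only when other sections exist.
    separator_tokens = count_tokens("\n\n" + facts_header) if base_text else count_tokens(facts_header)
    running_tokens = base_tokens + separator_tokens

    kept: list[str] = []
    for line in lines:
        # Each additional line is preceded by a newline (except the first).
        line_text = ("\n" + line) if kept else line
        line_tokens = count_tokens(line_text)
        if running_tokens + line_tokens <= max_tokens:
            kept.append(line)
            running_tokens += line_tokens
        else:
            break  # assumption: stop at the first line that does not fit
    return kept


facts = ["- likes Python (0.9)", "- works in UTC (0.8)", "- prefers tabs (0.7)"]
print(select_fact_lines("User context", facts, max_tokens=16))
```

Each line is tokenized once, so the cost grows with the total text length rather than re-tokenizing the whole accumulated string on every iteration.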
@@ -355,8 +426,9 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
     result = "\n\n".join(sections)
-    # Use accurate token counting with tiktoken
-    token_count = _count_tokens(result)
+    # Use accurate token counting with tiktoken (or the char-based estimate
+    # when use_tiktoken is False).
+    token_count = _count_tokens(result, use_tiktoken=use_tiktoken)
     if token_count > max_tokens:
         # Truncate to fit within token limit
         # Estimate characters to remove based on token ratio
@@ -1141,6 +1141,7 @@ class DeerFlowClient:
             "fact_confidence_threshold": config.fact_confidence_threshold,
             "injection_enabled": config.injection_enabled,
             "max_injection_tokens": config.max_injection_tokens,
+            "token_counting": config.token_counting,
         }
     def get_memory_status(self) -> dict:
@@ -67,11 +67,13 @@ def resolve_agent_dir(name: str, *, user_id: str | None = None) -> Path:
     paths = get_paths()
     effective_user = user_id or get_effective_user_id()
     user_path = paths.user_agent_dir(effective_user, name)
-    if user_path.exists():
+    # Require config.yaml to confirm this is a genuine agent directory,
+    # not a leftover from memory/storage writes (see #3390).
+    if user_path.exists() and (user_path / "config.yaml").exists():
         return user_path
     legacy_path = paths.agent_dir(name)
-    if legacy_path.exists():
+    if legacy_path.exists() and (legacy_path / "config.yaml").exists():
         return legacy_path
     return user_path
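The resolution order in this hunk can be tried against a throwaway directory tree. The sketch below is a simplified stand-in: `pick_agent_dir` is a hypothetical helper that takes the two candidate paths directly instead of going through `get_paths()`, and the directory layout mirrors the one used in the regression tests.

```python
import tempfile
from pathlib import Path


def pick_agent_dir(user_path: Path, legacy_path: Path) -> Path:
    """Prefer the user dir, then the legacy dir, but only if config.yaml exists."""
    if user_path.exists() and (user_path / "config.yaml").exists():
        return user_path
    if legacy_path.exists() and (legacy_path / "config.yaml").exists():
        return legacy_path
    # Neither holds a config: default to the user location.
    return user_path


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    legacy = root / "agents" / "my-agent"
    legacy.mkdir(parents=True)
    (legacy / "config.yaml").write_text("name: my-agent\n", encoding="utf-8")

    # A memory write created the user dir, but it has no config.yaml.
    user = root / "users" / "u1" / "agents" / "my-agent"
    user.mkdir(parents=True)
    (user / "memory.json").write_text("{}", encoding="utf-8")

    print(pick_agent_dir(user, legacy) == legacy)  # True: falls back to legacy
```

Before the change, the mere existence of the user directory was enough to select it, which is the situation reported in #3390.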
@@ -1,5 +1,7 @@
 """Configuration for memory mechanism."""
+from typing import Literal
 from pydantic import BaseModel, Field
@@ -60,6 +62,17 @@ class MemoryConfig(BaseModel):
         le=8000,
         description="Maximum tokens to use for memory injection",
     )
+    token_counting: Literal["tiktoken", "char"] = Field(
+        default="tiktoken",
+        description=(
+            "Token counting strategy for memory-injection budgeting. "
+            "'tiktoken' is accurate but the encoding's BPE data may be "
+            "downloaded from a public network endpoint on first use, which "
+            "can block for a long time in network-restricted environments "
+            "(see issue #3402/#3429). 'char' uses a network-free "
+            "CJK-aware character-based estimate and never touches tiktoken."
+        ),
+    )
 # Global configuration instance
@@ -4,7 +4,20 @@ from pydantic import BaseModel, ConfigDict, Field
 class VolumeMountConfig(BaseModel):
     """Configuration for a volume mount."""
-    host_path: str = Field(..., description="Path on the host machine")
+    host_path: str = Field(
+        ...,
+        description=(
+            "Source path for the mount. Resolution depends on the active provider: "
+            "``LocalSandboxProvider`` checks this path from the gateway process — in "
+            "``make dev`` that is the host machine, but in Docker deployments "
+            "(``make up`` / docker-compose) it is the path *inside* the "
+            "``deer-flow-gateway`` container, so the host directory must also be "
+            "bind-mounted into the gateway service for the mount to take effect. "
+            "``AioSandboxProvider`` (DooD) passes this value straight to ``docker -v`` "
+            "for the sandbox container, where it is resolved by the host Docker daemon "
+            "from the host machine's perspective."
+        ),
+    )
     container_path: str = Field(..., description="Path inside the container")
     read_only: bool = Field(default=False, description="Whether the mount is read-only")
@@ -164,7 +164,18 @@ class RunJournal(BaseCallbackHandler):
             metadata={"caller": caller, **(metadata or {})},
         )
-    def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
+    def on_chain_end(
+        self,
+        outputs: Any,
+        *,
+        run_id: UUID,
+        parent_run_id: UUID | None = None,
+        **kwargs: Any,
+    ) -> None:
+        # Nested chain ends fire for internal graph nodes; only the root chain
+        # represents the user-visible run lifecycle.
+        if parent_run_id is not None:
+            return
         self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"})
         self._flush_sync()
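The root-only filter added to `on_chain_end` can be illustrated without LangChain. `MiniJournal` below is a hypothetical stand-in for `RunJournal` that appends to a list instead of calling `_put` and `_flush_sync`; only the `parent_run_id` check is taken from the diff.

```python
from __future__ import annotations

from typing import Any
from uuid import UUID, uuid4


class MiniJournal:
    """Toy journal: records a run.end event only for the root chain."""

    def __init__(self) -> None:
        self.events: list[dict[str, Any]] = []

    def on_chain_end(
        self,
        outputs: Any,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> None:
        # Nested chains carry a parent_run_id; the root chain does not.
        if parent_run_id is not None:
            return
        self.events.append({"event_type": "run.end", "content": outputs})


journal = MiniJournal()
root_id = uuid4()
journal.on_chain_end({"node": "planner"}, run_id=uuid4(), parent_run_id=root_id)  # ignored
journal.on_chain_end({"answer": "done"}, run_id=root_id)  # recorded
print(len(journal.events))  # 1
```

Without the check, every internal graph node that finishes would emit its own `run.end`, so a single user-visible run could appear to end several times.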
@@ -147,7 +147,17 @@ class LocalSandboxProvider(SandboxProvider):
                         mount.container_path,
                     )
                     continue
-                # Ensure the host path exists before adding mapping
+                # Ensure the host path exists before adding mapping.
+                #
+                # ``host_path`` is resolved against the filesystem of the
+                # process running this provider — for ``make dev`` that is
+                # the host machine, but for ``make up`` it is the
+                # ``deer-flow-gateway`` container, so any host path that
+                # isn't bind-mounted into the gateway image will be missing
+                # here. Skipping silently makes this a high-cost-to-debug
+                # silent failure (sandbox skill / tool reads an empty dir
+                # instead of the configured mount), so escalate to ERROR
+                # and include actionable guidance. See #3244.
                 if host_path.exists():
                     mappings.append(
                         PathMapping(
@@ -157,10 +167,16 @@ class LocalSandboxProvider(SandboxProvider):
                         )
                     )
                 else:
-                    logger.warning(
-                        "Mount host_path does not exist, skipping: %s -> %s",
+                    logger.error(
+                        "sandbox.mounts entry %s -> %s ignored: host_path %s does not exist from the "
+                        "perspective of the gateway process. In Docker deployments (make up / docker-compose), "
+                        "this path must also be bind-mounted into the gateway container — add a matching "
+                        "volume entry under services.gateway.volumes in docker/docker-compose.yaml (and use "
+                        "the in-container path here), or run in local mode (make dev) where the gateway sees "
+                        "the host filesystem directly.",
                         mount.host_path,
                         mount.container_path,
+                        mount.host_path,
                     )
         except Exception as e:
             # Log but don't fail if config loading fails
+187 -4
@@ -4,6 +4,7 @@ import pytest
 from starlette.testclient import TestClient
 from app.gateway.auth_middleware import AuthMiddleware, _is_public
+from app.gateway.csrf_middleware import CSRFMiddleware
 # ── _is_public unit tests ─────────────────────────────────────────────────
@@ -88,7 +89,9 @@ def test_unknown_api_path_is_protected():
 def _make_app():
     """Create a minimal FastAPI app with AuthMiddleware for testing."""
-    from fastapi import FastAPI
+    from fastapi import FastAPI, Request
+    from deerflow.runtime.user_context import get_effective_user_id
     app = FastAPI()
     app.add_middleware(AuthMiddleware)
@@ -98,8 +101,16 @@ def _make_app():
         return {"status": "ok"}
     @app.get("/api/v1/auth/me")
-    async def auth_me():
-        return {"id": "1", "email": "test@test.com"}
+    async def auth_me(request: Request):
+        from app.gateway.deps import get_current_user_from_request
+        user = await get_current_user_from_request(request)
+        return {
+            "id": str(user.id),
+            "email": user.email,
+            "system_role": user.system_role,
+            "needs_setup": user.needs_setup,
+        }
     @app.get("/api/v1/auth/setup-status")
     async def setup_status():
@@ -109,6 +120,29 @@ def _make_app():
     async def models_get():
         return {"models": []}
+    @app.get("/api/whoami")
+    async def whoami(request: Request):
+        user = request.state.user
+        return {
+            "id": str(user.id),
+            "email": getattr(user, "email", None),
+            "system_role": getattr(user, "system_role", None),
+            "context_user_id": get_effective_user_id(),
+        }
+    @app.get("/api/current-user-from-dep")
+    async def current_user_from_dep(request: Request):
+        from app.gateway.deps import get_current_user_from_request
+        user = await get_current_user_from_request(request)
+        state_user = request.state.user
+        return {
+            "id": str(user.id),
+            "state_id": str(state_user.id),
+            "auth_source": request.state.auth_source,
+            "context_user_id": get_effective_user_id(),
+        }
     @app.put("/api/mcp/config")
     async def mcp_put():
         return {"ok": True}
@@ -132,8 +166,24 @@ def _make_app():
     return app
+def _make_auth_csrf_app():
+    """Create a minimal app with production middleware ordering."""
+    from fastapi import FastAPI
+    app = FastAPI()
+    app.add_middleware(AuthMiddleware)
+    app.add_middleware(CSRFMiddleware)
+    @app.post("/api/threads/abc/runs/stream")
+    async def protected_mutation():
+        return {"ok": True}
+    return app
 @pytest.fixture
-def client():
+def client(monkeypatch):
+    monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
     return TestClient(_make_app())
@@ -161,6 +211,139 @@ def test_protected_path_no_cookie_returns_401(client):
assert body["detail"]["code"] == "not_authenticated"
def test_auth_disabled_allows_protected_path_without_cookie(monkeypatch):
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
client = TestClient(_make_app())
res = client.get("/api/models")
assert res.status_code == 200
assert res.json() == {"models": []}
def test_auth_disabled_stamps_e2e_admin_user_without_cookie(monkeypatch):
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
client = TestClient(_make_app())
res = client.get("/api/whoami")
assert res.status_code == 200
assert res.json() == {
"id": "e2e-user",
"email": "e2e@test.local",
"system_role": "admin",
"context_user_id": "e2e-user",
}
def test_auth_disabled_auth_me_reuses_middleware_user_without_cookie(monkeypatch):
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
client = TestClient(_make_app())
res = client.get("/api/v1/auth/me")
assert res.status_code == 200
assert res.json() == {
"id": "e2e-user",
"email": "e2e@test.local",
"system_role": "admin",
"needs_setup": False,
}
def test_auth_disabled_does_not_clobber_valid_session_cookie(monkeypatch):
from types import SimpleNamespace
async def fake_current_user(request):
return SimpleNamespace(
id="session-user",
email="session@test.local",
system_role="user",
needs_setup=False,
)
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
monkeypatch.setattr("app.gateway.deps.get_current_user_from_request", fake_current_user)
client = TestClient(_make_app())
res = client.get("/api/whoami", cookies={"access_token": "valid-session"})
assert res.status_code == 200
assert res.json() == {
"id": "session-user",
"email": "session@test.local",
"system_role": "user",
"context_user_id": "session-user",
}
def test_auth_disabled_does_not_clobber_internal_auth_identity(monkeypatch):
from app.gateway.internal_auth import create_internal_auth_headers
from deerflow.runtime.user_context import DEFAULT_USER_ID
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
client = TestClient(_make_app())
res = client.get(
"/api/current-user-from-dep",
headers=create_internal_auth_headers(),
)
assert res.status_code == 200
assert res.json() == {
"id": DEFAULT_USER_ID,
"state_id": DEFAULT_USER_ID,
"auth_source": "internal",
"context_user_id": DEFAULT_USER_ID,
}
def test_auth_disabled_skips_csrf_for_state_changing_requests(monkeypatch):
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
client = TestClient(_make_auth_csrf_app())
res = client.post("/api/threads/abc/runs/stream")
assert res.status_code == 200
assert res.json() == {"ok": True}
def test_auth_disabled_is_ignored_in_explicit_production_env(monkeypatch):
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
monkeypatch.setenv("DEER_FLOW_ENV", "production")
client = TestClient(_make_app())
res = client.get("/api/models")
assert res.status_code == 401
def test_auth_disabled_startup_warning_when_effective(monkeypatch, caplog):
from app.gateway.auth_disabled import warn_if_auth_disabled_enabled
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
monkeypatch.delenv("DEER_FLOW_ENV", raising=False)
monkeypatch.delenv("ENVIRONMENT", raising=False)
with caplog.at_level("WARNING", logger="app.gateway.auth_disabled"):
warn_if_auth_disabled_enabled()
assert "authentication is bypassed" in caplog.text
assert "e2e-user" in caplog.text
def test_auth_disabled_startup_warning_suppressed_in_explicit_production_env(monkeypatch, caplog):
from app.gateway.auth_disabled import warn_if_auth_disabled_enabled
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
monkeypatch.setenv("ENVIRONMENT", "production")
with caplog.at_level("WARNING", logger="app.gateway.auth_disabled"):
warn_if_auth_disabled_enabled()
assert "authentication is bypassed" not in caplog.text
def test_protected_path_with_junk_cookie_rejected(client):
"""Junk cookie → 401. Middleware strictly validates the JWT now
(AUTH_TEST_PLAN test 7.5.8); it no longer silently passes bad
+4
@@ -2472,6 +2472,7 @@ class TestGatewayConformance:
         mem_cfg.fact_confidence_threshold = 0.7
         mem_cfg.injection_enabled = True
         mem_cfg.max_injection_tokens = 2000
+        mem_cfg.token_counting = "tiktoken"
         with patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg):
             result = client.get_memory_config()
@@ -2479,6 +2480,7 @@ class TestGatewayConformance:
         parsed = MemoryConfigResponse(**result)
         assert parsed.enabled is True
         assert parsed.max_facts == 100
+        assert parsed.token_counting == "tiktoken"
     def test_get_memory_status(self, client):
         mem_cfg = MagicMock()
@@ -2489,6 +2491,7 @@ class TestGatewayConformance:
         mem_cfg.fact_confidence_threshold = 0.7
         mem_cfg.injection_enabled = True
         mem_cfg.max_injection_tokens = 2000
+        mem_cfg.token_counting = "tiktoken"
         memory_data = {
             "version": "1.0",
@@ -2514,6 +2517,7 @@ class TestGatewayConformance:
         parsed = MemoryStatusResponse(**result)
         assert parsed.config.enabled is True
+        assert parsed.config.token_counting == "tiktoken"
         assert parsed.data.version == "1.0"
@@ -0,0 +1,45 @@
"""Regression test for the Docker Compose default Gateway worker count.
The Gateway holds run state (RunManager and the stream bridge) in process, so
the default deployment must run a single Uvicorn worker. Running more than one
worker without a shared cross-worker stream bridge breaks run cancellation, SSE
reconnects, request de-duplication, and IM channels (nginx has no sticky
sessions, so requests scatter across workers that each keep their own run
state). This test pins the safe default so it cannot silently regress to a
multi-worker default, while still allowing operators to override it once a
shared stream bridge exists.
"""
from __future__ import annotations
import re
from pathlib import Path
import yaml
REPO_ROOT = Path(__file__).resolve().parents[2]
COMPOSE_PATH = REPO_ROOT / "docker" / "docker-compose.yaml"
def _gateway_command() -> str:
"""Return the gateway service command as a single string."""
compose = yaml.safe_load(COMPOSE_PATH.read_text(encoding="utf-8"))
command = compose["services"]["gateway"]["command"]
# ``command`` may load as a scalar string or a list depending on YAML style.
if isinstance(command, list):
command = " ".join(str(part) for part in command)
return command
def test_gateway_defaults_to_single_worker():
"""With GATEWAY_WORKERS unset, the worker count must default to 1."""
command = _gateway_command()
match = re.search(r"GATEWAY_WORKERS:-(\d+)", command)
assert match is not None, f"gateway command must set a GATEWAY_WORKERS default; got: {command}"
assert match.group(1) == "1", f"default Gateway worker count must be 1, got {match.group(1)}"
def test_gateway_worker_count_remains_overridable():
"""The worker count must stay configurable, not hard-coded to 1."""
command = _gateway_command()
assert "${GATEWAY_WORKERS:-1}" in command, f"worker count must use ${{GATEWAY_WORKERS:-1}} so operators can override it; got: {command}"
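The regular expression used by the first test can be checked against a sample command string. The command below is a hypothetical example of what the compose file's gateway command might look like; only the `${GATEWAY_WORKERS:-N}` shell default syntax and the regex come from the test, and `default_worker_count` is an illustrative helper.

```python
from __future__ import annotations

import re

# Hypothetical gateway command; the real one lives in docker/docker-compose.yaml.
sample_command = "uvicorn app.gateway.app:app --host 0.0.0.0 --workers ${GATEWAY_WORKERS:-1}"


def default_worker_count(command: str) -> int | None:
    """Extract N from a ${GATEWAY_WORKERS:-N} shell default, if present."""
    match = re.search(r"GATEWAY_WORKERS:-(\d+)", command)
    return int(match.group(1)) if match else None


print(default_worker_count(sample_command))  # 1
print(default_worker_count("uvicorn app.gateway.app:app --workers 4"))  # None
```

A hard-coded `--workers 4` returns `None`, so the first test would fail on its "must set a GATEWAY_WORKERS default" assertion, while a default other than 1 would fail on the second assertion.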
+73
@@ -203,6 +203,79 @@ class TestLoadAgentConfig:
assert cfg.name == "legacy-agent"
# ===========================================================================
# 3b. resolve_agent_dir — memory-only directory fallback (#3390)
# ===========================================================================
class TestResolveAgentDirMemoryOnlyFallback:
"""Regression tests for #3390.
When memory is enabled, the first conversation creates a user-isolated
agent directory containing only ``memory.json`` (no ``config.yaml``).
On the next turn ``resolve_agent_dir`` must fall through to the legacy
shared layout instead of returning the incomplete user directory.
"""
def test_user_dir_with_only_memory_falls_back_to_legacy(self, tmp_path):
"""User dir has memory.json but no config.yaml → use legacy dir."""
from deerflow.config.agents_config import resolve_agent_dir
# Legacy agent with full config
legacy_dir = tmp_path / "agents" / "my-agent"
legacy_dir.mkdir(parents=True)
(legacy_dir / "config.yaml").write_text("name: my-agent\n", encoding="utf-8")
(legacy_dir / "SOUL.md").write_text("legacy soul", encoding="utf-8")
# User dir created by memory write — no config.yaml
user_dir = tmp_path / "users" / "u1" / "agents" / "my-agent"
user_dir.mkdir(parents=True)
(user_dir / "memory.json").write_text("{}", encoding="utf-8")
with patch("deerflow.config.agents_config.get_paths", return_value=_make_paths(tmp_path)), patch("deerflow.config.agents_config.get_effective_user_id", return_value="u1"):
result = resolve_agent_dir("my-agent", user_id="u1")
assert result == legacy_dir
def test_user_dir_with_config_takes_priority(self, tmp_path):
"""User dir with config.yaml should still win over legacy."""
from deerflow.config.agents_config import resolve_agent_dir
# Legacy
legacy_dir = tmp_path / "agents" / "my-agent"
legacy_dir.mkdir(parents=True)
(legacy_dir / "config.yaml").write_text("name: my-agent\n", encoding="utf-8")
# User dir with full config (migrated)
user_dir = tmp_path / "users" / "u1" / "agents" / "my-agent"
user_dir.mkdir(parents=True)
(user_dir / "config.yaml").write_text("name: my-agent\nmodel: gpt-4\n", encoding="utf-8")
(user_dir / "memory.json").write_text("{}", encoding="utf-8")
with patch("deerflow.config.agents_config.get_paths", return_value=_make_paths(tmp_path)), patch("deerflow.config.agents_config.get_effective_user_id", return_value="u1"):
result = resolve_agent_dir("my-agent", user_id="u1")
assert result == user_dir
def test_load_config_falls_back_when_user_dir_is_memory_only(self, tmp_path):
"""End-to-end: load_agent_config works when user dir only has memory.json."""
config_dict = {"name": "my-agent", "description": "Legacy agent", "model": "deepseek-v3"}
_write_agent(tmp_path, "my-agent", config_dict)
# Simulate memory write creating user dir without config
user_dir = tmp_path / "users" / "u1" / "agents" / "my-agent"
user_dir.mkdir(parents=True)
(user_dir / "memory.json").write_text("{}", encoding="utf-8")
with patch("deerflow.config.agents_config.get_paths", return_value=_make_paths(tmp_path)), patch("deerflow.config.agents_config.get_effective_user_id", return_value="u1"):
from deerflow.config.agents_config import load_agent_config
cfg = load_agent_config("my-agent", user_id="u1")
assert cfg.name == "my-agent"
assert cfg.model == "deepseek-v3"
# ===========================================================================
# 4. load_agent_soul
# ===========================================================================
+9
@@ -21,6 +21,7 @@ from langgraph_sdk import Auth
 from app.gateway.auth.config import AuthConfig, set_auth_config
 from app.gateway.auth.jwt import create_access_token, decode_token
 from app.gateway.auth.models import User
+from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID
 from app.gateway.langgraph_auth import add_owner_filter, authenticate
 # ── Helpers ───────────────────────────────────────────────────────────────
@@ -59,6 +60,14 @@ def test_no_cookie_raises_401():
     assert "Not authenticated" in str(exc.value.detail)
+def test_auth_disabled_skips_csrf_and_authenticates_e2e_user(monkeypatch):
+    monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
+    identity = asyncio.run(authenticate(_req(method="POST")))
+    assert identity == AUTH_DISABLED_USER_ID
 def test_invalid_jwt_raises_401():
     with pytest.raises(Auth.exceptions.HTTPException) as exc:
         asyncio.run(authenticate(_req({"access_token": "garbage"})))
+4 -2
@@ -192,7 +192,7 @@ def test_build_acp_section_uses_explicit_app_config_without_global_config(monkey
 def test_get_memory_context_uses_explicit_app_config_without_global_config(monkeypatch):
     explicit_config = SimpleNamespace(
-        memory=SimpleNamespace(enabled=True, injection_enabled=True, max_injection_tokens=1234),
+        memory=SimpleNamespace(enabled=True, injection_enabled=True, max_injection_tokens=1234, token_counting="tiktoken"),
     )
     captured: dict[str, object] = {}
@@ -204,9 +204,10 @@ def test_get_memory_context_uses_explicit_app_config_without_global_config(monke
captured["user_id"] = user_id
return {"facts": []}
-def fake_format_memory_for_injection(memory_data, *, max_tokens):
+def fake_format_memory_for_injection(memory_data, *, max_tokens, use_tiktoken=True):
captured["memory_data"] = memory_data
captured["max_tokens"] = max_tokens
captured["use_tiktoken"] = use_tiktoken
return "remember this"
monkeypatch.setattr("deerflow.config.memory_config.get_memory_config", fail_get_memory_config)
@@ -223,6 +224,7 @@ def test_get_memory_context_uses_explicit_app_config_without_global_config(monke
"user_id": "user-1",
"memory_data": {"facts": []},
"max_tokens": 1234,
"use_tiktoken": True,
}
@@ -612,6 +612,54 @@ class TestLocalSandboxProviderMounts:
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
def test_setup_path_mappings_logs_actionable_error_for_missing_host_path(self, tmp_path, caplog):
"""Regression for #3244.
When ``sandbox.mounts[].host_path`` is absent from the gateway process's
filesystem (the typical symptom in Docker production mode: host_path is a
host machine path that is not bind-mounted into the gateway container),
the mount is still skipped but the failure must be a hard-to-miss ERROR
log with explicit, actionable guidance about Docker bind mounts, not the
old DEBUG/WARNING that buried the silent failure.
"""
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
missing_host_path = tmp_path / "does-not-exist"
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
sandbox_config = SandboxConfig(
use="deerflow.sandbox.local:LocalSandboxProvider",
mounts=[
VolumeMountConfig(host_path=str(missing_host_path), container_path="/mnt/knowledge", read_only=True),
],
)
config = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir, use="deerflow.skills.storage.local_skill_storage:LocalSkillStorage"),
sandbox=sandbox_config,
)
with caplog.at_level("ERROR", logger="deerflow.sandbox.local.local_sandbox_provider"):
with patch("deerflow.config.get_app_config", return_value=config):
provider = LocalSandboxProvider()
# Silent-skip behaviour is preserved (no breaking change for existing deployments).
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
# The failure must be observable at ERROR level and reference the offending paths.
error_records = [r for r in caplog.records if r.levelname == "ERROR"]
assert error_records, "expected an ERROR log when host_path is missing"
message = "\n".join(r.getMessage() for r in error_records)
assert str(missing_host_path) in message
assert "/mnt/knowledge" in message
# And it must include actionable Docker guidance so users don't lose hours
# to a silent empty-mount failure in production.
lowered = message.lower()
assert "docker" in lowered
assert "gateway" in lowered
assert "docker-compose" in lowered
def test_write_file_resolves_container_paths_in_content(self, tmp_path):
"""write_file should replace container paths in file content with local paths."""
data_dir = tmp_path / "data"
@@ -39,7 +39,7 @@ def test_format_memory_sorts_facts_by_confidence_desc() -> None:
def test_format_memory_respects_budget_when_adding_facts(monkeypatch) -> None:
# Make token counting deterministic for this test by counting characters.
-monkeypatch.setattr("deerflow.agents.memory.prompt._count_tokens", lambda text, encoding_name="cl100k_base": len(text))
+monkeypatch.setattr("deerflow.agents.memory.prompt._count_tokens", lambda text, encoding_name="cl100k_base", *, use_tiktoken=True: len(text))
memory_data = {
"user": {},
+4 -3
@@ -179,15 +179,16 @@ class TestLifecycleCallbacks:
assert "run.end" in types
@pytest.mark.anyio
-async def test_nested_chain_no_run_start(self, journal_setup):
-"""Nested chains (parent_run_id set) should NOT produce run.start."""
+async def test_nested_chain_no_run_lifecycle_events(self, journal_setup):
+"""Nested chains (parent_run_id set) should NOT produce root run lifecycle events."""
j, store = journal_setup
parent_id = uuid4()
j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=parent_id)
-j.on_chain_end({}, run_id=uuid4())
+j.on_chain_end({}, run_id=uuid4(), parent_run_id=parent_id)
await j.flush()
events = await store.list_events("t1", "r1")
assert not any(e["event_type"] == "run.start" for e in events)
assert not any(e["event_type"] == "run.end" for e in events)
class TestToolCallbacks:
@@ -0,0 +1,173 @@
"""Cross-user isolation for the stateless ``POST /api/runs/stream`` and ``/wait`` endpoints.
These endpoints receive ``thread_id`` in the request body, so the
``@require_permission(owner_check=True)`` decorator, which reads the
``thread_id`` *path* parameter, cannot protect them. The owner check
lives inside ``services.start_run()`` instead; this suite pins it at the
HTTP layer so the gap cannot silently reopen.
Strategy
--------
``app.state.run_manager.create_or_reject`` raises ``ConflictError``, so a
request that *passes* the owner check deterministically short-circuits
with 409 before any agent code runs. The two outcomes:
- 404 + ``create_or_reject`` never awaited -> blocked by the owner check
- 409 + ``create_or_reject`` awaited -> passed the owner check
The thread store is a real ``MemoryThreadMetaStore`` (not a mock) so the
``check_access`` semantics under test (missing row allows, ``user_id``
NULL allows, foreign owner denies) are exercised through real code.
"""
from __future__ import annotations
import asyncio
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
import pytest
from _router_auth_helpers import make_authed_test_app
from fastapi.testclient import TestClient
from langgraph.store.memory import InMemoryStore
from app.gateway.auth.models import User
from app.gateway.routers import runs
from deerflow.config.app_config import AppConfig, reset_app_config, set_app_config
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
from deerflow.runtime import ConflictError
USER_A = User(email="owner-a@example.com", password_hash="x", system_role="user", id=uuid4())
USER_B = User(email="intruder-b@example.com", password_hash="x", system_role="user", id=uuid4())
INTERNAL_USER = SimpleNamespace(id="default", system_role="internal")
THREAD_A = "thread-owned-by-a"
THREAD_SHARED = "thread-shared-null-owner"
@pytest.fixture(autouse=True)
def _stub_app_config():
"""Inject a minimal AppConfig so the allowed path (which builds a
RunContext via ``get_config()``) never reads config.yaml from disk."""
set_app_config(AppConfig.model_validate({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}}))
yield
reset_app_config()
def _make_thread_store() -> MemoryThreadMetaStore:
store = MemoryThreadMetaStore(InMemoryStore())
async def _seed():
await store.create(THREAD_A, user_id=str(USER_A.id))
await store.create(THREAD_SHARED, user_id=None)
asyncio.run(_seed())
return store
@contextmanager
def _client(user):
"""Yield a ``TestClient`` authenticated as ``user`` plus the stubbed
``create_or_reject`` mock, closing the client (and its anyio portal /
background threads) on exit.
``create_or_reject`` raises ``ConflictError`` so a request that passes the
owner check short-circuits to 409 before any agent code runs.
"""
app = make_authed_test_app(user_factory=lambda: user)
app.include_router(runs.router)
app.state.thread_store = _make_thread_store()
app.state.stream_bridge = MagicMock()
app.state.checkpointer = MagicMock()
app.state.store = MagicMock()
app.state.run_events_config = None
app.state.run_event_store = MagicMock()
run_manager = MagicMock()
run_manager.create_or_reject = AsyncMock(side_effect=ConflictError("sentinel: owner check passed"))
app.state.run_manager = run_manager
with TestClient(app) as client:
yield client, run_manager.create_or_reject
def _body(thread_id: str | None = None) -> dict:
if thread_id is None:
return {}
return {"config": {"configurable": {"thread_id": thread_id}}}
# ---------------------------------------------------------------------------
# Denied: another user's thread
# ---------------------------------------------------------------------------
def test_stream_cross_user_returns_404():
"""User B cannot start a run on user A's thread via /api/runs/stream."""
with _client(USER_B) as (client, create_or_reject):
response = client.post("/api/runs/stream", json=_body(THREAD_A))
assert response.status_code == 404
assert response.json()["detail"] == f"Thread {THREAD_A} not found"
create_or_reject.assert_not_awaited()
def test_wait_cross_user_returns_404_without_channel_values():
"""User B cannot read user A's checkpoint state via /api/runs/wait."""
with _client(USER_B) as (client, create_or_reject):
response = client.post("/api/runs/wait", json=_body(THREAD_A))
assert response.status_code == 404
assert response.json() == {"detail": f"Thread {THREAD_A} not found"}
create_or_reject.assert_not_awaited()
# ---------------------------------------------------------------------------
# Allowed: owner, fresh/untracked/shared threads, internal role
# ---------------------------------------------------------------------------
def test_stream_owner_passes_owner_check():
"""User A reaches run creation on their own thread (409 sentinel)."""
with _client(USER_A) as (client, create_or_reject):
response = client.post("/api/runs/stream", json=_body(THREAD_A))
assert response.status_code == 409
create_or_reject.assert_awaited()
def test_wait_owner_passes_owner_check():
with _client(USER_A) as (client, create_or_reject):
response = client.post("/api/runs/wait", json=_body(THREAD_A))
assert response.status_code == 409
create_or_reject.assert_awaited()
def test_stream_without_thread_id_passes_owner_check():
"""Stateless run with no thread_id auto-creates a thread — never blocked."""
with _client(USER_B) as (client, create_or_reject):
response = client.post("/api/runs/stream", json=_body())
assert response.status_code == 409
create_or_reject.assert_awaited()
def test_stream_untracked_thread_passes_owner_check():
"""A thread_id with no thread_meta row (untracked legacy) stays accessible."""
with _client(USER_B) as (client, create_or_reject):
response = client.post("/api/runs/stream", json=_body("never-created-thread"))
assert response.status_code == 409
create_or_reject.assert_awaited()
def test_stream_shared_thread_passes_owner_check():
"""A thread_meta row with user_id NULL (shared / pre-auth data) stays accessible."""
with _client(USER_B) as (client, create_or_reject):
response = client.post("/api/runs/stream", json=_body(THREAD_SHARED))
assert response.status_code == 409
create_or_reject.assert_awaited()
def test_stream_internal_role_bypasses_owner_check():
"""IM channels run with the internal system role on behalf of platform
users whose threads they do not own; the owner check must not break them."""
with _client(INTERNAL_USER) as (client, create_or_reject):
response = client.post("/api/runs/stream", json=_body(THREAD_A))
assert response.status_code == 409
create_or_reject.assert_awaited()
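The access rule this suite pins down (no thread_id allows, missing row allows, NULL owner allows, foreign owner denies, internal role bypasses) can be sketched in isolation. This is only an illustration under stated assumptions, not the project's `check_access` or `services.start_run()` code: `ThreadMeta`, `can_start_run`, and the dict-backed row store are hypothetical names introduced here.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class ThreadMeta:
    """Hypothetical stand-in for a thread_meta row."""

    thread_id: str
    user_id: str | None  # None models a shared / pre-auth row


def can_start_run(
    rows: dict[str, ThreadMeta],
    thread_id: str | None,
    user_id: str,
    system_role: str = "user",
) -> bool:
    """Return True when the caller may start a run on ``thread_id``."""
    if system_role == "internal":
        return True  # internal callers act on behalf of platform users
    if thread_id is None:
        return True  # stateless run: a fresh thread will be created
    row = rows.get(thread_id)
    if row is None:
        return True  # untracked legacy thread: no row, no owner to enforce
    if row.user_id is None:
        return True  # shared row with a NULL owner
    return row.user_id == user_id  # only the owner passes


rows = {
    "thread-owned-by-a": ThreadMeta("thread-owned-by-a", "user-a"),
    "thread-shared": ThreadMeta("thread-shared", None),
}
```

In the tests above, a denied request maps to a 404 with `create_or_reject` never awaited, and an allowed one maps to the 409 sentinel.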
@@ -5,18 +5,22 @@ Verifies:
- ``_count_tokens`` falls back to character estimation when tiktoken is
unavailable or the encoding fails to load.
- ``warm_tiktoken_cache`` populates the cache on success.
- An in-flight tiktoken load prevents duplicate blocking downloads.
"""
from __future__ import annotations
import threading
from unittest import mock
from deerflow.agents.memory.prompt import (
_count_tokens,
_get_tiktoken_encoding,
_tiktoken_encoding_cache,
format_memory_for_injection,
warm_tiktoken_cache,
)
from deerflow.config.memory_config import MemoryConfig
# ---------------------------------------------------------------------------
# _get_tiktoken_encoding
@@ -62,14 +66,103 @@ class TestGetTiktokenEncoding:
assert enc is fake_enc
tiktoken.get_encoding.assert_not_called()
-def test_returns_none_and_warns_on_get_encoding_failure(self, monkeypatch):
+def test_returns_none_and_caches_failure_sentinel(self, monkeypatch):
"""A failed load is cached (with a timestamp) so it is not re-attempted (no repeated network download)."""
_tiktoken_encoding_cache.pop("bogus_encoding", None)
import tiktoken
-monkeypatch.setattr(tiktoken, "get_encoding", mock.Mock(side_effect=OSError("download failed")))
+get_encoding = mock.Mock(side_effect=OSError("download failed"))
+monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
result = _get_tiktoken_encoding("bogus_encoding")
assert result is None
-assert "bogus_encoding" not in _tiktoken_encoding_cache
# The failure is remembered as a (None, timestamp) tuple.
assert "bogus_encoding" in _tiktoken_encoding_cache
cached = _tiktoken_encoding_cache["bogus_encoding"]
assert isinstance(cached, tuple)
assert cached[0] is None
# A second call must NOT re-attempt get_encoding (avoids re-blocking on
# the network download in restricted environments — see #3429).
result2 = _get_tiktoken_encoding("bogus_encoding")
assert result2 is None
assert get_encoding.call_count == 1
# Cleanup module-level cache to avoid cross-test leakage.
_tiktoken_encoding_cache.pop("bogus_encoding", None)
def test_failure_self_heals_after_cooldown(self, monkeypatch):
"""After the retry cooldown expires, a transient failure is re-attempted and can recover."""
_tiktoken_encoding_cache.pop("flaky_encoding", None)
import tiktoken
fake_enc = mock.Mock()
# First call fails, second call (after cooldown) succeeds.
get_encoding = mock.Mock(side_effect=[OSError("transient outage"), fake_enc])
monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
# Initial failure is cached.
assert _get_tiktoken_encoding("flaky_encoding") is None
assert get_encoding.call_count == 1
# Within the cooldown window: no retry, immediate fallback.
assert _get_tiktoken_encoding("flaky_encoding") is None
assert get_encoding.call_count == 1
# Simulate the cooldown having elapsed by ageing the cached timestamp.
from deerflow.agents.memory import prompt as prompt_module
_, _failed_at = _tiktoken_encoding_cache["flaky_encoding"]
_tiktoken_encoding_cache["flaky_encoding"] = (
None,
_failed_at - prompt_module._TIKTOKEN_RETRY_COOLDOWN_S - 1,
)
# Now the load is retried and recovers to accurate counting.
assert _get_tiktoken_encoding("flaky_encoding") is fake_enc
assert get_encoding.call_count == 2
_tiktoken_encoding_cache.pop("flaky_encoding", None)
def test_in_flight_load_returns_none_without_duplicate_get_encoding(self, monkeypatch):
"""Concurrent callers must not start duplicate blocking BPE downloads."""
_tiktoken_encoding_cache.pop("slow_encoding", None)
import tiktoken
started = threading.Event()
release = threading.Event()
fake_enc = mock.Mock()
def slow_get_encoding(_name):
started.set()
assert release.wait(timeout=2), "test timed out waiting to release slow get_encoding"
return fake_enc
get_encoding = mock.Mock(side_effect=slow_get_encoding)
monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
result: dict[str, object | None] = {}
def load_encoding():
result["encoding"] = _get_tiktoken_encoding("slow_encoding")
thread = threading.Thread(target=load_encoding)
thread.start()
try:
assert started.wait(timeout=1), "slow get_encoding did not start"
# While the first call is still blocked, a second call should see
# the in-flight sentinel and fall back immediately instead of
# starting another potentially long network download.
assert _get_tiktoken_encoding("slow_encoding") is None
assert get_encoding.call_count == 1
finally:
release.set()
thread.join(timeout=2)
_tiktoken_encoding_cache.pop("slow_encoding", None)
assert result["encoding"] is fake_enc
assert get_encoding.call_count == 1
# ---------------------------------------------------------------------------
@@ -115,6 +208,45 @@ class TestCountTokens:
result = _count_tokens(text, encoding_name="test_enc")
assert result == len(text) // 4
def test_use_tiktoken_false_returns_char_estimate_without_touching_tiktoken(self, monkeypatch):
"""use_tiktoken=False must never call tiktoken (guarantees no BPE download)."""
# Spy on both the encoding loader and tiktoken.get_encoding directly.
get_encoding_spy = mock.Mock(side_effect=AssertionError("get_encoding must not be called"))
loader_spy = mock.Mock(side_effect=AssertionError("_get_tiktoken_encoding must not be called"))
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", get_encoding_spy)
monkeypatch.setattr("deerflow.agents.memory.prompt._get_tiktoken_encoding", loader_spy)
text = "Hello, world! This is a network-free count."
result = _count_tokens(text, use_tiktoken=False)
assert result == len(text) // 4
get_encoding_spy.assert_not_called()
loader_spy.assert_not_called()
def test_cjk_estimate_is_denser_than_plain_quarter(self, monkeypatch):
"""CJK text should estimate more tokens than the plain len // 4 heuristic.
CJK characters are ~2 chars/token, so the char-based estimate must not
under-fill the budget the way ``len(text) // 4`` would.
"""
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
# "User prefers concise answers" rendered in CJK (Chinese) characters.
text = "\u7528\u6237\u504f\u597d\u7b80\u6d01\u7684\u4e2d\u6587\u56de\u7b54\u5e76\u5173\u6ce8\u91d1\u878d\u9886\u57df"
result = _count_tokens(text)
# Each CJK char counts as ~1/2 token (vs 1/4 for the plain heuristic).
assert result == len(text) // 2
assert result > len(text) // 4
def test_cjk_estimate_combines_cjk_and_non_cjk_characters(self, monkeypatch):
"""Mixed-language text should apply the CJK density only to CJK chars."""
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
# ASCII words mixed with CJK (Chinese) characters: "User" + "likes" + "Python and data analysis".
text = "User\u559c\u6b22Python\u548c\u6570\u636e\u5206\u6790"
cjk = sum(1 for ch in text if "\u4e00" <= ch <= "\u9fff")
result = _count_tokens(text)
assert result == (len(text) - cjk) // 4 + cjk // 2
# ---------------------------------------------------------------------------
# warm_tiktoken_cache
@@ -146,3 +278,69 @@ class TestWarmTiktokenCache:
def test_returns_false_when_tiktoken_unavailable(self, monkeypatch):
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
assert warm_tiktoken_cache() is False
# ---------------------------------------------------------------------------
# format_memory_for_injection token_counting strategy
# ---------------------------------------------------------------------------
class TestFormatMemoryForInjectionTokenCounting:
"""Verify the use_tiktoken flag is honoured end-to-end."""
@staticmethod
def _sample_memory() -> dict:
return {
"facts": [
{"content": "User prefers concise answers.", "category": "preference", "confidence": 0.9},
{"content": "User works in the finance domain.", "category": "context", "confidence": 0.8},
],
}
def test_use_tiktoken_false_never_touches_tiktoken(self, monkeypatch):
"""With use_tiktoken=False, formatting must not call tiktoken at all."""
get_encoding_spy = mock.Mock(side_effect=AssertionError("get_encoding must not be called"))
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", get_encoding_spy)
result = format_memory_for_injection(self._sample_memory(), max_tokens=2000, use_tiktoken=False)
assert "User prefers concise answers." in result
get_encoding_spy.assert_not_called()
def test_use_tiktoken_true_uses_encoding(self, monkeypatch):
"""With use_tiktoken=True (default), the cached encoding is used for counting."""
fake_enc = mock.Mock()
fake_enc.encode.side_effect = lambda text: list(range(len(text)))
monkeypatch.setattr(
"deerflow.agents.memory.prompt._get_tiktoken_encoding",
mock.Mock(return_value=fake_enc),
)
result = format_memory_for_injection(self._sample_memory(), max_tokens=2000, use_tiktoken=True)
assert "User prefers concise answers." in result
assert fake_enc.encode.called
def test_empty_memory_returns_empty(self):
assert format_memory_for_injection({}, max_tokens=2000, use_tiktoken=False) == ""
# ---------------------------------------------------------------------------
# MemoryConfig.token_counting
# ---------------------------------------------------------------------------
class TestMemoryConfigTokenCounting:
"""Verify the new config field defaults and validation."""
def test_default_is_tiktoken(self):
"""Default must remain tiktoken so existing deployments are unaffected."""
assert MemoryConfig().token_counting == "tiktoken"
def test_accepts_char(self):
assert MemoryConfig(token_counting="char").token_counting == "char"
def test_rejects_invalid_value(self):
import pytest
from pydantic import ValidationError
with pytest.raises(ValidationError):
MemoryConfig(token_counting="invalid")
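The character-based estimate that the tests above expect from the network-free path can be reproduced in a few lines. This sketch mirrors only the arithmetic asserted in the tests (about 4 characters per token for non-CJK text, about 2 for characters in U+4E00..U+9FFF); `estimate_tokens` is a hypothetical name, and the real `_count_tokens` fallback may handle further cases not shown in this diff.

```python
def estimate_tokens(text: str) -> int:
    """Estimate a token count without touching tiktoken (no network I/O)."""
    # Count characters in the CJK Unified Ideographs block.
    cjk = sum(1 for ch in text if "\u4e00" <= ch <= "\u9fff")
    # Non-CJK characters: ~4 per token. CJK characters: ~2 per token.
    return (len(text) - cjk) // 4 + cjk // 2


ascii_text = "Hello, world! This is a network-free count."
mixed_text = "User\u559c\u6b22Python\u548c\u6570\u636e\u5206\u6790"

ascii_estimate = estimate_tokens(ascii_text)
mixed_estimate = estimate_tokens(mixed_text)
```

For pure ASCII input the estimate reduces to `len(text) // 4`, the same value the `use_tiktoken=False` test asserts.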
+15 -2
@@ -15,7 +15,7 @@
# ============================================================================
# Bump this number when the config schema changes.
# Run `make config-upgrade` to merge new fields into your local config.yaml.
-config_version: 11
+config_version: 12
# ============================================================================
# Logging
@@ -768,8 +768,12 @@ sandbox:
allow_host_bash: false
# Optional: Mount additional host directories into the sandbox.
# Each mount maps a host path to a virtual container path accessible by the agent.
# Note: with LocalSandboxProvider under `make up` (docker-compose), host_path is
# checked from inside the deer-flow-gateway container — you must also bind-mount
# the same directory into services.gateway.volumes in docker/docker-compose.yaml
# for this mount to take effect (see issue #3244).
# mounts:
-# - host_path: /home/user/my-project # Absolute path on the host machine
+# - host_path: /home/user/my-project # Absolute path; see note above for Docker mode
# container_path: /mnt/my-project # Virtual path inside the sandbox
# read_only: true # Whether the mount is read-only (default: false)
@@ -1020,6 +1024,15 @@ memory:
fact_confidence_threshold: 0.7 # Minimum confidence for storing facts
injection_enabled: true # Whether to inject memory into system prompt
max_injection_tokens: 2000 # Maximum tokens for memory injection
# Token counting strategy for memory-injection budgeting:
# tiktoken (default) - accurate, but the encoding's BPE data may be
# downloaded from a public network endpoint on first use. In
# network-restricted environments this download can block for a long
# time (see issues #3402 / #3429). Pre-cache the encoding or set this
# to "char" to avoid it.
# char - network-free CJK-aware character-based estimate; never touches
# tiktoken. Slightly less precise budgeting, zero network I/O.
token_counting: tiktoken
# ============================================================================
# Custom Agent Management API
+7 -1
@@ -72,7 +72,13 @@ services:
UV_INDEX_URL: ${UV_INDEX_URL:-https://pypi.org/simple}
UV_EXTRAS: ${UV_EXTRAS:-}
container_name: deer-flow-gateway
-command: sh -c "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001 --workers ${GATEWAY_WORKERS:-4}"
# Gateway hosts the agent runtime with in-process RunManager + StreamBridge
# singletons -- run state lives in this worker's memory. Default to a single
# worker: with >1 worker and no nginx sticky sessions, run cancel, SSE
# reconnect, request dedup, and per-worker IM channel services all break
# across workers until a shared (e.g. redis) stream bridge lands, which is
# not yet implemented. Override GATEWAY_WORKERS only once that is in place.
+command: sh -c "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001 --workers ${GATEWAY_WORKERS:-1}"
volumes:
- ${DEER_FLOW_CONFIG_PATH}:/app/backend/config.yaml:ro
- ${DEER_FLOW_EXTENSIONS_CONFIG_PATH}:/app/backend/extensions_config.json:ro
+7 -3
@@ -7,8 +7,9 @@ import { defineConfig, devices } from "@playwright/test";
* so the mock-based suite is untouched.
*
* Two webServers are started: the replay gateway (:8011) and the frontend
-* (:3000, pointed at the gateway). Auth uses a throwaway test account the spec
-* registers at runtime, no secrets.
+* (:3000, pointed at the gateway). Auth-disabled mode is enabled on both
+* servers so the no-cookie e2e contract is covered; specs that need session
+* cookies still register a throwaway test account at runtime.
*/
export default defineConfig({
testDir: "./tests/e2e-real-backend",
@@ -38,7 +39,10 @@ export default defineConfig({
// Mount the test-only run/message seeder used by multi-run-order.spec.ts
// (#3352). The endpoint exists only on this replay gateway, never in the
// production app.
-env: { DEERFLOW_ENABLE_TEST_SEED: "1" },
+env: {
+DEERFLOW_ENABLE_TEST_SEED: "1",
+DEER_FLOW_AUTH_DISABLED: "1",
+},
},
{
command: "pnpm build && pnpm start",
+61 -4
@@ -1,8 +1,9 @@
"use client";
import Link from "next/link";
-import { useEffect, useMemo, useState } from "react";
+import { useEffect, useMemo, useRef, useState } from "react";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { ScrollArea } from "@/components/ui/scroll-area";
import {
@@ -11,24 +12,58 @@ import {
WorkspaceHeader,
} from "@/components/workspace/workspace-container";
import { useI18n } from "@/core/i18n/hooks";
-import { useThreads } from "@/core/threads/hooks";
+import { useInfiniteThreads } from "@/core/threads/hooks";
import { pathOfThread, titleOfThread } from "@/core/threads/utils";
import { formatTimeAgo } from "@/core/utils/datetime";
export default function ChatsPage() {
const { t } = useI18n();
-const { data: threads } = useThreads();
+const {
+data: infiniteThreads,
+fetchNextPage,
+hasNextPage,
+isFetchingNextPage,
+} = useInfiniteThreads();
const threads = useMemo(
() => infiniteThreads?.pages.flat() ?? [],
[infiniteThreads],
);
const [search, setSearch] = useState("");
const isSearching = search.trim().length > 0;
useEffect(() => {
document.title = `${t.pages.chats} - ${t.pages.appName}`;
}, [t.pages.chats, t.pages.appName]);
const filteredThreads = useMemo(() => {
-return threads?.filter((thread) => {
+return threads.filter((thread) => {
return titleOfThread(thread).toLowerCase().includes(search.toLowerCase());
});
}, [threads, search]);
// Sentinel-based auto load-more for the unfiltered list (issue #3482).
// In search mode we deliberately do NOT auto-paginate, otherwise an empty
// filtered view would keep the sentinel in the viewport and drain the
// entire backend list one page at a time. Searching falls back to an
// explicit button so users can still reach older conversations on demand.
const sentinelRef = useRef<HTMLDivElement | null>(null);
useEffect(() => {
const element = sentinelRef.current;
if (!element || !hasNextPage || isSearching) {
return;
}
const observer = new IntersectionObserver(
([entry]) => {
if (entry?.isIntersecting && hasNextPage && !isFetchingNextPage) {
void fetchNextPage();
}
},
{ rootMargin: "200px 0px 200px 0px" },
);
observer.observe(element);
return () => observer.disconnect();
}, [fetchNextPage, hasNextPage, isFetchingNextPage, isSearching]);
return (
<WorkspaceContainer>
<WorkspaceHeader></WorkspaceHeader>
@@ -61,6 +96,28 @@ export default function ChatsPage() {
</div>
</Link>
))}
{hasNextPage && !isSearching && (
<div
ref={sentinelRef}
aria-hidden="true"
className="h-px w-full"
data-testid="chats-page-sentinel"
/>
)}
{hasNextPage && isSearching && (
<div className="flex justify-center p-4">
<Button
variant="outline"
onClick={() => void fetchNextPage()}
disabled={isFetchingNextPage}
data-testid="chats-page-load-more"
>
{isFetchingNextPage
? t.chats.loadingMore
: t.chats.loadMoreToSearch}
</Button>
</div>
)}
</div>
</ScrollArea>
</main>
@@ -11,7 +11,7 @@ import {
} from "lucide-react";
import Link from "next/link";
import { useParams, usePathname, useRouter } from "next/navigation";
- import { useCallback, useState } from "react";
+ import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "sonner";
import { Button } from "@/components/ui/button";
@@ -51,8 +51,8 @@ import {
} from "@/core/threads/export";
import {
useDeleteThread,
+ useInfiniteThreads,
useRenameThread,
- useThreads,
} from "@/core/threads/hooks";
import type { AgentThread, AgentThreadState } from "@/core/threads/types";
import { pathOfThread, titleOfThread } from "@/core/threads/utils";
@@ -68,7 +68,35 @@ export function RecentChatList() {
thread_id: string;
agent_name?: string;
}>();
- const { data: threads = [] } = useThreads();
+ const {
+ data: infiniteThreads,
+ fetchNextPage,
+ hasNextPage,
+ isFetchingNextPage,
+ } = useInfiniteThreads();
const threads = useMemo(
() => infiniteThreads?.pages.flat() ?? [],
[infiniteThreads],
);
const sentinelRef = useRef<HTMLDivElement | null>(null);
useEffect(() => {
const element = sentinelRef.current;
if (!element || !hasNextPage) {
return;
}
const observer = new IntersectionObserver(
([entry]) => {
if (entry?.isIntersecting && hasNextPage && !isFetchingNextPage) {
void fetchNextPage();
}
},
{ rootMargin: "120px 0px 120px 0px" },
);
observer.observe(element);
return () => observer.disconnect();
}, [fetchNextPage, hasNextPage, isFetchingNextPage]);
const { mutate: deleteThread } = useDeleteThread();
const { mutate: renameThread } = useRenameThread();
@@ -267,6 +295,28 @@ export function RecentChatList() {
</SidebarMenuItem>
);
})}
{hasNextPage && (
<>
<Button
variant="ghost"
size="sm"
className="mx-2 my-1 w-[calc(100%-1rem)] justify-center text-xs"
onClick={() => void fetchNextPage()}
disabled={isFetchingNextPage}
data-testid="recent-chat-list-load-more"
>
{isFetchingNextPage
? t.chats.loadingMore
: t.chats.loadOlderChats}
</Button>
<div
ref={sentinelRef}
aria-hidden="true"
className="h-px w-full"
data-testid="recent-chat-list-sentinel"
/>
</>
)}
</div>
</SidebarMenu>
</SidebarGroupContent>
@@ -0,0 +1,23 @@
import type { User } from "./types";
export const AUTH_DISABLED_USER: User = {
id: "e2e-user",
email: "e2e@test.local",
system_role: "admin",
needs_setup: false,
};
const PRODUCTION_ENV_VALUES = new Set(["prod", "production"]);
function isExplicitProductionEnvironment() {
return ["DEER_FLOW_ENV", "ENVIRONMENT"].some((name) =>
PRODUCTION_ENV_VALUES.has((process.env[name] ?? "").trim().toLowerCase()),
);
}
export function isAuthDisabledMode() {
return (
process.env.DEER_FLOW_AUTH_DISABLED === "1" &&
!isExplicitProductionEnvironment()
);
}
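The guard above can be exercised without touching `process.env`. A minimal sketch, assuming the environment is passed in as a plain object; `Env`, `isExplicitProduction` and `isAuthDisabled` are illustrative names, not identifiers from this commit:

```typescript
// Hypothetical, dependency-free variant of the guard: the environment is
// injected as a plain object instead of being read from process.env.
type Env = Record<string, string | undefined>;

const PRODUCTION_VALUES = new Set(["prod", "production"]);

// True when either variable names a production environment. Values are
// trimmed and lower-cased before the lookup.
function isExplicitProduction(env: Env): boolean {
  return ["DEER_FLOW_ENV", "ENVIRONMENT"].some((name) =>
    PRODUCTION_VALUES.has((env[name] ?? "").trim().toLowerCase()),
  );
}

// The bypass needs the flag set to exactly "1" and no explicit production marker.
function isAuthDisabled(env: Env): boolean {
  return env["DEER_FLOW_AUTH_DISABLED"] === "1" && !isExplicitProduction(env);
}
```

Because values are normalised before comparison, a setting such as ` Production ` in either variable also turns the bypass off.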
+3 -7
@@ -2,6 +2,7 @@ import { cookies } from "next/headers";
import { isStaticWebsiteOnly } from "../static-mode";
+ import { AUTH_DISABLED_USER, isAuthDisabledMode } from "./auth-disabled-user";
import { getGatewayConfig } from "./gateway-config";
import { STATIC_WEBSITE_USER } from "./static-user";
import { type AuthResult, userSchema } from "./types";
@@ -20,15 +21,10 @@ export async function getServerSideUser(): Promise<AuthResult> {
};
}
- if (process.env.DEER_FLOW_AUTH_DISABLED === "1") {
+ if (isAuthDisabledMode()) {
return {
tag: "authenticated",
- user: {
- id: "e2e-user",
- email: "e2e@test.local",
- system_role: "admin",
- needs_setup: false,
- },
+ user: AUTH_DISABLED_USER,
};
}
+3
@@ -252,6 +252,9 @@ export const enUS: Translations = {
// Chats
chats: {
searchChats: "Search chats",
+ loadMoreToSearch: "Load more to search older conversations",
+ loadingMore: "Loading more...",
+ loadOlderChats: "Load older chats",
},
// Page titles (document title)
+3
@@ -183,6 +183,9 @@ export interface Translations {
// Chats
chats: {
searchChats: string;
+ loadMoreToSearch: string;
+ loadingMore: string;
+ loadOlderChats: string;
};
// Page titles (document title)
+3
@@ -240,6 +240,9 @@ export const zhCN: Translations = {
// Chats
chats: {
searchChats: "搜索对话",
+ loadMoreToSearch: "加载更多以搜索更早的对话",
+ loadingMore: "正在加载...",
+ loadOlderChats: "加载更早的对话",
},
// Page titles (document title)
+240 -11
@@ -3,6 +3,8 @@ import type { ThreadsClient } from "@langchain/langgraph-sdk/client";
import { useStream } from "@langchain/langgraph-sdk/react";
import {
type QueryClient,
+ type InfiniteData,
+ useInfiniteQuery,
useMutation,
useQuery,
useQueryClient,
@@ -311,6 +313,56 @@ export function upsertThreadInSearchCache(
);
}
export function upsertThreadInInfiniteCache(
queryClient: QueryClient,
thread: AgentThread,
) {
queryClient.setQueriesData(
{
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
exact: false,
},
(oldData: InfiniteData<AgentThread[]> | undefined) => {
if (!oldData) {
return oldData;
}
const merged = oldData.pages.map((page) =>
page.map((t) =>
t.thread_id === thread.thread_id
? {
...thread,
...t,
metadata: {
...(thread.metadata ?? {}),
...(t.metadata ?? {}),
},
values: {
...thread.values,
...t.values,
},
}
: t,
),
);
const exists = merged.some((page) =>
page.some((t) => t.thread_id === thread.thread_id),
);
if (exists) {
return { ...oldData, pages: merged };
}
const firstPage = merged[0] ?? [];
const restPages = merged.slice(1);
return {
...oldData,
pages: [[thread, ...firstPage], ...restPages],
};
},
);
}
function getStreamErrorMessage(error: unknown): string {
if (typeof error === "string" && error.trim()) {
return error;
@@ -364,7 +416,7 @@ export function useThreadStream({
loadMore: loadMoreHistory,
loading: isHistoryLoading,
appendMessages,
- } = useThreadHistory(onStreamThreadId ?? "");
+ } = useThreadHistory(onStreamThreadId ?? "", { enabled: !isMock });
// Keep listeners ref updated with latest callbacks
useEffect(() => {
@@ -417,6 +469,19 @@ export function useThreadStream({
},
interrupts: {},
});
upsertThreadInInfiniteCache(queryClient, {
thread_id: meta.thread_id,
created_at: now,
updated_at: now,
metadata: context.agent_name ? { agent_name: context.agent_name } : {},
status: "busy",
values: {
title: t.pages.newChat,
messages: [],
artifacts: [],
},
interrupts: {},
});
if (context.agent_name && !isMock) {
void getAPIClient()
.threads.update(meta.thread_id, {
@@ -488,6 +553,27 @@ export function useThreadStream({
});
},
);
const nextTitle: string = update.title;
void queryClient.setQueriesData(
{
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
exact: false,
},
(oldData: InfiniteData<AgentThread[]> | undefined) =>
mapInfiniteThreadsCache(
oldData,
(t): AgentThread =>
t.thread_id === threadIdRef.current
? {
...t,
values: {
...t.values,
title: nextTitle,
},
}
: t,
),
);
}
}
},
@@ -542,6 +628,9 @@ export function useThreadStream({
.filter((id): id is string => Boolean(id)),
);
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
void queryClient.invalidateQueries({
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
});
if (threadIdRef.current && !isMock) {
void queryClient.invalidateQueries({
queryKey: threadTokenUsageQueryKey(threadIdRef.current),
@@ -801,6 +890,9 @@ export function useThreadStream({
},
);
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
void queryClient.invalidateQueries({
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
});
} catch (error) {
setOptimisticMessages([]);
setIsUploading(false);
@@ -854,8 +946,15 @@ export function useThreadStream({
} as const;
}
- export function useThreadHistory(threadId: string) {
- const runs = useThreadRuns(threadId);
+ type ThreadHistoryOptions = {
+ enabled?: boolean;
+ };
+ export function useThreadHistory(
+ threadId: string,
+ { enabled = true }: ThreadHistoryOptions = {},
+ ) {
+ const runs = useThreadRuns(threadId, { enabled });
const threadIdRef = useRef(threadId);
const runsRef = useRef(runs.data ?? []);
const indexRef = useRef(-1);
@@ -864,10 +963,15 @@ export function useThreadHistory(threadId: string) {
const loadingRunIdRef = useRef<string | null>(null);
const loadedRunIdsRef = useRef<Set<string>>(new Set());
const runBeforeSeqRef = useRef<Map<string, number>>(new Map());
+ const loadGenerationRef = useRef(0);
const [loading, setLoading] = useState(false);
const [messages, setMessages] = useState<Message[]>([]);
const loadMessages = useCallback(async () => {
+ if (!enabled) {
+ return;
+ }
+ const loadGeneration = loadGenerationRef.current;
if (loadingRef.current) {
const pendingRunIndex = findLatestUnloadedRunIndex(
runsRef.current,
@@ -921,12 +1025,15 @@ export function useThreadHistory(threadId: string) {
}).then((res) => {
return res.json();
});
+ if (
+ loadGenerationRef.current !== loadGeneration ||
+ threadIdRef.current !== requestThreadId
+ ) {
+ return;
+ }
const _messages = result.data
.filter((m) => !m.metadata.caller?.startsWith("middleware:"))
.map((m) => m.content);
- if (threadIdRef.current !== requestThreadId) {
- return;
- }
setMessages((prev) =>
dedupeMessagesByIdentity([..._messages, ...prev]),
);
@@ -961,16 +1068,19 @@ export function useThreadHistory(threadId: string) {
} catch (err) {
console.error(err);
} finally {
+ if (loadGenerationRef.current === loadGeneration) {
loadingRef.current = false;
loadingRunIdRef.current = null;
setLoading(false);
+ }
}
- }, []);
+ }, [enabled]);
useEffect(() => {
const threadChanged = threadIdRef.current !== threadId;
threadIdRef.current = threadId;
- if (threadChanged) {
+ if (!enabled || threadChanged) {
+ loadGenerationRef.current += 1;
runsRef.current = [];
indexRef.current = -1;
pendingLoadRef.current = false;
@@ -982,6 +1092,10 @@ export function useThreadHistory(threadId: string) {
setMessages([]);
}
+ if (!enabled) {
+ return;
+ }
if (runs.data && runs.data.length > 0) {
runsRef.current = runs.data ?? [];
indexRef.current = findLatestUnloadedRunIndex(
@@ -992,14 +1106,15 @@ export function useThreadHistory(threadId: string) {
loadMessages().catch(() => {
toast.error("Failed to load thread history.");
});
- }, [threadId, runs.data, loadMessages]);
+ }, [enabled, threadId, runs.data, loadMessages]);
const appendMessages = useCallback((_messages: Message[]) => {
setMessages((prev) => {
return dedupeMessagesByIdentity([...prev, ..._messages]);
});
}, []);
- const hasMore = indexRef.current >= 0 || !runs.data;
+ const hasMore =
+ enabled && Boolean(threadId) && (indexRef.current >= 0 || !runs.data);
return {
runs: runs.data,
messages,
@@ -1077,7 +1192,90 @@ export function useThreads(
});
}
- export function useThreadRuns(threadId?: string) {
+ export const INFINITE_THREADS_PAGE_SIZE = 50;
export const INFINITE_THREADS_QUERY_KEY_PREFIX = [
"threads",
"searchInfinite",
] as const;
type InfiniteThreadsParams = Omit<
Parameters<ThreadsClient["search"]>[0],
"limit" | "offset"
>;
export function getInfiniteThreadsNextPageParam(
lastPage: AgentThread[],
allPages: AgentThread[][],
pageSize: number = INFINITE_THREADS_PAGE_SIZE,
): number | undefined {
if (lastPage.length < pageSize) {
return undefined;
}
return allPages.reduce((sum, page) => sum + page.length, 0);
}
export function mapInfiniteThreadsCache(
oldData: InfiniteData<AgentThread[]> | undefined,
mapper: (thread: AgentThread) => AgentThread,
): InfiniteData<AgentThread[]> | undefined {
if (!oldData) {
return oldData;
}
return {
...oldData,
pages: oldData.pages.map((page) => page.map(mapper)),
};
}
export function filterInfiniteThreadsCache(
oldData: InfiniteData<AgentThread[]> | undefined,
predicate: (thread: AgentThread) => boolean,
): InfiniteData<AgentThread[]> | undefined {
if (!oldData) {
return oldData;
}
return {
...oldData,
pages: oldData.pages.map((page) => page.filter(predicate)),
};
}
export function useInfiniteThreads(
params: InfiniteThreadsParams = {
sortBy: "updated_at",
sortOrder: "desc",
select: ["thread_id", "updated_at", "values", "metadata"],
},
) {
const apiClient = getAPIClient();
return useInfiniteQuery<
AgentThread[],
Error,
InfiniteData<AgentThread[]>,
readonly unknown[],
number
>({
queryKey: [...INFINITE_THREADS_QUERY_KEY_PREFIX, params],
initialPageParam: 0,
queryFn: async ({ pageParam }) => {
const response = (await apiClient.threads.search<AgentThreadState>({
...params,
limit: INFINITE_THREADS_PAGE_SIZE,
offset: pageParam,
})) as AgentThread[];
return response;
},
getNextPageParam: (lastPage, allPages) =>
getInfiniteThreadsNextPageParam(lastPage, allPages),
refetchOnWindowFocus: false,
});
}
+ export function useThreadRuns(
+ threadId?: string,
+ { enabled = true }: { enabled?: boolean } = {},
+ ) {
const apiClient = getAPIClient();
return useQuery<Run[]>({
queryKey: ["thread", threadId],
@@ -1088,6 +1286,7 @@ export function useThreadRuns(threadId?: string) {
const response = await apiClient.runs.list(threadId);
return response;
},
+ enabled: enabled && Boolean(threadId),
refetchOnWindowFocus: false,
});
}
@@ -1156,9 +1355,21 @@ export function useDeleteThread() {
return oldData.filter((t) => t.thread_id !== threadId);
},
);
+ queryClient.setQueriesData(
+ {
+ queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
+ exact: false,
},
+ (oldData: InfiniteData<AgentThread[]> | undefined) =>
+ filterInfiniteThreadsCache(oldData, (t) => t.thread_id !== threadId),
+ );
+ },
onSettled() {
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
+ void queryClient.invalidateQueries({
+ queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
+ });
},
});
}
@@ -1199,6 +1410,24 @@ export function useRenameThread() {
});
},
);
queryClient.setQueriesData(
{
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
exact: false,
},
(oldData: InfiniteData<AgentThread[]> | undefined) =>
mapInfiniteThreadsCache(oldData, (t) =>
t.thread_id === threadId
? {
...t,
values: {
...t.values,
title,
},
}
: t,
),
);
},
});
}
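The offset rule in `getInfiniteThreadsNextPageParam` can be checked in isolation. A minimal sketch, assuming an in-memory array stands in for the backend; `nextOffset`, `fetchPage` and `loadAll` are hypothetical helpers, not part of this change:

```typescript
// Same rule as the helper in the diff: a short page means the list is
// exhausted, otherwise the next offset is the number of rows loaded so far.
function nextOffset<T>(
  lastPage: T[],
  allPages: T[][],
  pageSize: number,
): number | undefined {
  if (lastPage.length < pageSize) {
    return undefined;
  }
  return allPages.reduce((sum, page) => sum + page.length, 0);
}

// Hypothetical stand-in for threads.search with limit/offset.
function fetchPage(backend: number[], offset: number, limit: number): number[] {
  return backend.slice(offset, offset + limit);
}

// Keeps requesting pages until nextOffset reports the end of the list.
function loadAll(backend: number[], pageSize: number): number[][] {
  const pages: number[][] = [];
  let offset: number | undefined = 0;
  while (offset !== undefined) {
    const page = fetchPage(backend, offset, pageSize);
    pages.push(page);
    offset = nextOffset(page, pages, pageSize);
  }
  return pages;
}

// 120 rows with a page size of 50 load as pages of 50, 50 and 20 rows.
const pages = loadAll(Array.from({ length: 120 }, (_, i) => i), 50);
```

One consequence of the short-page rule: when the total is an exact multiple of the page size, the end is only detected by one further request that returns an empty page.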
@@ -0,0 +1,16 @@
import { expect, test } from "@playwright/test";
import { AUTH_DISABLED_USER } from "../../src/core/auth/auth-disabled-user";
const APP = "http://localhost:3000";
test.describe("auth-disabled contract (real backend)", () => {
test("gateway /auth/me returns the frontend synthetic user without a cookie", async ({
context,
}) => {
const resp = await context.request.get(`${APP}/api/v1/auth/me`);
expect(resp.status(), await resp.text()).toBe(200);
await expect(resp.json()).resolves.toEqual(AUTH_DISABLED_USER);
});
});
@@ -101,10 +101,11 @@ test.describe("real backend render (replay, no API key)", () => {
EXPECTED_SUGGESTION,
"fixture should contain a suggestions turn (re-record; the record spec waits for /suggestions)",
).not.toBe("");
- await expect(page.getByText(EXPECTED_TITLE)).toBeVisible({
+ const chat = page.locator("#chat");
+ await expect(chat.getByText(EXPECTED_TITLE)).toBeVisible({
timeout: 60_000,
});
- await expect(page.getByText(EXPECTED_SUGGESTION)).toBeVisible({
+ await expect(chat.getByText(EXPECTED_SUGGESTION)).toBeVisible({
timeout: 30_000,
});
+1
@@ -12,6 +12,7 @@ test.describe("Chat workspace", () => {
const textarea = page.getByPlaceholder(/how can i assist you/i);
await expect(textarea).toBeVisible({ timeout: 15_000 });
+ await expect(page.getByRole("button", { name: /load more/i })).toBeHidden();
});
test("can type a message in the input box", async ({ page }) => {
+79
@@ -18,6 +18,7 @@ const THREADS = [
updated_at: "2025-06-02T12:00:00Z",
},
];
+ const DEMO_THREAD_ID = "7cfa5f8f-a2f8-47ad-acbd-da7137baf990";
test.describe("Thread history", () => {
test("sidebar shows existing threads", async ({ page }) => {
@@ -61,6 +62,84 @@ test.describe("Thread history", () => {
).toBeVisible({ timeout: 15_000 });
});
test("mock thread does not load real backend run history", async ({
page,
}) => {
mockLangGraphAPI(page, {
threads: [
{
thread_id: DEMO_THREAD_ID,
title: "Forecasting 2026 Trends and Opportunities",
updated_at: "2025-06-01T12:00:00Z",
messages: [
{
type: "human",
id: `run-human-${DEMO_THREAD_ID}`,
content: [
{
type: "text",
text: "This run-message endpoint should not be called.",
},
],
},
],
},
],
});
const backendRunHistoryUrls: string[] = [];
await page.route(
/\/api\/langgraph\/threads\/[^/]+\/runs(?:\?|$)/,
(route) => {
if (
route.request().method() === "GET" &&
route
.request()
.url()
.includes(`/api/langgraph/threads/${DEMO_THREAD_ID}/runs`)
) {
backendRunHistoryUrls.push(route.request().url());
return route.fulfill({
status: 500,
contentType: "application/json",
body: JSON.stringify({
error: "mock=true must not load real runs",
}),
});
}
return route.fallback();
},
);
await page.route(
/\/api\/threads\/[^/]+\/runs\/[^/]+\/messages(?:\?|$)/,
(route) => {
if (
route.request().method() === "GET" &&
route.request().url().includes(`/api/threads/${DEMO_THREAD_ID}/runs/`)
) {
backendRunHistoryUrls.push(route.request().url());
return route.fulfill({
status: 500,
contentType: "application/json",
body: JSON.stringify({
error: "mock=true must not load real run messages",
}),
});
}
return route.fallback();
},
);
await page.goto(`/workspace/chats/${DEMO_THREAD_ID}?mock=true`);
await expect(
page.getByText("What might be the trends and opportunities in 2026?"),
).toBeVisible({ timeout: 15_000 });
await expect(
page.getByText("I've created a modern, minimalist website"),
).toBeVisible();
expect(backendRunHistoryUrls).toEqual([]);
});
test("chats list page shows all threads", async ({ page }) => {
mockLangGraphAPI(page, { threads: THREADS });
@@ -0,0 +1,123 @@
import { expect, test } from "@playwright/test";
import { mockLangGraphAPI } from "./utils/mock-api";
// Issue #3482: the sidebar's "Recent chats" and the /workspace/chats list
// page used to stop at the first 50 threads with no way to load more.
// `useInfiniteThreads()` + a sentinel near the bottom of each list now
// pages through the backend.
const TOTAL_THREADS = 120;
const PAGE_SIZE = 50;
const THREADS = Array.from({ length: TOTAL_THREADS }, (_, i) => {
// Pad index so titles sort deterministically as strings. The thread-search
// mock returns threads in the order provided, so paging boundaries are
// stable across runs.
const index = String(i + 1).padStart(3, "0");
return {
thread_id: `00000000-0000-0000-0000-0000000${index.padStart(5, "0")}`,
title: `Conversation ${index}`,
updated_at: `2025-06-${String((i % 28) + 1).padStart(2, "0")}T12:00:00Z`,
};
});
const FIRST_PAGE_LAST = `Conversation ${String(PAGE_SIZE).padStart(3, "0")}`;
const SECOND_PAGE_FIRST = `Conversation ${String(PAGE_SIZE + 1).padStart(3, "0")}`;
test.describe("Thread list infinite scroll (issue #3482)", () => {
test("chats list page loads more threads when scrolling to the bottom", async ({
page,
}) => {
mockLangGraphAPI(page, { threads: THREADS });
await page.goto("/workspace/chats");
const main = page.locator("main");
// First page renders.
await expect(main.getByText(FIRST_PAGE_LAST)).toBeVisible({
timeout: 15_000,
});
// Items past the first page have not been fetched yet.
await expect(main.getByText(SECOND_PAGE_FIRST)).toHaveCount(0);
// Scrolling the sentinel into view triggers the next page.
const sentinel = page.getByTestId("chats-page-sentinel");
await sentinel.scrollIntoViewIfNeeded();
await expect(main.getByText(SECOND_PAGE_FIRST)).toBeVisible({
timeout: 15_000,
});
});
test("sidebar recent chats loads more threads when scrolling to the bottom", async ({
page,
}) => {
mockLangGraphAPI(page, { threads: THREADS });
await page.goto("/workspace/chats/new");
// The 50th thread (end of first page) appears in the sidebar.
await expect(page.getByText(FIRST_PAGE_LAST).first()).toBeVisible({
timeout: 15_000,
});
// The 51st has not been fetched yet.
await expect(page.getByText(SECOND_PAGE_FIRST)).toHaveCount(0);
// Scroll the sidebar sentinel into view to trigger the next page.
const sentinel = page.getByTestId("recent-chat-list-sentinel");
await sentinel.scrollIntoViewIfNeeded();
await expect(page.getByText(SECOND_PAGE_FIRST).first()).toBeVisible({
timeout: 15_000,
});
});
test("chats list page does NOT auto-paginate while a search filter is active", async ({
page,
}) => {
// Count search requests via a passive request observer. Using
// page.route() here would race with mockLangGraphAPI's fulfill route
// (Playwright matches routes in reverse registration order), so the
// counter could miss real requests. page.on('request') is a pure
// observer and never interferes with routing.
let searchRequestCount = 0;
page.on("request", (request) => {
if (request.url().includes("/api/langgraph/threads/search")) {
searchRequestCount += 1;
}
});
mockLangGraphAPI(page, { threads: THREADS });
await page.goto("/workspace/chats");
// Wait for the first page to render so we have a baseline count.
await expect(page.locator("main").getByText(FIRST_PAGE_LAST)).toBeVisible({
timeout: 15_000,
});
const baselineRequests = searchRequestCount;
// Type a query that matches nothing in the first page (and nothing at
// all, since titles are deterministic).
await page
.getByPlaceholder("Search chats")
.fill("zzz-no-such-conversation");
// The auto-sentinel must be gone; an explicit button takes its place.
await expect(page.getByTestId("chats-page-sentinel")).toHaveCount(0);
await expect(page.getByTestId("chats-page-load-more")).toBeVisible();
// Give the IntersectionObserver a couple of frames to misbehave if the
// guard regresses. No additional /threads/search calls should fire.
await page.waitForTimeout(500);
expect(searchRequestCount).toBe(baselineRequests);
// The explicit button still works as an escape hatch.
await page.getByTestId("chats-page-load-more").click();
await expect
.poll(() => searchRequestCount, { timeout: 10_000 })
.toBeGreaterThan(baselineRequests);
});
});
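The rendering rule these specs exercise can be stated as two pure functions. A minimal sketch, assuming the same three flags the page component reads; `chooseLoadMoreControl` and `shouldAutoFetch` are hypothetical names that do not exist in the commit:

```typescript
// Which load-more affordance the chats page renders below the list.
type LoadMoreControl = "sentinel" | "button" | "none";

function chooseLoadMoreControl(
  hasNextPage: boolean,
  isSearching: boolean,
): LoadMoreControl {
  if (!hasNextPage) {
    return "none";
  }
  // While a search filter is active, paging is manual so that an empty
  // filtered view cannot keep triggering fetches.
  return isSearching ? "button" : "sentinel";
}

// Mirrors the observer callback's condition: fetch only when the sentinel is
// visible, another page exists, and no fetch is already in flight.
function shouldAutoFetch(
  isIntersecting: boolean,
  hasNextPage: boolean,
  isFetchingNextPage: boolean,
): boolean {
  return isIntersecting && hasNextPage && !isFetchingNextPage;
}
```

Keeping the decision in plain functions like these would let it be unit tested without a browser, while the Playwright specs above continue to cover the observer wiring.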
+25 -2
@@ -85,7 +85,7 @@ export function mockLangGraphAPI(page: Page, options?: MockAPIOptions) {
const skills = options?.skills ?? DEFAULT_SKILLS;
// Thread search — sidebar thread list & chats list page
- void page.route("**/api/langgraph/threads/search", (route) => {
+ void page.route("**/api/langgraph/threads/search", async (route) => {
const body = threads.map((t) => ({
thread_id: t.thread_id,
created_at: "2025-01-01T00:00:00Z",
@@ -94,10 +94,33 @@ export function mockLangGraphAPI(page: Page, options?: MockAPIOptions) {
status: "idle",
values: { title: t.title ?? "Untitled" },
}));
let limit: number | undefined;
let offset = 0;
try {
const postData = route.request().postDataJSON() as {
limit?: number;
offset?: number;
} | null;
if (postData) {
if (typeof postData.limit === "number") {
limit = postData.limit;
}
if (typeof postData.offset === "number") {
offset = postData.offset;
}
}
} catch {
// No / invalid JSON body — fall back to returning the full list.
}
const sliced =
typeof limit === "number" ? body.slice(offset, offset + limit) : body;
return route.fulfill({
status: 200,
contentType: "application/json",
- body: JSON.stringify(body),
+ body: JSON.stringify(sliced),
});
});
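The slicing the mock now performs can be separated from Playwright's routing. A minimal sketch, assuming the request body has already been parsed into an unknown value; `Paging`, `readPaging` and `slicePage` are hypothetical names:

```typescript
// Paging fields as the mock interprets them: offset defaults to 0 and a
// missing limit means "return everything".
type Paging = { limit?: number; offset: number };

// Accepts any parsed body and keeps only numeric limit/offset fields.
function readPaging(body: unknown): Paging {
  const paging: Paging = { offset: 0 };
  if (typeof body === "object" && body !== null) {
    const candidate = body as { limit?: unknown; offset?: unknown };
    if (typeof candidate.limit === "number") {
      paging.limit = candidate.limit;
    }
    if (typeof candidate.offset === "number") {
      paging.offset = candidate.offset;
    }
  }
  return paging;
}

// Returns the requested window, or the full list when no limit was sent.
function slicePage<T>(all: T[], body: unknown): T[] {
  const { limit, offset } = readPaging(body);
  return limit === undefined ? all : all.slice(offset, offset + limit);
}
```

As in the mock, an offset without a limit is ignored and the whole list is returned, which keeps older tests that send no paging fields working unchanged.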
@@ -1,5 +1,6 @@
import { afterEach, beforeEach, describe, expect, test, vi } from "vitest";
+ import { AUTH_DISABLED_USER } from "@/core/auth/auth-disabled-user";
import { STATIC_WEBSITE_USER } from "@/core/auth/static-user";
vi.mock("next/headers", () => ({
@@ -10,6 +11,8 @@ vi.mock("next/headers", () => ({
const ENV_KEYS = [
"DEER_FLOW_AUTH_DISABLED",
+ "DEER_FLOW_ENV",
+ "ENVIRONMENT",
"NEXT_PUBLIC_STATIC_WEBSITE_ONLY",
] as const;
@@ -51,6 +54,8 @@ describe("getServerSideUser", () => {
beforeEach(() => {
saved = snapshotEnv();
setEnv("DEER_FLOW_AUTH_DISABLED", undefined);
+ setEnv("DEER_FLOW_ENV", undefined);
+ setEnv("ENVIRONMENT", undefined);
setEnv("NEXT_PUBLIC_STATIC_WEBSITE_ONLY", undefined);
});
@@ -74,4 +79,30 @@ describe("getServerSideUser", () => {
});
expect(fetchSpy).not.toHaveBeenCalled();
});
test("bypasses gateway auth in auth-disabled mode", async () => {
setEnv("DEER_FLOW_AUTH_DISABLED", "1");
const fetchSpy = vi.fn(() => {
throw new Error("fetch should not be called in auth-disabled mode");
});
vi.stubGlobal("fetch", fetchSpy);
const { getServerSideUser } = await loadFreshServerAuth();
await expect(getServerSideUser()).resolves.toEqual({
tag: "authenticated",
user: AUTH_DISABLED_USER,
});
expect(fetchSpy).not.toHaveBeenCalled();
});
test("does not enable auth-disabled mode in explicit production environments", async () => {
setEnv("DEER_FLOW_AUTH_DISABLED", "1");
setEnv("DEER_FLOW_ENV", "production");
const { isAuthDisabledMode } =
await import("@/core/auth/auth-disabled-user");
expect(isAuthDisabledMode()).toBe(false);
});
});
@@ -0,0 +1,228 @@
import { QueryClient, type InfiniteData } from "@tanstack/react-query";
import { describe, expect, test } from "vitest";
import {
filterInfiniteThreadsCache,
getInfiniteThreadsNextPageParam,
INFINITE_THREADS_PAGE_SIZE,
INFINITE_THREADS_QUERY_KEY_PREFIX,
mapInfiniteThreadsCache,
upsertThreadInInfiniteCache,
} from "@/core/threads/hooks";
import type { AgentThread } from "@/core/threads/types";
// Issue #3482: the sidebar and /workspace/chats list used to be capped at
// 50 threads because `useThreads()` exits as soon as `threads.length >=
// params.limit`. These pure helpers back the `useInfiniteThreads()`
// pagination logic and the mirrored cache writes that keep rename / delete
// / stream-finish in sync with both the legacy array cache and the new
// infinite cache.
function makeThread(id: string, title = `Title ${id}`): AgentThread {
return {
thread_id: id,
created_at: "2025-01-01T00:00:00Z",
updated_at: "2025-01-01T00:00:00Z",
metadata: {},
status: "idle",
values: { title },
} as unknown as AgentThread;
}
function makePage(start: number, size: number): AgentThread[] {
return Array.from({ length: size }, (_, i) => makeThread(`t-${start + i}`));
}
function makeInfiniteData(pages: AgentThread[][]): InfiniteData<AgentThread[]> {
return {
pages,
pageParams: pages.map((_, i) => i * INFINITE_THREADS_PAGE_SIZE),
};
}

describe("getInfiniteThreadsNextPageParam", () => {
  test("returns next offset when the last page is full", () => {
    const page1 = makePage(0, INFINITE_THREADS_PAGE_SIZE);
    expect(getInfiniteThreadsNextPageParam(page1, [page1])).toBe(
      INFINITE_THREADS_PAGE_SIZE,
    );
  });

  test("returns next offset across multiple full pages", () => {
    const page1 = makePage(0, INFINITE_THREADS_PAGE_SIZE);
    const page2 = makePage(
      INFINITE_THREADS_PAGE_SIZE,
      INFINITE_THREADS_PAGE_SIZE,
    );
    expect(getInfiniteThreadsNextPageParam(page2, [page1, page2])).toBe(
      INFINITE_THREADS_PAGE_SIZE * 2,
    );
  });

  test("returns undefined when the last page is short (end of list)", () => {
    const page1 = makePage(0, INFINITE_THREADS_PAGE_SIZE);
    const page2 = makePage(INFINITE_THREADS_PAGE_SIZE, 10);
    expect(
      getInfiniteThreadsNextPageParam(page2, [page1, page2]),
    ).toBeUndefined();
  });

  test("returns undefined when the last page is empty", () => {
    const page1 = makePage(0, INFINITE_THREADS_PAGE_SIZE);
    expect(getInfiniteThreadsNextPageParam([], [page1, []])).toBeUndefined();
  });

  test("respects a custom page size", () => {
    const page1 = makePage(0, 5);
    expect(getInfiniteThreadsNextPageParam(page1, [page1], 5)).toBe(5);
    expect(getInfiniteThreadsNextPageParam(page1, [page1], 10)).toBeUndefined();
  });
});

describe("mapInfiniteThreadsCache", () => {
  test("returns undefined when oldData is undefined", () => {
    expect(mapInfiniteThreadsCache(undefined, (t) => t)).toBeUndefined();
  });

  test("updates the matching thread across multiple pages", () => {
    const page1 = [makeThread("a"), makeThread("b")];
    const page2 = [makeThread("c"), makeThread("d")];
    const data = makeInfiniteData([page1, page2]);
    const updated = mapInfiniteThreadsCache(data, (t) =>
      t.thread_id === "c"
        ? { ...t, values: { ...t.values, title: "renamed" } }
        : t,
    );
    expect(updated?.pages[0]?.[0]?.values?.title).toBe("Title a");
    expect(updated?.pages[1]?.[0]?.thread_id).toBe("c");
    expect(updated?.pages[1]?.[0]?.values?.title).toBe("renamed");
    expect(updated?.pages[1]?.[1]?.values?.title).toBe("Title d");
  });

  test("preserves pageParams", () => {
    const data = makeInfiniteData([[makeThread("a")]]);
    const updated = mapInfiniteThreadsCache(data, (t) => t);
    expect(updated?.pageParams).toEqual(data.pageParams);
  });
});

describe("filterInfiniteThreadsCache", () => {
  test("returns undefined when oldData is undefined", () => {
    expect(filterInfiniteThreadsCache(undefined, () => true)).toBeUndefined();
  });

  test("removes matching threads across all pages", () => {
    const page1 = [makeThread("a"), makeThread("b")];
    const page2 = [makeThread("b"), makeThread("c")];
    const data = makeInfiniteData([page1, page2]);
    const filtered = filterInfiniteThreadsCache(
      data,
      (t) => t.thread_id !== "b",
    );
    expect(filtered?.pages[0]?.map((t) => t.thread_id)).toEqual(["a"]);
    expect(filtered?.pages[1]?.map((t) => t.thread_id)).toEqual(["c"]);
  });

  test("keeps an emptied page as an empty array (does not drop the page)", () => {
    const page1 = [makeThread("a")];
    const page2 = [makeThread("b")];
    const data = makeInfiniteData([page1, page2]);
    const filtered = filterInfiniteThreadsCache(
      data,
      (t) => t.thread_id !== "a",
    );
    expect(filtered?.pages).toHaveLength(2);
    expect(filtered?.pages[0]).toEqual([]);
    expect(filtered?.pages[1]?.[0]?.thread_id).toBe("b");
  });

  test("does not regress next offset when an earlier page has been shrunk by a delete", () => {
    // Simulate two full pages already loaded. Building them with makePage and
    // INFINITE_THREADS_PAGE_SIZE avoids hard-coding the page size and the
    // `as unknown as AgentThread[]` casts.
    const page1 = makePage(0, INFINITE_THREADS_PAGE_SIZE);
    const page2 = makePage(
      INFINITE_THREADS_PAGE_SIZE,
      INFINITE_THREADS_PAGE_SIZE,
    );

    // Offset right after fetching page 2 (this is the value TanStack Query
    // freezes into pageParams).
    const offsetAfterPage2 = getInfiniteThreadsNextPageParam(page2, [
      page1,
      page2,
    ]);
    expect(offsetAfterPage2).toBe(INFINITE_THREADS_PAGE_SIZE * 2);

    // Now a delete mutation runs filterInfiniteThreadsCache and shrinks
    // page 1 by one entry. TanStack does NOT re-invoke getNextPageParam on
    // cache mutations; the previously computed offset remains the param for
    // the next fetchNextPage() call, so the helper is consistent with how
    // the library uses its return value.
    const shrunkPage1 = page1.slice(0, -1);
    const recomputed = getInfiniteThreadsNextPageParam(page2, [
      shrunkPage1,
      page2,
    ]);

    // We document the recomputed value (one less than the frozen offset) for
    // completeness, but in practice useDeleteThread invalidates the query in
    // onSettled, so pages are refetched from offset 0 rather than relying on
    // this number.
    expect(recomputed).toBe(INFINITE_THREADS_PAGE_SIZE * 2 - 1);
  });
});

describe("upsertThreadInInfiniteCache", () => {
  function seedClient(initial?: InfiniteData<AgentThread[]>): QueryClient {
    const client = new QueryClient();
    if (initial) {
      client.setQueryData([...INFINITE_THREADS_QUERY_KEY_PREFIX, {}], initial);
    }
    return client;
  }

  function readCache(
    client: QueryClient,
  ): InfiniteData<AgentThread[]> | undefined {
    return client.getQueryData([...INFINITE_THREADS_QUERY_KEY_PREFIX, {}]);
  }

  test("no-op when the infinite cache has not been initialised yet", () => {
    const client = seedClient();
    upsertThreadInInfiniteCache(client, makeThread("new"));
    expect(readCache(client)).toBeUndefined();
  });

  test("prepends a brand-new thread to the first page", () => {
    const client = seedClient({
      pages: [[makeThread("a"), makeThread("b")]],
      pageParams: [0],
    });
    upsertThreadInInfiniteCache(client, makeThread("new"));
    const cache = readCache(client);
    expect(cache?.pages[0]?.map((t) => t.thread_id)).toEqual(["new", "a", "b"]);
  });

  test("merges into the existing entry instead of duplicating it", () => {
    const existing = makeThread("a", "Old title");
    const client = seedClient({
      pages: [[existing, makeThread("b")]],
      pageParams: [0],
    });
    // Simulate an onCreated upsert that races with a thread already in cache:
    // the cache copy should win for title/metadata (it represents later state),
    // but no duplicate row should appear.
    upsertThreadInInfiniteCache(client, {
      ...makeThread("a", "New title"),
      status: "busy",
    });
    const cache = readCache(client);
    const ids = cache?.pages[0]?.map((t) => t.thread_id);
    expect(ids).toEqual(["a", "b"]);
    expect(cache?.pages[0]?.[0]?.values.title).toBe("Old title");
  });
});