mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 17:35:57 +00:00
Compare commits
9 Commits
v2.0-m1-rc3
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 2d5f0787de | |||
| 5819bd8a59 | |||
| b3c2cc42cf | |||
| 167ef4512f | |||
| ba9cc5e972 | |||
| 05ae4467ae | |||
| 2b795265e7 | |||
| a57d05fe0a | |||
| ae9e8bc0bf |
@@ -10,7 +10,7 @@ permissions:
|
|||||||
contents: read
|
contents: read
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
lint:
|
lint-backend:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v6
|
- uses: actions/checkout@v6
|
||||||
|
|||||||
@@ -247,6 +247,9 @@ Access: http://localhost:2026
|
|||||||
|
|
||||||
The unified nginx endpoint is same-origin by default and does not emit browser CORS headers. If you run a split-origin or port-forwarded browser client, set `GATEWAY_CORS_ORIGINS` to comma-separated exact origins such as `http://localhost:3000`; the Gateway then applies the CORS allowlist and matching CSRF origin checks.
|
The unified nginx endpoint is same-origin by default and does not emit browser CORS headers. If you run a split-origin or port-forwarded browser client, set `GATEWAY_CORS_ORIGINS` to comma-separated exact origins such as `http://localhost:3000`; the Gateway then applies the CORS allowlist and matching CSRF origin checks.
|
||||||
|
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> The Gateway holds run state (RunManager and the stream bridge) in process, so production defaults to a single Gateway worker (`GATEWAY_WORKERS=1`). Raising the worker count without a shared cross-worker stream bridge — which is not yet available — breaks run cancellation, SSE reconnects, request de-duplication, and IM channels, because nginx uses no sticky sessions and each worker keeps its own run state. Scale a single worker up with more CPU/RAM (or move the database and sandbox onto dedicated tiers) instead of raising `GATEWAY_WORKERS`.
|
||||||
|
|
||||||
See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
|
See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
|
||||||
|
|
||||||
#### Option 2: Local Development
|
#### Option 2: Local Development
|
||||||
|
|||||||
@@ -429,6 +429,12 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
|
|||||||
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
|
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
|
||||||
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
|
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
|
||||||
|
|
||||||
|
**Token counting** (`packages/harness/deerflow/agents/memory/prompt.py`):
|
||||||
|
- `_count_tokens` budgets the injection. In default `tiktoken` mode, the encoding is loaded lazily and cached.
|
||||||
|
- Failed tiktoken loads are cached with a timestamp. During the fixed cooldown (`_TIKTOKEN_RETRY_COOLDOWN_S`, 600s), callers fall back to char estimation immediately instead of re-triggering the blocking BPE download; after the cooldown, transient outages can self-heal without a restart.
|
||||||
|
- In-flight loads are cached as a LOADING sentinel so concurrent callers fall back instead of spawning more blocking threads.
|
||||||
|
- Set `memory.token_counting: char` to skip tiktoken entirely and use the network-free CJK-aware char estimate.
|
||||||
|
|
||||||
Focused regression coverage for the updater lives in `backend/tests/test_memory_updater.py`.
|
Focused regression coverage for the updater lives in `backend/tests/test_memory_updater.py`.
|
||||||
|
|
||||||
**Configuration** (`config.yaml` → `memory`):
|
**Configuration** (`config.yaml` → `memory`):
|
||||||
@@ -438,6 +444,7 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_
|
|||||||
- `model_name` - LLM for updates (null = default model)
|
- `model_name` - LLM for updates (null = default model)
|
||||||
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
|
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
|
||||||
- `max_injection_tokens` - Token limit for prompt injection (2000)
|
- `max_injection_tokens` - Token limit for prompt injection (2000)
|
||||||
|
- `token_counting` - Token counting strategy for the injection budget: `tiktoken` (default, accurate but may download BPE data from a public endpoint on first use — can block for a long time in network-restricted environments, see issues #3402/#3429) or `char` (network-free CJK-aware char estimate, never touches tiktoken)
|
||||||
|
|
||||||
### Reflection System (`packages/harness/deerflow/reflection/`)
|
### Reflection System (`packages/harness/deerflow/reflection/`)
|
||||||
|
|
||||||
|
|||||||
+22
-14
@@ -6,6 +6,7 @@ from contextlib import asynccontextmanager
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from app.gateway.auth_disabled import warn_if_auth_disabled_enabled
|
||||||
from app.gateway.auth_middleware import AuthMiddleware
|
from app.gateway.auth_middleware import AuthMiddleware
|
||||||
from app.gateway.config import get_gateway_config
|
from app.gateway.config import get_gateway_config
|
||||||
from app.gateway.csrf_middleware import CSRFMiddleware, get_configured_cors_origins
|
from app.gateway.csrf_middleware import CSRFMiddleware, get_configured_cors_origins
|
||||||
@@ -172,6 +173,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
startup_config = get_app_config()
|
startup_config = get_app_config()
|
||||||
apply_logging_level(startup_config.log_level)
|
apply_logging_level(startup_config.log_level)
|
||||||
logger.info("Configuration loaded successfully")
|
logger.info("Configuration loaded successfully")
|
||||||
|
warn_if_auth_disabled_enabled()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Failed to load configuration during gateway startup: {e}"
|
error_msg = f"Failed to load configuration during gateway startup: {e}"
|
||||||
logger.exception(error_msg)
|
logger.exception(error_msg)
|
||||||
@@ -182,21 +184,27 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
# Pre-warm tiktoken encoding cache so the first memory-injection request
|
# Pre-warm tiktoken encoding cache so the first memory-injection request
|
||||||
# never blocks on the BPE data download (which hits an OpenAI/Azure URL
|
# never blocks on the BPE data download (which hits an OpenAI/Azure URL
|
||||||
# that may be unreachable in restricted networks — see issue #3402).
|
# that may be unreachable in restricted networks — see issue #3402).
|
||||||
try:
|
# When memory.token_counting is "char", token counting never touches
|
||||||
from deerflow.agents.memory.prompt import warm_tiktoken_cache
|
# tiktoken, so skip the warm-up entirely (avoids even the 5s probe in
|
||||||
|
# network-restricted deployments — see issue #3429).
|
||||||
|
if startup_config.memory.token_counting == "char":
|
||||||
|
logger.info("memory.token_counting='char'; skipping tiktoken warm-up (network-free token estimation)")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
from deerflow.agents.memory.prompt import warm_tiktoken_cache
|
||||||
|
|
||||||
warmed = await asyncio.wait_for(
|
warmed = await asyncio.wait_for(
|
||||||
asyncio.to_thread(warm_tiktoken_cache),
|
asyncio.to_thread(warm_tiktoken_cache),
|
||||||
timeout=5,
|
timeout=5,
|
||||||
)
|
)
|
||||||
if warmed:
|
if warmed:
|
||||||
logger.info("tiktoken encoding cache warmed successfully")
|
logger.info("tiktoken encoding cache warmed successfully")
|
||||||
else:
|
else:
|
||||||
logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback")
|
logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback until tiktoken loads successfully")
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback")
|
logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback until tiktoken loads successfully")
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("tiktoken warm-up skipped", exc_info=True)
|
logger.warning("tiktoken warm-up skipped", exc_info=True)
|
||||||
|
|
||||||
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
|
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
|
||||||
async with langgraph_runtime(app, startup_config):
|
async with langgraph_runtime(app, startup_config):
|
||||||
|
|||||||
@@ -0,0 +1,54 @@
|
|||||||
|
"""Shared helpers for local/E2E auth-disabled mode."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
AUTH_DISABLED_ENV_VAR = "DEER_FLOW_AUTH_DISABLED"
|
||||||
|
AUTH_DISABLED_USER_ID = "e2e-user"
|
||||||
|
AUTH_DISABLED_USER_EMAIL = "e2e@test.local"
|
||||||
|
|
||||||
|
AUTH_SOURCE_SESSION = "session"
|
||||||
|
AUTH_SOURCE_INTERNAL = "internal"
|
||||||
|
AUTH_SOURCE_AUTH_DISABLED = "auth_disabled"
|
||||||
|
|
||||||
|
_PRODUCTION_ENV_VARS: tuple[str, ...] = ("DEER_FLOW_ENV", "ENVIRONMENT")
|
||||||
|
_PRODUCTION_ENV_VALUES: frozenset[str] = frozenset({"prod", "production"})
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def is_explicit_production_environment() -> bool:
|
||||||
|
return any(os.environ.get(name, "").strip().lower() in _PRODUCTION_ENV_VALUES for name in _PRODUCTION_ENV_VARS)
|
||||||
|
|
||||||
|
|
||||||
|
def is_auth_disabled_requested() -> bool:
|
||||||
|
return os.environ.get(AUTH_DISABLED_ENV_VAR) == "1"
|
||||||
|
|
||||||
|
|
||||||
|
def is_auth_disabled() -> bool:
|
||||||
|
return is_auth_disabled_requested() and not is_explicit_production_environment()
|
||||||
|
|
||||||
|
|
||||||
|
def warn_if_auth_disabled_enabled() -> None:
|
||||||
|
if not is_auth_disabled():
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"%s=1 is active: authentication is bypassed and anonymous requests run as synthetic admin user %r. Do not enable this in shared or production deployments.",
|
||||||
|
AUTH_DISABLED_ENV_VAR,
|
||||||
|
AUTH_DISABLED_USER_ID,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_disabled_user():
|
||||||
|
return SimpleNamespace(
|
||||||
|
id=AUTH_DISABLED_USER_ID,
|
||||||
|
email=AUTH_DISABLED_USER_EMAIL,
|
||||||
|
password_hash=None,
|
||||||
|
system_role="admin",
|
||||||
|
needs_setup=False,
|
||||||
|
token_version=0,
|
||||||
|
)
|
||||||
@@ -17,6 +17,13 @@ from starlette.responses import JSONResponse
|
|||||||
from starlette.types import ASGIApp
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
|
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
|
||||||
|
from app.gateway.auth_disabled import (
|
||||||
|
AUTH_SOURCE_AUTH_DISABLED,
|
||||||
|
AUTH_SOURCE_INTERNAL,
|
||||||
|
AUTH_SOURCE_SESSION,
|
||||||
|
get_auth_disabled_user,
|
||||||
|
is_auth_disabled,
|
||||||
|
)
|
||||||
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
|
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
|
||||||
from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token
|
from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token
|
||||||
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
||||||
@@ -80,8 +87,38 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||||||
if is_valid_internal_auth_token(request.headers.get(INTERNAL_AUTH_HEADER_NAME)):
|
if is_valid_internal_auth_token(request.headers.get(INTERNAL_AUTH_HEADER_NAME)):
|
||||||
internal_user = get_internal_user()
|
internal_user = get_internal_user()
|
||||||
|
|
||||||
|
auth_source = AUTH_SOURCE_SESSION
|
||||||
|
access_token = request.cookies.get("access_token")
|
||||||
|
|
||||||
# Non-public path: require session cookie
|
# Non-public path: require session cookie
|
||||||
if internal_user is None and not request.cookies.get("access_token"):
|
if internal_user is not None:
|
||||||
|
user = internal_user
|
||||||
|
auth_source = AUTH_SOURCE_INTERNAL
|
||||||
|
elif access_token:
|
||||||
|
# Strict JWT validation: reject junk/expired tokens with 401
|
||||||
|
# right here instead of silently passing through. This closes
|
||||||
|
# the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8):
|
||||||
|
# without this, non-isolation routes like /api/models would
|
||||||
|
# accept any cookie-shaped string as authentication.
|
||||||
|
#
|
||||||
|
# We call the *strict* resolver so that fine-grained error
|
||||||
|
# codes (token_expired, token_invalid, user_not_found, …)
|
||||||
|
# propagate from AuthErrorCode, not get flattened into one
|
||||||
|
# generic code. BaseHTTPMiddleware doesn't let HTTPException
|
||||||
|
# bubble up, so we catch and render it as JSONResponse here.
|
||||||
|
from app.gateway.deps import get_current_user_from_request
|
||||||
|
|
||||||
|
try:
|
||||||
|
user = await get_current_user_from_request(request)
|
||||||
|
except HTTPException as exc:
|
||||||
|
if not is_auth_disabled():
|
||||||
|
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
||||||
|
user = get_auth_disabled_user()
|
||||||
|
auth_source = AUTH_SOURCE_AUTH_DISABLED
|
||||||
|
elif is_auth_disabled():
|
||||||
|
user = get_auth_disabled_user()
|
||||||
|
auth_source = AUTH_SOURCE_AUTH_DISABLED
|
||||||
|
else:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=401,
|
status_code=401,
|
||||||
content={
|
content={
|
||||||
@@ -92,32 +129,12 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Strict JWT validation: reject junk/expired tokens with 401
|
|
||||||
# right here instead of silently passing through. This closes
|
|
||||||
# the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8):
|
|
||||||
# without this, non-isolation routes like /api/models would
|
|
||||||
# accept any cookie-shaped string as authentication.
|
|
||||||
#
|
|
||||||
# We call the *strict* resolver so that fine-grained error
|
|
||||||
# codes (token_expired, token_invalid, user_not_found, …)
|
|
||||||
# propagate from AuthErrorCode, not get flattened into one
|
|
||||||
# generic code. BaseHTTPMiddleware doesn't let HTTPException
|
|
||||||
# bubble up, so we catch and render it as JSONResponse here.
|
|
||||||
from app.gateway.deps import get_current_user_from_request
|
|
||||||
|
|
||||||
if internal_user is not None:
|
|
||||||
user = internal_user
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
user = await get_current_user_from_request(request)
|
|
||||||
except HTTPException as exc:
|
|
||||||
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
|
||||||
|
|
||||||
# Stamp both request.state.user (for the contextvar pattern)
|
# Stamp both request.state.user (for the contextvar pattern)
|
||||||
# and request.state.auth (so @require_permission's "auth is
|
# and request.state.auth (so @require_permission's "auth is
|
||||||
# None" branch short-circuits instead of running the entire
|
# None" branch short-circuits instead of running the entire
|
||||||
# JWT-decode + DB-lookup pipeline a second time per request).
|
# JWT-decode + DB-lookup pipeline a second time per request).
|
||||||
request.state.user = user
|
request.state.user = user
|
||||||
|
request.state.auth_source = auth_source
|
||||||
request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS)
|
request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS)
|
||||||
token = set_current_user(user)
|
token = set_current_user(user)
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
|||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
from starlette.types import ASGIApp
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
|
from app.gateway.auth_disabled import is_auth_disabled
|
||||||
|
|
||||||
CSRF_COOKIE_NAME = "csrf_token"
|
CSRF_COOKIE_NAME = "csrf_token"
|
||||||
CSRF_HEADER_NAME = "X-CSRF-Token"
|
CSRF_HEADER_NAME = "X-CSRF-Token"
|
||||||
CSRF_TOKEN_LENGTH = 64 # bytes
|
CSRF_TOKEN_LENGTH = 64 # bytes
|
||||||
@@ -38,6 +40,9 @@ def should_check_csrf(request: Request) -> bool:
|
|||||||
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
|
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if is_auth_disabled():
|
||||||
|
return False
|
||||||
|
|
||||||
path = request.url.path.rstrip("/")
|
path = request.url.path.rstrip("/")
|
||||||
# Exempt /api/v1/auth/me endpoint
|
# Exempt /api/v1/auth/me endpoint
|
||||||
if path == "/api/v1/auth/me":
|
if path == "/api/v1/auth/me":
|
||||||
|
|||||||
@@ -331,6 +331,17 @@ async def get_current_user_from_request(request: Request):
|
|||||||
|
|
||||||
Raises HTTPException 401 if not authenticated.
|
Raises HTTPException 401 if not authenticated.
|
||||||
"""
|
"""
|
||||||
|
state = getattr(request, "state", None)
|
||||||
|
state_user = getattr(state, "user", None)
|
||||||
|
from app.gateway.auth_disabled import AUTH_SOURCE_AUTH_DISABLED, AUTH_SOURCE_INTERNAL, AUTH_SOURCE_SESSION
|
||||||
|
|
||||||
|
if state_user is not None and getattr(state, "auth_source", None) in {
|
||||||
|
AUTH_SOURCE_SESSION,
|
||||||
|
AUTH_SOURCE_AUTH_DISABLED,
|
||||||
|
AUTH_SOURCE_INTERNAL,
|
||||||
|
}:
|
||||||
|
return state_user
|
||||||
|
|
||||||
from app.gateway.auth import decode_token
|
from app.gateway.auth import decode_token
|
||||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
|
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from langgraph_sdk import Auth
|
|||||||
|
|
||||||
from app.gateway.auth.errors import TokenError
|
from app.gateway.auth.errors import TokenError
|
||||||
from app.gateway.auth.jwt import decode_token
|
from app.gateway.auth.jwt import decode_token
|
||||||
|
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID, is_auth_disabled
|
||||||
from app.gateway.deps import get_local_provider
|
from app.gateway.deps import get_local_provider
|
||||||
|
|
||||||
auth = Auth()
|
auth = Auth()
|
||||||
@@ -38,6 +39,9 @@ def _check_csrf(request) -> None:
|
|||||||
if method.upper() not in _CSRF_METHODS:
|
if method.upper() not in _CSRF_METHODS:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if is_auth_disabled():
|
||||||
|
return
|
||||||
|
|
||||||
cookie_token = request.cookies.get("csrf_token")
|
cookie_token = request.cookies.get("csrf_token")
|
||||||
header_token = request.headers.get("x-csrf-token")
|
header_token = request.headers.get("x-csrf-token")
|
||||||
|
|
||||||
@@ -66,6 +70,9 @@ async def authenticate(request):
|
|||||||
# are rejected early, even if the cookie carries a valid JWT.
|
# are rejected early, even if the cookie carries a valid JWT.
|
||||||
_check_csrf(request)
|
_check_csrf(request)
|
||||||
|
|
||||||
|
if is_auth_disabled():
|
||||||
|
return AUTH_DISABLED_USER_ID
|
||||||
|
|
||||||
token = request.cookies.get("access_token")
|
token = request.cookies.get("access_token")
|
||||||
if not token:
|
if not token:
|
||||||
raise Auth.exceptions.HTTPException(
|
raise Auth.exceptions.HTTPException(
|
||||||
|
|||||||
@@ -341,9 +341,19 @@ async def change_password(request: Request, response: Response, body: ChangePass
|
|||||||
- Re-issues session cookie with new token_version
|
- Re-issues session cookie with new token_version
|
||||||
"""
|
"""
|
||||||
from app.gateway.auth.password import hash_password_async, verify_password_async
|
from app.gateway.auth.password import hash_password_async, verify_password_async
|
||||||
|
from app.gateway.auth_disabled import AUTH_SOURCE_AUTH_DISABLED
|
||||||
|
|
||||||
user = await get_current_user_from_request(request)
|
user = await get_current_user_from_request(request)
|
||||||
|
|
||||||
|
if getattr(request.state, "auth_source", None) == AUTH_SOURCE_AUTH_DISABLED:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=AuthErrorResponse(
|
||||||
|
code=AuthErrorCode.INVALID_CREDENTIALS,
|
||||||
|
message="Password changes are not available when DEER_FLOW_AUTH_DISABLED=1.",
|
||||||
|
).model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
if user.password_hash is None:
|
if user.password_hash is None:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
|
||||||
|
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ class MemoryConfigResponse(BaseModel):
|
|||||||
fact_confidence_threshold: float = Field(..., description="Minimum confidence threshold for facts")
|
fact_confidence_threshold: float = Field(..., description="Minimum confidence threshold for facts")
|
||||||
injection_enabled: bool = Field(..., description="Whether memory injection is enabled")
|
injection_enabled: bool = Field(..., description="Whether memory injection is enabled")
|
||||||
max_injection_tokens: int = Field(..., description="Maximum tokens for memory injection")
|
max_injection_tokens: int = Field(..., description="Maximum tokens for memory injection")
|
||||||
|
token_counting: str = Field(..., description="Token counting strategy for memory injection ('tiktoken' or 'char')")
|
||||||
|
|
||||||
|
|
||||||
class MemoryStatusResponse(BaseModel):
|
class MemoryStatusResponse(BaseModel):
|
||||||
@@ -310,7 +311,8 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
|||||||
"max_facts": 100,
|
"max_facts": 100,
|
||||||
"fact_confidence_threshold": 0.7,
|
"fact_confidence_threshold": 0.7,
|
||||||
"injection_enabled": true,
|
"injection_enabled": true,
|
||||||
"max_injection_tokens": 2000
|
"max_injection_tokens": 2000,
|
||||||
|
"token_counting": "tiktoken"
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
@@ -323,6 +325,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
|||||||
fact_confidence_threshold=config.fact_confidence_threshold,
|
fact_confidence_threshold=config.fact_confidence_threshold,
|
||||||
injection_enabled=config.injection_enabled,
|
injection_enabled=config.injection_enabled,
|
||||||
max_injection_tokens=config.max_injection_tokens,
|
max_injection_tokens=config.max_injection_tokens,
|
||||||
|
token_counting=config.token_counting,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -351,6 +354,7 @@ async def get_memory_status() -> MemoryStatusResponse:
|
|||||||
fact_confidence_threshold=config.fact_confidence_threshold,
|
fact_confidence_threshold=config.fact_confidence_threshold,
|
||||||
injection_enabled=config.injection_enabled,
|
injection_enabled=config.injection_enabled,
|
||||||
max_injection_tokens=config.max_injection_tokens,
|
max_injection_tokens=config.max_injection_tokens,
|
||||||
|
token_counting=config.token_counting,
|
||||||
),
|
),
|
||||||
data=MemoryResponse(**memory_data),
|
data=MemoryResponse(**memory_data),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -315,6 +315,21 @@ async def start_run(
|
|||||||
detail=f"Model {model_name!r} is not in the configured model allowlist",
|
detail=f"Model {model_name!r} is not in the configured model allowlist",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Stateless run endpoints carry thread_id in the request *body*, so the
|
||||||
|
# @require_permission(owner_check=True) decorator -- which resolves ownership
|
||||||
|
# from the path param -- cannot protect them. Enforce thread ownership here,
|
||||||
|
# before any run is created, so one user cannot start runs on (or read /wait
|
||||||
|
# checkpoint state from) another user's thread. Missing rows (auto-created
|
||||||
|
# temp threads) and NULL-owner rows (shared / pre-auth data) stay accessible
|
||||||
|
# via check_access; only a thread already owned by another user is rejected
|
||||||
|
# with 404, matching thread_runs.py's anti-enumeration behaviour. Internal
|
||||||
|
# channel runs act on behalf of IM users they do not own (see
|
||||||
|
# inject_authenticated_user_context), so the internal system role is exempt.
|
||||||
|
user = getattr(request.state, "user", None)
|
||||||
|
if user is not None and getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE:
|
||||||
|
if not await run_ctx.thread_store.check_access(thread_id, str(user.id)):
|
||||||
|
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
record = await run_mgr.create_or_reject(
|
record = await run_mgr.create_or_reject(
|
||||||
thread_id,
|
thread_id,
|
||||||
|
|||||||
@@ -31,7 +31,8 @@ Current injection format:
|
|||||||
|
|
||||||
Token counting:
|
Token counting:
|
||||||
- Uses `tiktoken` (`cl100k_base`) when available
|
- Uses `tiktoken` (`cl100k_base`) when available
|
||||||
- Falls back to `len(text) // 4` if tokenizer import fails
|
- Falls back to a network-free CJK-aware character estimate if tokenizer import or encoding load fails
|
||||||
|
(CJK characters count as ~2 chars/token, other characters as ~4 chars/token)
|
||||||
|
|
||||||
## Known Gap
|
## Known Gap
|
||||||
|
|
||||||
|
|||||||
@@ -586,7 +586,11 @@ def _get_memory_context(agent_name: str | None = None, *, app_config: AppConfig
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
memory_data = get_memory_data(agent_name, user_id=get_effective_user_id())
|
memory_data = get_memory_data(agent_name, user_id=get_effective_user_id())
|
||||||
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
|
memory_content = format_memory_for_injection(
|
||||||
|
memory_data,
|
||||||
|
max_tokens=config.max_injection_tokens,
|
||||||
|
use_tiktoken=(config.token_counting == "tiktoken"),
|
||||||
|
)
|
||||||
|
|
||||||
if not memory_content.strip():
|
if not memory_content.strip():
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -5,7 +5,9 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -169,7 +171,26 @@ Return ONLY valid JSON."""
|
|||||||
# subsequent calls are a dict lookup (no network I/O). Pre-warming at
|
# subsequent calls are a dict lookup (no network I/O). Pre-warming at
|
||||||
# startup via :func:`warm_tiktoken_cache` avoids blocking a request on the
|
# startup via :func:`warm_tiktoken_cache` avoids blocking a request on the
|
||||||
# (potentially slow) first ``get_encoding`` call.
|
# (potentially slow) first ``get_encoding`` call.
|
||||||
_tiktoken_encoding_cache: dict[str, tiktoken.Encoding] = {}
|
#
|
||||||
|
# A *failed* load is cached as a ``(None, monotonic_timestamp)`` tuple so that
|
||||||
|
# a network-restricted environment does not re-attempt the blocking BPE
|
||||||
|
# download on every subsequent call. After ``_TIKTOKEN_RETRY_COOLDOWN_S`` the
|
||||||
|
# failure is allowed to expire so a transient network outage can self-heal back
|
||||||
|
# to accurate tiktoken counting without a process restart. A load already in
|
||||||
|
# progress is cached as ``_TIKTOKEN_ENCODING_LOADING`` so concurrent callers
|
||||||
|
# fall back immediately instead of spawning more blocking
|
||||||
|
# ``tiktoken.get_encoding`` threads. Use the ``memory.token_counting: char``
|
||||||
|
# config to skip tiktoken entirely.
|
||||||
|
_TIKTOKEN_ENCODING_MISSING = object()
|
||||||
|
_TIKTOKEN_ENCODING_LOADING = object()
|
||||||
|
# Cooldown before a *failed* tiktoken load is re-attempted. This is an internal
|
||||||
|
# tuning constant rather than a user-facing config: it only affects how quickly
|
||||||
|
# the default ``tiktoken`` mode self-heals after a transient network outage.
|
||||||
|
# Deployments that want to avoid tiktoken's network dependency entirely should
|
||||||
|
# set ``memory.token_counting: char`` instead of tuning this value.
|
||||||
|
_TIKTOKEN_RETRY_COOLDOWN_S = 600.0
|
||||||
|
_tiktoken_encoding_cache: dict[str, Any] = {}
|
||||||
|
_tiktoken_encoding_cache_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encoding | None:
|
def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encoding | None:
|
||||||
@@ -181,44 +202,91 @@ def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encod
|
|||||||
download can block for tens of minutes before the OS TCP timeout kicks in.
|
download can block for tens of minutes before the OS TCP timeout kicks in.
|
||||||
The caller must therefore be prepared for this to block and should run it
|
The caller must therefore be prepared for this to block and should run it
|
||||||
off the event loop (e.g. via ``asyncio.to_thread``).
|
off the event loop (e.g. via ``asyncio.to_thread``).
|
||||||
|
|
||||||
|
A failed load is remembered (with a timestamp) so subsequent calls fall
|
||||||
|
back immediately to character-based estimation instead of re-triggering the
|
||||||
|
blocking download. The failure expires after ``_TIKTOKEN_RETRY_COOLDOWN_S``
|
||||||
|
so a transient outage can self-heal without a restart. A load already in
|
||||||
|
progress is also remembered so that a timed-out caller does not leave a
|
||||||
|
window where later requests start more blocking ``get_encoding`` calls.
|
||||||
"""
|
"""
|
||||||
if not TIKTOKEN_AVAILABLE:
|
if not TIKTOKEN_AVAILABLE:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
cached = _tiktoken_encoding_cache.get(encoding_name)
|
with _tiktoken_encoding_cache_lock:
|
||||||
if cached is not None:
|
cached = _tiktoken_encoding_cache.get(encoding_name, _TIKTOKEN_ENCODING_MISSING)
|
||||||
return cached
|
if cached is _TIKTOKEN_ENCODING_LOADING:
|
||||||
|
return None
|
||||||
|
if isinstance(cached, tuple):
|
||||||
|
# Cached failure: (None, failed_at). Retry only after cooldown.
|
||||||
|
_, failed_at = cached
|
||||||
|
if time.monotonic() - failed_at < _TIKTOKEN_RETRY_COOLDOWN_S:
|
||||||
|
return None
|
||||||
|
cached = _TIKTOKEN_ENCODING_MISSING
|
||||||
|
if cached is not _TIKTOKEN_ENCODING_MISSING:
|
||||||
|
return cast("tiktoken.Encoding", cached)
|
||||||
|
_tiktoken_encoding_cache[encoding_name] = _TIKTOKEN_ENCODING_LOADING
|
||||||
|
|
||||||
try:
|
try:
|
||||||
encoding = tiktoken.get_encoding(encoding_name)
|
encoding = tiktoken.get_encoding(encoding_name)
|
||||||
_tiktoken_encoding_cache[encoding_name] = encoding
|
|
||||||
return encoding
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to load tiktoken encoding %r; falling back to char-based estimation", encoding_name, exc_info=True)
|
logger.warning("Failed to load tiktoken encoding %r; falling back to char-based estimation", encoding_name, exc_info=True)
|
||||||
|
with _tiktoken_encoding_cache_lock:
|
||||||
|
_tiktoken_encoding_cache[encoding_name] = (None, time.monotonic())
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
with _tiktoken_encoding_cache_lock:
|
||||||
|
_tiktoken_encoding_cache[encoding_name] = encoding
|
||||||
|
return encoding
|
||||||
|
|
||||||
def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
|
|
||||||
|
def _char_based_token_estimate(text: str) -> int:
|
||||||
|
"""Network-free token estimate that accounts for CJK density.
|
||||||
|
|
||||||
|
The plain ``len(text) // 4`` heuristic is reasonable for English/code
|
||||||
|
(~4 chars per token) but significantly under-estimates token counts for
|
||||||
|
Chinese, Japanese, and Korean text, where the ratio is closer to 1.5-2
|
||||||
|
characters per token. Counting CJK characters separately (~2 chars per
|
||||||
|
token) avoids over-filling the injection budget for CJK-heavy memory
|
||||||
|
content.
|
||||||
|
"""
|
||||||
|
cjk = sum(
|
||||||
|
1
|
||||||
|
for ch in text
|
||||||
|
if "\u4e00" <= ch <= "\u9fff" # CJK Unified Ideographs
|
||||||
|
or "\u3040" <= ch <= "\u30ff" # Hiragana + Katakana
|
||||||
|
or "\uac00" <= ch <= "\ud7a3" # Hangul syllables
|
||||||
|
)
|
||||||
|
return (len(text) - cjk) // 4 + cjk // 2
|
||||||
|
|
||||||
|
|
||||||
|
def _count_tokens(text: str, encoding_name: str = "cl100k_base", *, use_tiktoken: bool = True) -> int:
|
||||||
"""Count tokens in text using tiktoken.
|
"""Count tokens in text using tiktoken.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: The text to count tokens for.
|
text: The text to count tokens for.
|
||||||
encoding_name: The encoding to use (default: cl100k_base for GPT-4/3.5).
|
encoding_name: The encoding to use (default: cl100k_base for GPT-4/3.5).
|
||||||
|
use_tiktoken: When ``False``, skip tiktoken entirely and use the
|
||||||
|
network-free character-based estimate. This guarantees no BPE
|
||||||
|
download is attempted (see ``memory.token_counting`` config).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The number of tokens in the text.
|
The number of tokens in the text.
|
||||||
"""
|
"""
|
||||||
|
if not use_tiktoken:
|
||||||
|
return _char_based_token_estimate(text)
|
||||||
|
|
||||||
encoding = _get_tiktoken_encoding(encoding_name)
|
encoding = _get_tiktoken_encoding(encoding_name)
|
||||||
if encoding is None:
|
if encoding is None:
|
||||||
# Fallback to character-based estimation if tiktoken is not available
|
# Fallback to CJK-aware character estimation if tiktoken is not
|
||||||
# or the encoding failed to load.
|
# available or the encoding failed to load.
|
||||||
return len(text) // 4
|
return _char_based_token_estimate(text)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return len(encoding.encode(text))
|
return len(encoding.encode(text))
|
||||||
except Exception:
|
except Exception:
|
||||||
# Fallback to character-based estimation on error
|
# Fallback to CJK-aware character estimation on error.
|
||||||
return len(text) // 4
|
return _char_based_token_estimate(text)
|
||||||
|
|
||||||
|
|
||||||
def warm_tiktoken_cache() -> bool:
|
def warm_tiktoken_cache() -> bool:
|
||||||
@@ -248,12 +316,15 @@ def _coerce_confidence(value: Any, default: float = 0.0) -> float:
|
|||||||
return max(0.0, min(1.0, confidence))
|
return max(0.0, min(1.0, confidence))
|
||||||
|
|
||||||
|
|
||||||
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str:
|
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000, *, use_tiktoken: bool = True) -> str:
|
||||||
"""Format memory data for injection into system prompt.
|
"""Format memory data for injection into system prompt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_data: The memory data dictionary.
|
memory_data: The memory data dictionary.
|
||||||
max_tokens: Maximum tokens to use (counted via tiktoken for accuracy).
|
max_tokens: Maximum tokens to use (counted via tiktoken for accuracy).
|
||||||
|
use_tiktoken: When ``False``, all token counting uses the network-free
|
||||||
|
character-based estimate instead of tiktoken (see
|
||||||
|
``memory.token_counting`` config). Defaults to ``True``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Formatted memory string for system prompt injection.
|
Formatted memory string for system prompt injection.
|
||||||
@@ -315,10 +386,10 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
|||||||
# Compute token count for existing sections once, then account
|
# Compute token count for existing sections once, then account
|
||||||
# incrementally for each fact line to avoid full-string re-tokenization.
|
# incrementally for each fact line to avoid full-string re-tokenization.
|
||||||
base_text = "\n\n".join(sections)
|
base_text = "\n\n".join(sections)
|
||||||
base_tokens = _count_tokens(base_text) if base_text else 0
|
base_tokens = _count_tokens(base_text, use_tiktoken=use_tiktoken) if base_text else 0
|
||||||
# Account for the separator between existing sections and the facts section.
|
# Account for the separator between existing sections and the facts section.
|
||||||
facts_header = "Facts:\n"
|
facts_header = "Facts:\n"
|
||||||
separator_tokens = _count_tokens("\n\n" + facts_header) if base_text else _count_tokens(facts_header)
|
separator_tokens = _count_tokens("\n\n" + facts_header, use_tiktoken=use_tiktoken) if base_text else _count_tokens(facts_header, use_tiktoken=use_tiktoken)
|
||||||
running_tokens = base_tokens + separator_tokens
|
running_tokens = base_tokens + separator_tokens
|
||||||
|
|
||||||
fact_lines: list[str] = []
|
fact_lines: list[str] = []
|
||||||
@@ -339,7 +410,7 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
|||||||
|
|
||||||
# Each additional line is preceded by a newline (except the first).
|
# Each additional line is preceded by a newline (except the first).
|
||||||
line_text = ("\n" + line) if fact_lines else line
|
line_text = ("\n" + line) if fact_lines else line
|
||||||
line_tokens = _count_tokens(line_text)
|
line_tokens = _count_tokens(line_text, use_tiktoken=use_tiktoken)
|
||||||
|
|
||||||
if running_tokens + line_tokens <= max_tokens:
|
if running_tokens + line_tokens <= max_tokens:
|
||||||
fact_lines.append(line)
|
fact_lines.append(line)
|
||||||
@@ -355,8 +426,9 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
|||||||
|
|
||||||
result = "\n\n".join(sections)
|
result = "\n\n".join(sections)
|
||||||
|
|
||||||
# Use accurate token counting with tiktoken
|
# Use accurate token counting with tiktoken (or the char-based estimate
|
||||||
token_count = _count_tokens(result)
|
# when use_tiktoken is False).
|
||||||
|
token_count = _count_tokens(result, use_tiktoken=use_tiktoken)
|
||||||
if token_count > max_tokens:
|
if token_count > max_tokens:
|
||||||
# Truncate to fit within token limit
|
# Truncate to fit within token limit
|
||||||
# Estimate characters to remove based on token ratio
|
# Estimate characters to remove based on token ratio
|
||||||
|
|||||||
@@ -1141,6 +1141,7 @@ class DeerFlowClient:
|
|||||||
"fact_confidence_threshold": config.fact_confidence_threshold,
|
"fact_confidence_threshold": config.fact_confidence_threshold,
|
||||||
"injection_enabled": config.injection_enabled,
|
"injection_enabled": config.injection_enabled,
|
||||||
"max_injection_tokens": config.max_injection_tokens,
|
"max_injection_tokens": config.max_injection_tokens,
|
||||||
|
"token_counting": config.token_counting,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_memory_status(self) -> dict:
|
def get_memory_status(self) -> dict:
|
||||||
|
|||||||
@@ -67,11 +67,13 @@ def resolve_agent_dir(name: str, *, user_id: str | None = None) -> Path:
|
|||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
effective_user = user_id or get_effective_user_id()
|
effective_user = user_id or get_effective_user_id()
|
||||||
user_path = paths.user_agent_dir(effective_user, name)
|
user_path = paths.user_agent_dir(effective_user, name)
|
||||||
if user_path.exists():
|
# Require config.yaml to confirm this is a genuine agent directory,
|
||||||
|
# not a leftover from memory/storage writes (see #3390).
|
||||||
|
if user_path.exists() and (user_path / "config.yaml").exists():
|
||||||
return user_path
|
return user_path
|
||||||
|
|
||||||
legacy_path = paths.agent_dir(name)
|
legacy_path = paths.agent_dir(name)
|
||||||
if legacy_path.exists():
|
if legacy_path.exists() and (legacy_path / "config.yaml").exists():
|
||||||
return legacy_path
|
return legacy_path
|
||||||
|
|
||||||
return user_path
|
return user_path
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""Configuration for memory mechanism."""
|
"""Configuration for memory mechanism."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
@@ -60,6 +62,17 @@ class MemoryConfig(BaseModel):
|
|||||||
le=8000,
|
le=8000,
|
||||||
description="Maximum tokens to use for memory injection",
|
description="Maximum tokens to use for memory injection",
|
||||||
)
|
)
|
||||||
|
token_counting: Literal["tiktoken", "char"] = Field(
|
||||||
|
default="tiktoken",
|
||||||
|
description=(
|
||||||
|
"Token counting strategy for memory-injection budgeting. "
|
||||||
|
"'tiktoken' is accurate but the encoding's BPE data may be "
|
||||||
|
"downloaded from a public network endpoint on first use, which "
|
||||||
|
"can block for a long time in network-restricted environments "
|
||||||
|
"(see issue #3402/#3429). 'char' uses a network-free "
|
||||||
|
"CJK-aware character-based estimate and never touches tiktoken."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Global configuration instance
|
# Global configuration instance
|
||||||
|
|||||||
@@ -4,7 +4,20 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||||||
class VolumeMountConfig(BaseModel):
|
class VolumeMountConfig(BaseModel):
|
||||||
"""Configuration for a volume mount."""
|
"""Configuration for a volume mount."""
|
||||||
|
|
||||||
host_path: str = Field(..., description="Path on the host machine")
|
host_path: str = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Source path for the mount. Resolution depends on the active provider: "
|
||||||
|
"``LocalSandboxProvider`` checks this path from the gateway process — in "
|
||||||
|
"``make dev`` that is the host machine, but in Docker deployments "
|
||||||
|
"(``make up`` / docker-compose) it is the path *inside* the "
|
||||||
|
"``deer-flow-gateway`` container, so the host directory must also be "
|
||||||
|
"bind-mounted into the gateway service for the mount to take effect. "
|
||||||
|
"``AioSandboxProvider`` (DooD) passes this value straight to ``docker -v`` "
|
||||||
|
"for the sandbox container, where it is resolved by the host Docker daemon "
|
||||||
|
"from the host machine's perspective."
|
||||||
|
),
|
||||||
|
)
|
||||||
container_path: str = Field(..., description="Path inside the container")
|
container_path: str = Field(..., description="Path inside the container")
|
||||||
read_only: bool = Field(default=False, description="Whether the mount is read-only")
|
read_only: bool = Field(default=False, description="Whether the mount is read-only")
|
||||||
|
|
||||||
|
|||||||
@@ -164,7 +164,18 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
metadata={"caller": caller, **(metadata or {})},
|
metadata={"caller": caller, **(metadata or {})},
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
def on_chain_end(
|
||||||
|
self,
|
||||||
|
outputs: Any,
|
||||||
|
*,
|
||||||
|
run_id: UUID,
|
||||||
|
parent_run_id: UUID | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
# Nested chain ends fire for internal graph nodes; only the root chain
|
||||||
|
# represents the user-visible run lifecycle.
|
||||||
|
if parent_run_id is not None:
|
||||||
|
return
|
||||||
self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"})
|
self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"})
|
||||||
self._flush_sync()
|
self._flush_sync()
|
||||||
|
|
||||||
|
|||||||
@@ -147,7 +147,17 @@ class LocalSandboxProvider(SandboxProvider):
|
|||||||
mount.container_path,
|
mount.container_path,
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
# Ensure the host path exists before adding mapping
|
# Ensure the host path exists before adding mapping.
|
||||||
|
#
|
||||||
|
# ``host_path`` is resolved against the filesystem of the
|
||||||
|
# process running this provider — for ``make dev`` that is
|
||||||
|
# the host machine, but for ``make up`` it is the
|
||||||
|
# ``deer-flow-gateway`` container, so any host path that
|
||||||
|
# isn't bind-mounted into the gateway image will be missing
|
||||||
|
# here. Skipping silently makes this a high-cost-to-debug
|
||||||
|
# silent failure (sandbox skill / tool reads an empty dir
|
||||||
|
# instead of the configured mount), so escalate to ERROR
|
||||||
|
# and include actionable guidance. See #3244.
|
||||||
if host_path.exists():
|
if host_path.exists():
|
||||||
mappings.append(
|
mappings.append(
|
||||||
PathMapping(
|
PathMapping(
|
||||||
@@ -157,10 +167,16 @@ class LocalSandboxProvider(SandboxProvider):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.error(
|
||||||
"Mount host_path does not exist, skipping: %s -> %s",
|
"sandbox.mounts entry %s -> %s ignored: host_path %s does not exist from the "
|
||||||
|
"perspective of the gateway process. In Docker deployments (make up / docker-compose), "
|
||||||
|
"this path must also be bind-mounted into the gateway container — add a matching "
|
||||||
|
"volume entry under services.gateway.volumes in docker/docker-compose.yaml (and use "
|
||||||
|
"the in-container path here), or run in local mode (make dev) where the gateway sees "
|
||||||
|
"the host filesystem directly.",
|
||||||
mount.host_path,
|
mount.host_path,
|
||||||
mount.container_path,
|
mount.container_path,
|
||||||
|
mount.host_path,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Log but don't fail if config loading fails
|
# Log but don't fail if config loading fails
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import pytest
|
|||||||
from starlette.testclient import TestClient
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
from app.gateway.auth_middleware import AuthMiddleware, _is_public
|
from app.gateway.auth_middleware import AuthMiddleware, _is_public
|
||||||
|
from app.gateway.csrf_middleware import CSRFMiddleware
|
||||||
|
|
||||||
# ── _is_public unit tests ─────────────────────────────────────────────────
|
# ── _is_public unit tests ─────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -88,7 +89,9 @@ def test_unknown_api_path_is_protected():
|
|||||||
|
|
||||||
def _make_app():
|
def _make_app():
|
||||||
"""Create a minimal FastAPI app with AuthMiddleware for testing."""
|
"""Create a minimal FastAPI app with AuthMiddleware for testing."""
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI, Request
|
||||||
|
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.add_middleware(AuthMiddleware)
|
app.add_middleware(AuthMiddleware)
|
||||||
@@ -98,8 +101,16 @@ def _make_app():
|
|||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
@app.get("/api/v1/auth/me")
|
@app.get("/api/v1/auth/me")
|
||||||
async def auth_me():
|
async def auth_me(request: Request):
|
||||||
return {"id": "1", "email": "test@test.com"}
|
from app.gateway.deps import get_current_user_from_request
|
||||||
|
|
||||||
|
user = await get_current_user_from_request(request)
|
||||||
|
return {
|
||||||
|
"id": str(user.id),
|
||||||
|
"email": user.email,
|
||||||
|
"system_role": user.system_role,
|
||||||
|
"needs_setup": user.needs_setup,
|
||||||
|
}
|
||||||
|
|
||||||
@app.get("/api/v1/auth/setup-status")
|
@app.get("/api/v1/auth/setup-status")
|
||||||
async def setup_status():
|
async def setup_status():
|
||||||
@@ -109,6 +120,29 @@ def _make_app():
|
|||||||
async def models_get():
|
async def models_get():
|
||||||
return {"models": []}
|
return {"models": []}
|
||||||
|
|
||||||
|
@app.get("/api/whoami")
|
||||||
|
async def whoami(request: Request):
|
||||||
|
user = request.state.user
|
||||||
|
return {
|
||||||
|
"id": str(user.id),
|
||||||
|
"email": getattr(user, "email", None),
|
||||||
|
"system_role": getattr(user, "system_role", None),
|
||||||
|
"context_user_id": get_effective_user_id(),
|
||||||
|
}
|
||||||
|
|
||||||
|
@app.get("/api/current-user-from-dep")
|
||||||
|
async def current_user_from_dep(request: Request):
|
||||||
|
from app.gateway.deps import get_current_user_from_request
|
||||||
|
|
||||||
|
user = await get_current_user_from_request(request)
|
||||||
|
state_user = request.state.user
|
||||||
|
return {
|
||||||
|
"id": str(user.id),
|
||||||
|
"state_id": str(state_user.id),
|
||||||
|
"auth_source": request.state.auth_source,
|
||||||
|
"context_user_id": get_effective_user_id(),
|
||||||
|
}
|
||||||
|
|
||||||
@app.put("/api/mcp/config")
|
@app.put("/api/mcp/config")
|
||||||
async def mcp_put():
|
async def mcp_put():
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
@@ -132,8 +166,24 @@ def _make_app():
|
|||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
def _make_auth_csrf_app():
|
||||||
|
"""Create a minimal app with production middleware ordering."""
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.add_middleware(AuthMiddleware)
|
||||||
|
app.add_middleware(CSRFMiddleware)
|
||||||
|
|
||||||
|
@app.post("/api/threads/abc/runs/stream")
|
||||||
|
async def protected_mutation():
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client():
|
def client(monkeypatch):
|
||||||
|
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||||
return TestClient(_make_app())
|
return TestClient(_make_app())
|
||||||
|
|
||||||
|
|
||||||
@@ -161,6 +211,139 @@ def test_protected_path_no_cookie_returns_401(client):
|
|||||||
assert body["detail"]["code"] == "not_authenticated"
|
assert body["detail"]["code"] == "not_authenticated"
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_allows_protected_path_without_cookie(monkeypatch):
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
client = TestClient(_make_app())
|
||||||
|
|
||||||
|
res = client.get("/api/models")
|
||||||
|
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.json() == {"models": []}
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_stamps_e2e_admin_user_without_cookie(monkeypatch):
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
client = TestClient(_make_app())
|
||||||
|
|
||||||
|
res = client.get("/api/whoami")
|
||||||
|
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.json() == {
|
||||||
|
"id": "e2e-user",
|
||||||
|
"email": "e2e@test.local",
|
||||||
|
"system_role": "admin",
|
||||||
|
"context_user_id": "e2e-user",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_auth_me_reuses_middleware_user_without_cookie(monkeypatch):
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
client = TestClient(_make_app())
|
||||||
|
|
||||||
|
res = client.get("/api/v1/auth/me")
|
||||||
|
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.json() == {
|
||||||
|
"id": "e2e-user",
|
||||||
|
"email": "e2e@test.local",
|
||||||
|
"system_role": "admin",
|
||||||
|
"needs_setup": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_does_not_clobber_valid_session_cookie(monkeypatch):
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
async def fake_current_user(request):
|
||||||
|
return SimpleNamespace(
|
||||||
|
id="session-user",
|
||||||
|
email="session@test.local",
|
||||||
|
system_role="user",
|
||||||
|
needs_setup=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
monkeypatch.setattr("app.gateway.deps.get_current_user_from_request", fake_current_user)
|
||||||
|
client = TestClient(_make_app())
|
||||||
|
|
||||||
|
res = client.get("/api/whoami", cookies={"access_token": "valid-session"})
|
||||||
|
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.json() == {
|
||||||
|
"id": "session-user",
|
||||||
|
"email": "session@test.local",
|
||||||
|
"system_role": "user",
|
||||||
|
"context_user_id": "session-user",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_does_not_clobber_internal_auth_identity(monkeypatch):
|
||||||
|
from app.gateway.internal_auth import create_internal_auth_headers
|
||||||
|
from deerflow.runtime.user_context import DEFAULT_USER_ID
|
||||||
|
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
client = TestClient(_make_app())
|
||||||
|
|
||||||
|
res = client.get(
|
||||||
|
"/api/current-user-from-dep",
|
||||||
|
headers=create_internal_auth_headers(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.json() == {
|
||||||
|
"id": DEFAULT_USER_ID,
|
||||||
|
"state_id": DEFAULT_USER_ID,
|
||||||
|
"auth_source": "internal",
|
||||||
|
"context_user_id": DEFAULT_USER_ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_skips_csrf_for_state_changing_requests(monkeypatch):
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
client = TestClient(_make_auth_csrf_app())
|
||||||
|
|
||||||
|
res = client.post("/api/threads/abc/runs/stream")
|
||||||
|
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.json() == {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_is_ignored_in_explicit_production_env(monkeypatch):
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
monkeypatch.setenv("DEER_FLOW_ENV", "production")
|
||||||
|
client = TestClient(_make_app())
|
||||||
|
|
||||||
|
res = client.get("/api/models")
|
||||||
|
|
||||||
|
assert res.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_startup_warning_when_effective(monkeypatch, caplog):
|
||||||
|
from app.gateway.auth_disabled import warn_if_auth_disabled_enabled
|
||||||
|
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
monkeypatch.delenv("DEER_FLOW_ENV", raising=False)
|
||||||
|
monkeypatch.delenv("ENVIRONMENT", raising=False)
|
||||||
|
|
||||||
|
with caplog.at_level("WARNING", logger="app.gateway.auth_disabled"):
|
||||||
|
warn_if_auth_disabled_enabled()
|
||||||
|
|
||||||
|
assert "authentication is bypassed" in caplog.text
|
||||||
|
assert "e2e-user" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_startup_warning_suppressed_in_explicit_production_env(monkeypatch, caplog):
|
||||||
|
from app.gateway.auth_disabled import warn_if_auth_disabled_enabled
|
||||||
|
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
monkeypatch.setenv("ENVIRONMENT", "production")
|
||||||
|
|
||||||
|
with caplog.at_level("WARNING", logger="app.gateway.auth_disabled"):
|
||||||
|
warn_if_auth_disabled_enabled()
|
||||||
|
|
||||||
|
assert "authentication is bypassed" not in caplog.text
|
||||||
|
|
||||||
|
|
||||||
def test_protected_path_with_junk_cookie_rejected(client):
|
def test_protected_path_with_junk_cookie_rejected(client):
|
||||||
"""Junk cookie → 401. Middleware strictly validates the JWT now
|
"""Junk cookie → 401. Middleware strictly validates the JWT now
|
||||||
(AUTH_TEST_PLAN test 7.5.8); it no longer silently passes bad
|
(AUTH_TEST_PLAN test 7.5.8); it no longer silently passes bad
|
||||||
|
|||||||
@@ -2472,6 +2472,7 @@ class TestGatewayConformance:
|
|||||||
mem_cfg.fact_confidence_threshold = 0.7
|
mem_cfg.fact_confidence_threshold = 0.7
|
||||||
mem_cfg.injection_enabled = True
|
mem_cfg.injection_enabled = True
|
||||||
mem_cfg.max_injection_tokens = 2000
|
mem_cfg.max_injection_tokens = 2000
|
||||||
|
mem_cfg.token_counting = "tiktoken"
|
||||||
|
|
||||||
with patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg):
|
with patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg):
|
||||||
result = client.get_memory_config()
|
result = client.get_memory_config()
|
||||||
@@ -2479,6 +2480,7 @@ class TestGatewayConformance:
|
|||||||
parsed = MemoryConfigResponse(**result)
|
parsed = MemoryConfigResponse(**result)
|
||||||
assert parsed.enabled is True
|
assert parsed.enabled is True
|
||||||
assert parsed.max_facts == 100
|
assert parsed.max_facts == 100
|
||||||
|
assert parsed.token_counting == "tiktoken"
|
||||||
|
|
||||||
def test_get_memory_status(self, client):
|
def test_get_memory_status(self, client):
|
||||||
mem_cfg = MagicMock()
|
mem_cfg = MagicMock()
|
||||||
@@ -2489,6 +2491,7 @@ class TestGatewayConformance:
|
|||||||
mem_cfg.fact_confidence_threshold = 0.7
|
mem_cfg.fact_confidence_threshold = 0.7
|
||||||
mem_cfg.injection_enabled = True
|
mem_cfg.injection_enabled = True
|
||||||
mem_cfg.max_injection_tokens = 2000
|
mem_cfg.max_injection_tokens = 2000
|
||||||
|
mem_cfg.token_counting = "tiktoken"
|
||||||
|
|
||||||
memory_data = {
|
memory_data = {
|
||||||
"version": "1.0",
|
"version": "1.0",
|
||||||
@@ -2514,6 +2517,7 @@ class TestGatewayConformance:
|
|||||||
|
|
||||||
parsed = MemoryStatusResponse(**result)
|
parsed = MemoryStatusResponse(**result)
|
||||||
assert parsed.config.enabled is True
|
assert parsed.config.enabled is True
|
||||||
|
assert parsed.config.token_counting == "tiktoken"
|
||||||
assert parsed.data.version == "1.0"
|
assert parsed.data.version == "1.0"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,45 @@
|
|||||||
|
"""Regression test for the Docker Compose default Gateway worker count.
|
||||||
|
|
||||||
|
The Gateway holds run state (RunManager and the stream bridge) in process, so
|
||||||
|
the default deployment must run a single Uvicorn worker. Running more than one
|
||||||
|
worker without a shared cross-worker stream bridge breaks run cancellation, SSE
|
||||||
|
reconnects, request de-duplication, and IM channels (nginx has no sticky
|
||||||
|
sessions, so requests scatter across workers that each keep their own run
|
||||||
|
state). This test pins the safe default so it cannot silently regress to a
|
||||||
|
multi-worker default, while still allowing operators to override it once a
|
||||||
|
shared stream bridge exists.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||||
|
COMPOSE_PATH = REPO_ROOT / "docker" / "docker-compose.yaml"
|
||||||
|
|
||||||
|
|
||||||
|
def _gateway_command() -> str:
|
||||||
|
"""Return the gateway service command as a single string."""
|
||||||
|
compose = yaml.safe_load(COMPOSE_PATH.read_text(encoding="utf-8"))
|
||||||
|
command = compose["services"]["gateway"]["command"]
|
||||||
|
# ``command`` may load as a scalar string or a list depending on YAML style.
|
||||||
|
if isinstance(command, list):
|
||||||
|
command = " ".join(str(part) for part in command)
|
||||||
|
return command
|
||||||
|
|
||||||
|
|
||||||
|
def test_gateway_defaults_to_single_worker():
|
||||||
|
"""With GATEWAY_WORKERS unset, the worker count must default to 1."""
|
||||||
|
command = _gateway_command()
|
||||||
|
match = re.search(r"GATEWAY_WORKERS:-(\d+)", command)
|
||||||
|
assert match is not None, f"gateway command must set a GATEWAY_WORKERS default; got: {command}"
|
||||||
|
assert match.group(1) == "1", f"default Gateway worker count must be 1, got {match.group(1)}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_gateway_worker_count_remains_overridable():
|
||||||
|
"""The worker count must stay configurable, not hard-coded to 1."""
|
||||||
|
command = _gateway_command()
|
||||||
|
assert "${GATEWAY_WORKERS:-1}" in command, f"worker count must use ${{GATEWAY_WORKERS:-1}} so operators can override it; got: {command}"
|
||||||
@@ -203,6 +203,79 @@ class TestLoadAgentConfig:
|
|||||||
assert cfg.name == "legacy-agent"
|
assert cfg.name == "legacy-agent"
|
||||||
|
|
||||||
|
|
||||||
|
# ===========================================================================
|
||||||
|
# 3b. resolve_agent_dir — memory-only directory fallback (#3390)
|
||||||
|
# ===========================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolveAgentDirMemoryOnlyFallback:
|
||||||
|
"""Regression tests for #3390.
|
||||||
|
|
||||||
|
When memory is enabled, the first conversation creates a user-isolated
|
||||||
|
agent directory containing only ``memory.json`` (no ``config.yaml``).
|
||||||
|
On the next turn ``resolve_agent_dir`` must fall through to the legacy
|
||||||
|
shared layout instead of returning the incomplete user directory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_user_dir_with_only_memory_falls_back_to_legacy(self, tmp_path):
|
||||||
|
"""User dir has memory.json but no config.yaml → use legacy dir."""
|
||||||
|
from deerflow.config.agents_config import resolve_agent_dir
|
||||||
|
|
||||||
|
# Legacy agent with full config
|
||||||
|
legacy_dir = tmp_path / "agents" / "my-agent"
|
||||||
|
legacy_dir.mkdir(parents=True)
|
||||||
|
(legacy_dir / "config.yaml").write_text("name: my-agent\n", encoding="utf-8")
|
||||||
|
(legacy_dir / "SOUL.md").write_text("legacy soul", encoding="utf-8")
|
||||||
|
|
||||||
|
# User dir created by memory write — no config.yaml
|
||||||
|
user_dir = tmp_path / "users" / "u1" / "agents" / "my-agent"
|
||||||
|
user_dir.mkdir(parents=True)
|
||||||
|
(user_dir / "memory.json").write_text("{}", encoding="utf-8")
|
||||||
|
|
||||||
|
with patch("deerflow.config.agents_config.get_paths", return_value=_make_paths(tmp_path)), patch("deerflow.config.agents_config.get_effective_user_id", return_value="u1"):
|
||||||
|
result = resolve_agent_dir("my-agent", user_id="u1")
|
||||||
|
|
||||||
|
assert result == legacy_dir
|
||||||
|
|
||||||
|
def test_user_dir_with_config_takes_priority(self, tmp_path):
|
||||||
|
"""User dir with config.yaml should still win over legacy."""
|
||||||
|
from deerflow.config.agents_config import resolve_agent_dir
|
||||||
|
|
||||||
|
# Legacy
|
||||||
|
legacy_dir = tmp_path / "agents" / "my-agent"
|
||||||
|
legacy_dir.mkdir(parents=True)
|
||||||
|
(legacy_dir / "config.yaml").write_text("name: my-agent\n", encoding="utf-8")
|
||||||
|
|
||||||
|
# User dir with full config (migrated)
|
||||||
|
user_dir = tmp_path / "users" / "u1" / "agents" / "my-agent"
|
||||||
|
user_dir.mkdir(parents=True)
|
||||||
|
(user_dir / "config.yaml").write_text("name: my-agent\nmodel: gpt-4\n", encoding="utf-8")
|
||||||
|
(user_dir / "memory.json").write_text("{}", encoding="utf-8")
|
||||||
|
|
||||||
|
with patch("deerflow.config.agents_config.get_paths", return_value=_make_paths(tmp_path)), patch("deerflow.config.agents_config.get_effective_user_id", return_value="u1"):
|
||||||
|
result = resolve_agent_dir("my-agent", user_id="u1")
|
||||||
|
|
||||||
|
assert result == user_dir
|
||||||
|
|
||||||
|
def test_load_config_falls_back_when_user_dir_is_memory_only(self, tmp_path):
|
||||||
|
"""End-to-end: load_agent_config works when user dir only has memory.json."""
|
||||||
|
config_dict = {"name": "my-agent", "description": "Legacy agent", "model": "deepseek-v3"}
|
||||||
|
_write_agent(tmp_path, "my-agent", config_dict)
|
||||||
|
|
||||||
|
# Simulate memory write creating user dir without config
|
||||||
|
user_dir = tmp_path / "users" / "u1" / "agents" / "my-agent"
|
||||||
|
user_dir.mkdir(parents=True)
|
||||||
|
(user_dir / "memory.json").write_text("{}", encoding="utf-8")
|
||||||
|
|
||||||
|
with patch("deerflow.config.agents_config.get_paths", return_value=_make_paths(tmp_path)), patch("deerflow.config.agents_config.get_effective_user_id", return_value="u1"):
|
||||||
|
from deerflow.config.agents_config import load_agent_config
|
||||||
|
|
||||||
|
cfg = load_agent_config("my-agent", user_id="u1")
|
||||||
|
|
||||||
|
assert cfg.name == "my-agent"
|
||||||
|
assert cfg.model == "deepseek-v3"
|
||||||
|
|
||||||
|
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
# 4. load_agent_soul
|
# 4. load_agent_soul
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from langgraph_sdk import Auth
|
|||||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||||
from app.gateway.auth.jwt import create_access_token, decode_token
|
from app.gateway.auth.jwt import create_access_token, decode_token
|
||||||
from app.gateway.auth.models import User
|
from app.gateway.auth.models import User
|
||||||
|
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID
|
||||||
from app.gateway.langgraph_auth import add_owner_filter, authenticate
|
from app.gateway.langgraph_auth import add_owner_filter, authenticate
|
||||||
|
|
||||||
# ── Helpers ───────────────────────────────────────────────────────────────
|
# ── Helpers ───────────────────────────────────────────────────────────────
|
||||||
@@ -59,6 +60,14 @@ def test_no_cookie_raises_401():
|
|||||||
assert "Not authenticated" in str(exc.value.detail)
|
assert "Not authenticated" in str(exc.value.detail)
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_skips_csrf_and_authenticates_e2e_user(monkeypatch):
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
|
||||||
|
identity = asyncio.run(authenticate(_req(method="POST")))
|
||||||
|
|
||||||
|
assert identity == AUTH_DISABLED_USER_ID
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_jwt_raises_401():
|
def test_invalid_jwt_raises_401():
|
||||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||||
asyncio.run(authenticate(_req({"access_token": "garbage"})))
|
asyncio.run(authenticate(_req({"access_token": "garbage"})))
|
||||||
|
|||||||
@@ -192,7 +192,7 @@ def test_build_acp_section_uses_explicit_app_config_without_global_config(monkey
|
|||||||
|
|
||||||
def test_get_memory_context_uses_explicit_app_config_without_global_config(monkeypatch):
|
def test_get_memory_context_uses_explicit_app_config_without_global_config(monkeypatch):
|
||||||
explicit_config = SimpleNamespace(
|
explicit_config = SimpleNamespace(
|
||||||
memory=SimpleNamespace(enabled=True, injection_enabled=True, max_injection_tokens=1234),
|
memory=SimpleNamespace(enabled=True, injection_enabled=True, max_injection_tokens=1234, token_counting="tiktoken"),
|
||||||
)
|
)
|
||||||
captured: dict[str, object] = {}
|
captured: dict[str, object] = {}
|
||||||
|
|
||||||
@@ -204,9 +204,10 @@ def test_get_memory_context_uses_explicit_app_config_without_global_config(monke
|
|||||||
captured["user_id"] = user_id
|
captured["user_id"] = user_id
|
||||||
return {"facts": []}
|
return {"facts": []}
|
||||||
|
|
||||||
def fake_format_memory_for_injection(memory_data, *, max_tokens):
|
def fake_format_memory_for_injection(memory_data, *, max_tokens, use_tiktoken=True):
|
||||||
captured["memory_data"] = memory_data
|
captured["memory_data"] = memory_data
|
||||||
captured["max_tokens"] = max_tokens
|
captured["max_tokens"] = max_tokens
|
||||||
|
captured["use_tiktoken"] = use_tiktoken
|
||||||
return "remember this"
|
return "remember this"
|
||||||
|
|
||||||
monkeypatch.setattr("deerflow.config.memory_config.get_memory_config", fail_get_memory_config)
|
monkeypatch.setattr("deerflow.config.memory_config.get_memory_config", fail_get_memory_config)
|
||||||
@@ -223,6 +224,7 @@ def test_get_memory_context_uses_explicit_app_config_without_global_config(monke
|
|||||||
"user_id": "user-1",
|
"user_id": "user-1",
|
||||||
"memory_data": {"facts": []},
|
"memory_data": {"facts": []},
|
||||||
"max_tokens": 1234,
|
"max_tokens": 1234,
|
||||||
|
"use_tiktoken": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -612,6 +612,54 @@ class TestLocalSandboxProviderMounts:
|
|||||||
|
|
||||||
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
|
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
|
||||||
|
|
||||||
|
def test_setup_path_mappings_logs_actionable_error_for_missing_host_path(self, tmp_path, caplog):
|
||||||
|
"""Regression for #3244.
|
||||||
|
|
||||||
|
When ``sandbox.mounts[].host_path`` is absent from the gateway process's
|
||||||
|
filesystem (the typical symptom in Docker production mode: host_path is a
|
||||||
|
host machine path that is not bind-mounted into the gateway container),
|
||||||
|
the mount is still skipped — but the failure must be a hard-to-miss ERROR
|
||||||
|
log with explicit, actionable guidance about Docker bind mounts, not the
|
||||||
|
old DEBUG/WARNING that buried the silent failure.
|
||||||
|
"""
|
||||||
|
skills_dir = tmp_path / "skills"
|
||||||
|
skills_dir.mkdir()
|
||||||
|
missing_host_path = tmp_path / "does-not-exist"
|
||||||
|
|
||||||
|
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
|
||||||
|
|
||||||
|
sandbox_config = SandboxConfig(
|
||||||
|
use="deerflow.sandbox.local:LocalSandboxProvider",
|
||||||
|
mounts=[
|
||||||
|
VolumeMountConfig(host_path=str(missing_host_path), container_path="/mnt/knowledge", read_only=True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
config = SimpleNamespace(
|
||||||
|
skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir, use="deerflow.skills.storage.local_skill_storage:LocalSkillStorage"),
|
||||||
|
sandbox=sandbox_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
with caplog.at_level("ERROR", logger="deerflow.sandbox.local.local_sandbox_provider"):
|
||||||
|
with patch("deerflow.config.get_app_config", return_value=config):
|
||||||
|
provider = LocalSandboxProvider()
|
||||||
|
|
||||||
|
# Silent-skip behaviour is preserved (no breaking change for existing deployments).
|
||||||
|
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
|
||||||
|
|
||||||
|
# The failure must be observable at ERROR level and reference the offending paths.
|
||||||
|
error_records = [r for r in caplog.records if r.levelname == "ERROR"]
|
||||||
|
assert error_records, "expected an ERROR log when host_path is missing"
|
||||||
|
message = "\n".join(r.getMessage() for r in error_records)
|
||||||
|
assert str(missing_host_path) in message
|
||||||
|
assert "/mnt/knowledge" in message
|
||||||
|
|
||||||
|
# And it must include actionable Docker guidance so users don't lose hours
|
||||||
|
# to a silent empty-mount failure in production.
|
||||||
|
lowered = message.lower()
|
||||||
|
assert "docker" in lowered
|
||||||
|
assert "gateway" in lowered
|
||||||
|
assert "docker-compose" in lowered
|
||||||
|
|
||||||
def test_write_file_resolves_container_paths_in_content(self, tmp_path):
|
def test_write_file_resolves_container_paths_in_content(self, tmp_path):
|
||||||
"""write_file should replace container paths in file content with local paths."""
|
"""write_file should replace container paths in file content with local paths."""
|
||||||
data_dir = tmp_path / "data"
|
data_dir = tmp_path / "data"
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ def test_format_memory_sorts_facts_by_confidence_desc() -> None:
|
|||||||
|
|
||||||
def test_format_memory_respects_budget_when_adding_facts(monkeypatch) -> None:
|
def test_format_memory_respects_budget_when_adding_facts(monkeypatch) -> None:
|
||||||
# Make token counting deterministic for this test by counting characters.
|
# Make token counting deterministic for this test by counting characters.
|
||||||
monkeypatch.setattr("deerflow.agents.memory.prompt._count_tokens", lambda text, encoding_name="cl100k_base": len(text))
|
monkeypatch.setattr("deerflow.agents.memory.prompt._count_tokens", lambda text, encoding_name="cl100k_base", *, use_tiktoken=True: len(text))
|
||||||
|
|
||||||
memory_data = {
|
memory_data = {
|
||||||
"user": {},
|
"user": {},
|
||||||
|
|||||||
@@ -179,15 +179,16 @@ class TestLifecycleCallbacks:
|
|||||||
assert "run.end" in types
|
assert "run.end" in types
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_nested_chain_no_run_start(self, journal_setup):
|
async def test_nested_chain_no_run_lifecycle_events(self, journal_setup):
|
||||||
"""Nested chains (parent_run_id set) should NOT produce run.start."""
|
"""Nested chains (parent_run_id set) should NOT produce root run lifecycle events."""
|
||||||
j, store = journal_setup
|
j, store = journal_setup
|
||||||
parent_id = uuid4()
|
parent_id = uuid4()
|
||||||
j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=parent_id)
|
j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=parent_id)
|
||||||
j.on_chain_end({}, run_id=uuid4())
|
j.on_chain_end({}, run_id=uuid4(), parent_run_id=parent_id)
|
||||||
await j.flush()
|
await j.flush()
|
||||||
events = await store.list_events("t1", "r1")
|
events = await store.list_events("t1", "r1")
|
||||||
assert not any(e["event_type"] == "run.start" for e in events)
|
assert not any(e["event_type"] == "run.start" for e in events)
|
||||||
|
assert not any(e["event_type"] == "run.end" for e in events)
|
||||||
|
|
||||||
|
|
||||||
class TestToolCallbacks:
|
class TestToolCallbacks:
|
||||||
|
|||||||
@@ -0,0 +1,173 @@
|
|||||||
|
"""Cross-user isolation for the stateless ``POST /api/runs/stream`` and ``/wait`` endpoints.
|
||||||
|
|
||||||
|
These endpoints receive ``thread_id`` in the request body, so the
|
||||||
|
``@require_permission(owner_check=True)`` decorator — which reads the
|
||||||
|
``thread_id`` *path* parameter — cannot protect them. The owner check
|
||||||
|
lives inside ``services.start_run()`` instead; this suite pins it at the
|
||||||
|
HTTP layer so the gap cannot silently reopen.
|
||||||
|
|
||||||
|
Strategy
|
||||||
|
--------
|
||||||
|
``app.state.run_manager.create_or_reject`` raises ``ConflictError``, so a
|
||||||
|
request that *passes* the owner check deterministically short-circuits
|
||||||
|
with 409 before any agent code runs. The two outcomes:
|
||||||
|
|
||||||
|
- 404 + ``create_or_reject`` never awaited -> blocked by the owner check
|
||||||
|
- 409 + ``create_or_reject`` awaited -> passed the owner check
|
||||||
|
|
||||||
|
The thread store is a real ``MemoryThreadMetaStore`` (not a mock) so the
|
||||||
|
``check_access`` semantics under test — missing row allows, ``user_id``
|
||||||
|
NULL allows, foreign owner denies — are exercised through real code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from _router_auth_helpers import make_authed_test_app
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from langgraph.store.memory import InMemoryStore
|
||||||
|
|
||||||
|
from app.gateway.auth.models import User
|
||||||
|
from app.gateway.routers import runs
|
||||||
|
from deerflow.config.app_config import AppConfig, reset_app_config, set_app_config
|
||||||
|
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
||||||
|
from deerflow.runtime import ConflictError
|
||||||
|
|
||||||
|
USER_A = User(email="owner-a@example.com", password_hash="x", system_role="user", id=uuid4())
|
||||||
|
USER_B = User(email="intruder-b@example.com", password_hash="x", system_role="user", id=uuid4())
|
||||||
|
INTERNAL_USER = SimpleNamespace(id="default", system_role="internal")
|
||||||
|
|
||||||
|
THREAD_A = "thread-owned-by-a"
|
||||||
|
THREAD_SHARED = "thread-shared-null-owner"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _stub_app_config():
|
||||||
|
"""Inject a minimal AppConfig so the allowed path (which builds a
|
||||||
|
RunContext via ``get_config()``) never reads config.yaml from disk."""
|
||||||
|
set_app_config(AppConfig.model_validate({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}}))
|
||||||
|
yield
|
||||||
|
reset_app_config()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_thread_store() -> MemoryThreadMetaStore:
|
||||||
|
store = MemoryThreadMetaStore(InMemoryStore())
|
||||||
|
|
||||||
|
async def _seed():
|
||||||
|
await store.create(THREAD_A, user_id=str(USER_A.id))
|
||||||
|
await store.create(THREAD_SHARED, user_id=None)
|
||||||
|
|
||||||
|
asyncio.run(_seed())
|
||||||
|
return store
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _client(user):
|
||||||
|
"""Yield a ``TestClient`` authenticated as ``user`` plus the stubbed
|
||||||
|
``create_or_reject`` mock, closing the client (and its anyio portal /
|
||||||
|
background threads) on exit.
|
||||||
|
|
||||||
|
``create_or_reject`` raises ``ConflictError`` so a request that passes the
|
||||||
|
owner check short-circuits to 409 before any agent code runs.
|
||||||
|
"""
|
||||||
|
app = make_authed_test_app(user_factory=lambda: user)
|
||||||
|
app.include_router(runs.router)
|
||||||
|
app.state.thread_store = _make_thread_store()
|
||||||
|
app.state.stream_bridge = MagicMock()
|
||||||
|
app.state.checkpointer = MagicMock()
|
||||||
|
app.state.store = MagicMock()
|
||||||
|
app.state.run_events_config = None
|
||||||
|
app.state.run_event_store = MagicMock()
|
||||||
|
run_manager = MagicMock()
|
||||||
|
run_manager.create_or_reject = AsyncMock(side_effect=ConflictError("sentinel: owner check passed"))
|
||||||
|
app.state.run_manager = run_manager
|
||||||
|
with TestClient(app) as client:
|
||||||
|
yield client, run_manager.create_or_reject
|
||||||
|
|
||||||
|
|
||||||
|
def _body(thread_id: str | None = None) -> dict:
|
||||||
|
if thread_id is None:
|
||||||
|
return {}
|
||||||
|
return {"config": {"configurable": {"thread_id": thread_id}}}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Denied: another user's thread
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_cross_user_returns_404():
|
||||||
|
"""User B cannot start a run on user A's thread via /api/runs/stream."""
|
||||||
|
with _client(USER_B) as (client, create_or_reject):
|
||||||
|
response = client.post("/api/runs/stream", json=_body(THREAD_A))
|
||||||
|
assert response.status_code == 404
|
||||||
|
assert response.json()["detail"] == f"Thread {THREAD_A} not found"
|
||||||
|
create_or_reject.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
def test_wait_cross_user_returns_404_without_channel_values():
|
||||||
|
"""User B cannot read user A's checkpoint state via /api/runs/wait."""
|
||||||
|
with _client(USER_B) as (client, create_or_reject):
|
||||||
|
response = client.post("/api/runs/wait", json=_body(THREAD_A))
|
||||||
|
assert response.status_code == 404
|
||||||
|
assert response.json() == {"detail": f"Thread {THREAD_A} not found"}
|
||||||
|
create_or_reject.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Allowed: owner, fresh/untracked/shared threads, internal role
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_owner_passes_owner_check():
|
||||||
|
"""User A reaches run creation on their own thread (409 sentinel)."""
|
||||||
|
with _client(USER_A) as (client, create_or_reject):
|
||||||
|
response = client.post("/api/runs/stream", json=_body(THREAD_A))
|
||||||
|
assert response.status_code == 409
|
||||||
|
create_or_reject.assert_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
def test_wait_owner_passes_owner_check():
|
||||||
|
with _client(USER_A) as (client, create_or_reject):
|
||||||
|
response = client.post("/api/runs/wait", json=_body(THREAD_A))
|
||||||
|
assert response.status_code == 409
|
||||||
|
create_or_reject.assert_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_without_thread_id_passes_owner_check():
|
||||||
|
"""Stateless run with no thread_id auto-creates a thread — never blocked."""
|
||||||
|
with _client(USER_B) as (client, create_or_reject):
|
||||||
|
response = client.post("/api/runs/stream", json=_body())
|
||||||
|
assert response.status_code == 409
|
||||||
|
create_or_reject.assert_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_untracked_thread_passes_owner_check():
|
||||||
|
"""A thread_id with no thread_meta row (untracked legacy) stays accessible."""
|
||||||
|
with _client(USER_B) as (client, create_or_reject):
|
||||||
|
response = client.post("/api/runs/stream", json=_body("never-created-thread"))
|
||||||
|
assert response.status_code == 409
|
||||||
|
create_or_reject.assert_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_shared_thread_passes_owner_check():
|
||||||
|
"""A thread_meta row with user_id NULL (shared / pre-auth data) stays accessible."""
|
||||||
|
with _client(USER_B) as (client, create_or_reject):
|
||||||
|
response = client.post("/api/runs/stream", json=_body(THREAD_SHARED))
|
||||||
|
assert response.status_code == 409
|
||||||
|
create_or_reject.assert_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_internal_role_bypasses_owner_check():
|
||||||
|
"""IM channels run with the internal system role on behalf of platform
|
||||||
|
users whose threads they do not own — the owner check must not break them."""
|
||||||
|
with _client(INTERNAL_USER) as (client, create_or_reject):
|
||||||
|
response = client.post("/api/runs/stream", json=_body(THREAD_A))
|
||||||
|
assert response.status_code == 409
|
||||||
|
create_or_reject.assert_awaited()
|
||||||
@@ -5,18 +5,22 @@ Verifies:
|
|||||||
- ``_count_tokens`` falls back to character estimation when tiktoken is
|
- ``_count_tokens`` falls back to character estimation when tiktoken is
|
||||||
unavailable or the encoding fails to load.
|
unavailable or the encoding fails to load.
|
||||||
- ``warm_tiktoken_cache`` populates the cache on success.
|
- ``warm_tiktoken_cache`` populates the cache on success.
|
||||||
|
- An in-flight tiktoken load prevents duplicate blocking downloads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from deerflow.agents.memory.prompt import (
|
from deerflow.agents.memory.prompt import (
|
||||||
_count_tokens,
|
_count_tokens,
|
||||||
_get_tiktoken_encoding,
|
_get_tiktoken_encoding,
|
||||||
_tiktoken_encoding_cache,
|
_tiktoken_encoding_cache,
|
||||||
|
format_memory_for_injection,
|
||||||
warm_tiktoken_cache,
|
warm_tiktoken_cache,
|
||||||
)
|
)
|
||||||
|
from deerflow.config.memory_config import MemoryConfig
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# _get_tiktoken_encoding
|
# _get_tiktoken_encoding
|
||||||
@@ -62,14 +66,103 @@ class TestGetTiktokenEncoding:
|
|||||||
assert enc is fake_enc
|
assert enc is fake_enc
|
||||||
tiktoken.get_encoding.assert_not_called()
|
tiktoken.get_encoding.assert_not_called()
|
||||||
|
|
||||||
def test_returns_none_and_warns_on_get_encoding_failure(self, monkeypatch):
|
def test_returns_none_and_caches_failure_sentinel(self, monkeypatch):
|
||||||
|
"""A failed load is cached (with a timestamp) so it is not re-attempted (no repeated network download)."""
|
||||||
_tiktoken_encoding_cache.pop("bogus_encoding", None)
|
_tiktoken_encoding_cache.pop("bogus_encoding", None)
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
monkeypatch.setattr(tiktoken, "get_encoding", mock.Mock(side_effect=OSError("download failed")))
|
get_encoding = mock.Mock(side_effect=OSError("download failed"))
|
||||||
|
monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
|
||||||
|
|
||||||
result = _get_tiktoken_encoding("bogus_encoding")
|
result = _get_tiktoken_encoding("bogus_encoding")
|
||||||
assert result is None
|
assert result is None
|
||||||
assert "bogus_encoding" not in _tiktoken_encoding_cache
|
# The failure is remembered as a (None, timestamp) tuple.
|
||||||
|
assert "bogus_encoding" in _tiktoken_encoding_cache
|
||||||
|
cached = _tiktoken_encoding_cache["bogus_encoding"]
|
||||||
|
assert isinstance(cached, tuple)
|
||||||
|
assert cached[0] is None
|
||||||
|
|
||||||
|
# A second call must NOT re-attempt get_encoding (avoids re-blocking on
|
||||||
|
# the network download in restricted environments — see #3429).
|
||||||
|
result2 = _get_tiktoken_encoding("bogus_encoding")
|
||||||
|
assert result2 is None
|
||||||
|
assert get_encoding.call_count == 1
|
||||||
|
|
||||||
|
# Cleanup module-level cache to avoid cross-test leakage.
|
||||||
|
_tiktoken_encoding_cache.pop("bogus_encoding", None)
|
||||||
|
|
||||||
|
def test_failure_self_heals_after_cooldown(self, monkeypatch):
|
||||||
|
"""After the retry cooldown expires, a transient failure is re-attempted and can recover."""
|
||||||
|
_tiktoken_encoding_cache.pop("flaky_encoding", None)
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
fake_enc = mock.Mock()
|
||||||
|
# First call fails, second call (after cooldown) succeeds.
|
||||||
|
get_encoding = mock.Mock(side_effect=[OSError("transient outage"), fake_enc])
|
||||||
|
monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
|
||||||
|
|
||||||
|
# Initial failure is cached.
|
||||||
|
assert _get_tiktoken_encoding("flaky_encoding") is None
|
||||||
|
assert get_encoding.call_count == 1
|
||||||
|
|
||||||
|
# Within the cooldown window: no retry, immediate fallback.
|
||||||
|
assert _get_tiktoken_encoding("flaky_encoding") is None
|
||||||
|
assert get_encoding.call_count == 1
|
||||||
|
|
||||||
|
# Simulate the cooldown having elapsed by ageing the cached timestamp.
|
||||||
|
from deerflow.agents.memory import prompt as prompt_module
|
||||||
|
|
||||||
|
_, _failed_at = _tiktoken_encoding_cache["flaky_encoding"]
|
||||||
|
_tiktoken_encoding_cache["flaky_encoding"] = (
|
||||||
|
None,
|
||||||
|
_failed_at - prompt_module._TIKTOKEN_RETRY_COOLDOWN_S - 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now the load is retried and recovers to accurate counting.
|
||||||
|
assert _get_tiktoken_encoding("flaky_encoding") is fake_enc
|
||||||
|
assert get_encoding.call_count == 2
|
||||||
|
|
||||||
|
_tiktoken_encoding_cache.pop("flaky_encoding", None)
|
||||||
|
|
||||||
|
def test_in_flight_load_returns_none_without_duplicate_get_encoding(self, monkeypatch):
|
||||||
|
"""Concurrent callers must not start duplicate blocking BPE downloads."""
|
||||||
|
_tiktoken_encoding_cache.pop("slow_encoding", None)
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
started = threading.Event()
|
||||||
|
release = threading.Event()
|
||||||
|
fake_enc = mock.Mock()
|
||||||
|
|
||||||
|
def slow_get_encoding(_name):
|
||||||
|
started.set()
|
||||||
|
assert release.wait(timeout=2), "test timed out waiting to release slow get_encoding"
|
||||||
|
return fake_enc
|
||||||
|
|
||||||
|
get_encoding = mock.Mock(side_effect=slow_get_encoding)
|
||||||
|
monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
|
||||||
|
|
||||||
|
result: dict[str, object | None] = {}
|
||||||
|
|
||||||
|
def load_encoding():
|
||||||
|
result["encoding"] = _get_tiktoken_encoding("slow_encoding")
|
||||||
|
|
||||||
|
thread = threading.Thread(target=load_encoding)
|
||||||
|
thread.start()
|
||||||
|
try:
|
||||||
|
assert started.wait(timeout=1), "slow get_encoding did not start"
|
||||||
|
|
||||||
|
# While the first call is still blocked, a second call should see
|
||||||
|
# the in-flight sentinel and fall back immediately instead of
|
||||||
|
# starting another potentially long network download.
|
||||||
|
assert _get_tiktoken_encoding("slow_encoding") is None
|
||||||
|
assert get_encoding.call_count == 1
|
||||||
|
finally:
|
||||||
|
release.set()
|
||||||
|
thread.join(timeout=2)
|
||||||
|
_tiktoken_encoding_cache.pop("slow_encoding", None)
|
||||||
|
|
||||||
|
assert result["encoding"] is fake_enc
|
||||||
|
assert get_encoding.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -115,6 +208,45 @@ class TestCountTokens:
|
|||||||
result = _count_tokens(text, encoding_name="test_enc")
|
result = _count_tokens(text, encoding_name="test_enc")
|
||||||
assert result == len(text) // 4
|
assert result == len(text) // 4
|
||||||
|
|
||||||
|
def test_use_tiktoken_false_returns_char_estimate_without_touching_tiktoken(self, monkeypatch):
|
||||||
|
"""use_tiktoken=False must never call tiktoken (guarantees no BPE download)."""
|
||||||
|
# Spy on both the encoding loader and tiktoken.get_encoding directly.
|
||||||
|
get_encoding_spy = mock.Mock(side_effect=AssertionError("get_encoding must not be called"))
|
||||||
|
loader_spy = mock.Mock(side_effect=AssertionError("_get_tiktoken_encoding must not be called"))
|
||||||
|
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", get_encoding_spy)
|
||||||
|
monkeypatch.setattr("deerflow.agents.memory.prompt._get_tiktoken_encoding", loader_spy)
|
||||||
|
|
||||||
|
text = "Hello, world! This is a network-free count."
|
||||||
|
result = _count_tokens(text, use_tiktoken=False)
|
||||||
|
assert result == len(text) // 4
|
||||||
|
get_encoding_spy.assert_not_called()
|
||||||
|
loader_spy.assert_not_called()
|
||||||
|
|
||||||
|
def test_cjk_estimate_is_denser_than_plain_quarter(self, monkeypatch):
|
||||||
|
"""CJK text should estimate more tokens than the plain len // 4 heuristic.
|
||||||
|
|
||||||
|
CJK characters are ~2 chars/token, so the char-based estimate must not
|
||||||
|
under-fill the budget the way ``len(text) // 4`` would.
|
||||||
|
"""
|
||||||
|
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
|
||||||
|
# "User prefers concise answers" rendered in CJK (Chinese) characters.
|
||||||
|
text = "\u7528\u6237\u504f\u597d\u7b80\u6d01\u7684\u4e2d\u6587\u56de\u7b54\u5e76\u5173\u6ce8\u91d1\u878d\u9886\u57df"
|
||||||
|
result = _count_tokens(text)
|
||||||
|
# Each CJK char counts as ~1/2 token (vs 1/4 for the plain heuristic).
|
||||||
|
assert result == len(text) // 2
|
||||||
|
assert result > len(text) // 4
|
||||||
|
|
||||||
|
def test_cjk_estimate_combines_cjk_and_non_cjk_characters(self, monkeypatch):
|
||||||
|
"""Mixed-language text should apply the CJK density only to CJK chars."""
|
||||||
|
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
|
||||||
|
# ASCII words mixed with CJK (Chinese) characters: "User" + "likes" + "Python and data analysis".
|
||||||
|
text = "User\u559c\u6b22Python\u548c\u6570\u636e\u5206\u6790"
|
||||||
|
cjk = sum(1 for ch in text if "\u4e00" <= ch <= "\u9fff")
|
||||||
|
|
||||||
|
result = _count_tokens(text)
|
||||||
|
|
||||||
|
assert result == (len(text) - cjk) // 4 + cjk // 2
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# warm_tiktoken_cache
|
# warm_tiktoken_cache
|
||||||
@@ -146,3 +278,69 @@ class TestWarmTiktokenCache:
|
|||||||
def test_returns_false_when_tiktoken_unavailable(self, monkeypatch):
|
def test_returns_false_when_tiktoken_unavailable(self, monkeypatch):
|
||||||
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
|
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
|
||||||
assert warm_tiktoken_cache() is False
|
assert warm_tiktoken_cache() is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# format_memory_for_injection token_counting strategy
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatMemoryForInjectionTokenCounting:
|
||||||
|
"""Verify the use_tiktoken flag is honoured end-to-end."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sample_memory() -> dict:
|
||||||
|
return {
|
||||||
|
"facts": [
|
||||||
|
{"content": "User prefers concise answers.", "category": "preference", "confidence": 0.9},
|
||||||
|
{"content": "User works in the finance domain.", "category": "context", "confidence": 0.8},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_use_tiktoken_false_never_touches_tiktoken(self, monkeypatch):
|
||||||
|
"""With use_tiktoken=False, formatting must not call tiktoken at all."""
|
||||||
|
get_encoding_spy = mock.Mock(side_effect=AssertionError("get_encoding must not be called"))
|
||||||
|
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", get_encoding_spy)
|
||||||
|
|
||||||
|
result = format_memory_for_injection(self._sample_memory(), max_tokens=2000, use_tiktoken=False)
|
||||||
|
assert "User prefers concise answers." in result
|
||||||
|
get_encoding_spy.assert_not_called()
|
||||||
|
|
||||||
|
def test_use_tiktoken_true_uses_encoding(self, monkeypatch):
|
||||||
|
"""With use_tiktoken=True (default), the cached encoding is used for counting."""
|
||||||
|
fake_enc = mock.Mock()
|
||||||
|
fake_enc.encode.side_effect = lambda text: list(range(len(text)))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"deerflow.agents.memory.prompt._get_tiktoken_encoding",
|
||||||
|
mock.Mock(return_value=fake_enc),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = format_memory_for_injection(self._sample_memory(), max_tokens=2000, use_tiktoken=True)
|
||||||
|
assert "User prefers concise answers." in result
|
||||||
|
assert fake_enc.encode.called
|
||||||
|
|
||||||
|
def test_empty_memory_returns_empty(self):
|
||||||
|
assert format_memory_for_injection({}, max_tokens=2000, use_tiktoken=False) == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# MemoryConfig.token_counting
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryConfigTokenCounting:
|
||||||
|
"""Verify the new config field defaults and validation."""
|
||||||
|
|
||||||
|
def test_default_is_tiktoken(self):
|
||||||
|
"""Default must remain tiktoken so existing deployments are unaffected."""
|
||||||
|
assert MemoryConfig().token_counting == "tiktoken"
|
||||||
|
|
||||||
|
def test_accepts_char(self):
|
||||||
|
assert MemoryConfig(token_counting="char").token_counting == "char"
|
||||||
|
|
||||||
|
def test_rejects_invalid_value(self):
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
MemoryConfig(token_counting="invalid")
|
||||||
|
|||||||
+15
-2
@@ -15,7 +15,7 @@
|
|||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Bump this number when the config schema changes.
|
# Bump this number when the config schema changes.
|
||||||
# Run `make config-upgrade` to merge new fields into your local config.yaml.
|
# Run `make config-upgrade` to merge new fields into your local config.yaml.
|
||||||
config_version: 11
|
config_version: 12
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Logging
|
# Logging
|
||||||
@@ -768,8 +768,12 @@ sandbox:
|
|||||||
allow_host_bash: false
|
allow_host_bash: false
|
||||||
# Optional: Mount additional host directories into the sandbox.
|
# Optional: Mount additional host directories into the sandbox.
|
||||||
# Each mount maps a host path to a virtual container path accessible by the agent.
|
# Each mount maps a host path to a virtual container path accessible by the agent.
|
||||||
|
# Note: with LocalSandboxProvider under `make up` (docker-compose), host_path is
|
||||||
|
# checked from inside the deer-flow-gateway container — you must also bind-mount
|
||||||
|
# the same directory into services.gateway.volumes in docker/docker-compose.yaml
|
||||||
|
# for this mount to take effect (see issue #3244).
|
||||||
# mounts:
|
# mounts:
|
||||||
# - host_path: /home/user/my-project # Absolute path on the host machine
|
# - host_path: /home/user/my-project # Absolute path; see note above for Docker mode
|
||||||
# container_path: /mnt/my-project # Virtual path inside the sandbox
|
# container_path: /mnt/my-project # Virtual path inside the sandbox
|
||||||
# read_only: true # Whether the mount is read-only (default: false)
|
# read_only: true # Whether the mount is read-only (default: false)
|
||||||
|
|
||||||
@@ -1020,6 +1024,15 @@ memory:
|
|||||||
fact_confidence_threshold: 0.7 # Minimum confidence for storing facts
|
fact_confidence_threshold: 0.7 # Minimum confidence for storing facts
|
||||||
injection_enabled: true # Whether to inject memory into system prompt
|
injection_enabled: true # Whether to inject memory into system prompt
|
||||||
max_injection_tokens: 2000 # Maximum tokens for memory injection
|
max_injection_tokens: 2000 # Maximum tokens for memory injection
|
||||||
|
# Token counting strategy for memory-injection budgeting:
|
||||||
|
# tiktoken (default) - accurate, but the encoding's BPE data may be
|
||||||
|
# downloaded from a public network endpoint on first use. In
|
||||||
|
# network-restricted environments this download can block for a long
|
||||||
|
# time (see issues #3402 / #3429). Pre-cache the encoding or set this
|
||||||
|
# to "char" to avoid it.
|
||||||
|
# char - network-free CJK-aware character-based estimate; never touches
|
||||||
|
# tiktoken. Slightly less precise budgeting, zero network I/O.
|
||||||
|
token_counting: tiktoken
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Custom Agent Management API
|
# Custom Agent Management API
|
||||||
|
|||||||
@@ -72,7 +72,13 @@ services:
|
|||||||
UV_INDEX_URL: ${UV_INDEX_URL:-https://pypi.org/simple}
|
UV_INDEX_URL: ${UV_INDEX_URL:-https://pypi.org/simple}
|
||||||
UV_EXTRAS: ${UV_EXTRAS:-}
|
UV_EXTRAS: ${UV_EXTRAS:-}
|
||||||
container_name: deer-flow-gateway
|
container_name: deer-flow-gateway
|
||||||
command: sh -c "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001 --workers ${GATEWAY_WORKERS:-4}"
|
# Gateway hosts the agent runtime with in-process RunManager + StreamBridge
|
||||||
|
# singletons -- run state lives in this worker's memory. Default to a single
|
||||||
|
# worker: with >1 worker and no nginx sticky sessions, run cancel, SSE
|
||||||
|
# reconnect, request dedup, and per-worker IM channel services all break
|
||||||
|
# across workers until a shared (e.g. redis) stream bridge lands, which is
|
||||||
|
# not yet implemented. Override GATEWAY_WORKERS only once that is in place.
|
||||||
|
command: sh -c "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001 --workers ${GATEWAY_WORKERS:-1}"
|
||||||
volumes:
|
volumes:
|
||||||
- ${DEER_FLOW_CONFIG_PATH}:/app/backend/config.yaml:ro
|
- ${DEER_FLOW_CONFIG_PATH}:/app/backend/config.yaml:ro
|
||||||
- ${DEER_FLOW_EXTENSIONS_CONFIG_PATH}:/app/backend/extensions_config.json:ro
|
- ${DEER_FLOW_EXTENSIONS_CONFIG_PATH}:/app/backend/extensions_config.json:ro
|
||||||
|
|||||||
@@ -7,8 +7,9 @@ import { defineConfig, devices } from "@playwright/test";
|
|||||||
* so the mock-based suite is untouched.
|
* so the mock-based suite is untouched.
|
||||||
*
|
*
|
||||||
* Two webServers are started: the replay gateway (:8011) and the frontend
|
* Two webServers are started: the replay gateway (:8011) and the frontend
|
||||||
* (:3000, pointed at the gateway). Auth uses a throwaway test account the spec
|
* (:3000, pointed at the gateway). Auth-disabled mode is enabled on both
|
||||||
* registers at runtime — no secrets.
|
* servers so the no-cookie e2e contract is covered; specs that need session
|
||||||
|
* cookies still register a throwaway test account at runtime.
|
||||||
*/
|
*/
|
||||||
export default defineConfig({
|
export default defineConfig({
|
||||||
testDir: "./tests/e2e-real-backend",
|
testDir: "./tests/e2e-real-backend",
|
||||||
@@ -38,7 +39,10 @@ export default defineConfig({
|
|||||||
// Mount the test-only run/message seeder used by multi-run-order.spec.ts
|
// Mount the test-only run/message seeder used by multi-run-order.spec.ts
|
||||||
// (#3352). The endpoint exists only on this replay gateway, never in the
|
// (#3352). The endpoint exists only on this replay gateway, never in the
|
||||||
// production app.
|
// production app.
|
||||||
env: { DEERFLOW_ENABLE_TEST_SEED: "1" },
|
env: {
|
||||||
|
DEERFLOW_ENABLE_TEST_SEED: "1",
|
||||||
|
DEER_FLOW_AUTH_DISABLED: "1",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
command: "pnpm build && pnpm start",
|
command: "pnpm build && pnpm start",
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import Link from "next/link";
|
import Link from "next/link";
|
||||||
import { useEffect, useMemo, useState } from "react";
|
import { useEffect, useMemo, useRef, useState } from "react";
|
||||||
|
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
import { Input } from "@/components/ui/input";
|
import { Input } from "@/components/ui/input";
|
||||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||||
import {
|
import {
|
||||||
@@ -11,24 +12,58 @@ import {
|
|||||||
WorkspaceHeader,
|
WorkspaceHeader,
|
||||||
} from "@/components/workspace/workspace-container";
|
} from "@/components/workspace/workspace-container";
|
||||||
import { useI18n } from "@/core/i18n/hooks";
|
import { useI18n } from "@/core/i18n/hooks";
|
||||||
import { useThreads } from "@/core/threads/hooks";
|
import { useInfiniteThreads } from "@/core/threads/hooks";
|
||||||
import { pathOfThread, titleOfThread } from "@/core/threads/utils";
|
import { pathOfThread, titleOfThread } from "@/core/threads/utils";
|
||||||
import { formatTimeAgo } from "@/core/utils/datetime";
|
import { formatTimeAgo } from "@/core/utils/datetime";
|
||||||
|
|
||||||
export default function ChatsPage() {
|
export default function ChatsPage() {
|
||||||
const { t } = useI18n();
|
const { t } = useI18n();
|
||||||
const { data: threads } = useThreads();
|
const {
|
||||||
|
data: infiniteThreads,
|
||||||
|
fetchNextPage,
|
||||||
|
hasNextPage,
|
||||||
|
isFetchingNextPage,
|
||||||
|
} = useInfiniteThreads();
|
||||||
|
const threads = useMemo(
|
||||||
|
() => infiniteThreads?.pages.flat() ?? [],
|
||||||
|
[infiniteThreads],
|
||||||
|
);
|
||||||
const [search, setSearch] = useState("");
|
const [search, setSearch] = useState("");
|
||||||
|
const isSearching = search.trim().length > 0;
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
document.title = `${t.pages.chats} - ${t.pages.appName}`;
|
document.title = `${t.pages.chats} - ${t.pages.appName}`;
|
||||||
}, [t.pages.chats, t.pages.appName]);
|
}, [t.pages.chats, t.pages.appName]);
|
||||||
|
|
||||||
const filteredThreads = useMemo(() => {
|
const filteredThreads = useMemo(() => {
|
||||||
return threads?.filter((thread) => {
|
return threads.filter((thread) => {
|
||||||
return titleOfThread(thread).toLowerCase().includes(search.toLowerCase());
|
return titleOfThread(thread).toLowerCase().includes(search.toLowerCase());
|
||||||
});
|
});
|
||||||
}, [threads, search]);
|
}, [threads, search]);
|
||||||
|
|
||||||
|
// Sentinel-based auto load-more for the unfiltered list (issue #3482).
|
||||||
|
// In search mode we deliberately do NOT auto-paginate, otherwise an empty
|
||||||
|
// filtered view would keep the sentinel in the viewport and drain the
|
||||||
|
// entire backend list one page at a time. Searching falls back to an
|
||||||
|
// explicit button so users can still reach older conversations on demand.
|
||||||
|
const sentinelRef = useRef<HTMLDivElement | null>(null);
|
||||||
|
useEffect(() => {
|
||||||
|
const element = sentinelRef.current;
|
||||||
|
if (!element || !hasNextPage || isSearching) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const observer = new IntersectionObserver(
|
||||||
|
([entry]) => {
|
||||||
|
if (entry?.isIntersecting && hasNextPage && !isFetchingNextPage) {
|
||||||
|
void fetchNextPage();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{ rootMargin: "200px 0px 200px 0px" },
|
||||||
|
);
|
||||||
|
observer.observe(element);
|
||||||
|
return () => observer.disconnect();
|
||||||
|
}, [fetchNextPage, hasNextPage, isFetchingNextPage, isSearching]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<WorkspaceContainer>
|
<WorkspaceContainer>
|
||||||
<WorkspaceHeader></WorkspaceHeader>
|
<WorkspaceHeader></WorkspaceHeader>
|
||||||
@@ -61,6 +96,28 @@ export default function ChatsPage() {
|
|||||||
</div>
|
</div>
|
||||||
</Link>
|
</Link>
|
||||||
))}
|
))}
|
||||||
|
{hasNextPage && !isSearching && (
|
||||||
|
<div
|
||||||
|
ref={sentinelRef}
|
||||||
|
aria-hidden="true"
|
||||||
|
className="h-px w-full"
|
||||||
|
data-testid="chats-page-sentinel"
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{hasNextPage && isSearching && (
|
||||||
|
<div className="flex justify-center p-4">
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
onClick={() => void fetchNextPage()}
|
||||||
|
disabled={isFetchingNextPage}
|
||||||
|
data-testid="chats-page-load-more"
|
||||||
|
>
|
||||||
|
{isFetchingNextPage
|
||||||
|
? t.chats.loadingMore
|
||||||
|
: t.chats.loadMoreToSearch}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
</ScrollArea>
|
</ScrollArea>
|
||||||
</main>
|
</main>
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import {
|
|||||||
} from "lucide-react";
|
} from "lucide-react";
|
||||||
import Link from "next/link";
|
import Link from "next/link";
|
||||||
import { useParams, usePathname, useRouter } from "next/navigation";
|
import { useParams, usePathname, useRouter } from "next/navigation";
|
||||||
import { useCallback, useState } from "react";
|
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
|
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
@@ -51,8 +51,8 @@ import {
|
|||||||
} from "@/core/threads/export";
|
} from "@/core/threads/export";
|
||||||
import {
|
import {
|
||||||
useDeleteThread,
|
useDeleteThread,
|
||||||
|
useInfiniteThreads,
|
||||||
useRenameThread,
|
useRenameThread,
|
||||||
useThreads,
|
|
||||||
} from "@/core/threads/hooks";
|
} from "@/core/threads/hooks";
|
||||||
import type { AgentThread, AgentThreadState } from "@/core/threads/types";
|
import type { AgentThread, AgentThreadState } from "@/core/threads/types";
|
||||||
import { pathOfThread, titleOfThread } from "@/core/threads/utils";
|
import { pathOfThread, titleOfThread } from "@/core/threads/utils";
|
||||||
@@ -68,7 +68,35 @@ export function RecentChatList() {
|
|||||||
thread_id: string;
|
thread_id: string;
|
||||||
agent_name?: string;
|
agent_name?: string;
|
||||||
}>();
|
}>();
|
||||||
const { data: threads = [] } = useThreads();
|
const {
|
||||||
|
data: infiniteThreads,
|
||||||
|
fetchNextPage,
|
||||||
|
hasNextPage,
|
||||||
|
isFetchingNextPage,
|
||||||
|
} = useInfiniteThreads();
|
||||||
|
const threads = useMemo(
|
||||||
|
() => infiniteThreads?.pages.flat() ?? [],
|
||||||
|
[infiniteThreads],
|
||||||
|
);
|
||||||
|
|
||||||
|
const sentinelRef = useRef<HTMLDivElement | null>(null);
|
||||||
|
useEffect(() => {
|
||||||
|
const element = sentinelRef.current;
|
||||||
|
if (!element || !hasNextPage) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const observer = new IntersectionObserver(
|
||||||
|
([entry]) => {
|
||||||
|
if (entry?.isIntersecting && hasNextPage && !isFetchingNextPage) {
|
||||||
|
void fetchNextPage();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{ rootMargin: "120px 0px 120px 0px" },
|
||||||
|
);
|
||||||
|
observer.observe(element);
|
||||||
|
return () => observer.disconnect();
|
||||||
|
}, [fetchNextPage, hasNextPage, isFetchingNextPage]);
|
||||||
|
|
||||||
const { mutate: deleteThread } = useDeleteThread();
|
const { mutate: deleteThread } = useDeleteThread();
|
||||||
const { mutate: renameThread } = useRenameThread();
|
const { mutate: renameThread } = useRenameThread();
|
||||||
|
|
||||||
@@ -267,6 +295,28 @@ export function RecentChatList() {
|
|||||||
</SidebarMenuItem>
|
</SidebarMenuItem>
|
||||||
);
|
);
|
||||||
})}
|
})}
|
||||||
|
{hasNextPage && (
|
||||||
|
<>
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
className="mx-2 my-1 w-[calc(100%-1rem)] justify-center text-xs"
|
||||||
|
onClick={() => void fetchNextPage()}
|
||||||
|
disabled={isFetchingNextPage}
|
||||||
|
data-testid="recent-chat-list-load-more"
|
||||||
|
>
|
||||||
|
{isFetchingNextPage
|
||||||
|
? t.chats.loadingMore
|
||||||
|
: t.chats.loadOlderChats}
|
||||||
|
</Button>
|
||||||
|
<div
|
||||||
|
ref={sentinelRef}
|
||||||
|
aria-hidden="true"
|
||||||
|
className="h-px w-full"
|
||||||
|
data-testid="recent-chat-list-sentinel"
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
</SidebarMenu>
|
</SidebarMenu>
|
||||||
</SidebarGroupContent>
|
</SidebarGroupContent>
|
||||||
|
|||||||
@@ -0,0 +1,23 @@
|
|||||||
|
import type { User } from "./types";
|
||||||
|
|
||||||
|
export const AUTH_DISABLED_USER: User = {
|
||||||
|
id: "e2e-user",
|
||||||
|
email: "e2e@test.local",
|
||||||
|
system_role: "admin",
|
||||||
|
needs_setup: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
const PRODUCTION_ENV_VALUES = new Set(["prod", "production"]);
|
||||||
|
|
||||||
|
function isExplicitProductionEnvironment() {
|
||||||
|
return ["DEER_FLOW_ENV", "ENVIRONMENT"].some((name) =>
|
||||||
|
PRODUCTION_ENV_VALUES.has((process.env[name] ?? "").trim().toLowerCase()),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function isAuthDisabledMode() {
|
||||||
|
return (
|
||||||
|
process.env.DEER_FLOW_AUTH_DISABLED === "1" &&
|
||||||
|
!isExplicitProductionEnvironment()
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ import { cookies } from "next/headers";
|
|||||||
|
|
||||||
import { isStaticWebsiteOnly } from "../static-mode";
|
import { isStaticWebsiteOnly } from "../static-mode";
|
||||||
|
|
||||||
|
import { AUTH_DISABLED_USER, isAuthDisabledMode } from "./auth-disabled-user";
|
||||||
import { getGatewayConfig } from "./gateway-config";
|
import { getGatewayConfig } from "./gateway-config";
|
||||||
import { STATIC_WEBSITE_USER } from "./static-user";
|
import { STATIC_WEBSITE_USER } from "./static-user";
|
||||||
import { type AuthResult, userSchema } from "./types";
|
import { type AuthResult, userSchema } from "./types";
|
||||||
@@ -20,15 +21,10 @@ export async function getServerSideUser(): Promise<AuthResult> {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
if (process.env.DEER_FLOW_AUTH_DISABLED === "1") {
|
if (isAuthDisabledMode()) {
|
||||||
return {
|
return {
|
||||||
tag: "authenticated",
|
tag: "authenticated",
|
||||||
user: {
|
user: AUTH_DISABLED_USER,
|
||||||
id: "e2e-user",
|
|
||||||
email: "e2e@test.local",
|
|
||||||
system_role: "admin",
|
|
||||||
needs_setup: false,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -252,6 +252,9 @@ export const enUS: Translations = {
|
|||||||
// Chats
|
// Chats
|
||||||
chats: {
|
chats: {
|
||||||
searchChats: "Search chats",
|
searchChats: "Search chats",
|
||||||
|
loadMoreToSearch: "Load more to search older conversations",
|
||||||
|
loadingMore: "Loading more...",
|
||||||
|
loadOlderChats: "Load older chats",
|
||||||
},
|
},
|
||||||
|
|
||||||
// Page titles (document title)
|
// Page titles (document title)
|
||||||
|
|||||||
@@ -183,6 +183,9 @@ export interface Translations {
|
|||||||
// Chats
|
// Chats
|
||||||
chats: {
|
chats: {
|
||||||
searchChats: string;
|
searchChats: string;
|
||||||
|
loadMoreToSearch: string;
|
||||||
|
loadingMore: string;
|
||||||
|
loadOlderChats: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Page titles (document title)
|
// Page titles (document title)
|
||||||
|
|||||||
@@ -240,6 +240,9 @@ export const zhCN: Translations = {
|
|||||||
// Chats
|
// Chats
|
||||||
chats: {
|
chats: {
|
||||||
searchChats: "搜索对话",
|
searchChats: "搜索对话",
|
||||||
|
loadMoreToSearch: "加载更多以搜索更早的对话",
|
||||||
|
loadingMore: "正在加载...",
|
||||||
|
loadOlderChats: "加载更早的对话",
|
||||||
},
|
},
|
||||||
|
|
||||||
// Page titles (document title)
|
// Page titles (document title)
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ import type { ThreadsClient } from "@langchain/langgraph-sdk/client";
|
|||||||
import { useStream } from "@langchain/langgraph-sdk/react";
|
import { useStream } from "@langchain/langgraph-sdk/react";
|
||||||
import {
|
import {
|
||||||
type QueryClient,
|
type QueryClient,
|
||||||
|
type InfiniteData,
|
||||||
|
useInfiniteQuery,
|
||||||
useMutation,
|
useMutation,
|
||||||
useQuery,
|
useQuery,
|
||||||
useQueryClient,
|
useQueryClient,
|
||||||
@@ -311,6 +313,56 @@ export function upsertThreadInSearchCache(
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function upsertThreadInInfiniteCache(
|
||||||
|
queryClient: QueryClient,
|
||||||
|
thread: AgentThread,
|
||||||
|
) {
|
||||||
|
queryClient.setQueriesData(
|
||||||
|
{
|
||||||
|
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||||
|
exact: false,
|
||||||
|
},
|
||||||
|
(oldData: InfiniteData<AgentThread[]> | undefined) => {
|
||||||
|
if (!oldData) {
|
||||||
|
return oldData;
|
||||||
|
}
|
||||||
|
|
||||||
|
const merged = oldData.pages.map((page) =>
|
||||||
|
page.map((t) =>
|
||||||
|
t.thread_id === thread.thread_id
|
||||||
|
? {
|
||||||
|
...thread,
|
||||||
|
...t,
|
||||||
|
metadata: {
|
||||||
|
...(thread.metadata ?? {}),
|
||||||
|
...(t.metadata ?? {}),
|
||||||
|
},
|
||||||
|
values: {
|
||||||
|
...thread.values,
|
||||||
|
...t.values,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
: t,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
const exists = merged.some((page) =>
|
||||||
|
page.some((t) => t.thread_id === thread.thread_id),
|
||||||
|
);
|
||||||
|
if (exists) {
|
||||||
|
return { ...oldData, pages: merged };
|
||||||
|
}
|
||||||
|
|
||||||
|
const firstPage = merged[0] ?? [];
|
||||||
|
const restPages = merged.slice(1);
|
||||||
|
return {
|
||||||
|
...oldData,
|
||||||
|
pages: [[thread, ...firstPage], ...restPages],
|
||||||
|
};
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
function getStreamErrorMessage(error: unknown): string {
|
function getStreamErrorMessage(error: unknown): string {
|
||||||
if (typeof error === "string" && error.trim()) {
|
if (typeof error === "string" && error.trim()) {
|
||||||
return error;
|
return error;
|
||||||
@@ -364,7 +416,7 @@ export function useThreadStream({
|
|||||||
loadMore: loadMoreHistory,
|
loadMore: loadMoreHistory,
|
||||||
loading: isHistoryLoading,
|
loading: isHistoryLoading,
|
||||||
appendMessages,
|
appendMessages,
|
||||||
} = useThreadHistory(onStreamThreadId ?? "");
|
} = useThreadHistory(onStreamThreadId ?? "", { enabled: !isMock });
|
||||||
|
|
||||||
// Keep listeners ref updated with latest callbacks
|
// Keep listeners ref updated with latest callbacks
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@@ -417,6 +469,19 @@ export function useThreadStream({
|
|||||||
},
|
},
|
||||||
interrupts: {},
|
interrupts: {},
|
||||||
});
|
});
|
||||||
|
upsertThreadInInfiniteCache(queryClient, {
|
||||||
|
thread_id: meta.thread_id,
|
||||||
|
created_at: now,
|
||||||
|
updated_at: now,
|
||||||
|
metadata: context.agent_name ? { agent_name: context.agent_name } : {},
|
||||||
|
status: "busy",
|
||||||
|
values: {
|
||||||
|
title: t.pages.newChat,
|
||||||
|
messages: [],
|
||||||
|
artifacts: [],
|
||||||
|
},
|
||||||
|
interrupts: {},
|
||||||
|
});
|
||||||
if (context.agent_name && !isMock) {
|
if (context.agent_name && !isMock) {
|
||||||
void getAPIClient()
|
void getAPIClient()
|
||||||
.threads.update(meta.thread_id, {
|
.threads.update(meta.thread_id, {
|
||||||
@@ -488,6 +553,27 @@ export function useThreadStream({
|
|||||||
});
|
});
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
const nextTitle: string = update.title;
|
||||||
|
void queryClient.setQueriesData(
|
||||||
|
{
|
||||||
|
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||||
|
exact: false,
|
||||||
|
},
|
||||||
|
(oldData: InfiniteData<AgentThread[]> | undefined) =>
|
||||||
|
mapInfiniteThreadsCache(
|
||||||
|
oldData,
|
||||||
|
(t): AgentThread =>
|
||||||
|
t.thread_id === threadIdRef.current
|
||||||
|
? {
|
||||||
|
...t,
|
||||||
|
values: {
|
||||||
|
...t.values,
|
||||||
|
title: nextTitle,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
: t,
|
||||||
|
),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -542,6 +628,9 @@ export function useThreadStream({
|
|||||||
.filter((id): id is string => Boolean(id)),
|
.filter((id): id is string => Boolean(id)),
|
||||||
);
|
);
|
||||||
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
||||||
|
void queryClient.invalidateQueries({
|
||||||
|
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||||
|
});
|
||||||
if (threadIdRef.current && !isMock) {
|
if (threadIdRef.current && !isMock) {
|
||||||
void queryClient.invalidateQueries({
|
void queryClient.invalidateQueries({
|
||||||
queryKey: threadTokenUsageQueryKey(threadIdRef.current),
|
queryKey: threadTokenUsageQueryKey(threadIdRef.current),
|
||||||
@@ -801,6 +890,9 @@ export function useThreadStream({
|
|||||||
},
|
},
|
||||||
);
|
);
|
||||||
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
||||||
|
void queryClient.invalidateQueries({
|
||||||
|
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||||
|
});
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
setOptimisticMessages([]);
|
setOptimisticMessages([]);
|
||||||
setIsUploading(false);
|
setIsUploading(false);
|
||||||
@@ -854,8 +946,15 @@ export function useThreadStream({
|
|||||||
} as const;
|
} as const;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function useThreadHistory(threadId: string) {
|
type ThreadHistoryOptions = {
|
||||||
const runs = useThreadRuns(threadId);
|
enabled?: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
export function useThreadHistory(
|
||||||
|
threadId: string,
|
||||||
|
{ enabled = true }: ThreadHistoryOptions = {},
|
||||||
|
) {
|
||||||
|
const runs = useThreadRuns(threadId, { enabled });
|
||||||
const threadIdRef = useRef(threadId);
|
const threadIdRef = useRef(threadId);
|
||||||
const runsRef = useRef(runs.data ?? []);
|
const runsRef = useRef(runs.data ?? []);
|
||||||
const indexRef = useRef(-1);
|
const indexRef = useRef(-1);
|
||||||
@@ -864,10 +963,15 @@ export function useThreadHistory(threadId: string) {
|
|||||||
const loadingRunIdRef = useRef<string | null>(null);
|
const loadingRunIdRef = useRef<string | null>(null);
|
||||||
const loadedRunIdsRef = useRef<Set<string>>(new Set());
|
const loadedRunIdsRef = useRef<Set<string>>(new Set());
|
||||||
const runBeforeSeqRef = useRef<Map<string, number>>(new Map());
|
const runBeforeSeqRef = useRef<Map<string, number>>(new Map());
|
||||||
|
const loadGenerationRef = useRef(0);
|
||||||
const [loading, setLoading] = useState(false);
|
const [loading, setLoading] = useState(false);
|
||||||
const [messages, setMessages] = useState<Message[]>([]);
|
const [messages, setMessages] = useState<Message[]>([]);
|
||||||
|
|
||||||
const loadMessages = useCallback(async () => {
|
const loadMessages = useCallback(async () => {
|
||||||
|
if (!enabled) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const loadGeneration = loadGenerationRef.current;
|
||||||
if (loadingRef.current) {
|
if (loadingRef.current) {
|
||||||
const pendingRunIndex = findLatestUnloadedRunIndex(
|
const pendingRunIndex = findLatestUnloadedRunIndex(
|
||||||
runsRef.current,
|
runsRef.current,
|
||||||
@@ -921,12 +1025,15 @@ export function useThreadHistory(threadId: string) {
|
|||||||
}).then((res) => {
|
}).then((res) => {
|
||||||
return res.json();
|
return res.json();
|
||||||
});
|
});
|
||||||
|
if (
|
||||||
|
loadGenerationRef.current !== loadGeneration ||
|
||||||
|
threadIdRef.current !== requestThreadId
|
||||||
|
) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
const _messages = result.data
|
const _messages = result.data
|
||||||
.filter((m) => !m.metadata.caller?.startsWith("middleware:"))
|
.filter((m) => !m.metadata.caller?.startsWith("middleware:"))
|
||||||
.map((m) => m.content);
|
.map((m) => m.content);
|
||||||
if (threadIdRef.current !== requestThreadId) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
setMessages((prev) =>
|
setMessages((prev) =>
|
||||||
dedupeMessagesByIdentity([..._messages, ...prev]),
|
dedupeMessagesByIdentity([..._messages, ...prev]),
|
||||||
);
|
);
|
||||||
@@ -961,16 +1068,19 @@ export function useThreadHistory(threadId: string) {
|
|||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error(err);
|
console.error(err);
|
||||||
} finally {
|
} finally {
|
||||||
loadingRef.current = false;
|
if (loadGenerationRef.current === loadGeneration) {
|
||||||
loadingRunIdRef.current = null;
|
loadingRef.current = false;
|
||||||
setLoading(false);
|
loadingRunIdRef.current = null;
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}, []);
|
}, [enabled]);
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const threadChanged = threadIdRef.current !== threadId;
|
const threadChanged = threadIdRef.current !== threadId;
|
||||||
threadIdRef.current = threadId;
|
threadIdRef.current = threadId;
|
||||||
|
|
||||||
if (threadChanged) {
|
if (!enabled || threadChanged) {
|
||||||
|
loadGenerationRef.current += 1;
|
||||||
runsRef.current = [];
|
runsRef.current = [];
|
||||||
indexRef.current = -1;
|
indexRef.current = -1;
|
||||||
pendingLoadRef.current = false;
|
pendingLoadRef.current = false;
|
||||||
@@ -982,6 +1092,10 @@ export function useThreadHistory(threadId: string) {
|
|||||||
setMessages([]);
|
setMessages([]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!enabled) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (runs.data && runs.data.length > 0) {
|
if (runs.data && runs.data.length > 0) {
|
||||||
runsRef.current = runs.data ?? [];
|
runsRef.current = runs.data ?? [];
|
||||||
indexRef.current = findLatestUnloadedRunIndex(
|
indexRef.current = findLatestUnloadedRunIndex(
|
||||||
@@ -992,14 +1106,15 @@ export function useThreadHistory(threadId: string) {
|
|||||||
loadMessages().catch(() => {
|
loadMessages().catch(() => {
|
||||||
toast.error("Failed to load thread history.");
|
toast.error("Failed to load thread history.");
|
||||||
});
|
});
|
||||||
}, [threadId, runs.data, loadMessages]);
|
}, [enabled, threadId, runs.data, loadMessages]);
|
||||||
|
|
||||||
const appendMessages = useCallback((_messages: Message[]) => {
|
const appendMessages = useCallback((_messages: Message[]) => {
|
||||||
setMessages((prev) => {
|
setMessages((prev) => {
|
||||||
return dedupeMessagesByIdentity([...prev, ..._messages]);
|
return dedupeMessagesByIdentity([...prev, ..._messages]);
|
||||||
});
|
});
|
||||||
}, []);
|
}, []);
|
||||||
const hasMore = indexRef.current >= 0 || !runs.data;
|
const hasMore =
|
||||||
|
enabled && Boolean(threadId) && (indexRef.current >= 0 || !runs.data);
|
||||||
return {
|
return {
|
||||||
runs: runs.data,
|
runs: runs.data,
|
||||||
messages,
|
messages,
|
||||||
@@ -1077,7 +1192,90 @@ export function useThreads(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
export function useThreadRuns(threadId?: string) {
|
export const INFINITE_THREADS_PAGE_SIZE = 50;
|
||||||
|
|
||||||
|
export const INFINITE_THREADS_QUERY_KEY_PREFIX = [
|
||||||
|
"threads",
|
||||||
|
"searchInfinite",
|
||||||
|
] as const;
|
||||||
|
|
||||||
|
type InfiniteThreadsParams = Omit<
|
||||||
|
Parameters<ThreadsClient["search"]>[0],
|
||||||
|
"limit" | "offset"
|
||||||
|
>;
|
||||||
|
|
||||||
|
export function getInfiniteThreadsNextPageParam(
|
||||||
|
lastPage: AgentThread[],
|
||||||
|
allPages: AgentThread[][],
|
||||||
|
pageSize: number = INFINITE_THREADS_PAGE_SIZE,
|
||||||
|
): number | undefined {
|
||||||
|
if (lastPage.length < pageSize) {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
return allPages.reduce((sum, page) => sum + page.length, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function mapInfiniteThreadsCache(
|
||||||
|
oldData: InfiniteData<AgentThread[]> | undefined,
|
||||||
|
mapper: (thread: AgentThread) => AgentThread,
|
||||||
|
): InfiniteData<AgentThread[]> | undefined {
|
||||||
|
if (!oldData) {
|
||||||
|
return oldData;
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
...oldData,
|
||||||
|
pages: oldData.pages.map((page) => page.map(mapper)),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function filterInfiniteThreadsCache(
|
||||||
|
oldData: InfiniteData<AgentThread[]> | undefined,
|
||||||
|
predicate: (thread: AgentThread) => boolean,
|
||||||
|
): InfiniteData<AgentThread[]> | undefined {
|
||||||
|
if (!oldData) {
|
||||||
|
return oldData;
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
...oldData,
|
||||||
|
pages: oldData.pages.map((page) => page.filter(predicate)),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useInfiniteThreads(
|
||||||
|
params: InfiniteThreadsParams = {
|
||||||
|
sortBy: "updated_at",
|
||||||
|
sortOrder: "desc",
|
||||||
|
select: ["thread_id", "updated_at", "values", "metadata"],
|
||||||
|
},
|
||||||
|
) {
|
||||||
|
const apiClient = getAPIClient();
|
||||||
|
return useInfiniteQuery<
|
||||||
|
AgentThread[],
|
||||||
|
Error,
|
||||||
|
InfiniteData<AgentThread[]>,
|
||||||
|
readonly unknown[],
|
||||||
|
number
|
||||||
|
>({
|
||||||
|
queryKey: [...INFINITE_THREADS_QUERY_KEY_PREFIX, params],
|
||||||
|
initialPageParam: 0,
|
||||||
|
queryFn: async ({ pageParam }) => {
|
||||||
|
const response = (await apiClient.threads.search<AgentThreadState>({
|
||||||
|
...params,
|
||||||
|
limit: INFINITE_THREADS_PAGE_SIZE,
|
||||||
|
offset: pageParam,
|
||||||
|
})) as AgentThread[];
|
||||||
|
return response;
|
||||||
|
},
|
||||||
|
getNextPageParam: (lastPage, allPages) =>
|
||||||
|
getInfiniteThreadsNextPageParam(lastPage, allPages),
|
||||||
|
refetchOnWindowFocus: false,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useThreadRuns(
|
||||||
|
threadId?: string,
|
||||||
|
{ enabled = true }: { enabled?: boolean } = {},
|
||||||
|
) {
|
||||||
const apiClient = getAPIClient();
|
const apiClient = getAPIClient();
|
||||||
return useQuery<Run[]>({
|
return useQuery<Run[]>({
|
||||||
queryKey: ["thread", threadId],
|
queryKey: ["thread", threadId],
|
||||||
@@ -1088,6 +1286,7 @@ export function useThreadRuns(threadId?: string) {
|
|||||||
const response = await apiClient.runs.list(threadId);
|
const response = await apiClient.runs.list(threadId);
|
||||||
return response;
|
return response;
|
||||||
},
|
},
|
||||||
|
enabled: enabled && Boolean(threadId),
|
||||||
refetchOnWindowFocus: false,
|
refetchOnWindowFocus: false,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -1156,9 +1355,21 @@ export function useDeleteThread() {
|
|||||||
return oldData.filter((t) => t.thread_id !== threadId);
|
return oldData.filter((t) => t.thread_id !== threadId);
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
queryClient.setQueriesData(
|
||||||
|
{
|
||||||
|
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||||
|
exact: false,
|
||||||
|
},
|
||||||
|
(oldData: InfiniteData<AgentThread[]> | undefined) =>
|
||||||
|
filterInfiniteThreadsCache(oldData, (t) => t.thread_id !== threadId),
|
||||||
|
);
|
||||||
},
|
},
|
||||||
|
|
||||||
onSettled() {
|
onSettled() {
|
||||||
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
||||||
|
void queryClient.invalidateQueries({
|
||||||
|
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||||
|
});
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -1199,6 +1410,24 @@ export function useRenameThread() {
|
|||||||
});
|
});
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
queryClient.setQueriesData(
|
||||||
|
{
|
||||||
|
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||||
|
exact: false,
|
||||||
|
},
|
||||||
|
(oldData: InfiniteData<AgentThread[]> | undefined) =>
|
||||||
|
mapInfiniteThreadsCache(oldData, (t) =>
|
||||||
|
t.thread_id === threadId
|
||||||
|
? {
|
||||||
|
...t,
|
||||||
|
values: {
|
||||||
|
...t.values,
|
||||||
|
title,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
: t,
|
||||||
|
),
|
||||||
|
);
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,16 @@
|
|||||||
|
import { expect, test } from "@playwright/test";
|
||||||
|
|
||||||
|
import { AUTH_DISABLED_USER } from "../../src/core/auth/auth-disabled-user";
|
||||||
|
|
||||||
|
const APP = "http://localhost:3000";
|
||||||
|
|
||||||
|
test.describe("auth-disabled contract (real backend)", () => {
|
||||||
|
test("gateway /auth/me returns the frontend synthetic user without a cookie", async ({
|
||||||
|
context,
|
||||||
|
}) => {
|
||||||
|
const resp = await context.request.get(`${APP}/api/v1/auth/me`);
|
||||||
|
|
||||||
|
expect(resp.status(), await resp.text()).toBe(200);
|
||||||
|
await expect(resp.json()).resolves.toEqual(AUTH_DISABLED_USER);
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -101,10 +101,11 @@ test.describe("real backend render (replay, no API key)", () => {
|
|||||||
EXPECTED_SUGGESTION,
|
EXPECTED_SUGGESTION,
|
||||||
"fixture should contain a suggestions turn (re-record; the record spec waits for /suggestions)",
|
"fixture should contain a suggestions turn (re-record; the record spec waits for /suggestions)",
|
||||||
).not.toBe("");
|
).not.toBe("");
|
||||||
await expect(page.getByText(EXPECTED_TITLE)).toBeVisible({
|
const chat = page.locator("#chat");
|
||||||
|
await expect(chat.getByText(EXPECTED_TITLE)).toBeVisible({
|
||||||
timeout: 60_000,
|
timeout: 60_000,
|
||||||
});
|
});
|
||||||
await expect(page.getByText(EXPECTED_SUGGESTION)).toBeVisible({
|
await expect(chat.getByText(EXPECTED_SUGGESTION)).toBeVisible({
|
||||||
timeout: 30_000,
|
timeout: 30_000,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ test.describe("Chat workspace", () => {
|
|||||||
|
|
||||||
const textarea = page.getByPlaceholder(/how can i assist you/i);
|
const textarea = page.getByPlaceholder(/how can i assist you/i);
|
||||||
await expect(textarea).toBeVisible({ timeout: 15_000 });
|
await expect(textarea).toBeVisible({ timeout: 15_000 });
|
||||||
|
await expect(page.getByRole("button", { name: /load more/i })).toBeHidden();
|
||||||
});
|
});
|
||||||
|
|
||||||
test("can type a message in the input box", async ({ page }) => {
|
test("can type a message in the input box", async ({ page }) => {
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ const THREADS = [
|
|||||||
updated_at: "2025-06-02T12:00:00Z",
|
updated_at: "2025-06-02T12:00:00Z",
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
const DEMO_THREAD_ID = "7cfa5f8f-a2f8-47ad-acbd-da7137baf990";
|
||||||
|
|
||||||
test.describe("Thread history", () => {
|
test.describe("Thread history", () => {
|
||||||
test("sidebar shows existing threads", async ({ page }) => {
|
test("sidebar shows existing threads", async ({ page }) => {
|
||||||
@@ -61,6 +62,84 @@ test.describe("Thread history", () => {
|
|||||||
).toBeVisible({ timeout: 15_000 });
|
).toBeVisible({ timeout: 15_000 });
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test("mock thread does not load real backend run history", async ({
|
||||||
|
page,
|
||||||
|
}) => {
|
||||||
|
mockLangGraphAPI(page, {
|
||||||
|
threads: [
|
||||||
|
{
|
||||||
|
thread_id: DEMO_THREAD_ID,
|
||||||
|
title: "Forecasting 2026 Trends and Opportunities",
|
||||||
|
updated_at: "2025-06-01T12:00:00Z",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
type: "human",
|
||||||
|
id: `run-human-${DEMO_THREAD_ID}`,
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: "text",
|
||||||
|
text: "This run-message endpoint should not be called.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
const backendRunHistoryUrls: string[] = [];
|
||||||
|
await page.route(
|
||||||
|
/\/api\/langgraph\/threads\/[^/]+\/runs(?:\?|$)/,
|
||||||
|
(route) => {
|
||||||
|
if (
|
||||||
|
route.request().method() === "GET" &&
|
||||||
|
route
|
||||||
|
.request()
|
||||||
|
.url()
|
||||||
|
.includes(`/api/langgraph/threads/${DEMO_THREAD_ID}/runs`)
|
||||||
|
) {
|
||||||
|
backendRunHistoryUrls.push(route.request().url());
|
||||||
|
return route.fulfill({
|
||||||
|
status: 500,
|
||||||
|
contentType: "application/json",
|
||||||
|
body: JSON.stringify({
|
||||||
|
error: "mock=true must not load real runs",
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return route.fallback();
|
||||||
|
},
|
||||||
|
);
|
||||||
|
await page.route(
|
||||||
|
/\/api\/threads\/[^/]+\/runs\/[^/]+\/messages(?:\?|$)/,
|
||||||
|
(route) => {
|
||||||
|
if (
|
||||||
|
route.request().method() === "GET" &&
|
||||||
|
route.request().url().includes(`/api/threads/${DEMO_THREAD_ID}/runs/`)
|
||||||
|
) {
|
||||||
|
backendRunHistoryUrls.push(route.request().url());
|
||||||
|
return route.fulfill({
|
||||||
|
status: 500,
|
||||||
|
contentType: "application/json",
|
||||||
|
body: JSON.stringify({
|
||||||
|
error: "mock=true must not load real run messages",
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return route.fallback();
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
await page.goto(`/workspace/chats/${DEMO_THREAD_ID}?mock=true`);
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
page.getByText("What might be the trends and opportunities in 2026?"),
|
||||||
|
).toBeVisible({ timeout: 15_000 });
|
||||||
|
await expect(
|
||||||
|
page.getByText("I've created a modern, minimalist website"),
|
||||||
|
).toBeVisible();
|
||||||
|
expect(backendRunHistoryUrls).toEqual([]);
|
||||||
|
});
|
||||||
|
|
||||||
test("chats list page shows all threads", async ({ page }) => {
|
test("chats list page shows all threads", async ({ page }) => {
|
||||||
mockLangGraphAPI(page, { threads: THREADS });
|
mockLangGraphAPI(page, { threads: THREADS });
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,123 @@
|
|||||||
|
import { expect, test } from "@playwright/test";
|
||||||
|
|
||||||
|
import { mockLangGraphAPI } from "./utils/mock-api";
|
||||||
|
|
||||||
|
// Issue #3482: the sidebar's "Recent chats" and the /workspace/chats list
|
||||||
|
// page used to stop at the first 50 threads with no way to load more.
|
||||||
|
// `useInfiniteThreads()` + a sentinel near the bottom of each list now
|
||||||
|
// pages through the backend.
|
||||||
|
|
||||||
|
const TOTAL_THREADS = 120;
|
||||||
|
const PAGE_SIZE = 50;
|
||||||
|
|
||||||
|
const THREADS = Array.from({ length: TOTAL_THREADS }, (_, i) => {
|
||||||
|
// Pad index so titles sort deterministically as strings. The thread-search
|
||||||
|
// mock returns threads in the order provided, so paging boundaries are
|
||||||
|
// stable across runs.
|
||||||
|
const index = String(i + 1).padStart(3, "0");
|
||||||
|
return {
|
||||||
|
thread_id: `00000000-0000-0000-0000-0000000${index.padStart(5, "0")}`,
|
||||||
|
title: `Conversation ${index}`,
|
||||||
|
updated_at: `2025-06-${String((i % 28) + 1).padStart(2, "0")}T12:00:00Z`,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
const FIRST_PAGE_LAST = `Conversation ${String(PAGE_SIZE).padStart(3, "0")}`;
|
||||||
|
const SECOND_PAGE_FIRST = `Conversation ${String(PAGE_SIZE + 1).padStart(3, "0")}`;
|
||||||
|
|
||||||
|
test.describe("Thread list infinite scroll (issue #3482)", () => {
|
||||||
|
test("chats list page loads more threads when scrolling to the bottom", async ({
|
||||||
|
page,
|
||||||
|
}) => {
|
||||||
|
mockLangGraphAPI(page, { threads: THREADS });
|
||||||
|
|
||||||
|
await page.goto("/workspace/chats");
|
||||||
|
|
||||||
|
const main = page.locator("main");
|
||||||
|
|
||||||
|
// First page renders.
|
||||||
|
await expect(main.getByText(FIRST_PAGE_LAST)).toBeVisible({
|
||||||
|
timeout: 15_000,
|
||||||
|
});
|
||||||
|
// Items past the first page have not been fetched yet.
|
||||||
|
await expect(main.getByText(SECOND_PAGE_FIRST)).toHaveCount(0);
|
||||||
|
|
||||||
|
// Scrolling the sentinel into view triggers the next page.
|
||||||
|
const sentinel = page.getByTestId("chats-page-sentinel");
|
||||||
|
await sentinel.scrollIntoViewIfNeeded();
|
||||||
|
|
||||||
|
await expect(main.getByText(SECOND_PAGE_FIRST)).toBeVisible({
|
||||||
|
timeout: 15_000,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
test("sidebar recent chats loads more threads when scrolling to the bottom", async ({
|
||||||
|
page,
|
||||||
|
}) => {
|
||||||
|
mockLangGraphAPI(page, { threads: THREADS });
|
||||||
|
|
||||||
|
await page.goto("/workspace/chats/new");
|
||||||
|
|
||||||
|
// The 50th thread (end of first page) appears in the sidebar.
|
||||||
|
await expect(page.getByText(FIRST_PAGE_LAST).first()).toBeVisible({
|
||||||
|
timeout: 15_000,
|
||||||
|
});
|
||||||
|
// The 51st has not been fetched yet.
|
||||||
|
await expect(page.getByText(SECOND_PAGE_FIRST)).toHaveCount(0);
|
||||||
|
|
||||||
|
// Scroll the sidebar sentinel into view to trigger the next page.
|
||||||
|
const sentinel = page.getByTestId("recent-chat-list-sentinel");
|
||||||
|
await sentinel.scrollIntoViewIfNeeded();
|
||||||
|
|
||||||
|
await expect(page.getByText(SECOND_PAGE_FIRST).first()).toBeVisible({
|
||||||
|
timeout: 15_000,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
test("chats list page does NOT auto-paginate while a search filter is active", async ({
|
||||||
|
page,
|
||||||
|
}) => {
|
||||||
|
// Count search requests via a passive request observer. Using
|
||||||
|
// page.route() here would race with mockLangGraphAPI's fulfill route
|
||||||
|
// (Playwright matches routes in reverse registration order), so the
|
||||||
|
// counter could miss real requests. page.on('request') is a pure
|
||||||
|
// observer and never interferes with routing.
|
||||||
|
let searchRequestCount = 0;
|
||||||
|
page.on("request", (request) => {
|
||||||
|
if (request.url().includes("/api/langgraph/threads/search")) {
|
||||||
|
searchRequestCount += 1;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
mockLangGraphAPI(page, { threads: THREADS });
|
||||||
|
|
||||||
|
await page.goto("/workspace/chats");
|
||||||
|
|
||||||
|
// Wait for the first page to render so we have a baseline count.
|
||||||
|
await expect(page.locator("main").getByText(FIRST_PAGE_LAST)).toBeVisible({
|
||||||
|
timeout: 15_000,
|
||||||
|
});
|
||||||
|
const baselineRequests = searchRequestCount;
|
||||||
|
|
||||||
|
// Type a query that matches nothing in the first page (and nothing at
|
||||||
|
// all, since titles are deterministic).
|
||||||
|
await page
|
||||||
|
.getByPlaceholder("Search chats")
|
||||||
|
.fill("zzz-no-such-conversation");
|
||||||
|
|
||||||
|
// The auto-sentinel must be gone; an explicit button takes its place.
|
||||||
|
await expect(page.getByTestId("chats-page-sentinel")).toHaveCount(0);
|
||||||
|
await expect(page.getByTestId("chats-page-load-more")).toBeVisible();
|
||||||
|
|
||||||
|
// Give the IntersectionObserver a couple of frames to misbehave if the
|
||||||
|
// guard regresses. No additional /threads/search calls should fire.
|
||||||
|
await page.waitForTimeout(500);
|
||||||
|
expect(searchRequestCount).toBe(baselineRequests);
|
||||||
|
|
||||||
|
// The explicit button still works as an escape hatch.
|
||||||
|
await page.getByTestId("chats-page-load-more").click();
|
||||||
|
await expect
|
||||||
|
.poll(() => searchRequestCount, { timeout: 10_000 })
|
||||||
|
.toBeGreaterThan(baselineRequests);
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -85,7 +85,7 @@ export function mockLangGraphAPI(page: Page, options?: MockAPIOptions) {
|
|||||||
const skills = options?.skills ?? DEFAULT_SKILLS;
|
const skills = options?.skills ?? DEFAULT_SKILLS;
|
||||||
|
|
||||||
// Thread search — sidebar thread list & chats list page
|
// Thread search — sidebar thread list & chats list page
|
||||||
void page.route("**/api/langgraph/threads/search", (route) => {
|
void page.route("**/api/langgraph/threads/search", async (route) => {
|
||||||
const body = threads.map((t) => ({
|
const body = threads.map((t) => ({
|
||||||
thread_id: t.thread_id,
|
thread_id: t.thread_id,
|
||||||
created_at: "2025-01-01T00:00:00Z",
|
created_at: "2025-01-01T00:00:00Z",
|
||||||
@@ -94,10 +94,33 @@ export function mockLangGraphAPI(page: Page, options?: MockAPIOptions) {
|
|||||||
status: "idle",
|
status: "idle",
|
||||||
values: { title: t.title ?? "Untitled" },
|
values: { title: t.title ?? "Untitled" },
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
let limit: number | undefined;
|
||||||
|
let offset = 0;
|
||||||
|
try {
|
||||||
|
const postData = route.request().postDataJSON() as {
|
||||||
|
limit?: number;
|
||||||
|
offset?: number;
|
||||||
|
} | null;
|
||||||
|
if (postData) {
|
||||||
|
if (typeof postData.limit === "number") {
|
||||||
|
limit = postData.limit;
|
||||||
|
}
|
||||||
|
if (typeof postData.offset === "number") {
|
||||||
|
offset = postData.offset;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// No / invalid JSON body — fall back to returning the full list.
|
||||||
|
}
|
||||||
|
|
||||||
|
const sliced =
|
||||||
|
typeof limit === "number" ? body.slice(offset, offset + limit) : body;
|
||||||
|
|
||||||
return route.fulfill({
|
return route.fulfill({
|
||||||
status: 200,
|
status: 200,
|
||||||
contentType: "application/json",
|
contentType: "application/json",
|
||||||
body: JSON.stringify(body),
|
body: JSON.stringify(sliced),
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import { afterEach, beforeEach, describe, expect, test, vi } from "vitest";
|
import { afterEach, beforeEach, describe, expect, test, vi } from "vitest";
|
||||||
|
|
||||||
|
import { AUTH_DISABLED_USER } from "@/core/auth/auth-disabled-user";
|
||||||
import { STATIC_WEBSITE_USER } from "@/core/auth/static-user";
|
import { STATIC_WEBSITE_USER } from "@/core/auth/static-user";
|
||||||
|
|
||||||
vi.mock("next/headers", () => ({
|
vi.mock("next/headers", () => ({
|
||||||
@@ -10,6 +11,8 @@ vi.mock("next/headers", () => ({
|
|||||||
|
|
||||||
const ENV_KEYS = [
|
const ENV_KEYS = [
|
||||||
"DEER_FLOW_AUTH_DISABLED",
|
"DEER_FLOW_AUTH_DISABLED",
|
||||||
|
"DEER_FLOW_ENV",
|
||||||
|
"ENVIRONMENT",
|
||||||
"NEXT_PUBLIC_STATIC_WEBSITE_ONLY",
|
"NEXT_PUBLIC_STATIC_WEBSITE_ONLY",
|
||||||
] as const;
|
] as const;
|
||||||
|
|
||||||
@@ -51,6 +54,8 @@ describe("getServerSideUser", () => {
|
|||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
saved = snapshotEnv();
|
saved = snapshotEnv();
|
||||||
setEnv("DEER_FLOW_AUTH_DISABLED", undefined);
|
setEnv("DEER_FLOW_AUTH_DISABLED", undefined);
|
||||||
|
setEnv("DEER_FLOW_ENV", undefined);
|
||||||
|
setEnv("ENVIRONMENT", undefined);
|
||||||
setEnv("NEXT_PUBLIC_STATIC_WEBSITE_ONLY", undefined);
|
setEnv("NEXT_PUBLIC_STATIC_WEBSITE_ONLY", undefined);
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -74,4 +79,30 @@ describe("getServerSideUser", () => {
|
|||||||
});
|
});
|
||||||
expect(fetchSpy).not.toHaveBeenCalled();
|
expect(fetchSpy).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test("bypasses gateway auth in auth-disabled mode", async () => {
|
||||||
|
setEnv("DEER_FLOW_AUTH_DISABLED", "1");
|
||||||
|
const fetchSpy = vi.fn(() => {
|
||||||
|
throw new Error("fetch should not be called in auth-disabled mode");
|
||||||
|
});
|
||||||
|
vi.stubGlobal("fetch", fetchSpy);
|
||||||
|
|
||||||
|
const { getServerSideUser } = await loadFreshServerAuth();
|
||||||
|
|
||||||
|
await expect(getServerSideUser()).resolves.toEqual({
|
||||||
|
tag: "authenticated",
|
||||||
|
user: AUTH_DISABLED_USER,
|
||||||
|
});
|
||||||
|
expect(fetchSpy).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("does not enable auth-disabled mode in explicit production environments", async () => {
|
||||||
|
setEnv("DEER_FLOW_AUTH_DISABLED", "1");
|
||||||
|
setEnv("DEER_FLOW_ENV", "production");
|
||||||
|
|
||||||
|
const { isAuthDisabledMode } =
|
||||||
|
await import("@/core/auth/auth-disabled-user");
|
||||||
|
|
||||||
|
expect(isAuthDisabledMode()).toBe(false);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -0,0 +1,228 @@
|
|||||||
|
import { QueryClient, type InfiniteData } from "@tanstack/react-query";
|
||||||
|
import { describe, expect, test } from "vitest";
|
||||||
|
|
||||||
|
import {
|
||||||
|
filterInfiniteThreadsCache,
|
||||||
|
getInfiniteThreadsNextPageParam,
|
||||||
|
INFINITE_THREADS_PAGE_SIZE,
|
||||||
|
INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||||
|
mapInfiniteThreadsCache,
|
||||||
|
upsertThreadInInfiniteCache,
|
||||||
|
} from "@/core/threads/hooks";
|
||||||
|
import type { AgentThread } from "@/core/threads/types";
|
||||||
|
|
||||||
|
// Issue #3482: the sidebar and /workspace/chats list used to be capped at
|
||||||
|
// 50 threads because `useThreads()` exits as soon as `threads.length >=
|
||||||
|
// params.limit`. These pure helpers back the `useInfiniteThreads()`
|
||||||
|
// pagination logic and the mirrored cache writes that keep rename / delete
|
||||||
|
// / stream-finish in sync with both the legacy array cache and the new
|
||||||
|
// infinite cache.
|
||||||
|
|
||||||
|
function makeThread(id: string, title = `Title ${id}`): AgentThread {
|
||||||
|
return {
|
||||||
|
thread_id: id,
|
||||||
|
created_at: "2025-01-01T00:00:00Z",
|
||||||
|
updated_at: "2025-01-01T00:00:00Z",
|
||||||
|
metadata: {},
|
||||||
|
status: "idle",
|
||||||
|
values: { title },
|
||||||
|
} as unknown as AgentThread;
|
||||||
|
}
|
||||||
|
|
||||||
|
function makePage(start: number, size: number): AgentThread[] {
|
||||||
|
return Array.from({ length: size }, (_, i) => makeThread(`t-${start + i}`));
|
||||||
|
}
|
||||||
|
|
||||||
|
function makeInfiniteData(pages: AgentThread[][]): InfiniteData<AgentThread[]> {
|
||||||
|
return {
|
||||||
|
pages,
|
||||||
|
pageParams: pages.map((_, i) => i * INFINITE_THREADS_PAGE_SIZE),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
describe("getInfiniteThreadsNextPageParam", () => {
|
||||||
|
test("returns next offset when the last page is full", () => {
|
||||||
|
const page1 = makePage(0, INFINITE_THREADS_PAGE_SIZE);
|
||||||
|
expect(getInfiniteThreadsNextPageParam(page1, [page1])).toBe(
|
||||||
|
INFINITE_THREADS_PAGE_SIZE,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("returns next offset across multiple full pages", () => {
|
||||||
|
const page1 = makePage(0, INFINITE_THREADS_PAGE_SIZE);
|
||||||
|
const page2 = makePage(
|
||||||
|
INFINITE_THREADS_PAGE_SIZE,
|
||||||
|
INFINITE_THREADS_PAGE_SIZE,
|
||||||
|
);
|
||||||
|
expect(getInfiniteThreadsNextPageParam(page2, [page1, page2])).toBe(
|
||||||
|
INFINITE_THREADS_PAGE_SIZE * 2,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("returns undefined when the last page is short (end of list)", () => {
|
||||||
|
const page1 = makePage(0, INFINITE_THREADS_PAGE_SIZE);
|
||||||
|
const page2 = makePage(INFINITE_THREADS_PAGE_SIZE, 10);
|
||||||
|
expect(
|
||||||
|
getInfiniteThreadsNextPageParam(page2, [page1, page2]),
|
||||||
|
).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("returns undefined when the last page is empty", () => {
|
||||||
|
const page1 = makePage(0, INFINITE_THREADS_PAGE_SIZE);
|
||||||
|
expect(getInfiniteThreadsNextPageParam([], [page1, []])).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("respects a custom page size", () => {
|
||||||
|
const page1 = makePage(0, 5);
|
||||||
|
expect(getInfiniteThreadsNextPageParam(page1, [page1], 5)).toBe(5);
|
||||||
|
expect(getInfiniteThreadsNextPageParam(page1, [page1], 10)).toBeUndefined();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("mapInfiniteThreadsCache", () => {
|
||||||
|
test("returns undefined when oldData is undefined", () => {
|
||||||
|
expect(mapInfiniteThreadsCache(undefined, (t) => t)).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("updates the matching thread across multiple pages", () => {
|
||||||
|
const page1 = [makeThread("a"), makeThread("b")];
|
||||||
|
const page2 = [makeThread("c"), makeThread("d")];
|
||||||
|
const data = makeInfiniteData([page1, page2]);
|
||||||
|
|
||||||
|
const updated = mapInfiniteThreadsCache(data, (t) =>
|
||||||
|
t.thread_id === "c"
|
||||||
|
? { ...t, values: { ...t.values, title: "renamed" } }
|
||||||
|
: t,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(updated?.pages[0]?.[0]?.values?.title).toBe("Title a");
|
||||||
|
expect(updated?.pages[1]?.[0]?.thread_id).toBe("c");
|
||||||
|
expect(updated?.pages[1]?.[0]?.values?.title).toBe("renamed");
|
||||||
|
expect(updated?.pages[1]?.[1]?.values?.title).toBe("Title d");
|
||||||
|
});
|
||||||
|
|
||||||
|
test("preserves pageParams", () => {
|
||||||
|
const data = makeInfiniteData([[makeThread("a")]]);
|
||||||
|
const updated = mapInfiniteThreadsCache(data, (t) => t);
|
||||||
|
expect(updated?.pageParams).toEqual(data.pageParams);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("filterInfiniteThreadsCache", () => {
|
||||||
|
test("returns undefined when oldData is undefined", () => {
|
||||||
|
expect(filterInfiniteThreadsCache(undefined, () => true)).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("removes matching threads across all pages", () => {
|
||||||
|
const page1 = [makeThread("a"), makeThread("b")];
|
||||||
|
const page2 = [makeThread("b"), makeThread("c")];
|
||||||
|
const data = makeInfiniteData([page1, page2]);
|
||||||
|
|
||||||
|
const filtered = filterInfiniteThreadsCache(
|
||||||
|
data,
|
||||||
|
(t) => t.thread_id !== "b",
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(filtered?.pages[0]?.map((t) => t.thread_id)).toEqual(["a"]);
|
||||||
|
expect(filtered?.pages[1]?.map((t) => t.thread_id)).toEqual(["c"]);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("keeps an emptied page as an empty array (does not drop the page)", () => {
|
||||||
|
const page1 = [makeThread("a")];
|
||||||
|
const page2 = [makeThread("b")];
|
||||||
|
const data = makeInfiniteData([page1, page2]);
|
||||||
|
|
||||||
|
const filtered = filterInfiniteThreadsCache(
|
||||||
|
data,
|
||||||
|
(t) => t.thread_id !== "a",
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(filtered?.pages).toHaveLength(2);
|
||||||
|
expect(filtered?.pages[0]).toEqual([]);
|
||||||
|
expect(filtered?.pages[1]?.[0]?.thread_id).toBe("b");
|
||||||
|
});
|
||||||
|
|
||||||
|
test("does not regress next offset when an earlier page has been shrunk by a delete", () => {
|
||||||
|
// Simulate two full pages already loaded.
|
||||||
|
const page1 = Array.from({ length: 50 }, (_, i) => ({
|
||||||
|
thread_id: `a${i}`,
|
||||||
|
}));
|
||||||
|
const page2 = Array.from({ length: 50 }, (_, i) => ({
|
||||||
|
thread_id: `b${i}`,
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Offset right after fetching page 2 (this is the value TanStack Query
|
||||||
|
// freezes into pageParams).
|
||||||
|
const offsetAfterPage2 = getInfiniteThreadsNextPageParam(
|
||||||
|
page2 as unknown as AgentThread[],
|
||||||
|
[page1, page2] as unknown as AgentThread[][],
|
||||||
|
);
|
||||||
|
expect(offsetAfterPage2).toBe(100);
|
||||||
|
|
||||||
|
// Now a delete mutation runs filterInfiniteThreadsCache and shrinks
|
||||||
|
// page 1 from 50 to 49 entries. TanStack does NOT re-invoke
|
||||||
|
// getNextPageParam on cache mutations; the previously-computed offset
|
||||||
|
// (100) remains the param for the next fetchNextPage() call, so the
|
||||||
|
// helper is consistent with how the library uses its return value.
|
||||||
|
const shrunkPage1 = page1.slice(0, 49);
|
||||||
|
const recomputed = getInfiniteThreadsNextPageParam(
|
||||||
|
page2 as unknown as AgentThread[],
|
||||||
|
[shrunkPage1, page2] as unknown as AgentThread[][],
|
||||||
|
);
|
||||||
|
// We document the recomputed value for completeness, but in practice
|
||||||
|
// useDeleteThread invalidates the query in onSettled, so pages are
|
||||||
|
// refetched from offset 0 rather than relying on this number.
|
||||||
|
expect(recomputed).toBe(99);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("upsertThreadInInfiniteCache", () => {
|
||||||
|
function seedClient(initial?: InfiniteData<AgentThread[]>): QueryClient {
|
||||||
|
const client = new QueryClient();
|
||||||
|
if (initial) {
|
||||||
|
client.setQueryData([...INFINITE_THREADS_QUERY_KEY_PREFIX, {}], initial);
|
||||||
|
}
|
||||||
|
return client;
|
||||||
|
}
|
||||||
|
|
||||||
|
function readCache(
|
||||||
|
client: QueryClient,
|
||||||
|
): InfiniteData<AgentThread[]> | undefined {
|
||||||
|
return client.getQueryData([...INFINITE_THREADS_QUERY_KEY_PREFIX, {}]);
|
||||||
|
}
|
||||||
|
|
||||||
|
test("no-op when the infinite cache has not been initialised yet", () => {
|
||||||
|
const client = seedClient();
|
||||||
|
upsertThreadInInfiniteCache(client, makeThread("new"));
|
||||||
|
expect(readCache(client)).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("prepends a brand-new thread to the first page", () => {
|
||||||
|
const client = seedClient({
|
||||||
|
pages: [[makeThread("a"), makeThread("b")]],
|
||||||
|
pageParams: [0],
|
||||||
|
});
|
||||||
|
upsertThreadInInfiniteCache(client, makeThread("new"));
|
||||||
|
const cache = readCache(client);
|
||||||
|
expect(cache?.pages[0]?.map((t) => t.thread_id)).toEqual(["new", "a", "b"]);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("merges into the existing entry instead of duplicating it", () => {
|
||||||
|
const existing = makeThread("a", "Old title");
|
||||||
|
const client = seedClient({
|
||||||
|
pages: [[existing, makeThread("b")]],
|
||||||
|
pageParams: [0],
|
||||||
|
});
|
||||||
|
// Simulate an onCreated upsert that races with a thread already in cache:
|
||||||
|
// the cache copy should win for title/metadata (it represents later state),
|
||||||
|
// but no duplicate row should appear.
|
||||||
|
upsertThreadInInfiniteCache(client, {
|
||||||
|
...makeThread("a", "New title"),
|
||||||
|
status: "busy",
|
||||||
|
});
|
||||||
|
const cache = readCache(client);
|
||||||
|
const ids = cache?.pages[0]?.map((t) => t.thread_id);
|
||||||
|
expect(ids).toEqual(["a", "b"]);
|
||||||
|
expect(cache?.pages[0]?.[0]?.values.title).toBe("Old title");
|
||||||
|
});
|
||||||
|
});
|
||||||
Reference in New Issue
Block a user