diff --git a/.github/workflows/lint-check.yml b/.github/workflows/lint-check.yml index e6ffa7f07..d46b53cb8 100644 --- a/.github/workflows/lint-check.yml +++ b/.github/workflows/lint-check.yml @@ -10,7 +10,7 @@ permissions: contents: read jobs: - lint: + lint-backend: runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index 938d97b90..01da37b05 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -441,6 +441,12 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, Discord, DingTalk 4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append 5. Next interaction injects top 15 facts + context into `` tags in system prompt +**Token counting** (`packages/harness/deerflow/agents/memory/prompt.py`): +- `_count_tokens` budgets the injection. In default `tiktoken` mode, the encoding is loaded lazily and cached. +- Failed tiktoken loads are cached with a timestamp. During the fixed cooldown (`_TIKTOKEN_RETRY_COOLDOWN_S`, 600s), callers fall back to char estimation immediately instead of re-triggering the blocking BPE download; after the cooldown, transient outages can self-heal without a restart. +- In-flight loads are cached as a LOADING sentinel so concurrent callers fall back instead of spawning more blocking threads. +- Set `memory.token_counting: char` to skip tiktoken entirely and use the network-free CJK-aware char estimate. + Focused regression coverage for the updater lives in `backend/tests/test_memory_updater.py`. **Configuration** (`config.yaml` → `memory`): @@ -450,6 +456,7 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_ - `model_name` - LLM for updates (null = default model) - `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7) - `max_injection_tokens` - Token limit for prompt injection (2000) +- `token_counting` - Token counting strategy for the injection budget: `tiktoken` (default, accurate but may download BPE data from a public endpoint on first use — can block for a long time in network-restricted environments, see issues #3402/#3429) or `char` (network-free CJK-aware char estimate, never touches tiktoken) ### Reflection System (`packages/harness/deerflow/reflection/`) diff --git a/backend/app/gateway/app.py b/backend/app/gateway/app.py index 56ea57e50..7e080a587 100644 --- a/backend/app/gateway/app.py +++ b/backend/app/gateway/app.py @@ -185,21 +185,27 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # Pre-warm tiktoken encoding cache so the first memory-injection request # never blocks on the BPE data download (which hits an OpenAI/Azure URL # that may be unreachable in restricted networks — see issue #3402). - try: - from deerflow.agents.memory.prompt import warm_tiktoken_cache + # When memory.token_counting is "char", token counting never touches + # tiktoken, so skip the warm-up entirely (avoids even the 5s probe in + # network-restricted deployments — see issue #3429). + if startup_config.memory.token_counting == "char": + logger.info("memory.token_counting='char'; skipping tiktoken warm-up (network-free token estimation)") + else: + try: + from deerflow.agents.memory.prompt import warm_tiktoken_cache - warmed = await asyncio.wait_for( - asyncio.to_thread(warm_tiktoken_cache), - timeout=5, - ) - if warmed: - logger.info("tiktoken encoding cache warmed successfully") - else: - logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback") - except TimeoutError: - logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback") - except Exception: - logger.warning("tiktoken warm-up skipped", exc_info=True) + warmed = await asyncio.wait_for( + asyncio.to_thread(warm_tiktoken_cache), + timeout=5, + ) + if warmed: + logger.info("tiktoken encoding cache warmed successfully") + else: + logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback until tiktoken loads successfully") + except TimeoutError: + logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback until tiktoken loads successfully") + except Exception: + logger.warning("tiktoken warm-up skipped", exc_info=True) # Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store) async with langgraph_runtime(app, startup_config): diff --git a/backend/app/gateway/routers/memory.py b/backend/app/gateway/routers/memory.py index ca9e5f5e5..fd413a715 100644 --- a/backend/app/gateway/routers/memory.py +++ b/backend/app/gateway/routers/memory.py @@ -98,6 +98,7 @@ class MemoryConfigResponse(BaseModel): fact_confidence_threshold: float = Field(..., description="Minimum confidence threshold for facts") injection_enabled: bool = Field(..., description="Whether memory injection is enabled") max_injection_tokens: int = Field(..., description="Maximum tokens for memory injection") + token_counting: str = Field(..., description="Token counting strategy for memory injection ('tiktoken' or 'char')") class MemoryStatusResponse(BaseModel): @@ -310,7 +311,8 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse: "max_facts": 100, "fact_confidence_threshold": 0.7, "injection_enabled": true, - "max_injection_tokens": 2000 + "max_injection_tokens": 2000, + "token_counting": "tiktoken" } ``` """ @@ -323,6 +325,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse: fact_confidence_threshold=config.fact_confidence_threshold, injection_enabled=config.injection_enabled, max_injection_tokens=config.max_injection_tokens, + token_counting=config.token_counting, ) @@ -351,6 +354,7 @@ async def get_memory_status() -> MemoryStatusResponse: fact_confidence_threshold=config.fact_confidence_threshold, injection_enabled=config.injection_enabled, max_injection_tokens=config.max_injection_tokens, + token_counting=config.token_counting, ), data=MemoryResponse(**memory_data), ) diff --git a/backend/app/gateway/services.py b/backend/app/gateway/services.py index 015f74398..ea39ab72e 100644 --- a/backend/app/gateway/services.py +++ b/backend/app/gateway/services.py @@ -318,6 +318,21 @@ async def start_run( ) owner_user_id = get_trusted_internal_owner_user_id(request) + # Stateless run endpoints carry thread_id in the request *body*, so the + # @require_permission(owner_check=True) decorator -- which resolves ownership + # from the path param -- cannot protect them. Enforce thread ownership here, + # before any run is created, so one user cannot start runs on (or read /wait + # checkpoint state from) another user's thread. Missing rows (auto-created + # temp threads) and NULL-owner rows (shared / pre-auth data) stay accessible + # via check_access; only a thread already owned by another user is rejected + # with 404, matching thread_runs.py's anti-enumeration behaviour. Internal + # channel runs act on behalf of IM users they do not own (see + # inject_authenticated_user_context), so the internal system role is exempt. + user = getattr(request.state, "user", None) + if user is not None and getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE: + if not await run_ctx.thread_store.check_access(thread_id, str(user.id)): + raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") + owner_context_token = set_current_user(SimpleNamespace(id=owner_user_id)) if owner_user_id else None try: try: diff --git a/backend/docs/MEMORY_IMPROVEMENTS.md b/backend/docs/MEMORY_IMPROVEMENTS.md index 3fddd4b6b..0b098c150 100644 --- a/backend/docs/MEMORY_IMPROVEMENTS.md +++ b/backend/docs/MEMORY_IMPROVEMENTS.md @@ -31,7 +31,8 @@ Current injection format: Token counting: - Uses `tiktoken` (`cl100k_base`) when available -- Falls back to `len(text) // 4` if tokenizer import fails +- Falls back to a network-free CJK-aware character estimate if tokenizer import or encoding load fails + (CJK characters count as ~2 chars/token, other characters as ~4 chars/token) ## Known Gap diff --git a/backend/packages/harness/deerflow/agents/lead_agent/prompt.py b/backend/packages/harness/deerflow/agents/lead_agent/prompt.py index 7a32d0c9e..440f43960 100644 --- a/backend/packages/harness/deerflow/agents/lead_agent/prompt.py +++ b/backend/packages/harness/deerflow/agents/lead_agent/prompt.py @@ -586,7 +586,11 @@ def _get_memory_context(agent_name: str | None = None, *, app_config: AppConfig return "" memory_data = get_memory_data(agent_name, user_id=get_effective_user_id()) - memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens) + memory_content = format_memory_for_injection( + memory_data, + max_tokens=config.max_injection_tokens, + use_tiktoken=(config.token_counting == "tiktoken"), + ) if not memory_content.strip(): return "" diff --git a/backend/packages/harness/deerflow/agents/memory/prompt.py b/backend/packages/harness/deerflow/agents/memory/prompt.py index f9e0b582b..69fe48aff 100644 --- a/backend/packages/harness/deerflow/agents/memory/prompt.py +++ b/backend/packages/harness/deerflow/agents/memory/prompt.py @@ -5,7 +5,9 @@ from __future__ import annotations import logging import math import re -from typing import Any +import threading +import time +from typing import Any, cast logger = logging.getLogger(__name__) @@ -169,7 +171,26 @@ Return ONLY valid JSON.""" # subsequent calls are a dict lookup (no network I/O). Pre-warming at # startup via :func:`warm_tiktoken_cache` avoids blocking a request on the # (potentially slow) first ``get_encoding`` call. -_tiktoken_encoding_cache: dict[str, tiktoken.Encoding] = {} +# +# A *failed* load is cached as a ``(None, monotonic_timestamp)`` tuple so that +# a network-restricted environment does not re-attempt the blocking BPE +# download on every subsequent call. After ``_TIKTOKEN_RETRY_COOLDOWN_S`` the +# failure is allowed to expire so a transient network outage can self-heal back +# to accurate tiktoken counting without a process restart. A load already in +# progress is cached as ``_TIKTOKEN_ENCODING_LOADING`` so concurrent callers +# fall back immediately instead of spawning more blocking +# ``tiktoken.get_encoding`` threads. Use the ``memory.token_counting: char`` +# config to skip tiktoken entirely. +_TIKTOKEN_ENCODING_MISSING = object() +_TIKTOKEN_ENCODING_LOADING = object() +# Cooldown before a *failed* tiktoken load is re-attempted. This is an internal +# tuning constant rather than a user-facing config: it only affects how quickly +# the default ``tiktoken`` mode self-heals after a transient network outage. +# Deployments that want to avoid tiktoken's network dependency entirely should +# set ``memory.token_counting: char`` instead of tuning this value. +_TIKTOKEN_RETRY_COOLDOWN_S = 600.0 +_tiktoken_encoding_cache: dict[str, Any] = {} +_tiktoken_encoding_cache_lock = threading.Lock() def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encoding | None: @@ -181,44 +202,91 @@ def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encod download can block for tens of minutes before the OS TCP timeout kicks in. The caller must therefore be prepared for this to block and should run it off the event loop (e.g. via ``asyncio.to_thread``). + + A failed load is remembered (with a timestamp) so subsequent calls fall + back immediately to character-based estimation instead of re-triggering the + blocking download. The failure expires after ``_TIKTOKEN_RETRY_COOLDOWN_S`` + so a transient outage can self-heal without a restart. A load already in + progress is also remembered so that a timed-out caller does not leave a + window where later requests start more blocking ``get_encoding`` calls. """ if not TIKTOKEN_AVAILABLE: return None - cached = _tiktoken_encoding_cache.get(encoding_name) - if cached is not None: - return cached + with _tiktoken_encoding_cache_lock: + cached = _tiktoken_encoding_cache.get(encoding_name, _TIKTOKEN_ENCODING_MISSING) + if cached is _TIKTOKEN_ENCODING_LOADING: + return None + if isinstance(cached, tuple): + # Cached failure: (None, failed_at). Retry only after cooldown. + _, failed_at = cached + if time.monotonic() - failed_at < _TIKTOKEN_RETRY_COOLDOWN_S: + return None + cached = _TIKTOKEN_ENCODING_MISSING + if cached is not _TIKTOKEN_ENCODING_MISSING: + return cast("tiktoken.Encoding", cached) + _tiktoken_encoding_cache[encoding_name] = _TIKTOKEN_ENCODING_LOADING try: encoding = tiktoken.get_encoding(encoding_name) - _tiktoken_encoding_cache[encoding_name] = encoding - return encoding except Exception: logger.warning("Failed to load tiktoken encoding %r; falling back to char-based estimation", encoding_name, exc_info=True) + with _tiktoken_encoding_cache_lock: + _tiktoken_encoding_cache[encoding_name] = (None, time.monotonic()) return None + with _tiktoken_encoding_cache_lock: + _tiktoken_encoding_cache[encoding_name] = encoding + return encoding -def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int: + +def _char_based_token_estimate(text: str) -> int: + """Network-free token estimate that accounts for CJK density. + + The plain ``len(text) // 4`` heuristic is reasonable for English/code + (~4 chars per token) but significantly under-estimates token counts for + Chinese, Japanese, and Korean text, where the ratio is closer to 1.5-2 + characters per token. Counting CJK characters separately (~2 chars per + token) avoids over-filling the injection budget for CJK-heavy memory + content. + """ + cjk = sum( + 1 + for ch in text + if "\u4e00" <= ch <= "\u9fff" # CJK Unified Ideographs + or "\u3040" <= ch <= "\u30ff" # Hiragana + Katakana + or "\uac00" <= ch <= "\ud7a3" # Hangul syllables + ) + return (len(text) - cjk) // 4 + cjk // 2 + + +def _count_tokens(text: str, encoding_name: str = "cl100k_base", *, use_tiktoken: bool = True) -> int: """Count tokens in text using tiktoken. Args: text: The text to count tokens for. encoding_name: The encoding to use (default: cl100k_base for GPT-4/3.5). + use_tiktoken: When ``False``, skip tiktoken entirely and use the + network-free character-based estimate. This guarantees no BPE + download is attempted (see ``memory.token_counting`` config). Returns: The number of tokens in the text. """ + if not use_tiktoken: + return _char_based_token_estimate(text) + encoding = _get_tiktoken_encoding(encoding_name) if encoding is None: - # Fallback to character-based estimation if tiktoken is not available - # or the encoding failed to load. - return len(text) // 4 + # Fallback to CJK-aware character estimation if tiktoken is not + # available or the encoding failed to load. + return _char_based_token_estimate(text) try: return len(encoding.encode(text)) except Exception: - # Fallback to character-based estimation on error - return len(text) // 4 + # Fallback to CJK-aware character estimation on error. + return _char_based_token_estimate(text) def warm_tiktoken_cache() -> bool: @@ -248,12 +316,15 @@ def _coerce_confidence(value: Any, default: float = 0.0) -> float: return max(0.0, min(1.0, confidence)) -def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str: +def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000, *, use_tiktoken: bool = True) -> str: """Format memory data for injection into system prompt. Args: memory_data: The memory data dictionary. max_tokens: Maximum tokens to use (counted via tiktoken for accuracy). + use_tiktoken: When ``False``, all token counting uses the network-free + character-based estimate instead of tiktoken (see + ``memory.token_counting`` config). Defaults to ``True``. Returns: Formatted memory string for system prompt injection. @@ -315,10 +386,10 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2 # Compute token count for existing sections once, then account # incrementally for each fact line to avoid full-string re-tokenization. base_text = "\n\n".join(sections) - base_tokens = _count_tokens(base_text) if base_text else 0 + base_tokens = _count_tokens(base_text, use_tiktoken=use_tiktoken) if base_text else 0 # Account for the separator between existing sections and the facts section. facts_header = "Facts:\n" - separator_tokens = _count_tokens("\n\n" + facts_header) if base_text else _count_tokens(facts_header) + separator_tokens = _count_tokens("\n\n" + facts_header, use_tiktoken=use_tiktoken) if base_text else _count_tokens(facts_header, use_tiktoken=use_tiktoken) running_tokens = base_tokens + separator_tokens fact_lines: list[str] = [] @@ -339,7 +410,7 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2 # Each additional line is preceded by a newline (except the first). line_text = ("\n" + line) if fact_lines else line - line_tokens = _count_tokens(line_text) + line_tokens = _count_tokens(line_text, use_tiktoken=use_tiktoken) if running_tokens + line_tokens <= max_tokens: fact_lines.append(line) @@ -355,8 +426,9 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2 result = "\n\n".join(sections) - # Use accurate token counting with tiktoken - token_count = _count_tokens(result) + # Use accurate token counting with tiktoken (or the char-based estimate + # when use_tiktoken is False). + token_count = _count_tokens(result, use_tiktoken=use_tiktoken) if token_count > max_tokens: # Truncate to fit within token limit # Estimate characters to remove based on token ratio diff --git a/backend/packages/harness/deerflow/client.py b/backend/packages/harness/deerflow/client.py index 563c8f835..b0c0b8b13 100644 --- a/backend/packages/harness/deerflow/client.py +++ b/backend/packages/harness/deerflow/client.py @@ -1141,6 +1141,7 @@ class DeerFlowClient: "fact_confidence_threshold": config.fact_confidence_threshold, "injection_enabled": config.injection_enabled, "max_injection_tokens": config.max_injection_tokens, + "token_counting": config.token_counting, } def get_memory_status(self) -> dict: diff --git a/backend/packages/harness/deerflow/config/agents_config.py b/backend/packages/harness/deerflow/config/agents_config.py index 86b5347db..a9c9e212d 100644 --- a/backend/packages/harness/deerflow/config/agents_config.py +++ b/backend/packages/harness/deerflow/config/agents_config.py @@ -67,11 +67,13 @@ def resolve_agent_dir(name: str, *, user_id: str | None = None) -> Path: paths = get_paths() effective_user = user_id or get_effective_user_id() user_path = paths.user_agent_dir(effective_user, name) - if user_path.exists(): + # Require config.yaml to confirm this is a genuine agent directory, + # not a leftover from memory/storage writes (see #3390). + if user_path.exists() and (user_path / "config.yaml").exists(): return user_path legacy_path = paths.agent_dir(name) - if legacy_path.exists(): + if legacy_path.exists() and (legacy_path / "config.yaml").exists(): return legacy_path return user_path diff --git a/backend/packages/harness/deerflow/config/memory_config.py b/backend/packages/harness/deerflow/config/memory_config.py index f9153262f..9a2c12952 100644 --- a/backend/packages/harness/deerflow/config/memory_config.py +++ b/backend/packages/harness/deerflow/config/memory_config.py @@ -1,5 +1,7 @@ """Configuration for memory mechanism.""" +from typing import Literal + from pydantic import BaseModel, Field @@ -60,6 +62,17 @@ class MemoryConfig(BaseModel): le=8000, description="Maximum tokens to use for memory injection", ) + token_counting: Literal["tiktoken", "char"] = Field( + default="tiktoken", + description=( + "Token counting strategy for memory-injection budgeting. " + "'tiktoken' is accurate but the encoding's BPE data may be " + "downloaded from a public network endpoint on first use, which " + "can block for a long time in network-restricted environments " + "(see issue #3402/#3429). 'char' uses a network-free " + "CJK-aware character-based estimate and never touches tiktoken." + ), + ) # Global configuration instance diff --git a/backend/tests/test_client.py b/backend/tests/test_client.py index 6c15c04e7..5b593bc86 100644 --- a/backend/tests/test_client.py +++ b/backend/tests/test_client.py @@ -2472,6 +2472,7 @@ class TestGatewayConformance: mem_cfg.fact_confidence_threshold = 0.7 mem_cfg.injection_enabled = True mem_cfg.max_injection_tokens = 2000 + mem_cfg.token_counting = "tiktoken" with patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg): result = client.get_memory_config() @@ -2479,6 +2480,7 @@ class TestGatewayConformance: parsed = MemoryConfigResponse(**result) assert parsed.enabled is True assert parsed.max_facts == 100 + assert parsed.token_counting == "tiktoken" def test_get_memory_status(self, client): mem_cfg = MagicMock() @@ -2489,6 +2491,7 @@ class TestGatewayConformance: mem_cfg.fact_confidence_threshold = 0.7 mem_cfg.injection_enabled = True mem_cfg.max_injection_tokens = 2000 + mem_cfg.token_counting = "tiktoken" memory_data = { "version": "1.0", @@ -2514,6 +2517,7 @@ class TestGatewayConformance: parsed = MemoryStatusResponse(**result) assert parsed.config.enabled is True + assert parsed.config.token_counting == "tiktoken" assert parsed.data.version == "1.0" diff --git a/backend/tests/test_custom_agent.py b/backend/tests/test_custom_agent.py index 284908081..9f7a61ba2 100644 --- a/backend/tests/test_custom_agent.py +++ b/backend/tests/test_custom_agent.py @@ -203,6 +203,79 @@ class TestLoadAgentConfig: assert cfg.name == "legacy-agent" +# =========================================================================== +# 3b. resolve_agent_dir — memory-only directory fallback (#3390) +# =========================================================================== + + +class TestResolveAgentDirMemoryOnlyFallback: + """Regression tests for #3390. + + When memory is enabled, the first conversation creates a user-isolated + agent directory containing only ``memory.json`` (no ``config.yaml``). + On the next turn ``resolve_agent_dir`` must fall through to the legacy + shared layout instead of returning the incomplete user directory. + """ + + def test_user_dir_with_only_memory_falls_back_to_legacy(self, tmp_path): + """User dir has memory.json but no config.yaml → use legacy dir.""" + from deerflow.config.agents_config import resolve_agent_dir + + # Legacy agent with full config + legacy_dir = tmp_path / "agents" / "my-agent" + legacy_dir.mkdir(parents=True) + (legacy_dir / "config.yaml").write_text("name: my-agent\n", encoding="utf-8") + (legacy_dir / "SOUL.md").write_text("legacy soul", encoding="utf-8") + + # User dir created by memory write — no config.yaml + user_dir = tmp_path / "users" / "u1" / "agents" / "my-agent" + user_dir.mkdir(parents=True) + (user_dir / "memory.json").write_text("{}", encoding="utf-8") + + with patch("deerflow.config.agents_config.get_paths", return_value=_make_paths(tmp_path)), patch("deerflow.config.agents_config.get_effective_user_id", return_value="u1"): + result = resolve_agent_dir("my-agent", user_id="u1") + + assert result == legacy_dir + + def test_user_dir_with_config_takes_priority(self, tmp_path): + """User dir with config.yaml should still win over legacy.""" + from deerflow.config.agents_config import resolve_agent_dir + + # Legacy + legacy_dir = tmp_path / "agents" / "my-agent" + legacy_dir.mkdir(parents=True) + (legacy_dir / "config.yaml").write_text("name: my-agent\n", encoding="utf-8") + + # User dir with full config (migrated) + user_dir = tmp_path / "users" / "u1" / "agents" / "my-agent" + user_dir.mkdir(parents=True) + (user_dir / "config.yaml").write_text("name: my-agent\nmodel: gpt-4\n", encoding="utf-8") + (user_dir / "memory.json").write_text("{}", encoding="utf-8") + + with patch("deerflow.config.agents_config.get_paths", return_value=_make_paths(tmp_path)), patch("deerflow.config.agents_config.get_effective_user_id", return_value="u1"): + result = resolve_agent_dir("my-agent", user_id="u1") + + assert result == user_dir + + def test_load_config_falls_back_when_user_dir_is_memory_only(self, tmp_path): + """End-to-end: load_agent_config works when user dir only has memory.json.""" + config_dict = {"name": "my-agent", "description": "Legacy agent", "model": "deepseek-v3"} + _write_agent(tmp_path, "my-agent", config_dict) + + # Simulate memory write creating user dir without config + user_dir = tmp_path / "users" / "u1" / "agents" / "my-agent" + user_dir.mkdir(parents=True) + (user_dir / "memory.json").write_text("{}", encoding="utf-8") + + with patch("deerflow.config.agents_config.get_paths", return_value=_make_paths(tmp_path)), patch("deerflow.config.agents_config.get_effective_user_id", return_value="u1"): + from deerflow.config.agents_config import load_agent_config + + cfg = load_agent_config("my-agent", user_id="u1") + + assert cfg.name == "my-agent" + assert cfg.model == "deepseek-v3" + + # =========================================================================== # 4. load_agent_soul # =========================================================================== diff --git a/backend/tests/test_lead_agent_prompt.py b/backend/tests/test_lead_agent_prompt.py index 78a5739f3..e72e180e7 100644 --- a/backend/tests/test_lead_agent_prompt.py +++ b/backend/tests/test_lead_agent_prompt.py @@ -192,7 +192,7 @@ def test_build_acp_section_uses_explicit_app_config_without_global_config(monkey def test_get_memory_context_uses_explicit_app_config_without_global_config(monkeypatch): explicit_config = SimpleNamespace( - memory=SimpleNamespace(enabled=True, injection_enabled=True, max_injection_tokens=1234), + memory=SimpleNamespace(enabled=True, injection_enabled=True, max_injection_tokens=1234, token_counting="tiktoken"), ) captured: dict[str, object] = {} @@ -204,9 +204,10 @@ def test_get_memory_context_uses_explicit_app_config_without_global_config(monke captured["user_id"] = user_id return {"facts": []} - def fake_format_memory_for_injection(memory_data, *, max_tokens): + def fake_format_memory_for_injection(memory_data, *, max_tokens, use_tiktoken=True): captured["memory_data"] = memory_data captured["max_tokens"] = max_tokens + captured["use_tiktoken"] = use_tiktoken return "remember this" monkeypatch.setattr("deerflow.config.memory_config.get_memory_config", fail_get_memory_config) @@ -223,6 +224,7 @@ def test_get_memory_context_uses_explicit_app_config_without_global_config(monke "user_id": "user-1", "memory_data": {"facts": []}, "max_tokens": 1234, + "use_tiktoken": True, } diff --git a/backend/tests/test_memory_prompt_injection.py b/backend/tests/test_memory_prompt_injection.py index 7c3ad85c4..c2b58b61f 100644 --- a/backend/tests/test_memory_prompt_injection.py +++ b/backend/tests/test_memory_prompt_injection.py @@ -39,7 +39,7 @@ def test_format_memory_sorts_facts_by_confidence_desc() -> None: def test_format_memory_respects_budget_when_adding_facts(monkeypatch) -> None: # Make token counting deterministic for this test by counting characters. - monkeypatch.setattr("deerflow.agents.memory.prompt._count_tokens", lambda text, encoding_name="cl100k_base": len(text)) + monkeypatch.setattr("deerflow.agents.memory.prompt._count_tokens", lambda text, encoding_name="cl100k_base", *, use_tiktoken=True: len(text)) memory_data = { "user": {}, diff --git a/backend/tests/test_stateless_runs_owner_isolation.py b/backend/tests/test_stateless_runs_owner_isolation.py new file mode 100644 index 000000000..60a20d17c --- /dev/null +++ b/backend/tests/test_stateless_runs_owner_isolation.py @@ -0,0 +1,173 @@ +"""Cross-user isolation for the stateless ``POST /api/runs/stream`` and ``/wait`` endpoints. + +These endpoints receive ``thread_id`` in the request body, so the +``@require_permission(owner_check=True)`` decorator — which reads the +``thread_id`` *path* parameter — cannot protect them. The owner check +lives inside ``services.start_run()`` instead; this suite pins it at the +HTTP layer so the gap cannot silently reopen. + +Strategy +-------- +``app.state.run_manager.create_or_reject`` raises ``ConflictError``, so a +request that *passes* the owner check deterministically short-circuits +with 409 before any agent code runs. The two outcomes: + +- 404 + ``create_or_reject`` never awaited -> blocked by the owner check +- 409 + ``create_or_reject`` awaited -> passed the owner check + +The thread store is a real ``MemoryThreadMetaStore`` (not a mock) so the +``check_access`` semantics under test — missing row allows, ``user_id`` +NULL allows, foreign owner denies — are exercised through real code. +""" + +from __future__ import annotations + +import asyncio +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock +from uuid import uuid4 + +import pytest +from _router_auth_helpers import make_authed_test_app +from fastapi.testclient import TestClient +from langgraph.store.memory import InMemoryStore + +from app.gateway.auth.models import User +from app.gateway.routers import runs +from deerflow.config.app_config import AppConfig, reset_app_config, set_app_config +from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore +from deerflow.runtime import ConflictError + +USER_A = User(email="owner-a@example.com", password_hash="x", system_role="user", id=uuid4()) +USER_B = User(email="intruder-b@example.com", password_hash="x", system_role="user", id=uuid4()) +INTERNAL_USER = SimpleNamespace(id="default", system_role="internal") + +THREAD_A = "thread-owned-by-a" +THREAD_SHARED = "thread-shared-null-owner" + + +@pytest.fixture(autouse=True) +def _stub_app_config(): + """Inject a minimal AppConfig so the allowed path (which builds a + RunContext via ``get_config()``) never reads config.yaml from disk.""" + set_app_config(AppConfig.model_validate({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}})) + yield + reset_app_config() + + +def _make_thread_store() -> MemoryThreadMetaStore: + store = MemoryThreadMetaStore(InMemoryStore()) + + async def _seed(): + await store.create(THREAD_A, user_id=str(USER_A.id)) + await store.create(THREAD_SHARED, user_id=None) + + asyncio.run(_seed()) + return store + + +@contextmanager +def _client(user): + """Yield a ``TestClient`` authenticated as ``user`` plus the stubbed + ``create_or_reject`` mock, closing the client (and its anyio portal / + background threads) on exit. + + ``create_or_reject`` raises ``ConflictError`` so a request that passes the + owner check short-circuits to 409 before any agent code runs. + """ + app = make_authed_test_app(user_factory=lambda: user) + app.include_router(runs.router) + app.state.thread_store = _make_thread_store() + app.state.stream_bridge = MagicMock() + app.state.checkpointer = MagicMock() + app.state.store = MagicMock() + app.state.run_events_config = None + app.state.run_event_store = MagicMock() + run_manager = MagicMock() + run_manager.create_or_reject = AsyncMock(side_effect=ConflictError("sentinel: owner check passed")) + app.state.run_manager = run_manager + with TestClient(app) as client: + yield client, run_manager.create_or_reject + + +def _body(thread_id: str | None = None) -> dict: + if thread_id is None: + return {} + return {"config": {"configurable": {"thread_id": thread_id}}} + + +# --------------------------------------------------------------------------- +# Denied: another user's thread +# --------------------------------------------------------------------------- + + +def test_stream_cross_user_returns_404(): + """User B cannot start a run on user A's thread via /api/runs/stream.""" + with _client(USER_B) as (client, create_or_reject): + response = client.post("/api/runs/stream", json=_body(THREAD_A)) + assert response.status_code == 404 + assert response.json()["detail"] == f"Thread {THREAD_A} not found" + create_or_reject.assert_not_awaited() + + +def test_wait_cross_user_returns_404_without_channel_values(): + """User B cannot read user A's checkpoint state via /api/runs/wait.""" + with _client(USER_B) as (client, create_or_reject): + response = client.post("/api/runs/wait", json=_body(THREAD_A)) + assert response.status_code == 404 + assert response.json() == {"detail": f"Thread {THREAD_A} not found"} + create_or_reject.assert_not_awaited() + + +# --------------------------------------------------------------------------- +# Allowed: owner, fresh/untracked/shared threads, internal role +# --------------------------------------------------------------------------- + + +def test_stream_owner_passes_owner_check(): + """User A reaches run creation on their own thread (409 sentinel).""" + with _client(USER_A) as (client, create_or_reject): + response = client.post("/api/runs/stream", json=_body(THREAD_A)) + assert response.status_code == 409 + create_or_reject.assert_awaited() + + +def test_wait_owner_passes_owner_check(): + with _client(USER_A) as (client, create_or_reject): + response = client.post("/api/runs/wait", json=_body(THREAD_A)) + assert response.status_code == 409 + create_or_reject.assert_awaited() + + +def test_stream_without_thread_id_passes_owner_check(): + """Stateless run with no thread_id auto-creates a thread — never blocked.""" + with _client(USER_B) as (client, create_or_reject): + response = client.post("/api/runs/stream", json=_body()) + assert response.status_code == 409 + create_or_reject.assert_awaited() + + +def test_stream_untracked_thread_passes_owner_check(): + """A thread_id with no thread_meta row (untracked legacy) stays accessible.""" + with _client(USER_B) as (client, create_or_reject): + response = client.post("/api/runs/stream", json=_body("never-created-thread")) + assert response.status_code == 409 + create_or_reject.assert_awaited() + + +def test_stream_shared_thread_passes_owner_check(): + """A thread_meta row with user_id NULL (shared / pre-auth data) stays accessible.""" + with _client(USER_B) as (client, create_or_reject): + response = client.post("/api/runs/stream", json=_body(THREAD_SHARED)) + assert response.status_code == 409 + create_or_reject.assert_awaited() + + +def test_stream_internal_role_bypasses_owner_check(): + """IM channels run with the internal system role on behalf of platform + users whose threads they do not own — the owner check must not break them.""" + with _client(INTERNAL_USER) as (client, create_or_reject): + response = client.post("/api/runs/stream", json=_body(THREAD_A)) + assert response.status_code == 409 + create_or_reject.assert_awaited() diff --git a/backend/tests/test_tiktoken_cache_and_count_tokens.py b/backend/tests/test_tiktoken_cache_and_count_tokens.py index 730934039..045d9c13b 100644 --- a/backend/tests/test_tiktoken_cache_and_count_tokens.py +++ b/backend/tests/test_tiktoken_cache_and_count_tokens.py @@ -5,18 +5,22 @@ Verifies: - ``_count_tokens`` falls back to character estimation when tiktoken is unavailable or the encoding fails to load. - ``warm_tiktoken_cache`` populates the cache on success. +- An in-flight tiktoken load prevents duplicate blocking downloads. """ from __future__ import annotations +import threading from unittest import mock from deerflow.agents.memory.prompt import ( _count_tokens, _get_tiktoken_encoding, _tiktoken_encoding_cache, + format_memory_for_injection, warm_tiktoken_cache, ) +from deerflow.config.memory_config import MemoryConfig # --------------------------------------------------------------------------- # _get_tiktoken_encoding @@ -62,14 +66,103 @@ class TestGetTiktokenEncoding: assert enc is fake_enc tiktoken.get_encoding.assert_not_called() - def test_returns_none_and_warns_on_get_encoding_failure(self, monkeypatch): + def test_returns_none_and_caches_failure_sentinel(self, monkeypatch): + """A failed load is cached (with a timestamp) so it is not re-attempted (no repeated network download).""" _tiktoken_encoding_cache.pop("bogus_encoding", None) import tiktoken - monkeypatch.setattr(tiktoken, "get_encoding", mock.Mock(side_effect=OSError("download failed"))) + get_encoding = mock.Mock(side_effect=OSError("download failed")) + monkeypatch.setattr(tiktoken, "get_encoding", get_encoding) + result = _get_tiktoken_encoding("bogus_encoding") assert result is None - assert "bogus_encoding" not in _tiktoken_encoding_cache + # The failure is remembered as a (None, timestamp) tuple. + assert "bogus_encoding" in _tiktoken_encoding_cache + cached = _tiktoken_encoding_cache["bogus_encoding"] + assert isinstance(cached, tuple) + assert cached[0] is None + + # A second call must NOT re-attempt get_encoding (avoids re-blocking on + # the network download in restricted environments — see #3429). + result2 = _get_tiktoken_encoding("bogus_encoding") + assert result2 is None + assert get_encoding.call_count == 1 + + # Cleanup module-level cache to avoid cross-test leakage. + _tiktoken_encoding_cache.pop("bogus_encoding", None) + + def test_failure_self_heals_after_cooldown(self, monkeypatch): + """After the retry cooldown expires, a transient failure is re-attempted and can recover.""" + _tiktoken_encoding_cache.pop("flaky_encoding", None) + import tiktoken + + fake_enc = mock.Mock() + # First call fails, second call (after cooldown) succeeds. + get_encoding = mock.Mock(side_effect=[OSError("transient outage"), fake_enc]) + monkeypatch.setattr(tiktoken, "get_encoding", get_encoding) + + # Initial failure is cached. + assert _get_tiktoken_encoding("flaky_encoding") is None + assert get_encoding.call_count == 1 + + # Within the cooldown window: no retry, immediate fallback. + assert _get_tiktoken_encoding("flaky_encoding") is None + assert get_encoding.call_count == 1 + + # Simulate the cooldown having elapsed by ageing the cached timestamp. + from deerflow.agents.memory import prompt as prompt_module + + _, _failed_at = _tiktoken_encoding_cache["flaky_encoding"] + _tiktoken_encoding_cache["flaky_encoding"] = ( + None, + _failed_at - prompt_module._TIKTOKEN_RETRY_COOLDOWN_S - 1, + ) + + # Now the load is retried and recovers to accurate counting. + assert _get_tiktoken_encoding("flaky_encoding") is fake_enc + assert get_encoding.call_count == 2 + + _tiktoken_encoding_cache.pop("flaky_encoding", None) + + def test_in_flight_load_returns_none_without_duplicate_get_encoding(self, monkeypatch): + """Concurrent callers must not start duplicate blocking BPE downloads.""" + _tiktoken_encoding_cache.pop("slow_encoding", None) + import tiktoken + + started = threading.Event() + release = threading.Event() + fake_enc = mock.Mock() + + def slow_get_encoding(_name): + started.set() + assert release.wait(timeout=2), "test timed out waiting to release slow get_encoding" + return fake_enc + + get_encoding = mock.Mock(side_effect=slow_get_encoding) + monkeypatch.setattr(tiktoken, "get_encoding", get_encoding) + + result: dict[str, object | None] = {} + + def load_encoding(): + result["encoding"] = _get_tiktoken_encoding("slow_encoding") + + thread = threading.Thread(target=load_encoding) + thread.start() + try: + assert started.wait(timeout=1), "slow get_encoding did not start" + + # While the first call is still blocked, a second call should see + # the in-flight sentinel and fall back immediately instead of + # starting another potentially long network download. + assert _get_tiktoken_encoding("slow_encoding") is None + assert get_encoding.call_count == 1 + finally: + release.set() + thread.join(timeout=2) + _tiktoken_encoding_cache.pop("slow_encoding", None) + + assert result["encoding"] is fake_enc + assert get_encoding.call_count == 1 # --------------------------------------------------------------------------- @@ -115,6 +208,45 @@ class TestCountTokens: result = _count_tokens(text, encoding_name="test_enc") assert result == len(text) // 4 + def test_use_tiktoken_false_returns_char_estimate_without_touching_tiktoken(self, monkeypatch): + """use_tiktoken=False must never call tiktoken (guarantees no BPE download).""" + # Spy on both the encoding loader and tiktoken.get_encoding directly. + get_encoding_spy = mock.Mock(side_effect=AssertionError("get_encoding must not be called")) + loader_spy = mock.Mock(side_effect=AssertionError("_get_tiktoken_encoding must not be called")) + monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", get_encoding_spy) + monkeypatch.setattr("deerflow.agents.memory.prompt._get_tiktoken_encoding", loader_spy) + + text = "Hello, world! This is a network-free count." + result = _count_tokens(text, use_tiktoken=False) + assert result == len(text) // 4 + get_encoding_spy.assert_not_called() + loader_spy.assert_not_called() + + def test_cjk_estimate_is_denser_than_plain_quarter(self, monkeypatch): + """CJK text should estimate more tokens than the plain len // 4 heuristic. + + CJK characters are ~2 chars/token, so the char-based estimate must not + under-fill the budget the way ``len(text) // 4`` would. + """ + monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False) + # "User prefers concise answers" rendered in CJK (Chinese) characters. + text = "\u7528\u6237\u504f\u597d\u7b80\u6d01\u7684\u4e2d\u6587\u56de\u7b54\u5e76\u5173\u6ce8\u91d1\u878d\u9886\u57df" + result = _count_tokens(text) + # Each CJK char counts as ~1/2 token (vs 1/4 for the plain heuristic). + assert result == len(text) // 2 + assert result > len(text) // 4 + + def test_cjk_estimate_combines_cjk_and_non_cjk_characters(self, monkeypatch): + """Mixed-language text should apply the CJK density only to CJK chars.""" + monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False) + # ASCII words mixed with CJK (Chinese) characters: "User" + "likes" + "Python and data analysis". + text = "User\u559c\u6b22Python\u548c\u6570\u636e\u5206\u6790" + cjk = sum(1 for ch in text if "\u4e00" <= ch <= "\u9fff") + + result = _count_tokens(text) + + assert result == (len(text) - cjk) // 4 + cjk // 2 + # --------------------------------------------------------------------------- # warm_tiktoken_cache @@ -146,3 +278,69 @@ class TestWarmTiktokenCache: def test_returns_false_when_tiktoken_unavailable(self, monkeypatch): monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False) assert warm_tiktoken_cache() is False + + +# --------------------------------------------------------------------------- +# format_memory_for_injection token_counting strategy +# --------------------------------------------------------------------------- + + +class TestFormatMemoryForInjectionTokenCounting: + """Verify the use_tiktoken flag is honoured end-to-end.""" + + @staticmethod + def _sample_memory() -> dict: + return { + "facts": [ + {"content": "User prefers concise answers.", "category": "preference", "confidence": 0.9}, + {"content": "User works in the finance domain.", "category": "context", "confidence": 0.8}, + ], + } + + def test_use_tiktoken_false_never_touches_tiktoken(self, monkeypatch): + """With use_tiktoken=False, formatting must not call tiktoken at all.""" + get_encoding_spy = mock.Mock(side_effect=AssertionError("get_encoding must not be called")) + monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", get_encoding_spy) + + result = format_memory_for_injection(self._sample_memory(), max_tokens=2000, use_tiktoken=False) + assert "User prefers concise answers." in result + get_encoding_spy.assert_not_called() + + def test_use_tiktoken_true_uses_encoding(self, monkeypatch): + """With use_tiktoken=True (default), the cached encoding is used for counting.""" + fake_enc = mock.Mock() + fake_enc.encode.side_effect = lambda text: list(range(len(text))) + monkeypatch.setattr( + "deerflow.agents.memory.prompt._get_tiktoken_encoding", + mock.Mock(return_value=fake_enc), + ) + + result = format_memory_for_injection(self._sample_memory(), max_tokens=2000, use_tiktoken=True) + assert "User prefers concise answers." in result + assert fake_enc.encode.called + + def test_empty_memory_returns_empty(self): + assert format_memory_for_injection({}, max_tokens=2000, use_tiktoken=False) == "" + + +# --------------------------------------------------------------------------- +# MemoryConfig.token_counting +# --------------------------------------------------------------------------- + + +class TestMemoryConfigTokenCounting: + """Verify the new config field defaults and validation.""" + + def test_default_is_tiktoken(self): + """Default must remain tiktoken so existing deployments are unaffected.""" + assert MemoryConfig().token_counting == "tiktoken" + + def test_accepts_char(self): + assert MemoryConfig(token_counting="char").token_counting == "char" + + def test_rejects_invalid_value(self): + import pytest + from pydantic import ValidationError + + with pytest.raises(ValidationError): + MemoryConfig(token_counting="invalid") diff --git a/config.example.yaml b/config.example.yaml index 1a20e23fd..73af462f6 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -1024,6 +1024,15 @@ memory: fact_confidence_threshold: 0.7 # Minimum confidence for storing facts injection_enabled: true # Whether to inject memory into system prompt max_injection_tokens: 2000 # Maximum tokens for memory injection + # Token counting strategy for memory-injection budgeting: + # tiktoken (default) - accurate, but the encoding's BPE data may be + # downloaded from a public network endpoint on first use. In + # network-restricted environments this download can block for a long + # time (see issues #3402 / #3429). Pre-cache the encoding or set this + # to "char" to avoid it. + # char - network-free CJK-aware character-based estimate; never touches + # tiktoken. Slightly less precise budgeting, zero network I/O. + token_counting: tiktoken # ============================================================================ # Custom Agent Management API diff --git a/frontend/src/app/workspace/chats/page.tsx b/frontend/src/app/workspace/chats/page.tsx index a3cea55ed..faaeb8ff2 100644 --- a/frontend/src/app/workspace/chats/page.tsx +++ b/frontend/src/app/workspace/chats/page.tsx @@ -1,8 +1,9 @@ "use client"; import Link from "next/link"; -import { useEffect, useMemo, useState } from "react"; +import { useEffect, useMemo, useRef, useState } from "react"; +import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { ScrollArea } from "@/components/ui/scroll-area"; import { @@ -15,7 +16,7 @@ import { WorkspaceHeader, } from "@/components/workspace/workspace-container"; import { useI18n } from "@/core/i18n/hooks"; -import { useThreads } from "@/core/threads/hooks"; +import { useInfiniteThreads } from "@/core/threads/hooks"; import { channelSourceOfThread, pathOfThread, @@ -25,18 +26,52 @@ import { formatTimeAgo } from "@/core/utils/datetime"; export default function ChatsPage() { const { t } = useI18n(); - const { data: threads } = useThreads(); + const { + data: infiniteThreads, + fetchNextPage, + hasNextPage, + isFetchingNextPage, + } = useInfiniteThreads(); + const threads = useMemo( + () => infiniteThreads?.pages.flat() ?? [], + [infiniteThreads], + ); const [search, setSearch] = useState(""); + const isSearching = search.trim().length > 0; useEffect(() => { document.title = `${t.pages.chats} - ${t.pages.appName}`; }, [t.pages.chats, t.pages.appName]); const filteredThreads = useMemo(() => { - return threads?.filter((thread) => { + return threads.filter((thread) => { return titleOfThread(thread).toLowerCase().includes(search.toLowerCase()); }); }, [threads, search]); + + // Sentinel-based auto load-more for the unfiltered list (issue #3482). + // In search mode we deliberately do NOT auto-paginate, otherwise an empty + // filtered view would keep the sentinel in the viewport and drain the + // entire backend list one page at a time. Searching falls back to an + // explicit button so users can still reach older conversations on demand. + const sentinelRef = useRef(null); + useEffect(() => { + const element = sentinelRef.current; + if (!element || !hasNextPage || isSearching) { + return; + } + const observer = new IntersectionObserver( + ([entry]) => { + if (entry?.isIntersecting && hasNextPage && !isFetchingNextPage) { + void fetchNextPage(); + } + }, + { rootMargin: "200px 0px 200px 0px" }, + ); + observer.observe(element); + return () => observer.disconnect(); + }, [fetchNextPage, hasNextPage, isFetchingNextPage, isSearching]); + return ( @@ -55,7 +90,7 @@ export default function ChatsPage() {
- {filteredThreads?.map((thread) => { + {filteredThreads.map((thread) => { const channelSource = channelSourceOfThread(thread); return ( @@ -79,6 +114,28 @@ export default function ChatsPage() { ); })} + {hasNextPage && !isSearching && ( +
diff --git a/frontend/src/components/workspace/recent-chat-list.tsx b/frontend/src/components/workspace/recent-chat-list.tsx index 42b374dd0..65a76f25c 100644 --- a/frontend/src/components/workspace/recent-chat-list.tsx +++ b/frontend/src/components/workspace/recent-chat-list.tsx @@ -11,7 +11,7 @@ import { } from "lucide-react"; import Link from "next/link"; import { useParams, usePathname, useRouter } from "next/navigation"; -import { useCallback, useState } from "react"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { Button } from "@/components/ui/button"; @@ -51,8 +51,8 @@ import { } from "@/core/threads/export"; import { useDeleteThread, + useInfiniteThreads, useRenameThread, - useThreads, } from "@/core/threads/hooks"; import type { AgentThread, AgentThreadState } from "@/core/threads/types"; import { @@ -74,7 +74,35 @@ export function RecentChatList() { thread_id: string; agent_name?: string; }>(); - const { data: threads = [] } = useThreads(); + const { + data: infiniteThreads, + fetchNextPage, + hasNextPage, + isFetchingNextPage, + } = useInfiniteThreads(); + const threads = useMemo( + () => infiniteThreads?.pages.flat() ?? [], + [infiniteThreads], + ); + + const sentinelRef = useRef(null); + useEffect(() => { + const element = sentinelRef.current; + if (!element || !hasNextPage) { + return; + } + const observer = new IntersectionObserver( + ([entry]) => { + if (entry?.isIntersecting && hasNextPage && !isFetchingNextPage) { + void fetchNextPage(); + } + }, + { rootMargin: "120px 0px 120px 0px" }, + ); + observer.observe(element); + return () => observer.disconnect(); + }, [fetchNextPage, hasNextPage, isFetchingNextPage]); + const { mutate: deleteThread } = useDeleteThread(); const { mutate: renameThread } = useRenameThread(); @@ -287,6 +315,28 @@ export function RecentChatList() { ); })} + {hasNextPage && ( + <> + + diff --git a/frontend/src/core/i18n/locales/en-US.ts b/frontend/src/core/i18n/locales/en-US.ts index 53bb721c2..49fb3ad90 100644 --- a/frontend/src/core/i18n/locales/en-US.ts +++ b/frontend/src/core/i18n/locales/en-US.ts @@ -253,6 +253,9 @@ export const enUS: Translations = { // Chats chats: { searchChats: "Search chats", + loadMoreToSearch: "Load more to search older conversations", + loadingMore: "Loading more...", + loadOlderChats: "Load older chats", }, // Channels diff --git a/frontend/src/core/i18n/locales/types.ts b/frontend/src/core/i18n/locales/types.ts index de36a2b53..a2590a68a 100644 --- a/frontend/src/core/i18n/locales/types.ts +++ b/frontend/src/core/i18n/locales/types.ts @@ -184,6 +184,9 @@ export interface Translations { // Chats chats: { searchChats: string; + loadMoreToSearch: string; + loadingMore: string; + loadOlderChats: string; }; // Channels diff --git a/frontend/src/core/i18n/locales/zh-CN.ts b/frontend/src/core/i18n/locales/zh-CN.ts index 462ef0bb9..f821bdc30 100644 --- a/frontend/src/core/i18n/locales/zh-CN.ts +++ b/frontend/src/core/i18n/locales/zh-CN.ts @@ -241,6 +241,9 @@ export const zhCN: Translations = { // Chats chats: { searchChats: "搜索对话", + loadMoreToSearch: "加载更多以搜索更早的对话", + loadingMore: "正在加载...", + loadOlderChats: "加载更早的对话", }, // Channels diff --git a/frontend/src/core/threads/hooks.ts b/frontend/src/core/threads/hooks.ts index b8f6ceee5..0f5b9c7b0 100644 --- a/frontend/src/core/threads/hooks.ts +++ b/frontend/src/core/threads/hooks.ts @@ -1,7 +1,10 @@ import type { AIMessage, Message, Run } from "@langchain/langgraph-sdk"; +import type { ThreadsClient } from "@langchain/langgraph-sdk/client"; import { useStream } from "@langchain/langgraph-sdk/react"; import { type QueryClient, + type InfiniteData, + useInfiniteQuery, useMutation, useQuery, useQueryClient, @@ -315,6 +318,56 @@ export function upsertThreadInSearchCache( ); } +export function upsertThreadInInfiniteCache( + queryClient: QueryClient, + thread: AgentThread, +) { + queryClient.setQueriesData( + { + queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX, + exact: false, + }, + (oldData: InfiniteData | undefined) => { + if (!oldData) { + return oldData; + } + + const merged = oldData.pages.map((page) => + page.map((t) => + t.thread_id === thread.thread_id + ? { + ...thread, + ...t, + metadata: { + ...(thread.metadata ?? {}), + ...(t.metadata ?? {}), + }, + values: { + ...thread.values, + ...t.values, + }, + } + : t, + ), + ); + + const exists = merged.some((page) => + page.some((t) => t.thread_id === thread.thread_id), + ); + if (exists) { + return { ...oldData, pages: merged }; + } + + const firstPage = merged[0] ?? []; + const restPages = merged.slice(1); + return { + ...oldData, + pages: [[thread, ...firstPage], ...restPages], + }; + }, + ); +} + function getStreamErrorMessage(error: unknown): string { if (typeof error === "string" && error.trim()) { return error; @@ -421,6 +474,19 @@ export function useThreadStream({ }, interrupts: {}, }); + upsertThreadInInfiniteCache(queryClient, { + thread_id: meta.thread_id, + created_at: now, + updated_at: now, + metadata: context.agent_name ? { agent_name: context.agent_name } : {}, + status: "busy", + values: { + title: t.pages.newChat, + messages: [], + artifacts: [], + }, + interrupts: {}, + }); if (context.agent_name && !isMock) { void getAPIClient() .threads.update(meta.thread_id, { @@ -492,6 +558,27 @@ export function useThreadStream({ }); }, ); + const nextTitle: string = update.title; + void queryClient.setQueriesData( + { + queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX, + exact: false, + }, + (oldData: InfiniteData | undefined) => + mapInfiniteThreadsCache( + oldData, + (t): AgentThread => + t.thread_id === threadIdRef.current + ? { + ...t, + values: { + ...t.values, + title: nextTitle, + }, + } + : t, + ), + ); } } }, @@ -546,6 +633,9 @@ export function useThreadStream({ .filter((id): id is string => Boolean(id)), ); void queryClient.invalidateQueries({ queryKey: ["threads", "search"] }); + void queryClient.invalidateQueries({ + queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX, + }); if (threadIdRef.current && !isMock) { void queryClient.invalidateQueries({ queryKey: threadTokenUsageQueryKey(threadIdRef.current), @@ -805,6 +895,9 @@ export function useThreadStream({ }, ); void queryClient.invalidateQueries({ queryKey: ["threads", "search"] }); + void queryClient.invalidateQueries({ + queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX, + }); } catch (error) { setOptimisticMessages([]); setIsUploading(false); @@ -1046,6 +1139,86 @@ export function useThreads( }); } +export const INFINITE_THREADS_PAGE_SIZE = 50; + +export const INFINITE_THREADS_QUERY_KEY_PREFIX = [ + "threads", + "searchInfinite", +] as const; + +type InfiniteThreadsParams = Omit< + Parameters[0], + "limit" | "offset" +>; + +export function getInfiniteThreadsNextPageParam( + lastPage: AgentThread[], + allPages: AgentThread[][], + pageSize: number = INFINITE_THREADS_PAGE_SIZE, +): number | undefined { + if (lastPage.length < pageSize) { + return undefined; + } + return allPages.reduce((sum, page) => sum + page.length, 0); +} + +export function mapInfiniteThreadsCache( + oldData: InfiniteData | undefined, + mapper: (thread: AgentThread) => AgentThread, +): InfiniteData | undefined { + if (!oldData) { + return oldData; + } + return { + ...oldData, + pages: oldData.pages.map((page) => page.map(mapper)), + }; +} + +export function filterInfiniteThreadsCache( + oldData: InfiniteData | undefined, + predicate: (thread: AgentThread) => boolean, +): InfiniteData | undefined { + if (!oldData) { + return oldData; + } + return { + ...oldData, + pages: oldData.pages.map((page) => page.filter(predicate)), + }; +} + +export function useInfiniteThreads( + params: InfiniteThreadsParams = { + sortBy: "updated_at", + sortOrder: "desc", + select: ["thread_id", "updated_at", "values", "metadata"], + }, +) { + const apiClient = getAPIClient(); + return useInfiniteQuery< + AgentThread[], + Error, + InfiniteData, + readonly unknown[], + number + >({ + queryKey: [...INFINITE_THREADS_QUERY_KEY_PREFIX, params], + initialPageParam: 0, + queryFn: async ({ pageParam }) => { + const response = (await apiClient.threads.search({ + ...params, + limit: INFINITE_THREADS_PAGE_SIZE, + offset: pageParam, + })) as AgentThread[]; + return response; + }, + getNextPageParam: (lastPage, allPages) => + getInfiniteThreadsNextPageParam(lastPage, allPages), + refetchOnWindowFocus: false, + }); +} + export function useThreadRuns( threadId?: string, { enabled = true }: { enabled?: boolean } = {}, @@ -1129,9 +1302,21 @@ export function useDeleteThread() { return oldData.filter((t) => t.thread_id !== threadId); }, ); + queryClient.setQueriesData( + { + queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX, + exact: false, + }, + (oldData: InfiniteData | undefined) => + filterInfiniteThreadsCache(oldData, (t) => t.thread_id !== threadId), + ); }, + onSettled() { void queryClient.invalidateQueries({ queryKey: ["threads", "search"] }); + void queryClient.invalidateQueries({ + queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX, + }); }, }); } @@ -1172,6 +1357,24 @@ export function useRenameThread() { }); }, ); + queryClient.setQueriesData( + { + queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX, + exact: false, + }, + (oldData: InfiniteData | undefined) => + mapInfiniteThreadsCache(oldData, (t) => + t.thread_id === threadId + ? { + ...t, + values: { + ...t.values, + title, + }, + } + : t, + ), + ); }, }); } diff --git a/frontend/tests/e2e/thread-list-infinite-scroll.spec.ts b/frontend/tests/e2e/thread-list-infinite-scroll.spec.ts new file mode 100644 index 000000000..f0d75ecc0 --- /dev/null +++ b/frontend/tests/e2e/thread-list-infinite-scroll.spec.ts @@ -0,0 +1,123 @@ +import { expect, test } from "@playwright/test"; + +import { mockLangGraphAPI } from "./utils/mock-api"; + +// Issue #3482: the sidebar's "Recent chats" and the /workspace/chats list +// page used to stop at the first 50 threads with no way to load more. +// `useInfiniteThreads()` + a sentinel near the bottom of each list now +// pages through the backend. + +const TOTAL_THREADS = 120; +const PAGE_SIZE = 50; + +const THREADS = Array.from({ length: TOTAL_THREADS }, (_, i) => { + // Pad index so titles sort deterministically as strings. The thread-search + // mock returns threads in the order provided, so paging boundaries are + // stable across runs. + const index = String(i + 1).padStart(3, "0"); + return { + thread_id: `00000000-0000-0000-0000-0000000${index.padStart(5, "0")}`, + title: `Conversation ${index}`, + updated_at: `2025-06-${String((i % 28) + 1).padStart(2, "0")}T12:00:00Z`, + }; +}); + +const FIRST_PAGE_LAST = `Conversation ${String(PAGE_SIZE).padStart(3, "0")}`; +const SECOND_PAGE_FIRST = `Conversation ${String(PAGE_SIZE + 1).padStart(3, "0")}`; + +test.describe("Thread list infinite scroll (issue #3482)", () => { + test("chats list page loads more threads when scrolling to the bottom", async ({ + page, + }) => { + mockLangGraphAPI(page, { threads: THREADS }); + + await page.goto("/workspace/chats"); + + const main = page.locator("main"); + + // First page renders. + await expect(main.getByText(FIRST_PAGE_LAST)).toBeVisible({ + timeout: 15_000, + }); + // Items past the first page have not been fetched yet. + await expect(main.getByText(SECOND_PAGE_FIRST)).toHaveCount(0); + + // Scrolling the sentinel into view triggers the next page. + const sentinel = page.getByTestId("chats-page-sentinel"); + await sentinel.scrollIntoViewIfNeeded(); + + await expect(main.getByText(SECOND_PAGE_FIRST)).toBeVisible({ + timeout: 15_000, + }); + }); + + test("sidebar recent chats loads more threads when scrolling to the bottom", async ({ + page, + }) => { + mockLangGraphAPI(page, { threads: THREADS }); + + await page.goto("/workspace/chats/new"); + + // The 50th thread (end of first page) appears in the sidebar. + await expect(page.getByText(FIRST_PAGE_LAST).first()).toBeVisible({ + timeout: 15_000, + }); + // The 51st has not been fetched yet. + await expect(page.getByText(SECOND_PAGE_FIRST)).toHaveCount(0); + + // Scroll the sidebar sentinel into view to trigger the next page. + const sentinel = page.getByTestId("recent-chat-list-sentinel"); + await sentinel.scrollIntoViewIfNeeded(); + + await expect(page.getByText(SECOND_PAGE_FIRST).first()).toBeVisible({ + timeout: 15_000, + }); + }); + + test("chats list page does NOT auto-paginate while a search filter is active", async ({ + page, + }) => { + // Count search requests via a passive request observer. Using + // page.route() here would race with mockLangGraphAPI's fulfill route + // (Playwright matches routes in reverse registration order), so the + // counter could miss real requests. page.on('request') is a pure + // observer and never interferes with routing. + let searchRequestCount = 0; + page.on("request", (request) => { + if (request.url().includes("/api/langgraph/threads/search")) { + searchRequestCount += 1; + } + }); + + mockLangGraphAPI(page, { threads: THREADS }); + + await page.goto("/workspace/chats"); + + // Wait for the first page to render so we have a baseline count. + await expect(page.locator("main").getByText(FIRST_PAGE_LAST)).toBeVisible({ + timeout: 15_000, + }); + const baselineRequests = searchRequestCount; + + // Type a query that matches nothing in the first page (and nothing at + // all, since titles are deterministic). + await page + .getByPlaceholder("Search chats") + .fill("zzz-no-such-conversation"); + + // The auto-sentinel must be gone; an explicit button takes its place. + await expect(page.getByTestId("chats-page-sentinel")).toHaveCount(0); + await expect(page.getByTestId("chats-page-load-more")).toBeVisible(); + + // Give the IntersectionObserver a couple of frames to misbehave if the + // guard regresses. No additional /threads/search calls should fire. + await page.waitForTimeout(500); + expect(searchRequestCount).toBe(baselineRequests); + + // The explicit button still works as an escape hatch. + await page.getByTestId("chats-page-load-more").click(); + await expect + .poll(() => searchRequestCount, { timeout: 10_000 }) + .toBeGreaterThan(baselineRequests); + }); +}); diff --git a/frontend/tests/e2e/utils/mock-api.ts b/frontend/tests/e2e/utils/mock-api.ts index 1fbe3f348..b476be1ab 100644 --- a/frontend/tests/e2e/utils/mock-api.ts +++ b/frontend/tests/e2e/utils/mock-api.ts @@ -86,7 +86,7 @@ export function mockLangGraphAPI(page: Page, options?: MockAPIOptions) { const skills = options?.skills ?? DEFAULT_SKILLS; // Thread search — sidebar thread list & chats list page - void page.route("**/api/langgraph/threads/search", (route) => { + void page.route("**/api/langgraph/threads/search", async (route) => { const body = threads.map((t) => ({ thread_id: t.thread_id, created_at: "2025-01-01T00:00:00Z", @@ -98,10 +98,33 @@ export function mockLangGraphAPI(page: Page, options?: MockAPIOptions) { status: "idle", values: { title: t.title ?? "Untitled" }, })); + + let limit: number | undefined; + let offset = 0; + try { + const postData = route.request().postDataJSON() as { + limit?: number; + offset?: number; + } | null; + if (postData) { + if (typeof postData.limit === "number") { + limit = postData.limit; + } + if (typeof postData.offset === "number") { + offset = postData.offset; + } + } + } catch { + // No / invalid JSON body — fall back to returning the full list. + } + + const sliced = + typeof limit === "number" ? body.slice(offset, offset + limit) : body; + return route.fulfill({ status: 200, contentType: "application/json", - body: JSON.stringify(body), + body: JSON.stringify(sliced), }); }); diff --git a/frontend/tests/unit/core/threads/infinite.test.ts b/frontend/tests/unit/core/threads/infinite.test.ts new file mode 100644 index 000000000..d040ff0a1 --- /dev/null +++ b/frontend/tests/unit/core/threads/infinite.test.ts @@ -0,0 +1,228 @@ +import { QueryClient, type InfiniteData } from "@tanstack/react-query"; +import { describe, expect, test } from "vitest"; + +import { + filterInfiniteThreadsCache, + getInfiniteThreadsNextPageParam, + INFINITE_THREADS_PAGE_SIZE, + INFINITE_THREADS_QUERY_KEY_PREFIX, + mapInfiniteThreadsCache, + upsertThreadInInfiniteCache, +} from "@/core/threads/hooks"; +import type { AgentThread } from "@/core/threads/types"; + +// Issue #3482: the sidebar and /workspace/chats list used to be capped at +// 50 threads because `useThreads()` exits as soon as `threads.length >= +// params.limit`. These pure helpers back the `useInfiniteThreads()` +// pagination logic and the mirrored cache writes that keep rename / delete +// / stream-finish in sync with both the legacy array cache and the new +// infinite cache. + +function makeThread(id: string, title = `Title ${id}`): AgentThread { + return { + thread_id: id, + created_at: "2025-01-01T00:00:00Z", + updated_at: "2025-01-01T00:00:00Z", + metadata: {}, + status: "idle", + values: { title }, + } as unknown as AgentThread; +} + +function makePage(start: number, size: number): AgentThread[] { + return Array.from({ length: size }, (_, i) => makeThread(`t-${start + i}`)); +} + +function makeInfiniteData(pages: AgentThread[][]): InfiniteData { + return { + pages, + pageParams: pages.map((_, i) => i * INFINITE_THREADS_PAGE_SIZE), + }; +} + +describe("getInfiniteThreadsNextPageParam", () => { + test("returns next offset when the last page is full", () => { + const page1 = makePage(0, INFINITE_THREADS_PAGE_SIZE); + expect(getInfiniteThreadsNextPageParam(page1, [page1])).toBe( + INFINITE_THREADS_PAGE_SIZE, + ); + }); + + test("returns next offset across multiple full pages", () => { + const page1 = makePage(0, INFINITE_THREADS_PAGE_SIZE); + const page2 = makePage( + INFINITE_THREADS_PAGE_SIZE, + INFINITE_THREADS_PAGE_SIZE, + ); + expect(getInfiniteThreadsNextPageParam(page2, [page1, page2])).toBe( + INFINITE_THREADS_PAGE_SIZE * 2, + ); + }); + + test("returns undefined when the last page is short (end of list)", () => { + const page1 = makePage(0, INFINITE_THREADS_PAGE_SIZE); + const page2 = makePage(INFINITE_THREADS_PAGE_SIZE, 10); + expect( + getInfiniteThreadsNextPageParam(page2, [page1, page2]), + ).toBeUndefined(); + }); + + test("returns undefined when the last page is empty", () => { + const page1 = makePage(0, INFINITE_THREADS_PAGE_SIZE); + expect(getInfiniteThreadsNextPageParam([], [page1, []])).toBeUndefined(); + }); + + test("respects a custom page size", () => { + const page1 = makePage(0, 5); + expect(getInfiniteThreadsNextPageParam(page1, [page1], 5)).toBe(5); + expect(getInfiniteThreadsNextPageParam(page1, [page1], 10)).toBeUndefined(); + }); +}); + +describe("mapInfiniteThreadsCache", () => { + test("returns undefined when oldData is undefined", () => { + expect(mapInfiniteThreadsCache(undefined, (t) => t)).toBeUndefined(); + }); + + test("updates the matching thread across multiple pages", () => { + const page1 = [makeThread("a"), makeThread("b")]; + const page2 = [makeThread("c"), makeThread("d")]; + const data = makeInfiniteData([page1, page2]); + + const updated = mapInfiniteThreadsCache(data, (t) => + t.thread_id === "c" + ? { ...t, values: { ...t.values, title: "renamed" } } + : t, + ); + + expect(updated?.pages[0]?.[0]?.values?.title).toBe("Title a"); + expect(updated?.pages[1]?.[0]?.thread_id).toBe("c"); + expect(updated?.pages[1]?.[0]?.values?.title).toBe("renamed"); + expect(updated?.pages[1]?.[1]?.values?.title).toBe("Title d"); + }); + + test("preserves pageParams", () => { + const data = makeInfiniteData([[makeThread("a")]]); + const updated = mapInfiniteThreadsCache(data, (t) => t); + expect(updated?.pageParams).toEqual(data.pageParams); + }); +}); + +describe("filterInfiniteThreadsCache", () => { + test("returns undefined when oldData is undefined", () => { + expect(filterInfiniteThreadsCache(undefined, () => true)).toBeUndefined(); + }); + + test("removes matching threads across all pages", () => { + const page1 = [makeThread("a"), makeThread("b")]; + const page2 = [makeThread("b"), makeThread("c")]; + const data = makeInfiniteData([page1, page2]); + + const filtered = filterInfiniteThreadsCache( + data, + (t) => t.thread_id !== "b", + ); + + expect(filtered?.pages[0]?.map((t) => t.thread_id)).toEqual(["a"]); + expect(filtered?.pages[1]?.map((t) => t.thread_id)).toEqual(["c"]); + }); + + test("keeps an emptied page as an empty array (does not drop the page)", () => { + const page1 = [makeThread("a")]; + const page2 = [makeThread("b")]; + const data = makeInfiniteData([page1, page2]); + + const filtered = filterInfiniteThreadsCache( + data, + (t) => t.thread_id !== "a", + ); + + expect(filtered?.pages).toHaveLength(2); + expect(filtered?.pages[0]).toEqual([]); + expect(filtered?.pages[1]?.[0]?.thread_id).toBe("b"); + }); + + test("does not regress next offset when an earlier page has been shrunk by a delete", () => { + // Simulate two full pages already loaded. + const page1 = Array.from({ length: 50 }, (_, i) => ({ + thread_id: `a${i}`, + })); + const page2 = Array.from({ length: 50 }, (_, i) => ({ + thread_id: `b${i}`, + })); + + // Offset right after fetching page 2 (this is the value TanStack Query + // freezes into pageParams). + const offsetAfterPage2 = getInfiniteThreadsNextPageParam( + page2 as unknown as AgentThread[], + [page1, page2] as unknown as AgentThread[][], + ); + expect(offsetAfterPage2).toBe(100); + + // Now a delete mutation runs filterInfiniteThreadsCache and shrinks + // page 1 from 50 to 49 entries. TanStack does NOT re-invoke + // getNextPageParam on cache mutations; the previously-computed offset + // (100) remains the param for the next fetchNextPage() call, so the + // helper is consistent with how the library uses its return value. + const shrunkPage1 = page1.slice(0, 49); + const recomputed = getInfiniteThreadsNextPageParam( + page2 as unknown as AgentThread[], + [shrunkPage1, page2] as unknown as AgentThread[][], + ); + // We document the recomputed value for completeness, but in practice + // useDeleteThread invalidates the query in onSettled, so pages are + // refetched from offset 0 rather than relying on this number. + expect(recomputed).toBe(99); + }); +}); + +describe("upsertThreadInInfiniteCache", () => { + function seedClient(initial?: InfiniteData): QueryClient { + const client = new QueryClient(); + if (initial) { + client.setQueryData([...INFINITE_THREADS_QUERY_KEY_PREFIX, {}], initial); + } + return client; + } + + function readCache( + client: QueryClient, + ): InfiniteData | undefined { + return client.getQueryData([...INFINITE_THREADS_QUERY_KEY_PREFIX, {}]); + } + + test("no-op when the infinite cache has not been initialised yet", () => { + const client = seedClient(); + upsertThreadInInfiniteCache(client, makeThread("new")); + expect(readCache(client)).toBeUndefined(); + }); + + test("prepends a brand-new thread to the first page", () => { + const client = seedClient({ + pages: [[makeThread("a"), makeThread("b")]], + pageParams: [0], + }); + upsertThreadInInfiniteCache(client, makeThread("new")); + const cache = readCache(client); + expect(cache?.pages[0]?.map((t) => t.thread_id)).toEqual(["new", "a", "b"]); + }); + + test("merges into the existing entry instead of duplicating it", () => { + const existing = makeThread("a", "Old title"); + const client = seedClient({ + pages: [[existing, makeThread("b")]], + pageParams: [0], + }); + // Simulate an onCreated upsert that races with a thread already in cache: + // the cache copy should win for title/metadata (it represents later state), + // but no duplicate row should appear. + upsertThreadInInfiniteCache(client, { + ...makeThread("a", "New title"), + status: "busy", + }); + const cache = readCache(client); + const ids = cache?.pages[0]?.map((t) => t.thread_id); + expect(ids).toEqual(["a", "b"]); + expect(cache?.pages[0]?.[0]?.values.title).toBe("Old title"); + }); +});