feat(memory): add memory.token_counting config to avoid tiktoken network dependency (#3429) (#3465)

* feat(memory): add memory.token_counting config to avoid tiktoken network dependency (#3429)

Add a `memory.token_counting` option (`tiktoken` | `char`) so deployments in
network-restricted environments can opt out of tiktoken entirely. In `char`
mode the memory-injection budget uses a network-free character-based estimate
and never triggers the BPE download from openaipublic.blob.core.windows.net,
which could otherwise block for tens of minutes (see #3402).

Also harden the default `tiktoken` path:
- cache an in-flight LOADING sentinel so concurrent callers fall back
  immediately instead of spawning more blocking get_encoding threads when the
  first load is still running (e.g. under the 5s startup warm-up timeout);
- cache failures with a timestamp and retry after a cooldown so a transient
  network outage self-heals back to accurate counting without a restart;
- skip startup warm-up entirely in char mode.

The new config is surfaced via the memory config API and config.example.yaml
(config_version bumped). Default remains `tiktoken`, so existing deployments
are unaffected.

* fix(memory): use CJK-aware char token estimate and address review feedback

- Replace the flat len(text)//4 fallback with a CJK-aware estimate so
  Chinese/Japanese/Korean memory content does not over-fill the injection budget
- Document the internal tiktoken retry cooldown and char-mode escape hatch
- Sync CLAUDE.md / config.example.yaml / MEMORY_IMPROVEMENTS.md wording
- Fix MemoryConfigResponse mocks/assertions and add CJK estimate tests
This commit is contained in:
Ryker_Feng
2026-06-10 23:26:15 +08:00
committed by GitHub
parent ba9cc5e972
commit 167ef4512f
13 changed files with 364 additions and 43 deletions
+7
View File
@@ -429,6 +429,12 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append 4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt 5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
**Token counting** (`packages/harness/deerflow/agents/memory/prompt.py`):
- `_count_tokens` budgets the injection. In default `tiktoken` mode, the encoding is loaded lazily and cached.
- Failed tiktoken loads are cached with a timestamp. During the fixed cooldown (`_TIKTOKEN_RETRY_COOLDOWN_S`, 600s), callers fall back to char estimation immediately instead of re-triggering the blocking BPE download; after the cooldown, transient outages can self-heal without a restart.
- In-flight loads are cached as a LOADING sentinel so concurrent callers fall back instead of spawning more blocking threads.
- Set `memory.token_counting: char` to skip tiktoken entirely and use the network-free CJK-aware char estimate.
Focused regression coverage for the updater lives in `backend/tests/test_memory_updater.py`. Focused regression coverage for the updater lives in `backend/tests/test_memory_updater.py`.
**Configuration** (`config.yaml``memory`): **Configuration** (`config.yaml``memory`):
@@ -438,6 +444,7 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_
- `model_name` - LLM for updates (null = default model) - `model_name` - LLM for updates (null = default model)
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7) - `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
- `max_injection_tokens` - Token limit for prompt injection (2000) - `max_injection_tokens` - Token limit for prompt injection (2000)
- `token_counting` - Token counting strategy for the injection budget: `tiktoken` (default, accurate but may download BPE data from a public endpoint on first use — can block for a long time in network-restricted environments, see issues #3402/#3429) or `char` (network-free CJK-aware char estimate, never touches tiktoken)
### Reflection System (`packages/harness/deerflow/reflection/`) ### Reflection System (`packages/harness/deerflow/reflection/`)
+20 -14
View File
@@ -184,21 +184,27 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# Pre-warm tiktoken encoding cache so the first memory-injection request # Pre-warm tiktoken encoding cache so the first memory-injection request
# never blocks on the BPE data download (which hits an OpenAI/Azure URL # never blocks on the BPE data download (which hits an OpenAI/Azure URL
# that may be unreachable in restricted networks — see issue #3402). # that may be unreachable in restricted networks — see issue #3402).
try: # When memory.token_counting is "char", token counting never touches
from deerflow.agents.memory.prompt import warm_tiktoken_cache # tiktoken, so skip the warm-up entirely (avoids even the 5s probe in
# network-restricted deployments — see issue #3429).
if startup_config.memory.token_counting == "char":
logger.info("memory.token_counting='char'; skipping tiktoken warm-up (network-free token estimation)")
else:
try:
from deerflow.agents.memory.prompt import warm_tiktoken_cache
warmed = await asyncio.wait_for( warmed = await asyncio.wait_for(
asyncio.to_thread(warm_tiktoken_cache), asyncio.to_thread(warm_tiktoken_cache),
timeout=5, timeout=5,
) )
if warmed: if warmed:
logger.info("tiktoken encoding cache warmed successfully") logger.info("tiktoken encoding cache warmed successfully")
else: else:
logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback") logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback until tiktoken loads successfully")
except TimeoutError: except TimeoutError:
logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback") logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback until tiktoken loads successfully")
except Exception: except Exception:
logger.warning("tiktoken warm-up skipped", exc_info=True) logger.warning("tiktoken warm-up skipped", exc_info=True)
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store) # Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
async with langgraph_runtime(app, startup_config): async with langgraph_runtime(app, startup_config):
+5 -1
View File
@@ -98,6 +98,7 @@ class MemoryConfigResponse(BaseModel):
fact_confidence_threshold: float = Field(..., description="Minimum confidence threshold for facts") fact_confidence_threshold: float = Field(..., description="Minimum confidence threshold for facts")
injection_enabled: bool = Field(..., description="Whether memory injection is enabled") injection_enabled: bool = Field(..., description="Whether memory injection is enabled")
max_injection_tokens: int = Field(..., description="Maximum tokens for memory injection") max_injection_tokens: int = Field(..., description="Maximum tokens for memory injection")
token_counting: str = Field(..., description="Token counting strategy for memory injection ('tiktoken' or 'char')")
class MemoryStatusResponse(BaseModel): class MemoryStatusResponse(BaseModel):
@@ -310,7 +311,8 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
"max_facts": 100, "max_facts": 100,
"fact_confidence_threshold": 0.7, "fact_confidence_threshold": 0.7,
"injection_enabled": true, "injection_enabled": true,
"max_injection_tokens": 2000 "max_injection_tokens": 2000,
"token_counting": "tiktoken"
} }
``` ```
""" """
@@ -323,6 +325,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
fact_confidence_threshold=config.fact_confidence_threshold, fact_confidence_threshold=config.fact_confidence_threshold,
injection_enabled=config.injection_enabled, injection_enabled=config.injection_enabled,
max_injection_tokens=config.max_injection_tokens, max_injection_tokens=config.max_injection_tokens,
token_counting=config.token_counting,
) )
@@ -351,6 +354,7 @@ async def get_memory_status() -> MemoryStatusResponse:
fact_confidence_threshold=config.fact_confidence_threshold, fact_confidence_threshold=config.fact_confidence_threshold,
injection_enabled=config.injection_enabled, injection_enabled=config.injection_enabled,
max_injection_tokens=config.max_injection_tokens, max_injection_tokens=config.max_injection_tokens,
token_counting=config.token_counting,
), ),
data=MemoryResponse(**memory_data), data=MemoryResponse(**memory_data),
) )
+2 -1
View File
@@ -31,7 +31,8 @@ Current injection format:
Token counting: Token counting:
- Uses `tiktoken` (`cl100k_base`) when available - Uses `tiktoken` (`cl100k_base`) when available
- Falls back to `len(text) // 4` if tokenizer import fails - Falls back to a network-free CJK-aware character estimate if tokenizer import or encoding load fails
(CJK characters count as ~2 chars/token, other characters as ~4 chars/token)
## Known Gap ## Known Gap
@@ -586,7 +586,11 @@ def _get_memory_context(agent_name: str | None = None, *, app_config: AppConfig
return "" return ""
memory_data = get_memory_data(agent_name, user_id=get_effective_user_id()) memory_data = get_memory_data(agent_name, user_id=get_effective_user_id())
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens) memory_content = format_memory_for_injection(
memory_data,
max_tokens=config.max_injection_tokens,
use_tiktoken=(config.token_counting == "tiktoken"),
)
if not memory_content.strip(): if not memory_content.strip():
return "" return ""
@@ -5,7 +5,9 @@ from __future__ import annotations
import logging import logging
import math import math
import re import re
from typing import Any import threading
import time
from typing import Any, cast
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -169,7 +171,26 @@ Return ONLY valid JSON."""
# subsequent calls are a dict lookup (no network I/O). Pre-warming at # subsequent calls are a dict lookup (no network I/O). Pre-warming at
# startup via :func:`warm_tiktoken_cache` avoids blocking a request on the # startup via :func:`warm_tiktoken_cache` avoids blocking a request on the
# (potentially slow) first ``get_encoding`` call. # (potentially slow) first ``get_encoding`` call.
_tiktoken_encoding_cache: dict[str, tiktoken.Encoding] = {} #
# A *failed* load is cached as a ``(None, monotonic_timestamp)`` tuple so that
# a network-restricted environment does not re-attempt the blocking BPE
# download on every subsequent call. After ``_TIKTOKEN_RETRY_COOLDOWN_S`` the
# failure is allowed to expire so a transient network outage can self-heal back
# to accurate tiktoken counting without a process restart. A load already in
# progress is cached as ``_TIKTOKEN_ENCODING_LOADING`` so concurrent callers
# fall back immediately instead of spawning more blocking
# ``tiktoken.get_encoding`` threads. Use the ``memory.token_counting: char``
# config to skip tiktoken entirely.
_TIKTOKEN_ENCODING_MISSING = object()
_TIKTOKEN_ENCODING_LOADING = object()
# Cooldown before a *failed* tiktoken load is re-attempted. This is an internal
# tuning constant rather than a user-facing config: it only affects how quickly
# the default ``tiktoken`` mode self-heals after a transient network outage.
# Deployments that want to avoid tiktoken's network dependency entirely should
# set ``memory.token_counting: char`` instead of tuning this value.
_TIKTOKEN_RETRY_COOLDOWN_S = 600.0
_tiktoken_encoding_cache: dict[str, Any] = {}
_tiktoken_encoding_cache_lock = threading.Lock()
def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encoding | None: def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encoding | None:
@@ -181,44 +202,91 @@ def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encod
download can block for tens of minutes before the OS TCP timeout kicks in. download can block for tens of minutes before the OS TCP timeout kicks in.
The caller must therefore be prepared for this to block and should run it The caller must therefore be prepared for this to block and should run it
off the event loop (e.g. via ``asyncio.to_thread``). off the event loop (e.g. via ``asyncio.to_thread``).
A failed load is remembered (with a timestamp) so subsequent calls fall
back immediately to character-based estimation instead of re-triggering the
blocking download. The failure expires after ``_TIKTOKEN_RETRY_COOLDOWN_S``
so a transient outage can self-heal without a restart. A load already in
progress is also remembered so that a timed-out caller does not leave a
window where later requests start more blocking ``get_encoding`` calls.
""" """
if not TIKTOKEN_AVAILABLE: if not TIKTOKEN_AVAILABLE:
return None return None
cached = _tiktoken_encoding_cache.get(encoding_name) with _tiktoken_encoding_cache_lock:
if cached is not None: cached = _tiktoken_encoding_cache.get(encoding_name, _TIKTOKEN_ENCODING_MISSING)
return cached if cached is _TIKTOKEN_ENCODING_LOADING:
return None
if isinstance(cached, tuple):
# Cached failure: (None, failed_at). Retry only after cooldown.
_, failed_at = cached
if time.monotonic() - failed_at < _TIKTOKEN_RETRY_COOLDOWN_S:
return None
cached = _TIKTOKEN_ENCODING_MISSING
if cached is not _TIKTOKEN_ENCODING_MISSING:
return cast("tiktoken.Encoding", cached)
_tiktoken_encoding_cache[encoding_name] = _TIKTOKEN_ENCODING_LOADING
try: try:
encoding = tiktoken.get_encoding(encoding_name) encoding = tiktoken.get_encoding(encoding_name)
_tiktoken_encoding_cache[encoding_name] = encoding
return encoding
except Exception: except Exception:
logger.warning("Failed to load tiktoken encoding %r; falling back to char-based estimation", encoding_name, exc_info=True) logger.warning("Failed to load tiktoken encoding %r; falling back to char-based estimation", encoding_name, exc_info=True)
with _tiktoken_encoding_cache_lock:
_tiktoken_encoding_cache[encoding_name] = (None, time.monotonic())
return None return None
with _tiktoken_encoding_cache_lock:
_tiktoken_encoding_cache[encoding_name] = encoding
return encoding
def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
def _char_based_token_estimate(text: str) -> int:
"""Network-free token estimate that accounts for CJK density.
The plain ``len(text) // 4`` heuristic is reasonable for English/code
(~4 chars per token) but significantly under-estimates token counts for
Chinese, Japanese, and Korean text, where the ratio is closer to 1.5-2
characters per token. Counting CJK characters separately (~2 chars per
token) avoids over-filling the injection budget for CJK-heavy memory
content.
"""
cjk = sum(
1
for ch in text
if "\u4e00" <= ch <= "\u9fff" # CJK Unified Ideographs
or "\u3040" <= ch <= "\u30ff" # Hiragana + Katakana
or "\uac00" <= ch <= "\ud7a3" # Hangul syllables
)
return (len(text) - cjk) // 4 + cjk // 2
def _count_tokens(text: str, encoding_name: str = "cl100k_base", *, use_tiktoken: bool = True) -> int:
"""Count tokens in text using tiktoken. """Count tokens in text using tiktoken.
Args: Args:
text: The text to count tokens for. text: The text to count tokens for.
encoding_name: The encoding to use (default: cl100k_base for GPT-4/3.5). encoding_name: The encoding to use (default: cl100k_base for GPT-4/3.5).
use_tiktoken: When ``False``, skip tiktoken entirely and use the
network-free character-based estimate. This guarantees no BPE
download is attempted (see ``memory.token_counting`` config).
Returns: Returns:
The number of tokens in the text. The number of tokens in the text.
""" """
if not use_tiktoken:
return _char_based_token_estimate(text)
encoding = _get_tiktoken_encoding(encoding_name) encoding = _get_tiktoken_encoding(encoding_name)
if encoding is None: if encoding is None:
# Fallback to character-based estimation if tiktoken is not available # Fallback to CJK-aware character estimation if tiktoken is not
# or the encoding failed to load. # available or the encoding failed to load.
return len(text) // 4 return _char_based_token_estimate(text)
try: try:
return len(encoding.encode(text)) return len(encoding.encode(text))
except Exception: except Exception:
# Fallback to character-based estimation on error # Fallback to CJK-aware character estimation on error.
return len(text) // 4 return _char_based_token_estimate(text)
def warm_tiktoken_cache() -> bool: def warm_tiktoken_cache() -> bool:
@@ -248,12 +316,15 @@ def _coerce_confidence(value: Any, default: float = 0.0) -> float:
return max(0.0, min(1.0, confidence)) return max(0.0, min(1.0, confidence))
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str: def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000, *, use_tiktoken: bool = True) -> str:
"""Format memory data for injection into system prompt. """Format memory data for injection into system prompt.
Args: Args:
memory_data: The memory data dictionary. memory_data: The memory data dictionary.
max_tokens: Maximum tokens to use (counted via tiktoken for accuracy). max_tokens: Maximum tokens to use (counted via tiktoken for accuracy).
use_tiktoken: When ``False``, all token counting uses the network-free
character-based estimate instead of tiktoken (see
``memory.token_counting`` config). Defaults to ``True``.
Returns: Returns:
Formatted memory string for system prompt injection. Formatted memory string for system prompt injection.
@@ -315,10 +386,10 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
# Compute token count for existing sections once, then account # Compute token count for existing sections once, then account
# incrementally for each fact line to avoid full-string re-tokenization. # incrementally for each fact line to avoid full-string re-tokenization.
base_text = "\n\n".join(sections) base_text = "\n\n".join(sections)
base_tokens = _count_tokens(base_text) if base_text else 0 base_tokens = _count_tokens(base_text, use_tiktoken=use_tiktoken) if base_text else 0
# Account for the separator between existing sections and the facts section. # Account for the separator between existing sections and the facts section.
facts_header = "Facts:\n" facts_header = "Facts:\n"
separator_tokens = _count_tokens("\n\n" + facts_header) if base_text else _count_tokens(facts_header) separator_tokens = _count_tokens("\n\n" + facts_header, use_tiktoken=use_tiktoken) if base_text else _count_tokens(facts_header, use_tiktoken=use_tiktoken)
running_tokens = base_tokens + separator_tokens running_tokens = base_tokens + separator_tokens
fact_lines: list[str] = [] fact_lines: list[str] = []
@@ -339,7 +410,7 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
# Each additional line is preceded by a newline (except the first). # Each additional line is preceded by a newline (except the first).
line_text = ("\n" + line) if fact_lines else line line_text = ("\n" + line) if fact_lines else line
line_tokens = _count_tokens(line_text) line_tokens = _count_tokens(line_text, use_tiktoken=use_tiktoken)
if running_tokens + line_tokens <= max_tokens: if running_tokens + line_tokens <= max_tokens:
fact_lines.append(line) fact_lines.append(line)
@@ -355,8 +426,9 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
result = "\n\n".join(sections) result = "\n\n".join(sections)
# Use accurate token counting with tiktoken # Use accurate token counting with tiktoken (or the char-based estimate
token_count = _count_tokens(result) # when use_tiktoken is False).
token_count = _count_tokens(result, use_tiktoken=use_tiktoken)
if token_count > max_tokens: if token_count > max_tokens:
# Truncate to fit within token limit # Truncate to fit within token limit
# Estimate characters to remove based on token ratio # Estimate characters to remove based on token ratio
@@ -1141,6 +1141,7 @@ class DeerFlowClient:
"fact_confidence_threshold": config.fact_confidence_threshold, "fact_confidence_threshold": config.fact_confidence_threshold,
"injection_enabled": config.injection_enabled, "injection_enabled": config.injection_enabled,
"max_injection_tokens": config.max_injection_tokens, "max_injection_tokens": config.max_injection_tokens,
"token_counting": config.token_counting,
} }
def get_memory_status(self) -> dict: def get_memory_status(self) -> dict:
@@ -1,5 +1,7 @@
"""Configuration for memory mechanism.""" """Configuration for memory mechanism."""
from typing import Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -60,6 +62,17 @@ class MemoryConfig(BaseModel):
le=8000, le=8000,
description="Maximum tokens to use for memory injection", description="Maximum tokens to use for memory injection",
) )
token_counting: Literal["tiktoken", "char"] = Field(
default="tiktoken",
description=(
"Token counting strategy for memory-injection budgeting. "
"'tiktoken' is accurate but the encoding's BPE data may be "
"downloaded from a public network endpoint on first use, which "
"can block for a long time in network-restricted environments "
"(see issue #3402/#3429). 'char' uses a network-free "
"CJK-aware character-based estimate and never touches tiktoken."
),
)
# Global configuration instance # Global configuration instance
+4
View File
@@ -2472,6 +2472,7 @@ class TestGatewayConformance:
mem_cfg.fact_confidence_threshold = 0.7 mem_cfg.fact_confidence_threshold = 0.7
mem_cfg.injection_enabled = True mem_cfg.injection_enabled = True
mem_cfg.max_injection_tokens = 2000 mem_cfg.max_injection_tokens = 2000
mem_cfg.token_counting = "tiktoken"
with patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg): with patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg):
result = client.get_memory_config() result = client.get_memory_config()
@@ -2479,6 +2480,7 @@ class TestGatewayConformance:
parsed = MemoryConfigResponse(**result) parsed = MemoryConfigResponse(**result)
assert parsed.enabled is True assert parsed.enabled is True
assert parsed.max_facts == 100 assert parsed.max_facts == 100
assert parsed.token_counting == "tiktoken"
def test_get_memory_status(self, client): def test_get_memory_status(self, client):
mem_cfg = MagicMock() mem_cfg = MagicMock()
@@ -2489,6 +2491,7 @@ class TestGatewayConformance:
mem_cfg.fact_confidence_threshold = 0.7 mem_cfg.fact_confidence_threshold = 0.7
mem_cfg.injection_enabled = True mem_cfg.injection_enabled = True
mem_cfg.max_injection_tokens = 2000 mem_cfg.max_injection_tokens = 2000
mem_cfg.token_counting = "tiktoken"
memory_data = { memory_data = {
"version": "1.0", "version": "1.0",
@@ -2514,6 +2517,7 @@ class TestGatewayConformance:
parsed = MemoryStatusResponse(**result) parsed = MemoryStatusResponse(**result)
assert parsed.config.enabled is True assert parsed.config.enabled is True
assert parsed.config.token_counting == "tiktoken"
assert parsed.data.version == "1.0" assert parsed.data.version == "1.0"
+4 -2
View File
@@ -192,7 +192,7 @@ def test_build_acp_section_uses_explicit_app_config_without_global_config(monkey
def test_get_memory_context_uses_explicit_app_config_without_global_config(monkeypatch): def test_get_memory_context_uses_explicit_app_config_without_global_config(monkeypatch):
explicit_config = SimpleNamespace( explicit_config = SimpleNamespace(
memory=SimpleNamespace(enabled=True, injection_enabled=True, max_injection_tokens=1234), memory=SimpleNamespace(enabled=True, injection_enabled=True, max_injection_tokens=1234, token_counting="tiktoken"),
) )
captured: dict[str, object] = {} captured: dict[str, object] = {}
@@ -204,9 +204,10 @@ def test_get_memory_context_uses_explicit_app_config_without_global_config(monke
captured["user_id"] = user_id captured["user_id"] = user_id
return {"facts": []} return {"facts": []}
def fake_format_memory_for_injection(memory_data, *, max_tokens): def fake_format_memory_for_injection(memory_data, *, max_tokens, use_tiktoken=True):
captured["memory_data"] = memory_data captured["memory_data"] = memory_data
captured["max_tokens"] = max_tokens captured["max_tokens"] = max_tokens
captured["use_tiktoken"] = use_tiktoken
return "remember this" return "remember this"
monkeypatch.setattr("deerflow.config.memory_config.get_memory_config", fail_get_memory_config) monkeypatch.setattr("deerflow.config.memory_config.get_memory_config", fail_get_memory_config)
@@ -223,6 +224,7 @@ def test_get_memory_context_uses_explicit_app_config_without_global_config(monke
"user_id": "user-1", "user_id": "user-1",
"memory_data": {"facts": []}, "memory_data": {"facts": []},
"max_tokens": 1234, "max_tokens": 1234,
"use_tiktoken": True,
} }
@@ -39,7 +39,7 @@ def test_format_memory_sorts_facts_by_confidence_desc() -> None:
def test_format_memory_respects_budget_when_adding_facts(monkeypatch) -> None: def test_format_memory_respects_budget_when_adding_facts(monkeypatch) -> None:
# Make token counting deterministic for this test by counting characters. # Make token counting deterministic for this test by counting characters.
monkeypatch.setattr("deerflow.agents.memory.prompt._count_tokens", lambda text, encoding_name="cl100k_base": len(text)) monkeypatch.setattr("deerflow.agents.memory.prompt._count_tokens", lambda text, encoding_name="cl100k_base", *, use_tiktoken=True: len(text))
memory_data = { memory_data = {
"user": {}, "user": {},
@@ -5,18 +5,22 @@ Verifies:
- ``_count_tokens`` falls back to character estimation when tiktoken is - ``_count_tokens`` falls back to character estimation when tiktoken is
unavailable or the encoding fails to load. unavailable or the encoding fails to load.
- ``warm_tiktoken_cache`` populates the cache on success. - ``warm_tiktoken_cache`` populates the cache on success.
- An in-flight tiktoken load prevents duplicate blocking downloads.
""" """
from __future__ import annotations from __future__ import annotations
import threading
from unittest import mock from unittest import mock
from deerflow.agents.memory.prompt import ( from deerflow.agents.memory.prompt import (
_count_tokens, _count_tokens,
_get_tiktoken_encoding, _get_tiktoken_encoding,
_tiktoken_encoding_cache, _tiktoken_encoding_cache,
format_memory_for_injection,
warm_tiktoken_cache, warm_tiktoken_cache,
) )
from deerflow.config.memory_config import MemoryConfig
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# _get_tiktoken_encoding # _get_tiktoken_encoding
@@ -62,14 +66,103 @@ class TestGetTiktokenEncoding:
assert enc is fake_enc assert enc is fake_enc
tiktoken.get_encoding.assert_not_called() tiktoken.get_encoding.assert_not_called()
def test_returns_none_and_warns_on_get_encoding_failure(self, monkeypatch): def test_returns_none_and_caches_failure_sentinel(self, monkeypatch):
"""A failed load is cached (with a timestamp) so it is not re-attempted (no repeated network download)."""
_tiktoken_encoding_cache.pop("bogus_encoding", None) _tiktoken_encoding_cache.pop("bogus_encoding", None)
import tiktoken import tiktoken
monkeypatch.setattr(tiktoken, "get_encoding", mock.Mock(side_effect=OSError("download failed"))) get_encoding = mock.Mock(side_effect=OSError("download failed"))
monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
result = _get_tiktoken_encoding("bogus_encoding") result = _get_tiktoken_encoding("bogus_encoding")
assert result is None assert result is None
assert "bogus_encoding" not in _tiktoken_encoding_cache # The failure is remembered as a (None, timestamp) tuple.
assert "bogus_encoding" in _tiktoken_encoding_cache
cached = _tiktoken_encoding_cache["bogus_encoding"]
assert isinstance(cached, tuple)
assert cached[0] is None
# A second call must NOT re-attempt get_encoding (avoids re-blocking on
# the network download in restricted environments — see #3429).
result2 = _get_tiktoken_encoding("bogus_encoding")
assert result2 is None
assert get_encoding.call_count == 1
# Cleanup module-level cache to avoid cross-test leakage.
_tiktoken_encoding_cache.pop("bogus_encoding", None)
def test_failure_self_heals_after_cooldown(self, monkeypatch):
"""After the retry cooldown expires, a transient failure is re-attempted and can recover."""
_tiktoken_encoding_cache.pop("flaky_encoding", None)
import tiktoken
fake_enc = mock.Mock()
# First call fails, second call (after cooldown) succeeds.
get_encoding = mock.Mock(side_effect=[OSError("transient outage"), fake_enc])
monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
# Initial failure is cached.
assert _get_tiktoken_encoding("flaky_encoding") is None
assert get_encoding.call_count == 1
# Within the cooldown window: no retry, immediate fallback.
assert _get_tiktoken_encoding("flaky_encoding") is None
assert get_encoding.call_count == 1
# Simulate the cooldown having elapsed by ageing the cached timestamp.
from deerflow.agents.memory import prompt as prompt_module
_, _failed_at = _tiktoken_encoding_cache["flaky_encoding"]
_tiktoken_encoding_cache["flaky_encoding"] = (
None,
_failed_at - prompt_module._TIKTOKEN_RETRY_COOLDOWN_S - 1,
)
# Now the load is retried and recovers to accurate counting.
assert _get_tiktoken_encoding("flaky_encoding") is fake_enc
assert get_encoding.call_count == 2
_tiktoken_encoding_cache.pop("flaky_encoding", None)
def test_in_flight_load_returns_none_without_duplicate_get_encoding(self, monkeypatch):
"""Concurrent callers must not start duplicate blocking BPE downloads."""
_tiktoken_encoding_cache.pop("slow_encoding", None)
import tiktoken
started = threading.Event()
release = threading.Event()
fake_enc = mock.Mock()
def slow_get_encoding(_name):
started.set()
assert release.wait(timeout=2), "test timed out waiting to release slow get_encoding"
return fake_enc
get_encoding = mock.Mock(side_effect=slow_get_encoding)
monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
result: dict[str, object | None] = {}
def load_encoding():
result["encoding"] = _get_tiktoken_encoding("slow_encoding")
thread = threading.Thread(target=load_encoding)
thread.start()
try:
assert started.wait(timeout=1), "slow get_encoding did not start"
# While the first call is still blocked, a second call should see
# the in-flight sentinel and fall back immediately instead of
# starting another potentially long network download.
assert _get_tiktoken_encoding("slow_encoding") is None
assert get_encoding.call_count == 1
finally:
release.set()
thread.join(timeout=2)
_tiktoken_encoding_cache.pop("slow_encoding", None)
assert result["encoding"] is fake_enc
assert get_encoding.call_count == 1
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -115,6 +208,45 @@ class TestCountTokens:
result = _count_tokens(text, encoding_name="test_enc") result = _count_tokens(text, encoding_name="test_enc")
assert result == len(text) // 4 assert result == len(text) // 4
def test_use_tiktoken_false_returns_char_estimate_without_touching_tiktoken(self, monkeypatch):
"""use_tiktoken=False must never call tiktoken (guarantees no BPE download)."""
# Spy on both the encoding loader and tiktoken.get_encoding directly.
get_encoding_spy = mock.Mock(side_effect=AssertionError("get_encoding must not be called"))
loader_spy = mock.Mock(side_effect=AssertionError("_get_tiktoken_encoding must not be called"))
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", get_encoding_spy)
monkeypatch.setattr("deerflow.agents.memory.prompt._get_tiktoken_encoding", loader_spy)
text = "Hello, world! This is a network-free count."
result = _count_tokens(text, use_tiktoken=False)
assert result == len(text) // 4
get_encoding_spy.assert_not_called()
loader_spy.assert_not_called()
def test_cjk_estimate_is_denser_than_plain_quarter(self, monkeypatch):
"""CJK text should estimate more tokens than the plain len // 4 heuristic.
CJK characters are ~2 chars/token, so the char-based estimate must not
under-fill the budget the way ``len(text) // 4`` would.
"""
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
# "User prefers concise answers" rendered in CJK (Chinese) characters.
text = "\u7528\u6237\u504f\u597d\u7b80\u6d01\u7684\u4e2d\u6587\u56de\u7b54\u5e76\u5173\u6ce8\u91d1\u878d\u9886\u57df"
result = _count_tokens(text)
# Each CJK char counts as ~1/2 token (vs 1/4 for the plain heuristic).
assert result == len(text) // 2
assert result > len(text) // 4
def test_cjk_estimate_combines_cjk_and_non_cjk_characters(self, monkeypatch):
"""Mixed-language text should apply the CJK density only to CJK chars."""
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
# ASCII words mixed with CJK (Chinese) characters: "User" + "likes" + "Python and data analysis".
text = "User\u559c\u6b22Python\u548c\u6570\u636e\u5206\u6790"
cjk = sum(1 for ch in text if "\u4e00" <= ch <= "\u9fff")
result = _count_tokens(text)
assert result == (len(text) - cjk) // 4 + cjk // 2
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# warm_tiktoken_cache # warm_tiktoken_cache
@@ -146,3 +278,69 @@ class TestWarmTiktokenCache:
def test_returns_false_when_tiktoken_unavailable(self, monkeypatch): def test_returns_false_when_tiktoken_unavailable(self, monkeypatch):
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False) monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
assert warm_tiktoken_cache() is False assert warm_tiktoken_cache() is False
# ---------------------------------------------------------------------------
# format_memory_for_injection token_counting strategy
# ---------------------------------------------------------------------------
class TestFormatMemoryForInjectionTokenCounting:
"""Verify the use_tiktoken flag is honoured end-to-end."""
@staticmethod
def _sample_memory() -> dict:
return {
"facts": [
{"content": "User prefers concise answers.", "category": "preference", "confidence": 0.9},
{"content": "User works in the finance domain.", "category": "context", "confidence": 0.8},
],
}
def test_use_tiktoken_false_never_touches_tiktoken(self, monkeypatch):
"""With use_tiktoken=False, formatting must not call tiktoken at all."""
get_encoding_spy = mock.Mock(side_effect=AssertionError("get_encoding must not be called"))
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", get_encoding_spy)
result = format_memory_for_injection(self._sample_memory(), max_tokens=2000, use_tiktoken=False)
assert "User prefers concise answers." in result
get_encoding_spy.assert_not_called()
def test_use_tiktoken_true_uses_encoding(self, monkeypatch):
"""With use_tiktoken=True (default), the cached encoding is used for counting."""
fake_enc = mock.Mock()
fake_enc.encode.side_effect = lambda text: list(range(len(text)))
monkeypatch.setattr(
"deerflow.agents.memory.prompt._get_tiktoken_encoding",
mock.Mock(return_value=fake_enc),
)
result = format_memory_for_injection(self._sample_memory(), max_tokens=2000, use_tiktoken=True)
assert "User prefers concise answers." in result
assert fake_enc.encode.called
def test_empty_memory_returns_empty(self):
assert format_memory_for_injection({}, max_tokens=2000, use_tiktoken=False) == ""
# ---------------------------------------------------------------------------
# MemoryConfig.token_counting
# ---------------------------------------------------------------------------
class TestMemoryConfigTokenCounting:
"""Verify the new config field defaults and validation."""
def test_default_is_tiktoken(self):
"""Default must remain tiktoken so existing deployments are unaffected."""
assert MemoryConfig().token_counting == "tiktoken"
def test_accepts_char(self):
assert MemoryConfig(token_counting="char").token_counting == "char"
def test_rejects_invalid_value(self):
import pytest
from pydantic import ValidationError
with pytest.raises(ValidationError):
MemoryConfig(token_counting="invalid")
+10 -1
View File
@@ -15,7 +15,7 @@
# ============================================================================ # ============================================================================
# Bump this number when the config schema changes. # Bump this number when the config schema changes.
# Run `make config-upgrade` to merge new fields into your local config.yaml. # Run `make config-upgrade` to merge new fields into your local config.yaml.
config_version: 11 config_version: 12
# ============================================================================ # ============================================================================
# Logging # Logging
@@ -1024,6 +1024,15 @@ memory:
fact_confidence_threshold: 0.7 # Minimum confidence for storing facts fact_confidence_threshold: 0.7 # Minimum confidence for storing facts
injection_enabled: true # Whether to inject memory into system prompt injection_enabled: true # Whether to inject memory into system prompt
max_injection_tokens: 2000 # Maximum tokens for memory injection max_injection_tokens: 2000 # Maximum tokens for memory injection
# Token counting strategy for memory-injection budgeting:
# tiktoken (default) - accurate, but the encoding's BPE data may be
# downloaded from a public network endpoint on first use. In
# network-restricted environments this download can block for a long
# time (see issues #3402 / #3429). Pre-cache the encoding or set this
# to "char" to avoid it.
# char - network-free CJK-aware character-based estimate; never touches
# tiktoken. Slightly less precise budgeting, zero network I/O.
token_counting: tiktoken
# ============================================================================ # ============================================================================
# Custom Agent Management API # Custom Agent Management API