mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 01:45:58 +00:00
* feat(memory): add memory.token_counting config to avoid tiktoken network dependency (#3429) Add a `memory.token_counting` option (`tiktoken` | `char`) so deployments in network-restricted environments can opt out of tiktoken entirely. In `char` mode the memory-injection budget uses a network-free character-based estimate and never triggers the BPE download from openaipublic.blob.core.windows.net, which could otherwise block for tens of minutes (see #3402). Also harden the default `tiktoken` path: - cache an in-flight LOADING sentinel so concurrent callers fall back immediately instead of spawning more blocking get_encoding threads when the first load is still running (e.g. under the 5s startup warm-up timeout); - cache failures with a timestamp and retry after a cooldown so a transient network outage self-heals back to accurate counting without a restart; - skip startup warm-up entirely in char mode. The new config is surfaced via the memory config API and config.example.yaml (config_version bumped). Default remains `tiktoken`, so existing deployments are unaffected. * fix(memory): use CJK-aware char token estimate and address review feedback - Replace the flat len(text)//4 fallback with a CJK-aware estimate so Chinese/Japanese/Korean memory content does not over-fill the injection budget - Document the internal tiktoken retry cooldown and char-mode escape hatch - Sync CLAUDE.md / config.example.yaml / MEMORY_IMPROVEMENTS.md wording - Fix MemoryConfigResponse mocks/assertions and add CJK estimate tests
This commit is contained in:
@@ -5,7 +5,9 @@ from __future__ import annotations
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from typing import Any
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -169,7 +171,26 @@ Return ONLY valid JSON."""
|
||||
# subsequent calls are a dict lookup (no network I/O). Pre-warming at
|
||||
# startup via :func:`warm_tiktoken_cache` avoids blocking a request on the
|
||||
# (potentially slow) first ``get_encoding`` call.
|
||||
_tiktoken_encoding_cache: dict[str, tiktoken.Encoding] = {}
|
||||
#
|
||||
# A *failed* load is cached as a ``(None, monotonic_timestamp)`` tuple so that
|
||||
# a network-restricted environment does not re-attempt the blocking BPE
|
||||
# download on every subsequent call. After ``_TIKTOKEN_RETRY_COOLDOWN_S`` the
|
||||
# failure is allowed to expire so a transient network outage can self-heal back
|
||||
# to accurate tiktoken counting without a process restart. A load already in
|
||||
# progress is cached as ``_TIKTOKEN_ENCODING_LOADING`` so concurrent callers
|
||||
# fall back immediately instead of spawning more blocking
|
||||
# ``tiktoken.get_encoding`` threads. Use the ``memory.token_counting: char``
|
||||
# config to skip tiktoken entirely.
|
||||
_TIKTOKEN_ENCODING_MISSING = object()
|
||||
_TIKTOKEN_ENCODING_LOADING = object()
|
||||
# Cooldown before a *failed* tiktoken load is re-attempted. This is an internal
|
||||
# tuning constant rather than a user-facing config: it only affects how quickly
|
||||
# the default ``tiktoken`` mode self-heals after a transient network outage.
|
||||
# Deployments that want to avoid tiktoken's network dependency entirely should
|
||||
# set ``memory.token_counting: char`` instead of tuning this value.
|
||||
_TIKTOKEN_RETRY_COOLDOWN_S = 600.0
|
||||
_tiktoken_encoding_cache: dict[str, Any] = {}
|
||||
_tiktoken_encoding_cache_lock = threading.Lock()
|
||||
|
||||
|
||||
def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encoding | None:
|
||||
@@ -181,44 +202,91 @@ def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encod
|
||||
download can block for tens of minutes before the OS TCP timeout kicks in.
|
||||
The caller must therefore be prepared for this to block and should run it
|
||||
off the event loop (e.g. via ``asyncio.to_thread``).
|
||||
|
||||
A failed load is remembered (with a timestamp) so subsequent calls fall
|
||||
back immediately to character-based estimation instead of re-triggering the
|
||||
blocking download. The failure expires after ``_TIKTOKEN_RETRY_COOLDOWN_S``
|
||||
so a transient outage can self-heal without a restart. A load already in
|
||||
progress is also remembered so that a timed-out caller does not leave a
|
||||
window where later requests start more blocking ``get_encoding`` calls.
|
||||
"""
|
||||
if not TIKTOKEN_AVAILABLE:
|
||||
return None
|
||||
|
||||
cached = _tiktoken_encoding_cache.get(encoding_name)
|
||||
if cached is not None:
|
||||
return cached
|
||||
with _tiktoken_encoding_cache_lock:
|
||||
cached = _tiktoken_encoding_cache.get(encoding_name, _TIKTOKEN_ENCODING_MISSING)
|
||||
if cached is _TIKTOKEN_ENCODING_LOADING:
|
||||
return None
|
||||
if isinstance(cached, tuple):
|
||||
# Cached failure: (None, failed_at). Retry only after cooldown.
|
||||
_, failed_at = cached
|
||||
if time.monotonic() - failed_at < _TIKTOKEN_RETRY_COOLDOWN_S:
|
||||
return None
|
||||
cached = _TIKTOKEN_ENCODING_MISSING
|
||||
if cached is not _TIKTOKEN_ENCODING_MISSING:
|
||||
return cast("tiktoken.Encoding", cached)
|
||||
_tiktoken_encoding_cache[encoding_name] = _TIKTOKEN_ENCODING_LOADING
|
||||
|
||||
try:
|
||||
encoding = tiktoken.get_encoding(encoding_name)
|
||||
_tiktoken_encoding_cache[encoding_name] = encoding
|
||||
return encoding
|
||||
except Exception:
|
||||
logger.warning("Failed to load tiktoken encoding %r; falling back to char-based estimation", encoding_name, exc_info=True)
|
||||
with _tiktoken_encoding_cache_lock:
|
||||
_tiktoken_encoding_cache[encoding_name] = (None, time.monotonic())
|
||||
return None
|
||||
|
||||
with _tiktoken_encoding_cache_lock:
|
||||
_tiktoken_encoding_cache[encoding_name] = encoding
|
||||
return encoding
|
||||
|
||||
def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
|
||||
|
||||
def _char_based_token_estimate(text: str) -> int:
|
||||
"""Network-free token estimate that accounts for CJK density.
|
||||
|
||||
The plain ``len(text) // 4`` heuristic is reasonable for English/code
|
||||
(~4 chars per token) but significantly under-estimates token counts for
|
||||
Chinese, Japanese, and Korean text, where the ratio is closer to 1.5-2
|
||||
characters per token. Counting CJK characters separately (~2 chars per
|
||||
token) avoids over-filling the injection budget for CJK-heavy memory
|
||||
content.
|
||||
"""
|
||||
cjk = sum(
|
||||
1
|
||||
for ch in text
|
||||
if "\u4e00" <= ch <= "\u9fff" # CJK Unified Ideographs
|
||||
or "\u3040" <= ch <= "\u30ff" # Hiragana + Katakana
|
||||
or "\uac00" <= ch <= "\ud7a3" # Hangul syllables
|
||||
)
|
||||
return (len(text) - cjk) // 4 + cjk // 2
|
||||
|
||||
|
||||
def _count_tokens(text: str, encoding_name: str = "cl100k_base", *, use_tiktoken: bool = True) -> int:
|
||||
"""Count tokens in text using tiktoken.
|
||||
|
||||
Args:
|
||||
text: The text to count tokens for.
|
||||
encoding_name: The encoding to use (default: cl100k_base for GPT-4/3.5).
|
||||
use_tiktoken: When ``False``, skip tiktoken entirely and use the
|
||||
network-free character-based estimate. This guarantees no BPE
|
||||
download is attempted (see ``memory.token_counting`` config).
|
||||
|
||||
Returns:
|
||||
The number of tokens in the text.
|
||||
"""
|
||||
if not use_tiktoken:
|
||||
return _char_based_token_estimate(text)
|
||||
|
||||
encoding = _get_tiktoken_encoding(encoding_name)
|
||||
if encoding is None:
|
||||
# Fallback to character-based estimation if tiktoken is not available
|
||||
# or the encoding failed to load.
|
||||
return len(text) // 4
|
||||
# Fallback to CJK-aware character estimation if tiktoken is not
|
||||
# available or the encoding failed to load.
|
||||
return _char_based_token_estimate(text)
|
||||
|
||||
try:
|
||||
return len(encoding.encode(text))
|
||||
except Exception:
|
||||
# Fallback to character-based estimation on error
|
||||
return len(text) // 4
|
||||
# Fallback to CJK-aware character estimation on error.
|
||||
return _char_based_token_estimate(text)
|
||||
|
||||
|
||||
def warm_tiktoken_cache() -> bool:
|
||||
@@ -248,12 +316,15 @@ def _coerce_confidence(value: Any, default: float = 0.0) -> float:
|
||||
return max(0.0, min(1.0, confidence))
|
||||
|
||||
|
||||
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str:
|
||||
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000, *, use_tiktoken: bool = True) -> str:
|
||||
"""Format memory data for injection into system prompt.
|
||||
|
||||
Args:
|
||||
memory_data: The memory data dictionary.
|
||||
max_tokens: Maximum tokens to use (counted via tiktoken for accuracy).
|
||||
use_tiktoken: When ``False``, all token counting uses the network-free
|
||||
character-based estimate instead of tiktoken (see
|
||||
``memory.token_counting`` config). Defaults to ``True``.
|
||||
|
||||
Returns:
|
||||
Formatted memory string for system prompt injection.
|
||||
@@ -315,10 +386,10 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
||||
# Compute token count for existing sections once, then account
|
||||
# incrementally for each fact line to avoid full-string re-tokenization.
|
||||
base_text = "\n\n".join(sections)
|
||||
base_tokens = _count_tokens(base_text) if base_text else 0
|
||||
base_tokens = _count_tokens(base_text, use_tiktoken=use_tiktoken) if base_text else 0
|
||||
# Account for the separator between existing sections and the facts section.
|
||||
facts_header = "Facts:\n"
|
||||
separator_tokens = _count_tokens("\n\n" + facts_header) if base_text else _count_tokens(facts_header)
|
||||
separator_tokens = _count_tokens("\n\n" + facts_header, use_tiktoken=use_tiktoken) if base_text else _count_tokens(facts_header, use_tiktoken=use_tiktoken)
|
||||
running_tokens = base_tokens + separator_tokens
|
||||
|
||||
fact_lines: list[str] = []
|
||||
@@ -339,7 +410,7 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
||||
|
||||
# Each additional line is preceded by a newline (except the first).
|
||||
line_text = ("\n" + line) if fact_lines else line
|
||||
line_tokens = _count_tokens(line_text)
|
||||
line_tokens = _count_tokens(line_text, use_tiktoken=use_tiktoken)
|
||||
|
||||
if running_tokens + line_tokens <= max_tokens:
|
||||
fact_lines.append(line)
|
||||
@@ -355,8 +426,9 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
||||
|
||||
result = "\n\n".join(sections)
|
||||
|
||||
# Use accurate token counting with tiktoken
|
||||
token_count = _count_tokens(result)
|
||||
# Use accurate token counting with tiktoken (or the char-based estimate
|
||||
# when use_tiktoken is False).
|
||||
token_count = _count_tokens(result, use_tiktoken=use_tiktoken)
|
||||
if token_count > max_tokens:
|
||||
# Truncate to fit within token limit
|
||||
# Estimate characters to remove based on token ratio
|
||||
|
||||
Reference in New Issue
Block a user