Files
deer-flow/backend/tests/test_tiktoken_cache_and_count_tokens.py
T
Ryker_Feng 167ef4512f feat(memory): add memory.token_counting config to avoid tiktoken network dependency (#3429) (#3465)
* feat(memory): add memory.token_counting config to avoid tiktoken network dependency (#3429)

Add a `memory.token_counting` option (`tiktoken` | `char`) so deployments in
network-restricted environments can opt out of tiktoken entirely. In `char`
mode the memory-injection budget uses a network-free character-based estimate
and never triggers the BPE download from openaipublic.blob.core.windows.net,
which could otherwise block for tens of minutes (see #3402).

Also harden the default `tiktoken` path:
- cache an in-flight LOADING sentinel so concurrent callers fall back
  immediately instead of spawning more blocking get_encoding threads when the
  first load is still running (e.g. under the 5s startup warm-up timeout);
- cache failures with a timestamp and retry after a cooldown so a transient
  network outage self-heals back to accurate counting without a restart;
- skip startup warm-up entirely in char mode.

The new config is surfaced via the memory config API and config.example.yaml
(config_version bumped). Default remains `tiktoken`, so existing deployments
are unaffected.

* fix(memory): use CJK-aware char token estimate and address review feedback

- Replace the flat len(text)//4 fallback with a CJK-aware estimate so
  Chinese/Japanese/Korean memory content does not over-fill the injection budget
- Document the internal tiktoken retry cooldown and char-mode escape hatch
- Sync CLAUDE.md / config.example.yaml / MEMORY_IMPROVEMENTS.md wording
- Fix MemoryConfigResponse mocks/assertions and add CJK estimate tests
2026-06-10 23:26:15 +08:00

347 lines
14 KiB
Python

"""Tests for tiktoken encoding cache and _count_tokens fallback.
Verifies:
- Module-level cache avoids repeated ``get_encoding`` calls.
- ``_count_tokens`` falls back to character estimation when tiktoken is
unavailable or the encoding fails to load.
- ``warm_tiktoken_cache`` populates the cache on success.
- An in-flight tiktoken load prevents duplicate blocking downloads.
"""
from __future__ import annotations
import threading
from unittest import mock
from deerflow.agents.memory.prompt import (
_count_tokens,
_get_tiktoken_encoding,
_tiktoken_encoding_cache,
format_memory_for_injection,
warm_tiktoken_cache,
)
from deerflow.config.memory_config import MemoryConfig
# ---------------------------------------------------------------------------
# _get_tiktoken_encoding
# ---------------------------------------------------------------------------
class TestGetTiktokenEncoding:
    """Tests for _get_tiktoken_encoding caching and fallback.

    Every test that lets the loader write into the module-level
    ``_tiktoken_encoding_cache`` first calls ``_evict``.  That registers the
    key with ``monkeypatch`` so the entry's pre-test state is restored at
    teardown, even when an assertion fails part-way through the test, and no
    mock encoding leaks into later tests.
    """

    @staticmethod
    def _evict(monkeypatch, name):
        """Remove ``name`` from the encoding cache for the current test only.

        ``monkeypatch.setitem`` snapshots the entry (its value, or the fact
        that it is absent) before writing the placeholder, so teardown puts
        the cache back exactly as it was.  The placeholder is deleted straight
        away so the loader under test sees a cold cache.
        """
        monkeypatch.setitem(_tiktoken_encoding_cache, name, None)
        del _tiktoken_encoding_cache[name]

    def test_returns_none_when_tiktoken_unavailable(self, monkeypatch):
        monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
        assert _get_tiktoken_encoding("cl100k_base") is None

    def test_returns_encoding_on_success(self, monkeypatch):
        # Start from a cold cache so the loader really calls get_encoding.
        self._evict(monkeypatch, "cl100k_base")
        fake_enc = mock.Mock()
        monkeypatch.setattr(
            "deerflow.agents.memory.prompt.tiktoken.get_encoding",
            mock.Mock(return_value=fake_enc),
        )
        enc = _get_tiktoken_encoding("cl100k_base")
        assert enc is fake_enc

    def test_populates_cache_on_success(self, monkeypatch):
        self._evict(monkeypatch, "cl100k_base")
        fake_enc = mock.Mock()
        monkeypatch.setattr(
            "deerflow.agents.memory.prompt.tiktoken.get_encoding",
            mock.Mock(return_value=fake_enc),
        )
        _get_tiktoken_encoding("cl100k_base")
        assert _tiktoken_encoding_cache["cl100k_base"] is fake_enc

    def test_returns_cached_encoding_without_calling_get_encoding(self, monkeypatch):
        fake_enc = mock.Mock()
        monkeypatch.setitem(_tiktoken_encoding_cache, "cl100k_base", fake_enc)
        # Now patch tiktoken.get_encoding to raise if called
        import tiktoken

        monkeypatch.setattr(tiktoken, "get_encoding", mock.Mock(side_effect=RuntimeError("should not be called")))
        # Cached path — should NOT call get_encoding
        enc = _get_tiktoken_encoding("cl100k_base")
        assert enc is fake_enc
        tiktoken.get_encoding.assert_not_called()

    def test_returns_none_and_caches_failure_sentinel(self, monkeypatch):
        """A failed load is cached (with a timestamp) so it is not re-attempted (no repeated network download)."""
        # Teardown removes the failure entry again, whether or not the
        # assertions below pass.
        self._evict(monkeypatch, "bogus_encoding")
        import tiktoken

        get_encoding = mock.Mock(side_effect=OSError("download failed"))
        monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
        result = _get_tiktoken_encoding("bogus_encoding")
        assert result is None
        # The failure is remembered as a (None, timestamp) tuple.
        assert "bogus_encoding" in _tiktoken_encoding_cache
        cached = _tiktoken_encoding_cache["bogus_encoding"]
        assert isinstance(cached, tuple)
        assert cached[0] is None
        # A second call must NOT re-attempt get_encoding (avoids re-blocking on
        # the network download in restricted environments — see #3429).
        result2 = _get_tiktoken_encoding("bogus_encoding")
        assert result2 is None
        assert get_encoding.call_count == 1

    def test_failure_self_heals_after_cooldown(self, monkeypatch):
        """After the retry cooldown expires, a transient failure is re-attempted and can recover."""
        self._evict(monkeypatch, "flaky_encoding")
        import tiktoken

        fake_enc = mock.Mock()
        # First call fails, second call (after cooldown) succeeds.
        get_encoding = mock.Mock(side_effect=[OSError("transient outage"), fake_enc])
        monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
        # Initial failure is cached.
        assert _get_tiktoken_encoding("flaky_encoding") is None
        assert get_encoding.call_count == 1
        # Within the cooldown window: no retry, immediate fallback.
        assert _get_tiktoken_encoding("flaky_encoding") is None
        assert get_encoding.call_count == 1
        # Simulate the cooldown having elapsed by ageing the cached timestamp.
        from deerflow.agents.memory import prompt as prompt_module

        _, _failed_at = _tiktoken_encoding_cache["flaky_encoding"]
        _tiktoken_encoding_cache["flaky_encoding"] = (
            None,
            _failed_at - prompt_module._TIKTOKEN_RETRY_COOLDOWN_S - 1,
        )
        # Now the load is retried and recovers to accurate counting.
        assert _get_tiktoken_encoding("flaky_encoding") is fake_enc
        assert get_encoding.call_count == 2

    def test_in_flight_load_returns_none_without_duplicate_get_encoding(self, monkeypatch):
        """Concurrent callers must not start duplicate blocking BPE downloads."""
        self._evict(monkeypatch, "slow_encoding")
        import tiktoken

        started = threading.Event()
        release = threading.Event()
        fake_enc = mock.Mock()

        def slow_get_encoding(_name):
            # Signal that the load has begun, then block until the test
            # releases it, imitating a long-running network download.
            started.set()
            assert release.wait(timeout=2), "test timed out waiting to release slow get_encoding"
            return fake_enc

        get_encoding = mock.Mock(side_effect=slow_get_encoding)
        monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
        result: dict[str, object | None] = {}

        def load_encoding():
            result["encoding"] = _get_tiktoken_encoding("slow_encoding")

        thread = threading.Thread(target=load_encoding)
        thread.start()
        try:
            assert started.wait(timeout=1), "slow get_encoding did not start"
            # While the first call is still blocked, a second call should see
            # the in-flight sentinel and fall back immediately instead of
            # starting another potentially long network download.
            assert _get_tiktoken_encoding("slow_encoding") is None
            assert get_encoding.call_count == 1
        finally:
            # Always unblock and reap the worker so a failed assertion above
            # cannot leave a thread parked on the event.
            release.set()
            thread.join(timeout=2)
        # join() with a timeout returns silently on expiry; fail with a clear
        # message instead of a KeyError on ``result`` if the loader is stuck.
        assert not thread.is_alive(), "loader thread did not finish after release"
        assert result["encoding"] is fake_enc
        assert get_encoding.call_count == 1
# ---------------------------------------------------------------------------
# _count_tokens
# ---------------------------------------------------------------------------
class TestCountTokens:
    """Exercise ``_count_tokens`` on its tiktoken and character-estimate paths."""

    def test_returns_character_estimate_when_tiktoken_unavailable(self, monkeypatch):
        monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
        sample = "Hello, world! This is a test."
        assert _count_tokens(sample) == len(sample) // 4

    def test_returns_character_estimate_when_encoding_fails(self, monkeypatch):
        def no_encoding(_name=None):
            # Stand-in loader that always reports "encoding unavailable".
            return None

        monkeypatch.setattr("deerflow.agents.memory.prompt._get_tiktoken_encoding", no_encoding)
        sample = "Some text to count"
        assert _count_tokens(sample) == len(sample) // 4

    def test_returns_token_count_on_success(self, monkeypatch):
        encoder = mock.Mock()
        encoder.encode.return_value = list(range(4))
        loader = mock.Mock(return_value=encoder)
        monkeypatch.setattr("deerflow.agents.memory.prompt._get_tiktoken_encoding", loader)
        sample = "Hello, world!"
        counted = _count_tokens(sample)
        # The count is the length of whatever the encoder returned.
        assert counted == 4
        assert counted <= len(sample)

    def test_falls_back_on_encode_exception(self, monkeypatch):
        # Seed the cache with an encoding whose .encode always raises.
        broken_encoder = mock.Mock()
        broken_encoder.encode.side_effect = RuntimeError("encode failed")
        monkeypatch.setitem(_tiktoken_encoding_cache, "test_enc", broken_encoder)
        sample = "Fallback test"
        assert _count_tokens(sample, encoding_name="test_enc") == len(sample) // 4

    def test_use_tiktoken_false_returns_char_estimate_without_touching_tiktoken(self, monkeypatch):
        """use_tiktoken=False must bypass tiktoken completely (so no BPE download can start)."""
        # Both the raw tiktoken entry point and the cached loader are replaced
        # with spies that blow up if anything calls them.
        loader_spy = mock.Mock(side_effect=AssertionError("_get_tiktoken_encoding must not be called"))
        get_encoding_spy = mock.Mock(side_effect=AssertionError("get_encoding must not be called"))
        monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", get_encoding_spy)
        monkeypatch.setattr("deerflow.agents.memory.prompt._get_tiktoken_encoding", loader_spy)
        sample = "Hello, world! This is a network-free count."
        assert _count_tokens(sample, use_tiktoken=False) == len(sample) // 4
        get_encoding_spy.assert_not_called()
        loader_spy.assert_not_called()

    def test_cjk_estimate_is_denser_than_plain_quarter(self, monkeypatch):
        """CJK text must yield a higher estimate than the plain ``len // 4`` heuristic.

        The estimate treats CJK characters as ~2 chars/token.  Counting them
        at 4 chars/token would undercount and let too much memory content
        through, over-filling the injection budget.
        """
        monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
        # A Chinese sentence consisting solely of CJK ideographs.
        sample = "\u7528\u6237\u504f\u597d\u7b80\u6d01\u7684\u4e2d\u6587\u56de\u7b54\u5e76\u5173\u6ce8\u91d1\u878d\u9886\u57df"
        estimate = _count_tokens(sample)
        quarter = len(sample) // 4
        # Half a token per CJK character, versus a quarter for the plain heuristic.
        assert estimate == len(sample) // 2
        assert estimate > quarter

    def test_cjk_estimate_combines_cjk_and_non_cjk_characters(self, monkeypatch):
        """In mixed-language text only the CJK characters get the CJK density."""
        monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
        # ASCII words interleaved with Chinese characters.
        sample = "User\u559c\u6b22Python\u548c\u6570\u636e\u5206\u6790"
        cjk_count = len([ch for ch in sample if "\u4e00" <= ch <= "\u9fff"])
        other_count = len(sample) - cjk_count
        assert _count_tokens(sample) == other_count // 4 + cjk_count // 2
# ---------------------------------------------------------------------------
# warm_tiktoken_cache
# ---------------------------------------------------------------------------
class TestWarmTiktokenCache:
    """Tests for warm_tiktoken_cache startup helper."""

    def test_returns_true_on_success(self, monkeypatch):
        # Evict the entry through monkeypatch rather than a bare ``pop``:
        # setitem snapshots the pre-test state, so the mock encoding cached by
        # the warm-up below is rolled back at teardown instead of leaking into
        # later tests.  The placeholder is deleted so the warm-up starts cold.
        monkeypatch.setitem(_tiktoken_encoding_cache, "cl100k_base", None)
        del _tiktoken_encoding_cache["cl100k_base"]
        fake_enc = mock.Mock()
        monkeypatch.setattr(
            "deerflow.agents.memory.prompt.tiktoken.get_encoding",
            mock.Mock(return_value=fake_enc),
        )
        assert warm_tiktoken_cache() is True
        assert _tiktoken_encoding_cache["cl100k_base"] is fake_enc

    def test_returns_true_if_already_cached(self, monkeypatch):
        fake_enc = mock.Mock()
        monkeypatch.setitem(_tiktoken_encoding_cache, "cl100k_base", fake_enc)
        import tiktoken

        # A warm cache must be reported as success without reloading.
        monkeypatch.setattr(tiktoken, "get_encoding", mock.Mock(side_effect=RuntimeError("should not be called")))
        assert warm_tiktoken_cache() is True
        tiktoken.get_encoding.assert_not_called()

    def test_returns_false_when_tiktoken_unavailable(self, monkeypatch):
        monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
        assert warm_tiktoken_cache() is False
# ---------------------------------------------------------------------------
# format_memory_for_injection token_counting strategy
# ---------------------------------------------------------------------------
class TestFormatMemoryForInjectionTokenCounting:
    """Check that ``format_memory_for_injection`` respects ``use_tiktoken``."""

    @staticmethod
    def _sample_memory() -> dict:
        """Return a minimal memory payload holding two facts."""
        preference_fact = {
            "content": "User prefers concise answers.",
            "category": "preference",
            "confidence": 0.9,
        }
        context_fact = {
            "content": "User works in the finance domain.",
            "category": "context",
            "confidence": 0.8,
        }
        return {"facts": [preference_fact, context_fact]}

    def test_use_tiktoken_false_never_touches_tiktoken(self, monkeypatch):
        """Formatting with use_tiktoken=False must leave tiktoken uncalled."""
        spy = mock.Mock(side_effect=AssertionError("get_encoding must not be called"))
        monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", spy)
        rendered = format_memory_for_injection(
            self._sample_memory(),
            max_tokens=2000,
            use_tiktoken=False,
        )
        assert "User prefers concise answers." in rendered
        spy.assert_not_called()

    def test_use_tiktoken_true_uses_encoding(self, monkeypatch):
        """Formatting with use_tiktoken=True counts through the loaded encoding."""

        def one_token_per_char(text):
            # Deterministic fake encoder: one token id per input character.
            return list(range(len(text)))

        encoder = mock.Mock()
        encoder.encode.side_effect = one_token_per_char
        monkeypatch.setattr(
            "deerflow.agents.memory.prompt._get_tiktoken_encoding",
            mock.Mock(return_value=encoder),
        )
        rendered = format_memory_for_injection(
            self._sample_memory(),
            max_tokens=2000,
            use_tiktoken=True,
        )
        assert "User prefers concise answers." in rendered
        assert encoder.encode.called

    def test_empty_memory_returns_empty(self):
        rendered = format_memory_for_injection({}, max_tokens=2000, use_tiktoken=False)
        assert rendered == ""
# ---------------------------------------------------------------------------
# MemoryConfig.token_counting
# ---------------------------------------------------------------------------
class TestMemoryConfigTokenCounting:
    """Cover the default and the validation of ``MemoryConfig.token_counting``."""

    def test_default_is_tiktoken(self):
        """An unconfigured MemoryConfig keeps ``tiktoken`` so existing deployments do not change."""
        config = MemoryConfig()
        assert config.token_counting == "tiktoken"

    def test_accepts_char(self):
        config = MemoryConfig(token_counting="char")
        assert config.token_counting == "char"

    def test_rejects_invalid_value(self):
        import pytest
        from pydantic import ValidationError

        # Any value other than the supported strategies must fail validation.
        with pytest.raises(ValidationError):
            MemoryConfig(token_counting="invalid")