mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 17:35:57 +00:00
167ef4512f
* feat(memory): add memory.token_counting config to avoid tiktoken network dependency (#3429) Add a `memory.token_counting` option (`tiktoken` | `char`) so deployments in network-restricted environments can opt out of tiktoken entirely. In `char` mode the memory-injection budget uses a network-free character-based estimate and never triggers the BPE download from openaipublic.blob.core.windows.net, which could otherwise block for tens of minutes (see #3402). Also harden the default `tiktoken` path: - cache an in-flight LOADING sentinel so concurrent callers fall back immediately instead of spawning more blocking get_encoding threads when the first load is still running (e.g. under the 5s startup warm-up timeout); - cache failures with a timestamp and retry after a cooldown so a transient network outage self-heals back to accurate counting without a restart; - skip startup warm-up entirely in char mode. The new config is surfaced via the memory config API and config.example.yaml (config_version bumped). Default remains `tiktoken`, so existing deployments are unaffected. * fix(memory): use CJK-aware char token estimate and address review feedback - Replace the flat len(text)//4 fallback with a CJK-aware estimate so Chinese/Japanese/Korean memory content does not over-fill the injection budget - Document the internal tiktoken retry cooldown and char-mode escape hatch - Sync CLAUDE.md / config.example.yaml / MEMORY_IMPROVEMENTS.md wording - Fix MemoryConfigResponse mocks/assertions and add CJK estimate tests
347 lines
14 KiB
Python
347 lines
14 KiB
Python
"""Tests for tiktoken encoding cache and _count_tokens fallback.
|
|
|
|
Verifies:
|
|
- Module-level cache avoids repeated ``get_encoding`` calls.
|
|
- ``_count_tokens`` falls back to character estimation when tiktoken is
|
|
unavailable or the encoding fails to load.
|
|
- ``warm_tiktoken_cache`` populates the cache on success.
|
|
- An in-flight tiktoken load prevents duplicate blocking downloads.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import threading
|
|
from unittest import mock
|
|
|
|
from deerflow.agents.memory.prompt import (
|
|
_count_tokens,
|
|
_get_tiktoken_encoding,
|
|
_tiktoken_encoding_cache,
|
|
format_memory_for_injection,
|
|
warm_tiktoken_cache,
|
|
)
|
|
from deerflow.config.memory_config import MemoryConfig
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _get_tiktoken_encoding
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestGetTiktokenEncoding:
|
|
"""Tests for _get_tiktoken_encoding caching and fallback."""
|
|
|
|
def test_returns_none_when_tiktoken_unavailable(self, monkeypatch):
|
|
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
|
|
assert _get_tiktoken_encoding("cl100k_base") is None
|
|
|
|
def test_returns_encoding_on_success(self, monkeypatch):
|
|
# Clear cache to ensure a fresh call
|
|
_tiktoken_encoding_cache.pop("cl100k_base", None)
|
|
|
|
fake_enc = mock.Mock()
|
|
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", mock.Mock(return_value=fake_enc))
|
|
|
|
enc = _get_tiktoken_encoding("cl100k_base")
|
|
assert enc is fake_enc
|
|
|
|
def test_populates_cache_on_success(self, monkeypatch):
|
|
_tiktoken_encoding_cache.pop("cl100k_base", None)
|
|
|
|
fake_enc = mock.Mock()
|
|
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", mock.Mock(return_value=fake_enc))
|
|
|
|
_get_tiktoken_encoding("cl100k_base")
|
|
assert _tiktoken_encoding_cache["cl100k_base"] is fake_enc
|
|
|
|
def test_returns_cached_encoding_without_calling_get_encoding(self, monkeypatch):
|
|
fake_enc = mock.Mock()
|
|
monkeypatch.setitem(_tiktoken_encoding_cache, "cl100k_base", fake_enc)
|
|
|
|
# Now patch tiktoken.get_encoding to raise if called
|
|
import tiktoken
|
|
|
|
monkeypatch.setattr(tiktoken, "get_encoding", mock.Mock(side_effect=RuntimeError("should not be called")))
|
|
# Cached path — should NOT call get_encoding
|
|
enc = _get_tiktoken_encoding("cl100k_base")
|
|
assert enc is fake_enc
|
|
tiktoken.get_encoding.assert_not_called()
|
|
|
|
def test_returns_none_and_caches_failure_sentinel(self, monkeypatch):
|
|
"""A failed load is cached (with a timestamp) so it is not re-attempted (no repeated network download)."""
|
|
_tiktoken_encoding_cache.pop("bogus_encoding", None)
|
|
import tiktoken
|
|
|
|
get_encoding = mock.Mock(side_effect=OSError("download failed"))
|
|
monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
|
|
|
|
result = _get_tiktoken_encoding("bogus_encoding")
|
|
assert result is None
|
|
# The failure is remembered as a (None, timestamp) tuple.
|
|
assert "bogus_encoding" in _tiktoken_encoding_cache
|
|
cached = _tiktoken_encoding_cache["bogus_encoding"]
|
|
assert isinstance(cached, tuple)
|
|
assert cached[0] is None
|
|
|
|
# A second call must NOT re-attempt get_encoding (avoids re-blocking on
|
|
# the network download in restricted environments — see #3429).
|
|
result2 = _get_tiktoken_encoding("bogus_encoding")
|
|
assert result2 is None
|
|
assert get_encoding.call_count == 1
|
|
|
|
# Cleanup module-level cache to avoid cross-test leakage.
|
|
_tiktoken_encoding_cache.pop("bogus_encoding", None)
|
|
|
|
def test_failure_self_heals_after_cooldown(self, monkeypatch):
|
|
"""After the retry cooldown expires, a transient failure is re-attempted and can recover."""
|
|
_tiktoken_encoding_cache.pop("flaky_encoding", None)
|
|
import tiktoken
|
|
|
|
fake_enc = mock.Mock()
|
|
# First call fails, second call (after cooldown) succeeds.
|
|
get_encoding = mock.Mock(side_effect=[OSError("transient outage"), fake_enc])
|
|
monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
|
|
|
|
# Initial failure is cached.
|
|
assert _get_tiktoken_encoding("flaky_encoding") is None
|
|
assert get_encoding.call_count == 1
|
|
|
|
# Within the cooldown window: no retry, immediate fallback.
|
|
assert _get_tiktoken_encoding("flaky_encoding") is None
|
|
assert get_encoding.call_count == 1
|
|
|
|
# Simulate the cooldown having elapsed by ageing the cached timestamp.
|
|
from deerflow.agents.memory import prompt as prompt_module
|
|
|
|
_, _failed_at = _tiktoken_encoding_cache["flaky_encoding"]
|
|
_tiktoken_encoding_cache["flaky_encoding"] = (
|
|
None,
|
|
_failed_at - prompt_module._TIKTOKEN_RETRY_COOLDOWN_S - 1,
|
|
)
|
|
|
|
# Now the load is retried and recovers to accurate counting.
|
|
assert _get_tiktoken_encoding("flaky_encoding") is fake_enc
|
|
assert get_encoding.call_count == 2
|
|
|
|
_tiktoken_encoding_cache.pop("flaky_encoding", None)
|
|
|
|
def test_in_flight_load_returns_none_without_duplicate_get_encoding(self, monkeypatch):
|
|
"""Concurrent callers must not start duplicate blocking BPE downloads."""
|
|
_tiktoken_encoding_cache.pop("slow_encoding", None)
|
|
import tiktoken
|
|
|
|
started = threading.Event()
|
|
release = threading.Event()
|
|
fake_enc = mock.Mock()
|
|
|
|
def slow_get_encoding(_name):
|
|
started.set()
|
|
assert release.wait(timeout=2), "test timed out waiting to release slow get_encoding"
|
|
return fake_enc
|
|
|
|
get_encoding = mock.Mock(side_effect=slow_get_encoding)
|
|
monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
|
|
|
|
result: dict[str, object | None] = {}
|
|
|
|
def load_encoding():
|
|
result["encoding"] = _get_tiktoken_encoding("slow_encoding")
|
|
|
|
thread = threading.Thread(target=load_encoding)
|
|
thread.start()
|
|
try:
|
|
assert started.wait(timeout=1), "slow get_encoding did not start"
|
|
|
|
# While the first call is still blocked, a second call should see
|
|
# the in-flight sentinel and fall back immediately instead of
|
|
# starting another potentially long network download.
|
|
assert _get_tiktoken_encoding("slow_encoding") is None
|
|
assert get_encoding.call_count == 1
|
|
finally:
|
|
release.set()
|
|
thread.join(timeout=2)
|
|
_tiktoken_encoding_cache.pop("slow_encoding", None)
|
|
|
|
assert result["encoding"] is fake_enc
|
|
assert get_encoding.call_count == 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _count_tokens
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCountTokens:
|
|
"""Tests for _count_tokens fallback behaviour."""
|
|
|
|
def test_returns_character_estimate_when_tiktoken_unavailable(self, monkeypatch):
|
|
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
|
|
text = "Hello, world! This is a test."
|
|
result = _count_tokens(text)
|
|
assert result == len(text) // 4
|
|
|
|
def test_returns_character_estimate_when_encoding_fails(self, monkeypatch):
|
|
monkeypatch.setattr(
|
|
"deerflow.agents.memory.prompt._get_tiktoken_encoding",
|
|
lambda _name=None: None,
|
|
)
|
|
text = "Some text to count"
|
|
result = _count_tokens(text)
|
|
assert result == len(text) // 4
|
|
|
|
def test_returns_token_count_on_success(self, monkeypatch):
|
|
fake_enc = mock.Mock()
|
|
fake_enc.encode.return_value = [0, 1, 2, 3]
|
|
monkeypatch.setattr("deerflow.agents.memory.prompt._get_tiktoken_encoding", mock.Mock(return_value=fake_enc))
|
|
|
|
text = "Hello, world!"
|
|
result = _count_tokens(text)
|
|
assert result == 4
|
|
assert result <= len(text)
|
|
|
|
def test_falls_back_on_encode_exception(self, monkeypatch):
|
|
# Cache an encoding whose .encode raises
|
|
fake_enc = mock.Mock()
|
|
fake_enc.encode.side_effect = RuntimeError("encode failed")
|
|
monkeypatch.setitem(_tiktoken_encoding_cache, "test_enc", fake_enc)
|
|
|
|
text = "Fallback test"
|
|
result = _count_tokens(text, encoding_name="test_enc")
|
|
assert result == len(text) // 4
|
|
|
|
def test_use_tiktoken_false_returns_char_estimate_without_touching_tiktoken(self, monkeypatch):
|
|
"""use_tiktoken=False must never call tiktoken (guarantees no BPE download)."""
|
|
# Spy on both the encoding loader and tiktoken.get_encoding directly.
|
|
get_encoding_spy = mock.Mock(side_effect=AssertionError("get_encoding must not be called"))
|
|
loader_spy = mock.Mock(side_effect=AssertionError("_get_tiktoken_encoding must not be called"))
|
|
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", get_encoding_spy)
|
|
monkeypatch.setattr("deerflow.agents.memory.prompt._get_tiktoken_encoding", loader_spy)
|
|
|
|
text = "Hello, world! This is a network-free count."
|
|
result = _count_tokens(text, use_tiktoken=False)
|
|
assert result == len(text) // 4
|
|
get_encoding_spy.assert_not_called()
|
|
loader_spy.assert_not_called()
|
|
|
|
def test_cjk_estimate_is_denser_than_plain_quarter(self, monkeypatch):
|
|
"""CJK text should estimate more tokens than the plain len // 4 heuristic.
|
|
|
|
CJK characters are ~2 chars/token, so the char-based estimate must not
|
|
under-fill the budget the way ``len(text) // 4`` would.
|
|
"""
|
|
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
|
|
# "User prefers concise answers" rendered in CJK (Chinese) characters.
|
|
text = "\u7528\u6237\u504f\u597d\u7b80\u6d01\u7684\u4e2d\u6587\u56de\u7b54\u5e76\u5173\u6ce8\u91d1\u878d\u9886\u57df"
|
|
result = _count_tokens(text)
|
|
# Each CJK char counts as ~1/2 token (vs 1/4 for the plain heuristic).
|
|
assert result == len(text) // 2
|
|
assert result > len(text) // 4
|
|
|
|
def test_cjk_estimate_combines_cjk_and_non_cjk_characters(self, monkeypatch):
|
|
"""Mixed-language text should apply the CJK density only to CJK chars."""
|
|
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
|
|
# ASCII words mixed with CJK (Chinese) characters: "User" + "likes" + "Python and data analysis".
|
|
text = "User\u559c\u6b22Python\u548c\u6570\u636e\u5206\u6790"
|
|
cjk = sum(1 for ch in text if "\u4e00" <= ch <= "\u9fff")
|
|
|
|
result = _count_tokens(text)
|
|
|
|
assert result == (len(text) - cjk) // 4 + cjk // 2
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# warm_tiktoken_cache
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestWarmTiktokenCache:
|
|
"""Tests for warm_tiktoken_cache startup helper."""
|
|
|
|
def test_returns_true_on_success(self, monkeypatch):
|
|
_tiktoken_encoding_cache.pop("cl100k_base", None)
|
|
|
|
fake_enc = mock.Mock()
|
|
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", mock.Mock(return_value=fake_enc))
|
|
|
|
assert warm_tiktoken_cache() is True
|
|
assert _tiktoken_encoding_cache["cl100k_base"] is fake_enc
|
|
|
|
def test_returns_true_if_already_cached(self, monkeypatch):
|
|
fake_enc = mock.Mock()
|
|
monkeypatch.setitem(_tiktoken_encoding_cache, "cl100k_base", fake_enc)
|
|
|
|
import tiktoken
|
|
|
|
monkeypatch.setattr(tiktoken, "get_encoding", mock.Mock(side_effect=RuntimeError("should not be called")))
|
|
assert warm_tiktoken_cache() is True
|
|
tiktoken.get_encoding.assert_not_called()
|
|
|
|
def test_returns_false_when_tiktoken_unavailable(self, monkeypatch):
|
|
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
|
|
assert warm_tiktoken_cache() is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# format_memory_for_injection token_counting strategy
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestFormatMemoryForInjectionTokenCounting:
|
|
"""Verify the use_tiktoken flag is honoured end-to-end."""
|
|
|
|
@staticmethod
|
|
def _sample_memory() -> dict:
|
|
return {
|
|
"facts": [
|
|
{"content": "User prefers concise answers.", "category": "preference", "confidence": 0.9},
|
|
{"content": "User works in the finance domain.", "category": "context", "confidence": 0.8},
|
|
],
|
|
}
|
|
|
|
def test_use_tiktoken_false_never_touches_tiktoken(self, monkeypatch):
|
|
"""With use_tiktoken=False, formatting must not call tiktoken at all."""
|
|
get_encoding_spy = mock.Mock(side_effect=AssertionError("get_encoding must not be called"))
|
|
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", get_encoding_spy)
|
|
|
|
result = format_memory_for_injection(self._sample_memory(), max_tokens=2000, use_tiktoken=False)
|
|
assert "User prefers concise answers." in result
|
|
get_encoding_spy.assert_not_called()
|
|
|
|
def test_use_tiktoken_true_uses_encoding(self, monkeypatch):
|
|
"""With use_tiktoken=True (default), the cached encoding is used for counting."""
|
|
fake_enc = mock.Mock()
|
|
fake_enc.encode.side_effect = lambda text: list(range(len(text)))
|
|
monkeypatch.setattr(
|
|
"deerflow.agents.memory.prompt._get_tiktoken_encoding",
|
|
mock.Mock(return_value=fake_enc),
|
|
)
|
|
|
|
result = format_memory_for_injection(self._sample_memory(), max_tokens=2000, use_tiktoken=True)
|
|
assert "User prefers concise answers." in result
|
|
assert fake_enc.encode.called
|
|
|
|
def test_empty_memory_returns_empty(self):
|
|
assert format_memory_for_injection({}, max_tokens=2000, use_tiktoken=False) == ""
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# MemoryConfig.token_counting
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestMemoryConfigTokenCounting:
|
|
"""Verify the new config field defaults and validation."""
|
|
|
|
def test_default_is_tiktoken(self):
|
|
"""Default must remain tiktoken so existing deployments are unaffected."""
|
|
assert MemoryConfig().token_counting == "tiktoken"
|
|
|
|
def test_accepts_char(self):
|
|
assert MemoryConfig(token_counting="char").token_counting == "char"
|
|
|
|
def test_rejects_invalid_value(self):
|
|
import pytest
|
|
from pydantic import ValidationError
|
|
|
|
with pytest.raises(ValidationError):
|
|
MemoryConfig(token_counting="invalid")
|