mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 17:35:57 +00:00
fix(middleware): offload memory injection off event loop to prevent tiktoken blocking (#3402) (#3411)
* fix(middleware): offload memory injection off event loop to prevent tiktoken blocking (#3402) DynamicContextMiddleware.abefore_agent() called _inject() synchronously on the asyncio event loop. The first time memory is injected (second request), _inject() → format_memory_for_injection() → _count_tokens() → tiktoken.get_encoding("cl100k_base") needs to download the BPE data from openaipublic.blob.core.windows.net. In network-restricted environments this download blocks until the OS TCP timeout (~26 min), starving ALL concurrent handlers including /api/v1/auth/me. Fix: - abefore_agent now uses asyncio.to_thread(self._inject, state) so file I/O and tiktoken never block the event loop. - Extract _get_tiktoken_encoding() with a module-level cache so tiktoken.get_encoding() is called at most once per encoding name. - Add warm_tiktoken_cache() startup helper; gateway lifespan pre-warms the cache via asyncio.to_thread so the first request never triggers a cold download. - _count_tokens falls back to len(text) // 4 on any encoding failure. Tests: - tests/test_tiktoken_cache_and_count_tokens.py (12 tests): cache hit/miss, fallback paths, warm-up helper. - tests/blocking_io/test_dynamic_context_middleware.py (2 tests): Blockbuster gate verifies abefore_agent does not block the event loop; async/sync parity check. Fixes #3402 * Apply suggestions from code review Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * fix the lint error * fix(memory): use future annotations to avoid NameError when tiktoken is absent Add `from __future__ import annotations` to prompt.py so that tiktoken.Encoding type hints are never evaluated at runtime. Without this, environments where tiktoken is not installed could raise NameError on the module-level cache and function return annotations. Addresses Copilot review comment on PR #3411. * fix(middleware): bound abefore_agent injection with timeout to prevent hung requests Wrap the asyncio.to_thread(self._inject) offload in asyncio.wait_for() with a 5-second cap. If the startup warm-up failed silently (e.g. network blip during deploy), a cold tiktoken BPE download on the first request can block until the OS TCP timeout (~26 min). The bounded timeout ensures the request degrades gracefully (no memory/date context for that turn) rather than hanging. Adds test_abefore_agent_returns_none_on_timeout to the blocking-IO regression anchors. Addresses review feedback from xg-gh-25 on PR #3411. --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -179,6 +179,25 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
config = get_gateway_config()
|
config = get_gateway_config()
|
||||||
logger.info(f"Starting API Gateway on {config.host}:{config.port}")
|
logger.info(f"Starting API Gateway on {config.host}:{config.port}")
|
||||||
|
|
||||||
|
# Pre-warm tiktoken encoding cache so the first memory-injection request
|
||||||
|
# never blocks on the BPE data download (which hits an OpenAI/Azure URL
|
||||||
|
# that may be unreachable in restricted networks — see issue #3402).
|
||||||
|
try:
|
||||||
|
from deerflow.agents.memory.prompt import warm_tiktoken_cache
|
||||||
|
|
||||||
|
warmed = await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(warm_tiktoken_cache),
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
if warmed:
|
||||||
|
logger.info("tiktoken encoding cache warmed successfully")
|
||||||
|
else:
|
||||||
|
logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback")
|
||||||
|
except TimeoutError:
|
||||||
|
logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback")
|
||||||
|
except Exception:
|
||||||
|
logger.warning("tiktoken warm-up skipped", exc_info=True)
|
||||||
|
|
||||||
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
|
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
|
||||||
async with langgraph_runtime(app, startup_config):
|
async with langgraph_runtime(app, startup_config):
|
||||||
logger.info("LangGraph runtime initialised")
|
logger.info("LangGraph runtime initialised")
|
||||||
|
|||||||
@@ -1,9 +1,14 @@
|
|||||||
"""Prompt templates for memory update and injection."""
|
"""Prompt templates for memory update and injection."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
@@ -160,6 +165,39 @@ Rules:
|
|||||||
Return ONLY valid JSON."""
|
Return ONLY valid JSON."""
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level tiktoken encoding cache. Populated lazily on first use;
|
||||||
|
# subsequent calls are a dict lookup (no network I/O). Pre-warming at
|
||||||
|
# startup via :func:`warm_tiktoken_cache` avoids blocking a request on the
|
||||||
|
# (potentially slow) first ``get_encoding`` call.
|
||||||
|
_tiktoken_encoding_cache: dict[str, tiktoken.Encoding] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encoding | None:
|
||||||
|
"""Return a cached tiktoken encoding, or ``None`` on failure / unavailability.
|
||||||
|
|
||||||
|
On the very first call for a given *encoding_name*, tiktoken may need to
|
||||||
|
download the BPE data from ``openaipublic.blob.core.windows.net``. In
|
||||||
|
network-restricted environments (e.g. deployments behind the GFW) this
|
||||||
|
download can block for tens of minutes before the OS TCP timeout kicks in.
|
||||||
|
The caller must therefore be prepared for this to block and should run it
|
||||||
|
off the event loop (e.g. via ``asyncio.to_thread``).
|
||||||
|
"""
|
||||||
|
if not TIKTOKEN_AVAILABLE:
|
||||||
|
return None
|
||||||
|
|
||||||
|
cached = _tiktoken_encoding_cache.get(encoding_name)
|
||||||
|
if cached is not None:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
try:
|
||||||
|
encoding = tiktoken.get_encoding(encoding_name)
|
||||||
|
_tiktoken_encoding_cache[encoding_name] = encoding
|
||||||
|
return encoding
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to load tiktoken encoding %r; falling back to char-based estimation", encoding_name, exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
|
def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
|
||||||
"""Count tokens in text using tiktoken.
|
"""Count tokens in text using tiktoken.
|
||||||
|
|
||||||
@@ -170,18 +208,30 @@ def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
|
|||||||
Returns:
|
Returns:
|
||||||
The number of tokens in the text.
|
The number of tokens in the text.
|
||||||
"""
|
"""
|
||||||
if not TIKTOKEN_AVAILABLE:
|
encoding = _get_tiktoken_encoding(encoding_name)
|
||||||
|
if encoding is None:
|
||||||
# Fallback to character-based estimation if tiktoken is not available
|
# Fallback to character-based estimation if tiktoken is not available
|
||||||
|
# or the encoding failed to load.
|
||||||
return len(text) // 4
|
return len(text) // 4
|
||||||
|
|
||||||
try:
|
try:
|
||||||
encoding = tiktoken.get_encoding(encoding_name)
|
|
||||||
return len(encoding.encode(text))
|
return len(encoding.encode(text))
|
||||||
except Exception:
|
except Exception:
|
||||||
# Fallback to character-based estimation on error
|
# Fallback to character-based estimation on error
|
||||||
return len(text) // 4
|
return len(text) // 4
|
||||||
|
|
||||||
|
|
||||||
|
def warm_tiktoken_cache() -> bool:
|
||||||
|
"""Pre-warm the tiktoken encoding cache.
|
||||||
|
|
||||||
|
Call at startup (off the event loop) so the first request never blocks
|
||||||
|
on the BPE download. Returns ``True`` if the encoding was loaded
|
||||||
|
successfully (or was already cached), ``False`` if tiktoken is
|
||||||
|
unavailable or the download failed.
|
||||||
|
"""
|
||||||
|
return _get_tiktoken_encoding("cl100k_base") is not None
|
||||||
|
|
||||||
|
|
||||||
def _coerce_confidence(value: Any, default: float = 0.0) -> float:
|
def _coerce_confidence(value: Any, default: float = 0.0) -> float:
|
||||||
"""Coerce a confidence-like value to a bounded float in [0, 1].
|
"""Coerce a confidence-like value to a bounded float in [0, 1].
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ Date-update format:
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
@@ -43,6 +44,12 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Upper bound (seconds) for a single _inject() offload. If the warm-up at
|
||||||
|
# gateway startup failed silently, the first request may still hit a cold
|
||||||
|
# tiktoken BPE download that blocks until the OS TCP timeout (~26 min).
|
||||||
|
# This cap ensures the request degrades gracefully instead of hanging.
|
||||||
|
_INJECT_TIMEOUT_SECONDS = 5.0
|
||||||
|
|
||||||
_DATE_RE = re.compile(r"<current_date>([^<]+)</current_date>")
|
_DATE_RE = re.compile(r"<current_date>([^<]+)</current_date>")
|
||||||
_DYNAMIC_CONTEXT_REMINDER_KEY = "dynamic_context_reminder"
|
_DYNAMIC_CONTEXT_REMINDER_KEY = "dynamic_context_reminder"
|
||||||
_SUMMARY_MESSAGE_NAME = "summary"
|
_SUMMARY_MESSAGE_NAME = "summary"
|
||||||
@@ -201,4 +208,25 @@ class DynamicContextMiddleware(AgentMiddleware):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
async def abefore_agent(self, state, runtime: Runtime) -> dict | None:
|
async def abefore_agent(self, state, runtime: Runtime) -> dict | None:
|
||||||
return self._inject(state)
|
# _inject() performs synchronous file I/O (memory JSON loading) and
|
||||||
|
# potentially blocking network calls (tiktoken encoding download on
|
||||||
|
# first use). Offload to a thread so the event loop is never blocked
|
||||||
|
# — a blocking call here starves all concurrent HTTP handlers (auth,
|
||||||
|
# SSE heartbeats, etc.). See issue #3402.
|
||||||
|
#
|
||||||
|
# Bounded timeout: if startup warm-up failed silently (e.g. network
|
||||||
|
# blip during deploy), the first request's cold tiktoken download can
|
||||||
|
# block for tens of minutes (OS TCP timeout). Time-box injection so
|
||||||
|
# the request degrades gracefully (no memory context) rather than
|
||||||
|
# hanging.
|
||||||
|
try:
|
||||||
|
return await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(self._inject, state),
|
||||||
|
timeout=_INJECT_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
logger.warning(
|
||||||
|
"DynamicContextMiddleware: injection timed out (%.1fs); skipping memory/date injection for this turn",
|
||||||
|
_INJECT_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|||||||
@@ -0,0 +1,124 @@
|
|||||||
|
"""Regression anchor: DynamicContextMiddleware must not block the event loop.
|
||||||
|
|
||||||
|
``_inject`` performs synchronous file I/O (memory JSON loading) and
|
||||||
|
potentially blocking network calls (tiktoken encoding download on first
|
||||||
|
use — see issue #3402). ``abefore_agent`` offloads the call via
|
||||||
|
``asyncio.to_thread`` so the event loop stays responsive.
|
||||||
|
|
||||||
|
This anchor drives the real ``create_agent`` graph via ``ainvoke`` under
|
||||||
|
the strict Blockbuster gate. If the offload regresses and the blocking
|
||||||
|
I/O runs on the event loop, Blockbuster raises ``BlockingError`` and
|
||||||
|
this test fails.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain.agents import create_agent
|
||||||
|
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
|
from deerflow.agents.middlewares.dynamic_context_middleware import DynamicContextMiddleware
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeModel(FakeMessagesListChatModel):
|
||||||
|
"""FakeMessagesListChatModel with a no-op ``bind_tools`` for create_agent."""
|
||||||
|
|
||||||
|
def bind_tools(self, tools, **kwargs): # type: ignore[override]
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
async def test_abefore_agent_does_not_block_event_loop() -> None:
|
||||||
|
"""``abefore_agent`` must offload _inject() to a thread pool."""
|
||||||
|
mw = DynamicContextMiddleware()
|
||||||
|
|
||||||
|
# Mock _build_full_reminder to simulate a slow synchronous operation
|
||||||
|
# (file I/O + tiktoken download). The mock sleeps briefly to make any
|
||||||
|
# event-loop blocking visible to the Blockbuster gate.
|
||||||
|
original_build = mw._build_full_reminder
|
||||||
|
|
||||||
|
def slow_build_reminder():
|
||||||
|
import time
|
||||||
|
|
||||||
|
time.sleep(0.05) # 50ms sync sleep — blocks the thread it runs on
|
||||||
|
return original_build()
|
||||||
|
|
||||||
|
with (
|
||||||
|
mock.patch.object(mw, "_build_full_reminder", slow_build_reminder),
|
||||||
|
mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""),
|
||||||
|
):
|
||||||
|
agent = await asyncio.to_thread(
|
||||||
|
lambda: create_agent(
|
||||||
|
model=_FakeModel(responses=[AIMessage(content="ok")]),
|
||||||
|
tools=[],
|
||||||
|
middleware=[mw],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await agent.ainvoke(
|
||||||
|
{"messages": [HumanMessage(content="hi")]},
|
||||||
|
{"configurable": {"thread_id": "test-thread"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["messages"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_abefore_agent_returns_same_result_as_before_agent() -> None:
|
||||||
|
"""``abefore_agent`` (async, offloaded) must produce the same result as
|
||||||
|
``before_agent`` (sync, for backward compatibility)."""
|
||||||
|
mw = DynamicContextMiddleware()
|
||||||
|
|
||||||
|
state = {"messages": [HumanMessage(content="Hello", id="msg-1")]}
|
||||||
|
runtime = SimpleNamespace(context={})
|
||||||
|
|
||||||
|
with (
|
||||||
|
mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""),
|
||||||
|
mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt,
|
||||||
|
):
|
||||||
|
mock_dt.now.return_value.strftime.return_value = "2026-06-05, Friday"
|
||||||
|
|
||||||
|
# Sync path
|
||||||
|
sync_result = mw.before_agent(state, runtime)
|
||||||
|
|
||||||
|
# Async path (offloaded to thread)
|
||||||
|
async_result = await mw.abefore_agent(state, runtime)
|
||||||
|
|
||||||
|
assert sync_result is not None
|
||||||
|
assert async_result is not None
|
||||||
|
assert sync_result.keys() == async_result.keys()
|
||||||
|
# Both return 2 messages: reminder + user content
|
||||||
|
assert len(sync_result["messages"]) == 2
|
||||||
|
assert len(async_result["messages"]) == 2
|
||||||
|
# IDs match
|
||||||
|
assert sync_result["messages"][0].id == async_result["messages"][0].id
|
||||||
|
assert sync_result["messages"][1].id == async_result["messages"][1].id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_abefore_agent_returns_none_on_timeout() -> None:
|
||||||
|
"""If _inject() exceeds the timeout, abefore_agent returns None gracefully."""
|
||||||
|
import time
|
||||||
|
|
||||||
|
mw = DynamicContextMiddleware()
|
||||||
|
|
||||||
|
def blocking_inject(state):
|
||||||
|
time.sleep(10) # Simulate a blocking call that far exceeds the timeout
|
||||||
|
return {"messages": [HumanMessage(content="should not reach")]}
|
||||||
|
|
||||||
|
with (
|
||||||
|
mock.patch.object(mw, "_inject", blocking_inject),
|
||||||
|
mock.patch(
|
||||||
|
"deerflow.agents.middlewares.dynamic_context_middleware._INJECT_TIMEOUT_SECONDS",
|
||||||
|
0.1,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
state = {"messages": [HumanMessage(content="Hello", id="msg-1")]}
|
||||||
|
runtime = SimpleNamespace(context={})
|
||||||
|
result = await mw.abefore_agent(state, runtime)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
@@ -0,0 +1,148 @@
|
|||||||
|
"""Tests for tiktoken encoding cache and _count_tokens fallback.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- Module-level cache avoids repeated ``get_encoding`` calls.
|
||||||
|
- ``_count_tokens`` falls back to character estimation when tiktoken is
|
||||||
|
unavailable or the encoding fails to load.
|
||||||
|
- ``warm_tiktoken_cache`` populates the cache on success.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
from deerflow.agents.memory.prompt import (
|
||||||
|
_count_tokens,
|
||||||
|
_get_tiktoken_encoding,
|
||||||
|
_tiktoken_encoding_cache,
|
||||||
|
warm_tiktoken_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _get_tiktoken_encoding
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetTiktokenEncoding:
|
||||||
|
"""Tests for _get_tiktoken_encoding caching and fallback."""
|
||||||
|
|
||||||
|
def test_returns_none_when_tiktoken_unavailable(self, monkeypatch):
|
||||||
|
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
|
||||||
|
assert _get_tiktoken_encoding("cl100k_base") is None
|
||||||
|
|
||||||
|
def test_returns_encoding_on_success(self, monkeypatch):
|
||||||
|
# Clear cache to ensure a fresh call
|
||||||
|
_tiktoken_encoding_cache.pop("cl100k_base", None)
|
||||||
|
|
||||||
|
fake_enc = mock.Mock()
|
||||||
|
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", mock.Mock(return_value=fake_enc))
|
||||||
|
|
||||||
|
enc = _get_tiktoken_encoding("cl100k_base")
|
||||||
|
assert enc is fake_enc
|
||||||
|
|
||||||
|
def test_populates_cache_on_success(self, monkeypatch):
|
||||||
|
_tiktoken_encoding_cache.pop("cl100k_base", None)
|
||||||
|
|
||||||
|
fake_enc = mock.Mock()
|
||||||
|
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", mock.Mock(return_value=fake_enc))
|
||||||
|
|
||||||
|
_get_tiktoken_encoding("cl100k_base")
|
||||||
|
assert _tiktoken_encoding_cache["cl100k_base"] is fake_enc
|
||||||
|
|
||||||
|
def test_returns_cached_encoding_without_calling_get_encoding(self, monkeypatch):
|
||||||
|
fake_enc = mock.Mock()
|
||||||
|
monkeypatch.setitem(_tiktoken_encoding_cache, "cl100k_base", fake_enc)
|
||||||
|
|
||||||
|
# Now patch tiktoken.get_encoding to raise if called
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
monkeypatch.setattr(tiktoken, "get_encoding", mock.Mock(side_effect=RuntimeError("should not be called")))
|
||||||
|
# Cached path — should NOT call get_encoding
|
||||||
|
enc = _get_tiktoken_encoding("cl100k_base")
|
||||||
|
assert enc is fake_enc
|
||||||
|
tiktoken.get_encoding.assert_not_called()
|
||||||
|
|
||||||
|
def test_returns_none_and_warns_on_get_encoding_failure(self, monkeypatch):
|
||||||
|
_tiktoken_encoding_cache.pop("bogus_encoding", None)
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
monkeypatch.setattr(tiktoken, "get_encoding", mock.Mock(side_effect=OSError("download failed")))
|
||||||
|
result = _get_tiktoken_encoding("bogus_encoding")
|
||||||
|
assert result is None
|
||||||
|
assert "bogus_encoding" not in _tiktoken_encoding_cache
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _count_tokens
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCountTokens:
|
||||||
|
"""Tests for _count_tokens fallback behaviour."""
|
||||||
|
|
||||||
|
def test_returns_character_estimate_when_tiktoken_unavailable(self, monkeypatch):
|
||||||
|
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
|
||||||
|
text = "Hello, world! This is a test."
|
||||||
|
result = _count_tokens(text)
|
||||||
|
assert result == len(text) // 4
|
||||||
|
|
||||||
|
def test_returns_character_estimate_when_encoding_fails(self, monkeypatch):
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"deerflow.agents.memory.prompt._get_tiktoken_encoding",
|
||||||
|
lambda _name=None: None,
|
||||||
|
)
|
||||||
|
text = "Some text to count"
|
||||||
|
result = _count_tokens(text)
|
||||||
|
assert result == len(text) // 4
|
||||||
|
|
||||||
|
def test_returns_token_count_on_success(self, monkeypatch):
|
||||||
|
fake_enc = mock.Mock()
|
||||||
|
fake_enc.encode.return_value = [0, 1, 2, 3]
|
||||||
|
monkeypatch.setattr("deerflow.agents.memory.prompt._get_tiktoken_encoding", mock.Mock(return_value=fake_enc))
|
||||||
|
|
||||||
|
text = "Hello, world!"
|
||||||
|
result = _count_tokens(text)
|
||||||
|
assert result == 4
|
||||||
|
assert result <= len(text)
|
||||||
|
|
||||||
|
def test_falls_back_on_encode_exception(self, monkeypatch):
|
||||||
|
# Cache an encoding whose .encode raises
|
||||||
|
fake_enc = mock.Mock()
|
||||||
|
fake_enc.encode.side_effect = RuntimeError("encode failed")
|
||||||
|
monkeypatch.setitem(_tiktoken_encoding_cache, "test_enc", fake_enc)
|
||||||
|
|
||||||
|
text = "Fallback test"
|
||||||
|
result = _count_tokens(text, encoding_name="test_enc")
|
||||||
|
assert result == len(text) // 4
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# warm_tiktoken_cache
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestWarmTiktokenCache:
|
||||||
|
"""Tests for warm_tiktoken_cache startup helper."""
|
||||||
|
|
||||||
|
def test_returns_true_on_success(self, monkeypatch):
|
||||||
|
_tiktoken_encoding_cache.pop("cl100k_base", None)
|
||||||
|
|
||||||
|
fake_enc = mock.Mock()
|
||||||
|
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", mock.Mock(return_value=fake_enc))
|
||||||
|
|
||||||
|
assert warm_tiktoken_cache() is True
|
||||||
|
assert _tiktoken_encoding_cache["cl100k_base"] is fake_enc
|
||||||
|
|
||||||
|
def test_returns_true_if_already_cached(self, monkeypatch):
|
||||||
|
fake_enc = mock.Mock()
|
||||||
|
monkeypatch.setitem(_tiktoken_encoding_cache, "cl100k_base", fake_enc)
|
||||||
|
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
monkeypatch.setattr(tiktoken, "get_encoding", mock.Mock(side_effect=RuntimeError("should not be called")))
|
||||||
|
assert warm_tiktoken_cache() is True
|
||||||
|
tiktoken.get_encoding.assert_not_called()
|
||||||
|
|
||||||
|
def test_returns_false_when_tiktoken_unavailable(self, monkeypatch):
|
||||||
|
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
|
||||||
|
assert warm_tiktoken_cache() is False
|
||||||
Reference in New Issue
Block a user