From a55de566b93a63f0cd96496597108b6311218c14 Mon Sep 17 00:00:00 2001 From: Willem Jiang Date: Sun, 26 Apr 2026 10:52:37 +0800 Subject: [PATCH] refactor(backend): consolidate thread_id resolution into shared get_thread_id() utility (#2522) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract duplicated thread_id fallback logic from 11 files into a single deerflow.utils.runtime.get_thread_id() function with a documented 3-level cascade (runtime.context → runtime.config → get_config()). The module docstring also clarifies the __pregel_runtime injection pattern used in gateway mode. --- .../middlewares/loop_detection_middleware.py | 7 +- .../agents/middlewares/memory_middleware.py | 9 +- .../middlewares/sandbox_audit_middleware.py | 11 +-- .../middlewares/summarization_middleware.py | 16 +--- .../middlewares/thread_data_middleware.py | 8 +- .../agents/middlewares/uploads_middleware.py | 10 +-- .../harness/deerflow/sandbox/middleware.py | 3 +- .../harness/deerflow/sandbox/tools.py | 5 +- .../tools/builtins/present_file_tool.py | 21 +---- .../deerflow/tools/builtins/task_tool.py | 5 +- .../deerflow/tools/skill_manage_tool.py | 11 +-- .../harness/deerflow/utils/runtime.py | 88 +++++++++++++++++++ backend/tests/test_runtime_utils.py | 69 +++++++++++++++ backend/tests/test_thread_data_middleware.py | 6 +- 14 files changed, 185 insertions(+), 84 deletions(-) create mode 100644 backend/packages/harness/deerflow/utils/runtime.py create mode 100644 backend/tests/test_runtime_utils.py diff --git a/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py index 4c1ba28ec..e0af9da50 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py @@ -25,6 +25,8 @@ from langchain.agents.middleware import AgentMiddleware from langchain_core.messages import HumanMessage from langgraph.runtime import Runtime +from deerflow.utils.runtime import get_thread_id + logger = logging.getLogger(__name__) # Defaults — can be overridden via constructor @@ -183,10 +185,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): def _get_thread_id(self, runtime: Runtime) -> str: """Extract thread_id from runtime context for per-thread tracking.""" - thread_id = runtime.context.get("thread_id") if runtime.context else None - if thread_id: - return thread_id - return "default" + return get_thread_id(runtime) or "default" def _evict_if_needed(self) -> None: """Evict least recently used threads if over the limit. diff --git a/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py index f1dccf689..8a7b92208 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py @@ -5,12 +5,12 @@ from typing import override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware -from langgraph.config import get_config from langgraph.runtime import Runtime from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory from deerflow.agents.memory.queue import get_memory_queue from deerflow.config.memory_config import get_memory_config +from deerflow.utils.runtime import get_thread_id logger = logging.getLogger(__name__) @@ -57,11 +57,8 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]): if not config.enabled: return None - # Get thread ID from runtime context first, then fall back to LangGraph's configurable metadata - thread_id = runtime.context.get("thread_id") if runtime.context else None - if thread_id is None: - config_data = get_config() - thread_id = config_data.get("configurable", {}).get("thread_id") + # Get thread ID from runtime context + thread_id = get_thread_id(runtime) if not thread_id: logger.debug("No thread_id in context, skipping memory update") return None diff --git a/backend/packages/harness/deerflow/agents/middlewares/sandbox_audit_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/sandbox_audit_middleware.py index e41f5912a..b56bc924a 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/sandbox_audit_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/sandbox_audit_middleware.py @@ -14,6 +14,7 @@ from langgraph.prebuilt.tool_node import ToolCallRequest from langgraph.types import Command from deerflow.agents.thread_state import ThreadState +from deerflow.utils.runtime import get_thread_id logger = logging.getLogger(__name__) @@ -218,15 +219,7 @@ class SandboxAuditMiddleware(AgentMiddleware[ThreadState]): # ------------------------------------------------------------------ def _get_thread_id(self, request: ToolCallRequest) -> str | None: - runtime = request.runtime # ToolRuntime; may be None-like in tests - if runtime is None: - return None - ctx = getattr(runtime, "context", None) or {} - thread_id = ctx.get("thread_id") if isinstance(ctx, dict) else None - if thread_id is None: - cfg = getattr(runtime, "config", None) or {} - thread_id = cfg.get("configurable", {}).get("thread_id") - return thread_id + return get_thread_id(request.runtime) _AUDIT_COMMAND_LIMIT = 200 diff --git a/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py index 651b64a72..5a8e627b0 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py @@ -14,6 +14,8 @@ from langgraph.config import get_config from langgraph.graph.message import REMOVE_ALL_MESSAGES from langgraph.runtime import Runtime +from deerflow.utils.runtime import get_thread_id + logger = logging.getLogger(__name__) @@ -35,18 +37,6 @@ class BeforeSummarizationHook(Protocol): def __call__(self, event: SummarizationEvent) -> None: ... -def _resolve_thread_id(runtime: Runtime) -> str | None: - """Resolve the current thread ID from runtime context or LangGraph config.""" - thread_id = runtime.context.get("thread_id") if runtime.context else None - if thread_id is None: - try: - config_data = get_config() - except RuntimeError: - return None - thread_id = config_data.get("configurable", {}).get("thread_id") - return thread_id - - def _resolve_agent_name(runtime: Runtime) -> str | None: """Resolve the current agent name from runtime context or LangGraph config.""" agent_name = runtime.context.get("agent_name") if runtime.context else None @@ -334,7 +324,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware): event = SummarizationEvent( messages_to_summarize=tuple(messages_to_summarize), preserved_messages=tuple(preserved_messages), - thread_id=_resolve_thread_id(runtime), + thread_id=get_thread_id(runtime), agent_name=_resolve_agent_name(runtime), runtime=runtime, ) diff --git a/backend/packages/harness/deerflow/agents/middlewares/thread_data_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/thread_data_middleware.py index c25531e02..20a0e5c39 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/thread_data_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/thread_data_middleware.py @@ -3,11 +3,11 @@ from typing import NotRequired, override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware -from langgraph.config import get_config from langgraph.runtime import Runtime from deerflow.agents.thread_state import ThreadDataState from deerflow.config.paths import Paths, get_paths +from deerflow.utils.runtime import get_thread_id logger = logging.getLogger(__name__) @@ -75,11 +75,7 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]): @override def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None: - context = runtime.context or {} - thread_id = context.get("thread_id") - if thread_id is None: - config = get_config() - thread_id = config.get("configurable", {}).get("thread_id") + thread_id = get_thread_id(runtime) if thread_id is None: raise ValueError("Thread ID is required in runtime context or config.configurable") diff --git a/backend/packages/harness/deerflow/agents/middlewares/uploads_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/uploads_middleware.py index 0fb217bcc..91351bf26 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/uploads_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/uploads_middleware.py @@ -11,6 +11,7 @@ from langgraph.runtime import Runtime from deerflow.config.paths import Paths, get_paths from deerflow.utils.file_conversion import extract_outline +from deerflow.utils.runtime import get_thread_id logger = logging.getLogger(__name__) @@ -213,14 +214,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]): return None # Resolve uploads directory for existence checks - thread_id = (runtime.context or {}).get("thread_id") - if thread_id is None: - try: - from langgraph.config import get_config - - thread_id = get_config().get("configurable", {}).get("thread_id") - except RuntimeError: - pass # get_config() raises outside a runnable context (e.g. unit tests) + thread_id = get_thread_id(runtime) uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None # Get newly uploaded files from the current message's additional_kwargs.files diff --git a/backend/packages/harness/deerflow/sandbox/middleware.py b/backend/packages/harness/deerflow/sandbox/middleware.py index deefc2397..ca08023a8 100644 --- a/backend/packages/harness/deerflow/sandbox/middleware.py +++ b/backend/packages/harness/deerflow/sandbox/middleware.py @@ -7,6 +7,7 @@ from langgraph.runtime import Runtime from deerflow.agents.thread_state import SandboxState, ThreadDataState from deerflow.sandbox import get_sandbox_provider +from deerflow.utils.runtime import get_thread_id logger = logging.getLogger(__name__) @@ -56,7 +57,7 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]): # Eager initialization (original behavior) if "sandbox" not in state or state["sandbox"] is None: - thread_id = (runtime.context or {}).get("thread_id") + thread_id = get_thread_id(runtime) if thread_id is None: return super().before_agent(state, runtime) sandbox_id = self._acquire_sandbox(thread_id) diff --git a/backend/packages/harness/deerflow/sandbox/tools.py b/backend/packages/harness/deerflow/sandbox/tools.py index 7b09358e7..86c0fef8c 100644 --- a/backend/packages/harness/deerflow/sandbox/tools.py +++ b/backend/packages/harness/deerflow/sandbox/tools.py @@ -19,6 +19,7 @@ from deerflow.sandbox.sandbox import Sandbox from deerflow.sandbox.sandbox_provider import get_sandbox_provider from deerflow.sandbox.search import GrepMatch from deerflow.sandbox.security import LOCAL_HOST_BASH_DISABLED_MESSAGE, is_host_bash_allowed +from deerflow.utils.runtime import get_thread_id _ABSOLUTE_PATH_PATTERN = re.compile(r"(?()]+)") _FILE_URL_PATTERN = re.compile(r"\bfile://\S+", re.IGNORECASE) @@ -851,9 +852,7 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non # Sandbox was released, fall through to acquire new one # Lazy acquisition: get thread_id and acquire sandbox - thread_id = runtime.context.get("thread_id") if runtime.context else None - if thread_id is None: - thread_id = runtime.config.get("configurable", {}).get("thread_id") if runtime.config else None + thread_id = get_thread_id(runtime) if thread_id is None: raise SandboxRuntimeError("Thread ID not available in runtime context") diff --git a/backend/packages/harness/deerflow/tools/builtins/present_file_tool.py b/backend/packages/harness/deerflow/tools/builtins/present_file_tool.py index 13ddd247e..743e72018 100644 --- a/backend/packages/harness/deerflow/tools/builtins/present_file_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/present_file_tool.py @@ -3,33 +3,16 @@ from typing import Annotated from langchain.tools import InjectedToolCallId, ToolRuntime, tool from langchain_core.messages import ToolMessage -from langgraph.config import get_config from langgraph.types import Command from langgraph.typing import ContextT from deerflow.agents.thread_state import ThreadState from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths +from deerflow.utils.runtime import get_thread_id OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs" -def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState]) -> str | None: - """Resolve the current thread id from runtime context or RunnableConfig.""" - thread_id = runtime.context.get("thread_id") if runtime.context else None - if thread_id: - return thread_id - - runtime_config = getattr(runtime, "config", None) or {} - thread_id = runtime_config.get("configurable", {}).get("thread_id") - if thread_id: - return thread_id - - try: - return get_config().get("configurable", {}).get("thread_id") - except RuntimeError: - return None - - def _normalize_presented_filepath( runtime: ToolRuntime[ContextT, ThreadState], filepath: str, @@ -51,7 +34,7 @@ def _normalize_presented_filepath( if runtime.state is None: raise ValueError("Thread runtime state is not available") - thread_id = _get_thread_id(runtime) + thread_id = get_thread_id(runtime) if not thread_id: raise ValueError("Thread ID is not available in runtime context or runtime config") diff --git a/backend/packages/harness/deerflow/tools/builtins/task_tool.py b/backend/packages/harness/deerflow/tools/builtins/task_tool.py index 59613272c..f12e7418e 100644 --- a/backend/packages/harness/deerflow/tools/builtins/task_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/task_tool.py @@ -14,6 +14,7 @@ from deerflow.agents.thread_state import ThreadState from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config from deerflow.subagents.executor import SubagentStatus, cleanup_background_task, get_background_task_result, request_cancel_background_task +from deerflow.utils.runtime import get_thread_id logger = logging.getLogger(__name__) @@ -105,9 +106,7 @@ async def task_tool( if runtime is not None: sandbox_state = runtime.state.get("sandbox") thread_data = runtime.state.get("thread_data") - thread_id = runtime.context.get("thread_id") if runtime.context else None - if thread_id is None: - thread_id = runtime.config.get("configurable", {}).get("thread_id") + thread_id = get_thread_id(runtime) # Try to get parent model from configurable metadata = runtime.config.get("metadata", {}) diff --git a/backend/packages/harness/deerflow/tools/skill_manage_tool.py b/backend/packages/harness/deerflow/tools/skill_manage_tool.py index 3b7a109cc..bce85dc49 100644 --- a/backend/packages/harness/deerflow/tools/skill_manage_tool.py +++ b/backend/packages/harness/deerflow/tools/skill_manage_tool.py @@ -28,6 +28,7 @@ from deerflow.skills.manager import ( validate_skill_name, ) from deerflow.skills.security_scanner import scan_skill_content +from deerflow.utils.runtime import get_thread_id logger = logging.getLogger(__name__) @@ -42,14 +43,6 @@ def _get_lock(name: str) -> asyncio.Lock: return lock -def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState] | None) -> str | None: - if runtime is None: - return None - if runtime.context and runtime.context.get("thread_id"): - return runtime.context.get("thread_id") - return runtime.config.get("configurable", {}).get("thread_id") - - def _history_record(*, action: str, file_path: str, prev_content: str | None, new_content: str | None, thread_id: str | None, scanner: dict[str, Any]) -> dict[str, Any]: return { "action": action, @@ -98,7 +91,7 @@ async def _skill_manage_impl( """ name = validate_skill_name(name) lock = _get_lock(name) - thread_id = _get_thread_id(runtime) + thread_id = get_thread_id(runtime) async with lock: if action == "create": diff --git a/backend/packages/harness/deerflow/utils/runtime.py b/backend/packages/harness/deerflow/utils/runtime.py new file mode 100644 index 000000000..84776b3d9 --- /dev/null +++ b/backend/packages/harness/deerflow/utils/runtime.py @@ -0,0 +1,88 @@ +"""Runtime utilities for thread_id resolution and context access. + +Thread ID Resolution Strategy +============================= + +DeerFlow resolves the current ``thread_id`` from a three-level cascade: + +1. **runtime.context["thread_id"]** -- Set by ``worker.py`` (gateway mode) + or by LangGraph Server (standard mode) when constructing the Runtime. +2. **runtime.config["configurable"]["thread_id"]** -- Available on + ``ToolRuntime`` instances passed to tools via the ``@tool`` decorator. + Not available on ``Runtime`` instances received by middlewares. +3. **get_config()["configurable"]["thread_id"]** -- LangGraph's thread-local + config, available when executing inside a graph's runnable context. + +About ``__pregel_runtime`` +=========================== + +In gateway mode (``run_agent()`` in ``worker.py``), the agent graph does not +run inside the LangGraph Server. The server normally injects a ``Runtime`` +object automatically. Since we run the graph ourselves, we must inject the +Runtime manually via ``config["configurable"]["__pregel_runtime"]``. This is +the standard mechanism provided by LangGraph's Pregel engine for injecting +runtime context into graph nodes. It is not a private/internal hack -- it is +the documented way to pass Runtime when running a graph outside the server. + +Duck Typing +=========== + +Both ``langgraph.runtime.Runtime`` (middlewares) and +``langchain.tools.ToolRuntime`` (tools) expose a ``.context`` attribute (a +dict or None). ``ToolRuntime`` additionally exposes ``.config``. The +function below uses ``getattr`` with safe defaults so it works with either +type, with ``SimpleNamespace`` in tests, or with ``None``. +""" + +from __future__ import annotations + +from typing import Any + + +def get_thread_id(runtime: Any | None) -> str | None: + """Resolve the current thread_id from a runtime object. + + Follows a three-level fallback chain: + + 1. ``runtime.context.get("thread_id")`` -- if context is a non-empty dict. + 2. ``runtime.config.get("configurable", {}).get("thread_id")`` -- if + the runtime has a config dict (ToolRuntime). + 3. ``get_config().get("configurable", {}).get("thread_id")`` -- LangGraph's + thread-local config. Wrapped in ``try/except RuntimeError`` because it + raises outside a runnable context (e.g., unit tests). + + Args: + runtime: A Runtime, ToolRuntime, SimpleNamespace, or None. + + Returns: + The thread_id string, or None if it cannot be resolved. + """ + if runtime is None: + return None + + # Level 1: runtime.context["thread_id"] + context = getattr(runtime, "context", None) + if context and isinstance(context, dict): + thread_id = context.get("thread_id") + if thread_id: + return thread_id + + # Level 2: runtime.config["configurable"]["thread_id"] + config = getattr(runtime, "config", None) + if config and isinstance(config, dict): + thread_id = config.get("configurable", {}).get("thread_id") + if thread_id: + return thread_id + + # Level 3: langgraph.config.get_config() -- only works inside runnable context + try: + from langgraph.config import get_config + + config_data = get_config() + thread_id = config_data.get("configurable", {}).get("thread_id") + if thread_id: + return thread_id + except RuntimeError: + pass + + return None diff --git a/backend/tests/test_runtime_utils.py b/backend/tests/test_runtime_utils.py new file mode 100644 index 000000000..2f6cb5a11 --- /dev/null +++ b/backend/tests/test_runtime_utils.py @@ -0,0 +1,69 @@ +"""Tests for deerflow.utils.runtime.get_thread_id.""" + +from types import SimpleNamespace +from unittest.mock import patch + +from deerflow.utils.runtime import get_thread_id + + +class TestGetThreadId: + """Tests for get_thread_id() with various runtime shapes.""" + + def test_returns_none_when_runtime_is_none(self): + assert get_thread_id(None) is None + + def test_returns_thread_id_from_context(self): + runtime = SimpleNamespace(context={"thread_id": "t-1"}, config={}) + assert get_thread_id(runtime) == "t-1" + + def test_returns_none_from_empty_context(self): + runtime = SimpleNamespace(context={}, config={}) + assert get_thread_id(runtime) is None + + def test_returns_none_from_none_context(self): + runtime = SimpleNamespace(context=None, config={}) + assert get_thread_id(runtime) is None + + def test_falls_back_to_runtime_config(self): + runtime = SimpleNamespace( + context=None, + config={"configurable": {"thread_id": "t-from-config"}}, + ) + assert get_thread_id(runtime) == "t-from-config" + + def test_context_takes_precedence_over_config(self): + runtime = SimpleNamespace( + context={"thread_id": "t-from-context"}, + config={"configurable": {"thread_id": "t-from-config"}}, + ) + assert get_thread_id(runtime) == "t-from-context" + + def test_falls_back_to_get_config(self): + runtime = SimpleNamespace(context=None, config={}) + with patch("langgraph.config.get_config", return_value={"configurable": {"thread_id": "t-from-lg"}}): + assert get_thread_id(runtime) == "t-from-lg" + + def test_returns_none_when_get_config_raises_runtime_error(self): + runtime = SimpleNamespace(context=None, config={}) + assert get_thread_id(runtime) is None + + def test_handles_object_without_context_or_config(self): + runtime = SimpleNamespace() + assert get_thread_id(runtime) is None + + def test_handles_context_not_dict(self): + runtime = SimpleNamespace(context="not-a-dict", config={}) + assert get_thread_id(runtime) is None + + def test_config_without_configurable(self): + runtime = SimpleNamespace(context=None, config={"other_key": "value"}) + assert get_thread_id(runtime) is None + + def test_empty_string_thread_id_treated_as_missing(self): + runtime = SimpleNamespace(context={"thread_id": ""}, config={}) + assert get_thread_id(runtime) is None + + def test_full_cascade_with_all_levels_failing(self): + runtime = SimpleNamespace(context=None, config={}) + with patch("langgraph.config.get_config", return_value={"configurable": {}}): + assert get_thread_id(runtime) is None diff --git a/backend/tests/test_thread_data_middleware.py b/backend/tests/test_thread_data_middleware.py index ef3e440f7..737808791 100644 --- a/backend/tests/test_thread_data_middleware.py +++ b/backend/tests/test_thread_data_middleware.py @@ -23,7 +23,7 @@ class TestThreadDataMiddleware: middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True) runtime = Runtime(context=None) monkeypatch.setattr( - "deerflow.agents.middlewares.thread_data_middleware.get_config", + "langgraph.config.get_config", lambda: {"configurable": {"thread_id": "thread-from-config"}}, ) @@ -37,7 +37,7 @@ class TestThreadDataMiddleware: middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True) runtime = Runtime(context={}) monkeypatch.setattr( - "deerflow.agents.middlewares.thread_data_middleware.get_config", + "langgraph.config.get_config", lambda: {"configurable": {"thread_id": "thread-from-config"}}, ) @@ -50,7 +50,7 @@ class TestThreadDataMiddleware: def test_before_agent_raises_clear_error_when_thread_id_missing_everywhere(self, tmp_path, monkeypatch): middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True) monkeypatch.setattr( - "deerflow.agents.middlewares.thread_data_middleware.get_config", + "langgraph.config.get_config", lambda: {"configurable": {}}, )