mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 17:06:00 +00:00
refactor(backend): consolidate thread_id resolution into shared get_thread_id() utility (#2522)
Extract duplicated thread_id fallback logic from 11 files into a single deerflow.utils.runtime.get_thread_id() function with a documented 3-level cascade (runtime.context → runtime.config → get_config()). The module docstring also clarifies the __pregel_runtime injection pattern used in gateway mode.
This commit is contained in:
@@ -25,6 +25,8 @@ from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.utils.runtime import get_thread_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Defaults — can be overridden via constructor
|
||||
@@ -183,10 +185,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
|
||||
def _get_thread_id(self, runtime: Runtime) -> str:
|
||||
"""Extract thread_id from runtime context for per-thread tracking."""
|
||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
||||
if thread_id:
|
||||
return thread_id
|
||||
return "default"
|
||||
return get_thread_id(runtime) or "default"
|
||||
|
||||
def _evict_if_needed(self) -> None:
|
||||
"""Evict least recently used threads if over the limit.
|
||||
|
||||
@@ -5,12 +5,12 @@ from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
|
||||
from deerflow.agents.memory.queue import get_memory_queue
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.utils.runtime import get_thread_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -57,11 +57,8 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
if not config.enabled:
|
||||
return None
|
||||
|
||||
# Get thread ID from runtime context first, then fall back to LangGraph's configurable metadata
|
||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
||||
if thread_id is None:
|
||||
config_data = get_config()
|
||||
thread_id = config_data.get("configurable", {}).get("thread_id")
|
||||
# Get thread ID from runtime context
|
||||
thread_id = get_thread_id(runtime)
|
||||
if not thread_id:
|
||||
logger.debug("No thread_id in context, skipping memory update")
|
||||
return None
|
||||
|
||||
@@ -14,6 +14,7 @@ from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||
from langgraph.types import Command
|
||||
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.utils.runtime import get_thread_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -218,15 +219,7 @@ class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_thread_id(self, request: ToolCallRequest) -> str | None:
|
||||
runtime = request.runtime # ToolRuntime; may be None-like in tests
|
||||
if runtime is None:
|
||||
return None
|
||||
ctx = getattr(runtime, "context", None) or {}
|
||||
thread_id = ctx.get("thread_id") if isinstance(ctx, dict) else None
|
||||
if thread_id is None:
|
||||
cfg = getattr(runtime, "config", None) or {}
|
||||
thread_id = cfg.get("configurable", {}).get("thread_id")
|
||||
return thread_id
|
||||
return get_thread_id(request.runtime)
|
||||
|
||||
_AUDIT_COMMAND_LIMIT = 200
|
||||
|
||||
|
||||
@@ -14,6 +14,8 @@ from langgraph.config import get_config
|
||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.utils.runtime import get_thread_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -35,18 +37,6 @@ class BeforeSummarizationHook(Protocol):
|
||||
def __call__(self, event: SummarizationEvent) -> None: ...
|
||||
|
||||
|
||||
def _resolve_thread_id(runtime: Runtime) -> str | None:
|
||||
"""Resolve the current thread ID from runtime context or LangGraph config."""
|
||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
||||
if thread_id is None:
|
||||
try:
|
||||
config_data = get_config()
|
||||
except RuntimeError:
|
||||
return None
|
||||
thread_id = config_data.get("configurable", {}).get("thread_id")
|
||||
return thread_id
|
||||
|
||||
|
||||
def _resolve_agent_name(runtime: Runtime) -> str | None:
|
||||
"""Resolve the current agent name from runtime context or LangGraph config."""
|
||||
agent_name = runtime.context.get("agent_name") if runtime.context else None
|
||||
@@ -334,7 +324,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
||||
event = SummarizationEvent(
|
||||
messages_to_summarize=tuple(messages_to_summarize),
|
||||
preserved_messages=tuple(preserved_messages),
|
||||
thread_id=_resolve_thread_id(runtime),
|
||||
thread_id=get_thread_id(runtime),
|
||||
agent_name=_resolve_agent_name(runtime),
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
@@ -3,11 +3,11 @@ from typing import NotRequired, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.thread_state import ThreadDataState
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.utils.runtime import get_thread_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -75,11 +75,7 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
||||
|
||||
@override
|
||||
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
context = runtime.context or {}
|
||||
thread_id = context.get("thread_id")
|
||||
if thread_id is None:
|
||||
config = get_config()
|
||||
thread_id = config.get("configurable", {}).get("thread_id")
|
||||
thread_id = get_thread_id(runtime)
|
||||
|
||||
if thread_id is None:
|
||||
raise ValueError("Thread ID is required in runtime context or config.configurable")
|
||||
|
||||
@@ -11,6 +11,7 @@ from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.utils.file_conversion import extract_outline
|
||||
from deerflow.utils.runtime import get_thread_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -213,14 +214,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
return None
|
||||
|
||||
# Resolve uploads directory for existence checks
|
||||
thread_id = (runtime.context or {}).get("thread_id")
|
||||
if thread_id is None:
|
||||
try:
|
||||
from langgraph.config import get_config
|
||||
|
||||
thread_id = get_config().get("configurable", {}).get("thread_id")
|
||||
except RuntimeError:
|
||||
pass # get_config() raises outside a runnable context (e.g. unit tests)
|
||||
thread_id = get_thread_id(runtime)
|
||||
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
|
||||
|
||||
# Get newly uploaded files from the current message's additional_kwargs.files
|
||||
|
||||
Reference in New Issue
Block a user