Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7052978a43 | |||
| d9f7f658be | |||
| a55de566b9 |
@@ -25,6 +25,8 @@ from langchain.agents.middleware import AgentMiddleware
|
|||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
from deerflow.utils.runtime import get_thread_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Defaults — can be overridden via constructor
|
# Defaults — can be overridden via constructor
|
||||||
@@ -183,10 +185,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
|||||||
|
|
||||||
def _get_thread_id(self, runtime: Runtime) -> str:
|
def _get_thread_id(self, runtime: Runtime) -> str:
|
||||||
"""Extract thread_id from runtime context for per-thread tracking."""
|
"""Extract thread_id from runtime context for per-thread tracking."""
|
||||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
return get_thread_id(runtime) or "default"
|
||||||
if thread_id:
|
|
||||||
return thread_id
|
|
||||||
return "default"
|
|
||||||
|
|
||||||
def _evict_if_needed(self) -> None:
|
def _evict_if_needed(self) -> None:
|
||||||
"""Evict least recently used threads if over the limit.
|
"""Evict least recently used threads if over the limit.
|
||||||
|
|||||||
@@ -5,12 +5,12 @@ from typing import override
|
|||||||
|
|
||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
from langgraph.config import get_config
|
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
|
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
|
||||||
from deerflow.agents.memory.queue import get_memory_queue
|
from deerflow.agents.memory.queue import get_memory_queue
|
||||||
from deerflow.config.memory_config import get_memory_config
|
from deerflow.config.memory_config import get_memory_config
|
||||||
|
from deerflow.utils.runtime import get_thread_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -57,13 +57,10 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
|||||||
if not config.enabled:
|
if not config.enabled:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Get thread ID from runtime context first, then fall back to LangGraph's configurable metadata
|
# Resolve thread ID from the runtime or configured fallback sources
|
||||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
thread_id = get_thread_id(runtime)
|
||||||
if thread_id is None:
|
|
||||||
config_data = get_config()
|
|
||||||
thread_id = config_data.get("configurable", {}).get("thread_id")
|
|
||||||
if not thread_id:
|
if not thread_id:
|
||||||
logger.debug("No thread_id in context, skipping memory update")
|
logger.debug("No thread_id could be resolved from runtime/config, skipping memory update")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Get messages from state
|
# Get messages from state
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from langgraph.prebuilt.tool_node import ToolCallRequest
|
|||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
|
|
||||||
from deerflow.agents.thread_state import ThreadState
|
from deerflow.agents.thread_state import ThreadState
|
||||||
|
from deerflow.utils.runtime import get_thread_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -218,15 +219,7 @@ class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
|
|||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def _get_thread_id(self, request: ToolCallRequest) -> str | None:
|
def _get_thread_id(self, request: ToolCallRequest) -> str | None:
|
||||||
runtime = request.runtime # ToolRuntime; may be None-like in tests
|
return get_thread_id(request.runtime)
|
||||||
if runtime is None:
|
|
||||||
return None
|
|
||||||
ctx = getattr(runtime, "context", None) or {}
|
|
||||||
thread_id = ctx.get("thread_id") if isinstance(ctx, dict) else None
|
|
||||||
if thread_id is None:
|
|
||||||
cfg = getattr(runtime, "config", None) or {}
|
|
||||||
thread_id = cfg.get("configurable", {}).get("thread_id")
|
|
||||||
return thread_id
|
|
||||||
|
|
||||||
_AUDIT_COMMAND_LIMIT = 200
|
_AUDIT_COMMAND_LIMIT = 200
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ from langgraph.config import get_config
|
|||||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
from deerflow.utils.runtime import get_thread_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -35,18 +37,6 @@ class BeforeSummarizationHook(Protocol):
|
|||||||
def __call__(self, event: SummarizationEvent) -> None: ...
|
def __call__(self, event: SummarizationEvent) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
def _resolve_thread_id(runtime: Runtime) -> str | None:
|
|
||||||
"""Resolve the current thread ID from runtime context or LangGraph config."""
|
|
||||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
|
||||||
if thread_id is None:
|
|
||||||
try:
|
|
||||||
config_data = get_config()
|
|
||||||
except RuntimeError:
|
|
||||||
return None
|
|
||||||
thread_id = config_data.get("configurable", {}).get("thread_id")
|
|
||||||
return thread_id
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_agent_name(runtime: Runtime) -> str | None:
|
def _resolve_agent_name(runtime: Runtime) -> str | None:
|
||||||
"""Resolve the current agent name from runtime context or LangGraph config."""
|
"""Resolve the current agent name from runtime context or LangGraph config."""
|
||||||
agent_name = runtime.context.get("agent_name") if runtime.context else None
|
agent_name = runtime.context.get("agent_name") if runtime.context else None
|
||||||
@@ -334,7 +324,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
|||||||
event = SummarizationEvent(
|
event = SummarizationEvent(
|
||||||
messages_to_summarize=tuple(messages_to_summarize),
|
messages_to_summarize=tuple(messages_to_summarize),
|
||||||
preserved_messages=tuple(preserved_messages),
|
preserved_messages=tuple(preserved_messages),
|
||||||
thread_id=_resolve_thread_id(runtime),
|
thread_id=get_thread_id(runtime),
|
||||||
agent_name=_resolve_agent_name(runtime),
|
agent_name=_resolve_agent_name(runtime),
|
||||||
runtime=runtime,
|
runtime=runtime,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,11 +3,11 @@ from typing import NotRequired, override
|
|||||||
|
|
||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
from langgraph.config import get_config
|
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
from deerflow.agents.thread_state import ThreadDataState
|
from deerflow.agents.thread_state import ThreadDataState
|
||||||
from deerflow.config.paths import Paths, get_paths
|
from deerflow.config.paths import Paths, get_paths
|
||||||
|
from deerflow.utils.runtime import get_thread_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -75,11 +75,7 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
|
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
|
||||||
context = runtime.context or {}
|
thread_id = get_thread_id(runtime)
|
||||||
thread_id = context.get("thread_id")
|
|
||||||
if thread_id is None:
|
|
||||||
config = get_config()
|
|
||||||
thread_id = config.get("configurable", {}).get("thread_id")
|
|
||||||
|
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
raise ValueError("Thread ID is required in runtime context or config.configurable")
|
raise ValueError("Thread ID is required in runtime context or config.configurable")
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from langgraph.runtime import Runtime
|
|||||||
|
|
||||||
from deerflow.config.paths import Paths, get_paths
|
from deerflow.config.paths import Paths, get_paths
|
||||||
from deerflow.utils.file_conversion import extract_outline
|
from deerflow.utils.file_conversion import extract_outline
|
||||||
|
from deerflow.utils.runtime import get_thread_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -213,14 +214,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# Resolve uploads directory for existence checks
|
# Resolve uploads directory for existence checks
|
||||||
thread_id = (runtime.context or {}).get("thread_id")
|
thread_id = get_thread_id(runtime)
|
||||||
if thread_id is None:
|
|
||||||
try:
|
|
||||||
from langgraph.config import get_config
|
|
||||||
|
|
||||||
thread_id = get_config().get("configurable", {}).get("thread_id")
|
|
||||||
except RuntimeError:
|
|
||||||
pass # get_config() raises outside a runnable context (e.g. unit tests)
|
|
||||||
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
|
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
|
||||||
|
|
||||||
# Get newly uploaded files from the current message's additional_kwargs.files
|
# Get newly uploaded files from the current message's additional_kwargs.files
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from langgraph.runtime import Runtime
|
|||||||
|
|
||||||
from deerflow.agents.thread_state import SandboxState, ThreadDataState
|
from deerflow.agents.thread_state import SandboxState, ThreadDataState
|
||||||
from deerflow.sandbox import get_sandbox_provider
|
from deerflow.sandbox import get_sandbox_provider
|
||||||
|
from deerflow.utils.runtime import get_thread_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -56,7 +57,7 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
|
|||||||
|
|
||||||
# Eager initialization (original behavior)
|
# Eager initialization (original behavior)
|
||||||
if "sandbox" not in state or state["sandbox"] is None:
|
if "sandbox" not in state or state["sandbox"] is None:
|
||||||
thread_id = (runtime.context or {}).get("thread_id")
|
thread_id = get_thread_id(runtime)
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
return super().before_agent(state, runtime)
|
return super().before_agent(state, runtime)
|
||||||
sandbox_id = self._acquire_sandbox(thread_id)
|
sandbox_id = self._acquire_sandbox(thread_id)
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from deerflow.sandbox.sandbox import Sandbox
|
|||||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||||
from deerflow.sandbox.search import GrepMatch
|
from deerflow.sandbox.search import GrepMatch
|
||||||
from deerflow.sandbox.security import LOCAL_HOST_BASH_DISABLED_MESSAGE, is_host_bash_allowed
|
from deerflow.sandbox.security import LOCAL_HOST_BASH_DISABLED_MESSAGE, is_host_bash_allowed
|
||||||
|
from deerflow.utils.runtime import get_thread_id
|
||||||
|
|
||||||
_ABSOLUTE_PATH_PATTERN = re.compile(r"(?<![:\w])(?<!:/)/(?:[^\s\"'`;&|<>()]+)")
|
_ABSOLUTE_PATH_PATTERN = re.compile(r"(?<![:\w])(?<!:/)/(?:[^\s\"'`;&|<>()]+)")
|
||||||
_FILE_URL_PATTERN = re.compile(r"\bfile://\S+", re.IGNORECASE)
|
_FILE_URL_PATTERN = re.compile(r"\bfile://\S+", re.IGNORECASE)
|
||||||
@@ -851,11 +852,9 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
|
|||||||
# Sandbox was released, fall through to acquire new one
|
# Sandbox was released, fall through to acquire new one
|
||||||
|
|
||||||
# Lazy acquisition: get thread_id and acquire sandbox
|
# Lazy acquisition: get thread_id and acquire sandbox
|
||||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
thread_id = get_thread_id(runtime)
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
thread_id = runtime.config.get("configurable", {}).get("thread_id") if runtime.config else None
|
raise SandboxRuntimeError("Thread ID not available in runtime context, runtime config, or LangGraph config")
|
||||||
if thread_id is None:
|
|
||||||
raise SandboxRuntimeError("Thread ID not available in runtime context")
|
|
||||||
|
|
||||||
provider = get_sandbox_provider()
|
provider = get_sandbox_provider()
|
||||||
sandbox_id = provider.acquire(thread_id)
|
sandbox_id = provider.acquire(thread_id)
|
||||||
|
|||||||
@@ -3,33 +3,16 @@ from typing import Annotated
|
|||||||
|
|
||||||
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
||||||
from langchain_core.messages import ToolMessage
|
from langchain_core.messages import ToolMessage
|
||||||
from langgraph.config import get_config
|
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
from langgraph.typing import ContextT
|
from langgraph.typing import ContextT
|
||||||
|
|
||||||
from deerflow.agents.thread_state import ThreadState
|
from deerflow.agents.thread_state import ThreadState
|
||||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||||
|
from deerflow.utils.runtime import get_thread_id
|
||||||
|
|
||||||
OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"
|
OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"
|
||||||
|
|
||||||
|
|
||||||
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState]) -> str | None:
|
|
||||||
"""Resolve the current thread id from runtime context or RunnableConfig."""
|
|
||||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
|
||||||
if thread_id:
|
|
||||||
return thread_id
|
|
||||||
|
|
||||||
runtime_config = getattr(runtime, "config", None) or {}
|
|
||||||
thread_id = runtime_config.get("configurable", {}).get("thread_id")
|
|
||||||
if thread_id:
|
|
||||||
return thread_id
|
|
||||||
|
|
||||||
try:
|
|
||||||
return get_config().get("configurable", {}).get("thread_id")
|
|
||||||
except RuntimeError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_presented_filepath(
|
def _normalize_presented_filepath(
|
||||||
runtime: ToolRuntime[ContextT, ThreadState],
|
runtime: ToolRuntime[ContextT, ThreadState],
|
||||||
filepath: str,
|
filepath: str,
|
||||||
@@ -51,9 +34,9 @@ def _normalize_presented_filepath(
|
|||||||
if runtime.state is None:
|
if runtime.state is None:
|
||||||
raise ValueError("Thread runtime state is not available")
|
raise ValueError("Thread runtime state is not available")
|
||||||
|
|
||||||
thread_id = _get_thread_id(runtime)
|
thread_id = get_thread_id(runtime)
|
||||||
if not thread_id:
|
if not thread_id:
|
||||||
raise ValueError("Thread ID is not available in runtime context or runtime config")
|
raise ValueError("Thread ID is not available in runtime context, runtime config, or LangGraph thread-local config")
|
||||||
|
|
||||||
thread_data = runtime.state.get("thread_data") or {}
|
thread_data = runtime.state.get("thread_data") or {}
|
||||||
outputs_path = thread_data.get("outputs_path")
|
outputs_path = thread_data.get("outputs_path")
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from deerflow.agents.thread_state import ThreadState
|
|||||||
from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed
|
from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed
|
||||||
from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config
|
from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config
|
||||||
from deerflow.subagents.executor import SubagentStatus, cleanup_background_task, get_background_task_result, request_cancel_background_task
|
from deerflow.subagents.executor import SubagentStatus, cleanup_background_task, get_background_task_result, request_cancel_background_task
|
||||||
|
from deerflow.utils.runtime import get_thread_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -105,9 +106,7 @@ async def task_tool(
|
|||||||
if runtime is not None:
|
if runtime is not None:
|
||||||
sandbox_state = runtime.state.get("sandbox")
|
sandbox_state = runtime.state.get("sandbox")
|
||||||
thread_data = runtime.state.get("thread_data")
|
thread_data = runtime.state.get("thread_data")
|
||||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
thread_id = get_thread_id(runtime)
|
||||||
if thread_id is None:
|
|
||||||
thread_id = runtime.config.get("configurable", {}).get("thread_id")
|
|
||||||
|
|
||||||
# Try to get parent model from configurable
|
# Try to get parent model from configurable
|
||||||
metadata = runtime.config.get("metadata", {})
|
metadata = runtime.config.get("metadata", {})
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from deerflow.skills.manager import (
|
|||||||
validate_skill_name,
|
validate_skill_name,
|
||||||
)
|
)
|
||||||
from deerflow.skills.security_scanner import scan_skill_content
|
from deerflow.skills.security_scanner import scan_skill_content
|
||||||
|
from deerflow.utils.runtime import get_thread_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -42,14 +43,6 @@ def _get_lock(name: str) -> asyncio.Lock:
|
|||||||
return lock
|
return lock
|
||||||
|
|
||||||
|
|
||||||
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState] | None) -> str | None:
|
|
||||||
if runtime is None:
|
|
||||||
return None
|
|
||||||
if runtime.context and runtime.context.get("thread_id"):
|
|
||||||
return runtime.context.get("thread_id")
|
|
||||||
return runtime.config.get("configurable", {}).get("thread_id")
|
|
||||||
|
|
||||||
|
|
||||||
def _history_record(*, action: str, file_path: str, prev_content: str | None, new_content: str | None, thread_id: str | None, scanner: dict[str, Any]) -> dict[str, Any]:
|
def _history_record(*, action: str, file_path: str, prev_content: str | None, new_content: str | None, thread_id: str | None, scanner: dict[str, Any]) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"action": action,
|
"action": action,
|
||||||
@@ -98,7 +91,7 @@ async def _skill_manage_impl(
|
|||||||
"""
|
"""
|
||||||
name = validate_skill_name(name)
|
name = validate_skill_name(name)
|
||||||
lock = _get_lock(name)
|
lock = _get_lock(name)
|
||||||
thread_id = _get_thread_id(runtime)
|
thread_id = get_thread_id(runtime)
|
||||||
|
|
||||||
async with lock:
|
async with lock:
|
||||||
if action == "create":
|
if action == "create":
|
||||||
|
|||||||
@@ -0,0 +1,90 @@
|
|||||||
|
"""Runtime utilities for thread_id resolution and context access.
|
||||||
|
|
||||||
|
Thread ID Resolution Strategy
|
||||||
|
=============================
|
||||||
|
|
||||||
|
DeerFlow resolves the current ``thread_id`` from a three-level cascade:
|
||||||
|
|
||||||
|
1. **runtime.context["thread_id"]** -- Set by ``worker.py`` (gateway mode)
|
||||||
|
or by LangGraph Server (standard mode) when constructing the Runtime.
|
||||||
|
2. **runtime.config["configurable"]["thread_id"]** -- Available on
|
||||||
|
``ToolRuntime`` instances passed to tools via the ``@tool`` decorator.
|
||||||
|
Not available on ``Runtime`` instances received by middlewares.
|
||||||
|
3. **get_config()["configurable"]["thread_id"]** -- LangGraph's thread-local
|
||||||
|
config, available when executing inside a graph's runnable context.
|
||||||
|
|
||||||
|
About ``__pregel_runtime``
|
||||||
|
===========================
|
||||||
|
|
||||||
|
In gateway mode (``run_agent()`` in ``worker.py``), the agent graph does not
|
||||||
|
run inside the LangGraph Server. The server normally injects a ``Runtime``
|
||||||
|
object automatically. Since we run the graph ourselves, we must inject the
|
||||||
|
Runtime manually via ``config["configurable"]["__pregel_runtime"]``. This is
|
||||||
|
the standard mechanism provided by LangGraph's Pregel engine for injecting
|
||||||
|
runtime context into graph nodes. It is not a private/internal hack -- it is
|
||||||
|
the documented way to pass Runtime when running a graph outside the server.
|
||||||
|
|
||||||
|
Duck Typing
|
||||||
|
===========
|
||||||
|
|
||||||
|
Both ``langgraph.runtime.Runtime`` (middlewares) and
|
||||||
|
``langchain.tools.ToolRuntime`` (tools) expose a ``.context`` attribute (a
|
||||||
|
dict or None). ``ToolRuntime`` additionally exposes ``.config``. The
|
||||||
|
function below uses ``getattr`` with safe defaults so it works with either
|
||||||
|
type, with ``SimpleNamespace`` in tests, or with ``None``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def get_thread_id(runtime: Any | None) -> str | None:
|
||||||
|
"""Resolve the current thread_id from a runtime object.
|
||||||
|
|
||||||
|
Follows a three-level fallback chain:
|
||||||
|
|
||||||
|
1. ``runtime.context.get("thread_id")`` -- if context is a non-empty dict.
|
||||||
|
2. ``runtime.config.get("configurable", {}).get("thread_id")`` -- if
|
||||||
|
the runtime has a config dict (ToolRuntime).
|
||||||
|
3. ``get_config().get("configurable", {}).get("thread_id")`` -- LangGraph's
|
||||||
|
thread-local config. Wrapped in ``try/except RuntimeError`` because it
|
||||||
|
raises outside a runnable context (e.g., unit tests).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runtime: A Runtime, ToolRuntime, SimpleNamespace, or None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The thread_id string, or None if it cannot be resolved.
|
||||||
|
"""
|
||||||
|
if runtime is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Level 1: runtime.context["thread_id"]
|
||||||
|
context = getattr(runtime, "context", None)
|
||||||
|
if context and isinstance(context, dict):
|
||||||
|
thread_id = context.get("thread_id")
|
||||||
|
if thread_id:
|
||||||
|
return thread_id
|
||||||
|
|
||||||
|
# Level 2: runtime.config["configurable"]["thread_id"]
|
||||||
|
config = getattr(runtime, "config", None)
|
||||||
|
if config and isinstance(config, dict):
|
||||||
|
thread_id = config.get("configurable", {}).get("thread_id")
|
||||||
|
if thread_id:
|
||||||
|
return thread_id
|
||||||
|
|
||||||
|
# Level 3: langgraph.config.get_config() -- only works inside runnable context
|
||||||
|
try:
|
||||||
|
from langgraph.config import get_config
|
||||||
|
|
||||||
|
config_data = get_config()
|
||||||
|
thread_id = config_data.get("configurable", {}).get("thread_id")
|
||||||
|
if thread_id:
|
||||||
|
return thread_id
|
||||||
|
except RuntimeError:
|
||||||
|
# Expected when not running inside a LangGraph runnable context (e.g., unit tests).
|
||||||
|
# In that case, thread_id cannot be resolved from thread-local config, so fall through.
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
@@ -0,0 +1,70 @@
|
|||||||
|
"""Tests for deerflow.utils.runtime.get_thread_id."""
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from deerflow.utils.runtime import get_thread_id
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetThreadId:
|
||||||
|
"""Tests for get_thread_id() with various runtime shapes."""
|
||||||
|
|
||||||
|
def test_returns_none_when_runtime_is_none(self):
|
||||||
|
assert get_thread_id(None) is None
|
||||||
|
|
||||||
|
def test_returns_thread_id_from_context(self):
|
||||||
|
runtime = SimpleNamespace(context={"thread_id": "t-1"}, config={})
|
||||||
|
assert get_thread_id(runtime) == "t-1"
|
||||||
|
|
||||||
|
def test_returns_none_from_empty_context(self):
|
||||||
|
runtime = SimpleNamespace(context={}, config={})
|
||||||
|
assert get_thread_id(runtime) is None
|
||||||
|
|
||||||
|
def test_returns_none_from_none_context(self):
|
||||||
|
runtime = SimpleNamespace(context=None, config={})
|
||||||
|
assert get_thread_id(runtime) is None
|
||||||
|
|
||||||
|
def test_falls_back_to_runtime_config(self):
|
||||||
|
runtime = SimpleNamespace(
|
||||||
|
context=None,
|
||||||
|
config={"configurable": {"thread_id": "t-from-config"}},
|
||||||
|
)
|
||||||
|
assert get_thread_id(runtime) == "t-from-config"
|
||||||
|
|
||||||
|
def test_context_takes_precedence_over_config(self):
|
||||||
|
runtime = SimpleNamespace(
|
||||||
|
context={"thread_id": "t-from-context"},
|
||||||
|
config={"configurable": {"thread_id": "t-from-config"}},
|
||||||
|
)
|
||||||
|
assert get_thread_id(runtime) == "t-from-context"
|
||||||
|
|
||||||
|
def test_falls_back_to_get_config(self):
|
||||||
|
runtime = SimpleNamespace(context=None, config={})
|
||||||
|
with patch("langgraph.config.get_config", return_value={"configurable": {"thread_id": "t-from-lg"}}):
|
||||||
|
assert get_thread_id(runtime) == "t-from-lg"
|
||||||
|
|
||||||
|
def test_returns_none_when_get_config_raises_runtime_error(self):
|
||||||
|
runtime = SimpleNamespace(context=None, config={})
|
||||||
|
with patch("langgraph.config.get_config", side_effect=RuntimeError):
|
||||||
|
assert get_thread_id(runtime) is None
|
||||||
|
|
||||||
|
def test_handles_object_without_context_or_config(self):
|
||||||
|
runtime = SimpleNamespace()
|
||||||
|
assert get_thread_id(runtime) is None
|
||||||
|
|
||||||
|
def test_handles_context_not_dict(self):
|
||||||
|
runtime = SimpleNamespace(context="not-a-dict", config={})
|
||||||
|
assert get_thread_id(runtime) is None
|
||||||
|
|
||||||
|
def test_config_without_configurable(self):
|
||||||
|
runtime = SimpleNamespace(context=None, config={"other_key": "value"})
|
||||||
|
assert get_thread_id(runtime) is None
|
||||||
|
|
||||||
|
def test_empty_string_thread_id_treated_as_missing(self):
|
||||||
|
runtime = SimpleNamespace(context={"thread_id": ""}, config={})
|
||||||
|
assert get_thread_id(runtime) is None
|
||||||
|
|
||||||
|
def test_full_cascade_with_all_levels_failing(self):
|
||||||
|
runtime = SimpleNamespace(context=None, config={})
|
||||||
|
with patch("langgraph.config.get_config", return_value={"configurable": {}}):
|
||||||
|
assert get_thread_id(runtime) is None
|
||||||
@@ -23,7 +23,7 @@ class TestThreadDataMiddleware:
|
|||||||
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
|
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
|
||||||
runtime = Runtime(context=None)
|
runtime = Runtime(context=None)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"deerflow.agents.middlewares.thread_data_middleware.get_config",
|
"langgraph.config.get_config",
|
||||||
lambda: {"configurable": {"thread_id": "thread-from-config"}},
|
lambda: {"configurable": {"thread_id": "thread-from-config"}},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -37,7 +37,7 @@ class TestThreadDataMiddleware:
|
|||||||
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
|
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
|
||||||
runtime = Runtime(context={})
|
runtime = Runtime(context={})
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"deerflow.agents.middlewares.thread_data_middleware.get_config",
|
"langgraph.config.get_config",
|
||||||
lambda: {"configurable": {"thread_id": "thread-from-config"}},
|
lambda: {"configurable": {"thread_id": "thread-from-config"}},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -50,7 +50,7 @@ class TestThreadDataMiddleware:
|
|||||||
def test_before_agent_raises_clear_error_when_thread_id_missing_everywhere(self, tmp_path, monkeypatch):
|
def test_before_agent_raises_clear_error_when_thread_id_missing_everywhere(self, tmp_path, monkeypatch):
|
||||||
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
|
middleware = ThreadDataMiddleware(base_dir=str(tmp_path), lazy_init=True)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"deerflow.agents.middlewares.thread_data_middleware.get_config",
|
"langgraph.config.get_config",
|
||||||
lambda: {"configurable": {}},
|
lambda: {"configurable": {}},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user