mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 08:55:59 +00:00
refactor(backend): consolidate thread_id resolution into shared get_thread_id() utility (#2522)
Extract duplicated thread_id fallback logic from 11 files into a single deerflow.utils.runtime.get_thread_id() function with a documented 3-level cascade (runtime.context → runtime.config → get_config()). The module docstring also clarifies the __pregel_runtime injection pattern used in gateway mode.
This commit is contained in:
@@ -3,11 +3,11 @@ from typing import NotRequired, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.thread_state import ThreadDataState
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.utils.runtime import get_thread_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -75,11 +75,7 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
||||
|
||||
@override
|
||||
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
context = runtime.context or {}
|
||||
thread_id = context.get("thread_id")
|
||||
if thread_id is None:
|
||||
config = get_config()
|
||||
thread_id = config.get("configurable", {}).get("thread_id")
|
||||
thread_id = get_thread_id(runtime)
|
||||
|
||||
if thread_id is None:
|
||||
raise ValueError("Thread ID is required in runtime context or config.configurable")
|
||||
|
||||
Reference in New Issue
Block a user