mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 15:36:48 +00:00
Merge branch 'main' of https://github.com/bytedance/deer-flow into rayhpeng/storage-package-base
This commit is contained in:
@@ -40,6 +40,15 @@ class MemoryUpdateQueue:
|
||||
self._timer: threading.Timer | None = None
|
||||
self._processing = False
|
||||
|
||||
@staticmethod
|
||||
def _queue_key(
|
||||
thread_id: str,
|
||||
user_id: str | None,
|
||||
agent_name: str | None,
|
||||
) -> tuple[str, str | None, str | None]:
|
||||
"""Return the debounce identity for a memory update target."""
|
||||
return (thread_id, user_id, agent_name)
|
||||
|
||||
def add(
|
||||
self,
|
||||
thread_id: str,
|
||||
@@ -115,8 +124,9 @@ class MemoryUpdateQueue:
|
||||
correction_detected: bool,
|
||||
reinforcement_detected: bool,
|
||||
) -> None:
|
||||
queue_key = self._queue_key(thread_id, user_id, agent_name)
|
||||
existing_context = next(
|
||||
(context for context in self._queue if context.thread_id == thread_id),
|
||||
(context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) == queue_key),
|
||||
None,
|
||||
)
|
||||
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
|
||||
@@ -130,7 +140,7 @@ class MemoryUpdateQueue:
|
||||
reinforcement_detected=merged_reinforcement_detected,
|
||||
)
|
||||
|
||||
self._queue = [c for c in self._queue if c.thread_id != thread_id]
|
||||
self._queue = [context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) != queue_key]
|
||||
self._queue.append(context)
|
||||
|
||||
def _reset_timer(self) -> None:
|
||||
|
||||
@@ -6,6 +6,7 @@ from deerflow.agents.memory.message_processing import detect_correction, detect_
|
||||
from deerflow.agents.memory.queue import get_memory_queue
|
||||
from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.runtime.user_context import resolve_runtime_user_id
|
||||
|
||||
|
||||
def memory_flush_hook(event: SummarizationEvent) -> None:
|
||||
@@ -21,11 +22,13 @@ def memory_flush_hook(event: SummarizationEvent) -> None:
|
||||
|
||||
correction_detected = detect_correction(filtered_messages)
|
||||
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
||||
user_id = resolve_runtime_user_id(event.runtime)
|
||||
queue = get_memory_queue()
|
||||
queue.add_nowait(
|
||||
thread_id=event.thread_id,
|
||||
messages=filtered_messages,
|
||||
agent_name=event.agent_name,
|
||||
user_id=user_id,
|
||||
correction_detected=correction_detected,
|
||||
reinforcement_detected=reinforcement_detected,
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Any, override
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.todo import Todo
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -217,6 +217,17 @@ def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str:
|
||||
return "thinking"
|
||||
|
||||
|
||||
def _has_tool_call(message: AIMessage, tool_call_id: str) -> bool:
|
||||
"""Return True if the AIMessage contains a tool_call with the given id."""
|
||||
for tc in message.tool_calls or []:
|
||||
if isinstance(tc, dict):
|
||||
if tc.get("id") == tool_call_id:
|
||||
return True
|
||||
elif hasattr(tc, "id") and tc.id == tool_call_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
|
||||
tool_calls = getattr(message, "tool_calls", None) or []
|
||||
actions: list[dict[str, Any]] = []
|
||||
@@ -261,8 +272,51 @@ class TokenUsageMiddleware(AgentMiddleware):
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Annotate subagent token usage onto the AIMessage that dispatched it.
|
||||
# When a task tool completes, its usage is cached by tool_call_id. Detect
|
||||
# the ToolMessage → search backward for the corresponding AIMessage → merge.
|
||||
# Walk backward through consecutive ToolMessages before the new AIMessage
|
||||
# so that multiple concurrent task tool calls all get their subagent tokens
|
||||
# written back to the same dispatch message (merging into one update).
|
||||
state_updates: dict[int, AIMessage] = {}
|
||||
if len(messages) >= 2:
|
||||
from deerflow.tools.builtins.task_tool import pop_cached_subagent_usage
|
||||
|
||||
idx = len(messages) - 2
|
||||
while idx >= 0:
|
||||
tool_msg = messages[idx]
|
||||
if not isinstance(tool_msg, ToolMessage) or not tool_msg.tool_call_id:
|
||||
break
|
||||
|
||||
subagent_usage = pop_cached_subagent_usage(tool_msg.tool_call_id)
|
||||
if subagent_usage:
|
||||
# Search backward from the ToolMessage to find the AIMessage
|
||||
# that dispatched it. A single model response can dispatch
|
||||
# multiple task tool calls, so we can't assume a fixed offset.
|
||||
dispatch_idx = idx - 1
|
||||
while dispatch_idx >= 0:
|
||||
candidate = messages[dispatch_idx]
|
||||
if isinstance(candidate, AIMessage) and _has_tool_call(candidate, tool_msg.tool_call_id):
|
||||
# Accumulate into an existing update for the same
|
||||
# AIMessage (multiple task calls in one response),
|
||||
# or merge fresh from the original message.
|
||||
existing_update = state_updates.get(dispatch_idx)
|
||||
prev = existing_update.usage_metadata if existing_update else (getattr(candidate, "usage_metadata", None) or {})
|
||||
merged = {
|
||||
**prev,
|
||||
"input_tokens": prev.get("input_tokens", 0) + subagent_usage["input_tokens"],
|
||||
"output_tokens": prev.get("output_tokens", 0) + subagent_usage["output_tokens"],
|
||||
"total_tokens": prev.get("total_tokens", 0) + subagent_usage["total_tokens"],
|
||||
}
|
||||
state_updates[dispatch_idx] = candidate.model_copy(update={"usage_metadata": merged})
|
||||
break
|
||||
dispatch_idx -= 1
|
||||
idx -= 1
|
||||
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage):
|
||||
if state_updates:
|
||||
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
|
||||
return None
|
||||
|
||||
usage = getattr(last, "usage_metadata", None)
|
||||
@@ -288,11 +342,12 @@ class TokenUsageMiddleware(AgentMiddleware):
|
||||
additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
|
||||
|
||||
if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
|
||||
return None
|
||||
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} if state_updates else None
|
||||
|
||||
additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
|
||||
updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
|
||||
return {"messages": [updated_msg]}
|
||||
state_updates[len(messages) - 1] = updated_msg
|
||||
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
|
||||
@@ -0,0 +1,195 @@
|
||||
"""Dialect-aware JSON value matching for SQLAlchemy (SQLite + PostgreSQL)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import BigInteger, Float, String, bindparam
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.sql.compiler import SQLCompiler
|
||||
from sqlalchemy.sql.expression import ColumnElement
|
||||
from sqlalchemy.sql.visitors import InternalTraversal
|
||||
from sqlalchemy.types import Boolean, TypeEngine
|
||||
|
||||
# Key is interpolated into compiled SQL; restrict charset to prevent injection.
|
||||
_KEY_CHARSET_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
||||
|
||||
# Allowed value types for metadata filter values (same set accepted by JsonMatch).
|
||||
ALLOWED_FILTER_VALUE_TYPES: tuple[type, ...] = (type(None), bool, int, float, str)
|
||||
|
||||
# SQLite raises an overflow when binding values outside signed 64-bit range;
|
||||
# PostgreSQL overflows during BIGINT cast. Reject at validation time instead.
|
||||
_INT64_MIN = -(2**63)
|
||||
_INT64_MAX = 2**63 - 1
|
||||
|
||||
|
||||
def validate_metadata_filter_key(key: object) -> bool:
|
||||
"""Return True if *key* is safe for use as a JSON metadata filter key.
|
||||
|
||||
A key is "safe" when it is a string matching ``[A-Za-z0-9_-]+``. The
|
||||
charset is restricted because the key is interpolated into the
|
||||
compiled SQL path expression (``$."<key>"`` / ``->`` literal), so any
|
||||
laxer pattern would open a SQL/JSONPath injection surface.
|
||||
"""
|
||||
return isinstance(key, str) and bool(_KEY_CHARSET_RE.match(key))
|
||||
|
||||
|
||||
def validate_metadata_filter_value(value: object) -> bool:
|
||||
"""Return True if *value* is an allowed type for a JSON metadata filter.
|
||||
|
||||
Matches the set of types ``_build_clause`` knows how to compile into
|
||||
a dialect-portable predicate. Anything else (list/dict/bytes/...) is
|
||||
intentionally rejected rather than silently coerced via ``str()`` —
|
||||
silent coercion would (a) produce wrong matches and (b) break
|
||||
SQLAlchemy's ``inherit_cache`` invariant when ``value`` is unhashable.
|
||||
|
||||
Integer values are additionally restricted to the signed 64-bit range
|
||||
``[-2**63, 2**63 - 1]``: SQLite overflows when binding larger values
|
||||
and PostgreSQL overflows during the ``BIGINT`` cast.
|
||||
"""
|
||||
if not isinstance(value, ALLOWED_FILTER_VALUE_TYPES):
|
||||
return False
|
||||
if isinstance(value, int) and not isinstance(value, bool):
|
||||
if not (_INT64_MIN <= value <= _INT64_MAX):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class JsonMatch(ColumnElement):
|
||||
"""Dialect-portable ``column[key] == value`` for JSON columns.
|
||||
|
||||
Compiles to ``json_type``/``json_extract`` on SQLite and
|
||||
``json_typeof``/``->>`` on PostgreSQL, with type-safe comparison
|
||||
that distinguishes bool vs int and NULL vs missing key.
|
||||
|
||||
*key* must be a single literal key matching ``[A-Za-z0-9_-]+``.
|
||||
*value* must be one of: ``None``, ``bool``, ``int`` (signed 64-bit), ``float``, ``str``.
|
||||
"""
|
||||
|
||||
inherit_cache = True
|
||||
type = Boolean()
|
||||
_is_implicitly_boolean = True
|
||||
|
||||
_traverse_internals = [
|
||||
("column", InternalTraversal.dp_clauseelement),
|
||||
("key", InternalTraversal.dp_string),
|
||||
("value", InternalTraversal.dp_plain_obj),
|
||||
]
|
||||
|
||||
def __init__(self, column: ColumnElement, key: str, value: object) -> None:
|
||||
if not validate_metadata_filter_key(key):
|
||||
raise ValueError(f"JsonMatch key must match {_KEY_CHARSET_RE.pattern!r}; got: {key!r}")
|
||||
if not validate_metadata_filter_value(value):
|
||||
if isinstance(value, int) and not isinstance(value, bool):
|
||||
raise TypeError(f"JsonMatch int value out of signed 64-bit range [-2**63, 2**63-1]: {value!r}")
|
||||
raise TypeError(f"JsonMatch value must be None, bool, int, float, or str; got: {type(value).__name__!r}")
|
||||
self.column = column
|
||||
self.key = key
|
||||
self.value = value
|
||||
super().__init__()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _Dialect:
|
||||
"""Per-dialect names used when emitting JSON type/value comparisons."""
|
||||
|
||||
null_type: str
|
||||
num_types: tuple[str, ...]
|
||||
num_cast: str
|
||||
int_types: tuple[str, ...]
|
||||
int_cast: str
|
||||
# None for SQLite where json_type already returns 'integer'/'real';
|
||||
# regex literal for PostgreSQL where json_typeof returns 'number' for
|
||||
# both ints and floats, so an extra guard prevents CAST errors on floats.
|
||||
int_guard: str | None
|
||||
string_type: str
|
||||
bool_type: str | None
|
||||
|
||||
|
||||
_SQLITE = _Dialect(
|
||||
null_type="null",
|
||||
num_types=("integer", "real"),
|
||||
num_cast="REAL",
|
||||
int_types=("integer",),
|
||||
int_cast="INTEGER",
|
||||
int_guard=None,
|
||||
string_type="text",
|
||||
bool_type=None,
|
||||
)
|
||||
|
||||
_PG = _Dialect(
|
||||
null_type="null",
|
||||
num_types=("number",),
|
||||
num_cast="DOUBLE PRECISION",
|
||||
int_types=("number",),
|
||||
int_cast="BIGINT",
|
||||
int_guard="'^-?[0-9]+$'",
|
||||
string_type="string",
|
||||
bool_type="boolean",
|
||||
)
|
||||
|
||||
|
||||
def _bind(compiler: SQLCompiler, value: object, sa_type: TypeEngine[Any], **kw: Any) -> str:
|
||||
param = bindparam(None, value, type_=sa_type)
|
||||
return compiler.process(param, **kw)
|
||||
|
||||
|
||||
def _type_check(typeof: str, types: tuple[str, ...]) -> str:
|
||||
if len(types) == 1:
|
||||
return f"{typeof} = '{types[0]}'"
|
||||
quoted = ", ".join(f"'{t}'" for t in types)
|
||||
return f"{typeof} IN ({quoted})"
|
||||
|
||||
|
||||
def _build_clause(compiler: SQLCompiler, typeof: str, extract: str, value: object, dialect: _Dialect, **kw: Any) -> str:
|
||||
if value is None:
|
||||
return f"{typeof} = '{dialect.null_type}'"
|
||||
if isinstance(value, bool):
|
||||
# bool check must precede int check — bool is a subclass of int in Python
|
||||
bool_str = "true" if value else "false"
|
||||
if dialect.bool_type is None:
|
||||
return f"{typeof} = '{bool_str}'"
|
||||
return f"({typeof} = '{dialect.bool_type}' AND {extract} = '{bool_str}')"
|
||||
if isinstance(value, int):
|
||||
bp = _bind(compiler, value, BigInteger(), **kw)
|
||||
if dialect.int_guard:
|
||||
# CASE prevents CAST error when json_typeof = 'number' also matches floats
|
||||
return f"(CASE WHEN {_type_check(typeof, dialect.int_types)} AND {extract} ~ {dialect.int_guard} THEN CAST({extract} AS {dialect.int_cast}) END = {bp})"
|
||||
return f"({_type_check(typeof, dialect.int_types)} AND CAST({extract} AS {dialect.int_cast}) = {bp})"
|
||||
if isinstance(value, float):
|
||||
bp = _bind(compiler, value, Float(), **kw)
|
||||
return f"({_type_check(typeof, dialect.num_types)} AND CAST({extract} AS {dialect.num_cast}) = {bp})"
|
||||
bp = _bind(compiler, str(value), String(), **kw)
|
||||
return f"({typeof} = '{dialect.string_type}' AND {extract} = {bp})"
|
||||
|
||||
|
||||
@compiles(JsonMatch, "sqlite")
|
||||
def _compile_sqlite(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
|
||||
if not validate_metadata_filter_key(element.key):
|
||||
raise ValueError(f"Key escaped validation: {element.key!r}")
|
||||
col = compiler.process(element.column, **kw)
|
||||
path = f'$."{element.key}"'
|
||||
typeof = f"json_type({col}, '{path}')"
|
||||
extract = f"json_extract({col}, '{path}')"
|
||||
return _build_clause(compiler, typeof, extract, element.value, _SQLITE, **kw)
|
||||
|
||||
|
||||
@compiles(JsonMatch, "postgresql")
|
||||
def _compile_pg(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
|
||||
if not validate_metadata_filter_key(element.key):
|
||||
raise ValueError(f"Key escaped validation: {element.key!r}")
|
||||
col = compiler.process(element.column, **kw)
|
||||
typeof = f"json_typeof({col} -> '{element.key}')"
|
||||
extract = f"({col} ->> '{element.key}')"
|
||||
return _build_clause(compiler, typeof, extract, element.value, _PG, **kw)
|
||||
|
||||
|
||||
@compiles(JsonMatch)
|
||||
def _compile_default(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
|
||||
raise NotImplementedError(f"JsonMatch supports only sqlite and postgresql; got dialect: {compiler.dialect.name}")
|
||||
|
||||
|
||||
def json_match(column: ColumnElement, key: str, value: object) -> JsonMatch:
|
||||
return JsonMatch(column, key, value)
|
||||
@@ -223,10 +223,11 @@ class RunRepository(RunStore):
|
||||
"""Aggregate token usage via a single SQL GROUP BY query."""
|
||||
_completed = RunRow.status.in_(("success", "error"))
|
||||
_thread = RunRow.thread_id == thread_id
|
||||
model_name = func.coalesce(RunRow.model_name, "unknown")
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
func.coalesce(RunRow.model_name, "unknown").label("model"),
|
||||
model_name.label("model"),
|
||||
func.count().label("runs"),
|
||||
func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"),
|
||||
func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"),
|
||||
@@ -236,7 +237,7 @@ class RunRepository(RunStore):
|
||||
func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"),
|
||||
)
|
||||
.where(_thread, _completed)
|
||||
.group_by(func.coalesce(RunRow.model_name, "unknown"))
|
||||
.group_by(model_name)
|
||||
)
|
||||
|
||||
async with self._sf() as session:
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||
from deerflow.persistence.thread_meta.base import InvalidMetadataFilterError, ThreadMetaStore
|
||||
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||
from deerflow.persistence.thread_meta.sql import ThreadMetaRepository
|
||||
@@ -14,6 +14,7 @@ if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
__all__ = [
|
||||
"InvalidMetadataFilterError",
|
||||
"MemoryThreadMetaStore",
|
||||
"ThreadMetaRepository",
|
||||
"ThreadMetaRow",
|
||||
|
||||
@@ -15,10 +15,15 @@ three-state semantics (see :mod:`deerflow.runtime.user_context`):
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from typing import Any
|
||||
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel
|
||||
|
||||
|
||||
class InvalidMetadataFilterError(ValueError):
|
||||
"""Raised when all client-supplied metadata filter keys are rejected."""
|
||||
|
||||
|
||||
class ThreadMetaStore(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
async def create(
|
||||
@@ -40,12 +45,12 @@ class ThreadMetaStore(abc.ABC):
|
||||
async def search(
|
||||
self,
|
||||
*,
|
||||
metadata: dict | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
status: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
) -> list[dict[str, Any]]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -69,12 +69,12 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
async def search(
|
||||
self,
|
||||
*,
|
||||
metadata: dict | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
status: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
) -> list[dict[str, Any]]:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.search")
|
||||
filter_dict: dict[str, Any] = {}
|
||||
if metadata:
|
||||
|
||||
@@ -2,16 +2,20 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||
from deerflow.persistence.json_compat import json_match
|
||||
from deerflow.persistence.thread_meta.base import InvalidMetadataFilterError, ThreadMetaStore
|
||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ThreadMetaRepository(ThreadMetaStore):
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
@@ -20,7 +24,7 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
@staticmethod
|
||||
def _row_to_dict(row: ThreadMetaRow) -> dict[str, Any]:
|
||||
d = row.to_dict()
|
||||
d["metadata"] = d.pop("metadata_json", {})
|
||||
d["metadata"] = d.pop("metadata_json", None) or {}
|
||||
for key in ("created_at", "updated_at"):
|
||||
val = d.get(key)
|
||||
if isinstance(val, datetime):
|
||||
@@ -104,39 +108,43 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
async def search(
|
||||
self,
|
||||
*,
|
||||
metadata: dict | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
status: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search threads with optional metadata and status filters.
|
||||
|
||||
Owner filter is enforced by default: caller must be in a user
|
||||
context. Pass ``user_id=None`` to bypass (migration/CLI).
|
||||
"""
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.search")
|
||||
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
|
||||
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc(), ThreadMetaRow.thread_id.desc())
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(ThreadMetaRow.user_id == resolved_user_id)
|
||||
if status:
|
||||
stmt = stmt.where(ThreadMetaRow.status == status)
|
||||
|
||||
if metadata:
|
||||
# When metadata filter is active, fetch a larger window and filter
|
||||
# in Python. TODO(Phase 2): use JSON DB operators (Postgres @>,
|
||||
# SQLite json_extract) for server-side filtering.
|
||||
stmt = stmt.limit(limit * 5 + offset)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
rows = [self._row_to_dict(r) for r in result.scalars()]
|
||||
rows = [r for r in rows if all(r.get("metadata", {}).get(k) == v for k, v in metadata.items())]
|
||||
return rows[offset : offset + limit]
|
||||
else:
|
||||
stmt = stmt.limit(limit).offset(offset)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
applied = 0
|
||||
for key, value in metadata.items():
|
||||
try:
|
||||
stmt = stmt.where(json_match(ThreadMetaRow.metadata_json, key, value))
|
||||
applied += 1
|
||||
except (ValueError, TypeError) as exc:
|
||||
logger.warning("Skipping metadata filter key %s: %s", ascii(key), exc)
|
||||
if applied == 0:
|
||||
# Comma-separated plain string (no list repr / nested
|
||||
# quoting) so the 400 detail surfaced by the Gateway is
|
||||
# easy for clients to read. Sorted for determinism.
|
||||
rejected_keys = ", ".join(sorted(str(k) for k in metadata))
|
||||
raise InvalidMetadataFilterError(f"All metadata filter keys were rejected as unsafe: {rejected_keys}")
|
||||
|
||||
stmt = stmt.limit(limit).offset(offset)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_user_id: str | None) -> bool:
|
||||
"""Return True if the row exists and is owned (or filter bypassed)."""
|
||||
|
||||
@@ -11,7 +11,7 @@ import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy import delete, func, select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.models.run_event import RunEventRow
|
||||
@@ -86,6 +86,28 @@ class DbRunEventStore(RunEventStore):
|
||||
user = get_current_user()
|
||||
return str(user.id) if user is not None else None
|
||||
|
||||
@staticmethod
|
||||
async def _max_seq_for_thread(session: AsyncSession, thread_id: str) -> int | None:
|
||||
"""Return the current max seq while serializing writers per thread.
|
||||
|
||||
PostgreSQL rejects ``SELECT max(...) FOR UPDATE`` because aggregate
|
||||
results are not lockable rows. As a release-safe workaround, take a
|
||||
transaction-level advisory lock keyed by thread_id before reading the
|
||||
aggregate. Other dialects keep the existing row-locking statement.
|
||||
"""
|
||||
stmt = select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id)
|
||||
bind = session.get_bind()
|
||||
dialect_name = bind.dialect.name if bind is not None else ""
|
||||
|
||||
if dialect_name == "postgresql":
|
||||
await session.execute(
|
||||
text("SELECT pg_advisory_xact_lock(hashtext(CAST(:thread_id AS text))::bigint)"),
|
||||
{"thread_id": thread_id},
|
||||
)
|
||||
return await session.scalar(stmt)
|
||||
|
||||
return await session.scalar(stmt.with_for_update())
|
||||
|
||||
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
|
||||
"""Write a single event — low-frequency path only.
|
||||
|
||||
@@ -100,10 +122,7 @@ class DbRunEventStore(RunEventStore):
|
||||
user_id = self._user_id_from_context()
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# Use FOR UPDATE to serialize seq assignment within a thread.
|
||||
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
|
||||
# the UNIQUE(thread_id, seq) constraint catches races there.
|
||||
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
|
||||
max_seq = await self._max_seq_for_thread(session, thread_id)
|
||||
seq = (max_seq or 0) + 1
|
||||
row = RunEventRow(
|
||||
thread_id=thread_id,
|
||||
@@ -126,10 +145,8 @@ class DbRunEventStore(RunEventStore):
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# Get max seq for the thread (assume all events in batch belong to same thread).
|
||||
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
|
||||
# the UNIQUE(thread_id, seq) constraint catches races there.
|
||||
thread_id = events[0]["thread_id"]
|
||||
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
|
||||
max_seq = await self._max_seq_for_thread(session, thread_id)
|
||||
seq = max_seq or 0
|
||||
rows = []
|
||||
for e in events:
|
||||
|
||||
@@ -109,6 +109,34 @@ def get_effective_user_id() -> str:
|
||||
return str(user.id)
|
||||
|
||||
|
||||
def resolve_runtime_user_id(runtime: object | None) -> str:
|
||||
"""Single source of truth for a tool/middleware's effective user_id.
|
||||
|
||||
Resolution order (most authoritative first):
|
||||
1. ``runtime.context["user_id"]`` — set by ``inject_authenticated_user_context``
|
||||
in the gateway from the auth-validated ``request.state.user``. This is
|
||||
the only source that survives boundaries where the contextvar may have
|
||||
been lost (background tasks scheduled outside the request task,
|
||||
worker pools that don't copy_context, future cross-process drivers).
|
||||
2. The ``_current_user`` ContextVar — set by the auth middleware at
|
||||
request entry. Reliable for in-task work; copied by ``asyncio``
|
||||
child tasks and by ``ContextThreadPoolExecutor``.
|
||||
3. ``DEFAULT_USER_ID`` — last-resort fallback so unauthenticated
|
||||
CLI / migration / test paths keep working without raising.
|
||||
|
||||
Tools that persist user-scoped state (custom agents, memory, uploads)
|
||||
MUST call this instead of ``get_effective_user_id()`` directly so they
|
||||
benefit from the runtime.context channel that ``setup_agent`` already
|
||||
relies on.
|
||||
"""
|
||||
context = getattr(runtime, "context", None)
|
||||
if isinstance(context, dict):
|
||||
ctx_user_id = context.get("user_id")
|
||||
if ctx_user_id:
|
||||
return str(ctx_user_id)
|
||||
return get_effective_user_id()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sentinel-based user_id resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -7,19 +7,12 @@ from langgraph.types import Command
|
||||
|
||||
from deerflow.config.agents_config import validate_agent_name
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.runtime.user_context import resolve_runtime_user_id
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_runtime_user_id(runtime: Runtime) -> str:
|
||||
context_user_id = runtime.context.get("user_id") if runtime.context else None
|
||||
if context_user_id:
|
||||
return str(context_user_id)
|
||||
return get_effective_user_id()
|
||||
|
||||
|
||||
@tool(parse_docstring=True)
|
||||
def setup_agent(
|
||||
soul: str,
|
||||
@@ -45,7 +38,7 @@ def setup_agent(
|
||||
if agent_name:
|
||||
# Custom agents are persisted under the current user's bucket so
|
||||
# different users do not see each other's agents.
|
||||
user_id = _get_runtime_user_id(runtime)
|
||||
user_id = resolve_runtime_user_id(runtime)
|
||||
agent_dir = paths.user_agent_dir(user_id, agent_name)
|
||||
else:
|
||||
# Default agent (no agent_name): SOUL.md lives at the global base dir.
|
||||
|
||||
@@ -26,6 +26,28 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache subagent token usage by tool_call_id so TokenUsageMiddleware can
|
||||
# write it back to the triggering AIMessage's usage_metadata.
|
||||
_subagent_usage_cache: dict[str, dict[str, int]] = {}
|
||||
|
||||
|
||||
def _token_usage_cache_enabled(app_config: "AppConfig | None") -> bool:
|
||||
if app_config is None:
|
||||
try:
|
||||
app_config = get_app_config()
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
return bool(getattr(getattr(app_config, "token_usage", None), "enabled", False))
|
||||
|
||||
|
||||
def _cache_subagent_usage(tool_call_id: str, usage: dict | None, *, enabled: bool = True) -> None:
|
||||
if enabled and usage:
|
||||
_subagent_usage_cache[tool_call_id] = usage
|
||||
|
||||
|
||||
def pop_cached_subagent_usage(tool_call_id: str) -> dict | None:
|
||||
return _subagent_usage_cache.pop(tool_call_id, None)
|
||||
|
||||
|
||||
def _is_subagent_terminal(result: Any) -> bool:
|
||||
"""Return whether a background subagent result is safe to clean up."""
|
||||
@@ -92,6 +114,17 @@ def _find_usage_recorder(runtime: Any) -> Any | None:
|
||||
return None
|
||||
|
||||
|
||||
def _summarize_usage(records: list[dict] | None) -> dict | None:
|
||||
"""Summarize token usage records into a compact dict for SSE events."""
|
||||
if not records:
|
||||
return None
|
||||
return {
|
||||
"input_tokens": sum(r.get("input_tokens", 0) or 0 for r in records),
|
||||
"output_tokens": sum(r.get("output_tokens", 0) or 0 for r in records),
|
||||
"total_tokens": sum(r.get("total_tokens", 0) or 0 for r in records),
|
||||
}
|
||||
|
||||
|
||||
def _report_subagent_usage(runtime: Any, result: Any) -> None:
|
||||
"""Report subagent token usage to the parent RunJournal, if available.
|
||||
|
||||
@@ -177,6 +210,7 @@ async def task_tool(
|
||||
subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
|
||||
"""
|
||||
runtime_app_config = _get_runtime_app_config(runtime)
|
||||
cache_token_usage = _token_usage_cache_enabled(runtime_app_config)
|
||||
available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names()
|
||||
|
||||
# Get subagent configuration
|
||||
@@ -312,27 +346,32 @@ async def task_tool(
|
||||
last_message_count = current_message_count
|
||||
|
||||
# Check if task completed, failed, or timed out
|
||||
usage = _summarize_usage(getattr(result, "token_usage_records", None))
|
||||
if result.status == SubagentStatus.COMPLETED:
|
||||
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
|
||||
_report_subagent_usage(runtime, result)
|
||||
writer({"type": "task_completed", "task_id": task_id, "result": result.result})
|
||||
writer({"type": "task_completed", "task_id": task_id, "result": result.result, "usage": usage})
|
||||
logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
|
||||
cleanup_background_task(task_id)
|
||||
return f"Task Succeeded. Result: {result.result}"
|
||||
elif result.status == SubagentStatus.FAILED:
|
||||
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
|
||||
_report_subagent_usage(runtime, result)
|
||||
writer({"type": "task_failed", "task_id": task_id, "error": result.error})
|
||||
writer({"type": "task_failed", "task_id": task_id, "error": result.error, "usage": usage})
|
||||
logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
|
||||
cleanup_background_task(task_id)
|
||||
return f"Task failed. Error: {result.error}"
|
||||
elif result.status == SubagentStatus.CANCELLED:
|
||||
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
|
||||
_report_subagent_usage(runtime, result)
|
||||
writer({"type": "task_cancelled", "task_id": task_id, "error": result.error})
|
||||
writer({"type": "task_cancelled", "task_id": task_id, "error": result.error, "usage": usage})
|
||||
logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}")
|
||||
cleanup_background_task(task_id)
|
||||
return "Task cancelled by user."
|
||||
elif result.status == SubagentStatus.TIMED_OUT:
|
||||
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
|
||||
_report_subagent_usage(runtime, result)
|
||||
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
|
||||
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error, "usage": usage})
|
||||
logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
|
||||
cleanup_background_task(task_id)
|
||||
return f"Task timed out. Error: {result.error}"
|
||||
@@ -351,7 +390,9 @@ async def task_tool(
|
||||
timeout_minutes = config.timeout_seconds // 60
|
||||
logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)")
|
||||
_report_subagent_usage(runtime, result)
|
||||
writer({"type": "task_timed_out", "task_id": task_id})
|
||||
usage = _summarize_usage(getattr(result, "token_usage_records", None))
|
||||
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
|
||||
writer({"type": "task_timed_out", "task_id": task_id, "usage": usage})
|
||||
return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
|
||||
except asyncio.CancelledError:
|
||||
# Signal the background subagent thread to stop cooperatively.
|
||||
@@ -374,4 +415,8 @@ async def task_tool(
|
||||
cleanup_background_task(task_id)
|
||||
else:
|
||||
_schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count)
|
||||
_subagent_usage_cache.pop(tool_call_id, None)
|
||||
raise
|
||||
except Exception:
|
||||
_subagent_usage_cache.pop(tool_call_id, None)
|
||||
raise
|
||||
|
||||
@@ -27,7 +27,7 @@ from langgraph.types import Command
|
||||
from deerflow.config.agents_config import load_agent_config, validate_agent_name
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.runtime.user_context import resolve_runtime_user_id
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -118,9 +118,13 @@ def update_agent(
|
||||
return _err("update_agent is only available inside a custom agent's chat. There is no agent_name in the current runtime context, so there is nothing to update. If you are inside the bootstrap flow, use setup_agent instead.")
|
||||
|
||||
# Resolve the active user so that updates only affect this user's agent.
|
||||
# ``get_effective_user_id`` returns DEFAULT_USER_ID when no auth context
|
||||
# is set (matching how memory and thread storage behave).
|
||||
user_id = get_effective_user_id()
|
||||
# ``resolve_runtime_user_id`` prefers ``runtime.context["user_id"]`` (set by
|
||||
# the gateway from the auth-validated request) and falls back to the
|
||||
# contextvar, then DEFAULT_USER_ID. This matches setup_agent so a user
|
||||
# creating an agent and later refining it always touches the same files,
|
||||
# even if the contextvar gets lost across an async/thread boundary
|
||||
# (issue #2782 / #2862 class of bugs).
|
||||
user_id = resolve_runtime_user_id(runtime)
|
||||
|
||||
# Reject an unknown ``model`` *before* touching the filesystem. Otherwise
|
||||
# ``_resolve_model_name`` silently falls back to the default at runtime
|
||||
|
||||
@@ -7,7 +7,7 @@ from deerflow.config.app_config import AppConfig
|
||||
from deerflow.reflection import resolve_variable
|
||||
from deerflow.sandbox.security import is_host_bash_allowed
|
||||
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
|
||||
from deerflow.tools.builtins.tool_search import reset_deferred_registry
|
||||
from deerflow.tools.builtins.tool_search import get_deferred_registry
|
||||
from deerflow.tools.sync import make_sync_tool_wrapper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -116,8 +116,6 @@ def get_available_tools(
|
||||
# made through the Gateway API (which runs in a separate process) are immediately
|
||||
# reflected when loading MCP tools.
|
||||
mcp_tools = []
|
||||
# Reset deferred registry upfront to prevent stale state from previous calls
|
||||
reset_deferred_registry()
|
||||
if include_mcp:
|
||||
try:
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
@@ -135,12 +133,51 @@ def get_available_tools(
|
||||
from deerflow.tools.builtins.tool_search import DeferredToolRegistry, set_deferred_registry
|
||||
from deerflow.tools.builtins.tool_search import tool_search as tool_search_tool
|
||||
|
||||
registry = DeferredToolRegistry()
|
||||
for t in mcp_tools:
|
||||
registry.register(t)
|
||||
set_deferred_registry(registry)
|
||||
# Reuse the existing registry if one is already set for
|
||||
# this async context. ``get_available_tools`` is
|
||||
# re-entered whenever a subagent is spawned
|
||||
# (``task_tool`` calls it to build the child agent's
|
||||
# toolset), and previously we used to unconditionally
|
||||
# rebuild the registry — wiping out the parent agent's
|
||||
# tool_search promotions. The
|
||||
# ``DeferredToolFilterMiddleware`` then re-hid those
|
||||
# tools from subsequent model calls, leaving the agent
|
||||
# able to see a tool's name but unable to invoke it
|
||||
# (issue #2884). ``contextvars`` already gives us the
|
||||
# lifetime semantics we want: a fresh request / graph
|
||||
# run starts in a new asyncio task with the
|
||||
# ContextVar at its default of ``None``, so reuse is
|
||||
# only triggered for re-entrant calls inside one run.
|
||||
#
|
||||
# Intentionally NOT reconciling against the current
|
||||
# ``mcp_tools`` snapshot. The MCP cache only refreshes
|
||||
# on ``extensions_config.json`` mtime changes, which
|
||||
# in practice happens between graph runs — not inside
|
||||
# one. And even if a refresh did happen mid-run, the
|
||||
# already-built lead agent's ``ToolNode`` still holds
|
||||
# the *previous* tool set (LangGraph binds tools at
|
||||
# graph construction time), so a brand-new MCP tool
|
||||
# couldn't actually be invoked anyway. The
|
||||
# ``DeferredToolRegistry`` doesn't retain the names
|
||||
# of previously-promoted tools (``promote()`` drops
|
||||
# the entry entirely), so re-syncing the registry
|
||||
# against a fresh ``mcp_tools`` list would
|
||||
# mis-classify those promotions as new tools and
|
||||
# re-register them as deferred — exactly the bug
|
||||
# this fix exists to prevent.
|
||||
existing_registry = get_deferred_registry()
|
||||
if existing_registry is None:
|
||||
registry = DeferredToolRegistry()
|
||||
for t in mcp_tools:
|
||||
registry.register(t)
|
||||
set_deferred_registry(registry)
|
||||
logger.info(f"Tool search active: {len(mcp_tools)} tools deferred")
|
||||
else:
|
||||
mcp_tool_names = {t.name for t in mcp_tools}
|
||||
still_deferred = len(existing_registry)
|
||||
promoted_count = max(0, len(mcp_tool_names) - still_deferred)
|
||||
logger.info(f"Tool search active (preserved promotions): {still_deferred} tools deferred, {promoted_count} already promoted")
|
||||
builtin_tools.append(tool_search_tool)
|
||||
logger.info(f"Tool search active: {len(mcp_tools)} tools deferred")
|
||||
except ImportError:
|
||||
logger.warning("MCP module not available. Install 'langchain-mcp-adapters' package to enable MCP tools.")
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user