Merge branch 'main' into fix-3127

This commit is contained in:
Willem Jiang
2026-05-22 21:56:04 +08:00
committed by GitHub
57 changed files with 4981 additions and 195 deletions
@@ -29,6 +29,7 @@ from deerflow.agents.memory.summarization_hook import memory_flush_hook
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
from deerflow.agents.middlewares.summarization_middleware import BeforeSummarizationHook, DeerFlowSummarizationMiddleware
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
@@ -338,6 +339,15 @@ def _build_middlewares(
if custom_middlewares:
middlewares.extend(custom_middlewares)
# SafetyFinishReasonMiddleware — suppress tool execution when the provider
# safety-terminated the response. Registered after custom middlewares so
# that LangChain's reverse-order after_model dispatch runs Safety first;
# cleared tool_calls then flow through Loop/Subagent accounting without
# firing extra alarms. See safety_finish_reason_middleware.py docstring.
safety_config = resolved_app_config.safety_finish_reason
if safety_config.enabled:
middlewares.append(SafetyFinishReasonMiddleware.from_config(safety_config))
# ClarificationMiddleware should always be last
middlewares.append(ClarificationMiddleware())
return middlewares
@@ -15,6 +15,7 @@ to the end of the message list as before_model + add_messages reducer would do.
import json
import logging
from collections import defaultdict, deque
from collections.abc import Awaitable, Callable
from typing import override
@@ -109,10 +110,10 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
This normalizes model-bound causal order before provider serialization while
preserving already-valid transcripts unchanged.
"""
tool_messages_by_id: dict[str, ToolMessage] = {}
tool_messages_by_id: dict[str, deque[ToolMessage]] = defaultdict(deque)
for msg in messages:
if isinstance(msg, ToolMessage):
tool_messages_by_id.setdefault(msg.tool_call_id, msg)
tool_messages_by_id[msg.tool_call_id].append(msg)
tool_call_ids: set[str] = set()
for msg in messages:
@@ -124,7 +125,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
tool_call_ids.add(tc_id)
patched: list = []
consumed_tool_msg_ids: set[str] = set()
patch_count = 0
for msg in messages:
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
@@ -136,13 +136,13 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
for tc in self._message_tool_calls(msg):
tc_id = tc.get("id")
if not tc_id or tc_id in consumed_tool_msg_ids:
if not tc_id:
continue
existing_tool_msg = tool_messages_by_id.get(tc_id)
tool_msg_queue = tool_messages_by_id.get(tc_id)
existing_tool_msg = tool_msg_queue.popleft() if tool_msg_queue else None
if existing_tool_msg is not None:
patched.append(existing_tool_msg)
consumed_tool_msg_ids.add(tc_id)
else:
patched.append(
ToolMessage(
@@ -152,7 +152,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
status="error",
)
)
consumed_tool_msg_ids.add(tc_id)
patch_count += 1
if patched == messages:
@@ -0,0 +1,317 @@
"""Suppress tool execution when the provider safety-terminated the response.
Background — see issue bytedance/deer-flow#3028.
Some providers (OpenAI ``finish_reason='content_filter'``, Anthropic
``stop_reason='refusal'``, Gemini ``finish_reason='SAFETY'`` ...) can stop
generation mid-stream while still returning partially-formed ``tool_calls``.
LangChain's tool router treats any AIMessage with a non-empty ``tool_calls``
field as "go execute these", so half-truncated arguments — e.g. a markdown
``write_file`` that stops in the middle of a sentence — get dispatched as if
they were complete. The agent then sees the truncated file, tries to fix it,
gets filtered again, and loops.
This middleware sits at ``after_model`` and gates that behaviour: when a
configured ``SafetyTerminationDetector`` fires *and* the AIMessage carries
tool calls, we strip the tool calls (both structured and raw provider
payloads), append a user-facing explanation, and stash observability fields
in ``additional_kwargs.safety_termination`` so logs, traces, and SSE
consumers can see what happened.
Hook choice: ``after_model`` (not ``wrap_model_call``) because the response
is a *normal* return — not an exception — and we want to participate in the
same after-model chain as ``LoopDetectionMiddleware``, with which we share
the same tool-call-suppression mechanic but a different trigger.
Placement: register *after* ``LoopDetectionMiddleware`` in the middleware
list. LangChain factory wires ``after_model`` edges in reverse list order
(``langchain/agents/factory.py:add_edge("model", middleware_w_after_model[-1])``,
then walks ``range(len-1, 0, -1)``), so the *last* registered middleware is
the *first* to observe the model output. Registering Safety after Loop
means Safety sees the raw response first, clears tool calls if it fires,
and Loop then accounts against the cleaned message.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import AIMessage
from langgraph.runtime import Runtime
from deerflow.agents.middlewares.safety_termination_detectors import (
SafetyTermination,
SafetyTerminationDetector,
default_detectors,
)
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
if TYPE_CHECKING:
from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig
logger = logging.getLogger(__name__)
_USER_FACING_MESSAGE = (
"The model provider stopped this response with a safety-related signal "
"({reason_field}={reason_value!r}, detector={detector!r}). Any tool "
"calls produced in this turn were suppressed because their arguments "
"may be truncated and unsafe to execute. Please rephrase the request "
"or ask for a narrower output."
)
class SafetyFinishReasonMiddleware(AgentMiddleware[AgentState]):
"""Strip tool_calls from AIMessages flagged by a SafetyTerminationDetector."""
def __init__(self, detectors: list[SafetyTerminationDetector] | None = None) -> None:
super().__init__()
# Copy so caller mutations after construction don't leak into us.
self._detectors: list[SafetyTerminationDetector] = list(detectors) if detectors else default_detectors()
@classmethod
def from_config(cls, config: SafetyFinishReasonConfig) -> SafetyFinishReasonMiddleware:
"""Construct from validated Pydantic config, honouring the
reflection-loaded detector list when provided.
An explicit empty list is intentionally rejected — it would silently
disable detection while leaving the middleware in the chain, which
is the worst of both worlds. Use ``enabled: false`` instead.
"""
if config.detectors is None:
return cls()
if not config.detectors:
raise ValueError("safety_finish_reason.detectors must be omitted (use built-ins) or contain at least one entry; use enabled=false to disable the middleware entirely.")
from deerflow.reflection import resolve_variable
detectors: list[SafetyTerminationDetector] = []
for entry in config.detectors:
detector_cls = resolve_variable(entry.use)
kwargs = dict(entry.config) if entry.config else {}
detector = detector_cls(**kwargs)
if not isinstance(detector, SafetyTerminationDetector):
raise TypeError(f"{entry.use} did not produce a SafetyTerminationDetector (got {type(detector).__name__}); ensure it has a `name` attribute and a `detect(message)` method")
detectors.append(detector)
return cls(detectors=detectors)
# ----- detection -------------------------------------------------------
def _detect(self, message: AIMessage) -> SafetyTermination | None:
for detector in self._detectors:
try:
hit = detector.detect(message)
except Exception: # noqa: BLE001 - never let a buggy detector break the agent run
logger.exception("SafetyTerminationDetector %r raised; treating as no-match", getattr(detector, "name", type(detector).__name__))
continue
if hit is not None:
return hit
return None
# ----- message rewriting ----------------------------------------------
@staticmethod
def _append_user_message(content: object, text: str) -> str | list:
"""Append a plain-text explanation to AIMessage content.
Mirrors ``LoopDetectionMiddleware._append_text`` so list-content
responses (Anthropic thinking blocks, vLLM reasoning splits) keep
their structure instead of being string-coerced into a TypeError.
"""
if content is None or content == "":
return text
if isinstance(content, list):
return [*content, {"type": "text", "text": f"\n\n{text}"}]
if isinstance(content, str):
return content + f"\n\n{text}"
return str(content) + f"\n\n{text}"
def _build_suppressed_message(
self,
message: AIMessage,
termination: SafetyTermination,
) -> AIMessage:
suppressed_names = [tc.get("name") or "unknown" for tc in (message.tool_calls or [])]
explanation = _USER_FACING_MESSAGE.format(
reason_field=termination.reason_field,
reason_value=termination.reason_value,
detector=termination.detector,
)
new_content = self._append_user_message(message.content, explanation)
# clone_ai_message_with_tool_calls handles structured tool_calls,
# raw additional_kwargs.tool_calls, and function_call in one shot.
# It only rewrites finish_reason when the old value was "tool_calls",
# which is not our case — content_filter / refusal / SAFETY stay put
# so downstream SSE / converters keep seeing the real provider reason.
cleared = clone_ai_message_with_tool_calls(message, [], content=new_content)
# Re-clone additional_kwargs so we don't accidentally mutate the
# dict returned by clone_ai_message_with_tool_calls (which already
# made a shallow copy, but downstream model_copy still references
# it). Then stamp the observability record.
kwargs = dict(getattr(cleared, "additional_kwargs", None) or {})
kwargs["safety_termination"] = {
"detector": termination.detector,
"reason_field": termination.reason_field,
"reason_value": termination.reason_value,
"suppressed_tool_call_count": len(suppressed_names),
"suppressed_tool_call_names": suppressed_names,
"extras": dict(termination.extras) if termination.extras else {},
}
return cleared.model_copy(update={"additional_kwargs": kwargs})
# ----- observability ---------------------------------------------------
def _emit_event(
self,
termination: SafetyTermination,
suppressed_names: list[str],
runtime: Runtime,
) -> None:
"""Notify SSE consumers (e.g. the web UI) that a tool turn was
suppressed so they can reconcile any "tool starting..." placeholders
already streamed to the user. Failures are logged at debug and
ignored — this is a best-effort signal."""
try:
from langgraph.config import get_stream_writer
writer = get_stream_writer()
except Exception: # noqa: BLE001
logger.debug("get_stream_writer unavailable; skipping safety_termination event", exc_info=True)
return
thread_id = None
if runtime is not None and getattr(runtime, "context", None):
thread_id = runtime.context.get("thread_id") if isinstance(runtime.context, dict) else None
try:
writer(
{
"type": "safety_termination",
"detector": termination.detector,
"reason_field": termination.reason_field,
"reason_value": termination.reason_value,
"suppressed_tool_call_count": len(suppressed_names),
"suppressed_tool_call_names": suppressed_names,
"thread_id": thread_id,
}
)
except Exception: # noqa: BLE001
logger.debug("Failed to emit safety_termination stream event", exc_info=True)
def _record_audit_event(
self,
termination: SafetyTermination,
message,
tool_calls: list[dict],
runtime: Runtime,
) -> None:
"""Write a ``middleware:safety_termination`` record to RunEventStore
for post-run auditability.
The custom stream event in ``_emit_event`` is consumed by live SSE
clients and disappears after the run; this event is persisted so an
operator can answer "which runs were safety-suppressed today?" from
a single SQL query without joining the message body. Worker exposes
the run-scoped ``RunJournal`` via ``runtime.context["__run_journal"]``;
absent in unit-test / subagent / no-event-store paths, in which case
we silently skip.
Tool **arguments** are deliberately **not** recorded — those are the
very content the provider filtered; persisting them would defeat the
purpose of the safety filter. Names / count / ids are sufficient for
audit and debugging (issue #3028 review).
"""
journal = None
if runtime is not None and getattr(runtime, "context", None):
context = runtime.context
if isinstance(context, dict):
journal = context.get("__run_journal")
if journal is None:
return
suppressed_names = [tc.get("name") or "unknown" for tc in tool_calls]
suppressed_ids = [tc.get("id") for tc in tool_calls if tc.get("id")]
changes = {
"detector": termination.detector,
"reason_field": termination.reason_field,
"reason_value": termination.reason_value,
"suppressed_tool_call_count": len(tool_calls),
"suppressed_tool_call_names": suppressed_names,
"suppressed_tool_call_ids": suppressed_ids,
"message_id": getattr(message, "id", None),
"extras": dict(termination.extras) if termination.extras else {},
}
try:
journal.record_middleware(
tag="safety_termination",
name=type(self).__name__,
hook="after_model",
action="suppress_tool_calls",
changes=changes,
)
except Exception: # noqa: BLE001
# Audit-event persistence must never break agent execution.
logger.debug("Failed to record middleware:safety_termination event", exc_info=True)
# ----- main apply ------------------------------------------------------
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
messages = state.get("messages", [])
if not messages:
return None
last = messages[-1]
if not isinstance(last, AIMessage):
return None
# Issue scope: only intervene when there's something to suppress.
# ``content_filter`` without tool_calls is allowed through unchanged
# so the partial text response (if any) reaches the user naturally.
tool_calls = last.tool_calls
if not tool_calls:
return None
termination = self._detect(last)
if termination is None:
return None
patched = self._build_suppressed_message(last, termination)
thread_id = None
if runtime is not None and getattr(runtime, "context", None):
thread_id = runtime.context.get("thread_id") if isinstance(runtime.context, dict) else None
logger.warning(
"Provider safety termination detected — suppressed %d tool call(s)",
len(tool_calls),
extra={
"thread_id": thread_id,
"detector": termination.detector,
"reason_field": termination.reason_field,
"reason_value": termination.reason_value,
"suppressed_tool_call_names": [tc.get("name") for tc in tool_calls],
},
)
self._emit_event(termination, [tc.get("name") or "unknown" for tc in tool_calls], runtime)
self._record_audit_event(termination, last, list(tool_calls), runtime)
return {"messages": [patched]}
# ----- hooks -----------------------------------------------------------
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state, runtime)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state, runtime)
@@ -0,0 +1,237 @@
"""Detectors for provider-side safety termination signals.
Different LLM providers signal "I stopped this response for safety reasons"
through different fields with different values. This module defines a small
strategy interface and three built-in detectors that cover the major
providers DeerFlow supports today. New providers (Wenxin, Hunyuan, Bedrock
adapters, in-house gateways, ...) can be added by implementing
``SafetyTerminationDetector`` and wiring it through
``config.yaml: safety_finish_reason.detectors``.
The middleware that consumes these detectors lives in
``safety_finish_reason_middleware.py``.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Protocol, runtime_checkable
from langchain_core.messages import AIMessage
@dataclass(frozen=True)
class SafetyTermination:
"""A detected safety-related termination signal.
Attributes:
detector: Name of the detector that produced this result. Used for
observability so operators can see which provider rule fired.
reason_field: The message metadata field that carried the signal
(e.g. ``finish_reason``, ``stop_reason``).
reason_value: The actual value of that field
(e.g. ``content_filter``, ``refusal``, ``SAFETY``).
extras: Provider-specific metadata that may help downstream
consumers (e.g. Azure OpenAI content_filter_results, Gemini
safety_ratings). Detectors are free to populate or skip this.
"""
detector: str
reason_field: str
reason_value: str
extras: dict[str, Any] = field(default_factory=dict)
@runtime_checkable
class SafetyTerminationDetector(Protocol):
"""Strategy interface for provider safety termination detection."""
name: str
def detect(self, message: AIMessage) -> SafetyTermination | None:
"""Return a SafetyTermination if *message* indicates provider safety
termination, otherwise return ``None``.
Implementations must be side-effect free and tolerant of missing or
oddly-typed metadata — detectors run on every model response.
"""
...
def _get_metadata_value(message: AIMessage, field_name: str) -> str | None:
"""Read a string-typed value from either ``response_metadata`` or
``additional_kwargs``.
LangChain provider adapters are inconsistent about where they stash
provider stop signals. Most modern adapters use ``response_metadata``,
but some legacy / passthrough paths still surface them via
``additional_kwargs``. We check both, in that order, and only accept
string values — Pydantic enums or dicts are ignored so we never raise
on malformed inputs.
"""
for container_name in ("response_metadata", "additional_kwargs"):
container = getattr(message, container_name, None) or {}
if not isinstance(container, dict):
continue
value = container.get(field_name)
if isinstance(value, str) and value:
return value
return None
class OpenAICompatibleContentFilterDetector:
"""OpenAI-compatible content_filter signal.
Covers OpenAI, Azure OpenAI, Moonshot/Kimi, DeepSeek, Mistral, vLLM,
Qwen (OpenAI-compatible mode), and any other adapter that follows the
OpenAI ``finish_reason`` convention.
Some Chinese providers ship custom OpenAI-compatible gateways that use
alternative tokens like ``sensitive`` or ``violation``. Extend the set
via the ``finish_reasons`` kwarg in config.
"""
name = "openai_compatible_content_filter"
def __init__(self, finish_reasons: list[str] | tuple[str, ...] | None = None) -> None:
configured = finish_reasons if finish_reasons is not None else ("content_filter",)
self._finish_reasons: frozenset[str] = frozenset(r.lower() for r in configured)
def detect(self, message: AIMessage) -> SafetyTermination | None:
value = _get_metadata_value(message, "finish_reason")
if value is None or value.lower() not in self._finish_reasons:
return None
extras: dict[str, Any] = {}
# Azure OpenAI ships a structured content_filter_results block; carry it
# through so operators can see *what* was filtered without re-tracing.
response_metadata = getattr(message, "response_metadata", None) or {}
if isinstance(response_metadata, dict):
filter_results = response_metadata.get("content_filter_results")
if filter_results:
extras["content_filter_results"] = filter_results
return SafetyTermination(
detector=self.name,
reason_field="finish_reason",
reason_value=value,
extras=extras,
)
class AnthropicRefusalDetector:
"""Anthropic ``stop_reason == "refusal"`` signal.
Anthropic models surface safety refusals via a dedicated ``stop_reason``
rather than ``finish_reason``. See:
https://platform.claude.com/docs/en/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals
"""
name = "anthropic_refusal"
def __init__(self, stop_reasons: list[str] | tuple[str, ...] | None = None) -> None:
configured = stop_reasons if stop_reasons is not None else ("refusal",)
self._stop_reasons: frozenset[str] = frozenset(r.lower() for r in configured)
def detect(self, message: AIMessage) -> SafetyTermination | None:
value = _get_metadata_value(message, "stop_reason")
if value is None or value.lower() not in self._stop_reasons:
return None
return SafetyTermination(
detector=self.name,
reason_field="stop_reason",
reason_value=value,
)
class GeminiSafetyDetector:
"""Gemini / Vertex AI safety-related finish reasons.
Gemini uses the same ``finish_reason`` field as OpenAI but with an
enumerated upper-case taxonomy. The default set covers every Gemini
finish_reason that means "the model stopped because the content/image
tripped a safety, blocklist, recitation, or PII filter" — i.e. cases
where any tool_calls returned alongside are likely truncated/
unreliable. Full enum:
https://docs.cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform_v1.types.Candidate.FinishReason
Intentionally **excluded** from the default set:
- ``STOP`` — normal termination.
- ``MAX_TOKENS`` — output length truncation, not safety
(same root failure mode as
content_filter, but issue #3028
scopes it out; expose separately if
desired).
- ``LANGUAGE`` / ``NO_IMAGE`` — capability mismatches, unrelated to
safety; tool_calls would be absent
anyway.
- ``MALFORMED_FUNCTION_CALL`` /
``UNEXPECTED_TOOL_CALL`` — tool-call protocol errors. The
tool_calls are *also* unreliable
here, but the failure category is
distinct from safety filtering;
handle in a dedicated detector to
keep observability records honest.
- ``OTHER`` / ``IMAGE_OTHER`` /
``FINISH_REASON_UNSPECIFIED`` — too broad to enable by default;
opt in via ``finish_reasons=`` if
your provider abuses these.
"""
name = "gemini_safety"
_DEFAULT_FINISH_REASONS = (
# Text safety
"SAFETY",
"BLOCKLIST",
"PROHIBITED_CONTENT",
"SPII",
"RECITATION",
# Image safety (multimodal generation)
"IMAGE_SAFETY",
"IMAGE_PROHIBITED_CONTENT",
"IMAGE_RECITATION",
)
def __init__(self, finish_reasons: list[str] | tuple[str, ...] | None = None) -> None:
configured = finish_reasons if finish_reasons is not None else self._DEFAULT_FINISH_REASONS
self._finish_reasons: frozenset[str] = frozenset(r.upper() for r in configured)
def detect(self, message: AIMessage) -> SafetyTermination | None:
value = _get_metadata_value(message, "finish_reason")
if value is None or value.upper() not in self._finish_reasons:
return None
extras: dict[str, Any] = {}
response_metadata = getattr(message, "response_metadata", None) or {}
if isinstance(response_metadata, dict):
# Gemini surfaces per-category scoring under safety_ratings.
ratings = response_metadata.get("safety_ratings")
if ratings:
extras["safety_ratings"] = ratings
return SafetyTermination(
detector=self.name,
reason_field="finish_reason",
reason_value=value,
extras=extras,
)
def default_detectors() -> list[SafetyTerminationDetector]:
"""Built-in detector set used when no custom detectors are configured."""
return [
OpenAICompatibleContentFilterDetector(),
AnthropicRefusalDetector(),
GeminiSafetyDetector(),
]
__all__ = [
"AnthropicRefusalDetector",
"GeminiSafetyDetector",
"OpenAICompatibleContentFilterDetector",
"SafetyTermination",
"SafetyTerminationDetector",
"default_detectors",
]
@@ -164,4 +164,14 @@ def build_subagent_runtime_middlewares(
middlewares.append(ViewImageMiddleware())
# Same provider safety-termination guard the lead agent uses — subagents
# are equally exposed to truncated tool_calls returned with
# finish_reason=content_filter (and friends), and the bad call would then
# propagate back to the lead agent via the task tool result.
safety_config = app_config.safety_finish_reason
if safety_config.enabled:
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
middlewares.append(SafetyFinishReasonMiddleware.from_config(safety_config))
return middlewares
@@ -20,6 +20,7 @@ from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_
from deerflow.config.model_config import ModelConfig
from deerflow.config.run_events_config import RunEventsConfig
from deerflow.config.runtime_paths import existing_project_file
from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.skill_evolution_config import SkillEvolutionConfig
from deerflow.config.skills_config import SkillsConfig
@@ -102,6 +103,7 @@ class AppConfig(BaseModel):
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration")
model_config = ConfigDict(extra="allow")
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
@@ -0,0 +1,47 @@
"""Configuration for SafetyFinishReasonMiddleware.
Mirrors the shape of GuardrailsConfig: detectors are loaded by class path
through ``deerflow.reflection.resolve_variable`` (same loader the
``guardrails.provider`` config uses) so users can drop in custom provider
detectors without modifying core code.
"""
from __future__ import annotations
from pydantic import BaseModel, Field
class SafetyDetectorConfig(BaseModel):
"""One detector entry under ``safety_finish_reason.detectors``."""
use: str = Field(
description=("Class path of a SafetyTerminationDetector implementation (e.g. 'deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector')."),
)
config: dict = Field(
default_factory=dict,
description="Constructor kwargs passed to the detector class.",
)
class SafetyFinishReasonConfig(BaseModel):
"""Configuration for the SafetyFinishReasonMiddleware.
The middleware intercepts AIMessages where the provider signaled a
safety-related termination (e.g. OpenAI ``finish_reason='content_filter'``)
while still returning tool calls, and suppresses those tool calls so the
half-truncated arguments never execute.
"""
enabled: bool = Field(
default=True,
description="Master switch for the SafetyFinishReasonMiddleware.",
)
detectors: list[SafetyDetectorConfig] | None = Field(
default=None,
description=(
"Custom detector list. Leave unset (None) to use the built-in "
"set covering OpenAI-compatible content_filter, Anthropic "
"refusal, and Gemini SAFETY/BLOCKLIST/PROHIBITED_CONTENT/SPII/"
"RECITATION. Provide a non-null list to fully override."
),
)
@@ -134,9 +134,25 @@ def reset_mcp_tools_cache() -> None:
"""Reset the MCP tools cache.
This is useful for testing or when you want to reload MCP tools.
Also closes all persistent MCP sessions so they are recreated on
the next tool load.
"""
global _mcp_tools_cache, _cache_initialized, _config_mtime
_mcp_tools_cache = None
_cache_initialized = False
_config_mtime = None
# Close persistent sessions they will be recreated by the next
# get_mcp_tools() call with the (possibly updated) connection config.
try:
from deerflow.mcp.session_pool import get_session_pool
pool = get_session_pool()
pool.close_all_sync()
except Exception:
logger.debug("Could not close MCP session pool on cache reset", exc_info=True)
from deerflow.mcp.session_pool import reset_session_pool
reset_session_pool()
logger.info("MCP tools cache reset")
@@ -0,0 +1,198 @@
"""Persistent MCP session pool for stateful tool calls.
When MCP tools are loaded via langchain-mcp-adapters with ``session=None``,
each tool call creates a new MCP session. For stateful servers like Playwright,
this means browser state (opened pages, filled forms) is lost between calls.
This module provides a session pool that maintains persistent MCP sessions,
scoped by ``(server_name, scope_key)`` — typically scope_key is the thread_id —
so that consecutive tool calls share the same session and server-side state.
Sessions are evicted in LRU order when the pool reaches capacity.
"""
from __future__ import annotations
import asyncio
import logging
import threading
from collections import OrderedDict
from typing import Any
from mcp import ClientSession
logger = logging.getLogger(__name__)
class MCPSessionPool:
"""Manages persistent MCP sessions scoped by ``(server_name, scope_key)``."""
MAX_SESSIONS = 256
SESSION_CLOSE_TIMEOUT = 5.0 # seconds to wait when closing a session via run_coroutine_threadsafe
def __init__(self) -> None:
self._entries: OrderedDict[
tuple[str, str],
tuple[ClientSession, asyncio.AbstractEventLoop],
] = OrderedDict()
self._context_managers: dict[tuple[str, str], Any] = {}
# threading.Lock is not bound to any event loop, so it is safe to
# acquire from both async paths and sync/worker-thread paths.
self._lock = threading.Lock()
async def get_session(
self,
server_name: str,
scope_key: str,
connection: dict[str, Any],
) -> ClientSession:
"""Get or create a persistent MCP session.
If an existing session was created in a different event loop (e.g.
the sync-wrapper path), it is closed and replaced with a fresh one
in the current loop.
Args:
server_name: MCP server name.
scope_key: Isolation key (typically thread_id).
connection: Connection configuration for ``create_session``.
Returns:
An initialized ``ClientSession``.
"""
key = (server_name, scope_key)
current_loop = asyncio.get_running_loop()
# Phase 1: inspect/mutate the registry under the thread lock (no awaits).
cms_to_close: list[tuple[tuple[str, str], Any]] = []
with self._lock:
if key in self._entries:
session, loop = self._entries[key]
if loop is current_loop:
self._entries.move_to_end(key)
return session
# Session belongs to a different event loop evict it.
cm = self._context_managers.pop(key, None)
self._entries.pop(key)
if cm is not None:
cms_to_close.append((key, cm))
# Evict LRU entries when at capacity.
while len(self._entries) >= self.MAX_SESSIONS:
oldest_key = next(iter(self._entries))
cm = self._context_managers.pop(oldest_key, None)
self._entries.pop(oldest_key)
if cm is not None:
cms_to_close.append((oldest_key, cm))
# Phase 2: async cleanup outside the lock so we never await while holding it.
for close_key, cm in cms_to_close:
try:
await cm.__aexit__(None, None, None)
except Exception:
logger.warning("Error closing MCP session %s", close_key, exc_info=True)
from langchain_mcp_adapters.sessions import create_session
cm = create_session(connection)
session = await cm.__aenter__()
await session.initialize()
# Phase 3: register the new session under the lock.
with self._lock:
self._entries[key] = (session, current_loop)
self._context_managers[key] = cm
logger.info("Created persistent MCP session for %s/%s", server_name, scope_key)
return session
# ------------------------------------------------------------------
# Cleanup helpers
# ------------------------------------------------------------------
async def _close_cm(self, key: tuple[str, str], cm: Any) -> None:
"""Close a single context manager (must be called WITHOUT the lock)."""
try:
await cm.__aexit__(None, None, None)
except Exception:
logger.warning("Error closing MCP session %s", key, exc_info=True)
async def close_scope(self, scope_key: str) -> None:
"""Close all sessions for a given scope (e.g. thread_id)."""
with self._lock:
keys = [k for k in self._entries if k[1] == scope_key]
cms = [(k, self._context_managers.pop(k, None)) for k in keys]
for k in keys:
self._entries.pop(k, None)
for key, cm in cms:
if cm is not None:
await self._close_cm(key, cm)
async def close_server(self, server_name: str) -> None:
"""Close all sessions for a given server."""
with self._lock:
keys = [k for k in self._entries if k[0] == server_name]
cms = [(k, self._context_managers.pop(k, None)) for k in keys]
for k in keys:
self._entries.pop(k, None)
for key, cm in cms:
if cm is not None:
await self._close_cm(key, cm)
async def close_all(self) -> None:
"""Close every managed session."""
with self._lock:
cms = list(self._context_managers.items())
self._context_managers.clear()
self._entries.clear()
for key, cm in cms:
await self._close_cm(key, cm)
def close_all_sync(self) -> None:
"""Close all sessions using their owning event loops (synchronous).
Each session is closed on the loop it was created in, avoiding
cross-loop resource leaks. Safe to call from any thread without an
active event loop.
"""
with self._lock:
entries = list(self._entries.items())
cms = dict(self._context_managers)
self._entries.clear()
self._context_managers.clear()
for key, (_, loop) in entries:
cm = cms.get(key)
if cm is None or loop.is_closed():
continue
try:
if loop.is_running():
# Schedule on the owning loop from this (different) thread.
future = asyncio.run_coroutine_threadsafe(cm.__aexit__(None, None, None), loop)
future.result(timeout=self.SESSION_CLOSE_TIMEOUT)
else:
loop.run_until_complete(cm.__aexit__(None, None, None))
except Exception:
logger.debug("Error closing MCP session %s during sync close", key, exc_info=True)
# ------------------------------------------------------------------
# Module-level singleton
# ------------------------------------------------------------------
_pool: MCPSessionPool | None = None
_pool_lock = threading.Lock()
def get_session_pool() -> MCPSessionPool:
"""Return the global session-pool singleton."""
global _pool
if _pool is None:
with _pool_lock:
if _pool is None:
_pool = MCPSessionPool()
return _pool
def reset_session_pool() -> None:
"""Reset the singleton (for tests)."""
global _pool
_pool = None
+190 -8
View File
@@ -1,21 +1,181 @@
"""Load MCP tools using langchain-mcp-adapters."""
"""Load MCP tools using langchain-mcp-adapters with persistent sessions."""
from __future__ import annotations
import logging
from typing import Any
from langchain_core.tools import BaseTool
from langchain_core.tools import BaseTool, StructuredTool
from langgraph.config import get_config
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.mcp.client import build_servers_config
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
from deerflow.mcp.session_pool import get_session_pool
from deerflow.reflection import resolve_variable
from deerflow.tools.sync import make_sync_tool_wrapper
from deerflow.tools.types import Runtime
logger = logging.getLogger(__name__)
def _extract_thread_id(runtime: Runtime | None) -> str:
"""Extract thread_id from the injected tool runtime or LangGraph config."""
if runtime is not None:
tid = runtime.context.get("thread_id") if runtime.context else None
if tid is not None:
return str(tid)
config = runtime.config or {}
tid = config.get("configurable", {}).get("thread_id")
if tid is not None:
return str(tid)
try:
tid = get_config().get("configurable", {}).get("thread_id")
return str(tid) if tid is not None else "default"
except RuntimeError:
return "default"
def _convert_call_tool_result(call_tool_result: Any) -> Any:
"""Convert an MCP CallToolResult to the LangChain ``content_and_artifact`` format.
Implements the same conversion logic as the adapter without relying on
the private ``langchain_mcp_adapters.tools._convert_call_tool_result`` symbol.
"""
from langchain_core.messages import ToolMessage
from langchain_core.messages.content import create_file_block, create_image_block, create_text_block
from langchain_core.tools import ToolException
from mcp.types import EmbeddedResource, ImageContent, ResourceLink, TextContent, TextResourceContents
# Pass ToolMessage through directly (interceptor short-circuit).
if isinstance(call_tool_result, ToolMessage):
return call_tool_result, None
# Pass LangGraph Command through directly when langgraph is installed.
try:
from langgraph.types import Command
if isinstance(call_tool_result, Command):
return call_tool_result, None
except ImportError:
# langgraph is optional; if unavailable, continue with standard MCP content conversion.
pass
# Convert MCP content blocks to LangChain content blocks.
lc_content = []
for item in call_tool_result.content:
if isinstance(item, TextContent):
lc_content.append(create_text_block(text=item.text))
elif isinstance(item, ImageContent):
lc_content.append(create_image_block(base64=item.data, mime_type=item.mimeType))
elif isinstance(item, ResourceLink):
mime = item.mimeType or None
if mime and mime.startswith("image/"):
lc_content.append(create_image_block(url=str(item.uri), mime_type=mime))
else:
lc_content.append(create_file_block(url=str(item.uri), mime_type=mime))
elif isinstance(item, EmbeddedResource):
from mcp.types import BlobResourceContents
res = item.resource
if isinstance(res, TextResourceContents):
lc_content.append(create_text_block(text=res.text))
elif isinstance(res, BlobResourceContents):
mime = res.mimeType or None
if mime and mime.startswith("image/"):
lc_content.append(create_image_block(base64=res.blob, mime_type=mime))
else:
lc_content.append(create_file_block(base64=res.blob, mime_type=mime))
else:
lc_content.append(create_text_block(text=str(res)))
else:
lc_content.append(create_text_block(text=str(item)))
if call_tool_result.isError:
error_parts = [item["text"] for item in lc_content if isinstance(item, dict) and item.get("type") == "text"]
raise ToolException("\n".join(error_parts) if error_parts else str(lc_content))
artifact = None
if call_tool_result.structuredContent is not None:
artifact = {"structured_content": call_tool_result.structuredContent}
return lc_content, artifact
def _make_session_pool_tool(
tool: BaseTool,
server_name: str,
connection: dict[str, Any],
tool_interceptors: list[Any] | None = None,
) -> BaseTool:
"""Wrap an MCP tool so it reuses a persistent session from the pool.
Replaces the per-call session creation with pool-managed sessions scoped
by ``(server_name, thread_id)``. This ensures stateful MCP servers (e.g.
Playwright) keep their state across tool calls within the same thread.
The configured ``tool_interceptors`` (OAuth, custom) are preserved and
applied on every call before invoking the pooled session.
"""
# Strip the server-name prefix to recover the original MCP tool name.
original_name = tool.name
prefix = f"{server_name}_"
if original_name.startswith(prefix):
original_name = original_name[len(prefix) :]
pool = get_session_pool()
async def call_with_persistent_session(
runtime: Runtime | None = None,
**arguments: Any,
) -> Any:
thread_id = _extract_thread_id(runtime)
session = await pool.get_session(server_name, thread_id, connection)
if tool_interceptors:
from langchain_mcp_adapters.interceptors import MCPToolCallRequest
async def base_handler(request: MCPToolCallRequest) -> Any:
return await session.call_tool(request.name, request.args)
handler = base_handler
for interceptor in reversed(tool_interceptors):
outer = handler
async def wrapped(req: Any, _i: Any = interceptor, _h: Any = outer) -> Any:
return await _i(req, _h)
handler = wrapped
request = MCPToolCallRequest(
name=original_name,
args=arguments,
server_name=server_name,
runtime=runtime,
)
call_tool_result = await handler(request)
else:
call_tool_result = await session.call_tool(original_name, arguments)
return _convert_call_tool_result(call_tool_result)
return StructuredTool(
name=tool.name,
description=tool.description,
args_schema=tool.args_schema,
coroutine=call_with_persistent_session,
response_format="content_and_artifact",
metadata=tool.metadata,
)
async def get_mcp_tools() -> list[BaseTool]:
"""Get all tools from enabled MCP servers.
Tools are wrapped with persistent-session logic so that consecutive
calls within the same thread reuse the same MCP session.
Returns:
List of LangChain tools from all enabled MCP servers.
"""
@@ -50,7 +210,7 @@ async def get_mcp_tools() -> list[BaseTool]:
existing_headers["Authorization"] = auth_header
servers_config[server_name]["headers"] = existing_headers
tool_interceptors = []
tool_interceptors: list[Any] = []
oauth_interceptor = build_oauth_tool_interceptor(extensions_config)
if oauth_interceptor is not None:
tool_interceptors.append(oauth_interceptor)
@@ -74,20 +234,42 @@ async def get_mcp_tools() -> list[BaseTool]:
elif interceptor is not None:
logger.warning(f"Builder {interceptor_path} returned non-callable {type(interceptor).__name__}; skipping")
except Exception as e:
logger.warning(f"Failed to load MCP interceptor {interceptor_path}: {e}", exc_info=True)
logger.warning(
f"Failed to load MCP interceptor {interceptor_path}: {e}",
exc_info=True,
)
client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors, tool_name_prefix=True)
client = MultiServerMCPClient(
servers_config,
tool_interceptors=tool_interceptors,
tool_name_prefix=True,
)
# Get all tools from all servers
# Get all tools from all servers (discovers tool definitions via
# temporary sessions the persistent-session wrapping is applied below).
tools = await client.get_tools()
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")
# Patch tools to support sync invocation, as deerflow client streams synchronously
# Wrap each tool with persistent-session logic.
wrapped_tools: list[BaseTool] = []
for tool in tools:
tool_server: str | None = None
for name in servers_config:
if tool.name.startswith(f"{name}_"):
tool_server = name
break
if tool_server is not None:
wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server], tool_interceptors))
else:
wrapped_tools.append(tool)
# Patch tools to support sync invocation, as deerflow client streams synchronously
for tool in wrapped_tools:
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name)
return tools
return wrapped_tools
except Exception as e:
logger.error(f"Failed to load MCP tools: {e}", exc_info=True)
@@ -227,9 +227,48 @@ class RunRepository(RunStore):
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
await session.commit()
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
async def update_run_progress(
self,
run_id: str,
*,
total_input_tokens: int | None = None,
total_output_tokens: int | None = None,
total_tokens: int | None = None,
llm_call_count: int | None = None,
lead_agent_tokens: int | None = None,
subagent_tokens: int | None = None,
middleware_tokens: int | None = None,
message_count: int | None = None,
last_ai_message: str | None = None,
first_human_message: str | None = None,
) -> None:
"""Update token usage + convenience fields while a run is still active."""
values: dict[str, Any] = {"updated_at": datetime.now(UTC)}
optional_counters = {
"total_input_tokens": total_input_tokens,
"total_output_tokens": total_output_tokens,
"total_tokens": total_tokens,
"llm_call_count": llm_call_count,
"lead_agent_tokens": lead_agent_tokens,
"subagent_tokens": subagent_tokens,
"middleware_tokens": middleware_tokens,
"message_count": message_count,
}
for key, value in optional_counters.items():
if value is not None:
values[key] = value
if last_ai_message is not None:
values["last_ai_message"] = last_ai_message[:2000]
if first_human_message is not None:
values["first_human_message"] = first_human_message[:2000]
async with self._sf() as session:
await session.execute(update(RunRow).where(RunRow.run_id == run_id, RunRow.status == "running").values(**values))
await session.commit()
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
"""Aggregate token usage via a single SQL GROUP BY query."""
_completed = RunRow.status.in_(("success", "error"))
statuses = ("success", "error", "running") if include_active else ("success", "error")
_completed = RunRow.status.in_(statuses)
_thread = RunRow.thread_id == thread_id
model_name = func.coalesce(RunRow.model_name, "unknown")
@@ -20,7 +20,7 @@ from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import Mapping
from collections.abc import Awaitable, Callable, Mapping
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, cast
from uuid import UUID
@@ -46,6 +46,8 @@ class RunJournal(BaseCallbackHandler):
*,
track_token_usage: bool = True,
flush_threshold: int = 20,
progress_reporter: Callable[[dict], Awaitable[None]] | None = None,
progress_flush_interval: float = 5.0,
):
super().__init__()
self.run_id = run_id
@@ -53,10 +55,16 @@ class RunJournal(BaseCallbackHandler):
self._store = event_store
self._track_tokens = track_token_usage
self._flush_threshold = flush_threshold
self._progress_reporter = progress_reporter
self._progress_flush_interval = progress_flush_interval
# Write buffer
self._buffer: list[dict] = []
self._pending_flush_tasks: set[asyncio.Task[None]] = set()
self._pending_progress_task: asyncio.Task[None] | None = None
self._pending_progress_delayed = False
self._progress_dirty = False
self._last_progress_flush = 0.0
# Token accumulators
self._total_input_tokens = 0
@@ -294,6 +302,8 @@ class RunJournal(BaseCallbackHandler):
else:
self._lead_agent_tokens += total_tk
self._schedule_progress_flush()
if messages:
self._counted_message_llm_run_ids.add(str(run_id))
@@ -445,6 +455,8 @@ class RunJournal(BaseCallbackHandler):
else:
self._lead_agent_tokens += total_tk
self._schedule_progress_flush()
def set_first_human_message(self, content: str) -> None:
"""Record the first human message for convenience fields."""
self._first_human_msg = content[:2000] if content else None
@@ -474,6 +486,14 @@ class RunJournal(BaseCallbackHandler):
"""Force flush remaining buffer. Called in worker's finally block."""
if self._pending_flush_tasks:
await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)
while self._pending_progress_task is not None and not self._pending_progress_task.done():
if self._pending_progress_delayed:
self._pending_progress_task.cancel()
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
self._progress_dirty = False
self._pending_progress_delayed = False
break
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
while self._buffer:
batch = self._buffer[: self._flush_threshold]
@@ -484,6 +504,57 @@ class RunJournal(BaseCallbackHandler):
self._buffer = batch + self._buffer
raise
def _schedule_progress_flush(self) -> None:
"""Best-effort throttled progress snapshot for active run visibility."""
if self._progress_reporter is None:
return
now = time.monotonic()
elapsed = now - self._last_progress_flush
if elapsed < self._progress_flush_interval:
self._progress_dirty = True
self._schedule_delayed_progress_flush(self._progress_flush_interval - elapsed)
return
if self._pending_progress_task is not None and not self._pending_progress_task.done():
self._progress_dirty = True
return
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
self._progress_dirty = False
self._pending_progress_task = loop.create_task(self._flush_progress_async(snapshot=self.get_completion_data()))
def _schedule_delayed_progress_flush(self, delay: float) -> None:
if self._pending_progress_task is not None and not self._pending_progress_task.done():
return
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
delay = max(0.0, delay)
self._pending_progress_delayed = delay > 0
self._pending_progress_task = loop.create_task(self._flush_progress_async(delay=delay))
async def _flush_progress_async(self, *, snapshot: dict | None = None, delay: float = 0.0) -> None:
if self._progress_reporter is None:
return
if delay > 0:
self._pending_progress_delayed = True
await asyncio.sleep(delay)
self._pending_progress_delayed = False
dirty_before_write = self._progress_dirty
self._progress_dirty = False
snapshot_to_write = snapshot or self.get_completion_data()
try:
await self._progress_reporter(snapshot_to_write)
self._last_progress_flush = time.monotonic()
except Exception:
logger.warning("Failed to persist progress snapshot for run %s", self.run_id, exc_info=True)
if dirty_before_write or self._progress_dirty:
self._progress_dirty = False
self._pending_progress_task = None
self._schedule_delayed_progress_flush(self._progress_flush_interval)
def get_completion_data(self) -> dict:
"""Return accumulated token and message data for run completion."""
return {
@@ -38,6 +38,16 @@ class RunRecord:
error: str | None = None
model_name: str | None = None
store_only: bool = False
total_input_tokens: int = 0
total_output_tokens: int = 0
total_tokens: int = 0
llm_call_count: int = 0
lead_agent_tokens: int = 0
subagent_tokens: int = 0
middleware_tokens: int = 0
message_count: int = 0
last_ai_message: str | None = None
first_human_message: str | None = None
class RunManager:
@@ -102,16 +112,53 @@ class RunManager:
error=row.get("error"),
model_name=row.get("model_name"),
store_only=True,
total_input_tokens=row.get("total_input_tokens") or 0,
total_output_tokens=row.get("total_output_tokens") or 0,
total_tokens=row.get("total_tokens") or 0,
llm_call_count=row.get("llm_call_count") or 0,
lead_agent_tokens=row.get("lead_agent_tokens") or 0,
subagent_tokens=row.get("subagent_tokens") or 0,
middleware_tokens=row.get("middleware_tokens") or 0,
message_count=row.get("message_count") or 0,
last_ai_message=row.get("last_ai_message"),
first_human_message=row.get("first_human_message"),
)
async def update_run_completion(self, run_id: str, **kwargs) -> None:
"""Persist token usage and completion data to the backing store."""
async with self._lock:
record = self._runs.get(run_id)
if record is not None:
for key, value in kwargs.items():
if key == "status":
continue
if hasattr(record, key) and value is not None:
setattr(record, key, value)
record.updated_at = _now_iso()
if self._store is not None:
try:
await self._store.update_run_completion(run_id, **kwargs)
except Exception:
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
async def update_run_progress(self, run_id: str, **kwargs) -> None:
"""Persist a running token/message snapshot without changing status."""
should_persist = True
async with self._lock:
record = self._runs.get(run_id)
if record is not None:
should_persist = record.status == RunStatus.running
if record is not None and should_persist:
for key, value in kwargs.items():
if hasattr(record, key) and value is not None:
setattr(record, key, value)
record.updated_at = _now_iso()
if should_persist and self._store is not None:
try:
await self._store.update_run_progress(run_id, **kwargs)
except Exception:
logger.warning("Failed to persist run progress for %s", run_id, exc_info=True)
async def create(
self,
thread_id: str,
@@ -95,12 +95,30 @@ class RunStore(abc.ABC):
) -> None:
pass
async def update_run_progress(
self,
run_id: str,
*,
total_input_tokens: int | None = None,
total_output_tokens: int | None = None,
total_tokens: int | None = None,
llm_call_count: int | None = None,
lead_agent_tokens: int | None = None,
subagent_tokens: int | None = None,
middleware_tokens: int | None = None,
message_count: int | None = None,
last_ai_message: str | None = None,
first_human_message: str | None = None,
) -> None:
"""Persist a best-effort running snapshot without changing run status."""
return None
@abc.abstractmethod
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
pass
@abc.abstractmethod
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
"""Aggregate token usage for completed runs in a thread.
Returns a dict with keys: total_tokens, total_input_tokens,
@@ -82,14 +82,22 @@ class MemoryRunStore(RunStore):
self._runs[run_id][key] = value
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
async def update_run_progress(self, run_id, **kwargs):
if run_id in self._runs and self._runs[run_id].get("status") == "running":
for key, value in kwargs.items():
if value is not None:
self._runs[run_id][key] = value
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
async def list_pending(self, *, before=None):
now = before or datetime.now(UTC).isoformat()
results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now]
results.sort(key=lambda r: r["created_at"])
return results
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in ("success", "error")]
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
statuses = ("success", "error", "running") if include_active else ("success", "error")
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in statuses]
by_model: dict[str, dict] = {}
for r in completed:
model = r.get("model_name") or "unknown"
@@ -153,8 +153,6 @@ async def run_agent(
journal = None
journal = None
# Track whether "events" was requested but skipped
if "events" in requested_modes:
logger.info(
@@ -177,6 +175,7 @@ async def run_agent(
thread_id=thread_id,
event_store=event_store,
track_token_usage=getattr(run_events_config, "track_token_usage", True),
progress_reporter=lambda snapshot: run_manager.update_run_progress(run_id, **snapshot),
)
# 1. Mark running
@@ -219,6 +218,12 @@ async def run_agent(
# manually here because we drive the graph through ``agent.astream(config=...)``
# without passing the official ``context=`` parameter.
runtime_ctx = _build_runtime_context(thread_id, run_id, config.get("context"), ctx.app_config)
# Expose the run-scoped journal under a sentinel key so middleware can
# write audit events (e.g. SafetyFinishReasonMiddleware recording
# suppressed tool calls). Double-underscore prefix marks it as a
# runtime-internal channel; user code must not depend on the key name.
if journal is not None:
runtime_ctx["__run_journal"] = journal
_install_runtime_context(config, runtime_ctx)
runtime = Runtime(context=cast(Any, runtime_ctx), store=store)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
@@ -42,6 +42,7 @@ _DEFAULT_GLOB_MAX_RESULTS = 200
_MAX_GLOB_MAX_RESULTS = 1000
_DEFAULT_GREP_MAX_RESULTS = 100
_MAX_GREP_MAX_RESULTS = 500
_DEFAULT_WRITE_FILE_ERROR_MAX_CHARS = 2000
_LOCAL_BASH_CWD_COMMANDS = {"cd", "pushd"}
_LOCAL_BASH_COMMAND_WRAPPERS = {"command", "builtin"}
_LOCAL_BASH_COMMAND_PREFIX_KEYWORDS = {"!", "{", "case", "do", "elif", "else", "for", "if", "select", "then", "time", "until", "while"}
@@ -435,6 +436,42 @@ def _sanitize_error(error: Exception, runtime: Runtime | None = None) -> str:
return msg
def _truncate_write_file_error_detail(detail: str, max_chars: int) -> str:
"""Middle-truncate write_file error details, preserving the head and tail."""
if max_chars == 0:
return detail
if len(detail) <= max_chars:
return detail
total = len(detail)
marker_max_len = len(f"\n... [write_file error truncated: {total} chars skipped] ...\n")
kept = max(0, max_chars - marker_max_len)
if kept == 0:
return detail[:max_chars]
head_len = kept // 2
tail_len = kept - head_len
skipped = total - kept
marker = f"\n... [write_file error truncated: {skipped} chars skipped] ...\n"
return f"{detail[:head_len]}{marker}{detail[-tail_len:] if tail_len > 0 else ''}"
def _format_write_file_error(
requested_path: str,
error: Exception,
runtime: Runtime | None = None,
*,
max_chars: int = _DEFAULT_WRITE_FILE_ERROR_MAX_CHARS,
) -> str:
"""Return a bounded, sanitized error string for write_file failures."""
header = f"Error: Failed to write file '{requested_path}'"
detail = _sanitize_error(error, runtime)
if max_chars == 0:
return f"{header}: {detail}"
detail_budget = max_chars - len(header) - 2
if detail_budget <= 0:
return _truncate_write_file_error_detail(f"{header}: {detail}", max_chars)
return f"{header}: {_truncate_write_file_error_detail(detail, detail_budget)}"
def replace_virtual_path(path: str, thread_data: ThreadDataState | None) -> str:
"""Replace virtual /mnt/user-data paths with actual thread data paths.
@@ -1651,9 +1688,9 @@ def write_file_tool(
append: Whether to append content to the end of the file instead of overwriting it. Defaults to false.
"""
try:
requested_path = path
sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime)
requested_path = path
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
validate_local_tool_path(path, thread_data)
@@ -1664,15 +1701,21 @@ def write_file_tool(
sandbox.write_file(path, content, append)
return "OK"
except SandboxError as e:
return f"Error: {e}"
return _format_write_file_error(requested_path, e, runtime)
except PermissionError:
return f"Error: Permission denied writing to file: {requested_path}"
return _truncate_write_file_error_detail(
f"Error: Permission denied writing to file: {requested_path}",
_DEFAULT_WRITE_FILE_ERROR_MAX_CHARS,
)
except IsADirectoryError:
return f"Error: Path is a directory, not a file: {requested_path}"
return _truncate_write_file_error_detail(
f"Error: Path is a directory, not a file: {requested_path}",
_DEFAULT_WRITE_FILE_ERROR_MAX_CHARS,
)
except OSError as e:
return f"Error: Failed to write file '{requested_path}': {_sanitize_error(e, runtime)}"
return _format_write_file_error(requested_path, e, runtime)
except Exception as e:
return f"Error: Unexpected error writing file: {_sanitize_error(e, runtime)}"
return _format_write_file_error(requested_path, e, runtime)
async def _write_file_tool_async(
@@ -7,6 +7,7 @@ from dataclasses import replace
from typing import TYPE_CHECKING, Annotated, Any, cast
from langchain.tools import InjectedToolCallId, tool
from langchain_core.callbacks import BaseCallbackManager
from langgraph.config import get_stream_writer
from deerflow.config import get_app_config
@@ -99,15 +100,31 @@ def _schedule_deferred_subagent_cleanup(task_id: str, trace_id: str, max_polls:
def _find_usage_recorder(runtime: Any) -> Any | None:
"""Find a callback handler with ``record_external_llm_usage_records`` in the runtime config."""
"""Find a callback handler with ``record_external_llm_usage_records`` in the runtime config.
LangChain may pass ``config["callbacks"]`` in three different shapes:
- ``None`` (no callbacks registered): no recorder.
- A plain ``list[BaseCallbackHandler]``: iterate it directly.
- A ``BaseCallbackManager`` instance (e.g. ``AsyncCallbackManager`` on async
tool runs): managers are not iterable, so we unwrap ``.handlers`` first.
Any other shape (e.g. a single handler object accidentally passed without a
list wrapper) cannot be iterated safely; treat it as "no recorder" rather
than raise.
"""
if runtime is None:
return None
config = getattr(runtime, "config", None)
if not isinstance(config, dict):
return None
callbacks = config.get("callbacks", [])
callbacks = config.get("callbacks")
if isinstance(callbacks, BaseCallbackManager):
callbacks = callbacks.handlers
if not callbacks:
return None
if not isinstance(callbacks, list):
return None
for cb in callbacks:
if hasattr(cb, "record_external_llm_usage_records"):
return cb