refactor(runtime): restructure runs module with new execution architecture

Major refactoring of deerflow/runtime/:
- runs/callbacks/ - new callback system (builder, events, title, tokens)
- runs/internal/ - execution internals (executor, supervisor, stream_logic, registry)
- runs/internal/execution/ - execution artifacts and events handling
- runs/facade.py - high-level run facade
- runs/observer.py - run observation protocol
- runs/types.py - type definitions
- runs/store/ - simplified store interfaces (create, delete, query, event)

Refactor stream_bridge/:
- Replace old providers with contract.py and exceptions.py
- Remove async_provider.py, base.py, memory.py

Add documentation:
- README.md and README_zh.md for runtime module

Remove or relocate deprecated modules:
- manager.py moved to internal/
- worker.py, schemas.py
- user_context.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
rayhpeng
2026-04-22 11:28:01 +08:00
parent 39a575617b
commit 9d0a42c1fb
43 changed files with 3928 additions and 1192 deletions
@@ -1,16 +1,48 @@
"""Run lifecycle management for LangGraph Platform API compatibility."""
"""Public runs API."""
from .manager import ConflictError, RunManager, RunRecord, UnsupportedStrategyError
from .schemas import DisconnectMode, RunStatus
from .worker import RunContext, run_agent
from .facade import RunsFacade
from .internal.manager import RunManager
from .observer import (
CallbackObserver,
CompositeObserver,
LifecycleEventType,
NullObserver,
ObserverBinding,
ObserverLike,
RunEventCallback,
RunLifecycleEvent,
RunObserver,
RunResult,
ensure_observer,
)
from .store import RunCreateStore, RunDeleteStore, RunEventStore, RunQueryStore
from .types import CancelAction, RunRecord, RunScope, RunSpec, RunStatus, WaitResult
__all__ = [
"ConflictError",
"DisconnectMode",
"RunContext",
# facade
"RunsFacade",
"RunManager",
"RunCreateStore",
"RunDeleteStore",
"RunEventStore",
"RunQueryStore",
# hooks
"CallbackObserver",
"CompositeObserver",
"LifecycleEventType",
"NullObserver",
"ObserverBinding",
"ObserverLike",
"RunEventCallback",
"RunLifecycleEvent",
"RunObserver",
"RunResult",
"ensure_observer",
# types
"CancelAction",
"RunRecord",
"RunScope",
"RunSpec",
"WaitResult",
"RunStatus",
"UnsupportedStrategyError",
"run_agent",
]
@@ -0,0 +1,15 @@
"""Runs execution callbacks."""
from .builder import RunCallbackArtifacts, build_run_callbacks
from .events import RunEventCallback
from .title import RunTitleCallback
from .tokens import RunCompletionData, RunTokenCallback
__all__ = [
"RunCallbackArtifacts",
"RunCompletionData",
"RunEventCallback",
"RunTitleCallback",
"RunTokenCallback",
"build_run_callbacks",
]
@@ -0,0 +1,138 @@
"""Callback assembly for runs execution."""
from __future__ import annotations
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any
from langchain_core.callbacks import BaseCallbackHandler
from ..store import RunEventStore
from ..types import RunRecord
from .events import RunEventCallback
from .title import RunTitleCallback
from .tokens import RunCompletionData, RunTokenCallback
@dataclass
class RunCallbackArtifacts:
    """Callback handlers for a run plus the handles read back after execution.

    ``callbacks`` is the complete handler list. The three optional handles
    point at individual handlers so that their results can be queried through
    ``completion_data()`` and ``title()``.
    """

    callbacks: list[BaseCallbackHandler]
    event_callback: RunEventCallback | None = None
    token_callback: RunTokenCallback | None = None
    title_callback: RunTitleCallback | None = None

    async def flush(self) -> None:
        """Call ``flush()`` on every callback that has one, awaiting awaitable results."""
        for handler in self.callbacks:
            flusher = getattr(handler, "flush", None)
            if flusher is not None:
                pending = flusher()
                # Handlers may expose either a plain or an async ``flush``.
                if hasattr(pending, "__await__"):
                    await pending

    def completion_data(self) -> RunCompletionData:
        """Return the token callback's summary, or an empty one when no token callback is set."""
        tracker = self.token_callback
        return RunCompletionData() if tracker is None else tracker.completion_data()

    def title(self) -> str | None:
        """Return the captured title, or ``None`` when no title callback is set."""
        tracker = self.title_callback
        return None if tracker is None else tracker.title()
def build_run_callbacks(
    *,
    record: RunRecord,
    graph_input: dict[str, Any],
    event_store: RunEventStore | None,
    existing_callbacks: Iterable[BaseCallbackHandler] = (),
) -> RunCallbackArtifacts:
    """Assemble the callback handlers for one run.

    The handler list starts with ``existing_callbacks`` and is followed by an
    event recorder (only when ``event_store`` is given), a token tracker and a
    title tracker, in that order. The token tracker is seeded from
    ``graph_input`` before being added.

    Reference callbacks are intentionally not assembled here yet; they remain
    in the existing artifacts path until that integration is migrated.
    """
    handlers = [*existing_callbacks]

    event_recorder = None
    if event_store is not None:
        event_recorder = RunEventCallback(
            run_id=record.run_id,
            thread_id=record.thread_id,
            event_store=event_store,
        )
        handlers.append(event_recorder)

    token_tracker = RunTokenCallback(track_token_usage=True)
    _set_first_human_message(token_tracker, graph_input)
    handlers.append(token_tracker)

    title_tracker = RunTitleCallback()
    handlers.append(title_tracker)

    return RunCallbackArtifacts(
        callbacks=handlers,
        event_callback=event_recorder,
        token_callback=token_tracker,
        title_callback=title_tracker,
    )
def _set_first_human_message(token_callback: RunTokenCallback, graph_input: dict[str, Any]) -> None:
    """Seed ``token_callback`` with the text of the first entry of ``graph_input["messages"]``.

    Nothing happens when ``messages`` is missing, is not a non-empty list, or
    when no text can be extracted from its first element.
    """
    messages = graph_input.get("messages")
    if isinstance(messages, list) and messages:
        text = _extract_first_human_text(messages[0])
        if text:
            token_callback.set_first_human_message(text)
def _extract_first_human_text(message: Any) -> str | None:
    """Return the text carried by ``message``.

    Accepts a plain string, an object exposing a non-``None`` ``content``
    attribute, or a dict with a ``"content"`` key; anything else yields ``None``.
    """
    if isinstance(message, str):
        return message

    attr_content = getattr(message, "content", None)
    if attr_content is not None:
        return _extract_text_content(attr_content)

    if not isinstance(message, dict):
        return None
    return _extract_text_content(message.get("content"))
def _extract_text_content(content: Any) -> str | None:
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
continue
if not isinstance(item, dict):
continue
if item.get("type") == "text" and isinstance(item.get("text"), str):
parts.append(item["text"])
continue
if isinstance(item.get("content"), str):
parts.append(item["content"])
joined = "".join(parts).strip()
return joined or None
if isinstance(content, dict):
text = content.get("text")
if isinstance(text, str):
return text
nested = content.get("content")
if isinstance(nested, str):
return nested
return None
@@ -0,0 +1,353 @@
"""Run execution event recording callback."""
from __future__ import annotations
import asyncio
import logging
import time
from datetime import UTC, datetime
from typing import Any
from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import HumanMessage
from deerflow.runtime.converters import langchain_messages_to_openai, langchain_to_openai_completion
from ..store import RunEventStore
logger = logging.getLogger(__name__)
class RunEventCallback(BaseCallbackHandler):
    """Capture LangChain execution events into the run event store.

    The ``on_*`` hooks append event dicts to an in-memory buffer. A write via
    ``event_store.put_batch`` is scheduled when the buffer reaches
    ``flush_threshold`` entries or when a root chain ends or fails; ``flush()``
    writes whatever is still pending.

    String content of events whose category is not ``"message"`` is truncated
    to ``max_trace_content`` characters, with the original length recorded in
    the event metadata.
    """

    def __init__(
        self,
        *,
        run_id: str,
        thread_id: str,
        event_store: RunEventStore,
        flush_threshold: int = 5,
        max_trace_content: int = 10240,
    ) -> None:
        super().__init__()
        self.run_id = run_id
        self.thread_id = thread_id
        self._store = event_store
        self._flush_threshold = flush_threshold
        self._max_trace_content = max_trace_content
        self._buffer: list[dict[str, Any]] = []
        # Per-LLM-call bookkeeping, keyed by the stringified LangChain run id.
        self._llm_start_times: dict[str, float] = {}
        self._llm_call_index = 0
        self._cached_prompts: dict[str, list[dict[str, Any]]] = {}
        # tool run id -> tool_call_id captured in on_tool_start.
        self._tool_call_ids: dict[str, str] = {}
        self._human_message_recorded = False

    def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Record ``run_start`` for the root chain; nested chains are ignored."""
        if kwargs.get("parent_run_id") is not None:
            return
        self._put(
            event_type="run_start",
            category="lifecycle",
            metadata={"input_preview": str(inputs)[:500]},
        )

    def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Record ``run_end`` for the root chain and schedule a flush."""
        if kwargs.get("parent_run_id") is not None:
            return
        self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
        self._flush_sync()

    def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
        """Record ``run_error`` for the root chain and schedule a flush."""
        if kwargs.get("parent_run_id") is not None:
            return
        self._put(
            event_type="run_error",
            category="lifecycle",
            content=str(error),
            metadata={"error_type": type(error).__name__},
        )
        self._flush_sync()

    def on_chat_model_start(self, serialized: dict, messages: list[list], *, run_id: UUID, **kwargs: Any) -> None:
        """Record an ``llm_request`` trace event and, once per run, the first human message."""
        rid = str(run_id)
        self._llm_start_times[rid] = time.monotonic()
        self._llm_call_index += 1

        prompt_msgs = messages[0] if messages else []
        openai_msgs = langchain_messages_to_openai(prompt_msgs)
        self._cached_prompts[rid] = openai_msgs

        caller = self._identify_caller(kwargs)
        self._record_first_human_message(prompt_msgs, caller=caller)
        self._put(
            event_type="llm_request",
            category="trace",
            content={"model": serialized.get("name", ""), "messages": openai_msgs},
            metadata={
                "caller": caller,
                "llm_call_index": self._llm_call_index,
            },
        )

    def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
        """Remember the start time of a (non-chat) LLM call for latency measurement."""
        self._llm_start_times[str(run_id)] = time.monotonic()

    def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Record ``llm_response`` and, for the lead agent, the resulting AI message event."""
        rid = str(run_id)
        # Release the per-call bookkeeping first so that a malformed response
        # cannot leave stale entries behind.
        start = self._llm_start_times.pop(rid, None)
        had_cached_prompt = rid in self._cached_prompts
        self._cached_prompts.pop(rid, None)

        try:
            message = response.generations[0][0].message
        except (IndexError, AttributeError):
            logger.debug("on_llm_end: could not extract message from response")
            return

        latency_ms = int((time.monotonic() - start) * 1000) if start is not None else None
        usage = dict(getattr(message, "usage_metadata", None) or {})
        caller = self._identify_caller(kwargs)

        # Calls that never went through on_chat_model_start have not been
        # counted yet, so they are counted here.
        if not had_cached_prompt:
            self._llm_call_index += 1
        call_index = self._llm_call_index

        self._put(
            event_type="llm_response",
            category="trace",
            content=langchain_to_openai_completion(message),
            metadata={
                "caller": caller,
                "usage": usage,
                "latency_ms": latency_ms,
                "llm_call_index": call_index,
            },
        )

        # Only lead-agent output is surfaced as a message-category event.
        if caller != "lead_agent":
            return

        content = getattr(message, "content", "")
        tool_calls = getattr(message, "tool_calls", None) or []
        if tool_calls:
            self._put(
                event_type="ai_tool_call",
                category="message",
                content=message.model_dump(),
                metadata={"finish_reason": "tool_calls"},
            )
        elif isinstance(content, str) and content:
            self._put(
                event_type="ai_message",
                category="message",
                content=message.model_dump(),
                metadata={"finish_reason": "stop"},
            )

    def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
        """Record ``llm_error`` and drop the bookkeeping kept for the failed call."""
        rid = str(run_id)
        self._llm_start_times.pop(rid, None)
        self._cached_prompts.pop(rid, None)
        self._put(event_type="llm_error", category="trace", content=str(error))

    def on_tool_start(self, serialized: dict, input_str: str, *, run_id: UUID, **kwargs: Any) -> None:
        """Record ``tool_start`` and remember the tool_call_id for the matching end/error hook."""
        tool_call_id = kwargs.get("tool_call_id")
        if tool_call_id:
            self._tool_call_ids[str(run_id)] = tool_call_id
        self._put(
            event_type="tool_start",
            category="trace",
            metadata={
                "tool_name": serialized.get("name", ""),
                "tool_call_id": tool_call_id,
                "args": str(input_str)[:2000],
            },
        )

    def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Record ``tool_end`` (trace) and ``tool_result`` (message) for a finished tool call."""
        from langchain_core.messages import ToolMessage

        # Always release the tracked id, even when a better source is available,
        # so the mapping does not grow for the lifetime of the run.
        tracked_id = self._tool_call_ids.pop(str(run_id), None)

        if isinstance(output, ToolMessage):
            tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or tracked_id
            tool_name = output.name or kwargs.get("name", "")
            status = getattr(output, "status", "success") or "success"
            content_str = output.content if isinstance(output.content, str) else str(output.content)
            msg_content = output.model_dump()
            if msg_content.get("tool_call_id") != tool_call_id:
                msg_content["tool_call_id"] = tool_call_id
        else:
            tool_call_id = kwargs.get("tool_call_id") or tracked_id
            tool_name = kwargs.get("name", "")
            status = "success"
            content_str = str(output)
            msg_content = ToolMessage(
                content=content_str,
                tool_call_id=tool_call_id or "",
                name=tool_name,
                status=status,
            ).model_dump()

        self._put(
            event_type="tool_end",
            category="trace",
            content=content_str,
            metadata={
                "tool_name": tool_name,
                "tool_call_id": tool_call_id,
                "status": status,
            },
        )
        self._put(
            event_type="tool_result",
            category="message",
            content=msg_content,
            metadata={"tool_name": tool_name, "status": status},
        )

    def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
        """Record ``tool_error`` (trace) and an error ``tool_result`` (message)."""
        from langchain_core.messages import ToolMessage

        # Pop unconditionally; see on_tool_end.
        tracked_id = self._tool_call_ids.pop(str(run_id), None)
        tool_call_id = kwargs.get("tool_call_id") or tracked_id
        tool_name = kwargs.get("name", "")

        self._put(
            event_type="tool_error",
            category="trace",
            content=str(error),
            metadata={"tool_name": tool_name, "tool_call_id": tool_call_id},
        )
        self._put(
            event_type="tool_result",
            category="message",
            content=ToolMessage(
                content=str(error),
                tool_call_id=tool_call_id or "",
                name=tool_name,
                status="error",
            ).model_dump(),
            metadata={"tool_name": tool_name, "status": "error"},
        )

    def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Record a custom event; ``summarization`` is expanded into two dedicated events."""
        from deerflow.runtime.serialization import serialize_lc_object

        if name == "summarization":
            data_dict = data if isinstance(data, dict) else {}
            summary = data_dict.get("summary", "")
            replaced_count = data_dict.get("replaced_count", 0)
            self._put(
                event_type="summarization",
                category="trace",
                content=summary,
                metadata={
                    "replaced_message_ids": data_dict.get("replaced_message_ids", []),
                    "replaced_count": replaced_count,
                },
            )
            self._put(
                event_type="middleware:summarize",
                category="middleware",
                content={"role": "system", "content": summary},
                metadata={"replaced_count": replaced_count},
            )
            return

        event_data = serialize_lc_object(data) if not isinstance(data, dict) else data
        self._put(
            event_type=name,
            category="trace",
            metadata=event_data if isinstance(event_data, dict) else {"data": event_data},
        )

    async def flush(self) -> None:
        """Write all buffered events to the store.

        On a store failure the batch is put back at the front of the buffer
        (matching the background flush path) and the exception is re-raised.
        """
        if not self._buffer:
            return
        batch = self._buffer.copy()
        self._buffer.clear()
        try:
            await self._store.put_batch(batch)
        except Exception:
            self._buffer = batch + self._buffer
            raise

    def _put(
        self,
        *,
        event_type: str,
        category: str,
        content: Any = "",
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Append one event to the buffer and schedule a flush once the threshold is reached."""
        normalized_metadata = dict(metadata or {})
        # Message-category content is stored in full; other string content is capped.
        if category != "message" and isinstance(content, str) and len(content) > self._max_trace_content:
            normalized_metadata["content_truncated"] = True
            normalized_metadata["original_content_length"] = len(content)
            content = content[: self._max_trace_content]

        self._buffer.append(
            {
                "thread_id": self.thread_id,
                "run_id": self.run_id,
                "event_type": event_type,
                "category": category,
                "content": content,
                "metadata": normalized_metadata,
                "created_at": datetime.now(UTC).isoformat(),
            }
        )
        if len(self._buffer) >= self._flush_threshold:
            self._flush_sync()

    def _flush_sync(self) -> None:
        """Schedule a background write of the buffer on the running event loop.

        When no loop is running the events simply stay buffered until a later
        flush succeeds.
        """
        if not self._buffer:
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return
        batch = self._buffer.copy()
        self._buffer.clear()
        task = loop.create_task(self._flush_async(batch))
        task.add_done_callback(self._on_flush_done)

    async def _flush_async(self, batch: list[dict[str, Any]]) -> None:
        """Write ``batch``; on failure log a warning and put it back at the front of the buffer."""
        try:
            await self._store.put_batch(batch)
        except Exception:
            logger.warning(
                "Failed to flush %d events for run %s; returning to buffer",
                len(batch),
                self.run_id,
                exc_info=True,
            )
            self._buffer = batch + self._buffer

    @staticmethod
    def _on_flush_done(task: asyncio.Task) -> None:
        """Log an exception that escaped a background flush task, if any."""
        if task.cancelled():
            return
        exc = task.exception()
        if exc:
            logger.warning("Run event flush task failed: %s", exc)

    def _identify_caller(self, kwargs: dict[str, Any]) -> str:
        """Return the first caller tag (``subagent:*``, ``middleware:*`` or ``lead_agent``).

        Falls back to ``"lead_agent"`` when no such tag is present.
        """
        for tag in kwargs.get("tags") or []:
            if isinstance(tag, str) and (tag == "lead_agent" or tag.startswith(("subagent:", "middleware:"))):
                return tag
        return "lead_agent"

    def _record_first_human_message(self, messages: list[Any], *, caller: str) -> None:
        """Emit a ``human_message`` event for the first non-summary ``HumanMessage``, once per run."""
        if self._human_message_recorded:
            return
        for message in messages:
            if not isinstance(message, HumanMessage):
                continue
            # Messages named "summary" are skipped.
            if message.name == "summary":
                continue
            self._put(
                event_type="human_message",
                category="message",
                content=message.model_dump(),
                metadata={
                    "caller": caller,
                    "source": "chat_model_start",
                },
            )
            self._human_message_recorded = True
            return
@@ -0,0 +1,51 @@
"""Title capture callback for runs."""
from __future__ import annotations
from typing import Any
from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler
class RunTitleCallback(BaseCallbackHandler):
    """Remember the most recent title produced by a title LLM call or a title custom event."""

    def __init__(self) -> None:
        super().__init__()
        self._title: str | None = None

    def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Store the response text of LLM calls tagged ``middleware:title``.

        The text is stripped of surrounding whitespace and quotes and capped
        at 200 characters.
        """
        if self._identify_caller(kwargs) != "middleware:title":
            return
        try:
            message = response.generations[0][0].message
        except (IndexError, AttributeError):
            return

        text = getattr(message, "content", "")
        if not isinstance(text, str) or not text:
            return
        unquoted = text.strip().strip('"').strip("'")
        self._title = unquoted[:200]

    def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Store a title delivered as a string or as ``{"title": ...}`` by a title custom event."""
        if name not in {"title", "thread_title", "middleware:title"}:
            return
        candidate = data.get("title") if isinstance(data, dict) else data
        if isinstance(candidate, str):
            self._title = candidate.strip()[:200]

    def title(self) -> str | None:
        """Return the captured title, or ``None`` when none was seen."""
        return self._title

    def _identify_caller(self, kwargs: dict[str, Any]) -> str:
        """Return the first caller tag found in ``kwargs["tags"]``, defaulting to ``"lead_agent"``."""
        tags = kwargs.get("tags") or []
        for tag in tags:
            if not isinstance(tag, str):
                continue
            if tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent":
                return tag
        return "lead_agent"
@@ -0,0 +1,122 @@
"""Token and message summary callback for runs."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler
@dataclass(frozen=True)
class RunCompletionData:
    """Immutable summary of token usage and message figures for one run."""

    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_tokens: int = 0
    llm_call_count: int = 0
    lead_agent_tokens: int = 0
    subagent_tokens: int = 0
    middleware_tokens: int = 0
    message_count: int = 0
    last_ai_message: str | None = None
    first_human_message: str | None = None

    def to_dict(self) -> dict[str, object]:
        """Return the summary as a plain dict, keyed by field name in declaration order."""
        names = (
            "total_input_tokens",
            "total_output_tokens",
            "total_tokens",
            "llm_call_count",
            "lead_agent_tokens",
            "subagent_tokens",
            "middleware_tokens",
            "message_count",
            "last_ai_message",
            "first_human_message",
        )
        return {name: getattr(self, name) for name in names}
class RunTokenCallback(BaseCallbackHandler):
    """Accumulate token usage and message summary figures over one run."""

    def __init__(self, *, track_token_usage: bool = True) -> None:
        super().__init__()
        self._track_token_usage = track_token_usage
        # Token totals across all counted LLM calls.
        self._total_input_tokens = 0
        self._total_output_tokens = 0
        self._total_tokens = 0
        self._llm_call_count = 0
        # Token totals split by caller tag.
        self._lead_agent_tokens = 0
        self._subagent_tokens = 0
        self._middleware_tokens = 0
        # Message summary.
        self._message_count = 0
        self._last_ai_message: str | None = None
        self._first_human_message: str | None = None

    def set_first_human_message(self, content: str) -> None:
        """Store the first human message, capped at 2000 characters (``None`` when empty)."""
        self._first_human_message = content[:2000] if content else None

    def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Update the message summary and, when tracking is enabled, the token totals.

        Calls that report no positive token total are not counted.
        """
        try:
            message = response.generations[0][0].message
        except (IndexError, AttributeError):
            return

        self._record_ai_message(message, kwargs)
        if not self._track_token_usage:
            return

        usage = dict(getattr(message, "usage_metadata", None) or {})
        prompt_tokens = usage.get("input_tokens", 0) or 0
        completion_tokens = usage.get("output_tokens", 0) or 0
        # Fall back to input + output when the provider omits a total.
        combined = usage.get("total_tokens", 0) or prompt_tokens + completion_tokens
        if combined <= 0:
            return

        self._total_input_tokens += prompt_tokens
        self._total_output_tokens += completion_tokens
        self._total_tokens += combined
        self._llm_call_count += 1

        origin = self._identify_caller(kwargs)
        if origin.startswith("subagent:"):
            self._subagent_tokens += combined
        elif origin.startswith("middleware:"):
            self._middleware_tokens += combined
        else:
            self._lead_agent_tokens += combined

    def completion_data(self) -> RunCompletionData:
        """Return a snapshot of everything accumulated so far."""
        return RunCompletionData(
            total_input_tokens=self._total_input_tokens,
            total_output_tokens=self._total_output_tokens,
            total_tokens=self._total_tokens,
            llm_call_count=self._llm_call_count,
            lead_agent_tokens=self._lead_agent_tokens,
            subagent_tokens=self._subagent_tokens,
            middleware_tokens=self._middleware_tokens,
            message_count=self._message_count,
            last_ai_message=self._last_ai_message,
            first_human_message=self._first_human_message,
        )

    def _record_ai_message(self, message: Any, kwargs: dict[str, Any]) -> None:
        """Count a lead-agent text reply (no tool calls) and keep its first 2000 characters."""
        if self._identify_caller(kwargs) != "lead_agent" or getattr(message, "tool_calls", None):
            return
        text = getattr(message, "content", "")
        if not isinstance(text, str) or not text:
            return
        self._last_ai_message = text[:2000]
        self._message_count += 1

    def _identify_caller(self, kwargs: dict[str, Any]) -> str:
        """Return the first caller tag found in ``kwargs["tags"]``, defaulting to ``"lead_agent"``."""
        tags = kwargs.get("tags") or []
        for tag in tags:
            if not isinstance(tag, str):
                continue
            if tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent":
                return tag
        return "lead_agent"
@@ -0,0 +1,240 @@
"""Public runs facade."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, AsyncIterator, Callable
from deerflow.runtime.stream_bridge import StreamEvent
from .internal.execution.executor import _RunExecution
from .internal.execution.supervisor import RunSupervisor
from .internal.planner import ExecutionPlanner
from .internal.registry import RunRegistry
from .internal.streams import RunStreamService
from .internal.wait import RunWaitService, WaitErrorResult
from .observer import ObserverLike
from .store import RunCreateStore, RunDeleteStore, RunEventStore, RunQueryStore
from .types import CancelAction, RunRecord, RunSpec
class MultitaskRejectError(Exception):
    """Raised when multitask_strategy is reject and thread has inflight runs."""
@dataclass(frozen=True)
class RunsRuntime:
    """Immutable bundle of the dependencies handed to each run execution."""

    # Stream bridge passed to the execution as ``bridge``.
    bridge: Any
    # Checkpointer passed to the execution as ``checkpointer``.
    checkpointer: Any
    # Optional store passed to the execution as ``store``.
    store: Any | None
    # Optional run event store passed to the execution as ``event_store``.
    event_store: RunEventStore | None
    # Called with the spec's assistant id to obtain the agent factory for a run.
    agent_factory_resolver: Callable[[str | None], Any]
class _RegistryStatusAdapter:
"""Minimal adapter so execution can update registry-backed run status."""
def __init__(self, registry: RunRegistry) -> None:
self._registry = registry
async def set_status(self, run_id: str, status: Any, *, error: str | None = None) -> None:
await self._registry.set_status(run_id, status, error=error)
class RunsFacade:
    """
    Phase 1 runs domain facade.

    Single entry point for creating runs (``create_background``,
    ``create_and_stream``, ``create_and_wait``), attaching to existing runs
    (``join_stream``, ``join_wait``), and for cancel / get / list / delete.

    It coordinates the registry, planner, supervisor, stream service and wait
    service. Execution flows through ExecutionPlanner + RunSupervisor rather
    than the legacy RunManager create/start path.
    """

    def __init__(
        self,
        registry: RunRegistry,
        planner: ExecutionPlanner,
        supervisor: RunSupervisor,
        stream_service: RunStreamService,
        wait_service: RunWaitService,
        runtime: RunsRuntime,
        observer: ObserverLike = None,
        query_store: RunQueryStore | None = None,
        create_store: RunCreateStore | None = None,
        delete_store: RunDeleteStore | None = None,
    ) -> None:
        # Core collaborators.
        self._registry = registry
        self._planner = planner
        self._supervisor = supervisor
        self._stream = stream_service
        self._wait = wait_service
        self._runtime = runtime
        self._observer = observer
        # Optional stores; when absent the registry alone is used.
        self._query_store = query_store
        self._create_store = create_store
        self._delete_store = delete_store

    async def create_background(self, spec: RunSpec) -> RunRecord:
        """Create a run and return its record without waiting for it to execute."""
        record = await self._create_run(spec)
        return record

    async def create_and_stream(
        self,
        spec: RunSpec,
    ) -> tuple[RunRecord, AsyncIterator[StreamEvent]]:
        """Create a run and return ``(record, stream_iterator)`` for its events."""
        record = await self._create_run(spec)
        return record, self._stream.subscribe(record.run_id)

    async def create_and_wait(
        self,
        spec: RunSpec,
    ) -> tuple[RunRecord, dict[str, Any] | WaitErrorResult | None]:
        """Create a run, wait for it, and return ``(record, final_values_or_error)``."""
        record = await self._create_run(spec)
        outcome = await self._wait.wait_for_values_or_error(record.run_id)
        return record, outcome

    async def join_stream(
        self,
        run_id: str,
        *,
        last_event_id: str | None = None,
    ) -> AsyncIterator[StreamEvent]:
        """Subscribe to an existing run's stream, optionally resuming after ``last_event_id``."""
        return self._stream.subscribe(run_id, last_event_id=last_event_id)

    async def join_wait(
        self,
        run_id: str,
        *,
        last_event_id: str | None = None,
    ) -> dict[str, Any] | WaitErrorResult | None:
        """Wait for an existing run and return its final values or error."""
        return await self._wait.wait_for_values_or_error(
            run_id,
            last_event_id=last_event_id,
        )

    async def cancel(
        self,
        run_id: str,
        *,
        action: CancelAction = "interrupt",
    ) -> bool:
        """Ask the supervisor to cancel a run and return the supervisor's result."""
        return await self._supervisor.cancel(run_id, action=action)

    async def get_run(self, run_id: str) -> RunRecord | None:
        """Look up a run, preferring the query store over the registry."""
        if self._query_store is None:
            return self._registry.get(run_id)
        return await self._query_store.get_run(run_id)

    async def list_runs(self, thread_id: str) -> list[RunRecord]:
        """List a thread's runs, preferring the query store over the registry."""
        if self._query_store is None:
            return await self._registry.list_by_thread(thread_id)
        return await self._query_store.list_runs(thread_id)

    async def delete_run(self, run_id: str) -> bool:
        """Cancel a run, then remove it from the registry and the delete store.

        Returns ``False`` when the run is unknown; otherwise the delete
        store's result, or ``True`` when no delete store is configured.
        """
        if await self.get_run(run_id) is None:
            return False

        await self._supervisor.cancel(run_id, action="interrupt")
        await self._registry.delete(run_id)
        if self._delete_store is None:
            return True
        return await self._delete_store.delete_run(run_id)

    async def _create_run(self, spec: RunSpec) -> RunRecord:
        """Apply the multitask strategy, register (and persist) the run, then start it."""
        await self._apply_multitask_strategy(spec)
        record = await self._registry.create(spec)
        if self._create_store is not None:
            await self._create_store.create_run(record)
        await self._start_execution(record, spec)
        return record

    async def _apply_multitask_strategy(self, spec: RunSpec) -> None:
        """Reject the new run or interrupt inflight ones, depending on the spec's strategy.

        Raises:
            MultitaskRejectError: strategy is ``"reject"`` and the thread has inflight runs.
        """
        thread_id = spec.scope.thread_id
        if not await self._registry.has_inflight(thread_id):
            return

        strategy = spec.multitask_strategy
        if strategy == "reject":
            raise MultitaskRejectError(
                f"Thread {thread_id} has inflight runs"
            )
        if strategy == "interrupt":
            for run_id in await self._registry.interrupt_inflight(thread_id):
                await self._supervisor.cancel(run_id, action="interrupt")

    async def _start_execution(self, record: RunRecord, spec: RunSpec) -> None:
        """Mark the run as starting, build its plan and launch it through the supervisor."""
        await self._registry.set_status(record.run_id, "starting")

        plan = self._planner.build(record, spec)
        status_adapter = _RegistryStatusAdapter(self._registry)
        agent_factory = self._runtime.agent_factory_resolver(spec.assistant_id)
        runtime = self._runtime

        async def _runner(handle) -> Any:
            # Invoked by the supervisor with the run's handle.
            execution = _RunExecution(
                bridge=runtime.bridge,
                run_manager=status_adapter,  # type: ignore[arg-type]
                record=record,
                checkpointer=runtime.checkpointer,
                store=runtime.store,
                event_store=runtime.event_store,
                agent_factory=agent_factory,
                graph_input=plan.graph_input,
                config=plan.runnable_config,
                observer=self._observer,
                stream_modes=plan.stream_modes,
                stream_subgraphs=plan.stream_subgraphs,
                interrupt_before=plan.interrupt_before,
                interrupt_after=plan.interrupt_after,
                handle=handle,
            )
            return await execution.run()

        await self._supervisor.launch(record.run_id, runner=_runner)
@@ -0,0 +1,4 @@
"""Internal runs implementation modules.
These modules are implementation details behind the public runs surface.
"""
@@ -0,0 +1 @@
"""Internal execution components for runs domain."""
@@ -0,0 +1,64 @@
"""Execution preparation helpers for a single run."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
from deerflow.runtime.stream_bridge import StreamBridge
@dataclass
class RunBuildArtifacts:
    """Result of preparing a single run: the agent plus the config to run it with."""

    # Agent instance produced by the agent factory.
    agent: Any
    # Plain-dict copy of the runnable config the agent was built with.
    runnable_config: dict[str, Any]
    # Store handed to the builder, if any.
    reference_store: Any | None = None
def build_run_artifacts(
    *,
    thread_id: str,
    run_id: str,
    checkpointer: Any | None,
    store: Any | None,
    agent_factory: Any,
    config: dict[str, Any],
    bridge: StreamBridge,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    callbacks: list[BaseCallbackHandler] | None = None,
) -> RunBuildArtifacts:
    """Assemble all components needed for agent execution.

    A runnable config is derived from ``config`` (with the LangGraph runtime
    injected under ``configurable["__pregel_runtime"]``, ``thread_id``
    defaulted into a dict ``context`` and ``callbacks`` appended), the agent
    is created with ``agent_factory(config=...)``, and the checkpointer,
    store and interrupt nodes are attached to it when provided.

    ``config`` and the nested ``context`` / ``configurable`` / ``callbacks``
    objects are copied before being modified, so the caller's objects are
    left untouched. ``run_id`` and ``bridge`` are accepted for interface
    compatibility and are not used here.

    Returns:
        The built agent, a plain-dict copy of the runnable config, and ``store``.
    """
    runtime = Runtime(context={"thread_id": thread_id}, store=store)

    # Shallow-copy the top level, then copy each nested container we touch.
    config = dict(config)

    context = config.get("context")
    if isinstance(context, dict):
        context = dict(context)
        context.setdefault("thread_id", thread_id)
        config["context"] = context

    configurable = dict(config.get("configurable") or {})
    configurable["__pregel_runtime"] = runtime
    config["configurable"] = configurable

    extra_callbacks = list(callbacks or [])
    existing_callbacks = config.get("callbacks")
    if existing_callbacks is None:
        config["callbacks"] = extra_callbacks
    elif isinstance(existing_callbacks, list):
        config["callbacks"] = [*existing_callbacks, *extra_callbacks]
    elif extra_callbacks:
        # A single non-list handler: normalise to a list so extras can follow it.
        config["callbacks"] = [existing_callbacks, *extra_callbacks]

    runnable_config = RunnableConfig(**config)
    agent = agent_factory(config=runnable_config)

    if checkpointer is not None:
        agent.checkpointer = checkpointer
    if store is not None:
        agent.store = store
    if interrupt_before:
        agent.interrupt_before_nodes = interrupt_before
    if interrupt_after:
        agent.interrupt_after_nodes = interrupt_after

    return RunBuildArtifacts(
        agent=agent,
        runnable_config=dict(runnable_config),
        reference_store=store,
    )
@@ -0,0 +1,45 @@
"""Lifecycle event helpers for run execution."""
from __future__ import annotations
from datetime import UTC, datetime
from typing import Any
from ...observer import LifecycleEventType, RunLifecycleEvent, RunObserver
class RunEventEmitter:
"""Build and dispatch lifecycle events for a single run."""
def __init__(
self,
*,
run_id: str,
thread_id: str,
observer: RunObserver,
) -> None:
self._run_id = run_id
self._thread_id = thread_id
self._observer = observer
self._sequence = 0
@property
def sequence(self) -> int:
return self._sequence
async def emit(
self,
event_type: LifecycleEventType,
payload: dict[str, Any] | None = None,
) -> None:
self._sequence += 1
event = RunLifecycleEvent(
event_id=f"{self._run_id}:{event_type.value}:{self._sequence}",
event_type=event_type,
run_id=self._run_id,
thread_id=self._thread_id,
sequence=self._sequence,
occurred_at=datetime.now(UTC),
payload=payload or {},
)
await self._observer.on_event(event)
@@ -0,0 +1,376 @@
"""Single-run execution orchestrator and execution-local helpers."""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Literal
from langchain_core.runnables import RunnableConfig
from deerflow.runtime.serialization import serialize
from deerflow.runtime.stream_bridge import StreamBridge, StreamStatus
from ...callbacks.builder import RunCallbackArtifacts, build_run_callbacks
from ...observer import LifecycleEventType, RunObserver, RunResult
from ...store import RunEventStore
from ...types import RunStatus
from .artifacts import build_run_artifacts
from .events import RunEventEmitter
from .stream_logic import external_stream_event_name, normalize_stream_modes, should_filter_event, unpack_stream_item
from .supervisor import RunHandle
logger = logging.getLogger(__name__)
class _RunExecution:
"""Encapsulate the lifecycle of a single run."""
def __init__(
    self,
    *,
    bridge: StreamBridge,
    run_manager: Any,
    record: Any,
    checkpointer: Any | None = None,
    store: Any | None = None,
    event_store: RunEventStore | None = None,
    ctx: Any | None = None,
    agent_factory: Any,
    graph_input: dict,
    config: dict,
    observer: RunObserver,
    stream_modes: list[str] | None,
    stream_subgraphs: bool,
    interrupt_before: list[str] | Literal["*"] | None,
    interrupt_after: list[str] | Literal["*"] | None,
    handle: RunHandle | None = None,
) -> None:
    """Store the run's dependencies and initialise per-run state.

    When ``ctx`` is given, its ``checkpointer`` and ``store`` attributes (if
    present) take precedence over the explicit arguments.
    """
    if ctx is not None:
        checkpointer = getattr(ctx, "checkpointer", checkpointer)
        store = getattr(ctx, "store", store)

    # Identity of the run, taken from the record.
    self.record = record
    self.run_id = record.run_id
    self.thread_id = record.thread_id

    # Collaborators.
    self.bridge = bridge
    self.run_manager = run_manager
    self.checkpointer = checkpointer
    self.store = store
    self.event_store = event_store
    self.agent_factory = agent_factory
    self.observer = observer
    self.handle = handle

    # Execution parameters.
    self.graph_input = graph_input
    self.config = config
    self.stream_modes = stream_modes
    self.stream_subgraphs = stream_subgraphs
    self.interrupt_before = interrupt_before
    self.interrupt_after = interrupt_after

    # State filled in while the run progresses.
    self._pre_run_checkpoint_id: str | None = None
    self._emitter = RunEventEmitter(
        run_id=self.run_id,
        thread_id=self.thread_id,
        observer=observer,
    )
    self.result = RunResult(
        run_id=self.run_id,
        thread_id=self.thread_id,
        status=RunStatus.pending,
    )
    self._agent: Any = None
    self._runnable_config: dict[str, Any] = {}
    self._lg_modes: list[str] = []
    self._callback_artifacts: RunCallbackArtifacts | None = None
@property
def _event_sequence(self) -> int:
return self._emitter.sequence
async def _emit(
self,
event_type: LifecycleEventType,
payload: dict[str, Any] | None = None,
) -> None:
await self._emitter.emit(event_type, payload)
async def _start(self) -> None:
await self.run_manager.set_status(self.run_id, RunStatus.running)
await self._emit(LifecycleEventType.RUN_STARTED, {})
human_msg = self._extract_human_message()
if human_msg is not None:
await self._emit(
LifecycleEventType.HUMAN_MESSAGE,
{"message": human_msg.model_dump()},
)
await self._capture_pre_run_checkpoint()
await self.bridge.publish(
self.run_id,
"metadata",
{"run_id": self.run_id, "thread_id": self.thread_id},
)
def _extract_human_message(self) -> Any:
from langchain_core.messages import HumanMessage
messages = self.graph_input.get("messages")
if not messages:
return None
last = messages[-1] if isinstance(messages, list) else messages
if isinstance(last, HumanMessage):
return last
if isinstance(last, str):
return HumanMessage(content=last) if last else None
if hasattr(last, "content"):
return HumanMessage(content=last.content)
if isinstance(last, dict):
content = last.get("content", "")
return HumanMessage(content=content) if content else None
return None
async def _capture_pre_run_checkpoint(self) -> None:
try:
config_for_check = {"configurable": {"thread_id": self.thread_id, "checkpoint_ns": ""}}
ckpt_tuple = await self.checkpointer.aget_tuple(config_for_check)
if ckpt_tuple is not None:
self._pre_run_checkpoint_id = (
getattr(ckpt_tuple, "config", {})
.get("configurable", {})
.get("checkpoint_id")
)
except Exception:
logger.debug("Could not get pre-run checkpoint_id for run %s", self.run_id)
async def _prepare(self) -> None:
config = dict(self.config)
existing_callbacks = config.pop("callbacks", [])
if existing_callbacks is None:
existing_callbacks = []
elif not isinstance(existing_callbacks, list):
existing_callbacks = [existing_callbacks]
self._callback_artifacts = build_run_callbacks(
record=self.record,
graph_input=self.graph_input,
event_store=self.event_store,
existing_callbacks=existing_callbacks,
)
artifacts = build_run_artifacts(
thread_id=self.thread_id,
run_id=self.run_id,
checkpointer=self.checkpointer,
store=self.store,
agent_factory=self.agent_factory,
config=config,
bridge=self.bridge,
interrupt_before=self.interrupt_before,
interrupt_after=self.interrupt_after,
callbacks=self._callback_artifacts.callbacks,
)
self._agent = artifacts.agent
self._runnable_config = artifacts.runnable_config
self._lg_modes = normalize_stream_modes(self.stream_modes)
logger.info(
"Run %s: streaming with modes %s (requested: %s)",
self.run_id,
self._lg_modes,
self.stream_modes,
)
async def _finish_success(self) -> None:
await self.run_manager.set_status(self.run_id, RunStatus.success)
await self.bridge.publish_terminal(self.run_id, StreamStatus.ENDED)
self.result.status = RunStatus.success
completion_data = self._completion_data()
title = self._callback_title() or await self._extract_title_from_checkpoint()
self.result.title = title
self.result.completion_data = completion_data
await self._emit(
LifecycleEventType.RUN_COMPLETED,
{
"title": title,
"completion_data": completion_data,
},
)
async def _finish_aborted(self, cancel_mode: str) -> None:
payload = {
"cancel_mode": cancel_mode,
"pre_run_checkpoint_id": self._pre_run_checkpoint_id,
"completion_data": self._completion_data(),
}
if cancel_mode == "rollback":
await self.run_manager.set_status(
self.run_id,
RunStatus.error,
error="Rolled back by user",
)
await self.bridge.publish_terminal(
self.run_id,
StreamStatus.CANCELLED,
{"cancel_mode": "rollback", "message": "Rolled back by user"},
)
self.result.status = RunStatus.error
self.result.error = "Rolled back by user"
logger.info("Run %s rolled back", self.run_id)
else:
await self.run_manager.set_status(self.run_id, RunStatus.interrupted)
await self.bridge.publish_terminal(
self.run_id,
StreamStatus.CANCELLED,
{"cancel_mode": cancel_mode},
)
self.result.status = RunStatus.interrupted
logger.info("Run %s cancelled (mode=%s)", self.run_id, cancel_mode)
await self._emit(LifecycleEventType.RUN_CANCELLED, payload)
async def _finish_failed(self, exc: Exception) -> None:
error_msg = str(exc)
logger.exception("Run %s failed: %s", self.run_id, error_msg)
await self.run_manager.set_status(self.run_id, RunStatus.error, error=error_msg)
await self.bridge.publish_terminal(
self.run_id,
StreamStatus.ERRORED,
{"message": error_msg, "name": type(exc).__name__},
)
self.result.status = RunStatus.error
self.result.error = error_msg
await self._emit(
LifecycleEventType.RUN_FAILED,
{
"error": error_msg,
"error_type": type(exc).__name__,
"completion_data": self._completion_data(),
},
)
def _completion_data(self) -> dict[str, object]:
if self._callback_artifacts is None:
return {}
return self._callback_artifacts.completion_data().to_dict()
def _callback_title(self) -> str | None:
if self._callback_artifacts is None:
return None
return self._callback_artifacts.title()
async def _extract_title_from_checkpoint(self) -> str | None:
if self.checkpointer is None:
return None
try:
ckpt_config = {"configurable": {"thread_id": self.thread_id, "checkpoint_ns": ""}}
ckpt_tuple = await self.checkpointer.aget_tuple(ckpt_config)
if ckpt_tuple is not None:
ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
return ckpt.get("channel_values", {}).get("title")
except Exception:
logger.debug("Failed to extract title from checkpoint for thread %s", self.thread_id)
return None
def _map_run_status_to_thread_status(self, status: RunStatus) -> str:
if status == RunStatus.success:
return "idle"
if status == RunStatus.interrupted:
return "interrupted"
if status in (RunStatus.error, RunStatus.timeout):
return "error"
return "running"
def _abort_requested(self) -> bool:
if self.handle is not None:
return self.handle.cancel_event.is_set()
return self.record.abort_event.is_set()
def _abort_action(self) -> str:
if self.handle is not None:
return self.handle.cancel_action
return self.record.abort_action
async def _stream(self) -> None:
runnable_config = RunnableConfig(**self._runnable_config)
if len(self._lg_modes) == 1 and not self.stream_subgraphs:
single_mode = self._lg_modes[0]
async for chunk in self._agent.astream(
self.graph_input,
config=runnable_config,
stream_mode=single_mode,
):
if self._abort_requested():
logger.info("Run %s abort requested - stopping", self.run_id)
break
if should_filter_event(single_mode, chunk):
continue
await self.bridge.publish(
self.run_id,
external_stream_event_name(single_mode),
serialize(chunk, mode=single_mode),
)
return
async for item in self._agent.astream(
self.graph_input,
config=runnable_config,
stream_mode=self._lg_modes,
subgraphs=self.stream_subgraphs,
):
if self._abort_requested():
logger.info("Run %s abort requested - stopping", self.run_id)
break
mode, chunk = unpack_stream_item(item, self._lg_modes, stream_subgraphs=self.stream_subgraphs)
if mode is None:
continue
if should_filter_event(mode, chunk):
continue
await self.bridge.publish(
self.run_id,
external_stream_event_name(mode),
serialize(chunk, mode=mode),
)
async def _finish_after_stream(self) -> None:
if self._abort_requested():
action = self._abort_action()
cancel_mode = "rollback" if action == "rollback" else "interrupt"
await self._finish_aborted(cancel_mode)
return
await self._finish_success()
async def _emit_final_thread_status(self) -> None:
final_thread_status = self._map_run_status_to_thread_status(self.result.status)
await self._emit(
LifecycleEventType.THREAD_STATUS_UPDATED,
{"status": final_thread_status},
)
async def run(self) -> RunResult:
try:
await self._start()
await self._prepare()
await self._stream()
await self._finish_after_stream()
except asyncio.CancelledError:
await self._finish_aborted("task_cancelled")
except Exception as exc:
await self._finish_failed(exc)
finally:
await self._emit_final_thread_status()
if self._callback_artifacts is not None:
await self._callback_artifacts.flush()
await self.bridge.cleanup(self.run_id)
return self.result
__all__ = ["_RunExecution"]
@@ -0,0 +1,93 @@
"""Execution-local stream processing helpers."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Any
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class StreamItem:
    """Normalized stream item from LangGraph.

    Immutable pairing of a stream mode name with the chunk produced for it.
    """

    # Name of the stream mode this chunk belongs to.
    mode: str
    # Payload emitted for that mode; its shape is not constrained here.
    chunk: Any
# Node names whose events should_filter_event() drops before publishing.
_FILTERED_NODES = frozenset({"__start__", "__end__"})
# Stream mode names that normalize_stream_modes() passes through unchanged.
# Frozen, like _FILTERED_NODES, so this lookup table cannot be mutated by accident.
_VALID_LG_MODES = frozenset({"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"})
def normalize_stream_modes(requested_modes: list[str] | None) -> list[str]:
    """Translate requested stream modes into a de-duplicated list of LangGraph modes.

    ``"messages-tuple"`` becomes ``"messages"``, ``"events"`` is skipped with
    an info log, names outside the known mode set are dropped, and an empty
    result falls back to ``["values"]``. First-occurrence order is kept.
    """
    accepted: list[str] = []
    for requested in list(requested_modes or ["values"]):
        if requested == "messages-tuple":
            accepted.append("messages")
            continue
        if requested == "events":
            logger.info("'events' stream_mode not supported (requires astream_events). Skipping.")
            continue
        if requested in _VALID_LG_MODES:
            accepted.append(requested)

    if not accepted:
        accepted = ["values"]

    # dict keys keep insertion order, which drops repeats while preserving order.
    return list(dict.fromkeys(accepted))
def unpack_stream_item(
    item: Any,
    lg_modes: list[str],
    *,
    stream_subgraphs: bool,
) -> tuple[str | None, Any]:
    """Split a multi-mode or subgraph stream item into ``(mode, chunk)``.

    With ``stream_subgraphs`` set, 3-tuples are read as
    ``(namespace, mode, chunk)`` and 2-tuples as ``(mode, chunk)``; anything
    else yields ``(None, None)``. Without it, 2-tuples are ``(mode, chunk)``
    and every other item is attributed to the first entry of ``lg_modes``
    (or ``None`` when that list is empty).
    """
    arity = len(item) if isinstance(item, tuple) else 0

    if stream_subgraphs:
        if arity == 3:
            # The leading namespace element is discarded.
            return str(item[1]), item[2]
        if arity == 2:
            return str(item[0]), item[1]
        return None, None

    if arity == 2:
        return str(item[0]), item[1]

    fallback_mode = lg_modes[0] if lg_modes else None
    return fallback_mode, item
def should_filter_event(mode: str, chunk: Any) -> bool:
    """Tell whether a stream event must be dropped before it is published.

    ``updates`` chunks are dropped when any of their keys is a filtered node
    name; ``messages`` chunks are dropped when the metadata element of the
    ``(message, metadata)`` pair names a filtered node in ``langgraph_node``.
    """
    if mode == "updates":
        if not isinstance(chunk, dict):
            return False
        return bool(set(chunk.keys()) & _FILTERED_NODES)

    if mode == "messages":
        if not (isinstance(chunk, tuple) and len(chunk) == 2):
            return False
        metadata = chunk[1]
        if not isinstance(metadata, dict):
            return False
        return metadata.get("langgraph_node", "") in _FILTERED_NODES

    return False
def external_stream_event_name(mode: str) -> str:
    """Map LangGraph internal modes to the external SSE event contract.

    The mapping is currently the identity: ``mode`` is returned unchanged.
    """
    return mode
@@ -0,0 +1,78 @@
"""Active execution handle management for runs domain."""
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any
from ...types import CancelAction
@dataclass
class RunHandle:
    """In-process control handle for an active run."""

    # Identifier of the run this handle controls.
    run_id: str
    # Background task executing the run; None until RunSupervisor.launch starts it.
    task: asyncio.Task[Any] | None = None
    # Set by RunSupervisor.cancel to signal that cancellation was requested.
    cancel_event: asyncio.Event = field(default_factory=asyncio.Event)
    # Action recorded by the most recent cancel request.
    cancel_action: CancelAction = "interrupt"


class RunSupervisor:
    """Own and control active run handles within the current process."""

    def __init__(self) -> None:
        self._handles: dict[str, RunHandle] = {}
        self._lock = asyncio.Lock()
        # Strong references to in-flight cleanup tasks. The event loop only
        # keeps weak references to tasks, so an unreferenced cleanup task
        # could be garbage-collected before it removes its handle.
        self._cleanup_tasks: set[asyncio.Task[None]] = set()

    async def launch(
        self,
        run_id: str,
        *,
        runner: Callable[[RunHandle], Awaitable[Any]],
    ) -> RunHandle:
        """Create a handle and start a background task for it.

        Args:
            run_id: Identifier of the run to start.
            runner: Called with the new handle; its awaitable becomes the
                background task.

        Returns:
            The handle registered for ``run_id``.

        Raises:
            RuntimeError: If ``run_id`` already has an active handle.
        """
        handle = RunHandle(run_id=run_id)

        async with self._lock:
            if run_id in self._handles:
                raise RuntimeError(f"Run {run_id} is already active")

            self._handles[run_id] = handle
            task = asyncio.create_task(runner(handle))
            handle.task = task
            # Drop the handle once the run task finishes, however it ends.
            task.add_done_callback(lambda _finished: self._schedule_cleanup(run_id))

        return handle

    def _schedule_cleanup(self, run_id: str) -> None:
        """Start ``cleanup(run_id)`` in the background and keep it referenced."""
        cleanup_task = asyncio.create_task(self.cleanup(run_id))
        self._cleanup_tasks.add(cleanup_task)
        cleanup_task.add_done_callback(self._cleanup_tasks.discard)

    async def cancel(
        self,
        run_id: str,
        *,
        action: CancelAction = "interrupt",
    ) -> bool:
        """Signal cancellation for an active handle.

        Records ``action`` on the handle, sets its cancel event and cancels
        the task if it is still running.

        Returns:
            ``True`` if a handle existed for ``run_id``, ``False`` otherwise.
        """
        async with self._lock:
            handle = self._handles.get(run_id)
            if handle is None:
                return False

            handle.cancel_action = action
            handle.cancel_event.set()
            if handle.task is not None and not handle.task.done():
                handle.task.cancel()

        return True

    def get_handle(self, run_id: str) -> RunHandle | None:
        """Return the active handle for a run, if any."""
        return self._handles.get(run_id)

    async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
        """Remove a handle after optional delay."""
        if delay > 0:
            await asyncio.sleep(delay)

        async with self._lock:
            self._handles.pop(run_id, None)
@@ -7,12 +7,9 @@ import logging
import uuid
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Literal
from .schemas import DisconnectMode, RunStatus
if TYPE_CHECKING:
from deerflow.runtime.runs.store.base import RunStore
from ..types import RunStatus
logger = logging.getLogger(__name__)
@@ -29,7 +26,7 @@ class RunRecord:
thread_id: str
assistant_id: str | None
status: RunStatus
on_disconnect: DisconnectMode
on_disconnect: Literal["cancel", "continue"]
multitask_strategy: str = "reject"
metadata: dict = field(default_factory=dict)
kwargs: dict = field(default_factory=dict)
@@ -49,12 +46,12 @@ class RunManager:
that run history survives process restarts.
"""
def __init__(self, store: RunStore | None = None) -> None:
def __init__(self, store: Any | None = None) -> None:
self._runs: dict[str, RunRecord] = {}
self._lock = asyncio.Lock()
self._store = store
async def _persist_to_store(self, record: RunRecord) -> None:
async def _persist_to_store(self, record: RunRecord, *, follow_up_to_run_id: str | None = None) -> None:
"""Best-effort persist run record to backing store."""
if self._store is None:
return
@@ -68,6 +65,7 @@ class RunManager:
metadata=record.metadata or {},
kwargs=record.kwargs or {},
created_at=record.created_at,
follow_up_to_run_id=follow_up_to_run_id,
)
except Exception:
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
@@ -85,10 +83,11 @@ class RunManager:
thread_id: str,
assistant_id: str | None = None,
*,
on_disconnect: DisconnectMode = DisconnectMode.cancel,
on_disconnect: Literal["cancel", "continue"] = "cancel",
metadata: dict | None = None,
kwargs: dict | None = None,
multitask_strategy: str = "reject",
follow_up_to_run_id: str | None = None,
) -> RunRecord:
"""Create a new pending run and register it."""
run_id = str(uuid.uuid4())
@@ -107,7 +106,7 @@ class RunManager:
)
async with self._lock:
self._runs[run_id] = record
await self._persist_to_store(record)
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record
@@ -120,7 +119,7 @@ class RunManager:
async with self._lock:
# Dict insertion order matches creation order, so reversing it gives
# us deterministic newest-first results even when timestamps tie.
return [r for r in self._runs.values() if r.thread_id == thread_id]
return [r for r in reversed(self._runs.values()) if r.thread_id == thread_id]
async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
"""Transition a run to a new status."""
@@ -170,10 +169,11 @@ class RunManager:
thread_id: str,
assistant_id: str | None = None,
*,
on_disconnect: DisconnectMode = DisconnectMode.cancel,
on_disconnect: Literal["cancel", "continue"] = "cancel",
metadata: dict | None = None,
kwargs: dict | None = None,
multitask_strategy: str = "reject",
follow_up_to_run_id: str | None = None,
) -> RunRecord:
"""Atomically check for inflight runs and create a new one.
@@ -227,7 +227,7 @@ class RunManager:
)
self._runs[run_id] = record
await self._persist_to_store(record)
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record
@@ -0,0 +1,42 @@
"""Execution plan builder for runs domain."""
from __future__ import annotations
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Literal
from ..types import RunRecord, RunSpec
@dataclass(frozen=True)
class ExecutionPlan:
    """Normalized execution inputs derived from a run record and spec.

    Frozen: fields cannot be reassigned after construction (the containers
    they reference can still be mutated).
    """

    # Run record the plan was built for.
    record: RunRecord
    # Input handed to the graph.
    graph_input: dict[str, Any]
    # Runnable configuration for the execution.
    runnable_config: dict[str, Any]
    # Requested stream mode names.
    stream_modes: list[str]
    # Whether subgraph output should be streamed as well.
    stream_subgraphs: bool
    # Node names (or "*") to interrupt before / after; None when unset.
    interrupt_before: list[str] | Literal["*"] | None
    interrupt_after: list[str] | Literal["*"] | None
class ExecutionPlanner:
    """Turn public run specs into plans the executor can consume."""

    def build(self, record: RunRecord, spec: RunSpec) -> ExecutionPlan:
        """Assemble an :class:`ExecutionPlan` for ``record`` from ``spec``.

        Mutable inputs are copied so the plan does not alias the spec: the
        graph input and runnable config are deep-copied, the stream modes are
        copied into a new list.
        """
        graph_input = self._normalize_graph_input(spec.input)
        runnable_config = deepcopy(spec.runnable_config)
        stream_modes = list(spec.stream_modes)
        return ExecutionPlan(
            record=record,
            graph_input=graph_input,
            runnable_config=runnable_config,
            stream_modes=stream_modes,
            stream_subgraphs=spec.stream_subgraphs,
            interrupt_before=spec.interrupt_before,
            interrupt_after=spec.interrupt_after,
        )

    def _normalize_graph_input(self, raw_input: dict[str, Any] | None) -> dict[str, Any]:
        """Return a deep copy of ``raw_input``, or an empty dict for ``None``."""
        return {} if raw_input is None else deepcopy(raw_input)
@@ -0,0 +1,146 @@
"""In-memory run registry for runs domain state."""
from __future__ import annotations
import asyncio
import uuid
from datetime import datetime, timezone
from typing import Any
from ..types import INFLIGHT_STATUSES, RunRecord, RunSpec, RunStatus
class RunRegistry:
    """In-memory source of truth for run records and their status."""

    def __init__(self) -> None:
        self._records: dict[str, RunRecord] = {}
        # thread_id -> run_ids in creation order. A dict is used as an
        # insertion-ordered set so per-thread listings are deterministic
        # (a plain set iterates in hash order, which varies between runs).
        self._thread_index: dict[str, dict[str, None]] = {}
        self._lock = asyncio.Lock()

    async def create(self, spec: RunSpec) -> RunRecord:
        """Create a new RunRecord from RunSpec.

        The record gets a fresh UUID, ``"pending"`` status and identical
        ``created_at`` / ``updated_at`` timestamps (UTC, ISO 8601).
        """
        run_id = str(uuid.uuid4())
        now = datetime.now(timezone.utc).isoformat()

        record = RunRecord(
            run_id=run_id,
            thread_id=spec.scope.thread_id,
            assistant_id=spec.assistant_id,
            status="pending",
            temporary=spec.scope.temporary,
            multitask_strategy=spec.multitask_strategy,
            metadata=dict(spec.metadata),
            follow_up_to_run_id=spec.follow_up_to_run_id,
            created_at=now,
            updated_at=now,
        )

        async with self._lock:
            self._records[run_id] = record
            self._thread_index.setdefault(spec.scope.thread_id, {})[run_id] = None

        return record

    def get(self, run_id: str) -> RunRecord | None:
        """Get RunRecord by run_id."""
        return self._records.get(run_id)

    async def list_by_thread(self, thread_id: str) -> list[RunRecord]:
        """List all RunRecords for a thread, oldest first."""
        async with self._lock:
            run_ids = self._thread_index.get(thread_id, ())
            return [self._records[rid] for rid in run_ids if rid in self._records]

    async def set_status(
        self,
        run_id: str,
        status: RunStatus,
        *,
        error: str | None = None,
        started_at: str | None = None,
        ended_at: str | None = None,
    ) -> None:
        """Update run status and optional fields.

        Unknown ``run_id`` values are ignored. Optional fields are only
        written when a non-``None`` value is passed.
        """
        async with self._lock:
            record = self._records.get(run_id)
            if record is None:
                return

            record.status = status
            record.updated_at = datetime.now(timezone.utc).isoformat()

            if error is not None:
                record.error = error
            if started_at is not None:
                record.started_at = started_at
            if ended_at is not None:
                record.ended_at = ended_at

    async def has_inflight(self, thread_id: str) -> bool:
        """Check if thread has any inflight runs."""
        async with self._lock:
            records = (self._records.get(rid) for rid in self._thread_index.get(thread_id, ()))
            return any(record and record.status in INFLIGHT_STATUSES for record in records)

    async def interrupt_inflight(self, thread_id: str) -> list[str]:
        """Mark all inflight runs for a thread as interrupted.

        Returns:
            The interrupted run_ids, in creation order.
        """
        interrupted: list[str] = []
        now = datetime.now(timezone.utc).isoformat()

        async with self._lock:
            for rid in self._thread_index.get(thread_id, ()):
                record = self._records.get(rid)
                if record and record.status in INFLIGHT_STATUSES:
                    record.status = "interrupted"
                    record.updated_at = now
                    record.ended_at = now
                    interrupted.append(rid)

        return interrupted

    async def update_metadata(self, run_id: str, metadata: dict[str, Any]) -> None:
        """Merge ``metadata`` into the run's metadata; unknown runs are ignored."""
        async with self._lock:
            record = self._records.get(run_id)
            if record is not None:
                record.metadata.update(metadata)
                record.updated_at = datetime.now(timezone.utc).isoformat()

    async def delete(self, run_id: str) -> bool:
        """Delete a run record. Returns True if deleted."""
        async with self._lock:
            record = self._records.pop(run_id, None)
            if record is None:
                return False

            thread_runs = self._thread_index.get(record.thread_id)
            if thread_runs is not None:
                thread_runs.pop(run_id, None)
                # Drop emptied buckets so the index does not keep one entry
                # per thread that ever had a run.
                if not thread_runs:
                    del self._thread_index[record.thread_id]

            return True

    def count(self) -> int:
        """Return total number of records."""
        return len(self._records)

    def count_by_status(self, status: RunStatus) -> int:
        """Return count of records with given status."""
        return sum(1 for r in self._records.values() if r.status == status)


# Compatibility alias during the refactor.
RuntimeRunRegistry = RunRegistry
@@ -0,0 +1,76 @@
"""Internal run stream adapter over StreamBridge."""
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from deerflow.runtime.stream_bridge import JSONValue, StreamBridge, StreamEvent
from deerflow.runtime.stream_bridge import StreamStatus
class RunStreamService:
    """Runs-domain facade that forwards every call to a stream bridge."""

    def __init__(self, bridge: "StreamBridge") -> None:
        self._bridge = bridge

    async def publish_event(
        self,
        run_id: str,
        *,
        event: str,
        data: "JSONValue",
    ) -> str:
        """Publish ``data`` under ``event`` for the run and return the bridge's result."""
        published_id = await self._bridge.publish(run_id, event, data)
        return published_id

    async def publish_end(self, run_id: str) -> str:
        """Publish the terminal signal for a run that ended normally."""
        status = StreamStatus.ENDED
        return await self._bridge.publish_terminal(run_id, status)

    async def publish_cancelled(
        self,
        run_id: str,
        *,
        data: "JSONValue" = None,
    ) -> str:
        """Publish the terminal signal for a cancelled run, with optional data."""
        status = StreamStatus.CANCELLED
        return await self._bridge.publish_terminal(run_id, status, data)

    async def publish_error(
        self,
        run_id: str,
        *,
        data: "JSONValue",
    ) -> str:
        """Publish the terminal signal for a failed run together with its data."""
        status = StreamStatus.ERRORED
        return await self._bridge.publish_terminal(run_id, status, data)

    def subscribe(
        self,
        run_id: str,
        *,
        last_event_id: str | None = None,
        heartbeat_interval: float = 15.0,
    ) -> AsyncIterator[StreamEvent]:
        """Return the bridge's event iterator for the run.

        ``last_event_id`` and ``heartbeat_interval`` are passed through to the
        bridge unchanged.
        """
        options = {
            "last_event_id": last_event_id,
            "heartbeat_interval": heartbeat_interval,
        }
        return self._bridge.subscribe(run_id, **options)

    async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
        """Ask the bridge to release the run's resources, optionally after ``delay``."""
        await self._bridge.cleanup(run_id, delay=delay)
@@ -0,0 +1,95 @@
"""Internal run wait helpers based on stream events."""
from __future__ import annotations
from typing import Any
from deerflow.runtime.stream_bridge import StreamEvent
from .streams import RunStreamService
class WaitTimeoutError(TimeoutError):
    """Signals that waiting on a run timed out."""
class WaitErrorResult:
    """Represents an error result from wait."""

    def __init__(self, error: str, details: dict[str, Any] | None = None) -> None:
        """Store the error message and a private copy of ``details``."""
        self.error = error
        # Copied so later changes to the caller's dict cannot alter this result.
        self.details = dict(details) if details else {}

    def to_dict(self) -> dict[str, Any]:
        """Return ``{"error": ...}`` followed by the detail entries.

        A detail entry named ``"error"`` is ignored so it can never replace
        the actual error message.
        """
        payload: dict[str, Any] = {"error": self.error}
        payload.update((key, value) for key, value in self.details.items() if key != "error")
        return payload
class RunWaitService:
    """Wait service for runs domain.

    Based on RunStreamService.subscribe(), implements wait semantics.

    Phase 1 behavior:
    - Records last 'values' event
    - On 'error', returns unified error structure
    - On 'end' or 'cancel', returns last values
    """

    TERMINAL_EVENTS = frozenset({"end", "error", "cancel"})

    def __init__(self, stream_service: RunStreamService) -> None:
        self._stream_service = stream_service

    async def wait_for_terminal(
        self,
        run_id: str,
        *,
        last_event_id: str | None = None,
    ) -> StreamEvent | None:
        """Block until the next terminal event for a run is observed.

        Returns:
            The first event whose name is in ``TERMINAL_EVENTS``, or ``None``
            if the subscription ends without one.
        """
        async for event in self._stream_service.subscribe(
            run_id,
            last_event_id=last_event_id,
        ):
            if event.event in self.TERMINAL_EVENTS:
                return event

        return None

    async def wait_for_values_or_error(
        self,
        run_id: str,
        *,
        last_event_id: str | None = None,
    ) -> dict[str, Any] | WaitErrorResult | None:
        """Wait for run to complete and return final values or error.

        Returns:
            - dict: Data of the last 'values' event seen before the stream
              ended or a non-error terminal event arrived
            - WaitErrorResult: If an 'error' event arrived
            - None: If no values were produced
        """
        last_values: dict[str, Any] | None = None

        async for event in self._stream_service.subscribe(
            run_id,
            last_event_id=last_event_id,
        ):
            name = event.event
            if name == "values":
                last_values = event.data
            elif name == "error":
                return WaitErrorResult(
                    error=self._error_message(event.data),
                    details={"run_id": run_id},
                )
            elif name in self.TERMINAL_EVENTS:
                # Stream ended, return last values
                break

        return last_values

    @staticmethod
    def _error_message(data: Any) -> str:
        """Derive a readable message from an error event payload.

        A dict payload with a non-empty string ``"message"`` yields that
        message; any other non-empty payload is stringified as a whole.
        NOTE(review): assumes the bridge delivers error payloads shaped like
        ``{"message": ..., "name": ...}`` - confirm against the bridge.
        """
        if not data:
            return "Unknown error"
        if isinstance(data, dict):
            message = data.get("message")
            if isinstance(message, str) and message:
                return message
        return str(data)
@@ -0,0 +1,203 @@
"""Run lifecycle observer types for decoupled observation.
Defines the RunObserver protocol and lifecycle event types that allow
the harness layer to emit notifications without directly calling
storage implementations.
The app layer provides concrete observers (e.g., StorageObserver) that
map lifecycle events to persistence operations.
"""
from __future__ import annotations
import logging
from collections.abc import Awaitable, Callable, Mapping
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Protocol, runtime_checkable
from .types import RunStatus
# Callback type for lightweight observer registration
type RunEventCallback = Callable[["RunLifecycleEvent"], Awaitable[None]]
class LifecycleEventType(str, Enum):
    """Lifecycle event types emitted during run execution.

    Members are ``str`` subclasses, so each compares equal to its string value.
    """

    # Run lifecycle
    RUN_STARTED = "run_started"
    RUN_COMPLETED = "run_completed"
    RUN_FAILED = "run_failed"
    RUN_CANCELLED = "run_cancelled"

    # Human message (for event store)
    HUMAN_MESSAGE = "human_message"

    # Thread status updates
    THREAD_STATUS_UPDATED = "thread_status_updated"
@dataclass(frozen=True)
class RunLifecycleEvent:
    """A single lifecycle event emitted during run execution.

    Attributes:
        event_id: Identifier of this event (uniqueness is up to the emitter -
            TODO confirm).
        event_type: The type of lifecycle event.
        run_id: The run that emitted this event.
        thread_id: The thread this run belongs to.
        sequence: Integer sequence number assigned by the emitter (presumably
            per-run ordering - TODO confirm).
        occurred_at: Timestamp of the event.
        payload: Event-specific data (varies by event_type).
    """

    event_id: str
    event_type: LifecycleEventType
    run_id: str
    thread_id: str
    sequence: int
    occurred_at: datetime
    # Defaults to a fresh empty dict per instance.
    payload: Mapping[str, Any] = field(default_factory=dict)
@dataclass
class RunResult:
    """Minimal result returned after run execution.

    Contains only the data needed for the caller to understand
    what happened. Detailed events are delivered via observer.

    The dataclass is intentionally not frozen, so fields can be updated in
    place after construction.

    Attributes:
        run_id: The run ID.
        thread_id: The thread ID.
        status: Final status (success, error, interrupted, etc.).
        error: Error message if status is error.
        completion_data: Token usage and message counts from journal.
        title: Thread title extracted from checkpoint (if available).
    """

    run_id: str
    thread_id: str
    status: RunStatus
    error: str | None = None
    # Fresh empty dict per instance.
    completion_data: dict[str, Any] = field(default_factory=dict)
    title: str | None = None
@runtime_checkable
class RunObserver(Protocol):
    """Protocol for observing run lifecycle events.

    Implementations receive events as they occur during execution
    and can perform side effects (storage, logging, metrics, etc.)
    without coupling the worker to specific implementations.

    Methods are async to support IO-bound operations like database writes.

    Because the protocol is ``runtime_checkable``, ``isinstance`` checks only
    verify that an ``on_event`` attribute exists, not its signature.
    """

    async def on_event(self, event: RunLifecycleEvent) -> None:
        """Called when a lifecycle event occurs.

        Args:
            event: The lifecycle event with type, IDs, and payload.

        Implementations should be explicit about failure handling.
        CompositeObserver can be configured to either swallow or raise
        observer failures based on each binding's ``required`` flag.
        """
        ...
@dataclass(frozen=True)
class ObserverBinding:
    """Observer registration with failure policy.

    Attributes:
        observer: Observer instance to invoke.
        required: When True, observer failures are raised to the caller.
            When False, failures are logged and dispatch continues.
    """

    observer: RunObserver
    # Observers are optional unless registered otherwise.
    required: bool = False
class CompositeObserver:
    """Fan a lifecycle event out to several child observers.

    Handy for combining storage, metrics and logging observers. A failing
    optional observer is logged and skipped; a failing required observer
    aborts the dispatch by re-raising.
    """

    def __init__(
        self,
        observers: list[RunObserver | ObserverBinding] | None = None,
    ) -> None:
        bindings: list[ObserverBinding] = []
        for candidate in observers or []:
            if isinstance(candidate, ObserverBinding):
                bindings.append(candidate)
            else:
                # Bare observers are registered with the default (optional) policy.
                bindings.append(ObserverBinding(candidate))
        self._observers: list[ObserverBinding] = bindings

    def add(self, observer: RunObserver, *, required: bool = False) -> None:
        """Register one more observer with the given failure policy."""
        binding = ObserverBinding(observer=observer, required=required)
        self._observers.append(binding)

    async def on_event(self, event: RunLifecycleEvent) -> None:
        """Deliver ``event`` to every child observer in registration order."""
        log = logging.getLogger(__name__)
        for binding in self._observers:
            try:
                await binding.observer.on_event(event)
            except Exception:
                if not binding.required:
                    log.warning(
                        "Observer %s failed on event %s",
                        type(binding.observer).__name__,
                        event.event_type.value,
                        exc_info=True,
                    )
                    continue
                raise
class NullObserver:
    """Observer that ignores every event; used when nothing needs observing."""

    async def on_event(self, event: RunLifecycleEvent) -> None:
        """Ignore ``event``."""
        return None
@dataclass(slots=True)
class CallbackObserver:
"""Adapter that wraps a callback function as a RunObserver.
Allows lightweight callback functions to participate in the
observer protocol without defining a full class.
"""
callback: RunEventCallback
async def on_event(self, event: RunLifecycleEvent) -> None:
"""Invoke the wrapped callback with the event."""
await self.callback(event)
type ObserverLike = RunObserver | RunEventCallback | None
def ensure_observer(observer: ObserverLike) -> RunObserver:
    """Coerce an observer-like value into a RunObserver.

    Args:
        observer: One of:
            - None: replaced by a NullObserver
            - a callable that is not already a RunObserver: wrapped in a
              CallbackObserver
            - anything else: returned unchanged

    Returns:
        A RunObserver instance.
    """
    if observer is None:
        return NullObserver()

    needs_wrapping = callable(observer) and not isinstance(observer, RunObserver)
    return CallbackObserver(observer) if needs_wrapping else observer
@@ -1,21 +0,0 @@
"""Run status and disconnect mode enums."""
from enum import StrEnum
class RunStatus(StrEnum):
    """Lifecycle status of a single run.

    Members are strings (``StrEnum``), so each compares equal to its value.
    """

    pending = "pending"
    running = "running"
    success = "success"
    error = "error"
    timeout = "timeout"
    interrupted = "interrupted"
class DisconnectMode(StrEnum):
    """Behaviour when the SSE consumer disconnects."""

    cancel = "cancel"
    # Trailing underscore because ``continue`` is a Python keyword; the value
    # is still the plain string "continue".
    continue_ = "continue"
@@ -1,4 +1,13 @@
from deerflow.runtime.runs.store.base import RunStore
from deerflow.runtime.runs.store.memory import MemoryRunStore
"""Store boundary protocols for runs."""
__all__ = ["MemoryRunStore", "RunStore"]
from .create_store import RunCreateStore
from .delete_store import RunDeleteStore
from .event_store import RunEventStore
from .query_store import RunQueryStore
__all__ = [
"RunCreateStore",
"RunDeleteStore",
"RunEventStore",
"RunQueryStore",
]
@@ -1,95 +0,0 @@
"""Abstract interface for run metadata storage.
RunManager depends on this interface. Implementations:
- MemoryRunStore: in-memory dict (development, tests)
- Future: RunRepository backed by SQLAlchemy ORM
All methods accept an optional user_id for user isolation.
When user_id is None, no user filtering is applied (single-user mode).
"""
from __future__ import annotations
import abc
from typing import Any
class RunStore(abc.ABC):
    """Abstract storage interface for run metadata.

    Every method is abstract and asynchronous; concrete stores supply the
    behaviour. Where a method accepts ``user_id``, ``None`` means that no
    user filtering is applied.
    """

    @abc.abstractmethod
    async def put(
        self,
        run_id: str,
        *,
        thread_id: str,
        assistant_id: str | None = None,
        user_id: str | None = None,
        status: str = "pending",
        multitask_strategy: str = "reject",
        metadata: dict[str, Any] | None = None,
        kwargs: dict[str, Any] | None = None,
        error: str | None = None,
        created_at: str | None = None,
    ) -> None:
        """Store the metadata of a run under ``run_id``.

        NOTE(review): create-vs-overwrite semantics are not specified by this
        interface - confirm against the implementations.
        """
        pass

    @abc.abstractmethod
    async def get(self, run_id: str) -> dict[str, Any] | None:
        """Return the stored run as a dict, or ``None``."""
        pass

    @abc.abstractmethod
    async def list_by_thread(
        self,
        thread_id: str,
        *,
        user_id: str | None = None,
        limit: int = 100,
    ) -> list[dict[str, Any]]:
        """Return stored runs of a thread as dicts.

        NOTE(review): ordering and the exact meaning of ``limit`` are left to
        implementations - presumably a maximum result count.
        """
        pass

    @abc.abstractmethod
    async def update_status(
        self,
        run_id: str,
        status: str,
        *,
        error: str | None = None,
    ) -> None:
        """Set the status of a stored run, optionally with an error message."""
        pass

    @abc.abstractmethod
    async def delete(self, run_id: str) -> None:
        """Remove the stored run with the given id."""
        pass

    @abc.abstractmethod
    async def update_run_completion(
        self,
        run_id: str,
        *,
        status: str,
        total_input_tokens: int = 0,
        total_output_tokens: int = 0,
        total_tokens: int = 0,
        llm_call_count: int = 0,
        lead_agent_tokens: int = 0,
        subagent_tokens: int = 0,
        middleware_tokens: int = 0,
        message_count: int = 0,
        last_ai_message: str | None = None,
        first_human_message: str | None = None,
        error: str | None = None,
    ) -> None:
        """Record the final status and completion statistics of a run.

        All counters default to 0 and the message / error fields to ``None``.
        """
        pass

    @abc.abstractmethod
    async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
        """Return pending runs as dicts.

        NOTE(review): ``before`` is presumably a timestamp cutoff - confirm
        its format and inclusivity against the implementations.
        """
        pass

    @abc.abstractmethod
    async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
        """Aggregate token usage for completed runs in a thread.

        Returns a dict with keys: total_tokens, total_input_tokens,
        total_output_tokens, total_runs, by_model (model_name → {tokens, runs}),
        by_caller ({lead_agent, subagent, middleware}).
        """
        pass
@@ -0,0 +1,13 @@
"""Create-side boundary for durable run initialization."""
from __future__ import annotations
from typing import Protocol
from ..types import RunRecord
class RunCreateStore(Protocol):
    """Persist the initial durable row for a newly created run."""

    async def create_run(self, record: RunRecord) -> None:
        """Persist *record* as the initial durable row of a new run."""
        ...
@@ -0,0 +1,11 @@
"""Delete-side durable boundary for runs."""
from __future__ import annotations
from typing import Protocol
class RunDeleteStore(Protocol):
    """Minimal protocol for removing durable run records."""

    async def delete_run(self, run_id: str) -> bool:
        """Remove the durable record for *run_id*.

        NOTE(review): the boolean presumably reports whether a record was
        actually removed — confirm against the implementations.
        """
        ...
@@ -0,0 +1,11 @@
"""Run event store boundary used by runs callbacks."""
from __future__ import annotations
from typing import Any, Protocol
class RunEventStore(Protocol):
    """Minimal append-only event store protocol for execution callbacks."""

    async def put_batch(self, events: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Append a batch of event dicts and return a list of event dicts.

        NOTE(review): the returned list is presumably the stored form of
        *events* — confirm against the implementations.
        """
        ...
@@ -1,98 +0,0 @@
"""In-memory RunStore. Used when database.backend=memory (default) and in tests.
Equivalent to the original RunManager._runs dict behavior.
"""
from __future__ import annotations
from datetime import UTC, datetime
from typing import Any
from deerflow.runtime.runs.store.base import RunStore
class MemoryRunStore(RunStore):
    """Run store that keeps every record in a plain in-process dict."""

    def __init__(self) -> None:
        # Maps run_id -> record dict.
        self._runs: dict[str, dict[str, Any]] = {}

    async def put(
        self,
        run_id,
        *,
        thread_id,
        assistant_id=None,
        user_id=None,
        status="pending",
        multitask_strategy="reject",
        metadata=None,
        kwargs=None,
        error=None,
        created_at=None,
    ):
        """Insert the record for *run_id*, replacing any existing one."""
        timestamp = datetime.now(UTC).isoformat()
        record = {
            "run_id": run_id,
            "thread_id": thread_id,
            "assistant_id": assistant_id,
            "user_id": user_id,
            "status": status,
            "multitask_strategy": multitask_strategy,
            "metadata": metadata or {},
            "kwargs": kwargs or {},
            "error": error,
            "created_at": created_at or timestamp,
            "updated_at": timestamp,
        }
        self._runs[run_id] = record

    async def get(self, run_id):
        """Return the stored record for *run_id*, or ``None`` if unknown."""
        return self._runs.get(run_id)

    async def list_by_thread(self, thread_id, *, user_id=None, limit=100):
        """Return up to *limit* records of *thread_id*, newest first.

        When *user_id* is given, only that user's records are returned.
        """
        matches = []
        for record in self._runs.values():
            if record["thread_id"] != thread_id:
                continue
            if user_id is not None and record.get("user_id") != user_id:
                continue
            matches.append(record)
        newest_first = sorted(matches, key=lambda record: record["created_at"], reverse=True)
        return newest_first[:limit]

    async def update_status(self, run_id, status, *, error=None):
        """Update status (and error, when given) of a known run; no-op otherwise."""
        record = self._runs.get(run_id)
        if record is None:
            return
        record["status"] = status
        if error is not None:
            record["error"] = error
        record["updated_at"] = datetime.now(UTC).isoformat()

    async def delete(self, run_id):
        """Drop the record for *run_id*; unknown ids are ignored."""
        self._runs.pop(run_id, None)

    async def update_run_completion(self, run_id, *, status, **kwargs):
        """Store the final status and every non-``None`` completion field."""
        record = self._runs.get(run_id)
        if record is None:
            return
        record["status"] = status
        record.update({key: value for key, value in kwargs.items() if value is not None})
        record["updated_at"] = datetime.now(UTC).isoformat()

    async def list_pending(self, *, before=None):
        """Return pending runs created at or before *before* (default: now), oldest first."""
        cutoff = before or datetime.now(UTC).isoformat()
        pending = (
            record
            for record in self._runs.values()
            if record["status"] == "pending" and record["created_at"] <= cutoff
        )
        return sorted(pending, key=lambda record: record["created_at"])

    async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
        """Sum token usage over the thread's runs whose status is success or error."""
        finished = [
            record
            for record in self._runs.values()
            if record["thread_id"] == thread_id and record.get("status") in ("success", "error")
        ]

        per_model: dict[str, dict] = {}
        for record in finished:
            model_name = record.get("model_name") or "unknown"
            if model_name not in per_model:
                per_model[model_name] = {"tokens": 0, "runs": 0}
            bucket = per_model[model_name]
            bucket["tokens"] += record.get("total_tokens", 0)
            bucket["runs"] += 1

        def _total(field_name):
            # Missing counters count as zero.
            return sum(record.get(field_name, 0) for record in finished)

        return {
            "total_tokens": _total("total_tokens"),
            "total_input_tokens": _total("total_input_tokens"),
            "total_output_tokens": _total("total_output_tokens"),
            "total_runs": len(finished),
            "by_model": per_model,
            "by_caller": {
                "lead_agent": _total("lead_agent_tokens"),
                "subagent": _total("subagent_tokens"),
                "middleware": _total("middleware_tokens"),
            },
        }
@@ -0,0 +1,20 @@
"""Read-side boundary for durable run queries."""
from __future__ import annotations
from typing import Protocol
from ..types import RunRecord
class RunQueryStore(Protocol):
    """Read durable run records for public query APIs."""

    async def get_run(self, run_id: str) -> RunRecord | None:
        """Return the durable record for *run_id*, or ``None``.

        NOTE(review): ``None`` presumably means "no such run" — confirm.
        """
        ...

    async def list_runs(
        self,
        thread_id: str,
        *,
        limit: int = 100,
    ) -> list[RunRecord]:
        """Return run records for *thread_id*; *limit* defaults to 100.

        NOTE(review): ordering is not specified by this protocol — confirm
        with the implementations.
        """
        ...
@@ -0,0 +1,117 @@
"""Public runs domain types."""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import StrEnum
from typing import Any, Literal
# Intent: what the incoming request is asking for.
RunIntent = Literal[
    "create_background",
    "create_and_stream",
    "create_and_wait",
    "join_stream",
    "join_wait",
]

# Scope kind: stateful (requires a thread_id) vs stateless (temporary thread).
RunScopeKind = Literal["stateful", "stateless"]
class RunStatus(StrEnum):
    """Lifecycle status of a run.

    Members are string-valued (``StrEnum``), so each one compares equal to
    its plain string value.
    """

    pending = "pending"
    starting = "starting"
    running = "running"
    success = "success"
    error = "error"
    interrupted = "interrupted"
    timeout = "timeout"
# Actions a caller may name when cancelling a run.
# NOTE(review): what each action does is decided by the executor — confirm there.
CancelAction = Literal["interrupt", "rollback"]
@dataclass(frozen=True)
class RunScope:
    """Scope of a run.

    A stateful run requires a ``thread_id``; a stateless run automatically
    creates a temporary thread.
    """

    kind: RunScopeKind
    thread_id: str
    temporary: bool = False
@dataclass(frozen=True)
class CheckpointRequest:
    """Checkpoint restore request.

    Phase 1 only accepts the request; restore itself is not implemented.
    """

    checkpoint_id: str | None = None
    checkpoint: dict[str, Any] | None = None
@dataclass(frozen=True)
class RunSpec:
    """
    Run specification object, built by the app input layer and consumed
    by the executor as its input.

    Phase 1 limitations:
    - multitask_strategy only supports reject/interrupt
    - enqueue/rollback/after_seconds/batch are not supported
    """

    intent: RunIntent
    scope: RunScope
    assistant_id: str | None
    input: dict[str, Any] | None
    command: dict[str, Any] | None
    runnable_config: dict[str, Any]
    context: dict[str, Any] | None
    metadata: dict[str, Any]
    stream_modes: list[str]
    stream_subgraphs: bool
    stream_resumable: bool
    on_disconnect: Literal["cancel", "continue"]
    on_completion: Literal["delete", "keep"]
    multitask_strategy: Literal["reject", "interrupt"]
    interrupt_before: list[str] | Literal["*"] | None
    interrupt_after: list[str] | Literal["*"] | None
    checkpoint_request: CheckpointRequest | None
    # Optional fields; all default to None.
    follow_up_to_run_id: str | None = None
    webhook: str | None = None
    feedback_keys: list[str] | None = None
# Result type of a wait-style call: a dict payload, or None.
# NOTE(review): when None is returned is decided by the caller side — confirm there.
type WaitResult = dict[str, Any] | None
@dataclass
class RunRecord:
    """
    Runtime run record, managed by RuntimeRunRegistry.

    Decoupled from the ORM model and kept in memory only.

    ``created_at`` and ``updated_at`` are ISO-8601 strings. Either one may be
    left empty at construction time: a missing ``created_at`` is set to the
    current UTC time and a missing ``updated_at`` is set to ``created_at``.
    Explicitly supplied values are never overwritten.
    """

    run_id: str
    thread_id: str
    assistant_id: str | None
    status: RunStatus
    temporary: bool
    multitask_strategy: str
    metadata: dict[str, Any] = field(default_factory=dict)
    follow_up_to_run_id: str | None = None
    created_at: str = ""
    updated_at: str = ""
    started_at: str | None = None
    ended_at: str | None = None
    error: str | None = None

    def __post_init__(self) -> None:
        """Fill in whichever timestamps the caller left empty."""
        if not self.created_at:
            self.created_at = datetime.now(timezone.utc).isoformat()
        # Fill updated_at independently, so a record rebuilt with only
        # created_at does not end up with an empty updated_at.
        if not self.updated_at:
            self.updated_at = self.created_at
# Status groups for quick membership checks. Together the two sets cover
# every RunStatus value: "timeout" is a finished state as well, so it belongs
# with the terminal statuses rather than being left unclassified.
TERMINAL_STATUSES: frozenset[RunStatus] = frozenset({"success", "error", "interrupted", "timeout"})
INFLIGHT_STATUSES: frozenset[RunStatus] = frozenset({"pending", "starting", "running"})
@@ -1,493 +0,0 @@
"""Background agent execution.
Runs an agent graph inside an ``asyncio.Task``, publishing events to
a :class:`StreamBridge` as they are produced.
Uses ``graph.astream(stream_mode=[...])`` which gives correct full-state
snapshots for ``values`` mode, proper ``{node: writes}`` for ``updates``,
and ``(chunk, metadata)`` tuples for ``messages`` mode.
Note: ``events`` mode is not supported through the gateway — it requires
``graph.astream_events()`` which cannot simultaneously produce ``values``
snapshots. The JS open-source LangGraph API server works around this via
internal checkpoint callbacks that are not exposed in the Python public API.
"""
from __future__ import annotations
import asyncio
import copy
import inspect
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING:
from langchain_core.messages import HumanMessage
from deerflow.runtime.serialization import serialize
from deerflow.runtime.stream_bridge import StreamBridge
from .manager import RunManager, RunRecord
from .schemas import RunStatus
logger = logging.getLogger(__name__)

# Valid stream_mode values for LangGraph's graph.astream().
# run_agent forwards a requested mode only when it appears in this set.
_VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"}
@dataclass(frozen=True)
class RunContext:
    """Immutable bundle of the infrastructure objects one agent run needs.

    Passing a single context object keeps ``run_agent`` from accumulating
    one keyword argument per dependency. Only ``checkpointer`` is required;
    every other dependency defaults to ``None``.
    """

    checkpointer: Any
    store: Any | None = None
    event_store: Any | None = None
    run_events_config: Any | None = None
    thread_store: Any | None = None
# Strong references to the delayed bridge-cleanup tasks started by run_agent.
# The event loop only keeps weak references to tasks, so a task whose handle
# is dropped may be garbage-collected before it finishes.
_CLEANUP_TASKS: set[asyncio.Task[Any]] = set()


async def run_agent(
    bridge: StreamBridge,
    run_manager: RunManager,
    record: RunRecord,
    *,
    ctx: RunContext,
    agent_factory: Any,
    graph_input: dict,
    config: dict,
    stream_modes: list[str] | None = None,
    stream_subgraphs: bool = False,
    interrupt_before: list[str] | Literal["*"] | None = None,
    interrupt_after: list[str] | Literal["*"] | None = None,
) -> None:
    """Execute an agent in the background, publishing events to *bridge*.

    The function never lets an exception escape: cancellation and failures
    are turned into a run status (and, for failures, an ``error`` event), and
    the ``finally`` block always publishes the end-of-stream marker.

    Args:
        bridge: Stream bridge that receives the ``metadata``, per-mode and
            ``error`` events and the final end marker.
        run_manager: Used to update the run's status and completion data.
        record: Run being executed; its ``abort_event`` / ``abort_action``
            are polled to stop streaming and to choose the final status.
        ctx: Infrastructure dependencies (checkpointer, store, event store,
            run-events config, thread store).
        agent_factory: Callable invoked as ``agent_factory(config=...)`` to
            build the graph.
        graph_input: Input passed to ``agent.astream``.
        config: Runnable config. NOTE: it is mutated in place (runtime
            injection, ``context`` defaults, journal callback).
        stream_modes: Requested stream modes; defaults to ``["values"]``.
            ``"messages-tuple"`` is mapped to LangGraph's ``"messages"``;
            unsupported modes (including ``"events"``) are dropped.
        stream_subgraphs: Forwarded to ``agent.astream(subgraphs=...)``.
        interrupt_before: Nodes to interrupt before, if any.
        interrupt_after: Nodes to interrupt after, if any.
    """
    # Unpack infrastructure dependencies from RunContext.
    checkpointer = ctx.checkpointer
    store = ctx.store
    event_store = ctx.event_store
    run_events_config = ctx.run_events_config
    thread_store = ctx.thread_store

    run_id = record.run_id
    thread_id = record.thread_id
    # Ordered de-duplication: a set would iterate in arbitrary order and make
    # the mode list handed to LangGraph differ from run to run.
    requested_modes: list[str] = list(dict.fromkeys(stream_modes or ["values"]))
    pre_run_checkpoint_id: str | None = None
    pre_run_snapshot: dict[str, Any] | None = None
    snapshot_capture_failed = False

    journal = None

    # Track whether "events" was requested but skipped
    if "events" in requested_modes:
        logger.info(
            "Run %s: 'events' stream_mode not supported in gateway (requires astream_events + checkpoint callbacks). Skipping.",
            run_id,
        )

    try:
        # Initialize RunJournal inside the try block so any exception
        # (e.g. a DB error) flows through the except/finally path that
        # publishes an "end" event to the SSE bridge — otherwise a failure
        # here would leave the stream hanging with no terminator.
        if event_store is not None:
            from deerflow.runtime.journal import RunJournal

            journal = RunJournal(
                run_id=run_id,
                thread_id=thread_id,
                event_store=event_store,
                track_token_usage=getattr(run_events_config, "track_token_usage", True),
            )

        # 1. Mark running
        await run_manager.set_status(run_id, RunStatus.running)

        # Snapshot the latest pre-run checkpoint so rollback can restore it.
        if checkpointer is not None:
            try:
                config_for_check = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
                ckpt_tuple = await checkpointer.aget_tuple(config_for_check)
                if ckpt_tuple is not None:
                    ckpt_config = getattr(ckpt_tuple, "config", {}).get("configurable", {})
                    pre_run_checkpoint_id = ckpt_config.get("checkpoint_id")
                    pre_run_snapshot = {
                        "checkpoint_ns": ckpt_config.get("checkpoint_ns", ""),
                        "checkpoint": copy.deepcopy(getattr(ckpt_tuple, "checkpoint", {})),
                        "metadata": copy.deepcopy(getattr(ckpt_tuple, "metadata", {})),
                        "pending_writes": copy.deepcopy(getattr(ckpt_tuple, "pending_writes", []) or []),
                    }
            except Exception:
                snapshot_capture_failed = True
                logger.warning("Could not capture pre-run checkpoint snapshot for run %s", run_id, exc_info=True)

        # 2. Publish metadata — useStream needs both run_id AND thread_id
        await bridge.publish(
            run_id,
            "metadata",
            {
                "run_id": run_id,
                "thread_id": thread_id,
            },
        )

        # 3. Build the agent
        from langchain_core.runnables import RunnableConfig
        from langgraph.runtime import Runtime

        # Inject runtime context so middlewares can access thread_id
        # (langgraph-cli does this automatically; we must do it manually)
        runtime = Runtime(context={"thread_id": thread_id, "run_id": run_id}, store=store)
        # If the caller already set a ``context`` key (LangGraph >= 0.6.0
        # prefers it over ``configurable`` for thread-level data), make
        # sure ``thread_id`` is available there too.
        if "context" in config and isinstance(config["context"], dict):
            config["context"].setdefault("thread_id", thread_id)
            config["context"].setdefault("run_id", run_id)
        config.setdefault("configurable", {})["__pregel_runtime"] = runtime

        # Inject RunJournal as a LangChain callback handler.
        # on_llm_end captures token usage; on_chain_start/end captures lifecycle.
        if journal is not None:
            config.setdefault("callbacks", []).append(journal)

        runnable_config = RunnableConfig(**config)
        agent = agent_factory(config=runnable_config)

        # 4. Attach checkpointer and store
        if checkpointer is not None:
            agent.checkpointer = checkpointer
        if store is not None:
            agent.store = store

        # 5. Set interrupt nodes
        if interrupt_before:
            agent.interrupt_before_nodes = interrupt_before
        if interrupt_after:
            agent.interrupt_after_nodes = interrupt_after

        # 6. Build LangGraph stream_mode list
        #    "messages-tuple" maps to LangGraph's "messages" mode; anything
        #    not accepted by astream ("events" included — see log above) is
        #    dropped.
        lg_modes: list[str] = []
        for m in requested_modes:
            if m == "messages-tuple":
                lg_modes.append("messages")
            elif m in _VALID_LG_MODES:
                lg_modes.append(m)
        # Deduplicate while preserving order ("messages-tuple" and "messages"
        # both map to "messages"); fall back to "values" when nothing is left.
        lg_modes = list(dict.fromkeys(lg_modes)) or ["values"]

        logger.info("Run %s: streaming with modes %s (requested: %s)", run_id, lg_modes, requested_modes)

        # 7. Stream using graph.astream
        if len(lg_modes) == 1 and not stream_subgraphs:
            # Single mode, no subgraphs: astream yields raw chunks
            single_mode = lg_modes[0]
            async for chunk in agent.astream(graph_input, config=runnable_config, stream_mode=single_mode):
                if record.abort_event.is_set():
                    logger.info("Run %s abort requested — stopping", run_id)
                    break
                sse_event = _lg_mode_to_sse_event(single_mode)
                await bridge.publish(run_id, sse_event, serialize(chunk, mode=single_mode))
        else:
            # Multiple modes or subgraphs: astream yields tuples
            async for item in agent.astream(
                graph_input,
                config=runnable_config,
                stream_mode=lg_modes,
                subgraphs=stream_subgraphs,
            ):
                if record.abort_event.is_set():
                    logger.info("Run %s abort requested — stopping", run_id)
                    break

                mode, chunk = _unpack_stream_item(item, lg_modes, stream_subgraphs)
                if mode is None:
                    continue

                sse_event = _lg_mode_to_sse_event(mode)
                await bridge.publish(run_id, sse_event, serialize(chunk, mode=mode))

        # 8. Final status
        if record.abort_event.is_set():
            action = record.abort_action
            if action == "rollback":
                await run_manager.set_status(run_id, RunStatus.error, error="Rolled back by user")
                try:
                    await _rollback_to_pre_run_checkpoint(
                        checkpointer=checkpointer,
                        thread_id=thread_id,
                        run_id=run_id,
                        pre_run_checkpoint_id=pre_run_checkpoint_id,
                        pre_run_snapshot=pre_run_snapshot,
                        snapshot_capture_failed=snapshot_capture_failed,
                    )
                    logger.info("Run %s rolled back to pre-run checkpoint %s", run_id, pre_run_checkpoint_id)
                except Exception:
                    logger.warning("Failed to rollback checkpoint for run %s", run_id, exc_info=True)
            else:
                await run_manager.set_status(run_id, RunStatus.interrupted)
        else:
            await run_manager.set_status(run_id, RunStatus.success)

    except asyncio.CancelledError:
        # Cancellation is absorbed here and turned into a final run status.
        action = record.abort_action
        if action == "rollback":
            await run_manager.set_status(run_id, RunStatus.error, error="Rolled back by user")
            try:
                await _rollback_to_pre_run_checkpoint(
                    checkpointer=checkpointer,
                    thread_id=thread_id,
                    run_id=run_id,
                    pre_run_checkpoint_id=pre_run_checkpoint_id,
                    pre_run_snapshot=pre_run_snapshot,
                    snapshot_capture_failed=snapshot_capture_failed,
                )
                logger.info("Run %s was cancelled and rolled back", run_id)
            except Exception:
                logger.warning("Run %s cancellation rollback failed", run_id, exc_info=True)
        else:
            await run_manager.set_status(run_id, RunStatus.interrupted)
            logger.info("Run %s was cancelled", run_id)

    except Exception as exc:
        error_msg = f"{exc}"
        logger.exception("Run %s failed: %s", run_id, error_msg)
        await run_manager.set_status(run_id, RunStatus.error, error=error_msg)
        await bridge.publish(
            run_id,
            "error",
            {
                "message": error_msg,
                "name": type(exc).__name__,
            },
        )

    finally:
        # Flush any buffered journal events and persist completion data
        if journal is not None:
            try:
                await journal.flush()
            except Exception:
                logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)

            try:
                # Persist token usage + convenience fields to RunStore
                completion = journal.get_completion_data()
                await run_manager.update_run_completion(run_id, status=record.status.value, **completion)
            except Exception:
                logger.warning("Failed to persist run completion for %s (non-fatal)", run_id, exc_info=True)

        # Sync title from checkpoint to threads_meta.display_name
        if checkpointer is not None and thread_store is not None:
            try:
                ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
                ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
                if ckpt_tuple is not None:
                    ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
                    title = ckpt.get("channel_values", {}).get("title")
                    if title:
                        await thread_store.update_display_name(thread_id, title)
            except Exception:
                # Best effort, but keep the traceback so failures can be diagnosed.
                logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id, exc_info=True)

        # Update threads_meta status based on run outcome
        if thread_store is not None:
            try:
                final_status = "idle" if record.status == RunStatus.success else record.status.value
                await thread_store.update_status(thread_id, final_status)
            except Exception:
                logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id, exc_info=True)

        await bridge.publish_end(run_id)
        # Delayed cleanup runs detached from this coroutine; hold a strong
        # reference until it completes so it cannot be collected mid-flight.
        cleanup_task = asyncio.create_task(bridge.cleanup(run_id, delay=60))
        _CLEANUP_TASKS.add(cleanup_task)
        cleanup_task.add_done_callback(_CLEANUP_TASKS.discard)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _call_checkpointer_method(checkpointer: Any, async_name: str, sync_name: str, *args: Any, **kwargs: Any) -> Any:
    """Invoke a checkpointer operation, preferring its async variant.

    Looks up ``async_name`` first and falls back to ``sync_name``. The chosen
    method is called with ``*args`` / ``**kwargs``; an awaitable result is
    awaited before being returned.

    Raises:
        AttributeError: If the checkpointer provides neither method.
    """
    handler = getattr(checkpointer, async_name, None)
    if not handler:
        handler = getattr(checkpointer, sync_name, None)
    if handler is None:
        raise AttributeError(f"Missing checkpointer method: {async_name}/{sync_name}")

    outcome = handler(*args, **kwargs)
    return await outcome if inspect.isawaitable(outcome) else outcome
async def _rollback_to_pre_run_checkpoint(
    *,
    checkpointer: Any,
    thread_id: str,
    run_id: str,
    pre_run_checkpoint_id: str | None,
    pre_run_snapshot: dict[str, Any] | None,
    snapshot_capture_failed: bool,
) -> None:
    """Put the thread back into the state snapshotted before the run began.

    The rollback is skipped (with a log line) when there is no checkpointer,
    when the snapshot could not be captured, or when the snapshot is unusable.
    With no snapshot at all the thread is deleted. Otherwise the snapshot's
    checkpoint is written back, followed by its pending writes grouped per
    task.

    Raises:
        RuntimeError: If the checkpointer returns an unusable config from the
            restore, or a pending write in the snapshot is malformed.
    """
    if checkpointer is None:
        logger.info("Run %s rollback requested but no checkpointer is configured", run_id)
        return

    if snapshot_capture_failed:
        logger.warning("Run %s rollback skipped: pre-run checkpoint snapshot capture failed", run_id)
        return

    # No snapshot was recorded for this run: reset the thread instead.
    if pre_run_snapshot is None:
        await _call_checkpointer_method(checkpointer, "adelete_thread", "delete_thread", thread_id)
        logger.info("Run %s rollback reset thread %s to empty state", run_id, thread_id)
        return

    snapshot_checkpoint = pre_run_snapshot.get("checkpoint")
    if not isinstance(snapshot_checkpoint, dict):
        logger.warning("Run %s rollback skipped: invalid pre-run checkpoint snapshot", run_id)
        return

    # The checkpoint must carry an id; borrow the separately captured one
    # when the snapshot itself lacks it.
    target_checkpoint = snapshot_checkpoint
    if target_checkpoint.get("id") is None:
        if pre_run_checkpoint_id is None:
            logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id)
            return
        target_checkpoint = {**target_checkpoint, "id": pre_run_checkpoint_id}

    snapshot_metadata = pre_run_snapshot.get("metadata", {})
    target_metadata = snapshot_metadata if isinstance(snapshot_metadata, dict) else {}

    snapshot_ns = pre_run_snapshot.get("checkpoint_ns")
    target_ns = snapshot_ns if isinstance(snapshot_ns, str) else ""

    versions = target_checkpoint.get("channel_versions")
    versions_copy = dict(versions) if isinstance(versions, dict) else {}

    restored_config = await _call_checkpointer_method(
        checkpointer,
        "aput",
        "put",
        {"configurable": {"thread_id": thread_id, "checkpoint_ns": target_ns}},
        target_checkpoint,
        target_metadata,
        versions_copy,
    )

    # Validate what the checkpointer handed back before replaying writes.
    if not isinstance(restored_config, dict):
        raise RuntimeError(f"Run {run_id} rollback restore returned invalid config: expected dict")
    restored_configurable = restored_config.get("configurable", {})
    if not isinstance(restored_configurable, dict):
        raise RuntimeError(f"Run {run_id} rollback restore returned invalid config payload")
    if not restored_configurable.get("checkpoint_id"):
        raise RuntimeError(f"Run {run_id} rollback restore did not return checkpoint_id")

    pending_writes = pre_run_snapshot.get("pending_writes", [])
    if not pending_writes:
        return

    # Group (channel, value) pairs per task; every entry is validated before
    # any write is replayed.
    grouped: dict[str, list[tuple[str, Any]]] = {}
    for item in pending_writes:
        is_triple = isinstance(item, (tuple, list)) and len(item) == 3
        if not is_triple:
            raise RuntimeError(f"Run {run_id} rollback failed: pending_write is not a 3-tuple: {item!r}")
        task_id, channel, value = item
        if not isinstance(channel, str):
            raise RuntimeError(f"Run {run_id} rollback failed: pending_write has non-string channel: task_id={task_id!r}, channel={channel!r}")
        grouped.setdefault(str(task_id), []).append((channel, value))

    for task_key, task_writes in grouped.items():
        await _call_checkpointer_method(
            checkpointer,
            "aput_writes",
            "put_writes",
            restored_config,
            task_writes,
            task_id=task_key,
        )
def _lg_mode_to_sse_event(mode: str) -> str:
    """Map LangGraph internal stream_mode name to SSE event name.

    LangGraph's ``astream(stream_mode="messages")`` produces message
    tuples. The SSE protocol calls this ``messages-tuple`` when the
    client explicitly requests it, but the default SSE event name used
    by LangGraph Platform is simply ``"messages"``.

    The mapping is currently the identity: *mode* is returned unchanged.
    """
    # All LG modes map 1:1 to SSE event names — "messages" stays "messages"
    return mode
def _extract_human_message(graph_input: dict) -> HumanMessage | None:
"""Extract or construct a HumanMessage from graph_input for event recording.
Returns a LangChain HumanMessage so callers can use .model_dump() to get
the checkpoint-aligned serialization format.
"""
from langchain_core.messages import HumanMessage
messages = graph_input.get("messages")
if not messages:
return None
last = messages[-1] if isinstance(messages, list) else messages
if isinstance(last, HumanMessage):
return last
if isinstance(last, str):
return HumanMessage(content=last) if last else None
if hasattr(last, "content"):
content = last.content
return HumanMessage(content=content)
if isinstance(last, dict):
content = last.get("content", "")
return HumanMessage(content=content) if content else None
return None
def _unpack_stream_item(
    item: Any,
    lg_modes: list[str],
    stream_subgraphs: bool,
) -> tuple[str | None, Any]:
    """Split one ``astream`` item into ``(mode, chunk)``.

    With subgraph streaming, ``(namespace, mode, chunk)`` triples and
    ``(mode, chunk)`` pairs are understood; anything else gives
    ``(None, None)``. Without it, pairs are unpacked and any other item is
    attributed to the first entry of *lg_modes* (``None`` if the list is
    empty).
    """
    is_tuple = isinstance(item, tuple)

    if stream_subgraphs and is_tuple and len(item) == 3:
        # The leading element is the subgraph namespace, which is discarded.
        return str(item[1]), item[2]

    if is_tuple and len(item) == 2:
        return str(item[0]), item[1]

    if stream_subgraphs:
        # Unrecognised shape while streaming subgraphs.
        return None, None

    # Fallback: single-element output from first mode
    fallback_mode = lg_modes[0] if lg_modes else None
    return fallback_mode, item