Merge refactor/config-deerflow-context into release/2.0-rc

Cherry-pick PR #2271's config refactor onto release/2.0-rc.
Used 'git merge -X theirs' to auto-resolve content conflicts in favor of
the PR's design (frozen AppConfig + explicit-parameter passing).

Limitations:
- Release-only changes that overlapped with PR's refactor in 119 files
  are NOT preserved — those files reflect PR's version. Follow-up commits
  on this branch will need to re-apply release-only modifications where
  meaningful.
- See PR #2271 for design rationale.
This commit is contained in:
greatmengqi
2026-04-27 18:16:42 +08:00
227 changed files with 6965 additions and 5578 deletions
@@ -24,7 +24,7 @@ from collections.abc import AsyncIterator
from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.runtime.checkpointer.provider import (
POSTGRES_CONN_REQUIRED,
POSTGRES_INSTALL,
@@ -123,11 +123,11 @@ async def _async_checkpointer_from_database(db_config) -> AsyncIterator[Checkpoi
@contextlib.asynccontextmanager
async def make_checkpointer() -> AsyncIterator[Checkpointer]:
async def make_checkpointer(app_config: AppConfig) -> AsyncIterator[Checkpointer]:
"""Async context manager that yields a checkpointer for the caller's lifetime.
Resources are opened on enter and closed on exit -- no global state::
async with make_checkpointer() as checkpointer:
async with make_checkpointer(app_config) as checkpointer:
app.state.checkpointer = checkpointer
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
@@ -138,16 +138,14 @@ async def make_checkpointer() -> AsyncIterator[Checkpointer]:
3. Default InMemorySaver
"""
config = get_app_config()
# Legacy: standalone checkpointer config takes precedence
if config.checkpointer is not None:
async with _async_checkpointer(config.checkpointer) as saver:
if app_config.checkpointer is not None:
async with _async_checkpointer(app_config.checkpointer) as saver:
yield saver
return
# Unified database config
db_config = getattr(config, "database", None)
db_config = getattr(app_config, "database", None)
if db_config is not None and db_config.backend != "memory":
async with _async_checkpointer_from_database(db_config) as saver:
yield saver
@@ -25,7 +25,7 @@ from collections.abc import Iterator
from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
@@ -100,10 +100,13 @@ _checkpointer: Checkpointer | None = None
_checkpointer_ctx = None # open context manager keeping the connection alive
def get_checkpointer() -> Checkpointer:
def get_checkpointer(app_config: AppConfig) -> Checkpointer:
"""Return the global sync checkpointer singleton, creating it on first call.
Returns an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
Returns an ``InMemorySaver`` only when ``checkpointer`` is explicitly
absent from config.yaml. Any other failure (missing config, invalid
backend, connection error) propagates — silent degradation to in-memory
would drop persistent-run state on process restart.
Raises:
ImportError: If the required package for the configured backend is not installed.
@@ -114,25 +117,7 @@ def get_checkpointer() -> Checkpointer:
if _checkpointer is not None:
return _checkpointer
# Ensure app config is loaded before checking checkpointer config
# This prevents returning InMemorySaver when config.yaml actually has a checkpointer section
# but hasn't been loaded yet
from deerflow.config.app_config import _app_config
from deerflow.config.checkpointer_config import get_checkpointer_config
config = get_checkpointer_config()
if config is None and _app_config is None:
# Only load app config lazily when neither the app config nor an explicit
# checkpointer config has been initialized yet. This keeps tests that
# intentionally set the global checkpointer config isolated from any
# ambient config.yaml on disk.
try:
get_app_config()
except FileNotFoundError:
# In test environments without config.yaml, this is expected.
pass
config = get_checkpointer_config()
config = app_config.checkpointer
if config is None:
from langgraph.checkpoint.memory import InMemorySaver
@@ -168,25 +153,23 @@ def reset_checkpointer() -> None:
@contextlib.contextmanager
def checkpointer_context() -> Iterator[Checkpointer]:
def checkpointer_context(app_config: AppConfig) -> Iterator[Checkpointer]:
"""Sync context manager that yields a checkpointer and cleans up on exit.
Unlike :func:`get_checkpointer`, this does **not** cache the instance —
each ``with`` block creates and destroys its own connection. Use it in
CLI scripts or tests where you want deterministic cleanup::
with checkpointer_context() as cp:
with checkpointer_context(app_config) as cp:
graph.invoke(input, config={"configurable": {"thread_id": "1"}})
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
if config.checkpointer is None:
if app_config.checkpointer is None:
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
return
with _sync_checkpointer_cm(config.checkpointer) as saver:
with _sync_checkpointer_cm(app_config.checkpointer) as saver:
yield saver
@@ -6,10 +6,7 @@ handles token usage accumulation.
Key design decisions:
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
- on_chat_model_start captures structured prompts as llm_request (OpenAI format) and
extracts the first human message for run.input, because it is more reliable than
on_chain_start (fires on every node) — messages here are fully structured.
- on_chain_start with parent_run_id=None emits a run.start trace marking root invocation.
- on_chat_model_start captures structured prompts as llm_request (OpenAI format)
- on_llm_end emits llm_response in OpenAI Chat Completions format
- Token usage accumulated in memory, written to RunRow on run completion
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
@@ -21,12 +18,10 @@ import asyncio
import logging
import time
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any
from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage
from langgraph.types import Command
if TYPE_CHECKING:
from deerflow.runtime.events.store.base import RunEventStore
@@ -77,39 +72,34 @@ class RunJournal(BaseCallbackHandler):
# LLM request/response tracking
self._llm_call_index = 0
self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages
self._cached_models: dict[str, str] = {} # langchain run_id -> model name
# Tool call ID cache
self._tool_call_ids: dict[str, str] = {} # langchain run_id -> tool_call_id
# -- Lifecycle callbacks --
def on_chain_start(
self,
serialized: dict[str, Any],
inputs: dict[str, Any],
*,
run_id: UUID,
parent_run_id: UUID | None = None,
tags: list[str] | None = None,
metadata: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
caller = self._identify_caller(tags)
if parent_run_id is None:
# Root graph invocation — emit a single trace event for the run start.
chain_name = (serialized or {}).get("name", "unknown")
self._put(
event_type="run.start",
category="trace",
content={"chain": chain_name},
metadata={"caller": caller, **(metadata or {})},
)
def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
if kwargs.get("parent_run_id") is not None:
return
self._put(
event_type="run_start",
category="lifecycle",
metadata={"input_preview": str(inputs)[:500]},
)
def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"})
if kwargs.get("parent_run_id") is not None:
return
self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
self._flush_sync()
def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
if kwargs.get("parent_run_id") is not None:
return
self._put(
event_type="run.error",
category="error",
event_type="run_error",
category="lifecycle",
content=str(error),
metadata={"error_type": type(error).__name__},
)
@@ -117,132 +107,266 @@ class RunJournal(BaseCallbackHandler):
# -- LLM callbacks --
def on_chat_model_start(
self,
serialized: dict,
messages: list[list[BaseMessage]],
*,
run_id: UUID,
tags: list[str] | None = None,
**kwargs: Any,
) -> None:
"""Capture structured prompt messages for llm_request event.
def on_chat_model_start(self, serialized: dict, messages: list[list], *, run_id: UUID, **kwargs: Any) -> None:
"""Capture structured prompt messages for llm_request event."""
from deerflow.runtime.converters import langchain_messages_to_openai
This is also the canonical place to extract the first human message:
messages are fully structured here, it fires only on real LLM calls,
and the content is never compressed by checkpoint trimming.
"""
rid = str(run_id)
self._llm_start_times[rid] = time.monotonic()
self._llm_call_index += 1
# Mark this run_id as seen so on_llm_end knows not to increment again.
self._cached_prompts[rid] = []
logger.info(f"on_chat_model_start {run_id}: tags={tags} serialized={serialized} messages={messages}")
model_name = serialized.get("name", "")
self._cached_models[rid] = model_name
# Capture the first human message sent to any LLM in this run.
if not self._first_human_msg and not messages:
for batch in messages.reversed():
for m in batch.reversed():
if isinstance(m, HumanMessage) and m.name != "summary":
caller = self._identify_caller(tags)
self.set_first_human_message(m.text)
self._put(
event_type="llm.human.input",
category="message",
content=m.model_dump(),
metadata={"caller": caller},
)
break
if self._first_human_msg:
break
# Convert the first message list (LangChain passes list-of-lists)
prompt_msgs = messages[0] if messages else []
openai_msgs = langchain_messages_to_openai(prompt_msgs)
self._cached_prompts[rid] = openai_msgs
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, parent_run_id: UUID | None = None, tags: list[str] | None = None, metadata: dict[str, Any] | None = None, **kwargs: Any) -> None:
caller = self._identify_caller(kwargs)
self._put(
event_type="llm_request",
category="trace",
content={"model": model_name, "messages": openai_msgs},
metadata={"caller": caller, "llm_call_index": self._llm_call_index},
)
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
# Fallback: on_chat_model_start is preferred. This just tracks latency.
self._llm_start_times[str(run_id)] = time.monotonic()
def on_llm_end(self, response, *, run_id, parent_run_id, tags, **kwargs) -> None:
messages: list[AnyMessage] = []
logger.info(f"on_llm_end {run_id}: response: {tags} {kwargs}")
for generation in response.generations:
for gen in generation:
if hasattr(gen, "message"):
messages.append(gen.message)
else:
logger.warning(f"on_llm_end {run_id}: generation has no message attribute: {gen}")
def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
from deerflow.runtime.converters import langchain_to_openai_completion
for message in messages:
caller = self._identify_caller(tags)
try:
message = response.generations[0][0].message
except (IndexError, AttributeError):
logger.debug("on_llm_end: could not extract message from response")
return
# Latency
rid = str(run_id)
start = self._llm_start_times.pop(rid, None)
latency_ms = int((time.monotonic() - start) * 1000) if start else None
caller = self._identify_caller(kwargs)
# Token usage from message
usage = getattr(message, "usage_metadata", None)
usage_dict = dict(usage) if usage else {}
# Latency
rid = str(run_id)
start = self._llm_start_times.pop(rid, None)
latency_ms = int((time.monotonic() - start) * 1000) if start else None
# Resolve call index
# Token usage from message
usage = getattr(message, "usage_metadata", None)
usage_dict = dict(usage) if usage else {}
# Resolve call index
call_index = self._llm_call_index
if rid not in self._cached_prompts:
# Fallback: on_chat_model_start was not called
self._llm_call_index += 1
call_index = self._llm_call_index
if rid not in self._cached_prompts:
# Fallback: on_chat_model_start was not called
self._llm_call_index += 1
call_index = self._llm_call_index
# Trace event: llm_response (OpenAI completion format)
self._put(
event_type="llm.ai.response",
category="message",
content=message.model_dump(),
metadata={
"caller": caller,
"usage": usage_dict,
"latency_ms": latency_ms,
"llm_call_index": call_index,
},
)
# Clean up caches
self._cached_prompts.pop(rid, None)
self._cached_models.pop(rid, None)
# Token accumulation
if self._track_tokens:
input_tk = usage_dict.get("input_tokens", 0) or 0
output_tk = usage_dict.get("output_tokens", 0) or 0
total_tk = usage_dict.get("total_tokens", 0) or 0
if total_tk == 0:
total_tk = input_tk + output_tk
if total_tk > 0:
self._total_input_tokens += input_tk
self._total_output_tokens += output_tk
self._total_tokens += total_tk
self._llm_call_count += 1
# Trace event: llm_response (OpenAI completion format)
content = getattr(message, "content", "")
self._put(
event_type="llm_response",
category="trace",
content=langchain_to_openai_completion(message),
metadata={
"caller": caller,
"usage": usage_dict,
"latency_ms": latency_ms,
"llm_call_index": call_index,
},
)
# Message events: only lead_agent gets message-category events.
# Content uses message.model_dump() to align with checkpoint format.
tool_calls = getattr(message, "tool_calls", None) or []
if caller == "lead_agent":
resp_meta = getattr(message, "response_metadata", None) or {}
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
if tool_calls:
# ai_tool_call: agent decided to use tools
self._put(
event_type="ai_tool_call",
category="message",
content=message.model_dump(),
metadata={"model_name": model_name, "finish_reason": "tool_calls"},
)
elif isinstance(content, str) and content:
# ai_message: final text reply
self._put(
event_type="ai_message",
category="message",
content=message.model_dump(),
metadata={"model_name": model_name, "finish_reason": "stop"},
)
self._last_ai_msg = content
self._msg_count += 1
# Token accumulation
if self._track_tokens:
input_tk = usage_dict.get("input_tokens", 0) or 0
output_tk = usage_dict.get("output_tokens", 0) or 0
total_tk = usage_dict.get("total_tokens", 0) or 0
if total_tk == 0:
total_tk = input_tk + output_tk
if total_tk > 0:
self._total_input_tokens += input_tk
self._total_output_tokens += output_tk
self._total_tokens += total_tk
self._llm_call_count += 1
if caller.startswith("subagent:"):
self._subagent_tokens += total_tk
elif caller.startswith("middleware:"):
self._middleware_tokens += total_tk
else:
self._lead_agent_tokens += total_tk
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
self._llm_start_times.pop(str(run_id), None)
self._put(event_type="llm.error", category="trace", content=str(error))
self._put(event_type="llm_error", category="trace", content=str(error))
def on_tool_start(self, serialized, input_str, *, run_id, parent_run_id=None, tags=None, metadata=None, inputs=None, **kwargs):
"""Handle tool start event, cache tool call ID for later correlation"""
tool_call_id = str(run_id)
logger.info(f"Tool start for node {run_id}, tool_call_id={tool_call_id}, tags={tags}, metadata={metadata}")
# -- Tool callbacks --
def on_tool_end(self, output, *, run_id, parent_run_id=None, **kwargs):
"""Handle tool end event, append message and clear node data"""
try:
if isinstance(output, ToolMessage):
msg = cast(ToolMessage, output)
self._put(event_type="llm.tool.result", category="message", content=msg.model_dump())
elif isinstance(output, Command):
cmd = cast(Command, output)
messages = cmd.update.get("messages", [])
for message in messages:
if isinstance(message, BaseMessage):
self._put(event_type="llm.tool.result", category="message", content=message.model_dump())
else:
logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}")
else:
logger.warning(f"on_tool_end {run_id}: output is not ToolMessage: {type(output)}")
finally:
logger.info(f"Tool end for node {run_id}")
def on_tool_start(self, serialized: dict, input_str: str, *, run_id: UUID, **kwargs: Any) -> None:
tool_call_id = kwargs.get("tool_call_id")
if tool_call_id:
self._tool_call_ids[str(run_id)] = tool_call_id
self._put(
event_type="tool_start",
category="trace",
metadata={
"tool_name": serialized.get("name", ""),
"tool_call_id": tool_call_id,
"args": str(input_str)[:2000],
},
)
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
from langchain_core.messages import ToolMessage
from langgraph.types import Command
# Tools that update graph state return a ``Command`` (e.g.
# ``present_files``). LangGraph later unwraps the inner ToolMessage
# into checkpoint state, so to stay checkpoint-aligned we must
# extract it here rather than storing ``str(Command(...))``.
if isinstance(output, Command):
update = getattr(output, "update", None) or {}
inner_msgs = update.get("messages") if isinstance(update, dict) else None
if isinstance(inner_msgs, list):
inner_tool_msg = next((m for m in inner_msgs if isinstance(m, ToolMessage)), None)
if inner_tool_msg is not None:
output = inner_tool_msg
# Extract fields from ToolMessage object when LangChain provides one.
# LangChain's _format_output wraps tool results into a ToolMessage
# with tool_call_id, name, status, and artifact — more complete than
# what kwargs alone provides.
if isinstance(output, ToolMessage):
tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = output.name or kwargs.get("name", "")
status = getattr(output, "status", "success") or "success"
content_str = output.content if isinstance(output.content, str) else str(output.content)
# Use model_dump() for checkpoint-aligned message content.
# Override tool_call_id if it was resolved from cache.
msg_content = output.model_dump()
if msg_content.get("tool_call_id") != tool_call_id:
msg_content["tool_call_id"] = tool_call_id
else:
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = kwargs.get("name", "")
status = "success"
content_str = str(output)
# Construct checkpoint-aligned dict when output is a plain string.
msg_content = ToolMessage(
content=content_str,
tool_call_id=tool_call_id or "",
name=tool_name,
status=status,
).model_dump()
# Trace event (always)
self._put(
event_type="tool_end",
category="trace",
content=content_str,
metadata={
"tool_name": tool_name,
"tool_call_id": tool_call_id,
"status": status,
},
)
# Message event: tool_result (checkpoint-aligned model_dump format)
self._put(
event_type="tool_result",
category="message",
content=msg_content,
metadata={"tool_name": tool_name, "status": status},
)
def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
from langchain_core.messages import ToolMessage
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = kwargs.get("name", "")
# Trace event
self._put(
event_type="tool_error",
category="trace",
content=str(error),
metadata={
"tool_name": tool_name,
"tool_call_id": tool_call_id,
},
)
# Message event: tool_result with error status (checkpoint-aligned)
msg_content = ToolMessage(
content=str(error),
tool_call_id=tool_call_id or "",
name=tool_name,
status="error",
).model_dump()
self._put(
event_type="tool_result",
category="message",
content=msg_content,
metadata={"tool_name": tool_name, "status": "error"},
)
# -- Custom event callback --
def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
from deerflow.runtime.serialization import serialize_lc_object
if name == "summarization":
data_dict = data if isinstance(data, dict) else {}
self._put(
event_type="summarization",
category="trace",
content=data_dict.get("summary", ""),
metadata={
"replaced_message_ids": data_dict.get("replaced_message_ids", []),
"replaced_count": data_dict.get("replaced_count", 0),
},
)
self._put(
event_type="middleware:summarize",
category="middleware",
content={"role": "system", "content": data_dict.get("summary", "")},
metadata={"replaced_count": data_dict.get("replaced_count", 0)},
)
else:
event_data = serialize_lc_object(data) if not isinstance(data, dict) else data
self._put(
event_type=name,
category="trace",
metadata=event_data if isinstance(event_data, dict) else {"data": event_data},
)
# -- Internal methods --
@@ -307,9 +431,8 @@ class RunJournal(BaseCallbackHandler):
if exc:
logger.warning("Journal flush task failed: %s", exc)
def _identify_caller(self, tags: list[str] | None, **kwargs) -> str:
_tags = tags or kwargs.get("tags", [])
for tag in _tags:
def _identify_caller(self, kwargs: dict) -> str:
for tag in kwargs.get("tags") or []:
if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
return tag
# Default to lead_agent: the main agent graph does not inject
@@ -54,7 +54,7 @@ class RunManager:
self._lock = asyncio.Lock()
self._store = store
async def _persist_to_store(self, record: RunRecord) -> None:
async def _persist_to_store(self, record: RunRecord, *, follow_up_to_run_id: str | None = None) -> None:
"""Best-effort persist run record to backing store."""
if self._store is None:
return
@@ -68,6 +68,7 @@ class RunManager:
metadata=record.metadata or {},
kwargs=record.kwargs or {},
created_at=record.created_at,
follow_up_to_run_id=follow_up_to_run_id,
)
except Exception:
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
@@ -89,6 +90,7 @@ class RunManager:
metadata: dict | None = None,
kwargs: dict | None = None,
multitask_strategy: str = "reject",
follow_up_to_run_id: str | None = None,
) -> RunRecord:
"""Create a new pending run and register it."""
run_id = str(uuid.uuid4())
@@ -107,7 +109,7 @@ class RunManager:
)
async with self._lock:
self._runs[run_id] = record
await self._persist_to_store(record)
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record
@@ -174,6 +176,7 @@ class RunManager:
metadata: dict | None = None,
kwargs: dict | None = None,
multitask_strategy: str = "reject",
follow_up_to_run_id: str | None = None,
) -> RunRecord:
"""Atomically check for inflight runs and create a new one.
@@ -227,7 +230,7 @@ class RunManager:
)
self._runs[run_id] = record
await self._persist_to_store(record)
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record
@@ -29,6 +29,7 @@ class RunStore(abc.ABC):
kwargs: dict[str, Any] | None = None,
error: str | None = None,
created_at: str | None = None,
follow_up_to_run_id: str | None = None,
) -> None:
pass
@@ -28,6 +28,7 @@ class MemoryRunStore(RunStore):
kwargs=None,
error=None,
created_at=None,
follow_up_to_run_id=None,
):
now = datetime.now(UTC).isoformat()
self._runs[run_id] = {
@@ -40,6 +41,7 @@ class MemoryRunStore(RunStore):
"metadata": metadata or {},
"kwargs": kwargs or {},
"error": error,
"follow_up_to_run_id": follow_up_to_run_id,
"created_at": created_at or now,
"updated_at": now,
}
@@ -25,6 +25,8 @@ from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING:
from langchain_core.messages import HumanMessage
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.runtime.serialization import serialize
from deerflow.runtime.stream_bridge import StreamBridge
@@ -51,6 +53,8 @@ class RunContext:
event_store: Any | None = field(default=None)
run_events_config: Any | None = field(default=None)
thread_store: Any | None = field(default=None)
follow_up_to_run_id: str | None = field(default=None)
app_config: AppConfig | None = field(default=None)
async def run_agent(
@@ -75,6 +79,7 @@ async def run_agent(
event_store = ctx.event_store
run_events_config = ctx.run_events_config
thread_store = ctx.thread_store
follow_up_to_run_id = ctx.follow_up_to_run_id
run_id = record.run_id
thread_id = record.thread_id
@@ -111,6 +116,22 @@ async def run_agent(
track_token_usage=getattr(run_events_config, "track_token_usage", True),
)
human_msg = _extract_human_message(graph_input)
if human_msg is not None:
msg_metadata = {}
if follow_up_to_run_id:
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
await event_store.put(
thread_id=thread_id,
run_id=run_id,
event_type="human_message",
category="message",
content=human_msg.model_dump(),
metadata=msg_metadata or None,
)
content = human_msg.content
journal.set_first_human_message(content if isinstance(content, str) else str(content))
# 1. Mark running
await run_manager.set_status(run_id, RunStatus.running)
@@ -144,18 +165,21 @@ async def run_agent(
# 3. Build the agent
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
# Inject runtime context so middlewares can access thread_id
# (langgraph-cli does this automatically; we must do it manually)
runtime = Runtime(context={"thread_id": thread_id, "run_id": run_id}, store=store)
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
# prefers it over ``configurable`` for thread-level data), make
# sure ``thread_id`` is available there too.
if "context" in config and isinstance(config["context"], dict):
config["context"].setdefault("thread_id", thread_id)
config["context"].setdefault("run_id", run_id)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
# Construct typed context for the agent run.
# LangGraph's astream(context=...) injects this into Runtime.context
# so middleware/tools can access it via resolve_context().
if ctx.app_config is None:
raise RuntimeError("RunContext.app_config is required — Gateway must populate it via get_run_context")
deer_flow_context = DeerFlowContext(
app_config=ctx.app_config,
thread_id=thread_id,
)
# Inject RunJournal as a LangChain callback handler.
# on_llm_end captures token usage; on_chain_start/end captures lifecycle.
if journal is not None:
config.setdefault("callbacks", []).append(journal)
# Inject RunJournal as a LangChain callback handler.
# on_llm_end captures token usage; on_chain_start/end captures lifecycle.
@@ -207,7 +231,7 @@ async def run_agent(
if len(lg_modes) == 1 and not stream_subgraphs:
# Single mode, no subgraphs: astream yields raw chunks
single_mode = lg_modes[0]
async for chunk in agent.astream(graph_input, config=runnable_config, stream_mode=single_mode):
async for chunk in agent.astream(graph_input, config=runnable_config, context=deer_flow_context, stream_mode=single_mode):
if record.abort_event.is_set():
logger.info("Run %s abort requested — stopping", run_id)
break
@@ -218,6 +242,7 @@ async def run_agent(
async for item in agent.astream(
graph_input,
config=runnable_config,
context=deer_flow_context,
stream_mode=lg_modes,
subgraphs=stream_subgraphs,
):
@@ -23,7 +23,7 @@ from collections.abc import AsyncIterator
from langgraph.store.base import BaseStore
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.runtime.store.provider import POSTGRES_CONN_REQUIRED, POSTGRES_STORE_INSTALL, SQLITE_STORE_INSTALL, ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@@ -86,7 +86,7 @@ async def _async_store(config) -> AsyncIterator[BaseStore]:
@contextlib.asynccontextmanager
async def make_store() -> AsyncIterator[BaseStore]:
async def make_store(app_config: AppConfig) -> AsyncIterator[BaseStore]:
"""Async context manager that yields a Store whose backend matches the
configured checkpointer.
@@ -94,20 +94,18 @@ async def make_store() -> AsyncIterator[BaseStore]:
:func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer` so
that both singletons always use the same persistence technology::
async with make_store() as store:
async with make_store(app_config) as store:
app.state.store = store
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
``checkpointer`` section is configured (emits a WARNING in that case).
"""
config = get_app_config()
if config.checkpointer is None:
if app_config.checkpointer is None:
from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
yield InMemoryStore()
return
async with _async_store(config.checkpointer) as store:
async with _async_store(app_config.checkpointer) as store:
yield store
@@ -26,7 +26,7 @@ from collections.abc import Iterator
from langgraph.store.base import BaseStore
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@@ -100,7 +100,7 @@ _store: BaseStore | None = None
_store_ctx = None # open context manager keeping the connection alive
def get_store() -> BaseStore:
def get_store(app_config: AppConfig) -> BaseStore:
"""Return the global sync Store singleton, creating it on first call.
Returns an :class:`~langgraph.store.memory.InMemoryStore` when no
@@ -115,19 +115,10 @@ def get_store() -> BaseStore:
if _store is not None:
return _store
# Lazily load app config, mirroring the checkpointer singleton pattern so
# that tests that set the global checkpointer config explicitly remain isolated.
from deerflow.config.app_config import _app_config
from deerflow.config.checkpointer_config import get_checkpointer_config
config = get_checkpointer_config()
if config is None and _app_config is None:
try:
get_app_config()
except FileNotFoundError:
pass
config = get_checkpointer_config()
# See matching comment in checkpointer/provider.py: a missing config.yaml
# is a deployment error, not a cue to silently pick InMemoryStore. Only
# the explicit "no checkpointer section" path falls through to memory.
config = app_config.checkpointer
if config is None:
from langgraph.store.memory import InMemoryStore
@@ -163,26 +154,25 @@ def reset_store() -> None:
@contextlib.contextmanager
def store_context() -> Iterator[BaseStore]:
def store_context(app_config: AppConfig) -> Iterator[BaseStore]:
"""Sync context manager that yields a Store and cleans up on exit.
Unlike :func:`get_store`, this does **not** cache the instance — each
``with`` block creates and destroys its own connection. Use it in CLI
scripts or tests where you want deterministic cleanup::
with store_context() as store:
with store_context(app_config) as store:
store.put(("threads",), thread_id, {...})
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
if config.checkpointer is None:
if app_config.checkpointer is None:
from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
yield InMemoryStore()
return
with _sync_store_cm(config.checkpointer) as store:
with _sync_store_cm(app_config.checkpointer) as store:
yield store
@@ -17,7 +17,7 @@ import contextlib
import logging
from collections.abc import AsyncIterator
from deerflow.config.stream_bridge_config import get_stream_bridge_config
from deerflow.config.app_config import AppConfig
from .base import StreamBridge
@@ -25,14 +25,13 @@ logger = logging.getLogger(__name__)
@contextlib.asynccontextmanager
async def make_stream_bridge(config=None) -> AsyncIterator[StreamBridge]:
async def make_stream_bridge(app_config: AppConfig) -> AsyncIterator[StreamBridge]:
"""Async context manager that yields a :class:`StreamBridge`.
Falls back to :class:`MemoryStreamBridge` when no configuration is
provided and nothing is set globally.
Falls back to :class:`MemoryStreamBridge` when no ``stream_bridge``
section is configured.
"""
if config is None:
config = get_stream_bridge_config()
config = app_config.stream_bridge
if config is None or config.type == "memory":
from deerflow.runtime.stream_bridge.memory import MemoryStreamBridge