mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 00:45:57 +00:00
Merge refactor/config-deerflow-context into release/2.0-rc
Cherry-pick PR #2271's config refactor onto release/2.0-rc. Used 'git merge -X theirs' to auto-resolve content conflicts in favor of the PR's design (frozen AppConfig + explicit-parameter passing). Limitations: - Release-only changes that overlapped with the PR's refactor in 119 files are NOT preserved — those files reflect the PR's version. Follow-up commits on this branch will need to re-apply release-only modifications where meaningful. - See PR #2271 for the design rationale.
This commit is contained in:
@@ -24,7 +24,7 @@ from collections.abc import AsyncIterator
|
||||
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.runtime.checkpointer.provider import (
|
||||
POSTGRES_CONN_REQUIRED,
|
||||
POSTGRES_INSTALL,
|
||||
@@ -123,11 +123,11 @@ async def _async_checkpointer_from_database(db_config) -> AsyncIterator[Checkpoi
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def make_checkpointer() -> AsyncIterator[Checkpointer]:
|
||||
async def make_checkpointer(app_config: AppConfig) -> AsyncIterator[Checkpointer]:
|
||||
"""Async context manager that yields a checkpointer for the caller's lifetime.
|
||||
Resources are opened on enter and closed on exit -- no global state::
|
||||
|
||||
async with make_checkpointer() as checkpointer:
|
||||
async with make_checkpointer(app_config) as checkpointer:
|
||||
app.state.checkpointer = checkpointer
|
||||
|
||||
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
|
||||
@@ -138,16 +138,14 @@ async def make_checkpointer() -> AsyncIterator[Checkpointer]:
|
||||
3. Default InMemorySaver
|
||||
"""
|
||||
|
||||
config = get_app_config()
|
||||
|
||||
# Legacy: standalone checkpointer config takes precedence
|
||||
if config.checkpointer is not None:
|
||||
async with _async_checkpointer(config.checkpointer) as saver:
|
||||
if app_config.checkpointer is not None:
|
||||
async with _async_checkpointer(app_config.checkpointer) as saver:
|
||||
yield saver
|
||||
return
|
||||
|
||||
# Unified database config
|
||||
db_config = getattr(config, "database", None)
|
||||
db_config = getattr(app_config, "database", None)
|
||||
if db_config is not None and db_config.backend != "memory":
|
||||
async with _async_checkpointer_from_database(db_config) as saver:
|
||||
yield saver
|
||||
|
||||
@@ -25,7 +25,7 @@ from collections.abc import Iterator
|
||||
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.checkpointer_config import CheckpointerConfig
|
||||
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||
|
||||
@@ -100,10 +100,13 @@ _checkpointer: Checkpointer | None = None
|
||||
_checkpointer_ctx = None # open context manager keeping the connection alive
|
||||
|
||||
|
||||
def get_checkpointer() -> Checkpointer:
|
||||
def get_checkpointer(app_config: AppConfig) -> Checkpointer:
|
||||
"""Return the global sync checkpointer singleton, creating it on first call.
|
||||
|
||||
Returns an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
|
||||
Returns an ``InMemorySaver`` only when ``checkpointer`` is explicitly
|
||||
absent from config.yaml. Any other failure (missing config, invalid
|
||||
backend, connection error) propagates — silent degradation to in-memory
|
||||
would drop persistent-run state on process restart.
|
||||
|
||||
Raises:
|
||||
ImportError: If the required package for the configured backend is not installed.
|
||||
@@ -114,25 +117,7 @@ def get_checkpointer() -> Checkpointer:
|
||||
if _checkpointer is not None:
|
||||
return _checkpointer
|
||||
|
||||
# Ensure app config is loaded before checking checkpointer config
|
||||
# This prevents returning InMemorySaver when config.yaml actually has a checkpointer section
|
||||
# but hasn't been loaded yet
|
||||
from deerflow.config.app_config import _app_config
|
||||
from deerflow.config.checkpointer_config import get_checkpointer_config
|
||||
|
||||
config = get_checkpointer_config()
|
||||
|
||||
if config is None and _app_config is None:
|
||||
# Only load app config lazily when neither the app config nor an explicit
|
||||
# checkpointer config has been initialized yet. This keeps tests that
|
||||
# intentionally set the global checkpointer config isolated from any
|
||||
# ambient config.yaml on disk.
|
||||
try:
|
||||
get_app_config()
|
||||
except FileNotFoundError:
|
||||
# In test environments without config.yaml, this is expected.
|
||||
pass
|
||||
config = get_checkpointer_config()
|
||||
config = app_config.checkpointer
|
||||
if config is None:
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
@@ -168,25 +153,23 @@ def reset_checkpointer() -> None:
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def checkpointer_context() -> Iterator[Checkpointer]:
|
||||
def checkpointer_context(app_config: AppConfig) -> Iterator[Checkpointer]:
|
||||
"""Sync context manager that yields a checkpointer and cleans up on exit.
|
||||
|
||||
Unlike :func:`get_checkpointer`, this does **not** cache the instance —
|
||||
each ``with`` block creates and destroys its own connection. Use it in
|
||||
CLI scripts or tests where you want deterministic cleanup::
|
||||
|
||||
with checkpointer_context() as cp:
|
||||
with checkpointer_context(app_config) as cp:
|
||||
graph.invoke(input, config={"configurable": {"thread_id": "1"}})
|
||||
|
||||
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
|
||||
"""
|
||||
|
||||
config = get_app_config()
|
||||
if config.checkpointer is None:
|
||||
if app_config.checkpointer is None:
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
yield InMemorySaver()
|
||||
return
|
||||
|
||||
with _sync_checkpointer_cm(config.checkpointer) as saver:
|
||||
with _sync_checkpointer_cm(app_config.checkpointer) as saver:
|
||||
yield saver
|
||||
|
||||
@@ -6,10 +6,7 @@ handles token usage accumulation.
|
||||
|
||||
Key design decisions:
|
||||
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
|
||||
- on_chat_model_start captures structured prompts as llm_request (OpenAI format) and
|
||||
extracts the first human message for run.input, because it is more reliable than
|
||||
on_chain_start (fires on every node) — messages here are fully structured.
|
||||
- on_chain_start with parent_run_id=None emits a run.start trace marking root invocation.
|
||||
- on_chat_model_start captures structured prompts as llm_request (OpenAI format)
|
||||
- on_llm_end emits llm_response in OpenAI Chat Completions format
|
||||
- Token usage accumulated in memory, written to RunRow on run completion
|
||||
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
|
||||
@@ -21,12 +18,10 @@ import asyncio
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
@@ -77,39 +72,34 @@ class RunJournal(BaseCallbackHandler):
|
||||
# LLM request/response tracking
|
||||
self._llm_call_index = 0
|
||||
self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages
|
||||
self._cached_models: dict[str, str] = {} # langchain run_id -> model name
|
||||
|
||||
# Tool call ID cache
|
||||
self._tool_call_ids: dict[str, str] = {} # langchain run_id -> tool_call_id
|
||||
|
||||
# -- Lifecycle callbacks --
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
inputs: dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
tags: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
caller = self._identify_caller(tags)
|
||||
if parent_run_id is None:
|
||||
# Root graph invocation — emit a single trace event for the run start.
|
||||
chain_name = (serialized or {}).get("name", "unknown")
|
||||
self._put(
|
||||
event_type="run.start",
|
||||
category="trace",
|
||||
content={"chain": chain_name},
|
||||
metadata={"caller": caller, **(metadata or {})},
|
||||
)
|
||||
def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
if kwargs.get("parent_run_id") is not None:
|
||||
return
|
||||
self._put(
|
||||
event_type="run_start",
|
||||
category="lifecycle",
|
||||
metadata={"input_preview": str(inputs)[:500]},
|
||||
)
|
||||
|
||||
def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"})
|
||||
if kwargs.get("parent_run_id") is not None:
|
||||
return
|
||||
self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
|
||||
self._flush_sync()
|
||||
|
||||
def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
if kwargs.get("parent_run_id") is not None:
|
||||
return
|
||||
self._put(
|
||||
event_type="run.error",
|
||||
category="error",
|
||||
event_type="run_error",
|
||||
category="lifecycle",
|
||||
content=str(error),
|
||||
metadata={"error_type": type(error).__name__},
|
||||
)
|
||||
@@ -117,132 +107,266 @@ class RunJournal(BaseCallbackHandler):
|
||||
|
||||
# -- LLM callbacks --
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict,
|
||||
messages: list[list[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Capture structured prompt messages for llm_request event.
|
||||
def on_chat_model_start(self, serialized: dict, messages: list[list], *, run_id: UUID, **kwargs: Any) -> None:
|
||||
"""Capture structured prompt messages for llm_request event."""
|
||||
from deerflow.runtime.converters import langchain_messages_to_openai
|
||||
|
||||
This is also the canonical place to extract the first human message:
|
||||
messages are fully structured here, it fires only on real LLM calls,
|
||||
and the content is never compressed by checkpoint trimming.
|
||||
"""
|
||||
rid = str(run_id)
|
||||
self._llm_start_times[rid] = time.monotonic()
|
||||
self._llm_call_index += 1
|
||||
# Mark this run_id as seen so on_llm_end knows not to increment again.
|
||||
self._cached_prompts[rid] = []
|
||||
|
||||
logger.info(f"on_chat_model_start {run_id}: tags={tags} serialized={serialized} messages={messages}")
|
||||
model_name = serialized.get("name", "")
|
||||
self._cached_models[rid] = model_name
|
||||
|
||||
# Capture the first human message sent to any LLM in this run.
|
||||
if not self._first_human_msg and not messages:
|
||||
for batch in messages.reversed():
|
||||
for m in batch.reversed():
|
||||
if isinstance(m, HumanMessage) and m.name != "summary":
|
||||
caller = self._identify_caller(tags)
|
||||
self.set_first_human_message(m.text)
|
||||
self._put(
|
||||
event_type="llm.human.input",
|
||||
category="message",
|
||||
content=m.model_dump(),
|
||||
metadata={"caller": caller},
|
||||
)
|
||||
break
|
||||
if self._first_human_msg:
|
||||
break
|
||||
# Convert the first message list (LangChain passes list-of-lists)
|
||||
prompt_msgs = messages[0] if messages else []
|
||||
openai_msgs = langchain_messages_to_openai(prompt_msgs)
|
||||
self._cached_prompts[rid] = openai_msgs
|
||||
|
||||
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, parent_run_id: UUID | None = None, tags: list[str] | None = None, metadata: dict[str, Any] | None = None, **kwargs: Any) -> None:
|
||||
caller = self._identify_caller(kwargs)
|
||||
self._put(
|
||||
event_type="llm_request",
|
||||
category="trace",
|
||||
content={"model": model_name, "messages": openai_msgs},
|
||||
metadata={"caller": caller, "llm_call_index": self._llm_call_index},
|
||||
)
|
||||
|
||||
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
|
||||
# Fallback: on_chat_model_start is preferred. This just tracks latency.
|
||||
self._llm_start_times[str(run_id)] = time.monotonic()
|
||||
|
||||
def on_llm_end(self, response, *, run_id, parent_run_id, tags, **kwargs) -> None:
|
||||
messages: list[AnyMessage] = []
|
||||
logger.info(f"on_llm_end {run_id}: response: {tags} {kwargs}")
|
||||
for generation in response.generations:
|
||||
for gen in generation:
|
||||
if hasattr(gen, "message"):
|
||||
messages.append(gen.message)
|
||||
else:
|
||||
logger.warning(f"on_llm_end {run_id}: generation has no message attribute: {gen}")
|
||||
def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
from deerflow.runtime.converters import langchain_to_openai_completion
|
||||
|
||||
for message in messages:
|
||||
caller = self._identify_caller(tags)
|
||||
try:
|
||||
message = response.generations[0][0].message
|
||||
except (IndexError, AttributeError):
|
||||
logger.debug("on_llm_end: could not extract message from response")
|
||||
return
|
||||
|
||||
# Latency
|
||||
rid = str(run_id)
|
||||
start = self._llm_start_times.pop(rid, None)
|
||||
latency_ms = int((time.monotonic() - start) * 1000) if start else None
|
||||
caller = self._identify_caller(kwargs)
|
||||
|
||||
# Token usage from message
|
||||
usage = getattr(message, "usage_metadata", None)
|
||||
usage_dict = dict(usage) if usage else {}
|
||||
# Latency
|
||||
rid = str(run_id)
|
||||
start = self._llm_start_times.pop(rid, None)
|
||||
latency_ms = int((time.monotonic() - start) * 1000) if start else None
|
||||
|
||||
# Resolve call index
|
||||
# Token usage from message
|
||||
usage = getattr(message, "usage_metadata", None)
|
||||
usage_dict = dict(usage) if usage else {}
|
||||
|
||||
# Resolve call index
|
||||
call_index = self._llm_call_index
|
||||
if rid not in self._cached_prompts:
|
||||
# Fallback: on_chat_model_start was not called
|
||||
self._llm_call_index += 1
|
||||
call_index = self._llm_call_index
|
||||
if rid not in self._cached_prompts:
|
||||
# Fallback: on_chat_model_start was not called
|
||||
self._llm_call_index += 1
|
||||
call_index = self._llm_call_index
|
||||
|
||||
# Trace event: llm_response (OpenAI completion format)
|
||||
self._put(
|
||||
event_type="llm.ai.response",
|
||||
category="message",
|
||||
content=message.model_dump(),
|
||||
metadata={
|
||||
"caller": caller,
|
||||
"usage": usage_dict,
|
||||
"latency_ms": latency_ms,
|
||||
"llm_call_index": call_index,
|
||||
},
|
||||
)
|
||||
# Clean up caches
|
||||
self._cached_prompts.pop(rid, None)
|
||||
self._cached_models.pop(rid, None)
|
||||
|
||||
# Token accumulation
|
||||
if self._track_tokens:
|
||||
input_tk = usage_dict.get("input_tokens", 0) or 0
|
||||
output_tk = usage_dict.get("output_tokens", 0) or 0
|
||||
total_tk = usage_dict.get("total_tokens", 0) or 0
|
||||
if total_tk == 0:
|
||||
total_tk = input_tk + output_tk
|
||||
if total_tk > 0:
|
||||
self._total_input_tokens += input_tk
|
||||
self._total_output_tokens += output_tk
|
||||
self._total_tokens += total_tk
|
||||
self._llm_call_count += 1
|
||||
# Trace event: llm_response (OpenAI completion format)
|
||||
content = getattr(message, "content", "")
|
||||
self._put(
|
||||
event_type="llm_response",
|
||||
category="trace",
|
||||
content=langchain_to_openai_completion(message),
|
||||
metadata={
|
||||
"caller": caller,
|
||||
"usage": usage_dict,
|
||||
"latency_ms": latency_ms,
|
||||
"llm_call_index": call_index,
|
||||
},
|
||||
)
|
||||
|
||||
# Message events: only lead_agent gets message-category events.
|
||||
# Content uses message.model_dump() to align with checkpoint format.
|
||||
tool_calls = getattr(message, "tool_calls", None) or []
|
||||
if caller == "lead_agent":
|
||||
resp_meta = getattr(message, "response_metadata", None) or {}
|
||||
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
|
||||
if tool_calls:
|
||||
# ai_tool_call: agent decided to use tools
|
||||
self._put(
|
||||
event_type="ai_tool_call",
|
||||
category="message",
|
||||
content=message.model_dump(),
|
||||
metadata={"model_name": model_name, "finish_reason": "tool_calls"},
|
||||
)
|
||||
elif isinstance(content, str) and content:
|
||||
# ai_message: final text reply
|
||||
self._put(
|
||||
event_type="ai_message",
|
||||
category="message",
|
||||
content=message.model_dump(),
|
||||
metadata={"model_name": model_name, "finish_reason": "stop"},
|
||||
)
|
||||
self._last_ai_msg = content
|
||||
self._msg_count += 1
|
||||
|
||||
# Token accumulation
|
||||
if self._track_tokens:
|
||||
input_tk = usage_dict.get("input_tokens", 0) or 0
|
||||
output_tk = usage_dict.get("output_tokens", 0) or 0
|
||||
total_tk = usage_dict.get("total_tokens", 0) or 0
|
||||
if total_tk == 0:
|
||||
total_tk = input_tk + output_tk
|
||||
if total_tk > 0:
|
||||
self._total_input_tokens += input_tk
|
||||
self._total_output_tokens += output_tk
|
||||
self._total_tokens += total_tk
|
||||
self._llm_call_count += 1
|
||||
if caller.startswith("subagent:"):
|
||||
self._subagent_tokens += total_tk
|
||||
elif caller.startswith("middleware:"):
|
||||
self._middleware_tokens += total_tk
|
||||
else:
|
||||
self._lead_agent_tokens += total_tk
|
||||
|
||||
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
self._llm_start_times.pop(str(run_id), None)
|
||||
self._put(event_type="llm.error", category="trace", content=str(error))
|
||||
self._put(event_type="llm_error", category="trace", content=str(error))
|
||||
|
||||
def on_tool_start(self, serialized, input_str, *, run_id, parent_run_id=None, tags=None, metadata=None, inputs=None, **kwargs):
|
||||
"""Handle tool start event, cache tool call ID for later correlation"""
|
||||
tool_call_id = str(run_id)
|
||||
logger.info(f"Tool start for node {run_id}, tool_call_id={tool_call_id}, tags={tags}, metadata={metadata}")
|
||||
# -- Tool callbacks --
|
||||
|
||||
def on_tool_end(self, output, *, run_id, parent_run_id=None, **kwargs):
|
||||
"""Handle tool end event, append message and clear node data"""
|
||||
try:
|
||||
if isinstance(output, ToolMessage):
|
||||
msg = cast(ToolMessage, output)
|
||||
self._put(event_type="llm.tool.result", category="message", content=msg.model_dump())
|
||||
elif isinstance(output, Command):
|
||||
cmd = cast(Command, output)
|
||||
messages = cmd.update.get("messages", [])
|
||||
for message in messages:
|
||||
if isinstance(message, BaseMessage):
|
||||
self._put(event_type="llm.tool.result", category="message", content=message.model_dump())
|
||||
else:
|
||||
logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}")
|
||||
else:
|
||||
logger.warning(f"on_tool_end {run_id}: output is not ToolMessage: {type(output)}")
|
||||
finally:
|
||||
logger.info(f"Tool end for node {run_id}")
|
||||
def on_tool_start(self, serialized: dict, input_str: str, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
tool_call_id = kwargs.get("tool_call_id")
|
||||
if tool_call_id:
|
||||
self._tool_call_ids[str(run_id)] = tool_call_id
|
||||
self._put(
|
||||
event_type="tool_start",
|
||||
category="trace",
|
||||
metadata={
|
||||
"tool_name": serialized.get("name", ""),
|
||||
"tool_call_id": tool_call_id,
|
||||
"args": str(input_str)[:2000],
|
||||
},
|
||||
)
|
||||
|
||||
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
# Tools that update graph state return a ``Command`` (e.g.
|
||||
# ``present_files``). LangGraph later unwraps the inner ToolMessage
|
||||
# into checkpoint state, so to stay checkpoint-aligned we must
|
||||
# extract it here rather than storing ``str(Command(...))``.
|
||||
if isinstance(output, Command):
|
||||
update = getattr(output, "update", None) or {}
|
||||
inner_msgs = update.get("messages") if isinstance(update, dict) else None
|
||||
if isinstance(inner_msgs, list):
|
||||
inner_tool_msg = next((m for m in inner_msgs if isinstance(m, ToolMessage)), None)
|
||||
if inner_tool_msg is not None:
|
||||
output = inner_tool_msg
|
||||
|
||||
# Extract fields from ToolMessage object when LangChain provides one.
|
||||
# LangChain's _format_output wraps tool results into a ToolMessage
|
||||
# with tool_call_id, name, status, and artifact — more complete than
|
||||
# what kwargs alone provides.
|
||||
if isinstance(output, ToolMessage):
|
||||
tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
|
||||
tool_name = output.name or kwargs.get("name", "")
|
||||
status = getattr(output, "status", "success") or "success"
|
||||
content_str = output.content if isinstance(output.content, str) else str(output.content)
|
||||
# Use model_dump() for checkpoint-aligned message content.
|
||||
# Override tool_call_id if it was resolved from cache.
|
||||
msg_content = output.model_dump()
|
||||
if msg_content.get("tool_call_id") != tool_call_id:
|
||||
msg_content["tool_call_id"] = tool_call_id
|
||||
else:
|
||||
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
|
||||
tool_name = kwargs.get("name", "")
|
||||
status = "success"
|
||||
content_str = str(output)
|
||||
# Construct checkpoint-aligned dict when output is a plain string.
|
||||
msg_content = ToolMessage(
|
||||
content=content_str,
|
||||
tool_call_id=tool_call_id or "",
|
||||
name=tool_name,
|
||||
status=status,
|
||||
).model_dump()
|
||||
|
||||
# Trace event (always)
|
||||
self._put(
|
||||
event_type="tool_end",
|
||||
category="trace",
|
||||
content=content_str,
|
||||
metadata={
|
||||
"tool_name": tool_name,
|
||||
"tool_call_id": tool_call_id,
|
||||
"status": status,
|
||||
},
|
||||
)
|
||||
|
||||
# Message event: tool_result (checkpoint-aligned model_dump format)
|
||||
self._put(
|
||||
event_type="tool_result",
|
||||
category="message",
|
||||
content=msg_content,
|
||||
metadata={"tool_name": tool_name, "status": status},
|
||||
)
|
||||
|
||||
def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
|
||||
tool_name = kwargs.get("name", "")
|
||||
|
||||
# Trace event
|
||||
self._put(
|
||||
event_type="tool_error",
|
||||
category="trace",
|
||||
content=str(error),
|
||||
metadata={
|
||||
"tool_name": tool_name,
|
||||
"tool_call_id": tool_call_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Message event: tool_result with error status (checkpoint-aligned)
|
||||
msg_content = ToolMessage(
|
||||
content=str(error),
|
||||
tool_call_id=tool_call_id or "",
|
||||
name=tool_name,
|
||||
status="error",
|
||||
).model_dump()
|
||||
self._put(
|
||||
event_type="tool_result",
|
||||
category="message",
|
||||
content=msg_content,
|
||||
metadata={"tool_name": tool_name, "status": "error"},
|
||||
)
|
||||
|
||||
# -- Custom event callback --
|
||||
|
||||
def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
from deerflow.runtime.serialization import serialize_lc_object
|
||||
|
||||
if name == "summarization":
|
||||
data_dict = data if isinstance(data, dict) else {}
|
||||
self._put(
|
||||
event_type="summarization",
|
||||
category="trace",
|
||||
content=data_dict.get("summary", ""),
|
||||
metadata={
|
||||
"replaced_message_ids": data_dict.get("replaced_message_ids", []),
|
||||
"replaced_count": data_dict.get("replaced_count", 0),
|
||||
},
|
||||
)
|
||||
self._put(
|
||||
event_type="middleware:summarize",
|
||||
category="middleware",
|
||||
content={"role": "system", "content": data_dict.get("summary", "")},
|
||||
metadata={"replaced_count": data_dict.get("replaced_count", 0)},
|
||||
)
|
||||
else:
|
||||
event_data = serialize_lc_object(data) if not isinstance(data, dict) else data
|
||||
self._put(
|
||||
event_type=name,
|
||||
category="trace",
|
||||
metadata=event_data if isinstance(event_data, dict) else {"data": event_data},
|
||||
)
|
||||
|
||||
# -- Internal methods --
|
||||
|
||||
@@ -307,9 +431,8 @@ class RunJournal(BaseCallbackHandler):
|
||||
if exc:
|
||||
logger.warning("Journal flush task failed: %s", exc)
|
||||
|
||||
def _identify_caller(self, tags: list[str] | None, **kwargs) -> str:
|
||||
_tags = tags or kwargs.get("tags", [])
|
||||
for tag in _tags:
|
||||
def _identify_caller(self, kwargs: dict) -> str:
|
||||
for tag in kwargs.get("tags") or []:
|
||||
if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
|
||||
return tag
|
||||
# Default to lead_agent: the main agent graph does not inject
|
||||
|
||||
@@ -54,7 +54,7 @@ class RunManager:
|
||||
self._lock = asyncio.Lock()
|
||||
self._store = store
|
||||
|
||||
async def _persist_to_store(self, record: RunRecord) -> None:
|
||||
async def _persist_to_store(self, record: RunRecord, *, follow_up_to_run_id: str | None = None) -> None:
|
||||
"""Best-effort persist run record to backing store."""
|
||||
if self._store is None:
|
||||
return
|
||||
@@ -68,6 +68,7 @@ class RunManager:
|
||||
metadata=record.metadata or {},
|
||||
kwargs=record.kwargs or {},
|
||||
created_at=record.created_at,
|
||||
follow_up_to_run_id=follow_up_to_run_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
|
||||
@@ -89,6 +90,7 @@ class RunManager:
|
||||
metadata: dict | None = None,
|
||||
kwargs: dict | None = None,
|
||||
multitask_strategy: str = "reject",
|
||||
follow_up_to_run_id: str | None = None,
|
||||
) -> RunRecord:
|
||||
"""Create a new pending run and register it."""
|
||||
run_id = str(uuid.uuid4())
|
||||
@@ -107,7 +109,7 @@ class RunManager:
|
||||
)
|
||||
async with self._lock:
|
||||
self._runs[run_id] = record
|
||||
await self._persist_to_store(record)
|
||||
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
|
||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||
return record
|
||||
|
||||
@@ -174,6 +176,7 @@ class RunManager:
|
||||
metadata: dict | None = None,
|
||||
kwargs: dict | None = None,
|
||||
multitask_strategy: str = "reject",
|
||||
follow_up_to_run_id: str | None = None,
|
||||
) -> RunRecord:
|
||||
"""Atomically check for inflight runs and create a new one.
|
||||
|
||||
@@ -227,7 +230,7 @@ class RunManager:
|
||||
)
|
||||
self._runs[run_id] = record
|
||||
|
||||
await self._persist_to_store(record)
|
||||
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
|
||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||
return record
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ class RunStore(abc.ABC):
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
error: str | None = None,
|
||||
created_at: str | None = None,
|
||||
follow_up_to_run_id: str | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ class MemoryRunStore(RunStore):
|
||||
kwargs=None,
|
||||
error=None,
|
||||
created_at=None,
|
||||
follow_up_to_run_id=None,
|
||||
):
|
||||
now = datetime.now(UTC).isoformat()
|
||||
self._runs[run_id] = {
|
||||
@@ -40,6 +41,7 @@ class MemoryRunStore(RunStore):
|
||||
"metadata": metadata or {},
|
||||
"kwargs": kwargs or {},
|
||||
"error": error,
|
||||
"follow_up_to_run_id": follow_up_to_run_id,
|
||||
"created_at": created_at or now,
|
||||
"updated_at": now,
|
||||
}
|
||||
|
||||
@@ -25,6 +25,8 @@ from typing import TYPE_CHECKING, Any, Literal
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.runtime.serialization import serialize
|
||||
from deerflow.runtime.stream_bridge import StreamBridge
|
||||
|
||||
@@ -51,6 +53,8 @@ class RunContext:
|
||||
event_store: Any | None = field(default=None)
|
||||
run_events_config: Any | None = field(default=None)
|
||||
thread_store: Any | None = field(default=None)
|
||||
follow_up_to_run_id: str | None = field(default=None)
|
||||
app_config: AppConfig | None = field(default=None)
|
||||
|
||||
|
||||
async def run_agent(
|
||||
@@ -75,6 +79,7 @@ async def run_agent(
|
||||
event_store = ctx.event_store
|
||||
run_events_config = ctx.run_events_config
|
||||
thread_store = ctx.thread_store
|
||||
follow_up_to_run_id = ctx.follow_up_to_run_id
|
||||
|
||||
run_id = record.run_id
|
||||
thread_id = record.thread_id
|
||||
@@ -111,6 +116,22 @@ async def run_agent(
|
||||
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
||||
)
|
||||
|
||||
human_msg = _extract_human_message(graph_input)
|
||||
if human_msg is not None:
|
||||
msg_metadata = {}
|
||||
if follow_up_to_run_id:
|
||||
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
|
||||
await event_store.put(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
event_type="human_message",
|
||||
category="message",
|
||||
content=human_msg.model_dump(),
|
||||
metadata=msg_metadata or None,
|
||||
)
|
||||
content = human_msg.content
|
||||
journal.set_first_human_message(content if isinstance(content, str) else str(content))
|
||||
|
||||
# 1. Mark running
|
||||
await run_manager.set_status(run_id, RunStatus.running)
|
||||
|
||||
@@ -144,18 +165,21 @@ async def run_agent(
|
||||
|
||||
# 3. Build the agent
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
# Inject runtime context so middlewares can access thread_id
|
||||
# (langgraph-cli does this automatically; we must do it manually)
|
||||
runtime = Runtime(context={"thread_id": thread_id, "run_id": run_id}, store=store)
|
||||
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
|
||||
# prefers it over ``configurable`` for thread-level data), make
|
||||
# sure ``thread_id`` is available there too.
|
||||
if "context" in config and isinstance(config["context"], dict):
|
||||
config["context"].setdefault("thread_id", thread_id)
|
||||
config["context"].setdefault("run_id", run_id)
|
||||
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
||||
# Construct typed context for the agent run.
|
||||
# LangGraph's astream(context=...) injects this into Runtime.context
|
||||
# so middleware/tools can access it via resolve_context().
|
||||
if ctx.app_config is None:
|
||||
raise RuntimeError("RunContext.app_config is required — Gateway must populate it via get_run_context")
|
||||
deer_flow_context = DeerFlowContext(
|
||||
app_config=ctx.app_config,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
# Inject RunJournal as a LangChain callback handler.
|
||||
# on_llm_end captures token usage; on_chain_start/end captures lifecycle.
|
||||
if journal is not None:
|
||||
config.setdefault("callbacks", []).append(journal)
|
||||
|
||||
# Inject RunJournal as a LangChain callback handler.
|
||||
# on_llm_end captures token usage; on_chain_start/end captures lifecycle.
|
||||
@@ -207,7 +231,7 @@ async def run_agent(
|
||||
if len(lg_modes) == 1 and not stream_subgraphs:
|
||||
# Single mode, no subgraphs: astream yields raw chunks
|
||||
single_mode = lg_modes[0]
|
||||
async for chunk in agent.astream(graph_input, config=runnable_config, stream_mode=single_mode):
|
||||
async for chunk in agent.astream(graph_input, config=runnable_config, context=deer_flow_context, stream_mode=single_mode):
|
||||
if record.abort_event.is_set():
|
||||
logger.info("Run %s abort requested — stopping", run_id)
|
||||
break
|
||||
@@ -218,6 +242,7 @@ async def run_agent(
|
||||
async for item in agent.astream(
|
||||
graph_input,
|
||||
config=runnable_config,
|
||||
context=deer_flow_context,
|
||||
stream_mode=lg_modes,
|
||||
subgraphs=stream_subgraphs,
|
||||
):
|
||||
|
||||
@@ -23,7 +23,7 @@ from collections.abc import AsyncIterator
|
||||
|
||||
from langgraph.store.base import BaseStore
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.runtime.store.provider import POSTGRES_CONN_REQUIRED, POSTGRES_STORE_INSTALL, SQLITE_STORE_INSTALL, ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -86,7 +86,7 @@ async def _async_store(config) -> AsyncIterator[BaseStore]:
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def make_store() -> AsyncIterator[BaseStore]:
|
||||
async def make_store(app_config: AppConfig) -> AsyncIterator[BaseStore]:
|
||||
"""Async context manager that yields a Store whose backend matches the
|
||||
configured checkpointer.
|
||||
|
||||
@@ -94,20 +94,18 @@ async def make_store() -> AsyncIterator[BaseStore]:
|
||||
:func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer` so
|
||||
that both singletons always use the same persistence technology::
|
||||
|
||||
async with make_store() as store:
|
||||
async with make_store(app_config) as store:
|
||||
app.state.store = store
|
||||
|
||||
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
|
||||
``checkpointer`` section is configured (emits a WARNING in that case).
|
||||
"""
|
||||
config = get_app_config()
|
||||
|
||||
if config.checkpointer is None:
|
||||
if app_config.checkpointer is None:
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
|
||||
yield InMemoryStore()
|
||||
return
|
||||
|
||||
async with _async_store(config.checkpointer) as store:
|
||||
async with _async_store(app_config.checkpointer) as store:
|
||||
yield store
|
||||
|
||||
@@ -26,7 +26,7 @@ from collections.abc import Iterator
|
||||
|
||||
from langgraph.store.base import BaseStore
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -100,7 +100,7 @@ _store: BaseStore | None = None
|
||||
_store_ctx = None # open context manager keeping the connection alive
|
||||
|
||||
|
||||
def get_store() -> BaseStore:
|
||||
def get_store(app_config: AppConfig) -> BaseStore:
|
||||
"""Return the global sync Store singleton, creating it on first call.
|
||||
|
||||
Returns an :class:`~langgraph.store.memory.InMemoryStore` when no
|
||||
@@ -115,19 +115,10 @@ def get_store() -> BaseStore:
|
||||
if _store is not None:
|
||||
return _store
|
||||
|
||||
# Lazily load app config, mirroring the checkpointer singleton pattern so
|
||||
# that tests that set the global checkpointer config explicitly remain isolated.
|
||||
from deerflow.config.app_config import _app_config
|
||||
from deerflow.config.checkpointer_config import get_checkpointer_config
|
||||
|
||||
config = get_checkpointer_config()
|
||||
|
||||
if config is None and _app_config is None:
|
||||
try:
|
||||
get_app_config()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
config = get_checkpointer_config()
|
||||
# See matching comment in checkpointer/provider.py: a missing config.yaml
|
||||
# is a deployment error, not a cue to silently pick InMemoryStore. Only
|
||||
# the explicit "no checkpointer section" path falls through to memory.
|
||||
config = app_config.checkpointer
|
||||
|
||||
if config is None:
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
@@ -163,26 +154,25 @@ def reset_store() -> None:
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def store_context() -> Iterator[BaseStore]:
|
||||
def store_context(app_config: AppConfig) -> Iterator[BaseStore]:
|
||||
"""Sync context manager that yields a Store and cleans up on exit.
|
||||
|
||||
Unlike :func:`get_store`, this does **not** cache the instance — each
|
||||
``with`` block creates and destroys its own connection. Use it in CLI
|
||||
scripts or tests where you want deterministic cleanup::
|
||||
|
||||
with store_context() as store:
|
||||
with store_context(app_config) as store:
|
||||
store.put(("threads",), thread_id, {...})
|
||||
|
||||
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
|
||||
checkpointer is configured in *config.yaml*.
|
||||
"""
|
||||
config = get_app_config()
|
||||
if config.checkpointer is None:
|
||||
if app_config.checkpointer is None:
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
|
||||
yield InMemoryStore()
|
||||
return
|
||||
|
||||
with _sync_store_cm(config.checkpointer) as store:
|
||||
with _sync_store_cm(app_config.checkpointer) as store:
|
||||
yield store
|
||||
|
||||
@@ -17,7 +17,7 @@ import contextlib
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from deerflow.config.stream_bridge_config import get_stream_bridge_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
|
||||
from .base import StreamBridge
|
||||
|
||||
@@ -25,14 +25,13 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def make_stream_bridge(config=None) -> AsyncIterator[StreamBridge]:
|
||||
async def make_stream_bridge(app_config: AppConfig) -> AsyncIterator[StreamBridge]:
|
||||
"""Async context manager that yields a :class:`StreamBridge`.
|
||||
|
||||
Falls back to :class:`MemoryStreamBridge` when no configuration is
|
||||
provided and nothing is set globally.
|
||||
Falls back to :class:`MemoryStreamBridge` when no ``stream_bridge``
|
||||
section is configured.
|
||||
"""
|
||||
if config is None:
|
||||
config = get_stream_bridge_config()
|
||||
config = app_config.stream_bridge
|
||||
|
||||
if config is None or config.type == "memory":
|
||||
from deerflow.runtime.stream_bridge.memory import MemoryStreamBridge
|
||||
|
||||
Reference in New Issue
Block a user