mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 15:36:48 +00:00
feat: enhance chat history loading with new hooks and UI components (#2338)
* Refactor API fetch calls to use a unified fetch function; enhance chat history loading with new hooks and UI components - Replaced `fetchWithAuth` with a generic `fetch` function across various API modules for consistency. - Updated `useThreadStream` and `useThreadHistory` hooks to manage chat history loading, including loading states and pagination. - Introduced `LoadMoreHistoryIndicator` component for better user experience when loading more chat history. - Enhanced message handling in `MessageList` to accommodate new loading states and history management. - Added support for run messages in the thread context, improving the overall message handling logic. - Updated translations for loading indicators in English and Chinese. * Fix test assertions for run ordering in RunManager tests - Updated assertions in `test_list_by_thread` to reflect correct ordering of runs. - Modified `test_list_by_thread_is_stable_when_timestamps_tie` to ensure stable ordering when timestamps are tied.
This commit is contained in:
+20
-17
@@ -8,13 +8,17 @@ Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, TypeVar, cast
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from deerflow.runtime import RunContext, RunManager
|
||||
from deerflow.persistence.feedback import FeedbackRepository
|
||||
from deerflow.runtime import RunContext, RunManager, StreamBridge
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
from deerflow.runtime.runs.store.base import RunStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||
@@ -22,6 +26,9 @@ if TYPE_CHECKING:
|
||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Bootstrap and tear down all LangGraph runtime singletons.
|
||||
@@ -84,25 +91,25 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _require(attr: str, label: str):
|
||||
def _require(attr: str, label: str) -> Callable[[Request], T]:
|
||||
"""Create a FastAPI dependency that returns ``app.state.<attr>`` or 503."""
|
||||
|
||||
def dep(request: Request):
|
||||
def dep(request: Request) -> T:
|
||||
val = getattr(request.app.state, attr, None)
|
||||
if val is None:
|
||||
raise HTTPException(status_code=503, detail=f"{label} not available")
|
||||
return val
|
||||
return cast(T, val)
|
||||
|
||||
dep.__name__ = dep.__qualname__ = f"get_{attr}"
|
||||
return dep
|
||||
|
||||
|
||||
get_stream_bridge = _require("stream_bridge", "Stream bridge")
|
||||
get_run_manager = _require("run_manager", "Run manager")
|
||||
get_checkpointer = _require("checkpointer", "Checkpointer")
|
||||
get_run_event_store = _require("run_event_store", "Run event store")
|
||||
get_feedback_repo = _require("feedback_repo", "Feedback")
|
||||
get_run_store = _require("run_store", "Run store")
|
||||
get_stream_bridge: Callable[[Request], StreamBridge] = _require("stream_bridge", "Stream bridge")
|
||||
get_run_manager: Callable[[Request], RunManager] = _require("run_manager", "Run manager")
|
||||
get_checkpointer: Callable[[Request], Checkpointer] = _require("checkpointer", "Checkpointer")
|
||||
get_run_event_store: Callable[[Request], RunEventStore] = _require("run_event_store", "Run event store")
|
||||
get_feedback_repo: Callable[[Request], FeedbackRepository] = _require("feedback_repo", "Feedback")
|
||||
get_run_store: Callable[[Request], RunStore] = _require("run_store", "Run store")
|
||||
|
||||
|
||||
def get_store(request: Request):
|
||||
@@ -121,10 +128,7 @@ def get_thread_store(request: Request) -> ThreadMetaStore:
|
||||
def get_run_context(request: Request) -> RunContext:
|
||||
"""Build a :class:`RunContext` from ``app.state`` singletons.
|
||||
|
||||
Returns a *base* context with infrastructure dependencies. Callers that
|
||||
need per-run fields (e.g. ``follow_up_to_run_id``) should use
|
||||
``dataclasses.replace(ctx, follow_up_to_run_id=...)`` before passing it
|
||||
to :func:`run_agent`.
|
||||
Returns a *base* context with infrastructure dependencies.
|
||||
"""
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
@@ -137,7 +141,6 @@ def get_run_context(request: Request) -> RunContext:
|
||||
)
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth helpers (used by authz.py and auth middleware)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -123,7 +123,8 @@ async def run_messages(
|
||||
run = await _resolve_run(run_id, request)
|
||||
event_store = get_run_event_store(request)
|
||||
rows = await event_store.list_messages_by_run(
|
||||
run["thread_id"], run_id,
|
||||
run["thread_id"],
|
||||
run_id,
|
||||
limit=limit + 1,
|
||||
before_seq=before_seq,
|
||||
after_seq=after_seq,
|
||||
|
||||
@@ -54,7 +54,6 @@ class RunCreateRequest(BaseModel):
|
||||
after_seconds: float | None = Field(default=None, description="Delayed execution")
|
||||
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
|
||||
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
|
||||
follow_up_to_run_id: str | None = Field(default=None, description="Run ID this message follows up on. Auto-detected from latest successful run if not provided.")
|
||||
|
||||
|
||||
class RunResponse(BaseModel):
|
||||
@@ -312,11 +311,15 @@ async def list_thread_messages(
|
||||
if i in last_ai_indices:
|
||||
run_id = msg["run_id"]
|
||||
fb = feedback_map.get(run_id)
|
||||
msg["feedback"] = {
|
||||
"feedback_id": fb["feedback_id"],
|
||||
"rating": fb["rating"],
|
||||
"comment": fb.get("comment"),
|
||||
} if fb else None
|
||||
msg["feedback"] = (
|
||||
{
|
||||
"feedback_id": fb["feedback_id"],
|
||||
"rating": fb["rating"],
|
||||
"comment": fb.get("comment"),
|
||||
}
|
||||
if fb
|
||||
else None
|
||||
)
|
||||
else:
|
||||
msg["feedback"] = None
|
||||
|
||||
@@ -339,7 +342,8 @@ async def list_run_messages(
|
||||
"""
|
||||
event_store = get_run_event_store(request)
|
||||
rows = await event_store.list_messages_by_run(
|
||||
thread_id, run_id,
|
||||
thread_id,
|
||||
run_id,
|
||||
limit=limit + 1,
|
||||
before_seq=before_seq,
|
||||
after_seq=after_seq,
|
||||
|
||||
@@ -56,7 +56,7 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
|
||||
|
||||
|
||||
@router.post("", response_model=UploadResponse)
|
||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
||||
@require_permission("threads", "write", owner_check=True, require_existing=False)
|
||||
async def upload_files(
|
||||
thread_id: str,
|
||||
request: Request,
|
||||
|
||||
@@ -195,21 +195,6 @@ async def start_run(
|
||||
|
||||
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
|
||||
|
||||
# Resolve follow_up_to_run_id: explicit from request, or auto-detect from latest successful run
|
||||
follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None)
|
||||
if follow_up_to_run_id is None:
|
||||
run_store = get_run_store(request)
|
||||
try:
|
||||
recent_runs = await run_store.list_by_thread(thread_id, limit=1)
|
||||
if recent_runs and recent_runs[0].get("status") == "success":
|
||||
follow_up_to_run_id = recent_runs[0]["run_id"]
|
||||
except Exception:
|
||||
pass # Don't block run creation
|
||||
|
||||
# Enrich base context with per-run field
|
||||
if follow_up_to_run_id:
|
||||
run_ctx = dataclasses.replace(run_ctx, follow_up_to_run_id=follow_up_to_run_id)
|
||||
|
||||
try:
|
||||
record = await run_mgr.create_or_reject(
|
||||
thread_id,
|
||||
@@ -218,7 +203,6 @@ async def start_run(
|
||||
metadata=body.metadata or {},
|
||||
kwargs={"input": body.input, "config": body.config},
|
||||
multitask_strategy=body.multitask_strategy,
|
||||
follow_up_to_run_id=follow_up_to_run_id,
|
||||
)
|
||||
except ConflictError as exc:
|
||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from deerflow.agents.lead_agent.prompt import apply_prompt_template
|
||||
@@ -9,6 +9,7 @@ from deerflow.agents.middlewares.clarification_middleware import ClarificationMi
|
||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
|
||||
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
|
||||
from deerflow.agents.middlewares.summarization_middleware import SummarizationMiddleware
|
||||
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
||||
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
|
||||
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
|
||||
|
||||
@@ -283,7 +283,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
# the conversation; injecting one mid-conversation crashes
|
||||
# langchain_anthropic's _format_messages(). HumanMessage works
|
||||
# with all providers. See #1299.
|
||||
return {"messages": [HumanMessage(content=warning)]}
|
||||
return {"messages": [HumanMessage(content=warning, name="loop_warning")]}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
from typing import override
|
||||
|
||||
from langchain.agents.middleware import SummarizationMiddleware as BaseSummarizationMiddleware
|
||||
from langchain_core.messages.human import HumanMessage
|
||||
|
||||
|
||||
class SummarizationMiddleware(BaseSummarizationMiddleware):
|
||||
@override
|
||||
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
|
||||
"""Override the base implementation to let the human message with the special name 'summary'.
|
||||
And this message will be ignored to display in the frontend, but still can be used as context for the model.
|
||||
"""
|
||||
return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
|
||||
@@ -1,8 +1,10 @@
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import NotRequired, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
@@ -97,8 +99,20 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
||||
paths = self._create_thread_directories(thread_id, user_id=user_id)
|
||||
logger.debug("Created thread data directories for thread %s", thread_id)
|
||||
|
||||
messages = list(state.get("messages", []))
|
||||
last_message = messages[-1] if messages else None
|
||||
|
||||
if last_message and isinstance(last_message, HumanMessage):
|
||||
messages[-1] = HumanMessage(
|
||||
content=last_message.content,
|
||||
id=last_message.id,
|
||||
name=last_message.name or "user-input",
|
||||
additional_kwargs={**last_message.additional_kwargs, "run_id": runtime.context.get("run_id"), "timestamp": datetime.now(UTC).isoformat()},
|
||||
)
|
||||
|
||||
return {
|
||||
"thread_data": {
|
||||
**paths,
|
||||
}
|
||||
},
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
@@ -279,6 +279,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
updated_message = HumanMessage(
|
||||
content=f"{files_message}\n\n{original_content}",
|
||||
id=last_message.id,
|
||||
name=last_message.name,
|
||||
additional_kwargs=last_message.additional_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -6,7 +6,10 @@ handles token usage accumulation.
|
||||
|
||||
Key design decisions:
|
||||
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
|
||||
- on_chat_model_start captures structured prompts as llm_request (OpenAI format)
|
||||
- on_chat_model_start captures structured prompts as llm_request (OpenAI format) and
|
||||
extracts the first human message for run.input, because it is more reliable than
|
||||
on_chain_start (fires on every node) — messages here are fully structured.
|
||||
- on_chain_start with parent_run_id=None emits a run.start trace marking root invocation.
|
||||
- on_llm_end emits llm_response in OpenAI Chat Completions format
|
||||
- Token usage accumulated in memory, written to RunRow on run completion
|
||||
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
|
||||
@@ -18,10 +21,12 @@ import asyncio
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
@@ -72,34 +77,39 @@ class RunJournal(BaseCallbackHandler):
|
||||
# LLM request/response tracking
|
||||
self._llm_call_index = 0
|
||||
self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages
|
||||
self._cached_models: dict[str, str] = {} # langchain run_id -> model name
|
||||
|
||||
# Tool call ID cache
|
||||
self._tool_call_ids: dict[str, str] = {} # langchain run_id -> tool_call_id
|
||||
|
||||
# -- Lifecycle callbacks --
|
||||
|
||||
def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
if kwargs.get("parent_run_id") is not None:
|
||||
return
|
||||
self._put(
|
||||
event_type="run_start",
|
||||
category="lifecycle",
|
||||
metadata={"input_preview": str(inputs)[:500]},
|
||||
)
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
inputs: dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
tags: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
caller = self._identify_caller(tags)
|
||||
if parent_run_id is None:
|
||||
# Root graph invocation — emit a single trace event for the run start.
|
||||
chain_name = (serialized or {}).get("name", "unknown")
|
||||
self._put(
|
||||
event_type="run.start",
|
||||
category="trace",
|
||||
content={"chain": chain_name},
|
||||
metadata={"caller": caller, **(metadata or {})},
|
||||
)
|
||||
|
||||
def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
if kwargs.get("parent_run_id") is not None:
|
||||
return
|
||||
self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
|
||||
self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"})
|
||||
self._flush_sync()
|
||||
|
||||
def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
if kwargs.get("parent_run_id") is not None:
|
||||
return
|
||||
self._put(
|
||||
event_type="run_error",
|
||||
category="lifecycle",
|
||||
event_type="run.error",
|
||||
category="error",
|
||||
content=str(error),
|
||||
metadata={"error_type": type(error).__name__},
|
||||
)
|
||||
@@ -107,266 +117,132 @@ class RunJournal(BaseCallbackHandler):
|
||||
|
||||
# -- LLM callbacks --
|
||||
|
||||
def on_chat_model_start(self, serialized: dict, messages: list[list], *, run_id: UUID, **kwargs: Any) -> None:
|
||||
"""Capture structured prompt messages for llm_request event."""
|
||||
from deerflow.runtime.converters import langchain_messages_to_openai
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict,
|
||||
messages: list[list[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Capture structured prompt messages for llm_request event.
|
||||
|
||||
This is also the canonical place to extract the first human message:
|
||||
messages are fully structured here, it fires only on real LLM calls,
|
||||
and the content is never compressed by checkpoint trimming.
|
||||
"""
|
||||
rid = str(run_id)
|
||||
self._llm_start_times[rid] = time.monotonic()
|
||||
self._llm_call_index += 1
|
||||
# Mark this run_id as seen so on_llm_end knows not to increment again.
|
||||
self._cached_prompts[rid] = []
|
||||
|
||||
model_name = serialized.get("name", "")
|
||||
self._cached_models[rid] = model_name
|
||||
logger.info(f"on_chat_model_start {run_id}: tags={tags} serialized={serialized} messages={messages}")
|
||||
|
||||
# Convert the first message list (LangChain passes list-of-lists)
|
||||
prompt_msgs = messages[0] if messages else []
|
||||
openai_msgs = langchain_messages_to_openai(prompt_msgs)
|
||||
self._cached_prompts[rid] = openai_msgs
|
||||
# Capture the first human message sent to any LLM in this run.
|
||||
if not self._first_human_msg:
|
||||
for batch in messages.reversed():
|
||||
for m in batch.reversed():
|
||||
if isinstance(m, HumanMessage) and m.name != "summary":
|
||||
caller = self._identify_caller(tags)
|
||||
self.set_first_human_message(m.text)
|
||||
self._put(
|
||||
event_type="llm.human.input",
|
||||
category="message",
|
||||
content=m.model_dump(),
|
||||
metadata={"caller": caller},
|
||||
)
|
||||
break
|
||||
if self._first_human_msg:
|
||||
break
|
||||
|
||||
caller = self._identify_caller(kwargs)
|
||||
self._put(
|
||||
event_type="llm_request",
|
||||
category="trace",
|
||||
content={"model": model_name, "messages": openai_msgs},
|
||||
metadata={"caller": caller, "llm_call_index": self._llm_call_index},
|
||||
)
|
||||
|
||||
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
|
||||
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, parent_run_id: UUID | None = None, tags: list[str] | None = None, metadata: dict[str, Any] | None = None, **kwargs: Any) -> None:
|
||||
# Fallback: on_chat_model_start is preferred. This just tracks latency.
|
||||
self._llm_start_times[str(run_id)] = time.monotonic()
|
||||
|
||||
def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
from deerflow.runtime.converters import langchain_to_openai_completion
|
||||
|
||||
try:
|
||||
message = response.generations[0][0].message
|
||||
except (IndexError, AttributeError):
|
||||
logger.debug("on_llm_end: could not extract message from response")
|
||||
return
|
||||
|
||||
caller = self._identify_caller(kwargs)
|
||||
|
||||
# Latency
|
||||
rid = str(run_id)
|
||||
start = self._llm_start_times.pop(rid, None)
|
||||
latency_ms = int((time.monotonic() - start) * 1000) if start else None
|
||||
|
||||
# Token usage from message
|
||||
usage = getattr(message, "usage_metadata", None)
|
||||
usage_dict = dict(usage) if usage else {}
|
||||
|
||||
# Resolve call index
|
||||
call_index = self._llm_call_index
|
||||
if rid not in self._cached_prompts:
|
||||
# Fallback: on_chat_model_start was not called
|
||||
self._llm_call_index += 1
|
||||
call_index = self._llm_call_index
|
||||
|
||||
# Clean up caches
|
||||
self._cached_prompts.pop(rid, None)
|
||||
self._cached_models.pop(rid, None)
|
||||
|
||||
# Trace event: llm_response (OpenAI completion format)
|
||||
content = getattr(message, "content", "")
|
||||
self._put(
|
||||
event_type="llm_response",
|
||||
category="trace",
|
||||
content=langchain_to_openai_completion(message),
|
||||
metadata={
|
||||
"caller": caller,
|
||||
"usage": usage_dict,
|
||||
"latency_ms": latency_ms,
|
||||
"llm_call_index": call_index,
|
||||
},
|
||||
)
|
||||
|
||||
# Message events: only lead_agent gets message-category events.
|
||||
# Content uses message.model_dump() to align with checkpoint format.
|
||||
tool_calls = getattr(message, "tool_calls", None) or []
|
||||
if caller == "lead_agent":
|
||||
resp_meta = getattr(message, "response_metadata", None) or {}
|
||||
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
|
||||
if tool_calls:
|
||||
# ai_tool_call: agent decided to use tools
|
||||
self._put(
|
||||
event_type="ai_tool_call",
|
||||
category="message",
|
||||
content=message.model_dump(),
|
||||
metadata={"model_name": model_name, "finish_reason": "tool_calls"},
|
||||
)
|
||||
elif isinstance(content, str) and content:
|
||||
# ai_message: final text reply
|
||||
self._put(
|
||||
event_type="ai_message",
|
||||
category="message",
|
||||
content=message.model_dump(),
|
||||
metadata={"model_name": model_name, "finish_reason": "stop"},
|
||||
)
|
||||
self._last_ai_msg = content
|
||||
self._msg_count += 1
|
||||
|
||||
# Token accumulation
|
||||
if self._track_tokens:
|
||||
input_tk = usage_dict.get("input_tokens", 0) or 0
|
||||
output_tk = usage_dict.get("output_tokens", 0) or 0
|
||||
total_tk = usage_dict.get("total_tokens", 0) or 0
|
||||
if total_tk == 0:
|
||||
total_tk = input_tk + output_tk
|
||||
if total_tk > 0:
|
||||
self._total_input_tokens += input_tk
|
||||
self._total_output_tokens += output_tk
|
||||
self._total_tokens += total_tk
|
||||
self._llm_call_count += 1
|
||||
if caller.startswith("subagent:"):
|
||||
self._subagent_tokens += total_tk
|
||||
elif caller.startswith("middleware:"):
|
||||
self._middleware_tokens += total_tk
|
||||
def on_llm_end(self, response, *, run_id, parent_run_id, tags, **kwargs) -> None:
|
||||
messages: list[AnyMessage] = []
|
||||
logger.info(f"on_llm_end {run_id}: response: {tags} {kwargs}")
|
||||
for generation in response.generations:
|
||||
for gen in generation:
|
||||
if hasattr(gen, "message"):
|
||||
messages.append(gen.message)
|
||||
else:
|
||||
self._lead_agent_tokens += total_tk
|
||||
logger.warning(f"on_llm_end {run_id}: generation has no message attribute: {gen}")
|
||||
|
||||
for message in messages:
|
||||
caller = self._identify_caller(tags)
|
||||
|
||||
# Latency
|
||||
rid = str(run_id)
|
||||
start = self._llm_start_times.pop(rid, None)
|
||||
latency_ms = int((time.monotonic() - start) * 1000) if start else None
|
||||
|
||||
# Token usage from message
|
||||
usage = getattr(message, "usage_metadata", None)
|
||||
usage_dict = dict(usage) if usage else {}
|
||||
|
||||
# Resolve call index
|
||||
call_index = self._llm_call_index
|
||||
if rid not in self._cached_prompts:
|
||||
# Fallback: on_chat_model_start was not called
|
||||
self._llm_call_index += 1
|
||||
call_index = self._llm_call_index
|
||||
|
||||
# Trace event: llm_response (OpenAI completion format)
|
||||
self._put(
|
||||
event_type="llm.ai.response",
|
||||
category="message",
|
||||
content=message.model_dump(),
|
||||
metadata={
|
||||
"caller": caller,
|
||||
"usage": usage_dict,
|
||||
"latency_ms": latency_ms,
|
||||
"llm_call_index": call_index,
|
||||
},
|
||||
)
|
||||
|
||||
# Token accumulation
|
||||
if self._track_tokens:
|
||||
input_tk = usage_dict.get("input_tokens", 0) or 0
|
||||
output_tk = usage_dict.get("output_tokens", 0) or 0
|
||||
total_tk = usage_dict.get("total_tokens", 0) or 0
|
||||
if total_tk == 0:
|
||||
total_tk = input_tk + output_tk
|
||||
if total_tk > 0:
|
||||
self._total_input_tokens += input_tk
|
||||
self._total_output_tokens += output_tk
|
||||
self._total_tokens += total_tk
|
||||
self._llm_call_count += 1
|
||||
|
||||
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
self._llm_start_times.pop(str(run_id), None)
|
||||
self._put(event_type="llm_error", category="trace", content=str(error))
|
||||
self._put(event_type="llm.error", category="trace", content=str(error))
|
||||
|
||||
# -- Tool callbacks --
|
||||
def on_tool_start(self, serialized, input_str, *, run_id, parent_run_id=None, tags=None, metadata=None, inputs=None, **kwargs):
|
||||
"""Handle tool start event, cache tool call ID for later correlation"""
|
||||
tool_call_id = str(run_id)
|
||||
logger.info(f"Tool start for node {run_id}, tool_call_id={tool_call_id}, tags={tags}, metadata={metadata}")
|
||||
|
||||
def on_tool_start(self, serialized: dict, input_str: str, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
tool_call_id = kwargs.get("tool_call_id")
|
||||
if tool_call_id:
|
||||
self._tool_call_ids[str(run_id)] = tool_call_id
|
||||
self._put(
|
||||
event_type="tool_start",
|
||||
category="trace",
|
||||
metadata={
|
||||
"tool_name": serialized.get("name", ""),
|
||||
"tool_call_id": tool_call_id,
|
||||
"args": str(input_str)[:2000],
|
||||
},
|
||||
)
|
||||
|
||||
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
# Tools that update graph state return a ``Command`` (e.g.
|
||||
# ``present_files``). LangGraph later unwraps the inner ToolMessage
|
||||
# into checkpoint state, so to stay checkpoint-aligned we must
|
||||
# extract it here rather than storing ``str(Command(...))``.
|
||||
if isinstance(output, Command):
|
||||
update = getattr(output, "update", None) or {}
|
||||
inner_msgs = update.get("messages") if isinstance(update, dict) else None
|
||||
if isinstance(inner_msgs, list):
|
||||
inner_tool_msg = next((m for m in inner_msgs if isinstance(m, ToolMessage)), None)
|
||||
if inner_tool_msg is not None:
|
||||
output = inner_tool_msg
|
||||
|
||||
# Extract fields from ToolMessage object when LangChain provides one.
|
||||
# LangChain's _format_output wraps tool results into a ToolMessage
|
||||
# with tool_call_id, name, status, and artifact — more complete than
|
||||
# what kwargs alone provides.
|
||||
if isinstance(output, ToolMessage):
|
||||
tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
|
||||
tool_name = output.name or kwargs.get("name", "")
|
||||
status = getattr(output, "status", "success") or "success"
|
||||
content_str = output.content if isinstance(output.content, str) else str(output.content)
|
||||
# Use model_dump() for checkpoint-aligned message content.
|
||||
# Override tool_call_id if it was resolved from cache.
|
||||
msg_content = output.model_dump()
|
||||
if msg_content.get("tool_call_id") != tool_call_id:
|
||||
msg_content["tool_call_id"] = tool_call_id
|
||||
else:
|
||||
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
|
||||
tool_name = kwargs.get("name", "")
|
||||
status = "success"
|
||||
content_str = str(output)
|
||||
# Construct checkpoint-aligned dict when output is a plain string.
|
||||
msg_content = ToolMessage(
|
||||
content=content_str,
|
||||
tool_call_id=tool_call_id or "",
|
||||
name=tool_name,
|
||||
status=status,
|
||||
).model_dump()
|
||||
|
||||
# Trace event (always)
|
||||
self._put(
|
||||
event_type="tool_end",
|
||||
category="trace",
|
||||
content=content_str,
|
||||
metadata={
|
||||
"tool_name": tool_name,
|
||||
"tool_call_id": tool_call_id,
|
||||
"status": status,
|
||||
},
|
||||
)
|
||||
|
||||
# Message event: tool_result (checkpoint-aligned model_dump format)
|
||||
self._put(
|
||||
event_type="tool_result",
|
||||
category="message",
|
||||
content=msg_content,
|
||||
metadata={"tool_name": tool_name, "status": status},
|
||||
)
|
||||
|
||||
def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
|
||||
tool_name = kwargs.get("name", "")
|
||||
|
||||
# Trace event
|
||||
self._put(
|
||||
event_type="tool_error",
|
||||
category="trace",
|
||||
content=str(error),
|
||||
metadata={
|
||||
"tool_name": tool_name,
|
||||
"tool_call_id": tool_call_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Message event: tool_result with error status (checkpoint-aligned)
|
||||
msg_content = ToolMessage(
|
||||
content=str(error),
|
||||
tool_call_id=tool_call_id or "",
|
||||
name=tool_name,
|
||||
status="error",
|
||||
).model_dump()
|
||||
self._put(
|
||||
event_type="tool_result",
|
||||
category="message",
|
||||
content=msg_content,
|
||||
metadata={"tool_name": tool_name, "status": "error"},
|
||||
)
|
||||
|
||||
# -- Custom event callback --
|
||||
|
||||
def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
from deerflow.runtime.serialization import serialize_lc_object
|
||||
|
||||
if name == "summarization":
|
||||
data_dict = data if isinstance(data, dict) else {}
|
||||
self._put(
|
||||
event_type="summarization",
|
||||
category="trace",
|
||||
content=data_dict.get("summary", ""),
|
||||
metadata={
|
||||
"replaced_message_ids": data_dict.get("replaced_message_ids", []),
|
||||
"replaced_count": data_dict.get("replaced_count", 0),
|
||||
},
|
||||
)
|
||||
self._put(
|
||||
event_type="middleware:summarize",
|
||||
category="middleware",
|
||||
content={"role": "system", "content": data_dict.get("summary", "")},
|
||||
metadata={"replaced_count": data_dict.get("replaced_count", 0)},
|
||||
)
|
||||
else:
|
||||
event_data = serialize_lc_object(data) if not isinstance(data, dict) else data
|
||||
self._put(
|
||||
event_type=name,
|
||||
category="trace",
|
||||
metadata=event_data if isinstance(event_data, dict) else {"data": event_data},
|
||||
)
|
||||
def on_tool_end(self, output, *, run_id, parent_run_id=None, **kwargs):
|
||||
"""Handle tool end event, append message and clear node data"""
|
||||
try:
|
||||
if isinstance(output, ToolMessage):
|
||||
msg = cast(ToolMessage, output)
|
||||
self._put(event_type="llm.tool.result", category="message", content=msg.model_dump())
|
||||
elif isinstance(output, Command):
|
||||
cmd = cast(Command, output)
|
||||
messages = cmd.update.get("messages", [])
|
||||
for message in messages:
|
||||
if isinstance(message, BaseMessage):
|
||||
self._put(event_type="llm.tool.result", category="message", content=message.model_dump())
|
||||
else:
|
||||
logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}")
|
||||
else:
|
||||
logger.warning(f"on_tool_end {run_id}: output is not ToolMessage: {type(output)}")
|
||||
finally:
|
||||
logger.info(f"Tool end for node {run_id}")
|
||||
|
||||
# -- Internal methods --
|
||||
|
||||
@@ -431,8 +307,9 @@ class RunJournal(BaseCallbackHandler):
|
||||
if exc:
|
||||
logger.warning("Journal flush task failed: %s", exc)
|
||||
|
||||
def _identify_caller(self, kwargs: dict) -> str:
|
||||
for tag in kwargs.get("tags") or []:
|
||||
def _identify_caller(self, tags: list[str] | None, **kwargs) -> str:
|
||||
_tags = tags or kwargs.get("tags", [])
|
||||
for tag in _tags:
|
||||
if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
|
||||
return tag
|
||||
# Default to lead_agent: the main agent graph does not inject
|
||||
|
||||
@@ -54,7 +54,7 @@ class RunManager:
|
||||
self._lock = asyncio.Lock()
|
||||
self._store = store
|
||||
|
||||
async def _persist_to_store(self, record: RunRecord, *, follow_up_to_run_id: str | None = None) -> None:
|
||||
async def _persist_to_store(self, record: RunRecord) -> None:
|
||||
"""Best-effort persist run record to backing store."""
|
||||
if self._store is None:
|
||||
return
|
||||
@@ -68,7 +68,6 @@ class RunManager:
|
||||
metadata=record.metadata or {},
|
||||
kwargs=record.kwargs or {},
|
||||
created_at=record.created_at,
|
||||
follow_up_to_run_id=follow_up_to_run_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
|
||||
@@ -90,7 +89,6 @@ class RunManager:
|
||||
metadata: dict | None = None,
|
||||
kwargs: dict | None = None,
|
||||
multitask_strategy: str = "reject",
|
||||
follow_up_to_run_id: str | None = None,
|
||||
) -> RunRecord:
|
||||
"""Create a new pending run and register it."""
|
||||
run_id = str(uuid.uuid4())
|
||||
@@ -109,7 +107,7 @@ class RunManager:
|
||||
)
|
||||
async with self._lock:
|
||||
self._runs[run_id] = record
|
||||
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
|
||||
await self._persist_to_store(record)
|
||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||
return record
|
||||
|
||||
@@ -122,7 +120,7 @@ class RunManager:
|
||||
async with self._lock:
|
||||
# Dict insertion order matches creation order, so reversing it gives
|
||||
# us deterministic newest-first results even when timestamps tie.
|
||||
return [r for r in reversed(self._runs.values()) if r.thread_id == thread_id]
|
||||
return [r for r in self._runs.values() if r.thread_id == thread_id]
|
||||
|
||||
async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
|
||||
"""Transition a run to a new status."""
|
||||
@@ -176,7 +174,6 @@ class RunManager:
|
||||
metadata: dict | None = None,
|
||||
kwargs: dict | None = None,
|
||||
multitask_strategy: str = "reject",
|
||||
follow_up_to_run_id: str | None = None,
|
||||
) -> RunRecord:
|
||||
"""Atomically check for inflight runs and create a new one.
|
||||
|
||||
@@ -230,7 +227,7 @@ class RunManager:
|
||||
)
|
||||
self._runs[run_id] = record
|
||||
|
||||
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
|
||||
await self._persist_to_store(record)
|
||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||
return record
|
||||
|
||||
|
||||
@@ -29,7 +29,6 @@ class RunStore(abc.ABC):
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
error: str | None = None,
|
||||
created_at: str | None = None,
|
||||
follow_up_to_run_id: str | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@@ -28,7 +28,6 @@ class MemoryRunStore(RunStore):
|
||||
kwargs=None,
|
||||
error=None,
|
||||
created_at=None,
|
||||
follow_up_to_run_id=None,
|
||||
):
|
||||
now = datetime.now(UTC).isoformat()
|
||||
self._runs[run_id] = {
|
||||
@@ -41,7 +40,6 @@ class MemoryRunStore(RunStore):
|
||||
"metadata": metadata or {},
|
||||
"kwargs": kwargs or {},
|
||||
"error": error,
|
||||
"follow_up_to_run_id": follow_up_to_run_id,
|
||||
"created_at": created_at or now,
|
||||
"updated_at": now,
|
||||
}
|
||||
|
||||
@@ -51,7 +51,6 @@ class RunContext:
|
||||
event_store: Any | None = field(default=None)
|
||||
run_events_config: Any | None = field(default=None)
|
||||
thread_store: Any | None = field(default=None)
|
||||
follow_up_to_run_id: str | None = field(default=None)
|
||||
|
||||
|
||||
async def run_agent(
|
||||
@@ -76,7 +75,6 @@ async def run_agent(
|
||||
event_store = ctx.event_store
|
||||
run_events_config = ctx.run_events_config
|
||||
thread_store = ctx.thread_store
|
||||
follow_up_to_run_id = ctx.follow_up_to_run_id
|
||||
|
||||
run_id = record.run_id
|
||||
thread_id = record.thread_id
|
||||
@@ -113,22 +111,6 @@ async def run_agent(
|
||||
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
||||
)
|
||||
|
||||
human_msg = _extract_human_message(graph_input)
|
||||
if human_msg is not None:
|
||||
msg_metadata = {}
|
||||
if follow_up_to_run_id:
|
||||
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
|
||||
await event_store.put(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
event_type="human_message",
|
||||
category="message",
|
||||
content=human_msg.model_dump(),
|
||||
metadata=msg_metadata or None,
|
||||
)
|
||||
content = human_msg.content
|
||||
journal.set_first_human_message(content if isinstance(content, str) else str(content))
|
||||
|
||||
# 1. Mark running
|
||||
await run_manager.set_status(run_id, RunStatus.running)
|
||||
|
||||
@@ -166,12 +148,13 @@ async def run_agent(
|
||||
|
||||
# Inject runtime context so middlewares can access thread_id
|
||||
# (langgraph-cli does this automatically; we must do it manually)
|
||||
runtime = Runtime(context={"thread_id": thread_id}, store=store)
|
||||
runtime = Runtime(context={"thread_id": thread_id, "run_id": run_id}, store=store)
|
||||
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
|
||||
# prefers it over ``configurable`` for thread-level data), make
|
||||
# sure ``thread_id`` is available there too.
|
||||
if "context" in config and isinstance(config["context"], dict):
|
||||
config["context"].setdefault("thread_id", thread_id)
|
||||
config["context"].setdefault("run_id", run_id)
|
||||
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
||||
|
||||
# Inject RunJournal as a LangChain callback handler.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -75,27 +75,27 @@ async def test_cancel_not_inflight(manager: RunManager):
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread(manager: RunManager):
|
||||
"""Same thread should return multiple runs, newest first."""
|
||||
"""Same thread should return multiple runs."""
|
||||
r1 = await manager.create("thread-1")
|
||||
r2 = await manager.create("thread-1")
|
||||
await manager.create("thread-2")
|
||||
|
||||
runs = await manager.list_by_thread("thread-1")
|
||||
assert len(runs) == 2
|
||||
assert runs[0].run_id == r2.run_id
|
||||
assert runs[1].run_id == r1.run_id
|
||||
assert runs[0].run_id == r1.run_id
|
||||
assert runs[1].run_id == r2.run_id
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_is_stable_when_timestamps_tie(manager: RunManager, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Newest-first ordering should not depend on timestamp precision."""
|
||||
"""Ordering should be stable (insertion order) even when timestamps tie."""
|
||||
monkeypatch.setattr("deerflow.runtime.runs.manager._now_iso", lambda: "2026-01-01T00:00:00+00:00")
|
||||
|
||||
r1 = await manager.create("thread-1")
|
||||
r2 = await manager.create("thread-1")
|
||||
|
||||
runs = await manager.list_by_thread("thread-1")
|
||||
assert [run.run_id for run in runs] == [r2.run_id, r1.run_id]
|
||||
assert [run.run_id for run in runs] == [r1.run_id, r2.run_id]
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
|
||||
Reference in New Issue
Block a user