feat(events): replace llm_start/llm_end with llm_request/llm_response in OpenAI format

Add on_chat_model_start to capture structured prompt messages as llm_request events.
Replace llm_end trace events with llm_response using OpenAI Chat Completions format.
Track llm_call_index to pair request/response events.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng
2026-04-04 09:37:34 +08:00
parent 362226be6e
commit 41745f1f2b
2 changed files with 138 additions and 19 deletions
@@ -6,7 +6,8 @@ handles token usage accumulation.
Key design decisions:
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
- All LangChain objects serialized via serialize_lc_object (same as worker.py SSE)
- on_chat_model_start captures structured prompts as llm_request (OpenAI format)
- on_llm_end emits llm_response in OpenAI Chat Completions format
- Token usage accumulated in memory, written to RunRow on run completion
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
"""
@@ -67,6 +68,11 @@ class RunJournal(BaseCallbackHandler):
# Latency tracking
self._llm_start_times: dict[str, float] = {} # langchain run_id -> start time
# LLM request/response tracking
self._llm_call_index = 0
self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages
self._cached_models: dict[str, str] = {} # langchain run_id -> model name
# Tool call ID cache
self._tool_call_ids: dict[str, str] = {} # langchain run_id -> tool_call_id
@@ -100,17 +106,36 @@ class RunJournal(BaseCallbackHandler):
# -- LLM callbacks --
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
self._llm_start_times[str(run_id)] = time.monotonic()
def on_chat_model_start(self, serialized: dict, messages: list[list], *, run_id: UUID, **kwargs: Any) -> None:
"""Capture structured prompt messages for llm_request event."""
from deerflow.runtime.converters import langchain_messages_to_openai
rid = str(run_id)
self._llm_start_times[rid] = time.monotonic()
self._llm_call_index += 1
model_name = serialized.get("name", "")
self._cached_models[rid] = model_name
# Convert the first message list (LangChain passes list-of-lists)
prompt_msgs = messages[0] if messages else []
openai_msgs = langchain_messages_to_openai(prompt_msgs)
self._cached_prompts[rid] = openai_msgs
caller = self._identify_caller(kwargs)
self._put(
event_type="llm_start",
event_type="llm_request",
category="trace",
metadata={"model_name": serialized.get("name", "")},
content={"model": model_name, "messages": openai_msgs},
metadata={"caller": caller, "llm_call_index": self._llm_call_index},
)
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
# Fallback: on_chat_model_start is preferred. This just tracks latency.
self._llm_start_times[str(run_id)] = time.monotonic()
def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
from deerflow.runtime.converters import langchain_to_openai_message
from deerflow.runtime.serialization import serialize_lc_object
from deerflow.runtime.converters import langchain_to_openai_completion, langchain_to_openai_message
try:
message = response.generations[0][0].message
@@ -121,24 +146,36 @@ class RunJournal(BaseCallbackHandler):
caller = self._identify_caller(kwargs)
# Latency
start = self._llm_start_times.pop(str(run_id), None)
rid = str(run_id)
start = self._llm_start_times.pop(rid, None)
latency_ms = int((time.monotonic() - start) * 1000) if start else None
# Token usage from message
usage = getattr(message, "usage_metadata", None)
usage_dict = dict(usage) if usage else {}
# Trace event: llm_end (every LLM call)
# Resolve call index
call_index = self._llm_call_index
if rid not in self._cached_prompts:
# Fallback: on_chat_model_start was not called
self._llm_call_index += 1
call_index = self._llm_call_index
# Clean up caches
self._cached_prompts.pop(rid, None)
self._cached_models.pop(rid, None)
# Trace event: llm_response (OpenAI completion format)
content = getattr(message, "content", "")
self._put(
event_type="llm_end",
event_type="llm_response",
category="trace",
content=content if isinstance(content, str) else str(content),
content=langchain_to_openai_completion(message),
metadata={
"message": serialize_lc_object(message),
"caller": caller,
"usage": usage_dict,
"latency_ms": latency_ms,
"llm_call_index": call_index,
},
)