mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 23:46:50 +00:00
feat(events): replace llm_start/llm_end with llm_request/llm_response in OpenAI format
Add on_chat_model_start to capture structured prompt messages as llm_request events (replacing the old llm_start trace event). Replace llm_end trace events with llm_response events in OpenAI Chat Completions format. Track llm_call_index in event metadata so each llm_request can be paired with its llm_response.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -6,7 +6,8 @@ handles token usage accumulation.
|
|||||||
|
|
||||||
Key design decisions:
|
Key design decisions:
|
||||||
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
|
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
|
||||||
- All LangChain objects serialized via serialize_lc_object (same as worker.py SSE)
|
- on_chat_model_start captures structured prompts as llm_request (OpenAI format)
|
||||||
|
- on_llm_end emits llm_response in OpenAI Chat Completions format
|
||||||
- Token usage accumulated in memory, written to RunRow on run completion
|
- Token usage accumulated in memory, written to RunRow on run completion
|
||||||
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
|
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
|
||||||
"""
|
"""
|
||||||
@@ -67,6 +68,11 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
# Latency tracking
|
# Latency tracking
|
||||||
self._llm_start_times: dict[str, float] = {} # langchain run_id -> start time
|
self._llm_start_times: dict[str, float] = {} # langchain run_id -> start time
|
||||||
|
|
||||||
|
# LLM request/response tracking
|
||||||
|
self._llm_call_index = 0
|
||||||
|
self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages
|
||||||
|
self._cached_models: dict[str, str] = {} # langchain run_id -> model name
|
||||||
|
|
||||||
# Tool call ID cache
|
# Tool call ID cache
|
||||||
self._tool_call_ids: dict[str, str] = {} # langchain run_id -> tool_call_id
|
self._tool_call_ids: dict[str, str] = {} # langchain run_id -> tool_call_id
|
||||||
|
|
||||||
@@ -100,17 +106,36 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
|
|
||||||
# -- LLM callbacks --
|
# -- LLM callbacks --
|
||||||
|
|
||||||
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
|
def on_chat_model_start(self, serialized: dict, messages: list[list], *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
self._llm_start_times[str(run_id)] = time.monotonic()
|
"""Capture structured prompt messages for llm_request event."""
|
||||||
|
from deerflow.runtime.converters import langchain_messages_to_openai
|
||||||
|
|
||||||
|
rid = str(run_id)
|
||||||
|
self._llm_start_times[rid] = time.monotonic()
|
||||||
|
self._llm_call_index += 1
|
||||||
|
|
||||||
|
model_name = serialized.get("name", "")
|
||||||
|
self._cached_models[rid] = model_name
|
||||||
|
|
||||||
|
# Convert the first message list (LangChain passes list-of-lists)
|
||||||
|
prompt_msgs = messages[0] if messages else []
|
||||||
|
openai_msgs = langchain_messages_to_openai(prompt_msgs)
|
||||||
|
self._cached_prompts[rid] = openai_msgs
|
||||||
|
|
||||||
|
caller = self._identify_caller(kwargs)
|
||||||
self._put(
|
self._put(
|
||||||
event_type="llm_start",
|
event_type="llm_request",
|
||||||
category="trace",
|
category="trace",
|
||||||
metadata={"model_name": serialized.get("name", "")},
|
content={"model": model_name, "messages": openai_msgs},
|
||||||
|
metadata={"caller": caller, "llm_call_index": self._llm_call_index},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
|
# Fallback path: on_chat_model_start is the preferred hook and emits llm_request. This one only records the start time for latency; it emits no event.
|
||||||
|
self._llm_start_times[str(run_id)] = time.monotonic()
|
||||||
|
|
||||||
def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
from deerflow.runtime.converters import langchain_to_openai_message
|
from deerflow.runtime.converters import langchain_to_openai_completion, langchain_to_openai_message
|
||||||
from deerflow.runtime.serialization import serialize_lc_object
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
message = response.generations[0][0].message
|
message = response.generations[0][0].message
|
||||||
@@ -121,24 +146,36 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
caller = self._identify_caller(kwargs)
|
caller = self._identify_caller(kwargs)
|
||||||
|
|
||||||
# Latency
|
# Latency
|
||||||
start = self._llm_start_times.pop(str(run_id), None)
|
rid = str(run_id)
|
||||||
|
start = self._llm_start_times.pop(rid, None)
|
||||||
latency_ms = int((time.monotonic() - start) * 1000) if start else None
|
latency_ms = int((time.monotonic() - start) * 1000) if start else None
|
||||||
|
|
||||||
# Token usage from message
|
# Token usage from message
|
||||||
usage = getattr(message, "usage_metadata", None)
|
usage = getattr(message, "usage_metadata", None)
|
||||||
usage_dict = dict(usage) if usage else {}
|
usage_dict = dict(usage) if usage else {}
|
||||||
|
|
||||||
# Trace event: llm_end (every LLM call)
|
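# Resolve call index. NOTE(review): this reads the shared counter, so if another LLM call started after this one, the response gets that later index -- consider caching the index per run_id in on_chat_model_start.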
|
||||||
|
call_index = self._llm_call_index
|
||||||
|
if rid not in self._cached_prompts:
|
||||||
|
# Fallback: on_chat_model_start was not called for this run_id, so no index was allocated yet -- allocate one now
|
||||||
|
self._llm_call_index += 1
|
||||||
|
call_index = self._llm_call_index
|
||||||
|
|
||||||
|
# Clean up caches
|
||||||
|
self._cached_prompts.pop(rid, None)
|
||||||
|
self._cached_models.pop(rid, None)
|
||||||
|
|
||||||
|
# Trace event: llm_response (OpenAI completion format)
|
||||||
content = getattr(message, "content", "")
|
content = getattr(message, "content", "")
|
||||||
self._put(
|
self._put(
|
||||||
event_type="llm_end",
|
event_type="llm_response",
|
||||||
category="trace",
|
category="trace",
|
||||||
content=content if isinstance(content, str) else str(content),
|
content=langchain_to_openai_completion(message),
|
||||||
metadata={
|
metadata={
|
||||||
"message": serialize_lc_object(message),
|
|
||||||
"caller": caller,
|
"caller": caller,
|
||||||
"usage": usage_dict,
|
"usage": usage_dict,
|
||||||
"latency_ms": latency_ms,
|
"latency_ms": latency_ms,
|
||||||
|
"llm_call_index": call_index,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ class TestLlmCallbacks:
|
|||||||
j.on_llm_end(_make_llm_response("Hi"), run_id=run_id, tags=["lead_agent"])
|
j.on_llm_end(_make_llm_response("Hi"), run_id=run_id, tags=["lead_agent"])
|
||||||
await j.flush()
|
await j.flush()
|
||||||
events = await store.list_events("t1", "r1")
|
events = await store.list_events("t1", "r1")
|
||||||
trace_events = [e for e in events if e["event_type"] == "llm_end"]
|
trace_events = [e for e in events if e["event_type"] == "llm_response"]
|
||||||
assert len(trace_events) == 1
|
assert len(trace_events) == 1
|
||||||
assert trace_events[0]["category"] == "trace"
|
assert trace_events[0]["category"] == "trace"
|
||||||
|
|
||||||
@@ -147,9 +147,9 @@ class TestLlmCallbacks:
|
|||||||
j.on_llm_end(_make_llm_response("Fast"), run_id=run_id, tags=["lead_agent"])
|
j.on_llm_end(_make_llm_response("Fast"), run_id=run_id, tags=["lead_agent"])
|
||||||
await j.flush()
|
await j.flush()
|
||||||
events = await store.list_events("t1", "r1")
|
events = await store.list_events("t1", "r1")
|
||||||
llm_end = [e for e in events if e["event_type"] == "llm_end"][0]
|
llm_resp = [e for e in events if e["event_type"] == "llm_response"][0]
|
||||||
assert "latency_ms" in llm_end["metadata"]
|
assert "latency_ms" in llm_resp["metadata"]
|
||||||
assert llm_end["metadata"]["latency_ms"] is not None
|
assert llm_resp["metadata"]["latency_ms"] is not None
|
||||||
|
|
||||||
|
|
||||||
class TestLifecycleCallbacks:
|
class TestLifecycleCallbacks:
|
||||||
@@ -252,14 +252,14 @@ class TestBufferFlush:
|
|||||||
|
|
||||||
asyncio.get_running_loop = no_loop
|
asyncio.get_running_loop = no_loop
|
||||||
try:
|
try:
|
||||||
j._put(event_type="llm_end", category="trace", content="test")
|
j._put(event_type="llm_response", category="trace", content="test")
|
||||||
finally:
|
finally:
|
||||||
asyncio.get_running_loop = original
|
asyncio.get_running_loop = original
|
||||||
|
|
||||||
assert len(j._buffer) == 1
|
assert len(j._buffer) == 1
|
||||||
await j.flush()
|
await j.flush()
|
||||||
events = await store.list_events("t1", "r1")
|
events = await store.list_events("t1", "r1")
|
||||||
assert any(e["event_type"] == "llm_end" for e in events)
|
assert any(e["event_type"] == "llm_response" for e in events)
|
||||||
|
|
||||||
|
|
||||||
class TestIdentifyCaller:
|
class TestIdentifyCaller:
|
||||||
@@ -417,7 +417,7 @@ class TestDbBackedLifecycle:
|
|||||||
events = await event_store.list_events("t1", run_id)
|
events = await event_store.list_events("t1", run_id)
|
||||||
event_types = {e["event_type"] for e in events}
|
event_types = {e["event_type"] for e in events}
|
||||||
assert "run_start" in event_types
|
assert "run_start" in event_types
|
||||||
assert "llm_end" in event_types
|
assert "llm_response" in event_types
|
||||||
assert "run_end" in event_types
|
assert "run_end" in event_types
|
||||||
|
|
||||||
await close_engine()
|
await close_engine()
|
||||||
@@ -661,3 +661,85 @@ class TestToolResultMessage:
|
|||||||
await j.flush()
|
await j.flush()
|
||||||
messages = await store.list_messages("t1")
|
messages = await store.list_messages("t1")
|
||||||
assert len(messages) == 0
|
assert len(messages) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def _make_base_messages():
|
||||||
|
"""Create mock LangChain BaseMessages for on_chat_model_start."""
|
||||||
|
sys_msg = MagicMock()
|
||||||
|
sys_msg.content = "You are helpful."
|
||||||
|
sys_msg.type = "system"
|
||||||
|
sys_msg.tool_calls = []
|
||||||
|
sys_msg.tool_call_id = None
|
||||||
|
|
||||||
|
user_msg = MagicMock()
|
||||||
|
user_msg.content = "Hello"
|
||||||
|
user_msg.type = "human"
|
||||||
|
user_msg.tool_calls = []
|
||||||
|
user_msg.tool_call_id = None
|
||||||
|
|
||||||
|
return [sys_msg, user_msg]
|
||||||
|
|
||||||
|
|
||||||
|
class TestLlmRequestResponse:
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_llm_request_event(self, journal_setup):
|
||||||
|
j, store = journal_setup
|
||||||
|
run_id = uuid4()
|
||||||
|
messages = _make_base_messages()
|
||||||
|
j.on_chat_model_start({"name": "gpt-4o"}, [messages], run_id=run_id, tags=["lead_agent"])
|
||||||
|
await j.flush()
|
||||||
|
events = await store.list_events("t1", "r1")
|
||||||
|
req_events = [e for e in events if e["event_type"] == "llm_request"]
|
||||||
|
assert len(req_events) == 1
|
||||||
|
content = req_events[0]["content"]
|
||||||
|
assert content["model"] == "gpt-4o"
|
||||||
|
assert len(content["messages"]) == 2
|
||||||
|
assert content["messages"][0]["role"] == "system"
|
||||||
|
assert content["messages"][1]["role"] == "user"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_llm_response_event(self, journal_setup):
|
||||||
|
j, store = journal_setup
|
||||||
|
run_id = uuid4()
|
||||||
|
j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
|
||||||
|
j.on_llm_end(
|
||||||
|
_make_llm_response("Answer", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
|
||||||
|
run_id=run_id,
|
||||||
|
tags=["lead_agent"],
|
||||||
|
)
|
||||||
|
await j.flush()
|
||||||
|
events = await store.list_events("t1", "r1")
|
||||||
|
assert not any(e["event_type"] == "llm_end" for e in events)
|
||||||
|
resp_events = [e for e in events if e["event_type"] == "llm_response"]
|
||||||
|
assert len(resp_events) == 1
|
||||||
|
content = resp_events[0]["content"]
|
||||||
|
assert "choices" in content
|
||||||
|
assert content["choices"][0]["message"]["role"] == "assistant"
|
||||||
|
assert content["choices"][0]["message"]["content"] == "Answer"
|
||||||
|
assert content["usage"]["prompt_tokens"] == 10
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_llm_request_response_paired(self, journal_setup):
|
||||||
|
j, store = journal_setup
|
||||||
|
run_id = uuid4()
|
||||||
|
messages = _make_base_messages()
|
||||||
|
j.on_chat_model_start({"name": "gpt-4o"}, [messages], run_id=run_id, tags=["lead_agent"])
|
||||||
|
j.on_llm_end(
|
||||||
|
_make_llm_response("Hi", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
|
||||||
|
run_id=run_id,
|
||||||
|
tags=["lead_agent"],
|
||||||
|
)
|
||||||
|
await j.flush()
|
||||||
|
events = await store.list_events("t1", "r1")
|
||||||
|
req = [e for e in events if e["event_type"] == "llm_request"][0]
|
||||||
|
resp = [e for e in events if e["event_type"] == "llm_response"][0]
|
||||||
|
assert req["metadata"]["llm_call_index"] == resp["metadata"]["llm_call_index"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_no_llm_start_event(self, journal_setup):
|
||||||
|
j, store = journal_setup
|
||||||
|
run_id = uuid4()
|
||||||
|
j.on_llm_start({"name": "test"}, [], run_id=run_id, tags=["lead_agent"])
|
||||||
|
await j.flush()
|
||||||
|
events = await store.list_events("t1", "r1")
|
||||||
|
assert not any(e["event_type"] == "llm_start" for e in events)
|
||||||
|
|||||||
Reference in New Issue
Block a user