mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 17:06:00 +00:00
refactor(runtime): restructure runs module with new execution architecture
Major refactoring of deerflow/runtime/:
- runs/callbacks/ - new callback system (builder, events, title, tokens)
- runs/internal/ - execution internals (executor, supervisor, stream_logic, registry)
- runs/internal/execution/ - execution artifacts and events handling
- runs/facade.py - high-level run facade
- runs/observer.py - run observation protocol
- runs/types.py - type definitions
- runs/store/ - simplified store interfaces (create, delete, query, event)

Refactor stream_bridge/:
- Replace old providers with contract.py and exceptions.py
- Remove async_provider.py, base.py, memory.py

Add documentation:
- README.md and README_zh.md for runtime module

Remove deprecated:
- manager.py moved to internal/
- worker.py, schemas.py
- user_context.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,15 @@
|
||||
"""Runs execution callbacks."""
|
||||
|
||||
from .builder import RunCallbackArtifacts, build_run_callbacks
|
||||
from .events import RunEventCallback
|
||||
from .title import RunTitleCallback
|
||||
from .tokens import RunCompletionData, RunTokenCallback
|
||||
|
||||
# Names re-exported by this package; each one is imported from a sibling
# module above. Entries are kept in ASCII sort order.
__all__ = [
    "RunCallbackArtifacts",
    "RunCompletionData",
    "RunEventCallback",
    "RunTitleCallback",
    "RunTokenCallback",
    "build_run_callbacks",
]
|
||||
@@ -0,0 +1,138 @@
|
||||
"""Callback assembly for runs execution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
|
||||
from ..store import RunEventStore
|
||||
from ..types import RunRecord
|
||||
from .events import RunEventCallback
|
||||
from .title import RunTitleCallback
|
||||
from .tokens import RunCompletionData, RunTokenCallback
|
||||
|
||||
|
||||
@dataclass
class RunCallbackArtifacts:
    """Callback handlers for a run plus direct handles to the specialised ones.

    ``callbacks`` is the full handler list. The three optional attributes
    reference individual handlers so their collected results can be read
    back through :meth:`completion_data` and :meth:`title`.
    """

    callbacks: list[BaseCallbackHandler]
    event_callback: RunEventCallback | None = None
    token_callback: RunTokenCallback | None = None
    title_callback: RunTitleCallback | None = None

    async def flush(self) -> None:
        """Call ``flush`` on every callback that defines it.

        A result exposing ``__await__`` is awaited; any other result is
        ignored. Callbacks without a ``flush`` attribute are skipped.
        """
        for handler in self.callbacks:
            flusher = getattr(handler, "flush", None)
            if flusher is not None:
                outcome = flusher()
                if hasattr(outcome, "__await__"):
                    await outcome

    def completion_data(self) -> RunCompletionData:
        """Return the token callback's completion data.

        Falls back to a default-constructed ``RunCompletionData`` when no
        token callback is attached.
        """
        tracker = self.token_callback
        return RunCompletionData() if tracker is None else tracker.completion_data()

    def title(self) -> str | None:
        """Return the title callback's title, or ``None`` when none is attached."""
        titler = self.title_callback
        return None if titler is None else titler.title()
|
||||
|
||||
|
||||
def build_run_callbacks(
    *,
    record: RunRecord,
    graph_input: dict[str, Any],
    event_store: RunEventStore | None,
    existing_callbacks: Iterable[BaseCallbackHandler] = (),
) -> RunCallbackArtifacts:
    """Assemble the callback handlers used while executing a run.

    The returned handler list starts with ``existing_callbacks`` and is
    followed by an event callback (only when ``event_store`` is given), a
    token callback and a title callback, in that order.

    Reference callbacks are intentionally not assembled here yet; they remain
    in the existing artifacts path until that integration is migrated.

    Args:
        record: Run record supplying ``run_id`` and ``thread_id`` for the
            event callback.
        graph_input: Graph input; its ``messages`` entry seeds the token
            callback's first human message.
        event_store: Destination for recorded events, or ``None`` to skip
            event recording.
        existing_callbacks: Handlers to keep at the front of the list.

    Returns:
        The handler list together with handles to the handlers created here.
    """
    handlers: list[BaseCallbackHandler] = [*existing_callbacks]

    event_handler: RunEventCallback | None = None
    if event_store is not None:
        event_handler = RunEventCallback(
            run_id=record.run_id,
            thread_id=record.thread_id,
            event_store=event_store,
        )
        handlers.append(event_handler)

    token_handler = RunTokenCallback(track_token_usage=True)
    _set_first_human_message(token_handler, graph_input)
    title_handler = RunTitleCallback()
    handlers += [token_handler, title_handler]

    return RunCallbackArtifacts(
        callbacks=handlers,
        event_callback=event_handler,
        token_callback=token_handler,
        title_callback=title_handler,
    )
|
||||
|
||||
|
||||
def _set_first_human_message(token_callback: RunTokenCallback, graph_input: dict[str, Any]) -> None:
|
||||
messages = graph_input.get("messages")
|
||||
if not isinstance(messages, list) or not messages:
|
||||
return
|
||||
|
||||
first = messages[0]
|
||||
content = _extract_first_human_text(first)
|
||||
if content:
|
||||
token_callback.set_first_human_message(content)
|
||||
|
||||
|
||||
def _extract_first_human_text(message: Any) -> str | None:
|
||||
if isinstance(message, str):
|
||||
return message
|
||||
|
||||
content = getattr(message, "content", None)
|
||||
if content is not None:
|
||||
return _extract_text_content(content)
|
||||
|
||||
if isinstance(message, dict):
|
||||
return _extract_text_content(message.get("content"))
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _extract_text_content(content: Any) -> str | None:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
continue
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("type") == "text" and isinstance(item.get("text"), str):
|
||||
parts.append(item["text"])
|
||||
continue
|
||||
if isinstance(item.get("content"), str):
|
||||
parts.append(item["content"])
|
||||
joined = "".join(parts).strip()
|
||||
return joined or None
|
||||
|
||||
if isinstance(content, dict):
|
||||
text = content.get("text")
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
nested = content.get("content")
|
||||
if isinstance(nested, str):
|
||||
return nested
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,353 @@
|
||||
"""Run execution event recording callback."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from deerflow.runtime.converters import langchain_messages_to_openai, langchain_to_openai_completion
|
||||
|
||||
from ..store import RunEventStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RunEventCallback(BaseCallbackHandler):
    """Capture LangChain execution events into the run event store.

    Events are appended to an in-memory buffer by :meth:`_put`. When the
    buffer reaches ``flush_threshold`` entries, or a top-level chain ends or
    errors, the buffered events are handed to ``event_store.put_batch`` in a
    background task on the running event loop. :meth:`flush` waits for those
    background tasks and then writes whatever is still buffered.
    """

    def __init__(
        self,
        *,
        run_id: str,
        thread_id: str,
        event_store: RunEventStore,
        flush_threshold: int = 5,
        max_trace_content: int = 10240,
    ) -> None:
        """Create a recorder for one run.

        Args:
            run_id: Identifier stamped on every recorded event.
            thread_id: Thread identifier stamped on every recorded event.
            event_store: Store whose ``put_batch`` receives event batches.
            flush_threshold: Buffer length that triggers a background flush.
            max_trace_content: Maximum length of string content kept for
                non-``message`` events; longer content is truncated.
        """
        super().__init__()
        self.run_id = run_id
        self.thread_id = thread_id
        self._store = event_store
        self._flush_threshold = flush_threshold
        self._max_trace_content = max_trace_content
        self._buffer: list[dict[str, Any]] = []
        self._llm_start_times: dict[str, float] = {}
        self._llm_call_index = 0
        self._cached_prompts: dict[str, list[dict[str, Any]]] = {}
        self._tool_call_ids: dict[str, str] = {}
        self._human_message_recorded = False
        # Strong references to in-flight background flush tasks. The event
        # loop only holds weak references, so an unreferenced task could be
        # garbage-collected before it finishes.
        self._pending_flushes: set[asyncio.Task] = set()

    def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Record ``run_start`` for the top-level chain (no ``parent_run_id``)."""
        if kwargs.get("parent_run_id") is not None:
            return
        self._put(
            event_type="run_start",
            category="lifecycle",
            metadata={"input_preview": str(inputs)[:500]},
        )

    def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Record ``run_end`` for the top-level chain and schedule a flush."""
        if kwargs.get("parent_run_id") is not None:
            return
        self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
        self._flush_sync()

    def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
        """Record ``run_error`` for the top-level chain and schedule a flush."""
        if kwargs.get("parent_run_id") is not None:
            return
        self._put(
            event_type="run_error",
            category="lifecycle",
            content=str(error),
            metadata={"error_type": type(error).__name__},
        )
        self._flush_sync()

    def on_chat_model_start(self, serialized: dict, messages: list[list], *, run_id: UUID, **kwargs: Any) -> None:
        """Record an ``llm_request`` trace event for a chat model call.

        Also stores the call's start time and converted prompt keyed by
        ``run_id``, bumps the LLM call index, and records the first human
        message of the run if it has not been recorded yet.
        """
        rid = str(run_id)
        self._llm_start_times[rid] = time.monotonic()
        self._llm_call_index += 1

        prompt_msgs = messages[0] if messages else []
        openai_msgs = langchain_messages_to_openai(prompt_msgs)
        self._cached_prompts[rid] = openai_msgs
        caller = self._identify_caller(kwargs)

        self._record_first_human_message(prompt_msgs, caller=caller)

        self._put(
            event_type="llm_request",
            category="trace",
            content={"model": serialized.get("name", ""), "messages": openai_msgs},
            metadata={
                "caller": caller,
                "llm_call_index": self._llm_call_index,
            },
        )

    def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
        """Remember the start time of an LLM call for latency measurement."""
        self._llm_start_times[str(run_id)] = time.monotonic()

    def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Record ``llm_response`` and, for the lead agent, the AI message.

        Lead-agent responses with tool calls additionally produce an
        ``ai_tool_call`` message event; lead-agent responses with non-empty
        string content produce an ``ai_message`` event.
        """
        rid = str(run_id)
        # Release per-call bookkeeping up front so that a response we cannot
        # parse does not leave its start time or cached prompt behind.
        start = self._llm_start_times.pop(rid, None)
        had_chat_start = rid in self._cached_prompts
        self._cached_prompts.pop(rid, None)

        try:
            message = response.generations[0][0].message
        except (IndexError, AttributeError):
            logger.debug("on_llm_end: could not extract message from response")
            return

        # Compare against None: a start timestamp is a float and must not be
        # subjected to a truthiness test.
        latency_ms = int((time.monotonic() - start) * 1000) if start is not None else None
        usage = dict(getattr(message, "usage_metadata", None) or {})
        caller = self._identify_caller(kwargs)

        # Calls that never went through on_chat_model_start have not been
        # counted yet, so count them here.
        if not had_chat_start:
            self._llm_call_index += 1
        call_index = self._llm_call_index

        self._put(
            event_type="llm_response",
            category="trace",
            content=langchain_to_openai_completion(message),
            metadata={
                "caller": caller,
                "usage": usage,
                "latency_ms": latency_ms,
                "llm_call_index": call_index,
            },
        )

        content = getattr(message, "content", "")
        tool_calls = getattr(message, "tool_calls", None) or []
        if caller != "lead_agent":
            return
        if tool_calls:
            self._put(
                event_type="ai_tool_call",
                category="message",
                content=message.model_dump(),
                metadata={"finish_reason": "tool_calls"},
            )
        elif isinstance(content, str) and content:
            self._put(
                event_type="ai_message",
                category="message",
                content=message.model_dump(),
                metadata={"finish_reason": "stop"},
            )

    def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
        """Record ``llm_error`` and drop the failed call's bookkeeping."""
        rid = str(run_id)
        self._llm_start_times.pop(rid, None)
        # on_llm_end never runs for a failed call, so the cached prompt has
        # to be released here as well.
        self._cached_prompts.pop(rid, None)
        self._put(event_type="llm_error", category="trace", content=str(error))

    def on_tool_start(self, serialized: dict, input_str: str, *, run_id: UUID, **kwargs: Any) -> None:
        """Record ``tool_start`` and remember the tool call id for this run id."""
        tool_call_id = kwargs.get("tool_call_id")
        if tool_call_id:
            self._tool_call_ids[str(run_id)] = tool_call_id
        self._put(
            event_type="tool_start",
            category="trace",
            metadata={
                "tool_name": serialized.get("name", ""),
                "tool_call_id": tool_call_id,
                "args": str(input_str)[:2000],
            },
        )

    def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Record ``tool_end`` (trace) and ``tool_result`` (message) events.

        The tool call id is taken from the ``ToolMessage`` output, then from
        ``kwargs``, then from the id remembered in :meth:`on_tool_start`.
        """
        from langchain_core.messages import ToolMessage

        # Pop unconditionally: the remembered id is only a fallback, but the
        # entry must be removed even when a better source exists.
        remembered_id = self._tool_call_ids.pop(str(run_id), None)

        if isinstance(output, ToolMessage):
            tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or remembered_id
            tool_name = output.name or kwargs.get("name", "")
            status = getattr(output, "status", "success") or "success"
            content_str = output.content if isinstance(output.content, str) else str(output.content)
            msg_content = output.model_dump()
            if msg_content.get("tool_call_id") != tool_call_id:
                msg_content["tool_call_id"] = tool_call_id
        else:
            tool_call_id = kwargs.get("tool_call_id") or remembered_id
            tool_name = kwargs.get("name", "")
            status = "success"
            content_str = str(output)
            msg_content = ToolMessage(
                content=content_str,
                tool_call_id=tool_call_id or "",
                name=tool_name,
                status=status,
            ).model_dump()

        self._put(
            event_type="tool_end",
            category="trace",
            content=content_str,
            metadata={
                "tool_name": tool_name,
                "tool_call_id": tool_call_id,
                "status": status,
            },
        )
        self._put(
            event_type="tool_result",
            category="message",
            content=msg_content,
            metadata={"tool_name": tool_name, "status": status},
        )

    def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
        """Record ``tool_error`` (trace) and an error ``tool_result`` (message)."""
        from langchain_core.messages import ToolMessage

        # Same unconditional pop as in on_tool_end.
        remembered_id = self._tool_call_ids.pop(str(run_id), None)
        tool_call_id = kwargs.get("tool_call_id") or remembered_id
        tool_name = kwargs.get("name", "")
        self._put(
            event_type="tool_error",
            category="trace",
            content=str(error),
            metadata={"tool_name": tool_name, "tool_call_id": tool_call_id},
        )
        self._put(
            event_type="tool_result",
            category="message",
            content=ToolMessage(
                content=str(error),
                tool_call_id=tool_call_id or "",
                name=tool_name,
                status="error",
            ).model_dump(),
            metadata={"tool_name": tool_name, "status": "error"},
        )

    def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Record a custom event.

        ``"summarization"`` events produce a ``summarization`` trace event
        and a ``middleware:summarize`` event. Any other event is recorded as
        a trace event named ``name`` with ``data`` as its metadata.
        """
        from deerflow.runtime.serialization import serialize_lc_object

        if name == "summarization":
            data_dict = data if isinstance(data, dict) else {}
            self._put(
                event_type="summarization",
                category="trace",
                content=data_dict.get("summary", ""),
                metadata={
                    "replaced_message_ids": data_dict.get("replaced_message_ids", []),
                    "replaced_count": data_dict.get("replaced_count", 0),
                },
            )
            self._put(
                event_type="middleware:summarize",
                category="middleware",
                content={"role": "system", "content": data_dict.get("summary", "")},
                metadata={"replaced_count": data_dict.get("replaced_count", 0)},
            )
            return

        event_data = serialize_lc_object(data) if not isinstance(data, dict) else data
        self._put(
            event_type=name,
            category="trace",
            metadata=event_data if isinstance(event_data, dict) else {"data": event_data},
        )

    async def flush(self) -> None:
        """Write all outstanding events to the store.

        Background flushes that are still running on the current loop are
        awaited first, because a failed background flush puts its batch back
        into the buffer. The buffer is then written with ``put_batch``.

        Raises:
            Exception: Whatever ``put_batch`` raises. The batch is returned
                to the buffer before the exception propagates.
        """
        loop = asyncio.get_running_loop()
        in_flight = [
            task
            for task in self._pending_flushes
            if not task.done() and task.get_loop() is loop
        ]
        if in_flight:
            await asyncio.gather(*in_flight, return_exceptions=True)

        if not self._buffer:
            return
        batch = self._buffer.copy()
        self._buffer.clear()
        try:
            await self._store.put_batch(batch)
        except Exception:
            # Keep the events so a later flush can retry them.
            self._buffer = batch + self._buffer
            raise

    def _put(
        self,
        *,
        event_type: str,
        category: str,
        content: Any = "",
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Buffer one event and schedule a flush once the threshold is hit.

        String content of non-``message`` events longer than
        ``max_trace_content`` is truncated; the metadata then carries
        ``content_truncated`` and ``original_content_length``.
        """
        normalized_metadata = dict(metadata or {})
        if category != "message" and isinstance(content, str) and len(content) > self._max_trace_content:
            normalized_metadata["content_truncated"] = True
            normalized_metadata["original_content_length"] = len(content)
            content = content[: self._max_trace_content]

        self._buffer.append(
            {
                "thread_id": self.thread_id,
                "run_id": self.run_id,
                "event_type": event_type,
                "category": category,
                "content": content,
                "metadata": normalized_metadata,
                "created_at": datetime.now(UTC).isoformat(),
            }
        )
        if len(self._buffer) >= self._flush_threshold:
            self._flush_sync()

    def _flush_sync(self) -> None:
        """Hand the buffer to a background task on the running event loop.

        Without a running loop in the current thread the buffer is left
        untouched so that a later flush can pick it up.
        """
        if not self._buffer:
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return
        batch = self._buffer.copy()
        self._buffer.clear()
        task = loop.create_task(self._flush_async(batch))
        # Hold a strong reference until the task completes.
        self._pending_flushes.add(task)
        task.add_done_callback(self._pending_flushes.discard)
        task.add_done_callback(self._on_flush_done)

    async def _flush_async(self, batch: list[dict[str, Any]]) -> None:
        """Write ``batch``; on failure log and put it back at the buffer front."""
        try:
            await self._store.put_batch(batch)
        except Exception:
            logger.warning(
                "Failed to flush %d events for run %s; returning to buffer",
                len(batch),
                self.run_id,
                exc_info=True,
            )
            self._buffer = batch + self._buffer

    @staticmethod
    def _on_flush_done(task: asyncio.Task) -> None:
        """Log an exception raised by a finished background flush task."""
        if task.cancelled():
            return
        exc = task.exception()
        if exc:
            logger.warning("Run event flush task failed: %s", exc)

    def _identify_caller(self, kwargs: dict[str, Any]) -> str:
        """Return the first caller tag in ``kwargs["tags"]``.

        A caller tag is ``"lead_agent"`` or any string starting with
        ``"subagent:"`` or ``"middleware:"``. Defaults to ``"lead_agent"``.
        """
        for tag in kwargs.get("tags") or []:
            if not isinstance(tag, str):
                continue
            if tag == "lead_agent" or tag.startswith(("subagent:", "middleware:")):
                return tag
        return "lead_agent"

    def _record_first_human_message(self, messages: list[Any], *, caller: str) -> None:
        """Record the first non-summary ``HumanMessage`` once per instance."""
        if self._human_message_recorded:
            return

        for message in messages:
            if not isinstance(message, HumanMessage):
                continue
            # Human messages named "summary" are skipped.
            if message.name == "summary":
                continue
            self._put(
                event_type="human_message",
                category="message",
                content=message.model_dump(),
                metadata={
                    "caller": caller,
                    "source": "chat_model_start",
                },
            )
            self._human_message_recorded = True
            return
|
||||
@@ -0,0 +1,51 @@
|
||||
"""Title capture callback for runs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
|
||||
|
||||
class RunTitleCallback(BaseCallbackHandler):
    """Remember the latest title seen in title-middleware LLM calls or custom events."""

    def __init__(self) -> None:
        super().__init__()
        self._title: str | None = None

    def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Store the response text of an LLM call tagged ``middleware:title``.

        The text is stripped of surrounding whitespace, then of double and
        single quotes, and capped at 200 characters. Responses without a
        readable message or with empty / non-string content are ignored.
        """
        caller = self._identify_caller(kwargs)
        if caller == "middleware:title":
            try:
                message = response.generations[0][0].message
            except (IndexError, AttributeError):
                return
            text = getattr(message, "content", "")
            if isinstance(text, str) and text:
                unquoted = text.strip().strip('"').strip("'")
                self._title = unquoted[:200]

    def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Store a title delivered by a ``title``-style custom event.

        ``data`` may be the title string itself or a dict holding it under
        ``"title"``. The value is stripped and capped at 200 characters.
        """
        if name not in {"title", "thread_title", "middleware:title"}:
            return
        candidate = data.get("title") if isinstance(data, dict) else data
        if isinstance(candidate, str):
            self._title = candidate.strip()[:200]

    def title(self) -> str | None:
        """Return the captured title, or ``None`` if nothing was captured."""
        return self._title

    def _identify_caller(self, kwargs: dict[str, Any]) -> str:
        """Return the first caller tag in ``kwargs["tags"]``, else ``"lead_agent"``."""
        for tag in kwargs.get("tags") or []:
            if not isinstance(tag, str):
                continue
            if tag == "lead_agent" or tag.startswith(("subagent:", "middleware:")):
                return tag
        return "lead_agent"
|
||||
@@ -0,0 +1,122 @@
|
||||
"""Token and message summary callback for runs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RunCompletionData:
    """Immutable token and message summary; every field has a zero/None default."""

    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_tokens: int = 0
    llm_call_count: int = 0
    lead_agent_tokens: int = 0
    subagent_tokens: int = 0
    middleware_tokens: int = 0
    message_count: int = 0
    last_ai_message: str | None = None
    first_human_message: str | None = None

    def to_dict(self) -> dict[str, object]:
        """Return all fields as a plain dict, in declaration order."""
        exported = (
            "total_input_tokens",
            "total_output_tokens",
            "total_tokens",
            "llm_call_count",
            "lead_agent_tokens",
            "subagent_tokens",
            "middleware_tokens",
            "message_count",
            "last_ai_message",
            "first_human_message",
        )
        return {field_name: getattr(self, field_name) for field_name in exported}
|
||||
|
||||
|
||||
class RunTokenCallback(BaseCallbackHandler):
    """Accumulate token usage and message summary data over one run."""

    def __init__(self, *, track_token_usage: bool = True) -> None:
        """Start with all counters at zero.

        Args:
            track_token_usage: When false, token counters are never updated;
                AI message tracking still happens.
        """
        super().__init__()
        self._track_token_usage = track_token_usage
        self._total_input_tokens = 0
        self._total_output_tokens = 0
        self._total_tokens = 0
        self._llm_call_count = 0
        self._lead_agent_tokens = 0
        self._subagent_tokens = 0
        self._middleware_tokens = 0
        self._message_count = 0
        self._last_ai_message: str | None = None
        self._first_human_message: str | None = None

    def set_first_human_message(self, content: str) -> None:
        """Store up to 2000 characters of ``content``; falsy input stores ``None``."""
        if content:
            self._first_human_message = content[:2000]
        else:
            self._first_human_message = None

    def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Fold one LLM response into the running totals.

        The AI message summary is updated first. Token counters are only
        touched when tracking is enabled and the response reports a positive
        total; that total is attributed to the subagent, middleware or lead
        agent bucket according to the caller tag.
        """
        try:
            message = response.generations[0][0].message
        except (IndexError, AttributeError):
            return

        self._record_ai_message(message, kwargs)
        if not self._track_token_usage:
            return

        usage = dict(getattr(message, "usage_metadata", None) or {})
        prompt_tokens = usage.get("input_tokens", 0) or 0
        completion_tokens = usage.get("output_tokens", 0) or 0
        # Fall back to input + output when no total is reported.
        combined = usage.get("total_tokens", 0) or prompt_tokens + completion_tokens
        if combined <= 0:
            return

        self._total_input_tokens += prompt_tokens
        self._total_output_tokens += completion_tokens
        self._total_tokens += combined
        self._llm_call_count += 1

        caller = self._identify_caller(kwargs)
        if caller.startswith("subagent:"):
            self._subagent_tokens += combined
        elif caller.startswith("middleware:"):
            self._middleware_tokens += combined
        else:
            self._lead_agent_tokens += combined

    def completion_data(self) -> RunCompletionData:
        """Return a snapshot of everything accumulated so far."""
        return RunCompletionData(
            total_input_tokens=self._total_input_tokens,
            total_output_tokens=self._total_output_tokens,
            total_tokens=self._total_tokens,
            llm_call_count=self._llm_call_count,
            lead_agent_tokens=self._lead_agent_tokens,
            subagent_tokens=self._subagent_tokens,
            middleware_tokens=self._middleware_tokens,
            message_count=self._message_count,
            last_ai_message=self._last_ai_message,
            first_human_message=self._first_human_message,
        )

    def _record_ai_message(self, message: Any, kwargs: dict[str, Any]) -> None:
        """Track lead-agent messages that have text content and no tool calls.

        Stores up to 2000 characters as the last AI message and increments
        the message count.
        """
        from_lead = self._identify_caller(kwargs) == "lead_agent"
        if not from_lead or getattr(message, "tool_calls", None):
            return
        text = getattr(message, "content", "")
        if isinstance(text, str) and text:
            self._last_ai_message = text[:2000]
            self._message_count += 1

    def _identify_caller(self, kwargs: dict[str, Any]) -> str:
        """Return the first caller tag in ``kwargs["tags"]``, else ``"lead_agent"``."""
        for tag in kwargs.get("tags") or []:
            if not isinstance(tag, str):
                continue
            if tag == "lead_agent" or tag.startswith(("subagent:", "middleware:")):
                return tag
        return "lead_agent"
|
||||
Reference in New Issue
Block a user