fix(replay-e2e): key fixtures by caller and conversation (#3453)

* add caller identity in replay e2e

* make format

* fix(replay-e2e): stabilize title caller replay

* fix(replay-e2e): use captured caller without run manager

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
Nan Gao
2026-06-09 15:58:31 +02:00
committed by GitHub
parent 37337b77f9
commit 63ce88f874
7 changed files with 308 additions and 34 deletions
+137 -13
View File
@@ -2,14 +2,19 @@
record/replay e2e (mirrors open-design's ``mocks/`` golden traces).
A fixture is a JSON file capturing the *real* model calls of one scenario,
keyed by a normalized hash of the **input** each call received::
keyed by a normalized hash of the **caller + input** each call received::
{
"scenario": "write_read_file",
"mode": "ultra",
"model": "gpt-5.5",
"turns": [
{"input_hash": "<sha256>", "input_preview": "...", "output": <message dict>},
{
"caller": "lead_agent",
"conversation_hash": "<sha256>",
"input_hash": "<sha256>",
"output": <message dict>,
},
...
]
}
@@ -21,8 +26,11 @@ A real run makes model calls from several callers — the lead agent's own turns
and their count/order is not something we want a replay to depend on. Matching by
a normalized hash of the *input messages* means each call gets back exactly the
output that was recorded for that input, regardless of order or which middleware
issued it. That keeps the in-graph, deterministic title call part of the
recording; memory/summarization, by contrast, are disabled in the replay config
issued it. The caller name (``lead_agent``, ``middleware:title``,
``suggest_agent``, ``subagent:*``, ...) is included so two different model
callers with the same conversation text do not compete for the same replay
bucket. That keeps the in-graph, deterministic title call part of the recording;
memory/summarization, by contrast, are disabled in the replay config
(``_replay_fixture.py``) because their background, debounced timing is not
reproducible across runs.
@@ -67,7 +75,7 @@ from collections import deque
from collections.abc import Iterator
from typing import Any
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.callbacks import BaseCallbackHandler, CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, messages_from_dict
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
@@ -75,6 +83,14 @@ from langchain_core.runnables import Runnable
from pydantic import PrivateAttr
_FIXTURE_ENV = "DEERFLOW_REPLAY_FIXTURE"
_DEFAULT_CALLER = "lead_agent"
_CALLER_TAG_PREFIXES = ("middleware:", "subagent:")
_CALLER_NAME_ALIASES = {
# TitleMiddleware uses this run_name and tags the call as middleware:title.
# Some execution paths do not preserve the tag down to the model callback,
# so keep the run_name and tag in the same replay namespace.
"title_agent": "middleware:title",
}
# Process-wide record of replay misses. A miss raises inside the model, but the
# gateway's LLMErrorHandlingMiddleware swallows it into a normal assistant error
@@ -94,6 +110,30 @@ def reset_replay_misses() -> None:
_replay_misses.clear()
def _normalize_caller(caller: str | None) -> str:
value = _normalize_text(str(caller or "").strip())
if not value:
return _DEFAULT_CALLER
return _CALLER_NAME_ALIASES.get(value, value)
def _caller_from_tags(tags: list[str] | None) -> str | None:
for tag in tags or []:
if isinstance(tag, str) and (tag == _DEFAULT_CALLER or tag.startswith(_CALLER_TAG_PREFIXES)):
return tag
return None
def caller_identity(*, name: str | None = None, tags: list[str] | None = None) -> str:
"""Stable model-caller identity shared by record and replay.
Tags win because graph middleware and subagents already use them as the
explicit caller marker. ``run_name`` is exposed to callbacks as ``name`` and
covers route-level callers such as ``suggest_agent``.
"""
return _normalize_caller(_caller_from_tags(tags) or name)
# Volatile substrings that differ between a recording run and a replay run but
# carry no semantic weight for matching. Normalized to stable placeholders
# before hashing so the same logical input hashes identically across processes.
@@ -172,10 +212,30 @@ def _canonical_messages(messages: list[BaseMessage]) -> str:
def hash_messages(messages: list[BaseMessage]) -> str:
"""Stable hash of a model call's input. Shared by recorder and replayer."""
"""Legacy stable hash of only a model call's conversation input."""
return hashlib.sha256(_canonical_messages(messages).encode("utf-8")).hexdigest()
def hash_replay_input(messages: list[BaseMessage], *, caller: str | None) -> str:
"""Stable replay key for a caller-specific model input."""
return hash_input_key(hash_messages(messages), caller=caller)
def hash_input_key(conversation_hash: str, *, caller: str | None) -> str:
"""Namespace a conversation hash by caller identity.
Keeping this as ``hash(caller + legacy_conversation_hash)`` lets existing
fixtures migrate without a live-model re-record: their old ``input_hash`` is
exactly the conversation hash.
"""
payload = json.dumps(
{"caller": _normalize_caller(caller), "conversation_hash": conversation_hash},
sort_keys=True,
ensure_ascii=False,
)
return hashlib.sha256(payload.encode("utf-8")).hexdigest()
def _load_fixture(fixture_path: str) -> dict[str, deque[AIMessage]]:
with open(fixture_path, encoding="utf-8") as handle:
payload = json.load(handle)
@@ -199,24 +259,54 @@ class ReplayChatModel(BaseChatModel):
_table: dict[str, deque] = PrivateAttr(default_factory=dict)
_fixture_path: str = PrivateAttr(default="")
_run_callers: dict[str, str] = PrivateAttr(default_factory=dict)
def __init__(self, **kwargs: Any) -> None:
# Ignore provider noise the factory forwards from config (model, api_key,
# base_url, ...). Fixture path comes from the ``fixture`` kwarg or env.
fixture_path = kwargs.pop("fixture", None) or os.environ.get(_FIXTURE_ENV)
super().__init__()
callbacks = kwargs.pop("callbacks", None)
super().__init__(callbacks=callbacks)
if not fixture_path:
raise ValueError(f"ReplayChatModel needs a fixture path via the ``fixture`` kwarg or ${_FIXTURE_ENV}")
self._fixture_path = fixture_path
self._table = _load_fixture(fixture_path)
self.callbacks = [*(self.callbacks or []), _ReplayCallerCapture(self._run_callers)]
@property
def _llm_type(self) -> str:
return "deerflow-replay"
def _match(self, messages: list[BaseMessage]) -> AIMessage:
key = hash_messages(messages)
def _caller_from_run_manager(self, run_manager: CallbackManagerForLLMRun | None) -> str:
if run_manager is None:
if len(self._run_callers) == 1:
# Some async LangGraph paths fire on_chat_model_start with the
# caller metadata but invoke the model implementation without a
# run_manager. When there is only one pending start event, it is
# the current call; use it so record/replay share the same
# caller key.
return self._run_callers.pop(next(iter(self._run_callers)))
return _DEFAULT_CALLER
run_id = str(getattr(run_manager, "run_id", ""))
caller = self._run_callers.pop(run_id, None)
if caller:
return caller
return caller_identity(
name=getattr(run_manager, "run_name", None) or getattr(run_manager, "name", None),
tags=getattr(run_manager, "tags", None),
)
def _match(self, messages: list[BaseMessage], run_manager: CallbackManagerForLLMRun | None = None) -> AIMessage:
caller = self._caller_from_run_manager(run_manager)
key = hash_replay_input(messages, caller=caller)
bucket = self._table.get(key)
if not bucket:
# Backward compatibility for fixtures recorded before caller-aware
# keys. New recordings write caller-aware ``input_hash`` values.
legacy_key = hash_messages(messages)
bucket = self._table.get(legacy_key)
if bucket:
key = legacy_key
if not bucket:
_replay_misses.append(key)
preview = _canonical_messages(messages)
@@ -224,6 +314,7 @@ class ReplayChatModel(BaseChatModel):
f"replay miss: no recorded output for input hash {key} in {self._fixture_path!r}. "
"The replayed run diverged from the recording (graph changed, a non-deterministic tool result "
"altered a downstream input, or a volatile field slipped past normalization). "
f"Caller: {caller!r}. "
f"Known hashes: {sorted(self._table)}. "
f"Normalized input (first 800 chars): {preview[:800]!r}"
)
@@ -236,7 +327,7 @@ class ReplayChatModel(BaseChatModel):
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
return ChatResult(generations=[ChatGeneration(message=self._match(messages))])
return ChatResult(generations=[ChatGeneration(message=self._match(messages, run_manager))])
def _stream(
self,
@@ -245,9 +336,16 @@ class ReplayChatModel(BaseChatModel):
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
turn = self._match(messages)
turn = self._match(messages, run_manager)
text = turn.content if isinstance(turn.content, str) else ""
chunk = ChatGenerationChunk(message=AIMessageChunk(content=turn.content, tool_calls=turn.tool_calls, additional_kwargs=turn.additional_kwargs, id=turn.id))
chunk = ChatGenerationChunk(
message=AIMessageChunk(
content=turn.content,
tool_calls=turn.tool_calls,
additional_kwargs=turn.additional_kwargs,
id=turn.id,
)
)
if run_manager is not None and text:
run_manager.on_llm_new_token(text, chunk=chunk)
yield chunk
@@ -256,5 +354,31 @@ class ReplayChatModel(BaseChatModel):
return self
class _ReplayCallerCapture(BaseCallbackHandler):
def __init__(self, run_callers: dict[str, str]) -> None:
self._run_callers = run_callers
def on_chat_model_start(
self,
serialized: dict,
messages: list[list[BaseMessage]],
*,
run_id: Any = None,
tags: list[str] | None = None,
name: str | None = None,
**kwargs: Any,
) -> None:
if run_id is not None:
self._run_callers[str(run_id)] = caller_identity(name=name, tags=tags)
# Re-export so the recorder shares the exact hashing logic.
__all__ = ["ReplayChatModel", "hash_messages", "replay_misses", "reset_replay_misses"]
__all__ = [
"ReplayChatModel",
"caller_identity",
"hash_input_key",
"hash_messages",
"hash_replay_input",
"replay_misses",
"reset_replay_misses",
]