update the code with review comments
This commit is contained in:
@@ -194,7 +194,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
||||
formatted = get_buffer_string(trimmed)
|
||||
|
||||
try:
|
||||
response = self.model.invoke(
|
||||
response = self.model.with_config(callbacks=[]).invoke(
|
||||
self.summary_prompt.format(messages=formatted).rstrip(),
|
||||
config={
|
||||
"metadata": {"lc_source": "summarization"},
|
||||
@@ -223,7 +223,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
||||
formatted = get_buffer_string(trimmed)
|
||||
|
||||
try:
|
||||
response = await self.model.ainvoke(
|
||||
response = await self.model.with_config(callbacks=[]).ainvoke(
|
||||
self.summary_prompt.format(messages=formatted).rstrip(),
|
||||
config={
|
||||
"metadata": {"lc_source": "summarization"},
|
||||
@@ -235,9 +235,11 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
||||
return f"Error generating summary: {e!s}"
|
||||
|
||||
def _extract_summary_text(self, response: Any) -> str:
|
||||
summary_text = getattr(response, "content", None)
|
||||
# Prefer .text which normalizes list content blocks (e.g. [{"type": "text", "text": "..."}]).
|
||||
# Fall back to .content for non-LangChain responses.
|
||||
summary_text = getattr(response, "text", None)
|
||||
if summary_text is None:
|
||||
summary_text = getattr(response, "text", "")
|
||||
summary_text = getattr(response, "content", "")
|
||||
return summary_text.strip() if isinstance(summary_text, str) else str(summary_text).strip()
|
||||
|
||||
@override
|
||||
|
||||
@@ -57,6 +57,7 @@ def _middleware(
|
||||
) -> DeerFlowSummarizationMiddleware:
|
||||
model = MagicMock()
|
||||
model.invoke.return_value = AIMessage(content="compressed summary")
|
||||
model.with_config.return_value.invoke.return_value = AIMessage(content="compressed summary")
|
||||
return DeerFlowSummarizationMiddleware(
|
||||
model=model,
|
||||
trigger=trigger,
|
||||
@@ -660,14 +661,16 @@ def test_build_new_messages_sets_hide_from_ui() -> None:
|
||||
|
||||
|
||||
def test_create_summary_suppresses_callbacks() -> None:
|
||||
"""_create_summary must pass callbacks=[] to prevent the internal LLM call
|
||||
from producing visible streaming events in the frontend (issue #2804)."""
|
||||
"""_create_summary must bind callbacks=[] on the model AND pass callbacks=[]
|
||||
in the invoke config to suppress inherited LangGraph stream callbacks."""
|
||||
middleware = _middleware()
|
||||
|
||||
middleware._create_summary(_messages())
|
||||
|
||||
middleware.model.invoke.assert_called_once()
|
||||
call_config = middleware.model.invoke.call_args.kwargs.get("config") or middleware.model.invoke.call_args[1].get("config")
|
||||
middleware.model.with_config.assert_called_once_with(callbacks=[])
|
||||
bound = middleware.model.with_config.return_value
|
||||
bound.invoke.assert_called_once()
|
||||
call_config = bound.invoke.call_args.kwargs.get("config") or bound.invoke.call_args[1].get("config")
|
||||
assert call_config is not None
|
||||
assert call_config.get("callbacks") == []
|
||||
assert call_config.get("metadata", {}).get("lc_source") == "summarization"
|
||||
@@ -675,15 +678,17 @@ def test_create_summary_suppresses_callbacks() -> None:
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_acreate_summary_suppresses_callbacks() -> None:
|
||||
"""_acreate_summary must pass callbacks=[] to prevent the internal LLM call
|
||||
from producing visible streaming events in the frontend (issue #2804)."""
|
||||
"""_acreate_summary must bind callbacks=[] on the model AND pass callbacks=[]
|
||||
in the ainvoke config to suppress inherited LangGraph stream callbacks."""
|
||||
middleware = _middleware()
|
||||
middleware.model.ainvoke = mock.AsyncMock(return_value=AIMessage(content="async summary"))
|
||||
middleware.model.with_config.return_value.ainvoke = mock.AsyncMock(return_value=AIMessage(content="async summary"))
|
||||
|
||||
await middleware._acreate_summary(_messages())
|
||||
|
||||
middleware.model.ainvoke.assert_called_once()
|
||||
call_config = middleware.model.ainvoke.call_args.kwargs.get("config") or middleware.model.ainvoke.call_args[1].get("config")
|
||||
middleware.model.with_config.assert_called_once_with(callbacks=[])
|
||||
bound = middleware.model.with_config.return_value
|
||||
bound.ainvoke.assert_called_once()
|
||||
call_config = bound.ainvoke.call_args.kwargs.get("config") or bound.ainvoke.call_args[1].get("config")
|
||||
assert call_config is not None
|
||||
assert call_config.get("callbacks") == []
|
||||
assert call_config.get("metadata", {}).get("lc_source") == "summarization"
|
||||
@@ -718,3 +723,17 @@ def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatc
|
||||
|
||||
queue.add_nowait.assert_called_once()
|
||||
assert queue.add_nowait.call_args.kwargs["user_id"] == "alice"
|
||||
|
||||
|
||||
def test_extract_summary_text_normalizes_list_content_blocks() -> None:
|
||||
"""AIMessage.content can be a list of content blocks; _extract_summary_text
|
||||
must normalize to plain text via the .text property instead of producing
|
||||
a Python repr like [{'type': 'text', 'text': 'summary'}]."""
|
||||
middleware = _middleware()
|
||||
|
||||
response = AIMessage(content=[{"type": "text", "text": "A summary of the chat."}])
|
||||
assert middleware._extract_summary_text(response) == "A summary of the chat."
|
||||
|
||||
# Plain string content still works
|
||||
response_str = AIMessage(content="Plain summary")
|
||||
assert middleware._extract_summary_text(response_str) == "Plain summary"
|
||||
|
||||
Reference in New Issue
Block a user