update the code with review comments

This commit is contained in:
Willem Jiang
2026-05-16 17:07:53 +08:00
parent ba99a23814
commit 7752e74e2b
2 changed files with 34 additions and 13 deletions
@@ -194,7 +194,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
formatted = get_buffer_string(trimmed) formatted = get_buffer_string(trimmed)
try: try:
response = self.model.invoke( response = self.model.with_config(callbacks=[]).invoke(
self.summary_prompt.format(messages=formatted).rstrip(), self.summary_prompt.format(messages=formatted).rstrip(),
config={ config={
"metadata": {"lc_source": "summarization"}, "metadata": {"lc_source": "summarization"},
@@ -223,7 +223,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
formatted = get_buffer_string(trimmed) formatted = get_buffer_string(trimmed)
try: try:
response = await self.model.ainvoke( response = await self.model.with_config(callbacks=[]).ainvoke(
self.summary_prompt.format(messages=formatted).rstrip(), self.summary_prompt.format(messages=formatted).rstrip(),
config={ config={
"metadata": {"lc_source": "summarization"}, "metadata": {"lc_source": "summarization"},
@@ -235,9 +235,11 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
return f"Error generating summary: {e!s}" return f"Error generating summary: {e!s}"
def _extract_summary_text(self, response: Any) -> str: def _extract_summary_text(self, response: Any) -> str:
summary_text = getattr(response, "content", None) # Prefer .text which normalizes list content blocks (e.g. [{"type": "text", "text": "..."}]).
# Fall back to .content for non-LangChain responses.
summary_text = getattr(response, "text", None)
if summary_text is None: if summary_text is None:
summary_text = getattr(response, "text", "") summary_text = getattr(response, "content", "")
return summary_text.strip() if isinstance(summary_text, str) else str(summary_text).strip() return summary_text.strip() if isinstance(summary_text, str) else str(summary_text).strip()
@override @override
+28 -9
View File
@@ -57,6 +57,7 @@ def _middleware(
) -> DeerFlowSummarizationMiddleware: ) -> DeerFlowSummarizationMiddleware:
model = MagicMock() model = MagicMock()
model.invoke.return_value = AIMessage(content="compressed summary") model.invoke.return_value = AIMessage(content="compressed summary")
model.with_config.return_value.invoke.return_value = AIMessage(content="compressed summary")
return DeerFlowSummarizationMiddleware( return DeerFlowSummarizationMiddleware(
model=model, model=model,
trigger=trigger, trigger=trigger,
@@ -660,14 +661,16 @@ def test_build_new_messages_sets_hide_from_ui() -> None:
def test_create_summary_suppresses_callbacks() -> None: def test_create_summary_suppresses_callbacks() -> None:
"""_create_summary must pass callbacks=[] to prevent the internal LLM call """_create_summary must bind callbacks=[] on the model AND pass callbacks=[]
from producing visible streaming events in the frontend (issue #2804).""" in the invoke config to suppress inherited LangGraph stream callbacks."""
middleware = _middleware() middleware = _middleware()
middleware._create_summary(_messages()) middleware._create_summary(_messages())
middleware.model.invoke.assert_called_once() middleware.model.with_config.assert_called_once_with(callbacks=[])
call_config = middleware.model.invoke.call_args.kwargs.get("config") or middleware.model.invoke.call_args[1].get("config") bound = middleware.model.with_config.return_value
bound.invoke.assert_called_once()
call_config = bound.invoke.call_args.kwargs.get("config") or bound.invoke.call_args[1].get("config")
assert call_config is not None assert call_config is not None
assert call_config.get("callbacks") == [] assert call_config.get("callbacks") == []
assert call_config.get("metadata", {}).get("lc_source") == "summarization" assert call_config.get("metadata", {}).get("lc_source") == "summarization"
@@ -675,15 +678,17 @@ def test_create_summary_suppresses_callbacks() -> None:
@pytest.mark.anyio @pytest.mark.anyio
async def test_acreate_summary_suppresses_callbacks() -> None: async def test_acreate_summary_suppresses_callbacks() -> None:
"""_acreate_summary must pass callbacks=[] to prevent the internal LLM call """_acreate_summary must bind callbacks=[] on the model AND pass callbacks=[]
from producing visible streaming events in the frontend (issue #2804).""" in the ainvoke config to suppress inherited LangGraph stream callbacks."""
middleware = _middleware() middleware = _middleware()
middleware.model.ainvoke = mock.AsyncMock(return_value=AIMessage(content="async summary")) middleware.model.with_config.return_value.ainvoke = mock.AsyncMock(return_value=AIMessage(content="async summary"))
await middleware._acreate_summary(_messages()) await middleware._acreate_summary(_messages())
middleware.model.ainvoke.assert_called_once() middleware.model.with_config.assert_called_once_with(callbacks=[])
call_config = middleware.model.ainvoke.call_args.kwargs.get("config") or middleware.model.ainvoke.call_args[1].get("config") bound = middleware.model.with_config.return_value
bound.ainvoke.assert_called_once()
call_config = bound.ainvoke.call_args.kwargs.get("config") or bound.ainvoke.call_args[1].get("config")
assert call_config is not None assert call_config is not None
assert call_config.get("callbacks") == [] assert call_config.get("callbacks") == []
assert call_config.get("metadata", {}).get("lc_source") == "summarization" assert call_config.get("metadata", {}).get("lc_source") == "summarization"
@@ -718,3 +723,17 @@ def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatc
queue.add_nowait.assert_called_once() queue.add_nowait.assert_called_once()
assert queue.add_nowait.call_args.kwargs["user_id"] == "alice" assert queue.add_nowait.call_args.kwargs["user_id"] == "alice"
def test_extract_summary_text_normalizes_list_content_blocks() -> None:
"""AIMessage.content can be a list of content blocks; _extract_summary_text
must normalize to plain text via the .text property instead of producing
a Python repr like [{'type': 'text', 'text': 'summary'}]."""
middleware = _middleware()
response = AIMessage(content=[{"type": "text", "text": "A summary of the chat."}])
assert middleware._extract_summary_text(response) == "A summary of the chat."
# Plain string content still works
response_str = AIMessage(content="Plain summary")
assert middleware._extract_summary_text(response_str) == "Plain summary"