update the code with review comments
This commit is contained in:
@@ -194,7 +194,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
|||||||
formatted = get_buffer_string(trimmed)
|
formatted = get_buffer_string(trimmed)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.model.invoke(
|
response = self.model.with_config(callbacks=[]).invoke(
|
||||||
self.summary_prompt.format(messages=formatted).rstrip(),
|
self.summary_prompt.format(messages=formatted).rstrip(),
|
||||||
config={
|
config={
|
||||||
"metadata": {"lc_source": "summarization"},
|
"metadata": {"lc_source": "summarization"},
|
||||||
@@ -223,7 +223,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
|||||||
formatted = get_buffer_string(trimmed)
|
formatted = get_buffer_string(trimmed)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await self.model.ainvoke(
|
response = await self.model.with_config(callbacks=[]).ainvoke(
|
||||||
self.summary_prompt.format(messages=formatted).rstrip(),
|
self.summary_prompt.format(messages=formatted).rstrip(),
|
||||||
config={
|
config={
|
||||||
"metadata": {"lc_source": "summarization"},
|
"metadata": {"lc_source": "summarization"},
|
||||||
@@ -235,9 +235,11 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
|||||||
return f"Error generating summary: {e!s}"
|
return f"Error generating summary: {e!s}"
|
||||||
|
|
||||||
def _extract_summary_text(self, response: Any) -> str:
|
def _extract_summary_text(self, response: Any) -> str:
|
||||||
summary_text = getattr(response, "content", None)
|
# Prefer .text which normalizes list content blocks (e.g. [{"type": "text", "text": "..."}]).
|
||||||
|
# Fall back to .content for non-LangChain responses.
|
||||||
|
summary_text = getattr(response, "text", None)
|
||||||
if summary_text is None:
|
if summary_text is None:
|
||||||
summary_text = getattr(response, "text", "")
|
summary_text = getattr(response, "content", "")
|
||||||
return summary_text.strip() if isinstance(summary_text, str) else str(summary_text).strip()
|
return summary_text.strip() if isinstance(summary_text, str) else str(summary_text).strip()
|
||||||
|
|
||||||
@override
|
@override
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ def _middleware(
|
|||||||
) -> DeerFlowSummarizationMiddleware:
|
) -> DeerFlowSummarizationMiddleware:
|
||||||
model = MagicMock()
|
model = MagicMock()
|
||||||
model.invoke.return_value = AIMessage(content="compressed summary")
|
model.invoke.return_value = AIMessage(content="compressed summary")
|
||||||
|
model.with_config.return_value.invoke.return_value = AIMessage(content="compressed summary")
|
||||||
return DeerFlowSummarizationMiddleware(
|
return DeerFlowSummarizationMiddleware(
|
||||||
model=model,
|
model=model,
|
||||||
trigger=trigger,
|
trigger=trigger,
|
||||||
@@ -660,14 +661,16 @@ def test_build_new_messages_sets_hide_from_ui() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_create_summary_suppresses_callbacks() -> None:
|
def test_create_summary_suppresses_callbacks() -> None:
|
||||||
"""_create_summary must pass callbacks=[] to prevent the internal LLM call
|
"""_create_summary must bind callbacks=[] on the model AND pass callbacks=[]
|
||||||
from producing visible streaming events in the frontend (issue #2804)."""
|
in the invoke config to suppress inherited LangGraph stream callbacks."""
|
||||||
middleware = _middleware()
|
middleware = _middleware()
|
||||||
|
|
||||||
middleware._create_summary(_messages())
|
middleware._create_summary(_messages())
|
||||||
|
|
||||||
middleware.model.invoke.assert_called_once()
|
middleware.model.with_config.assert_called_once_with(callbacks=[])
|
||||||
call_config = middleware.model.invoke.call_args.kwargs.get("config") or middleware.model.invoke.call_args[1].get("config")
|
bound = middleware.model.with_config.return_value
|
||||||
|
bound.invoke.assert_called_once()
|
||||||
|
call_config = bound.invoke.call_args.kwargs.get("config") or bound.invoke.call_args[1].get("config")
|
||||||
assert call_config is not None
|
assert call_config is not None
|
||||||
assert call_config.get("callbacks") == []
|
assert call_config.get("callbacks") == []
|
||||||
assert call_config.get("metadata", {}).get("lc_source") == "summarization"
|
assert call_config.get("metadata", {}).get("lc_source") == "summarization"
|
||||||
@@ -675,15 +678,17 @@ def test_create_summary_suppresses_callbacks() -> None:
|
|||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_acreate_summary_suppresses_callbacks() -> None:
|
async def test_acreate_summary_suppresses_callbacks() -> None:
|
||||||
"""_acreate_summary must pass callbacks=[] to prevent the internal LLM call
|
"""_acreate_summary must bind callbacks=[] on the model AND pass callbacks=[]
|
||||||
from producing visible streaming events in the frontend (issue #2804)."""
|
in the ainvoke config to suppress inherited LangGraph stream callbacks."""
|
||||||
middleware = _middleware()
|
middleware = _middleware()
|
||||||
middleware.model.ainvoke = mock.AsyncMock(return_value=AIMessage(content="async summary"))
|
middleware.model.with_config.return_value.ainvoke = mock.AsyncMock(return_value=AIMessage(content="async summary"))
|
||||||
|
|
||||||
await middleware._acreate_summary(_messages())
|
await middleware._acreate_summary(_messages())
|
||||||
|
|
||||||
middleware.model.ainvoke.assert_called_once()
|
middleware.model.with_config.assert_called_once_with(callbacks=[])
|
||||||
call_config = middleware.model.ainvoke.call_args.kwargs.get("config") or middleware.model.ainvoke.call_args[1].get("config")
|
bound = middleware.model.with_config.return_value
|
||||||
|
bound.ainvoke.assert_called_once()
|
||||||
|
call_config = bound.ainvoke.call_args.kwargs.get("config") or bound.ainvoke.call_args[1].get("config")
|
||||||
assert call_config is not None
|
assert call_config is not None
|
||||||
assert call_config.get("callbacks") == []
|
assert call_config.get("callbacks") == []
|
||||||
assert call_config.get("metadata", {}).get("lc_source") == "summarization"
|
assert call_config.get("metadata", {}).get("lc_source") == "summarization"
|
||||||
@@ -718,3 +723,17 @@ def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatc
|
|||||||
|
|
||||||
queue.add_nowait.assert_called_once()
|
queue.add_nowait.assert_called_once()
|
||||||
assert queue.add_nowait.call_args.kwargs["user_id"] == "alice"
|
assert queue.add_nowait.call_args.kwargs["user_id"] == "alice"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_summary_text_normalizes_list_content_blocks() -> None:
|
||||||
|
"""AIMessage.content can be a list of content blocks; _extract_summary_text
|
||||||
|
must normalize to plain text via the .text property instead of producing
|
||||||
|
a Python repr like [{'type': 'text', 'text': 'summary'}]."""
|
||||||
|
middleware = _middleware()
|
||||||
|
|
||||||
|
response = AIMessage(content=[{"type": "text", "text": "A summary of the chat."}])
|
||||||
|
assert middleware._extract_summary_text(response) == "A summary of the chat."
|
||||||
|
|
||||||
|
# Plain string content still works
|
||||||
|
response_str = AIMessage(content="Plain summary")
|
||||||
|
assert middleware._extract_summary_text(response_str) == "Plain summary"
|
||||||
|
|||||||
Reference in New Issue
Block a user