mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 09:25:57 +00:00
fix(chat): preserve messages after summarization (#3280)
* fix(chat): preserve messages after summarization * make format * fix(chat): address summarization review comments
This commit is contained in:
@@ -476,6 +476,24 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
|
|||||||
fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])
|
fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_summarization_middleware_uses_frontend_supported_update_key(monkeypatch):
|
||||||
|
"""LangGraph update keys use the middleware class name plus hook name."""
|
||||||
|
|
||||||
|
app_config = _make_app_config([_make_model("safe-model", supports_thinking=False)])
|
||||||
|
app_config.summarization = SummarizationConfig(enabled=True)
|
||||||
|
app_config.memory = MemoryConfig(enabled=False)
|
||||||
|
|
||||||
|
fake_model = MagicMock()
|
||||||
|
fake_model.with_config.return_value = fake_model
|
||||||
|
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: fake_model)
|
||||||
|
|
||||||
|
middleware = lead_agent_module._create_summarization_middleware(app_config=app_config)
|
||||||
|
|
||||||
|
assert middleware is not None
|
||||||
|
update_key = f"{type(middleware).__name__}.before_model"
|
||||||
|
assert update_key == "DeerFlowSummarizationMiddleware.before_model"
|
||||||
|
|
||||||
|
|
||||||
def test_create_summarization_middleware_threads_resolved_app_config_to_model(monkeypatch):
|
def test_create_summarization_middleware_threads_resolved_app_config_to_model(monkeypatch):
|
||||||
fallback_app_config = _make_app_config([_make_model("fallback-model", supports_thinking=False)])
|
fallback_app_config = _make_app_config([_make_model("fallback-model", supports_thinking=False)])
|
||||||
fallback_app_config.summarization = SummarizationConfig(enabled=True, model_name="fallback-model")
|
fallback_app_config.summarization = SummarizationConfig(enabled=True, model_name="fallback-model")
|
||||||
|
|||||||
@@ -5,7 +5,10 @@ from unittest import mock
|
|||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from langchain.agents import create_agent
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage
|
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage
|
||||||
|
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||||
|
|
||||||
from deerflow.agents.memory.summarization_hook import memory_flush_hook
|
from deerflow.agents.memory.summarization_hook import memory_flush_hook
|
||||||
from deerflow.agents.middlewares.dynamic_context_middleware import _DYNAMIC_CONTEXT_REMINDER_KEY, DynamicContextMiddleware
|
from deerflow.agents.middlewares.dynamic_context_middleware import _DYNAMIC_CONTEXT_REMINDER_KEY, DynamicContextMiddleware
|
||||||
@@ -22,6 +25,23 @@ def _messages() -> list:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class _StaticChatModel(BaseChatModel):
|
||||||
|
text: str = "ok"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "static-test-chat-model"
|
||||||
|
|
||||||
|
def bind_tools(self, tools, **kwargs):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||||
|
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=self.text))])
|
||||||
|
|
||||||
|
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||||
|
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage:
|
def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage:
|
||||||
return HumanMessage(
|
return HumanMessage(
|
||||||
content="<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>",
|
content="<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>",
|
||||||
@@ -114,6 +134,32 @@ def test_before_summarization_hook_receives_messages_before_compression() -> Non
|
|||||||
assert result["messages"][1].content.startswith("Here is a summary")
|
assert result["messages"][1].content.startswith("Here is a summary")
|
||||||
|
|
||||||
|
|
||||||
|
def test_summarization_middleware_emits_frontend_update_key_in_agent_stream() -> None:
|
||||||
|
middleware = DeerFlowSummarizationMiddleware(
|
||||||
|
model=_StaticChatModel(text="compressed summary"),
|
||||||
|
trigger=("messages", 4),
|
||||||
|
keep=("messages", 2),
|
||||||
|
token_counter=len,
|
||||||
|
)
|
||||||
|
agent = create_agent(
|
||||||
|
model=_StaticChatModel(text="done"),
|
||||||
|
tools=[],
|
||||||
|
middleware=[middleware],
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = list(agent.stream({"messages": _messages()}, stream_mode="updates"))
|
||||||
|
update = next(
|
||||||
|
(chunk["DeerFlowSummarizationMiddleware.before_model"] for chunk in chunks if "DeerFlowSummarizationMiddleware.before_model" in chunk),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert update is not None
|
||||||
|
emitted = update["messages"]
|
||||||
|
assert isinstance(emitted[0], RemoveMessage)
|
||||||
|
assert emitted[1].name == "summary"
|
||||||
|
assert emitted[1].content == ("Here is a summary of the conversation to date:\n\ncompressed summary")
|
||||||
|
|
||||||
|
|
||||||
def test_dynamic_context_reminder_is_preserved_across_summarization() -> None:
|
def test_dynamic_context_reminder_is_preserved_across_summarization() -> None:
|
||||||
captured: list[SummarizationEvent] = []
|
captured: list[SummarizationEvent] = []
|
||||||
middleware = _middleware(before_summarization=[captured.append])
|
middleware = _middleware(before_summarization=[captured.append])
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import { getAPIClient } from "../api";
|
|||||||
import { fetch } from "../api/fetcher";
|
import { fetch } from "../api/fetcher";
|
||||||
import { getBackendBaseURL } from "../config";
|
import { getBackendBaseURL } from "../config";
|
||||||
import { useI18n } from "../i18n/hooks";
|
import { useI18n } from "../i18n/hooks";
|
||||||
|
import { isHiddenFromUIMessage } from "../messages/utils";
|
||||||
import type { FileInMessage } from "../messages/utils";
|
import type { FileInMessage } from "../messages/utils";
|
||||||
import type { LocalSettings } from "../settings";
|
import type { LocalSettings } from "../settings";
|
||||||
import { useUpdateSubtask } from "../tasks/context";
|
import { useUpdateSubtask } from "../tasks/context";
|
||||||
@@ -54,6 +55,11 @@ function isNonEmptyString(value: string | undefined): value is string {
|
|||||||
return typeof value === "string" && value.length > 0;
|
return typeof value === "string" && value.length > 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const SUMMARIZATION_MIDDLEWARE_UPDATE_KEYS = new Set([
|
||||||
|
"SummarizationMiddleware.before_model",
|
||||||
|
"DeerFlowSummarizationMiddleware.before_model",
|
||||||
|
]);
|
||||||
|
|
||||||
function messageIdentity(message: Message): string | undefined {
|
function messageIdentity(message: Message): string | undefined {
|
||||||
if (
|
if (
|
||||||
"tool_call_id" in message &&
|
"tool_call_id" in message &&
|
||||||
@@ -70,17 +76,33 @@ function messageIdentity(message: Message): string | undefined {
|
|||||||
|
|
||||||
function dedupeMessagesByIdentity(messages: Message[]): Message[] {
|
function dedupeMessagesByIdentity(messages: Message[]): Message[] {
|
||||||
const lastIndexByIdentity = new Map<string, number>();
|
const lastIndexByIdentity = new Map<string, number>();
|
||||||
|
const lastVisibleIndexByIdentity = new Map<string, number>();
|
||||||
|
|
||||||
|
// This is a UI-display dedupe rule, not a general LangChain message-stream
|
||||||
|
// contract. Hidden messages that share an identity with a visible message are
|
||||||
|
// treated as control messages for this merged view; hidden messages carrying
|
||||||
|
// independent tracing/task semantics should use a distinct id or a custom
|
||||||
|
// stream/state channel instead of relying on message dedupe preservation.
|
||||||
messages.forEach((message, index) => {
|
messages.forEach((message, index) => {
|
||||||
const identity = messageIdentity(message);
|
const identity = messageIdentity(message);
|
||||||
if (identity) {
|
if (identity) {
|
||||||
lastIndexByIdentity.set(identity, index);
|
lastIndexByIdentity.set(identity, index);
|
||||||
|
if (!isHiddenFromUIMessage(message)) {
|
||||||
|
lastVisibleIndexByIdentity.set(identity, index);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
return messages.filter((message, index) => {
|
return messages.filter((message, index) => {
|
||||||
const identity = messageIdentity(message);
|
const identity = messageIdentity(message);
|
||||||
return !identity || lastIndexByIdentity.get(identity) === index;
|
if (!identity) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
const visibleIndex = lastVisibleIndexByIdentity.get(identity);
|
||||||
|
if (visibleIndex !== undefined) {
|
||||||
|
return visibleIndex === index;
|
||||||
|
}
|
||||||
|
return lastIndexByIdentity.get(identity) === index;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -102,8 +124,15 @@ export function mergeMessages(
|
|||||||
threadMessages: Message[],
|
threadMessages: Message[],
|
||||||
optimisticMessages: Message[],
|
optimisticMessages: Message[],
|
||||||
): Message[] {
|
): Message[] {
|
||||||
|
// Only visible live messages should trim overlapping history. Hidden messages
|
||||||
|
// are UI control messages in this path, not observability records; any hidden
|
||||||
|
// message that must survive as task/tracing data should use custom events or a
|
||||||
|
// separate state channel instead of participating in this overlap heuristic.
|
||||||
const threadMessageIds = new Set(
|
const threadMessageIds = new Set(
|
||||||
threadMessages.map(messageIdentity).filter(isNonEmptyString),
|
threadMessages
|
||||||
|
.filter((message) => !isHiddenFromUIMessage(message))
|
||||||
|
.map(messageIdentity)
|
||||||
|
.filter(isNonEmptyString),
|
||||||
);
|
);
|
||||||
|
|
||||||
// The overlap is a contiguous suffix of historyMessages (newest history == oldest thread).
|
// The overlap is a contiguous suffix of historyMessages (newest history == oldest thread).
|
||||||
@@ -154,6 +183,30 @@ export function getVisibleOptimisticMessages(
|
|||||||
return optimisticMessages;
|
return optimisticMessages;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function getSummarizationMiddlewareMessages(
|
||||||
|
data: unknown,
|
||||||
|
): Message[] | undefined {
|
||||||
|
if (typeof data !== "object" || data === null) {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const [key, update] of Object.entries(data)) {
|
||||||
|
if (!SUMMARIZATION_MIDDLEWARE_UPDATE_KEYS.has(key)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (typeof update !== "object" || update === null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const messages = Reflect.get(update, "messages");
|
||||||
|
if (Array.isArray(messages)) {
|
||||||
|
return [...messages] as Message[];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
|
||||||
export function upsertThreadInSearchCache(
|
export function upsertThreadInSearchCache(
|
||||||
queryClient: QueryClient,
|
queryClient: QueryClient,
|
||||||
thread: AgentThread,
|
thread: AgentThread,
|
||||||
@@ -319,24 +372,25 @@ export function useThreadStream({
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
onUpdateEvent(data) {
|
onUpdateEvent(data) {
|
||||||
if (data["SummarizationMiddleware.before_model"]) {
|
const _messages = getSummarizationMiddlewareMessages(data);
|
||||||
const _messages = [
|
if (_messages && _messages.length >= 2) {
|
||||||
...(data["SummarizationMiddleware.before_model"].messages ?? []),
|
|
||||||
];
|
|
||||||
|
|
||||||
if (_messages.length < 2) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
for (const m of _messages) {
|
for (const m of _messages) {
|
||||||
if (m.name === "summary" && m.type === "human") {
|
if (m.name === "summary" && m.type === "human") {
|
||||||
summarizedRef.current?.add(m.id ?? "");
|
summarizedRef.current?.add(m.id ?? "");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const _lastKeepMessage = _messages[2];
|
const firstRetainedVisibleIdentity = _messages
|
||||||
|
.filter((message) => message.type !== "remove")
|
||||||
|
.filter((message) => !isHiddenFromUIMessage(message))
|
||||||
|
.map(messageIdentity)
|
||||||
|
.find(isNonEmptyString);
|
||||||
const _currentMessages = [...messagesRef.current];
|
const _currentMessages = [...messagesRef.current];
|
||||||
const _movedMessages: Message[] = [];
|
const _movedMessages: Message[] = [];
|
||||||
for (const m of _currentMessages) {
|
for (const m of _currentMessages) {
|
||||||
if (m.id !== undefined && m.id === _lastKeepMessage?.id) {
|
if (
|
||||||
|
firstRetainedVisibleIdentity &&
|
||||||
|
messageIdentity(m) === firstRetainedVisibleIdentity
|
||||||
|
) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (!summarizedRef.current?.has(m.id ?? "")) {
|
if (!summarizedRef.current?.has(m.id ?? "")) {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import type { Message } from "@langchain/langgraph-sdk";
|
|||||||
import { expect, test } from "vitest";
|
import { expect, test } from "vitest";
|
||||||
|
|
||||||
import {
|
import {
|
||||||
|
getSummarizationMiddlewareMessages,
|
||||||
getVisibleOptimisticMessages,
|
getVisibleOptimisticMessages,
|
||||||
mergeMessages,
|
mergeMessages,
|
||||||
} from "@/core/threads/hooks";
|
} from "@/core/threads/hooks";
|
||||||
@@ -66,6 +67,104 @@ test("mergeMessages deduplicates tool messages by tool_call_id", () => {
|
|||||||
expect(mergeMessages([oldTool], [liveTool], [])).toEqual([liveTool]);
|
expect(mergeMessages([oldTool], [liveTool], [])).toEqual([liveTool]);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test("mergeMessages keeps a visible history message when a hidden live message reuses its id", () => {
|
||||||
|
const historyHuman = {
|
||||||
|
id: "human-1",
|
||||||
|
type: "human",
|
||||||
|
content: "visible user prompt",
|
||||||
|
} as Message;
|
||||||
|
const hiddenReminder = {
|
||||||
|
id: "human-1",
|
||||||
|
type: "human",
|
||||||
|
content: "<system-reminder>hidden</system-reminder>",
|
||||||
|
additional_kwargs: { hide_from_ui: true },
|
||||||
|
} as Message;
|
||||||
|
const liveAi = {
|
||||||
|
id: "ai-1",
|
||||||
|
type: "ai",
|
||||||
|
content: "live answer",
|
||||||
|
} as Message;
|
||||||
|
|
||||||
|
expect(mergeMessages([historyHuman], [hiddenReminder, liveAi], [])).toEqual([
|
||||||
|
historyHuman,
|
||||||
|
liveAi,
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("mergeMessages lets a visible live message replace overlapping hidden history", () => {
|
||||||
|
const hiddenHistoryHuman = {
|
||||||
|
id: "human-1",
|
||||||
|
type: "human",
|
||||||
|
content: "<system-reminder>hidden</system-reminder>",
|
||||||
|
additional_kwargs: { hide_from_ui: true },
|
||||||
|
} as Message;
|
||||||
|
const liveHuman = {
|
||||||
|
id: "human-1",
|
||||||
|
type: "human",
|
||||||
|
content: "visible user prompt",
|
||||||
|
} as Message;
|
||||||
|
|
||||||
|
expect(mergeMessages([hiddenHistoryHuman], [liveHuman], [])).toEqual([
|
||||||
|
liveHuman,
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("getSummarizationMiddlewareMessages matches DeerFlow summarization update keys", () => {
|
||||||
|
const removeAll = {
|
||||||
|
id: "__remove_all__",
|
||||||
|
type: "remove",
|
||||||
|
content: "",
|
||||||
|
} as Message;
|
||||||
|
const summary = {
|
||||||
|
id: "summary-1",
|
||||||
|
type: "human",
|
||||||
|
name: "summary",
|
||||||
|
content: "summary",
|
||||||
|
} as Message;
|
||||||
|
|
||||||
|
expect(
|
||||||
|
getSummarizationMiddlewareMessages({
|
||||||
|
"DeerFlowSummarizationMiddleware.before_model": {
|
||||||
|
messages: [removeAll, summary],
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
).toEqual([removeAll, summary]);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("getSummarizationMiddlewareMessages matches base LangChain summarization update keys", () => {
|
||||||
|
const summary = {
|
||||||
|
id: "summary-1",
|
||||||
|
type: "human",
|
||||||
|
name: "summary",
|
||||||
|
content: "summary",
|
||||||
|
} as Message;
|
||||||
|
|
||||||
|
expect(
|
||||||
|
getSummarizationMiddlewareMessages({
|
||||||
|
"SummarizationMiddleware.before_model": {
|
||||||
|
messages: [summary],
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
).toEqual([summary]);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("getSummarizationMiddlewareMessages ignores unrelated suffix-sharing update keys", () => {
|
||||||
|
const summary = {
|
||||||
|
id: "summary-1",
|
||||||
|
type: "human",
|
||||||
|
name: "summary",
|
||||||
|
content: "summary",
|
||||||
|
} as Message;
|
||||||
|
|
||||||
|
expect(
|
||||||
|
getSummarizationMiddlewareMessages({
|
||||||
|
"OtherSummarizationMiddleware.before_model": {
|
||||||
|
messages: [summary],
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
test("getVisibleOptimisticMessages hides optimistic user input after server human arrives", () => {
|
test("getVisibleOptimisticMessages hides optimistic user input after server human arrives", () => {
|
||||||
const optimisticHuman = {
|
const optimisticHuman = {
|
||||||
id: "opt-human-1",
|
id: "opt-human-1",
|
||||||
|
|||||||
Reference in New Issue
Block a user