diff --git a/frontend/src/core/threads/hooks.ts b/frontend/src/core/threads/hooks.ts index adf9dbbb6..fba3edd0c 100644 --- a/frontend/src/core/threads/hooks.ts +++ b/frontend/src/core/threads/hooks.ts @@ -45,15 +45,60 @@ type SendMessageOptions = { additionalKwargs?: Record; }; -function mergeMessages( +function isNonEmptyString(value: string | undefined): value is string { + return typeof value === "string" && value.length > 0; +} + +function messageIdentity(message: Message): string | undefined { + if ( + "tool_call_id" in message && + typeof message.tool_call_id === "string" && + message.tool_call_id.length > 0 + ) { + return `tool:${message.tool_call_id}`; + } + if (typeof message.id === "string" && message.id.length > 0) { + return `message:${message.id}`; + } + return undefined; +} + +function dedupeMessagesByIdentity(messages: Message[]): Message[] { + const lastIndexByIdentity = new Map(); + + messages.forEach((message, index) => { + const identity = messageIdentity(message); + if (identity) { + lastIndexByIdentity.set(identity, index); + } + }); + + return messages.filter((message, index) => { + const identity = messageIdentity(message); + return !identity || lastIndexByIdentity.get(identity) === index; + }); +} + +function findLatestUnloadedRunIndex( + runs: Run[], + loadedRunIds: ReadonlySet, +): number { + for (let i = runs.length - 1; i >= 0; i--) { + const run = runs[i]; + if (run && !loadedRunIds.has(run.run_id)) { + return i; + } + } + return -1; +} + +export function mergeMessages( historyMessages: Message[], threadMessages: Message[], optimisticMessages: Message[], ): Message[] { const threadMessageIds = new Set( - threadMessages - .map((m) => ("tool_call_id" in m ? m.tool_call_id : m.id)) - .filter(Boolean), + threadMessages.map(messageIdentity).filter(isNonEmptyString), ); // The overlap is a contiguous suffix of historyMessages (newest history == oldest thread). @@ -65,28 +110,19 @@ function mergeMessages( if (!msg) { continue; } - if ( - (msg?.id && threadMessageIds.has(msg.id)) || - ("tool_call_id" in msg && threadMessageIds.has(msg.tool_call_id)) - ) { + const identity = messageIdentity(msg); + if (identity && threadMessageIds.has(identity)) { cutoff = i; } else { break; } } - return [ + return dedupeMessagesByIdentity([ ...historyMessages.slice(0, cutoff), ...threadMessages, ...optimisticMessages, - ]; -} - -function messageIdentity(message: Message): string | undefined { - if ("tool_call_id" in message) { - return message.tool_call_id; - } - return message.id; + ]); } function getMessagesAfterBaseline( @@ -627,48 +663,105 @@ export function useThreadHistory(threadId: string) { const runsRef = useRef(runs.data ?? []); const indexRef = useRef(-1); const loadingRef = useRef(false); + const pendingLoadRef = useRef(false); + const loadingRunIdRef = useRef(null); + const loadedRunIdsRef = useRef>(new Set()); const [loading, setLoading] = useState(false); const [messages, setMessages] = useState([]); - loadingRef.current = loading; const loadMessages = useCallback(async () => { + if (loadingRef.current) { + const pendingRunIndex = findLatestUnloadedRunIndex( + runsRef.current, + loadedRunIdsRef.current, + ); + const pendingRun = runsRef.current[pendingRunIndex]; + if (pendingRun && pendingRun.run_id !== loadingRunIdRef.current) { + pendingLoadRef.current = true; + } + return; + } if (runsRef.current.length === 0) { return; } - const run = runsRef.current[indexRef.current]; - if (!run || loadingRef.current) { - return; - } + + loadingRef.current = true; + setLoading(true); + try { - setLoading(true); - const result: { data: RunMessage[]; hasMore: boolean } = await fetch( - `${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadIdRef.current)}/runs/${encodeURIComponent(run.run_id)}/messages`, - { - method: "GET", - headers: { - "Content-Type": "application/json", + do { + pendingLoadRef.current = false; + + const nextRunIndex = findLatestUnloadedRunIndex( + runsRef.current, + loadedRunIdsRef.current, + ); + indexRef.current = nextRunIndex; + + const run = runsRef.current[nextRunIndex]; + if (!run) { + indexRef.current = -1; + return; + } + + const requestThreadId = threadIdRef.current; + loadingRunIdRef.current = run.run_id; + const result: { data: RunMessage[]; hasMore: boolean } = await fetch( + `${getBackendBaseURL()}/api/threads/${encodeURIComponent(requestThreadId)}/runs/${encodeURIComponent(run.run_id)}/messages`, + { + method: "GET", + headers: { + "Content-Type": "application/json", + }, + credentials: "include", }, - credentials: "include", - }, - ).then((res) => { - return res.json(); - }); - const _messages = result.data - .filter((m) => !m.metadata.caller?.startsWith("middleware:")) - .map((m) => m.content); - setMessages((prev) => [..._messages, ...prev]); - indexRef.current -= 1; + ).then((res) => { + return res.json(); + }); + const _messages = result.data + .filter((m) => !m.metadata.caller?.startsWith("middleware:")) + .map((m) => m.content); + if (threadIdRef.current !== requestThreadId) { + return; + } + setMessages((prev) => + dedupeMessagesByIdentity([..._messages, ...prev]), + ); + loadedRunIdsRef.current.add(run.run_id); + indexRef.current = findLatestUnloadedRunIndex( + runsRef.current, + loadedRunIdsRef.current, + ); + } while (pendingLoadRef.current); } catch (err) { console.error(err); } finally { + loadingRef.current = false; + loadingRunIdRef.current = null; setLoading(false); } }, []); useEffect(() => { + const threadChanged = threadIdRef.current !== threadId; threadIdRef.current = threadId; + + if (threadChanged) { + runsRef.current = []; + indexRef.current = -1; + pendingLoadRef.current = false; + loadingRunIdRef.current = null; + loadedRunIdsRef.current = new Set(); + loadingRef.current = false; + setLoading(false); + setMessages([]); + } + if (runs.data && runs.data.length > 0) { runsRef.current = runs.data ?? []; - indexRef.current = runs.data.length - 1; + indexRef.current = findLatestUnloadedRunIndex( + runs.data, + loadedRunIdsRef.current, + ); } loadMessages().catch(() => { toast.error("Failed to load thread history."); @@ -677,7 +770,7 @@ export function useThreadHistory(threadId: string) { const appendMessages = useCallback((_messages: Message[]) => { setMessages((prev) => { - return [...prev, ..._messages]; + return dedupeMessagesByIdentity([...prev, ..._messages]); }); }, []); const hasMore = indexRef.current >= 0 || !runs.data; diff --git a/frontend/tests/unit/core/threads/message-merge.test.ts b/frontend/tests/unit/core/threads/message-merge.test.ts new file mode 100644 index 000000000..9b29aebc9 --- /dev/null +++ b/frontend/tests/unit/core/threads/message-merge.test.ts @@ -0,0 +1,64 @@ +import type { Message } from "@langchain/langgraph-sdk"; +import { expect, test } from "vitest"; + +import { mergeMessages } from "@/core/threads/hooks"; + +test("mergeMessages removes duplicate messages already present in history", () => { + const human = { + id: "human-1", + type: "human", + content: "Design an agent", + } as Message; + const ai = { + id: "ai-1", + type: "ai", + content: "Let's design it.", + } as Message; + + expect(mergeMessages([human, ai, human, ai], [], [])).toEqual([human, ai]); +}); + +test("mergeMessages lets live thread messages replace overlapping history", () => { + const oldHuman = { + id: "human-1", + type: "human", + content: "old", + } as Message; + const liveHuman = { + id: "human-1", + type: "human", + content: "live", + } as Message; + const oldAi = { + id: "ai-1", + type: "ai", + content: "old", + } as Message; + const liveAi = { + id: "ai-1", + type: "ai", + content: "live", + } as Message; + + expect(mergeMessages([oldHuman, oldAi], [liveHuman, liveAi], [])).toEqual([ + liveHuman, + liveAi, + ]); +}); + +test("mergeMessages deduplicates tool messages by tool_call_id", () => { + const oldTool = { + id: "tool-message-old", + type: "tool", + tool_call_id: "call-1", + content: "old", + } as Message; + const liveTool = { + id: "tool-message-live", + type: "tool", + tool_call_id: "call-1", + content: "live", + } as Message; + + expect(mergeMessages([oldTool], [liveTool], [])).toEqual([liveTool]); +});