Files
deer-flow/frontend/src/core/threads/hooks.ts
T
Willem Jiang e5c7328cf5 fix(chat): refresh run list after each turn so scroll-to-load history works
After context compression in an active conversation, scrolling to the top
  would not load earlier messages because useThreadRuns was never invalidated
  once new runs completed. The stale runs list left hasMore permanently false.

  - Invalidate the ["thread", threadId] query in onFinish so the runs list
    stays current as new runs are created
  - Split the useThreadHistory effect into separate reset (threadId change)
    and incremental-update (runs.data change) paths so indexRef is adjusted
    without resetting already-loaded history
  - Add message deduplication when prepending loaded run messages to guard
    against overlap with the live stream
  - Extract deduplicateHistoryMessages and adjustHistoryIndex into a pure
    utility module with 10 unit tests

  Closes #2965
2026-05-15 10:11:56 +08:00

921 lines
28 KiB
TypeScript

import type { AIMessage, Message, Run } from "@langchain/langgraph-sdk";
import type { ThreadsClient } from "@langchain/langgraph-sdk/client";
import { useStream } from "@langchain/langgraph-sdk/react";
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import { useCallback, useEffect, useRef, useState } from "react";
import { toast } from "sonner";
import type { PromptInputMessage } from "@/components/ai-elements/prompt-input";
import { getAPIClient } from "../api";
import { fetch } from "../api/fetcher";
import { getBackendBaseURL } from "../config";
import { useI18n } from "../i18n/hooks";
import type { FileInMessage } from "../messages/utils";
import type { LocalSettings } from "../settings";
import { useUpdateSubtask } from "../tasks/context";
import type { UploadedFileInfo } from "../uploads";
import { promptInputFilePartToFile, uploadFiles } from "../uploads";
import { fetchThreadTokenUsage } from "./api";
import {
adjustHistoryIndex,
deduplicateHistoryMessages,
} from "./history-utils";
import { threadTokenUsageQueryKey } from "./token-usage";
import type {
AgentThread,
AgentThreadState,
RunMessage,
ThreadTokenUsageResponse,
} from "./types";
/**
 * Payload passed to the `onToolEnd` listener when the stream reports a
 * LangChain event whose `event` field is `"on_tool_end"`.
 */
export type ToolEndEvent = {
// Copied from the event's `name` field.
name: string;
// Copied from the event's `data` field; its shape is not constrained here.
data: unknown;
};
/**
 * Options accepted by `useThreadStream`.
 */
export type ThreadStreamOptions = {
// Thread to stream; a nullish value is normalized to `null` by the hook and
// treated as a brand new, unsaved thread.
threadId?: string | null | undefined;
// Settings context spread into the run context on every submit.
context: LocalSettings["context"];
// Forwarded to `getAPIClient(isMock)`; when true the hook also skips the
// thread metadata update and the per-thread query invalidations.
isMock?: boolean;
// Called from `sendMessage` right after the optimistic messages are set.
onSend?: (threadId: string) => void;
// Called when a run is created, at most once until the thread id changes.
onStart?: (threadId: string, runId: string) => void;
// Called with `state.values` when the stream finishes.
onFinish?: (state: AgentThreadState) => void;
// Called for every `on_tool_end` LangChain event.
onToolEnd?: (event: ToolEndEvent) => void;
};
/**
 * Per-call options for `sendMessage`.
 */
type SendMessageOptions = {
// Merged into the submitted human message's `additional_kwargs`.
// `hide_from_ui === true` suppresses the optimistic messages for that send.
additionalKwargs?: Record<string, unknown>;
};
/**
 * Builds the display list: history first, then the live thread messages, then
 * optimistic placeholders.
 *
 * History and the live thread can overlap at the boundary (the newest history
 * entries may be the oldest live ones). Only that trailing overlap is removed
 * from history; the live copies are kept.
 *
 * @param historyMessages Previously loaded messages, oldest first.
 * @param threadMessages Messages currently held by the live stream.
 * @param optimisticMessages Local placeholders not yet confirmed by the server.
 * @returns The concatenated list with the overlapping history suffix dropped.
 */
function mergeMessages(
  historyMessages: Message[],
  threadMessages: Message[],
  optimisticMessages: Message[],
): Message[] {
  // Identity of a live message: tool messages are keyed by tool_call_id,
  // everything else by id. Falsy keys are never recorded.
  const liveKeys = new Set<string | undefined>();
  for (const live of threadMessages) {
    const key = "tool_call_id" in live ? live.tool_call_id : live.id;
    if (key) {
      liveKeys.add(key);
    }
  }

  const alreadyLive = (candidate: Message): boolean =>
    Boolean(candidate.id && liveKeys.has(candidate.id)) ||
    ("tool_call_id" in candidate && liveKeys.has(candidate.tool_call_id));

  // Walk backwards from the newest history entry. Each entry that is already
  // live pulls the keep-boundary down; the first one that is not live ends the
  // scan, so only a trailing run of duplicates is trimmed.
  let keepCount = historyMessages.length;
  let cursor = historyMessages.length;
  while (cursor-- > 0) {
    const candidate = historyMessages[cursor];
    if (!candidate) {
      continue;
    }
    if (!alreadyLive(candidate)) {
      break;
    }
    keepCount = cursor;
  }

  const keptHistory = historyMessages.slice(0, keepCount);
  return [...keptHistory, ...threadMessages, ...optimisticMessages];
}
/**
 * Returns the key used to recognise a message across snapshots: the
 * `tool_call_id` when the message has that property, otherwise its `id`.
 *
 * @param message Message to identify.
 * @returns The identity key, or `undefined` when the chosen field is unset.
 */
function messageIdentity(message: Message): string | undefined {
  return "tool_call_id" in message ? message.tool_call_id : message.id;
}
/**
 * Selects the messages that are not part of a previously captured baseline.
 *
 * Messages without an identity are always kept, since they cannot be matched
 * against the baseline.
 *
 * @param messages Messages to inspect, in display order.
 * @param baselineMessageIds Identities that were present when the baseline was taken.
 * @returns The messages whose identity is missing or absent from the baseline.
 */
function getMessagesAfterBaseline(
  messages: Message[],
  baselineMessageIds: ReadonlySet<string>,
): Message[] {
  const fresh: Message[] = [];
  for (const candidate of messages) {
    const key = messageIdentity(candidate);
    const inBaseline = Boolean(key) && baselineMessageIds.has(key as string);
    if (!inBaseline) {
      fresh.push(candidate);
    }
  }
  return fresh;
}
/**
 * Extracts a human-readable message from whatever the stream reported as an
 * error.
 *
 * Lookup order: a non-blank string, an `Error` with a non-blank message, an
 * object's `message` property, then an object's nested `error` property
 * (either an `Error` or a string). Blank values are skipped.
 *
 * @param error The value reported by the stream.
 * @returns The first usable message, or `"Request failed."` when none is found.
 */
function getStreamErrorMessage(error: unknown): string {
  const fallback = "Request failed.";

  if (typeof error === "string") {
    return error.trim() ? error : fallback;
  }
  if (error instanceof Error && error.message.trim()) {
    return error.message;
  }
  if (typeof error !== "object" || error === null) {
    return fallback;
  }

  const direct: unknown = Reflect.get(error, "message");
  if (typeof direct === "string" && direct.trim()) {
    return direct;
  }

  const nested: unknown = Reflect.get(error, "error");
  if (nested instanceof Error) {
    return nested.message.trim() ? nested.message : fallback;
  }
  if (typeof nested === "string" && nested.trim()) {
    return nested;
  }
  return fallback;
}
/**
 * Streams the `"lead_agent"` assistant for a thread and exposes a single
 * message list that combines three sources: paginated history (from
 * `useThreadHistory`), the live stream (`useStream`), and optimistic
 * placeholders created by `sendMessage`.
 *
 * @param options See `ThreadStreamOptions`.
 * @returns An object with:
 *   - `thread`: the `useStream` result with `messages` replaced by the merged list
 *   - `pendingUsageMessages`: while the stream is loading, the live messages
 *     that are not part of the token-usage baseline; otherwise an empty array
 *   - `sendMessage`: uploads attachments (if any) and submits a human message
 *   - `isUploading`: true while attachments are being prepared/uploaded
 *   - `isHistoryLoading`, `hasMoreHistory`, `loadMoreHistory`: history paging
 *     state forwarded from `useThreadHistory`
 */
export function useThreadStream({
threadId,
context,
isMock,
onSend,
onStart,
onFinish,
onToolEnd,
}: ThreadStreamOptions) {
const { t } = useI18n();
// Track the thread ID that is currently streaming to handle thread changes during streaming
const [onStreamThreadId, setOnStreamThreadId] = useState(() => threadId);
// Ref to track current thread ID across async callbacks without causing re-renders,
// and to allow access to the current thread id in onUpdateEvent
const threadIdRef = useRef<string | null>(threadId ?? null);
// True once onStart has been delivered for the current thread.
const startedRef = useRef(false);
// Identities of the messages that existed before the current run; anything
// not in this set is reported through `pendingUsageMessages` while loading.
const pendingUsageBaselineMessageIdsRef = useRef<Set<string>>(new Set());
const listeners = useRef({
onSend,
onStart,
onFinish,
onToolEnd,
});
const {
messages: history,
hasMore: hasMoreHistory,
loadMore: loadMoreHistory,
loading: isHistoryLoading,
appendMessages,
} = useThreadHistory(onStreamThreadId ?? "");
// Keep listeners ref updated with latest callbacks
useEffect(() => {
listeners.current = { onSend, onStart, onFinish, onToolEnd };
}, [onSend, onStart, onFinish, onToolEnd]);
// Mirror the `threadId` prop into the streaming state and the ref.
useEffect(() => {
const normalizedThreadId = threadId ?? null;
if (!normalizedThreadId) {
// Reset when the UI moves back to a brand new unsaved thread.
startedRef.current = false;
setOnStreamThreadId(normalizedThreadId);
} else {
setOnStreamThreadId(normalizedThreadId);
}
threadIdRef.current = normalizedThreadId;
}, [threadId]);
// Invoked when a run is created: records the server-assigned thread id and
// delivers onStart at most once (guarded by startedRef).
const handleStreamStart = useCallback((_threadId: string, _runId: string) => {
threadIdRef.current = _threadId;
if (!startedRef.current) {
listeners.current.onStart?.(_threadId, _runId);
startedRef.current = true;
}
setOnStreamThreadId(_threadId);
}, []);
const queryClient = useQueryClient();
const updateSubtask = useUpdateSubtask();
const thread = useStream<AgentThreadState>({
client: getAPIClient(isMock),
assistantId: "lead_agent",
threadId: onStreamThreadId,
reconnectOnMount: true,
fetchStateHistory: { limit: 1 },
onCreated(meta) {
handleStreamStart(meta.thread_id, meta.run_id);
// Best-effort: store the agent name in the thread metadata; a failed
// update is deliberately ignored.
if (context.agent_name && !isMock) {
void getAPIClient()
.threads.update(meta.thread_id, {
metadata: { agent_name: context.agent_name },
})
.catch(() => ({}));
}
},
onLangChainEvent(event) {
if (event.event === "on_tool_end") {
listeners.current.onToolEnd?.({
name: event.name,
data: event.data,
});
}
},
onUpdateEvent(data) {
// Summarization update: move the messages that were live before the
// summary into the history list so they stay visible.
if (data["SummarizationMiddleware.before_model"]) {
const _messages = [
...(data["SummarizationMiddleware.before_model"].messages ?? []),
];
if (_messages.length < 2) {
return;
}
// Remember the ids of summary messages so they are never moved to history.
for (const m of _messages) {
if (m.name === "summary" && m.type === "human") {
summarizedRef.current?.add(m.id ?? "");
}
}
// NOTE(review): the guard above only ensures two entries, so index 2 is
// undefined when exactly two messages arrive and the loop below then never
// stops early — confirm the intended index.
const _lastKeepMessage = _messages[2];
const _currentMessages = [...messagesRef.current];
const _movedMessages: Message[] = [];
// Collect cached messages up to (not including) the one matching
// `_lastKeepMessage`, skipping summary messages.
for (const m of _currentMessages) {
if (m.id !== undefined && m.id === _lastKeepMessage?.id) {
break;
}
if (!summarizedRef.current?.has(m.id ?? "")) {
_movedMessages.push(m);
}
}
appendMessages(_movedMessages);
// Clear the cache so the next render re-seeds it from thread.messages.
messagesRef.current = [];
}
const updates: Array<Partial<AgentThreadState> | null> = Object.values(
data || {},
);
// Propagate a new title into every cached thread-search result.
for (const update of updates) {
if (update && "title" in update && update.title) {
void queryClient.setQueriesData(
{
queryKey: ["threads", "search"],
exact: false,
},
(oldData: Array<AgentThread> | undefined) => {
return oldData?.map((t) => {
if (t.thread_id === threadIdRef.current) {
return {
...t,
values: {
...t.values,
title: update.title,
},
};
}
return t;
});
},
);
}
}
},
onCustomEvent(event: unknown) {
// "task_running": forward the latest subtask message to the task context.
if (
typeof event === "object" &&
event !== null &&
"type" in event &&
event.type === "task_running"
) {
const e = event as {
type: "task_running";
task_id: string;
message: AIMessage;
};
updateSubtask({ id: e.task_id, latestMessage: e.message });
return;
}
// "llm_retry" with a non-blank message: surface it as a toast.
if (
typeof event === "object" &&
event !== null &&
"type" in event &&
event.type === "llm_retry" &&
"message" in event &&
typeof event.message === "string" &&
event.message.trim()
) {
const e = event as { type: "llm_retry"; message: string };
toast(e.message);
}
},
onError(error) {
// Drop placeholders, show the error, and re-baseline token usage on the
// messages seen so far.
setOptimisticMessages([]);
toast.error(getStreamErrorMessage(error));
pendingUsageBaselineMessageIdsRef.current = new Set(
messagesRef.current
.map(messageIdentity)
.filter((id): id is string => Boolean(id)),
);
if (threadIdRef.current && !isMock) {
void queryClient.invalidateQueries({
queryKey: threadTokenUsageQueryKey(threadIdRef.current),
});
}
},
onFinish(state) {
listeners.current.onFinish?.(state.values);
pendingUsageBaselineMessageIdsRef.current = new Set(
messagesRef.current
.map(messageIdentity)
.filter((id): id is string => Boolean(id)),
);
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
if (threadIdRef.current && !isMock) {
// ["thread", id] is the key used by useThreadRuns, so this refreshes the
// run list that drives history paging.
void queryClient.invalidateQueries({
queryKey: ["thread", threadIdRef.current],
});
void queryClient.invalidateQueries({
queryKey: threadTokenUsageQueryKey(threadIdRef.current),
});
}
},
});
// Optimistic messages shown before the server stream responds
const [optimisticMessages, setOptimisticMessages] = useState<Message[]>([]);
const [isUploading, setIsUploading] = useState(false);
const humanMessageCount = thread.messages.filter(
(m) => m.type === "human",
).length;
const latestMessageCountsRef = useRef({ humanMessageCount });
// Guards against overlapping sendMessage calls.
const sendInFlightRef = useRef(false);
const messagesRef = useRef<Message[]>([]);
// Ids of summary messages seen in summarization updates (lazily initialised below).
const summarizedRef = useRef<Set<string>>(null);
// Track human message count before sending to prevent clearing optimistic
// messages before the server's human message arrives (e.g. when AI messages
// from "messages-tuple" events arrive before the input human message from
// "values" events).
const prevHumanMsgCountRef = useRef(humanMessageCount);
latestMessageCountsRef.current = { humanMessageCount };
summarizedRef.current ??= new Set<string>();
// Reset thread-local pending UI state when switching between threads so
// optimistic messages and in-flight guards do not leak across chat views.
useEffect(() => {
startedRef.current = false;
sendInFlightRef.current = false;
pendingUsageBaselineMessageIdsRef.current = new Set(
messagesRef.current
.map(messageIdentity)
.filter((id): id is string => Boolean(id)),
);
prevHumanMsgCountRef.current =
latestMessageCountsRef.current.humanMessageCount;
}, [threadId]);
// When streaming starts without a baseline (e.g. reconnection, run started
// from another client, or page reload mid-stream), snapshot the current
// messages so only *new* messages are treated as "pending" for token usage.
useEffect(() => {
if (
thread.isLoading &&
pendingUsageBaselineMessageIdsRef.current.size === 0
) {
pendingUsageBaselineMessageIdsRef.current = new Set(
thread.messages
.map(messageIdentity)
.filter((id): id is string => Boolean(id)),
);
}
}, [thread.isLoading, thread.messages]);
// Clear optimistic when server messages arrive.
// For messages with a human optimistic message, wait until the server's
// human message has arrived to avoid clearing before the input message
// appears in the stream (the input message may arrive via "values" events
// after individual "messages-tuple" events for AI messages).
const optimisticMessageCount = optimisticMessages.length;
const hasHumanOptimistic = optimisticMessages.some((m) => m.type === "human");
useEffect(() => {
if (optimisticMessageCount === 0) return;
const newHumanMsgArrived = humanMessageCount > prevHumanMsgCountRef.current;
if (!hasHumanOptimistic || newHumanMsgArrived) {
setOptimisticMessages([]);
}
}, [hasHumanOptimistic, humanMessageCount, optimisticMessageCount]);
// Sends one user message: shows optimistic placeholders, uploads attachments,
// then submits the run. Returns without doing anything when another send is
// still in flight; rethrows upload/submit errors after clearing placeholders.
const sendMessage = useCallback(
async (
threadId: string,
message: PromptInputMessage,
extraContext?: Record<string, unknown>,
options?: SendMessageOptions,
) => {
if (sendInFlightRef.current) {
return;
}
sendInFlightRef.current = true;
const text = message.text.trim();
// Capture the current human message count before showing optimistic
// messages so we can wait for the server's copy of the user input.
prevHumanMsgCountRef.current = humanMessageCount;
pendingUsageBaselineMessageIdsRef.current = new Set(
thread.messages
.map(messageIdentity)
.filter((id): id is string => Boolean(id)),
);
// Build optimistic files list with uploading status
const optimisticFiles: FileInMessage[] = (message.files ?? []).map(
(f) => ({
filename: f.filename ?? "",
size: 0,
status: "uploading" as const,
}),
);
const hideFromUI = options?.additionalKwargs?.hide_from_ui === true;
const optimisticAdditionalKwargs = {
...options?.additionalKwargs,
...(optimisticFiles.length > 0 ? { files: optimisticFiles } : {}),
};
const newOptimistic: Message[] = [];
if (!hideFromUI) {
newOptimistic.push({
type: "human",
id: `opt-human-${Date.now()}`,
content: text ? [{ type: "text", text }] : "",
additional_kwargs: optimisticAdditionalKwargs,
});
}
if (optimisticFiles.length > 0 && !hideFromUI) {
// Mock AI message while files are being uploaded
newOptimistic.push({
type: "ai",
id: `opt-ai-${Date.now()}`,
content: t.uploads.uploadingFiles,
additional_kwargs: { element: "task" },
});
}
setOptimisticMessages(newOptimistic);
listeners.current.onSend?.(threadId);
let uploadedFileInfo: UploadedFileInfo[] = [];
try {
// Upload files first if any
if (message.files && message.files.length > 0) {
setIsUploading(true);
try {
const filePromises = message.files.map((fileUIPart) =>
promptInputFilePartToFile(fileUIPart),
);
const conversionResults = await Promise.all(filePromises);
const files = conversionResults.filter(
(file): file is File => file !== null,
);
// Any attachment that could not be converted aborts the whole send.
const failedConversions = conversionResults.length - files.length;
if (failedConversions > 0) {
throw new Error(
`Failed to prepare ${failedConversions} attachment(s) for upload. Please retry.`,
);
}
if (!threadId) {
throw new Error("Thread is not ready for file upload.");
}
if (files.length > 0) {
const uploadResponse = await uploadFiles(threadId, files);
uploadedFileInfo = uploadResponse.files;
// Update optimistic human message with uploaded status + paths
const uploadedFiles: FileInMessage[] = uploadedFileInfo.map(
(info) => ({
filename: info.filename,
size: info.size,
path: info.virtual_path,
status: "uploaded" as const,
}),
);
setOptimisticMessages((messages) => {
if (messages.length > 1 && messages[0]) {
const humanMessage: Message = messages[0];
// NOTE(review): additional_kwargs is replaced wholesale here, so any
// keys copied from options.additionalKwargs onto the optimistic message
// are dropped after upload — confirm this is intended.
return [
{
...humanMessage,
additional_kwargs: { files: uploadedFiles },
},
...messages.slice(1),
];
}
return messages;
});
}
} catch (error) {
const errorMessage =
error instanceof Error
? error.message
: "Failed to upload files.";
toast.error(errorMessage);
setOptimisticMessages([]);
throw error;
} finally {
setIsUploading(false);
}
}
// Build files metadata for submission (included in additional_kwargs)
const filesForSubmit: FileInMessage[] = uploadedFileInfo.map(
(info) => ({
filename: info.filename,
size: info.size,
path: info.virtual_path,
status: "uploaded" as const,
}),
);
await thread.submit(
{
messages: [
{
type: "human",
content: [
{
type: "text",
text,
},
],
additional_kwargs: {
...options?.additionalKwargs,
...(filesForSubmit.length > 0
? { files: filesForSubmit }
: {}),
},
},
],
},
{
threadId: threadId,
streamSubgraphs: true,
streamResumable: true,
config: {
recursion_limit: 1000,
},
// `context` is spread after `extraContext`, so settings win on key
// collisions; the flags below are derived from `context.mode`.
context: {
...extraContext,
...context,
thinking_enabled: context.mode !== "flash",
is_plan_mode: context.mode === "pro" || context.mode === "ultra",
subagent_enabled: context.mode === "ultra",
reasoning_effort:
context.reasoning_effort ??
(context.mode === "ultra"
? "high"
: context.mode === "pro"
? "medium"
: context.mode === "thinking"
? "low"
: undefined),
thread_id: threadId,
},
},
);
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
} catch (error) {
setOptimisticMessages([]);
setIsUploading(false);
throw error;
} finally {
sendInFlightRef.current = false;
}
},
[thread, t.uploads.uploadingFiles, context, queryClient, humanMessageCount],
);
// Cache the latest thread messages in a ref to compare against incoming history messages for deduplication,
// and to allow access to the full message list in onUpdateEvent without causing re-renders.
// The cache is only replaced when the live list is at least as long as the cached one.
if (thread.messages.length >= messagesRef.current.length) {
messagesRef.current = thread.messages;
}
const mergedMessages = mergeMessages(
history,
thread.messages,
optimisticMessages,
);
const pendingUsageMessages = thread.isLoading
? getMessagesAfterBaseline(
thread.messages,
pendingUsageBaselineMessageIdsRef.current,
)
: [];
// Merge history, live stream, and optimistic messages for display
// History messages may overlap with thread.messages; thread.messages take precedence
const mergedThread = {
...thread,
messages: mergedMessages,
} as typeof thread;
return {
thread: mergedThread,
pendingUsageMessages,
sendMessage,
isUploading,
isHistoryLoading,
hasMoreHistory,
loadMoreHistory,
} as const;
}
/**
 * Loads a thread's earlier messages one run at a time, newest run first.
 *
 * The run list comes from `useThreadRuns`. The first page (the most recent
 * run) is fetched automatically once runs are available; each `loadMore` call
 * fetches the next older run and prepends its de-duplicated messages.
 *
 * Concurrency rules:
 * - Only one page request is in flight at a time. The guard is a ref that is
 *   set synchronously, so two `loadMore` calls issued before React re-renders
 *   cannot both start a fetch (which would skip a run).
 * - Every thread switch starts a new "generation". A response that belongs to
 *   an older generation is discarded, so it can neither leak messages into the
 *   new thread nor move the new thread's paging cursor.
 *
 * @param threadId Thread whose history is paged; an empty string yields no runs.
 * @returns `runs` (raw run list), `messages` (loaded history, oldest first),
 *   `loading`, `appendMessages` (push messages onto the end of the history),
 *   `hasMore` (older runs remain, or the run list has not arrived yet) and
 *   `loadMore` (fetch the next older run; never rejects, failures are logged).
 */
export function useThreadHistory(threadId: string) {
  const runs = useThreadRuns(threadId);
  const threadIdRef = useRef(threadId);
  const runsRef = useRef(runs.data ?? []);
  // Index into runsRef of the next run to load; -1 means nothing left.
  const indexRef = useRef(-1);
  // Synchronous in-flight flag (state updates are too late to guard re-entry).
  const loadingRef = useRef(false);
  // Bumped on every thread switch so late responses can be recognised as stale.
  const generationRef = useRef(0);
  const [loading, setLoading] = useState(false);
  const [messages, setMessages] = useState<Message[]>([]);
  const initialLoadDoneRef = useRef(false);

  const loadMessages = useCallback(async () => {
    const run = runsRef.current[indexRef.current];
    if (!run || loadingRef.current) {
      return;
    }
    const generation = generationRef.current;
    loadingRef.current = true;
    setLoading(true);
    try {
      const response = await fetch(
        `${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadIdRef.current)}/runs/${encodeURIComponent(run.run_id)}/messages`,
        {
          method: "GET",
          headers: {
            "Content-Type": "application/json",
          },
          credentials: "include",
        },
      );
      if (!response.ok) {
        // An error payload would otherwise be read as a message page.
        throw new Error(`Failed to load messages for run ${run.run_id}.`);
      }
      const result: { data: RunMessage[]; hasMore: boolean } =
        await response.json();
      if (generationRef.current !== generation) {
        // The thread changed while this request was in flight: drop the page.
        return;
      }
      // Messages emitted by middleware are not part of the visible history.
      const _messages = result.data
        .filter((m) => !m.metadata.caller?.startsWith("middleware:"))
        .map((m) => m.content);
      setMessages((prev) => {
        const deduped = deduplicateHistoryMessages(prev, _messages);
        return [...deduped, ...prev];
      });
      indexRef.current -= 1;
    } catch (err) {
      console.error(err);
    } finally {
      // A stale request must not clear the flag owned by the current thread.
      if (generationRef.current === generation) {
        loadingRef.current = false;
        setLoading(false);
      }
    }
  }, []);

  // Reset state when threadId changes
  useEffect(() => {
    generationRef.current += 1;
    loadingRef.current = false;
    setLoading(false);
    threadIdRef.current = threadId;
    runsRef.current = [];
    indexRef.current = -1;
    initialLoadDoneRef.current = false;
    setMessages([]);
  }, [threadId]);

  // Load/update history when runs data changes
  useEffect(() => {
    if (runs.data && runs.data.length > 0) {
      const prevLength = runsRef.current.length;
      runsRef.current = runs.data;
      if (!initialLoadDoneRef.current) {
        // Initial load: start from the most recent run
        initialLoadDoneRef.current = true;
        indexRef.current = runs.data.length - 1;
        loadMessages().catch(() => {
          toast.error("Failed to load thread history.");
        });
      } else if (runs.data.length > prevLength) {
        // New runs added (e.g., after query invalidation): adjust indexRef
        // so the user can load older history by scrolling up
        indexRef.current = adjustHistoryIndex(
          indexRef.current,
          prevLength,
          runs.data.length,
        );
      }
    }
  }, [threadId, runs.data, loadMessages]);

  const appendMessages = useCallback((_messages: Message[]) => {
    setMessages((prev) => {
      return [...prev, ..._messages];
    });
  }, []);

  const hasMore = indexRef.current >= 0 || !runs.data;
  return {
    runs: runs.data,
    messages,
    loading,
    appendMessages,
    hasMore,
    loadMore: loadMessages,
  };
}
/**
 * Queries the thread list, transparently paging through the search endpoint.
 *
 * Pages hold at most 50 threads. Paging stops when `params.limit` results have
 * been collected or when a page comes back shorter than requested. A
 * non-positive `limit` is passed through to a single search call unchanged.
 *
 * @param params Search parameters; defaults to the 50 most recently updated
 *   threads with `thread_id`, `updated_at`, `values` and `metadata` selected.
 * @returns The react-query result holding the collected threads.
 */
export function useThreads(
  params: Parameters<ThreadsClient["search"]>[0] = {
    limit: 50,
    sortBy: "updated_at",
    sortOrder: "desc",
    select: ["thread_id", "updated_at", "values", "metadata"],
  },
) {
  const apiClient = getAPIClient();

  const collectThreads = async (): Promise<AgentThread[]> => {
    const PAGE_SIZE_CAP = 50;
    const cap = params.limit;

    // Preserve prior semantics: an explicit non-positive limit is delegated
    // to one search call with the caller's parameters untouched.
    if (cap !== undefined && cap <= 0) {
      return (await apiClient.threads.search<AgentThreadState>(
        params,
      )) as AgentThread[];
    }

    const pageSize =
      typeof cap === "number" && cap > 0
        ? Math.min(PAGE_SIZE_CAP, cap)
        : PAGE_SIZE_CAP;
    const collected: AgentThread[] = [];
    let cursor = params.offset ?? 0;

    for (;;) {
      // Without a cap there is always "room" for another full page.
      const room =
        typeof cap === "number"
          ? cap - collected.length
          : Number.POSITIVE_INFINITY;
      if (room <= 0) {
        break;
      }
      const batchLimit = Math.min(pageSize, room);
      const page = (await apiClient.threads.search<AgentThreadState>({
        ...params,
        limit: batchLimit,
        offset: cursor,
      })) as AgentThread[];
      collected.push(...page);
      // A short page means the server has nothing further to return.
      if (page.length < batchLimit) {
        break;
      }
      cursor += page.length;
    }
    return collected;
  };

  return useQuery<AgentThread[]>({
    queryKey: ["threads", "search", params],
    queryFn: collectThreads,
    refetchOnWindowFocus: false,
  });
}
/**
 * Queries the list of runs for a thread under the `["thread", threadId]` key.
 *
 * @param threadId Thread to list runs for; when falsy the query resolves to an
 *   empty list without calling the API.
 * @returns The react-query result holding the runs.
 */
export function useThreadRuns(threadId?: string) {
  const apiClient = getAPIClient();

  const listRuns = async (): Promise<Run[]> => {
    if (threadId) {
      return await apiClient.runs.list(threadId);
    }
    return [];
  };

  return useQuery<Run[]>({
    queryKey: ["thread", threadId],
    queryFn: listRuns,
    refetchOnWindowFocus: false,
  });
}
/**
 * Queries token usage for a thread.
 *
 * The query only runs when `enabled` is true and a thread id is present; it
 * never retries and does not refetch on window focus.
 *
 * @param threadId Thread to fetch usage for; a nullish or empty id resolves to `null`.
 * @param options `enabled` (default `true`) lets callers switch the query off.
 * @returns The react-query result holding the usage response or `null`.
 */
export function useThreadTokenUsage(
  threadId?: string | null,
  { enabled = true }: { enabled?: boolean } = {},
) {
  const hasThread = Boolean(threadId);

  return useQuery<ThreadTokenUsageResponse | null>({
    queryKey: threadTokenUsageQueryKey(threadId),
    queryFn: async () => (threadId ? fetchThreadTokenUsage(threadId) : null),
    enabled: enabled && hasThread,
    retry: false,
    refetchOnWindowFocus: false,
  });
}
/**
 * Queries a single run of a thread under the
 * `["thread", threadId, "run", runId]` key.
 *
 * @param threadId Thread that owns the run.
 * @param runId Run to fetch.
 * @returns The react-query result holding the run.
 */
export function useRunDetail(threadId: string, runId: string) {
  const apiClient = getAPIClient();

  const fetchRun = async (): Promise<Run> => {
    return await apiClient.runs.get(threadId, runId);
  };

  return useQuery<Run>({
    queryKey: ["thread", threadId, "run", runId],
    queryFn: fetchRun,
    refetchOnWindowFocus: false,
  });
}
/**
 * Mutation that deletes a thread in two steps: first through the API client,
 * then through the backend's own `DELETE /api/threads/:id` endpoint.
 *
 * On success the thread is removed from every cached thread-search result; the
 * search queries are invalidated whether or not the mutation succeeded.
 *
 * @returns The react-query mutation; call it with `{ threadId }`.
 */
export function useDeleteThread() {
  const queryClient = useQueryClient();
  const apiClient = getAPIClient();

  const deleteThread = async ({ threadId }: { threadId: string }) => {
    await apiClient.threads.delete(threadId);

    const url = `${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}`;
    const response = await fetch(url, { method: "DELETE" });
    if (response.ok) {
      return;
    }
    // Prefer the server's `detail`; fall back when the body is not JSON.
    const error = await response
      .json()
      .catch(() => ({ detail: "Failed to delete local thread data." }));
    throw new Error(error.detail ?? "Failed to delete local thread data.");
  };

  const dropFromCache = (threadId: string) => {
    queryClient.setQueriesData(
      { queryKey: ["threads", "search"], exact: false },
      (oldData: Array<AgentThread> | undefined) =>
        oldData == null
          ? oldData
          : oldData.filter((entry) => entry.thread_id !== threadId),
    );
  };

  return useMutation({
    mutationFn: deleteThread,
    onSuccess(_, { threadId }) {
      dropFromCache(threadId);
    },
    onSettled() {
      void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
    },
  });
}
/**
 * Mutation that renames a thread by writing `title` into its state values.
 *
 * On success the new title is patched into every cached thread-search result
 * so lists update without a refetch.
 *
 * @returns The react-query mutation; call it with `{ threadId, title }`.
 */
export function useRenameThread() {
  const queryClient = useQueryClient();
  const apiClient = getAPIClient();
  return useMutation({
    mutationFn: async ({
      threadId,
      title,
    }: {
      threadId: string;
      title: string;
    }) => {
      await apiClient.threads.updateState(threadId, {
        values: { title },
      });
    },
    onSuccess(_, { threadId, title }) {
      queryClient.setQueriesData(
        {
          queryKey: ["threads", "search"],
          exact: false,
        },
        // The updater also runs for matching queries that hold no data yet
        // (still loading or failed), so `oldData` may be undefined. Returning
        // undefined leaves such a query untouched instead of throwing.
        (oldData: Array<AgentThread> | undefined) => {
          return oldData?.map((t) => {
            if (t.thread_id === threadId) {
              return {
                ...t,
                values: {
                  ...t.values,
                  title,
                },
              };
            }
            return t;
          });
        },
      );
    },
  });
}