mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 08:25:57 +00:00
fix: use backend thread token usage for header total (#2800)
* fix: use backend thread token usage for header total
* Refactor thread token usage fetch
This commit is contained in:
@@ -0,0 +1,24 @@
|
||||
import { fetch as fetchWithAuth } from "@/core/api/fetcher";
|
||||
import { getBackendBaseURL } from "@/core/config";
|
||||
|
||||
import type { ThreadTokenUsageResponse } from "./types";
|
||||
|
||||
/**
 * Loads the token usage summary for one thread from the backend.
 *
 * Issues a GET to `/api/threads/{threadId}/token-usage` (the thread id is
 * URI-encoded) through the project fetch wrapper imported as `fetchWithAuth`.
 *
 * @param threadId - Identifier of the thread whose usage is requested.
 * @returns The parsed JSON body, or `null` when the backend answers 403 or
 *   404. The body is cast to `ThreadTokenUsageResponse`; it is not validated
 *   at runtime.
 * @throws Error for any other non-OK HTTP status.
 */
export async function fetchThreadTokenUsage(
  threadId: string,
): Promise<ThreadTokenUsageResponse | null> {
  const response = await fetchWithAuth(
    `${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/token-usage`,
    {
      method: "GET",
    },
  );

  if (!response.ok) {
    // 403 and 404 are reported to the caller as "no usage data" rather than
    // as a failure.
    if (response.status === 403 || response.status === 404) {
      return null;
    }
    throw new Error("Failed to load thread token usage.");
  }

  return (await response.json()) as ThreadTokenUsageResponse;
}
|
||||
@@ -17,7 +17,14 @@ import { useUpdateSubtask } from "../tasks/context";
|
||||
import type { UploadedFileInfo } from "../uploads";
|
||||
import { promptInputFilePartToFile, uploadFiles } from "../uploads";
|
||||
|
||||
import type { AgentThread, AgentThreadState, RunMessage } from "./types";
|
||||
import { fetchThreadTokenUsage } from "./api";
|
||||
import { threadTokenUsageQueryKey } from "./token-usage";
|
||||
import type {
|
||||
AgentThread,
|
||||
AgentThreadState,
|
||||
RunMessage,
|
||||
ThreadTokenUsageResponse,
|
||||
} from "./types";
|
||||
|
||||
export type ToolEndEvent = {
|
||||
name: string;
|
||||
@@ -75,6 +82,23 @@ function mergeMessages(
|
||||
];
|
||||
}
|
||||
|
||||
/**
 * Returns the key used to recognise a message across stream updates.
 *
 * Messages that carry a non-empty `tool_call_id` are identified by it; every
 * other message is identified by its `id`.
 *
 * @param message - The message to identify.
 * @returns The identity string, or `undefined` when the message has neither a
 *   usable `tool_call_id` nor an `id`.
 */
function messageIdentity(message: Message): string | undefined {
  // A present-but-empty `tool_call_id` falls through to `id` so the message
  // is not left without an identity when an `id` is available.
  if ("tool_call_id" in message && message.tool_call_id) {
    return message.tool_call_id;
  }
  return message.id;
}
|
||||
|
||||
/**
 * Filters out messages that were already present in a baseline snapshot.
 *
 * @param messages - Messages to filter; the array is not mutated.
 * @param baselineMessageIds - Identities (as produced by `messageIdentity`)
 *   captured in the baseline.
 * @returns A new array, in the original order, of the messages whose identity
 *   is not in the baseline. Messages without an identity are always kept.
 */
function getMessagesAfterBaseline(
  messages: Message[],
  baselineMessageIds: ReadonlySet<string>,
): Message[] {
  return messages.filter((message) => {
    const id = messageIdentity(message);
    // No identity means the message cannot be matched against the baseline,
    // so it is treated as new.
    return !id || !baselineMessageIds.has(id);
  });
}
|
||||
|
||||
function getStreamErrorMessage(error: unknown): string {
|
||||
if (typeof error === "string" && error.trim()) {
|
||||
return error;
|
||||
@@ -114,6 +138,7 @@ export function useThreadStream({
|
||||
// and to allow access to the current thread id in onUpdateEvent
|
||||
const threadIdRef = useRef<string | null>(threadId ?? null);
|
||||
const startedRef = useRef(false);
|
||||
const pendingUsageBaselineMessageIdsRef = useRef<Set<string>>(new Set());
|
||||
const listeners = useRef({
|
||||
onSend,
|
||||
onStart,
|
||||
@@ -271,29 +296,42 @@ export function useThreadStream({
|
||||
onError(error) {
|
||||
setOptimisticMessages([]);
|
||||
toast.error(getStreamErrorMessage(error));
|
||||
pendingUsageBaselineMessageIdsRef.current = new Set();
|
||||
if (threadIdRef.current && !isMock) {
|
||||
void queryClient.invalidateQueries({
|
||||
queryKey: threadTokenUsageQueryKey(threadIdRef.current),
|
||||
});
|
||||
}
|
||||
},
|
||||
onFinish(state) {
|
||||
listeners.current.onFinish?.(state.values);
|
||||
pendingUsageBaselineMessageIdsRef.current = new Set();
|
||||
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
||||
if (threadIdRef.current && !isMock) {
|
||||
void queryClient.invalidateQueries({
|
||||
queryKey: threadTokenUsageQueryKey(threadIdRef.current),
|
||||
});
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
// Optimistic messages shown before the server stream responds
|
||||
const [optimisticMessages, setOptimisticMessages] = useState<Message[]>([]);
|
||||
const [isUploading, setIsUploading] = useState(false);
|
||||
const humanMessageCount = thread.messages.filter(
|
||||
(m) => m.type === "human",
|
||||
).length;
|
||||
const latestMessageCountsRef = useRef({ humanMessageCount });
|
||||
const sendInFlightRef = useRef(false);
|
||||
const messagesRef = useRef<Message[]>([]);
|
||||
const summarizedRef = useRef<Set<string>>(null);
|
||||
// Track message count before sending so we know when server has responded
|
||||
const prevMsgCountRef = useRef(thread.messages.length);
|
||||
// Track human message count before sending to prevent clearing optimistic
|
||||
// messages before the server's human message arrives (e.g. when AI messages
|
||||
// from "messages-tuple" events arrive before the input human message from
|
||||
// "values" events).
|
||||
const prevHumanMsgCountRef = useRef(
|
||||
thread.messages.filter((m) => m.type === "human").length,
|
||||
);
|
||||
const prevHumanMsgCountRef = useRef(humanMessageCount);
|
||||
|
||||
latestMessageCountsRef.current = { humanMessageCount };
|
||||
summarizedRef.current ??= new Set<string>();
|
||||
|
||||
// Reset thread-local pending UI state when switching between threads so
|
||||
@@ -301,31 +339,43 @@ export function useThreadStream({
|
||||
useEffect(() => {
|
||||
startedRef.current = false;
|
||||
sendInFlightRef.current = false;
|
||||
prevMsgCountRef.current = thread.messages.length;
|
||||
prevHumanMsgCountRef.current = thread.messages.filter(
|
||||
(m) => m.type === "human",
|
||||
).length;
|
||||
pendingUsageBaselineMessageIdsRef.current = new Set();
|
||||
prevHumanMsgCountRef.current =
|
||||
latestMessageCountsRef.current.humanMessageCount;
|
||||
}, [threadId]);
|
||||
|
||||
// When streaming starts without a baseline (e.g. reconnection, run started
|
||||
// from another client, or page reload mid-stream), snapshot the current
|
||||
// messages so only *new* messages are treated as "pending" for token usage.
|
||||
useEffect(() => {
|
||||
if (
|
||||
thread.isLoading &&
|
||||
pendingUsageBaselineMessageIdsRef.current.size === 0
|
||||
) {
|
||||
pendingUsageBaselineMessageIdsRef.current = new Set(
|
||||
thread.messages
|
||||
.map(messageIdentity)
|
||||
.filter((id): id is string => Boolean(id)),
|
||||
);
|
||||
}
|
||||
}, [thread.isLoading, thread.messages]);
|
||||
|
||||
// Clear optimistic when server messages arrive.
|
||||
// For messages with a human optimistic message, wait until the server's
|
||||
// human message has arrived to avoid clearing before the input message
|
||||
// appears in the stream (the input message may arrive via "values" events
|
||||
// after individual "messages-tuple" events for AI messages).
|
||||
const optimisticMessageCount = optimisticMessages.length;
|
||||
const hasHumanOptimistic = optimisticMessages.some((m) => m.type === "human");
|
||||
useEffect(() => {
|
||||
if (optimisticMessages.length === 0) return;
|
||||
if (optimisticMessageCount === 0) return;
|
||||
|
||||
const hasHumanOptimistic = optimisticMessages.some(
|
||||
(m) => m.type === "human",
|
||||
);
|
||||
const newHumanMsgArrived =
|
||||
thread.messages.filter((m) => m.type === "human").length >
|
||||
prevHumanMsgCountRef.current;
|
||||
const newHumanMsgArrived = humanMessageCount > prevHumanMsgCountRef.current;
|
||||
|
||||
if (!hasHumanOptimistic || newHumanMsgArrived) {
|
||||
setOptimisticMessages([]);
|
||||
}
|
||||
}, [thread.messages.length, optimisticMessages.length]);
|
||||
}, [hasHumanOptimistic, humanMessageCount, optimisticMessageCount]);
|
||||
|
||||
const sendMessage = useCallback(
|
||||
async (
|
||||
@@ -341,11 +391,14 @@ export function useThreadStream({
|
||||
|
||||
const text = message.text.trim();
|
||||
|
||||
// Capture current count before showing optimistic messages
|
||||
prevMsgCountRef.current = thread.messages.length;
|
||||
prevHumanMsgCountRef.current = thread.messages.filter(
|
||||
(m) => m.type === "human",
|
||||
).length;
|
||||
// Capture the current human message count before showing optimistic
|
||||
// messages so we can wait for the server's copy of the user input.
|
||||
prevHumanMsgCountRef.current = humanMessageCount;
|
||||
pendingUsageBaselineMessageIdsRef.current = new Set(
|
||||
thread.messages
|
||||
.map(messageIdentity)
|
||||
.filter((id): id is string => Boolean(id)),
|
||||
);
|
||||
|
||||
// Build optimistic files list with uploading status
|
||||
const optimisticFiles: FileInMessage[] = (message.files ?? []).map(
|
||||
@@ -517,7 +570,7 @@ export function useThreadStream({
|
||||
sendInFlightRef.current = false;
|
||||
}
|
||||
},
|
||||
[thread, t.uploads.uploadingFiles, context, queryClient],
|
||||
[thread, t.uploads.uploadingFiles, context, queryClient, humanMessageCount],
|
||||
);
|
||||
|
||||
// Cache the latest thread messages in a ref to compare against incoming history messages for deduplication,
|
||||
@@ -531,6 +584,12 @@ export function useThreadStream({
|
||||
thread.messages,
|
||||
optimisticMessages,
|
||||
);
|
||||
const pendingUsageMessages = thread.isLoading
|
||||
? getMessagesAfterBaseline(
|
||||
thread.messages,
|
||||
pendingUsageBaselineMessageIdsRef.current,
|
||||
)
|
||||
: [];
|
||||
|
||||
// Merge history, live stream, and optimistic messages for display
|
||||
// History messages may overlap with thread.messages; thread.messages take precedence
|
||||
@@ -541,6 +600,7 @@ export function useThreadStream({
|
||||
|
||||
return {
|
||||
thread: mergedThread,
|
||||
pendingUsageMessages,
|
||||
sendMessage,
|
||||
isUploading,
|
||||
isHistoryLoading,
|
||||
@@ -701,6 +761,24 @@ export function useThreadRuns(threadId?: string) {
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Query hook for a thread's token usage summary.
 *
 * The query is keyed by `threadTokenUsageQueryKey(threadId)` and only runs
 * when a thread id is given and `enabled` is true. Retries and
 * refetch-on-window-focus are turned off.
 *
 * @param threadId - Thread to load usage for; a missing or empty id disables
 *   the query.
 * @param options - `enabled` (default `true`) lets the caller switch the
 *   query off.
 * @returns The `useQuery` result, whose data is the usage response or `null`.
 */
export function useThreadTokenUsage(
  threadId?: string | null,
  { enabled = true }: { enabled?: boolean } = {},
) {
  return useQuery<ThreadTokenUsageResponse | null>({
    queryKey: threadTokenUsageQueryKey(threadId),
    queryFn: async () => {
      // Guard for the type checker and for any run without a thread id.
      if (!threadId) {
        return null;
      }
      return fetchThreadTokenUsage(threadId);
    },
    enabled: enabled && Boolean(threadId),
    retry: false,
    refetchOnWindowFocus: false,
  });
}
|
||||
|
||||
export function useRunDetail(threadId: string, runId: string) {
|
||||
const apiClient = getAPIClient();
|
||||
return useQuery<Run>({
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
import type { TokenUsage } from "@/core/messages/usage";
|
||||
|
||||
import type { ThreadTokenUsageResponse } from "./types";
|
||||
|
||||
/**
 * Builds the query key under which a thread's token usage is cached.
 *
 * @param threadId - Thread identifier; `null` and `undefined` are placed in
 *   the key unchanged.
 * @returns A readonly tuple `["thread-token-usage", threadId]`.
 */
export function threadTokenUsageQueryKey(threadId?: string | null) {
  return ["thread-token-usage", threadId] as const;
}
|
||||
|
||||
/**
 * Converts a backend thread usage response into the `TokenUsage` shape.
 *
 * @param usage - Backend response, or `null`/`undefined` when none is
 *   available.
 * @returns `null` when `usage` is missing; otherwise the input, output and
 *   total token counts. Missing input or output counts default to 0. A
 *   missing total defaults to the sum of the input and output counts.
 */
export function threadTokenUsageToTokenUsage(
  usage: ThreadTokenUsageResponse | null | undefined,
): TokenUsage | null {
  if (!usage) {
    return null;
  }

  const inputTokens = usage.total_input_tokens ?? 0;
  const outputTokens = usage.total_output_tokens ?? 0;

  return {
    inputTokens,
    outputTokens,
    // Deriving the total from its parts avoids reporting a total of 0 next to
    // non-zero input/output counts when the payload omits `total_tokens`.
    totalTokens: usage.total_tokens ?? inputTokens + outputTokens,
  };
}
|
||||
@@ -31,3 +31,17 @@ export interface RunMessage {
|
||||
};
|
||||
created_at: string;
|
||||
}
|
||||
|
||||
/**
 * Shape of the JSON body the client reads from the thread token-usage
 * endpoint.
 *
 * NOTE(review): field meanings below are inferred from the field names;
 * confirm against the backend before relying on them.
 */
export interface ThreadTokenUsageResponse {
  /** Identifier of the thread the figures belong to. */
  thread_id: string;
  /** Total token count for the thread. */
  total_tokens: number;
  /** Input-token portion of the total. */
  total_input_tokens: number;
  /** Output-token portion of the total. */
  total_output_tokens: number;
  /** Number of runs counted (presumably runs recorded for the thread). */
  total_runs: number;
  /** Per-model breakdown, keyed by a string (presumably the model name). */
  by_model: Record<string, { tokens: number; runs: number }>;
  /** Per-caller breakdown (presumably token counts; confirm with backend). */
  by_caller: {
    lead_agent: number;
    subagent: number;
    middleware: number;
  };
}
|
||||
|
||||
Reference in New Issue
Block a user