Compare commits

...

2 Commits

Author SHA1 Message Date
JeffJiang 565ab432fc Fix test assertions for run ordering in RunManager tests
- Updated assertions in `test_list_by_thread` to reflect correct ordering of runs.
- Modified `test_list_by_thread_is_stable_when_timestamps_tie` to ensure stable ordering when timestamps are tied.
2026-04-19 09:55:34 +08:00
JeffJiang df63c104a7 Refactor API fetch calls to use a unified fetch function; enhance chat history loading with new hooks and UI components
- Replaced `fetchWithAuth` with a generic `fetch` function across various API modules for consistency.
- Updated `useThreadStream` and `useThreadHistory` hooks to manage chat history loading, including loading states and pagination.
- Introduced `LoadMoreHistoryIndicator` component for better user experience when loading more chat history.
- Enhanced message handling in `MessageList` to accommodate new loading states and history management.
- Added support for run messages in the thread context, improving the overall message handling logic.
- Updated translations for loading indicators in English and Chinese.
2026-04-17 23:41:11 +08:00
35 changed files with 755 additions and 1440 deletions
+20 -17
View File
@@ -8,13 +8,17 @@ Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator, Callable
from contextlib import AsyncExitStack, asynccontextmanager from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, TypeVar, cast
from fastapi import FastAPI, HTTPException, Request from fastapi import FastAPI, HTTPException, Request
from langgraph.types import Checkpointer
from deerflow.runtime import RunContext, RunManager from deerflow.persistence.feedback import FeedbackRepository
from deerflow.runtime import RunContext, RunManager, StreamBridge
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.runs.store.base import RunStore
if TYPE_CHECKING: if TYPE_CHECKING:
from app.gateway.auth.local_provider import LocalAuthProvider from app.gateway.auth.local_provider import LocalAuthProvider
@@ -22,6 +26,9 @@ if TYPE_CHECKING:
from deerflow.persistence.thread_meta.base import ThreadMetaStore from deerflow.persistence.thread_meta.base import ThreadMetaStore
T = TypeVar("T")
@asynccontextmanager @asynccontextmanager
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]: async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
"""Bootstrap and tear down all LangGraph runtime singletons. """Bootstrap and tear down all LangGraph runtime singletons.
@@ -84,25 +91,25 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _require(attr: str, label: str): def _require(attr: str, label: str) -> Callable[[Request], T]:
"""Create a FastAPI dependency that returns ``app.state.<attr>`` or 503.""" """Create a FastAPI dependency that returns ``app.state.<attr>`` or 503."""
def dep(request: Request): def dep(request: Request) -> T:
val = getattr(request.app.state, attr, None) val = getattr(request.app.state, attr, None)
if val is None: if val is None:
raise HTTPException(status_code=503, detail=f"{label} not available") raise HTTPException(status_code=503, detail=f"{label} not available")
return val return cast(T, val)
dep.__name__ = dep.__qualname__ = f"get_{attr}" dep.__name__ = dep.__qualname__ = f"get_{attr}"
return dep return dep
get_stream_bridge = _require("stream_bridge", "Stream bridge") get_stream_bridge: Callable[[Request], StreamBridge] = _require("stream_bridge", "Stream bridge")
get_run_manager = _require("run_manager", "Run manager") get_run_manager: Callable[[Request], RunManager] = _require("run_manager", "Run manager")
get_checkpointer = _require("checkpointer", "Checkpointer") get_checkpointer: Callable[[Request], Checkpointer] = _require("checkpointer", "Checkpointer")
get_run_event_store = _require("run_event_store", "Run event store") get_run_event_store: Callable[[Request], RunEventStore] = _require("run_event_store", "Run event store")
get_feedback_repo = _require("feedback_repo", "Feedback") get_feedback_repo: Callable[[Request], FeedbackRepository] = _require("feedback_repo", "Feedback")
get_run_store = _require("run_store", "Run store") get_run_store: Callable[[Request], RunStore] = _require("run_store", "Run store")
def get_store(request: Request): def get_store(request: Request):
@@ -121,10 +128,7 @@ def get_thread_store(request: Request) -> ThreadMetaStore:
def get_run_context(request: Request) -> RunContext: def get_run_context(request: Request) -> RunContext:
"""Build a :class:`RunContext` from ``app.state`` singletons. """Build a :class:`RunContext` from ``app.state`` singletons.
Returns a *base* context with infrastructure dependencies. Callers that Returns a *base* context with infrastructure dependencies.
need per-run fields (e.g. ``follow_up_to_run_id``) should use
``dataclasses.replace(ctx, follow_up_to_run_id=...)`` before passing it
to :func:`run_agent`.
""" """
from deerflow.config import get_app_config from deerflow.config import get_app_config
@@ -137,7 +141,6 @@ def get_run_context(request: Request) -> RunContext:
) )
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Auth helpers (used by authz.py and auth middleware) # Auth helpers (used by authz.py and auth middleware)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
+2 -1
View File
@@ -123,7 +123,8 @@ async def run_messages(
run = await _resolve_run(run_id, request) run = await _resolve_run(run_id, request)
event_store = get_run_event_store(request) event_store = get_run_event_store(request)
rows = await event_store.list_messages_by_run( rows = await event_store.list_messages_by_run(
run["thread_id"], run_id, run["thread_id"],
run_id,
limit=limit + 1, limit=limit + 1,
before_seq=before_seq, before_seq=before_seq,
after_seq=after_seq, after_seq=after_seq,
+11 -7
View File
@@ -54,7 +54,6 @@ class RunCreateRequest(BaseModel):
after_seconds: float | None = Field(default=None, description="Delayed execution") after_seconds: float | None = Field(default=None, description="Delayed execution")
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy") if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys") feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
follow_up_to_run_id: str | None = Field(default=None, description="Run ID this message follows up on. Auto-detected from latest successful run if not provided.")
class RunResponse(BaseModel): class RunResponse(BaseModel):
@@ -312,11 +311,15 @@ async def list_thread_messages(
if i in last_ai_indices: if i in last_ai_indices:
run_id = msg["run_id"] run_id = msg["run_id"]
fb = feedback_map.get(run_id) fb = feedback_map.get(run_id)
msg["feedback"] = { msg["feedback"] = (
"feedback_id": fb["feedback_id"], {
"rating": fb["rating"], "feedback_id": fb["feedback_id"],
"comment": fb.get("comment"), "rating": fb["rating"],
} if fb else None "comment": fb.get("comment"),
}
if fb
else None
)
else: else:
msg["feedback"] = None msg["feedback"] = None
@@ -339,7 +342,8 @@ async def list_run_messages(
""" """
event_store = get_run_event_store(request) event_store = get_run_event_store(request)
rows = await event_store.list_messages_by_run( rows = await event_store.list_messages_by_run(
thread_id, run_id, thread_id,
run_id,
limit=limit + 1, limit=limit + 1,
before_seq=before_seq, before_seq=before_seq,
after_seq=after_seq, after_seq=after_seq,
+1 -1
View File
@@ -56,7 +56,7 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
@router.post("", response_model=UploadResponse) @router.post("", response_model=UploadResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True) @require_permission("threads", "write", owner_check=True, require_existing=False)
async def upload_files( async def upload_files(
thread_id: str, thread_id: str,
request: Request, request: Request,
-16
View File
@@ -195,21 +195,6 @@ async def start_run(
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_ disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
# Resolve follow_up_to_run_id: explicit from request, or auto-detect from latest successful run
follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None)
if follow_up_to_run_id is None:
run_store = get_run_store(request)
try:
recent_runs = await run_store.list_by_thread(thread_id, limit=1)
if recent_runs and recent_runs[0].get("status") == "success":
follow_up_to_run_id = recent_runs[0]["run_id"]
except Exception:
pass # Don't block run creation
# Enrich base context with per-run field
if follow_up_to_run_id:
run_ctx = dataclasses.replace(run_ctx, follow_up_to_run_id=follow_up_to_run_id)
try: try:
record = await run_mgr.create_or_reject( record = await run_mgr.create_or_reject(
thread_id, thread_id,
@@ -218,7 +203,6 @@ async def start_run(
metadata=body.metadata or {}, metadata=body.metadata or {},
kwargs={"input": body.input, "config": body.config}, kwargs={"input": body.input, "config": body.config},
multitask_strategy=body.multitask_strategy, multitask_strategy=body.multitask_strategy,
follow_up_to_run_id=follow_up_to_run_id,
) )
except ConflictError as exc: except ConflictError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc raise HTTPException(status_code=409, detail=str(exc)) from exc
@@ -1,7 +1,7 @@
import logging import logging
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware from langchain.agents.middleware import AgentMiddleware
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from deerflow.agents.lead_agent.prompt import apply_prompt_template from deerflow.agents.lead_agent.prompt import apply_prompt_template
@@ -9,6 +9,7 @@ from deerflow.agents.middlewares.clarification_middleware import ClarificationMi
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
from deerflow.agents.middlewares.summarization_middleware import SummarizationMiddleware
from deerflow.agents.middlewares.title_middleware import TitleMiddleware from deerflow.agents.middlewares.title_middleware import TitleMiddleware
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
@@ -283,7 +283,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
# the conversation; injecting one mid-conversation crashes # the conversation; injecting one mid-conversation crashes
# langchain_anthropic's _format_messages(). HumanMessage works # langchain_anthropic's _format_messages(). HumanMessage works
# with all providers. See #1299. # with all providers. See #1299.
return {"messages": [HumanMessage(content=warning)]} return {"messages": [HumanMessage(content=warning, name="loop_warning")]}
return None return None
@@ -0,0 +1,13 @@
from typing import override
from langchain.agents.middleware import SummarizationMiddleware as BaseSummarizationMiddleware
from langchain_core.messages.human import HumanMessage
class SummarizationMiddleware(BaseSummarizationMiddleware):
@override
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
"""Override the base implementation to let the human message with the special name 'summary'.
And this message will be ignored to display in the frontend, but still can be used as context for the model.
"""
return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
@@ -1,8 +1,10 @@
import logging import logging
from datetime import UTC, datetime
from typing import NotRequired, override from typing import NotRequired, override
from langchain.agents import AgentState from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.config import get_config from langgraph.config import get_config
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
@@ -97,8 +99,20 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
paths = self._create_thread_directories(thread_id, user_id=user_id) paths = self._create_thread_directories(thread_id, user_id=user_id)
logger.debug("Created thread data directories for thread %s", thread_id) logger.debug("Created thread data directories for thread %s", thread_id)
messages = list(state.get("messages", []))
last_message = messages[-1] if messages else None
if last_message and isinstance(last_message, HumanMessage):
messages[-1] = HumanMessage(
content=last_message.content,
id=last_message.id,
name=last_message.name or "user-input",
additional_kwargs={**last_message.additional_kwargs, "run_id": runtime.context.get("run_id"), "timestamp": datetime.now(UTC).isoformat()},
)
return { return {
"thread_data": { "thread_data": {
**paths, **paths,
} },
"messages": messages,
} }
@@ -279,6 +279,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
updated_message = HumanMessage( updated_message = HumanMessage(
content=f"{files_message}\n\n{original_content}", content=f"{files_message}\n\n{original_content}",
id=last_message.id, id=last_message.id,
name=last_message.name,
additional_kwargs=last_message.additional_kwargs, additional_kwargs=last_message.additional_kwargs,
) )
@@ -6,7 +6,10 @@ handles token usage accumulation.
Key design decisions: Key design decisions:
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end - on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
- on_chat_model_start captures structured prompts as llm_request (OpenAI format) - on_chat_model_start captures structured prompts as llm_request (OpenAI format) and
extracts the first human message for run.input, because it is more reliable than
on_chain_start (fires on every node) — messages here are fully structured.
- on_chain_start with parent_run_id=None emits a run.start trace marking root invocation.
- on_llm_end emits llm_response in OpenAI Chat Completions format - on_llm_end emits llm_response in OpenAI Chat Completions format
- Token usage accumulated in memory, written to RunRow on run completion - Token usage accumulated in memory, written to RunRow on run completion
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name}) - Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
@@ -18,10 +21,12 @@ import asyncio
import logging import logging
import time import time
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, cast
from uuid import UUID from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage
from langgraph.types import Command
if TYPE_CHECKING: if TYPE_CHECKING:
from deerflow.runtime.events.store.base import RunEventStore from deerflow.runtime.events.store.base import RunEventStore
@@ -72,34 +77,39 @@ class RunJournal(BaseCallbackHandler):
# LLM request/response tracking # LLM request/response tracking
self._llm_call_index = 0 self._llm_call_index = 0
self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages
self._cached_models: dict[str, str] = {} # langchain run_id -> model name
# Tool call ID cache
self._tool_call_ids: dict[str, str] = {} # langchain run_id -> tool_call_id
# -- Lifecycle callbacks -- # -- Lifecycle callbacks --
def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None: def on_chain_start(
if kwargs.get("parent_run_id") is not None: self,
return serialized: dict[str, Any],
self._put( inputs: dict[str, Any],
event_type="run_start", *,
category="lifecycle", run_id: UUID,
metadata={"input_preview": str(inputs)[:500]}, parent_run_id: UUID | None = None,
) tags: list[str] | None = None,
metadata: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
caller = self._identify_caller(tags)
if parent_run_id is None:
# Root graph invocation — emit a single trace event for the run start.
chain_name = (serialized or {}).get("name", "unknown")
self._put(
event_type="run.start",
category="trace",
content={"chain": chain_name},
metadata={"caller": caller, **(metadata or {})},
)
def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None: def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
if kwargs.get("parent_run_id") is not None: self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"})
return
self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
self._flush_sync() self._flush_sync()
def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None: def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
if kwargs.get("parent_run_id") is not None:
return
self._put( self._put(
event_type="run_error", event_type="run.error",
category="lifecycle", category="error",
content=str(error), content=str(error),
metadata={"error_type": type(error).__name__}, metadata={"error_type": type(error).__name__},
) )
@@ -107,266 +117,132 @@ class RunJournal(BaseCallbackHandler):
# -- LLM callbacks -- # -- LLM callbacks --
def on_chat_model_start(self, serialized: dict, messages: list[list], *, run_id: UUID, **kwargs: Any) -> None: def on_chat_model_start(
"""Capture structured prompt messages for llm_request event.""" self,
from deerflow.runtime.converters import langchain_messages_to_openai serialized: dict,
messages: list[list[BaseMessage]],
*,
run_id: UUID,
tags: list[str] | None = None,
**kwargs: Any,
) -> None:
"""Capture structured prompt messages for llm_request event.
This is also the canonical place to extract the first human message:
messages are fully structured here, it fires only on real LLM calls,
and the content is never compressed by checkpoint trimming.
"""
rid = str(run_id) rid = str(run_id)
self._llm_start_times[rid] = time.monotonic() self._llm_start_times[rid] = time.monotonic()
self._llm_call_index += 1 self._llm_call_index += 1
# Mark this run_id as seen so on_llm_end knows not to increment again.
self._cached_prompts[rid] = []
model_name = serialized.get("name", "") logger.info(f"on_chat_model_start {run_id}: tags={tags} serialized={serialized} messages={messages}")
self._cached_models[rid] = model_name
# Convert the first message list (LangChain passes list-of-lists) # Capture the first human message sent to any LLM in this run.
prompt_msgs = messages[0] if messages else [] if not self._first_human_msg:
openai_msgs = langchain_messages_to_openai(prompt_msgs) for batch in messages.reversed():
self._cached_prompts[rid] = openai_msgs for m in batch.reversed():
if isinstance(m, HumanMessage) and m.name != "summary":
caller = self._identify_caller(tags)
self.set_first_human_message(m.text)
self._put(
event_type="llm.human.input",
category="message",
content=m.model_dump(),
metadata={"caller": caller},
)
break
if self._first_human_msg:
break
caller = self._identify_caller(kwargs) def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, parent_run_id: UUID | None = None, tags: list[str] | None = None, metadata: dict[str, Any] | None = None, **kwargs: Any) -> None:
self._put(
event_type="llm_request",
category="trace",
content={"model": model_name, "messages": openai_msgs},
metadata={"caller": caller, "llm_call_index": self._llm_call_index},
)
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
# Fallback: on_chat_model_start is preferred. This just tracks latency. # Fallback: on_chat_model_start is preferred. This just tracks latency.
self._llm_start_times[str(run_id)] = time.monotonic() self._llm_start_times[str(run_id)] = time.monotonic()
def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None: def on_llm_end(self, response, *, run_id, parent_run_id, tags, **kwargs) -> None:
from deerflow.runtime.converters import langchain_to_openai_completion messages: list[AnyMessage] = []
logger.info(f"on_llm_end {run_id}: response: {tags} {kwargs}")
try: for generation in response.generations:
message = response.generations[0][0].message for gen in generation:
except (IndexError, AttributeError): if hasattr(gen, "message"):
logger.debug("on_llm_end: could not extract message from response") messages.append(gen.message)
return
caller = self._identify_caller(kwargs)
# Latency
rid = str(run_id)
start = self._llm_start_times.pop(rid, None)
latency_ms = int((time.monotonic() - start) * 1000) if start else None
# Token usage from message
usage = getattr(message, "usage_metadata", None)
usage_dict = dict(usage) if usage else {}
# Resolve call index
call_index = self._llm_call_index
if rid not in self._cached_prompts:
# Fallback: on_chat_model_start was not called
self._llm_call_index += 1
call_index = self._llm_call_index
# Clean up caches
self._cached_prompts.pop(rid, None)
self._cached_models.pop(rid, None)
# Trace event: llm_response (OpenAI completion format)
content = getattr(message, "content", "")
self._put(
event_type="llm_response",
category="trace",
content=langchain_to_openai_completion(message),
metadata={
"caller": caller,
"usage": usage_dict,
"latency_ms": latency_ms,
"llm_call_index": call_index,
},
)
# Message events: only lead_agent gets message-category events.
# Content uses message.model_dump() to align with checkpoint format.
tool_calls = getattr(message, "tool_calls", None) or []
if caller == "lead_agent":
resp_meta = getattr(message, "response_metadata", None) or {}
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
if tool_calls:
# ai_tool_call: agent decided to use tools
self._put(
event_type="ai_tool_call",
category="message",
content=message.model_dump(),
metadata={"model_name": model_name, "finish_reason": "tool_calls"},
)
elif isinstance(content, str) and content:
# ai_message: final text reply
self._put(
event_type="ai_message",
category="message",
content=message.model_dump(),
metadata={"model_name": model_name, "finish_reason": "stop"},
)
self._last_ai_msg = content
self._msg_count += 1
# Token accumulation
if self._track_tokens:
input_tk = usage_dict.get("input_tokens", 0) or 0
output_tk = usage_dict.get("output_tokens", 0) or 0
total_tk = usage_dict.get("total_tokens", 0) or 0
if total_tk == 0:
total_tk = input_tk + output_tk
if total_tk > 0:
self._total_input_tokens += input_tk
self._total_output_tokens += output_tk
self._total_tokens += total_tk
self._llm_call_count += 1
if caller.startswith("subagent:"):
self._subagent_tokens += total_tk
elif caller.startswith("middleware:"):
self._middleware_tokens += total_tk
else: else:
self._lead_agent_tokens += total_tk logger.warning(f"on_llm_end {run_id}: generation has no message attribute: {gen}")
for message in messages:
caller = self._identify_caller(tags)
# Latency
rid = str(run_id)
start = self._llm_start_times.pop(rid, None)
latency_ms = int((time.monotonic() - start) * 1000) if start else None
# Token usage from message
usage = getattr(message, "usage_metadata", None)
usage_dict = dict(usage) if usage else {}
# Resolve call index
call_index = self._llm_call_index
if rid not in self._cached_prompts:
# Fallback: on_chat_model_start was not called
self._llm_call_index += 1
call_index = self._llm_call_index
# Trace event: llm_response (OpenAI completion format)
self._put(
event_type="llm.ai.response",
category="message",
content=message.model_dump(),
metadata={
"caller": caller,
"usage": usage_dict,
"latency_ms": latency_ms,
"llm_call_index": call_index,
},
)
# Token accumulation
if self._track_tokens:
input_tk = usage_dict.get("input_tokens", 0) or 0
output_tk = usage_dict.get("output_tokens", 0) or 0
total_tk = usage_dict.get("total_tokens", 0) or 0
if total_tk == 0:
total_tk = input_tk + output_tk
if total_tk > 0:
self._total_input_tokens += input_tk
self._total_output_tokens += output_tk
self._total_tokens += total_tk
self._llm_call_count += 1
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None: def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
self._llm_start_times.pop(str(run_id), None) self._llm_start_times.pop(str(run_id), None)
self._put(event_type="llm_error", category="trace", content=str(error)) self._put(event_type="llm.error", category="trace", content=str(error))
# -- Tool callbacks -- def on_tool_start(self, serialized, input_str, *, run_id, parent_run_id=None, tags=None, metadata=None, inputs=None, **kwargs):
"""Handle tool start event, cache tool call ID for later correlation"""
tool_call_id = str(run_id)
logger.info(f"Tool start for node {run_id}, tool_call_id={tool_call_id}, tags={tags}, metadata={metadata}")
def on_tool_start(self, serialized: dict, input_str: str, *, run_id: UUID, **kwargs: Any) -> None: def on_tool_end(self, output, *, run_id, parent_run_id=None, **kwargs):
tool_call_id = kwargs.get("tool_call_id") """Handle tool end event, append message and clear node data"""
if tool_call_id: try:
self._tool_call_ids[str(run_id)] = tool_call_id if isinstance(output, ToolMessage):
self._put( msg = cast(ToolMessage, output)
event_type="tool_start", self._put(event_type="llm.tool.result", category="message", content=msg.model_dump())
category="trace", elif isinstance(output, Command):
metadata={ cmd = cast(Command, output)
"tool_name": serialized.get("name", ""), messages = cmd.update.get("messages", [])
"tool_call_id": tool_call_id, for message in messages:
"args": str(input_str)[:2000], if isinstance(message, BaseMessage):
}, self._put(event_type="llm.tool.result", category="message", content=message.model_dump())
) else:
logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}")
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None: else:
from langchain_core.messages import ToolMessage logger.warning(f"on_tool_end {run_id}: output is not ToolMessage: {type(output)}")
from langgraph.types import Command finally:
logger.info(f"Tool end for node {run_id}")
# Tools that update graph state return a ``Command`` (e.g.
# ``present_files``). LangGraph later unwraps the inner ToolMessage
# into checkpoint state, so to stay checkpoint-aligned we must
# extract it here rather than storing ``str(Command(...))``.
if isinstance(output, Command):
update = getattr(output, "update", None) or {}
inner_msgs = update.get("messages") if isinstance(update, dict) else None
if isinstance(inner_msgs, list):
inner_tool_msg = next((m for m in inner_msgs if isinstance(m, ToolMessage)), None)
if inner_tool_msg is not None:
output = inner_tool_msg
# Extract fields from ToolMessage object when LangChain provides one.
# LangChain's _format_output wraps tool results into a ToolMessage
# with tool_call_id, name, status, and artifact — more complete than
# what kwargs alone provides.
if isinstance(output, ToolMessage):
tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = output.name or kwargs.get("name", "")
status = getattr(output, "status", "success") or "success"
content_str = output.content if isinstance(output.content, str) else str(output.content)
# Use model_dump() for checkpoint-aligned message content.
# Override tool_call_id if it was resolved from cache.
msg_content = output.model_dump()
if msg_content.get("tool_call_id") != tool_call_id:
msg_content["tool_call_id"] = tool_call_id
else:
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = kwargs.get("name", "")
status = "success"
content_str = str(output)
# Construct checkpoint-aligned dict when output is a plain string.
msg_content = ToolMessage(
content=content_str,
tool_call_id=tool_call_id or "",
name=tool_name,
status=status,
).model_dump()
# Trace event (always)
self._put(
event_type="tool_end",
category="trace",
content=content_str,
metadata={
"tool_name": tool_name,
"tool_call_id": tool_call_id,
"status": status,
},
)
# Message event: tool_result (checkpoint-aligned model_dump format)
self._put(
event_type="tool_result",
category="message",
content=msg_content,
metadata={"tool_name": tool_name, "status": status},
)
def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
from langchain_core.messages import ToolMessage
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = kwargs.get("name", "")
# Trace event
self._put(
event_type="tool_error",
category="trace",
content=str(error),
metadata={
"tool_name": tool_name,
"tool_call_id": tool_call_id,
},
)
# Message event: tool_result with error status (checkpoint-aligned)
msg_content = ToolMessage(
content=str(error),
tool_call_id=tool_call_id or "",
name=tool_name,
status="error",
).model_dump()
self._put(
event_type="tool_result",
category="message",
content=msg_content,
metadata={"tool_name": tool_name, "status": "error"},
)
# -- Custom event callback --
def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
from deerflow.runtime.serialization import serialize_lc_object
if name == "summarization":
data_dict = data if isinstance(data, dict) else {}
self._put(
event_type="summarization",
category="trace",
content=data_dict.get("summary", ""),
metadata={
"replaced_message_ids": data_dict.get("replaced_message_ids", []),
"replaced_count": data_dict.get("replaced_count", 0),
},
)
self._put(
event_type="middleware:summarize",
category="middleware",
content={"role": "system", "content": data_dict.get("summary", "")},
metadata={"replaced_count": data_dict.get("replaced_count", 0)},
)
else:
event_data = serialize_lc_object(data) if not isinstance(data, dict) else data
self._put(
event_type=name,
category="trace",
metadata=event_data if isinstance(event_data, dict) else {"data": event_data},
)
# -- Internal methods -- # -- Internal methods --
@@ -431,8 +307,9 @@ class RunJournal(BaseCallbackHandler):
if exc: if exc:
logger.warning("Journal flush task failed: %s", exc) logger.warning("Journal flush task failed: %s", exc)
def _identify_caller(self, kwargs: dict) -> str: def _identify_caller(self, tags: list[str] | None, **kwargs) -> str:
for tag in kwargs.get("tags") or []: _tags = tags or kwargs.get("tags", [])
for tag in _tags:
if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"): if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
return tag return tag
# Default to lead_agent: the main agent graph does not inject # Default to lead_agent: the main agent graph does not inject
@@ -54,7 +54,7 @@ class RunManager:
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
self._store = store self._store = store
async def _persist_to_store(self, record: RunRecord, *, follow_up_to_run_id: str | None = None) -> None: async def _persist_to_store(self, record: RunRecord) -> None:
"""Best-effort persist run record to backing store.""" """Best-effort persist run record to backing store."""
if self._store is None: if self._store is None:
return return
@@ -68,7 +68,6 @@ class RunManager:
metadata=record.metadata or {}, metadata=record.metadata or {},
kwargs=record.kwargs or {}, kwargs=record.kwargs or {},
created_at=record.created_at, created_at=record.created_at,
follow_up_to_run_id=follow_up_to_run_id,
) )
except Exception: except Exception:
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True) logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
@@ -90,7 +89,6 @@ class RunManager:
metadata: dict | None = None, metadata: dict | None = None,
kwargs: dict | None = None, kwargs: dict | None = None,
multitask_strategy: str = "reject", multitask_strategy: str = "reject",
follow_up_to_run_id: str | None = None,
) -> RunRecord: ) -> RunRecord:
"""Create a new pending run and register it.""" """Create a new pending run and register it."""
run_id = str(uuid.uuid4()) run_id = str(uuid.uuid4())
@@ -109,7 +107,7 @@ class RunManager:
) )
async with self._lock: async with self._lock:
self._runs[run_id] = record self._runs[run_id] = record
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id) await self._persist_to_store(record)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record return record
@@ -122,7 +120,7 @@ class RunManager:
async with self._lock: async with self._lock:
# Dict insertion order matches creation order, so reversing it gives # Dict insertion order matches creation order, so reversing it gives
# us deterministic newest-first results even when timestamps tie. # us deterministic newest-first results even when timestamps tie.
return [r for r in reversed(self._runs.values()) if r.thread_id == thread_id] return [r for r in self._runs.values() if r.thread_id == thread_id]
async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None: async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
"""Transition a run to a new status.""" """Transition a run to a new status."""
@@ -176,7 +174,6 @@ class RunManager:
metadata: dict | None = None, metadata: dict | None = None,
kwargs: dict | None = None, kwargs: dict | None = None,
multitask_strategy: str = "reject", multitask_strategy: str = "reject",
follow_up_to_run_id: str | None = None,
) -> RunRecord: ) -> RunRecord:
"""Atomically check for inflight runs and create a new one. """Atomically check for inflight runs and create a new one.
@@ -230,7 +227,7 @@ class RunManager:
) )
self._runs[run_id] = record self._runs[run_id] = record
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id) await self._persist_to_store(record)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record return record
@@ -29,7 +29,6 @@ class RunStore(abc.ABC):
kwargs: dict[str, Any] | None = None, kwargs: dict[str, Any] | None = None,
error: str | None = None, error: str | None = None,
created_at: str | None = None, created_at: str | None = None,
follow_up_to_run_id: str | None = None,
) -> None: ) -> None:
pass pass
@@ -28,7 +28,6 @@ class MemoryRunStore(RunStore):
kwargs=None, kwargs=None,
error=None, error=None,
created_at=None, created_at=None,
follow_up_to_run_id=None,
): ):
now = datetime.now(UTC).isoformat() now = datetime.now(UTC).isoformat()
self._runs[run_id] = { self._runs[run_id] = {
@@ -41,7 +40,6 @@ class MemoryRunStore(RunStore):
"metadata": metadata or {}, "metadata": metadata or {},
"kwargs": kwargs or {}, "kwargs": kwargs or {},
"error": error, "error": error,
"follow_up_to_run_id": follow_up_to_run_id,
"created_at": created_at or now, "created_at": created_at or now,
"updated_at": now, "updated_at": now,
} }
@@ -51,7 +51,6 @@ class RunContext:
event_store: Any | None = field(default=None) event_store: Any | None = field(default=None)
run_events_config: Any | None = field(default=None) run_events_config: Any | None = field(default=None)
thread_store: Any | None = field(default=None) thread_store: Any | None = field(default=None)
follow_up_to_run_id: str | None = field(default=None)
async def run_agent( async def run_agent(
@@ -76,7 +75,6 @@ async def run_agent(
event_store = ctx.event_store event_store = ctx.event_store
run_events_config = ctx.run_events_config run_events_config = ctx.run_events_config
thread_store = ctx.thread_store thread_store = ctx.thread_store
follow_up_to_run_id = ctx.follow_up_to_run_id
run_id = record.run_id run_id = record.run_id
thread_id = record.thread_id thread_id = record.thread_id
@@ -113,22 +111,6 @@ async def run_agent(
track_token_usage=getattr(run_events_config, "track_token_usage", True), track_token_usage=getattr(run_events_config, "track_token_usage", True),
) )
human_msg = _extract_human_message(graph_input)
if human_msg is not None:
msg_metadata = {}
if follow_up_to_run_id:
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
await event_store.put(
thread_id=thread_id,
run_id=run_id,
event_type="human_message",
category="message",
content=human_msg.model_dump(),
metadata=msg_metadata or None,
)
content = human_msg.content
journal.set_first_human_message(content if isinstance(content, str) else str(content))
# 1. Mark running # 1. Mark running
await run_manager.set_status(run_id, RunStatus.running) await run_manager.set_status(run_id, RunStatus.running)
@@ -166,12 +148,13 @@ async def run_agent(
# Inject runtime context so middlewares can access thread_id # Inject runtime context so middlewares can access thread_id
# (langgraph-cli does this automatically; we must do it manually) # (langgraph-cli does this automatically; we must do it manually)
runtime = Runtime(context={"thread_id": thread_id}, store=store) runtime = Runtime(context={"thread_id": thread_id, "run_id": run_id}, store=store)
# If the caller already set a ``context`` key (LangGraph >= 0.6.0 # If the caller already set a ``context`` key (LangGraph >= 0.6.0
# prefers it over ``configurable`` for thread-level data), make # prefers it over ``configurable`` for thread-level data), make
# sure ``thread_id`` is available there too. # sure ``thread_id`` is available there too.
if "context" in config and isinstance(config["context"], dict): if "context" in config and isinstance(config["context"], dict):
config["context"].setdefault("thread_id", thread_id) config["context"].setdefault("thread_id", thread_id)
config["context"].setdefault("run_id", run_id)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime config.setdefault("configurable", {})["__pregel_runtime"] = runtime
# Inject RunJournal as a LangChain callback handler. # Inject RunJournal as a LangChain callback handler.
File diff suppressed because it is too large Load Diff
+5 -5
View File
@@ -75,27 +75,27 @@ async def test_cancel_not_inflight(manager: RunManager):
@pytest.mark.anyio @pytest.mark.anyio
async def test_list_by_thread(manager: RunManager): async def test_list_by_thread(manager: RunManager):
"""Same thread should return multiple runs, newest first.""" """Same thread should return multiple runs."""
r1 = await manager.create("thread-1") r1 = await manager.create("thread-1")
r2 = await manager.create("thread-1") r2 = await manager.create("thread-1")
await manager.create("thread-2") await manager.create("thread-2")
runs = await manager.list_by_thread("thread-1") runs = await manager.list_by_thread("thread-1")
assert len(runs) == 2 assert len(runs) == 2
assert runs[0].run_id == r2.run_id assert runs[0].run_id == r1.run_id
assert runs[1].run_id == r1.run_id assert runs[1].run_id == r2.run_id
@pytest.mark.anyio @pytest.mark.anyio
async def test_list_by_thread_is_stable_when_timestamps_tie(manager: RunManager, monkeypatch: pytest.MonkeyPatch): async def test_list_by_thread_is_stable_when_timestamps_tie(manager: RunManager, monkeypatch: pytest.MonkeyPatch):
"""Newest-first ordering should not depend on timestamp precision.""" """Ordering should be stable (insertion order) even when timestamps tie."""
monkeypatch.setattr("deerflow.runtime.runs.manager._now_iso", lambda: "2026-01-01T00:00:00+00:00") monkeypatch.setattr("deerflow.runtime.runs.manager._now_iso", lambda: "2026-01-01T00:00:00+00:00")
r1 = await manager.create("thread-1") r1 = await manager.create("thread-1")
r2 = await manager.create("thread-1") r2 = await manager.create("thread-1")
runs = await manager.list_by_thread("thread-1") runs = await manager.list_by_thread("thread-1")
assert [run.run_id for run in runs] == [r2.run_id, r1.run_id] assert [run.run_id for run in runs] == [r1.run_id, r2.run_id]
@pytest.mark.anyio @pytest.mark.anyio
@@ -46,7 +46,13 @@ export default function AgentChatPage() {
const [settings, setSettings] = useThreadSettings(threadId); const [settings, setSettings] = useThreadSettings(threadId);
const { showNotification } = useNotification(); const { showNotification } = useNotification();
const [thread, sendMessage] = useThreadStream({ const {
thread,
sendMessage,
isHistoryLoading,
hasMoreHistory,
loadMoreHistory,
} = useThreadStream({
threadId: isNewThread ? undefined : threadId, threadId: isNewThread ? undefined : threadId,
context: { ...settings.context, agent_name: agent_name }, context: { ...settings.context, agent_name: agent_name },
onStart: (createdThreadId) => { onStart: (createdThreadId) => {
@@ -141,6 +147,9 @@ export default function AgentChatPage() {
threadId={threadId} threadId={threadId}
thread={thread} thread={thread}
paddingBottom={messageListPaddingBottom} paddingBottom={messageListPaddingBottom}
hasMoreHistory={hasMoreHistory}
loadMoreHistory={loadMoreHistory}
isHistoryLoading={isHistoryLoading}
/> />
</div> </div>
@@ -1,6 +1,6 @@
"use client"; "use client";
import { useCallback, useEffect, useState } from "react"; import { useCallback, useEffect, useRef, useState } from "react";
import { type PromptInputMessage } from "@/components/ai-elements/prompt-input"; import { type PromptInputMessage } from "@/components/ai-elements/prompt-input";
import { ArtifactTrigger } from "@/components/workspace/artifacts"; import { ArtifactTrigger } from "@/components/workspace/artifacts";
@@ -35,19 +35,30 @@ export default function ChatPage() {
const { threadId, setThreadId, isNewThread, setIsNewThread, isMock } = const { threadId, setThreadId, isNewThread, setIsNewThread, isMock } =
useThreadChat(); useThreadChat();
const [settings, setSettings] = useThreadSettings(threadId); const [settings, setSettings] = useThreadSettings(threadId);
const [mounted, setMounted] = useState(false); const mountedRef = useRef(false);
useSpecificChatMode(); useSpecificChatMode();
useEffect(() => { useEffect(() => {
setMounted(true); mountedRef.current = true;
}, []); }, []);
const { showNotification } = useNotification(); const { showNotification } = useNotification();
const [thread, sendMessage, isUploading] = useThreadStream({ const {
thread,
sendMessage,
isUploading,
isHistoryLoading,
hasMoreHistory,
loadMoreHistory,
} = useThreadStream({
threadId: isNewThread ? undefined : threadId, threadId: isNewThread ? undefined : threadId,
context: settings.context, context: settings.context,
isMock, isMock,
onSend: (_threadId) => {
setThreadId(_threadId);
setIsNewThread(false);
},
onStart: (createdThreadId) => { onStart: (createdThreadId) => {
setThreadId(createdThreadId); setThreadId(createdThreadId);
setIsNewThread(false); setIsNewThread(false);
@@ -115,6 +126,9 @@ export default function ChatPage() {
threadId={threadId} threadId={threadId}
thread={thread} thread={thread}
paddingBottom={messageListPaddingBottom} paddingBottom={messageListPaddingBottom}
hasMoreHistory={hasMoreHistory}
loadMoreHistory={loadMoreHistory}
isHistoryLoading={isHistoryLoading}
/> />
</div> </div>
<div className="absolute right-0 bottom-0 left-0 z-30 flex justify-center px-4"> <div className="absolute right-0 bottom-0 left-0 z-30 flex justify-center px-4">
@@ -138,7 +152,7 @@ export default function ChatPage() {
/> />
</div> </div>
</div> </div>
{mounted ? ( {mountedRef.current ? (
<InputBox <InputBox
className={cn("bg-background/5 w-full -translate-y-4")} className={cn("bg-background/5 w-full -translate-y-4")}
isNewThread={isNewThread} isNewThread={isNewThread}
@@ -170,7 +184,7 @@ export default function ChatPage() {
<div <div
aria-hidden="true" aria-hidden="true"
className={cn( className={cn(
"bg-background/5 h-32 w-full -translate-y-4 rounded-2xl border", "bg-background/5 h-32 w-full -translate-y-4 rounded-2xl",
)} )}
/> />
)} )}
+21 -17
View File
@@ -55,7 +55,7 @@ import {
DropdownMenuLabel, DropdownMenuLabel,
DropdownMenuSeparator, DropdownMenuSeparator,
} from "@/components/ui/dropdown-menu"; } from "@/components/ui/dropdown-menu";
import { fetchWithAuth } from "@/core/api/fetcher"; import { fetch } from "@/core/api/fetcher";
import { getBackendBaseURL } from "@/core/config"; import { getBackendBaseURL } from "@/core/config";
import { useI18n } from "@/core/i18n/hooks"; import { useI18n } from "@/core/i18n/hooks";
import { useModels } from "@/core/models/hooks"; import { useModels } from "@/core/models/hooks";
@@ -155,6 +155,7 @@ export function InputBox({
const [followupsLoading, setFollowupsLoading] = useState(false); const [followupsLoading, setFollowupsLoading] = useState(false);
const lastGeneratedForAiIdRef = useRef<string | null>(null); const lastGeneratedForAiIdRef = useRef<string | null>(null);
const wasStreamingRef = useRef(false); const wasStreamingRef = useRef(false);
const messagesRef = useRef(thread.messages);
const [confirmOpen, setConfirmOpen] = useState(false); const [confirmOpen, setConfirmOpen] = useState(false);
const [pendingSuggestion, setPendingSuggestion] = useState<string | null>( const [pendingSuggestion, setPendingSuggestion] = useState<string | null>(
@@ -354,6 +355,10 @@ export function InputBox({
followupsVisibilityChangeRef.current?.(showFollowups); followupsVisibilityChangeRef.current?.(showFollowups);
}, [showFollowups]); }, [showFollowups]);
useEffect(() => {
messagesRef.current = thread.messages;
}, [thread.messages]);
useEffect(() => { useEffect(() => {
return () => followupsVisibilityChangeRef.current?.(false); return () => followupsVisibilityChangeRef.current?.(false);
}, []); }, []);
@@ -370,14 +375,16 @@ export function InputBox({
return; return;
} }
const lastAi = [...thread.messages].reverse().find((m) => m.type === "ai"); const lastAi = [...messagesRef.current]
.reverse()
.find((m) => m.type === "ai");
const lastAiId = lastAi?.id ?? null; const lastAiId = lastAi?.id ?? null;
if (!lastAiId || lastAiId === lastGeneratedForAiIdRef.current) { if (!lastAiId || lastAiId === lastGeneratedForAiIdRef.current) {
return; return;
} }
lastGeneratedForAiIdRef.current = lastAiId; lastGeneratedForAiIdRef.current = lastAiId;
const recent = thread.messages const recent = messagesRef.current
.filter((m) => m.type === "human" || m.type === "ai") .filter((m) => m.type === "human" || m.type === "ai")
.map((m) => { .map((m) => {
const role = m.type === "human" ? "user" : "assistant"; const role = m.type === "human" ? "user" : "assistant";
@@ -396,19 +403,16 @@ export function InputBox({
setFollowupsLoading(true); setFollowupsLoading(true);
setFollowups([]); setFollowups([]);
fetchWithAuth( fetch(`${getBackendBaseURL()}/api/threads/${threadId}/suggestions`, {
`${getBackendBaseURL()}/api/threads/${threadId}/suggestions`, method: "POST",
{ headers: { "Content-Type": "application/json" },
method: "POST", body: JSON.stringify({
headers: { "Content-Type": "application/json" }, messages: recent,
body: JSON.stringify({ n: 3,
messages: recent, model_name: context.model_name ?? undefined,
n: 3, }),
model_name: context.model_name ?? undefined, signal: controller.signal,
}), })
signal: controller.signal,
},
)
.then(async (res) => { .then(async (res) => {
if (!res.ok) { if (!res.ok) {
return { suggestions: [] as string[] }; return { suggestions: [] as string[] };
@@ -430,7 +434,7 @@ export function InputBox({
}); });
return () => controller.abort(); return () => controller.abort();
}, [context.model_name, disabled, isMock, status, thread.messages, threadId]); }, [context.model_name, disabled, isMock, status, threadId]);
return ( return (
<div ref={promptRootRef} className="relative flex flex-col gap-4"> <div ref={promptRootRef} className="relative flex flex-col gap-4">
@@ -1,9 +1,12 @@
import type { BaseStream } from "@langchain/langgraph-sdk/react"; import type { BaseStream } from "@langchain/langgraph-sdk/react";
import { ChevronUpIcon, Loader2Icon } from "lucide-react";
import { useCallback, useEffect, useRef } from "react";
import { import {
Conversation, Conversation,
ConversationContent, ConversationContent,
} from "@/components/ai-elements/conversation"; } from "@/components/ai-elements/conversation";
import { Button } from "@/components/ui/button";
import { useI18n } from "@/core/i18n/hooks"; import { useI18n } from "@/core/i18n/hooks";
import { import {
extractContentFromMessage, extractContentFromMessage,
@@ -18,7 +21,6 @@ import { useRehypeSplitWordsIntoSpans } from "@/core/rehype";
import type { Subtask } from "@/core/tasks"; import type { Subtask } from "@/core/tasks";
import { useUpdateSubtask } from "@/core/tasks/context"; import { useUpdateSubtask } from "@/core/tasks/context";
import type { AgentThreadState } from "@/core/threads"; import type { AgentThreadState } from "@/core/threads";
import { useThreadMessageEnrichment } from "@/core/threads/hooks";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
import { ArtifactFileList } from "../artifacts/artifact-file-list"; import { ArtifactFileList } from "../artifacts/artifact-file-list";
@@ -33,22 +35,134 @@ import { SubtaskCard } from "./subtask-card";
export const MESSAGE_LIST_DEFAULT_PADDING_BOTTOM = 160; export const MESSAGE_LIST_DEFAULT_PADDING_BOTTOM = 160;
export const MESSAGE_LIST_FOLLOWUPS_EXTRA_PADDING_BOTTOM = 80; export const MESSAGE_LIST_FOLLOWUPS_EXTRA_PADDING_BOTTOM = 80;
const LOAD_MORE_HISTORY_THROTTLE_MS = 1200;
function LoadMoreHistoryIndicator({
isLoading,
hasMore,
loadMore,
}: {
isLoading?: boolean;
hasMore?: boolean;
loadMore?: () => void;
}) {
const { t } = useI18n();
const sentinelRef = useRef<HTMLDivElement | null>(null);
const timeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);
const lastLoadRef = useRef(0);
const throttledLoadMore = useCallback(() => {
if (!hasMore || isLoading) {
return;
}
const now = Date.now();
const remaining =
LOAD_MORE_HISTORY_THROTTLE_MS - (now - lastLoadRef.current);
if (remaining <= 0) {
lastLoadRef.current = now;
loadMore?.();
return;
}
if (timeoutRef.current) {
return;
}
timeoutRef.current = setTimeout(() => {
timeoutRef.current = null;
if (!hasMore || isLoading) {
return;
}
lastLoadRef.current = Date.now();
loadMore?.();
}, remaining);
}, [hasMore, isLoading, loadMore]);
useEffect(() => {
const element = sentinelRef.current;
if (!element || !hasMore) {
return;
}
const observer = new IntersectionObserver(
([entry]) => {
if (entry?.isIntersecting) {
throttledLoadMore();
}
},
{
rootMargin: "120px 0px 0px 0px",
},
);
observer.observe(element);
return () => {
observer.disconnect();
};
}, [hasMore, throttledLoadMore]);
useEffect(() => {
return () => {
if (timeoutRef.current) {
clearTimeout(timeoutRef.current);
}
};
}, []);
if (!hasMore && !isLoading) {
return null;
}
return (
<div ref={sentinelRef} className="flex w-full justify-center">
<Button
type="button"
variant="ghost"
size="sm"
className="text-muted-foreground hover:text-foreground rounded-full px-3"
disabled={(isLoading ?? false) || !hasMore}
onClick={throttledLoadMore}
>
{isLoading ? (
<>
<Loader2Icon className="mr-2 size-4 animate-spin" />
{t.common.loading}
</>
) : (
<>
<ChevronUpIcon className="mr-2 size-4" />
{t.common.loadMore}
</>
)}
</Button>
</div>
);
}
export function MessageList({ export function MessageList({
className, className,
threadId, threadId,
thread, thread,
paddingBottom = MESSAGE_LIST_DEFAULT_PADDING_BOTTOM, paddingBottom = MESSAGE_LIST_DEFAULT_PADDING_BOTTOM,
hasMoreHistory,
loadMoreHistory,
isHistoryLoading,
}: { }: {
className?: string; className?: string;
threadId: string; threadId: string;
thread: BaseStream<AgentThreadState>; thread: BaseStream<AgentThreadState>;
paddingBottom?: number; paddingBottom?: number;
hasMoreHistory?: boolean;
loadMoreHistory?: () => void;
isHistoryLoading?: boolean;
}) { }) {
const { t } = useI18n(); const { t } = useI18n();
const rehypePlugins = useRehypeSplitWordsIntoSpans(thread.isLoading); const rehypePlugins = useRehypeSplitWordsIntoSpans(thread.isLoading);
const updateSubtask = useUpdateSubtask(); const updateSubtask = useUpdateSubtask();
const messages = thread.messages; const messages = thread.messages;
const { data: enrichment } = useThreadMessageEnrichment(threadId);
if (thread.isThreadLoading && messages.length === 0) { if (thread.isThreadLoading && messages.length === 0) {
return <MessageListSkeleton />; return <MessageListSkeleton />;
@@ -57,19 +171,21 @@ export function MessageList({
<Conversation <Conversation
className={cn("flex size-full flex-col justify-center", className)} className={cn("flex size-full flex-col justify-center", className)}
> >
<ConversationContent className="mx-auto w-full max-w-(--container-width-md) gap-8 pt-12"> <ConversationContent className="mx-auto w-full max-w-(--container-width-md) gap-8 pt-8">
<LoadMoreHistoryIndicator
isLoading={isHistoryLoading}
hasMore={hasMoreHistory}
loadMore={loadMoreHistory}
/>
{groupMessages(messages, (group) => { {groupMessages(messages, (group) => {
if (group.type === "human" || group.type === "assistant") { if (group.type === "human" || group.type === "assistant") {
return group.messages.map((msg) => { return group.messages.map((msg) => {
const entry = msg.id ? enrichment?.get(msg.id) : undefined;
return ( return (
<MessageListItem <MessageListItem
key={`${group.id}/${msg.id}`} key={`${group.id}/${msg.id}`}
threadId={threadId} threadId={threadId}
message={msg} message={msg}
isLoading={thread.isLoading} isLoading={thread.isLoading}
runId={entry?.run_id}
feedback={entry?.feedback}
/> />
); );
}); });
@@ -5,7 +5,7 @@ import { useState } from "react";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input"; import { Input } from "@/components/ui/input";
import { fetchWithAuth, getCsrfHeaders } from "@/core/api/fetcher"; import { fetch, getCsrfHeaders } from "@/core/api/fetcher";
import { useAuth } from "@/core/auth/AuthProvider"; import { useAuth } from "@/core/auth/AuthProvider";
import { parseAuthError } from "@/core/auth/types"; import { parseAuthError } from "@/core/auth/types";
@@ -36,7 +36,7 @@ export function AccountSettingsPage() {
setLoading(true); setLoading(true);
try { try {
const res = await fetchWithAuth("/api/v1/auth/change-password", { const res = await fetch("/api/v1/auth/change-password", {
method: "POST", method: "POST",
headers: { headers: {
"Content-Type": "application/json", "Content-Type": "application/json",
+4 -4
View File
@@ -1,4 +1,4 @@
import { fetchWithAuth } from "@/core/api/fetcher"; import { fetch } from "@/core/api/fetcher";
import { getBackendBaseURL } from "@/core/config"; import { getBackendBaseURL } from "@/core/config";
import type { Agent, CreateAgentRequest, UpdateAgentRequest } from "./types"; import type { Agent, CreateAgentRequest, UpdateAgentRequest } from "./types";
@@ -29,7 +29,7 @@ export async function getAgent(name: string): Promise<Agent> {
} }
export async function createAgent(request: CreateAgentRequest): Promise<Agent> { export async function createAgent(request: CreateAgentRequest): Promise<Agent> {
const res = await fetchWithAuth(`${getBackendBaseURL()}/api/agents`, { const res = await fetch(`${getBackendBaseURL()}/api/agents`, {
method: "POST", method: "POST",
headers: { "Content-Type": "application/json" }, headers: { "Content-Type": "application/json" },
body: JSON.stringify(request), body: JSON.stringify(request),
@@ -45,7 +45,7 @@ export async function updateAgent(
name: string, name: string,
request: UpdateAgentRequest, request: UpdateAgentRequest,
): Promise<Agent> { ): Promise<Agent> {
const res = await fetchWithAuth(`${getBackendBaseURL()}/api/agents/${name}`, { const res = await fetch(`${getBackendBaseURL()}/api/agents/${name}`, {
method: "PUT", method: "PUT",
headers: { "Content-Type": "application/json" }, headers: { "Content-Type": "application/json" },
body: JSON.stringify(request), body: JSON.stringify(request),
@@ -58,7 +58,7 @@ export async function updateAgent(
} }
export async function deleteAgent(name: string): Promise<void> { export async function deleteAgent(name: string): Promise<void> {
const res = await fetchWithAuth(`${getBackendBaseURL()}/api/agents/${name}`, { const res = await fetch(`${getBackendBaseURL()}/api/agents/${name}`, {
method: "DELETE", method: "DELETE",
}); });
if (!res.ok) throw new Error(`Failed to delete agent: ${res.statusText}`); if (!res.ok) throw new Error(`Failed to delete agent: ${res.statusText}`);
+3 -3
View File
@@ -1,6 +1,6 @@
import { getBackendBaseURL } from "../config"; import { getBackendBaseURL } from "../config";
import { fetchWithAuth } from "./fetcher"; import { fetch } from "./fetcher";
export interface FeedbackData { export interface FeedbackData {
feedback_id: string; feedback_id: string;
@@ -14,7 +14,7 @@ export async function upsertFeedback(
rating: number, rating: number,
comment?: string, comment?: string,
): Promise<FeedbackData> { ): Promise<FeedbackData> {
const res = await fetchWithAuth( const res = await fetch(
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/runs/${encodeURIComponent(runId)}/feedback`, `${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/runs/${encodeURIComponent(runId)}/feedback`,
{ {
method: "PUT", method: "PUT",
@@ -32,7 +32,7 @@ export async function deleteFeedback(
threadId: string, threadId: string,
runId: string, runId: string,
): Promise<void> { ): Promise<void> {
const res = await fetchWithAuth( const res = await fetch(
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/runs/${encodeURIComponent(runId)}/feedback`, `${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/runs/${encodeURIComponent(runId)}/feedback`,
{ method: "DELETE" }, { method: "DELETE" },
); );
+2 -2
View File
@@ -53,7 +53,7 @@ export function readCsrfCookie(): string | null {
* preserved; the helper only ADDS the CSRF header when it isn't already * preserved; the helper only ADDS the CSRF header when it isn't already
* present, so explicit overrides win. * present, so explicit overrides win.
*/ */
export async function fetchWithAuth( export async function fetch(
input: RequestInfo | string, input: RequestInfo | string,
init?: RequestInit, init?: RequestInit,
): Promise<Response> { ): Promise<Response> {
@@ -74,7 +74,7 @@ export async function fetchWithAuth(
} }
} }
const res = await fetch(url, { const res = await globalThis.fetch(url, {
...init, ...init,
headers, headers,
credentials: "include", credentials: "include",
+1
View File
@@ -29,6 +29,7 @@ export const enUS: Translations = {
close: "Close", close: "Close",
more: "More", more: "More",
search: "Search", search: "Search",
loadMore: "Load more",
download: "Download", download: "Download",
thinking: "Thinking", thinking: "Thinking",
artifacts: "Artifacts", artifacts: "Artifacts",
+1
View File
@@ -18,6 +18,7 @@ export interface Translations {
close: string; close: string;
more: string; more: string;
search: string; search: string;
loadMore: string;
download: string; download: string;
thinking: string; thinking: string;
artifacts: string; artifacts: string;
+1
View File
@@ -29,6 +29,7 @@ export const zhCN: Translations = {
close: "关闭", close: "关闭",
more: "更多", more: "更多",
search: "搜索", search: "搜索",
loadMore: "加载更多",
download: "下载", download: "下载",
thinking: "思考", thinking: "思考",
artifacts: "文件", artifacts: "文件",
+7 -10
View File
@@ -1,4 +1,4 @@
import { fetchWithAuth } from "@/core/api/fetcher"; import { fetch } from "@/core/api/fetcher";
import { getBackendBaseURL } from "@/core/config"; import { getBackendBaseURL } from "@/core/config";
import type { MCPConfig } from "./types"; import type { MCPConfig } from "./types";
@@ -9,15 +9,12 @@ export async function loadMCPConfig() {
} }
export async function updateMCPConfig(config: MCPConfig) { export async function updateMCPConfig(config: MCPConfig) {
const response = await fetchWithAuth( const response = await fetch(`${getBackendBaseURL()}/api/mcp/config`, {
`${getBackendBaseURL()}/api/mcp/config`, method: "PUT",
{ headers: {
method: "PUT", "Content-Type": "application/json",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(config),
}, },
); body: JSON.stringify(config),
});
return response.json(); return response.json();
} }
+16 -22
View File
@@ -1,4 +1,4 @@
import { fetchWithAuth } from "../api/fetcher"; import { fetch } from "../api/fetcher";
import { getBackendBaseURL } from "../config"; import { getBackendBaseURL } from "../config";
import type { import type {
@@ -86,14 +86,14 @@ export async function loadMemory(): Promise<UserMemory> {
} }
export async function clearMemory(): Promise<UserMemory> { export async function clearMemory(): Promise<UserMemory> {
const response = await fetchWithAuth(`${getBackendBaseURL()}/api/memory`, { const response = await fetch(`${getBackendBaseURL()}/api/memory`, {
method: "DELETE", method: "DELETE",
}); });
return readMemoryResponse(response, "Failed to clear memory"); return readMemoryResponse(response, "Failed to clear memory");
} }
export async function deleteMemoryFact(factId: string): Promise<UserMemory> { export async function deleteMemoryFact(factId: string): Promise<UserMemory> {
const response = await fetchWithAuth( const response = await fetch(
`${getBackendBaseURL()}/api/memory/facts/${encodeURIComponent(factId)}`, `${getBackendBaseURL()}/api/memory/facts/${encodeURIComponent(factId)}`,
{ {
method: "DELETE", method: "DELETE",
@@ -108,32 +108,26 @@ export async function exportMemory(): Promise<UserMemory> {
} }
export async function importMemory(memory: UserMemory): Promise<UserMemory> { export async function importMemory(memory: UserMemory): Promise<UserMemory> {
const response = await fetchWithAuth( const response = await fetch(`${getBackendBaseURL()}/api/memory/import`, {
`${getBackendBaseURL()}/api/memory/import`, method: "POST",
{ headers: {
method: "POST", "Content-Type": "application/json",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(memory),
}, },
); body: JSON.stringify(memory),
});
return readMemoryResponse(response, "Failed to import memory"); return readMemoryResponse(response, "Failed to import memory");
} }
export async function createMemoryFact( export async function createMemoryFact(
input: MemoryFactInput, input: MemoryFactInput,
): Promise<UserMemory> { ): Promise<UserMemory> {
const response = await fetchWithAuth( const response = await fetch(`${getBackendBaseURL()}/api/memory/facts`, {
`${getBackendBaseURL()}/api/memory/facts`, method: "POST",
{ headers: {
method: "POST", "Content-Type": "application/json",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(input),
}, },
); body: JSON.stringify(input),
});
return readMemoryResponse(response, "Failed to create memory fact"); return readMemoryResponse(response, "Failed to create memory fact");
} }
@@ -141,7 +135,7 @@ export async function updateMemoryFact(
factId: string, factId: string,
input: MemoryFactPatchInput, input: MemoryFactPatchInput,
): Promise<UserMemory> { ): Promise<UserMemory> {
const response = await fetchWithAuth( const response = await fetch(
`${getBackendBaseURL()}/api/memory/facts/${encodeURIComponent(factId)}`, `${getBackendBaseURL()}/api/memory/facts/${encodeURIComponent(factId)}`,
{ {
method: "PATCH", method: "PATCH",
+5 -1
View File
@@ -328,7 +328,11 @@ export function findToolCallResult(toolCallId: string, messages: Message[]) {
} }
export function isHiddenFromUIMessage(message: Message) { export function isHiddenFromUIMessage(message: Message) {
return message.additional_kwargs?.hide_from_ui === true; return (
message.additional_kwargs?.hide_from_ui === true ||
message.name === "summary" ||
message.name === "loop_warning"
);
} }
/** /**
+8 -11
View File
@@ -1,4 +1,4 @@
import { fetchWithAuth } from "@/core/api/fetcher"; import { fetch } from "@/core/api/fetcher";
import { getBackendBaseURL } from "@/core/config"; import { getBackendBaseURL } from "@/core/config";
import type { Skill } from "./type"; import type { Skill } from "./type";
@@ -10,7 +10,7 @@ export async function loadSkills() {
} }
export async function enableSkill(skillName: string, enabled: boolean) { export async function enableSkill(skillName: string, enabled: boolean) {
const response = await fetchWithAuth( const response = await fetch(
`${getBackendBaseURL()}/api/skills/${skillName}`, `${getBackendBaseURL()}/api/skills/${skillName}`,
{ {
method: "PUT", method: "PUT",
@@ -39,16 +39,13 @@ export interface InstallSkillResponse {
export async function installSkill( export async function installSkill(
request: InstallSkillRequest, request: InstallSkillRequest,
): Promise<InstallSkillResponse> { ): Promise<InstallSkillResponse> {
const response = await fetchWithAuth( const response = await fetch(`${getBackendBaseURL()}/api/skills/install`, {
`${getBackendBaseURL()}/api/skills/install`, method: "POST",
{ headers: {
method: "POST", "Content-Type": "application/json",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(request),
}, },
); body: JSON.stringify(request),
});
if (!response.ok) { if (!response.ok) {
// Handle HTTP error responses (4xx, 5xx) // Handle HTTP error responses (4xx, 5xx)
+214 -189
View File
@@ -1,15 +1,14 @@
import type { AIMessage, Message } from "@langchain/langgraph-sdk"; import type { AIMessage, Message, Run } from "@langchain/langgraph-sdk";
import type { ThreadsClient } from "@langchain/langgraph-sdk/client"; import type { ThreadsClient } from "@langchain/langgraph-sdk/client";
import { useStream } from "@langchain/langgraph-sdk/react"; import { useStream } from "@langchain/langgraph-sdk/react";
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import { useCallback, useEffect, useRef, useState } from "react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "sonner"; import { toast } from "sonner";
import type { PromptInputMessage } from "@/components/ai-elements/prompt-input"; import type { PromptInputMessage } from "@/components/ai-elements/prompt-input";
import { getAPIClient } from "../api"; import { getAPIClient } from "../api";
import type { FeedbackData } from "../api/feedback"; import { fetch } from "../api/fetcher";
import { fetchWithAuth } from "../api/fetcher";
import { getBackendBaseURL } from "../config"; import { getBackendBaseURL } from "../config";
import { useI18n } from "../i18n/hooks"; import { useI18n } from "../i18n/hooks";
import type { FileInMessage } from "../messages/utils"; import type { FileInMessage } from "../messages/utils";
@@ -18,7 +17,7 @@ import { useUpdateSubtask } from "../tasks/context";
import type { UploadedFileInfo } from "../uploads"; import type { UploadedFileInfo } from "../uploads";
import { promptInputFilePartToFile, uploadFiles } from "../uploads"; import { promptInputFilePartToFile, uploadFiles } from "../uploads";
import type { AgentThread, AgentThreadState } from "./types"; import type { AgentThread, AgentThreadState, RunMessage } from "./types";
export type ToolEndEvent = { export type ToolEndEvent = {
name: string; name: string;
@@ -29,7 +28,8 @@ export type ThreadStreamOptions = {
threadId?: string | null | undefined; threadId?: string | null | undefined;
context: LocalSettings["context"]; context: LocalSettings["context"];
isMock?: boolean; isMock?: boolean;
onStart?: (threadId: string) => void; onSend?: (threadId: string) => void;
onStart?: (threadId: string, runId: string) => void;
onFinish?: (state: AgentThreadState) => void; onFinish?: (state: AgentThreadState) => void;
onToolEnd?: (event: ToolEndEvent) => void; onToolEnd?: (event: ToolEndEvent) => void;
}; };
@@ -38,79 +38,41 @@ type SendMessageOptions = {
additionalKwargs?: Record<string, unknown>; additionalKwargs?: Record<string, unknown>;
}; };
function normalizeStoredRunId(runId: string | null): string | null { function mergeMessages(
if (!runId) { historyMessages: Message[],
return null; threadMessages: Message[],
} optimisticMessages: Message[],
): Message[] {
const threadMessageIds = new Set(
threadMessages
.map((m) => ("tool_call_id" in m ? m.tool_call_id : m.id))
.filter(Boolean),
);
const trimmed = runId.trim(); // The overlap is a contiguous suffix of historyMessages (newest history == oldest thread).
if (!trimmed) { // Scan from the end: shrink cutoff while messages are already in thread, stop as soon as
return null; // we hit one that isn't — everything before that point is non-overlapping.
} let cutoff = historyMessages.length;
for (let i = historyMessages.length - 1; i >= 0; i--) {
const queryIndex = trimmed.indexOf("?"); const msg = historyMessages[i];
if (queryIndex >= 0) { if (!msg) {
const params = new URLSearchParams(trimmed.slice(queryIndex + 1)); continue;
const queryRunId = params.get("run_id")?.trim(); }
if (queryRunId) { if (
return queryRunId; (msg?.id && threadMessageIds.has(msg.id)) ||
("tool_call_id" in msg && threadMessageIds.has(msg.tool_call_id))
) {
cutoff = i;
} else {
break;
} }
} }
const pathWithoutQueryOrHash = trimmed.split(/[?#]/, 1)[0]?.trim() ?? ""; return [
if (!pathWithoutQueryOrHash) { ...historyMessages.slice(0, cutoff),
return null; ...threadMessages,
} ...optimisticMessages,
];
const runsMarker = "/runs/";
const runsIndex = pathWithoutQueryOrHash.lastIndexOf(runsMarker);
if (runsIndex >= 0) {
const runIdAfterMarker = pathWithoutQueryOrHash
.slice(runsIndex + runsMarker.length)
.split("/", 1)[0]
?.trim();
if (runIdAfterMarker) {
return runIdAfterMarker;
}
return null;
}
const segments = pathWithoutQueryOrHash
.split("/")
.map((segment) => segment.trim())
.filter(Boolean);
return segments.at(-1) ?? null;
}
function getRunMetadataStorage(): {
getItem(key: `lg:stream:${string}`): string | null;
setItem(key: `lg:stream:${string}`, value: string): void;
removeItem(key: `lg:stream:${string}`): void;
} {
return {
getItem(key) {
const normalized = normalizeStoredRunId(
window.sessionStorage.getItem(key),
);
if (normalized) {
window.sessionStorage.setItem(key, normalized);
return normalized;
}
window.sessionStorage.removeItem(key);
return null;
},
setItem(key, value) {
const normalized = normalizeStoredRunId(value);
if (normalized) {
window.sessionStorage.setItem(key, normalized);
return;
}
window.sessionStorage.removeItem(key);
},
removeItem(key) {
window.sessionStorage.removeItem(key);
},
};
} }
function getStreamErrorMessage(error: unknown): string { function getStreamErrorMessage(error: unknown): string {
@@ -140,6 +102,7 @@ export function useThreadStream({
threadId, threadId,
context, context,
isMock, isMock,
onSend,
onStart, onStart,
onFinish, onFinish,
onToolEnd, onToolEnd,
@@ -151,17 +114,25 @@ export function useThreadStream({
// and to allow access to the current thread id in onUpdateEvent // and to allow access to the current thread id in onUpdateEvent
const threadIdRef = useRef<string | null>(threadId ?? null); const threadIdRef = useRef<string | null>(threadId ?? null);
const startedRef = useRef(false); const startedRef = useRef(false);
const listeners = useRef({ const listeners = useRef({
onSend,
onStart, onStart,
onFinish, onFinish,
onToolEnd, onToolEnd,
}); });
const {
messages: history,
hasMore: hasMoreHistory,
loadMore: loadMoreHistory,
loading: isHistoryLoading,
appendMessages,
} = useThreadHistory(onStreamThreadId ?? "");
// Keep listeners ref updated with latest callbacks // Keep listeners ref updated with latest callbacks
useEffect(() => { useEffect(() => {
listeners.current = { onStart, onFinish, onToolEnd }; listeners.current = { onSend, onStart, onFinish, onToolEnd };
}, [onStart, onFinish, onToolEnd]); }, [onSend, onStart, onFinish, onToolEnd]);
useEffect(() => { useEffect(() => {
const normalizedThreadId = threadId ?? null; const normalizedThreadId = threadId ?? null;
@@ -175,45 +146,26 @@ export function useThreadStream({
threadIdRef.current = normalizedThreadId; threadIdRef.current = normalizedThreadId;
}, [threadId]); }, [threadId]);
const _handleOnStart = useCallback((id: string) => { const handleStreamStart = useCallback((_threadId: string, _runId: string) => {
threadIdRef.current = _threadId;
if (!startedRef.current) { if (!startedRef.current) {
listeners.current.onStart?.(id); listeners.current.onStart?.(_threadId, _runId);
startedRef.current = true; startedRef.current = true;
} }
setOnStreamThreadId(_threadId);
}, []); }, []);
const handleStreamStart = useCallback(
(_threadId: string) => {
threadIdRef.current = _threadId;
_handleOnStart(_threadId);
},
[_handleOnStart],
);
const queryClient = useQueryClient(); const queryClient = useQueryClient();
const updateSubtask = useUpdateSubtask(); const updateSubtask = useUpdateSubtask();
const runMetadataStorageRef = useRef<
ReturnType<typeof getRunMetadataStorage> | undefined
>(undefined);
if (
typeof window !== "undefined" &&
runMetadataStorageRef.current === undefined
) {
runMetadataStorageRef.current = getRunMetadataStorage();
}
const thread = useStream<AgentThreadState>({ const thread = useStream<AgentThreadState>({
client: getAPIClient(isMock), client: getAPIClient(isMock),
assistantId: "lead_agent", assistantId: "lead_agent",
threadId: onStreamThreadId, threadId: onStreamThreadId,
reconnectOnMount: runMetadataStorageRef.current reconnectOnMount: true,
? () => runMetadataStorageRef.current!
: false,
fetchStateHistory: { limit: 1 }, fetchStateHistory: { limit: 1 },
onCreated(meta) { onCreated(meta) {
handleStreamStart(meta.thread_id); handleStreamStart(meta.thread_id, meta.run_id);
setOnStreamThreadId(meta.thread_id);
if (context.agent_name && !isMock) { if (context.agent_name && !isMock) {
void getAPIClient() void getAPIClient()
.threads.update(meta.thread_id, { .threads.update(meta.thread_id, {
@@ -231,6 +183,34 @@ export function useThreadStream({
} }
}, },
onUpdateEvent(data) { onUpdateEvent(data) {
if (data["SummarizationMiddleware.before_model"]) {
const _messages = [
...(data["SummarizationMiddleware.before_model"].messages ?? []),
];
if (_messages.length < 2) {
return;
}
for (const m of _messages) {
if (m.name === "summary" && m.type === "human") {
summarizedRef.current?.add(m.id ?? "");
}
}
const _lastKeepMessage = _messages[2];
const _currentMessages = [...messagesRef.current];
const _movedMessages: Message[] = [];
for (const m of _currentMessages) {
if (m.id !== undefined && m.id === _lastKeepMessage?.id) {
break;
}
if (!summarizedRef.current?.has(m.id ?? "")) {
_movedMessages.push(m);
}
}
appendMessages(_movedMessages);
messagesRef.current = [];
}
const updates: Array<Partial<AgentThreadState> | null> = Object.values( const updates: Array<Partial<AgentThreadState> | null> = Object.values(
data || {}, data || {},
); );
@@ -295,9 +275,6 @@ export function useThreadStream({
onFinish(state) { onFinish(state) {
listeners.current.onFinish?.(state.values); listeners.current.onFinish?.(state.values);
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] }); void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
void queryClient.invalidateQueries({
queryKey: ["thread-message-enrichment"],
});
}, },
}); });
@@ -305,24 +282,25 @@ export function useThreadStream({
const [optimisticMessages, setOptimisticMessages] = useState<Message[]>([]); const [optimisticMessages, setOptimisticMessages] = useState<Message[]>([]);
const [isUploading, setIsUploading] = useState(false); const [isUploading, setIsUploading] = useState(false);
const sendInFlightRef = useRef(false); const sendInFlightRef = useRef(false);
const messagesRef = useRef<Message[]>([]);
const summarizedRef = useRef<Set<string>>(null);
// Track message count before sending so we know when server has responded // Track message count before sending so we know when server has responded
const prevMsgCountRef = useRef(thread.messages.length); const prevMsgCountRef = useRef(thread.messages.length);
summarizedRef.current ??= new Set<string>();
// Reset thread-local pending UI state when switching between threads so // Reset thread-local pending UI state when switching between threads so
// optimistic messages and in-flight guards do not leak across chat views. // optimistic messages and in-flight guards do not leak across chat views.
useEffect(() => { useEffect(() => {
startedRef.current = false; startedRef.current = false;
sendInFlightRef.current = false; sendInFlightRef.current = false;
prevMsgCountRef.current = 0;
setOptimisticMessages([]);
setIsUploading(false);
}, [threadId]); }, [threadId]);
// Clear optimistic when server messages arrive (count increases) // Clear optimistic when server messages arrive (count increases)
useEffect(() => { useEffect(() => {
if ( if (
optimisticMessages.length > 0 && optimisticMessages.length > 0 &&
thread.messages.length > prevMsgCountRef.current + 1 thread.messages.length > prevMsgCountRef.current
) { ) {
setOptimisticMessages([]); setOptimisticMessages([]);
} }
@@ -381,12 +359,7 @@ export function useThreadStream({
} }
setOptimisticMessages(newOptimistic); setOptimisticMessages(newOptimistic);
// Only fire onStart immediately for an existing persisted thread. listeners.current.onSend?.(threadId);
// Brand-new chats should wait for onCreated(meta.thread_id) so URL sync
// uses the real server-generated thread id.
if (threadIdRef.current) {
_handleOnStart(threadId);
}
let uploadedFileInfo: UploadedFileInfo[] = []; let uploadedFileInfo: UploadedFileInfo[] = [];
@@ -520,19 +493,106 @@ export function useThreadStream({
sendInFlightRef.current = false; sendInFlightRef.current = false;
} }
}, },
[thread, _handleOnStart, t.uploads.uploadingFiles, context, queryClient], [thread, t.uploads.uploadingFiles, context, queryClient],
); );
// Merge thread with optimistic messages for display // Cache the latest thread messages in a ref to compare against incoming history messages for deduplication,
const mergedThread = // and to allow access to the full message list in onUpdateEvent without causing re-renders.
optimisticMessages.length > 0 if (thread.messages.length >= messagesRef.current.length) {
? ({ messagesRef.current = thread.messages;
...thread, }
messages: [...thread.messages, ...optimisticMessages],
} as typeof thread)
: thread;
return [mergedThread, sendMessage, isUploading] as const; const mergedMessages = mergeMessages(
history,
thread.messages,
optimisticMessages,
);
// Merge history, live stream, and optimistic messages for display
// History messages may overlap with thread.messages; thread.messages take precedence
const mergedThread = {
...thread,
messages: mergedMessages,
} as typeof thread;
return {
thread: mergedThread,
sendMessage,
isUploading,
isHistoryLoading,
hasMoreHistory,
loadMoreHistory,
} as const;
}
export function useThreadHistory(threadId: string) {
const runs = useThreadRuns(threadId);
const threadIdRef = useRef(threadId);
const runsRef = useRef(runs.data ?? []);
const indexRef = useRef(-1);
const loadingRef = useRef(false);
const [loading, setLoading] = useState(false);
const [messages, setMessages] = useState<Message[]>([]);
loadingRef.current = loading;
const loadMessages = useCallback(async () => {
if (runsRef.current.length === 0) {
return;
}
const run = runsRef.current[indexRef.current];
if (!run || loadingRef.current) {
return;
}
try {
setLoading(true);
const result: { data: RunMessage[]; hasMore: boolean } = await fetch(
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadIdRef.current)}/runs/${encodeURIComponent(run.run_id)}/messages`,
{
method: "GET",
headers: {
"Content-Type": "application/json",
},
credentials: "include",
},
).then((res) => {
return res.json();
});
const _messages = result.data
.filter((m) => !m.metadata.caller?.startsWith("middleware:"))
.map((m) => m.content);
setMessages((prev) => [..._messages, ...prev]);
indexRef.current -= 1;
} catch (err) {
console.error(err);
} finally {
setLoading(false);
}
}, []);
useEffect(() => {
threadIdRef.current = threadId;
if (runs.data && runs.data.length > 0) {
runsRef.current = runs.data ?? [];
indexRef.current = runs.data.length - 1;
}
loadMessages().catch(() => {
toast.error("Failed to load thread history.");
});
}, [threadId, runs.data, loadMessages]);
const appendMessages = useCallback((_messages: Message[]) => {
setMessages((prev) => {
return [...prev, ..._messages];
});
}, []);
const hasMore = indexRef.current >= 0 || !runs.data;
return {
runs: runs.data,
messages,
loading,
appendMessages,
hasMore,
loadMore: loadMessages,
};
} }
export function useThreads( export function useThreads(
@@ -602,6 +662,33 @@ export function useThreads(
}); });
} }
export function useThreadRuns(threadId?: string) {
const apiClient = getAPIClient();
return useQuery<Run[]>({
queryKey: ["thread", threadId],
queryFn: async () => {
if (!threadId) {
return [];
}
const response = await apiClient.runs.list(threadId);
return response;
},
refetchOnWindowFocus: false,
});
}
export function useRunDetail(threadId: string, runId: string) {
const apiClient = getAPIClient();
return useQuery<Run>({
queryKey: ["thread", threadId, "run", runId],
queryFn: async () => {
const response = await apiClient.runs.get(threadId, runId);
return response;
},
refetchOnWindowFocus: false,
});
}
export function useDeleteThread() { export function useDeleteThread() {
const queryClient = useQueryClient(); const queryClient = useQueryClient();
const apiClient = getAPIClient(); const apiClient = getAPIClient();
@@ -609,7 +696,7 @@ export function useDeleteThread() {
mutationFn: async ({ threadId }: { threadId: string }) => { mutationFn: async ({ threadId }: { threadId: string }) => {
await apiClient.threads.delete(threadId); await apiClient.threads.delete(threadId);
const response = await fetchWithAuth( const response = await fetch(
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}`, `${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}`,
{ {
method: "DELETE", method: "DELETE",
@@ -682,65 +769,3 @@ export function useRenameThread() {
}, },
}); });
} }
/** Per-message enrichment data attached by the backend ``/history`` helper. */
export interface MessageEnrichment {
run_id: string;
/** ``undefined`` = not feedback-eligible; ``null`` = eligible but unrated. */
feedback?: FeedbackData | null;
}
/**
* Fetch ``/history`` once and index feedback + run_id by message id.
*
* Replaces the old ``useThreadFeedback`` hook which keyed by AI-message
* ordinal position — an inherently fragile mapping that broke whenever
* ``ai_tool_call`` messages were interleaved with ``ai_message`` messages.
* Keying by ``message.id`` is stable regardless of run count, tool-call
* chains, or summarization.
*
* The ``/history`` response is refreshed on every stream completion via
* ``invalidateQueries(["thread-message-enrichment"])`` in ``onFinish``.
*/
export function useThreadMessageEnrichment(
threadId: string | null | undefined,
) {
return useQuery({
queryKey: ["thread-message-enrichment", threadId],
queryFn: async (): Promise<Map<string, MessageEnrichment>> => {
const empty = new Map<string, MessageEnrichment>();
if (!threadId) return empty;
const res = await fetchWithAuth(
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/history`,
{
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ limit: 1 }),
},
);
if (!res.ok) return empty;
const entries = (await res.json()) as Array<{
values?: {
messages?: Array<{
id?: string;
run_id?: string;
feedback?: FeedbackData | null;
}>;
};
}>;
const messages = entries[0]?.values?.messages ?? [];
const map = new Map<string, MessageEnrichment>();
for (const m of messages) {
if (!m.id || !m.run_id) continue;
const entry: MessageEnrichment = { run_id: m.run_id };
// Preserve presence: "feedback" key absent → ineligible; present with
// null → eligible but unrated; present with object → rated.
if ("feedback" in m) entry.feedback = m.feedback;
map.set(m.id, entry);
}
return map;
},
enabled: !!threadId,
staleTime: 30_000,
});
}
+9
View File
@@ -22,3 +22,12 @@ export interface AgentThreadContext extends Record<string, unknown> {
export interface AgentThread extends Thread<AgentThreadState> { export interface AgentThread extends Thread<AgentThreadState> {
context?: AgentThreadContext; context?: AgentThreadContext;
} }
export interface RunMessage {
run_id: string;
content: Message;
metadata: {
caller: string;
};
created_at: string;
}
+3 -3
View File
@@ -2,7 +2,7 @@
* API functions for file uploads * API functions for file uploads
*/ */
import { fetchWithAuth } from "../api/fetcher"; import { fetch } from "../api/fetcher";
import { getBackendBaseURL } from "../config"; import { getBackendBaseURL } from "../config";
export interface UploadedFileInfo { export interface UploadedFileInfo {
@@ -51,7 +51,7 @@ export async function uploadFiles(
formData.append("files", file); formData.append("files", file);
}); });
const response = await fetchWithAuth( const response = await fetch(
`${getBackendBaseURL()}/api/threads/${threadId}/uploads`, `${getBackendBaseURL()}/api/threads/${threadId}/uploads`,
{ {
method: "POST", method: "POST",
@@ -92,7 +92,7 @@ export async function deleteUploadedFile(
threadId: string, threadId: string,
filename: string, filename: string,
): Promise<{ success: boolean; message: string }> { ): Promise<{ success: boolean; message: string }> {
const response = await fetchWithAuth( const response = await fetch(
`${getBackendBaseURL()}/api/threads/${threadId}/uploads/${filename}`, `${getBackendBaseURL()}/api/threads/${threadId}/uploads/${filename}`,
{ {
method: "DELETE", method: "DELETE",