mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 07:56:48 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 565ab432fc | |||
| df63c104a7 |
+20
-17
@@ -8,13 +8,17 @@ Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator, Callable
|
||||||
from contextlib import AsyncExitStack, asynccontextmanager
|
from contextlib import AsyncExitStack, asynccontextmanager
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, TypeVar, cast
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
|
from langgraph.types import Checkpointer
|
||||||
|
|
||||||
from deerflow.runtime import RunContext, RunManager
|
from deerflow.persistence.feedback import FeedbackRepository
|
||||||
|
from deerflow.runtime import RunContext, RunManager, StreamBridge
|
||||||
|
from deerflow.runtime.events.store.base import RunEventStore
|
||||||
|
from deerflow.runtime.runs.store.base import RunStore
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||||
@@ -22,6 +26,9 @@ if TYPE_CHECKING:
|
|||||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
"""Bootstrap and tear down all LangGraph runtime singletons.
|
"""Bootstrap and tear down all LangGraph runtime singletons.
|
||||||
@@ -84,25 +91,25 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _require(attr: str, label: str):
|
def _require(attr: str, label: str) -> Callable[[Request], T]:
|
||||||
"""Create a FastAPI dependency that returns ``app.state.<attr>`` or 503."""
|
"""Create a FastAPI dependency that returns ``app.state.<attr>`` or 503."""
|
||||||
|
|
||||||
def dep(request: Request):
|
def dep(request: Request) -> T:
|
||||||
val = getattr(request.app.state, attr, None)
|
val = getattr(request.app.state, attr, None)
|
||||||
if val is None:
|
if val is None:
|
||||||
raise HTTPException(status_code=503, detail=f"{label} not available")
|
raise HTTPException(status_code=503, detail=f"{label} not available")
|
||||||
return val
|
return cast(T, val)
|
||||||
|
|
||||||
dep.__name__ = dep.__qualname__ = f"get_{attr}"
|
dep.__name__ = dep.__qualname__ = f"get_{attr}"
|
||||||
return dep
|
return dep
|
||||||
|
|
||||||
|
|
||||||
get_stream_bridge = _require("stream_bridge", "Stream bridge")
|
get_stream_bridge: Callable[[Request], StreamBridge] = _require("stream_bridge", "Stream bridge")
|
||||||
get_run_manager = _require("run_manager", "Run manager")
|
get_run_manager: Callable[[Request], RunManager] = _require("run_manager", "Run manager")
|
||||||
get_checkpointer = _require("checkpointer", "Checkpointer")
|
get_checkpointer: Callable[[Request], Checkpointer] = _require("checkpointer", "Checkpointer")
|
||||||
get_run_event_store = _require("run_event_store", "Run event store")
|
get_run_event_store: Callable[[Request], RunEventStore] = _require("run_event_store", "Run event store")
|
||||||
get_feedback_repo = _require("feedback_repo", "Feedback")
|
get_feedback_repo: Callable[[Request], FeedbackRepository] = _require("feedback_repo", "Feedback")
|
||||||
get_run_store = _require("run_store", "Run store")
|
get_run_store: Callable[[Request], RunStore] = _require("run_store", "Run store")
|
||||||
|
|
||||||
|
|
||||||
def get_store(request: Request):
|
def get_store(request: Request):
|
||||||
@@ -121,10 +128,7 @@ def get_thread_store(request: Request) -> ThreadMetaStore:
|
|||||||
def get_run_context(request: Request) -> RunContext:
|
def get_run_context(request: Request) -> RunContext:
|
||||||
"""Build a :class:`RunContext` from ``app.state`` singletons.
|
"""Build a :class:`RunContext` from ``app.state`` singletons.
|
||||||
|
|
||||||
Returns a *base* context with infrastructure dependencies. Callers that
|
Returns a *base* context with infrastructure dependencies.
|
||||||
need per-run fields (e.g. ``follow_up_to_run_id``) should use
|
|
||||||
``dataclasses.replace(ctx, follow_up_to_run_id=...)`` before passing it
|
|
||||||
to :func:`run_agent`.
|
|
||||||
"""
|
"""
|
||||||
from deerflow.config import get_app_config
|
from deerflow.config import get_app_config
|
||||||
|
|
||||||
@@ -137,7 +141,6 @@ def get_run_context(request: Request) -> RunContext:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Auth helpers (used by authz.py and auth middleware)
|
# Auth helpers (used by authz.py and auth middleware)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -123,7 +123,8 @@ async def run_messages(
|
|||||||
run = await _resolve_run(run_id, request)
|
run = await _resolve_run(run_id, request)
|
||||||
event_store = get_run_event_store(request)
|
event_store = get_run_event_store(request)
|
||||||
rows = await event_store.list_messages_by_run(
|
rows = await event_store.list_messages_by_run(
|
||||||
run["thread_id"], run_id,
|
run["thread_id"],
|
||||||
|
run_id,
|
||||||
limit=limit + 1,
|
limit=limit + 1,
|
||||||
before_seq=before_seq,
|
before_seq=before_seq,
|
||||||
after_seq=after_seq,
|
after_seq=after_seq,
|
||||||
|
|||||||
@@ -54,7 +54,6 @@ class RunCreateRequest(BaseModel):
|
|||||||
after_seconds: float | None = Field(default=None, description="Delayed execution")
|
after_seconds: float | None = Field(default=None, description="Delayed execution")
|
||||||
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
|
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
|
||||||
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
|
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
|
||||||
follow_up_to_run_id: str | None = Field(default=None, description="Run ID this message follows up on. Auto-detected from latest successful run if not provided.")
|
|
||||||
|
|
||||||
|
|
||||||
class RunResponse(BaseModel):
|
class RunResponse(BaseModel):
|
||||||
@@ -312,11 +311,15 @@ async def list_thread_messages(
|
|||||||
if i in last_ai_indices:
|
if i in last_ai_indices:
|
||||||
run_id = msg["run_id"]
|
run_id = msg["run_id"]
|
||||||
fb = feedback_map.get(run_id)
|
fb = feedback_map.get(run_id)
|
||||||
msg["feedback"] = {
|
msg["feedback"] = (
|
||||||
"feedback_id": fb["feedback_id"],
|
{
|
||||||
"rating": fb["rating"],
|
"feedback_id": fb["feedback_id"],
|
||||||
"comment": fb.get("comment"),
|
"rating": fb["rating"],
|
||||||
} if fb else None
|
"comment": fb.get("comment"),
|
||||||
|
}
|
||||||
|
if fb
|
||||||
|
else None
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
msg["feedback"] = None
|
msg["feedback"] = None
|
||||||
|
|
||||||
@@ -339,7 +342,8 @@ async def list_run_messages(
|
|||||||
"""
|
"""
|
||||||
event_store = get_run_event_store(request)
|
event_store = get_run_event_store(request)
|
||||||
rows = await event_store.list_messages_by_run(
|
rows = await event_store.list_messages_by_run(
|
||||||
thread_id, run_id,
|
thread_id,
|
||||||
|
run_id,
|
||||||
limit=limit + 1,
|
limit=limit + 1,
|
||||||
before_seq=before_seq,
|
before_seq=before_seq,
|
||||||
after_seq=after_seq,
|
after_seq=after_seq,
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=UploadResponse)
|
@router.post("", response_model=UploadResponse)
|
||||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
@require_permission("threads", "write", owner_check=True, require_existing=False)
|
||||||
async def upload_files(
|
async def upload_files(
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
|
|||||||
@@ -195,21 +195,6 @@ async def start_run(
|
|||||||
|
|
||||||
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
|
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
|
||||||
|
|
||||||
# Resolve follow_up_to_run_id: explicit from request, or auto-detect from latest successful run
|
|
||||||
follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None)
|
|
||||||
if follow_up_to_run_id is None:
|
|
||||||
run_store = get_run_store(request)
|
|
||||||
try:
|
|
||||||
recent_runs = await run_store.list_by_thread(thread_id, limit=1)
|
|
||||||
if recent_runs and recent_runs[0].get("status") == "success":
|
|
||||||
follow_up_to_run_id = recent_runs[0]["run_id"]
|
|
||||||
except Exception:
|
|
||||||
pass # Don't block run creation
|
|
||||||
|
|
||||||
# Enrich base context with per-run field
|
|
||||||
if follow_up_to_run_id:
|
|
||||||
run_ctx = dataclasses.replace(run_ctx, follow_up_to_run_id=follow_up_to_run_id)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
record = await run_mgr.create_or_reject(
|
record = await run_mgr.create_or_reject(
|
||||||
thread_id,
|
thread_id,
|
||||||
@@ -218,7 +203,6 @@ async def start_run(
|
|||||||
metadata=body.metadata or {},
|
metadata=body.metadata or {},
|
||||||
kwargs={"input": body.input, "config": body.config},
|
kwargs={"input": body.input, "config": body.config},
|
||||||
multitask_strategy=body.multitask_strategy,
|
multitask_strategy=body.multitask_strategy,
|
||||||
follow_up_to_run_id=follow_up_to_run_id,
|
|
||||||
)
|
)
|
||||||
except ConflictError as exc:
|
except ConflictError as exc:
|
||||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
from deerflow.agents.lead_agent.prompt import apply_prompt_template
|
from deerflow.agents.lead_agent.prompt import apply_prompt_template
|
||||||
@@ -9,6 +9,7 @@ from deerflow.agents.middlewares.clarification_middleware import ClarificationMi
|
|||||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||||
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
|
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
|
||||||
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
|
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
|
||||||
|
from deerflow.agents.middlewares.summarization_middleware import SummarizationMiddleware
|
||||||
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
||||||
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
|
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
|
||||||
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
|
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
|
||||||
|
|||||||
@@ -283,7 +283,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
|||||||
# the conversation; injecting one mid-conversation crashes
|
# the conversation; injecting one mid-conversation crashes
|
||||||
# langchain_anthropic's _format_messages(). HumanMessage works
|
# langchain_anthropic's _format_messages(). HumanMessage works
|
||||||
# with all providers. See #1299.
|
# with all providers. See #1299.
|
||||||
return {"messages": [HumanMessage(content=warning)]}
|
return {"messages": [HumanMessage(content=warning, name="loop_warning")]}
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,13 @@
|
|||||||
|
from typing import override
|
||||||
|
|
||||||
|
from langchain.agents.middleware import SummarizationMiddleware as BaseSummarizationMiddleware
|
||||||
|
from langchain_core.messages.human import HumanMessage
|
||||||
|
|
||||||
|
|
||||||
|
class SummarizationMiddleware(BaseSummarizationMiddleware):
|
||||||
|
@override
|
||||||
|
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
|
||||||
|
"""Override the base implementation to let the human message with the special name 'summary'.
|
||||||
|
And this message will be ignored to display in the frontend, but still can be used as context for the model.
|
||||||
|
"""
|
||||||
|
return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
|
||||||
@@ -1,8 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from datetime import UTC, datetime
|
||||||
from typing import NotRequired, override
|
from typing import NotRequired, override
|
||||||
|
|
||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
from langgraph.config import get_config
|
from langgraph.config import get_config
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
@@ -97,8 +99,20 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
|||||||
paths = self._create_thread_directories(thread_id, user_id=user_id)
|
paths = self._create_thread_directories(thread_id, user_id=user_id)
|
||||||
logger.debug("Created thread data directories for thread %s", thread_id)
|
logger.debug("Created thread data directories for thread %s", thread_id)
|
||||||
|
|
||||||
|
messages = list(state.get("messages", []))
|
||||||
|
last_message = messages[-1] if messages else None
|
||||||
|
|
||||||
|
if last_message and isinstance(last_message, HumanMessage):
|
||||||
|
messages[-1] = HumanMessage(
|
||||||
|
content=last_message.content,
|
||||||
|
id=last_message.id,
|
||||||
|
name=last_message.name or "user-input",
|
||||||
|
additional_kwargs={**last_message.additional_kwargs, "run_id": runtime.context.get("run_id"), "timestamp": datetime.now(UTC).isoformat()},
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"thread_data": {
|
"thread_data": {
|
||||||
**paths,
|
**paths,
|
||||||
}
|
},
|
||||||
|
"messages": messages,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -279,6 +279,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
|||||||
updated_message = HumanMessage(
|
updated_message = HumanMessage(
|
||||||
content=f"{files_message}\n\n{original_content}",
|
content=f"{files_message}\n\n{original_content}",
|
||||||
id=last_message.id,
|
id=last_message.id,
|
||||||
|
name=last_message.name,
|
||||||
additional_kwargs=last_message.additional_kwargs,
|
additional_kwargs=last_message.additional_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,10 @@ handles token usage accumulation.
|
|||||||
|
|
||||||
Key design decisions:
|
Key design decisions:
|
||||||
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
|
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
|
||||||
- on_chat_model_start captures structured prompts as llm_request (OpenAI format)
|
- on_chat_model_start captures structured prompts as llm_request (OpenAI format) and
|
||||||
|
extracts the first human message for run.input, because it is more reliable than
|
||||||
|
on_chain_start (fires on every node) — messages here are fully structured.
|
||||||
|
- on_chain_start with parent_run_id=None emits a run.start trace marking root invocation.
|
||||||
- on_llm_end emits llm_response in OpenAI Chat Completions format
|
- on_llm_end emits llm_response in OpenAI Chat Completions format
|
||||||
- Token usage accumulated in memory, written to RunRow on run completion
|
- Token usage accumulated in memory, written to RunRow on run completion
|
||||||
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
|
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
|
||||||
@@ -18,10 +21,12 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from langchain_core.callbacks import BaseCallbackHandler
|
from langchain_core.callbacks import BaseCallbackHandler
|
||||||
|
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage
|
||||||
|
from langgraph.types import Command
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from deerflow.runtime.events.store.base import RunEventStore
|
from deerflow.runtime.events.store.base import RunEventStore
|
||||||
@@ -72,34 +77,39 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
# LLM request/response tracking
|
# LLM request/response tracking
|
||||||
self._llm_call_index = 0
|
self._llm_call_index = 0
|
||||||
self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages
|
self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages
|
||||||
self._cached_models: dict[str, str] = {} # langchain run_id -> model name
|
|
||||||
|
|
||||||
# Tool call ID cache
|
|
||||||
self._tool_call_ids: dict[str, str] = {} # langchain run_id -> tool_call_id
|
|
||||||
|
|
||||||
# -- Lifecycle callbacks --
|
# -- Lifecycle callbacks --
|
||||||
|
|
||||||
def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
def on_chain_start(
|
||||||
if kwargs.get("parent_run_id") is not None:
|
self,
|
||||||
return
|
serialized: dict[str, Any],
|
||||||
self._put(
|
inputs: dict[str, Any],
|
||||||
event_type="run_start",
|
*,
|
||||||
category="lifecycle",
|
run_id: UUID,
|
||||||
metadata={"input_preview": str(inputs)[:500]},
|
parent_run_id: UUID | None = None,
|
||||||
)
|
tags: list[str] | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
caller = self._identify_caller(tags)
|
||||||
|
if parent_run_id is None:
|
||||||
|
# Root graph invocation — emit a single trace event for the run start.
|
||||||
|
chain_name = (serialized or {}).get("name", "unknown")
|
||||||
|
self._put(
|
||||||
|
event_type="run.start",
|
||||||
|
category="trace",
|
||||||
|
content={"chain": chain_name},
|
||||||
|
metadata={"caller": caller, **(metadata or {})},
|
||||||
|
)
|
||||||
|
|
||||||
def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
if kwargs.get("parent_run_id") is not None:
|
self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"})
|
||||||
return
|
|
||||||
self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
|
|
||||||
self._flush_sync()
|
self._flush_sync()
|
||||||
|
|
||||||
def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
if kwargs.get("parent_run_id") is not None:
|
|
||||||
return
|
|
||||||
self._put(
|
self._put(
|
||||||
event_type="run_error",
|
event_type="run.error",
|
||||||
category="lifecycle",
|
category="error",
|
||||||
content=str(error),
|
content=str(error),
|
||||||
metadata={"error_type": type(error).__name__},
|
metadata={"error_type": type(error).__name__},
|
||||||
)
|
)
|
||||||
@@ -107,266 +117,132 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
|
|
||||||
# -- LLM callbacks --
|
# -- LLM callbacks --
|
||||||
|
|
||||||
def on_chat_model_start(self, serialized: dict, messages: list[list], *, run_id: UUID, **kwargs: Any) -> None:
|
def on_chat_model_start(
|
||||||
"""Capture structured prompt messages for llm_request event."""
|
self,
|
||||||
from deerflow.runtime.converters import langchain_messages_to_openai
|
serialized: dict,
|
||||||
|
messages: list[list[BaseMessage]],
|
||||||
|
*,
|
||||||
|
run_id: UUID,
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Capture structured prompt messages for llm_request event.
|
||||||
|
|
||||||
|
This is also the canonical place to extract the first human message:
|
||||||
|
messages are fully structured here, it fires only on real LLM calls,
|
||||||
|
and the content is never compressed by checkpoint trimming.
|
||||||
|
"""
|
||||||
rid = str(run_id)
|
rid = str(run_id)
|
||||||
self._llm_start_times[rid] = time.monotonic()
|
self._llm_start_times[rid] = time.monotonic()
|
||||||
self._llm_call_index += 1
|
self._llm_call_index += 1
|
||||||
|
# Mark this run_id as seen so on_llm_end knows not to increment again.
|
||||||
|
self._cached_prompts[rid] = []
|
||||||
|
|
||||||
model_name = serialized.get("name", "")
|
logger.info(f"on_chat_model_start {run_id}: tags={tags} serialized={serialized} messages={messages}")
|
||||||
self._cached_models[rid] = model_name
|
|
||||||
|
|
||||||
# Convert the first message list (LangChain passes list-of-lists)
|
# Capture the first human message sent to any LLM in this run.
|
||||||
prompt_msgs = messages[0] if messages else []
|
if not self._first_human_msg:
|
||||||
openai_msgs = langchain_messages_to_openai(prompt_msgs)
|
for batch in messages.reversed():
|
||||||
self._cached_prompts[rid] = openai_msgs
|
for m in batch.reversed():
|
||||||
|
if isinstance(m, HumanMessage) and m.name != "summary":
|
||||||
|
caller = self._identify_caller(tags)
|
||||||
|
self.set_first_human_message(m.text)
|
||||||
|
self._put(
|
||||||
|
event_type="llm.human.input",
|
||||||
|
category="message",
|
||||||
|
content=m.model_dump(),
|
||||||
|
metadata={"caller": caller},
|
||||||
|
)
|
||||||
|
break
|
||||||
|
if self._first_human_msg:
|
||||||
|
break
|
||||||
|
|
||||||
caller = self._identify_caller(kwargs)
|
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, parent_run_id: UUID | None = None, tags: list[str] | None = None, metadata: dict[str, Any] | None = None, **kwargs: Any) -> None:
|
||||||
self._put(
|
|
||||||
event_type="llm_request",
|
|
||||||
category="trace",
|
|
||||||
content={"model": model_name, "messages": openai_msgs},
|
|
||||||
metadata={"caller": caller, "llm_call_index": self._llm_call_index},
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
|
|
||||||
# Fallback: on_chat_model_start is preferred. This just tracks latency.
|
# Fallback: on_chat_model_start is preferred. This just tracks latency.
|
||||||
self._llm_start_times[str(run_id)] = time.monotonic()
|
self._llm_start_times[str(run_id)] = time.monotonic()
|
||||||
|
|
||||||
def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
def on_llm_end(self, response, *, run_id, parent_run_id, tags, **kwargs) -> None:
|
||||||
from deerflow.runtime.converters import langchain_to_openai_completion
|
messages: list[AnyMessage] = []
|
||||||
|
logger.info(f"on_llm_end {run_id}: response: {tags} {kwargs}")
|
||||||
try:
|
for generation in response.generations:
|
||||||
message = response.generations[0][0].message
|
for gen in generation:
|
||||||
except (IndexError, AttributeError):
|
if hasattr(gen, "message"):
|
||||||
logger.debug("on_llm_end: could not extract message from response")
|
messages.append(gen.message)
|
||||||
return
|
|
||||||
|
|
||||||
caller = self._identify_caller(kwargs)
|
|
||||||
|
|
||||||
# Latency
|
|
||||||
rid = str(run_id)
|
|
||||||
start = self._llm_start_times.pop(rid, None)
|
|
||||||
latency_ms = int((time.monotonic() - start) * 1000) if start else None
|
|
||||||
|
|
||||||
# Token usage from message
|
|
||||||
usage = getattr(message, "usage_metadata", None)
|
|
||||||
usage_dict = dict(usage) if usage else {}
|
|
||||||
|
|
||||||
# Resolve call index
|
|
||||||
call_index = self._llm_call_index
|
|
||||||
if rid not in self._cached_prompts:
|
|
||||||
# Fallback: on_chat_model_start was not called
|
|
||||||
self._llm_call_index += 1
|
|
||||||
call_index = self._llm_call_index
|
|
||||||
|
|
||||||
# Clean up caches
|
|
||||||
self._cached_prompts.pop(rid, None)
|
|
||||||
self._cached_models.pop(rid, None)
|
|
||||||
|
|
||||||
# Trace event: llm_response (OpenAI completion format)
|
|
||||||
content = getattr(message, "content", "")
|
|
||||||
self._put(
|
|
||||||
event_type="llm_response",
|
|
||||||
category="trace",
|
|
||||||
content=langchain_to_openai_completion(message),
|
|
||||||
metadata={
|
|
||||||
"caller": caller,
|
|
||||||
"usage": usage_dict,
|
|
||||||
"latency_ms": latency_ms,
|
|
||||||
"llm_call_index": call_index,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Message events: only lead_agent gets message-category events.
|
|
||||||
# Content uses message.model_dump() to align with checkpoint format.
|
|
||||||
tool_calls = getattr(message, "tool_calls", None) or []
|
|
||||||
if caller == "lead_agent":
|
|
||||||
resp_meta = getattr(message, "response_metadata", None) or {}
|
|
||||||
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
|
|
||||||
if tool_calls:
|
|
||||||
# ai_tool_call: agent decided to use tools
|
|
||||||
self._put(
|
|
||||||
event_type="ai_tool_call",
|
|
||||||
category="message",
|
|
||||||
content=message.model_dump(),
|
|
||||||
metadata={"model_name": model_name, "finish_reason": "tool_calls"},
|
|
||||||
)
|
|
||||||
elif isinstance(content, str) and content:
|
|
||||||
# ai_message: final text reply
|
|
||||||
self._put(
|
|
||||||
event_type="ai_message",
|
|
||||||
category="message",
|
|
||||||
content=message.model_dump(),
|
|
||||||
metadata={"model_name": model_name, "finish_reason": "stop"},
|
|
||||||
)
|
|
||||||
self._last_ai_msg = content
|
|
||||||
self._msg_count += 1
|
|
||||||
|
|
||||||
# Token accumulation
|
|
||||||
if self._track_tokens:
|
|
||||||
input_tk = usage_dict.get("input_tokens", 0) or 0
|
|
||||||
output_tk = usage_dict.get("output_tokens", 0) or 0
|
|
||||||
total_tk = usage_dict.get("total_tokens", 0) or 0
|
|
||||||
if total_tk == 0:
|
|
||||||
total_tk = input_tk + output_tk
|
|
||||||
if total_tk > 0:
|
|
||||||
self._total_input_tokens += input_tk
|
|
||||||
self._total_output_tokens += output_tk
|
|
||||||
self._total_tokens += total_tk
|
|
||||||
self._llm_call_count += 1
|
|
||||||
if caller.startswith("subagent:"):
|
|
||||||
self._subagent_tokens += total_tk
|
|
||||||
elif caller.startswith("middleware:"):
|
|
||||||
self._middleware_tokens += total_tk
|
|
||||||
else:
|
else:
|
||||||
self._lead_agent_tokens += total_tk
|
logger.warning(f"on_llm_end {run_id}: generation has no message attribute: {gen}")
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
caller = self._identify_caller(tags)
|
||||||
|
|
||||||
|
# Latency
|
||||||
|
rid = str(run_id)
|
||||||
|
start = self._llm_start_times.pop(rid, None)
|
||||||
|
latency_ms = int((time.monotonic() - start) * 1000) if start else None
|
||||||
|
|
||||||
|
# Token usage from message
|
||||||
|
usage = getattr(message, "usage_metadata", None)
|
||||||
|
usage_dict = dict(usage) if usage else {}
|
||||||
|
|
||||||
|
# Resolve call index
|
||||||
|
call_index = self._llm_call_index
|
||||||
|
if rid not in self._cached_prompts:
|
||||||
|
# Fallback: on_chat_model_start was not called
|
||||||
|
self._llm_call_index += 1
|
||||||
|
call_index = self._llm_call_index
|
||||||
|
|
||||||
|
# Trace event: llm_response (OpenAI completion format)
|
||||||
|
self._put(
|
||||||
|
event_type="llm.ai.response",
|
||||||
|
category="message",
|
||||||
|
content=message.model_dump(),
|
||||||
|
metadata={
|
||||||
|
"caller": caller,
|
||||||
|
"usage": usage_dict,
|
||||||
|
"latency_ms": latency_ms,
|
||||||
|
"llm_call_index": call_index,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Token accumulation
|
||||||
|
if self._track_tokens:
|
||||||
|
input_tk = usage_dict.get("input_tokens", 0) or 0
|
||||||
|
output_tk = usage_dict.get("output_tokens", 0) or 0
|
||||||
|
total_tk = usage_dict.get("total_tokens", 0) or 0
|
||||||
|
if total_tk == 0:
|
||||||
|
total_tk = input_tk + output_tk
|
||||||
|
if total_tk > 0:
|
||||||
|
self._total_input_tokens += input_tk
|
||||||
|
self._total_output_tokens += output_tk
|
||||||
|
self._total_tokens += total_tk
|
||||||
|
self._llm_call_count += 1
|
||||||
|
|
||||||
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
self._llm_start_times.pop(str(run_id), None)
|
self._llm_start_times.pop(str(run_id), None)
|
||||||
self._put(event_type="llm_error", category="trace", content=str(error))
|
self._put(event_type="llm.error", category="trace", content=str(error))
|
||||||
|
|
||||||
# -- Tool callbacks --
|
def on_tool_start(self, serialized, input_str, *, run_id, parent_run_id=None, tags=None, metadata=None, inputs=None, **kwargs):
|
||||||
|
"""Handle tool start event, cache tool call ID for later correlation"""
|
||||||
|
tool_call_id = str(run_id)
|
||||||
|
logger.info(f"Tool start for node {run_id}, tool_call_id={tool_call_id}, tags={tags}, metadata={metadata}")
|
||||||
|
|
||||||
def on_tool_start(self, serialized: dict, input_str: str, *, run_id: UUID, **kwargs: Any) -> None:
|
def on_tool_end(self, output, *, run_id, parent_run_id=None, **kwargs):
|
||||||
tool_call_id = kwargs.get("tool_call_id")
|
"""Handle tool end event, append message and clear node data"""
|
||||||
if tool_call_id:
|
try:
|
||||||
self._tool_call_ids[str(run_id)] = tool_call_id
|
if isinstance(output, ToolMessage):
|
||||||
self._put(
|
msg = cast(ToolMessage, output)
|
||||||
event_type="tool_start",
|
self._put(event_type="llm.tool.result", category="message", content=msg.model_dump())
|
||||||
category="trace",
|
elif isinstance(output, Command):
|
||||||
metadata={
|
cmd = cast(Command, output)
|
||||||
"tool_name": serialized.get("name", ""),
|
messages = cmd.update.get("messages", [])
|
||||||
"tool_call_id": tool_call_id,
|
for message in messages:
|
||||||
"args": str(input_str)[:2000],
|
if isinstance(message, BaseMessage):
|
||||||
},
|
self._put(event_type="llm.tool.result", category="message", content=message.model_dump())
|
||||||
)
|
else:
|
||||||
|
logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}")
|
||||||
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
else:
|
||||||
from langchain_core.messages import ToolMessage
|
logger.warning(f"on_tool_end {run_id}: output is not ToolMessage: {type(output)}")
|
||||||
from langgraph.types import Command
|
finally:
|
||||||
|
logger.info(f"Tool end for node {run_id}")
|
||||||
# Tools that update graph state return a ``Command`` (e.g.
|
|
||||||
# ``present_files``). LangGraph later unwraps the inner ToolMessage
|
|
||||||
# into checkpoint state, so to stay checkpoint-aligned we must
|
|
||||||
# extract it here rather than storing ``str(Command(...))``.
|
|
||||||
if isinstance(output, Command):
|
|
||||||
update = getattr(output, "update", None) or {}
|
|
||||||
inner_msgs = update.get("messages") if isinstance(update, dict) else None
|
|
||||||
if isinstance(inner_msgs, list):
|
|
||||||
inner_tool_msg = next((m for m in inner_msgs if isinstance(m, ToolMessage)), None)
|
|
||||||
if inner_tool_msg is not None:
|
|
||||||
output = inner_tool_msg
|
|
||||||
|
|
||||||
# Extract fields from ToolMessage object when LangChain provides one.
|
|
||||||
# LangChain's _format_output wraps tool results into a ToolMessage
|
|
||||||
# with tool_call_id, name, status, and artifact — more complete than
|
|
||||||
# what kwargs alone provides.
|
|
||||||
if isinstance(output, ToolMessage):
|
|
||||||
tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
|
|
||||||
tool_name = output.name or kwargs.get("name", "")
|
|
||||||
status = getattr(output, "status", "success") or "success"
|
|
||||||
content_str = output.content if isinstance(output.content, str) else str(output.content)
|
|
||||||
# Use model_dump() for checkpoint-aligned message content.
|
|
||||||
# Override tool_call_id if it was resolved from cache.
|
|
||||||
msg_content = output.model_dump()
|
|
||||||
if msg_content.get("tool_call_id") != tool_call_id:
|
|
||||||
msg_content["tool_call_id"] = tool_call_id
|
|
||||||
else:
|
|
||||||
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
|
|
||||||
tool_name = kwargs.get("name", "")
|
|
||||||
status = "success"
|
|
||||||
content_str = str(output)
|
|
||||||
# Construct checkpoint-aligned dict when output is a plain string.
|
|
||||||
msg_content = ToolMessage(
|
|
||||||
content=content_str,
|
|
||||||
tool_call_id=tool_call_id or "",
|
|
||||||
name=tool_name,
|
|
||||||
status=status,
|
|
||||||
).model_dump()
|
|
||||||
|
|
||||||
# Trace event (always)
|
|
||||||
self._put(
|
|
||||||
event_type="tool_end",
|
|
||||||
category="trace",
|
|
||||||
content=content_str,
|
|
||||||
metadata={
|
|
||||||
"tool_name": tool_name,
|
|
||||||
"tool_call_id": tool_call_id,
|
|
||||||
"status": status,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Message event: tool_result (checkpoint-aligned model_dump format)
|
|
||||||
self._put(
|
|
||||||
event_type="tool_result",
|
|
||||||
category="message",
|
|
||||||
content=msg_content,
|
|
||||||
metadata={"tool_name": tool_name, "status": status},
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
|
||||||
from langchain_core.messages import ToolMessage
|
|
||||||
|
|
||||||
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
|
|
||||||
tool_name = kwargs.get("name", "")
|
|
||||||
|
|
||||||
# Trace event
|
|
||||||
self._put(
|
|
||||||
event_type="tool_error",
|
|
||||||
category="trace",
|
|
||||||
content=str(error),
|
|
||||||
metadata={
|
|
||||||
"tool_name": tool_name,
|
|
||||||
"tool_call_id": tool_call_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Message event: tool_result with error status (checkpoint-aligned)
|
|
||||||
msg_content = ToolMessage(
|
|
||||||
content=str(error),
|
|
||||||
tool_call_id=tool_call_id or "",
|
|
||||||
name=tool_name,
|
|
||||||
status="error",
|
|
||||||
).model_dump()
|
|
||||||
self._put(
|
|
||||||
event_type="tool_result",
|
|
||||||
category="message",
|
|
||||||
content=msg_content,
|
|
||||||
metadata={"tool_name": tool_name, "status": "error"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# -- Custom event callback --
|
|
||||||
|
|
||||||
def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
|
||||||
from deerflow.runtime.serialization import serialize_lc_object
|
|
||||||
|
|
||||||
if name == "summarization":
|
|
||||||
data_dict = data if isinstance(data, dict) else {}
|
|
||||||
self._put(
|
|
||||||
event_type="summarization",
|
|
||||||
category="trace",
|
|
||||||
content=data_dict.get("summary", ""),
|
|
||||||
metadata={
|
|
||||||
"replaced_message_ids": data_dict.get("replaced_message_ids", []),
|
|
||||||
"replaced_count": data_dict.get("replaced_count", 0),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self._put(
|
|
||||||
event_type="middleware:summarize",
|
|
||||||
category="middleware",
|
|
||||||
content={"role": "system", "content": data_dict.get("summary", "")},
|
|
||||||
metadata={"replaced_count": data_dict.get("replaced_count", 0)},
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
event_data = serialize_lc_object(data) if not isinstance(data, dict) else data
|
|
||||||
self._put(
|
|
||||||
event_type=name,
|
|
||||||
category="trace",
|
|
||||||
metadata=event_data if isinstance(event_data, dict) else {"data": event_data},
|
|
||||||
)
|
|
||||||
|
|
||||||
# -- Internal methods --
|
# -- Internal methods --
|
||||||
|
|
||||||
@@ -431,8 +307,9 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
if exc:
|
if exc:
|
||||||
logger.warning("Journal flush task failed: %s", exc)
|
logger.warning("Journal flush task failed: %s", exc)
|
||||||
|
|
||||||
def _identify_caller(self, kwargs: dict) -> str:
|
def _identify_caller(self, tags: list[str] | None, **kwargs) -> str:
|
||||||
for tag in kwargs.get("tags") or []:
|
_tags = tags or kwargs.get("tags", [])
|
||||||
|
for tag in _tags:
|
||||||
if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
|
if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
|
||||||
return tag
|
return tag
|
||||||
# Default to lead_agent: the main agent graph does not inject
|
# Default to lead_agent: the main agent graph does not inject
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class RunManager:
|
|||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
self._store = store
|
self._store = store
|
||||||
|
|
||||||
async def _persist_to_store(self, record: RunRecord, *, follow_up_to_run_id: str | None = None) -> None:
|
async def _persist_to_store(self, record: RunRecord) -> None:
|
||||||
"""Best-effort persist run record to backing store."""
|
"""Best-effort persist run record to backing store."""
|
||||||
if self._store is None:
|
if self._store is None:
|
||||||
return
|
return
|
||||||
@@ -68,7 +68,6 @@ class RunManager:
|
|||||||
metadata=record.metadata or {},
|
metadata=record.metadata or {},
|
||||||
kwargs=record.kwargs or {},
|
kwargs=record.kwargs or {},
|
||||||
created_at=record.created_at,
|
created_at=record.created_at,
|
||||||
follow_up_to_run_id=follow_up_to_run_id,
|
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
|
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
|
||||||
@@ -90,7 +89,6 @@ class RunManager:
|
|||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
kwargs: dict | None = None,
|
kwargs: dict | None = None,
|
||||||
multitask_strategy: str = "reject",
|
multitask_strategy: str = "reject",
|
||||||
follow_up_to_run_id: str | None = None,
|
|
||||||
) -> RunRecord:
|
) -> RunRecord:
|
||||||
"""Create a new pending run and register it."""
|
"""Create a new pending run and register it."""
|
||||||
run_id = str(uuid.uuid4())
|
run_id = str(uuid.uuid4())
|
||||||
@@ -109,7 +107,7 @@ class RunManager:
|
|||||||
)
|
)
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._runs[run_id] = record
|
self._runs[run_id] = record
|
||||||
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
|
await self._persist_to_store(record)
|
||||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||||
return record
|
return record
|
||||||
|
|
||||||
@@ -122,7 +120,7 @@ class RunManager:
|
|||||||
async with self._lock:
|
async with self._lock:
|
||||||
# Dict insertion order matches creation order, so reversing it gives
|
# Dict insertion order matches creation order, so reversing it gives
|
||||||
# us deterministic newest-first results even when timestamps tie.
|
# us deterministic newest-first results even when timestamps tie.
|
||||||
return [r for r in reversed(self._runs.values()) if r.thread_id == thread_id]
|
return [r for r in self._runs.values() if r.thread_id == thread_id]
|
||||||
|
|
||||||
async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
|
async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
|
||||||
"""Transition a run to a new status."""
|
"""Transition a run to a new status."""
|
||||||
@@ -176,7 +174,6 @@ class RunManager:
|
|||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
kwargs: dict | None = None,
|
kwargs: dict | None = None,
|
||||||
multitask_strategy: str = "reject",
|
multitask_strategy: str = "reject",
|
||||||
follow_up_to_run_id: str | None = None,
|
|
||||||
) -> RunRecord:
|
) -> RunRecord:
|
||||||
"""Atomically check for inflight runs and create a new one.
|
"""Atomically check for inflight runs and create a new one.
|
||||||
|
|
||||||
@@ -230,7 +227,7 @@ class RunManager:
|
|||||||
)
|
)
|
||||||
self._runs[run_id] = record
|
self._runs[run_id] = record
|
||||||
|
|
||||||
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
|
await self._persist_to_store(record)
|
||||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||||
return record
|
return record
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ class RunStore(abc.ABC):
|
|||||||
kwargs: dict[str, Any] | None = None,
|
kwargs: dict[str, Any] | None = None,
|
||||||
error: str | None = None,
|
error: str | None = None,
|
||||||
created_at: str | None = None,
|
created_at: str | None = None,
|
||||||
follow_up_to_run_id: str | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ class MemoryRunStore(RunStore):
|
|||||||
kwargs=None,
|
kwargs=None,
|
||||||
error=None,
|
error=None,
|
||||||
created_at=None,
|
created_at=None,
|
||||||
follow_up_to_run_id=None,
|
|
||||||
):
|
):
|
||||||
now = datetime.now(UTC).isoformat()
|
now = datetime.now(UTC).isoformat()
|
||||||
self._runs[run_id] = {
|
self._runs[run_id] = {
|
||||||
@@ -41,7 +40,6 @@ class MemoryRunStore(RunStore):
|
|||||||
"metadata": metadata or {},
|
"metadata": metadata or {},
|
||||||
"kwargs": kwargs or {},
|
"kwargs": kwargs or {},
|
||||||
"error": error,
|
"error": error,
|
||||||
"follow_up_to_run_id": follow_up_to_run_id,
|
|
||||||
"created_at": created_at or now,
|
"created_at": created_at or now,
|
||||||
"updated_at": now,
|
"updated_at": now,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -51,7 +51,6 @@ class RunContext:
|
|||||||
event_store: Any | None = field(default=None)
|
event_store: Any | None = field(default=None)
|
||||||
run_events_config: Any | None = field(default=None)
|
run_events_config: Any | None = field(default=None)
|
||||||
thread_store: Any | None = field(default=None)
|
thread_store: Any | None = field(default=None)
|
||||||
follow_up_to_run_id: str | None = field(default=None)
|
|
||||||
|
|
||||||
|
|
||||||
async def run_agent(
|
async def run_agent(
|
||||||
@@ -76,7 +75,6 @@ async def run_agent(
|
|||||||
event_store = ctx.event_store
|
event_store = ctx.event_store
|
||||||
run_events_config = ctx.run_events_config
|
run_events_config = ctx.run_events_config
|
||||||
thread_store = ctx.thread_store
|
thread_store = ctx.thread_store
|
||||||
follow_up_to_run_id = ctx.follow_up_to_run_id
|
|
||||||
|
|
||||||
run_id = record.run_id
|
run_id = record.run_id
|
||||||
thread_id = record.thread_id
|
thread_id = record.thread_id
|
||||||
@@ -113,22 +111,6 @@ async def run_agent(
|
|||||||
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
||||||
)
|
)
|
||||||
|
|
||||||
human_msg = _extract_human_message(graph_input)
|
|
||||||
if human_msg is not None:
|
|
||||||
msg_metadata = {}
|
|
||||||
if follow_up_to_run_id:
|
|
||||||
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
|
|
||||||
await event_store.put(
|
|
||||||
thread_id=thread_id,
|
|
||||||
run_id=run_id,
|
|
||||||
event_type="human_message",
|
|
||||||
category="message",
|
|
||||||
content=human_msg.model_dump(),
|
|
||||||
metadata=msg_metadata or None,
|
|
||||||
)
|
|
||||||
content = human_msg.content
|
|
||||||
journal.set_first_human_message(content if isinstance(content, str) else str(content))
|
|
||||||
|
|
||||||
# 1. Mark running
|
# 1. Mark running
|
||||||
await run_manager.set_status(run_id, RunStatus.running)
|
await run_manager.set_status(run_id, RunStatus.running)
|
||||||
|
|
||||||
@@ -166,12 +148,13 @@ async def run_agent(
|
|||||||
|
|
||||||
# Inject runtime context so middlewares can access thread_id
|
# Inject runtime context so middlewares can access thread_id
|
||||||
# (langgraph-cli does this automatically; we must do it manually)
|
# (langgraph-cli does this automatically; we must do it manually)
|
||||||
runtime = Runtime(context={"thread_id": thread_id}, store=store)
|
runtime = Runtime(context={"thread_id": thread_id, "run_id": run_id}, store=store)
|
||||||
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
|
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
|
||||||
# prefers it over ``configurable`` for thread-level data), make
|
# prefers it over ``configurable`` for thread-level data), make
|
||||||
# sure ``thread_id`` is available there too.
|
# sure ``thread_id`` is available there too.
|
||||||
if "context" in config and isinstance(config["context"], dict):
|
if "context" in config and isinstance(config["context"], dict):
|
||||||
config["context"].setdefault("thread_id", thread_id)
|
config["context"].setdefault("thread_id", thread_id)
|
||||||
|
config["context"].setdefault("run_id", run_id)
|
||||||
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
||||||
|
|
||||||
# Inject RunJournal as a LangChain callback handler.
|
# Inject RunJournal as a LangChain callback handler.
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -75,27 +75,27 @@ async def test_cancel_not_inflight(manager: RunManager):
|
|||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_list_by_thread(manager: RunManager):
|
async def test_list_by_thread(manager: RunManager):
|
||||||
"""Same thread should return multiple runs, newest first."""
|
"""Same thread should return multiple runs."""
|
||||||
r1 = await manager.create("thread-1")
|
r1 = await manager.create("thread-1")
|
||||||
r2 = await manager.create("thread-1")
|
r2 = await manager.create("thread-1")
|
||||||
await manager.create("thread-2")
|
await manager.create("thread-2")
|
||||||
|
|
||||||
runs = await manager.list_by_thread("thread-1")
|
runs = await manager.list_by_thread("thread-1")
|
||||||
assert len(runs) == 2
|
assert len(runs) == 2
|
||||||
assert runs[0].run_id == r2.run_id
|
assert runs[0].run_id == r1.run_id
|
||||||
assert runs[1].run_id == r1.run_id
|
assert runs[1].run_id == r2.run_id
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_list_by_thread_is_stable_when_timestamps_tie(manager: RunManager, monkeypatch: pytest.MonkeyPatch):
|
async def test_list_by_thread_is_stable_when_timestamps_tie(manager: RunManager, monkeypatch: pytest.MonkeyPatch):
|
||||||
"""Newest-first ordering should not depend on timestamp precision."""
|
"""Ordering should be stable (insertion order) even when timestamps tie."""
|
||||||
monkeypatch.setattr("deerflow.runtime.runs.manager._now_iso", lambda: "2026-01-01T00:00:00+00:00")
|
monkeypatch.setattr("deerflow.runtime.runs.manager._now_iso", lambda: "2026-01-01T00:00:00+00:00")
|
||||||
|
|
||||||
r1 = await manager.create("thread-1")
|
r1 = await manager.create("thread-1")
|
||||||
r2 = await manager.create("thread-1")
|
r2 = await manager.create("thread-1")
|
||||||
|
|
||||||
runs = await manager.list_by_thread("thread-1")
|
runs = await manager.list_by_thread("thread-1")
|
||||||
assert [run.run_id for run in runs] == [r2.run_id, r1.run_id]
|
assert [run.run_id for run in runs] == [r1.run_id, r2.run_id]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
|
|||||||
@@ -46,7 +46,13 @@ export default function AgentChatPage() {
|
|||||||
const [settings, setSettings] = useThreadSettings(threadId);
|
const [settings, setSettings] = useThreadSettings(threadId);
|
||||||
|
|
||||||
const { showNotification } = useNotification();
|
const { showNotification } = useNotification();
|
||||||
const [thread, sendMessage] = useThreadStream({
|
const {
|
||||||
|
thread,
|
||||||
|
sendMessage,
|
||||||
|
isHistoryLoading,
|
||||||
|
hasMoreHistory,
|
||||||
|
loadMoreHistory,
|
||||||
|
} = useThreadStream({
|
||||||
threadId: isNewThread ? undefined : threadId,
|
threadId: isNewThread ? undefined : threadId,
|
||||||
context: { ...settings.context, agent_name: agent_name },
|
context: { ...settings.context, agent_name: agent_name },
|
||||||
onStart: (createdThreadId) => {
|
onStart: (createdThreadId) => {
|
||||||
@@ -141,6 +147,9 @@ export default function AgentChatPage() {
|
|||||||
threadId={threadId}
|
threadId={threadId}
|
||||||
thread={thread}
|
thread={thread}
|
||||||
paddingBottom={messageListPaddingBottom}
|
paddingBottom={messageListPaddingBottom}
|
||||||
|
hasMoreHistory={hasMoreHistory}
|
||||||
|
loadMoreHistory={loadMoreHistory}
|
||||||
|
isHistoryLoading={isHistoryLoading}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { useCallback, useEffect, useState } from "react";
|
import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
|
|
||||||
import { type PromptInputMessage } from "@/components/ai-elements/prompt-input";
|
import { type PromptInputMessage } from "@/components/ai-elements/prompt-input";
|
||||||
import { ArtifactTrigger } from "@/components/workspace/artifacts";
|
import { ArtifactTrigger } from "@/components/workspace/artifacts";
|
||||||
@@ -35,19 +35,30 @@ export default function ChatPage() {
|
|||||||
const { threadId, setThreadId, isNewThread, setIsNewThread, isMock } =
|
const { threadId, setThreadId, isNewThread, setIsNewThread, isMock } =
|
||||||
useThreadChat();
|
useThreadChat();
|
||||||
const [settings, setSettings] = useThreadSettings(threadId);
|
const [settings, setSettings] = useThreadSettings(threadId);
|
||||||
const [mounted, setMounted] = useState(false);
|
const mountedRef = useRef(false);
|
||||||
useSpecificChatMode();
|
useSpecificChatMode();
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
setMounted(true);
|
mountedRef.current = true;
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const { showNotification } = useNotification();
|
const { showNotification } = useNotification();
|
||||||
|
|
||||||
const [thread, sendMessage, isUploading] = useThreadStream({
|
const {
|
||||||
|
thread,
|
||||||
|
sendMessage,
|
||||||
|
isUploading,
|
||||||
|
isHistoryLoading,
|
||||||
|
hasMoreHistory,
|
||||||
|
loadMoreHistory,
|
||||||
|
} = useThreadStream({
|
||||||
threadId: isNewThread ? undefined : threadId,
|
threadId: isNewThread ? undefined : threadId,
|
||||||
context: settings.context,
|
context: settings.context,
|
||||||
isMock,
|
isMock,
|
||||||
|
onSend: (_threadId) => {
|
||||||
|
setThreadId(_threadId);
|
||||||
|
setIsNewThread(false);
|
||||||
|
},
|
||||||
onStart: (createdThreadId) => {
|
onStart: (createdThreadId) => {
|
||||||
setThreadId(createdThreadId);
|
setThreadId(createdThreadId);
|
||||||
setIsNewThread(false);
|
setIsNewThread(false);
|
||||||
@@ -115,6 +126,9 @@ export default function ChatPage() {
|
|||||||
threadId={threadId}
|
threadId={threadId}
|
||||||
thread={thread}
|
thread={thread}
|
||||||
paddingBottom={messageListPaddingBottom}
|
paddingBottom={messageListPaddingBottom}
|
||||||
|
hasMoreHistory={hasMoreHistory}
|
||||||
|
loadMoreHistory={loadMoreHistory}
|
||||||
|
isHistoryLoading={isHistoryLoading}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div className="absolute right-0 bottom-0 left-0 z-30 flex justify-center px-4">
|
<div className="absolute right-0 bottom-0 left-0 z-30 flex justify-center px-4">
|
||||||
@@ -138,7 +152,7 @@ export default function ChatPage() {
|
|||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
{mounted ? (
|
{mountedRef.current ? (
|
||||||
<InputBox
|
<InputBox
|
||||||
className={cn("bg-background/5 w-full -translate-y-4")}
|
className={cn("bg-background/5 w-full -translate-y-4")}
|
||||||
isNewThread={isNewThread}
|
isNewThread={isNewThread}
|
||||||
@@ -170,7 +184,7 @@ export default function ChatPage() {
|
|||||||
<div
|
<div
|
||||||
aria-hidden="true"
|
aria-hidden="true"
|
||||||
className={cn(
|
className={cn(
|
||||||
"bg-background/5 h-32 w-full -translate-y-4 rounded-2xl border",
|
"bg-background/5 h-32 w-full -translate-y-4 rounded-2xl",
|
||||||
)}
|
)}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ import {
|
|||||||
DropdownMenuLabel,
|
DropdownMenuLabel,
|
||||||
DropdownMenuSeparator,
|
DropdownMenuSeparator,
|
||||||
} from "@/components/ui/dropdown-menu";
|
} from "@/components/ui/dropdown-menu";
|
||||||
import { fetchWithAuth } from "@/core/api/fetcher";
|
import { fetch } from "@/core/api/fetcher";
|
||||||
import { getBackendBaseURL } from "@/core/config";
|
import { getBackendBaseURL } from "@/core/config";
|
||||||
import { useI18n } from "@/core/i18n/hooks";
|
import { useI18n } from "@/core/i18n/hooks";
|
||||||
import { useModels } from "@/core/models/hooks";
|
import { useModels } from "@/core/models/hooks";
|
||||||
@@ -155,6 +155,7 @@ export function InputBox({
|
|||||||
const [followupsLoading, setFollowupsLoading] = useState(false);
|
const [followupsLoading, setFollowupsLoading] = useState(false);
|
||||||
const lastGeneratedForAiIdRef = useRef<string | null>(null);
|
const lastGeneratedForAiIdRef = useRef<string | null>(null);
|
||||||
const wasStreamingRef = useRef(false);
|
const wasStreamingRef = useRef(false);
|
||||||
|
const messagesRef = useRef(thread.messages);
|
||||||
|
|
||||||
const [confirmOpen, setConfirmOpen] = useState(false);
|
const [confirmOpen, setConfirmOpen] = useState(false);
|
||||||
const [pendingSuggestion, setPendingSuggestion] = useState<string | null>(
|
const [pendingSuggestion, setPendingSuggestion] = useState<string | null>(
|
||||||
@@ -354,6 +355,10 @@ export function InputBox({
|
|||||||
followupsVisibilityChangeRef.current?.(showFollowups);
|
followupsVisibilityChangeRef.current?.(showFollowups);
|
||||||
}, [showFollowups]);
|
}, [showFollowups]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
messagesRef.current = thread.messages;
|
||||||
|
}, [thread.messages]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
return () => followupsVisibilityChangeRef.current?.(false);
|
return () => followupsVisibilityChangeRef.current?.(false);
|
||||||
}, []);
|
}, []);
|
||||||
@@ -370,14 +375,16 @@ export function InputBox({
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const lastAi = [...thread.messages].reverse().find((m) => m.type === "ai");
|
const lastAi = [...messagesRef.current]
|
||||||
|
.reverse()
|
||||||
|
.find((m) => m.type === "ai");
|
||||||
const lastAiId = lastAi?.id ?? null;
|
const lastAiId = lastAi?.id ?? null;
|
||||||
if (!lastAiId || lastAiId === lastGeneratedForAiIdRef.current) {
|
if (!lastAiId || lastAiId === lastGeneratedForAiIdRef.current) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
lastGeneratedForAiIdRef.current = lastAiId;
|
lastGeneratedForAiIdRef.current = lastAiId;
|
||||||
|
|
||||||
const recent = thread.messages
|
const recent = messagesRef.current
|
||||||
.filter((m) => m.type === "human" || m.type === "ai")
|
.filter((m) => m.type === "human" || m.type === "ai")
|
||||||
.map((m) => {
|
.map((m) => {
|
||||||
const role = m.type === "human" ? "user" : "assistant";
|
const role = m.type === "human" ? "user" : "assistant";
|
||||||
@@ -396,19 +403,16 @@ export function InputBox({
|
|||||||
setFollowupsLoading(true);
|
setFollowupsLoading(true);
|
||||||
setFollowups([]);
|
setFollowups([]);
|
||||||
|
|
||||||
fetchWithAuth(
|
fetch(`${getBackendBaseURL()}/api/threads/${threadId}/suggestions`, {
|
||||||
`${getBackendBaseURL()}/api/threads/${threadId}/suggestions`,
|
method: "POST",
|
||||||
{
|
headers: { "Content-Type": "application/json" },
|
||||||
method: "POST",
|
body: JSON.stringify({
|
||||||
headers: { "Content-Type": "application/json" },
|
messages: recent,
|
||||||
body: JSON.stringify({
|
n: 3,
|
||||||
messages: recent,
|
model_name: context.model_name ?? undefined,
|
||||||
n: 3,
|
}),
|
||||||
model_name: context.model_name ?? undefined,
|
signal: controller.signal,
|
||||||
}),
|
})
|
||||||
signal: controller.signal,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.then(async (res) => {
|
.then(async (res) => {
|
||||||
if (!res.ok) {
|
if (!res.ok) {
|
||||||
return { suggestions: [] as string[] };
|
return { suggestions: [] as string[] };
|
||||||
@@ -430,7 +434,7 @@ export function InputBox({
|
|||||||
});
|
});
|
||||||
|
|
||||||
return () => controller.abort();
|
return () => controller.abort();
|
||||||
}, [context.model_name, disabled, isMock, status, thread.messages, threadId]);
|
}, [context.model_name, disabled, isMock, status, threadId]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div ref={promptRootRef} className="relative flex flex-col gap-4">
|
<div ref={promptRootRef} className="relative flex flex-col gap-4">
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
import type { BaseStream } from "@langchain/langgraph-sdk/react";
|
import type { BaseStream } from "@langchain/langgraph-sdk/react";
|
||||||
|
import { ChevronUpIcon, Loader2Icon } from "lucide-react";
|
||||||
|
import { useCallback, useEffect, useRef } from "react";
|
||||||
|
|
||||||
import {
|
import {
|
||||||
Conversation,
|
Conversation,
|
||||||
ConversationContent,
|
ConversationContent,
|
||||||
} from "@/components/ai-elements/conversation";
|
} from "@/components/ai-elements/conversation";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
import { useI18n } from "@/core/i18n/hooks";
|
import { useI18n } from "@/core/i18n/hooks";
|
||||||
import {
|
import {
|
||||||
extractContentFromMessage,
|
extractContentFromMessage,
|
||||||
@@ -18,7 +21,6 @@ import { useRehypeSplitWordsIntoSpans } from "@/core/rehype";
|
|||||||
import type { Subtask } from "@/core/tasks";
|
import type { Subtask } from "@/core/tasks";
|
||||||
import { useUpdateSubtask } from "@/core/tasks/context";
|
import { useUpdateSubtask } from "@/core/tasks/context";
|
||||||
import type { AgentThreadState } from "@/core/threads";
|
import type { AgentThreadState } from "@/core/threads";
|
||||||
import { useThreadMessageEnrichment } from "@/core/threads/hooks";
|
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
import { ArtifactFileList } from "../artifacts/artifact-file-list";
|
import { ArtifactFileList } from "../artifacts/artifact-file-list";
|
||||||
@@ -33,22 +35,134 @@ import { SubtaskCard } from "./subtask-card";
|
|||||||
export const MESSAGE_LIST_DEFAULT_PADDING_BOTTOM = 160;
|
export const MESSAGE_LIST_DEFAULT_PADDING_BOTTOM = 160;
|
||||||
export const MESSAGE_LIST_FOLLOWUPS_EXTRA_PADDING_BOTTOM = 80;
|
export const MESSAGE_LIST_FOLLOWUPS_EXTRA_PADDING_BOTTOM = 80;
|
||||||
|
|
||||||
|
const LOAD_MORE_HISTORY_THROTTLE_MS = 1200;
|
||||||
|
|
||||||
|
function LoadMoreHistoryIndicator({
|
||||||
|
isLoading,
|
||||||
|
hasMore,
|
||||||
|
loadMore,
|
||||||
|
}: {
|
||||||
|
isLoading?: boolean;
|
||||||
|
hasMore?: boolean;
|
||||||
|
loadMore?: () => void;
|
||||||
|
}) {
|
||||||
|
const { t } = useI18n();
|
||||||
|
const sentinelRef = useRef<HTMLDivElement | null>(null);
|
||||||
|
const timeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||||
|
const lastLoadRef = useRef(0);
|
||||||
|
|
||||||
|
const throttledLoadMore = useCallback(() => {
|
||||||
|
if (!hasMore || isLoading) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const now = Date.now();
|
||||||
|
const remaining =
|
||||||
|
LOAD_MORE_HISTORY_THROTTLE_MS - (now - lastLoadRef.current);
|
||||||
|
|
||||||
|
if (remaining <= 0) {
|
||||||
|
lastLoadRef.current = now;
|
||||||
|
loadMore?.();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (timeoutRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
timeoutRef.current = setTimeout(() => {
|
||||||
|
timeoutRef.current = null;
|
||||||
|
if (!hasMore || isLoading) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
lastLoadRef.current = Date.now();
|
||||||
|
loadMore?.();
|
||||||
|
}, remaining);
|
||||||
|
}, [hasMore, isLoading, loadMore]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const element = sentinelRef.current;
|
||||||
|
if (!element || !hasMore) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const observer = new IntersectionObserver(
|
||||||
|
([entry]) => {
|
||||||
|
if (entry?.isIntersecting) {
|
||||||
|
throttledLoadMore();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
rootMargin: "120px 0px 0px 0px",
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
observer.observe(element);
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
observer.disconnect();
|
||||||
|
};
|
||||||
|
}, [hasMore, throttledLoadMore]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
return () => {
|
||||||
|
if (timeoutRef.current) {
|
||||||
|
clearTimeout(timeoutRef.current);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
if (!hasMore && !isLoading) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div ref={sentinelRef} className="flex w-full justify-center">
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
className="text-muted-foreground hover:text-foreground rounded-full px-3"
|
||||||
|
disabled={(isLoading ?? false) || !hasMore}
|
||||||
|
onClick={throttledLoadMore}
|
||||||
|
>
|
||||||
|
{isLoading ? (
|
||||||
|
<>
|
||||||
|
<Loader2Icon className="mr-2 size-4 animate-spin" />
|
||||||
|
{t.common.loading}
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<>
|
||||||
|
<ChevronUpIcon className="mr-2 size-4" />
|
||||||
|
{t.common.loadMore}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
export function MessageList({
|
export function MessageList({
|
||||||
className,
|
className,
|
||||||
threadId,
|
threadId,
|
||||||
thread,
|
thread,
|
||||||
paddingBottom = MESSAGE_LIST_DEFAULT_PADDING_BOTTOM,
|
paddingBottom = MESSAGE_LIST_DEFAULT_PADDING_BOTTOM,
|
||||||
|
hasMoreHistory,
|
||||||
|
loadMoreHistory,
|
||||||
|
isHistoryLoading,
|
||||||
}: {
|
}: {
|
||||||
className?: string;
|
className?: string;
|
||||||
threadId: string;
|
threadId: string;
|
||||||
thread: BaseStream<AgentThreadState>;
|
thread: BaseStream<AgentThreadState>;
|
||||||
paddingBottom?: number;
|
paddingBottom?: number;
|
||||||
|
hasMoreHistory?: boolean;
|
||||||
|
loadMoreHistory?: () => void;
|
||||||
|
isHistoryLoading?: boolean;
|
||||||
}) {
|
}) {
|
||||||
const { t } = useI18n();
|
const { t } = useI18n();
|
||||||
const rehypePlugins = useRehypeSplitWordsIntoSpans(thread.isLoading);
|
const rehypePlugins = useRehypeSplitWordsIntoSpans(thread.isLoading);
|
||||||
const updateSubtask = useUpdateSubtask();
|
const updateSubtask = useUpdateSubtask();
|
||||||
const messages = thread.messages;
|
const messages = thread.messages;
|
||||||
const { data: enrichment } = useThreadMessageEnrichment(threadId);
|
|
||||||
|
|
||||||
if (thread.isThreadLoading && messages.length === 0) {
|
if (thread.isThreadLoading && messages.length === 0) {
|
||||||
return <MessageListSkeleton />;
|
return <MessageListSkeleton />;
|
||||||
@@ -57,19 +171,21 @@ export function MessageList({
|
|||||||
<Conversation
|
<Conversation
|
||||||
className={cn("flex size-full flex-col justify-center", className)}
|
className={cn("flex size-full flex-col justify-center", className)}
|
||||||
>
|
>
|
||||||
<ConversationContent className="mx-auto w-full max-w-(--container-width-md) gap-8 pt-12">
|
<ConversationContent className="mx-auto w-full max-w-(--container-width-md) gap-8 pt-8">
|
||||||
|
<LoadMoreHistoryIndicator
|
||||||
|
isLoading={isHistoryLoading}
|
||||||
|
hasMore={hasMoreHistory}
|
||||||
|
loadMore={loadMoreHistory}
|
||||||
|
/>
|
||||||
{groupMessages(messages, (group) => {
|
{groupMessages(messages, (group) => {
|
||||||
if (group.type === "human" || group.type === "assistant") {
|
if (group.type === "human" || group.type === "assistant") {
|
||||||
return group.messages.map((msg) => {
|
return group.messages.map((msg) => {
|
||||||
const entry = msg.id ? enrichment?.get(msg.id) : undefined;
|
|
||||||
return (
|
return (
|
||||||
<MessageListItem
|
<MessageListItem
|
||||||
key={`${group.id}/${msg.id}`}
|
key={`${group.id}/${msg.id}`}
|
||||||
threadId={threadId}
|
threadId={threadId}
|
||||||
message={msg}
|
message={msg}
|
||||||
isLoading={thread.isLoading}
|
isLoading={thread.isLoading}
|
||||||
runId={entry?.run_id}
|
|
||||||
feedback={entry?.feedback}
|
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import { useState } from "react";
|
|||||||
|
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import { Input } from "@/components/ui/input";
|
import { Input } from "@/components/ui/input";
|
||||||
import { fetchWithAuth, getCsrfHeaders } from "@/core/api/fetcher";
|
import { fetch, getCsrfHeaders } from "@/core/api/fetcher";
|
||||||
import { useAuth } from "@/core/auth/AuthProvider";
|
import { useAuth } from "@/core/auth/AuthProvider";
|
||||||
import { parseAuthError } from "@/core/auth/types";
|
import { parseAuthError } from "@/core/auth/types";
|
||||||
|
|
||||||
@@ -36,7 +36,7 @@ export function AccountSettingsPage() {
|
|||||||
|
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
try {
|
try {
|
||||||
const res = await fetchWithAuth("/api/v1/auth/change-password", {
|
const res = await fetch("/api/v1/auth/change-password", {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: {
|
headers: {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { fetchWithAuth } from "@/core/api/fetcher";
|
import { fetch } from "@/core/api/fetcher";
|
||||||
import { getBackendBaseURL } from "@/core/config";
|
import { getBackendBaseURL } from "@/core/config";
|
||||||
|
|
||||||
import type { Agent, CreateAgentRequest, UpdateAgentRequest } from "./types";
|
import type { Agent, CreateAgentRequest, UpdateAgentRequest } from "./types";
|
||||||
@@ -29,7 +29,7 @@ export async function getAgent(name: string): Promise<Agent> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export async function createAgent(request: CreateAgentRequest): Promise<Agent> {
|
export async function createAgent(request: CreateAgentRequest): Promise<Agent> {
|
||||||
const res = await fetchWithAuth(`${getBackendBaseURL()}/api/agents`, {
|
const res = await fetch(`${getBackendBaseURL()}/api/agents`, {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: { "Content-Type": "application/json" },
|
headers: { "Content-Type": "application/json" },
|
||||||
body: JSON.stringify(request),
|
body: JSON.stringify(request),
|
||||||
@@ -45,7 +45,7 @@ export async function updateAgent(
|
|||||||
name: string,
|
name: string,
|
||||||
request: UpdateAgentRequest,
|
request: UpdateAgentRequest,
|
||||||
): Promise<Agent> {
|
): Promise<Agent> {
|
||||||
const res = await fetchWithAuth(`${getBackendBaseURL()}/api/agents/${name}`, {
|
const res = await fetch(`${getBackendBaseURL()}/api/agents/${name}`, {
|
||||||
method: "PUT",
|
method: "PUT",
|
||||||
headers: { "Content-Type": "application/json" },
|
headers: { "Content-Type": "application/json" },
|
||||||
body: JSON.stringify(request),
|
body: JSON.stringify(request),
|
||||||
@@ -58,7 +58,7 @@ export async function updateAgent(
|
|||||||
}
|
}
|
||||||
|
|
||||||
export async function deleteAgent(name: string): Promise<void> {
|
export async function deleteAgent(name: string): Promise<void> {
|
||||||
const res = await fetchWithAuth(`${getBackendBaseURL()}/api/agents/${name}`, {
|
const res = await fetch(`${getBackendBaseURL()}/api/agents/${name}`, {
|
||||||
method: "DELETE",
|
method: "DELETE",
|
||||||
});
|
});
|
||||||
if (!res.ok) throw new Error(`Failed to delete agent: ${res.statusText}`);
|
if (!res.ok) throw new Error(`Failed to delete agent: ${res.statusText}`);
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { getBackendBaseURL } from "../config";
|
import { getBackendBaseURL } from "../config";
|
||||||
|
|
||||||
import { fetchWithAuth } from "./fetcher";
|
import { fetch } from "./fetcher";
|
||||||
|
|
||||||
export interface FeedbackData {
|
export interface FeedbackData {
|
||||||
feedback_id: string;
|
feedback_id: string;
|
||||||
@@ -14,7 +14,7 @@ export async function upsertFeedback(
|
|||||||
rating: number,
|
rating: number,
|
||||||
comment?: string,
|
comment?: string,
|
||||||
): Promise<FeedbackData> {
|
): Promise<FeedbackData> {
|
||||||
const res = await fetchWithAuth(
|
const res = await fetch(
|
||||||
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/runs/${encodeURIComponent(runId)}/feedback`,
|
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/runs/${encodeURIComponent(runId)}/feedback`,
|
||||||
{
|
{
|
||||||
method: "PUT",
|
method: "PUT",
|
||||||
@@ -32,7 +32,7 @@ export async function deleteFeedback(
|
|||||||
threadId: string,
|
threadId: string,
|
||||||
runId: string,
|
runId: string,
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
const res = await fetchWithAuth(
|
const res = await fetch(
|
||||||
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/runs/${encodeURIComponent(runId)}/feedback`,
|
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/runs/${encodeURIComponent(runId)}/feedback`,
|
||||||
{ method: "DELETE" },
|
{ method: "DELETE" },
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ export function readCsrfCookie(): string | null {
|
|||||||
* preserved; the helper only ADDS the CSRF header when it isn't already
|
* preserved; the helper only ADDS the CSRF header when it isn't already
|
||||||
* present, so explicit overrides win.
|
* present, so explicit overrides win.
|
||||||
*/
|
*/
|
||||||
export async function fetchWithAuth(
|
export async function fetch(
|
||||||
input: RequestInfo | string,
|
input: RequestInfo | string,
|
||||||
init?: RequestInit,
|
init?: RequestInit,
|
||||||
): Promise<Response> {
|
): Promise<Response> {
|
||||||
@@ -74,7 +74,7 @@ export async function fetchWithAuth(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const res = await fetch(url, {
|
const res = await globalThis.fetch(url, {
|
||||||
...init,
|
...init,
|
||||||
headers,
|
headers,
|
||||||
credentials: "include",
|
credentials: "include",
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ export const enUS: Translations = {
|
|||||||
close: "Close",
|
close: "Close",
|
||||||
more: "More",
|
more: "More",
|
||||||
search: "Search",
|
search: "Search",
|
||||||
|
loadMore: "Load more",
|
||||||
download: "Download",
|
download: "Download",
|
||||||
thinking: "Thinking",
|
thinking: "Thinking",
|
||||||
artifacts: "Artifacts",
|
artifacts: "Artifacts",
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ export interface Translations {
|
|||||||
close: string;
|
close: string;
|
||||||
more: string;
|
more: string;
|
||||||
search: string;
|
search: string;
|
||||||
|
loadMore: string;
|
||||||
download: string;
|
download: string;
|
||||||
thinking: string;
|
thinking: string;
|
||||||
artifacts: string;
|
artifacts: string;
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ export const zhCN: Translations = {
|
|||||||
close: "关闭",
|
close: "关闭",
|
||||||
more: "更多",
|
more: "更多",
|
||||||
search: "搜索",
|
search: "搜索",
|
||||||
|
loadMore: "加载更多",
|
||||||
download: "下载",
|
download: "下载",
|
||||||
thinking: "思考",
|
thinking: "思考",
|
||||||
artifacts: "文件",
|
artifacts: "文件",
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { fetchWithAuth } from "@/core/api/fetcher";
|
import { fetch } from "@/core/api/fetcher";
|
||||||
import { getBackendBaseURL } from "@/core/config";
|
import { getBackendBaseURL } from "@/core/config";
|
||||||
|
|
||||||
import type { MCPConfig } from "./types";
|
import type { MCPConfig } from "./types";
|
||||||
@@ -9,15 +9,12 @@ export async function loadMCPConfig() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export async function updateMCPConfig(config: MCPConfig) {
|
export async function updateMCPConfig(config: MCPConfig) {
|
||||||
const response = await fetchWithAuth(
|
const response = await fetch(`${getBackendBaseURL()}/api/mcp/config`, {
|
||||||
`${getBackendBaseURL()}/api/mcp/config`,
|
method: "PUT",
|
||||||
{
|
headers: {
|
||||||
method: "PUT",
|
"Content-Type": "application/json",
|
||||||
headers: {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
body: JSON.stringify(config),
|
|
||||||
},
|
},
|
||||||
);
|
body: JSON.stringify(config),
|
||||||
|
});
|
||||||
return response.json();
|
return response.json();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { fetchWithAuth } from "../api/fetcher";
|
import { fetch } from "../api/fetcher";
|
||||||
import { getBackendBaseURL } from "../config";
|
import { getBackendBaseURL } from "../config";
|
||||||
|
|
||||||
import type {
|
import type {
|
||||||
@@ -86,14 +86,14 @@ export async function loadMemory(): Promise<UserMemory> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export async function clearMemory(): Promise<UserMemory> {
|
export async function clearMemory(): Promise<UserMemory> {
|
||||||
const response = await fetchWithAuth(`${getBackendBaseURL()}/api/memory`, {
|
const response = await fetch(`${getBackendBaseURL()}/api/memory`, {
|
||||||
method: "DELETE",
|
method: "DELETE",
|
||||||
});
|
});
|
||||||
return readMemoryResponse(response, "Failed to clear memory");
|
return readMemoryResponse(response, "Failed to clear memory");
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function deleteMemoryFact(factId: string): Promise<UserMemory> {
|
export async function deleteMemoryFact(factId: string): Promise<UserMemory> {
|
||||||
const response = await fetchWithAuth(
|
const response = await fetch(
|
||||||
`${getBackendBaseURL()}/api/memory/facts/${encodeURIComponent(factId)}`,
|
`${getBackendBaseURL()}/api/memory/facts/${encodeURIComponent(factId)}`,
|
||||||
{
|
{
|
||||||
method: "DELETE",
|
method: "DELETE",
|
||||||
@@ -108,32 +108,26 @@ export async function exportMemory(): Promise<UserMemory> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export async function importMemory(memory: UserMemory): Promise<UserMemory> {
|
export async function importMemory(memory: UserMemory): Promise<UserMemory> {
|
||||||
const response = await fetchWithAuth(
|
const response = await fetch(`${getBackendBaseURL()}/api/memory/import`, {
|
||||||
`${getBackendBaseURL()}/api/memory/import`,
|
method: "POST",
|
||||||
{
|
headers: {
|
||||||
method: "POST",
|
"Content-Type": "application/json",
|
||||||
headers: {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
body: JSON.stringify(memory),
|
|
||||||
},
|
},
|
||||||
);
|
body: JSON.stringify(memory),
|
||||||
|
});
|
||||||
return readMemoryResponse(response, "Failed to import memory");
|
return readMemoryResponse(response, "Failed to import memory");
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function createMemoryFact(
|
export async function createMemoryFact(
|
||||||
input: MemoryFactInput,
|
input: MemoryFactInput,
|
||||||
): Promise<UserMemory> {
|
): Promise<UserMemory> {
|
||||||
const response = await fetchWithAuth(
|
const response = await fetch(`${getBackendBaseURL()}/api/memory/facts`, {
|
||||||
`${getBackendBaseURL()}/api/memory/facts`,
|
method: "POST",
|
||||||
{
|
headers: {
|
||||||
method: "POST",
|
"Content-Type": "application/json",
|
||||||
headers: {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
body: JSON.stringify(input),
|
|
||||||
},
|
},
|
||||||
);
|
body: JSON.stringify(input),
|
||||||
|
});
|
||||||
return readMemoryResponse(response, "Failed to create memory fact");
|
return readMemoryResponse(response, "Failed to create memory fact");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -141,7 +135,7 @@ export async function updateMemoryFact(
|
|||||||
factId: string,
|
factId: string,
|
||||||
input: MemoryFactPatchInput,
|
input: MemoryFactPatchInput,
|
||||||
): Promise<UserMemory> {
|
): Promise<UserMemory> {
|
||||||
const response = await fetchWithAuth(
|
const response = await fetch(
|
||||||
`${getBackendBaseURL()}/api/memory/facts/${encodeURIComponent(factId)}`,
|
`${getBackendBaseURL()}/api/memory/facts/${encodeURIComponent(factId)}`,
|
||||||
{
|
{
|
||||||
method: "PATCH",
|
method: "PATCH",
|
||||||
|
|||||||
@@ -328,7 +328,11 @@ export function findToolCallResult(toolCallId: string, messages: Message[]) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export function isHiddenFromUIMessage(message: Message) {
|
export function isHiddenFromUIMessage(message: Message) {
|
||||||
return message.additional_kwargs?.hide_from_ui === true;
|
return (
|
||||||
|
message.additional_kwargs?.hide_from_ui === true ||
|
||||||
|
message.name === "summary" ||
|
||||||
|
message.name === "loop_warning"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { fetchWithAuth } from "@/core/api/fetcher";
|
import { fetch } from "@/core/api/fetcher";
|
||||||
import { getBackendBaseURL } from "@/core/config";
|
import { getBackendBaseURL } from "@/core/config";
|
||||||
|
|
||||||
import type { Skill } from "./type";
|
import type { Skill } from "./type";
|
||||||
@@ -10,7 +10,7 @@ export async function loadSkills() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export async function enableSkill(skillName: string, enabled: boolean) {
|
export async function enableSkill(skillName: string, enabled: boolean) {
|
||||||
const response = await fetchWithAuth(
|
const response = await fetch(
|
||||||
`${getBackendBaseURL()}/api/skills/${skillName}`,
|
`${getBackendBaseURL()}/api/skills/${skillName}`,
|
||||||
{
|
{
|
||||||
method: "PUT",
|
method: "PUT",
|
||||||
@@ -39,16 +39,13 @@ export interface InstallSkillResponse {
|
|||||||
export async function installSkill(
|
export async function installSkill(
|
||||||
request: InstallSkillRequest,
|
request: InstallSkillRequest,
|
||||||
): Promise<InstallSkillResponse> {
|
): Promise<InstallSkillResponse> {
|
||||||
const response = await fetchWithAuth(
|
const response = await fetch(`${getBackendBaseURL()}/api/skills/install`, {
|
||||||
`${getBackendBaseURL()}/api/skills/install`,
|
method: "POST",
|
||||||
{
|
headers: {
|
||||||
method: "POST",
|
"Content-Type": "application/json",
|
||||||
headers: {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
body: JSON.stringify(request),
|
|
||||||
},
|
},
|
||||||
);
|
body: JSON.stringify(request),
|
||||||
|
});
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
// Handle HTTP error responses (4xx, 5xx)
|
// Handle HTTP error responses (4xx, 5xx)
|
||||||
|
|||||||
+214
-189
@@ -1,15 +1,14 @@
|
|||||||
import type { AIMessage, Message } from "@langchain/langgraph-sdk";
|
import type { AIMessage, Message, Run } from "@langchain/langgraph-sdk";
|
||||||
import type { ThreadsClient } from "@langchain/langgraph-sdk/client";
|
import type { ThreadsClient } from "@langchain/langgraph-sdk/client";
|
||||||
import { useStream } from "@langchain/langgraph-sdk/react";
|
import { useStream } from "@langchain/langgraph-sdk/react";
|
||||||
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
|
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
|
||||||
import { useCallback, useEffect, useRef, useState } from "react";
|
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
|
|
||||||
import type { PromptInputMessage } from "@/components/ai-elements/prompt-input";
|
import type { PromptInputMessage } from "@/components/ai-elements/prompt-input";
|
||||||
|
|
||||||
import { getAPIClient } from "../api";
|
import { getAPIClient } from "../api";
|
||||||
import type { FeedbackData } from "../api/feedback";
|
import { fetch } from "../api/fetcher";
|
||||||
import { fetchWithAuth } from "../api/fetcher";
|
|
||||||
import { getBackendBaseURL } from "../config";
|
import { getBackendBaseURL } from "../config";
|
||||||
import { useI18n } from "../i18n/hooks";
|
import { useI18n } from "../i18n/hooks";
|
||||||
import type { FileInMessage } from "../messages/utils";
|
import type { FileInMessage } from "../messages/utils";
|
||||||
@@ -18,7 +17,7 @@ import { useUpdateSubtask } from "../tasks/context";
|
|||||||
import type { UploadedFileInfo } from "../uploads";
|
import type { UploadedFileInfo } from "../uploads";
|
||||||
import { promptInputFilePartToFile, uploadFiles } from "../uploads";
|
import { promptInputFilePartToFile, uploadFiles } from "../uploads";
|
||||||
|
|
||||||
import type { AgentThread, AgentThreadState } from "./types";
|
import type { AgentThread, AgentThreadState, RunMessage } from "./types";
|
||||||
|
|
||||||
export type ToolEndEvent = {
|
export type ToolEndEvent = {
|
||||||
name: string;
|
name: string;
|
||||||
@@ -29,7 +28,8 @@ export type ThreadStreamOptions = {
|
|||||||
threadId?: string | null | undefined;
|
threadId?: string | null | undefined;
|
||||||
context: LocalSettings["context"];
|
context: LocalSettings["context"];
|
||||||
isMock?: boolean;
|
isMock?: boolean;
|
||||||
onStart?: (threadId: string) => void;
|
onSend?: (threadId: string) => void;
|
||||||
|
onStart?: (threadId: string, runId: string) => void;
|
||||||
onFinish?: (state: AgentThreadState) => void;
|
onFinish?: (state: AgentThreadState) => void;
|
||||||
onToolEnd?: (event: ToolEndEvent) => void;
|
onToolEnd?: (event: ToolEndEvent) => void;
|
||||||
};
|
};
|
||||||
@@ -38,79 +38,41 @@ type SendMessageOptions = {
|
|||||||
additionalKwargs?: Record<string, unknown>;
|
additionalKwargs?: Record<string, unknown>;
|
||||||
};
|
};
|
||||||
|
|
||||||
function normalizeStoredRunId(runId: string | null): string | null {
|
function mergeMessages(
|
||||||
if (!runId) {
|
historyMessages: Message[],
|
||||||
return null;
|
threadMessages: Message[],
|
||||||
}
|
optimisticMessages: Message[],
|
||||||
|
): Message[] {
|
||||||
|
const threadMessageIds = new Set(
|
||||||
|
threadMessages
|
||||||
|
.map((m) => ("tool_call_id" in m ? m.tool_call_id : m.id))
|
||||||
|
.filter(Boolean),
|
||||||
|
);
|
||||||
|
|
||||||
const trimmed = runId.trim();
|
// The overlap is a contiguous suffix of historyMessages (newest history == oldest thread).
|
||||||
if (!trimmed) {
|
// Scan from the end: shrink cutoff while messages are already in thread, stop as soon as
|
||||||
return null;
|
// we hit one that isn't — everything before that point is non-overlapping.
|
||||||
}
|
let cutoff = historyMessages.length;
|
||||||
|
for (let i = historyMessages.length - 1; i >= 0; i--) {
|
||||||
const queryIndex = trimmed.indexOf("?");
|
const msg = historyMessages[i];
|
||||||
if (queryIndex >= 0) {
|
if (!msg) {
|
||||||
const params = new URLSearchParams(trimmed.slice(queryIndex + 1));
|
continue;
|
||||||
const queryRunId = params.get("run_id")?.trim();
|
}
|
||||||
if (queryRunId) {
|
if (
|
||||||
return queryRunId;
|
(msg?.id && threadMessageIds.has(msg.id)) ||
|
||||||
|
("tool_call_id" in msg && threadMessageIds.has(msg.tool_call_id))
|
||||||
|
) {
|
||||||
|
cutoff = i;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const pathWithoutQueryOrHash = trimmed.split(/[?#]/, 1)[0]?.trim() ?? "";
|
return [
|
||||||
if (!pathWithoutQueryOrHash) {
|
...historyMessages.slice(0, cutoff),
|
||||||
return null;
|
...threadMessages,
|
||||||
}
|
...optimisticMessages,
|
||||||
|
];
|
||||||
const runsMarker = "/runs/";
|
|
||||||
const runsIndex = pathWithoutQueryOrHash.lastIndexOf(runsMarker);
|
|
||||||
if (runsIndex >= 0) {
|
|
||||||
const runIdAfterMarker = pathWithoutQueryOrHash
|
|
||||||
.slice(runsIndex + runsMarker.length)
|
|
||||||
.split("/", 1)[0]
|
|
||||||
?.trim();
|
|
||||||
if (runIdAfterMarker) {
|
|
||||||
return runIdAfterMarker;
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
const segments = pathWithoutQueryOrHash
|
|
||||||
.split("/")
|
|
||||||
.map((segment) => segment.trim())
|
|
||||||
.filter(Boolean);
|
|
||||||
return segments.at(-1) ?? null;
|
|
||||||
}
|
|
||||||
|
|
||||||
function getRunMetadataStorage(): {
|
|
||||||
getItem(key: `lg:stream:${string}`): string | null;
|
|
||||||
setItem(key: `lg:stream:${string}`, value: string): void;
|
|
||||||
removeItem(key: `lg:stream:${string}`): void;
|
|
||||||
} {
|
|
||||||
return {
|
|
||||||
getItem(key) {
|
|
||||||
const normalized = normalizeStoredRunId(
|
|
||||||
window.sessionStorage.getItem(key),
|
|
||||||
);
|
|
||||||
if (normalized) {
|
|
||||||
window.sessionStorage.setItem(key, normalized);
|
|
||||||
return normalized;
|
|
||||||
}
|
|
||||||
window.sessionStorage.removeItem(key);
|
|
||||||
return null;
|
|
||||||
},
|
|
||||||
setItem(key, value) {
|
|
||||||
const normalized = normalizeStoredRunId(value);
|
|
||||||
if (normalized) {
|
|
||||||
window.sessionStorage.setItem(key, normalized);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
window.sessionStorage.removeItem(key);
|
|
||||||
},
|
|
||||||
removeItem(key) {
|
|
||||||
window.sessionStorage.removeItem(key);
|
|
||||||
},
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function getStreamErrorMessage(error: unknown): string {
|
function getStreamErrorMessage(error: unknown): string {
|
||||||
@@ -140,6 +102,7 @@ export function useThreadStream({
|
|||||||
threadId,
|
threadId,
|
||||||
context,
|
context,
|
||||||
isMock,
|
isMock,
|
||||||
|
onSend,
|
||||||
onStart,
|
onStart,
|
||||||
onFinish,
|
onFinish,
|
||||||
onToolEnd,
|
onToolEnd,
|
||||||
@@ -151,17 +114,25 @@ export function useThreadStream({
|
|||||||
// and to allow access to the current thread id in onUpdateEvent
|
// and to allow access to the current thread id in onUpdateEvent
|
||||||
const threadIdRef = useRef<string | null>(threadId ?? null);
|
const threadIdRef = useRef<string | null>(threadId ?? null);
|
||||||
const startedRef = useRef(false);
|
const startedRef = useRef(false);
|
||||||
|
|
||||||
const listeners = useRef({
|
const listeners = useRef({
|
||||||
|
onSend,
|
||||||
onStart,
|
onStart,
|
||||||
onFinish,
|
onFinish,
|
||||||
onToolEnd,
|
onToolEnd,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const {
|
||||||
|
messages: history,
|
||||||
|
hasMore: hasMoreHistory,
|
||||||
|
loadMore: loadMoreHistory,
|
||||||
|
loading: isHistoryLoading,
|
||||||
|
appendMessages,
|
||||||
|
} = useThreadHistory(onStreamThreadId ?? "");
|
||||||
|
|
||||||
// Keep listeners ref updated with latest callbacks
|
// Keep listeners ref updated with latest callbacks
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
listeners.current = { onStart, onFinish, onToolEnd };
|
listeners.current = { onSend, onStart, onFinish, onToolEnd };
|
||||||
}, [onStart, onFinish, onToolEnd]);
|
}, [onSend, onStart, onFinish, onToolEnd]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const normalizedThreadId = threadId ?? null;
|
const normalizedThreadId = threadId ?? null;
|
||||||
@@ -175,45 +146,26 @@ export function useThreadStream({
|
|||||||
threadIdRef.current = normalizedThreadId;
|
threadIdRef.current = normalizedThreadId;
|
||||||
}, [threadId]);
|
}, [threadId]);
|
||||||
|
|
||||||
const _handleOnStart = useCallback((id: string) => {
|
const handleStreamStart = useCallback((_threadId: string, _runId: string) => {
|
||||||
|
threadIdRef.current = _threadId;
|
||||||
if (!startedRef.current) {
|
if (!startedRef.current) {
|
||||||
listeners.current.onStart?.(id);
|
listeners.current.onStart?.(_threadId, _runId);
|
||||||
startedRef.current = true;
|
startedRef.current = true;
|
||||||
}
|
}
|
||||||
|
setOnStreamThreadId(_threadId);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const handleStreamStart = useCallback(
|
|
||||||
(_threadId: string) => {
|
|
||||||
threadIdRef.current = _threadId;
|
|
||||||
_handleOnStart(_threadId);
|
|
||||||
},
|
|
||||||
[_handleOnStart],
|
|
||||||
);
|
|
||||||
|
|
||||||
const queryClient = useQueryClient();
|
const queryClient = useQueryClient();
|
||||||
const updateSubtask = useUpdateSubtask();
|
const updateSubtask = useUpdateSubtask();
|
||||||
const runMetadataStorageRef = useRef<
|
|
||||||
ReturnType<typeof getRunMetadataStorage> | undefined
|
|
||||||
>(undefined);
|
|
||||||
|
|
||||||
if (
|
|
||||||
typeof window !== "undefined" &&
|
|
||||||
runMetadataStorageRef.current === undefined
|
|
||||||
) {
|
|
||||||
runMetadataStorageRef.current = getRunMetadataStorage();
|
|
||||||
}
|
|
||||||
|
|
||||||
const thread = useStream<AgentThreadState>({
|
const thread = useStream<AgentThreadState>({
|
||||||
client: getAPIClient(isMock),
|
client: getAPIClient(isMock),
|
||||||
assistantId: "lead_agent",
|
assistantId: "lead_agent",
|
||||||
threadId: onStreamThreadId,
|
threadId: onStreamThreadId,
|
||||||
reconnectOnMount: runMetadataStorageRef.current
|
reconnectOnMount: true,
|
||||||
? () => runMetadataStorageRef.current!
|
|
||||||
: false,
|
|
||||||
fetchStateHistory: { limit: 1 },
|
fetchStateHistory: { limit: 1 },
|
||||||
onCreated(meta) {
|
onCreated(meta) {
|
||||||
handleStreamStart(meta.thread_id);
|
handleStreamStart(meta.thread_id, meta.run_id);
|
||||||
setOnStreamThreadId(meta.thread_id);
|
|
||||||
if (context.agent_name && !isMock) {
|
if (context.agent_name && !isMock) {
|
||||||
void getAPIClient()
|
void getAPIClient()
|
||||||
.threads.update(meta.thread_id, {
|
.threads.update(meta.thread_id, {
|
||||||
@@ -231,6 +183,34 @@ export function useThreadStream({
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
onUpdateEvent(data) {
|
onUpdateEvent(data) {
|
||||||
|
if (data["SummarizationMiddleware.before_model"]) {
|
||||||
|
const _messages = [
|
||||||
|
...(data["SummarizationMiddleware.before_model"].messages ?? []),
|
||||||
|
];
|
||||||
|
|
||||||
|
if (_messages.length < 2) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (const m of _messages) {
|
||||||
|
if (m.name === "summary" && m.type === "human") {
|
||||||
|
summarizedRef.current?.add(m.id ?? "");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const _lastKeepMessage = _messages[2];
|
||||||
|
const _currentMessages = [...messagesRef.current];
|
||||||
|
const _movedMessages: Message[] = [];
|
||||||
|
for (const m of _currentMessages) {
|
||||||
|
if (m.id !== undefined && m.id === _lastKeepMessage?.id) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (!summarizedRef.current?.has(m.id ?? "")) {
|
||||||
|
_movedMessages.push(m);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
appendMessages(_movedMessages);
|
||||||
|
messagesRef.current = [];
|
||||||
|
}
|
||||||
|
|
||||||
const updates: Array<Partial<AgentThreadState> | null> = Object.values(
|
const updates: Array<Partial<AgentThreadState> | null> = Object.values(
|
||||||
data || {},
|
data || {},
|
||||||
);
|
);
|
||||||
@@ -295,9 +275,6 @@ export function useThreadStream({
|
|||||||
onFinish(state) {
|
onFinish(state) {
|
||||||
listeners.current.onFinish?.(state.values);
|
listeners.current.onFinish?.(state.values);
|
||||||
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
||||||
void queryClient.invalidateQueries({
|
|
||||||
queryKey: ["thread-message-enrichment"],
|
|
||||||
});
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -305,24 +282,25 @@ export function useThreadStream({
|
|||||||
const [optimisticMessages, setOptimisticMessages] = useState<Message[]>([]);
|
const [optimisticMessages, setOptimisticMessages] = useState<Message[]>([]);
|
||||||
const [isUploading, setIsUploading] = useState(false);
|
const [isUploading, setIsUploading] = useState(false);
|
||||||
const sendInFlightRef = useRef(false);
|
const sendInFlightRef = useRef(false);
|
||||||
|
const messagesRef = useRef<Message[]>([]);
|
||||||
|
const summarizedRef = useRef<Set<string>>(null);
|
||||||
// Track message count before sending so we know when server has responded
|
// Track message count before sending so we know when server has responded
|
||||||
const prevMsgCountRef = useRef(thread.messages.length);
|
const prevMsgCountRef = useRef(thread.messages.length);
|
||||||
|
|
||||||
|
summarizedRef.current ??= new Set<string>();
|
||||||
|
|
||||||
// Reset thread-local pending UI state when switching between threads so
|
// Reset thread-local pending UI state when switching between threads so
|
||||||
// optimistic messages and in-flight guards do not leak across chat views.
|
// optimistic messages and in-flight guards do not leak across chat views.
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
startedRef.current = false;
|
startedRef.current = false;
|
||||||
sendInFlightRef.current = false;
|
sendInFlightRef.current = false;
|
||||||
prevMsgCountRef.current = 0;
|
|
||||||
setOptimisticMessages([]);
|
|
||||||
setIsUploading(false);
|
|
||||||
}, [threadId]);
|
}, [threadId]);
|
||||||
|
|
||||||
// Clear optimistic when server messages arrive (count increases)
|
// Clear optimistic when server messages arrive (count increases)
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (
|
if (
|
||||||
optimisticMessages.length > 0 &&
|
optimisticMessages.length > 0 &&
|
||||||
thread.messages.length > prevMsgCountRef.current + 1
|
thread.messages.length > prevMsgCountRef.current
|
||||||
) {
|
) {
|
||||||
setOptimisticMessages([]);
|
setOptimisticMessages([]);
|
||||||
}
|
}
|
||||||
@@ -381,12 +359,7 @@ export function useThreadStream({
|
|||||||
}
|
}
|
||||||
setOptimisticMessages(newOptimistic);
|
setOptimisticMessages(newOptimistic);
|
||||||
|
|
||||||
// Only fire onStart immediately for an existing persisted thread.
|
listeners.current.onSend?.(threadId);
|
||||||
// Brand-new chats should wait for onCreated(meta.thread_id) so URL sync
|
|
||||||
// uses the real server-generated thread id.
|
|
||||||
if (threadIdRef.current) {
|
|
||||||
_handleOnStart(threadId);
|
|
||||||
}
|
|
||||||
|
|
||||||
let uploadedFileInfo: UploadedFileInfo[] = [];
|
let uploadedFileInfo: UploadedFileInfo[] = [];
|
||||||
|
|
||||||
@@ -520,19 +493,106 @@ export function useThreadStream({
|
|||||||
sendInFlightRef.current = false;
|
sendInFlightRef.current = false;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[thread, _handleOnStart, t.uploads.uploadingFiles, context, queryClient],
|
[thread, t.uploads.uploadingFiles, context, queryClient],
|
||||||
);
|
);
|
||||||
|
|
||||||
// Merge thread with optimistic messages for display
|
// Cache the latest thread messages in a ref to compare against incoming history messages for deduplication,
|
||||||
const mergedThread =
|
// and to allow access to the full message list in onUpdateEvent without causing re-renders.
|
||||||
optimisticMessages.length > 0
|
if (thread.messages.length >= messagesRef.current.length) {
|
||||||
? ({
|
messagesRef.current = thread.messages;
|
||||||
...thread,
|
}
|
||||||
messages: [...thread.messages, ...optimisticMessages],
|
|
||||||
} as typeof thread)
|
|
||||||
: thread;
|
|
||||||
|
|
||||||
return [mergedThread, sendMessage, isUploading] as const;
|
const mergedMessages = mergeMessages(
|
||||||
|
history,
|
||||||
|
thread.messages,
|
||||||
|
optimisticMessages,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Merge history, live stream, and optimistic messages for display
|
||||||
|
// History messages may overlap with thread.messages; thread.messages take precedence
|
||||||
|
const mergedThread = {
|
||||||
|
...thread,
|
||||||
|
messages: mergedMessages,
|
||||||
|
} as typeof thread;
|
||||||
|
|
||||||
|
return {
|
||||||
|
thread: mergedThread,
|
||||||
|
sendMessage,
|
||||||
|
isUploading,
|
||||||
|
isHistoryLoading,
|
||||||
|
hasMoreHistory,
|
||||||
|
loadMoreHistory,
|
||||||
|
} as const;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useThreadHistory(threadId: string) {
|
||||||
|
const runs = useThreadRuns(threadId);
|
||||||
|
const threadIdRef = useRef(threadId);
|
||||||
|
const runsRef = useRef(runs.data ?? []);
|
||||||
|
const indexRef = useRef(-1);
|
||||||
|
const loadingRef = useRef(false);
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
const [messages, setMessages] = useState<Message[]>([]);
|
||||||
|
|
||||||
|
loadingRef.current = loading;
|
||||||
|
const loadMessages = useCallback(async () => {
|
||||||
|
if (runsRef.current.length === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const run = runsRef.current[indexRef.current];
|
||||||
|
if (!run || loadingRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
setLoading(true);
|
||||||
|
const result: { data: RunMessage[]; hasMore: boolean } = await fetch(
|
||||||
|
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadIdRef.current)}/runs/${encodeURIComponent(run.run_id)}/messages`,
|
||||||
|
{
|
||||||
|
method: "GET",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
credentials: "include",
|
||||||
|
},
|
||||||
|
).then((res) => {
|
||||||
|
return res.json();
|
||||||
|
});
|
||||||
|
const _messages = result.data
|
||||||
|
.filter((m) => !m.metadata.caller?.startsWith("middleware:"))
|
||||||
|
.map((m) => m.content);
|
||||||
|
setMessages((prev) => [..._messages, ...prev]);
|
||||||
|
indexRef.current -= 1;
|
||||||
|
} catch (err) {
|
||||||
|
console.error(err);
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
useEffect(() => {
|
||||||
|
threadIdRef.current = threadId;
|
||||||
|
if (runs.data && runs.data.length > 0) {
|
||||||
|
runsRef.current = runs.data ?? [];
|
||||||
|
indexRef.current = runs.data.length - 1;
|
||||||
|
}
|
||||||
|
loadMessages().catch(() => {
|
||||||
|
toast.error("Failed to load thread history.");
|
||||||
|
});
|
||||||
|
}, [threadId, runs.data, loadMessages]);
|
||||||
|
|
||||||
|
const appendMessages = useCallback((_messages: Message[]) => {
|
||||||
|
setMessages((prev) => {
|
||||||
|
return [...prev, ..._messages];
|
||||||
|
});
|
||||||
|
}, []);
|
||||||
|
const hasMore = indexRef.current >= 0 || !runs.data;
|
||||||
|
return {
|
||||||
|
runs: runs.data,
|
||||||
|
messages,
|
||||||
|
loading,
|
||||||
|
appendMessages,
|
||||||
|
hasMore,
|
||||||
|
loadMore: loadMessages,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
export function useThreads(
|
export function useThreads(
|
||||||
@@ -602,6 +662,33 @@ export function useThreads(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function useThreadRuns(threadId?: string) {
|
||||||
|
const apiClient = getAPIClient();
|
||||||
|
return useQuery<Run[]>({
|
||||||
|
queryKey: ["thread", threadId],
|
||||||
|
queryFn: async () => {
|
||||||
|
if (!threadId) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
const response = await apiClient.runs.list(threadId);
|
||||||
|
return response;
|
||||||
|
},
|
||||||
|
refetchOnWindowFocus: false,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useRunDetail(threadId: string, runId: string) {
|
||||||
|
const apiClient = getAPIClient();
|
||||||
|
return useQuery<Run>({
|
||||||
|
queryKey: ["thread", threadId, "run", runId],
|
||||||
|
queryFn: async () => {
|
||||||
|
const response = await apiClient.runs.get(threadId, runId);
|
||||||
|
return response;
|
||||||
|
},
|
||||||
|
refetchOnWindowFocus: false,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
export function useDeleteThread() {
|
export function useDeleteThread() {
|
||||||
const queryClient = useQueryClient();
|
const queryClient = useQueryClient();
|
||||||
const apiClient = getAPIClient();
|
const apiClient = getAPIClient();
|
||||||
@@ -609,7 +696,7 @@ export function useDeleteThread() {
|
|||||||
mutationFn: async ({ threadId }: { threadId: string }) => {
|
mutationFn: async ({ threadId }: { threadId: string }) => {
|
||||||
await apiClient.threads.delete(threadId);
|
await apiClient.threads.delete(threadId);
|
||||||
|
|
||||||
const response = await fetchWithAuth(
|
const response = await fetch(
|
||||||
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}`,
|
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}`,
|
||||||
{
|
{
|
||||||
method: "DELETE",
|
method: "DELETE",
|
||||||
@@ -682,65 +769,3 @@ export function useRenameThread() {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Per-message enrichment data attached by the backend ``/history`` helper. */
|
|
||||||
export interface MessageEnrichment {
|
|
||||||
run_id: string;
|
|
||||||
/** ``undefined`` = not feedback-eligible; ``null`` = eligible but unrated. */
|
|
||||||
feedback?: FeedbackData | null;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Fetch ``/history`` once and index feedback + run_id by message id.
|
|
||||||
*
|
|
||||||
* Replaces the old ``useThreadFeedback`` hook which keyed by AI-message
|
|
||||||
* ordinal position — an inherently fragile mapping that broke whenever
|
|
||||||
* ``ai_tool_call`` messages were interleaved with ``ai_message`` messages.
|
|
||||||
* Keying by ``message.id`` is stable regardless of run count, tool-call
|
|
||||||
* chains, or summarization.
|
|
||||||
*
|
|
||||||
* The ``/history`` response is refreshed on every stream completion via
|
|
||||||
* ``invalidateQueries(["thread-message-enrichment"])`` in ``onFinish``.
|
|
||||||
*/
|
|
||||||
export function useThreadMessageEnrichment(
|
|
||||||
threadId: string | null | undefined,
|
|
||||||
) {
|
|
||||||
return useQuery({
|
|
||||||
queryKey: ["thread-message-enrichment", threadId],
|
|
||||||
queryFn: async (): Promise<Map<string, MessageEnrichment>> => {
|
|
||||||
const empty = new Map<string, MessageEnrichment>();
|
|
||||||
if (!threadId) return empty;
|
|
||||||
const res = await fetchWithAuth(
|
|
||||||
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/history`,
|
|
||||||
{
|
|
||||||
method: "POST",
|
|
||||||
headers: { "Content-Type": "application/json" },
|
|
||||||
body: JSON.stringify({ limit: 1 }),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
if (!res.ok) return empty;
|
|
||||||
const entries = (await res.json()) as Array<{
|
|
||||||
values?: {
|
|
||||||
messages?: Array<{
|
|
||||||
id?: string;
|
|
||||||
run_id?: string;
|
|
||||||
feedback?: FeedbackData | null;
|
|
||||||
}>;
|
|
||||||
};
|
|
||||||
}>;
|
|
||||||
const messages = entries[0]?.values?.messages ?? [];
|
|
||||||
const map = new Map<string, MessageEnrichment>();
|
|
||||||
for (const m of messages) {
|
|
||||||
if (!m.id || !m.run_id) continue;
|
|
||||||
const entry: MessageEnrichment = { run_id: m.run_id };
|
|
||||||
// Preserve presence: "feedback" key absent → ineligible; present with
|
|
||||||
// null → eligible but unrated; present with object → rated.
|
|
||||||
if ("feedback" in m) entry.feedback = m.feedback;
|
|
||||||
map.set(m.id, entry);
|
|
||||||
}
|
|
||||||
return map;
|
|
||||||
},
|
|
||||||
enabled: !!threadId,
|
|
||||||
staleTime: 30_000,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -22,3 +22,12 @@ export interface AgentThreadContext extends Record<string, unknown> {
|
|||||||
export interface AgentThread extends Thread<AgentThreadState> {
|
export interface AgentThread extends Thread<AgentThreadState> {
|
||||||
context?: AgentThreadContext;
|
context?: AgentThreadContext;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface RunMessage {
|
||||||
|
run_id: string;
|
||||||
|
content: Message;
|
||||||
|
metadata: {
|
||||||
|
caller: string;
|
||||||
|
};
|
||||||
|
created_at: string;
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* API functions for file uploads
|
* API functions for file uploads
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { fetchWithAuth } from "../api/fetcher";
|
import { fetch } from "../api/fetcher";
|
||||||
import { getBackendBaseURL } from "../config";
|
import { getBackendBaseURL } from "../config";
|
||||||
|
|
||||||
export interface UploadedFileInfo {
|
export interface UploadedFileInfo {
|
||||||
@@ -51,7 +51,7 @@ export async function uploadFiles(
|
|||||||
formData.append("files", file);
|
formData.append("files", file);
|
||||||
});
|
});
|
||||||
|
|
||||||
const response = await fetchWithAuth(
|
const response = await fetch(
|
||||||
`${getBackendBaseURL()}/api/threads/${threadId}/uploads`,
|
`${getBackendBaseURL()}/api/threads/${threadId}/uploads`,
|
||||||
{
|
{
|
||||||
method: "POST",
|
method: "POST",
|
||||||
@@ -92,7 +92,7 @@ export async function deleteUploadedFile(
|
|||||||
threadId: string,
|
threadId: string,
|
||||||
filename: string,
|
filename: string,
|
||||||
): Promise<{ success: boolean; message: string }> {
|
): Promise<{ success: boolean; message: string }> {
|
||||||
const response = await fetchWithAuth(
|
const response = await fetch(
|
||||||
`${getBackendBaseURL()}/api/threads/${threadId}/uploads/${filename}`,
|
`${getBackendBaseURL()}/api/threads/${threadId}/uploads/${filename}`,
|
||||||
{
|
{
|
||||||
method: "DELETE",
|
method: "DELETE",
|
||||||
|
|||||||
Reference in New Issue
Block a user