Compare commits

..

1 Commits

Author SHA1 Message Date
greatmengqi 2eb45e9bb5 fix: thread app config through client and sync providers 2026-05-02 12:07:26 +08:00
53 changed files with 538 additions and 4089 deletions
-3
View File
@@ -1,6 +1,3 @@
# Serper API Key (Google Search) - https://serper.dev
SERPER_API_KEY=your-serper-api-key
# TAVILY API Key
TAVILY_API_KEY=your-tavily-api-key
-101
View File
@@ -1,101 +0,0 @@
name: Publish Containers
on:
push:
tags:
- "v*"
jobs:
backend-container:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
attestations: write
id-token: write
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}-backend
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Log in to the Container registry
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 #v3.4.0
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@902fa8ec7d6ecbf8d84d538b9b233a880e428804 #v5.7.0
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=ref,event=tag
type=ref,event=branch
type=sha
type=raw,value=latest,enable={{is_default_branch}}
- name: Build and push Docker image
id: push
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 #v6.18.0
with:
context: .
file: backend/Dockerfile
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
- name: Generate artifact attestation
uses: actions/attest-build-provenance@v2
with:
subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}}
subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true
frontend-container:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
attestations: write
id-token: write
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}-frontend
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Log in to the Container registry
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 #v3.4.0
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@902fa8ec7d6ecbf8d84d538b9b233a880e428804 #v5.7.0
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=ref,event=tag
type=ref,event=branch
type=sha
type=raw,value=latest,enable={{is_default_branch}}
- name: Build and push Docker image
id: push
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 #v6.18.0
with:
context: .
file: frontend/Dockerfile
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
- name: Generate artifact attestation
uses: actions/attest-build-provenance@v2
with:
subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}}
subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true
-10
View File
@@ -50,12 +50,6 @@ COPY backend ./backend
RUN --mount=type=cache,target=/root/.cache/uv \
sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync ${UV_EXTRAS:+--extra $UV_EXTRAS}"
# UTF-8 locale prevents UnicodeEncodeError on Chinese/emoji content in minimal
# containers where locale configuration may be missing and the default encoding is not UTF-8.
ENV LANG=C.UTF-8
ENV LC_ALL=C.UTF-8
ENV PYTHONIOENCODING=utf-8
# ── Stage 2: Dev ──────────────────────────────────────────────────────────────
# Retains compiler toolchain from builder so startup-time `uv sync` can build
# source distributions in development containers.
@@ -72,10 +66,6 @@ CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app
# Clean image without build-essential — reduces size (~200 MB) and attack surface.
FROM python:3.12-slim-bookworm
ENV LANG=C.UTF-8
ENV LC_ALL=C.UTF-8
ENV PYTHONIOENCODING=utf-8
# Copy Node.js runtime from builder (provides npx for MCP servers)
COPY --from=builder /usr/bin/node /usr/bin/node
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
+2 -11
View File
@@ -420,13 +420,7 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
if not msg.files:
return []
from deerflow.uploads.manager import (
UnsafeUploadPathError,
claim_unique_filename,
ensure_uploads_dir,
normalize_filename,
write_upload_file_no_symlink,
)
from deerflow.uploads.manager import claim_unique_filename, ensure_uploads_dir, normalize_filename
uploads_dir = ensure_uploads_dir(thread_id)
seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()}
@@ -477,10 +471,7 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
dest = uploads_dir / safe_name
try:
dest = write_upload_file_no_symlink(uploads_dir, safe_name, data)
except UnsafeUploadPathError:
logger.warning("[Manager] skipping inbound file with unsafe destination: %s", safe_name)
continue
dest.write_bytes(data)
except Exception:
logger.exception("[Manager] failed to write inbound file: %s", dest)
continue
+21 -23
View File
@@ -13,11 +13,11 @@ matching the LangGraph Platform wire format expected by the
from __future__ import annotations
import logging
import time
import uuid
from typing import Any
from fastapi import APIRouter, HTTPException, Request
from langgraph.checkpoint.base import empty_checkpoint
from pydantic import BaseModel, Field, field_validator
from app.gateway.authz import require_permission
@@ -26,7 +26,6 @@ from app.gateway.utils import sanitize_log_param
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime import serialize_channel_values
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.utils.time import coerce_iso, now_iso
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["threads"])
@@ -234,7 +233,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
checkpointer = get_checkpointer(request)
thread_store = get_thread_store(request)
thread_id = body.thread_id or str(uuid.uuid4())
now = now_iso()
now = time.time()
# ``body.metadata`` is already stripped of server-reserved keys by
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
@@ -244,8 +243,8 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
return ThreadResponse(
thread_id=thread_id,
status=existing_record.get("status", "idle"),
created_at=coerce_iso(existing_record.get("created_at", "")),
updated_at=coerce_iso(existing_record.get("updated_at", "")),
created_at=str(existing_record.get("created_at", "")),
updated_at=str(existing_record.get("updated_at", "")),
metadata=existing_record.get("metadata", {}),
)
@@ -263,6 +262,8 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
# Write an empty checkpoint so state endpoints work immediately
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try:
from langgraph.checkpoint.base import empty_checkpoint
ckpt_metadata = {
"step": -1,
"source": "input",
@@ -280,8 +281,8 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
return ThreadResponse(
thread_id=thread_id,
status="idle",
created_at=now,
updated_at=now,
created_at=str(now),
updated_at=str(now),
metadata=body.metadata,
)
@@ -306,11 +307,8 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
ThreadResponse(
thread_id=r["thread_id"],
status=r.get("status", "idle"),
# ``coerce_iso`` heals legacy unix-second values that
# ``MemoryThreadMetaStore`` historically wrote with ``time.time()``;
# SQL-backed rows already arrive as ISO strings and pass through.
created_at=coerce_iso(r.get("created_at", "")),
updated_at=coerce_iso(r.get("updated_at", "")),
created_at=r.get("created_at", ""),
updated_at=r.get("updated_at", ""),
metadata=r.get("metadata", {}),
values={"title": r["display_name"]} if r.get("display_name") else {},
interrupts={},
@@ -342,8 +340,8 @@ async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Reques
return ThreadResponse(
thread_id=thread_id,
status=record.get("status", "idle"),
created_at=coerce_iso(record.get("created_at", "")),
updated_at=coerce_iso(record.get("updated_at", "")),
created_at=str(record.get("created_at", "")),
updated_at=str(record.get("updated_at", "")),
metadata=record.get("metadata", {}),
)
@@ -383,8 +381,8 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
record = {
"thread_id": thread_id,
"status": "idle",
"created_at": coerce_iso(ckpt_meta.get("created_at", "")),
"updated_at": coerce_iso(ckpt_meta.get("updated_at", ckpt_meta.get("created_at", ""))),
"created_at": ckpt_meta.get("created_at", ""),
"updated_at": ckpt_meta.get("updated_at", ckpt_meta.get("created_at", "")),
"metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
}
@@ -398,8 +396,8 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
return ThreadResponse(
thread_id=thread_id,
status=status,
created_at=coerce_iso(record.get("created_at", "")),
updated_at=coerce_iso(record.get("updated_at", "")),
created_at=str(record.get("created_at", "")),
updated_at=str(record.get("updated_at", "")),
metadata=record.get("metadata", {}),
values=serialize_channel_values(channel_values),
)
@@ -450,10 +448,10 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
values=values,
next=next_tasks,
metadata=metadata,
checkpoint={"id": checkpoint_id, "ts": coerce_iso(metadata.get("created_at", ""))},
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
checkpoint_id=checkpoint_id,
parent_checkpoint_id=parent_checkpoint_id,
created_at=coerce_iso(metadata.get("created_at", "")),
created_at=str(metadata.get("created_at", "")),
tasks=tasks,
)
@@ -503,7 +501,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
channel_values.update(body.values)
checkpoint["channel_values"] = channel_values
metadata["updated_at"] = now_iso()
metadata["updated_at"] = time.time()
if body.as_node:
metadata["source"] = "update"
@@ -544,7 +542,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
next=[],
metadata=metadata,
checkpoint_id=new_checkpoint_id,
created_at=coerce_iso(metadata.get("created_at", "")),
created_at=str(metadata.get("created_at", "")),
)
@@ -611,7 +609,7 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
parent_checkpoint_id=parent_id,
metadata=user_meta,
values=values,
created_at=coerce_iso(metadata.get("created_at", "")),
created_at=str(metadata.get("created_at", "")),
next=next_tasks,
)
)
+13 -35
View File
@@ -5,7 +5,7 @@ import os
import stat
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from pydantic import BaseModel, Field
from pydantic import BaseModel
from app.gateway.authz import require_permission
from app.gateway.deps import get_config
@@ -15,14 +15,12 @@ from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
from deerflow.uploads.manager import (
PathTraversalError,
UnsafeUploadPathError,
delete_file_safe,
enrich_file_listing,
ensure_uploads_dir,
get_uploads_dir,
list_files_in_dir,
normalize_filename,
open_upload_file_no_symlink,
upload_artifact_url,
upload_virtual_path,
)
@@ -44,7 +42,6 @@ class UploadResponse(BaseModel):
success: bool
files: list[dict[str, str]]
message: str
skipped_files: list[str] = Field(default_factory=list)
class UploadLimits(BaseModel):
@@ -119,18 +116,17 @@ def _cleanup_uploaded_paths(paths: list[os.PathLike[str] | str]) -> None:
logger.warning("Failed to clean up upload path after rejected request: %s", path, exc_info=True)
async def _write_upload_file_with_limits(
async def _write_upload_file_streaming(
file: UploadFile,
file_path: os.PathLike[str] | str,
*,
uploads_dir: os.PathLike[str] | str,
display_filename: str,
max_single_file_size: int,
max_total_size: int,
total_size: int,
) -> tuple[os.PathLike[str] | str, int, int]:
) -> tuple[int, int]:
file_size = 0
file_path, fh = open_upload_file_no_symlink(uploads_dir, display_filename)
try:
with open(file_path, "wb") as output:
while chunk := await file.read(UPLOAD_CHUNK_SIZE):
file_size += len(chunk)
total_size += len(chunk)
@@ -138,17 +134,8 @@ async def _write_upload_file_with_limits(
raise HTTPException(status_code=413, detail=f"File too large: {display_filename}")
if total_size > max_total_size:
raise HTTPException(status_code=413, detail="Total upload size too large")
fh.write(chunk)
except Exception:
fh.close()
try:
os.unlink(file_path)
except FileNotFoundError:
pass
raise
else:
fh.close()
return file_path, file_size, total_size
output.write(chunk)
return file_size, total_size
def _auto_convert_documents_enabled(app_config: AppConfig) -> bool:
@@ -190,7 +177,6 @@ async def upload_files(
uploaded_files = []
written_paths = []
sandbox_sync_targets = []
skipped_files = []
total_size = 0
sandbox_provider = get_sandbox_provider()
@@ -214,15 +200,16 @@ async def upload_files(
continue
try:
file_path, file_size, total_size = await _write_upload_file_with_limits(
file_path = uploads_dir / safe_filename
written_paths.append(file_path)
file_size, total_size = await _write_upload_file_streaming(
file,
uploads_dir=uploads_dir,
file_path,
display_filename=safe_filename,
max_single_file_size=limits.max_file_size,
max_total_size=limits.max_total_size,
total_size=total_size,
)
written_paths.append(file_path)
virtual_path = upload_virtual_path(safe_filename)
@@ -259,10 +246,6 @@ async def upload_files(
except HTTPException as e:
_cleanup_uploaded_paths(written_paths)
raise e
except UnsafeUploadPathError as e:
logger.warning("Skipping upload with unsafe destination %s: %s", file.filename, e)
skipped_files.append(safe_filename)
continue
except Exception as e:
logger.error(f"Failed to upload {file.filename}: {e}")
_cleanup_uploaded_paths(written_paths)
@@ -273,15 +256,10 @@ async def upload_files(
_make_file_sandbox_writable(file_path)
sandbox.update_file(virtual_path, file_path.read_bytes())
message = f"Successfully uploaded {len(uploaded_files)} file(s)"
if skipped_files:
message += f"; skipped {len(skipped_files)} unsafe file(s)"
return UploadResponse(
success=not skipped_files,
success=True,
files=uploaded_files,
message=message,
skipped_files=skipped_files,
message=f"Successfully uploaded {len(uploaded_files)} file(s)",
)
@@ -1,270 +1,31 @@
"""Middleware for logging token usage and annotating step attribution."""
from __future__ import annotations
"""Middleware for logging LLM token usage."""
import logging
from collections import defaultdict
from typing import Any, override
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.todo import Todo
from langchain_core.messages import AIMessage
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
TOKEN_USAGE_ATTRIBUTION_KEY = "token_usage_attribution"
def _string_arg(value: Any) -> str | None:
if isinstance(value, str):
normalized = value.strip()
return normalized or None
return None
def _normalize_todos(value: Any) -> list[Todo]:
if not isinstance(value, list):
return []
normalized: list[Todo] = []
for item in value:
if not isinstance(item, dict):
continue
todo: Todo = {}
content = _string_arg(item.get("content"))
status = item.get("status")
if content is not None:
todo["content"] = content
if status in {"pending", "in_progress", "completed"}:
todo["status"] = status
normalized.append(todo)
return normalized
def _todo_action_kind(previous: Todo | None, current: Todo) -> str:
status = current.get("status")
previous_content = previous.get("content") if previous else None
current_content = current.get("content")
if previous is None:
if status == "completed":
return "todo_complete"
if status == "in_progress":
return "todo_start"
return "todo_update"
if previous_content != current_content:
return "todo_update"
if status == "completed":
return "todo_complete"
if status == "in_progress":
return "todo_start"
return "todo_update"
def _build_todo_actions(previous_todos: list[Todo], next_todos: list[Todo]) -> list[dict[str, Any]]:
# This is the single source of truth for precise write_todos token
# attribution. The frontend intentionally falls back to a generic
# "Update to-do list" label when this metadata is missing or malformed.
previous_by_content: dict[str, list[tuple[int, Todo]]] = defaultdict(list)
matched_previous_indices: set[int] = set()
for index, todo in enumerate(previous_todos):
content = todo.get("content")
if isinstance(content, str) and content:
previous_by_content[content].append((index, todo))
actions: list[dict[str, Any]] = []
for index, todo in enumerate(next_todos):
content = todo.get("content")
if not isinstance(content, str) or not content:
continue
previous_match: Todo | None = None
content_matches = previous_by_content.get(content)
if content_matches:
while content_matches and content_matches[0][0] in matched_previous_indices:
content_matches.pop(0)
if content_matches:
previous_index, previous_match = content_matches.pop(0)
matched_previous_indices.add(previous_index)
if previous_match is None and index < len(previous_todos) and index not in matched_previous_indices:
previous_match = previous_todos[index]
matched_previous_indices.add(index)
if previous_match is not None:
previous_content = previous_match.get("content")
previous_status = previous_match.get("status")
if previous_content == content and previous_status == todo.get("status"):
continue
actions.append(
{
"kind": _todo_action_kind(previous_match, todo),
"content": content,
}
)
for index, todo in enumerate(previous_todos):
if index in matched_previous_indices:
continue
content = todo.get("content")
if not isinstance(content, str) or not content:
continue
actions.append(
{
"kind": "todo_remove",
"content": content,
}
)
return actions
def _describe_tool_call(tool_call: dict[str, Any], todos: list[Todo]) -> list[dict[str, Any]]:
name = _string_arg(tool_call.get("name")) or "unknown"
args = tool_call.get("args") if isinstance(tool_call.get("args"), dict) else {}
tool_call_id = _string_arg(tool_call.get("id"))
if name == "write_todos":
next_todos = _normalize_todos(args.get("todos"))
actions = _build_todo_actions(todos, next_todos)
if not actions:
return [
{
"kind": "tool",
"tool_name": name,
"tool_call_id": tool_call_id,
}
]
return [
{
**action,
"tool_call_id": tool_call_id,
}
for action in actions
]
if name == "task":
return [
{
"kind": "subagent",
"description": _string_arg(args.get("description")),
"subagent_type": _string_arg(args.get("subagent_type")),
"tool_call_id": tool_call_id,
}
]
if name in {"web_search", "image_search"}:
query = _string_arg(args.get("query"))
return [
{
"kind": "search",
"tool_name": name,
"query": query,
"tool_call_id": tool_call_id,
}
]
if name == "present_files":
return [
{
"kind": "present_files",
"tool_call_id": tool_call_id,
}
]
if name == "ask_clarification":
return [
{
"kind": "clarification",
"tool_call_id": tool_call_id,
}
]
return [
{
"kind": "tool",
"tool_name": name,
"description": _string_arg(args.get("description")),
"tool_call_id": tool_call_id,
}
]
def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str:
if actions:
first_kind = actions[0].get("kind")
if len(actions) == 1 and first_kind in {"todo_start", "todo_complete", "todo_update", "todo_remove"}:
return "todo_update"
if len(actions) == 1 and first_kind == "subagent":
return "subagent_dispatch"
return "tool_batch"
if message.content:
return "final_answer"
return "thinking"
def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
tool_calls = getattr(message, "tool_calls", None) or []
actions: list[dict[str, Any]] = []
current_todos = list(todos)
for raw_tool_call in tool_calls:
if not isinstance(raw_tool_call, dict):
continue
described_actions = _describe_tool_call(raw_tool_call, current_todos)
actions.extend(described_actions)
if raw_tool_call.get("name") == "write_todos":
args = raw_tool_call.get("args") if isinstance(raw_tool_call.get("args"), dict) else {}
current_todos = _normalize_todos(args.get("todos"))
tool_call_ids: list[str] = []
for tool_call in tool_calls:
if not isinstance(tool_call, dict):
continue
tool_call_id = _string_arg(tool_call.get("id"))
if tool_call_id is not None:
tool_call_ids.append(tool_call_id)
return {
# Schema changes should remain additive where possible so older
# frontends can ignore unknown fields and fall back safely.
"version": 1,
"kind": _infer_step_kind(message, actions),
"shared_attribution": len(actions) > 1,
"tool_call_ids": tool_call_ids,
"actions": actions,
}
class TokenUsageMiddleware(AgentMiddleware):
"""Logs token usage from model responses and annotates the AI step."""
"""Logs token usage from model response usage_metadata."""
def _apply(self, state: AgentState) -> dict | None:
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._log_usage(state)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._log_usage(state)
def _log_usage(self, state: AgentState) -> None:
messages = state.get("messages", [])
if not messages:
return None
last = messages[-1]
if not isinstance(last, AIMessage):
return None
usage = getattr(last, "usage_metadata", None)
if usage:
logger.info(
@@ -273,22 +34,4 @@ class TokenUsageMiddleware(AgentMiddleware):
usage.get("output_tokens", "?"),
usage.get("total_tokens", "?"),
)
todos = state.get("todos") or []
attribution = _build_attribution(last, todos if isinstance(todos, list) else [])
additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
return None
additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
return {"messages": [updated_msg]}
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state)
return None
+37 -106
View File
@@ -228,14 +228,21 @@ class DeerFlowClient:
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
kwargs: dict[str, Any] = {
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=self._app_config),
"tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled),
"middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares),
"middleware": _build_middlewares(
config,
model_name=model_name,
agent_name=self._agent_name,
custom_middlewares=self._middlewares,
app_config=self._app_config,
),
"system_prompt": apply_prompt_template(
subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents,
agent_name=self._agent_name,
available_skills=self._available_skills,
app_config=self._app_config,
),
"state_schema": ThreadState,
}
@@ -243,7 +250,7 @@ class DeerFlowClient:
if checkpointer is None:
from deerflow.runtime.checkpointer import get_checkpointer
checkpointer = get_checkpointer()
checkpointer = get_checkpointer(app_config=self._app_config)
if checkpointer is not None:
kwargs["checkpointer"] = checkpointer
@@ -251,12 +258,15 @@ class DeerFlowClient:
self._agent_config_key = key
logger.info("Agent created: agent_name=%s, model=%s, thinking=%s", self._agent_name, model_name, thinking_enabled)
@staticmethod
def _get_tools(*, model_name: str | None, subagent_enabled: bool):
def _get_tools(self, *, model_name: str | None, subagent_enabled: bool):
"""Lazy import to avoid circular dependency at module level."""
from deerflow.tools import get_available_tools
return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled)
return get_available_tools(
model_name=model_name,
subagent_enabled=subagent_enabled,
app_config=self._app_config,
)
@staticmethod
def _serialize_tool_calls(tool_calls) -> list[dict]:
@@ -264,35 +274,25 @@ class DeerFlowClient:
return [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in tool_calls]
@staticmethod
def _serialize_additional_kwargs(msg) -> dict[str, Any] | None:
"""Copy message additional_kwargs when present."""
additional_kwargs = getattr(msg, "additional_kwargs", None)
if isinstance(additional_kwargs, dict) and additional_kwargs:
return dict(additional_kwargs)
return None
@staticmethod
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None, additional_kwargs: dict[str, Any] | None = None) -> "StreamEvent":
"""Build a ``messages-tuple`` AI text event."""
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None) -> "StreamEvent":
"""Build a ``messages-tuple`` AI text event, attaching usage when present."""
data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id}
if usage:
data["usage_metadata"] = usage
if additional_kwargs:
data["additional_kwargs"] = additional_kwargs
return StreamEvent(type="messages-tuple", data=data)
@staticmethod
def _ai_tool_calls_event(msg_id: str | None, tool_calls, additional_kwargs: dict[str, Any] | None = None) -> "StreamEvent":
def _ai_tool_calls_event(msg_id: str | None, tool_calls) -> "StreamEvent":
"""Build a ``messages-tuple`` AI tool-calls event."""
data: dict[str, Any] = {
"type": "ai",
"content": "",
"id": msg_id,
"tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls),
}
if additional_kwargs:
data["additional_kwargs"] = additional_kwargs
return StreamEvent(type="messages-tuple", data=data)
return StreamEvent(
type="messages-tuple",
data={
"type": "ai",
"content": "",
"id": msg_id,
"tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls),
},
)
@staticmethod
def _tool_message_event(msg: ToolMessage) -> "StreamEvent":
@@ -317,30 +317,19 @@ class DeerFlowClient:
d["tool_calls"] = DeerFlowClient._serialize_tool_calls(msg.tool_calls)
if getattr(msg, "usage_metadata", None):
d["usage_metadata"] = msg.usage_metadata
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
d["additional_kwargs"] = additional_kwargs
return d
if isinstance(msg, ToolMessage):
d = {
return {
"type": "tool",
"content": DeerFlowClient._extract_text(msg.content),
"name": getattr(msg, "name", None),
"tool_call_id": getattr(msg, "tool_call_id", None),
"id": getattr(msg, "id", None),
}
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
d["additional_kwargs"] = additional_kwargs
return d
if isinstance(msg, HumanMessage):
d = {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)}
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
d["additional_kwargs"] = additional_kwargs
return d
return {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)}
if isinstance(msg, SystemMessage):
d = {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)}
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
d["additional_kwargs"] = additional_kwargs
return d
return {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)}
return {"type": "unknown", "content": str(msg), "id": getattr(msg, "id", None)}
@staticmethod
@@ -398,7 +387,7 @@ class DeerFlowClient:
if checkpointer is None:
from deerflow.runtime.checkpointer.provider import get_checkpointer
checkpointer = get_checkpointer()
checkpointer = get_checkpointer(app_config=self._app_config)
thread_info_map = {}
@@ -453,7 +442,7 @@ class DeerFlowClient:
if checkpointer is None:
from deerflow.runtime.checkpointer.provider import get_checkpointer
checkpointer = get_checkpointer()
checkpointer = get_checkpointer(app_config=self._app_config)
config = {"configurable": {"thread_id": thread_id}}
checkpoints = []
@@ -563,7 +552,6 @@ class DeerFlowClient:
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str}
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str, "usage_metadata": {...}}
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]}
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "additional_kwargs": {...}}
- type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str}
- type="end" data={"usage": {"input_tokens": int, "output_tokens": int, "total_tokens": int}}
"""
@@ -586,7 +574,6 @@ class DeerFlowClient:
# in both the final ``messages`` chunk and the values snapshot —
# count it only on whichever arrives first.
counted_usage_ids: set[str] = set()
sent_additional_kwargs_by_id: dict[str, dict[str, Any]] = {}
cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
def _account_usage(msg_id: str | None, usage: Any) -> dict | None:
@@ -616,20 +603,6 @@ class DeerFlowClient:
"total_tokens": total_tokens,
}
def _unsent_additional_kwargs(msg_id: str | None, additional_kwargs: dict[str, Any] | None) -> dict[str, Any] | None:
if not additional_kwargs:
return None
if not msg_id:
return additional_kwargs
sent = sent_additional_kwargs_by_id.setdefault(msg_id, {})
delta = {key: value for key, value in additional_kwargs.items() if sent.get(key) != value}
if not delta:
return None
sent.update(delta)
return delta
for item in self._agent.stream(
state,
config=config,
@@ -657,31 +630,17 @@ class DeerFlowClient:
if isinstance(msg_chunk, AIMessage):
text = self._extract_text(msg_chunk.content)
additional_kwargs = self._serialize_additional_kwargs(msg_chunk)
counted_usage = _account_usage(msg_id, msg_chunk.usage_metadata)
sent_additional_kwargs = False
if text:
if msg_id:
streamed_ids.add(msg_id)
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
yield self._ai_text_event(
msg_id,
text,
counted_usage,
additional_kwargs_delta,
)
sent_additional_kwargs = bool(additional_kwargs_delta)
yield self._ai_text_event(msg_id, text, counted_usage)
if msg_chunk.tool_calls:
if msg_id:
streamed_ids.add(msg_id)
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
yield self._ai_tool_calls_event(
msg_id,
msg_chunk.tool_calls,
additional_kwargs_delta,
)
yield self._ai_tool_calls_event(msg_id, msg_chunk.tool_calls)
elif isinstance(msg_chunk, ToolMessage):
if msg_id:
@@ -704,45 +663,17 @@ class DeerFlowClient:
if msg_id and msg_id in streamed_ids:
if isinstance(msg, AIMessage):
_account_usage(msg_id, getattr(msg, "usage_metadata", None))
additional_kwargs = self._serialize_additional_kwargs(msg)
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
if additional_kwargs_delta:
# Metadata-only follow-up: ``messages-tuple`` has no
# dedicated attribution event, so clients should
# merge this empty-content AI event by message id
# and ignore it for text rendering.
yield self._ai_text_event(msg_id, "", None, additional_kwargs_delta)
continue
if isinstance(msg, AIMessage):
counted_usage = _account_usage(msg_id, msg.usage_metadata)
additional_kwargs = self._serialize_additional_kwargs(msg)
sent_additional_kwargs = False
if msg.tool_calls:
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
yield self._ai_tool_calls_event(
msg_id,
msg.tool_calls,
additional_kwargs_delta,
)
sent_additional_kwargs = bool(additional_kwargs_delta)
yield self._ai_tool_calls_event(msg_id, msg.tool_calls)
text = self._extract_text(msg.content)
if text:
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
yield self._ai_text_event(
msg_id,
text,
counted_usage,
additional_kwargs_delta,
)
elif msg_id:
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
if not additional_kwargs_delta:
continue
# See the metadata-only follow-up convention above.
yield self._ai_text_event(msg_id, "", None, additional_kwargs_delta)
yield self._ai_text_event(msg_id, text, counted_usage)
elif isinstance(msg, ToolMessage):
yield self._tool_message_event(msg)
@@ -1,3 +0,0 @@
from .tools import web_search_tool
__all__ = ["web_search_tool"]
@@ -1,95 +0,0 @@
"""
Web Search Tool - Search the web using Serper (Google Search API).
Serper provides real-time Google Search results via a JSON API.
An API key is required. Sign up at https://serper.dev to get one.
"""
import json
import logging
import os
import httpx
from langchain.tools import tool
from deerflow.config import get_app_config
logger = logging.getLogger(__name__)
_SERPER_ENDPOINT = "https://google.serper.dev/search"
_api_key_warned = False
def _get_api_key() -> str | None:
config = get_app_config().get_tool_config("web_search")
if config is not None:
api_key = config.model_extra.get("api_key")
if isinstance(api_key, str) and api_key.strip():
return api_key
return os.getenv("SERPER_API_KEY")
@tool("web_search", parse_docstring=True)
def web_search_tool(query: str, max_results: int = 5) -> str:
"""Search the web for information using Google Search via Serper.
Args:
query: Search keywords describing what you want to find. Be specific for better results.
max_results: Maximum number of search results to return. Default is 5.
"""
global _api_key_warned
config = get_app_config().get_tool_config("web_search")
if config is not None and "max_results" in config.model_extra:
max_results = config.model_extra.get("max_results", max_results)
api_key = _get_api_key()
if not api_key:
if not _api_key_warned:
_api_key_warned = True
logger.warning("Serper API key is not set. Set SERPER_API_KEY in your environment or provide api_key in config.yaml. Sign up at https://serper.dev")
return json.dumps(
{"error": "SERPER_API_KEY is not configured", "query": query},
ensure_ascii=False,
)
headers = {
"X-API-KEY": api_key,
"Content-Type": "application/json",
}
payload = {"q": query, "num": max_results}
try:
with httpx.Client(timeout=30) as client:
response = client.post(_SERPER_ENDPOINT, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
except httpx.HTTPStatusError as e:
logger.error(f"Serper API returned HTTP {e.response.status_code}: {e.response.text}")
return json.dumps(
{"error": f"Serper API error: HTTP {e.response.status_code}", "query": query},
ensure_ascii=False,
)
except Exception as e:
logger.error(f"Serper search failed: {type(e).__name__}: {e}")
return json.dumps({"error": str(e), "query": query}, ensure_ascii=False)
organic = data.get("organic", [])
if not organic:
return json.dumps({"error": "No results found", "query": query}, ensure_ascii=False)
normalized_results = [
{
"title": r.get("title", ""),
"url": r.get("link", ""),
"content": r.get("snippet", ""),
}
for r in organic[:max_results]
]
output = {
"query": query,
"total_results": len(normalized_results),
"results": normalized_results,
}
return json.dumps(output, indent=2, ensure_ascii=False)
@@ -6,13 +6,6 @@ from pydantic import BaseModel, Field
from deerflow.config.runtime_paths import project_root, resolve_path
def _legacy_skills_candidates() -> tuple[Path, ...]:
"""Return source-tree skills locations for monorepo compatibility."""
backend_dir = Path(__file__).resolve().parents[4]
repo_root = backend_dir.parent
return (repo_root / "skills",)
class SkillsConfig(BaseModel):
"""Configuration for skills system"""
@@ -22,7 +15,7 @@ class SkillsConfig(BaseModel):
)
path: str | None = Field(
default=None,
description=("Path to skills directory. If not specified, defaults to `skills` under the caller project root, falling back to the legacy repo-root location for monorepo compatibility."),
description="Path to skills directory. If not specified, defaults to skills under the caller project root.",
)
container_path: str = Field(
default="/mnt/skills",
@@ -33,30 +26,15 @@ class SkillsConfig(BaseModel):
"""
Get the resolved skills directory path.
Resolution order:
1. Explicit ``path`` field
2. ``DEER_FLOW_SKILLS_PATH`` environment variable
3. ``skills`` under the caller project root (``project_root()``)
4. Legacy repo-root candidates for monorepo compatibility (``_legacy_skills_candidates``)
When none of (3) or (4) exist on disk, the project-root default is returned so callers
can still surface a stable "no skills" location without raising.
Returns:
Path to the skills directory
"""
if self.path:
# Use configured path (can be absolute or relative to project root)
return resolve_path(self.path)
if env_path := os.getenv("DEER_FLOW_SKILLS_PATH"):
return resolve_path(env_path)
project_default = project_root() / "skills"
if project_default.is_dir():
return project_default
for candidate in _legacy_skills_candidates():
if candidate.is_dir():
return candidate
return project_default
return project_root() / "skills"
def get_skill_container_path(self, skill_name: str, category: str = "public") -> str:
"""
@@ -27,34 +27,6 @@ from deerflow.models.credential_loader import CodexCliCredential, load_codex_cli
logger = logging.getLogger(__name__)
CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
def _build_usage_metadata(oai_usage: dict) -> dict:
"""Convert Codex/Responses API usage dict to LangChain usage_metadata format.
Maps OpenAI Responses API token usage fields to the dict structure that
LangChain AIMessage.usage_metadata expects. This avoids depending on
langchain_openai private helpers like ``_create_usage_metadata_responses``.
"""
input_tokens = oai_usage.get("input_tokens", 0)
output_tokens = oai_usage.get("output_tokens", 0)
total_tokens = oai_usage.get("total_tokens", input_tokens + output_tokens)
metadata: dict = {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": total_tokens,
}
input_details = oai_usage.get("input_tokens_details") or {}
output_details = oai_usage.get("output_tokens_details") or {}
cache_read = input_details.get("cached_tokens")
if cache_read is not None:
metadata["input_token_details"] = {"cache_read": cache_read}
reasoning = output_details.get("reasoning_tokens")
if reasoning is not None:
metadata["output_token_details"] = {"reasoning": reasoning}
return metadata
MAX_RETRIES = 3
@@ -374,7 +346,6 @@ class CodexChatModel(BaseChatModel):
)
usage = response.get("usage", {})
usage_metadata = _build_usage_metadata(usage) if usage else None
additional_kwargs = {}
if reasoning_content:
additional_kwargs["reasoning_content"] = reasoning_content
@@ -384,7 +355,6 @@ class CodexChatModel(BaseChatModel):
tool_calls=tool_calls if tool_calls else [],
invalid_tool_calls=invalid_tool_calls,
additional_kwargs=additional_kwargs,
usage_metadata=usage_metadata,
response_metadata={
"model": response.get("model", self.model),
"usage": usage,
@@ -7,13 +7,13 @@ router for thread records.
from __future__ import annotations
import time
from typing import Any
from langgraph.store.base import BaseStore
from deerflow.persistence.thread_meta.base import ThreadMetaStore
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
from deerflow.utils.time import coerce_iso, now_iso
THREADS_NS: tuple[str, ...] = ("threads",)
@@ -48,7 +48,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
metadata: dict | None = None,
) -> dict:
resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.create")
now = now_iso()
now = time.time()
record: dict[str, Any] = {
"thread_id": thread_id,
"assistant_id": assistant_id,
@@ -106,7 +106,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
if record is None:
return
record["display_name"] = display_name
record["updated_at"] = now_iso()
record["updated_at"] = time.time()
await self._store.aput(THREADS_NS, thread_id, record)
async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
@@ -114,7 +114,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
if record is None:
return
record["status"] = status
record["updated_at"] = now_iso()
record["updated_at"] = time.time()
await self._store.aput(THREADS_NS, thread_id, record)
async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
@@ -124,7 +124,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
merged = dict(record.get("metadata") or {})
merged.update(metadata)
record["metadata"] = merged
record["updated_at"] = now_iso()
record["updated_at"] = time.time()
await self._store.aput(THREADS_NS, thread_id, record)
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
@@ -144,8 +144,6 @@ class MemoryThreadMetaStore(ThreadMetaStore):
"display_name": val.get("display_name"),
"status": val.get("status", "idle"),
"metadata": val.get("metadata", {}),
# ``coerce_iso`` heals legacy unix-second values written by
# earlier Gateway versions that called ``str(time.time())``.
"created_at": coerce_iso(val.get("created_at", "")),
"updated_at": coerce_iso(val.get("updated_at", "")),
"created_at": str(val.get("created_at", "")),
"updated_at": str(val.get("updated_at", "")),
}
@@ -25,7 +25,7 @@ from collections.abc import Iterator
from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig, get_app_config
from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
@@ -98,9 +98,78 @@ def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
_checkpointer: Checkpointer | None = None
_checkpointer_ctx = None # open context manager keeping the connection alive
_explicit_checkpointers: dict[int, Checkpointer] = {}
_explicit_checkpointer_contexts: dict[int, object] = {}
def get_checkpointer() -> Checkpointer:
def _default_in_memory_checkpointer() -> Checkpointer:
from langgraph.checkpoint.memory import InMemorySaver
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
return InMemorySaver()
def _persistent_database_backend(db_config) -> str | None:
backend = getattr(db_config, "backend", None)
if backend in {"sqlite", "postgres"}:
return backend
return None
@contextlib.contextmanager
def _sync_checkpointer_from_database_cm(db_config) -> Iterator[Checkpointer]:
"""Context manager that creates a sync checkpointer from unified DatabaseConfig."""
backend = _persistent_database_backend(db_config)
if backend is None:
yield _default_in_memory_checkpointer()
return
if backend == "sqlite":
try:
from langgraph.checkpoint.sqlite import SqliteSaver
except ImportError as exc:
raise ImportError(SQLITE_INSTALL) from exc
conn_str = db_config.checkpointer_sqlite_path
ensure_sqlite_parent_dir(conn_str)
with SqliteSaver.from_conn_string(conn_str) as saver:
saver.setup()
logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)
yield saver
return
if backend == "postgres":
try:
from langgraph.checkpoint.postgres import PostgresSaver
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
if not db_config.postgres_url:
raise ValueError("database.postgres_url is required for the postgres backend")
with PostgresSaver.from_conn_string(db_config.postgres_url) as saver:
saver.setup()
logger.info("Checkpointer: using PostgresSaver")
yield saver
return
raise ValueError(f"Unknown database backend: {backend!r}")
def _build_checkpointer_from_app_config(app_config: AppConfig) -> tuple[Checkpointer, object | None]:
if app_config.checkpointer is not None:
ctx = _sync_checkpointer_cm(app_config.checkpointer)
return ctx.__enter__(), ctx
db_config = getattr(app_config, "database", None)
if _persistent_database_backend(db_config) is not None:
ctx = _sync_checkpointer_from_database_cm(db_config)
return ctx.__enter__(), ctx
return _default_in_memory_checkpointer(), None
def get_checkpointer(app_config: AppConfig | None = None) -> Checkpointer:
"""Return the global sync checkpointer singleton, creating it on first call.
Returns an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
@@ -111,6 +180,18 @@ def get_checkpointer() -> Checkpointer:
"""
global _checkpointer, _checkpointer_ctx
if app_config is not None:
cache_key = id(app_config)
cached = _explicit_checkpointers.get(cache_key)
if cached is not None:
return cached
explicit_checkpointer, explicit_ctx = _build_checkpointer_from_app_config(app_config)
_explicit_checkpointers[cache_key] = explicit_checkpointer
if explicit_ctx is not None:
_explicit_checkpointer_contexts[cache_key] = explicit_ctx
return explicit_checkpointer
if _checkpointer is not None:
return _checkpointer
@@ -121,28 +202,30 @@ def get_checkpointer() -> Checkpointer:
from deerflow.config.checkpointer_config import get_checkpointer_config
config = get_checkpointer_config()
global_app_config = _app_config
if config is None and _app_config is None:
if config is None and global_app_config is None:
# Only load app config lazily when neither the app config nor an explicit
# checkpointer config has been initialized yet. This keeps tests that
# intentionally set the global checkpointer config isolated from any
# ambient config.yaml on disk.
try:
get_app_config()
global_app_config = get_app_config()
except FileNotFoundError:
# In test environments without config.yaml, this is expected.
pass
config = get_checkpointer_config()
if config is None:
from langgraph.checkpoint.memory import InMemorySaver
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
_checkpointer = InMemorySaver()
if config is not None:
_checkpointer_ctx = _sync_checkpointer_cm(config)
_checkpointer = _checkpointer_ctx.__enter__()
return _checkpointer
_checkpointer_ctx = _sync_checkpointer_cm(config)
_checkpointer = _checkpointer_ctx.__enter__()
if global_app_config is not None:
_checkpointer, _checkpointer_ctx = _build_checkpointer_from_app_config(global_app_config)
return _checkpointer
_checkpointer = _default_in_memory_checkpointer()
return _checkpointer
@@ -161,6 +244,18 @@ def reset_checkpointer() -> None:
_checkpointer_ctx = None
_checkpointer = None
for cache_key, ctx in list(_explicit_checkpointer_contexts.items()):
try:
ctx.__exit__(None, None, None)
except Exception:
logger.warning("Error during explicit checkpointer cleanup", exc_info=True)
finally:
_explicit_checkpointer_contexts.pop(cache_key, None)
_explicit_checkpointers.pop(cache_key, None)
_explicit_checkpointers.clear()
_explicit_checkpointer_contexts.clear()
# ---------------------------------------------------------------------------
# Sync context manager
@@ -168,7 +263,7 @@ def reset_checkpointer() -> None:
@contextlib.contextmanager
def checkpointer_context() -> Iterator[Checkpointer]:
def checkpointer_context(app_config: AppConfig | None = None) -> Iterator[Checkpointer]:
"""Sync context manager that yields a checkpointer and cleans up on exit.
Unlike :func:`get_checkpointer`, this does **not** cache the instance —
@@ -181,12 +276,16 @@ def checkpointer_context() -> Iterator[Checkpointer]:
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
if config.checkpointer is None:
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
resolved_app_config = app_config or get_app_config()
if resolved_app_config.checkpointer is not None:
with _sync_checkpointer_cm(resolved_app_config.checkpointer) as saver:
yield saver
return
with _sync_checkpointer_cm(config.checkpointer) as saver:
yield saver
db_config = getattr(resolved_app_config, "database", None)
if _persistent_database_backend(db_config) is not None:
with _sync_checkpointer_from_database_cm(db_config) as saver:
yield saver
return
yield _default_in_memory_checkpointer()
@@ -6,10 +6,9 @@ import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING
from deerflow.utils.time import now_iso as _now_iso
from .schemas import DisconnectMode, RunStatus
if TYPE_CHECKING:
@@ -18,6 +17,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _now_iso() -> str:
return datetime.now(UTC).isoformat()
@dataclass
class RunRecord:
"""Mutable record for a single run."""
@@ -23,8 +23,6 @@ from dataclasses import dataclass, field
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal, cast
from langgraph.checkpoint.base import empty_checkpoint
if TYPE_CHECKING:
from langchain_core.messages import HumanMessage
@@ -444,12 +442,6 @@ async def _rollback_to_pre_run_checkpoint(
if checkpoint_to_restore.get("id") is None:
logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id)
return
restore_marker = _new_checkpoint_marker()
checkpoint_to_restore = {
**checkpoint_to_restore,
"id": restore_marker["id"],
"ts": restore_marker["ts"],
}
metadata = pre_run_snapshot.get("metadata", {})
metadata_to_restore = metadata if isinstance(metadata, dict) else {}
raw_checkpoint_ns = pre_run_snapshot.get("checkpoint_ns")
@@ -501,11 +493,6 @@ async def _rollback_to_pre_run_checkpoint(
)
def _new_checkpoint_marker() -> dict[str, str]:
marker = empty_checkpoint()
return {"id": marker["id"], "ts": marker["ts"]}
def _lg_mode_to_sse_event(mode: str) -> str:
"""Map LangGraph internal stream_mode name to SSE event name.
@@ -26,7 +26,7 @@ from collections.abc import Iterator
from langgraph.store.base import BaseStore
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig, get_app_config
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@@ -98,9 +98,26 @@ def _sync_store_cm(config) -> Iterator[BaseStore]:
_store: BaseStore | None = None
_store_ctx = None # open context manager keeping the connection alive
_explicit_stores: dict[int, BaseStore] = {}
_explicit_store_contexts: dict[int, object] = {}
def get_store() -> BaseStore:
def _default_in_memory_store() -> BaseStore:
from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
return InMemoryStore()
def _build_store_from_app_config(app_config: AppConfig) -> tuple[BaseStore, object | None]:
if app_config.checkpointer is not None:
ctx = _sync_store_cm(app_config.checkpointer)
return ctx.__enter__(), ctx
return _default_in_memory_store(), None
def get_store(app_config: AppConfig | None = None) -> BaseStore:
"""Return the global sync Store singleton, creating it on first call.
Returns an :class:`~langgraph.store.memory.InMemoryStore` when no
@@ -112,6 +129,18 @@ def get_store() -> BaseStore:
"""
global _store, _store_ctx
if app_config is not None:
cache_key = id(app_config)
cached = _explicit_stores.get(cache_key)
if cached is not None:
return cached
explicit_store, explicit_ctx = _build_store_from_app_config(app_config)
_explicit_stores[cache_key] = explicit_store
if explicit_ctx is not None:
_explicit_store_contexts[cache_key] = explicit_ctx
return explicit_store
if _store is not None:
return _store
@@ -130,10 +159,7 @@ def get_store() -> BaseStore:
config = get_checkpointer_config()
if config is None:
from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
_store = InMemoryStore()
_store = _default_in_memory_store()
return _store
_store_ctx = _sync_store_cm(config)
@@ -156,6 +182,18 @@ def reset_store() -> None:
_store_ctx = None
_store = None
for cache_key, ctx in list(_explicit_store_contexts.items()):
try:
ctx.__exit__(None, None, None)
except Exception:
logger.warning("Error during explicit store cleanup", exc_info=True)
finally:
_explicit_store_contexts.pop(cache_key, None)
_explicit_stores.pop(cache_key, None)
_explicit_stores.clear()
_explicit_store_contexts.clear()
# ---------------------------------------------------------------------------
# Sync context manager
@@ -163,7 +201,7 @@ def reset_store() -> None:
@contextlib.contextmanager
def store_context() -> Iterator[BaseStore]:
def store_context(app_config: AppConfig | None = None) -> Iterator[BaseStore]:
"""Sync context manager that yields a Store and cleans up on exit.
Unlike :func:`get_store`, this does **not** cache the instance — each
@@ -176,13 +214,10 @@ def store_context() -> Iterator[BaseStore]:
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
if config.checkpointer is None:
from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
yield InMemoryStore()
resolved_app_config = app_config or get_app_config()
if resolved_app_config.checkpointer is None:
yield _default_in_memory_store()
return
with _sync_store_cm(config.checkpointer) as store:
with _sync_store_cm(resolved_app_config.checkpointer) as store:
yield store
@@ -4,10 +4,8 @@ Pure business logic — no FastAPI/HTTP dependencies.
Both Gateway and Client delegate to these functions.
"""
import errno
import os
import re
import stat
from pathlib import Path
from urllib.parse import quote
@@ -19,10 +17,6 @@ class PathTraversalError(ValueError):
"""Raised when a path escapes its allowed base directory."""
class UnsafeUploadPathError(ValueError):
"""Raised when an upload destination is not a safe regular file path."""
# thread_id must be alphanumeric, hyphens, underscores, or dots only.
_SAFE_THREAD_ID = re.compile(r"^[a-zA-Z0-9._-]+$")
@@ -115,64 +109,6 @@ def validate_path_traversal(path: Path, base: Path) -> None:
raise PathTraversalError("Path traversal detected") from None
def open_upload_file_no_symlink(base_dir: Path, filename: str) -> tuple[Path, object]:
"""Open an upload destination for safe streaming writes.
Upload directories may be mounted into local sandboxes. A sandbox process can
therefore leave a symlink at a future upload filename. Normal ``Path.write_bytes``
follows that link and can overwrite files outside the uploads directory with
gateway privileges. This helper rejects symlink destinations and uses
``O_NOFOLLOW`` where available so the final path component cannot be raced into
a symlink between validation and open.
"""
safe_name = normalize_filename(filename)
dest = base_dir / safe_name
try:
st = os.lstat(dest)
except FileNotFoundError:
st = None
if st is not None and not stat.S_ISREG(st.st_mode):
raise UnsafeUploadPathError(f"Upload destination is not a regular file: {safe_name}")
validate_path_traversal(dest, base_dir)
if not hasattr(os, "O_NOFOLLOW"):
raise UnsafeUploadPathError("Upload writes require O_NOFOLLOW support")
flags = os.O_WRONLY | os.O_CREAT | os.O_NOFOLLOW
if hasattr(os, "O_NONBLOCK"):
flags |= os.O_NONBLOCK
try:
fd = os.open(dest, flags, 0o600)
except OSError as exc:
if exc.errno in {errno.ELOOP, errno.EISDIR, errno.ENOTDIR, errno.ENXIO, errno.EAGAIN}:
raise UnsafeUploadPathError(f"Unsafe upload destination: {safe_name}") from exc
raise
try:
opened_stat = os.fstat(fd)
if not stat.S_ISREG(opened_stat.st_mode) or opened_stat.st_nlink != 1:
raise UnsafeUploadPathError(f"Upload destination is not an exclusive regular file: {safe_name}")
os.ftruncate(fd, 0)
fh = os.fdopen(fd, "wb")
fd = -1
finally:
if fd >= 0:
os.close(fd)
return dest, fh
def write_upload_file_no_symlink(base_dir: Path, filename: str, data: bytes) -> Path:
"""Write upload bytes without following a pre-existing destination symlink."""
dest, fh = open_upload_file_no_symlink(base_dir, filename)
with fh:
fh.write(data)
return dest
def list_files_in_dir(directory: Path) -> dict:
"""List files (not directories) in *directory*.
@@ -1,75 +0,0 @@
"""ISO 8601 timestamp helpers for the Gateway and embedded runtime.
DeerFlow stores and serializes thread/run timestamps as ISO 8601 UTC
strings to match the LangGraph Platform schema (see
``langgraph_sdk.schema.Thread``, where ``created_at`` / ``updated_at``
are ``datetime`` and JSON-encode to ISO 8601). All timestamp generation
should funnel through :func:`now_iso` so the wire format stays
consistent across endpoints, the embedded ``RunManager``, and the
checkpoint metadata written by the Gateway.
:func:`coerce_iso` provides a forward-compatible read path for legacy
records that historically stored ``str(time.time())`` floats.
"""
from __future__ import annotations
import re
from datetime import UTC, datetime
__all__ = ["coerce_iso", "now_iso"]
_UNIX_TIMESTAMP_PATTERN = re.compile(r"^\d{10}(?:\.\d+)?$")
"""Matches the unix-timestamp string shape historically written by
``str(time.time())`` (10-digit seconds with optional fractional part).
The 10-digit anchor avoids accidentally rewriting ISO years like
``"2026"`` and stays valid until the year 2286.
"""
def now_iso() -> str:
"""Return the current UTC time as an ISO 8601 string.
Example: ``"2026-04-27T03:19:46.511479+00:00"``.
"""
return datetime.now(UTC).isoformat()
def coerce_iso(value: object) -> str:
"""Best-effort coerce a stored timestamp to an ISO 8601 string.
Translates legacy unix-timestamp floats / strings written by older
DeerFlow versions into ISO without a one-shot migration. ISO strings
pass through unchanged; ``datetime`` instances are normalised to UTC
(tz-naive values are assumed to be UTC) and emitted via
``isoformat()`` so the wire format always uses the ``T`` separator;
empty values become ``""``; unrecognised values are stringified as a
last resort.
"""
if value is None or value == "":
return ""
if isinstance(value, bool):
# ``bool`` is a subclass of ``int`` — treat as garbage, not 0/1.
return str(value)
if isinstance(value, datetime):
# ``datetime`` must be handled before the ``int``/``float`` check;
# str(datetime) would produce ``"YYYY-MM-DD HH:MM:SS+00:00"``
# (space separator), which breaks strict ISO 8601 consumers.
if value.tzinfo is None:
value = value.replace(tzinfo=UTC)
else:
value = value.astimezone(UTC)
return value.isoformat()
if isinstance(value, (int, float)):
try:
return datetime.fromtimestamp(float(value), UTC).isoformat()
except (ValueError, OverflowError, OSError):
return str(value)
if isinstance(value, str):
if _UNIX_TIMESTAMP_PATTERN.match(value):
try:
return datetime.fromtimestamp(float(value), UTC).isoformat()
except (ValueError, OverflowError, OSError):
return value
return value
return str(value)
+1
View File
@@ -47,3 +47,4 @@ members = ["packages/harness"]
[tool.uv.sources]
deerflow-harness = { workspace = true }
+1 -105
View File
@@ -3,12 +3,11 @@
from __future__ import annotations
import asyncio
import os
from pathlib import Path
from unittest.mock import MagicMock, patch
from app.channels.base import Channel
from app.channels.message_bus import InboundMessage, MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.message_bus import MessageBus, OutboundMessage, ResolvedAttachment
def _run(coro):
@@ -249,109 +248,6 @@ class TestResolveAttachments:
assert result[0].filename == "data.csv"
# ---------------------------------------------------------------------------
# Inbound file ingestion tests
# ---------------------------------------------------------------------------
class TestInboundFileIngestion:
def test_rejects_preexisting_symlink_destination(self, tmp_path):
from app.channels import manager
uploads_dir = tmp_path / "uploads"
uploads_dir.mkdir()
outside_file = tmp_path / "outside-created.txt"
(uploads_dir / "victim.txt").symlink_to(outside_file)
msg = InboundMessage(
channel_name="test-channel",
chat_id="chat-1",
user_id="user-1",
text="see attachment",
files=[{"filename": "victim.txt", "url": "https://example.invalid/victim.txt"}],
)
async def fake_reader(file_info, client):
return b"attacker data"
with (
patch("deerflow.uploads.manager.ensure_uploads_dir", return_value=uploads_dir),
patch.dict(manager.INBOUND_FILE_READERS, {"test-channel": fake_reader}, clear=False),
):
result = _run(manager._ingest_inbound_files("thread-1", msg))
assert result == []
assert not outside_file.exists()
assert (uploads_dir / "victim.txt").is_symlink()
def test_rejects_dangling_symlink_destination(self, tmp_path):
from app.channels import manager
uploads_dir = tmp_path / "uploads"
uploads_dir.mkdir()
missing_target = tmp_path / "missing-created.txt"
(uploads_dir / "victim.txt").symlink_to(missing_target)
msg = InboundMessage(
channel_name="test-channel",
chat_id="chat-1",
user_id="user-1",
text="see attachment",
files=[{"filename": "victim.txt", "url": "https://example.invalid/victim.txt"}],
)
async def fake_reader(file_info, client):
return b"attacker data"
with (
patch("deerflow.uploads.manager.ensure_uploads_dir", return_value=uploads_dir),
patch.dict(manager.INBOUND_FILE_READERS, {"test-channel": fake_reader}, clear=False),
):
result = _run(manager._ingest_inbound_files("thread-1", msg))
assert result == []
assert not missing_target.exists()
assert (uploads_dir / "victim.txt").is_symlink()
def test_hardlinked_existing_file_is_not_overwritten(self, tmp_path):
from app.channels import manager
uploads_dir = tmp_path / "uploads"
uploads_dir.mkdir()
outside_file = tmp_path / "outside-created.txt"
outside_file.write_text("protected", encoding="utf-8")
os.link(outside_file, uploads_dir / "victim.txt")
msg = InboundMessage(
channel_name="test-channel",
chat_id="chat-1",
user_id="user-1",
text="see attachment",
files=[{"filename": "victim.txt", "url": "https://example.invalid/victim.txt"}],
)
async def fake_reader(file_info, client):
return b"new attachment data"
with (
patch("deerflow.uploads.manager.ensure_uploads_dir", return_value=uploads_dir),
patch.dict(manager.INBOUND_FILE_READERS, {"test-channel": fake_reader}, clear=False),
):
result = _run(manager._ingest_inbound_files("thread-1", msg))
assert result == [
{
"filename": "victim_1.txt",
"size": len(b"new attachment data"),
"path": "/mnt/user-data/uploads/victim_1.txt",
"is_image": False,
}
]
assert outside_file.read_text(encoding="utf-8") == "protected"
assert (uploads_dir / "victim.txt").read_text(encoding="utf-8") == "protected"
assert (uploads_dir / "victim_1.txt").read_bytes() == b"new attachment data"
# ---------------------------------------------------------------------------
# Channel base class _on_outbound with attachments
# ---------------------------------------------------------------------------
+48
View File
@@ -1,6 +1,7 @@
"""Unit tests for checkpointer config and singleton factory."""
import sys
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -103,6 +104,53 @@ class TestGetCheckpointer:
cp2 = get_checkpointer()
assert cp1 is not cp2
def test_explicit_app_config_bypasses_global_config_lookup(self):
from langgraph.checkpoint.memory import InMemorySaver
explicit_config = SimpleNamespace(
checkpointer=CheckpointerConfig(type="memory"),
database=SimpleNamespace(backend="memory"),
)
with patch(
"deerflow.runtime.checkpointer.provider.get_app_config",
side_effect=AssertionError("ambient get_app_config() must not be used when app_config is explicit"),
):
cp = get_checkpointer(app_config=explicit_config)
assert isinstance(cp, InMemorySaver)
def test_explicit_app_config_uses_unified_database_sqlite_backend(self):
explicit_config = SimpleNamespace(
checkpointer=None,
database=SimpleNamespace(backend="sqlite", checkpointer_sqlite_path="/tmp/explicit/deerflow.db"),
)
mock_saver_instance = MagicMock()
mock_cm = MagicMock()
mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance)
mock_cm.__exit__ = MagicMock(return_value=False)
mock_saver_cls = MagicMock()
mock_saver_cls.from_conn_string = MagicMock(return_value=mock_cm)
mock_module = MagicMock()
mock_module.SqliteSaver = mock_saver_cls
with (
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}),
patch(
"deerflow.runtime.checkpointer.provider.get_app_config",
side_effect=AssertionError("ambient get_app_config() must not be used when app_config is explicit"),
),
patch("deerflow.runtime.checkpointer.provider.ensure_sqlite_parent_dir") as mock_ensure,
):
cp = get_checkpointer(app_config=explicit_config)
assert cp is mock_saver_instance
mock_ensure.assert_called_once_with("/tmp/explicit/deerflow.db")
mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/explicit/deerflow.db")
def test_sqlite_raises_when_package_missing(self):
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": None}):
+22 -79
View File
@@ -437,85 +437,6 @@ class TestStream:
call_kwargs = agent.stream.call_args.kwargs
assert "messages" in call_kwargs["stream_mode"]
def test_stream_emits_additional_kwargs_updates_for_streamed_ai_messages(self, client):
    """stream() surfaces attribution metadata that only appears in the ``values`` snapshot."""
    attribution = {
        "version": 1,
        "kind": "final_answer",
        "shared_attribution": False,
        "actions": [],
    }
    assembled = AIMessage(
        content="Hello!",
        id="ai-1",
        additional_kwargs={"token_usage_attribution": attribution},
    )
    # First the streamed chunk (no metadata), then the full state snapshot carrying it.
    stream_items = [
        ("messages", (AIMessageChunk(content="Hello!", id="ai-1"), {})),
        ("values", {"messages": [HumanMessage(content="hi", id="h-1"), assembled]}),
    ]
    agent = MagicMock()
    agent.stream.return_value = iter(stream_items)

    with patch.object(client, "_ensure_agent"), patch.object(client, "_agent", agent):
        events = list(client.stream("hi", thread_id="t-stream-kwargs"))

    def _is_target_ai_event(event):
        data = event.data
        return event.type == "messages-tuple" and data.get("type") == "ai" and data.get("id") == "ai-1"

    ai_events = [event for event in events if _is_target_ai_event(event)]

    saw_content = any(event.data.get("content") == "Hello!" for event in ai_events)
    assert saw_content
    saw_attribution = any(
        event.data.get("additional_kwargs", {}).get("token_usage_attribution", {}).get("kind") == "final_answer"
        for event in ai_events
    )
    assert saw_attribution
def test_stream_emits_new_additional_kwargs_after_prior_metadata(self, client):
    """A later snapshot emits only the kwargs not already streamed for the same message id."""
    attribution = {
        "version": 1,
        "kind": "final_answer",
        "shared_attribution": False,
        "actions": [],
    }
    streamed_chunk = AIMessageChunk(
        content="Hello!",
        id="ai-1",
        additional_kwargs={"reasoning_content": "Thinking first."},
    )
    assembled = AIMessage(
        content="Hello!",
        id="ai-1",
        additional_kwargs={
            "reasoning_content": "Thinking first.",
            "token_usage_attribution": attribution,
        },
    )
    agent = MagicMock()
    agent.stream.return_value = iter(
        [
            ("messages", (streamed_chunk, {})),
            ("values", {"messages": [HumanMessage(content="hi", id="h-1"), assembled]}),
        ]
    )

    with patch.object(client, "_ensure_agent"), patch.object(client, "_agent", agent):
        events = list(client.stream("hi", thread_id="t-stream-kwargs-delta"))

    metadata_events = []
    for event in events:
        if event.type != "messages-tuple":
            continue
        if event.data.get("type") != "ai" or event.data.get("id") != "ai-1":
            continue
        if event.data.get("additional_kwargs"):
            metadata_events.append(event)

    # First metadata event: the kwargs that arrived with the streamed chunk.
    assert metadata_events[0].data["additional_kwargs"] == {"reasoning_content": "Thinking first."}
    # Second metadata event: empty content, and only the newly seen attribution.
    assert metadata_events[1].data["content"] == ""
    assert metadata_events[1].data["additional_kwargs"] == {"token_usage_attribution": attribution}
def test_chat_accumulates_streamed_deltas(self, client):
"""chat() concatenates per-id deltas from messages mode."""
agent = MagicMock()
@@ -927,6 +848,28 @@ class TestEnsureAgent:
assert mock_apply_prompt.call_args.kwargs.get("agent_name") == "custom-agent"
assert mock_apply_prompt.call_args.kwargs.get("available_skills") == {"test_skill"}
def test_threads_explicit_app_config_to_dependencies(self, client):
    """_ensure_agent() must pass the client's own AppConfig to every factory it calls."""
    config = client._get_runnable_config("t1")
    mock_agent = MagicMock()
    mock_checkpointer = MagicMock()

    with (
        patch("deerflow.client.create_chat_model", return_value=MagicMock()) as model_factory,
        patch("deerflow.client.create_agent", return_value=mock_agent),
        patch("deerflow.client._build_middlewares", return_value=[]) as middleware_builder,
        patch("deerflow.client.apply_prompt_template", return_value="prompt") as prompt_builder,
        patch("deerflow.tools.get_available_tools", return_value=[]) as tools_lookup,
        patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=mock_checkpointer) as checkpointer_lookup,
    ):
        client._ensure_agent(config)

    # Each patched dependency must have received the very same config object.
    for dependency in (model_factory, middleware_builder, prompt_builder, tools_lookup, checkpointer_lookup):
        assert dependency.call_args.kwargs["app_config"] is client._app_config
def test_uses_default_checkpointer_when_available(self, client):
mock_agent = MagicMock()
mock_checkpointer = MagicMock()
@@ -1,53 +0,0 @@
"""Tests for DeerFlowClient message serialization helpers."""
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.client import DeerFlowClient
def test_serialize_ai_message_preserves_additional_kwargs():
    """Serializing an AI message keeps ``usage_metadata`` and ``additional_kwargs`` intact."""

    def _attribution_kwargs():
        # Fresh dict per call so the expected value never aliases the message's own kwargs.
        return {
            "token_usage_attribution": {
                "version": 1,
                "kind": "final_answer",
                "shared_attribution": False,
                "actions": [],
            }
        }

    message = AIMessage(
        content="done",
        additional_kwargs=_attribution_kwargs(),
        usage_metadata={"input_tokens": 12, "output_tokens": 3, "total_tokens": 15},
    )

    serialized = DeerFlowClient._serialize_message(message)

    assert serialized["type"] == "ai"
    expected_usage = {"input_tokens": 12, "output_tokens": 3, "total_tokens": 15}
    assert serialized["usage_metadata"] == expected_usage
    assert serialized["additional_kwargs"] == _attribution_kwargs()
def test_serialize_human_message_preserves_additional_kwargs():
    """A human message serializes to exactly type, content, id and its additional_kwargs."""
    message = HumanMessage(content="hello", additional_kwargs={"files": [{"name": "diagram.png"}]})

    expected = {
        "type": "human",
        "content": "hello",
        "id": None,
        "additional_kwargs": {"files": [{"name": "diagram.png"}]},
    }
    assert DeerFlowClient._serialize_message(message) == expected
-30
View File
@@ -82,36 +82,6 @@ def test_parse_response_text_content():
assert result.generations[0].message.content == "Hello world"
def test_parse_response_populates_usage_metadata():
    """_parse_response() maps the response's ``usage`` section onto ``usage_metadata``."""
    usage = {
        "input_tokens": 10,
        "output_tokens": 5,
        "total_tokens": 15,
        "input_tokens_details": {"cached_tokens": 3},
        "output_tokens_details": {"reasoning_tokens": 2},
    }
    output = [
        {
            "type": "message",
            "content": [{"type": "output_text", "text": "Hello world"}],
        }
    ]
    model = _make_model()

    result = model._parse_response({"output": output, "usage": usage, "model": "gpt-5.4"})

    meta = result.generations[0].message.usage_metadata
    assert meta is not None
    for key, expected in (("input_tokens", 10), ("output_tokens", 5), ("total_tokens", 15)):
        assert meta[key] == expected
    # Detail counters are exposed under renamed keys.
    assert meta["input_token_details"]["cache_read"] == 3
    assert meta["output_token_details"]["reasoning"] == 2
def test_parse_response_reasoning_content():
model = _make_model()
response = {
+16 -61
View File
@@ -3,8 +3,6 @@ from types import SimpleNamespace
from unittest.mock import AsyncMock, call
import pytest
from langgraph.checkpoint.base import empty_checkpoint
from langgraph.checkpoint.memory import InMemorySaver
from deerflow.runtime.runs.manager import RunManager
from deerflow.runtime.runs.schemas import RunStatus
@@ -18,14 +16,6 @@ class FakeCheckpointer:
self.aput_writes = AsyncMock()
def _make_checkpoint(checkpoint_id: str, messages: list[str], version: int):
    """Return an empty checkpoint stamped with the given id and a single ``messages`` channel."""
    checkpoint = empty_checkpoint()
    overrides = {
        "id": checkpoint_id,
        "channel_values": {"messages": messages},
        "channel_versions": {"messages": version},
    }
    for key, value in overrides.items():
        checkpoint[key] = value
    return checkpoint
def test_build_runtime_context_includes_app_config_when_present():
app_config = object()
@@ -120,16 +110,16 @@ async def test_rollback_restores_snapshot_without_deleting_thread():
)
checkpointer.adelete_thread.assert_not_awaited()
checkpointer.aput.assert_awaited_once()
restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
assert restored_checkpoint["id"] != "ckpt-1"
assert "channel_versions" in restored_checkpoint
assert "channel_values" in restored_checkpoint
assert restored_checkpoint["channel_versions"] == {"messages": 3}
assert restored_checkpoint["channel_values"] == {"messages": ["before"]}
assert restored_metadata == {"source": "input"}
assert new_versions == {"messages": 3}
checkpointer.aput.assert_awaited_once_with(
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
{
"id": "ckpt-1",
"channel_versions": {"messages": 3},
"channel_values": {"messages": ["before"]},
},
{"source": "input"},
{"messages": 3},
)
assert checkpointer.aput_writes.await_args_list == [
call(
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
@@ -144,40 +134,6 @@ async def test_rollback_restores_snapshot_without_deleting_thread():
]
@pytest.mark.anyio
async def test_rollback_restored_checkpoint_becomes_latest_with_real_checkpointer():
    """After rollback on a real ``InMemorySaver``, the latest checkpoint holds the pre-run state."""
    saver = InMemorySaver()
    thread_config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}

    # Seed a pre-run checkpoint, then a post-run one that carries a pending write.
    before_checkpoint = _make_checkpoint("0001", ["before"], 1)
    before_config = saver.put(thread_config, before_checkpoint, {"step": 1}, {"messages": 1})
    after_checkpoint = _make_checkpoint("0002", ["after"], 2)
    after_config = saver.put(before_config, after_checkpoint, {"step": 2}, {"messages": 2})
    saver.put_writes(after_config, [("messages", "pending-after")], task_id="task-after")

    snapshot = {
        "checkpoint_ns": "",
        "checkpoint": before_checkpoint,
        "metadata": {"step": 1},
        "pending_writes": [("task-before", "messages", "pending-before")],
    }
    await _rollback_to_pre_run_checkpoint(
        checkpointer=saver,
        thread_id="thread-1",
        run_id="run-1",
        pre_run_checkpoint_id="0001",
        pre_run_snapshot=snapshot,
        snapshot_capture_failed=False,
    )

    latest = saver.get_tuple(thread_config)
    assert latest is not None
    # The restored state is written under a new id, not either seeded one.
    assert latest.config["configurable"]["checkpoint_id"] != "0001"
    assert latest.config["configurable"]["checkpoint_id"] != "0002"
    assert latest.checkpoint["channel_values"] == {"messages": ["before"]}
    assert latest.pending_writes == [("task-before", "messages", "pending-before")]
    assert ("task-after", "messages", "pending-after") not in latest.pending_writes
@pytest.mark.anyio
async def test_rollback_deletes_thread_when_no_snapshot_exists():
checkpointer = FakeCheckpointer(put_result=None)
@@ -238,13 +194,12 @@ async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace():
snapshot_capture_failed=False,
)
checkpointer.aput.assert_awaited_once()
restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
assert restored_checkpoint["id"] != "ckpt-1"
assert restored_checkpoint["channel_versions"] == {}
assert restored_metadata == {}
assert new_versions == {}
checkpointer.aput.assert_awaited_once_with(
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
{"id": "ckpt-1", "channel_versions": {}},
{},
{},
)
@pytest.mark.anyio
-36
View File
@@ -7,7 +7,6 @@ import yaml
from deerflow.config import app_config as app_config_module
from deerflow.config import extensions_config as extensions_config_module
from deerflow.config import skills_config as skills_config_module
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.paths import Paths
@@ -36,7 +35,6 @@ def test_default_runtime_paths_resolve_from_current_project(tmp_path: Path, monk
encoding="utf-8",
)
(tmp_path / "extensions_config.json").write_text('{"mcpServers": {}, "skills": {}}', encoding="utf-8")
(tmp_path / "skills").mkdir()
assert AppConfig.resolve_config_path() == tmp_path / "config.yaml"
assert ExtensionsConfig.resolve_config_path() == tmp_path / "extensions_config.json"
@@ -123,40 +121,6 @@ def test_app_config_falls_back_to_legacy_when_project_root_lacks_config(tmp_path
assert AppConfig.resolve_config_path() == legacy_backend_config
def test_skills_config_falls_back_to_legacy_when_project_root_lacks_skills(tmp_path: Path, monkeypatch):
    """With DEER_FLOW_PROJECT_ROOT unset and no `skills/` in cwd, the legacy
    repo-root candidate is returned, so monorepo runs (cwd=backend/) still find
    `<repo>/skills` rather than `<repo>/backend/skills` (regression test for #2694)."""
    _clear_path_env(monkeypatch)

    working_dir = tmp_path / "cwd"
    working_dir.mkdir()
    monkeypatch.chdir(working_dir)

    legacy_skills = tmp_path / "legacy-repo" / "skills"
    legacy_skills.mkdir(parents=True)

    def _only_legacy_candidate():
        return (legacy_skills,)

    monkeypatch.setattr(skills_config_module, "_legacy_skills_candidates", _only_legacy_candidate)

    resolved = SkillsConfig().get_skills_path()
    assert resolved == legacy_skills
def test_skills_config_returns_project_default_when_neither_exists(tmp_path: Path, monkeypatch):
    """When no skills directory exists anywhere, the project-root default path is
    returned so callers get a stable empty location rather than a stale legacy dir."""
    _clear_path_env(monkeypatch)

    working_dir = tmp_path / "cwd"
    working_dir.mkdir()
    monkeypatch.chdir(working_dir)

    def _no_legacy_candidates():
        return ()

    monkeypatch.setattr(skills_config_module, "_legacy_skills_candidates", _no_legacy_candidates)

    resolved = SkillsConfig().get_skills_path()
    assert resolved == working_dir / "skills"
def test_extensions_config_falls_back_to_legacy_when_project_root_lacks_file(tmp_path: Path, monkeypatch):
"""ExtensionsConfig should hit the legacy backend/repo-root locations when
the caller project root has no extensions_config.json/mcp_config.json."""
-308
View File
@@ -1,308 +0,0 @@
"""Unit tests for the Serper community web search tool."""
import json
from unittest.mock import MagicMock, patch
import httpx
import pytest
@pytest.fixture(autouse=True)
def reset_api_key_warned():
    """Clear the module-level ``_api_key_warned`` flag before and after every test."""
    import deerflow.community.serper.tools as serper_mod

    def _clear_flag():
        serper_mod._api_key_warned = False

    _clear_flag()
    yield
    _clear_flag()
@pytest.fixture
def mock_config_with_key():
    """Patch ``get_app_config`` so the tool config carries an API key and ``max_results=5``."""
    tool_config = MagicMock(model_extra={"api_key": "test-serper-key", "max_results": 5})
    with patch("deerflow.community.serper.tools.get_app_config") as mock:
        mock.return_value.get_tool_config.return_value = tool_config
        yield mock
@pytest.fixture
def mock_config_no_key():
    """Patch ``get_app_config`` so the tool config exists but has no extra settings at all."""
    tool_config = MagicMock(model_extra={})
    with patch("deerflow.community.serper.tools.get_app_config") as mock:
        mock.return_value.get_tool_config.return_value = tool_config
        yield mock
def _make_serper_response(organic: list) -> MagicMock:
mock_resp = MagicMock()
mock_resp.json.return_value = {"organic": organic}
mock_resp.raise_for_status = MagicMock()
return mock_resp
class TestGetApiKey:
    """``_get_api_key`` resolution: a usable config key wins, otherwise ``SERPER_API_KEY``."""

    @staticmethod
    def _tool_config(**model_extra):
        """Return a mock tool config whose ``model_extra`` is the given mapping."""
        return MagicMock(model_extra=model_extra)

    @staticmethod
    def _resolve(tool_config, env=None, *, clear_env=False):
        """Call ``_get_api_key`` with ``get_app_config`` and ``os.environ`` patched.

        Args:
            tool_config: Value returned by ``get_tool_config()`` (a mock or ``None``).
            env: Variables merged into ``os.environ`` for the call.
            clear_env: When true the environment is emptied first, so no pre-existing
                ``SERPER_API_KEY`` can leak in (no extra ``os.environ.pop`` is needed).
        """
        with (
            patch("deerflow.community.serper.tools.get_app_config") as mock,
            patch.dict("os.environ", env or {}, clear=clear_env),
        ):
            mock.return_value.get_tool_config.return_value = tool_config
            from deerflow.community.serper.tools import _get_api_key

            return _get_api_key()

    def test_returns_config_key_when_present(self):
        assert self._resolve(self._tool_config(api_key="from-config")) == "from-config"

    def test_falls_back_to_env_when_config_key_empty(self):
        resolved = self._resolve(self._tool_config(api_key=""), {"SERPER_API_KEY": "env-key"})
        assert resolved == "env-key"

    def test_falls_back_to_env_when_config_key_whitespace(self):
        resolved = self._resolve(self._tool_config(api_key=" "), {"SERPER_API_KEY": "env-key"})
        assert resolved == "env-key"

    def test_falls_back_to_env_when_config_key_null(self):
        resolved = self._resolve(self._tool_config(api_key=None), {"SERPER_API_KEY": "env-key"})
        assert resolved == "env-key"

    def test_falls_back_to_env_when_no_config(self):
        assert self._resolve(None, {"SERPER_API_KEY": "env-only"}) == "env-only"

    def test_returns_none_when_no_key_anywhere(self):
        assert self._resolve(None, clear_env=True) is None
class TestWebSearchTool:
    """Tests for the Serper ``web_search_tool``: result shaping, limits, errors and request format."""

    @staticmethod
    def _organic(count):
        """Return ``count`` distinct organic results in Serper's response shape."""
        return [{"title": f"R{i}", "link": f"https://x.com/{i}", "snippet": f"S{i}"} for i in range(count)]

    @staticmethod
    def _search(arguments, *, response=None, side_effect=None):
        """Invoke ``web_search_tool`` with ``httpx.Client`` patched.

        Args:
            arguments: Tool input passed to ``invoke``.
            response: Mock returned by the patched ``post`` call.
            side_effect: Exception raised by ``post`` instead of returning ``response``.

        Returns:
            ``(parsed_result, mock_post)``: the JSON-decoded tool output and the
            patched ``post`` mock for inspecting the outgoing request.
        """
        with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
            mock_post = mock_client_cls.return_value.__enter__.return_value.post
            if side_effect is not None:
                mock_post.side_effect = side_effect
            else:
                mock_post.return_value = response
            from deerflow.community.serper.tools import web_search_tool

            result = web_search_tool.invoke(arguments)
        return json.loads(result), mock_post

    def test_basic_search_returns_normalized_results(self, mock_config_with_key):
        organic = [
            {"title": "Result 1", "link": "https://example.com/1", "snippet": "Snippet 1"},
            {"title": "Result 2", "link": "https://example.com/2", "snippet": "Snippet 2"},
        ]
        parsed, _ = self._search({"query": "python tutorial"}, response=_make_serper_response(organic))
        assert parsed["query"] == "python tutorial"
        assert parsed["total_results"] == 2
        first = parsed["results"][0]
        assert first["title"] == "Result 1"
        assert first["url"] == "https://example.com/1"
        assert first["content"] == "Snippet 1"

    def test_respects_max_results_from_config(self, mock_config_with_key):
        mock_config_with_key.return_value.get_tool_config.return_value.model_extra = {
            "api_key": "test-key",
            "max_results": 3,
        }
        parsed, _ = self._search({"query": "test"}, response=_make_serper_response(self._organic(10)))
        assert parsed["total_results"] == 3
        assert len(parsed["results"]) == 3

    def test_max_results_parameter_accepted(self, mock_config_no_key):
        """Tool accepts max_results as a call parameter when config does not override it."""
        mock_resp = _make_serper_response(self._organic(10))
        with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}):
            parsed, _ = self._search({"query": "test", "max_results": 2}, response=mock_resp)
        assert parsed["total_results"] == 2

    def test_config_max_results_overrides_parameter(self):
        """Config max_results overrides the parameter passed at call time, matching ddg_search behaviour."""
        tool_config = MagicMock()
        tool_config.model_extra = {"api_key": "test-key", "max_results": 3}
        mock_resp = _make_serper_response(self._organic(10))
        with patch("deerflow.community.serper.tools.get_app_config") as mock:
            mock.return_value.get_tool_config.return_value = tool_config
            parsed, _ = self._search({"query": "test", "max_results": 8}, response=mock_resp)
        assert parsed["total_results"] == 3

    def test_empty_organic_returns_error_json(self, mock_config_with_key):
        """Empty organic list returns structured error, matching ddg_search convention."""
        parsed, _ = self._search({"query": "no results"}, response=_make_serper_response([]))
        assert "error" in parsed
        assert parsed["error"] == "No results found"
        assert parsed["query"] == "no results"

    def test_missing_api_key_returns_error_json(self, mock_config_no_key):
        # clear=True already empties the environment, so no explicit pop is required.
        # httpx is deliberately left unpatched, as in the original test.
        with patch.dict("os.environ", {}, clear=True):
            from deerflow.community.serper.tools import web_search_tool

            result = web_search_tool.invoke({"query": "test"})
        parsed = json.loads(result)
        assert "error" in parsed
        assert "SERPER_API_KEY" in parsed["error"]

    def test_missing_api_key_logs_warning_once(self, mock_config_no_key, caplog):
        import logging

        with patch.dict("os.environ", {}, clear=True):
            from deerflow.community.serper.tools import web_search_tool

            with caplog.at_level(logging.WARNING, logger="deerflow.community.serper.tools"):
                web_search_tool.invoke({"query": "q1"})
                web_search_tool.invoke({"query": "q2"})
        warnings = [r for r in caplog.records if r.levelno == logging.WARNING]
        assert len(warnings) == 1

    def test_http_error_returns_structured_error(self, mock_config_with_key):
        mock_error_response = MagicMock()
        mock_error_response.status_code = 403
        mock_error_response.text = "Forbidden"
        http_error = httpx.HTTPStatusError("403", request=MagicMock(), response=mock_error_response)
        parsed, _ = self._search({"query": "test"}, side_effect=http_error)
        assert "error" in parsed
        assert "403" in parsed["error"]

    def test_network_exception_returns_error_json(self, mock_config_with_key):
        parsed, _ = self._search({"query": "test"}, side_effect=Exception("timeout"))
        assert "error" in parsed

    def test_sends_correct_headers_and_payload(self, mock_config_with_key):
        organic = [{"title": "T", "link": "https://x.com", "snippet": "S"}]
        _, mock_post = self._search({"query": "hello world"}, response=_make_serper_response(organic))
        call_kwargs = mock_post.call_args
        headers = call_kwargs.kwargs["headers"]
        payload = call_kwargs.kwargs["json"]
        assert headers["X-API-KEY"] == "test-serper-key"
        assert payload["q"] == "hello world"
        assert payload["num"] == 5

    def test_uses_env_key_when_config_absent(self):
        organic = [{"title": "T", "link": "https://x.com", "snippet": "S"}]
        mock_resp = _make_serper_response(organic)
        with patch("deerflow.community.serper.tools.get_app_config") as mock:
            mock.return_value.get_tool_config.return_value = None
            with patch.dict("os.environ", {"SERPER_API_KEY": "env-only-key"}):
                _, mock_post = self._search({"query": "env key test"}, response=mock_resp)
        headers = mock_post.call_args.kwargs["headers"]
        assert headers["X-API-KEY"] == "env-only-key"

    def test_partial_fields_in_organic_result(self, mock_config_with_key):
        """Missing title/link/snippet should default to empty string."""
        parsed, _ = self._search({"query": "test"}, response=_make_serper_response([{}]))
        assert parsed["results"][0] == {"title": "", "url": "", "content": ""}
-1
View File
@@ -19,7 +19,6 @@ def test_get_skills_root_path_points_to_current_project_skills(tmp_path: Path, m
monkeypatch.delenv("DEER_FLOW_SKILLS_PATH", raising=False)
monkeypatch.delenv("DEER_FLOW_PROJECT_ROOT", raising=False)
monkeypatch.chdir(tmp_path)
(tmp_path / "skills").mkdir()
app_config = SimpleNamespace(skills=SkillsConfig())
path = get_or_new_skill_storage(app_config=app_config).get_skills_root_path()
+1 -296
View File
@@ -1,66 +1,12 @@
import re
from unittest.mock import patch
import pytest
from _router_auth_helpers import make_authed_test_app
from fastapi import FastAPI, HTTPException
from fastapi import HTTPException
from fastapi.testclient import TestClient
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.store.memory import InMemoryStore
from app.gateway.routers import threads
from deerflow.config.paths import Paths
from deerflow.persistence.thread_meta.memory import THREADS_NS, MemoryThreadMetaStore
_ISO_TIMESTAMP_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
class _PermissiveThreadMetaStore(MemoryThreadMetaStore):
    """In-memory thread store with owner filtering switched off, for router tests.

    Owner isolation has its own coverage in
    ``test_memory_thread_meta_isolation.py``. These router tests exercise the
    FastAPI surface end-to-end as one fixed app user, yet the stub auth
    middleware in ``_router_auth_helpers`` stamps a fresh UUID on every
    request, so the production filter would reject every pre-seeded record.
    Skipping the filter lets the tests concentrate on the timestamp wire
    format.
    """

    async def _get_owned_record(self, thread_id, user_id, method_name):  # type: ignore[override]
        # Return the stored record regardless of who owns it.
        item = await self._store.aget(THREADS_NS, thread_id)
        if item is None:
            return None
        return dict(item.value)

    async def check_access(self, thread_id, user_id, *, require_existing=False):  # type: ignore[override]
        # Existing threads are always accessible; missing ones only when existence is optional.
        item = await self._store.aget(THREADS_NS, thread_id)
        return True if item is not None else not require_existing

    async def create(self, thread_id, *, assistant_id=None, user_id=None, display_name=None, metadata=None):  # type: ignore[override]
        # The caller's user_id is dropped so records are created without an owner.
        return await super().create(
            thread_id,
            assistant_id=assistant_id,
            user_id=None,
            display_name=display_name,
            metadata=metadata,
        )

    async def search(self, *, metadata=None, status=None, limit=100, offset=0, user_id=None):  # type: ignore[override]
        # The caller's user_id is dropped so the search is not restricted to one owner.
        return await super().search(
            metadata=metadata,
            status=status,
            limit=limit,
            offset=offset,
            user_id=None,
        )
def _build_thread_app() -> tuple[FastAPI, InMemoryStore, InMemorySaver]:
    """Create a stub-authed FastAPI app backed by in-memory thread storage.

    ``app.state.thread_store`` is a permissive ``MemoryThreadMetaStore``
    subclass, so tests can call ``/api/threads`` end-to-end and pre-seed
    legacy records directly through the underlying BaseStore.

    Returns ``(app, store, checkpointer)`` for direct seeding/inspection.
    """
    store = InMemoryStore()
    checkpointer = InMemorySaver()

    app = make_authed_test_app()
    state = app.state
    state.store = store
    state.checkpointer = checkpointer
    state.thread_store = _PermissiveThreadMetaStore(store)
    app.include_router(threads.router)

    return app, store, checkpointer
def test_delete_thread_data_removes_thread_directory(tmp_path):
@@ -190,244 +136,3 @@ def test_strip_reserved_metadata_empty_input():
def test_strip_reserved_metadata_strips_all_reserved_keys():
out = threads._strip_reserved_metadata({"user_id": "x", "keep": "me"})
assert out == {"keep": "me"}
# ---------------------------------------------------------------------------
# ISO 8601 timestamp contract (issue #2594)
# ---------------------------------------------------------------------------
#
# Threads endpoints document ``created_at`` / ``updated_at`` as ISO
# timestamps and that is the format LangGraph Platform uses
# (``langgraph_sdk.schema.Thread.created_at: datetime`` JSON-encodes to
# ISO 8601). The tests below pin that contract end-to-end and also
# exercise the ``coerce_iso`` healing path for legacy unix-timestamp
# records written by older Gateway versions.
def test_create_thread_returns_iso_timestamps() -> None:
    """POST /api/threads returns ISO-formatted, identical created_at/updated_at values."""
    app, _store, _checkpointer = _build_thread_app()

    with TestClient(app) as client:
        response = client.post("/api/threads", json={"metadata": {}})

    assert response.status_code == 200, response.text
    body = response.json()
    for field in ("created_at", "updated_at"):
        assert _ISO_TIMESTAMP_RE.match(body[field]), body[field]
    assert body["created_at"] == body["updated_at"]
def test_get_thread_returns_iso_for_legacy_unix_record() -> None:
    """Thread records from older versions hold unix-second timestamps (seeded
    here as a numeric string). ``get_thread`` must return them as ISO so the
    frontend's ``new Date(...)`` parser does not break.
    """
    import asyncio

    from langgraph.checkpoint.base import empty_checkpoint

    app, store, checkpointer = _build_thread_app()
    legacy_thread_id = "legacy-thread"
    legacy_ts = "1777252410.411327"
    legacy_record = {
        "thread_id": legacy_thread_id,
        "status": "idle",
        "created_at": legacy_ts,
        "updated_at": legacy_ts,
        "metadata": {},
    }

    async def _seed() -> None:
        # Seed both the thread record and a checkpoint for the same thread id.
        await store.aput(THREADS_NS, legacy_thread_id, legacy_record)
        await checkpointer.aput(
            {"configurable": {"thread_id": legacy_thread_id, "checkpoint_ns": ""}},
            empty_checkpoint(),
            {"step": -1, "source": "input", "writes": None, "parents": {}},
            {},
        )

    asyncio.run(_seed())

    with TestClient(app) as client:
        response = client.get(f"/api/threads/{legacy_thread_id}")

    assert response.status_code == 200, response.text
    body = response.json()
    for field in ("created_at", "updated_at"):
        assert _ISO_TIMESTAMP_RE.match(body[field]), body[field]
def test_patch_thread_returns_iso_and_advances_updated_at() -> None:
    """PATCH on a legacy record returns ISO timestamps, a newer updated_at and the new metadata."""
    import asyncio

    app, store, _checkpointer = _build_thread_app()
    thread_id = "patch-target"
    legacy_created = "1777000000.000000"
    legacy_updated = "1777000000.000000"
    seeded_record = {
        "thread_id": thread_id,
        "status": "idle",
        "created_at": legacy_created,
        "updated_at": legacy_updated,
        "metadata": {"k": "v0"},
    }

    async def _seed() -> None:
        await store.aput(THREADS_NS, thread_id, seeded_record)

    asyncio.run(_seed())

    with TestClient(app) as client:
        response = client.patch(f"/api/threads/{thread_id}", json={"metadata": {"k": "v1"}})

    assert response.status_code == 200, response.text
    body = response.json()
    for field in ("created_at", "updated_at"):
        assert _ISO_TIMESTAMP_RE.match(body[field]), body[field]
    # Patch issues a fresh ``updated_at`` via ``MemoryThreadMetaStore.update_metadata``,
    # so it must sort after the migrated legacy ``created_at`` (ISO strings in a
    # consistent format order lexicographically by time).
    assert body["updated_at"] > body["created_at"]
    assert body["metadata"] == {"k": "v1"}
def test_search_threads_normalizes_legacy_unix_seconds_to_iso() -> None:
    """``MemoryThreadMetaStore`` may still contain legacy ``time.time()`` floats
    from older Gateway versions. ``/search`` must return them as ISO via
    ``coerce_iso`` so the frontend's ``new Date(...)`` parser does not break.
    """
    import asyncio

    app, store, _checkpointer = _build_thread_app()

    def _record(thread_id, timestamp):
        return {
            "thread_id": thread_id,
            "status": "idle",
            "created_at": timestamp,
            "updated_at": timestamp,
            "metadata": {},
        }

    async def _seed() -> None:
        # Legacy unix-second float (the literal value from issue #2594).
        await store.aput(THREADS_NS, "legacy", _record("legacy", 1777000000.0))
        # Modern ISO string, slightly later.
        await store.aput(THREADS_NS, "modern", _record("modern", "2026-04-27T00:00:00+00:00"))

    asyncio.run(_seed())

    with TestClient(app) as client:
        response = client.post("/api/threads/search", json={"limit": 10})

    assert response.status_code == 200, response.text
    items = response.json()
    returned_ids = {item["thread_id"] for item in items}
    assert returned_ids == {"legacy", "modern"}
    for item in items:
        assert _ISO_TIMESTAMP_RE.match(item["created_at"]), item
        assert _ISO_TIMESTAMP_RE.match(item["updated_at"]), item
def test_memory_thread_meta_store_writes_iso_on_create() -> None:
    """Records written by ``MemoryThreadMetaStore.create`` carry ISO timestamps.

    Freshly created threads must serialize correctly on their own, without
    relying on the router's ``coerce_iso`` heal path.
    """
    import asyncio

    backing_store = InMemoryStore()
    repo = MemoryThreadMetaStore(backing_store)

    async def _create_and_fetch() -> dict:
        await repo.create("fresh", user_id=None, metadata={"a": 1})
        stored_item = await backing_store.aget(THREADS_NS, "fresh")
        return stored_item.value

    record = asyncio.run(_create_and_fetch())
    for field in ("created_at", "updated_at"):
        assert _ISO_TIMESTAMP_RE.match(record[field]), record
def test_get_thread_state_returns_iso_for_legacy_checkpoint_metadata() -> None:
    """``/state`` must emit ISO timestamps for legacy checkpoint metadata.

    Checkpoints written by older Gateway versions stored ``created_at`` as
    a unix-second float in their metadata. The endpoint has to surface that
    value as ISO so the frontend's ``new Date(...)`` parser does not break.
    This is the same root cause as the thread-record bug fixed in #2594,
    but on the checkpoint side.
    """
    import asyncio

    app, _store, checkpointer = _build_thread_app()
    thread_id = "legacy-state"
    checkpoint_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
    legacy_metadata = {"step": -1, "source": "input", "writes": None, "parents": {}, "created_at": 1777252410.411327}

    async def _store_legacy_checkpoint() -> None:
        from langgraph.checkpoint.base import empty_checkpoint

        await checkpointer.aput(checkpoint_config, empty_checkpoint(), legacy_metadata, {})

    asyncio.run(_store_legacy_checkpoint())

    with TestClient(app) as client:
        response = client.get(f"/api/threads/{thread_id}/state")

    assert response.status_code == 200, response.text
    payload = response.json()
    assert _ISO_TIMESTAMP_RE.match(payload["created_at"]), payload["created_at"]
    assert _ISO_TIMESTAMP_RE.match(payload["checkpoint"]["ts"]), payload["checkpoint"]
def test_get_thread_history_returns_iso_for_legacy_checkpoint_metadata() -> None:
    """``/history`` must emit ISO ``created_at`` for every returned entry.

    The endpoint walks ``checkpointer.alist`` and produces one entry per
    checkpoint. Each entry's ``created_at`` has to come out as ISO even when
    an older checkpoint stored a unix-second float in its metadata.
    """
    import asyncio

    app, _store, checkpointer = _build_thread_app()
    thread_id = "legacy-history"
    checkpoint_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
    legacy_metadata = {"step": -1, "source": "input", "writes": None, "parents": {}, "created_at": 1777252410.411327}

    async def _store_legacy_checkpoint() -> None:
        from langgraph.checkpoint.base import empty_checkpoint

        await checkpointer.aput(checkpoint_config, empty_checkpoint(), legacy_metadata, {})

    asyncio.run(_store_legacy_checkpoint())

    with TestClient(app) as client:
        response = client.post(f"/api/threads/{thread_id}/history", json={"limit": 10})

    assert response.status_code == 200, response.text
    history = response.json()
    assert history, "expected at least one history entry"
    for entry in history:
        assert _ISO_TIMESTAMP_RE.match(entry["created_at"]), entry
@@ -1,157 +0,0 @@
"""Tests for TokenUsageMiddleware attribution annotations."""
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage
from deerflow.agents.middlewares.token_usage_middleware import (
TOKEN_USAGE_ATTRIBUTION_KEY,
TokenUsageMiddleware,
)
def _make_runtime():
runtime = MagicMock()
runtime.context = {"thread_id": "test-thread"}
return runtime
class TestTokenUsageMiddleware:
    """Checks the attribution that ``TokenUsageMiddleware.after_model`` attaches to AI messages."""

    @staticmethod
    def _attribution_after_model(state):
        """Run a fresh middleware over ``state`` and return the first message's attribution."""
        outcome = TokenUsageMiddleware().after_model(state, _make_runtime())
        assert outcome is not None
        return outcome["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]

    def test_annotates_todo_updates_with_structured_actions(self):
        """A ``write_todos`` call is diffed against the state's todos into structured actions."""
        todo_call = {
            "id": "write_todos:1",
            "name": "write_todos",
            "args": {
                "todos": [
                    {"content": "Inspect streaming path", "status": "completed"},
                    {"content": "Design token attribution schema", "status": "in_progress"},
                ]
            },
        }
        ai_message = AIMessage(
            content="",
            tool_calls=[todo_call],
            usage_metadata={"input_tokens": 100, "output_tokens": 20, "total_tokens": 120},
        )
        previous_todos = [
            {"content": "Inspect streaming path", "status": "in_progress"},
            {"content": "Design token attribution schema", "status": "pending"},
        ]

        attribution = self._attribution_after_model({"messages": [ai_message], "todos": previous_todos})

        assert attribution["kind"] == "tool_batch"
        assert attribution["shared_attribution"] is True
        assert attribution["tool_call_ids"] == ["write_todos:1"]
        expected_actions = [
            {
                "kind": "todo_complete",
                "content": "Inspect streaming path",
                "tool_call_id": "write_todos:1",
            },
            {
                "kind": "todo_start",
                "content": "Design token attribution schema",
                "tool_call_id": "write_todos:1",
            },
        ]
        assert attribution["actions"] == expected_actions

    def test_annotates_subagent_and_search_steps(self):
        """A ``task`` call plus a ``web_search`` call yield subagent and search actions."""
        subagent_call = {
            "id": "task:1",
            "name": "task",
            "args": {
                "description": "spec-coder patch message grouping",
                "subagent_type": "general-purpose",
            },
        }
        search_call = {
            "id": "web_search:1",
            "name": "web_search",
            "args": {"query": "LangGraph useStream messages tuple"},
        }
        ai_message = AIMessage(content="", tool_calls=[subagent_call, search_call])

        attribution = self._attribution_after_model({"messages": [ai_message]})

        assert attribution["kind"] == "tool_batch"
        assert attribution["shared_attribution"] is True
        expected_actions = [
            {
                "kind": "subagent",
                "description": "spec-coder patch message grouping",
                "subagent_type": "general-purpose",
                "tool_call_id": "task:1",
            },
            {
                "kind": "search",
                "tool_name": "web_search",
                "query": "LangGraph useStream messages tuple",
                "tool_call_id": "web_search:1",
            },
        ]
        assert attribution["actions"] == expected_actions

    def test_marks_final_answer_when_no_tools(self):
        """A message without tool calls is labelled ``final_answer`` with no actions."""
        ai_message = AIMessage(content="Here is the final answer.")

        attribution = self._attribution_after_model({"messages": [ai_message]})

        assert attribution["kind"] == "final_answer"
        assert attribution["shared_attribution"] is False
        assert attribution["actions"] == []

    def test_annotates_removed_todos(self):
        """Writing an empty todo list reports each previous todo as ``todo_remove``."""
        clearing_call = {
            "id": "write_todos:remove",
            "name": "write_todos",
            "args": {
                "todos": [],
            },
        }
        ai_message = AIMessage(content="", tool_calls=[clearing_call])
        state = {
            "messages": [ai_message],
            "todos": [
                {"content": "Archive obsolete plan", "status": "pending"},
            ],
        }

        attribution = self._attribution_after_model(state)

        assert attribution["kind"] == "todo_update"
        assert attribution["shared_attribution"] is False
        assert attribution["actions"] == [
            {
                "kind": "todo_remove",
                "content": "Archive obsolete plan",
                "tool_call_id": "write_todos:remove",
            }
        ]
-54
View File
@@ -1,20 +1,14 @@
"""Tests for deerflow.uploads.manager — shared upload management logic."""
import errno
import os
from unittest.mock import patch
import pytest
from deerflow.uploads.manager import (
PathTraversalError,
UnsafeUploadPathError,
claim_unique_filename,
delete_file_safe,
list_files_in_dir,
normalize_filename,
validate_path_traversal,
write_upload_file_no_symlink,
)
# ---------------------------------------------------------------------------
@@ -103,54 +97,6 @@ class TestValidatePathTraversal:
validate_path_traversal(link, tmp_path)
# ---------------------------------------------------------------------------
# write_upload_file_no_symlink
# ---------------------------------------------------------------------------
class TestWriteUploadFileNoSymlink:
    """Behavioural checks for ``write_upload_file_no_symlink``."""

    def test_writes_new_file(self, tmp_path):
        """A missing destination is created with the given bytes and its path is returned."""
        written = write_upload_file_no_symlink(tmp_path, "notes.txt", b"hello")

        assert written == tmp_path / "notes.txt"
        assert written.read_bytes() == b"hello"

    def test_overwrites_existing_regular_file_with_single_link(self, tmp_path):
        """An existing regular file with a single link is replaced and keeps one link."""
        target = tmp_path / "notes.txt"
        target.write_bytes(b"old contents")
        assert os.stat(target).st_nlink == 1

        returned = write_upload_file_no_symlink(tmp_path, "notes.txt", b"new contents")

        assert returned == target
        assert target.read_bytes() == b"new contents"
        assert os.stat(target).st_nlink == 1

    def test_fails_closed_without_no_follow_support(self, tmp_path, monkeypatch):
        """With ``os.O_NOFOLLOW`` removed, the write is refused and no file appears."""
        monkeypatch.delattr(os, "O_NOFOLLOW", raising=False)

        with pytest.raises(UnsafeUploadPathError, match="O_NOFOLLOW"):
            write_upload_file_no_symlink(tmp_path, "notes.txt", b"hello")

        assert not (tmp_path / "notes.txt").exists()

    def test_open_uses_nonblocking_flag_when_available(self, tmp_path):
        """The flags handed to ``os.open`` include ``os.O_NONBLOCK``."""
        failing_open = patch(
            "deerflow.uploads.manager.os.open",
            side_effect=OSError(errno.ENXIO, "no reader"),
        )
        with failing_open as open_mock, pytest.raises(UnsafeUploadPathError, match="Unsafe upload destination"):
            write_upload_file_no_symlink(tmp_path, "pipe.txt", b"hello")

        # ``os.open(path, flags, ...)``: the flags are the second positional argument.
        passed_flags = open_mock.call_args.args[1]
        assert passed_flags & os.O_NONBLOCK

    @pytest.mark.parametrize("open_errno", [errno.ENXIO, errno.EAGAIN])
    def test_nonblocking_special_file_open_errors_are_unsafe(self, tmp_path, open_errno):
        """``ENXIO``/``EAGAIN`` from ``os.open`` are reported as an unsafe destination."""
        failing_open = patch(
            "deerflow.uploads.manager.os.open",
            side_effect=OSError(open_errno, "would block"),
        )
        with failing_open, pytest.raises(UnsafeUploadPathError, match="Unsafe upload destination"):
            write_upload_file_no_symlink(tmp_path, "pipe.txt", b"hello")

        assert not (tmp_path / "pipe.txt").exists()
# ---------------------------------------------------------------------------
# list_files_in_dir
# ---------------------------------------------------------------------------
-100
View File
@@ -1,5 +1,4 @@
import asyncio
import os
import stat
from io import BytesIO
from pathlib import Path
@@ -429,105 +428,6 @@ def test_upload_files_rejects_dotdot_and_dot_filenames(tmp_path):
assert [f.name for f in thread_uploads_dir.iterdir()] == ["passwd"]
def test_upload_files_rejects_preexisting_symlink_destination(tmp_path):
    """An upload whose destination is already a symlink is skipped.

    The link's target lives outside the uploads directory and must keep its
    content; the link itself must still be in place afterwards.
    """
    uploads_root = tmp_path / "uploads"
    uploads_root.mkdir(parents=True)
    protected = tmp_path / "outside.txt"
    protected.write_text("protected", encoding="utf-8")
    planted_link = uploads_root / "victim.txt"
    planted_link.symlink_to(protected)

    sandbox_provider = MagicMock()
    sandbox_provider.uses_thread_data_mounts = True
    incoming = UploadFile(filename="victim.txt", file=BytesIO(b"attacker upload"))

    with (
        patch.object(uploads, "get_uploads_dir", return_value=uploads_root),
        patch.object(uploads, "ensure_uploads_dir", return_value=uploads_root),
        patch.object(uploads, "get_sandbox_provider", return_value=sandbox_provider),
    ):
        result = asyncio.run(uploads.upload_files("thread-local", files=[incoming]))

    assert result.success is False
    assert result.files == []
    assert result.skipped_files == ["victim.txt"]
    assert "skipped 1 unsafe file" in result.message
    assert protected.read_text(encoding="utf-8") == "protected"
    assert planted_link.is_symlink()
def test_upload_files_rejects_dangling_symlink_destination(tmp_path):
    """An upload onto a symlink with a missing target is skipped.

    The missing target must not get created through the link, and the link
    itself must still be in place afterwards.
    """
    uploads_root = tmp_path / "uploads"
    uploads_root.mkdir(parents=True)
    absent_target = tmp_path / "missing-target.txt"
    planted_link = uploads_root / "victim.txt"
    planted_link.symlink_to(absent_target)

    sandbox_provider = MagicMock()
    sandbox_provider.uses_thread_data_mounts = True
    incoming = UploadFile(filename="victim.txt", file=BytesIO(b"attacker upload"))

    with (
        patch.object(uploads, "get_uploads_dir", return_value=uploads_root),
        patch.object(uploads, "ensure_uploads_dir", return_value=uploads_root),
        patch.object(uploads, "get_sandbox_provider", return_value=sandbox_provider),
    ):
        result = asyncio.run(uploads.upload_files("thread-local", files=[incoming]))

    assert result.success is False
    assert result.files == []
    assert result.skipped_files == ["victim.txt"]
    assert not absent_target.exists()
    assert planted_link.is_symlink()
def test_upload_files_rejects_hardlinked_destination_without_truncating(tmp_path):
    """An upload onto a hard link to an outside file is skipped without truncating it.

    Both names of the linked file must still read back the original content.
    """
    uploads_root = tmp_path / "uploads"
    uploads_root.mkdir(parents=True)
    protected = tmp_path / "outside.txt"
    protected.write_text("protected", encoding="utf-8")
    hardlinked_name = uploads_root / "victim.txt"
    os.link(protected, hardlinked_name)

    sandbox_provider = MagicMock()
    sandbox_provider.uses_thread_data_mounts = True
    incoming = UploadFile(filename="victim.txt", file=BytesIO(b"attacker upload"))

    with (
        patch.object(uploads, "get_uploads_dir", return_value=uploads_root),
        patch.object(uploads, "ensure_uploads_dir", return_value=uploads_root),
        patch.object(uploads, "get_sandbox_provider", return_value=sandbox_provider),
    ):
        result = asyncio.run(uploads.upload_files("thread-local", files=[incoming]))

    assert result.success is False
    assert result.files == []
    assert result.skipped_files == ["victim.txt"]
    assert protected.read_text(encoding="utf-8") == "protected"
    assert hardlinked_name.read_text(encoding="utf-8") == "protected"
def test_upload_files_overwrites_existing_regular_file(tmp_path):
    """Re-uploading over an ordinary single-link file replaces its content."""
    uploads_root = tmp_path / "uploads"
    uploads_root.mkdir(parents=True)
    previous_upload = uploads_root / "notes.txt"
    previous_upload.write_bytes(b"old upload")
    assert previous_upload.stat().st_nlink == 1

    sandbox_provider = MagicMock()
    sandbox_provider.uses_thread_data_mounts = True
    incoming = UploadFile(filename="notes.txt", file=BytesIO(b"new upload"))

    with (
        patch.object(uploads, "get_uploads_dir", return_value=uploads_root),
        patch.object(uploads, "ensure_uploads_dir", return_value=uploads_root),
        patch.object(uploads, "get_sandbox_provider", return_value=sandbox_provider),
    ):
        result = asyncio.run(uploads.upload_files("thread-local", files=[incoming]))

    assert result.success is True
    uploaded_names = [file_info["filename"] for file_info in result.files]
    assert uploaded_names == ["notes.txt"]
    assert previous_upload.read_bytes() == b"new upload"
    assert previous_upload.stat().st_nlink == 1
def test_delete_uploaded_file_removes_generated_markdown_companion(tmp_path):
thread_uploads_dir = tmp_path / "uploads"
thread_uploads_dir.mkdir(parents=True)
-90
View File
@@ -1,90 +0,0 @@
"""Tests for ``deerflow.utils.time``."""
from __future__ import annotations
import re
from datetime import UTC, datetime, timedelta, timezone
from deerflow.utils.time import coerce_iso, now_iso
_ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
def test_now_iso_is_utc_iso8601() -> None:
    """``now_iso`` yields an ISO-8601 string that parses to a zero-offset datetime."""
    stamp = now_iso()
    assert _ISO_RE.match(stamp), stamp

    moment = datetime.fromisoformat(stamp)
    assert moment.tzinfo is not None
    assert moment.tzinfo.utcoffset(moment) == UTC.utcoffset(moment)
def test_coerce_iso_passes_iso_through() -> None:
    """An already-ISO string is returned unchanged."""
    already_iso = "2026-04-27T01:13:30.411334+00:00"
    result = coerce_iso(already_iso)
    assert result == already_iso
def test_coerce_iso_converts_unix_float_string() -> None:
    """A unix-seconds float given as a string becomes ISO for the same instant."""
    converted = coerce_iso("1777252410.411327")
    assert _ISO_RE.match(converted), converted

    # Round-trip: the parsed timestamp matches the original epoch value.
    drift = datetime.fromisoformat(converted).timestamp() - 1777252410.411327
    assert abs(drift) < 1e-3
def test_coerce_iso_converts_unix_int_string() -> None:
    """A ten-digit unix-seconds string is converted to ISO."""
    converted = coerce_iso("1700000000")
    assert _ISO_RE.match(converted), converted
def test_coerce_iso_converts_numeric_types() -> None:
    """Epoch values passed as real ``float`` and ``int`` objects are converted to ISO."""
    from_float = coerce_iso(1777252410.411327)
    from_int = coerce_iso(1700000000)
    for converted in (from_float, from_int):
        assert _ISO_RE.match(converted)
def test_coerce_iso_handles_empty_and_none() -> None:
    """``None`` and the empty string both map to the empty string."""
    for blank in (None, ""):
        assert coerce_iso(blank) == ""
def test_coerce_iso_does_not_misinterpret_short_numeric() -> None:
    """A short numeric string such as a bare year is left untouched."""
    # Only ten-digit unix-second strings match the legacy pattern; a
    # four-digit year must never be read as a unix timestamp.
    bare_year = "2026"
    assert coerce_iso(bare_year) == "2026"
def test_coerce_iso_handles_unparseable_string() -> None:
    """Text that is not a timestamp comes back unchanged."""
    garbage = "not-a-timestamp"
    assert coerce_iso(garbage) == "not-a-timestamp"
def test_coerce_iso_rejects_bool() -> None:
    """Booleans are stringified, not treated as epoch 0/1."""
    # ``bool`` is a subclass of ``int``, so it needs an explicit carve-out.
    for flag, expected in ((True, "True"), (False, "False")):
        assert coerce_iso(flag) == expected
def test_coerce_iso_handles_tz_aware_datetime() -> None:
    """A UTC-aware ``datetime`` is rendered with a ``T`` separator."""
    # ``str(datetime)`` would emit a space separator; ``coerce_iso`` must use ``T``.
    aware = datetime(2026, 4, 27, 1, 13, 30, 411327, tzinfo=UTC)
    rendered = coerce_iso(aware)
    assert rendered == "2026-04-27T01:13:30.411327+00:00"
    assert "T" in rendered
    assert " " not in rendered
def test_coerce_iso_handles_tz_naive_datetime_as_utc() -> None:
    """A naive ``datetime`` is rendered with a ``+00:00`` offset and the same wall-clock fields."""
    naive = datetime(2026, 4, 27, 1, 13, 30, 411327)
    rendered = coerce_iso(naive)
    assert rendered == "2026-04-27T01:13:30.411327+00:00"

    reparsed = datetime.fromisoformat(rendered)
    assert reparsed.tzinfo is not None
    assert reparsed.utcoffset() == timedelta(0)
def test_coerce_iso_normalises_non_utc_datetime_to_utc() -> None:
    """A ``datetime`` with a non-UTC offset is converted to the equivalent UTC instant."""
    # 09:13 wall-clock at +08:00 is 01:13 UTC.
    utc_plus_eight = timezone(timedelta(hours=8))
    local_moment = datetime(2026, 4, 27, 9, 13, 30, 411327, tzinfo=utc_plus_eight)
    rendered = coerce_iso(local_moment)
    assert rendered == "2026-04-27T01:13:30.411327+00:00"
-10
View File
@@ -373,16 +373,6 @@ tools:
use: deerflow.community.ddg_search.tools:web_search_tool
max_results: 5
# Web search tool (uses Serper - Google Search API, requires SERPER_API_KEY)
# Serper provides real-time Google Search results. Sign up at https://serper.dev
# Note: set SERPER_API_KEY in your environment before starting the app, or
# uncomment and fill in api_key below (the $VAR syntax is resolved at startup).
# - name: web_search
# group: web
# use: deerflow.community.serper.tools:web_search_tool
# max_results: 5
# # api_key: $SERPER_API_KEY # Optional if SERPER_API_KEY env var is set
# Web search tool (requires Tavily API key)
# - name: web_search
# group: web
@@ -25,7 +25,7 @@ import { useAgent } from "@/core/agents";
import { useI18n } from "@/core/i18n/hooks";
import { useModels } from "@/core/models/hooks";
import { useNotification } from "@/core/notification/hooks";
import { useLocalSettings, useThreadSettings } from "@/core/settings";
import { useThreadSettings } from "@/core/settings";
import { useThreadStream } from "@/core/threads/hooks";
import { textOfMessage } from "@/core/threads/utils";
import { env } from "@/env";
@@ -45,7 +45,6 @@ export default function AgentChatPage() {
const { threadId, setThreadId, isNewThread, setIsNewThread } =
useThreadChat();
const [settings, setSettings] = useThreadSettings(threadId);
const [localSettings, setLocalSettings] = useLocalSettings();
const { tokenUsageEnabled } = useModels();
const { showNotification } = useNotification();
@@ -101,9 +100,6 @@ export default function AgentChatPage() {
? MESSAGE_LIST_DEFAULT_PADDING_BOTTOM +
MESSAGE_LIST_FOLLOWUPS_EXTRA_PADDING_BOTTOM
: undefined;
const tokenUsageInlineMode = tokenUsageEnabled
? localSettings.tokenUsage.inlineMode
: "off";
return (
<ThreadContext.Provider value={{ thread }}>
@@ -143,10 +139,6 @@ export default function AgentChatPage() {
<TokenUsageIndicator
enabled={tokenUsageEnabled}
messages={thread.messages}
preferences={localSettings.tokenUsage}
onPreferencesChange={(preferences) =>
setLocalSettings("tokenUsage", preferences)
}
/>
<ExportTrigger threadId={threadId} />
<ArtifactTrigger />
@@ -160,10 +152,10 @@ export default function AgentChatPage() {
threadId={threadId}
thread={thread}
paddingBottom={messageListPaddingBottom}
tokenUsageEnabled={tokenUsageEnabled}
hasMoreHistory={hasMoreHistory}
loadMoreHistory={loadMoreHistory}
isHistoryLoading={isHistoryLoading}
tokenUsageInlineMode={tokenUsageInlineMode}
/>
</div>
@@ -33,7 +33,6 @@ import { ThreadContext } from "@/components/workspace/messages/context";
import type { Agent } from "@/core/agents";
import {
AgentNameCheckError,
AgentsApiDisabledError,
checkAgentName,
createAgent,
getAgent,
@@ -155,9 +154,7 @@ export default function NewAgentPage() {
return;
}
} catch (err) {
if (err instanceof AgentsApiDisabledError) {
setNameError(t.agents.nameStepApiDisabledError);
} else if (
if (
err instanceof AgentNameCheckError &&
err.reason === "backend_unreachable"
) {
@@ -178,10 +175,6 @@ export default function NewAgentPage() {
soul: "",
});
} catch (err) {
if (err instanceof AgentsApiDisabledError) {
setNameError(t.agents.nameStepApiDisabledError);
return;
}
setNameError(
getCreateAgentErrorMessage(
err,
@@ -204,7 +197,6 @@ export default function NewAgentPage() {
nameInput,
sendMessage,
t.agents.nameStepAlreadyExistsError,
t.agents.nameStepApiDisabledError,
t.agents.nameStepNetworkError,
t.agents.nameStepBootstrapMessage,
t.agents.nameStepCheckError,
@@ -24,7 +24,7 @@ import { Welcome } from "@/components/workspace/welcome";
import { useI18n } from "@/core/i18n/hooks";
import { useModels } from "@/core/models/hooks";
import { useNotification } from "@/core/notification/hooks";
import { useLocalSettings, useThreadSettings } from "@/core/settings";
import { useThreadSettings } from "@/core/settings";
import { useThreadStream } from "@/core/threads/hooks";
import { textOfMessage } from "@/core/threads/utils";
import { env } from "@/env";
@@ -36,7 +36,6 @@ export default function ChatPage() {
const { threadId, setThreadId, isNewThread, setIsNewThread, isMock } =
useThreadChat();
const [settings, setSettings] = useThreadSettings(threadId);
const [localSettings, setLocalSettings] = useLocalSettings();
const { tokenUsageEnabled } = useModels();
const mountedRef = useRef(false);
useSpecificChatMode();
@@ -100,9 +99,6 @@ export default function ChatPage() {
? MESSAGE_LIST_DEFAULT_PADDING_BOTTOM +
MESSAGE_LIST_FOLLOWUPS_EXTRA_PADDING_BOTTOM
: undefined;
const tokenUsageInlineMode = tokenUsageEnabled
? localSettings.tokenUsage.inlineMode
: "off";
return (
<ThreadContext.Provider value={{ thread, isMock }}>
@@ -123,10 +119,6 @@ export default function ChatPage() {
<TokenUsageIndicator
enabled={tokenUsageEnabled}
messages={thread.messages}
preferences={localSettings.tokenUsage}
onPreferencesChange={(preferences) =>
setLocalSettings("tokenUsage", preferences)
}
/>
<ExportTrigger threadId={threadId} />
<ArtifactTrigger />
@@ -139,10 +131,10 @@ export default function ChatPage() {
threadId={threadId}
thread={thread}
paddingBottom={messageListPaddingBottom}
tokenUsageEnabled={tokenUsageEnabled}
hasMoreHistory={hasMoreHistory}
loadMoreHistory={loadMoreHistory}
isHistoryLoading={isHistoryLoading}
tokenUsageInlineMode={tokenUsageInlineMode}
/>
</div>
<div className="absolute right-0 bottom-0 left-0 z-30 flex justify-center px-4">
@@ -2,7 +2,6 @@ import type { Message } from "@langchain/langgraph-sdk";
import {
BookOpenTextIcon,
ChevronUp,
CoinsIcon,
FolderOpenIcon,
GlobeIcon,
LightbulbIcon,
@@ -25,8 +24,6 @@ import {
import { CodeBlock } from "@/components/ai-elements/code-block";
import { Button } from "@/components/ui/button";
import { useI18n } from "@/core/i18n/hooks";
import { formatTokenCount } from "@/core/messages/usage";
import type { TokenDebugStep } from "@/core/messages/usage-model";
import {
extractReasoningContentFromMessage,
findToolCallResult,
@@ -46,14 +43,10 @@ export function MessageGroup({
className,
messages,
isLoading = false,
tokenDebugSteps = [],
showTokenDebugSummaries = false,
}: {
className?: string;
messages: Message[];
isLoading?: boolean;
tokenDebugSteps?: TokenDebugStep[];
showTokenDebugSummaries?: boolean;
}) {
const { t } = useI18n();
const [showAbove, setShowAbove] = useState(
@@ -63,28 +56,6 @@ export function MessageGroup({
env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true",
);
const steps = useMemo(() => convertToSteps(messages), [messages]);
const debugStepByMessageId = useMemo(
() =>
new Map(
tokenDebugSteps.map(
(step) => [step.messageId || step.id, step] as const,
),
),
[tokenDebugSteps],
);
const toolCallCountByMessageId = useMemo(() => {
const counts = new Map<string, number>();
for (const step of steps) {
if (step.type !== "toolCall" || !step.messageId) {
continue;
}
counts.set(step.messageId, (counts.get(step.messageId) ?? 0) + 1);
}
return counts;
}, [steps]);
const lastToolCallStep = useMemo(() => {
const filteredSteps = steps.filter((step) => step.type === "toolCall");
return filteredSteps[filteredSteps.length - 1];
@@ -106,125 +77,6 @@ export function MessageGroup({
}
}, [lastToolCallStep, steps]);
const rehypePlugins = useRehypeSplitWordsIntoSpans(isLoading);
const firstEligibleDebugSummaryStepIndexByMessageId = useMemo(() => {
const firstIndices = new Map<string, number>();
if (!showTokenDebugSummaries) {
return firstIndices;
}
for (const [index, step] of steps.entries()) {
const messageId = step.messageId;
if (!messageId || firstIndices.has(messageId)) {
continue;
}
const debugStep = debugStepByMessageId.get(messageId);
if (!debugStep) {
continue;
}
const toolCallCount = toolCallCountByMessageId.get(messageId) ?? 0;
if (!debugStep.sharedAttribution && toolCallCount > 0) {
continue;
}
if (
!debugStep.sharedAttribution &&
toolCallCount === 0 &&
debugStep.label === t.common.thinking &&
debugStep.secondaryLabels.length === 0
) {
continue;
}
firstIndices.set(messageId, index);
}
return firstIndices;
}, [
debugStepByMessageId,
showTokenDebugSummaries,
steps,
t.common.thinking,
toolCallCountByMessageId,
]);
const renderDebugSummary = (
messageId: string | undefined,
stepIndex: number,
) => {
if (!showTokenDebugSummaries || !messageId) {
return null;
}
const debugStep = debugStepByMessageId.get(messageId);
if (!debugStep) {
return null;
}
if (
firstEligibleDebugSummaryStepIndexByMessageId.get(messageId) !== stepIndex
) {
return null;
}
return (
<ChainOfThoughtStep
key={`token-debug-${messageId}`}
icon={CoinsIcon}
label={
<DebugStepLabel
label={debugStep.label}
token={formatDebugToken(debugStep, t)}
/>
}
description={
debugStep.sharedAttribution
? t.tokenUsage.sharedAttribution
: undefined
}
>
{debugStep.secondaryLabels.length > 0 && (
<ChainOfThoughtSearchResults>
{debugStep.secondaryLabels.map((label, index) => (
<ChainOfThoughtSearchResult
key={`${debugStep.id}-${index}-${label}`}
>
{label}
</ChainOfThoughtSearchResult>
))}
</ChainOfThoughtSearchResults>
)}
</ChainOfThoughtStep>
);
};
const renderToolCall = (
step: CoTToolCallStep,
options?: { isLast?: boolean },
) => {
const debugStep =
showTokenDebugSummaries && step.messageId
? debugStepByMessageId.get(step.messageId)
: undefined;
return (
<ToolCall
key={step.id}
{...step}
isLast={options?.isLast}
isLoading={isLoading}
tokenDebugStep={
debugStep && !debugStep.sharedAttribution ? debugStep : undefined
}
/>
);
};
const lastReasoningDebugStep =
showTokenDebugSummaries && lastReasoningStep?.messageId
? debugStepByMessageId.get(lastReasoningStep.messageId)
: undefined;
return (
<ChainOfThought
className={cn("w-full gap-2 rounded-lg border p-0.5", className)}
@@ -259,46 +111,36 @@ export function MessageGroup({
{lastToolCallStep && (
<ChainOfThoughtContent className="px-4 pb-2">
{showAbove &&
aboveLastToolCallSteps.flatMap((step) => {
const stepIndex = steps.indexOf(step);
if (step.type === "reasoning") {
return [
renderDebugSummary(step.messageId, stepIndex),
<ChainOfThoughtStep
key={step.id}
label={
<MarkdownContent
content={step.reasoning ?? ""}
isLoading={isLoading}
rehypePlugins={rehypePlugins}
/>
}
></ChainOfThoughtStep>,
];
}
return [
renderDebugSummary(step.messageId, stepIndex),
renderToolCall(step),
];
})}
{renderDebugSummary(
lastToolCallStep.messageId,
steps.indexOf(lastToolCallStep),
)}
aboveLastToolCallSteps.map((step) =>
step.type === "reasoning" ? (
<ChainOfThoughtStep
key={step.id}
label={
<MarkdownContent
content={step.reasoning ?? ""}
isLoading={isLoading}
rehypePlugins={rehypePlugins}
/>
}
></ChainOfThoughtStep>
) : (
<ToolCall key={step.id} {...step} isLoading={isLoading} />
),
)}
{lastToolCallStep && (
<FlipDisplay uniqueKey={lastToolCallStep.id ?? ""}>
{renderToolCall(lastToolCallStep, { isLast: true })}
<ToolCall
key={lastToolCallStep.id}
{...lastToolCallStep}
isLast={true}
isLoading={isLoading}
/>
</FlipDisplay>
)}
</ChainOfThoughtContent>
)}
{lastReasoningStep && (
<>
{renderDebugSummary(
lastReasoningStep.messageId,
steps.indexOf(lastReasoningStep),
)}
<Button
key={lastReasoningStep.id}
className="w-full items-start justify-start text-left"
@@ -308,22 +150,7 @@ export function MessageGroup({
<div className="flex w-full items-center justify-between">
<ChainOfThoughtStep
className="font-normal"
label={
<DebugStepLabel
label={t.common.thinking}
token={shouldInlineThinkingToken({
debugStep: lastReasoningDebugStep,
toolCallCount: lastReasoningStep.messageId
? (toolCallCountByMessageId.get(
lastReasoningStep.messageId,
) ?? 0)
: 0,
enabled: showTokenDebugSummaries,
thinkingLabel: t.common.thinking,
t,
})}
/>
}
label={t.common.thinking}
icon={LightbulbIcon}
></ChainOfThoughtStep>
<div>
@@ -356,60 +183,6 @@ export function MessageGroup({
);
}
/**
 * Builds the short token text shown next to a debug step.
 *
 * @param debugStep - Step whose `usage` may or may not be present.
 * @param t - Translation table from `useI18n`.
 * @returns The formatted total-token count followed by the token-usage
 *   label, or the "unavailable" short text when the step has no usage.
 */
function formatDebugToken(
  debugStep: TokenDebugStep,
  t: ReturnType<typeof useI18n>["t"],
) {
  if (!debugStep.usage) {
    return t.tokenUsage.unavailableShort;
  }
  return `${formatTokenCount(debugStep.usage.totalTokens)} ${t.tokenUsage.label}`;
}
/**
 * Decides whether the "thinking" row shows its token text inline.
 *
 * @returns The formatted token text when the feature is enabled, a debug
 *   step exists, it is not shared-attributed, its message has no tool
 *   calls, and its label equals the thinking label; otherwise `null`.
 */
function shouldInlineThinkingToken({
  debugStep,
  toolCallCount,
  enabled,
  thinkingLabel,
  t,
}: {
  debugStep?: TokenDebugStep;
  toolCallCount: number;
  enabled: boolean;
  thinkingLabel: string;
  t: ReturnType<typeof useI18n>["t"];
}) {
  // Feature switched off, or nothing recorded for this message.
  if (!enabled || !debugStep) {
    return null;
  }
  // Shared attribution or any tool call on the message rules out an inline token.
  if (debugStep.sharedAttribution || toolCallCount > 0) {
    return null;
  }
  // Only steps labelled as plain "thinking" get an inline token.
  if (debugStep.label !== thinkingLabel) {
    return null;
  }
  return formatDebugToken(debugStep, t);
}
/**
 * Renders a step label with an optional token text aligned to its right.
 *
 * @param label - Main label content.
 * @param token - Token text; nothing is rendered for it when falsy.
 */
function DebugStepLabel({
  label,
  token,
}: {
  label: React.ReactNode;
  token?: string | null;
}) {
  const tokenBadge = token ? (
    <div className="text-muted-foreground shrink-0 font-mono text-[11px]">
      {token}
    </div>
  ) : null;

  return (
    <div className="flex items-center justify-between gap-3">
      <div className="min-w-0 flex-1">{label}</div>
      {tokenBadge}
    </div>
  );
}
function ToolCall({
id,
messageId,
@@ -418,7 +191,6 @@ function ToolCall({
result,
isLast = false,
isLoading = false,
tokenDebugStep,
}: {
id?: string;
messageId?: string;
@@ -427,20 +199,10 @@ function ToolCall({
result?: string | Record<string, unknown>;
isLast?: boolean;
isLoading?: boolean;
tokenDebugStep?: TokenDebugStep;
}) {
const { t } = useI18n();
const { setOpen, autoOpen, autoSelect, selectedArtifact, select } =
useArtifacts();
const tokenLabel = tokenDebugStep
? formatDebugToken(tokenDebugStep, t)
: null;
const resolveLabel = (fallback: React.ReactNode) =>
tokenDebugStep ? (
<DebugStepLabel label={tokenDebugStep.label} token={tokenLabel} />
) : (
fallback
);
if (name === "web_search") {
let label: React.ReactNode = t.toolCalls.searchForRelatedInfo;
@@ -448,11 +210,7 @@ function ToolCall({
label = t.toolCalls.searchOnWebFor(args.query);
}
return (
<ChainOfThoughtStep
key={id}
label={resolveLabel(label)}
icon={SearchIcon}
>
<ChainOfThoughtStep key={id} label={label} icon={SearchIcon}>
{Array.isArray(result) && (
<ChainOfThoughtSearchResults>
{result.map((item) => (
@@ -482,11 +240,7 @@ function ToolCall({
}
)?.results;
return (
<ChainOfThoughtStep
key={id}
label={resolveLabel(label)}
icon={SearchIcon}
>
<ChainOfThoughtStep key={id} label={label} icon={SearchIcon}>
{Array.isArray(results) && (
<ChainOfThoughtSearchResults>
{Array.isArray(results) &&
@@ -526,7 +280,7 @@ function ToolCall({
return (
<ChainOfThoughtStep
key={id}
label={resolveLabel(t.toolCalls.viewWebPage)}
label={t.toolCalls.viewWebPage}
icon={GlobeIcon}
>
<ChainOfThoughtSearchResult>
@@ -551,11 +305,7 @@ function ToolCall({
}
const path: string | undefined = (args as { path: string })?.path;
return (
<ChainOfThoughtStep
key={id}
label={resolveLabel(description)}
icon={FolderOpenIcon}
>
<ChainOfThoughtStep key={id} label={description} icon={FolderOpenIcon}>
{path && (
<ChainOfThoughtSearchResult className="cursor-pointer">
{path}
@@ -571,11 +321,7 @@ function ToolCall({
}
const { path } = args as { path: string; content: string };
return (
<ChainOfThoughtStep
key={id}
label={resolveLabel(description)}
icon={BookOpenTextIcon}
>
<ChainOfThoughtStep key={id} label={description} icon={BookOpenTextIcon}>
{path && (
<ChainOfThoughtSearchResult className="cursor-pointer">
{path}
@@ -607,7 +353,7 @@ function ToolCall({
<ChainOfThoughtStep
key={id}
className="cursor-pointer"
label={resolveLabel(description)}
label={description}
icon={NotebookPenIcon}
onClick={() => {
select(
@@ -629,19 +375,13 @@ function ToolCall({
const description: string | undefined = (args as { description: string })
?.description;
if (!description) {
return (
<ChainOfThoughtStep
key={id}
label={resolveLabel(t.toolCalls.executeCommand)}
icon={SquareTerminalIcon}
/>
);
return t.toolCalls.executeCommand;
}
const command: string | undefined = (args as { command: string })?.command;
return (
<ChainOfThoughtStep
key={id}
label={resolveLabel(description)}
label={description}
icon={SquareTerminalIcon}
>
{command && (
@@ -658,7 +398,7 @@ function ToolCall({
return (
<ChainOfThoughtStep
key={id}
label={resolveLabel(t.toolCalls.needYourHelp)}
label={t.toolCalls.needYourHelp}
icon={MessageCircleQuestionMarkIcon}
></ChainOfThoughtStep>
);
@@ -666,7 +406,7 @@ function ToolCall({
return (
<ChainOfThoughtStep
key={id}
label={resolveLabel(t.toolCalls.writeTodos)}
label={t.toolCalls.writeTodos}
icon={ListTodoIcon}
></ChainOfThoughtStep>
);
@@ -676,7 +416,7 @@ function ToolCall({
return (
<ChainOfThoughtStep
key={id}
label={resolveLabel(description ?? t.toolCalls.useTool(name))}
label={description ?? t.toolCalls.useTool(name)}
icon={WrenchIcon}
></ChainOfThoughtStep>
);
@@ -50,6 +50,7 @@ import { cn } from "@/lib/utils";
import { CopyButton } from "../copy-button";
import { MarkdownContent } from "./markdown-content";
import { MessageTokenUsage } from "./message-token-usage";
function FeedbackButtons({
threadId,
@@ -120,20 +121,20 @@ function FeedbackButtons({
export function MessageListItem({
className,
threadId,
message,
isLoading,
tokenUsageEnabled = false,
feedback,
runId,
threadId,
showCopyButton = true,
}: {
className?: string;
message: Message;
isLoading?: boolean;
threadId: string;
tokenUsageEnabled?: boolean;
feedback?: FeedbackData | null;
runId?: string;
showCopyButton?: boolean;
}) {
const isHuman = message.type === "human";
return (
@@ -146,17 +147,16 @@ export function MessageListItem({
message={message}
isLoading={isLoading}
threadId={threadId}
tokenUsageEnabled={tokenUsageEnabled}
/>
{!isLoading && showCopyButton && (
{!isLoading && (
<MessageToolbar
className={cn(
isHuman
? "absolute right-0 -bottom-9 left-0 justify-end"
: "absolute right-0 bottom-0 left-0",
"z-20 opacity-0 transition-opacity delay-200 duration-300 group-hover/conversation-message:opacity-100",
isHuman ? "-bottom-9 justify-end" : "-bottom-8",
"absolute right-0 left-0 z-20",
)}
>
<div className="pointer-events-auto flex gap-1">
<div className="flex gap-1">
<CopyButton
clipboardData={
extractContentFromMessage(message) ??
@@ -213,11 +213,13 @@ function MessageContent_({
message,
isLoading = false,
threadId,
tokenUsageEnabled = false,
}: {
className?: string;
message: Message;
isLoading?: boolean;
threadId: string;
tokenUsageEnabled?: boolean;
}) {
const rehypePlugins = useRehypeSplitWordsIntoSpans(isLoading);
const isHuman = message.type === "human";
@@ -295,6 +297,11 @@ function MessageContent_({
<ReasoningTrigger />
<ReasoningContent>{reasoningContent}</ReasoningContent>
</Reasoning>
<MessageTokenUsage
enabled={tokenUsageEnabled}
isLoading={isLoading}
message={message}
/>
</AIElementMessageContent>
);
}
@@ -332,6 +339,11 @@ function MessageContent_({
className="my-3"
components={components}
/>
<MessageTokenUsage
enabled={tokenUsageEnabled}
isLoading={isLoading}
message={message}
/>
</AIElementMessageContent>
);
}
@@ -1,7 +1,6 @@
import type { Message } from "@langchain/langgraph-sdk";
import type { BaseStream } from "@langchain/langgraph-sdk/react";
import { ChevronUpIcon, Loader2Icon } from "lucide-react";
import { useCallback, useEffect, useMemo, useRef } from "react";
import { useCallback, useEffect, useRef } from "react";
import {
Conversation,
@@ -9,20 +8,15 @@ import {
} from "@/components/ai-elements/conversation";
import { Button } from "@/components/ui/button";
import { useI18n } from "@/core/i18n/hooks";
import {
buildTokenDebugSteps,
type TokenUsageInlineMode,
} from "@/core/messages/usage-model";
import {
extractContentFromMessage,
extractPresentFilesFromMessage,
extractReasoningContentFromMessage,
extractTextFromMessage,
getAssistantTurnUsageMessages,
getMessageGroups,
groupMessages,
hasContent,
hasPresentFiles,
hasReasoning,
hasToolCalls,
} from "@/core/messages/utils";
import { useRehypeSplitWordsIntoSpans } from "@/core/rehype";
import type { Subtask } from "@/core/tasks";
@@ -31,16 +25,12 @@ import type { AgentThreadState } from "@/core/threads";
import { cn } from "@/lib/utils";
import { ArtifactFileList } from "../artifacts/artifact-file-list";
import { CopyButton } from "../copy-button";
import { StreamingIndicator } from "../streaming-indicator";
import { MarkdownContent } from "./markdown-content";
import { MessageGroup } from "./message-group";
import { MessageListItem } from "./message-list-item";
import {
MessageTokenUsageDebugList,
MessageTokenUsageList,
} from "./message-token-usage";
import { MessageTokenUsageList } from "./message-token-usage";
import { MessageListSkeleton } from "./skeleton";
import { SubtaskCard } from "./subtask-card";
@@ -159,7 +149,7 @@ export function MessageList({
threadId,
thread,
paddingBottom = MESSAGE_LIST_DEFAULT_PADDING_BOTTOM,
tokenUsageInlineMode = "off",
tokenUsageEnabled = false,
hasMoreHistory,
loadMoreHistory,
isHistoryLoading,
@@ -168,7 +158,7 @@ export function MessageList({
threadId: string;
thread: BaseStream<AgentThreadState>;
paddingBottom?: number;
tokenUsageInlineMode?: TokenUsageInlineMode;
tokenUsageEnabled?: boolean;
hasMoreHistory?: boolean;
loadMoreHistory?: () => void;
isHistoryLoading?: boolean;
@@ -177,85 +167,10 @@ export function MessageList({
const rehypePlugins = useRehypeSplitWordsIntoSpans(thread.isLoading);
const updateSubtask = useUpdateSubtask();
const messages = thread.messages;
const groupedMessages = getMessageGroups(messages);
const turnUsageMessagesByGroupIndex =
getAssistantTurnUsageMessages(groupedMessages);
const tokenDebugSteps = useMemo(
() => buildTokenDebugSteps(messages, t),
[messages, t],
);
const renderAssistantCopyButton = useCallback((messages: Message[]) => {
const clipboardData = [...messages]
.reverse()
.filter((message) => message.type === "ai")
.map((message) => {
const content = extractContentFromMessage(message);
return content ?? extractReasoningContentFromMessage(message) ?? "";
})
.find((content) => content.length > 0);
if (!clipboardData) {
return null;
}
return (
<div className="mt-2 flex justify-start opacity-0 transition-opacity delay-200 duration-300 group-hover/assistant-turn:opacity-100">
<CopyButton clipboardData={clipboardData} />
</div>
);
}, []);
const renderTokenUsage = useCallback(
({
messages,
turnUsageMessages,
inlineDebug = true,
debugMessageIds,
}: {
messages: Message[];
turnUsageMessages?: Message[] | null;
inlineDebug?: boolean;
debugMessageIds?: string[];
}) => {
if (tokenUsageInlineMode === "per_turn") {
return (
<MessageTokenUsageList
enabled={true}
isLoading={thread.isLoading}
messages={turnUsageMessages ?? []}
/>
);
}
if (tokenUsageInlineMode === "step_debug" && inlineDebug) {
const messageIds = new Set(
debugMessageIds ??
messages
.filter((message) => message.type === "ai")
.map((message) => message.id)
.filter((id): id is string => typeof id === "string"),
);
return (
<MessageTokenUsageDebugList
enabled={true}
isLoading={thread.isLoading}
steps={tokenDebugSteps.filter((step) =>
messageIds.has(step.messageId),
)}
/>
);
}
return null;
},
[thread.isLoading, tokenDebugSteps, tokenUsageInlineMode],
);
if (thread.isThreadLoading && messages.length === 0) {
return <MessageListSkeleton />;
}
return (
<Conversation
className={cn("flex size-full flex-col justify-center", className)}
@@ -266,37 +181,19 @@ export function MessageList({
hasMore={hasMoreHistory}
loadMore={loadMoreHistory}
/>
{groupedMessages.map((group, groupIndex) => {
const turnUsageMessages = turnUsageMessagesByGroupIndex[groupIndex];
{groupMessages(messages, (group) => {
if (group.type === "human" || group.type === "assistant") {
return (
<div
key={group.id}
className={cn(
"w-full",
group.type === "assistant" && "group/assistant-turn",
)}
>
{group.messages.map((msg) => {
return (
<MessageListItem
key={`${group.id}/${msg.id}`}
message={msg}
isLoading={thread.isLoading}
threadId={threadId}
showCopyButton={group.type !== "assistant"}
/>
);
})}
{renderTokenUsage({
messages: group.messages,
turnUsageMessages,
})}
{group.type === "assistant" &&
renderAssistantCopyButton(group.messages)}
</div>
);
return group.messages.map((msg) => {
return (
<MessageListItem
key={`${group.id}/${msg.id}`}
threadId={threadId}
message={msg}
isLoading={thread.isLoading}
tokenUsageEnabled={tokenUsageEnabled}
/>
);
});
} else if (group.type === "assistant:clarification") {
const message = group.messages[0];
if (message && hasContent(message)) {
@@ -307,10 +204,11 @@ export function MessageList({
isLoading={thread.isLoading}
rehypePlugins={rehypePlugins}
/>
{renderTokenUsage({
messages: group.messages,
turnUsageMessages,
})}
<MessageTokenUsageList
enabled={tokenUsageEnabled}
isLoading={thread.isLoading}
messages={group.messages}
/>
</div>
);
}
@@ -334,10 +232,11 @@ export function MessageList({
/>
)}
<ArtifactFileList files={files} threadId={threadId} />
{renderTokenUsage({
messages: group.messages,
turnUsageMessages,
})}
<MessageTokenUsageList
enabled={tokenUsageEnabled}
isLoading={thread.isLoading}
messages={group.messages}
/>
</div>
);
} else if (group.type === "assistant:subagent") {
@@ -390,19 +289,7 @@ export function MessageList({
}
}
}
const results: React.ReactNode[] = [];
const subagentDebugMessageIds: string[] = [];
if (tasks.size > 0) {
results.push(
<div
key="subtask-count"
className="text-muted-foreground pt-2 text-sm font-normal"
>
{t.subtasks.executing(tasks.size)}
</div>,
);
}
for (const message of group.messages.filter(
(message) => message.type === "ai",
)) {
@@ -412,17 +299,17 @@ export function MessageList({
key={"thinking-group-" + message.id}
messages={[message]}
isLoading={thread.isLoading}
tokenDebugSteps={tokenDebugSteps.filter(
(step) => step.messageId === message.id,
)}
showTokenDebugSummaries={
tokenUsageInlineMode === "step_debug"
}
/>,
);
} else if (message.id) {
subagentDebugMessageIds.push(message.id);
}
// Header line announcing how many subtasks are being executed.
// `font-normal` (previously misspelled `font-norma`, which is not a Tailwind
// utility and therefore applied no font weight at all).
results.push(
  <div
    key="subtask-count"
    className="text-muted-foreground pt-2 text-sm font-normal"
  >
    {t.subtasks.executing(tasks.size)}
  </div>,
);
const taskIds = message.tool_calls
?.filter((toolCall) => toolCall.name === "task")
.map((toolCall) => toolCall.id);
@@ -442,31 +329,30 @@ export function MessageList({
className="relative z-1 flex flex-col gap-2"
>
{results}
{renderTokenUsage({
messages: group.messages,
turnUsageMessages,
debugMessageIds: subagentDebugMessageIds,
})}
<MessageTokenUsageList
enabled={tokenUsageEnabled}
isLoading={thread.isLoading}
messages={group.messages}
/>
</div>
);
}
const tokenUsageMessages = group.messages.filter(
(message) =>
message.type === "ai" &&
(hasToolCalls(message) ? true : !hasContent(message)),
);
return (
<div key={"group-" + group.id} className="w-full">
<MessageGroup
messages={group.messages}
isLoading={thread.isLoading}
tokenDebugSteps={tokenDebugSteps.filter((step) =>
group.messages.some(
(message) => message.id === step.messageId,
),
)}
showTokenDebugSummaries={tokenUsageInlineMode === "step_debug"}
/>
{renderTokenUsage({
messages: group.messages,
turnUsageMessages,
inlineDebug: false,
})}
<MessageTokenUsageList
enabled={tokenUsageEnabled}
isLoading={thread.isLoading}
messages={tokenUsageMessages}
/>
</div>
);
})}
@@ -1,27 +1,29 @@
import type { Message } from "@langchain/langgraph-sdk";
import { CoinsIcon } from "lucide-react";
import { Badge } from "@/components/ui/badge";
import { useI18n } from "@/core/i18n/hooks";
import { accumulateUsage, formatTokenCount } from "@/core/messages/usage";
import type { TokenDebugStep } from "@/core/messages/usage-model";
import { formatTokenCount, getUsageMetadata } from "@/core/messages/usage";
import { cn } from "@/lib/utils";
function TokenUsageSummary({
export function MessageTokenUsage({
className,
inputTokens,
outputTokens,
totalTokens,
unavailable = false,
enabled = false,
isLoading = false,
message,
}: {
className?: string;
inputTokens?: number;
outputTokens?: number;
totalTokens?: number;
unavailable?: boolean;
enabled?: boolean;
isLoading?: boolean;
message: Message;
}) {
const { t } = useI18n();
if (!enabled || isLoading || message.type !== "ai") {
return null;
}
const usage = getUsageMetadata(message);
return (
<div
className={cn(
@@ -33,16 +35,16 @@ function TokenUsageSummary({
<CoinsIcon className="size-3" />
{t.tokenUsage.label}
</span>
{!unavailable ? (
{usage ? (
<>
<span>
{t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)}
{t.tokenUsage.input}: {formatTokenCount(usage.inputTokens)}
</span>
<span>
{t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)}
{t.tokenUsage.output}: {formatTokenCount(usage.outputTokens)}
</span>
<span className="font-medium">
{t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)}
{t.tokenUsage.total}: {formatTokenCount(usage.totalTokens)}
</span>
</>
) : (
@@ -73,93 +75,17 @@ export function MessageTokenUsageList({
return null;
}
const usage = accumulateUsage(aiMessages);
return (
<TokenUsageSummary
className={className}
inputTokens={usage?.inputTokens}
outputTokens={usage?.outputTokens}
totalTokens={usage?.totalTokens}
unavailable={!usage}
/>
);
}
export function MessageTokenUsageDebugList({
className,
enabled = false,
isLoading = false,
steps,
}: {
className?: string;
enabled?: boolean;
isLoading?: boolean;
steps: TokenDebugStep[];
}) {
const { t } = useI18n();
if (!enabled || isLoading) {
return null;
}
if (steps.length === 0) {
return null;
}
return (
<div className={cn("border-border/60 mt-1 border-t pt-2", className)}>
<div className="space-y-2">
{steps.map((step) => (
<div
key={step.id}
className="bg-muted/30 border-border/50 flex items-start justify-between gap-3 rounded-md border px-3 py-2"
>
<div className="min-w-0 flex-1 space-y-1">
<div className="text-foreground flex items-center gap-2 text-xs font-medium">
<CoinsIcon className="text-muted-foreground size-3" />
<span className="truncate">{step.label}</span>
</div>
{step.secondaryLabels.length > 0 && (
<div className="flex flex-wrap gap-1.5">
{step.secondaryLabels.map((label, index) => (
<Badge
key={`${step.id}-${index}-${label}`}
className="px-1.5 py-0 text-[10px] font-normal"
variant="secondary"
>
{label}
</Badge>
))}
</div>
)}
{step.sharedAttribution && (
<div className="text-muted-foreground text-[11px]">
{t.tokenUsage.sharedAttribution}
</div>
)}
<div className="text-muted-foreground text-[11px]">
{step.usage ? (
<>
{t.tokenUsage.input}:{" "}
{formatTokenCount(step.usage.inputTokens)}
{" · "}
{t.tokenUsage.output}:{" "}
{formatTokenCount(step.usage.outputTokens)}
</>
) : (
t.tokenUsage.unavailableShort
)}
</div>
</div>
<Badge className="shrink-0 font-mono" variant="outline">
{step.usage
? `${formatTokenCount(step.usage.totalTokens)} ${t.tokenUsage.label}`
: t.tokenUsage.unavailableShort}
</Badge>
</div>
))}
</div>
</div>
<>
{aiMessages.map((message, index) => (
<MessageTokenUsage
className={className}
enabled={enabled}
isLoading={isLoading}
key={message.id ?? index}
message={message}
/>
))}
</>
);
}
@@ -8,13 +8,11 @@ import { Input } from "@/components/ui/input";
import { fetch, getCsrfHeaders } from "@/core/api/fetcher";
import { useAuth } from "@/core/auth/AuthProvider";
import { parseAuthError } from "@/core/auth/types";
import { useI18n } from "@/core/i18n/hooks";
import { SettingsSection } from "./settings-section";
export function AccountSettingsPage() {
const { user, logout } = useAuth();
const { t } = useI18n();
const [currentPassword, setCurrentPassword] = useState("");
const [newPassword, setNewPassword] = useState("");
const [confirmPassword, setConfirmPassword] = useState("");
@@ -28,11 +26,11 @@ export function AccountSettingsPage() {
setMessage("");
if (newPassword !== confirmPassword) {
setError(t.settings.account.passwordMismatch);
setError("New passwords do not match");
return;
}
if (newPassword.length < 8) {
setError(t.settings.account.passwordTooShort);
setError("Password must be at least 8 characters");
return;
}
@@ -57,12 +55,12 @@ export function AccountSettingsPage() {
return;
}
setMessage(t.settings.account.passwordChangedSuccess);
setMessage("Password changed successfully");
setCurrentPassword("");
setNewPassword("");
setConfirmPassword("");
} catch {
setError(t.settings.account.networkError);
setError("Network error. Please try again.");
} finally {
setLoading(false);
}
@@ -70,16 +68,12 @@ export function AccountSettingsPage() {
return (
<div className="space-y-8">
<SettingsSection title={t.settings.account.profileTitle}>
<SettingsSection title="Profile">
<div className="space-y-2">
<div className="grid grid-cols-[max-content_max-content] items-center gap-4">
<span className="text-muted-foreground text-sm">
{t.settings.account.email}
</span>
<span className="text-muted-foreground text-sm">Email</span>
<span className="text-sm font-medium">{user?.email ?? "—"}</span>
<span className="text-muted-foreground text-sm">
{t.settings.account.role}
</span>
<span className="text-muted-foreground text-sm">Role</span>
<span className="text-sm font-medium capitalize">
{user?.system_role ?? "—"}
</span>
@@ -88,20 +82,20 @@ export function AccountSettingsPage() {
</SettingsSection>
<SettingsSection
title={t.settings.account.changePasswordTitle}
description={t.settings.account.changePasswordDescription}
title="Change Password"
description="Update your account password."
>
<form onSubmit={handleChangePassword} className="max-w-sm space-y-3">
<Input
type="password"
placeholder={t.settings.account.currentPassword}
placeholder="Current password"
value={currentPassword}
onChange={(e) => setCurrentPassword(e.target.value)}
required
/>
<Input
type="password"
placeholder={t.settings.account.newPassword}
placeholder="New password"
value={newPassword}
onChange={(e) => setNewPassword(e.target.value)}
required
@@ -109,7 +103,7 @@ export function AccountSettingsPage() {
/>
<Input
type="password"
placeholder={t.settings.account.confirmNewPassword}
placeholder="Confirm new password"
value={confirmPassword}
onChange={(e) => setConfirmPassword(e.target.value)}
required
@@ -118,9 +112,7 @@ export function AccountSettingsPage() {
{error && <p className="text-sm text-red-500">{error}</p>}
{message && <p className="text-sm text-green-500">{message}</p>}
<Button type="submit" variant="outline" size="sm" disabled={loading}>
{loading
? t.settings.account.updating
: t.settings.account.updatePassword}
{loading ? "Updating..." : "Update Password"}
</Button>
</form>
</SettingsSection>
@@ -133,7 +125,7 @@ export function AccountSettingsPage() {
className="gap-2"
>
<LogOutIcon className="size-4" />
{t.settings.account.signOut}
Sign Out
</Button>
</SettingsSection>
</div>
@@ -1,81 +1,60 @@
"use client";
import type { Message } from "@langchain/langgraph-sdk";
import { ChevronDownIcon, CoinsIcon } from "lucide-react";
import { CoinsIcon } from "lucide-react";
import { useMemo } from "react";
import { Button } from "@/components/ui/button";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuLabel,
DropdownMenuRadioGroup,
DropdownMenuRadioItem,
DropdownMenuSeparator,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
import { useI18n } from "@/core/i18n/hooks";
import { accumulateUsage, formatTokenCount } from "@/core/messages/usage";
import {
getTokenUsageViewPreset,
tokenUsagePreferencesFromPreset,
type TokenUsagePreferences,
type TokenUsageViewPreset,
} from "@/core/messages/usage-model";
import { cn } from "@/lib/utils";
interface TokenUsageIndicatorProps {
messages: Message[];
enabled?: boolean;
preferences: TokenUsagePreferences;
onPreferencesChange: (preferences: TokenUsagePreferences) => void;
className?: string;
}
export function TokenUsageIndicator({
messages,
enabled = false,
preferences,
onPreferencesChange,
className,
}: TokenUsageIndicatorProps) {
const { t } = useI18n();
const usage = useMemo(() => accumulateUsage(messages), [messages]);
const preset = getTokenUsageViewPreset(preferences);
if (!enabled) {
return null;
}
return (
<DropdownMenu>
<DropdownMenuTrigger asChild>
<Button
<Tooltip delayDuration={200}>
<TooltipTrigger asChild>
<button
type="button"
variant="ghost"
className={cn(
"text-muted-foreground bg-background/70 hover:bg-background/90 flex h-auto items-center gap-1.5 rounded-full border px-2 py-1 text-xs font-normal",
"text-muted-foreground bg-background/70 flex cursor-default items-center gap-1.5 rounded-full border px-2 py-1 text-xs",
!usage && "opacity-60",
className,
)}
>
<CoinsIcon size={14} />
<span>{t.tokenUsage.label}</span>
<span className="font-mono">
{preferences.headerTotal
? usage
? formatTokenCount(usage.totalTokens)
: "-"
: t.tokenUsage.presets[presetKeyToTranslationKey(preset)]}
{usage ? formatTokenCount(usage.totalTokens) : "-"}
</span>
<ChevronDownIcon className="size-3" />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent side="bottom" align="end" className="w-80">
<DropdownMenuLabel>{t.tokenUsage.title}</DropdownMenuLabel>
<div className="px-2 py-1 text-xs">
</button>
</TooltipTrigger>
<TooltipContent side="bottom" align="end">
<div className="space-y-1 text-xs">
<div className="font-medium">{t.tokenUsage.title}</div>
{usage ? (
<div className="space-y-1">
<>
<div className="flex justify-between gap-4">
<span>{t.tokenUsage.input}</span>
<span className="font-mono">
@@ -96,53 +75,14 @@ export function TokenUsageIndicator({
</span>
</div>
</div>
</div>
</>
) : (
<div className="text-muted-foreground">
<div className="text-muted-foreground max-w-56">
{t.tokenUsage.unavailable}
</div>
)}
</div>
<DropdownMenuSeparator />
<DropdownMenuLabel>{t.tokenUsage.view}</DropdownMenuLabel>
<DropdownMenuRadioGroup
value={preset}
onValueChange={(value) =>
onPreferencesChange(
tokenUsagePreferencesFromPreset(value as TokenUsageViewPreset),
)
}
>
{(
["off", "summary", "per_turn", "debug"] as TokenUsageViewPreset[]
).map((value) => {
const translationKey = presetKeyToTranslationKey(value);
return (
<DropdownMenuRadioItem key={value} value={value}>
<div className="grid gap-0.5">
<span>{t.tokenUsage.presets[translationKey]}</span>
<span className="text-muted-foreground text-xs">
{t.tokenUsage.presetDescriptions[translationKey]}
</span>
</div>
</DropdownMenuRadioItem>
);
})}
</DropdownMenuRadioGroup>
<DropdownMenuSeparator />
<div className="text-muted-foreground px-2 py-2 text-xs leading-relaxed">
{t.tokenUsage.note}
</div>
</DropdownMenuContent>
</DropdownMenu>
</TooltipContent>
</Tooltip>
);
}
/**
 * Converts a view preset identifier into the property name used by the
 * `t.tokenUsage.presets` / `t.tokenUsage.presetDescriptions` translation
 * tables. Only `"per_turn"` differs (it becomes `"perTurn"`); every other
 * preset is returned unchanged.
 */
function presetKeyToTranslationKey(preset: TokenUsageViewPreset) {
  if (preset === "per_turn") {
    return "perTurn" as const;
  }
  return preset;
}
-17
View File
@@ -15,17 +15,6 @@ export class AgentNameCheckError extends Error {
}
}
/**
 * Error type used when a failed agents request carries a `detail` string that
 * mentions `agents_api.enabled` (see `isAgentsApiDisabledDetail`). Thrown by
 * `createAgent` and `checkAgentName` so callers can tell this case apart from
 * a generic failure.
 */
export class AgentsApiDisabledError extends Error {
constructor(message: string) {
super(message);
// Set an explicit name so the error can be identified by `error.name`.
this.name = "AgentsApiDisabledError";
}
}
/**
 * Reports whether an error `detail` string refers to the `agents_api.enabled`
 * setting.
 *
 * @param detail - The `detail` field parsed from an error response, if any.
 * @returns `true` only when `detail` is a string containing
 *   `"agents_api.enabled"`; `false` for `undefined` or any other text.
 */
function isAgentsApiDisabledDetail(detail: string | undefined): boolean {
  if (typeof detail !== "string") {
    return false;
  }
  return detail.includes("agents_api.enabled");
}
export async function listAgents(): Promise<Agent[]> {
const res = await fetch(`${getBackendBaseURL()}/api/agents`);
if (!res.ok) throw new Error(`Failed to load agents: ${res.statusText}`);
@@ -47,9 +36,6 @@ export async function createAgent(request: CreateAgentRequest): Promise<Agent> {
});
if (!res.ok) {
const err = (await res.json().catch(() => ({}))) as { detail?: string };
if (isAgentsApiDisabledDetail(err.detail)) {
throw new AgentsApiDisabledError(err.detail!);
}
throw new Error(err.detail ?? `Failed to create agent: ${res.statusText}`);
}
return res.json() as Promise<Agent>;
@@ -95,9 +81,6 @@ export async function checkAgentName(
if (!res.ok) {
const err = (await res.json().catch(() => ({}))) as { detail?: string };
if (isAgentsApiDisabledDetail(err.detail)) {
throw new AgentsApiDisabledError(err.detail!);
}
if (BACKEND_UNAVAILABLE_STATUSES.has(res.status)) {
throw new AgentNameCheckError(
"Could not reach the DeerFlow backend.",
-42
View File
@@ -204,8 +204,6 @@ export const enUS: Translations = {
nameStepNetworkError:
"Network request failed — check your network or backend connection",
nameStepCheckError: "Could not verify name availability — please try again",
nameStepApiDisabledError:
"Custom agent management is not enabled on this server. Please contact your administrator.",
// "{name}" is a literal placeholder token in this message.
// NOTE(review): presumably replaced with the agent name by the caller — confirm.
nameStepBootstrapMessage:
  "The new custom agent name is {name}. Let's bootstrap its **SOUL**.",
save: "Save agent",
@@ -306,32 +304,9 @@ export const enUS: Translations = {
input: "Input",
output: "Output",
total: "Total",
view: "Display",
unavailable:
"No token usage yet. Usage appears only after a successful model response when the provider returns usage_metadata.",
unavailableShort: "No usage returned",
note: "Shown from provider-returned usage_metadata. Totals are best-effort conversation totals and may differ from provider billing pages.",
presets: {
off: "Off",
summary: "Summary",
perTurn: "Per turn",
debug: "Debug",
},
presetDescriptions: {
off: "Hide token usage in the header and conversation.",
summary: "Show only the current conversation total in the header.",
perTurn:
"Show the header total and one token summary per assistant turn.",
debug: "Show the header total and step-level token debugging details.",
},
finalAnswer: "Final answer",
stepTotal: "Step total",
sharedAttribution: "Shared across multiple actions in this step",
subagent: (description: string) => `Subagent: ${description}`,
startTodo: (content: string) => `Start To-do: ${content}`,
completeTodo: (content: string) => `Complete To-do: ${content}`,
updateTodo: (content: string) => `Update To-do: ${content}`,
removeTodo: (content: string) => `Remove To-do: ${content}`,
},
// Shortcuts
@@ -478,23 +453,6 @@ export const enUS: Translations = {
notSupported: "Your browser does not support notifications.",
disableNotification: "Disable notification",
},
account: {
profileTitle: "Profile",
email: "Email",
role: "Role",
changePasswordTitle: "Change Password",
changePasswordDescription: "Update your account password.",
currentPassword: "Current password",
newPassword: "New password",
confirmNewPassword: "Confirm new password",
passwordMismatch: "New passwords do not match",
passwordTooShort: "Password must be at least 8 characters",
passwordChangedSuccess: "Password changed successfully",
networkError: "Network error. Please try again.",
updating: "Updating...",
updatePassword: "Update Password",
signOut: "Sign Out",
},
acknowledge: {
emptyTitle: "Acknowledgements",
emptyDescription: "Credits and acknowledgements will show here.",
-40
View File
@@ -141,7 +141,6 @@ export interface Translations {
nameStepAlreadyExistsError: string;
nameStepNetworkError: string;
nameStepCheckError: string;
nameStepApiDisabledError: string;
nameStepBootstrapMessage: string;
save: string;
saving: string;
@@ -236,30 +235,8 @@ export interface Translations {
input: string;
output: string;
total: string;
view: string;
unavailable: string;
unavailableShort: string;
note: string;
presets: {
off: string;
summary: string;
perTurn: string;
debug: string;
};
presetDescriptions: {
off: string;
summary: string;
perTurn: string;
debug: string;
};
finalAnswer: string;
stepTotal: string;
sharedAttribution: string;
subagent: (description: string) => string;
startTodo: (content: string) => string;
completeTodo: (content: string) => string;
updateTodo: (content: string) => string;
removeTodo: (content: string) => string;
};
// Shortcuts
@@ -394,23 +371,6 @@ export interface Translations {
notSupported: string;
disableNotification: string;
};
account: {
profileTitle: string;
email: string;
role: string;
changePasswordTitle: string;
changePasswordDescription: string;
currentPassword: string;
newPassword: string;
confirmNewPassword: string;
passwordMismatch: string;
passwordTooShort: string;
passwordChangedSuccess: string;
networkError: string;
updating: string;
updatePassword: string;
signOut: string;
};
acknowledge: {
emptyTitle: string;
emptyDescription: string;
-41
View File
@@ -192,8 +192,6 @@ export const zhCN: Translations = {
nameStepAlreadyExistsError: "已存在同名智能体",
nameStepNetworkError: "网络请求失败,请检查网络或后端连接",
nameStepCheckError: "无法验证名称可用性,请稍后重试",
nameStepApiDisabledError:
"服务器未开启自定义智能体管理功能,请联系管理员。",
nameStepBootstrapMessage:
"新智能体的名称是 {name},现在开始为它生成 **SOUL**。",
save: "保存智能体",
@@ -292,31 +290,9 @@ export const zhCN: Translations = {
input: "输入",
output: "输出",
total: "总计",
view: "显示方式",
unavailable:
"暂无 Token 用量。只有模型成功返回且供应商提供 usage_metadata 时才会显示。",
unavailableShort: "未返回用量",
note: "基于供应商返回的 usage_metadata 展示。当前总量是 best-effort 的会话参考值,可能与平台账单页不完全一致。",
presets: {
off: "关闭",
summary: "总览",
perTurn: "每轮",
debug: "调试",
},
presetDescriptions: {
off: "隐藏顶部和会话内的 token 展示。",
summary: "只在顶部显示当前对话累计 token。",
perTurn: "显示顶部累计,并为每轮 assistant 回复显示一条汇总 token。",
debug: "显示顶部累计,并展示按步骤归类的 token 调试信息。",
},
finalAnswer: "最终回复",
stepTotal: "步骤总计",
sharedAttribution: "该 token 由此步骤中的多个动作共同消耗",
subagent: (description: string) => `子任务:${description}`,
startTodo: (content: string) => `开始 To-do${content}`,
completeTodo: (content: string) => `完成 To-do${content}`,
updateTodo: (content: string) => `更新 To-do${content}`,
removeTodo: (content: string) => `移除 To-do${content}`,
},
// Shortcuts
@@ -458,23 +434,6 @@ export const zhCN: Translations = {
notSupported: "当前浏览器不支持通知功能。",
disableNotification: "关闭通知",
},
account: {
profileTitle: "个人信息",
email: "邮箱",
role: "角色",
changePasswordTitle: "修改密码",
changePasswordDescription: "更新你的账号密码。",
currentPassword: "当前密码",
newPassword: "新密码",
confirmNewPassword: "确认新密码",
passwordMismatch: "两次输入的新密码不一致",
passwordTooShort: "密码长度至少为 8 个字符",
passwordChangedSuccess: "密码修改成功",
networkError: "网络错误,请重试。",
updating: "更新中...",
updatePassword: "修改密码",
signOut: "退出登录",
},
acknowledge: {
emptyTitle: "致谢",
emptyDescription: "相关的致谢信息会展示在这里。",
-440
View File
@@ -1,440 +0,0 @@
import type { Message } from "@langchain/langgraph-sdk";
import type { Translations } from "@/core/i18n/locales/types";
import { getUsageMetadata, type TokenUsage } from "./usage";
import { hasContent } from "./utils";
/**
 * How token usage is rendered inside the conversation itself:
 * not at all (`"off"`), one summary per assistant turn (`"per_turn"`), or
 * step-level debug details (`"step_debug"`).
 */
export type TokenUsageInlineMode = "off" | "per_turn" | "step_debug";
/**
 * User-facing token usage display settings.
 * `headerTotal` controls whether the header shows the conversation total;
 * `inlineMode` controls the in-conversation rendering.
 */
export interface TokenUsagePreferences {
headerTotal: boolean;
inlineMode: TokenUsageInlineMode;
}
/**
 * Named combinations of `TokenUsagePreferences`. Converted to and from
 * preferences by `tokenUsagePreferencesFromPreset` and
 * `getTokenUsageViewPreset`.
 */
export type TokenUsageViewPreset = "off" | "summary" | "per_turn" | "debug";
/**
 * One row of token debugging output, produced per AI message by
 * `buildTokenDebugSteps`.
 */
export interface TokenDebugStep {
// Message id, or a `token-step-<index>` fallback when the message has no id.
id: string;
// Same value as `id`; used to match a step back to its message.
messageId: string;
// Single action label, or the generic "step total" label when several actions share the step.
label: string;
// The individual action labels when more than one action shares the step; otherwise empty.
secondaryLabels: string[];
// Usage read from the message via `getUsageMetadata`; null when none is available.
usage: TokenUsage | null;
// True when the usage is attributed jointly to multiple actions.
sharedAttribution: boolean;
}
/**
 * A single action described inside a token usage attribution payload,
 * discriminated by `kind`.
 * NOTE(review): the snake_case field names presumably mirror the backend
 * payload — confirm against the token usage middleware.
 */
type TokenUsageAttributionAction =
| {
// To-do list changes; `content` is the to-do text when provided.
kind: "todo_start" | "todo_complete" | "todo_update" | "todo_remove";
content?: string;
tool_call_id?: string;
}
| {
// A subagent dispatch.
kind: "subagent";
description?: string | null;
subagent_type?: string | null;
tool_call_id?: string;
}
| {
// A search action.
kind: "search";
query?: string | null;
tool_name?: string | null;
tool_call_id?: string;
}
| {
// Actions that carry no extra detail beyond their kind.
kind: "present_files" | "clarification";
tool_call_id?: string;
}
| {
// Any other tool invocation.
kind: "tool";
tool_name?: string | null;
description?: string | null;
tool_call_id?: string;
};
/**
 * Attribution metadata attached to an AI message, describing what the
 * message's token usage was spent on. All fields are optional.
 */
interface TokenUsageAttribution {
version?: number;
// Overall category of the step.
kind?:
| "thinking"
| "final_answer"
| "tool_batch"
| "todo_update"
| "subagent_dispatch";
// When set, overrides the "more than one action label" heuristic in `buildTokenDebugSteps`.
shared_attribution?: boolean;
tool_call_ids?: string[];
actions?: TokenUsageAttributionAction[];
}
// Precise write_todos labels come from the backend attribution payload.
// The frontend fallback intentionally stays generic so we do not duplicate
// backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py
//::_build_todo_actions and risk the two diffing algorithms drifting apart.
/**
 * Derives the named view preset that corresponds to a set of token usage
 * preferences.
 *
 * @param preferences - Current header/inline display settings.
 * @returns `"off"` or `"summary"` when inline display is off (depending on
 *   whether the header total is shown), `"debug"` for step-level inline
 *   display, and `"per_turn"` in every other case.
 */
export function getTokenUsageViewPreset(
  preferences: TokenUsagePreferences,
): TokenUsageViewPreset {
  const { headerTotal, inlineMode } = preferences;

  switch (inlineMode) {
    case "off":
      // With inline display off, the header flag alone picks the preset.
      return headerTotal ? "summary" : "off";
    case "step_debug":
      return "debug";
    default:
      return "per_turn";
  }
}
/**
 * Expands a named view preset into concrete token usage preferences.
 *
 * @param preset - The preset to expand.
 * @returns A new preferences object. `"off"` disables both the header total
 *   and inline display; all other presets enable the header total, with
 *   inline display off for `"summary"`, step-level for `"debug"`, and
 *   per-turn for `"per_turn"` or any unrecognised value.
 */
export function tokenUsagePreferencesFromPreset(
  preset: TokenUsageViewPreset,
): TokenUsagePreferences {
  if (preset === "off") {
    return { headerTotal: false, inlineMode: "off" };
  }

  // Every remaining preset shows the header total; only the inline mode varies.
  let inlineMode: TokenUsagePreferences["inlineMode"] = "per_turn";
  if (preset === "summary") {
    inlineMode = "off";
  } else if (preset === "debug") {
    inlineMode = "step_debug";
  }

  return { headerTotal: true, inlineMode };
}
/**
 * Builds one debug step per AI message, pairing the message's usage metadata
 * with human-readable labels for what the message did.
 *
 * Labels are taken from the attribution payload on the message when it
 * yields any; otherwise they are derived from the message's tool calls,
 * and finally from whether the message has content.
 *
 * @param messages - Thread messages; entries whose type is not "ai" are skipped.
 * @param t - Translations used to render the labels.
 * @returns The steps, in message order.
 */
export function buildTokenDebugSteps(
  messages: Message[],
  t: Translations,
): TokenDebugStep[] {
  const result: TokenDebugStep[] = [];

  for (const [position, message] of messages.entries()) {
    if (message.type !== "ai") {
      continue;
    }

    const usage = getUsageMetadata(message);
    const attribution = getTokenUsageAttribution(message);
    const stepId = message.id ?? `token-step-${position}`;

    if (attribution) {
      let attributed = buildActionLabelsFromAttribution(attribution, t);

      // No usable actions: the attribution kind may still name the step.
      if (attributed.length === 0) {
        if (attribution.kind === "final_answer") {
          attributed = [t.tokenUsage.finalAnswer];
        } else if (attribution.kind === "thinking") {
          attributed = [t.common.thinking];
        }
      }

      if (attributed.length > 0) {
        const hasSeveral = attributed.length > 1;
        // An explicit flag on the payload wins over the label count.
        const sharedAttribution = attribution.shared_attribution ?? hasSeveral;
        const showAsTotal = sharedAttribution && hasSeveral;

        result.push({
          id: stepId,
          messageId: stepId,
          label: showAsTotal ? t.tokenUsage.stepTotal : attributed[0]!,
          secondaryLabels: showAsTotal ? attributed : [],
          usage,
          sharedAttribution,
        });
        continue;
      }
    }

    // Fallback: describe each tool call. write_todos stays generic on
    // purpose; precise to-do labels only come from the attribution payload.
    const derived = (message.tool_calls ?? []).map((toolCall) =>
      toolCall.name === "write_todos"
        ? t.toolCalls.writeTodos
        : describeToolCall(
            {
              name: toolCall.name,
              args: (toolCall.args ?? {}) as Record<string, unknown>,
            },
            t,
          ),
    );

    // No tool calls at all: classify by whether the message has content.
    if (derived.length === 0) {
      derived.push(
        hasContent(message) ? t.tokenUsage.finalAnswer : t.common.thinking,
      );
    }

    const isSingle = derived.length === 1;
    result.push({
      id: stepId,
      messageId: stepId,
      label: isSingle ? derived[0]! : t.tokenUsage.stepTotal,
      secondaryLabels: isSingle ? [] : derived,
      usage,
      sharedAttribution: !isSingle,
    });
  }

  return result;
}
/**
 * Reads and normalizes the `token_usage_attribution` entry stored in an AI
 * message's `additional_kwargs`.
 *
 * @param message - Any thread message.
 * @returns The normalized attribution, or `null` when the message is not an
 *   AI message, has no object-valued `additional_kwargs`, or the payload
 *   does not normalize.
 */
function getTokenUsageAttribution(
  message: Message,
): TokenUsageAttribution | null {
  if (message.type !== "ai") {
    return null;
  }

  const { additional_kwargs: extras } = message;
  if (!extras || typeof extras !== "object") {
    return null;
  }

  const rawAttribution = (extras as Record<string, unknown>)
    .token_usage_attribution;
  return normalizeTokenUsageAttribution(rawAttribution) ?? null;
}
/**
 * Turns the actions of an attribution payload into display labels.
 *
 * @param attribution - Normalized attribution; `actions` may be absent.
 * @param t - Translations used to render the labels.
 * @returns One label per action, in order, skipping actions whose label is
 *   `null` or empty.
 */
function buildActionLabelsFromAttribution(
  attribution: TokenUsageAttribution,
  t: Translations,
): string[] {
  const labels: string[] = [];
  for (const action of attribution.actions ?? []) {
    const label = describeAttributionAction(action, t);
    if (label) {
      labels.push(label);
    }
  }
  return labels;
}
/**
 * Renders the display label for a single attribution action.
 *
 * @param action - A normalized attribution action.
 * @param t - Translations used to render the label.
 * @returns The label, or `null` when the action kind is not recognized.
 */
function describeAttributionAction(
  action: TokenUsageAttributionAction,
  t: Translations,
): string | null {
  if (action.kind === "subagent") {
    return t.tokenUsage.subagent(action.description ?? t.subtasks.subtask);
  }

  if (action.kind === "search") {
    // Prefer the concrete query; otherwise name the tool that ran.
    return action.query
      ? t.toolCalls.searchFor(action.query)
      : t.toolCalls.useTool(action.tool_name ?? "search");
  }

  if (action.kind === "present_files") {
    return t.toolCalls.presentFiles;
  }

  if (action.kind === "clarification") {
    return t.toolCalls.needYourHelp;
  }

  if (action.kind === "tool") {
    const args = action.description ? { description: action.description } : {};
    return describeToolCall({ name: action.tool_name ?? "tool", args }, t);
  }

  if (
    action.kind === "todo_start" ||
    action.kind === "todo_complete" ||
    action.kind === "todo_update" ||
    action.kind === "todo_remove"
  ) {
    const { content } = action;
    // Without the item text every to-do action gets the generic label.
    if (!content) {
      return t.toolCalls.writeTodos;
    }
    if (action.kind === "todo_start") {
      return t.tokenUsage.startTodo(content);
    }
    if (action.kind === "todo_complete") {
      return t.tokenUsage.completeTodo(content);
    }
    if (action.kind === "todo_update") {
      return t.tokenUsage.updateTodo(content);
    }
    return t.tokenUsage.removeTodo(content);
  }

  return null;
}
/**
 * Renders the display label for a tool call from its name and arguments.
 *
 * @param toolCall - The tool name and its argument record.
 * @param t - Translations used to render the label.
 * @returns A tool-specific label for the known tools; otherwise the call's
 *   string `description` argument, or a generic "use tool" label.
 */
function describeToolCall(
  toolCall: {
    name: string;
    args: Record<string, unknown>;
  },
  t: Translations,
): string {
  const { name, args } = toolCall;

  switch (name) {
    case "task": {
      const description =
        typeof args.description === "string"
          ? args.description
          : t.subtasks.subtask;
      return t.tokenUsage.subagent(description);
    }
    case "web_search":
    case "image_search":
      if (typeof args.query === "string") {
        return t.toolCalls.searchFor(args.query);
      }
      // Without a string query, fall through to the generic handling below.
      break;
    case "web_fetch":
      return t.toolCalls.viewWebPage;
    case "present_files":
      return t.toolCalls.presentFiles;
    case "ask_clarification":
      return t.toolCalls.needYourHelp;
    default:
      break;
  }

  return typeof args.description === "string"
    ? args.description
    : t.toolCalls.useTool(name);
}
/**
 * Validates an untyped attribution payload and converts it into a
 * `TokenUsageAttribution`.
 *
 * @param value - The raw `token_usage_attribution` value.
 * @returns `null` when the value is not a plain object or when `actions` is
 *   present but not an array. Otherwise an object in which every field that
 *   fails its type check is `undefined`, and invalid array entries are dropped.
 */
function normalizeTokenUsageAttribution(
  value: unknown,
): TokenUsageAttribution | null {
  const payload = asRecord(value);
  if (!payload) {
    return null;
  }

  const { actions: rawActions, tool_call_ids: rawToolCallIds } = payload;

  let actions: TokenUsageAttributionAction[] | undefined;
  if (Array.isArray(rawActions)) {
    actions = [];
    for (const candidate of rawActions) {
      const action = normalizeTokenUsageAttributionAction(candidate);
      if (action !== null) {
        actions.push(action);
      }
    }
  } else if (rawActions !== undefined) {
    // A non-array `actions` makes the whole payload unusable.
    return null;
  }

  let toolCallIds: string[] | undefined;
  if (Array.isArray(rawToolCallIds)) {
    // Keep only non-blank strings; the kept ids are not trimmed.
    toolCallIds = rawToolCallIds.filter(
      (id): id is string => typeof id === "string" && id.trim() !== "",
    );
  }

  return {
    // Versioning is additive for now: the frontend should ignore unknown
    // fields and fall back when required fields become incompatible.
    version: typeof payload.version === "number" ? payload.version : undefined,
    kind: isTokenUsageAttributionKind(payload.kind) ? payload.kind : undefined,
    shared_attribution:
      typeof payload.shared_attribution === "boolean"
        ? payload.shared_attribution
        : undefined,
    tool_call_ids: toolCallIds,
    actions,
  };
}
/**
 * Validates one untyped entry of an attribution payload's `actions` array.
 *
 * @param value - The raw action entry.
 * @returns The action with its string fields passed through `readString`,
 *   or `null` when the entry is not a plain object or its `kind` is not one
 *   of the supported kinds.
 */
function normalizeTokenUsageAttributionAction(
  value: unknown,
): TokenUsageAttributionAction | null {
  const raw = asRecord(value);
  if (!raw) {
    return null;
  }

  const kind = raw.kind;
  if (
    !(
      kind === "todo_start" ||
      kind === "todo_complete" ||
      kind === "todo_update" ||
      kind === "todo_remove" ||
      kind === "subagent" ||
      kind === "search" ||
      kind === "present_files" ||
      kind === "clarification" ||
      kind === "tool"
    )
  ) {
    return null;
  }

  const toolCallId = readString(raw.tool_call_id);

  if (kind === "subagent") {
    return {
      kind,
      description: readString(raw.description),
      subagent_type: readString(raw.subagent_type),
      tool_call_id: toolCallId,
    };
  }

  if (kind === "search") {
    return {
      kind,
      query: readString(raw.query),
      tool_name: readString(raw.tool_name),
      tool_call_id: toolCallId,
    };
  }

  if (kind === "tool") {
    return {
      kind,
      tool_name: readString(raw.tool_name),
      description: readString(raw.description),
      tool_call_id: toolCallId,
    };
  }

  if (kind === "present_files" || kind === "clarification") {
    return {
      kind,
      tool_call_id: toolCallId,
    };
  }

  // Only the four to-do kinds remain at this point.
  return {
    kind,
    content: readString(raw.content),
    tool_call_id: toolCallId,
  };
}
/**
 * Narrows an unknown value to a string-keyed record.
 *
 * @param value - Any value.
 * @returns The same reference when it is a non-null, non-array object;
 *   `null` otherwise.
 */
function asRecord(value: unknown): Record<string, unknown> | null {
  const isObjectLike =
    typeof value === "object" && value !== null && !Array.isArray(value);
  return isObjectLike ? (value as Record<string, unknown>) : null;
}
/**
 * Reads a value as a trimmed, non-empty string.
 *
 * @param value - Any value.
 * @returns The trimmed string, or `undefined` when the value is not a
 *   string or is blank after trimming.
 */
function readString(value: unknown): string | undefined {
  const trimmed = typeof value === "string" ? value.trim() : "";
  return trimmed === "" ? undefined : trimmed;
}
/**
 * Type guard for the top-level attribution `kind`.
 *
 * @param value - Any value.
 * @returns `true` only for the five supported kind strings.
 */
function isTokenUsageAttributionKind(
  value: unknown,
): value is NonNullable<TokenUsageAttribution["kind"]> {
  switch (value) {
    case "thinking":
    case "final_answer":
    case "tool_batch":
    case "todo_update":
    case "subagent_dispatch":
      return true;
    default:
      return false;
  }
}
+6 -44
View File
@@ -18,7 +18,7 @@ interface AssistantClarificationGroup extends GenericMessageGroup<"assistant:cla
interface AssistantSubagentGroup extends GenericMessageGroup<"assistant:subagent"> {}
export type MessageGroup =
type MessageGroup =
| HumanMessageGroup
| AssistantProcessingGroup
| AssistantMessageGroup
@@ -26,7 +26,10 @@ export type MessageGroup =
| AssistantClarificationGroup
| AssistantSubagentGroup;
export function getMessageGroups(messages: Message[]): MessageGroup[] {
export function groupMessages<T>(
messages: Message[],
mapper: (group: MessageGroup) => T,
): T[] {
if (messages.length === 0) {
return [];
}
@@ -121,52 +124,11 @@ export function getMessageGroups(messages: Message[]): MessageGroup[] {
}
}
return groups;
}
export function groupMessages<T>(
messages: Message[],
mapper: (group: MessageGroup) => T,
): T[] {
return getMessageGroups(messages)
return groups
.map(mapper)
.filter((result) => result !== undefined && result !== null) as T[];
}
/**
 * Collects, for every assistant turn, the AI messages whose usage should be
 * shown together, and attaches them to the last group of that turn.
 *
 * A turn is a run of consecutive non-human groups; it ends at the last group
 * before a human group or at the end of the list.
 *
 * @param groups - Message groups in display order.
 * @returns An array parallel to `groups`: the turn's AI messages at the index
 *   of each turn-ending group, and `null` everywhere else.
 */
export function getAssistantTurnUsageMessages(groups: MessageGroup[]) {
  const perGroup = new Array<Message[] | null>(groups.length).fill(null);

  // AI messages gathered so far for the turn currently being walked.
  let pending: Message[] = [];

  for (const [position, group] of groups.entries()) {
    if (group.type === "human") {
      pending = [];
      continue;
    }

    for (const message of group.messages) {
      if (message.type === "ai") {
        pending.push(message);
      }
    }

    const following = groups[position + 1];
    if (following && following.type !== "human") {
      // The turn continues in the next group.
      continue;
    }

    perGroup[position] = pending;
    pending = [];
  }

  return perGroup;
}
export function extractTextFromMessage(message: Message) {
if (typeof message.content === "string") {
return (
-13
View File
@@ -1,14 +1,9 @@
import type { TokenUsageInlineMode } from "../messages/usage-model";
import type { AgentThreadContext } from "../threads";
export const DEFAULT_LOCAL_SETTINGS: LocalSettings = {
notification: {
enabled: true,
},
tokenUsage: {
headerTotal: true,
inlineMode: "per_turn",
},
context: {
model_name: undefined,
mode: undefined,
@@ -27,10 +22,6 @@ export interface LocalSettings {
notification: {
enabled: boolean;
};
tokenUsage: {
headerTotal: boolean;
inlineMode: TokenUsageInlineMode;
};
context: Omit<
AgentThreadContext,
| "thread_id"
@@ -53,10 +44,6 @@ function mergeLocalSettings(settings?: Partial<LocalSettings>): LocalSettings {
...DEFAULT_LOCAL_SETTINGS.context,
...settings?.context,
},
tokenUsage: {
...DEFAULT_LOCAL_SETTINGS.tokenUsage,
...settings?.tokenUsage,
},
notification: {
...DEFAULT_LOCAL_SETTINGS.notification,
...settings?.notification,
@@ -1,396 +0,0 @@
import type { Message } from "@langchain/langgraph-sdk";
import { expect, test } from "vitest";
import { enUS } from "@/core/i18n";
import {
buildTokenDebugSteps,
getTokenUsageViewPreset,
tokenUsagePreferencesFromPreset,
} from "@/core/messages/usage-model";
// Checks that each of the four presets expands to the expected
// { headerTotal, inlineMode } pair.
test("maps token usage presets to persisted preferences", () => {
expect(tokenUsagePreferencesFromPreset("off")).toEqual({
headerTotal: false,
inlineMode: "off",
});
expect(tokenUsagePreferencesFromPreset("summary")).toEqual({
headerTotal: true,
inlineMode: "off",
});
expect(tokenUsagePreferencesFromPreset("per_turn")).toEqual({
headerTotal: true,
inlineMode: "per_turn",
});
expect(tokenUsagePreferencesFromPreset("debug")).toEqual({
headerTotal: true,
inlineMode: "step_debug",
});
});
// Checks the inverse mapping: each preferences pair resolves back to the
// preset it was expanded from.
test("derives the active preset from persisted preferences", () => {
expect(
getTokenUsageViewPreset({
headerTotal: false,
inlineMode: "off",
}),
).toBe("off");
expect(
getTokenUsageViewPreset({
headerTotal: true,
inlineMode: "off",
}),
).toBe("summary");
expect(
getTokenUsageViewPreset({
headerTotal: true,
inlineMode: "per_turn",
}),
).toBe("per_turn");
expect(
getTokenUsageViewPreset({
headerTotal: true,
inlineMode: "step_debug",
}),
).toBe("debug");
});
// None of these messages carries token_usage_attribution, so write_todos
// calls get the generic to-do label and the content-only AI message is
// labelled as the final answer. The tool message produces no step.
test("uses generic todo labels when backend attribution is absent", () => {
const messages = [
{
id: "ai-1",
type: "ai",
content: "",
tool_calls: [
{
id: "write_todos:1",
name: "write_todos",
args: {
todos: [{ content: "Draft the plan", status: "in_progress" }],
},
},
],
usage_metadata: {
input_tokens: 100,
output_tokens: 20,
total_tokens: 120,
},
},
{
id: "tool-1",
type: "tool",
name: "write_todos",
tool_call_id: "write_todos:1",
content: "ok",
},
{
id: "ai-2",
type: "ai",
content: "",
tool_calls: [
{
id: "write_todos:2",
name: "write_todos",
args: {
todos: [{ content: "Draft the plan", status: "completed" }],
},
},
],
usage_metadata: { input_tokens: 50, output_tokens: 10, total_tokens: 60 },
},
{
id: "ai-3",
type: "ai",
content: "Here is the result",
usage_metadata: { input_tokens: 40, output_tokens: 15, total_tokens: 55 },
},
] as Message[];
expect(buildTokenDebugSteps(messages, enUS)).toEqual([
expect.objectContaining({
messageId: "ai-1",
label: "Update to-do list",
sharedAttribution: false,
}),
expect.objectContaining({
messageId: "ai-2",
label: "Update to-do list",
sharedAttribution: false,
}),
expect.objectContaining({
messageId: "ai-3",
label: "Final answer",
sharedAttribution: false,
}),
]);
});
// A single AI message with two tool calls and no attribution is reported as
// one shared "Step total" step, with each call's label in secondaryLabels.
test("marks multi-action AI steps as shared attribution", () => {
const messages = [
{
id: "ai-1",
type: "ai",
content: "",
tool_calls: [
{
id: "web_search:1",
name: "web_search",
args: { query: "LangGraph stream mode" },
},
{
id: "write_todos:1",
name: "write_todos",
args: {
todos: [
{
content: "Inspect stream mode handling",
status: "in_progress",
},
],
},
},
],
usage_metadata: {
input_tokens: 120,
output_tokens: 30,
total_tokens: 150,
},
},
] as Message[];
expect(buildTokenDebugSteps(messages, enUS)).toEqual([
expect.objectContaining({
messageId: "ai-1",
label: "Step total",
sharedAttribution: true,
secondaryLabels: [
'Search for "LangGraph stream mode"',
"Update to-do list",
],
}),
]);
});
// When token_usage_attribution carries a usable action, its label takes
// precedence over the label that would be derived from tool_calls.
test("prefers backend attribution metadata when available", () => {
const messages = [
{
id: "ai-1",
type: "ai",
content: "",
tool_calls: [
{
id: "write_todos:1",
name: "write_todos",
args: {
todos: [
{
content: "Fallback label should not win",
status: "in_progress",
},
],
},
},
],
additional_kwargs: {
token_usage_attribution: {
version: 1,
kind: "todo_update",
shared_attribution: false,
actions: [{ kind: "todo_start", content: "Use backend attribution" }],
},
},
usage_metadata: { input_tokens: 25, output_tokens: 5, total_tokens: 30 },
},
] as Message[];
expect(buildTokenDebugSteps(messages, enUS)).toEqual([
expect.objectContaining({
messageId: "ai-1",
label: "Start To-do: Use backend attribution",
sharedAttribution: false,
}),
]);
});
// `actions` is an object rather than an array here, so the attribution is
// discarded and the label is derived from the web_search tool call.
test("falls back safely when attribution payload is malformed", () => {
const messages = [
{
id: "ai-1",
type: "ai",
content: "",
tool_calls: [
{
id: "web_search:1",
name: "web_search",
args: { query: "LangGraph stream mode" },
},
],
additional_kwargs: {
token_usage_attribution: {
version: 1,
kind: "tool_batch",
actions: { broken: true },
},
},
usage_metadata: { input_tokens: 10, output_tokens: 5, total_tokens: 15 },
},
] as Message[];
expect(buildTokenDebugSteps(messages, enUS)).toEqual([
expect.objectContaining({
messageId: "ai-1",
label: 'Search for "LangGraph stream mode"',
sharedAttribution: false,
}),
]);
});
// Non-object entries in `actions` (null, a string) are dropped; the one
// remaining valid search action supplies the label, and its extra field is
// ignored.
test("ignores attribution actions that are not objects", () => {
const messages = [
{
id: "ai-1",
type: "ai",
content: "",
tool_calls: [],
additional_kwargs: {
token_usage_attribution: {
version: 1,
kind: "tool_batch",
shared_attribution: true,
actions: [
null,
"bad-action",
{ kind: "search", query: "valid search", ignored: "extra-field" },
],
},
},
usage_metadata: { input_tokens: 10, output_tokens: 5, total_tokens: 15 },
},
] as Message[];
expect(buildTokenDebugSteps(messages, enUS)).toEqual([
expect.objectContaining({
messageId: "ai-1",
label: 'Search for "valid search"',
}),
]);
});
// `kind` and `shared_attribution` are null and the only action has no
// `kind`, so the attribution yields no label and the step is classified
// from the message content instead.
test("ignores malformed attribution fields and falls back to message content", () => {
const messages = [
{
id: "ai-1",
type: "ai",
content: "Real final answer",
tool_calls: [],
additional_kwargs: {
token_usage_attribution: {
version: 1,
kind: null,
shared_attribution: null,
tool_call_ids: [null, "tool-1", 123],
actions: [{ query: "missing kind" }],
},
},
usage_metadata: { input_tokens: 9, output_tokens: 3, total_tokens: 12 },
},
] as Message[];
expect(buildTokenDebugSteps(messages, enUS)).toEqual([
expect.objectContaining({
messageId: "ai-1",
label: "Final answer",
sharedAttribution: false,
}),
]);
});
// An extra key on the attribution payload does not prevent the valid
// subagent action from being used for the label.
test("ignores unknown top-level attribution fields", () => {
const messages = [
{
id: "ai-1",
type: "ai",
content: "",
tool_calls: [],
additional_kwargs: {
token_usage_attribution: {
version: 1,
kind: "tool_batch",
shared_attribution: false,
unknown_field: "ignored",
actions: [{ kind: "subagent", description: "Inspect the fix" }],
},
},
usage_metadata: { input_tokens: 12, output_tokens: 4, total_tokens: 16 },
},
] as Message[];
expect(buildTokenDebugSteps(messages, enUS)).toEqual([
expect.objectContaining({
messageId: "ai-1",
label: "Subagent: Inspect the fix",
sharedAttribution: false,
}),
]);
});
// The second message has an attribution of kind "todo_update" with an empty
// `actions` array; that produces no label, so its write_todos tool call
// gets the generic to-do label just like the unattributed first message.
test("falls back to generic todo labels when backend attribution has no actions", () => {
const messages = [
{
id: "ai-1",
type: "ai",
content: "",
tool_calls: [
{
id: "write_todos:1",
name: "write_todos",
args: {
todos: [{ content: "Clean up stale tasks", status: "in_progress" }],
},
},
],
usage_metadata: {
input_tokens: 100,
output_tokens: 20,
total_tokens: 120,
},
},
{
id: "ai-2",
type: "ai",
content: "",
tool_calls: [
{
id: "write_todos:2",
name: "write_todos",
args: {
todos: [],
},
},
],
additional_kwargs: {
token_usage_attribution: {
version: 1,
kind: "todo_update",
shared_attribution: false,
actions: [],
},
},
usage_metadata: { input_tokens: 30, output_tokens: 8, total_tokens: 38 },
},
] as Message[];
expect(buildTokenDebugSteps(messages, enUS)).toEqual([
expect.objectContaining({
messageId: "ai-1",
label: "Update to-do list",
}),
expect.objectContaining({
messageId: "ai-2",
label: "Update to-do list",
sharedAttribution: false,
}),
]);
});
@@ -1,65 +0,0 @@
import type { Message } from "@langchain/langgraph-sdk";
import { expect, test } from "vitest";
import {
getAssistantTurnUsageMessages,
getMessageGroups,
} from "@/core/messages/utils";
// Two assistant turns separated by a human message. The expected result
// attaches each turn's AI message ids to the last group of that turn and
// leaves null for human groups and for earlier groups within a turn.
test("aggregates token usage messages once per assistant turn", () => {
const messages = [
{
id: "human-1",
type: "human",
content: "Plan a trip",
},
{
id: "ai-1",
type: "ai",
content: "",
tool_calls: [{ id: "tool-1", name: "web_search", args: {} }],
usage_metadata: { input_tokens: 10, output_tokens: 5, total_tokens: 15 },
},
{
id: "tool-1-result",
type: "tool",
name: "web_search",
tool_call_id: "tool-1",
content: "[]",
},
{
id: "ai-2",
type: "ai",
content: "Here is the itinerary",
usage_metadata: { input_tokens: 2, output_tokens: 8, total_tokens: 10 },
},
{
id: "human-2",
type: "human",
content: "Make it shorter",
},
{
id: "ai-3",
type: "ai",
content: "Short version",
usage_metadata: { input_tokens: 1, output_tokens: 1, total_tokens: 2 },
},
] as Message[];
const groups = getMessageGroups(messages);
const usageMessagesByGroupIndex = getAssistantTurnUsageMessages(groups);
expect(groups.map((group) => group.type)).toEqual([
"human",
"assistant:processing",
"assistant",
"human",
"assistant",
]);
expect(
usageMessagesByGroupIndex.map(
(groupMessages) => groupMessages?.map((message) => message.id) ?? null,
),
).toEqual([null, null, ["ai-1", "ai-2"], null, ["ai-3"]]);
});