Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2eb45e9bb5 |
@@ -1,6 +1,3 @@
|
||||
# Serper API Key (Google Search) - https://serper.dev
|
||||
SERPER_API_KEY=your-serper-api-key
|
||||
|
||||
# TAVILY API Key
|
||||
TAVILY_API_KEY=your-tavily-api-key
|
||||
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
name: Publish Containers
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "v*"
|
||||
|
||||
jobs:
|
||||
|
||||
backend-container:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
attestations: write
|
||||
id-token: write
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}-backend
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 #v3.4.0
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@902fa8ec7d6ecbf8d84d538b9b233a880e428804 #v5.7.0
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
tags: |
|
||||
type=ref,event=tag
|
||||
type=ref,event=branch
|
||||
type=sha
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
- name: Build and push Docker image
|
||||
id: push
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 #v6.18.0
|
||||
with:
|
||||
context: .
|
||||
file: backend/Dockerfile
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
- name: Generate artifact attestation
|
||||
uses: actions/attest-build-provenance@v2
|
||||
with:
|
||||
subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}}
|
||||
subject-digest: ${{ steps.push.outputs.digest }}
|
||||
push-to-registry: true
|
||||
|
||||
frontend-container:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
attestations: write
|
||||
id-token: write
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}-frontend
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 #v3.4.0
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@902fa8ec7d6ecbf8d84d538b9b233a880e428804 #v5.7.0
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
tags: |
|
||||
type=ref,event=tag
|
||||
type=ref,event=branch
|
||||
type=sha
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
- name: Build and push Docker image
|
||||
id: push
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 #v6.18.0
|
||||
with:
|
||||
context: .
|
||||
file: frontend/Dockerfile
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
- name: Generate artifact attestation
|
||||
uses: actions/attest-build-provenance@v2
|
||||
with:
|
||||
subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}}
|
||||
subject-digest: ${{ steps.push.outputs.digest }}
|
||||
push-to-registry: true
|
||||
@@ -50,12 +50,6 @@ COPY backend ./backend
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync ${UV_EXTRAS:+--extra $UV_EXTRAS}"
|
||||
|
||||
# UTF-8 locale prevents UnicodeEncodeError on Chinese/emoji content in minimal
|
||||
# containers where locale configuration may be missing and the default encoding is not UTF-8.
|
||||
ENV LANG=C.UTF-8
|
||||
ENV LC_ALL=C.UTF-8
|
||||
ENV PYTHONIOENCODING=utf-8
|
||||
|
||||
# ── Stage 2: Dev ──────────────────────────────────────────────────────────────
|
||||
# Retains compiler toolchain from builder so startup-time `uv sync` can build
|
||||
# source distributions in development containers.
|
||||
@@ -72,10 +66,6 @@ CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app
|
||||
# Clean image without build-essential — reduces size (~200 MB) and attack surface.
|
||||
FROM python:3.12-slim-bookworm
|
||||
|
||||
ENV LANG=C.UTF-8
|
||||
ENV LC_ALL=C.UTF-8
|
||||
ENV PYTHONIOENCODING=utf-8
|
||||
|
||||
# Copy Node.js runtime from builder (provides npx for MCP servers)
|
||||
COPY --from=builder /usr/bin/node /usr/bin/node
|
||||
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
|
||||
|
||||
@@ -420,13 +420,7 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
|
||||
if not msg.files:
|
||||
return []
|
||||
|
||||
from deerflow.uploads.manager import (
|
||||
UnsafeUploadPathError,
|
||||
claim_unique_filename,
|
||||
ensure_uploads_dir,
|
||||
normalize_filename,
|
||||
write_upload_file_no_symlink,
|
||||
)
|
||||
from deerflow.uploads.manager import claim_unique_filename, ensure_uploads_dir, normalize_filename
|
||||
|
||||
uploads_dir = ensure_uploads_dir(thread_id)
|
||||
seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()}
|
||||
@@ -477,10 +471,7 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
|
||||
|
||||
dest = uploads_dir / safe_name
|
||||
try:
|
||||
dest = write_upload_file_no_symlink(uploads_dir, safe_name, data)
|
||||
except UnsafeUploadPathError:
|
||||
logger.warning("[Manager] skipping inbound file with unsafe destination: %s", safe_name)
|
||||
continue
|
||||
dest.write_bytes(data)
|
||||
except Exception:
|
||||
logger.exception("[Manager] failed to write inbound file: %s", dest)
|
||||
continue
|
||||
|
||||
@@ -13,11 +13,11 @@ matching the LangGraph Platform wire format expected by the
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from langgraph.checkpoint.base import empty_checkpoint
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
@@ -26,7 +26,6 @@ from app.gateway.utils import sanitize_log_param
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.runtime import serialize_channel_values
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.utils.time import coerce_iso, now_iso
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/threads", tags=["threads"])
|
||||
@@ -234,7 +233,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
checkpointer = get_checkpointer(request)
|
||||
thread_store = get_thread_store(request)
|
||||
thread_id = body.thread_id or str(uuid.uuid4())
|
||||
now = now_iso()
|
||||
now = time.time()
|
||||
# ``body.metadata`` is already stripped of server-reserved keys by
|
||||
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
||||
|
||||
@@ -244,8 +243,8 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=existing_record.get("status", "idle"),
|
||||
created_at=coerce_iso(existing_record.get("created_at", "")),
|
||||
updated_at=coerce_iso(existing_record.get("updated_at", "")),
|
||||
created_at=str(existing_record.get("created_at", "")),
|
||||
updated_at=str(existing_record.get("updated_at", "")),
|
||||
metadata=existing_record.get("metadata", {}),
|
||||
)
|
||||
|
||||
@@ -263,6 +262,8 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
# Write an empty checkpoint so state endpoints work immediately
|
||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
try:
|
||||
from langgraph.checkpoint.base import empty_checkpoint
|
||||
|
||||
ckpt_metadata = {
|
||||
"step": -1,
|
||||
"source": "input",
|
||||
@@ -280,8 +281,8 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status="idle",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
created_at=str(now),
|
||||
updated_at=str(now),
|
||||
metadata=body.metadata,
|
||||
)
|
||||
|
||||
@@ -306,11 +307,8 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
|
||||
ThreadResponse(
|
||||
thread_id=r["thread_id"],
|
||||
status=r.get("status", "idle"),
|
||||
# ``coerce_iso`` heals legacy unix-second values that
|
||||
# ``MemoryThreadMetaStore`` historically wrote with ``time.time()``;
|
||||
# SQL-backed rows already arrive as ISO strings and pass through.
|
||||
created_at=coerce_iso(r.get("created_at", "")),
|
||||
updated_at=coerce_iso(r.get("updated_at", "")),
|
||||
created_at=r.get("created_at", ""),
|
||||
updated_at=r.get("updated_at", ""),
|
||||
metadata=r.get("metadata", {}),
|
||||
values={"title": r["display_name"]} if r.get("display_name") else {},
|
||||
interrupts={},
|
||||
@@ -342,8 +340,8 @@ async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Reques
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=record.get("status", "idle"),
|
||||
created_at=coerce_iso(record.get("created_at", "")),
|
||||
updated_at=coerce_iso(record.get("updated_at", "")),
|
||||
created_at=str(record.get("created_at", "")),
|
||||
updated_at=str(record.get("updated_at", "")),
|
||||
metadata=record.get("metadata", {}),
|
||||
)
|
||||
|
||||
@@ -383,8 +381,8 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
||||
record = {
|
||||
"thread_id": thread_id,
|
||||
"status": "idle",
|
||||
"created_at": coerce_iso(ckpt_meta.get("created_at", "")),
|
||||
"updated_at": coerce_iso(ckpt_meta.get("updated_at", ckpt_meta.get("created_at", ""))),
|
||||
"created_at": ckpt_meta.get("created_at", ""),
|
||||
"updated_at": ckpt_meta.get("updated_at", ckpt_meta.get("created_at", "")),
|
||||
"metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
|
||||
}
|
||||
|
||||
@@ -398,8 +396,8 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=status,
|
||||
created_at=coerce_iso(record.get("created_at", "")),
|
||||
updated_at=coerce_iso(record.get("updated_at", "")),
|
||||
created_at=str(record.get("created_at", "")),
|
||||
updated_at=str(record.get("updated_at", "")),
|
||||
metadata=record.get("metadata", {}),
|
||||
values=serialize_channel_values(channel_values),
|
||||
)
|
||||
@@ -450,10 +448,10 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
|
||||
values=values,
|
||||
next=next_tasks,
|
||||
metadata=metadata,
|
||||
checkpoint={"id": checkpoint_id, "ts": coerce_iso(metadata.get("created_at", ""))},
|
||||
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
|
||||
checkpoint_id=checkpoint_id,
|
||||
parent_checkpoint_id=parent_checkpoint_id,
|
||||
created_at=coerce_iso(metadata.get("created_at", "")),
|
||||
created_at=str(metadata.get("created_at", "")),
|
||||
tasks=tasks,
|
||||
)
|
||||
|
||||
@@ -503,7 +501,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
||||
channel_values.update(body.values)
|
||||
|
||||
checkpoint["channel_values"] = channel_values
|
||||
metadata["updated_at"] = now_iso()
|
||||
metadata["updated_at"] = time.time()
|
||||
|
||||
if body.as_node:
|
||||
metadata["source"] = "update"
|
||||
@@ -544,7 +542,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
||||
next=[],
|
||||
metadata=metadata,
|
||||
checkpoint_id=new_checkpoint_id,
|
||||
created_at=coerce_iso(metadata.get("created_at", "")),
|
||||
created_at=str(metadata.get("created_at", "")),
|
||||
)
|
||||
|
||||
|
||||
@@ -611,7 +609,7 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
|
||||
parent_checkpoint_id=parent_id,
|
||||
metadata=user_meta,
|
||||
values=values,
|
||||
created_at=coerce_iso(metadata.get("created_at", "")),
|
||||
created_at=str(metadata.get("created_at", "")),
|
||||
next=next_tasks,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
import stat
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
from app.gateway.deps import get_config
|
||||
@@ -15,14 +15,12 @@ from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
|
||||
from deerflow.uploads.manager import (
|
||||
PathTraversalError,
|
||||
UnsafeUploadPathError,
|
||||
delete_file_safe,
|
||||
enrich_file_listing,
|
||||
ensure_uploads_dir,
|
||||
get_uploads_dir,
|
||||
list_files_in_dir,
|
||||
normalize_filename,
|
||||
open_upload_file_no_symlink,
|
||||
upload_artifact_url,
|
||||
upload_virtual_path,
|
||||
)
|
||||
@@ -44,7 +42,6 @@ class UploadResponse(BaseModel):
|
||||
success: bool
|
||||
files: list[dict[str, str]]
|
||||
message: str
|
||||
skipped_files: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class UploadLimits(BaseModel):
|
||||
@@ -119,18 +116,17 @@ def _cleanup_uploaded_paths(paths: list[os.PathLike[str] | str]) -> None:
|
||||
logger.warning("Failed to clean up upload path after rejected request: %s", path, exc_info=True)
|
||||
|
||||
|
||||
async def _write_upload_file_with_limits(
|
||||
async def _write_upload_file_streaming(
|
||||
file: UploadFile,
|
||||
file_path: os.PathLike[str] | str,
|
||||
*,
|
||||
uploads_dir: os.PathLike[str] | str,
|
||||
display_filename: str,
|
||||
max_single_file_size: int,
|
||||
max_total_size: int,
|
||||
total_size: int,
|
||||
) -> tuple[os.PathLike[str] | str, int, int]:
|
||||
) -> tuple[int, int]:
|
||||
file_size = 0
|
||||
file_path, fh = open_upload_file_no_symlink(uploads_dir, display_filename)
|
||||
try:
|
||||
with open(file_path, "wb") as output:
|
||||
while chunk := await file.read(UPLOAD_CHUNK_SIZE):
|
||||
file_size += len(chunk)
|
||||
total_size += len(chunk)
|
||||
@@ -138,17 +134,8 @@ async def _write_upload_file_with_limits(
|
||||
raise HTTPException(status_code=413, detail=f"File too large: {display_filename}")
|
||||
if total_size > max_total_size:
|
||||
raise HTTPException(status_code=413, detail="Total upload size too large")
|
||||
fh.write(chunk)
|
||||
except Exception:
|
||||
fh.close()
|
||||
try:
|
||||
os.unlink(file_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
raise
|
||||
else:
|
||||
fh.close()
|
||||
return file_path, file_size, total_size
|
||||
output.write(chunk)
|
||||
return file_size, total_size
|
||||
|
||||
|
||||
def _auto_convert_documents_enabled(app_config: AppConfig) -> bool:
|
||||
@@ -190,7 +177,6 @@ async def upload_files(
|
||||
uploaded_files = []
|
||||
written_paths = []
|
||||
sandbox_sync_targets = []
|
||||
skipped_files = []
|
||||
total_size = 0
|
||||
|
||||
sandbox_provider = get_sandbox_provider()
|
||||
@@ -214,15 +200,16 @@ async def upload_files(
|
||||
continue
|
||||
|
||||
try:
|
||||
file_path, file_size, total_size = await _write_upload_file_with_limits(
|
||||
file_path = uploads_dir / safe_filename
|
||||
written_paths.append(file_path)
|
||||
file_size, total_size = await _write_upload_file_streaming(
|
||||
file,
|
||||
uploads_dir=uploads_dir,
|
||||
file_path,
|
||||
display_filename=safe_filename,
|
||||
max_single_file_size=limits.max_file_size,
|
||||
max_total_size=limits.max_total_size,
|
||||
total_size=total_size,
|
||||
)
|
||||
written_paths.append(file_path)
|
||||
|
||||
virtual_path = upload_virtual_path(safe_filename)
|
||||
|
||||
@@ -259,10 +246,6 @@ async def upload_files(
|
||||
except HTTPException as e:
|
||||
_cleanup_uploaded_paths(written_paths)
|
||||
raise e
|
||||
except UnsafeUploadPathError as e:
|
||||
logger.warning("Skipping upload with unsafe destination %s: %s", file.filename, e)
|
||||
skipped_files.append(safe_filename)
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload {file.filename}: {e}")
|
||||
_cleanup_uploaded_paths(written_paths)
|
||||
@@ -273,15 +256,10 @@ async def upload_files(
|
||||
_make_file_sandbox_writable(file_path)
|
||||
sandbox.update_file(virtual_path, file_path.read_bytes())
|
||||
|
||||
message = f"Successfully uploaded {len(uploaded_files)} file(s)"
|
||||
if skipped_files:
|
||||
message += f"; skipped {len(skipped_files)} unsafe file(s)"
|
||||
|
||||
return UploadResponse(
|
||||
success=not skipped_files,
|
||||
success=True,
|
||||
files=uploaded_files,
|
||||
message=message,
|
||||
skipped_files=skipped_files,
|
||||
message=f"Successfully uploaded {len(uploaded_files)} file(s)",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,270 +1,31 @@
|
||||
"""Middleware for logging token usage and annotating step attribution."""
|
||||
|
||||
from __future__ import annotations
|
||||
"""Middleware for logging LLM token usage."""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any, override
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.todo import Todo
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TOKEN_USAGE_ATTRIBUTION_KEY = "token_usage_attribution"
|
||||
|
||||
|
||||
def _string_arg(value: Any) -> str | None:
|
||||
if isinstance(value, str):
|
||||
normalized = value.strip()
|
||||
return normalized or None
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_todos(value: Any) -> list[Todo]:
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
|
||||
normalized: list[Todo] = []
|
||||
for item in value:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
todo: Todo = {}
|
||||
content = _string_arg(item.get("content"))
|
||||
status = item.get("status")
|
||||
|
||||
if content is not None:
|
||||
todo["content"] = content
|
||||
if status in {"pending", "in_progress", "completed"}:
|
||||
todo["status"] = status
|
||||
|
||||
normalized.append(todo)
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def _todo_action_kind(previous: Todo | None, current: Todo) -> str:
|
||||
status = current.get("status")
|
||||
previous_content = previous.get("content") if previous else None
|
||||
current_content = current.get("content")
|
||||
|
||||
if previous is None:
|
||||
if status == "completed":
|
||||
return "todo_complete"
|
||||
if status == "in_progress":
|
||||
return "todo_start"
|
||||
return "todo_update"
|
||||
|
||||
if previous_content != current_content:
|
||||
return "todo_update"
|
||||
|
||||
if status == "completed":
|
||||
return "todo_complete"
|
||||
if status == "in_progress":
|
||||
return "todo_start"
|
||||
return "todo_update"
|
||||
|
||||
|
||||
def _build_todo_actions(previous_todos: list[Todo], next_todos: list[Todo]) -> list[dict[str, Any]]:
|
||||
# This is the single source of truth for precise write_todos token
|
||||
# attribution. The frontend intentionally falls back to a generic
|
||||
# "Update to-do list" label when this metadata is missing or malformed.
|
||||
previous_by_content: dict[str, list[tuple[int, Todo]]] = defaultdict(list)
|
||||
matched_previous_indices: set[int] = set()
|
||||
|
||||
for index, todo in enumerate(previous_todos):
|
||||
content = todo.get("content")
|
||||
if isinstance(content, str) and content:
|
||||
previous_by_content[content].append((index, todo))
|
||||
|
||||
actions: list[dict[str, Any]] = []
|
||||
|
||||
for index, todo in enumerate(next_todos):
|
||||
content = todo.get("content")
|
||||
if not isinstance(content, str) or not content:
|
||||
continue
|
||||
|
||||
previous_match: Todo | None = None
|
||||
content_matches = previous_by_content.get(content)
|
||||
if content_matches:
|
||||
while content_matches and content_matches[0][0] in matched_previous_indices:
|
||||
content_matches.pop(0)
|
||||
if content_matches:
|
||||
previous_index, previous_match = content_matches.pop(0)
|
||||
matched_previous_indices.add(previous_index)
|
||||
|
||||
if previous_match is None and index < len(previous_todos) and index not in matched_previous_indices:
|
||||
previous_match = previous_todos[index]
|
||||
matched_previous_indices.add(index)
|
||||
|
||||
if previous_match is not None:
|
||||
previous_content = previous_match.get("content")
|
||||
previous_status = previous_match.get("status")
|
||||
if previous_content == content and previous_status == todo.get("status"):
|
||||
continue
|
||||
|
||||
actions.append(
|
||||
{
|
||||
"kind": _todo_action_kind(previous_match, todo),
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
|
||||
for index, todo in enumerate(previous_todos):
|
||||
if index in matched_previous_indices:
|
||||
continue
|
||||
|
||||
content = todo.get("content")
|
||||
if not isinstance(content, str) or not content:
|
||||
continue
|
||||
|
||||
actions.append(
|
||||
{
|
||||
"kind": "todo_remove",
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
|
||||
return actions
|
||||
|
||||
|
||||
def _describe_tool_call(tool_call: dict[str, Any], todos: list[Todo]) -> list[dict[str, Any]]:
|
||||
name = _string_arg(tool_call.get("name")) or "unknown"
|
||||
args = tool_call.get("args") if isinstance(tool_call.get("args"), dict) else {}
|
||||
tool_call_id = _string_arg(tool_call.get("id"))
|
||||
|
||||
if name == "write_todos":
|
||||
next_todos = _normalize_todos(args.get("todos"))
|
||||
actions = _build_todo_actions(todos, next_todos)
|
||||
if not actions:
|
||||
return [
|
||||
{
|
||||
"kind": "tool",
|
||||
"tool_name": name,
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
]
|
||||
return [
|
||||
{
|
||||
**action,
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
for action in actions
|
||||
]
|
||||
|
||||
if name == "task":
|
||||
return [
|
||||
{
|
||||
"kind": "subagent",
|
||||
"description": _string_arg(args.get("description")),
|
||||
"subagent_type": _string_arg(args.get("subagent_type")),
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
]
|
||||
|
||||
if name in {"web_search", "image_search"}:
|
||||
query = _string_arg(args.get("query"))
|
||||
return [
|
||||
{
|
||||
"kind": "search",
|
||||
"tool_name": name,
|
||||
"query": query,
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
]
|
||||
|
||||
if name == "present_files":
|
||||
return [
|
||||
{
|
||||
"kind": "present_files",
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
]
|
||||
|
||||
if name == "ask_clarification":
|
||||
return [
|
||||
{
|
||||
"kind": "clarification",
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
]
|
||||
|
||||
return [
|
||||
{
|
||||
"kind": "tool",
|
||||
"tool_name": name,
|
||||
"description": _string_arg(args.get("description")),
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str:
|
||||
if actions:
|
||||
first_kind = actions[0].get("kind")
|
||||
if len(actions) == 1 and first_kind in {"todo_start", "todo_complete", "todo_update", "todo_remove"}:
|
||||
return "todo_update"
|
||||
if len(actions) == 1 and first_kind == "subagent":
|
||||
return "subagent_dispatch"
|
||||
return "tool_batch"
|
||||
|
||||
if message.content:
|
||||
return "final_answer"
|
||||
return "thinking"
|
||||
|
||||
|
||||
def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
|
||||
tool_calls = getattr(message, "tool_calls", None) or []
|
||||
actions: list[dict[str, Any]] = []
|
||||
current_todos = list(todos)
|
||||
|
||||
for raw_tool_call in tool_calls:
|
||||
if not isinstance(raw_tool_call, dict):
|
||||
continue
|
||||
|
||||
described_actions = _describe_tool_call(raw_tool_call, current_todos)
|
||||
actions.extend(described_actions)
|
||||
|
||||
if raw_tool_call.get("name") == "write_todos":
|
||||
args = raw_tool_call.get("args") if isinstance(raw_tool_call.get("args"), dict) else {}
|
||||
current_todos = _normalize_todos(args.get("todos"))
|
||||
|
||||
tool_call_ids: list[str] = []
|
||||
for tool_call in tool_calls:
|
||||
if not isinstance(tool_call, dict):
|
||||
continue
|
||||
|
||||
tool_call_id = _string_arg(tool_call.get("id"))
|
||||
if tool_call_id is not None:
|
||||
tool_call_ids.append(tool_call_id)
|
||||
|
||||
return {
|
||||
# Schema changes should remain additive where possible so older
|
||||
# frontends can ignore unknown fields and fall back safely.
|
||||
"version": 1,
|
||||
"kind": _infer_step_kind(message, actions),
|
||||
"shared_attribution": len(actions) > 1,
|
||||
"tool_call_ids": tool_call_ids,
|
||||
"actions": actions,
|
||||
}
|
||||
|
||||
|
||||
class TokenUsageMiddleware(AgentMiddleware):
|
||||
"""Logs token usage from model responses and annotates the AI step."""
|
||||
"""Logs token usage from model response usage_metadata."""
|
||||
|
||||
def _apply(self, state: AgentState) -> dict | None:
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._log_usage(state)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._log_usage(state)
|
||||
|
||||
def _log_usage(self, state: AgentState) -> None:
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage):
|
||||
return None
|
||||
|
||||
usage = getattr(last, "usage_metadata", None)
|
||||
if usage:
|
||||
logger.info(
|
||||
@@ -273,22 +34,4 @@ class TokenUsageMiddleware(AgentMiddleware):
|
||||
usage.get("output_tokens", "?"),
|
||||
usage.get("total_tokens", "?"),
|
||||
)
|
||||
|
||||
todos = state.get("todos") or []
|
||||
attribution = _build_attribution(last, todos if isinstance(todos, list) else [])
|
||||
additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
|
||||
|
||||
if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
|
||||
return None
|
||||
|
||||
additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
|
||||
updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
|
||||
return {"messages": [updated_msg]}
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._apply(state)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._apply(state)
|
||||
return None
|
||||
|
||||
@@ -228,14 +228,21 @@ class DeerFlowClient:
|
||||
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
|
||||
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=self._app_config),
|
||||
"tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled),
|
||||
"middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares),
|
||||
"middleware": _build_middlewares(
|
||||
config,
|
||||
model_name=model_name,
|
||||
agent_name=self._agent_name,
|
||||
custom_middlewares=self._middlewares,
|
||||
app_config=self._app_config,
|
||||
),
|
||||
"system_prompt": apply_prompt_template(
|
||||
subagent_enabled=subagent_enabled,
|
||||
max_concurrent_subagents=max_concurrent_subagents,
|
||||
agent_name=self._agent_name,
|
||||
available_skills=self._available_skills,
|
||||
app_config=self._app_config,
|
||||
),
|
||||
"state_schema": ThreadState,
|
||||
}
|
||||
@@ -243,7 +250,7 @@ class DeerFlowClient:
|
||||
if checkpointer is None:
|
||||
from deerflow.runtime.checkpointer import get_checkpointer
|
||||
|
||||
checkpointer = get_checkpointer()
|
||||
checkpointer = get_checkpointer(app_config=self._app_config)
|
||||
if checkpointer is not None:
|
||||
kwargs["checkpointer"] = checkpointer
|
||||
|
||||
@@ -251,12 +258,15 @@ class DeerFlowClient:
|
||||
self._agent_config_key = key
|
||||
logger.info("Agent created: agent_name=%s, model=%s, thinking=%s", self._agent_name, model_name, thinking_enabled)
|
||||
|
||||
@staticmethod
|
||||
def _get_tools(*, model_name: str | None, subagent_enabled: bool):
|
||||
def _get_tools(self, *, model_name: str | None, subagent_enabled: bool):
|
||||
"""Lazy import to avoid circular dependency at module level."""
|
||||
from deerflow.tools import get_available_tools
|
||||
|
||||
return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled)
|
||||
return get_available_tools(
|
||||
model_name=model_name,
|
||||
subagent_enabled=subagent_enabled,
|
||||
app_config=self._app_config,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _serialize_tool_calls(tool_calls) -> list[dict]:
|
||||
@@ -264,35 +274,25 @@ class DeerFlowClient:
|
||||
return [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in tool_calls]
|
||||
|
||||
@staticmethod
|
||||
def _serialize_additional_kwargs(msg) -> dict[str, Any] | None:
|
||||
"""Copy message additional_kwargs when present."""
|
||||
additional_kwargs = getattr(msg, "additional_kwargs", None)
|
||||
if isinstance(additional_kwargs, dict) and additional_kwargs:
|
||||
return dict(additional_kwargs)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None, additional_kwargs: dict[str, Any] | None = None) -> "StreamEvent":
|
||||
"""Build a ``messages-tuple`` AI text event."""
|
||||
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None) -> "StreamEvent":
|
||||
"""Build a ``messages-tuple`` AI text event, attaching usage when present."""
|
||||
data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id}
|
||||
if usage:
|
||||
data["usage_metadata"] = usage
|
||||
if additional_kwargs:
|
||||
data["additional_kwargs"] = additional_kwargs
|
||||
return StreamEvent(type="messages-tuple", data=data)
|
||||
|
||||
@staticmethod
|
||||
def _ai_tool_calls_event(msg_id: str | None, tool_calls, additional_kwargs: dict[str, Any] | None = None) -> "StreamEvent":
|
||||
def _ai_tool_calls_event(msg_id: str | None, tool_calls) -> "StreamEvent":
|
||||
"""Build a ``messages-tuple`` AI tool-calls event."""
|
||||
data: dict[str, Any] = {
|
||||
"type": "ai",
|
||||
"content": "",
|
||||
"id": msg_id,
|
||||
"tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls),
|
||||
}
|
||||
if additional_kwargs:
|
||||
data["additional_kwargs"] = additional_kwargs
|
||||
return StreamEvent(type="messages-tuple", data=data)
|
||||
return StreamEvent(
|
||||
type="messages-tuple",
|
||||
data={
|
||||
"type": "ai",
|
||||
"content": "",
|
||||
"id": msg_id,
|
||||
"tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _tool_message_event(msg: ToolMessage) -> "StreamEvent":
|
||||
@@ -317,30 +317,19 @@ class DeerFlowClient:
|
||||
d["tool_calls"] = DeerFlowClient._serialize_tool_calls(msg.tool_calls)
|
||||
if getattr(msg, "usage_metadata", None):
|
||||
d["usage_metadata"] = msg.usage_metadata
|
||||
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
|
||||
d["additional_kwargs"] = additional_kwargs
|
||||
return d
|
||||
if isinstance(msg, ToolMessage):
|
||||
d = {
|
||||
return {
|
||||
"type": "tool",
|
||||
"content": DeerFlowClient._extract_text(msg.content),
|
||||
"name": getattr(msg, "name", None),
|
||||
"tool_call_id": getattr(msg, "tool_call_id", None),
|
||||
"id": getattr(msg, "id", None),
|
||||
}
|
||||
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
|
||||
d["additional_kwargs"] = additional_kwargs
|
||||
return d
|
||||
if isinstance(msg, HumanMessage):
|
||||
d = {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
|
||||
d["additional_kwargs"] = additional_kwargs
|
||||
return d
|
||||
return {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||
if isinstance(msg, SystemMessage):
|
||||
d = {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
|
||||
d["additional_kwargs"] = additional_kwargs
|
||||
return d
|
||||
return {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||
return {"type": "unknown", "content": str(msg), "id": getattr(msg, "id", None)}
|
||||
|
||||
@staticmethod
|
||||
@@ -398,7 +387,7 @@ class DeerFlowClient:
|
||||
if checkpointer is None:
|
||||
from deerflow.runtime.checkpointer.provider import get_checkpointer
|
||||
|
||||
checkpointer = get_checkpointer()
|
||||
checkpointer = get_checkpointer(app_config=self._app_config)
|
||||
|
||||
thread_info_map = {}
|
||||
|
||||
@@ -453,7 +442,7 @@ class DeerFlowClient:
|
||||
if checkpointer is None:
|
||||
from deerflow.runtime.checkpointer.provider import get_checkpointer
|
||||
|
||||
checkpointer = get_checkpointer()
|
||||
checkpointer = get_checkpointer(app_config=self._app_config)
|
||||
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
checkpoints = []
|
||||
@@ -563,7 +552,6 @@ class DeerFlowClient:
|
||||
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str}
|
||||
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str, "usage_metadata": {...}}
|
||||
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]}
|
||||
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "additional_kwargs": {...}}
|
||||
- type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str}
|
||||
- type="end" data={"usage": {"input_tokens": int, "output_tokens": int, "total_tokens": int}}
|
||||
"""
|
||||
@@ -586,7 +574,6 @@ class DeerFlowClient:
|
||||
# in both the final ``messages`` chunk and the values snapshot —
|
||||
# count it only on whichever arrives first.
|
||||
counted_usage_ids: set[str] = set()
|
||||
sent_additional_kwargs_by_id: dict[str, dict[str, Any]] = {}
|
||||
cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||
|
||||
def _account_usage(msg_id: str | None, usage: Any) -> dict | None:
|
||||
@@ -616,20 +603,6 @@ class DeerFlowClient:
|
||||
"total_tokens": total_tokens,
|
||||
}
|
||||
|
||||
def _unsent_additional_kwargs(msg_id: str | None, additional_kwargs: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
if not additional_kwargs:
|
||||
return None
|
||||
if not msg_id:
|
||||
return additional_kwargs
|
||||
|
||||
sent = sent_additional_kwargs_by_id.setdefault(msg_id, {})
|
||||
delta = {key: value for key, value in additional_kwargs.items() if sent.get(key) != value}
|
||||
if not delta:
|
||||
return None
|
||||
|
||||
sent.update(delta)
|
||||
return delta
|
||||
|
||||
for item in self._agent.stream(
|
||||
state,
|
||||
config=config,
|
||||
@@ -657,31 +630,17 @@ class DeerFlowClient:
|
||||
|
||||
if isinstance(msg_chunk, AIMessage):
|
||||
text = self._extract_text(msg_chunk.content)
|
||||
additional_kwargs = self._serialize_additional_kwargs(msg_chunk)
|
||||
counted_usage = _account_usage(msg_id, msg_chunk.usage_metadata)
|
||||
sent_additional_kwargs = False
|
||||
|
||||
if text:
|
||||
if msg_id:
|
||||
streamed_ids.add(msg_id)
|
||||
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
|
||||
yield self._ai_text_event(
|
||||
msg_id,
|
||||
text,
|
||||
counted_usage,
|
||||
additional_kwargs_delta,
|
||||
)
|
||||
sent_additional_kwargs = bool(additional_kwargs_delta)
|
||||
yield self._ai_text_event(msg_id, text, counted_usage)
|
||||
|
||||
if msg_chunk.tool_calls:
|
||||
if msg_id:
|
||||
streamed_ids.add(msg_id)
|
||||
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
|
||||
yield self._ai_tool_calls_event(
|
||||
msg_id,
|
||||
msg_chunk.tool_calls,
|
||||
additional_kwargs_delta,
|
||||
)
|
||||
yield self._ai_tool_calls_event(msg_id, msg_chunk.tool_calls)
|
||||
|
||||
elif isinstance(msg_chunk, ToolMessage):
|
||||
if msg_id:
|
||||
@@ -704,45 +663,17 @@ class DeerFlowClient:
|
||||
if msg_id and msg_id in streamed_ids:
|
||||
if isinstance(msg, AIMessage):
|
||||
_account_usage(msg_id, getattr(msg, "usage_metadata", None))
|
||||
additional_kwargs = self._serialize_additional_kwargs(msg)
|
||||
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
|
||||
if additional_kwargs_delta:
|
||||
# Metadata-only follow-up: ``messages-tuple`` has no
|
||||
# dedicated attribution event, so clients should
|
||||
# merge this empty-content AI event by message id
|
||||
# and ignore it for text rendering.
|
||||
yield self._ai_text_event(msg_id, "", None, additional_kwargs_delta)
|
||||
continue
|
||||
|
||||
if isinstance(msg, AIMessage):
|
||||
counted_usage = _account_usage(msg_id, msg.usage_metadata)
|
||||
additional_kwargs = self._serialize_additional_kwargs(msg)
|
||||
sent_additional_kwargs = False
|
||||
|
||||
if msg.tool_calls:
|
||||
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
|
||||
yield self._ai_tool_calls_event(
|
||||
msg_id,
|
||||
msg.tool_calls,
|
||||
additional_kwargs_delta,
|
||||
)
|
||||
sent_additional_kwargs = bool(additional_kwargs_delta)
|
||||
yield self._ai_tool_calls_event(msg_id, msg.tool_calls)
|
||||
|
||||
text = self._extract_text(msg.content)
|
||||
if text:
|
||||
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
|
||||
yield self._ai_text_event(
|
||||
msg_id,
|
||||
text,
|
||||
counted_usage,
|
||||
additional_kwargs_delta,
|
||||
)
|
||||
elif msg_id:
|
||||
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
|
||||
if not additional_kwargs_delta:
|
||||
continue
|
||||
# See the metadata-only follow-up convention above.
|
||||
yield self._ai_text_event(msg_id, "", None, additional_kwargs_delta)
|
||||
yield self._ai_text_event(msg_id, text, counted_usage)
|
||||
|
||||
elif isinstance(msg, ToolMessage):
|
||||
yield self._tool_message_event(msg)
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from .tools import web_search_tool
|
||||
|
||||
__all__ = ["web_search_tool"]
|
||||
@@ -1,95 +0,0 @@
|
||||
"""
|
||||
Web Search Tool - Search the web using Serper (Google Search API).
|
||||
|
||||
Serper provides real-time Google Search results via a JSON API.
|
||||
An API key is required. Sign up at https://serper.dev to get one.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
import httpx
|
||||
from langchain.tools import tool
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SERPER_ENDPOINT = "https://google.serper.dev/search"
|
||||
_api_key_warned = False
|
||||
|
||||
|
||||
def _get_api_key() -> str | None:
|
||||
config = get_app_config().get_tool_config("web_search")
|
||||
if config is not None:
|
||||
api_key = config.model_extra.get("api_key")
|
||||
if isinstance(api_key, str) and api_key.strip():
|
||||
return api_key
|
||||
return os.getenv("SERPER_API_KEY")
|
||||
|
||||
|
||||
@tool("web_search", parse_docstring=True)
|
||||
def web_search_tool(query: str, max_results: int = 5) -> str:
|
||||
"""Search the web for information using Google Search via Serper.
|
||||
|
||||
Args:
|
||||
query: Search keywords describing what you want to find. Be specific for better results.
|
||||
max_results: Maximum number of search results to return. Default is 5.
|
||||
"""
|
||||
global _api_key_warned
|
||||
|
||||
config = get_app_config().get_tool_config("web_search")
|
||||
if config is not None and "max_results" in config.model_extra:
|
||||
max_results = config.model_extra.get("max_results", max_results)
|
||||
|
||||
api_key = _get_api_key()
|
||||
if not api_key:
|
||||
if not _api_key_warned:
|
||||
_api_key_warned = True
|
||||
logger.warning("Serper API key is not set. Set SERPER_API_KEY in your environment or provide api_key in config.yaml. Sign up at https://serper.dev")
|
||||
return json.dumps(
|
||||
{"error": "SERPER_API_KEY is not configured", "query": query},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
headers = {
|
||||
"X-API-KEY": api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {"q": query, "num": max_results}
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=30) as client:
|
||||
response = client.post(_SERPER_ENDPOINT, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"Serper API returned HTTP {e.response.status_code}: {e.response.text}")
|
||||
return json.dumps(
|
||||
{"error": f"Serper API error: HTTP {e.response.status_code}", "query": query},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Serper search failed: {type(e).__name__}: {e}")
|
||||
return json.dumps({"error": str(e), "query": query}, ensure_ascii=False)
|
||||
|
||||
organic = data.get("organic", [])
|
||||
if not organic:
|
||||
return json.dumps({"error": "No results found", "query": query}, ensure_ascii=False)
|
||||
|
||||
normalized_results = [
|
||||
{
|
||||
"title": r.get("title", ""),
|
||||
"url": r.get("link", ""),
|
||||
"content": r.get("snippet", ""),
|
||||
}
|
||||
for r in organic[:max_results]
|
||||
]
|
||||
|
||||
output = {
|
||||
"query": query,
|
||||
"total_results": len(normalized_results),
|
||||
"results": normalized_results,
|
||||
}
|
||||
return json.dumps(output, indent=2, ensure_ascii=False)
|
||||
@@ -6,13 +6,6 @@ from pydantic import BaseModel, Field
|
||||
from deerflow.config.runtime_paths import project_root, resolve_path
|
||||
|
||||
|
||||
def _legacy_skills_candidates() -> tuple[Path, ...]:
|
||||
"""Return source-tree skills locations for monorepo compatibility."""
|
||||
backend_dir = Path(__file__).resolve().parents[4]
|
||||
repo_root = backend_dir.parent
|
||||
return (repo_root / "skills",)
|
||||
|
||||
|
||||
class SkillsConfig(BaseModel):
|
||||
"""Configuration for skills system"""
|
||||
|
||||
@@ -22,7 +15,7 @@ class SkillsConfig(BaseModel):
|
||||
)
|
||||
path: str | None = Field(
|
||||
default=None,
|
||||
description=("Path to skills directory. If not specified, defaults to `skills` under the caller project root, falling back to the legacy repo-root location for monorepo compatibility."),
|
||||
description="Path to skills directory. If not specified, defaults to skills under the caller project root.",
|
||||
)
|
||||
container_path: str = Field(
|
||||
default="/mnt/skills",
|
||||
@@ -33,30 +26,15 @@ class SkillsConfig(BaseModel):
|
||||
"""
|
||||
Get the resolved skills directory path.
|
||||
|
||||
Resolution order:
|
||||
1. Explicit ``path`` field
|
||||
2. ``DEER_FLOW_SKILLS_PATH`` environment variable
|
||||
3. ``skills`` under the caller project root (``project_root()``)
|
||||
4. Legacy repo-root candidates for monorepo compatibility (``_legacy_skills_candidates``)
|
||||
|
||||
When none of (3) or (4) exist on disk, the project-root default is returned so callers
|
||||
can still surface a stable "no skills" location without raising.
|
||||
Returns:
|
||||
Path to the skills directory
|
||||
"""
|
||||
if self.path:
|
||||
# Use configured path (can be absolute or relative to project root)
|
||||
return resolve_path(self.path)
|
||||
if env_path := os.getenv("DEER_FLOW_SKILLS_PATH"):
|
||||
return resolve_path(env_path)
|
||||
|
||||
project_default = project_root() / "skills"
|
||||
if project_default.is_dir():
|
||||
return project_default
|
||||
|
||||
for candidate in _legacy_skills_candidates():
|
||||
if candidate.is_dir():
|
||||
return candidate
|
||||
|
||||
return project_default
|
||||
return project_root() / "skills"
|
||||
|
||||
def get_skill_container_path(self, skill_name: str, category: str = "public") -> str:
|
||||
"""
|
||||
|
||||
@@ -27,34 +27,6 @@ from deerflow.models.credential_loader import CodexCliCredential, load_codex_cli
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
|
||||
|
||||
|
||||
def _build_usage_metadata(oai_usage: dict) -> dict:
|
||||
"""Convert Codex/Responses API usage dict to LangChain usage_metadata format.
|
||||
|
||||
Maps OpenAI Responses API token usage fields to the dict structure that
|
||||
LangChain AIMessage.usage_metadata expects. This avoids depending on
|
||||
langchain_openai private helpers like ``_create_usage_metadata_responses``.
|
||||
"""
|
||||
input_tokens = oai_usage.get("input_tokens", 0)
|
||||
output_tokens = oai_usage.get("output_tokens", 0)
|
||||
total_tokens = oai_usage.get("total_tokens", input_tokens + output_tokens)
|
||||
metadata: dict = {
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
}
|
||||
input_details = oai_usage.get("input_tokens_details") or {}
|
||||
output_details = oai_usage.get("output_tokens_details") or {}
|
||||
cache_read = input_details.get("cached_tokens")
|
||||
if cache_read is not None:
|
||||
metadata["input_token_details"] = {"cache_read": cache_read}
|
||||
reasoning = output_details.get("reasoning_tokens")
|
||||
if reasoning is not None:
|
||||
metadata["output_token_details"] = {"reasoning": reasoning}
|
||||
return metadata
|
||||
|
||||
|
||||
MAX_RETRIES = 3
|
||||
|
||||
|
||||
@@ -374,7 +346,6 @@ class CodexChatModel(BaseChatModel):
|
||||
)
|
||||
|
||||
usage = response.get("usage", {})
|
||||
usage_metadata = _build_usage_metadata(usage) if usage else None
|
||||
additional_kwargs = {}
|
||||
if reasoning_content:
|
||||
additional_kwargs["reasoning_content"] = reasoning_content
|
||||
@@ -384,7 +355,6 @@ class CodexChatModel(BaseChatModel):
|
||||
tool_calls=tool_calls if tool_calls else [],
|
||||
invalid_tool_calls=invalid_tool_calls,
|
||||
additional_kwargs=additional_kwargs,
|
||||
usage_metadata=usage_metadata,
|
||||
response_metadata={
|
||||
"model": response.get("model", self.model),
|
||||
"usage": usage,
|
||||
|
||||
@@ -7,13 +7,13 @@ router for thread records.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from langgraph.store.base import BaseStore
|
||||
|
||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||
from deerflow.utils.time import coerce_iso, now_iso
|
||||
|
||||
THREADS_NS: tuple[str, ...] = ("threads",)
|
||||
|
||||
@@ -48,7 +48,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
metadata: dict | None = None,
|
||||
) -> dict:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.create")
|
||||
now = now_iso()
|
||||
now = time.time()
|
||||
record: dict[str, Any] = {
|
||||
"thread_id": thread_id,
|
||||
"assistant_id": assistant_id,
|
||||
@@ -106,7 +106,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
if record is None:
|
||||
return
|
||||
record["display_name"] = display_name
|
||||
record["updated_at"] = now_iso()
|
||||
record["updated_at"] = time.time()
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
|
||||
async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
@@ -114,7 +114,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
if record is None:
|
||||
return
|
||||
record["status"] = status
|
||||
record["updated_at"] = now_iso()
|
||||
record["updated_at"] = time.time()
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
|
||||
async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
@@ -124,7 +124,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
merged = dict(record.get("metadata") or {})
|
||||
merged.update(metadata)
|
||||
record["metadata"] = merged
|
||||
record["updated_at"] = now_iso()
|
||||
record["updated_at"] = time.time()
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
|
||||
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
@@ -144,8 +144,6 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
"display_name": val.get("display_name"),
|
||||
"status": val.get("status", "idle"),
|
||||
"metadata": val.get("metadata", {}),
|
||||
# ``coerce_iso`` heals legacy unix-second values written by
|
||||
# earlier Gateway versions that called ``str(time.time())``.
|
||||
"created_at": coerce_iso(val.get("created_at", "")),
|
||||
"updated_at": coerce_iso(val.get("updated_at", "")),
|
||||
"created_at": str(val.get("created_at", "")),
|
||||
"updated_at": str(val.get("updated_at", "")),
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ from collections.abc import Iterator
|
||||
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.config.app_config import AppConfig, get_app_config
|
||||
from deerflow.config.checkpointer_config import CheckpointerConfig
|
||||
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||
|
||||
@@ -98,9 +98,78 @@ def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
|
||||
|
||||
_checkpointer: Checkpointer | None = None
|
||||
_checkpointer_ctx = None # open context manager keeping the connection alive
|
||||
_explicit_checkpointers: dict[int, Checkpointer] = {}
|
||||
_explicit_checkpointer_contexts: dict[int, object] = {}
|
||||
|
||||
|
||||
def get_checkpointer() -> Checkpointer:
|
||||
def _default_in_memory_checkpointer() -> Checkpointer:
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
|
||||
return InMemorySaver()
|
||||
|
||||
|
||||
def _persistent_database_backend(db_config) -> str | None:
|
||||
backend = getattr(db_config, "backend", None)
|
||||
if backend in {"sqlite", "postgres"}:
|
||||
return backend
|
||||
return None
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _sync_checkpointer_from_database_cm(db_config) -> Iterator[Checkpointer]:
|
||||
"""Context manager that creates a sync checkpointer from unified DatabaseConfig."""
|
||||
backend = _persistent_database_backend(db_config)
|
||||
if backend is None:
|
||||
yield _default_in_memory_checkpointer()
|
||||
return
|
||||
|
||||
if backend == "sqlite":
|
||||
try:
|
||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(SQLITE_INSTALL) from exc
|
||||
|
||||
conn_str = db_config.checkpointer_sqlite_path
|
||||
ensure_sqlite_parent_dir(conn_str)
|
||||
with SqliteSaver.from_conn_string(conn_str) as saver:
|
||||
saver.setup()
|
||||
logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)
|
||||
yield saver
|
||||
return
|
||||
|
||||
if backend == "postgres":
|
||||
try:
|
||||
from langgraph.checkpoint.postgres import PostgresSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(POSTGRES_INSTALL) from exc
|
||||
|
||||
if not db_config.postgres_url:
|
||||
raise ValueError("database.postgres_url is required for the postgres backend")
|
||||
|
||||
with PostgresSaver.from_conn_string(db_config.postgres_url) as saver:
|
||||
saver.setup()
|
||||
logger.info("Checkpointer: using PostgresSaver")
|
||||
yield saver
|
||||
return
|
||||
|
||||
raise ValueError(f"Unknown database backend: {backend!r}")
|
||||
|
||||
|
||||
def _build_checkpointer_from_app_config(app_config: AppConfig) -> tuple[Checkpointer, object | None]:
|
||||
if app_config.checkpointer is not None:
|
||||
ctx = _sync_checkpointer_cm(app_config.checkpointer)
|
||||
return ctx.__enter__(), ctx
|
||||
|
||||
db_config = getattr(app_config, "database", None)
|
||||
if _persistent_database_backend(db_config) is not None:
|
||||
ctx = _sync_checkpointer_from_database_cm(db_config)
|
||||
return ctx.__enter__(), ctx
|
||||
|
||||
return _default_in_memory_checkpointer(), None
|
||||
|
||||
|
||||
def get_checkpointer(app_config: AppConfig | None = None) -> Checkpointer:
|
||||
"""Return the global sync checkpointer singleton, creating it on first call.
|
||||
|
||||
Returns an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
|
||||
@@ -111,6 +180,18 @@ def get_checkpointer() -> Checkpointer:
|
||||
"""
|
||||
global _checkpointer, _checkpointer_ctx
|
||||
|
||||
if app_config is not None:
|
||||
cache_key = id(app_config)
|
||||
cached = _explicit_checkpointers.get(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
explicit_checkpointer, explicit_ctx = _build_checkpointer_from_app_config(app_config)
|
||||
_explicit_checkpointers[cache_key] = explicit_checkpointer
|
||||
if explicit_ctx is not None:
|
||||
_explicit_checkpointer_contexts[cache_key] = explicit_ctx
|
||||
return explicit_checkpointer
|
||||
|
||||
if _checkpointer is not None:
|
||||
return _checkpointer
|
||||
|
||||
@@ -121,28 +202,30 @@ def get_checkpointer() -> Checkpointer:
|
||||
from deerflow.config.checkpointer_config import get_checkpointer_config
|
||||
|
||||
config = get_checkpointer_config()
|
||||
global_app_config = _app_config
|
||||
|
||||
if config is None and _app_config is None:
|
||||
if config is None and global_app_config is None:
|
||||
# Only load app config lazily when neither the app config nor an explicit
|
||||
# checkpointer config has been initialized yet. This keeps tests that
|
||||
# intentionally set the global checkpointer config isolated from any
|
||||
# ambient config.yaml on disk.
|
||||
try:
|
||||
get_app_config()
|
||||
global_app_config = get_app_config()
|
||||
except FileNotFoundError:
|
||||
# In test environments without config.yaml, this is expected.
|
||||
pass
|
||||
config = get_checkpointer_config()
|
||||
if config is None:
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
|
||||
_checkpointer = InMemorySaver()
|
||||
if config is not None:
|
||||
_checkpointer_ctx = _sync_checkpointer_cm(config)
|
||||
_checkpointer = _checkpointer_ctx.__enter__()
|
||||
return _checkpointer
|
||||
|
||||
_checkpointer_ctx = _sync_checkpointer_cm(config)
|
||||
_checkpointer = _checkpointer_ctx.__enter__()
|
||||
if global_app_config is not None:
|
||||
_checkpointer, _checkpointer_ctx = _build_checkpointer_from_app_config(global_app_config)
|
||||
return _checkpointer
|
||||
|
||||
_checkpointer = _default_in_memory_checkpointer()
|
||||
return _checkpointer
|
||||
|
||||
|
||||
@@ -161,6 +244,18 @@ def reset_checkpointer() -> None:
|
||||
_checkpointer_ctx = None
|
||||
_checkpointer = None
|
||||
|
||||
for cache_key, ctx in list(_explicit_checkpointer_contexts.items()):
|
||||
try:
|
||||
ctx.__exit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("Error during explicit checkpointer cleanup", exc_info=True)
|
||||
finally:
|
||||
_explicit_checkpointer_contexts.pop(cache_key, None)
|
||||
_explicit_checkpointers.pop(cache_key, None)
|
||||
|
||||
_explicit_checkpointers.clear()
|
||||
_explicit_checkpointer_contexts.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync context manager
|
||||
@@ -168,7 +263,7 @@ def reset_checkpointer() -> None:
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def checkpointer_context() -> Iterator[Checkpointer]:
|
||||
def checkpointer_context(app_config: AppConfig | None = None) -> Iterator[Checkpointer]:
|
||||
"""Sync context manager that yields a checkpointer and cleans up on exit.
|
||||
|
||||
Unlike :func:`get_checkpointer`, this does **not** cache the instance —
|
||||
@@ -181,12 +276,16 @@ def checkpointer_context() -> Iterator[Checkpointer]:
|
||||
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
|
||||
"""
|
||||
|
||||
config = get_app_config()
|
||||
if config.checkpointer is None:
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
yield InMemorySaver()
|
||||
resolved_app_config = app_config or get_app_config()
|
||||
if resolved_app_config.checkpointer is not None:
|
||||
with _sync_checkpointer_cm(resolved_app_config.checkpointer) as saver:
|
||||
yield saver
|
||||
return
|
||||
|
||||
with _sync_checkpointer_cm(config.checkpointer) as saver:
|
||||
yield saver
|
||||
db_config = getattr(resolved_app_config, "database", None)
|
||||
if _persistent_database_backend(db_config) is not None:
|
||||
with _sync_checkpointer_from_database_cm(db_config) as saver:
|
||||
yield saver
|
||||
return
|
||||
|
||||
yield _default_in_memory_checkpointer()
|
||||
|
||||
@@ -6,10 +6,9 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from deerflow.utils.time import now_iso as _now_iso
|
||||
|
||||
from .schemas import DisconnectMode, RunStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -18,6 +17,10 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(UTC).isoformat()
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunRecord:
|
||||
"""Mutable record for a single run."""
|
||||
|
||||
@@ -23,8 +23,6 @@ from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from langgraph.checkpoint.base import empty_checkpoint
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
@@ -444,12 +442,6 @@ async def _rollback_to_pre_run_checkpoint(
|
||||
if checkpoint_to_restore.get("id") is None:
|
||||
logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id)
|
||||
return
|
||||
restore_marker = _new_checkpoint_marker()
|
||||
checkpoint_to_restore = {
|
||||
**checkpoint_to_restore,
|
||||
"id": restore_marker["id"],
|
||||
"ts": restore_marker["ts"],
|
||||
}
|
||||
metadata = pre_run_snapshot.get("metadata", {})
|
||||
metadata_to_restore = metadata if isinstance(metadata, dict) else {}
|
||||
raw_checkpoint_ns = pre_run_snapshot.get("checkpoint_ns")
|
||||
@@ -501,11 +493,6 @@ async def _rollback_to_pre_run_checkpoint(
|
||||
)
|
||||
|
||||
|
||||
def _new_checkpoint_marker() -> dict[str, str]:
|
||||
marker = empty_checkpoint()
|
||||
return {"id": marker["id"], "ts": marker["ts"]}
|
||||
|
||||
|
||||
def _lg_mode_to_sse_event(mode: str) -> str:
|
||||
"""Map LangGraph internal stream_mode name to SSE event name.
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from collections.abc import Iterator
|
||||
|
||||
from langgraph.store.base import BaseStore
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.config.app_config import AppConfig, get_app_config
|
||||
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -98,9 +98,26 @@ def _sync_store_cm(config) -> Iterator[BaseStore]:
|
||||
|
||||
_store: BaseStore | None = None
|
||||
_store_ctx = None # open context manager keeping the connection alive
|
||||
_explicit_stores: dict[int, BaseStore] = {}
|
||||
_explicit_store_contexts: dict[int, object] = {}
|
||||
|
||||
|
||||
def get_store() -> BaseStore:
|
||||
def _default_in_memory_store() -> BaseStore:
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
|
||||
return InMemoryStore()
|
||||
|
||||
|
||||
def _build_store_from_app_config(app_config: AppConfig) -> tuple[BaseStore, object | None]:
|
||||
if app_config.checkpointer is not None:
|
||||
ctx = _sync_store_cm(app_config.checkpointer)
|
||||
return ctx.__enter__(), ctx
|
||||
|
||||
return _default_in_memory_store(), None
|
||||
|
||||
|
||||
def get_store(app_config: AppConfig | None = None) -> BaseStore:
|
||||
"""Return the global sync Store singleton, creating it on first call.
|
||||
|
||||
Returns an :class:`~langgraph.store.memory.InMemoryStore` when no
|
||||
@@ -112,6 +129,18 @@ def get_store() -> BaseStore:
|
||||
"""
|
||||
global _store, _store_ctx
|
||||
|
||||
if app_config is not None:
|
||||
cache_key = id(app_config)
|
||||
cached = _explicit_stores.get(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
explicit_store, explicit_ctx = _build_store_from_app_config(app_config)
|
||||
_explicit_stores[cache_key] = explicit_store
|
||||
if explicit_ctx is not None:
|
||||
_explicit_store_contexts[cache_key] = explicit_ctx
|
||||
return explicit_store
|
||||
|
||||
if _store is not None:
|
||||
return _store
|
||||
|
||||
@@ -130,10 +159,7 @@ def get_store() -> BaseStore:
|
||||
config = get_checkpointer_config()
|
||||
|
||||
if config is None:
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
|
||||
_store = InMemoryStore()
|
||||
_store = _default_in_memory_store()
|
||||
return _store
|
||||
|
||||
_store_ctx = _sync_store_cm(config)
|
||||
@@ -156,6 +182,18 @@ def reset_store() -> None:
|
||||
_store_ctx = None
|
||||
_store = None
|
||||
|
||||
for cache_key, ctx in list(_explicit_store_contexts.items()):
|
||||
try:
|
||||
ctx.__exit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("Error during explicit store cleanup", exc_info=True)
|
||||
finally:
|
||||
_explicit_store_contexts.pop(cache_key, None)
|
||||
_explicit_stores.pop(cache_key, None)
|
||||
|
||||
_explicit_stores.clear()
|
||||
_explicit_store_contexts.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync context manager
|
||||
@@ -163,7 +201,7 @@ def reset_store() -> None:
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def store_context() -> Iterator[BaseStore]:
|
||||
def store_context(app_config: AppConfig | None = None) -> Iterator[BaseStore]:
|
||||
"""Sync context manager that yields a Store and cleans up on exit.
|
||||
|
||||
Unlike :func:`get_store`, this does **not** cache the instance — each
|
||||
@@ -176,13 +214,10 @@ def store_context() -> Iterator[BaseStore]:
|
||||
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
|
||||
checkpointer is configured in *config.yaml*.
|
||||
"""
|
||||
config = get_app_config()
|
||||
if config.checkpointer is None:
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
|
||||
yield InMemoryStore()
|
||||
resolved_app_config = app_config or get_app_config()
|
||||
if resolved_app_config.checkpointer is None:
|
||||
yield _default_in_memory_store()
|
||||
return
|
||||
|
||||
with _sync_store_cm(config.checkpointer) as store:
|
||||
with _sync_store_cm(resolved_app_config.checkpointer) as store:
|
||||
yield store
|
||||
|
||||
@@ -4,10 +4,8 @@ Pure business logic — no FastAPI/HTTP dependencies.
|
||||
Both Gateway and Client delegate to these functions.
|
||||
"""
|
||||
|
||||
import errno
|
||||
import os
|
||||
import re
|
||||
import stat
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote
|
||||
|
||||
@@ -19,10 +17,6 @@ class PathTraversalError(ValueError):
|
||||
"""Raised when a path escapes its allowed base directory."""
|
||||
|
||||
|
||||
class UnsafeUploadPathError(ValueError):
|
||||
"""Raised when an upload destination is not a safe regular file path."""
|
||||
|
||||
|
||||
# thread_id must be alphanumeric, hyphens, underscores, or dots only.
|
||||
_SAFE_THREAD_ID = re.compile(r"^[a-zA-Z0-9._-]+$")
|
||||
|
||||
@@ -115,64 +109,6 @@ def validate_path_traversal(path: Path, base: Path) -> None:
|
||||
raise PathTraversalError("Path traversal detected") from None
|
||||
|
||||
|
||||
def open_upload_file_no_symlink(base_dir: Path, filename: str) -> tuple[Path, object]:
|
||||
"""Open an upload destination for safe streaming writes.
|
||||
|
||||
Upload directories may be mounted into local sandboxes. A sandbox process can
|
||||
therefore leave a symlink at a future upload filename. Normal ``Path.write_bytes``
|
||||
follows that link and can overwrite files outside the uploads directory with
|
||||
gateway privileges. This helper rejects symlink destinations and uses
|
||||
``O_NOFOLLOW`` where available so the final path component cannot be raced into
|
||||
a symlink between validation and open.
|
||||
"""
|
||||
safe_name = normalize_filename(filename)
|
||||
dest = base_dir / safe_name
|
||||
|
||||
try:
|
||||
st = os.lstat(dest)
|
||||
except FileNotFoundError:
|
||||
st = None
|
||||
|
||||
if st is not None and not stat.S_ISREG(st.st_mode):
|
||||
raise UnsafeUploadPathError(f"Upload destination is not a regular file: {safe_name}")
|
||||
|
||||
validate_path_traversal(dest, base_dir)
|
||||
|
||||
if not hasattr(os, "O_NOFOLLOW"):
|
||||
raise UnsafeUploadPathError("Upload writes require O_NOFOLLOW support")
|
||||
|
||||
flags = os.O_WRONLY | os.O_CREAT | os.O_NOFOLLOW
|
||||
if hasattr(os, "O_NONBLOCK"):
|
||||
flags |= os.O_NONBLOCK
|
||||
|
||||
try:
|
||||
fd = os.open(dest, flags, 0o600)
|
||||
except OSError as exc:
|
||||
if exc.errno in {errno.ELOOP, errno.EISDIR, errno.ENOTDIR, errno.ENXIO, errno.EAGAIN}:
|
||||
raise UnsafeUploadPathError(f"Unsafe upload destination: {safe_name}") from exc
|
||||
raise
|
||||
|
||||
try:
|
||||
opened_stat = os.fstat(fd)
|
||||
if not stat.S_ISREG(opened_stat.st_mode) or opened_stat.st_nlink != 1:
|
||||
raise UnsafeUploadPathError(f"Upload destination is not an exclusive regular file: {safe_name}")
|
||||
os.ftruncate(fd, 0)
|
||||
fh = os.fdopen(fd, "wb")
|
||||
fd = -1
|
||||
finally:
|
||||
if fd >= 0:
|
||||
os.close(fd)
|
||||
return dest, fh
|
||||
|
||||
|
||||
def write_upload_file_no_symlink(base_dir: Path, filename: str, data: bytes) -> Path:
|
||||
"""Write upload bytes without following a pre-existing destination symlink."""
|
||||
dest, fh = open_upload_file_no_symlink(base_dir, filename)
|
||||
with fh:
|
||||
fh.write(data)
|
||||
return dest
|
||||
|
||||
|
||||
def list_files_in_dir(directory: Path) -> dict:
|
||||
"""List files (not directories) in *directory*.
|
||||
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
"""ISO 8601 timestamp helpers for the Gateway and embedded runtime.
|
||||
|
||||
DeerFlow stores and serializes thread/run timestamps as ISO 8601 UTC
|
||||
strings to match the LangGraph Platform schema (see
|
||||
``langgraph_sdk.schema.Thread``, where ``created_at`` / ``updated_at``
|
||||
are ``datetime`` and JSON-encode to ISO 8601). All timestamp generation
|
||||
should funnel through :func:`now_iso` so the wire format stays
|
||||
consistent across endpoints, the embedded ``RunManager``, and the
|
||||
checkpoint metadata written by the Gateway.
|
||||
|
||||
:func:`coerce_iso` provides a forward-compatible read path for legacy
|
||||
records that historically stored ``str(time.time())`` floats.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import UTC, datetime
|
||||
|
||||
__all__ = ["coerce_iso", "now_iso"]
|
||||
|
||||
_UNIX_TIMESTAMP_PATTERN = re.compile(r"^\d{10}(?:\.\d+)?$")
|
||||
"""Matches the unix-timestamp string shape historically written by
|
||||
``str(time.time())`` (10-digit seconds with optional fractional part).
|
||||
The 10-digit anchor avoids accidentally rewriting ISO years like
|
||||
``"2026"`` and stays valid until the year 2286.
|
||||
"""
|
||||
|
||||
|
||||
def now_iso() -> str:
|
||||
"""Return the current UTC time as an ISO 8601 string.
|
||||
|
||||
Example: ``"2026-04-27T03:19:46.511479+00:00"``.
|
||||
"""
|
||||
return datetime.now(UTC).isoformat()
|
||||
|
||||
|
||||
def coerce_iso(value: object) -> str:
|
||||
"""Best-effort coerce a stored timestamp to an ISO 8601 string.
|
||||
|
||||
Translates legacy unix-timestamp floats / strings written by older
|
||||
DeerFlow versions into ISO without a one-shot migration. ISO strings
|
||||
pass through unchanged; ``datetime`` instances are normalised to UTC
|
||||
(tz-naive values are assumed to be UTC) and emitted via
|
||||
``isoformat()`` so the wire format always uses the ``T`` separator;
|
||||
empty values become ``""``; unrecognised values are stringified as a
|
||||
last resort.
|
||||
"""
|
||||
if value is None or value == "":
|
||||
return ""
|
||||
if isinstance(value, bool):
|
||||
# ``bool`` is a subclass of ``int`` — treat as garbage, not 0/1.
|
||||
return str(value)
|
||||
if isinstance(value, datetime):
|
||||
# ``datetime`` must be handled before the ``int``/``float`` check;
|
||||
# str(datetime) would produce ``"YYYY-MM-DD HH:MM:SS+00:00"``
|
||||
# (space separator), which breaks strict ISO 8601 consumers.
|
||||
if value.tzinfo is None:
|
||||
value = value.replace(tzinfo=UTC)
|
||||
else:
|
||||
value = value.astimezone(UTC)
|
||||
return value.isoformat()
|
||||
if isinstance(value, (int, float)):
|
||||
try:
|
||||
return datetime.fromtimestamp(float(value), UTC).isoformat()
|
||||
except (ValueError, OverflowError, OSError):
|
||||
return str(value)
|
||||
if isinstance(value, str):
|
||||
if _UNIX_TIMESTAMP_PATTERN.match(value):
|
||||
try:
|
||||
return datetime.fromtimestamp(float(value), UTC).isoformat()
|
||||
except (ValueError, OverflowError, OSError):
|
||||
return value
|
||||
return value
|
||||
return str(value)
|
||||
@@ -47,3 +47,4 @@ members = ["packages/harness"]
|
||||
|
||||
[tool.uv.sources]
|
||||
deerflow-harness = { workspace = true }
|
||||
|
||||
|
||||
@@ -3,12 +3,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.message_bus import InboundMessage, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
from app.channels.message_bus import MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
|
||||
def _run(coro):
|
||||
@@ -249,109 +248,6 @@ class TestResolveAttachments:
|
||||
assert result[0].filename == "data.csv"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inbound file ingestion tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInboundFileIngestion:
|
||||
def test_rejects_preexisting_symlink_destination(self, tmp_path):
|
||||
from app.channels import manager
|
||||
|
||||
uploads_dir = tmp_path / "uploads"
|
||||
uploads_dir.mkdir()
|
||||
outside_file = tmp_path / "outside-created.txt"
|
||||
(uploads_dir / "victim.txt").symlink_to(outside_file)
|
||||
|
||||
msg = InboundMessage(
|
||||
channel_name="test-channel",
|
||||
chat_id="chat-1",
|
||||
user_id="user-1",
|
||||
text="see attachment",
|
||||
files=[{"filename": "victim.txt", "url": "https://example.invalid/victim.txt"}],
|
||||
)
|
||||
|
||||
async def fake_reader(file_info, client):
|
||||
return b"attacker data"
|
||||
|
||||
with (
|
||||
patch("deerflow.uploads.manager.ensure_uploads_dir", return_value=uploads_dir),
|
||||
patch.dict(manager.INBOUND_FILE_READERS, {"test-channel": fake_reader}, clear=False),
|
||||
):
|
||||
result = _run(manager._ingest_inbound_files("thread-1", msg))
|
||||
|
||||
assert result == []
|
||||
assert not outside_file.exists()
|
||||
assert (uploads_dir / "victim.txt").is_symlink()
|
||||
|
||||
def test_rejects_dangling_symlink_destination(self, tmp_path):
|
||||
from app.channels import manager
|
||||
|
||||
uploads_dir = tmp_path / "uploads"
|
||||
uploads_dir.mkdir()
|
||||
missing_target = tmp_path / "missing-created.txt"
|
||||
(uploads_dir / "victim.txt").symlink_to(missing_target)
|
||||
|
||||
msg = InboundMessage(
|
||||
channel_name="test-channel",
|
||||
chat_id="chat-1",
|
||||
user_id="user-1",
|
||||
text="see attachment",
|
||||
files=[{"filename": "victim.txt", "url": "https://example.invalid/victim.txt"}],
|
||||
)
|
||||
|
||||
async def fake_reader(file_info, client):
|
||||
return b"attacker data"
|
||||
|
||||
with (
|
||||
patch("deerflow.uploads.manager.ensure_uploads_dir", return_value=uploads_dir),
|
||||
patch.dict(manager.INBOUND_FILE_READERS, {"test-channel": fake_reader}, clear=False),
|
||||
):
|
||||
result = _run(manager._ingest_inbound_files("thread-1", msg))
|
||||
|
||||
assert result == []
|
||||
assert not missing_target.exists()
|
||||
assert (uploads_dir / "victim.txt").is_symlink()
|
||||
|
||||
def test_hardlinked_existing_file_is_not_overwritten(self, tmp_path):
|
||||
from app.channels import manager
|
||||
|
||||
uploads_dir = tmp_path / "uploads"
|
||||
uploads_dir.mkdir()
|
||||
outside_file = tmp_path / "outside-created.txt"
|
||||
outside_file.write_text("protected", encoding="utf-8")
|
||||
os.link(outside_file, uploads_dir / "victim.txt")
|
||||
|
||||
msg = InboundMessage(
|
||||
channel_name="test-channel",
|
||||
chat_id="chat-1",
|
||||
user_id="user-1",
|
||||
text="see attachment",
|
||||
files=[{"filename": "victim.txt", "url": "https://example.invalid/victim.txt"}],
|
||||
)
|
||||
|
||||
async def fake_reader(file_info, client):
|
||||
return b"new attachment data"
|
||||
|
||||
with (
|
||||
patch("deerflow.uploads.manager.ensure_uploads_dir", return_value=uploads_dir),
|
||||
patch.dict(manager.INBOUND_FILE_READERS, {"test-channel": fake_reader}, clear=False),
|
||||
):
|
||||
result = _run(manager._ingest_inbound_files("thread-1", msg))
|
||||
|
||||
assert result == [
|
||||
{
|
||||
"filename": "victim_1.txt",
|
||||
"size": len(b"new attachment data"),
|
||||
"path": "/mnt/user-data/uploads/victim_1.txt",
|
||||
"is_image": False,
|
||||
}
|
||||
]
|
||||
assert outside_file.read_text(encoding="utf-8") == "protected"
|
||||
assert (uploads_dir / "victim.txt").read_text(encoding="utf-8") == "protected"
|
||||
assert (uploads_dir / "victim_1.txt").read_bytes() == b"new attachment data"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Channel base class _on_outbound with attachments
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Unit tests for checkpointer config and singleton factory."""
|
||||
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -103,6 +104,53 @@ class TestGetCheckpointer:
|
||||
cp2 = get_checkpointer()
|
||||
assert cp1 is not cp2
|
||||
|
||||
def test_explicit_app_config_bypasses_global_config_lookup(self):
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
explicit_config = SimpleNamespace(
|
||||
checkpointer=CheckpointerConfig(type="memory"),
|
||||
database=SimpleNamespace(backend="memory"),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"deerflow.runtime.checkpointer.provider.get_app_config",
|
||||
side_effect=AssertionError("ambient get_app_config() must not be used when app_config is explicit"),
|
||||
):
|
||||
cp = get_checkpointer(app_config=explicit_config)
|
||||
|
||||
assert isinstance(cp, InMemorySaver)
|
||||
|
||||
def test_explicit_app_config_uses_unified_database_sqlite_backend(self):
|
||||
explicit_config = SimpleNamespace(
|
||||
checkpointer=None,
|
||||
database=SimpleNamespace(backend="sqlite", checkpointer_sqlite_path="/tmp/explicit/deerflow.db"),
|
||||
)
|
||||
|
||||
mock_saver_instance = MagicMock()
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance)
|
||||
mock_cm.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_saver_cls = MagicMock()
|
||||
mock_saver_cls.from_conn_string = MagicMock(return_value=mock_cm)
|
||||
|
||||
mock_module = MagicMock()
|
||||
mock_module.SqliteSaver = mock_saver_cls
|
||||
|
||||
with (
|
||||
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}),
|
||||
patch(
|
||||
"deerflow.runtime.checkpointer.provider.get_app_config",
|
||||
side_effect=AssertionError("ambient get_app_config() must not be used when app_config is explicit"),
|
||||
),
|
||||
patch("deerflow.runtime.checkpointer.provider.ensure_sqlite_parent_dir") as mock_ensure,
|
||||
):
|
||||
cp = get_checkpointer(app_config=explicit_config)
|
||||
|
||||
assert cp is mock_saver_instance
|
||||
mock_ensure.assert_called_once_with("/tmp/explicit/deerflow.db")
|
||||
mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/explicit/deerflow.db")
|
||||
|
||||
def test_sqlite_raises_when_package_missing(self):
|
||||
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
|
||||
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": None}):
|
||||
|
||||
@@ -437,85 +437,6 @@ class TestStream:
|
||||
call_kwargs = agent.stream.call_args.kwargs
|
||||
assert "messages" in call_kwargs["stream_mode"]
|
||||
|
||||
def test_stream_emits_additional_kwargs_updates_for_streamed_ai_messages(self, client):
|
||||
"""stream() emits a follow-up AI event when attribution metadata arrives via values."""
|
||||
assembled = AIMessage(
|
||||
content="Hello!",
|
||||
id="ai-1",
|
||||
additional_kwargs={
|
||||
"token_usage_attribution": {
|
||||
"version": 1,
|
||||
"kind": "final_answer",
|
||||
"shared_attribution": False,
|
||||
"actions": [],
|
||||
}
|
||||
},
|
||||
)
|
||||
agent = MagicMock()
|
||||
agent.stream.return_value = iter(
|
||||
[
|
||||
("messages", (AIMessageChunk(content="Hello!", id="ai-1"), {})),
|
||||
("values", {"messages": [HumanMessage(content="hi", id="h-1"), assembled]}),
|
||||
]
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(client, "_ensure_agent"),
|
||||
patch.object(client, "_agent", agent),
|
||||
):
|
||||
events = list(client.stream("hi", thread_id="t-stream-kwargs"))
|
||||
|
||||
ai_events = [event for event in events if event.type == "messages-tuple" and event.data.get("type") == "ai" and event.data.get("id") == "ai-1"]
|
||||
assert any(event.data.get("content") == "Hello!" for event in ai_events)
|
||||
assert any(event.data.get("additional_kwargs", {}).get("token_usage_attribution", {}).get("kind") == "final_answer" for event in ai_events)
|
||||
|
||||
def test_stream_emits_new_additional_kwargs_after_prior_metadata(self, client):
|
||||
"""stream() emits later attribution metadata even after earlier kwargs for the same id."""
|
||||
attribution = {
|
||||
"version": 1,
|
||||
"kind": "final_answer",
|
||||
"shared_attribution": False,
|
||||
"actions": [],
|
||||
}
|
||||
assembled = AIMessage(
|
||||
content="Hello!",
|
||||
id="ai-1",
|
||||
additional_kwargs={
|
||||
"reasoning_content": "Thinking first.",
|
||||
"token_usage_attribution": attribution,
|
||||
},
|
||||
)
|
||||
agent = MagicMock()
|
||||
agent.stream.return_value = iter(
|
||||
[
|
||||
(
|
||||
"messages",
|
||||
(
|
||||
AIMessageChunk(
|
||||
content="Hello!",
|
||||
id="ai-1",
|
||||
additional_kwargs={"reasoning_content": "Thinking first."},
|
||||
),
|
||||
{},
|
||||
),
|
||||
),
|
||||
("values", {"messages": [HumanMessage(content="hi", id="h-1"), assembled]}),
|
||||
]
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(client, "_ensure_agent"),
|
||||
patch.object(client, "_agent", agent),
|
||||
):
|
||||
events = list(client.stream("hi", thread_id="t-stream-kwargs-delta"))
|
||||
|
||||
ai_events = [event for event in events if event.type == "messages-tuple" and event.data.get("type") == "ai" and event.data.get("id") == "ai-1"]
|
||||
metadata_events = [event for event in ai_events if event.data.get("additional_kwargs")]
|
||||
|
||||
assert metadata_events[0].data["additional_kwargs"] == {"reasoning_content": "Thinking first."}
|
||||
assert metadata_events[1].data["content"] == ""
|
||||
assert metadata_events[1].data["additional_kwargs"] == {"token_usage_attribution": attribution}
|
||||
|
||||
def test_chat_accumulates_streamed_deltas(self, client):
|
||||
"""chat() concatenates per-id deltas from messages mode."""
|
||||
agent = MagicMock()
|
||||
@@ -927,6 +848,28 @@ class TestEnsureAgent:
|
||||
assert mock_apply_prompt.call_args.kwargs.get("agent_name") == "custom-agent"
|
||||
assert mock_apply_prompt.call_args.kwargs.get("available_skills") == {"test_skill"}
|
||||
|
||||
def test_threads_explicit_app_config_to_dependencies(self, client):
|
||||
"""Client-owned AppConfig must flow into model/tool/prompt/checkpointer composition."""
|
||||
mock_agent = MagicMock()
|
||||
mock_checkpointer = MagicMock()
|
||||
config = client._get_runnable_config("t1")
|
||||
|
||||
with (
|
||||
patch("deerflow.client.create_chat_model", return_value=MagicMock()) as mock_create_chat_model,
|
||||
patch("deerflow.client.create_agent", return_value=mock_agent),
|
||||
patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
|
||||
patch("deerflow.tools.get_available_tools", return_value=[]) as mock_get_available_tools,
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=mock_checkpointer) as mock_get_checkpointer,
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
|
||||
assert mock_create_chat_model.call_args.kwargs["app_config"] is client._app_config
|
||||
assert mock_build_middlewares.call_args.kwargs["app_config"] is client._app_config
|
||||
assert mock_apply_prompt.call_args.kwargs["app_config"] is client._app_config
|
||||
assert mock_get_available_tools.call_args.kwargs["app_config"] is client._app_config
|
||||
assert mock_get_checkpointer.call_args.kwargs["app_config"] is client._app_config
|
||||
|
||||
def test_uses_default_checkpointer_when_available(self, client):
|
||||
mock_agent = MagicMock()
|
||||
mock_checkpointer = MagicMock()
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
"""Tests for DeerFlowClient message serialization helpers."""
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from deerflow.client import DeerFlowClient
|
||||
|
||||
|
||||
def test_serialize_ai_message_preserves_additional_kwargs():
|
||||
message = AIMessage(
|
||||
content="done",
|
||||
additional_kwargs={
|
||||
"token_usage_attribution": {
|
||||
"version": 1,
|
||||
"kind": "final_answer",
|
||||
"shared_attribution": False,
|
||||
"actions": [],
|
||||
}
|
||||
},
|
||||
usage_metadata={"input_tokens": 12, "output_tokens": 3, "total_tokens": 15},
|
||||
)
|
||||
|
||||
serialized = DeerFlowClient._serialize_message(message)
|
||||
|
||||
assert serialized["type"] == "ai"
|
||||
assert serialized["usage_metadata"] == {
|
||||
"input_tokens": 12,
|
||||
"output_tokens": 3,
|
||||
"total_tokens": 15,
|
||||
}
|
||||
assert serialized["additional_kwargs"] == {
|
||||
"token_usage_attribution": {
|
||||
"version": 1,
|
||||
"kind": "final_answer",
|
||||
"shared_attribution": False,
|
||||
"actions": [],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_serialize_human_message_preserves_additional_kwargs():
|
||||
message = HumanMessage(
|
||||
content="hello",
|
||||
additional_kwargs={"files": [{"name": "diagram.png"}]},
|
||||
)
|
||||
|
||||
serialized = DeerFlowClient._serialize_message(message)
|
||||
|
||||
assert serialized == {
|
||||
"type": "human",
|
||||
"content": "hello",
|
||||
"id": None,
|
||||
"additional_kwargs": {"files": [{"name": "diagram.png"}]},
|
||||
}
|
||||
@@ -82,36 +82,6 @@ def test_parse_response_text_content():
|
||||
assert result.generations[0].message.content == "Hello world"
|
||||
|
||||
|
||||
def test_parse_response_populates_usage_metadata():
|
||||
model = _make_model()
|
||||
response = {
|
||||
"output": [
|
||||
{
|
||||
"type": "message",
|
||||
"content": [{"type": "output_text", "text": "Hello world"}],
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 5,
|
||||
"total_tokens": 15,
|
||||
"input_tokens_details": {"cached_tokens": 3},
|
||||
"output_tokens_details": {"reasoning_tokens": 2},
|
||||
},
|
||||
"model": "gpt-5.4",
|
||||
}
|
||||
|
||||
result = model._parse_response(response)
|
||||
|
||||
meta = result.generations[0].message.usage_metadata
|
||||
assert meta is not None
|
||||
assert meta["input_tokens"] == 10
|
||||
assert meta["output_tokens"] == 5
|
||||
assert meta["total_tokens"] == 15
|
||||
assert meta["input_token_details"]["cache_read"] == 3
|
||||
assert meta["output_token_details"]["reasoning"] == 2
|
||||
|
||||
|
||||
def test_parse_response_reasoning_content():
|
||||
model = _make_model()
|
||||
response = {
|
||||
|
||||
@@ -3,8 +3,6 @@ from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, call
|
||||
|
||||
import pytest
|
||||
from langgraph.checkpoint.base import empty_checkpoint
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from deerflow.runtime.runs.manager import RunManager
|
||||
from deerflow.runtime.runs.schemas import RunStatus
|
||||
@@ -18,14 +16,6 @@ class FakeCheckpointer:
|
||||
self.aput_writes = AsyncMock()
|
||||
|
||||
|
||||
def _make_checkpoint(checkpoint_id: str, messages: list[str], version: int):
|
||||
checkpoint = empty_checkpoint()
|
||||
checkpoint["id"] = checkpoint_id
|
||||
checkpoint["channel_values"] = {"messages": messages}
|
||||
checkpoint["channel_versions"] = {"messages": version}
|
||||
return checkpoint
|
||||
|
||||
|
||||
def test_build_runtime_context_includes_app_config_when_present():
|
||||
app_config = object()
|
||||
|
||||
@@ -120,16 +110,16 @@ async def test_rollback_restores_snapshot_without_deleting_thread():
|
||||
)
|
||||
|
||||
checkpointer.adelete_thread.assert_not_awaited()
|
||||
checkpointer.aput.assert_awaited_once()
|
||||
restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
|
||||
assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
|
||||
assert restored_checkpoint["id"] != "ckpt-1"
|
||||
assert "channel_versions" in restored_checkpoint
|
||||
assert "channel_values" in restored_checkpoint
|
||||
assert restored_checkpoint["channel_versions"] == {"messages": 3}
|
||||
assert restored_checkpoint["channel_values"] == {"messages": ["before"]}
|
||||
assert restored_metadata == {"source": "input"}
|
||||
assert new_versions == {"messages": 3}
|
||||
checkpointer.aput.assert_awaited_once_with(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
|
||||
{
|
||||
"id": "ckpt-1",
|
||||
"channel_versions": {"messages": 3},
|
||||
"channel_values": {"messages": ["before"]},
|
||||
},
|
||||
{"source": "input"},
|
||||
{"messages": 3},
|
||||
)
|
||||
assert checkpointer.aput_writes.await_args_list == [
|
||||
call(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
|
||||
@@ -144,40 +134,6 @@ async def test_rollback_restores_snapshot_without_deleting_thread():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_restored_checkpoint_becomes_latest_with_real_checkpointer():
|
||||
checkpointer = InMemorySaver()
|
||||
thread_config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
|
||||
before_checkpoint = _make_checkpoint("0001", ["before"], 1)
|
||||
before_config = checkpointer.put(thread_config, before_checkpoint, {"step": 1}, {"messages": 1})
|
||||
after_checkpoint = _make_checkpoint("0002", ["after"], 2)
|
||||
after_config = checkpointer.put(before_config, after_checkpoint, {"step": 2}, {"messages": 2})
|
||||
checkpointer.put_writes(after_config, [("messages", "pending-after")], task_id="task-after")
|
||||
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id="0001",
|
||||
pre_run_snapshot={
|
||||
"checkpoint_ns": "",
|
||||
"checkpoint": before_checkpoint,
|
||||
"metadata": {"step": 1},
|
||||
"pending_writes": [("task-before", "messages", "pending-before")],
|
||||
},
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
latest = checkpointer.get_tuple(thread_config)
|
||||
|
||||
assert latest is not None
|
||||
assert latest.config["configurable"]["checkpoint_id"] != "0001"
|
||||
assert latest.config["configurable"]["checkpoint_id"] != "0002"
|
||||
assert latest.checkpoint["channel_values"] == {"messages": ["before"]}
|
||||
assert latest.pending_writes == [("task-before", "messages", "pending-before")]
|
||||
assert ("task-after", "messages", "pending-after") not in latest.pending_writes
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_deletes_thread_when_no_snapshot_exists():
|
||||
checkpointer = FakeCheckpointer(put_result=None)
|
||||
@@ -238,13 +194,12 @@ async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace():
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
checkpointer.aput.assert_awaited_once()
|
||||
restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
|
||||
assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
|
||||
assert restored_checkpoint["id"] != "ckpt-1"
|
||||
assert restored_checkpoint["channel_versions"] == {}
|
||||
assert restored_metadata == {}
|
||||
assert new_versions == {}
|
||||
checkpointer.aput.assert_awaited_once_with(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
|
||||
{"id": "ckpt-1", "channel_versions": {}},
|
||||
{},
|
||||
{},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
|
||||
@@ -7,7 +7,6 @@ import yaml
|
||||
|
||||
from deerflow.config import app_config as app_config_module
|
||||
from deerflow.config import extensions_config as extensions_config_module
|
||||
from deerflow.config import skills_config as skills_config_module
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
from deerflow.config.paths import Paths
|
||||
@@ -36,7 +35,6 @@ def test_default_runtime_paths_resolve_from_current_project(tmp_path: Path, monk
|
||||
encoding="utf-8",
|
||||
)
|
||||
(tmp_path / "extensions_config.json").write_text('{"mcpServers": {}, "skills": {}}', encoding="utf-8")
|
||||
(tmp_path / "skills").mkdir()
|
||||
|
||||
assert AppConfig.resolve_config_path() == tmp_path / "config.yaml"
|
||||
assert ExtensionsConfig.resolve_config_path() == tmp_path / "extensions_config.json"
|
||||
@@ -123,40 +121,6 @@ def test_app_config_falls_back_to_legacy_when_project_root_lacks_config(tmp_path
|
||||
assert AppConfig.resolve_config_path() == legacy_backend_config
|
||||
|
||||
|
||||
def test_skills_config_falls_back_to_legacy_when_project_root_lacks_skills(tmp_path: Path, monkeypatch):
|
||||
"""When DEER_FLOW_PROJECT_ROOT is unset and cwd has no `skills/`, the legacy
|
||||
repo-root candidate must be used so monorepo runs (cwd=backend/) keep finding
|
||||
`<repo>/skills` instead of `<repo>/backend/skills` (regression test for #2694)."""
|
||||
_clear_path_env(monkeypatch)
|
||||
cwd = tmp_path / "cwd"
|
||||
cwd.mkdir()
|
||||
monkeypatch.chdir(cwd)
|
||||
|
||||
legacy_skills = tmp_path / "legacy-repo" / "skills"
|
||||
legacy_skills.mkdir(parents=True)
|
||||
|
||||
monkeypatch.setattr(
|
||||
skills_config_module,
|
||||
"_legacy_skills_candidates",
|
||||
lambda: (legacy_skills,),
|
||||
)
|
||||
|
||||
assert SkillsConfig().get_skills_path() == legacy_skills
|
||||
|
||||
|
||||
def test_skills_config_returns_project_default_when_neither_exists(tmp_path: Path, monkeypatch):
|
||||
"""When nothing exists, fall back to the project-root default path so callers
|
||||
surface a stable empty location instead of silently picking a stale legacy dir."""
|
||||
_clear_path_env(monkeypatch)
|
||||
cwd = tmp_path / "cwd"
|
||||
cwd.mkdir()
|
||||
monkeypatch.chdir(cwd)
|
||||
|
||||
monkeypatch.setattr(skills_config_module, "_legacy_skills_candidates", lambda: ())
|
||||
|
||||
assert SkillsConfig().get_skills_path() == cwd / "skills"
|
||||
|
||||
|
||||
def test_extensions_config_falls_back_to_legacy_when_project_root_lacks_file(tmp_path: Path, monkeypatch):
|
||||
"""ExtensionsConfig should hit the legacy backend/repo-root locations when
|
||||
the caller project root has no extensions_config.json/mcp_config.json."""
|
||||
|
||||
@@ -1,308 +0,0 @@
|
||||
"""Unit tests for the Serper community web search tool."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_api_key_warned():
|
||||
"""Reset the module-level warning flag before each test."""
|
||||
import deerflow.community.serper.tools as serper_mod
|
||||
|
||||
serper_mod._api_key_warned = False
|
||||
yield
|
||||
serper_mod._api_key_warned = False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_with_key():
|
||||
with patch("deerflow.community.serper.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": "test-serper-key", "max_results": 5}
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_no_key():
|
||||
with patch("deerflow.community.serper.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {}
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
yield mock
|
||||
|
||||
|
||||
def _make_serper_response(organic: list) -> MagicMock:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = {"organic": organic}
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
return mock_resp
|
||||
|
||||
|
||||
class TestGetApiKey:
|
||||
def test_returns_config_key_when_present(self):
|
||||
with patch("deerflow.community.serper.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": "from-config"}
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
|
||||
from deerflow.community.serper.tools import _get_api_key
|
||||
|
||||
assert _get_api_key() == "from-config"
|
||||
|
||||
def test_falls_back_to_env_when_config_key_empty(self):
|
||||
with patch("deerflow.community.serper.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": ""}
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}):
|
||||
from deerflow.community.serper.tools import _get_api_key
|
||||
|
||||
assert _get_api_key() == "env-key"
|
||||
|
||||
def test_falls_back_to_env_when_config_key_whitespace(self):
|
||||
with patch("deerflow.community.serper.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": " "}
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}):
|
||||
from deerflow.community.serper.tools import _get_api_key
|
||||
|
||||
assert _get_api_key() == "env-key"
|
||||
|
||||
def test_falls_back_to_env_when_config_key_null(self):
|
||||
with patch("deerflow.community.serper.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": None}
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}):
|
||||
from deerflow.community.serper.tools import _get_api_key
|
||||
|
||||
assert _get_api_key() == "env-key"
|
||||
|
||||
def test_falls_back_to_env_when_no_config(self):
|
||||
with patch("deerflow.community.serper.tools.get_app_config") as mock:
|
||||
mock.return_value.get_tool_config.return_value = None
|
||||
with patch.dict("os.environ", {"SERPER_API_KEY": "env-only"}):
|
||||
from deerflow.community.serper.tools import _get_api_key
|
||||
|
||||
assert _get_api_key() == "env-only"
|
||||
|
||||
def test_returns_none_when_no_key_anywhere(self):
|
||||
with patch("deerflow.community.serper.tools.get_app_config") as mock:
|
||||
mock.return_value.get_tool_config.return_value = None
|
||||
with patch.dict("os.environ", {}, clear=True):
|
||||
import os
|
||||
|
||||
os.environ.pop("SERPER_API_KEY", None)
|
||||
from deerflow.community.serper.tools import _get_api_key
|
||||
|
||||
assert _get_api_key() is None
|
||||
|
||||
|
||||
class TestWebSearchTool:
|
||||
def test_basic_search_returns_normalized_results(self, mock_config_with_key):
|
||||
organic = [
|
||||
{"title": "Result 1", "link": "https://example.com/1", "snippet": "Snippet 1"},
|
||||
{"title": "Result 2", "link": "https://example.com/2", "snippet": "Snippet 2"},
|
||||
]
|
||||
mock_resp = _make_serper_response(organic)
|
||||
|
||||
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "python tutorial"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert parsed["query"] == "python tutorial"
|
||||
assert parsed["total_results"] == 2
|
||||
assert parsed["results"][0]["title"] == "Result 1"
|
||||
assert parsed["results"][0]["url"] == "https://example.com/1"
|
||||
assert parsed["results"][0]["content"] == "Snippet 1"
|
||||
|
||||
def test_respects_max_results_from_config(self, mock_config_with_key):
|
||||
mock_config_with_key.return_value.get_tool_config.return_value.model_extra = {
|
||||
"api_key": "test-key",
|
||||
"max_results": 3,
|
||||
}
|
||||
organic = [{"title": f"R{i}", "link": f"https://x.com/{i}", "snippet": f"S{i}"} for i in range(10)]
|
||||
mock_resp = _make_serper_response(organic)
|
||||
|
||||
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert parsed["total_results"] == 3
|
||||
assert len(parsed["results"]) == 3
|
||||
|
||||
def test_max_results_parameter_accepted(self, mock_config_no_key):
|
||||
"""Tool accepts max_results as a call parameter when config does not override it."""
|
||||
organic = [{"title": f"R{i}", "link": f"https://x.com/{i}", "snippet": f"S{i}"} for i in range(10)]
|
||||
mock_resp = _make_serper_response(organic)
|
||||
|
||||
with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}):
|
||||
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test", "max_results": 2})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert parsed["total_results"] == 2
|
||||
|
||||
def test_config_max_results_overrides_parameter(self):
|
||||
"""Config max_results overrides the parameter passed at call time, matching ddg_search behaviour."""
|
||||
with patch("deerflow.community.serper.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": "test-key", "max_results": 3}
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
|
||||
organic = [{"title": f"R{i}", "link": f"https://x.com/{i}", "snippet": f"S{i}"} for i in range(10)]
|
||||
mock_resp = _make_serper_response(organic)
|
||||
|
||||
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test", "max_results": 8})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert parsed["total_results"] == 3
|
||||
|
||||
def test_empty_organic_returns_error_json(self, mock_config_with_key):
|
||||
"""Empty organic list returns structured error, matching ddg_search convention."""
|
||||
mock_resp = _make_serper_response([])
|
||||
|
||||
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "no results"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert "error" in parsed
|
||||
assert parsed["error"] == "No results found"
|
||||
assert parsed["query"] == "no results"
|
||||
|
||||
def test_missing_api_key_returns_error_json(self, mock_config_no_key):
|
||||
with patch.dict("os.environ", {}, clear=True):
|
||||
import os
|
||||
|
||||
os.environ.pop("SERPER_API_KEY", None)
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert "error" in parsed
|
||||
assert "SERPER_API_KEY" in parsed["error"]
|
||||
|
||||
def test_missing_api_key_logs_warning_once(self, mock_config_no_key, caplog):
|
||||
import logging
|
||||
|
||||
with patch.dict("os.environ", {}, clear=True):
|
||||
import os
|
||||
|
||||
os.environ.pop("SERPER_API_KEY", None)
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="deerflow.community.serper.tools"):
|
||||
web_search_tool.invoke({"query": "q1"})
|
||||
web_search_tool.invoke({"query": "q2"})
|
||||
|
||||
warnings = [r for r in caplog.records if r.levelno == logging.WARNING]
|
||||
assert len(warnings) == 1
|
||||
|
||||
def test_http_error_returns_structured_error(self, mock_config_with_key):
|
||||
mock_error_response = MagicMock()
|
||||
mock_error_response.status_code = 403
|
||||
mock_error_response.text = "Forbidden"
|
||||
|
||||
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = httpx.HTTPStatusError("403", request=MagicMock(), response=mock_error_response)
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert "error" in parsed
|
||||
assert "403" in parsed["error"]
|
||||
|
||||
def test_network_exception_returns_error_json(self, mock_config_with_key):
|
||||
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = Exception("timeout")
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert "error" in parsed
|
||||
|
||||
def test_sends_correct_headers_and_payload(self, mock_config_with_key):
|
||||
organic = [{"title": "T", "link": "https://x.com", "snippet": "S"}]
|
||||
mock_resp = _make_serper_response(organic)
|
||||
|
||||
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
|
||||
mock_post = mock_client_cls.return_value.__enter__.return_value.post
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
web_search_tool.invoke({"query": "hello world"})
|
||||
|
||||
call_kwargs = mock_post.call_args
|
||||
headers = call_kwargs.kwargs["headers"]
|
||||
payload = call_kwargs.kwargs["json"]
|
||||
|
||||
assert headers["X-API-KEY"] == "test-serper-key"
|
||||
assert payload["q"] == "hello world"
|
||||
assert payload["num"] == 5
|
||||
|
||||
def test_uses_env_key_when_config_absent(self):
|
||||
with patch("deerflow.community.serper.tools.get_app_config") as mock:
|
||||
mock.return_value.get_tool_config.return_value = None
|
||||
with patch.dict("os.environ", {"SERPER_API_KEY": "env-only-key"}):
|
||||
organic = [{"title": "T", "link": "https://x.com", "snippet": "S"}]
|
||||
mock_resp = _make_serper_response(organic)
|
||||
|
||||
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
|
||||
mock_post = mock_client_cls.return_value.__enter__.return_value.post
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
web_search_tool.invoke({"query": "env key test"})
|
||||
headers = mock_post.call_args.kwargs["headers"]
|
||||
|
||||
assert headers["X-API-KEY"] == "env-only-key"
|
||||
|
||||
def test_partial_fields_in_organic_result(self, mock_config_with_key):
|
||||
"""Missing title/link/snippet should default to empty string."""
|
||||
organic = [{}]
|
||||
mock_resp = _make_serper_response(organic)
|
||||
|
||||
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
|
||||
|
||||
from deerflow.community.serper.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert parsed["results"][0] == {"title": "", "url": "", "content": ""}
|
||||
@@ -19,7 +19,6 @@ def test_get_skills_root_path_points_to_current_project_skills(tmp_path: Path, m
|
||||
monkeypatch.delenv("DEER_FLOW_SKILLS_PATH", raising=False)
|
||||
monkeypatch.delenv("DEER_FLOW_PROJECT_ROOT", raising=False)
|
||||
monkeypatch.chdir(tmp_path)
|
||||
(tmp_path / "skills").mkdir()
|
||||
|
||||
app_config = SimpleNamespace(skills=SkillsConfig())
|
||||
path = get_or_new_skill_storage(app_config=app_config).get_skills_root_path()
|
||||
|
||||
@@ -1,66 +1,12 @@
|
||||
import re
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from _router_auth_helpers import make_authed_test_app
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
from app.gateway.routers import threads
|
||||
from deerflow.config.paths import Paths
|
||||
from deerflow.persistence.thread_meta.memory import THREADS_NS, MemoryThreadMetaStore
|
||||
|
||||
_ISO_TIMESTAMP_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
|
||||
|
||||
|
||||
class _PermissiveThreadMetaStore(MemoryThreadMetaStore):
|
||||
"""Memory store that skips user-id filtering for router tests.
|
||||
|
||||
Owner isolation is exercised separately in
|
||||
``test_memory_thread_meta_isolation.py``. Router tests need to drive
|
||||
the FastAPI surface end-to-end with a single fixed app user, but the
|
||||
stub auth middleware in ``_router_auth_helpers`` stamps a fresh UUID
|
||||
on every request, so the production filtering would reject every
|
||||
pre-seeded record. Bypass that filter so the test can focus on the
|
||||
timestamp wire format.
|
||||
"""
|
||||
|
||||
async def _get_owned_record(self, thread_id, user_id, method_name): # type: ignore[override]
|
||||
item = await self._store.aget(THREADS_NS, thread_id)
|
||||
return dict(item.value) if item is not None else None
|
||||
|
||||
async def check_access(self, thread_id, user_id, *, require_existing=False): # type: ignore[override]
|
||||
item = await self._store.aget(THREADS_NS, thread_id)
|
||||
if item is None:
|
||||
return not require_existing
|
||||
return True
|
||||
|
||||
async def create(self, thread_id, *, assistant_id=None, user_id=None, display_name=None, metadata=None): # type: ignore[override]
|
||||
return await super().create(thread_id, assistant_id=assistant_id, user_id=None, display_name=display_name, metadata=metadata)
|
||||
|
||||
async def search(self, *, metadata=None, status=None, limit=100, offset=0, user_id=None): # type: ignore[override]
|
||||
return await super().search(metadata=metadata, status=status, limit=limit, offset=offset, user_id=None)
|
||||
|
||||
|
||||
def _build_thread_app() -> tuple[FastAPI, InMemoryStore, InMemorySaver]:
|
||||
"""Build a stub-authed FastAPI app wired with an in-memory ThreadMetaStore.
|
||||
|
||||
The thread_store on ``app.state`` is a permissive subclass of
|
||||
``MemoryThreadMetaStore`` so tests can drive ``/api/threads``
|
||||
end-to-end and pre-seed legacy records via the underlying BaseStore.
|
||||
|
||||
Returns ``(app, store, checkpointer)`` for direct seeding/inspection.
|
||||
"""
|
||||
app = make_authed_test_app()
|
||||
store = InMemoryStore()
|
||||
checkpointer = InMemorySaver()
|
||||
app.state.store = store
|
||||
app.state.checkpointer = checkpointer
|
||||
app.state.thread_store = _PermissiveThreadMetaStore(store)
|
||||
app.include_router(threads.router)
|
||||
return app, store, checkpointer
|
||||
|
||||
|
||||
def test_delete_thread_data_removes_thread_directory(tmp_path):
|
||||
@@ -190,244 +136,3 @@ def test_strip_reserved_metadata_empty_input():
|
||||
def test_strip_reserved_metadata_strips_all_reserved_keys():
|
||||
out = threads._strip_reserved_metadata({"user_id": "x", "keep": "me"})
|
||||
assert out == {"keep": "me"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ISO 8601 timestamp contract (issue #2594)
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Threads endpoints document ``created_at`` / ``updated_at`` as ISO
|
||||
# timestamps and that is the format LangGraph Platform uses
|
||||
# (``langgraph_sdk.schema.Thread.created_at: datetime`` JSON-encodes to
|
||||
# ISO 8601). The tests below pin that contract end-to-end and also
|
||||
# exercise the ``coerce_iso`` healing path for legacy unix-timestamp
|
||||
# records written by older Gateway versions.
|
||||
|
||||
|
||||
def test_create_thread_returns_iso_timestamps() -> None:
|
||||
app, _store, _checkpointer = _build_thread_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.post("/api/threads", json={"metadata": {}})
|
||||
|
||||
assert response.status_code == 200, response.text
|
||||
body = response.json()
|
||||
assert _ISO_TIMESTAMP_RE.match(body["created_at"]), body["created_at"]
|
||||
assert _ISO_TIMESTAMP_RE.match(body["updated_at"]), body["updated_at"]
|
||||
assert body["created_at"] == body["updated_at"]
|
||||
|
||||
|
||||
def test_get_thread_returns_iso_for_legacy_unix_record() -> None:
|
||||
"""A thread record written by older versions stores ``time.time()``
|
||||
floats. ``get_thread`` must transparently surface them as ISO so the
|
||||
frontend's ``new Date(...)`` parser does not break.
|
||||
"""
|
||||
app, store, checkpointer = _build_thread_app()
|
||||
|
||||
legacy_thread_id = "legacy-thread"
|
||||
legacy_ts = "1777252410.411327"
|
||||
|
||||
async def _seed() -> None:
|
||||
await store.aput(
|
||||
THREADS_NS,
|
||||
legacy_thread_id,
|
||||
{
|
||||
"thread_id": legacy_thread_id,
|
||||
"status": "idle",
|
||||
"created_at": legacy_ts,
|
||||
"updated_at": legacy_ts,
|
||||
"metadata": {},
|
||||
},
|
||||
)
|
||||
from langgraph.checkpoint.base import empty_checkpoint
|
||||
|
||||
await checkpointer.aput(
|
||||
{"configurable": {"thread_id": legacy_thread_id, "checkpoint_ns": ""}},
|
||||
empty_checkpoint(),
|
||||
{"step": -1, "source": "input", "writes": None, "parents": {}},
|
||||
{},
|
||||
)
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(_seed())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get(f"/api/threads/{legacy_thread_id}")
|
||||
|
||||
assert response.status_code == 200, response.text
|
||||
body = response.json()
|
||||
assert _ISO_TIMESTAMP_RE.match(body["created_at"]), body["created_at"]
|
||||
assert _ISO_TIMESTAMP_RE.match(body["updated_at"]), body["updated_at"]
|
||||
|
||||
|
||||
def test_patch_thread_returns_iso_and_advances_updated_at() -> None:
|
||||
app, store, _checkpointer = _build_thread_app()
|
||||
thread_id = "patch-target"
|
||||
|
||||
legacy_created = "1777000000.000000"
|
||||
legacy_updated = "1777000000.000000"
|
||||
|
||||
async def _seed() -> None:
|
||||
await store.aput(
|
||||
THREADS_NS,
|
||||
thread_id,
|
||||
{
|
||||
"thread_id": thread_id,
|
||||
"status": "idle",
|
||||
"created_at": legacy_created,
|
||||
"updated_at": legacy_updated,
|
||||
"metadata": {"k": "v0"},
|
||||
},
|
||||
)
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(_seed())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.patch(f"/api/threads/{thread_id}", json={"metadata": {"k": "v1"}})
|
||||
|
||||
assert response.status_code == 200, response.text
|
||||
body = response.json()
|
||||
assert _ISO_TIMESTAMP_RE.match(body["created_at"]), body["created_at"]
|
||||
assert _ISO_TIMESTAMP_RE.match(body["updated_at"]), body["updated_at"]
|
||||
# Patch issues a fresh ``updated_at`` via ``MemoryThreadMetaStore.update_metadata``,
|
||||
# so it must be > the migrated legacy ``created_at`` (both ISO strings
|
||||
# sort lexicographically by time when the format is consistent).
|
||||
assert body["updated_at"] > body["created_at"]
|
||||
assert body["metadata"] == {"k": "v1"}
|
||||
|
||||
|
||||
def test_search_threads_normalizes_legacy_unix_seconds_to_iso() -> None:
|
||||
"""``MemoryThreadMetaStore`` may hold legacy ``time.time()`` floats
|
||||
written by older Gateway versions. ``/search`` must surface them as
|
||||
ISO via ``coerce_iso`` so the frontend's ``new Date(...)`` parser
|
||||
does not break.
|
||||
"""
|
||||
app, store, _checkpointer = _build_thread_app()
|
||||
|
||||
async def _seed() -> None:
|
||||
# Legacy unix-second float (the literal value from issue #2594).
|
||||
await store.aput(
|
||||
THREADS_NS,
|
||||
"legacy",
|
||||
{
|
||||
"thread_id": "legacy",
|
||||
"status": "idle",
|
||||
"created_at": 1777000000.0,
|
||||
"updated_at": 1777000000.0,
|
||||
"metadata": {},
|
||||
},
|
||||
)
|
||||
# Modern ISO string, slightly later.
|
||||
await store.aput(
|
||||
THREADS_NS,
|
||||
"modern",
|
||||
{
|
||||
"thread_id": "modern",
|
||||
"status": "idle",
|
||||
"created_at": "2026-04-27T00:00:00+00:00",
|
||||
"updated_at": "2026-04-27T00:00:00+00:00",
|
||||
"metadata": {},
|
||||
},
|
||||
)
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(_seed())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.post("/api/threads/search", json={"limit": 10})
|
||||
|
||||
assert response.status_code == 200, response.text
|
||||
items = response.json()
|
||||
assert {item["thread_id"] for item in items} == {"legacy", "modern"}
|
||||
for item in items:
|
||||
assert _ISO_TIMESTAMP_RE.match(item["created_at"]), item
|
||||
assert _ISO_TIMESTAMP_RE.match(item["updated_at"]), item
|
||||
|
||||
|
||||
def test_memory_thread_meta_store_writes_iso_on_create() -> None:
|
||||
"""``MemoryThreadMetaStore.create`` must emit ISO so newly created
|
||||
threads serialize correctly without depending on the router's
|
||||
``coerce_iso`` heal path.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
store = InMemoryStore()
|
||||
repo = MemoryThreadMetaStore(store)
|
||||
|
||||
async def _scenario() -> dict:
|
||||
await repo.create("fresh", user_id=None, metadata={"a": 1})
|
||||
record = (await store.aget(THREADS_NS, "fresh")).value
|
||||
return record
|
||||
|
||||
record = asyncio.run(_scenario())
|
||||
assert _ISO_TIMESTAMP_RE.match(record["created_at"]), record
|
||||
assert _ISO_TIMESTAMP_RE.match(record["updated_at"]), record
|
||||
|
||||
|
||||
def test_get_thread_state_returns_iso_for_legacy_checkpoint_metadata() -> None:
|
||||
"""Checkpoints written by older Gateway versions stored
|
||||
``created_at`` as a unix-second float in their metadata. The
|
||||
``/state`` endpoint must surface that value as ISO so the frontend's
|
||||
``new Date(...)`` parser does not break — same root cause as the
|
||||
thread-record bug fixed in #2594, but on the checkpoint side.
|
||||
"""
|
||||
app, _store, checkpointer = _build_thread_app()
|
||||
thread_id = "legacy-state"
|
||||
|
||||
async def _seed() -> None:
|
||||
from langgraph.checkpoint.base import empty_checkpoint
|
||||
|
||||
await checkpointer.aput(
|
||||
{"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}},
|
||||
empty_checkpoint(),
|
||||
{"step": -1, "source": "input", "writes": None, "parents": {}, "created_at": 1777252410.411327},
|
||||
{},
|
||||
)
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(_seed())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get(f"/api/threads/{thread_id}/state")
|
||||
|
||||
assert response.status_code == 200, response.text
|
||||
body = response.json()
|
||||
assert _ISO_TIMESTAMP_RE.match(body["created_at"]), body["created_at"]
|
||||
assert _ISO_TIMESTAMP_RE.match(body["checkpoint"]["ts"]), body["checkpoint"]
|
||||
|
||||
|
||||
def test_get_thread_history_returns_iso_for_legacy_checkpoint_metadata() -> None:
|
||||
"""``/history`` walks ``checkpointer.alist`` and emits one entry per
|
||||
checkpoint. Each entry's ``created_at`` must come out as ISO even if
|
||||
older checkpoints stored a unix-second float in their metadata.
|
||||
"""
|
||||
app, _store, checkpointer = _build_thread_app()
|
||||
thread_id = "legacy-history"
|
||||
|
||||
async def _seed() -> None:
|
||||
from langgraph.checkpoint.base import empty_checkpoint
|
||||
|
||||
await checkpointer.aput(
|
||||
{"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}},
|
||||
empty_checkpoint(),
|
||||
{"step": -1, "source": "input", "writes": None, "parents": {}, "created_at": 1777252410.411327},
|
||||
{},
|
||||
)
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(_seed())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.post(f"/api/threads/{thread_id}/history", json={"limit": 10})
|
||||
|
||||
assert response.status_code == 200, response.text
|
||||
entries = response.json()
|
||||
assert entries, "expected at least one history entry"
|
||||
for entry in entries:
|
||||
assert _ISO_TIMESTAMP_RE.match(entry["created_at"]), entry
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
"""Tests for TokenUsageMiddleware attribution annotations."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from deerflow.agents.middlewares.token_usage_middleware import (
|
||||
TOKEN_USAGE_ATTRIBUTION_KEY,
|
||||
TokenUsageMiddleware,
|
||||
)
|
||||
|
||||
|
||||
def _make_runtime():
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": "test-thread"}
|
||||
return runtime
|
||||
|
||||
|
||||
class TestTokenUsageMiddleware:
|
||||
def test_annotates_todo_updates_with_structured_actions(self):
|
||||
middleware = TokenUsageMiddleware()
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "write_todos:1",
|
||||
"name": "write_todos",
|
||||
"args": {
|
||||
"todos": [
|
||||
{"content": "Inspect streaming path", "status": "completed"},
|
||||
{"content": "Design token attribution schema", "status": "in_progress"},
|
||||
]
|
||||
},
|
||||
}
|
||||
],
|
||||
usage_metadata={"input_tokens": 100, "output_tokens": 20, "total_tokens": 120},
|
||||
)
|
||||
|
||||
state = {
|
||||
"messages": [message],
|
||||
"todos": [
|
||||
{"content": "Inspect streaming path", "status": "in_progress"},
|
||||
{"content": "Design token attribution schema", "status": "pending"},
|
||||
],
|
||||
}
|
||||
|
||||
result = middleware.after_model(state, _make_runtime())
|
||||
|
||||
assert result is not None
|
||||
updated_message = result["messages"][0]
|
||||
attribution = updated_message.additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
|
||||
assert attribution["kind"] == "tool_batch"
|
||||
assert attribution["shared_attribution"] is True
|
||||
assert attribution["tool_call_ids"] == ["write_todos:1"]
|
||||
assert attribution["actions"] == [
|
||||
{
|
||||
"kind": "todo_complete",
|
||||
"content": "Inspect streaming path",
|
||||
"tool_call_id": "write_todos:1",
|
||||
},
|
||||
{
|
||||
"kind": "todo_start",
|
||||
"content": "Design token attribution schema",
|
||||
"tool_call_id": "write_todos:1",
|
||||
},
|
||||
]
|
||||
|
||||
def test_annotates_subagent_and_search_steps(self):
|
||||
middleware = TokenUsageMiddleware()
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "task:1",
|
||||
"name": "task",
|
||||
"args": {
|
||||
"description": "spec-coder patch message grouping",
|
||||
"subagent_type": "general-purpose",
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "web_search:1",
|
||||
"name": "web_search",
|
||||
"args": {"query": "LangGraph useStream messages tuple"},
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
result = middleware.after_model({"messages": [message]}, _make_runtime())
|
||||
|
||||
assert result is not None
|
||||
attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
|
||||
assert attribution["kind"] == "tool_batch"
|
||||
assert attribution["shared_attribution"] is True
|
||||
assert attribution["actions"] == [
|
||||
{
|
||||
"kind": "subagent",
|
||||
"description": "spec-coder patch message grouping",
|
||||
"subagent_type": "general-purpose",
|
||||
"tool_call_id": "task:1",
|
||||
},
|
||||
{
|
||||
"kind": "search",
|
||||
"tool_name": "web_search",
|
||||
"query": "LangGraph useStream messages tuple",
|
||||
"tool_call_id": "web_search:1",
|
||||
},
|
||||
]
|
||||
|
||||
def test_marks_final_answer_when_no_tools(self):
|
||||
middleware = TokenUsageMiddleware()
|
||||
message = AIMessage(content="Here is the final answer.")
|
||||
|
||||
result = middleware.after_model({"messages": [message]}, _make_runtime())
|
||||
|
||||
assert result is not None
|
||||
attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
|
||||
assert attribution["kind"] == "final_answer"
|
||||
assert attribution["shared_attribution"] is False
|
||||
assert attribution["actions"] == []
|
||||
|
||||
def test_annotates_removed_todos(self):
|
||||
middleware = TokenUsageMiddleware()
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "write_todos:remove",
|
||||
"name": "write_todos",
|
||||
"args": {
|
||||
"todos": [],
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
result = middleware.after_model(
|
||||
{
|
||||
"messages": [message],
|
||||
"todos": [
|
||||
{"content": "Archive obsolete plan", "status": "pending"},
|
||||
],
|
||||
},
|
||||
_make_runtime(),
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
|
||||
assert attribution["kind"] == "todo_update"
|
||||
assert attribution["shared_attribution"] is False
|
||||
assert attribution["actions"] == [
|
||||
{
|
||||
"kind": "todo_remove",
|
||||
"content": "Archive obsolete plan",
|
||||
"tool_call_id": "write_todos:remove",
|
||||
}
|
||||
]
|
||||
@@ -1,20 +1,14 @@
|
||||
"""Tests for deerflow.uploads.manager — shared upload management logic."""
|
||||
|
||||
import errno
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.uploads.manager import (
|
||||
PathTraversalError,
|
||||
UnsafeUploadPathError,
|
||||
claim_unique_filename,
|
||||
delete_file_safe,
|
||||
list_files_in_dir,
|
||||
normalize_filename,
|
||||
validate_path_traversal,
|
||||
write_upload_file_no_symlink,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -103,54 +97,6 @@ class TestValidatePathTraversal:
|
||||
validate_path_traversal(link, tmp_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# write_upload_file_no_symlink
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWriteUploadFileNoSymlink:
|
||||
def test_writes_new_file(self, tmp_path):
|
||||
dest = write_upload_file_no_symlink(tmp_path, "notes.txt", b"hello")
|
||||
|
||||
assert dest == tmp_path / "notes.txt"
|
||||
assert dest.read_bytes() == b"hello"
|
||||
|
||||
def test_overwrites_existing_regular_file_with_single_link(self, tmp_path):
|
||||
dest = tmp_path / "notes.txt"
|
||||
dest.write_bytes(b"old contents")
|
||||
assert os.stat(dest).st_nlink == 1
|
||||
|
||||
result = write_upload_file_no_symlink(tmp_path, "notes.txt", b"new contents")
|
||||
|
||||
assert result == dest
|
||||
assert dest.read_bytes() == b"new contents"
|
||||
assert os.stat(dest).st_nlink == 1
|
||||
|
||||
def test_fails_closed_without_no_follow_support(self, tmp_path, monkeypatch):
|
||||
monkeypatch.delattr(os, "O_NOFOLLOW", raising=False)
|
||||
|
||||
with pytest.raises(UnsafeUploadPathError, match="O_NOFOLLOW"):
|
||||
write_upload_file_no_symlink(tmp_path, "notes.txt", b"hello")
|
||||
|
||||
assert not (tmp_path / "notes.txt").exists()
|
||||
|
||||
def test_open_uses_nonblocking_flag_when_available(self, tmp_path):
|
||||
with patch("deerflow.uploads.manager.os.open", side_effect=OSError(errno.ENXIO, "no reader")) as open_mock:
|
||||
with pytest.raises(UnsafeUploadPathError, match="Unsafe upload destination"):
|
||||
write_upload_file_no_symlink(tmp_path, "pipe.txt", b"hello")
|
||||
|
||||
flags = open_mock.call_args.args[1]
|
||||
assert flags & os.O_NONBLOCK
|
||||
|
||||
@pytest.mark.parametrize("open_errno", [errno.ENXIO, errno.EAGAIN])
|
||||
def test_nonblocking_special_file_open_errors_are_unsafe(self, tmp_path, open_errno):
|
||||
with patch("deerflow.uploads.manager.os.open", side_effect=OSError(open_errno, "would block")):
|
||||
with pytest.raises(UnsafeUploadPathError, match="Unsafe upload destination"):
|
||||
write_upload_file_no_symlink(tmp_path, "pipe.txt", b"hello")
|
||||
|
||||
assert not (tmp_path / "pipe.txt").exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_files_in_dir
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import os
|
||||
import stat
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
@@ -429,105 +428,6 @@ def test_upload_files_rejects_dotdot_and_dot_filenames(tmp_path):
|
||||
assert [f.name for f in thread_uploads_dir.iterdir()] == ["passwd"]
|
||||
|
||||
|
||||
def test_upload_files_rejects_preexisting_symlink_destination(tmp_path):
|
||||
thread_uploads_dir = tmp_path / "uploads"
|
||||
thread_uploads_dir.mkdir(parents=True)
|
||||
outside_file = tmp_path / "outside.txt"
|
||||
outside_file.write_text("protected", encoding="utf-8")
|
||||
(thread_uploads_dir / "victim.txt").symlink_to(outside_file)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.uses_thread_data_mounts = True
|
||||
|
||||
with (
|
||||
patch.object(uploads, "get_uploads_dir", return_value=thread_uploads_dir),
|
||||
patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
|
||||
patch.object(uploads, "get_sandbox_provider", return_value=provider),
|
||||
):
|
||||
file = UploadFile(filename="victim.txt", file=BytesIO(b"attacker upload"))
|
||||
result = asyncio.run(uploads.upload_files("thread-local", files=[file]))
|
||||
|
||||
assert result.success is False
|
||||
assert result.files == []
|
||||
assert result.skipped_files == ["victim.txt"]
|
||||
assert "skipped 1 unsafe file" in result.message
|
||||
assert outside_file.read_text(encoding="utf-8") == "protected"
|
||||
assert (thread_uploads_dir / "victim.txt").is_symlink()
|
||||
|
||||
|
||||
def test_upload_files_rejects_dangling_symlink_destination(tmp_path):
|
||||
thread_uploads_dir = tmp_path / "uploads"
|
||||
thread_uploads_dir.mkdir(parents=True)
|
||||
missing_target = tmp_path / "missing-target.txt"
|
||||
(thread_uploads_dir / "victim.txt").symlink_to(missing_target)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.uses_thread_data_mounts = True
|
||||
|
||||
with (
|
||||
patch.object(uploads, "get_uploads_dir", return_value=thread_uploads_dir),
|
||||
patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
|
||||
patch.object(uploads, "get_sandbox_provider", return_value=provider),
|
||||
):
|
||||
file = UploadFile(filename="victim.txt", file=BytesIO(b"attacker upload"))
|
||||
result = asyncio.run(uploads.upload_files("thread-local", files=[file]))
|
||||
|
||||
assert result.success is False
|
||||
assert result.files == []
|
||||
assert result.skipped_files == ["victim.txt"]
|
||||
assert not missing_target.exists()
|
||||
assert (thread_uploads_dir / "victim.txt").is_symlink()
|
||||
|
||||
|
||||
def test_upload_files_rejects_hardlinked_destination_without_truncating(tmp_path):
|
||||
thread_uploads_dir = tmp_path / "uploads"
|
||||
thread_uploads_dir.mkdir(parents=True)
|
||||
outside_file = tmp_path / "outside.txt"
|
||||
outside_file.write_text("protected", encoding="utf-8")
|
||||
os.link(outside_file, thread_uploads_dir / "victim.txt")
|
||||
|
||||
provider = MagicMock()
|
||||
provider.uses_thread_data_mounts = True
|
||||
|
||||
with (
|
||||
patch.object(uploads, "get_uploads_dir", return_value=thread_uploads_dir),
|
||||
patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
|
||||
patch.object(uploads, "get_sandbox_provider", return_value=provider),
|
||||
):
|
||||
file = UploadFile(filename="victim.txt", file=BytesIO(b"attacker upload"))
|
||||
result = asyncio.run(uploads.upload_files("thread-local", files=[file]))
|
||||
|
||||
assert result.success is False
|
||||
assert result.files == []
|
||||
assert result.skipped_files == ["victim.txt"]
|
||||
assert outside_file.read_text(encoding="utf-8") == "protected"
|
||||
assert (thread_uploads_dir / "victim.txt").read_text(encoding="utf-8") == "protected"
|
||||
|
||||
|
||||
def test_upload_files_overwrites_existing_regular_file(tmp_path):
|
||||
thread_uploads_dir = tmp_path / "uploads"
|
||||
thread_uploads_dir.mkdir(parents=True)
|
||||
existing_file = thread_uploads_dir / "notes.txt"
|
||||
existing_file.write_bytes(b"old upload")
|
||||
assert existing_file.stat().st_nlink == 1
|
||||
|
||||
provider = MagicMock()
|
||||
provider.uses_thread_data_mounts = True
|
||||
|
||||
with (
|
||||
patch.object(uploads, "get_uploads_dir", return_value=thread_uploads_dir),
|
||||
patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
|
||||
patch.object(uploads, "get_sandbox_provider", return_value=provider),
|
||||
):
|
||||
file = UploadFile(filename="notes.txt", file=BytesIO(b"new upload"))
|
||||
result = asyncio.run(uploads.upload_files("thread-local", files=[file]))
|
||||
|
||||
assert result.success is True
|
||||
assert [file_info["filename"] for file_info in result.files] == ["notes.txt"]
|
||||
assert existing_file.read_bytes() == b"new upload"
|
||||
assert existing_file.stat().st_nlink == 1
|
||||
|
||||
|
||||
def test_delete_uploaded_file_removes_generated_markdown_companion(tmp_path):
|
||||
thread_uploads_dir = tmp_path / "uploads"
|
||||
thread_uploads_dir.mkdir(parents=True)
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
"""Tests for ``deerflow.utils.time``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import UTC, datetime, timedelta, timezone
|
||||
|
||||
from deerflow.utils.time import coerce_iso, now_iso
|
||||
|
||||
_ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
|
||||
|
||||
|
||||
def test_now_iso_is_utc_iso8601() -> None:
|
||||
value = now_iso()
|
||||
assert _ISO_RE.match(value), value
|
||||
parsed = datetime.fromisoformat(value)
|
||||
assert parsed.tzinfo is not None
|
||||
assert parsed.tzinfo.utcoffset(parsed) == UTC.utcoffset(parsed)
|
||||
|
||||
|
||||
def test_coerce_iso_passes_iso_through() -> None:
|
||||
iso = "2026-04-27T01:13:30.411334+00:00"
|
||||
assert coerce_iso(iso) == iso
|
||||
|
||||
|
||||
def test_coerce_iso_converts_unix_float_string() -> None:
|
||||
legacy = "1777252410.411327"
|
||||
out = coerce_iso(legacy)
|
||||
assert _ISO_RE.match(out), out
|
||||
# Round-trip: parsed timestamp matches the original epoch.
|
||||
parsed = datetime.fromisoformat(out)
|
||||
assert abs(parsed.timestamp() - 1777252410.411327) < 1e-3
|
||||
|
||||
|
||||
def test_coerce_iso_converts_unix_int_string() -> None:
|
||||
out = coerce_iso("1700000000")
|
||||
assert _ISO_RE.match(out), out
|
||||
|
||||
|
||||
def test_coerce_iso_converts_numeric_types() -> None:
|
||||
out_float = coerce_iso(1777252410.411327)
|
||||
out_int = coerce_iso(1700000000)
|
||||
assert _ISO_RE.match(out_float)
|
||||
assert _ISO_RE.match(out_int)
|
||||
|
||||
|
||||
def test_coerce_iso_handles_empty_and_none() -> None:
|
||||
assert coerce_iso(None) == ""
|
||||
assert coerce_iso("") == ""
|
||||
|
||||
|
||||
def test_coerce_iso_does_not_misinterpret_short_numeric() -> None:
|
||||
# A 4-digit year should never be parsed as a unix timestamp; only
|
||||
# 10-digit unix-second strings match the legacy pattern.
|
||||
assert coerce_iso("2026") == "2026"
|
||||
|
||||
|
||||
def test_coerce_iso_handles_unparseable_string() -> None:
|
||||
assert coerce_iso("not-a-timestamp") == "not-a-timestamp"
|
||||
|
||||
|
||||
def test_coerce_iso_rejects_bool() -> None:
|
||||
# ``bool`` is a subclass of ``int`` — must not be treated as epoch 0/1.
|
||||
assert coerce_iso(True) == "True"
|
||||
assert coerce_iso(False) == "False"
|
||||
|
||||
|
||||
def test_coerce_iso_handles_tz_aware_datetime() -> None:
|
||||
# str(datetime) would emit a space separator; coerce_iso must use ``T``.
|
||||
dt = datetime(2026, 4, 27, 1, 13, 30, 411327, tzinfo=UTC)
|
||||
out = coerce_iso(dt)
|
||||
assert out == "2026-04-27T01:13:30.411327+00:00"
|
||||
assert "T" in out and " " not in out
|
||||
|
||||
|
||||
def test_coerce_iso_handles_tz_naive_datetime_as_utc() -> None:
|
||||
dt = datetime(2026, 4, 27, 1, 13, 30, 411327)
|
||||
out = coerce_iso(dt)
|
||||
assert out == "2026-04-27T01:13:30.411327+00:00"
|
||||
parsed = datetime.fromisoformat(out)
|
||||
assert parsed.tzinfo is not None
|
||||
assert parsed.utcoffset() == timedelta(0)
|
||||
|
||||
|
||||
def test_coerce_iso_normalises_non_utc_datetime_to_utc() -> None:
|
||||
# +08:00 wall-clock 09:13 == UTC 01:13.
|
||||
plus_eight = timezone(timedelta(hours=8))
|
||||
dt = datetime(2026, 4, 27, 9, 13, 30, 411327, tzinfo=plus_eight)
|
||||
out = coerce_iso(dt)
|
||||
assert out == "2026-04-27T01:13:30.411327+00:00"
|
||||
@@ -373,16 +373,6 @@ tools:
|
||||
use: deerflow.community.ddg_search.tools:web_search_tool
|
||||
max_results: 5
|
||||
|
||||
# Web search tool (uses Serper - Google Search API, requires SERPER_API_KEY)
|
||||
# Serper provides real-time Google Search results. Sign up at https://serper.dev
|
||||
# Note: set SERPER_API_KEY in your environment before starting the app, or
|
||||
# uncomment and fill in api_key below (the $VAR syntax is resolved at startup).
|
||||
# - name: web_search
|
||||
# group: web
|
||||
# use: deerflow.community.serper.tools:web_search_tool
|
||||
# max_results: 5
|
||||
# # api_key: $SERPER_API_KEY # Optional if SERPER_API_KEY env var is set
|
||||
|
||||
# Web search tool (requires Tavily API key)
|
||||
# - name: web_search
|
||||
# group: web
|
||||
|
||||
@@ -25,7 +25,7 @@ import { useAgent } from "@/core/agents";
|
||||
import { useI18n } from "@/core/i18n/hooks";
|
||||
import { useModels } from "@/core/models/hooks";
|
||||
import { useNotification } from "@/core/notification/hooks";
|
||||
import { useLocalSettings, useThreadSettings } from "@/core/settings";
|
||||
import { useThreadSettings } from "@/core/settings";
|
||||
import { useThreadStream } from "@/core/threads/hooks";
|
||||
import { textOfMessage } from "@/core/threads/utils";
|
||||
import { env } from "@/env";
|
||||
@@ -45,7 +45,6 @@ export default function AgentChatPage() {
|
||||
const { threadId, setThreadId, isNewThread, setIsNewThread } =
|
||||
useThreadChat();
|
||||
const [settings, setSettings] = useThreadSettings(threadId);
|
||||
const [localSettings, setLocalSettings] = useLocalSettings();
|
||||
const { tokenUsageEnabled } = useModels();
|
||||
|
||||
const { showNotification } = useNotification();
|
||||
@@ -101,9 +100,6 @@ export default function AgentChatPage() {
|
||||
? MESSAGE_LIST_DEFAULT_PADDING_BOTTOM +
|
||||
MESSAGE_LIST_FOLLOWUPS_EXTRA_PADDING_BOTTOM
|
||||
: undefined;
|
||||
const tokenUsageInlineMode = tokenUsageEnabled
|
||||
? localSettings.tokenUsage.inlineMode
|
||||
: "off";
|
||||
|
||||
return (
|
||||
<ThreadContext.Provider value={{ thread }}>
|
||||
@@ -143,10 +139,6 @@ export default function AgentChatPage() {
|
||||
<TokenUsageIndicator
|
||||
enabled={tokenUsageEnabled}
|
||||
messages={thread.messages}
|
||||
preferences={localSettings.tokenUsage}
|
||||
onPreferencesChange={(preferences) =>
|
||||
setLocalSettings("tokenUsage", preferences)
|
||||
}
|
||||
/>
|
||||
<ExportTrigger threadId={threadId} />
|
||||
<ArtifactTrigger />
|
||||
@@ -160,10 +152,10 @@ export default function AgentChatPage() {
|
||||
threadId={threadId}
|
||||
thread={thread}
|
||||
paddingBottom={messageListPaddingBottom}
|
||||
tokenUsageEnabled={tokenUsageEnabled}
|
||||
hasMoreHistory={hasMoreHistory}
|
||||
loadMoreHistory={loadMoreHistory}
|
||||
isHistoryLoading={isHistoryLoading}
|
||||
tokenUsageInlineMode={tokenUsageInlineMode}
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -33,7 +33,6 @@ import { ThreadContext } from "@/components/workspace/messages/context";
|
||||
import type { Agent } from "@/core/agents";
|
||||
import {
|
||||
AgentNameCheckError,
|
||||
AgentsApiDisabledError,
|
||||
checkAgentName,
|
||||
createAgent,
|
||||
getAgent,
|
||||
@@ -155,9 +154,7 @@ export default function NewAgentPage() {
|
||||
return;
|
||||
}
|
||||
} catch (err) {
|
||||
if (err instanceof AgentsApiDisabledError) {
|
||||
setNameError(t.agents.nameStepApiDisabledError);
|
||||
} else if (
|
||||
if (
|
||||
err instanceof AgentNameCheckError &&
|
||||
err.reason === "backend_unreachable"
|
||||
) {
|
||||
@@ -178,10 +175,6 @@ export default function NewAgentPage() {
|
||||
soul: "",
|
||||
});
|
||||
} catch (err) {
|
||||
if (err instanceof AgentsApiDisabledError) {
|
||||
setNameError(t.agents.nameStepApiDisabledError);
|
||||
return;
|
||||
}
|
||||
setNameError(
|
||||
getCreateAgentErrorMessage(
|
||||
err,
|
||||
@@ -204,7 +197,6 @@ export default function NewAgentPage() {
|
||||
nameInput,
|
||||
sendMessage,
|
||||
t.agents.nameStepAlreadyExistsError,
|
||||
t.agents.nameStepApiDisabledError,
|
||||
t.agents.nameStepNetworkError,
|
||||
t.agents.nameStepBootstrapMessage,
|
||||
t.agents.nameStepCheckError,
|
||||
|
||||
@@ -24,7 +24,7 @@ import { Welcome } from "@/components/workspace/welcome";
|
||||
import { useI18n } from "@/core/i18n/hooks";
|
||||
import { useModels } from "@/core/models/hooks";
|
||||
import { useNotification } from "@/core/notification/hooks";
|
||||
import { useLocalSettings, useThreadSettings } from "@/core/settings";
|
||||
import { useThreadSettings } from "@/core/settings";
|
||||
import { useThreadStream } from "@/core/threads/hooks";
|
||||
import { textOfMessage } from "@/core/threads/utils";
|
||||
import { env } from "@/env";
|
||||
@@ -36,7 +36,6 @@ export default function ChatPage() {
|
||||
const { threadId, setThreadId, isNewThread, setIsNewThread, isMock } =
|
||||
useThreadChat();
|
||||
const [settings, setSettings] = useThreadSettings(threadId);
|
||||
const [localSettings, setLocalSettings] = useLocalSettings();
|
||||
const { tokenUsageEnabled } = useModels();
|
||||
const mountedRef = useRef(false);
|
||||
useSpecificChatMode();
|
||||
@@ -100,9 +99,6 @@ export default function ChatPage() {
|
||||
? MESSAGE_LIST_DEFAULT_PADDING_BOTTOM +
|
||||
MESSAGE_LIST_FOLLOWUPS_EXTRA_PADDING_BOTTOM
|
||||
: undefined;
|
||||
const tokenUsageInlineMode = tokenUsageEnabled
|
||||
? localSettings.tokenUsage.inlineMode
|
||||
: "off";
|
||||
|
||||
return (
|
||||
<ThreadContext.Provider value={{ thread, isMock }}>
|
||||
@@ -123,10 +119,6 @@ export default function ChatPage() {
|
||||
<TokenUsageIndicator
|
||||
enabled={tokenUsageEnabled}
|
||||
messages={thread.messages}
|
||||
preferences={localSettings.tokenUsage}
|
||||
onPreferencesChange={(preferences) =>
|
||||
setLocalSettings("tokenUsage", preferences)
|
||||
}
|
||||
/>
|
||||
<ExportTrigger threadId={threadId} />
|
||||
<ArtifactTrigger />
|
||||
@@ -139,10 +131,10 @@ export default function ChatPage() {
|
||||
threadId={threadId}
|
||||
thread={thread}
|
||||
paddingBottom={messageListPaddingBottom}
|
||||
tokenUsageEnabled={tokenUsageEnabled}
|
||||
hasMoreHistory={hasMoreHistory}
|
||||
loadMoreHistory={loadMoreHistory}
|
||||
isHistoryLoading={isHistoryLoading}
|
||||
tokenUsageInlineMode={tokenUsageInlineMode}
|
||||
/>
|
||||
</div>
|
||||
<div className="absolute right-0 bottom-0 left-0 z-30 flex justify-center px-4">
|
||||
|
||||
@@ -2,7 +2,6 @@ import type { Message } from "@langchain/langgraph-sdk";
|
||||
import {
|
||||
BookOpenTextIcon,
|
||||
ChevronUp,
|
||||
CoinsIcon,
|
||||
FolderOpenIcon,
|
||||
GlobeIcon,
|
||||
LightbulbIcon,
|
||||
@@ -25,8 +24,6 @@ import {
|
||||
import { CodeBlock } from "@/components/ai-elements/code-block";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useI18n } from "@/core/i18n/hooks";
|
||||
import { formatTokenCount } from "@/core/messages/usage";
|
||||
import type { TokenDebugStep } from "@/core/messages/usage-model";
|
||||
import {
|
||||
extractReasoningContentFromMessage,
|
||||
findToolCallResult,
|
||||
@@ -46,14 +43,10 @@ export function MessageGroup({
|
||||
className,
|
||||
messages,
|
||||
isLoading = false,
|
||||
tokenDebugSteps = [],
|
||||
showTokenDebugSummaries = false,
|
||||
}: {
|
||||
className?: string;
|
||||
messages: Message[];
|
||||
isLoading?: boolean;
|
||||
tokenDebugSteps?: TokenDebugStep[];
|
||||
showTokenDebugSummaries?: boolean;
|
||||
}) {
|
||||
const { t } = useI18n();
|
||||
const [showAbove, setShowAbove] = useState(
|
||||
@@ -63,28 +56,6 @@ export function MessageGroup({
|
||||
env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true",
|
||||
);
|
||||
const steps = useMemo(() => convertToSteps(messages), [messages]);
|
||||
const debugStepByMessageId = useMemo(
|
||||
() =>
|
||||
new Map(
|
||||
tokenDebugSteps.map(
|
||||
(step) => [step.messageId || step.id, step] as const,
|
||||
),
|
||||
),
|
||||
[tokenDebugSteps],
|
||||
);
|
||||
const toolCallCountByMessageId = useMemo(() => {
|
||||
const counts = new Map<string, number>();
|
||||
|
||||
for (const step of steps) {
|
||||
if (step.type !== "toolCall" || !step.messageId) {
|
||||
continue;
|
||||
}
|
||||
|
||||
counts.set(step.messageId, (counts.get(step.messageId) ?? 0) + 1);
|
||||
}
|
||||
|
||||
return counts;
|
||||
}, [steps]);
|
||||
const lastToolCallStep = useMemo(() => {
|
||||
const filteredSteps = steps.filter((step) => step.type === "toolCall");
|
||||
return filteredSteps[filteredSteps.length - 1];
|
||||
@@ -106,125 +77,6 @@ export function MessageGroup({
|
||||
}
|
||||
}, [lastToolCallStep, steps]);
|
||||
const rehypePlugins = useRehypeSplitWordsIntoSpans(isLoading);
|
||||
const firstEligibleDebugSummaryStepIndexByMessageId = useMemo(() => {
|
||||
const firstIndices = new Map<string, number>();
|
||||
|
||||
if (!showTokenDebugSummaries) {
|
||||
return firstIndices;
|
||||
}
|
||||
|
||||
for (const [index, step] of steps.entries()) {
|
||||
const messageId = step.messageId;
|
||||
if (!messageId || firstIndices.has(messageId)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const debugStep = debugStepByMessageId.get(messageId);
|
||||
if (!debugStep) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const toolCallCount = toolCallCountByMessageId.get(messageId) ?? 0;
|
||||
if (!debugStep.sharedAttribution && toolCallCount > 0) {
|
||||
continue;
|
||||
}
|
||||
if (
|
||||
!debugStep.sharedAttribution &&
|
||||
toolCallCount === 0 &&
|
||||
debugStep.label === t.common.thinking &&
|
||||
debugStep.secondaryLabels.length === 0
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
|
||||
firstIndices.set(messageId, index);
|
||||
}
|
||||
|
||||
return firstIndices;
|
||||
}, [
|
||||
debugStepByMessageId,
|
||||
showTokenDebugSummaries,
|
||||
steps,
|
||||
t.common.thinking,
|
||||
toolCallCountByMessageId,
|
||||
]);
|
||||
|
||||
const renderDebugSummary = (
|
||||
messageId: string | undefined,
|
||||
stepIndex: number,
|
||||
) => {
|
||||
if (!showTokenDebugSummaries || !messageId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const debugStep = debugStepByMessageId.get(messageId);
|
||||
if (!debugStep) {
|
||||
return null;
|
||||
}
|
||||
if (
|
||||
firstEligibleDebugSummaryStepIndexByMessageId.get(messageId) !== stepIndex
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<ChainOfThoughtStep
|
||||
key={`token-debug-${messageId}`}
|
||||
icon={CoinsIcon}
|
||||
label={
|
||||
<DebugStepLabel
|
||||
label={debugStep.label}
|
||||
token={formatDebugToken(debugStep, t)}
|
||||
/>
|
||||
}
|
||||
description={
|
||||
debugStep.sharedAttribution
|
||||
? t.tokenUsage.sharedAttribution
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
{debugStep.secondaryLabels.length > 0 && (
|
||||
<ChainOfThoughtSearchResults>
|
||||
{debugStep.secondaryLabels.map((label, index) => (
|
||||
<ChainOfThoughtSearchResult
|
||||
key={`${debugStep.id}-${index}-${label}`}
|
||||
>
|
||||
{label}
|
||||
</ChainOfThoughtSearchResult>
|
||||
))}
|
||||
</ChainOfThoughtSearchResults>
|
||||
)}
|
||||
</ChainOfThoughtStep>
|
||||
);
|
||||
};
|
||||
|
||||
const renderToolCall = (
|
||||
step: CoTToolCallStep,
|
||||
options?: { isLast?: boolean },
|
||||
) => {
|
||||
const debugStep =
|
||||
showTokenDebugSummaries && step.messageId
|
||||
? debugStepByMessageId.get(step.messageId)
|
||||
: undefined;
|
||||
|
||||
return (
|
||||
<ToolCall
|
||||
key={step.id}
|
||||
{...step}
|
||||
isLast={options?.isLast}
|
||||
isLoading={isLoading}
|
||||
tokenDebugStep={
|
||||
debugStep && !debugStep.sharedAttribution ? debugStep : undefined
|
||||
}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
const lastReasoningDebugStep =
|
||||
showTokenDebugSummaries && lastReasoningStep?.messageId
|
||||
? debugStepByMessageId.get(lastReasoningStep.messageId)
|
||||
: undefined;
|
||||
|
||||
return (
|
||||
<ChainOfThought
|
||||
className={cn("w-full gap-2 rounded-lg border p-0.5", className)}
|
||||
@@ -259,46 +111,36 @@ export function MessageGroup({
|
||||
{lastToolCallStep && (
|
||||
<ChainOfThoughtContent className="px-4 pb-2">
|
||||
{showAbove &&
|
||||
aboveLastToolCallSteps.flatMap((step) => {
|
||||
const stepIndex = steps.indexOf(step);
|
||||
if (step.type === "reasoning") {
|
||||
return [
|
||||
renderDebugSummary(step.messageId, stepIndex),
|
||||
<ChainOfThoughtStep
|
||||
key={step.id}
|
||||
label={
|
||||
<MarkdownContent
|
||||
content={step.reasoning ?? ""}
|
||||
isLoading={isLoading}
|
||||
rehypePlugins={rehypePlugins}
|
||||
/>
|
||||
}
|
||||
></ChainOfThoughtStep>,
|
||||
];
|
||||
}
|
||||
|
||||
return [
|
||||
renderDebugSummary(step.messageId, stepIndex),
|
||||
renderToolCall(step),
|
||||
];
|
||||
})}
|
||||
{renderDebugSummary(
|
||||
lastToolCallStep.messageId,
|
||||
steps.indexOf(lastToolCallStep),
|
||||
)}
|
||||
aboveLastToolCallSteps.map((step) =>
|
||||
step.type === "reasoning" ? (
|
||||
<ChainOfThoughtStep
|
||||
key={step.id}
|
||||
label={
|
||||
<MarkdownContent
|
||||
content={step.reasoning ?? ""}
|
||||
isLoading={isLoading}
|
||||
rehypePlugins={rehypePlugins}
|
||||
/>
|
||||
}
|
||||
></ChainOfThoughtStep>
|
||||
) : (
|
||||
<ToolCall key={step.id} {...step} isLoading={isLoading} />
|
||||
),
|
||||
)}
|
||||
{lastToolCallStep && (
|
||||
<FlipDisplay uniqueKey={lastToolCallStep.id ?? ""}>
|
||||
{renderToolCall(lastToolCallStep, { isLast: true })}
|
||||
<ToolCall
|
||||
key={lastToolCallStep.id}
|
||||
{...lastToolCallStep}
|
||||
isLast={true}
|
||||
isLoading={isLoading}
|
||||
/>
|
||||
</FlipDisplay>
|
||||
)}
|
||||
</ChainOfThoughtContent>
|
||||
)}
|
||||
{lastReasoningStep && (
|
||||
<>
|
||||
{renderDebugSummary(
|
||||
lastReasoningStep.messageId,
|
||||
steps.indexOf(lastReasoningStep),
|
||||
)}
|
||||
<Button
|
||||
key={lastReasoningStep.id}
|
||||
className="w-full items-start justify-start text-left"
|
||||
@@ -308,22 +150,7 @@ export function MessageGroup({
|
||||
<div className="flex w-full items-center justify-between">
|
||||
<ChainOfThoughtStep
|
||||
className="font-normal"
|
||||
label={
|
||||
<DebugStepLabel
|
||||
label={t.common.thinking}
|
||||
token={shouldInlineThinkingToken({
|
||||
debugStep: lastReasoningDebugStep,
|
||||
toolCallCount: lastReasoningStep.messageId
|
||||
? (toolCallCountByMessageId.get(
|
||||
lastReasoningStep.messageId,
|
||||
) ?? 0)
|
||||
: 0,
|
||||
enabled: showTokenDebugSummaries,
|
||||
thinkingLabel: t.common.thinking,
|
||||
t,
|
||||
})}
|
||||
/>
|
||||
}
|
||||
label={t.common.thinking}
|
||||
icon={LightbulbIcon}
|
||||
></ChainOfThoughtStep>
|
||||
<div>
|
||||
@@ -356,60 +183,6 @@ export function MessageGroup({
|
||||
);
|
||||
}
|
||||
|
||||
function formatDebugToken(
|
||||
debugStep: TokenDebugStep,
|
||||
t: ReturnType<typeof useI18n>["t"],
|
||||
) {
|
||||
return debugStep.usage
|
||||
? `${formatTokenCount(debugStep.usage.totalTokens)} ${t.tokenUsage.label}`
|
||||
: t.tokenUsage.unavailableShort;
|
||||
}
|
||||
|
||||
function shouldInlineThinkingToken({
|
||||
debugStep,
|
||||
toolCallCount,
|
||||
enabled,
|
||||
thinkingLabel,
|
||||
t,
|
||||
}: {
|
||||
debugStep?: TokenDebugStep;
|
||||
toolCallCount: number;
|
||||
enabled: boolean;
|
||||
thinkingLabel: string;
|
||||
t: ReturnType<typeof useI18n>["t"];
|
||||
}) {
|
||||
if (
|
||||
!enabled ||
|
||||
!debugStep ||
|
||||
debugStep.sharedAttribution ||
|
||||
toolCallCount > 0 ||
|
||||
debugStep.label !== thinkingLabel
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return formatDebugToken(debugStep, t);
|
||||
}
|
||||
|
||||
function DebugStepLabel({
|
||||
label,
|
||||
token,
|
||||
}: {
|
||||
label: React.ReactNode;
|
||||
token?: string | null;
|
||||
}) {
|
||||
return (
|
||||
<div className="flex items-center justify-between gap-3">
|
||||
<div className="min-w-0 flex-1">{label}</div>
|
||||
{token ? (
|
||||
<div className="text-muted-foreground shrink-0 font-mono text-[11px]">
|
||||
{token}
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ToolCall({
|
||||
id,
|
||||
messageId,
|
||||
@@ -418,7 +191,6 @@ function ToolCall({
|
||||
result,
|
||||
isLast = false,
|
||||
isLoading = false,
|
||||
tokenDebugStep,
|
||||
}: {
|
||||
id?: string;
|
||||
messageId?: string;
|
||||
@@ -427,20 +199,10 @@ function ToolCall({
|
||||
result?: string | Record<string, unknown>;
|
||||
isLast?: boolean;
|
||||
isLoading?: boolean;
|
||||
tokenDebugStep?: TokenDebugStep;
|
||||
}) {
|
||||
const { t } = useI18n();
|
||||
const { setOpen, autoOpen, autoSelect, selectedArtifact, select } =
|
||||
useArtifacts();
|
||||
const tokenLabel = tokenDebugStep
|
||||
? formatDebugToken(tokenDebugStep, t)
|
||||
: null;
|
||||
const resolveLabel = (fallback: React.ReactNode) =>
|
||||
tokenDebugStep ? (
|
||||
<DebugStepLabel label={tokenDebugStep.label} token={tokenLabel} />
|
||||
) : (
|
||||
fallback
|
||||
);
|
||||
|
||||
if (name === "web_search") {
|
||||
let label: React.ReactNode = t.toolCalls.searchForRelatedInfo;
|
||||
@@ -448,11 +210,7 @@ function ToolCall({
|
||||
label = t.toolCalls.searchOnWebFor(args.query);
|
||||
}
|
||||
return (
|
||||
<ChainOfThoughtStep
|
||||
key={id}
|
||||
label={resolveLabel(label)}
|
||||
icon={SearchIcon}
|
||||
>
|
||||
<ChainOfThoughtStep key={id} label={label} icon={SearchIcon}>
|
||||
{Array.isArray(result) && (
|
||||
<ChainOfThoughtSearchResults>
|
||||
{result.map((item) => (
|
||||
@@ -482,11 +240,7 @@ function ToolCall({
|
||||
}
|
||||
)?.results;
|
||||
return (
|
||||
<ChainOfThoughtStep
|
||||
key={id}
|
||||
label={resolveLabel(label)}
|
||||
icon={SearchIcon}
|
||||
>
|
||||
<ChainOfThoughtStep key={id} label={label} icon={SearchIcon}>
|
||||
{Array.isArray(results) && (
|
||||
<ChainOfThoughtSearchResults>
|
||||
{Array.isArray(results) &&
|
||||
@@ -526,7 +280,7 @@ function ToolCall({
|
||||
return (
|
||||
<ChainOfThoughtStep
|
||||
key={id}
|
||||
label={resolveLabel(t.toolCalls.viewWebPage)}
|
||||
label={t.toolCalls.viewWebPage}
|
||||
icon={GlobeIcon}
|
||||
>
|
||||
<ChainOfThoughtSearchResult>
|
||||
@@ -551,11 +305,7 @@ function ToolCall({
|
||||
}
|
||||
const path: string | undefined = (args as { path: string })?.path;
|
||||
return (
|
||||
<ChainOfThoughtStep
|
||||
key={id}
|
||||
label={resolveLabel(description)}
|
||||
icon={FolderOpenIcon}
|
||||
>
|
||||
<ChainOfThoughtStep key={id} label={description} icon={FolderOpenIcon}>
|
||||
{path && (
|
||||
<ChainOfThoughtSearchResult className="cursor-pointer">
|
||||
{path}
|
||||
@@ -571,11 +321,7 @@ function ToolCall({
|
||||
}
|
||||
const { path } = args as { path: string; content: string };
|
||||
return (
|
||||
<ChainOfThoughtStep
|
||||
key={id}
|
||||
label={resolveLabel(description)}
|
||||
icon={BookOpenTextIcon}
|
||||
>
|
||||
<ChainOfThoughtStep key={id} label={description} icon={BookOpenTextIcon}>
|
||||
{path && (
|
||||
<ChainOfThoughtSearchResult className="cursor-pointer">
|
||||
{path}
|
||||
@@ -607,7 +353,7 @@ function ToolCall({
|
||||
<ChainOfThoughtStep
|
||||
key={id}
|
||||
className="cursor-pointer"
|
||||
label={resolveLabel(description)}
|
||||
label={description}
|
||||
icon={NotebookPenIcon}
|
||||
onClick={() => {
|
||||
select(
|
||||
@@ -629,19 +375,13 @@ function ToolCall({
|
||||
const description: string | undefined = (args as { description: string })
|
||||
?.description;
|
||||
if (!description) {
|
||||
return (
|
||||
<ChainOfThoughtStep
|
||||
key={id}
|
||||
label={resolveLabel(t.toolCalls.executeCommand)}
|
||||
icon={SquareTerminalIcon}
|
||||
/>
|
||||
);
|
||||
return t.toolCalls.executeCommand;
|
||||
}
|
||||
const command: string | undefined = (args as { command: string })?.command;
|
||||
return (
|
||||
<ChainOfThoughtStep
|
||||
key={id}
|
||||
label={resolveLabel(description)}
|
||||
label={description}
|
||||
icon={SquareTerminalIcon}
|
||||
>
|
||||
{command && (
|
||||
@@ -658,7 +398,7 @@ function ToolCall({
|
||||
return (
|
||||
<ChainOfThoughtStep
|
||||
key={id}
|
||||
label={resolveLabel(t.toolCalls.needYourHelp)}
|
||||
label={t.toolCalls.needYourHelp}
|
||||
icon={MessageCircleQuestionMarkIcon}
|
||||
></ChainOfThoughtStep>
|
||||
);
|
||||
@@ -666,7 +406,7 @@ function ToolCall({
|
||||
return (
|
||||
<ChainOfThoughtStep
|
||||
key={id}
|
||||
label={resolveLabel(t.toolCalls.writeTodos)}
|
||||
label={t.toolCalls.writeTodos}
|
||||
icon={ListTodoIcon}
|
||||
></ChainOfThoughtStep>
|
||||
);
|
||||
@@ -676,7 +416,7 @@ function ToolCall({
|
||||
return (
|
||||
<ChainOfThoughtStep
|
||||
key={id}
|
||||
label={resolveLabel(description ?? t.toolCalls.useTool(name))}
|
||||
label={description ?? t.toolCalls.useTool(name)}
|
||||
icon={WrenchIcon}
|
||||
></ChainOfThoughtStep>
|
||||
);
|
||||
|
||||
@@ -50,6 +50,7 @@ import { cn } from "@/lib/utils";
|
||||
import { CopyButton } from "../copy-button";
|
||||
|
||||
import { MarkdownContent } from "./markdown-content";
|
||||
import { MessageTokenUsage } from "./message-token-usage";
|
||||
|
||||
function FeedbackButtons({
|
||||
threadId,
|
||||
@@ -120,20 +121,20 @@ function FeedbackButtons({
|
||||
|
||||
export function MessageListItem({
|
||||
className,
|
||||
threadId,
|
||||
message,
|
||||
isLoading,
|
||||
tokenUsageEnabled = false,
|
||||
feedback,
|
||||
runId,
|
||||
threadId,
|
||||
showCopyButton = true,
|
||||
}: {
|
||||
className?: string;
|
||||
message: Message;
|
||||
isLoading?: boolean;
|
||||
threadId: string;
|
||||
tokenUsageEnabled?: boolean;
|
||||
feedback?: FeedbackData | null;
|
||||
runId?: string;
|
||||
showCopyButton?: boolean;
|
||||
}) {
|
||||
const isHuman = message.type === "human";
|
||||
return (
|
||||
@@ -146,17 +147,16 @@ export function MessageListItem({
|
||||
message={message}
|
||||
isLoading={isLoading}
|
||||
threadId={threadId}
|
||||
tokenUsageEnabled={tokenUsageEnabled}
|
||||
/>
|
||||
{!isLoading && showCopyButton && (
|
||||
{!isLoading && (
|
||||
<MessageToolbar
|
||||
className={cn(
|
||||
isHuman
|
||||
? "absolute right-0 -bottom-9 left-0 justify-end"
|
||||
: "absolute right-0 bottom-0 left-0",
|
||||
"z-20 opacity-0 transition-opacity delay-200 duration-300 group-hover/conversation-message:opacity-100",
|
||||
isHuman ? "-bottom-9 justify-end" : "-bottom-8",
|
||||
"absolute right-0 left-0 z-20",
|
||||
)}
|
||||
>
|
||||
<div className="pointer-events-auto flex gap-1">
|
||||
<div className="flex gap-1">
|
||||
<CopyButton
|
||||
clipboardData={
|
||||
extractContentFromMessage(message) ??
|
||||
@@ -213,11 +213,13 @@ function MessageContent_({
|
||||
message,
|
||||
isLoading = false,
|
||||
threadId,
|
||||
tokenUsageEnabled = false,
|
||||
}: {
|
||||
className?: string;
|
||||
message: Message;
|
||||
isLoading?: boolean;
|
||||
threadId: string;
|
||||
tokenUsageEnabled?: boolean;
|
||||
}) {
|
||||
const rehypePlugins = useRehypeSplitWordsIntoSpans(isLoading);
|
||||
const isHuman = message.type === "human";
|
||||
@@ -295,6 +297,11 @@ function MessageContent_({
|
||||
<ReasoningTrigger />
|
||||
<ReasoningContent>{reasoningContent}</ReasoningContent>
|
||||
</Reasoning>
|
||||
<MessageTokenUsage
|
||||
enabled={tokenUsageEnabled}
|
||||
isLoading={isLoading}
|
||||
message={message}
|
||||
/>
|
||||
</AIElementMessageContent>
|
||||
);
|
||||
}
|
||||
@@ -332,6 +339,11 @@ function MessageContent_({
|
||||
className="my-3"
|
||||
components={components}
|
||||
/>
|
||||
<MessageTokenUsage
|
||||
enabled={tokenUsageEnabled}
|
||||
isLoading={isLoading}
|
||||
message={message}
|
||||
/>
|
||||
</AIElementMessageContent>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import type { Message } from "@langchain/langgraph-sdk";
|
||||
import type { BaseStream } from "@langchain/langgraph-sdk/react";
|
||||
import { ChevronUpIcon, Loader2Icon } from "lucide-react";
|
||||
import { useCallback, useEffect, useMemo, useRef } from "react";
|
||||
import { useCallback, useEffect, useRef } from "react";
|
||||
|
||||
import {
|
||||
Conversation,
|
||||
@@ -9,20 +8,15 @@ import {
|
||||
} from "@/components/ai-elements/conversation";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useI18n } from "@/core/i18n/hooks";
|
||||
import {
|
||||
buildTokenDebugSteps,
|
||||
type TokenUsageInlineMode,
|
||||
} from "@/core/messages/usage-model";
|
||||
import {
|
||||
extractContentFromMessage,
|
||||
extractPresentFilesFromMessage,
|
||||
extractReasoningContentFromMessage,
|
||||
extractTextFromMessage,
|
||||
getAssistantTurnUsageMessages,
|
||||
getMessageGroups,
|
||||
groupMessages,
|
||||
hasContent,
|
||||
hasPresentFiles,
|
||||
hasReasoning,
|
||||
hasToolCalls,
|
||||
} from "@/core/messages/utils";
|
||||
import { useRehypeSplitWordsIntoSpans } from "@/core/rehype";
|
||||
import type { Subtask } from "@/core/tasks";
|
||||
@@ -31,16 +25,12 @@ import type { AgentThreadState } from "@/core/threads";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
import { ArtifactFileList } from "../artifacts/artifact-file-list";
|
||||
import { CopyButton } from "../copy-button";
|
||||
import { StreamingIndicator } from "../streaming-indicator";
|
||||
|
||||
import { MarkdownContent } from "./markdown-content";
|
||||
import { MessageGroup } from "./message-group";
|
||||
import { MessageListItem } from "./message-list-item";
|
||||
import {
|
||||
MessageTokenUsageDebugList,
|
||||
MessageTokenUsageList,
|
||||
} from "./message-token-usage";
|
||||
import { MessageTokenUsageList } from "./message-token-usage";
|
||||
import { MessageListSkeleton } from "./skeleton";
|
||||
import { SubtaskCard } from "./subtask-card";
|
||||
|
||||
@@ -159,7 +149,7 @@ export function MessageList({
|
||||
threadId,
|
||||
thread,
|
||||
paddingBottom = MESSAGE_LIST_DEFAULT_PADDING_BOTTOM,
|
||||
tokenUsageInlineMode = "off",
|
||||
tokenUsageEnabled = false,
|
||||
hasMoreHistory,
|
||||
loadMoreHistory,
|
||||
isHistoryLoading,
|
||||
@@ -168,7 +158,7 @@ export function MessageList({
|
||||
threadId: string;
|
||||
thread: BaseStream<AgentThreadState>;
|
||||
paddingBottom?: number;
|
||||
tokenUsageInlineMode?: TokenUsageInlineMode;
|
||||
tokenUsageEnabled?: boolean;
|
||||
hasMoreHistory?: boolean;
|
||||
loadMoreHistory?: () => void;
|
||||
isHistoryLoading?: boolean;
|
||||
@@ -177,85 +167,10 @@ export function MessageList({
|
||||
const rehypePlugins = useRehypeSplitWordsIntoSpans(thread.isLoading);
|
||||
const updateSubtask = useUpdateSubtask();
|
||||
const messages = thread.messages;
|
||||
const groupedMessages = getMessageGroups(messages);
|
||||
const turnUsageMessagesByGroupIndex =
|
||||
getAssistantTurnUsageMessages(groupedMessages);
|
||||
const tokenDebugSteps = useMemo(
|
||||
() => buildTokenDebugSteps(messages, t),
|
||||
[messages, t],
|
||||
);
|
||||
|
||||
const renderAssistantCopyButton = useCallback((messages: Message[]) => {
|
||||
const clipboardData = [...messages]
|
||||
.reverse()
|
||||
.filter((message) => message.type === "ai")
|
||||
.map((message) => {
|
||||
const content = extractContentFromMessage(message);
|
||||
return content ?? extractReasoningContentFromMessage(message) ?? "";
|
||||
})
|
||||
.find((content) => content.length > 0);
|
||||
|
||||
if (!clipboardData) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="mt-2 flex justify-start opacity-0 transition-opacity delay-200 duration-300 group-hover/assistant-turn:opacity-100">
|
||||
<CopyButton clipboardData={clipboardData} />
|
||||
</div>
|
||||
);
|
||||
}, []);
|
||||
|
||||
const renderTokenUsage = useCallback(
|
||||
({
|
||||
messages,
|
||||
turnUsageMessages,
|
||||
inlineDebug = true,
|
||||
debugMessageIds,
|
||||
}: {
|
||||
messages: Message[];
|
||||
turnUsageMessages?: Message[] | null;
|
||||
inlineDebug?: boolean;
|
||||
debugMessageIds?: string[];
|
||||
}) => {
|
||||
if (tokenUsageInlineMode === "per_turn") {
|
||||
return (
|
||||
<MessageTokenUsageList
|
||||
enabled={true}
|
||||
isLoading={thread.isLoading}
|
||||
messages={turnUsageMessages ?? []}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (tokenUsageInlineMode === "step_debug" && inlineDebug) {
|
||||
const messageIds = new Set(
|
||||
debugMessageIds ??
|
||||
messages
|
||||
.filter((message) => message.type === "ai")
|
||||
.map((message) => message.id)
|
||||
.filter((id): id is string => typeof id === "string"),
|
||||
);
|
||||
return (
|
||||
<MessageTokenUsageDebugList
|
||||
enabled={true}
|
||||
isLoading={thread.isLoading}
|
||||
steps={tokenDebugSteps.filter((step) =>
|
||||
messageIds.has(step.messageId),
|
||||
)}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return null;
|
||||
},
|
||||
[thread.isLoading, tokenDebugSteps, tokenUsageInlineMode],
|
||||
);
|
||||
|
||||
if (thread.isThreadLoading && messages.length === 0) {
|
||||
return <MessageListSkeleton />;
|
||||
}
|
||||
|
||||
return (
|
||||
<Conversation
|
||||
className={cn("flex size-full flex-col justify-center", className)}
|
||||
@@ -266,37 +181,19 @@ export function MessageList({
|
||||
hasMore={hasMoreHistory}
|
||||
loadMore={loadMoreHistory}
|
||||
/>
|
||||
{groupedMessages.map((group, groupIndex) => {
|
||||
const turnUsageMessages = turnUsageMessagesByGroupIndex[groupIndex];
|
||||
|
||||
{groupMessages(messages, (group) => {
|
||||
if (group.type === "human" || group.type === "assistant") {
|
||||
return (
|
||||
<div
|
||||
key={group.id}
|
||||
className={cn(
|
||||
"w-full",
|
||||
group.type === "assistant" && "group/assistant-turn",
|
||||
)}
|
||||
>
|
||||
{group.messages.map((msg) => {
|
||||
return (
|
||||
<MessageListItem
|
||||
key={`${group.id}/${msg.id}`}
|
||||
message={msg}
|
||||
isLoading={thread.isLoading}
|
||||
threadId={threadId}
|
||||
showCopyButton={group.type !== "assistant"}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
{renderTokenUsage({
|
||||
messages: group.messages,
|
||||
turnUsageMessages,
|
||||
})}
|
||||
{group.type === "assistant" &&
|
||||
renderAssistantCopyButton(group.messages)}
|
||||
</div>
|
||||
);
|
||||
return group.messages.map((msg) => {
|
||||
return (
|
||||
<MessageListItem
|
||||
key={`${group.id}/${msg.id}`}
|
||||
threadId={threadId}
|
||||
message={msg}
|
||||
isLoading={thread.isLoading}
|
||||
tokenUsageEnabled={tokenUsageEnabled}
|
||||
/>
|
||||
);
|
||||
});
|
||||
} else if (group.type === "assistant:clarification") {
|
||||
const message = group.messages[0];
|
||||
if (message && hasContent(message)) {
|
||||
@@ -307,10 +204,11 @@ export function MessageList({
|
||||
isLoading={thread.isLoading}
|
||||
rehypePlugins={rehypePlugins}
|
||||
/>
|
||||
{renderTokenUsage({
|
||||
messages: group.messages,
|
||||
turnUsageMessages,
|
||||
})}
|
||||
<MessageTokenUsageList
|
||||
enabled={tokenUsageEnabled}
|
||||
isLoading={thread.isLoading}
|
||||
messages={group.messages}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -334,10 +232,11 @@ export function MessageList({
|
||||
/>
|
||||
)}
|
||||
<ArtifactFileList files={files} threadId={threadId} />
|
||||
{renderTokenUsage({
|
||||
messages: group.messages,
|
||||
turnUsageMessages,
|
||||
})}
|
||||
<MessageTokenUsageList
|
||||
enabled={tokenUsageEnabled}
|
||||
isLoading={thread.isLoading}
|
||||
messages={group.messages}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
} else if (group.type === "assistant:subagent") {
|
||||
@@ -390,19 +289,7 @@ export function MessageList({
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const results: React.ReactNode[] = [];
|
||||
const subagentDebugMessageIds: string[] = [];
|
||||
if (tasks.size > 0) {
|
||||
results.push(
|
||||
<div
|
||||
key="subtask-count"
|
||||
className="text-muted-foreground pt-2 text-sm font-normal"
|
||||
>
|
||||
{t.subtasks.executing(tasks.size)}
|
||||
</div>,
|
||||
);
|
||||
}
|
||||
for (const message of group.messages.filter(
|
||||
(message) => message.type === "ai",
|
||||
)) {
|
||||
@@ -412,17 +299,17 @@ export function MessageList({
|
||||
key={"thinking-group-" + message.id}
|
||||
messages={[message]}
|
||||
isLoading={thread.isLoading}
|
||||
tokenDebugSteps={tokenDebugSteps.filter(
|
||||
(step) => step.messageId === message.id,
|
||||
)}
|
||||
showTokenDebugSummaries={
|
||||
tokenUsageInlineMode === "step_debug"
|
||||
}
|
||||
/>,
|
||||
);
|
||||
} else if (message.id) {
|
||||
subagentDebugMessageIds.push(message.id);
|
||||
}
|
||||
results.push(
|
||||
<div
|
||||
key="subtask-count"
|
||||
className="text-muted-foreground font-norma pt-2 text-sm"
|
||||
>
|
||||
{t.subtasks.executing(tasks.size)}
|
||||
</div>,
|
||||
);
|
||||
const taskIds = message.tool_calls
|
||||
?.filter((toolCall) => toolCall.name === "task")
|
||||
.map((toolCall) => toolCall.id);
|
||||
@@ -442,31 +329,30 @@ export function MessageList({
|
||||
className="relative z-1 flex flex-col gap-2"
|
||||
>
|
||||
{results}
|
||||
{renderTokenUsage({
|
||||
messages: group.messages,
|
||||
turnUsageMessages,
|
||||
debugMessageIds: subagentDebugMessageIds,
|
||||
})}
|
||||
<MessageTokenUsageList
|
||||
enabled={tokenUsageEnabled}
|
||||
isLoading={thread.isLoading}
|
||||
messages={group.messages}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
const tokenUsageMessages = group.messages.filter(
|
||||
(message) =>
|
||||
message.type === "ai" &&
|
||||
(hasToolCalls(message) ? true : !hasContent(message)),
|
||||
);
|
||||
return (
|
||||
<div key={"group-" + group.id} className="w-full">
|
||||
<MessageGroup
|
||||
messages={group.messages}
|
||||
isLoading={thread.isLoading}
|
||||
tokenDebugSteps={tokenDebugSteps.filter((step) =>
|
||||
group.messages.some(
|
||||
(message) => message.id === step.messageId,
|
||||
),
|
||||
)}
|
||||
showTokenDebugSummaries={tokenUsageInlineMode === "step_debug"}
|
||||
/>
|
||||
{renderTokenUsage({
|
||||
messages: group.messages,
|
||||
turnUsageMessages,
|
||||
inlineDebug: false,
|
||||
})}
|
||||
<MessageTokenUsageList
|
||||
enabled={tokenUsageEnabled}
|
||||
isLoading={thread.isLoading}
|
||||
messages={tokenUsageMessages}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
|
||||
@@ -1,27 +1,29 @@
|
||||
import type { Message } from "@langchain/langgraph-sdk";
|
||||
import { CoinsIcon } from "lucide-react";
|
||||
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { useI18n } from "@/core/i18n/hooks";
|
||||
import { accumulateUsage, formatTokenCount } from "@/core/messages/usage";
|
||||
import type { TokenDebugStep } from "@/core/messages/usage-model";
|
||||
import { formatTokenCount, getUsageMetadata } from "@/core/messages/usage";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
function TokenUsageSummary({
|
||||
export function MessageTokenUsage({
|
||||
className,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
totalTokens,
|
||||
unavailable = false,
|
||||
enabled = false,
|
||||
isLoading = false,
|
||||
message,
|
||||
}: {
|
||||
className?: string;
|
||||
inputTokens?: number;
|
||||
outputTokens?: number;
|
||||
totalTokens?: number;
|
||||
unavailable?: boolean;
|
||||
enabled?: boolean;
|
||||
isLoading?: boolean;
|
||||
message: Message;
|
||||
}) {
|
||||
const { t } = useI18n();
|
||||
|
||||
if (!enabled || isLoading || message.type !== "ai") {
|
||||
return null;
|
||||
}
|
||||
|
||||
const usage = getUsageMetadata(message);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
@@ -33,16 +35,16 @@ function TokenUsageSummary({
|
||||
<CoinsIcon className="size-3" />
|
||||
{t.tokenUsage.label}
|
||||
</span>
|
||||
{!unavailable ? (
|
||||
{usage ? (
|
||||
<>
|
||||
<span>
|
||||
{t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)}
|
||||
{t.tokenUsage.input}: {formatTokenCount(usage.inputTokens)}
|
||||
</span>
|
||||
<span>
|
||||
{t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)}
|
||||
{t.tokenUsage.output}: {formatTokenCount(usage.outputTokens)}
|
||||
</span>
|
||||
<span className="font-medium">
|
||||
{t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)}
|
||||
{t.tokenUsage.total}: {formatTokenCount(usage.totalTokens)}
|
||||
</span>
|
||||
</>
|
||||
) : (
|
||||
@@ -73,93 +75,17 @@ export function MessageTokenUsageList({
|
||||
return null;
|
||||
}
|
||||
|
||||
const usage = accumulateUsage(aiMessages);
|
||||
|
||||
return (
|
||||
<TokenUsageSummary
|
||||
className={className}
|
||||
inputTokens={usage?.inputTokens}
|
||||
outputTokens={usage?.outputTokens}
|
||||
totalTokens={usage?.totalTokens}
|
||||
unavailable={!usage}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export function MessageTokenUsageDebugList({
|
||||
className,
|
||||
enabled = false,
|
||||
isLoading = false,
|
||||
steps,
|
||||
}: {
|
||||
className?: string;
|
||||
enabled?: boolean;
|
||||
isLoading?: boolean;
|
||||
steps: TokenDebugStep[];
|
||||
}) {
|
||||
const { t } = useI18n();
|
||||
|
||||
if (!enabled || isLoading) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (steps.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={cn("border-border/60 mt-1 border-t pt-2", className)}>
|
||||
<div className="space-y-2">
|
||||
{steps.map((step) => (
|
||||
<div
|
||||
key={step.id}
|
||||
className="bg-muted/30 border-border/50 flex items-start justify-between gap-3 rounded-md border px-3 py-2"
|
||||
>
|
||||
<div className="min-w-0 flex-1 space-y-1">
|
||||
<div className="text-foreground flex items-center gap-2 text-xs font-medium">
|
||||
<CoinsIcon className="text-muted-foreground size-3" />
|
||||
<span className="truncate">{step.label}</span>
|
||||
</div>
|
||||
{step.secondaryLabels.length > 0 && (
|
||||
<div className="flex flex-wrap gap-1.5">
|
||||
{step.secondaryLabels.map((label, index) => (
|
||||
<Badge
|
||||
key={`${step.id}-${index}-${label}`}
|
||||
className="px-1.5 py-0 text-[10px] font-normal"
|
||||
variant="secondary"
|
||||
>
|
||||
{label}
|
||||
</Badge>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
{step.sharedAttribution && (
|
||||
<div className="text-muted-foreground text-[11px]">
|
||||
{t.tokenUsage.sharedAttribution}
|
||||
</div>
|
||||
)}
|
||||
<div className="text-muted-foreground text-[11px]">
|
||||
{step.usage ? (
|
||||
<>
|
||||
{t.tokenUsage.input}:{" "}
|
||||
{formatTokenCount(step.usage.inputTokens)}
|
||||
{" · "}
|
||||
{t.tokenUsage.output}:{" "}
|
||||
{formatTokenCount(step.usage.outputTokens)}
|
||||
</>
|
||||
) : (
|
||||
t.tokenUsage.unavailableShort
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<Badge className="shrink-0 font-mono" variant="outline">
|
||||
{step.usage
|
||||
? `${formatTokenCount(step.usage.totalTokens)} ${t.tokenUsage.label}`
|
||||
: t.tokenUsage.unavailableShort}
|
||||
</Badge>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
<>
|
||||
{aiMessages.map((message, index) => (
|
||||
<MessageTokenUsage
|
||||
className={className}
|
||||
enabled={enabled}
|
||||
isLoading={isLoading}
|
||||
key={message.id ?? index}
|
||||
message={message}
|
||||
/>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -8,13 +8,11 @@ import { Input } from "@/components/ui/input";
|
||||
import { fetch, getCsrfHeaders } from "@/core/api/fetcher";
|
||||
import { useAuth } from "@/core/auth/AuthProvider";
|
||||
import { parseAuthError } from "@/core/auth/types";
|
||||
import { useI18n } from "@/core/i18n/hooks";
|
||||
|
||||
import { SettingsSection } from "./settings-section";
|
||||
|
||||
export function AccountSettingsPage() {
|
||||
const { user, logout } = useAuth();
|
||||
const { t } = useI18n();
|
||||
const [currentPassword, setCurrentPassword] = useState("");
|
||||
const [newPassword, setNewPassword] = useState("");
|
||||
const [confirmPassword, setConfirmPassword] = useState("");
|
||||
@@ -28,11 +26,11 @@ export function AccountSettingsPage() {
|
||||
setMessage("");
|
||||
|
||||
if (newPassword !== confirmPassword) {
|
||||
setError(t.settings.account.passwordMismatch);
|
||||
setError("New passwords do not match");
|
||||
return;
|
||||
}
|
||||
if (newPassword.length < 8) {
|
||||
setError(t.settings.account.passwordTooShort);
|
||||
setError("Password must be at least 8 characters");
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -57,12 +55,12 @@ export function AccountSettingsPage() {
|
||||
return;
|
||||
}
|
||||
|
||||
setMessage(t.settings.account.passwordChangedSuccess);
|
||||
setMessage("Password changed successfully");
|
||||
setCurrentPassword("");
|
||||
setNewPassword("");
|
||||
setConfirmPassword("");
|
||||
} catch {
|
||||
setError(t.settings.account.networkError);
|
||||
setError("Network error. Please try again.");
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
@@ -70,16 +68,12 @@ export function AccountSettingsPage() {
|
||||
|
||||
return (
|
||||
<div className="space-y-8">
|
||||
<SettingsSection title={t.settings.account.profileTitle}>
|
||||
<SettingsSection title="Profile">
|
||||
<div className="space-y-2">
|
||||
<div className="grid grid-cols-[max-content_max-content] items-center gap-4">
|
||||
<span className="text-muted-foreground text-sm">
|
||||
{t.settings.account.email}
|
||||
</span>
|
||||
<span className="text-muted-foreground text-sm">Email</span>
|
||||
<span className="text-sm font-medium">{user?.email ?? "—"}</span>
|
||||
<span className="text-muted-foreground text-sm">
|
||||
{t.settings.account.role}
|
||||
</span>
|
||||
<span className="text-muted-foreground text-sm">Role</span>
|
||||
<span className="text-sm font-medium capitalize">
|
||||
{user?.system_role ?? "—"}
|
||||
</span>
|
||||
@@ -88,20 +82,20 @@ export function AccountSettingsPage() {
|
||||
</SettingsSection>
|
||||
|
||||
<SettingsSection
|
||||
title={t.settings.account.changePasswordTitle}
|
||||
description={t.settings.account.changePasswordDescription}
|
||||
title="Change Password"
|
||||
description="Update your account password."
|
||||
>
|
||||
<form onSubmit={handleChangePassword} className="max-w-sm space-y-3">
|
||||
<Input
|
||||
type="password"
|
||||
placeholder={t.settings.account.currentPassword}
|
||||
placeholder="Current password"
|
||||
value={currentPassword}
|
||||
onChange={(e) => setCurrentPassword(e.target.value)}
|
||||
required
|
||||
/>
|
||||
<Input
|
||||
type="password"
|
||||
placeholder={t.settings.account.newPassword}
|
||||
placeholder="New password"
|
||||
value={newPassword}
|
||||
onChange={(e) => setNewPassword(e.target.value)}
|
||||
required
|
||||
@@ -109,7 +103,7 @@ export function AccountSettingsPage() {
|
||||
/>
|
||||
<Input
|
||||
type="password"
|
||||
placeholder={t.settings.account.confirmNewPassword}
|
||||
placeholder="Confirm new password"
|
||||
value={confirmPassword}
|
||||
onChange={(e) => setConfirmPassword(e.target.value)}
|
||||
required
|
||||
@@ -118,9 +112,7 @@ export function AccountSettingsPage() {
|
||||
{error && <p className="text-sm text-red-500">{error}</p>}
|
||||
{message && <p className="text-sm text-green-500">{message}</p>}
|
||||
<Button type="submit" variant="outline" size="sm" disabled={loading}>
|
||||
{loading
|
||||
? t.settings.account.updating
|
||||
: t.settings.account.updatePassword}
|
||||
{loading ? "Updating..." : "Update Password"}
|
||||
</Button>
|
||||
</form>
|
||||
</SettingsSection>
|
||||
@@ -133,7 +125,7 @@ export function AccountSettingsPage() {
|
||||
className="gap-2"
|
||||
>
|
||||
<LogOutIcon className="size-4" />
|
||||
{t.settings.account.signOut}
|
||||
Sign Out
|
||||
</Button>
|
||||
</SettingsSection>
|
||||
</div>
|
||||
|
||||
@@ -1,81 +1,60 @@
|
||||
"use client";
|
||||
|
||||
import type { Message } from "@langchain/langgraph-sdk";
|
||||
import { ChevronDownIcon, CoinsIcon } from "lucide-react";
|
||||
import { CoinsIcon } from "lucide-react";
|
||||
import { useMemo } from "react";
|
||||
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuLabel,
|
||||
DropdownMenuRadioGroup,
|
||||
DropdownMenuRadioItem,
|
||||
DropdownMenuSeparator,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import { useI18n } from "@/core/i18n/hooks";
|
||||
import { accumulateUsage, formatTokenCount } from "@/core/messages/usage";
|
||||
import {
|
||||
getTokenUsageViewPreset,
|
||||
tokenUsagePreferencesFromPreset,
|
||||
type TokenUsagePreferences,
|
||||
type TokenUsageViewPreset,
|
||||
} from "@/core/messages/usage-model";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface TokenUsageIndicatorProps {
|
||||
messages: Message[];
|
||||
enabled?: boolean;
|
||||
preferences: TokenUsagePreferences;
|
||||
onPreferencesChange: (preferences: TokenUsagePreferences) => void;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function TokenUsageIndicator({
|
||||
messages,
|
||||
enabled = false,
|
||||
preferences,
|
||||
onPreferencesChange,
|
||||
className,
|
||||
}: TokenUsageIndicatorProps) {
|
||||
const { t } = useI18n();
|
||||
|
||||
const usage = useMemo(() => accumulateUsage(messages), [messages]);
|
||||
const preset = getTokenUsageViewPreset(preferences);
|
||||
|
||||
if (!enabled) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button
|
||||
<Tooltip delayDuration={200}>
|
||||
<TooltipTrigger asChild>
|
||||
<button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
className={cn(
|
||||
"text-muted-foreground bg-background/70 hover:bg-background/90 flex h-auto items-center gap-1.5 rounded-full border px-2 py-1 text-xs font-normal",
|
||||
"text-muted-foreground bg-background/70 flex cursor-default items-center gap-1.5 rounded-full border px-2 py-1 text-xs",
|
||||
!usage && "opacity-60",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<CoinsIcon size={14} />
|
||||
<span>{t.tokenUsage.label}</span>
|
||||
<span className="font-mono">
|
||||
{preferences.headerTotal
|
||||
? usage
|
||||
? formatTokenCount(usage.totalTokens)
|
||||
: "-"
|
||||
: t.tokenUsage.presets[presetKeyToTranslationKey(preset)]}
|
||||
{usage ? formatTokenCount(usage.totalTokens) : "-"}
|
||||
</span>
|
||||
<ChevronDownIcon className="size-3" />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent side="bottom" align="end" className="w-80">
|
||||
<DropdownMenuLabel>{t.tokenUsage.title}</DropdownMenuLabel>
|
||||
<div className="px-2 py-1 text-xs">
|
||||
</button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="bottom" align="end">
|
||||
<div className="space-y-1 text-xs">
|
||||
<div className="font-medium">{t.tokenUsage.title}</div>
|
||||
{usage ? (
|
||||
<div className="space-y-1">
|
||||
<>
|
||||
<div className="flex justify-between gap-4">
|
||||
<span>{t.tokenUsage.input}</span>
|
||||
<span className="font-mono">
|
||||
@@ -96,53 +75,14 @@ export function TokenUsageIndicator({
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<div className="text-muted-foreground">
|
||||
<div className="text-muted-foreground max-w-56">
|
||||
{t.tokenUsage.unavailable}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<DropdownMenuSeparator />
|
||||
<DropdownMenuLabel>{t.tokenUsage.view}</DropdownMenuLabel>
|
||||
<DropdownMenuRadioGroup
|
||||
value={preset}
|
||||
onValueChange={(value) =>
|
||||
onPreferencesChange(
|
||||
tokenUsagePreferencesFromPreset(value as TokenUsageViewPreset),
|
||||
)
|
||||
}
|
||||
>
|
||||
{(
|
||||
["off", "summary", "per_turn", "debug"] as TokenUsageViewPreset[]
|
||||
).map((value) => {
|
||||
const translationKey = presetKeyToTranslationKey(value);
|
||||
return (
|
||||
<DropdownMenuRadioItem key={value} value={value}>
|
||||
<div className="grid gap-0.5">
|
||||
<span>{t.tokenUsage.presets[translationKey]}</span>
|
||||
<span className="text-muted-foreground text-xs">
|
||||
{t.tokenUsage.presetDescriptions[translationKey]}
|
||||
</span>
|
||||
</div>
|
||||
</DropdownMenuRadioItem>
|
||||
);
|
||||
})}
|
||||
</DropdownMenuRadioGroup>
|
||||
<DropdownMenuSeparator />
|
||||
<div className="text-muted-foreground px-2 py-2 text-xs leading-relaxed">
|
||||
{t.tokenUsage.note}
|
||||
</div>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
|
||||
function presetKeyToTranslationKey(preset: TokenUsageViewPreset) {
|
||||
switch (preset) {
|
||||
case "per_turn":
|
||||
return "perTurn" as const;
|
||||
default:
|
||||
return preset;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,17 +15,6 @@ export class AgentNameCheckError extends Error {
|
||||
}
|
||||
}
|
||||
|
||||
export class AgentsApiDisabledError extends Error {
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = "AgentsApiDisabledError";
|
||||
}
|
||||
}
|
||||
|
||||
function isAgentsApiDisabledDetail(detail: string | undefined): boolean {
|
||||
return typeof detail === "string" && detail.includes("agents_api.enabled");
|
||||
}
|
||||
|
||||
export async function listAgents(): Promise<Agent[]> {
|
||||
const res = await fetch(`${getBackendBaseURL()}/api/agents`);
|
||||
if (!res.ok) throw new Error(`Failed to load agents: ${res.statusText}`);
|
||||
@@ -47,9 +36,6 @@ export async function createAgent(request: CreateAgentRequest): Promise<Agent> {
|
||||
});
|
||||
if (!res.ok) {
|
||||
const err = (await res.json().catch(() => ({}))) as { detail?: string };
|
||||
if (isAgentsApiDisabledDetail(err.detail)) {
|
||||
throw new AgentsApiDisabledError(err.detail!);
|
||||
}
|
||||
throw new Error(err.detail ?? `Failed to create agent: ${res.statusText}`);
|
||||
}
|
||||
return res.json() as Promise<Agent>;
|
||||
@@ -95,9 +81,6 @@ export async function checkAgentName(
|
||||
|
||||
if (!res.ok) {
|
||||
const err = (await res.json().catch(() => ({}))) as { detail?: string };
|
||||
if (isAgentsApiDisabledDetail(err.detail)) {
|
||||
throw new AgentsApiDisabledError(err.detail!);
|
||||
}
|
||||
if (BACKEND_UNAVAILABLE_STATUSES.has(res.status)) {
|
||||
throw new AgentNameCheckError(
|
||||
"Could not reach the DeerFlow backend.",
|
||||
|
||||
@@ -204,8 +204,6 @@ export const enUS: Translations = {
|
||||
nameStepNetworkError:
|
||||
"Network request failed — check your network or backend connection",
|
||||
nameStepCheckError: "Could not verify name availability — please try again",
|
||||
nameStepApiDisabledError:
|
||||
"Custom agent management is not enabled on this server. Please contact your administrator.",
|
||||
nameStepBootstrapMessage:
|
||||
"The new custom agent name is {name}. Let's bootstrap it's **SOUL**.",
|
||||
save: "Save agent",
|
||||
@@ -306,32 +304,9 @@ export const enUS: Translations = {
|
||||
input: "Input",
|
||||
output: "Output",
|
||||
total: "Total",
|
||||
view: "Display",
|
||||
unavailable:
|
||||
"No token usage yet. Usage appears only after a successful model response when the provider returns usage_metadata.",
|
||||
unavailableShort: "No usage returned",
|
||||
note: "Shown from provider-returned usage_metadata. Totals are best-effort conversation totals and may differ from provider billing pages.",
|
||||
presets: {
|
||||
off: "Off",
|
||||
summary: "Summary",
|
||||
perTurn: "Per turn",
|
||||
debug: "Debug",
|
||||
},
|
||||
presetDescriptions: {
|
||||
off: "Hide token usage in the header and conversation.",
|
||||
summary: "Show only the current conversation total in the header.",
|
||||
perTurn:
|
||||
"Show the header total and one token summary per assistant turn.",
|
||||
debug: "Show the header total and step-level token debugging details.",
|
||||
},
|
||||
finalAnswer: "Final answer",
|
||||
stepTotal: "Step total",
|
||||
sharedAttribution: "Shared across multiple actions in this step",
|
||||
subagent: (description: string) => `Subagent: ${description}`,
|
||||
startTodo: (content: string) => `Start To-do: ${content}`,
|
||||
completeTodo: (content: string) => `Complete To-do: ${content}`,
|
||||
updateTodo: (content: string) => `Update To-do: ${content}`,
|
||||
removeTodo: (content: string) => `Remove To-do: ${content}`,
|
||||
},
|
||||
|
||||
// Shortcuts
|
||||
@@ -478,23 +453,6 @@ export const enUS: Translations = {
|
||||
notSupported: "Your browser does not support notifications.",
|
||||
disableNotification: "Disable notification",
|
||||
},
|
||||
account: {
|
||||
profileTitle: "Profile",
|
||||
email: "Email",
|
||||
role: "Role",
|
||||
changePasswordTitle: "Change Password",
|
||||
changePasswordDescription: "Update your account password.",
|
||||
currentPassword: "Current password",
|
||||
newPassword: "New password",
|
||||
confirmNewPassword: "Confirm new password",
|
||||
passwordMismatch: "New passwords do not match",
|
||||
passwordTooShort: "Password must be at least 8 characters",
|
||||
passwordChangedSuccess: "Password changed successfully",
|
||||
networkError: "Network error. Please try again.",
|
||||
updating: "Updating...",
|
||||
updatePassword: "Update Password",
|
||||
signOut: "Sign Out",
|
||||
},
|
||||
acknowledge: {
|
||||
emptyTitle: "Acknowledgements",
|
||||
emptyDescription: "Credits and acknowledgements will show here.",
|
||||
|
||||
@@ -141,7 +141,6 @@ export interface Translations {
|
||||
nameStepAlreadyExistsError: string;
|
||||
nameStepNetworkError: string;
|
||||
nameStepCheckError: string;
|
||||
nameStepApiDisabledError: string;
|
||||
nameStepBootstrapMessage: string;
|
||||
save: string;
|
||||
saving: string;
|
||||
@@ -236,30 +235,8 @@ export interface Translations {
|
||||
input: string;
|
||||
output: string;
|
||||
total: string;
|
||||
view: string;
|
||||
unavailable: string;
|
||||
unavailableShort: string;
|
||||
note: string;
|
||||
presets: {
|
||||
off: string;
|
||||
summary: string;
|
||||
perTurn: string;
|
||||
debug: string;
|
||||
};
|
||||
presetDescriptions: {
|
||||
off: string;
|
||||
summary: string;
|
||||
perTurn: string;
|
||||
debug: string;
|
||||
};
|
||||
finalAnswer: string;
|
||||
stepTotal: string;
|
||||
sharedAttribution: string;
|
||||
subagent: (description: string) => string;
|
||||
startTodo: (content: string) => string;
|
||||
completeTodo: (content: string) => string;
|
||||
updateTodo: (content: string) => string;
|
||||
removeTodo: (content: string) => string;
|
||||
};
|
||||
|
||||
// Shortcuts
|
||||
@@ -394,23 +371,6 @@ export interface Translations {
|
||||
notSupported: string;
|
||||
disableNotification: string;
|
||||
};
|
||||
account: {
|
||||
profileTitle: string;
|
||||
email: string;
|
||||
role: string;
|
||||
changePasswordTitle: string;
|
||||
changePasswordDescription: string;
|
||||
currentPassword: string;
|
||||
newPassword: string;
|
||||
confirmNewPassword: string;
|
||||
passwordMismatch: string;
|
||||
passwordTooShort: string;
|
||||
passwordChangedSuccess: string;
|
||||
networkError: string;
|
||||
updating: string;
|
||||
updatePassword: string;
|
||||
signOut: string;
|
||||
};
|
||||
acknowledge: {
|
||||
emptyTitle: string;
|
||||
emptyDescription: string;
|
||||
|
||||
@@ -192,8 +192,6 @@ export const zhCN: Translations = {
|
||||
nameStepAlreadyExistsError: "已存在同名智能体",
|
||||
nameStepNetworkError: "网络请求失败,请检查网络或后端连接",
|
||||
nameStepCheckError: "无法验证名称可用性,请稍后重试",
|
||||
nameStepApiDisabledError:
|
||||
"服务器未开启自定义智能体管理功能,请联系管理员。",
|
||||
nameStepBootstrapMessage:
|
||||
"新智能体的名称是 {name},现在开始为它生成 **SOUL**。",
|
||||
save: "保存智能体",
|
||||
@@ -292,31 +290,9 @@ export const zhCN: Translations = {
|
||||
input: "输入",
|
||||
output: "输出",
|
||||
total: "总计",
|
||||
view: "显示方式",
|
||||
unavailable:
|
||||
"暂无 Token 用量。只有模型成功返回且供应商提供 usage_metadata 时才会显示。",
|
||||
unavailableShort: "未返回用量",
|
||||
note: "基于供应商返回的 usage_metadata 展示。当前总量是 best-effort 的会话参考值,可能与平台账单页不完全一致。",
|
||||
presets: {
|
||||
off: "关闭",
|
||||
summary: "总览",
|
||||
perTurn: "每轮",
|
||||
debug: "调试",
|
||||
},
|
||||
presetDescriptions: {
|
||||
off: "隐藏顶部和会话内的 token 展示。",
|
||||
summary: "只在顶部显示当前对话累计 token。",
|
||||
perTurn: "显示顶部累计,并为每轮 assistant 回复显示一条汇总 token。",
|
||||
debug: "显示顶部累计,并展示按步骤归类的 token 调试信息。",
|
||||
},
|
||||
finalAnswer: "最终回复",
|
||||
stepTotal: "步骤总计",
|
||||
sharedAttribution: "该 token 由此步骤中的多个动作共同消耗",
|
||||
subagent: (description: string) => `子任务:${description}`,
|
||||
startTodo: (content: string) => `开始 To-do:${content}`,
|
||||
completeTodo: (content: string) => `完成 To-do:${content}`,
|
||||
updateTodo: (content: string) => `更新 To-do:${content}`,
|
||||
removeTodo: (content: string) => `移除 To-do:${content}`,
|
||||
},
|
||||
|
||||
// Shortcuts
|
||||
@@ -458,23 +434,6 @@ export const zhCN: Translations = {
|
||||
notSupported: "当前浏览器不支持通知功能。",
|
||||
disableNotification: "关闭通知",
|
||||
},
|
||||
account: {
|
||||
profileTitle: "个人信息",
|
||||
email: "邮箱",
|
||||
role: "角色",
|
||||
changePasswordTitle: "修改密码",
|
||||
changePasswordDescription: "更新你的账号密码。",
|
||||
currentPassword: "当前密码",
|
||||
newPassword: "新密码",
|
||||
confirmNewPassword: "确认新密码",
|
||||
passwordMismatch: "两次输入的新密码不一致",
|
||||
passwordTooShort: "密码长度至少为 8 个字符",
|
||||
passwordChangedSuccess: "密码修改成功",
|
||||
networkError: "网络错误,请重试。",
|
||||
updating: "更新中...",
|
||||
updatePassword: "修改密码",
|
||||
signOut: "退出登录",
|
||||
},
|
||||
acknowledge: {
|
||||
emptyTitle: "致谢",
|
||||
emptyDescription: "相关的致谢信息会展示在这里。",
|
||||
|
||||
@@ -1,440 +0,0 @@
|
||||
import type { Message } from "@langchain/langgraph-sdk";
|
||||
|
||||
import type { Translations } from "@/core/i18n/locales/types";
|
||||
|
||||
import { getUsageMetadata, type TokenUsage } from "./usage";
|
||||
import { hasContent } from "./utils";
|
||||
|
||||
export type TokenUsageInlineMode = "off" | "per_turn" | "step_debug";
|
||||
|
||||
export interface TokenUsagePreferences {
|
||||
headerTotal: boolean;
|
||||
inlineMode: TokenUsageInlineMode;
|
||||
}
|
||||
|
||||
export type TokenUsageViewPreset = "off" | "summary" | "per_turn" | "debug";
|
||||
|
||||
export interface TokenDebugStep {
|
||||
id: string;
|
||||
messageId: string;
|
||||
label: string;
|
||||
secondaryLabels: string[];
|
||||
usage: TokenUsage | null;
|
||||
sharedAttribution: boolean;
|
||||
}
|
||||
|
||||
type TokenUsageAttributionAction =
|
||||
| {
|
||||
kind: "todo_start" | "todo_complete" | "todo_update" | "todo_remove";
|
||||
content?: string;
|
||||
tool_call_id?: string;
|
||||
}
|
||||
| {
|
||||
kind: "subagent";
|
||||
description?: string | null;
|
||||
subagent_type?: string | null;
|
||||
tool_call_id?: string;
|
||||
}
|
||||
| {
|
||||
kind: "search";
|
||||
query?: string | null;
|
||||
tool_name?: string | null;
|
||||
tool_call_id?: string;
|
||||
}
|
||||
| {
|
||||
kind: "present_files" | "clarification";
|
||||
tool_call_id?: string;
|
||||
}
|
||||
| {
|
||||
kind: "tool";
|
||||
tool_name?: string | null;
|
||||
description?: string | null;
|
||||
tool_call_id?: string;
|
||||
};
|
||||
|
||||
interface TokenUsageAttribution {
|
||||
version?: number;
|
||||
kind?:
|
||||
| "thinking"
|
||||
| "final_answer"
|
||||
| "tool_batch"
|
||||
| "todo_update"
|
||||
| "subagent_dispatch";
|
||||
shared_attribution?: boolean;
|
||||
tool_call_ids?: string[];
|
||||
actions?: TokenUsageAttributionAction[];
|
||||
}
|
||||
|
||||
// Precise write_todos labels come from the backend attribution payload.
|
||||
// The frontend fallback intentionally stays generic so we do not duplicate
|
||||
// backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py
|
||||
//::_build_todo_actions and risk the two diffing algorithms drifting apart.
|
||||
|
||||
export function getTokenUsageViewPreset(
|
||||
preferences: TokenUsagePreferences,
|
||||
): TokenUsageViewPreset {
|
||||
if (!preferences.headerTotal && preferences.inlineMode === "off") {
|
||||
return "off";
|
||||
}
|
||||
if (preferences.headerTotal && preferences.inlineMode === "off") {
|
||||
return "summary";
|
||||
}
|
||||
if (preferences.inlineMode === "step_debug") {
|
||||
return "debug";
|
||||
}
|
||||
return "per_turn";
|
||||
}
|
||||
|
||||
export function tokenUsagePreferencesFromPreset(
|
||||
preset: TokenUsageViewPreset,
|
||||
): TokenUsagePreferences {
|
||||
switch (preset) {
|
||||
case "off":
|
||||
return { headerTotal: false, inlineMode: "off" };
|
||||
case "summary":
|
||||
return { headerTotal: true, inlineMode: "off" };
|
||||
case "debug":
|
||||
return { headerTotal: true, inlineMode: "step_debug" };
|
||||
case "per_turn":
|
||||
default:
|
||||
return { headerTotal: true, inlineMode: "per_turn" };
|
||||
}
|
||||
}
|
||||
|
||||
export function buildTokenDebugSteps(
|
||||
messages: Message[],
|
||||
t: Translations,
|
||||
): TokenDebugStep[] {
|
||||
const steps: TokenDebugStep[] = [];
|
||||
|
||||
for (const [index, message] of messages.entries()) {
|
||||
if (message.type !== "ai") {
|
||||
continue;
|
||||
}
|
||||
|
||||
const usage = getUsageMetadata(message);
|
||||
const attribution = getTokenUsageAttribution(message);
|
||||
const actionLabels: string[] = [];
|
||||
|
||||
if (attribution) {
|
||||
actionLabels.push(...buildActionLabelsFromAttribution(attribution, t));
|
||||
|
||||
if (actionLabels.length === 0) {
|
||||
if (attribution.kind === "final_answer") {
|
||||
actionLabels.push(t.tokenUsage.finalAnswer);
|
||||
} else if (attribution.kind === "thinking") {
|
||||
actionLabels.push(t.common.thinking);
|
||||
}
|
||||
}
|
||||
|
||||
if (actionLabels.length > 0) {
|
||||
const sharedAttribution =
|
||||
attribution.shared_attribution ?? actionLabels.length > 1;
|
||||
steps.push({
|
||||
id: message.id ?? `token-step-${index}`,
|
||||
messageId: message.id ?? `token-step-${index}`,
|
||||
label:
|
||||
sharedAttribution && actionLabels.length > 1
|
||||
? t.tokenUsage.stepTotal
|
||||
: actionLabels[0]!,
|
||||
secondaryLabels:
|
||||
sharedAttribution && actionLabels.length > 1 ? actionLabels : [],
|
||||
usage,
|
||||
sharedAttribution,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
for (const toolCall of message.tool_calls ?? []) {
|
||||
const toolArgs = (toolCall.args ?? {}) as Record<string, unknown>;
|
||||
|
||||
if (toolCall.name === "write_todos") {
|
||||
actionLabels.push(t.toolCalls.writeTodos);
|
||||
continue;
|
||||
}
|
||||
|
||||
actionLabels.push(
|
||||
describeToolCall(
|
||||
{
|
||||
name: toolCall.name,
|
||||
args: toolArgs,
|
||||
},
|
||||
t,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
if (actionLabels.length === 0) {
|
||||
if (hasContent(message)) {
|
||||
actionLabels.push(t.tokenUsage.finalAnswer);
|
||||
} else {
|
||||
actionLabels.push(t.common.thinking);
|
||||
}
|
||||
}
|
||||
|
||||
steps.push({
|
||||
id: message.id ?? `token-step-${index}`,
|
||||
messageId: message.id ?? `token-step-${index}`,
|
||||
label:
|
||||
actionLabels.length === 1 ? actionLabels[0]! : t.tokenUsage.stepTotal,
|
||||
secondaryLabels: actionLabels.length > 1 ? actionLabels : [],
|
||||
usage,
|
||||
sharedAttribution: actionLabels.length > 1,
|
||||
});
|
||||
}
|
||||
|
||||
return steps;
|
||||
}
|
||||
|
||||
function getTokenUsageAttribution(
|
||||
message: Message,
|
||||
): TokenUsageAttribution | null {
|
||||
if (message.type !== "ai") {
|
||||
return null;
|
||||
}
|
||||
|
||||
const additionalKwargs = message.additional_kwargs;
|
||||
if (!additionalKwargs || typeof additionalKwargs !== "object") {
|
||||
return null;
|
||||
}
|
||||
|
||||
const attribution = (additionalKwargs as Record<string, unknown>)
|
||||
.token_usage_attribution;
|
||||
const normalized = normalizeTokenUsageAttribution(attribution);
|
||||
if (!normalized) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return normalized;
|
||||
}
|
||||
|
||||
function buildActionLabelsFromAttribution(
|
||||
attribution: TokenUsageAttribution,
|
||||
t: Translations,
|
||||
): string[] {
|
||||
return (attribution.actions ?? [])
|
||||
.map((action) => describeAttributionAction(action, t))
|
||||
.filter((label): label is string => !!label);
|
||||
}
|
||||
|
||||
function describeAttributionAction(
|
||||
action: TokenUsageAttributionAction,
|
||||
t: Translations,
|
||||
): string | null {
|
||||
switch (action.kind) {
|
||||
case "todo_start":
|
||||
return action.content
|
||||
? t.tokenUsage.startTodo(action.content)
|
||||
: t.toolCalls.writeTodos;
|
||||
case "todo_complete":
|
||||
return action.content
|
||||
? t.tokenUsage.completeTodo(action.content)
|
||||
: t.toolCalls.writeTodos;
|
||||
case "todo_update":
|
||||
return action.content
|
||||
? t.tokenUsage.updateTodo(action.content)
|
||||
: t.toolCalls.writeTodos;
|
||||
case "todo_remove":
|
||||
return action.content
|
||||
? t.tokenUsage.removeTodo(action.content)
|
||||
: t.toolCalls.writeTodos;
|
||||
case "subagent":
|
||||
return t.tokenUsage.subagent(action.description ?? t.subtasks.subtask);
|
||||
case "search":
|
||||
if (action.query) {
|
||||
return t.toolCalls.searchFor(action.query);
|
||||
}
|
||||
return t.toolCalls.useTool(action.tool_name ?? "search");
|
||||
case "present_files":
|
||||
return t.toolCalls.presentFiles;
|
||||
case "clarification":
|
||||
return t.toolCalls.needYourHelp;
|
||||
case "tool":
|
||||
return describeToolCall(
|
||||
{
|
||||
name: action.tool_name ?? "tool",
|
||||
args: action.description ? { description: action.description } : {},
|
||||
},
|
||||
t,
|
||||
);
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function describeToolCall(
|
||||
toolCall: {
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
},
|
||||
t: Translations,
|
||||
): string {
|
||||
if (toolCall.name === "task") {
|
||||
const description =
|
||||
typeof toolCall.args.description === "string"
|
||||
? toolCall.args.description
|
||||
: t.subtasks.subtask;
|
||||
return t.tokenUsage.subagent(description);
|
||||
}
|
||||
|
||||
if (
|
||||
(toolCall.name === "web_search" || toolCall.name === "image_search") &&
|
||||
typeof toolCall.args.query === "string"
|
||||
) {
|
||||
return t.toolCalls.searchFor(toolCall.args.query);
|
||||
}
|
||||
|
||||
if (toolCall.name === "web_fetch") {
|
||||
return t.toolCalls.viewWebPage;
|
||||
}
|
||||
|
||||
if (toolCall.name === "present_files") {
|
||||
return t.toolCalls.presentFiles;
|
||||
}
|
||||
|
||||
if (toolCall.name === "ask_clarification") {
|
||||
return t.toolCalls.needYourHelp;
|
||||
}
|
||||
|
||||
if (typeof toolCall.args.description === "string") {
|
||||
return toolCall.args.description;
|
||||
}
|
||||
|
||||
return t.toolCalls.useTool(toolCall.name);
|
||||
}
|
||||
|
||||
function normalizeTokenUsageAttribution(
|
||||
value: unknown,
|
||||
): TokenUsageAttribution | null {
|
||||
const record = asRecord(value);
|
||||
if (!record) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const rawActions = record.actions;
|
||||
if (rawActions !== undefined && !Array.isArray(rawActions)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
// Versioning is additive for now: the frontend should ignore unknown
|
||||
// fields and fall back when required fields become incompatible.
|
||||
version: typeof record.version === "number" ? record.version : undefined,
|
||||
kind: isTokenUsageAttributionKind(record.kind) ? record.kind : undefined,
|
||||
shared_attribution:
|
||||
typeof record.shared_attribution === "boolean"
|
||||
? record.shared_attribution
|
||||
: undefined,
|
||||
tool_call_ids: Array.isArray(record.tool_call_ids)
|
||||
? record.tool_call_ids.filter(
|
||||
(toolCallId): toolCallId is string =>
|
||||
typeof toolCallId === "string" && toolCallId.trim().length > 0,
|
||||
)
|
||||
: undefined,
|
||||
actions: Array.isArray(rawActions)
|
||||
? rawActions
|
||||
.map((action) => normalizeTokenUsageAttributionAction(action))
|
||||
.filter(
|
||||
(action): action is TokenUsageAttributionAction => action !== null,
|
||||
)
|
||||
: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
function normalizeTokenUsageAttributionAction(
|
||||
value: unknown,
|
||||
): TokenUsageAttributionAction | null {
|
||||
const record = asRecord(value);
|
||||
if (!record) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const kind = record.kind;
|
||||
if (
|
||||
kind !== "todo_start" &&
|
||||
kind !== "todo_complete" &&
|
||||
kind !== "todo_update" &&
|
||||
kind !== "todo_remove" &&
|
||||
kind !== "subagent" &&
|
||||
kind !== "search" &&
|
||||
kind !== "present_files" &&
|
||||
kind !== "clarification" &&
|
||||
kind !== "tool"
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const content = readString(record.content);
|
||||
const toolCallId = readString(record.tool_call_id);
|
||||
|
||||
switch (kind) {
|
||||
case "todo_start":
|
||||
case "todo_complete":
|
||||
case "todo_update":
|
||||
case "todo_remove":
|
||||
return {
|
||||
kind,
|
||||
content,
|
||||
tool_call_id: toolCallId,
|
||||
};
|
||||
case "subagent":
|
||||
return {
|
||||
kind,
|
||||
description: readString(record.description),
|
||||
subagent_type: readString(record.subagent_type),
|
||||
tool_call_id: toolCallId,
|
||||
};
|
||||
case "search":
|
||||
return {
|
||||
kind,
|
||||
query: readString(record.query),
|
||||
tool_name: readString(record.tool_name),
|
||||
tool_call_id: toolCallId,
|
||||
};
|
||||
case "present_files":
|
||||
case "clarification":
|
||||
return {
|
||||
kind,
|
||||
tool_call_id: toolCallId,
|
||||
};
|
||||
case "tool":
|
||||
return {
|
||||
kind,
|
||||
tool_name: readString(record.tool_name),
|
||||
description: readString(record.description),
|
||||
tool_call_id: toolCallId,
|
||||
};
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function asRecord(value: unknown): Record<string, unknown> | null {
|
||||
if (!value || typeof value !== "object" || Array.isArray(value)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return value as Record<string, unknown>;
|
||||
}
|
||||
|
||||
function readString(value: unknown): string | undefined {
|
||||
if (typeof value !== "string") {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const normalized = value.trim();
|
||||
return normalized.length > 0 ? normalized : undefined;
|
||||
}
|
||||
|
||||
function isTokenUsageAttributionKind(
|
||||
value: unknown,
|
||||
): value is NonNullable<TokenUsageAttribution["kind"]> {
|
||||
return (
|
||||
value === "thinking" ||
|
||||
value === "final_answer" ||
|
||||
value === "tool_batch" ||
|
||||
value === "todo_update" ||
|
||||
value === "subagent_dispatch"
|
||||
);
|
||||
}
|
||||
@@ -18,7 +18,7 @@ interface AssistantClarificationGroup extends GenericMessageGroup<"assistant:cla
|
||||
|
||||
interface AssistantSubagentGroup extends GenericMessageGroup<"assistant:subagent"> {}
|
||||
|
||||
export type MessageGroup =
|
||||
type MessageGroup =
|
||||
| HumanMessageGroup
|
||||
| AssistantProcessingGroup
|
||||
| AssistantMessageGroup
|
||||
@@ -26,7 +26,10 @@ export type MessageGroup =
|
||||
| AssistantClarificationGroup
|
||||
| AssistantSubagentGroup;
|
||||
|
||||
export function getMessageGroups(messages: Message[]): MessageGroup[] {
|
||||
export function groupMessages<T>(
|
||||
messages: Message[],
|
||||
mapper: (group: MessageGroup) => T,
|
||||
): T[] {
|
||||
if (messages.length === 0) {
|
||||
return [];
|
||||
}
|
||||
@@ -121,52 +124,11 @@ export function getMessageGroups(messages: Message[]): MessageGroup[] {
|
||||
}
|
||||
}
|
||||
|
||||
return groups;
|
||||
}
|
||||
|
||||
export function groupMessages<T>(
|
||||
messages: Message[],
|
||||
mapper: (group: MessageGroup) => T,
|
||||
): T[] {
|
||||
return getMessageGroups(messages)
|
||||
return groups
|
||||
.map(mapper)
|
||||
.filter((result) => result !== undefined && result !== null) as T[];
|
||||
}
|
||||
|
||||
export function getAssistantTurnUsageMessages(groups: MessageGroup[]) {
|
||||
const usageMessagesByGroupIndex: Array<Message[] | null> = Array.from(
|
||||
{ length: groups.length },
|
||||
() => null,
|
||||
);
|
||||
|
||||
let turnStartIndex: number | null = null;
|
||||
|
||||
for (const [index, group] of groups.entries()) {
|
||||
if (group.type === "human") {
|
||||
turnStartIndex = null;
|
||||
continue;
|
||||
}
|
||||
|
||||
turnStartIndex ??= index;
|
||||
|
||||
const nextGroup = groups[index + 1];
|
||||
const isTurnEnd = !nextGroup || nextGroup.type === "human";
|
||||
|
||||
if (!isTurnEnd) {
|
||||
continue;
|
||||
}
|
||||
|
||||
usageMessagesByGroupIndex[index] = groups
|
||||
.slice(turnStartIndex, index + 1)
|
||||
.flatMap((currentGroup) => currentGroup.messages)
|
||||
.filter((message) => message.type === "ai");
|
||||
|
||||
turnStartIndex = null;
|
||||
}
|
||||
|
||||
return usageMessagesByGroupIndex;
|
||||
}
|
||||
|
||||
export function extractTextFromMessage(message: Message) {
|
||||
if (typeof message.content === "string") {
|
||||
return (
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
import type { TokenUsageInlineMode } from "../messages/usage-model";
|
||||
import type { AgentThreadContext } from "../threads";
|
||||
|
||||
export const DEFAULT_LOCAL_SETTINGS: LocalSettings = {
|
||||
notification: {
|
||||
enabled: true,
|
||||
},
|
||||
tokenUsage: {
|
||||
headerTotal: true,
|
||||
inlineMode: "per_turn",
|
||||
},
|
||||
context: {
|
||||
model_name: undefined,
|
||||
mode: undefined,
|
||||
@@ -27,10 +22,6 @@ export interface LocalSettings {
|
||||
notification: {
|
||||
enabled: boolean;
|
||||
};
|
||||
tokenUsage: {
|
||||
headerTotal: boolean;
|
||||
inlineMode: TokenUsageInlineMode;
|
||||
};
|
||||
context: Omit<
|
||||
AgentThreadContext,
|
||||
| "thread_id"
|
||||
@@ -53,10 +44,6 @@ function mergeLocalSettings(settings?: Partial<LocalSettings>): LocalSettings {
|
||||
...DEFAULT_LOCAL_SETTINGS.context,
|
||||
...settings?.context,
|
||||
},
|
||||
tokenUsage: {
|
||||
...DEFAULT_LOCAL_SETTINGS.tokenUsage,
|
||||
...settings?.tokenUsage,
|
||||
},
|
||||
notification: {
|
||||
...DEFAULT_LOCAL_SETTINGS.notification,
|
||||
...settings?.notification,
|
||||
|
||||
@@ -1,396 +0,0 @@
|
||||
import type { Message } from "@langchain/langgraph-sdk";
|
||||
import { expect, test } from "vitest";
|
||||
|
||||
import { enUS } from "@/core/i18n";
|
||||
import {
|
||||
buildTokenDebugSteps,
|
||||
getTokenUsageViewPreset,
|
||||
tokenUsagePreferencesFromPreset,
|
||||
} from "@/core/messages/usage-model";
|
||||
|
||||
test("maps token usage presets to persisted preferences", () => {
|
||||
expect(tokenUsagePreferencesFromPreset("off")).toEqual({
|
||||
headerTotal: false,
|
||||
inlineMode: "off",
|
||||
});
|
||||
expect(tokenUsagePreferencesFromPreset("summary")).toEqual({
|
||||
headerTotal: true,
|
||||
inlineMode: "off",
|
||||
});
|
||||
expect(tokenUsagePreferencesFromPreset("per_turn")).toEqual({
|
||||
headerTotal: true,
|
||||
inlineMode: "per_turn",
|
||||
});
|
||||
expect(tokenUsagePreferencesFromPreset("debug")).toEqual({
|
||||
headerTotal: true,
|
||||
inlineMode: "step_debug",
|
||||
});
|
||||
});
|
||||
|
||||
test("derives the active preset from persisted preferences", () => {
|
||||
expect(
|
||||
getTokenUsageViewPreset({
|
||||
headerTotal: false,
|
||||
inlineMode: "off",
|
||||
}),
|
||||
).toBe("off");
|
||||
|
||||
expect(
|
||||
getTokenUsageViewPreset({
|
||||
headerTotal: true,
|
||||
inlineMode: "off",
|
||||
}),
|
||||
).toBe("summary");
|
||||
|
||||
expect(
|
||||
getTokenUsageViewPreset({
|
||||
headerTotal: true,
|
||||
inlineMode: "per_turn",
|
||||
}),
|
||||
).toBe("per_turn");
|
||||
|
||||
expect(
|
||||
getTokenUsageViewPreset({
|
||||
headerTotal: true,
|
||||
inlineMode: "step_debug",
|
||||
}),
|
||||
).toBe("debug");
|
||||
});
|
||||
|
||||
test("uses generic todo labels when backend attribution is absent", () => {
|
||||
const messages = [
|
||||
{
|
||||
id: "ai-1",
|
||||
type: "ai",
|
||||
content: "",
|
||||
tool_calls: [
|
||||
{
|
||||
id: "write_todos:1",
|
||||
name: "write_todos",
|
||||
args: {
|
||||
todos: [{ content: "Draft the plan", status: "in_progress" }],
|
||||
},
|
||||
},
|
||||
],
|
||||
usage_metadata: {
|
||||
input_tokens: 100,
|
||||
output_tokens: 20,
|
||||
total_tokens: 120,
|
||||
},
|
||||
},
|
||||
{
|
||||
id: "tool-1",
|
||||
type: "tool",
|
||||
name: "write_todos",
|
||||
tool_call_id: "write_todos:1",
|
||||
content: "ok",
|
||||
},
|
||||
{
|
||||
id: "ai-2",
|
||||
type: "ai",
|
||||
content: "",
|
||||
tool_calls: [
|
||||
{
|
||||
id: "write_todos:2",
|
||||
name: "write_todos",
|
||||
args: {
|
||||
todos: [{ content: "Draft the plan", status: "completed" }],
|
||||
},
|
||||
},
|
||||
],
|
||||
usage_metadata: { input_tokens: 50, output_tokens: 10, total_tokens: 60 },
|
||||
},
|
||||
{
|
||||
id: "ai-3",
|
||||
type: "ai",
|
||||
content: "Here is the result",
|
||||
usage_metadata: { input_tokens: 40, output_tokens: 15, total_tokens: 55 },
|
||||
},
|
||||
] as Message[];
|
||||
|
||||
expect(buildTokenDebugSteps(messages, enUS)).toEqual([
|
||||
expect.objectContaining({
|
||||
messageId: "ai-1",
|
||||
label: "Update to-do list",
|
||||
sharedAttribution: false,
|
||||
}),
|
||||
expect.objectContaining({
|
||||
messageId: "ai-2",
|
||||
label: "Update to-do list",
|
||||
sharedAttribution: false,
|
||||
}),
|
||||
expect.objectContaining({
|
||||
messageId: "ai-3",
|
||||
label: "Final answer",
|
||||
sharedAttribution: false,
|
||||
}),
|
||||
]);
|
||||
});
|
||||
|
||||
test("marks multi-action AI steps as shared attribution", () => {
|
||||
const messages = [
|
||||
{
|
||||
id: "ai-1",
|
||||
type: "ai",
|
||||
content: "",
|
||||
tool_calls: [
|
||||
{
|
||||
id: "web_search:1",
|
||||
name: "web_search",
|
||||
args: { query: "LangGraph stream mode" },
|
||||
},
|
||||
{
|
||||
id: "write_todos:1",
|
||||
name: "write_todos",
|
||||
args: {
|
||||
todos: [
|
||||
{
|
||||
content: "Inspect stream mode handling",
|
||||
status: "in_progress",
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
],
|
||||
usage_metadata: {
|
||||
input_tokens: 120,
|
||||
output_tokens: 30,
|
||||
total_tokens: 150,
|
||||
},
|
||||
},
|
||||
] as Message[];
|
||||
|
||||
expect(buildTokenDebugSteps(messages, enUS)).toEqual([
|
||||
expect.objectContaining({
|
||||
messageId: "ai-1",
|
||||
label: "Step total",
|
||||
sharedAttribution: true,
|
||||
secondaryLabels: [
|
||||
'Search for "LangGraph stream mode"',
|
||||
"Update to-do list",
|
||||
],
|
||||
}),
|
||||
]);
|
||||
});
|
||||
|
||||
test("prefers backend attribution metadata when available", () => {
|
||||
const messages = [
|
||||
{
|
||||
id: "ai-1",
|
||||
type: "ai",
|
||||
content: "",
|
||||
tool_calls: [
|
||||
{
|
||||
id: "write_todos:1",
|
||||
name: "write_todos",
|
||||
args: {
|
||||
todos: [
|
||||
{
|
||||
content: "Fallback label should not win",
|
||||
status: "in_progress",
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
],
|
||||
additional_kwargs: {
|
||||
token_usage_attribution: {
|
||||
version: 1,
|
||||
kind: "todo_update",
|
||||
shared_attribution: false,
|
||||
actions: [{ kind: "todo_start", content: "Use backend attribution" }],
|
||||
},
|
||||
},
|
||||
usage_metadata: { input_tokens: 25, output_tokens: 5, total_tokens: 30 },
|
||||
},
|
||||
] as Message[];
|
||||
|
||||
expect(buildTokenDebugSteps(messages, enUS)).toEqual([
|
||||
expect.objectContaining({
|
||||
messageId: "ai-1",
|
||||
label: "Start To-do: Use backend attribution",
|
||||
sharedAttribution: false,
|
||||
}),
|
||||
]);
|
||||
});
|
||||
|
||||
test("falls back safely when attribution payload is malformed", () => {
|
||||
const messages = [
|
||||
{
|
||||
id: "ai-1",
|
||||
type: "ai",
|
||||
content: "",
|
||||
tool_calls: [
|
||||
{
|
||||
id: "web_search:1",
|
||||
name: "web_search",
|
||||
args: { query: "LangGraph stream mode" },
|
||||
},
|
||||
],
|
||||
additional_kwargs: {
|
||||
token_usage_attribution: {
|
||||
version: 1,
|
||||
kind: "tool_batch",
|
||||
actions: { broken: true },
|
||||
},
|
||||
},
|
||||
usage_metadata: { input_tokens: 10, output_tokens: 5, total_tokens: 15 },
|
||||
},
|
||||
] as Message[];
|
||||
|
||||
expect(buildTokenDebugSteps(messages, enUS)).toEqual([
|
||||
expect.objectContaining({
|
||||
messageId: "ai-1",
|
||||
label: 'Search for "LangGraph stream mode"',
|
||||
sharedAttribution: false,
|
||||
}),
|
||||
]);
|
||||
});
|
||||
|
||||
test("ignores attribution actions that are not objects", () => {
|
||||
const messages = [
|
||||
{
|
||||
id: "ai-1",
|
||||
type: "ai",
|
||||
content: "",
|
||||
tool_calls: [],
|
||||
additional_kwargs: {
|
||||
token_usage_attribution: {
|
||||
version: 1,
|
||||
kind: "tool_batch",
|
||||
shared_attribution: true,
|
||||
actions: [
|
||||
null,
|
||||
"bad-action",
|
||||
{ kind: "search", query: "valid search", ignored: "extra-field" },
|
||||
],
|
||||
},
|
||||
},
|
||||
usage_metadata: { input_tokens: 10, output_tokens: 5, total_tokens: 15 },
|
||||
},
|
||||
] as Message[];
|
||||
|
||||
expect(buildTokenDebugSteps(messages, enUS)).toEqual([
|
||||
expect.objectContaining({
|
||||
messageId: "ai-1",
|
||||
label: 'Search for "valid search"',
|
||||
}),
|
||||
]);
|
||||
});
|
||||
|
||||
test("ignores malformed attribution fields and falls back to message content", () => {
|
||||
const messages = [
|
||||
{
|
||||
id: "ai-1",
|
||||
type: "ai",
|
||||
content: "Real final answer",
|
||||
tool_calls: [],
|
||||
additional_kwargs: {
|
||||
token_usage_attribution: {
|
||||
version: 1,
|
||||
kind: null,
|
||||
shared_attribution: null,
|
||||
tool_call_ids: [null, "tool-1", 123],
|
||||
actions: [{ query: "missing kind" }],
|
||||
},
|
||||
},
|
||||
usage_metadata: { input_tokens: 9, output_tokens: 3, total_tokens: 12 },
|
||||
},
|
||||
] as Message[];
|
||||
|
||||
expect(buildTokenDebugSteps(messages, enUS)).toEqual([
|
||||
expect.objectContaining({
|
||||
messageId: "ai-1",
|
||||
label: "Final answer",
|
||||
sharedAttribution: false,
|
||||
}),
|
||||
]);
|
||||
});
|
||||
|
||||
test("ignores unknown top-level attribution fields", () => {
|
||||
const messages = [
|
||||
{
|
||||
id: "ai-1",
|
||||
type: "ai",
|
||||
content: "",
|
||||
tool_calls: [],
|
||||
additional_kwargs: {
|
||||
token_usage_attribution: {
|
||||
version: 1,
|
||||
kind: "tool_batch",
|
||||
shared_attribution: false,
|
||||
unknown_field: "ignored",
|
||||
actions: [{ kind: "subagent", description: "Inspect the fix" }],
|
||||
},
|
||||
},
|
||||
usage_metadata: { input_tokens: 12, output_tokens: 4, total_tokens: 16 },
|
||||
},
|
||||
] as Message[];
|
||||
|
||||
expect(buildTokenDebugSteps(messages, enUS)).toEqual([
|
||||
expect.objectContaining({
|
||||
messageId: "ai-1",
|
||||
label: "Subagent: Inspect the fix",
|
||||
sharedAttribution: false,
|
||||
}),
|
||||
]);
|
||||
});
|
||||
|
||||
test("falls back to generic todo labels when backend attribution has no actions", () => {
|
||||
const messages = [
|
||||
{
|
||||
id: "ai-1",
|
||||
type: "ai",
|
||||
content: "",
|
||||
tool_calls: [
|
||||
{
|
||||
id: "write_todos:1",
|
||||
name: "write_todos",
|
||||
args: {
|
||||
todos: [{ content: "Clean up stale tasks", status: "in_progress" }],
|
||||
},
|
||||
},
|
||||
],
|
||||
usage_metadata: {
|
||||
input_tokens: 100,
|
||||
output_tokens: 20,
|
||||
total_tokens: 120,
|
||||
},
|
||||
},
|
||||
{
|
||||
id: "ai-2",
|
||||
type: "ai",
|
||||
content: "",
|
||||
tool_calls: [
|
||||
{
|
||||
id: "write_todos:2",
|
||||
name: "write_todos",
|
||||
args: {
|
||||
todos: [],
|
||||
},
|
||||
},
|
||||
],
|
||||
additional_kwargs: {
|
||||
token_usage_attribution: {
|
||||
version: 1,
|
||||
kind: "todo_update",
|
||||
shared_attribution: false,
|
||||
actions: [],
|
||||
},
|
||||
},
|
||||
usage_metadata: { input_tokens: 30, output_tokens: 8, total_tokens: 38 },
|
||||
},
|
||||
] as Message[];
|
||||
|
||||
expect(buildTokenDebugSteps(messages, enUS)).toEqual([
|
||||
expect.objectContaining({
|
||||
messageId: "ai-1",
|
||||
label: "Update to-do list",
|
||||
}),
|
||||
expect.objectContaining({
|
||||
messageId: "ai-2",
|
||||
label: "Update to-do list",
|
||||
sharedAttribution: false,
|
||||
}),
|
||||
]);
|
||||
});
|
||||
@@ -1,65 +0,0 @@
|
||||
import type { Message } from "@langchain/langgraph-sdk";
|
||||
import { expect, test } from "vitest";
|
||||
|
||||
import {
|
||||
getAssistantTurnUsageMessages,
|
||||
getMessageGroups,
|
||||
} from "@/core/messages/utils";
|
||||
|
||||
test("aggregates token usage messages once per assistant turn", () => {
|
||||
const messages = [
|
||||
{
|
||||
id: "human-1",
|
||||
type: "human",
|
||||
content: "Plan a trip",
|
||||
},
|
||||
{
|
||||
id: "ai-1",
|
||||
type: "ai",
|
||||
content: "",
|
||||
tool_calls: [{ id: "tool-1", name: "web_search", args: {} }],
|
||||
usage_metadata: { input_tokens: 10, output_tokens: 5, total_tokens: 15 },
|
||||
},
|
||||
{
|
||||
id: "tool-1-result",
|
||||
type: "tool",
|
||||
name: "web_search",
|
||||
tool_call_id: "tool-1",
|
||||
content: "[]",
|
||||
},
|
||||
{
|
||||
id: "ai-2",
|
||||
type: "ai",
|
||||
content: "Here is the itinerary",
|
||||
usage_metadata: { input_tokens: 2, output_tokens: 8, total_tokens: 10 },
|
||||
},
|
||||
{
|
||||
id: "human-2",
|
||||
type: "human",
|
||||
content: "Make it shorter",
|
||||
},
|
||||
{
|
||||
id: "ai-3",
|
||||
type: "ai",
|
||||
content: "Short version",
|
||||
usage_metadata: { input_tokens: 1, output_tokens: 1, total_tokens: 2 },
|
||||
},
|
||||
] as Message[];
|
||||
|
||||
const groups = getMessageGroups(messages);
|
||||
const usageMessagesByGroupIndex = getAssistantTurnUsageMessages(groups);
|
||||
|
||||
expect(groups.map((group) => group.type)).toEqual([
|
||||
"human",
|
||||
"assistant:processing",
|
||||
"assistant",
|
||||
"human",
|
||||
"assistant",
|
||||
]);
|
||||
|
||||
expect(
|
||||
usageMessagesByGroupIndex.map(
|
||||
(groupMessages) => groupMessages?.map((message) => message.id) ?? null,
|
||||
),
|
||||
).toEqual([null, null, ["ai-1", "ai-2"], null, ["ai-3"]]);
|
||||
});
|
||||
Reference in New Issue
Block a user