From 93e3281cbf12baad0501117c6e0436d64bb74bb6 Mon Sep 17 00:00:00 2001 From: AochenShen99 Date: Tue, 9 Jun 2026 15:29:40 +0800 Subject: [PATCH 01/11] fix(dev): create backend/sandbox before uvicorn reload-exclude (#3459) (#3460) * fix(dev): create backend/sandbox before uvicorn reload-exclude (#3459) #3426 switched the dev gateway's --reload-exclude patterns to absolute paths. uvicorn only excludes an absolute path directly when it already exists as a directory; otherwise it globs the pattern, and Python 3.12's pathlib raises NotImplementedError("Non-relative patterns are unsupported") for an absolute glob pattern. serve.sh mkdir'd the .deer-flow excludes but not backend/sandbox, so `make dev` crashed on startup on a fresh checkout under Python 3.12 (#3454). docker/dev-entrypoint.sh had the same latent gap. Create backend/sandbox in both launchers so every absolute exclude stays on uvicorn's is_dir() short-circuit. Add a regression test that pins the uvicorn mechanism (crash on missing dir, safe once created) and enforces that every absolute --reload-exclude is mkdir'd before launch. Closes #3459 * test(dev): harden reload-exclude invariant parser against false pass/negatives The launcher invariant test parsed shell with a "mkdir -p" line filter and a substring membership check. Two latent gaps (sub-threshold for this fix, but this code guards a user-facing startup path, so close them): - A `\`-continued multi-line `mkdir` would drop arguments on continuation lines, silently weakening coverage. - Substring membership could false-pass when an exclude is a path-prefix of a different created dir (e.g. `/app/backend/sandbox` "found" inside `/app/backend/sandbox-other`). Fold line-continuations, drop comments, and shlex-tokenize each `mkdir` argument list into an exact set (quotes stripped, `$VAR` literal); assert exact set membership. Same shlex handling for `--reload-exclude` values. Verified the parser still flags the pre-fix missing `backend/sandbox` (RED preserved) and no longer false-passes on a path-prefix. * fix(dev): gitignore backend/sandbox runtime dir + pin mkdir-before-launch Address two review findings on the #3459 fix: - backend/sandbox was described as "gitignored runtime state" but no ignore rule actually matched it. Add an anchored `/sandbox/` to backend/.gitignore (anchored so it does NOT shadow the source package backend/packages/harness/deerflow/sandbox/) so sandbox artifacts created at runtime can't pollute the working tree or be committed by accident. New test asserts content under backend/sandbox is ignored, making the claim verifiable. - The launcher invariant test only proved the sandbox mkdir exists somewhere, not that it runs before uvicorn starts. Add an order test (sandbox mkdir line must precede the `uv run uvicorn` launch) so a future edit can't move the mkdir below the launch and silently reintroduce the crash. * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * test(dev): fix reload-exclude parser to handle serve.sh's quoted flag bundle The previous autofix tokenized each whole line with shlex, but serve.sh packs every flag into a single double-quoted `GATEWAY_EXTRA_FLAGS="..."` assignment. shlex collapses that into one token, so no `--reload-exclude` flag is found and `test_launcher_precreates_every_absolute_reload_exclude[scripts/serve.sh]` failed CI with "expected at least one absolute reload-exclude". Parse `--reload-exclude` with a regex that matches a balanced single/double quoted group or a bare token, so the assignment's surrounding `"` is never swallowed into the value. This recovers all three serve.sh excludes (the prior regex also silently dropped the last `$BACKEND_RUNTIME_HOME` because the adjacent closing quote broke shlex) while still covering dev-entrypoint.sh and the space-separated `--reload-exclude ` form. --------- Co-authored-by: Willem Jiang Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- backend/.gitignore | 5 + backend/tests/test_dev_entrypoint.py | 3 +- backend/tests/test_gateway_runtime_cleanup.py | 4 +- backend/tests/test_uvicorn_reload_exclude.py | 185 ++++++++++++++++++ docker/dev-entrypoint.sh | 10 +- scripts/serve.sh | 7 +- 6 files changed, 207 insertions(+), 7 deletions(-) create mode 100644 backend/tests/test_uvicorn_reload_exclude.py diff --git a/backend/.gitignore b/backend/.gitignore index 6e56d9e81..3967bcb3a 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -24,5 +24,10 @@ config.yaml # Langgraph .langgraph_api +# Sandbox runtime working dir — pre-created and excluded from uvicorn reload +# (scripts/serve.sh, docker/dev-entrypoint.sh). Anchored so it does not match +# the source package backend/packages/harness/deerflow/sandbox/. +/sandbox/ + # Claude Code settings .claude/settings.local.json diff --git a/backend/tests/test_dev_entrypoint.py b/backend/tests/test_dev_entrypoint.py index 5ca8f3443..12bbe1898 100644 --- a/backend/tests/test_dev_entrypoint.py +++ b/backend/tests/test_dev_entrypoint.py @@ -44,7 +44,8 @@ def test_entrypoint_excludes_runtime_state_from_uvicorn_reload(): content = ENTRYPOINT.read_text(encoding="utf-8") assert ': "${DEER_FLOW_HOME:=/app/backend/.deer-flow}"' in content - assert 'mkdir -p "$DEER_FLOW_HOME" /app/backend/.deer-flow' in content + # sandbox must be created too, not just .deer-flow (#3459 / #3454). + assert 'mkdir -p "$DEER_FLOW_HOME" /app/backend/.deer-flow /app/backend/sandbox' in content assert "--reload-include='*.yaml .env'" not in content assert "--reload-include='*.yaml'" in content assert "--reload-include='.env'" in content diff --git a/backend/tests/test_gateway_runtime_cleanup.py b/backend/tests/test_gateway_runtime_cleanup.py index 4642559e5..145ef0eab 100644 --- a/backend/tests/test_gateway_runtime_cleanup.py +++ b/backend/tests/test_gateway_runtime_cleanup.py @@ -49,7 +49,9 @@ def test_local_dev_gateway_reload_excludes_runtime_state_with_absolute_dirs(): assert 'export DEER_FLOW_PROJECT_ROOT="$REPO_ROOT"' in serve_sh assert 'BACKEND_RUNTIME_HOME="$REPO_ROOT/backend/.deer-flow"' in serve_sh assert 'export DEER_FLOW_HOME="$BACKEND_RUNTIME_HOME"' in serve_sh - assert 'mkdir -p "$DEER_FLOW_HOME" "$BACKEND_RUNTIME_HOME"' in serve_sh + # Every absolute reload-exclude must be pre-created, including backend/sandbox + # (#3459 / #3454) — see test_uvicorn_reload_exclude.py for the mechanism. + assert 'mkdir -p "$DEER_FLOW_HOME" "$BACKEND_RUNTIME_HOME" "$REPO_ROOT/backend/sandbox"' in serve_sh assert "--reload-exclude='$DEER_FLOW_HOME'" in serve_sh assert "--reload-exclude='$BACKEND_RUNTIME_HOME'" in serve_sh assert "--reload-exclude='sandbox/'" not in serve_sh diff --git a/backend/tests/test_uvicorn_reload_exclude.py b/backend/tests/test_uvicorn_reload_exclude.py new file mode 100644 index 000000000..430a72284 --- /dev/null +++ b/backend/tests/test_uvicorn_reload_exclude.py @@ -0,0 +1,185 @@ +"""Regression for #3459 / #3454 — dev gateway reload-exclude must not crash. + +#3426 switched the dev gateway's ``--reload-exclude`` patterns from relative +(``sandbox/``) to absolute (``$REPO_ROOT/backend/sandbox``). uvicorn only +excludes such a path directly when it already exists as a directory; otherwise +it falls back to ``Path.cwd().glob(pattern)``, and on **Python 3.12** +``pathlib.Path.glob()`` raises ``NotImplementedError: Non-relative patterns are +unsupported`` for an absolute pattern. ``serve.sh`` created the ``.deer-flow`` +excludes but not ``backend/sandbox``, so a fresh checkout crashed ``make dev`` +on startup. + +Two layers of coverage: + +* ``test_*_resolve_*`` exercises uvicorn's real ``resolve_reload_patterns`` to + pin the failure mode and the fix's mechanism. +* ``test_launcher_precreates_every_absolute_reload_exclude`` enforces the actual + invariant on both launchers: every absolute exclude dir is ``mkdir -p``'d + before uvicorn starts. This encodes the root cause, so any future absolute + exclude that forgets its ``mkdir`` fails here. +""" + +from __future__ import annotations + +import re +import shlex +import subprocess +import sys +from pathlib import Path + +import pytest +from uvicorn.config import resolve_reload_patterns + +REPO_ROOT = Path(__file__).resolve().parents[2] + +LAUNCHERS = { + "scripts/serve.sh": REPO_ROOT / "scripts" / "serve.sh", + "docker/dev-entrypoint.sh": REPO_ROOT / "docker" / "dev-entrypoint.sh", +} + +# Shell terminators / redirects that end a simple command's argument list. +_CMD_BOUNDARY = re.compile(r"[;&|<>]") + + +def _logical_lines(script: str) -> list[str]: + """Fold ``\\``-continuations and drop comment lines, yielding logical lines. + + A ``mkdir`` or ``--reload-exclude`` list split across lines with a trailing + backslash becomes one line here, so an argument on a continuation line can't + be silently dropped by per-line scanning. + """ + folded = script.replace("\\\n", " ") + return [line for line in folded.splitlines() if not line.lstrip().startswith("#")] + + +def _shlex(fragment: str) -> list[str]: + """Tokenize a shell fragment (quotes stripped, ``$VAR`` kept literal, + trailing ``# comment`` honored); tolerate pathological quoting.""" + try: + return shlex.split(fragment, comments=True) + except ValueError: + return fragment.split() + + +# ``--reload-exclude`` followed by ``=`` or whitespace, then a value that is a +# single-quoted group, a double-quoted group, or a bare token. The quoted +# alternatives match a *balanced* pair first, so serve.sh's surrounding +# ``GATEWAY_EXTRA_FLAGS="..."`` closing quote is never swallowed into the value. +_RELOAD_EXCLUDE = re.compile(r"""--reload-exclude[=\s]+('[^']*'|"[^"]*"|[^\s'"]+)""") + + +def _reload_exclude_values(script: str) -> list[str]: + """Every ``--reload-exclude`` value, with surrounding quotes removed. + + Handles both CLI forms (``--reload-exclude=`` and the space form + ``--reload-exclude ``) and both shell quotings the launchers use: + + * ``docker/dev-entrypoint.sh`` puts each flag on its own line. + * ``scripts/serve.sh`` packs every flag into a single double-quoted + ``GATEWAY_EXTRA_FLAGS="... --reload-exclude='$X' ..."`` assignment. A + whole-line ``shlex`` would collapse that assignment into one token and + find no flags (this is what regressed serve.sh in CI); matching balanced + inner quotes here keeps the assignment's closing ``"`` out of the value, + so every exclude — including the last ``$BACKEND_RUNTIME_HOME`` — is seen. + """ + values: list[str] = [] + for line in _logical_lines(script): + for raw in _RELOAD_EXCLUDE.findall(line): + values.append(raw.strip("\"'")) + return values + + +def _mkdir_dirs(script: str) -> set[str]: + """Exact set of directories created by every ``mkdir`` command. + + Tokenizes each ``mkdir`` argument list rather than substring-matching, so + ``/app/backend/sandbox`` is not falsely considered created by, say, + ``mkdir -p /app/backend/sandbox-other``. + """ + dirs: set[str] = set() + for line in _logical_lines(script): + match = re.search(r"\bmkdir\b(.*)", line) + if not match: + continue + args = _CMD_BOUNDARY.split(match.group(1), maxsplit=1)[0] + for token in _shlex(args): + if token.startswith("-"): # skip flags such as -p + continue + dirs.add(token) + return dirs + + +@pytest.mark.skipif( + sys.version_info >= (3, 13), + reason="pathlib accepts absolute glob patterns on 3.13+, so the crash is 3.12-only", +) +def test_resolve_reload_patterns_crashes_on_missing_absolute_dir(tmp_path): + """The exact #3454 failure: absolute exclude + missing dir on Python 3.12.""" + missing = tmp_path / "sandbox" # absolute path that does not exist yet + assert not missing.exists() + with pytest.raises(NotImplementedError): + resolve_reload_patterns([str(missing)], []) + + +def test_resolve_reload_patterns_is_safe_once_dir_exists(tmp_path): + """The fix's mechanism: a pre-created dir takes uvicorn's is_dir() path.""" + sandbox = tmp_path / "sandbox" + sandbox.mkdir() + _patterns, directories = resolve_reload_patterns([str(sandbox)], []) + resolved = {d.resolve() for d in directories} + assert sandbox.resolve() in resolved + + +@pytest.mark.parametrize("name", list(LAUNCHERS)) +def test_launcher_precreates_every_absolute_reload_exclude(name): + """Every absolute ``--reload-exclude`` dir must be created by ``mkdir`` first. + + Relative glob patterns (``*.pyc``, ``__pycache__``) are safe and skipped; + anything anchored at ``/`` or a shell variable is an absolute path that + uvicorn would glob — and crash on — unless it already exists. Membership is + an exact match against the parsed ``mkdir`` argument set (not a substring + test), so a path-prefix can't produce a false pass. + """ + script = LAUNCHERS[name].read_text(encoding="utf-8") + created = _mkdir_dirs(script) + + absolute_excludes = [v for v in _reload_exclude_values(script) if v.startswith(("/", "$"))] + assert absolute_excludes, f"{name}: expected at least one absolute reload-exclude" + + for value in absolute_excludes: + assert value in created, f"{name}: absolute reload-exclude {value!r} is never created via mkdir (created dirs: {sorted(created)})" + + +@pytest.mark.parametrize("name", list(LAUNCHERS)) +def test_sandbox_mkdir_precedes_uvicorn_launch(name): + """The sandbox mkdir must come before the uvicorn launch, not just exist. + + ``_mkdir_dirs`` only proves the mkdir is present somewhere; this pins script + order so a future edit can't move (or guard) the mkdir below the launch and + silently reintroduce the #3454 crash on a fresh checkout. ``uv run uvicorn`` + matches the launch but not serve.sh's ``stop_all`` kill line. + """ + lines = LAUNCHERS[name].read_text(encoding="utf-8").splitlines() + launch_idx = next((i for i, ln in enumerate(lines) if "uv run uvicorn" in ln), None) + mkdir_idx = next((i for i, ln in enumerate(lines) if re.search(r"\bmkdir\b", ln) and "sandbox" in ln), None) + + assert launch_idx is not None, f"{name}: could not locate the 'uv run uvicorn' launch line" + assert mkdir_idx is not None, f"{name}: could not locate the sandbox mkdir line" + assert mkdir_idx < launch_idx, f"{name}: sandbox mkdir (line {mkdir_idx + 1}) must precede uvicorn launch (line {launch_idx + 1})" + + +def test_precreated_sandbox_artifacts_are_gitignored(): + """backend/sandbox is runtime state — its contents must stay out of git so + sandbox artifacts can't be accidentally committed (matches the reload-exclude + intent). A content path is existence-independent, unlike the bare dir path. + + Guards against the inaccurate "gitignored" claim by making it verifiable. + """ + probe = "backend/sandbox/__artifact_probe__" + result = subprocess.run( + ["git", "-C", str(REPO_ROOT), "check-ignore", "-q", probe], + capture_output=True, + ) + if result.returncode == 128: # not a git checkout (e.g. packaged install) + pytest.skip("not inside a git working tree") + assert result.returncode == 0, "backend/sandbox/* should be gitignored (see backend/.gitignore '/sandbox/')" diff --git a/docker/dev-entrypoint.sh b/docker/dev-entrypoint.sh index c7f2a1b31..23fa5c19a 100755 --- a/docker/dev-entrypoint.sh +++ b/docker/dev-entrypoint.sh @@ -64,12 +64,14 @@ if [ -n "$EXTRAS_FLAGS" ]; then echo "[startup] uv extras:$EXTRAS_FLAGS" fi -# Keep runtime-owned files out of uvicorn's reload watcher. The directory must -# exist before uvicorn starts so watchfiles treats it as an excluded directory, -# not as a plain glob pattern. +# Keep runtime-owned files out of uvicorn's reload watcher. Each excluded path +# must exist before uvicorn starts so watchfiles treats it as an excluded +# directory, not as a plain glob pattern — on Python 3.12, globbing an absolute +# pattern raises NotImplementedError and crashes startup (#3459 / #3454). That +# means `sandbox` must be created here too, not just `.deer-flow`. : "${DEER_FLOW_HOME:=/app/backend/.deer-flow}" export DEER_FLOW_HOME -mkdir -p "$DEER_FLOW_HOME" /app/backend/.deer-flow +mkdir -p "$DEER_FLOW_HOME" /app/backend/.deer-flow /app/backend/sandbox # ── Sync dependencies (with self-heal) ────────────────────────────────────── diff --git a/scripts/serve.sh b/scripts/serve.sh index 39dbb6679..bae85f22c 100755 --- a/scripts/serve.sh +++ b/scripts/serve.sh @@ -297,7 +297,12 @@ if [ -z "$DEER_FLOW_HOME" ]; then export DEER_FLOW_HOME="$BACKEND_RUNTIME_HOME" fi -mkdir -p "$DEER_FLOW_HOME" "$BACKEND_RUNTIME_HOME" +# `backend/sandbox` is excluded from uvicorn's reload watcher below. uvicorn only +# excludes an absolute path directly when it already exists as a directory; +# otherwise it globs the pattern, and Python 3.12's pathlib rejects absolute glob +# patterns with NotImplementedError, crashing `make dev` on a fresh checkout +# (#3459 / #3454). Creating it here keeps every absolute exclude on the is_dir path. +mkdir -p "$DEER_FLOW_HOME" "$BACKEND_RUNTIME_HOME" "$REPO_ROOT/backend/sandbox" DEER_FLOW_HOME="$(cd "$DEER_FLOW_HOME" && pwd -P)" BACKEND_RUNTIME_HOME="$(cd "$BACKEND_RUNTIME_HOME" && pwd -P)" export DEER_FLOW_HOME From 8db16bb3d80930f7628f38fa33b057fc5bf221f0 Mon Sep 17 00:00:00 2001 From: ly-wang19 <94427531+ly-wang19@users.noreply.github.com> Date: Tue, 9 Jun 2026 15:45:28 +0800 Subject: [PATCH 02/11] fix(config): coerce null config.yaml list sections to empty list (#3434) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Copying config.example.yaml to config.yaml and starting DeerFlow crashed with `pydantic ValidationError: models — Input should be a valid list [input_value=None]`, because the example ships every entry under `models:` commented out, so PyYAML parses the key as null. Reported in #1444. Add a field_validator(mode="before") on AppConfig that coerces null models/tools/tool_groups to [] (matching their default_factory=list), and emit an actionable warning from from_file when no models are configured (pointing to config.example.yaml / make setup). Adds regression tests. Closes #1444 Co-authored-by: ly-wang19 Co-authored-by: Claude Opus 4.8 (1M context) Co-authored-by: Willem Jiang --- .../harness/deerflow/config/app_config.py | 22 +++++++- backend/tests/test_app_config_reload.py | 51 +++++++++++++++++++ 2 files changed, 72 insertions(+), 1 deletion(-) diff --git a/backend/packages/harness/deerflow/config/app_config.py b/backend/packages/harness/deerflow/config/app_config.py index 842b49d7a..7352d0af7 100644 --- a/backend/packages/harness/deerflow/config/app_config.py +++ b/backend/packages/harness/deerflow/config/app_config.py @@ -7,7 +7,7 @@ from typing import Any, Self import yaml from dotenv import load_dotenv -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator from deerflow.config.acp_config import ACPAgentConfig, load_acp_config_from_dict from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict @@ -148,6 +148,21 @@ class AppConfig(BaseModel): ), ) + @field_validator("models", "tools", "tool_groups", mode="before") + @classmethod + def _coerce_null_list_sections(cls, value: Any) -> Any: + """Treat a present-but-empty config section as an empty list. + + Commenting out every entry under a top-level YAML key — e.g. ``models:`` + with only comments beneath it, exactly as shipped in + ``config.example.yaml`` — makes PyYAML parse the value as ``None``. + Without this, the documented ``cp config.example.yaml config.yaml`` + first-run flow crashes with an opaque ``Input should be a valid list`` + pydantic error. Coercing ``None`` to ``[]`` keeps that flow working and + matches the field's own ``default_factory=list``. + """ + return [] if value is None else value + @classmethod def resolve_config_path(cls, config_path: str | None = None) -> Path: """Resolve the config file path. @@ -209,6 +224,11 @@ class AppConfig(BaseModel): config_data["extensions"] = extensions_config.model_dump() result = cls.model_validate(config_data) + if not result.models: + logger.warning( + "No models are configured in %s. Add at least one entry under `models:` (see the commented examples in config.example.yaml) or run `make setup`.", + resolved_path, + ) acp_agents = cls._validate_acp_agents(config_data.get("acp_agents", {})) cls._apply_singleton_configs(result, acp_agents) return result diff --git a/backend/tests/test_app_config_reload.py b/backend/tests/test_app_config_reload.py index 3f744aee2..c0bc00bff 100644 --- a/backend/tests/test_app_config_reload.py +++ b/backend/tests/test_app_config_reload.py @@ -140,6 +140,57 @@ def test_app_config_defaults_empty_database_to_sqlite(tmp_path, monkeypatch): assert config.database.sqlite_dir == ".deer-flow/data" +def test_app_config_coerces_commented_out_list_sections(tmp_path, monkeypatch): + """Commenting out every entry under a list key makes PyYAML parse it as None. + + Regression for the documented ``cp config.example.yaml config.yaml`` flow + (issue #1444): such a config must load with empty lists instead of raising + ``Input should be a valid list``. + """ + config_path = tmp_path / "config.yaml" + extensions_path = tmp_path / "extensions_config.json" + _write_extensions_config(extensions_path) + config_path.write_text( + yaml.safe_dump( + { + "sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}, + "models": None, + "tools": None, + "tool_groups": None, + } + ), + encoding="utf-8", + ) + monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path)) + + config = AppConfig.from_file(str(config_path)) + + assert config.models == [] + assert config.tools == [] + assert config.tool_groups == [] + + +def test_app_config_warns_when_no_models_configured(tmp_path, monkeypatch, caplog): + config_path = tmp_path / "config.yaml" + extensions_path = tmp_path / "extensions_config.json" + _write_extensions_config(extensions_path) + config_path.write_text( + yaml.safe_dump( + { + "sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}, + "models": None, + } + ), + encoding="utf-8", + ) + monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path)) + + with caplog.at_level("WARNING", logger="deerflow.config.app_config"): + AppConfig.from_file(str(config_path)) + + assert "No models are configured" in caplog.text + + def test_get_app_config_reloads_when_file_changes(tmp_path, monkeypatch): config_path = tmp_path / "config.yaml" extensions_path = tmp_path / "extensions_config.json" From 37337b77f91b829ec23b9055cecc54aae38b243b Mon Sep 17 00:00:00 2001 From: hataa <79907651+hata33@users.noreply.github.com> Date: Tue, 9 Jun 2026 18:01:43 +0800 Subject: [PATCH 03/11] feat(models): add StepFun reasoning model adapter (#3461) Add PatchedChatStepFun adapter for StepFun reasoning models (step-3.7-flash, step-3.5-flash). Captures reasoning from both streaming and non-streaming responses and replays it on historical assistant messages for multi-turn tool-call conversations. - New: PatchedChatStepFun adapter with streaming/non-streaming reasoning capture - Support both reasoning and reasoning_content field names - 17 unit tests covering all response paths - Updated: config.example.yaml with StepFun configuration example --- .env.example | 1 + .../deerflow/models/patched_stepfun.py | 175 ++++++++++ backend/tests/test_patched_stepfun.py | 305 ++++++++++++++++++ config.example.yaml | 26 ++ 4 files changed, 507 insertions(+) create mode 100644 backend/packages/harness/deerflow/models/patched_stepfun.py create mode 100644 backend/tests/test_patched_stepfun.py diff --git a/.env.example b/.env.example index c4dbe326e..aec43adcf 100644 --- a/.env.example +++ b/.env.example @@ -21,6 +21,7 @@ INFOQUEST_API_KEY=your-infoquest-api-key # DEEPSEEK_API_KEY=your-deepseek-api-key # NOVITA_API_KEY=your-novita-api-key # OpenAI-compatible, see https://novita.ai # MINIMAX_API_KEY=your-minimax-api-key # OpenAI-compatible, see https://platform.minimax.io +# STEPFUN_API_KEY=your-stepfun-api-key # OpenAI-compatible, see https://platform.stepfun.com # VLLM_API_KEY=your-vllm-api-key # OpenAI-compatible # FEISHU_APP_ID=your-feishu-app-id # FEISHU_APP_SECRET=your-feishu-app-secret diff --git a/backend/packages/harness/deerflow/models/patched_stepfun.py b/backend/packages/harness/deerflow/models/patched_stepfun.py new file mode 100644 index 000000000..1a30332a2 --- /dev/null +++ b/backend/packages/harness/deerflow/models/patched_stepfun.py @@ -0,0 +1,175 @@ +"""Patched ChatOpenAI adapter for StepFun reasoning models. + +StepFun returns ``reasoning`` (or ``reasoning_content`` with deepseek-style) in +both streaming deltas and non-streaming responses. Standard ``ChatOpenAI`` +ignores these non-standard fields, so reasoning content is silently dropped. +This adapter captures reasoning from all response paths and replays it on +historical assistant messages for multi-turn tool-call conversations. +""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from langchain_core.language_models import LanguageModelInput +from langchain_core.messages import AIMessage, AIMessageChunk +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_openai import ChatOpenAI + +from deerflow.models.assistant_payload_replay import ( + restore_assistant_payloads, + restore_reasoning_content, +) + +_MISSING = object() + + +def _extract_reasoning(value: Any) -> str | object: + """Return reasoning content from a dict/Pydantic object. + + StepFun may return reasoning via ``reasoning`` (default) or + ``reasoning_content`` (deepseek-style). Check both fields. + """ + if isinstance(value, Mapping): + # Check reasoning_content first (deepseek-style), then reasoning (default) + for field in ("reasoning_content", "reasoning"): + if field in value and value[field] is not None: + return value[field] + return _MISSING + + # Pydantic / SDK object attributes + for field in ("reasoning_content", "reasoning"): + attr = getattr(value, field, _MISSING) + if attr is not _MISSING and attr is not None: + return attr + + # Some SDK versions store extra fields in model_extra + model_extra = getattr(value, "model_extra", None) + if isinstance(model_extra, Mapping): + for field in ("reasoning_content", "reasoning"): + if field in model_extra and model_extra[field] is not None: + return model_extra[field] + + return _MISSING + + +def _with_reasoning_content(message: AIMessage | AIMessageChunk, reasoning: str) -> AIMessage | AIMessageChunk: + """Return a copy of *message* with reasoning_content stored in additional_kwargs.""" + additional_kwargs = dict(message.additional_kwargs) + if additional_kwargs.get("reasoning_content") != reasoning: + additional_kwargs["reasoning_content"] = reasoning + return message.model_copy(update={"additional_kwargs": additional_kwargs}) + + +def _get_typed_choice_message(response: Any, index: int) -> Any: + """Extract the SDK-typed choice message at *index*, if available.""" + choices = getattr(response, "choices", None) + if choices is None: + return None + try: + return choices[index].message + except (AttributeError, IndexError, TypeError): + return None + + +class PatchedChatStepFun(ChatOpenAI): + """ChatOpenAI with full reasoning support for StepFun models. + + Captures ``reasoning`` / ``reasoning_content`` from both streaming and + non-streaming responses and replays it on historical assistant messages in + multi-turn tool-call conversations. + """ + + @classmethod + def is_lc_serializable(cls) -> bool: + return True + + @property + def lc_secrets(self) -> dict[str, str]: + return {"api_key": "STEPFUN_API_KEY", "openai_api_key": "STEPFUN_API_KEY"} + + # --- Request payload replay --- + + def _get_request_payload( + self, + input_: LanguageModelInput, + *, + stop: list[str] | None = None, + **kwargs: Any, + ) -> dict: + """Restore ``reasoning_content`` on historical assistant messages.""" + original_messages = self._convert_input(input_).to_messages() + payload = super()._get_request_payload(input_, stop=stop, **kwargs) + + restore_assistant_payloads( + payload.get("messages", []), + original_messages, + restore_reasoning_content, + ) + + return payload + + # --- Streaming reasoning capture --- + + def _convert_chunk_to_generation_chunk( + self, + chunk: dict, + default_chunk_class: type, + base_generation_info: dict | None, + ) -> ChatGenerationChunk | None: + """Capture ``reasoning`` / ``reasoning_content`` from streaming deltas.""" + generation_chunk = super()._convert_chunk_to_generation_chunk( + chunk, + default_chunk_class, + base_generation_info, + ) + if generation_chunk is None: + return None + + choices = chunk.get("choices", []) + if choices: + delta = choices[0].get("delta") or {} + reasoning = _extract_reasoning(delta) + if reasoning is not _MISSING and isinstance(generation_chunk.message, AIMessageChunk): + generation_chunk = ChatGenerationChunk( + message=_with_reasoning_content(generation_chunk.message, reasoning), + generation_info=generation_chunk.generation_info, + ) + + return generation_chunk + + # --- Non-streaming reasoning capture --- + + def _create_chat_result( + self, + response: dict | Any, + generation_info: dict | None = None, + ) -> ChatResult: + """Extract ``reasoning`` / ``reasoning_content`` from non-streaming responses.""" + result = super()._create_chat_result(response, generation_info) + response_dict = response if isinstance(response, dict) else response.model_dump() + choices = response_dict.get("choices", []) + + patched_generations: list[ChatGeneration] | None = None + for index, generation in enumerate(result.generations): + choice = choices[index] if index < len(choices) else {} + choice_message = choice.get("message", {}) if isinstance(choice, Mapping) else {} + reasoning = _extract_reasoning(choice_message) + + if reasoning is _MISSING and not isinstance(response, dict): + reasoning = _extract_reasoning(_get_typed_choice_message(response, index)) + + message = generation.message + if reasoning is not _MISSING and isinstance(message, AIMessage): + if patched_generations is None: + patched_generations = list(result.generations) + patched_generations[index] = ChatGeneration( + message=_with_reasoning_content(message, reasoning), + generation_info=generation.generation_info, + ) + + return ChatResult( + generations=patched_generations or result.generations, + llm_output=result.llm_output, + ) diff --git a/backend/tests/test_patched_stepfun.py b/backend/tests/test_patched_stepfun.py new file mode 100644 index 000000000..cc6221695 --- /dev/null +++ b/backend/tests/test_patched_stepfun.py @@ -0,0 +1,305 @@ +"""Tests for deerflow.models.patched_stepfun.PatchedChatStepFun.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage + + +def _make_model(**kwargs): + from deerflow.models.patched_stepfun import PatchedChatStepFun + + return PatchedChatStepFun( + model="step-3.7-flash", + api_key="test-key", + base_url="https://api.stepfun.com/v1", + **kwargs, + ) + + +# --------------------------------------------------------------------------- +# Basic properties +# --------------------------------------------------------------------------- + + +def test_is_lc_serializable_returns_true(): + from deerflow.models.patched_stepfun import PatchedChatStepFun + + assert PatchedChatStepFun.is_lc_serializable() is True + + +def test_lc_secrets_contains_stepfun_api_key_mapping(): + model = _make_model() + assert model.lc_secrets["api_key"] == "STEPFUN_API_KEY" + assert model.lc_secrets["openai_api_key"] == "STEPFUN_API_KEY" + + +# --------------------------------------------------------------------------- +# _extract_reasoning helper +# --------------------------------------------------------------------------- + + +def test_extract_reasoning_from_dict_with_reasoning(): + from deerflow.models.patched_stepfun import _extract_reasoning + + assert _extract_reasoning({"reasoning": "thinking..."}) == "thinking..." + + +def test_extract_reasoning_from_dict_with_reasoning_content(): + from deerflow.models.patched_stepfun import _extract_reasoning + + assert _extract_reasoning({"reasoning_content": "thinking..."}) == "thinking..." + + +def test_extract_reasoning_prefers_reasoning_content_over_reasoning(): + from deerflow.models.patched_stepfun import _extract_reasoning + + result = _extract_reasoning({"reasoning_content": "deepseek", "reasoning": "native"}) + assert result == "deepseek" + + +def test_extract_reasoning_missing_returns_sentinel(): + from deerflow.models.patched_stepfun import _MISSING, _extract_reasoning + + assert _extract_reasoning({}) is _MISSING + assert _extract_reasoning({"reasoning": None}) is _MISSING + + +# --------------------------------------------------------------------------- +# Request payload replay (_get_request_payload) +# --------------------------------------------------------------------------- + + +def test_reasoning_content_injected_into_assistant_tool_call_message(): + model = _make_model() + + human = HumanMessage(content="Check Beijing weather.") + ai = AIMessage( + content="", + additional_kwargs={"reasoning_content": "I need to call the weather tool."}, + ) + payload_message = { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_weather", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"location":"Beijing"}'}, + } + ], + } + base_payload = { + "messages": [ + {"role": "user", "content": "Check Beijing weather."}, + payload_message, + ] + } + + with patch.object(type(model).__bases__[0], "_get_request_payload", return_value=base_payload): + with patch.object(model, "_convert_input") as mock_convert: + mock_convert.return_value = MagicMock(to_messages=lambda: [human, ai]) + payload = model._get_request_payload([human, ai]) + + assert payload["messages"][1]["reasoning_content"] == "I need to call the weather tool." + + +def test_reasoning_content_is_noop_when_missing(): + model = _make_model() + + human = HumanMessage(content="hello") + ai = AIMessage(content="hi", additional_kwargs={}) + base_payload = { + "messages": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + ] + } + + with patch.object(type(model).__bases__[0], "_get_request_payload", return_value=base_payload): + with patch.object(model, "_convert_input") as mock_convert: + mock_convert.return_value = MagicMock(to_messages=lambda: [human, ai]) + payload = model._get_request_payload([human, ai]) + + assert "reasoning_content" not in payload["messages"][1] + + +# --------------------------------------------------------------------------- +# Streaming reasoning capture (_convert_chunk_to_generation_chunk) +# --------------------------------------------------------------------------- + + +def test_convert_chunk_captures_reasoning_field(): + """StepFun default format: delta.reasoning.""" + model = _make_model() + + chunk = model._convert_chunk_to_generation_chunk( + {"choices": [{"delta": {"role": "assistant", "reasoning": "I need "}}]}, + AIMessageChunk, + {}, + ) + + assert chunk is not None + assert chunk.message.additional_kwargs["reasoning_content"] == "I need " + + +def test_convert_chunk_captures_reasoning_content_field(): + """StepFun deepseek-style format: delta.reasoning_content.""" + model = _make_model() + + chunk = model._convert_chunk_to_generation_chunk( + {"choices": [{"delta": {"role": "assistant", "reasoning_content": "I need "}}]}, + AIMessageChunk, + {}, + ) + + assert chunk is not None + assert chunk.message.additional_kwargs["reasoning_content"] == "I need " + + +def test_convert_chunk_streams_reasoning_then_content(): + """Full streaming flow: reasoning deltas followed by content.""" + model = _make_model() + + first = model._convert_chunk_to_generation_chunk( + {"choices": [{"delta": {"role": "assistant", "reasoning": "I need "}}]}, + AIMessageChunk, + {}, + ) + second = model._convert_chunk_to_generation_chunk( + {"choices": [{"delta": {"reasoning": "a tool."}}]}, + AIMessageChunk, + {}, + ) + answer = model._convert_chunk_to_generation_chunk( + {"choices": [{"delta": {"content": "Done."}, "finish_reason": "stop"}], "model": "step-3.7-flash"}, + AIMessageChunk, + {}, + ) + + assert first is not None + assert second is not None + assert answer is not None + + combined = first.message + second.message + answer.message + assert combined.additional_kwargs["reasoning_content"] == "I need a tool." + assert combined.content == "Done." + + +def test_convert_chunk_noop_when_no_reasoning(): + model = _make_model() + + chunk = model._convert_chunk_to_generation_chunk( + {"choices": [{"delta": {"content": "Hello."}, "finish_reason": "stop"}], "model": "step-3.7-flash"}, + AIMessageChunk, + {}, + ) + + assert chunk is not None + assert "reasoning_content" not in chunk.message.additional_kwargs + + +# --------------------------------------------------------------------------- +# Non-streaming reasoning capture (_create_chat_result) +# --------------------------------------------------------------------------- + + +def test_create_chat_result_extracts_reasoning_field(): + """StepFun default format: message.reasoning.""" + model = _make_model() + response = { + "choices": [ + { + "message": { + "role": "assistant", + "content": "The weather is sunny.", + "reasoning": "The tool returned sunny weather.", + }, + "finish_reason": "stop", + } + ], + "model": "step-3.7-flash", + } + + result = model._create_chat_result(response) + message = result.generations[0].message + + assert message.content == "The weather is sunny." + assert message.additional_kwargs["reasoning_content"] == "The tool returned sunny weather." + + +def test_create_chat_result_extracts_reasoning_content_field(): + """StepFun deepseek-style format: message.reasoning_content.""" + model = _make_model() + response = { + "choices": [ + { + "message": { + "role": "assistant", + "content": "The weather is sunny.", + "reasoning_content": "The tool returned sunny weather.", + }, + "finish_reason": "stop", + } + ], + "model": "step-3.7-flash", + } + + result = model._create_chat_result(response) + message = result.generations[0].message + + assert message.content == "The weather is sunny." + assert message.additional_kwargs["reasoning_content"] == "The tool returned sunny weather." + + +def test_create_chat_result_reads_reasoning_from_sdk_object(): + """When the response is a Pydantic model, reasoning is an attribute.""" + model = _make_model() + + class FakeMessage: + reasoning = "Reasoning stored on the SDK message object." + reasoning_content = None + model_extra = None + + class FakeChoice: + message = FakeMessage() + + class FakeResponse: + choices = [FakeChoice()] + + def model_dump(self, **kwargs): + return { + "choices": [ + { + "message": { + "role": "assistant", + "content": "Answer.", + }, + "finish_reason": "stop", + } + ], + "model": "step-3.7-flash", + } + + result = model._create_chat_result(FakeResponse()) + assert result.generations[0].message.additional_kwargs["reasoning_content"] == "Reasoning stored on the SDK message object." + + +def test_create_chat_result_noop_when_no_reasoning(): + model = _make_model() + response = { + "choices": [ + { + "message": { + "role": "assistant", + "content": "Hello!", + }, + "finish_reason": "stop", + } + ], + "model": "step-3.7-flash", + } + + result = model._create_chat_result(response) + assert "reasoning_content" not in result.generations[0].message.additional_kwargs diff --git a/config.example.yaml b/config.example.yaml index 5de11e226..290ef3302 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -274,6 +274,32 @@ models: # thinking: # type: disabled + # Example: StepFun (阶跃星辰) reasoning models + # StepFun provides OpenAI-compatible API with reasoning models. + # With reasoning_format: deepseek-style, the API returns reasoning_content + # (same field as DeepSeek), which must be replayed on historical assistant + # messages in multi-turn tool-call conversations. + # Use PatchedChatStepFun instead of plain ChatOpenAI. + # Docs: https://platform.stepfun.com/docs/api-reference/chat-completions + # - name: step-3.7-flash + # display_name: Step 3.7 Flash + # use: deerflow.models.patched_stepfun:PatchedChatStepFun + # model: step-3.7-flash + # api_key: $STEPFUN_API_KEY + # base_url: https://api.stepfun.com/v1 + # request_timeout: 600.0 + # max_retries: 2 + # max_tokens: 4096 + # supports_thinking: true + # supports_reasoning_effort: true + # supports_vision: true + # when_thinking_enabled: + # extra_body: + # reasoning_format: deepseek-style + # when_thinking_disabled: + # extra_body: + # reasoning_format: deepseek-style + # Example: MiniMax (OpenAI-compatible) - International Edition # MiniMax provides high-performance models with 512K context window and 128K max output # Docs: https://platform.minimax.io/docs/api-reference/text-openai-api From 63ce88f87410c48ed33f0bb918767c47e6891bb2 Mon Sep 17 00:00:00 2001 From: Nan Gao Date: Tue, 9 Jun 2026 15:58:31 +0200 Subject: [PATCH 04/11] fix(replay-e2e): key fixtures by caller and conversation (#3453) * add caller identity in replay e2e * make format * fix(replay-e2e): stabilize title caller replay * fix(replay-e2e): use captured caller without run manager --------- Co-authored-by: Willem Jiang --- backend/docs/REPLAY_E2E.md | 16 +- backend/scripts/build_fixture_from_jsonl.py | 3 +- backend/scripts/record_gateway.py | 32 +++- backend/tests/_replay_fixture.py | 3 +- .../replay/write_read_file.ultra.json | 22 ++- backend/tests/replay_provider.py | 150 ++++++++++++++++-- backend/tests/test_replay_provider.py | 116 ++++++++++++++ 7 files changed, 308 insertions(+), 34 deletions(-) create mode 100644 backend/tests/test_replay_provider.py diff --git a/backend/docs/REPLAY_E2E.md b/backend/docs/REPLAY_E2E.md index cd9920b4c..881c768b7 100644 --- a/backend/docs/REPLAY_E2E.md +++ b/backend/docs/REPLAY_E2E.md @@ -50,18 +50,22 @@ gateway's own run/event stores using the request's auth context, so the real ## How replay works `tests/replay_provider.py::ReplayChatModel` returns recorded assistant turns keyed -by a **normalized hash of the conversation** (human / ai / tool messages — role, -text, tool-call name+args; with ``, dates, UUIDs, tmp paths -stripped). A miss raises loudly rather than passing silently. +by a **normalized hash of the model caller + conversation**. The conversation is +human / ai / tool messages — role, text, tool-call name+args; with +``, dates, UUIDs, tmp paths stripped. The caller is the stable +source of the model call (`lead_agent`, `middleware:title`, `suggest_agent`, +`subagent:*`, etc.). A miss raises loudly rather than passing silently. **The system prompt is excluded from the match key.** The lead-agent system prompt is a living, frequently-edited implementation detail — its wording changes across PRs (e.g. #3195 added a "File Editing Workflow" section). Hashing it would make every fixture go stale and red-fail unrelated PRs the moment anyone edits the prompt. The conversation flow (user input → tool calls → results → answer) is the -stable contract that identifies a recorded turn. (This mirrors how open-design's -mock picker keys on the user prompt, not the system internals.) Combined with -pinning skills + extensions empty and disabling memory/summarization +stable contract that identifies a recorded turn. The caller still stays in the +key so two different model users with identical conversation text do not compete +for the same replay bucket. (This mirrors how open-design's mock picker keys on +the user prompt, not the system internals.) Combined with pinning skills + +extensions empty and disabling memory/summarization (`tests/_replay_fixture.py::build_config_yaml`), a fixture replays the same across machines, days, prompt edits, and CI. Replaying needs **no API key**. diff --git a/backend/scripts/build_fixture_from_jsonl.py b/backend/scripts/build_fixture_from_jsonl.py index 9bd7e1f93..6fcdba405 100644 --- a/backend/scripts/build_fixture_from_jsonl.py +++ b/backend/scripts/build_fixture_from_jsonl.py @@ -36,7 +36,8 @@ def main() -> int: for index, turn in enumerate(turns): data = turn["output"].get("data", {}) tool_calls = [tc.get("name") for tc in (data.get("tool_calls") or [])] - print(f" turn {index}: hash={turn['input_hash'][:12]} tool_calls={tool_calls} content={str(data.get('content'))[:50]!r}") + caller = turn.get("caller", "legacy") + print(f" turn {index}: caller={caller} hash={turn['input_hash'][:12]} tool_calls={tool_calls} content={str(data.get('content'))[:50]!r}") return 0 diff --git a/backend/scripts/record_gateway.py b/backend/scripts/record_gateway.py index ecab4b6cd..105c8bab7 100644 --- a/backend/scripts/record_gateway.py +++ b/backend/scripts/record_gateway.py @@ -28,27 +28,45 @@ sys.path.insert(0, str(_BACKEND / "tests")) def _install_capture(out_path: Path) -> None: from langchain_core.callbacks import BaseCallbackHandler from langchain_core.messages import messages_to_dict - from replay_provider import hash_messages + from replay_provider import caller_identity, hash_messages, hash_replay_input import deerflow.models.factory as factory_mod class Capture(BaseCallbackHandler): def __init__(self) -> None: - self.inputs: dict[str, list] = {} + self.inputs: dict[str, tuple[list, str]] = {} - def on_chat_model_start(self, serialized, messages, *, run_id=None, **kwargs): # noqa: ANN001 - self.inputs[str(run_id)] = messages[0] if messages else [] + def on_chat_model_start( # noqa: ANN001 + self, + serialized, + messages, + *, + run_id=None, + tags=None, + name=None, + **kwargs, + ): + self.inputs[str(run_id)] = ( + messages[0] if messages else [], + caller_identity(name=name, tags=tags), + ) def on_llm_end(self, response, *, run_id=None, **kwargs): # noqa: ANN001 - inp = self.inputs.pop(str(run_id), None) - if inp is None: + captured = self.inputs.pop(str(run_id), None) + if captured is None: return + inp, caller = captured for batch in response.generations: for gen in batch: message = getattr(gen, "message", None) if message is None: continue - record = {"input_hash": hash_messages(inp), "output": messages_to_dict([message])[0]} + record = { + "caller": caller, + "conversation_hash": hash_messages(inp), + "input_hash": hash_replay_input(inp, caller=caller), + "output": messages_to_dict([message])[0], + } with open(out_path, "a", encoding="utf-8") as handle: handle.write(json.dumps(record, ensure_ascii=False) + "\n") handle.flush() diff --git a/backend/tests/_replay_fixture.py b/backend/tests/_replay_fixture.py index 56f1a080a..053fc6ae4 100644 --- a/backend/tests/_replay_fixture.py +++ b/backend/tests/_replay_fixture.py @@ -32,7 +32,8 @@ REPLAY_MODEL_BLOCK = """\ - name: scenario-model display_name: Scenario Model use: replay_provider:ReplayChatModel - model: replay""" + model: replay + supports_thinking: true""" def real_model_block(model: str) -> str: diff --git a/backend/tests/fixtures/replay/write_read_file.ultra.json b/backend/tests/fixtures/replay/write_read_file.ultra.json index 95cce6ce8..b8cbe142c 100644 --- a/backend/tests/fixtures/replay/write_read_file.ultra.json +++ b/backend/tests/fixtures/replay/write_read_file.ultra.json @@ -12,7 +12,9 @@ }, "turns": [ { - "input_hash": "9c50eda6ab7e8593dabccbdeadc70a4a7bf778b2c0c3f275f1f96cf2c8ab58db", + "caller": "lead_agent", + "conversation_hash": "9c50eda6ab7e8593dabccbdeadc70a4a7bf778b2c0c3f275f1f96cf2c8ab58db", + "input_hash": "27aeb4c11bff2c3ebc182fe52a06556823c21928620a400c7f26be9733c31f3f", "output": { "type": "ai", "data": { @@ -56,7 +58,9 @@ } }, { - "input_hash": "3598aeb87e221ca8f554e4d61ce6d5e8801754606fa5c95a89c38bd6cb623045", + "caller": "middleware:title", + "conversation_hash": "3598aeb87e221ca8f554e4d61ce6d5e8801754606fa5c95a89c38bd6cb623045", + "input_hash": "75101f9faa453b1a35deff920b1e3c1a9f0b013a7627fbbaa03436752776b953", "output": { "type": "ai", "data": { @@ -89,7 +93,9 @@ } }, { - "input_hash": "6af134379b2a9efa01b4f63032f88211d5f38f459f8bed621eb6c65e8e05c1f9", + "caller": "lead_agent", + "conversation_hash": "6af134379b2a9efa01b4f63032f88211d5f38f459f8bed621eb6c65e8e05c1f9", + "input_hash": "f7468603a43d301fcc0167c2f7cd10e53137bfc584f1b3d776614b7a612ed7a6", "output": { "type": "ai", "data": { @@ -132,7 +138,9 @@ } }, { - "input_hash": "04751c4f7b0107b78b5c97d417063883fd586f5ebcbc4acf79be6cb3c0cdaec1", + "caller": "lead_agent", + "conversation_hash": "04751c4f7b0107b78b5c97d417063883fd586f5ebcbc4acf79be6cb3c0cdaec1", + "input_hash": "218645dabc6926a1dbdf45dd20fba8a41e1e690cef78d7752566db3acf5a36ce", "output": { "type": "ai", "data": { @@ -165,7 +173,9 @@ } }, { - "input_hash": "8b98ebdbb53e88f000556c4753adede8eaa076ff6fd7b8a1285bfd18aee8144d", + "caller": "suggest_agent", + "conversation_hash": "8b98ebdbb53e88f000556c4753adede8eaa076ff6fd7b8a1285bfd18aee8144d", + "input_hash": "dcd855d389d7179a1e4bc7074fa9ba7ce697570af8947225d6bacb538f14a0cb", "output": { "type": "ai", "data": { @@ -230,4 +240,4 @@ } } ] -} \ No newline at end of file +} diff --git a/backend/tests/replay_provider.py b/backend/tests/replay_provider.py index ab2ef3791..035889305 100644 --- a/backend/tests/replay_provider.py +++ b/backend/tests/replay_provider.py @@ -2,14 +2,19 @@ record/replay e2e (mirrors open-design's ``mocks/`` golden traces). A fixture is a JSON file capturing the *real* model calls of one scenario, -keyed by a normalized hash of the **input** each call received:: +keyed by a normalized hash of the **caller + input** each call received:: { "scenario": "write_read_file", "mode": "ultra", "model": "gpt-5.5", "turns": [ - {"input_hash": "", "input_preview": "...", "output": }, + { + "caller": "lead_agent", + "conversation_hash": "", + "input_hash": "", + "output": , + }, ... ] } @@ -21,8 +26,11 @@ A real run makes model calls from several callers — the lead agent's own turns and their count/order is not something we want a replay to depend on. Matching by a normalized hash of the *input messages* means each call gets back exactly the output that was recorded for that input, regardless of order or which middleware -issued it. That keeps the in-graph, deterministic title call part of the -recording; memory/summarization, by contrast, are disabled in the replay config +issued it. The caller name (``lead_agent``, ``middleware:title``, +``suggest_agent``, ``subagent:*``, ...) is included so two different model +callers with the same conversation text do not compete for the same replay +bucket. That keeps the in-graph, deterministic title call part of the recording; +memory/summarization, by contrast, are disabled in the replay config (``_replay_fixture.py``) because their background, debounced timing is not reproducible across runs. @@ -67,7 +75,7 @@ from collections import deque from collections.abc import Iterator from typing import Any -from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.callbacks import BaseCallbackHandler, CallbackManagerForLLMRun from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, messages_from_dict from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult @@ -75,6 +83,14 @@ from langchain_core.runnables import Runnable from pydantic import PrivateAttr _FIXTURE_ENV = "DEERFLOW_REPLAY_FIXTURE" +_DEFAULT_CALLER = "lead_agent" +_CALLER_TAG_PREFIXES = ("middleware:", "subagent:") +_CALLER_NAME_ALIASES = { + # TitleMiddleware uses this run_name and tags the call as middleware:title. + # Some execution paths do not preserve the tag down to the model callback, + # so keep the run_name and tag in the same replay namespace. + "title_agent": "middleware:title", +} # Process-wide record of replay misses. A miss raises inside the model, but the # gateway's LLMErrorHandlingMiddleware swallows it into a normal assistant error @@ -94,6 +110,30 @@ def reset_replay_misses() -> None: _replay_misses.clear() +def _normalize_caller(caller: str | None) -> str: + value = _normalize_text(str(caller or "").strip()) + if not value: + return _DEFAULT_CALLER + return _CALLER_NAME_ALIASES.get(value, value) + + +def _caller_from_tags(tags: list[str] | None) -> str | None: + for tag in tags or []: + if isinstance(tag, str) and (tag == _DEFAULT_CALLER or tag.startswith(_CALLER_TAG_PREFIXES)): + return tag + return None + + +def caller_identity(*, name: str | None = None, tags: list[str] | None = None) -> str: + """Stable model-caller identity shared by record and replay. + + Tags win because graph middleware and subagents already use them as the + explicit caller marker. ``run_name`` is exposed to callbacks as ``name`` and + covers route-level callers such as ``suggest_agent``. + """ + return _normalize_caller(_caller_from_tags(tags) or name) + + # Volatile substrings that differ between a recording run and a replay run but # carry no semantic weight for matching. Normalized to stable placeholders # before hashing so the same logical input hashes identically across processes. @@ -172,10 +212,30 @@ def _canonical_messages(messages: list[BaseMessage]) -> str: def hash_messages(messages: list[BaseMessage]) -> str: - """Stable hash of a model call's input. Shared by recorder and replayer.""" + """Legacy stable hash of only a model call's conversation input.""" return hashlib.sha256(_canonical_messages(messages).encode("utf-8")).hexdigest() +def hash_replay_input(messages: list[BaseMessage], *, caller: str | None) -> str: + """Stable replay key for a caller-specific model input.""" + return hash_input_key(hash_messages(messages), caller=caller) + + +def hash_input_key(conversation_hash: str, *, caller: str | None) -> str: + """Namespace a conversation hash by caller identity. + + Keeping this as ``hash(caller + legacy_conversation_hash)`` lets existing + fixtures migrate without a live-model re-record: their old ``input_hash`` is + exactly the conversation hash. + """ + payload = json.dumps( + {"caller": _normalize_caller(caller), "conversation_hash": conversation_hash}, + sort_keys=True, + ensure_ascii=False, + ) + return hashlib.sha256(payload.encode("utf-8")).hexdigest() + + def _load_fixture(fixture_path: str) -> dict[str, deque[AIMessage]]: with open(fixture_path, encoding="utf-8") as handle: payload = json.load(handle) @@ -199,24 +259,54 @@ class ReplayChatModel(BaseChatModel): _table: dict[str, deque] = PrivateAttr(default_factory=dict) _fixture_path: str = PrivateAttr(default="") + _run_callers: dict[str, str] = PrivateAttr(default_factory=dict) def __init__(self, **kwargs: Any) -> None: # Ignore provider noise the factory forwards from config (model, api_key, # base_url, ...). Fixture path comes from the ``fixture`` kwarg or env. fixture_path = kwargs.pop("fixture", None) or os.environ.get(_FIXTURE_ENV) - super().__init__() + callbacks = kwargs.pop("callbacks", None) + super().__init__(callbacks=callbacks) if not fixture_path: raise ValueError(f"ReplayChatModel needs a fixture path via the ``fixture`` kwarg or ${_FIXTURE_ENV}") self._fixture_path = fixture_path self._table = _load_fixture(fixture_path) + self.callbacks = [*(self.callbacks or []), _ReplayCallerCapture(self._run_callers)] @property def _llm_type(self) -> str: return "deerflow-replay" - def _match(self, messages: list[BaseMessage]) -> AIMessage: - key = hash_messages(messages) + def _caller_from_run_manager(self, run_manager: CallbackManagerForLLMRun | None) -> str: + if run_manager is None: + if len(self._run_callers) == 1: + # Some async LangGraph paths fire on_chat_model_start with the + # caller metadata but invoke the model implementation without a + # run_manager. When there is only one pending start event, it is + # the current call; use it so record/replay share the same + # caller key. + return self._run_callers.pop(next(iter(self._run_callers))) + return _DEFAULT_CALLER + run_id = str(getattr(run_manager, "run_id", "")) + caller = self._run_callers.pop(run_id, None) + if caller: + return caller + return caller_identity( + name=getattr(run_manager, "run_name", None) or getattr(run_manager, "name", None), + tags=getattr(run_manager, "tags", None), + ) + + def _match(self, messages: list[BaseMessage], run_manager: CallbackManagerForLLMRun | None = None) -> AIMessage: + caller = self._caller_from_run_manager(run_manager) + key = hash_replay_input(messages, caller=caller) bucket = self._table.get(key) + if not bucket: + # Backward compatibility for fixtures recorded before caller-aware + # keys. New recordings write caller-aware ``input_hash`` values. + legacy_key = hash_messages(messages) + bucket = self._table.get(legacy_key) + if bucket: + key = legacy_key if not bucket: _replay_misses.append(key) preview = _canonical_messages(messages) @@ -224,6 +314,7 @@ class ReplayChatModel(BaseChatModel): f"replay miss: no recorded output for input hash {key} in {self._fixture_path!r}. " "The replayed run diverged from the recording (graph changed, a non-deterministic tool result " "altered a downstream input, or a volatile field slipped past normalization). " + f"Caller: {caller!r}. " f"Known hashes: {sorted(self._table)}. " f"Normalized input (first 800 chars): {preview[:800]!r}" ) @@ -236,7 +327,7 @@ class ReplayChatModel(BaseChatModel): run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: - return ChatResult(generations=[ChatGeneration(message=self._match(messages))]) + return ChatResult(generations=[ChatGeneration(message=self._match(messages, run_manager))]) def _stream( self, @@ -245,9 +336,16 @@ class ReplayChatModel(BaseChatModel): run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: - turn = self._match(messages) + turn = self._match(messages, run_manager) text = turn.content if isinstance(turn.content, str) else "" - chunk = ChatGenerationChunk(message=AIMessageChunk(content=turn.content, tool_calls=turn.tool_calls, additional_kwargs=turn.additional_kwargs, id=turn.id)) + chunk = ChatGenerationChunk( + message=AIMessageChunk( + content=turn.content, + tool_calls=turn.tool_calls, + additional_kwargs=turn.additional_kwargs, + id=turn.id, + ) + ) if run_manager is not None and text: run_manager.on_llm_new_token(text, chunk=chunk) yield chunk @@ -256,5 +354,31 @@ class ReplayChatModel(BaseChatModel): return self +class _ReplayCallerCapture(BaseCallbackHandler): + def __init__(self, run_callers: dict[str, str]) -> None: + self._run_callers = run_callers + + def on_chat_model_start( + self, + serialized: dict, + messages: list[list[BaseMessage]], + *, + run_id: Any = None, + tags: list[str] | None = None, + name: str | None = None, + **kwargs: Any, + ) -> None: + if run_id is not None: + self._run_callers[str(run_id)] = caller_identity(name=name, tags=tags) + + # Re-export so the recorder shares the exact hashing logic. -__all__ = ["ReplayChatModel", "hash_messages", "replay_misses", "reset_replay_misses"] +__all__ = [ + "ReplayChatModel", + "caller_identity", + "hash_input_key", + "hash_messages", + "hash_replay_input", + "replay_misses", + "reset_replay_misses", +] diff --git a/backend/tests/test_replay_provider.py b/backend/tests/test_replay_provider.py new file mode 100644 index 000000000..e87f93cfb --- /dev/null +++ b/backend/tests/test_replay_provider.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +import json +from pathlib import Path + +from langchain_core.messages import AIMessage, HumanMessage, messages_to_dict +from replay_provider import ReplayChatModel, caller_identity, hash_messages, hash_replay_input + + +def _write_fixture(path: Path, turns: list[dict]) -> None: + path.write_text( + json.dumps( + { + "scenario": "unit", + "mode": "unit", + "model": "replay", + "prompt": "unit", + "context": {}, + "turns": turns, + } + ), + encoding="utf-8", + ) + + +def test_replay_key_includes_caller_identity(tmp_path: Path): + messages = [HumanMessage(content="same conversation")] + lead_output = AIMessage(content="lead") + suggest_output = AIMessage(content="suggest") + fixture_path = tmp_path / "fixture.json" + + _write_fixture( + fixture_path, + [ + { + "caller": "lead_agent", + "conversation_hash": hash_messages(messages), + "input_hash": hash_replay_input(messages, caller="lead_agent"), + "output": messages_to_dict([lead_output])[0], + }, + { + "caller": "suggest_agent", + "conversation_hash": hash_messages(messages), + "input_hash": hash_replay_input(messages, caller="suggest_agent"), + "output": messages_to_dict([suggest_output])[0], + }, + ], + ) + + model = ReplayChatModel(fixture=str(fixture_path)) + + assert model.invoke(messages, config={"run_name": "suggest_agent"}).content == "suggest" + assert model.invoke(messages, config={"run_name": "lead_agent"}).content == "lead" + + +def test_replay_supports_legacy_conversation_only_fixture(tmp_path: Path): + messages = [HumanMessage(content="legacy conversation")] + fixture_path = tmp_path / "legacy.json" + + _write_fixture( + fixture_path, + [ + { + "input_hash": hash_messages(messages), + "output": messages_to_dict([AIMessage(content="legacy")])[0], + } + ], + ) + + model = ReplayChatModel(fixture=str(fixture_path)) + + assert model.invoke(messages, config={"run_name": "suggest_agent"}).content == "legacy" + + +def test_title_run_name_uses_middleware_caller_namespace(tmp_path: Path): + messages = [HumanMessage(content="title prompt")] + fixture_path = tmp_path / "fixture.json" + + _write_fixture( + fixture_path, + [ + { + "caller": "middleware:title", + "conversation_hash": hash_messages(messages), + "input_hash": hash_replay_input(messages, caller="middleware:title"), + "output": messages_to_dict([AIMessage(content="generated title")])[0], + } + ], + ) + + model = ReplayChatModel(fixture=str(fixture_path)) + + assert caller_identity(name="title_agent") == "middleware:title" + assert model.invoke(messages, config={"run_name": "title_agent"}).content == "generated title" + + +def test_replay_uses_single_pending_capture_when_run_manager_is_missing(tmp_path: Path): + messages = [HumanMessage(content="title prompt")] + fixture_path = tmp_path / "fixture.json" + + _write_fixture( + fixture_path, + [ + { + "caller": "middleware:title", + "conversation_hash": hash_messages(messages), + "input_hash": hash_replay_input(messages, caller="middleware:title"), + "output": messages_to_dict([AIMessage(content="generated title")])[0], + } + ], + ) + + model = ReplayChatModel(fixture=str(fixture_path)) + model._run_callers["captured-run"] = caller_identity(name="title_agent", tags=["middleware:title"]) + + assert model._match(messages, run_manager=None).content == "generated title" From 5b81588b872e1b495f2dc18cf620a32faf30fe69 Mon Sep 17 00:00:00 2001 From: Admire <64821731+LittleChenLiya@users.noreply.github.com> Date: Tue, 9 Jun 2026 22:09:13 +0800 Subject: [PATCH 05/11] fix(frontend): fallback Streamdown clipboard copy (#3397) * fix(frontend): fallback streamdown clipboard copy * fix(frontend): address clipboard fallback review * fix(frontend): normalize clipboard fallback rejection * fix(frontend): harden clipboard fallback install * fix(frontend): clarify clipboard fallback errors * fix(frontend): cover clipboard fallback edge cases * fix(frontend): tighten clipboard fallback cleanup * fix(frontend): reduce clipboard fallback copy window * fix(frontend): guard clipboard item fallback install * fix(frontend): clean up clipboard fallback on selection errors * Address clipboard fallback review feedback * fix(frontend): guard clipboard fallback install during SSR --- .../src/components/ai-elements/message.tsx | 9 +- .../src/components/ai-elements/reasoning.tsx | 6 +- .../src/components/ai-elements/streamdown.tsx | 17 + .../artifacts/artifact-file-detail.tsx | 6 +- .../workspace/messages/subtask-card.tsx | 6 +- .../settings/about-settings-page.tsx | 4 +- .../settings/memory-settings-page.tsx | 6 +- frontend/src/core/clipboard.ts | 265 +++++++- frontend/tests/unit/core/clipboard.test.ts | 618 +++++++++++++++++- 9 files changed, 901 insertions(+), 36 deletions(-) create mode 100644 frontend/src/components/ai-elements/streamdown.tsx diff --git a/frontend/src/components/ai-elements/message.tsx b/frontend/src/components/ai-elements/message.tsx index c0071c219..c260f1f91 100644 --- a/frontend/src/components/ai-elements/message.tsx +++ b/frontend/src/components/ai-elements/message.tsx @@ -18,7 +18,8 @@ import { } from "lucide-react"; import type { ComponentProps, HTMLAttributes, ReactElement } from "react"; import { createContext, memo, useContext, useEffect, useState } from "react"; -import { Streamdown } from "streamdown"; + +import { ClipboardSafeStreamdown } from "./streamdown"; export type MessageProps = HTMLAttributes & { from: UIMessage["role"]; @@ -302,11 +303,13 @@ export const MessageBranchPage = ({ ); }; -export type MessageResponseProps = ComponentProps; +export type MessageResponseProps = ComponentProps< + typeof ClipboardSafeStreamdown +>; export const MessageResponse = memo( ({ className, ...props }: MessageResponseProps) => ( - *:first-child]:mt-0 [&>*:last-child]:mb-0", className, diff --git a/frontend/src/components/ai-elements/reasoning.tsx b/frontend/src/components/ai-elements/reasoning.tsx index b8e0bfcbc..94d4e7a5c 100644 --- a/frontend/src/components/ai-elements/reasoning.tsx +++ b/frontend/src/components/ai-elements/reasoning.tsx @@ -10,9 +10,9 @@ import { cn } from "@/lib/utils"; import { BrainIcon, ChevronDownIcon } from "lucide-react"; import type { ComponentProps, ReactNode } from "react"; import { createContext, memo, useContext, useEffect, useState } from "react"; -import { Streamdown } from "streamdown"; import { reasoningPlugins } from "@/core/streamdown/plugins"; import { Shimmer } from "./shimmer"; +import { ClipboardSafeStreamdown } from "./streamdown"; type ReasoningContextValue = { isStreaming: boolean; @@ -178,7 +178,9 @@ export const ReasoningContent = memo( )} {...props} > - {children} + + {children} + ), ); diff --git a/frontend/src/components/ai-elements/streamdown.tsx b/frontend/src/components/ai-elements/streamdown.tsx new file mode 100644 index 000000000..210053d9b --- /dev/null +++ b/frontend/src/components/ai-elements/streamdown.tsx @@ -0,0 +1,17 @@ +"use client"; + +import { type ComponentProps } from "react"; +import { Streamdown } from "streamdown"; + +import { installClipboardFallback } from "@/core/clipboard"; + +export type ClipboardSafeStreamdownProps = ComponentProps; + +// Only patch browser globals in client context; skip during SSR +if (typeof document !== "undefined") { + installClipboardFallback(); +} + +export function ClipboardSafeStreamdown(props: ClipboardSafeStreamdownProps) { + return ; +} diff --git a/frontend/src/components/workspace/artifacts/artifact-file-detail.tsx b/frontend/src/components/workspace/artifacts/artifact-file-detail.tsx index 70d1d2d88..4d9af67a6 100644 --- a/frontend/src/components/workspace/artifacts/artifact-file-detail.tsx +++ b/frontend/src/components/workspace/artifacts/artifact-file-detail.tsx @@ -10,7 +10,6 @@ import { } from "lucide-react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; -import { Streamdown } from "streamdown"; import { Artifact, @@ -20,6 +19,7 @@ import { ArtifactHeader, ArtifactTitle, } from "@/components/ai-elements/artifact"; +import { ClipboardSafeStreamdown } from "@/components/ai-elements/streamdown"; import { Select, SelectItem } from "@/components/ui/select"; import { SelectContent, @@ -400,13 +400,13 @@ export function ArtifactFilePreview({ if (language === "markdown") { return (
- {content ?? ""} - +
); } diff --git a/frontend/src/components/workspace/messages/subtask-card.tsx b/frontend/src/components/workspace/messages/subtask-card.tsx index b2aa74b34..7113ab480 100644 --- a/frontend/src/components/workspace/messages/subtask-card.tsx +++ b/frontend/src/components/workspace/messages/subtask-card.tsx @@ -6,7 +6,6 @@ import { XCircleIcon, } from "lucide-react"; import { useMemo, useState } from "react"; -import { Streamdown } from "streamdown"; import { ChainOfThought, @@ -14,6 +13,7 @@ import { ChainOfThoughtStep, } from "@/components/ai-elements/chain-of-thought"; import { Shimmer } from "@/components/ai-elements/shimmer"; +import { ClipboardSafeStreamdown } from "@/components/ai-elements/streamdown"; import { Button } from "@/components/ui/button"; import { ShineBorder } from "@/components/ui/shine-border"; import { useI18n } from "@/core/i18n/hooks"; @@ -126,12 +126,12 @@ export function SubtaskCard({ {task.prompt && ( {task.prompt} -
+ } > )} diff --git a/frontend/src/components/workspace/settings/about-settings-page.tsx b/frontend/src/components/workspace/settings/about-settings-page.tsx index 8635f8dec..d91e60a76 100644 --- a/frontend/src/components/workspace/settings/about-settings-page.tsx +++ b/frontend/src/components/workspace/settings/about-settings-page.tsx @@ -1,9 +1,9 @@ "use client"; -import { Streamdown } from "streamdown"; +import { ClipboardSafeStreamdown } from "@/components/ai-elements/streamdown"; import { aboutMarkdown } from "./about-content"; export function AboutSettingsPage() { - return {aboutMarkdown}; + return {aboutMarkdown}; } diff --git a/frontend/src/components/workspace/settings/memory-settings-page.tsx b/frontend/src/components/workspace/settings/memory-settings-page.tsx index f7b677445..9c84ff217 100644 --- a/frontend/src/components/workspace/settings/memory-settings-page.tsx +++ b/frontend/src/components/workspace/settings/memory-settings-page.tsx @@ -10,8 +10,8 @@ import { import Link from "next/link"; import { useDeferredValue, useId, useRef, useState } from "react"; import { toast } from "sonner"; -import { Streamdown } from "streamdown"; +import { ClipboardSafeStreamdown } from "@/components/ai-elements/streamdown"; import { Button } from "@/components/ui/button"; import { Dialog, @@ -639,12 +639,12 @@ export function MemorySettingsPage() {
{summaryReadOnly}
- {summariesToMarkdown(memory, filteredSectionGroups, t)} - + ) : null} diff --git a/frontend/src/core/clipboard.ts b/frontend/src/core/clipboard.ts index 1be382833..81bc92726 100644 --- a/frontend/src/core/clipboard.ts +++ b/frontend/src/core/clipboard.ts @@ -1,3 +1,47 @@ +type ClipboardItemLike = { + types?: readonly string[]; + getType?: (type: string) => Promise; + items?: Record; +}; + +function copyTextWithExecCommand(text: string): boolean { + const document = globalThis.document; + if ( + typeof document?.createElement !== "function" || + typeof document.body?.appendChild !== "function" || + typeof document.execCommand !== "function" + ) { + throw new Error("Clipboard DOM fallback not available"); + } + + const textarea = document.createElement("textarea"); + textarea.value = text; + textarea.setAttribute("readonly", ""); + textarea.style.position = "fixed"; + textarea.style.top = "-9999px"; + textarea.style.left = "-9999px"; + + let copied = false; + let appended = false; + try { + document.body.appendChild(textarea); + appended = true; + textarea.select(); + copied = document.execCommand("copy"); + } finally { + if (appended) { + const parentNode = textarea.parentNode; + if (typeof textarea.remove === "function") { + textarea.remove(); + } else if (typeof parentNode?.removeChild === "function") { + parentNode.removeChild(textarea); + } + } + } + + return copied; +} + export async function writeTextToClipboard(text: string): Promise { try { const clipboard = globalThis.navigator?.clipboard; @@ -6,26 +50,209 @@ export async function writeTextToClipboard(text: string): Promise { return true; } - const document = globalThis.document; - if (!document?.body?.appendChild || !document.execCommand) { - return false; - } - - const textarea = document.createElement("textarea"); - textarea.value = text; - textarea.setAttribute("readonly", ""); - textarea.style.position = "fixed"; - textarea.style.top = "-9999px"; - textarea.style.left = "-9999px"; - document.body.appendChild(textarea); - textarea.select(); - - try { - return document.execCommand("copy"); - } finally { - textarea.remove(); - } + return copyTextWithExecCommand(text); } catch { return false; } } + +function fallbackWriteText(text: string): Promise { + try { + if (!copyTextWithExecCommand(text)) { + return Promise.reject(new Error("Clipboard copy command failed")); + } + } catch (error) { + return Promise.reject( + error instanceof Error ? error : new Error(String(error)), + ); + } + return Promise.resolve(); +} + +function hasUsableClipboardItem(): boolean { + return typeof globalThis.ClipboardItem === "function"; +} + +async function readPlainTextFromClipboardItem( + item: ClipboardItemLike, +): Promise { + const plainText = item.items?.["text/plain"]; + if (typeof plainText === "string") { + return plainText; + } + if (plainText instanceof Blob) { + return await plainText.text(); + } + + if (item.types && !item.types.includes("text/plain")) { + throw new Error("Clipboard item is missing text/plain data"); + } + + if (typeof item.getType !== "function") { + throw new Error("Clipboard item cannot read text/plain data"); + } + + const blob = await item.getType("text/plain"); + if (blob instanceof Blob) { + return await blob.text(); + } + + throw new Error("Clipboard item text/plain data is not a Blob"); +} + +function canDefineNavigatorClipboard( + navigator: Navigator, + descriptor: PropertyDescriptor | undefined, +): boolean { + if (descriptor) { + return descriptor.configurable === true; + } + return Object.isExtensible(navigator); +} + +/** + * Installs browser clipboard fallbacks for Streamdown copy controls by patching + * missing navigator.clipboard methods and ClipboardItem when the host permits it. + */ +export function installClipboardFallback(): void { + const navigator = globalThis.navigator; + if (!navigator) { + return; + } + + const rawClipboard = navigator.clipboard; + const clipboard = + typeof rawClipboard === "object" && rawClipboard !== null + ? (rawClipboard as Partial) + : undefined; + const clipboardDescriptor = Object.getOwnPropertyDescriptor( + navigator, + "clipboard", + ); + const hasWriteText = typeof clipboard?.writeText === "function"; + const hasWrite = typeof clipboard?.write === "function"; + const hasClipboardItem = hasUsableClipboardItem(); + + if (hasWriteText && hasWrite && hasClipboardItem) { + return; + } + + const writeText = hasWriteText + ? clipboard.writeText!.bind(clipboard) + : fallbackWriteText; + const write = hasWrite + ? clipboard.write!.bind(clipboard) + : (items: ClipboardItemLike[]) => { + const firstItem = items[0]; + if (!firstItem) { + return Promise.reject(new Error("Clipboard item not available")); + } + + return readPlainTextFromClipboardItem(firstItem).then(writeText); + }; + + const fallbackClipboard = clipboard ?? {}; + + try { + const missingMethods: PropertyDescriptorMap = {}; + if (!hasWrite) { + missingMethods.write = { + configurable: true, + value: write, + writable: true, + }; + } + if (!hasWriteText) { + missingMethods.writeText = { + configurable: true, + value: writeText, + writable: true, + }; + } + + Object.defineProperties(fallbackClipboard, missingMethods); + + if ( + !clipboard && + canDefineNavigatorClipboard(navigator, clipboardDescriptor) + ) { + Object.defineProperty(navigator, "clipboard", { + configurable: true, + value: fallbackClipboard, + }); + } + } catch { + if (!canDefineNavigatorClipboard(navigator, clipboardDescriptor)) { + // The ClipboardItem fallback below is independent from navigator.clipboard. + if (hasClipboardItem) { + return; + } + } else { + const replacement = Object.create(clipboard ?? null); + for (const methodName of ["read", "readText"] as const) { + const method = clipboard?.[methodName]; + if (typeof method === "function") { + Object.defineProperty(replacement, methodName, { + configurable: true, + value: method.bind(clipboard), + writable: true, + }); + } + } + Object.defineProperties(replacement, { + write: { + configurable: true, + value: write, + writable: true, + }, + writeText: { + configurable: true, + value: writeText, + writable: true, + }, + }); + try { + Object.defineProperty(navigator, "clipboard", { + configurable: true, + value: replacement, + }); + } catch { + // The ClipboardItem fallback below is independent from navigator.clipboard. + } + } + } + + if (!hasClipboardItem) { + class ClipboardItemFallback { + items: Record; + types: string[]; + + constructor(items: Record) { + this.items = items; + this.types = Object.keys(items); + } + + getType(type: string): Promise { + const value = this.items[type]; + if (value instanceof Blob) { + return Promise.resolve(value); + } + if (typeof value === "string") { + return Promise.resolve(new Blob([value], { type })); + } + return Promise.reject( + new Error(`Clipboard item is missing ${type} data`), + ); + } + } + + try { + Object.defineProperty(globalThis, "ClipboardItem", { + configurable: true, + value: ClipboardItemFallback, + }); + } catch { + return; + } + } +} diff --git a/frontend/tests/unit/core/clipboard.test.ts b/frontend/tests/unit/core/clipboard.test.ts index 56db47c16..2329a17dc 100644 --- a/frontend/tests/unit/core/clipboard.test.ts +++ b/frontend/tests/unit/core/clipboard.test.ts @@ -1,11 +1,18 @@ import { afterEach, expect, test, vi } from "vitest"; -import { writeTextToClipboard } from "@/core/clipboard"; +import { + installClipboardFallback, + writeTextToClipboard, +} from "@/core/clipboard"; const originalNavigator = globalThis.navigator; const hadOriginalNavigator = "navigator" in globalThis; const originalDocument = globalThis.document; const hadOriginalDocument = "document" in globalThis; +const originalClipboardItemDescriptor = Object.getOwnPropertyDescriptor( + globalThis, + "ClipboardItem", +); afterEach(() => { vi.restoreAllMocks(); @@ -26,6 +33,16 @@ afterEach(() => { value: originalDocument, }); } + + if (!originalClipboardItemDescriptor) { + Reflect.deleteProperty(globalThis, "ClipboardItem"); + } else { + Object.defineProperty( + globalThis, + "ClipboardItem", + originalClipboardItemDescriptor, + ); + } }); test("writes text with the Clipboard API when available", async () => { @@ -90,6 +107,95 @@ test("falls back to execCommand when Clipboard API is unavailable", async () => expect(textarea.remove).toHaveBeenCalled(); }); +test("falls back to parent removal when textarea.remove is unavailable", async () => { + const parentNode = { + removeChild: vi.fn(), + }; + const textarea = { + parentNode, + select: vi.fn(), + setAttribute: vi.fn(), + style: {}, + value: "", + }; + const execCommand = vi.fn().mockReturnValue(true); + + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: {}, + }); + Object.defineProperty(globalThis, "document", { + configurable: true, + value: { + body: { + appendChild: vi.fn(), + }, + createElement: vi.fn().mockReturnValue(textarea), + execCommand, + }, + }); + + await expect(writeTextToClipboard("hello")).resolves.toBe(true); + expect(parentNode.removeChild).toHaveBeenCalledWith(textarea); +}); + +test("does not fail cleanup when textarea removal APIs are unavailable", async () => { + const textarea = { + parentNode: {}, + select: vi.fn(), + setAttribute: vi.fn(), + style: {}, + value: "", + }; + + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: {}, + }); + Object.defineProperty(globalThis, "document", { + configurable: true, + value: { + body: { + appendChild: vi.fn(), + }, + createElement: vi.fn().mockReturnValue(textarea), + execCommand: vi.fn().mockReturnValue(true), + }, + }); + + await expect(writeTextToClipboard("hello")).resolves.toBe(true); +}); + +test("cleans up the textarea when selecting text fails", async () => { + const textarea = { + remove: vi.fn(), + select: vi.fn(() => { + throw new Error("selection failed"); + }), + setAttribute: vi.fn(), + style: {}, + value: "", + }; + + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: {}, + }); + Object.defineProperty(globalThis, "document", { + configurable: true, + value: { + body: { + appendChild: vi.fn(), + }, + createElement: vi.fn().mockReturnValue(textarea), + execCommand: vi.fn(), + }, + }); + + await expect(writeTextToClipboard("hello")).resolves.toBe(false); + expect(textarea.remove).toHaveBeenCalled(); +}); + test("returns false when execCommand fallback fails", async () => { const textarea = { remove: vi.fn(), @@ -118,6 +224,24 @@ test("returns false when execCommand fallback fails", async () => { expect(textarea.remove).toHaveBeenCalled(); }); +test("returns false when execCommand fallback cannot create an element", async () => { + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: {}, + }); + Object.defineProperty(globalThis, "document", { + configurable: true, + value: { + body: { + appendChild: vi.fn(), + }, + execCommand: vi.fn(), + }, + }); + + await expect(writeTextToClipboard("hello")).resolves.toBe(false); +}); + test("returns false when navigator is unavailable", async () => { Object.defineProperty(globalThis, "navigator", { configurable: true, @@ -144,3 +268,495 @@ test("returns false when Clipboard API rejects", async () => { await expect(writeTextToClipboard("hello")).resolves.toBe(false); }); + +test("installs a writeText fallback when Clipboard API is unavailable", async () => { + const textarea = { + remove: vi.fn(), + select: vi.fn(), + setAttribute: vi.fn(), + style: {}, + value: "", + }; + const appendChild = vi.fn(); + const execCommand = vi.fn().mockReturnValue(true); + + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: {}, + }); + Object.defineProperty(globalThis, "document", { + configurable: true, + value: { + body: { + appendChild, + }, + createElement: vi.fn().mockReturnValue(textarea), + execCommand, + }, + }); + + installClipboardFallback(); + + await expect(globalThis.navigator.clipboard.writeText("hello")).resolves.toBe( + undefined, + ); + expect(textarea.value).toBe("hello"); + expect(appendChild).toHaveBeenCalledWith(textarea); + expect(textarea.select).toHaveBeenCalled(); + expect(execCommand).toHaveBeenCalledWith("copy"); + expect(textarea.remove).toHaveBeenCalled(); +}); + +test("installed writeText fallback rejects instead of throwing synchronously", async () => { + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: {}, + }); + Object.defineProperty(globalThis, "document", { + configurable: true, + value: undefined, + }); + + installClipboardFallback(); + + const result = globalThis.navigator.clipboard.writeText("hello"); + expect(result).toBeInstanceOf(Promise); + await expect(result).rejects.toThrow("Clipboard DOM fallback not available"); +}); + +test("installed writeText fallback converts thrown DOM failures to rejections", async () => { + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: {}, + }); + Object.defineProperty(globalThis, "document", { + configurable: true, + value: { + body: { + appendChild: vi.fn(), + }, + createElement: vi.fn(() => { + throw new Error("dom unavailable"); + }), + execCommand: vi.fn(), + }, + }); + + installClipboardFallback(); + + const result = globalThis.navigator.clipboard.writeText("hello"); + expect(result).toBeInstanceOf(Promise); + await expect(result).rejects.toThrow("dom unavailable"); +}); + +test("installed writeText fallback distinguishes copy command failure", async () => { + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: {}, + }); + Object.defineProperty(globalThis, "document", { + configurable: true, + value: { + body: { + appendChild: vi.fn(), + }, + createElement: vi.fn().mockReturnValue({ + remove: vi.fn(), + select: vi.fn(), + setAttribute: vi.fn(), + style: {}, + value: "", + }), + execCommand: vi.fn().mockReturnValue(false), + }, + }); + + installClipboardFallback(); + + await expect( + globalThis.navigator.clipboard.writeText("hello"), + ).rejects.toThrow("Clipboard copy command failed"); +}); + +test("installs a write fallback for ClipboardItem text/plain payloads", async () => { + const textarea = { + remove: vi.fn(), + select: vi.fn(), + setAttribute: vi.fn(), + style: {}, + value: "", + }; + const execCommand = vi.fn().mockReturnValue(true); + + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: {}, + }); + Object.defineProperty(globalThis, "document", { + configurable: true, + value: { + body: { + appendChild: vi.fn(), + }, + createElement: vi.fn().mockReturnValue(textarea), + execCommand, + }, + }); + Reflect.deleteProperty(globalThis, "ClipboardItem"); + + installClipboardFallback(); + + const item = new globalThis.ClipboardItem({ + "text/html": new Blob(["
"], { type: "text/html" }), + "text/plain": "| A |\n| B |", + }); + await expect(globalThis.navigator.clipboard.write([item])).resolves.toBe( + undefined, + ); + expect(textarea.value).toBe("| A |\n| B |"); + expect(execCommand).toHaveBeenCalledWith("copy"); +}); + +test("installed write fallback rejects when ClipboardItem lacks text/plain", async () => { + const execCommand = vi.fn().mockReturnValue(true); + + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: {}, + }); + Object.defineProperty(globalThis, "document", { + configurable: true, + value: { + body: { + appendChild: vi.fn(), + }, + createElement: vi.fn().mockReturnValue({ + remove: vi.fn(), + select: vi.fn(), + setAttribute: vi.fn(), + style: {}, + value: "", + }), + execCommand, + }, + }); + Reflect.deleteProperty(globalThis, "ClipboardItem"); + + installClipboardFallback(); + + const item = new globalThis.ClipboardItem({ + "text/html": new Blob(["
"], { type: "text/html" }), + }); + await expect(globalThis.navigator.clipboard.write([item])).rejects.toThrow( + "Clipboard item is missing text/plain data", + ); + expect(execCommand).not.toHaveBeenCalled(); +}); + +test("installed write fallback rejects when getType cannot provide text/plain", async () => { + const execCommand = vi.fn().mockReturnValue(true); + + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: {}, + }); + Object.defineProperty(globalThis, "document", { + configurable: true, + value: { + body: { + appendChild: vi.fn(), + }, + createElement: vi.fn().mockReturnValue({ + remove: vi.fn(), + select: vi.fn(), + setAttribute: vi.fn(), + style: {}, + value: "", + }), + execCommand, + }, + }); + + installClipboardFallback(); + + await expect( + globalThis.navigator.clipboard.write([ + { + getType: vi.fn().mockRejectedValue(new Error("missing")), + types: ["text/plain"], + } as unknown as ClipboardItem, + ]), + ).rejects.toThrow("missing"); + expect(execCommand).not.toHaveBeenCalled(); +}); + +test("installed write fallback rejects before getType when item types exclude text/plain", async () => { + const getType = vi.fn().mockResolvedValue(new Blob(["ignored"])); + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: {}, + }); + Object.defineProperty(globalThis, "document", { + configurable: true, + value: undefined, + }); + + installClipboardFallback(); + + await expect( + globalThis.navigator.clipboard.write([ + { + getType, + types: ["text/html"], + } as unknown as ClipboardItem, + ]), + ).rejects.toThrow("Clipboard item is missing text/plain data"); + expect(getType).not.toHaveBeenCalled(); +}); + +test("installed write fallback rejects when getType is missing", async () => { + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: {}, + }); + Object.defineProperty(globalThis, "document", { + configurable: true, + value: undefined, + }); + + installClipboardFallback(); + + await expect( + globalThis.navigator.clipboard.write([ + { + types: ["text/plain"], + } as unknown as ClipboardItem, + ]), + ).rejects.toThrow("Clipboard item cannot read text/plain data"); +}); + +test("installed write fallback rejects when getType returns a non-Blob", async () => { + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: {}, + }); + Object.defineProperty(globalThis, "document", { + configurable: true, + value: undefined, + }); + + installClipboardFallback(); + + await expect( + globalThis.navigator.clipboard.write([ + { + getType: vi.fn().mockResolvedValue("plain text"), + types: ["text/plain"], + } as unknown as ClipboardItem, + ]), + ).rejects.toThrow("Clipboard item text/plain data is not a Blob"); +}); + +test("installed write fallback preserves existing clipboard prototype methods", async () => { + const readText = vi.fn().mockResolvedValue("existing"); + const clipboard = Object.create({ + readText, + }); + + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: { + clipboard, + }, + }); + Object.defineProperty(globalThis, "document", { + configurable: true, + value: undefined, + }); + Reflect.deleteProperty(globalThis, "ClipboardItem"); + + installClipboardFallback(); + + expect(globalThis.navigator.clipboard).toBe(clipboard); + await expect(globalThis.navigator.clipboard.readText()).resolves.toBe( + "existing", + ); + expect(readText).toHaveBeenCalled(); + await expect( + globalThis.navigator.clipboard.writeText("hello"), + ).rejects.toThrow("Clipboard DOM fallback not available"); +}); + +test("installClipboardFallback does not replace existing clipboard methods when only ClipboardItem is missing", async () => { + const write = vi.fn().mockResolvedValue(undefined); + const writeText = vi.fn().mockResolvedValue(undefined); + const clipboard = { + write, + writeText, + }; + + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: { + clipboard, + }, + }); + Reflect.deleteProperty(globalThis, "ClipboardItem"); + + installClipboardFallback(); + + expect(globalThis.navigator.clipboard).toBe(clipboard); + expect(Reflect.get(globalThis.navigator.clipboard, "write")).toBe(write); + expect(Reflect.get(globalThis.navigator.clipboard, "writeText")).toBe( + writeText, + ); + expect(typeof globalThis.ClipboardItem).toBe("function"); +}); + +test("installClipboardFallback is idempotent for the same navigator", async () => { + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: {}, + }); + Object.defineProperty(globalThis, "document", { + configurable: true, + value: undefined, + }); + Reflect.deleteProperty(globalThis, "ClipboardItem"); + + installClipboardFallback(); + const clipboard = globalThis.navigator.clipboard; + const ClipboardItemFallback = globalThis.ClipboardItem; + + installClipboardFallback(); + + expect(globalThis.navigator.clipboard).toBe(clipboard); + expect(globalThis.ClipboardItem).toBe(ClipboardItemFallback); +}); + +test("installClipboardFallback can recover when the same navigator loses fallback globals", async () => { + const navigator = {}; + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: navigator, + }); + Object.defineProperty(globalThis, "document", { + configurable: true, + value: undefined, + }); + Reflect.deleteProperty(globalThis, "ClipboardItem"); + + installClipboardFallback(); + Reflect.deleteProperty(globalThis, "ClipboardItem"); + Reflect.deleteProperty(navigator, "clipboard"); + + installClipboardFallback(); + + expect(typeof globalThis.navigator.clipboard.writeText).toBe("function"); + expect(typeof globalThis.ClipboardItem).toBe("function"); +}); + +test("installClipboardFallback defines writable fallback methods", async () => { + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: {}, + }); + Object.defineProperty(globalThis, "document", { + configurable: true, + value: undefined, + }); + + installClipboardFallback(); + + expect( + Object.getOwnPropertyDescriptor(globalThis.navigator.clipboard, "write") + ?.writable, + ).toBe(true); + expect( + Object.getOwnPropertyDescriptor(globalThis.navigator.clipboard, "writeText") + ?.writable, + ).toBe(true); +}); + +test("installClipboardFallback skips missing clipboard on non-extensible navigator while installing ClipboardItem", async () => { + const navigator = {}; + Object.preventExtensions(navigator); + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: navigator, + }); + Reflect.deleteProperty(globalThis, "ClipboardItem"); + + installClipboardFallback(); + + expect("clipboard" in globalThis.navigator).toBe(false); + expect(typeof globalThis.ClipboardItem).toBe("function"); +}); + +test("installClipboardFallback handles non-object navigator.clipboard values", async () => { + const navigator = {}; + Object.defineProperty(navigator, "clipboard", { + configurable: true, + value: "locked", + }); + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: navigator, + }); + Object.defineProperty(globalThis, "document", { + configurable: true, + value: undefined, + }); + + installClipboardFallback(); + + expect(typeof globalThis.navigator.clipboard.writeText).toBe("function"); + await expect( + globalThis.navigator.clipboard.writeText("hello"), + ).rejects.toThrow("Clipboard DOM fallback not available"); +}); + +test("installClipboardFallback does not throw when ClipboardItem cannot be defined", async () => { + const originalDefineProperty = Object.defineProperty; + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: {}, + }); + Object.defineProperty(globalThis, "document", { + configurable: true, + value: undefined, + }); + Reflect.deleteProperty(globalThis, "ClipboardItem"); + vi.spyOn(Object, "defineProperty").mockImplementation( + (target, property, descriptor) => { + if (target === globalThis && property === "ClipboardItem") { + throw new Error("locked global"); + } + return originalDefineProperty(target, property, descriptor); + }, + ); + + expect(() => installClipboardFallback()).not.toThrow(); + expect(typeof globalThis.navigator.clipboard.writeText).toBe("function"); + expect("ClipboardItem" in globalThis).toBe(false); +}); + +test("installs ClipboardItem fallback when the global property exists but is unusable", async () => { + Object.defineProperty(globalThis, "navigator", { + configurable: true, + value: { + clipboard: { + write: vi.fn().mockResolvedValue(undefined), + writeText: vi.fn().mockResolvedValue(undefined), + }, + }, + }); + Object.defineProperty(globalThis, "ClipboardItem", { + configurable: true, + value: undefined, + }); + + installClipboardFallback(); + + expect(typeof globalThis.ClipboardItem).toBe("function"); +}); From b62c5a7b5bb1034d8fcfe16cf8314ee650efb417 Mon Sep 17 00:00:00 2001 From: ly-wang19 <94427531+ly-wang19@users.noreply.github.com> Date: Tue, 9 Jun 2026 22:24:53 +0800 Subject: [PATCH 06/11] fix(agents): offload blocking filesystem IO in the custom-agent router off the event loop (#3457) * fix(agents): offload blocking filesystem IO in delete_agent off the event loop delete_agent is an async route handler but resolved the agent directory (Paths.base_dir -> Path.resolve), probed it (Path.exists), and removed it (shutil.rmtree) directly on the event loop, blocking it for the duration of every delete. Surfaced by 'make detect-blocking-io'. Move the resolve/exists/rmtree sequence into a sync helper run via asyncio.to_thread, mapping its outcome back to the existing 404/409/500 responses (behavior unchanged). Adds a tests/blocking_io/ regression anchor under the strict Blockbuster gate, mirroring test_skills_load.py (#1917). Co-Authored-By: Claude Opus 4.8 (1M context) * fix(agents): offload blocking filesystem IO in create_agent_endpoint too Like delete_agent, the async create_agent_endpoint resolved and created the agent directory and wrote config.yaml + SOUL.md (with rmtree cleanup on failure) directly on the event loop. Move the whole create-or-409 sequence into a sync helper run via asyncio.to_thread; behavior is unchanged (201 / 409 / 500). Extends the blocking_io regression anchor to cover create as well as delete and renames it to test_agents_router.py. Co-Authored-By: Claude Opus 4.8 (1M context) * Apply suggestions from code review Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: ly-wang19 Co-authored-by: Claude Opus 4.8 (1M context) Co-authored-by: Willem Jiang Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- backend/app/gateway/routers/agents.py | 115 +++++++++++------- .../tests/blocking_io/test_agents_router.py | 64 ++++++++++ 2 files changed, 134 insertions(+), 45 deletions(-) create mode 100644 backend/tests/blocking_io/test_agents_router.py diff --git a/backend/app/gateway/routers/agents.py b/backend/app/gateway/routers/agents.py index 8769e9834..933dd4211 100644 --- a/backend/app/gateway/routers/agents.py +++ b/backend/app/gateway/routers/agents.py @@ -1,5 +1,6 @@ """CRUD API for custom agents.""" +import asyncio import logging import re import shutil @@ -213,48 +214,61 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse: user_id = get_effective_user_id() paths = get_paths() - agent_dir = paths.user_agent_dir(user_id, normalized_name) - legacy_dir = paths.agent_dir(normalized_name) + def _create_agent() -> AgentResponse | None: + # Worker thread: base-dir resolution, existence checks, directory/file + # creation, read-back, and failure cleanup are all blocking filesystem + # IO that must stay off the event loop. + agent_dir = paths.user_agent_dir(user_id, normalized_name) + legacy_dir = paths.agent_dir(normalized_name) - if agent_dir.exists() or legacy_dir.exists(): - raise HTTPException(status_code=409, detail=f"Agent '{normalized_name}' already exists") + if legacy_dir.exists(): + return None # signals 409 to the caller + + try: + try: + agent_dir.mkdir(parents=True, exist_ok=False) + except FileExistsError: + return None # signals 409 to the caller + # Write config.yaml + config_data: dict = {"name": normalized_name} + if request.description: + config_data["description"] = request.description + if request.model is not None: + config_data["model"] = request.model + if request.tool_groups is not None: + config_data["tool_groups"] = request.tool_groups + if request.skills is not None: + config_data["skills"] = request.skills + + config_file = agent_dir / "config.yaml" + with open(config_file, "w", encoding="utf-8") as f: + yaml.dump(config_data, f, default_flow_style=False, allow_unicode=True) + + # Write SOUL.md + soul_file = agent_dir / "SOUL.md" + soul_file.write_text(request.soul, encoding="utf-8") + + logger.info(f"Created agent '{normalized_name}' at {agent_dir}") + + agent_cfg = load_agent_config(normalized_name, user_id=user_id) + return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id) + except Exception: + # Clean up partial state on failure before surfacing the error. + if agent_dir.exists(): + shutil.rmtree(agent_dir) + raise try: - agent_dir.mkdir(parents=True, exist_ok=True) - - # Write config.yaml - config_data: dict = {"name": normalized_name} - if request.description: - config_data["description"] = request.description - if request.model is not None: - config_data["model"] = request.model - if request.tool_groups is not None: - config_data["tool_groups"] = request.tool_groups - if request.skills is not None: - config_data["skills"] = request.skills - - config_file = agent_dir / "config.yaml" - with open(config_file, "w", encoding="utf-8") as f: - yaml.dump(config_data, f, default_flow_style=False, allow_unicode=True) - - # Write SOUL.md - soul_file = agent_dir / "SOUL.md" - soul_file.write_text(request.soul, encoding="utf-8") - - logger.info(f"Created agent '{normalized_name}' at {agent_dir}") - - agent_cfg = load_agent_config(normalized_name, user_id=user_id) - return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id) - - except HTTPException: - raise + response = await asyncio.to_thread(_create_agent) except Exception as e: - # Clean up on failure - if agent_dir.exists(): - shutil.rmtree(agent_dir) logger.error(f"Failed to create agent '{request.name}': {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Failed to create agent: {str(e)}") + if response is None: + raise HTTPException(status_code=409, detail=f"Agent '{normalized_name}' already exists") + + return response + @router.put( "/agents/{name}", @@ -428,19 +442,30 @@ async def delete_agent(name: str) -> None: name = _normalize_agent_name(name) user_id = get_effective_user_id() paths = get_paths() - agent_dir = paths.user_agent_dir(user_id, name) - if not agent_dir.exists(): - if paths.agent_dir(name).exists(): - raise HTTPException( - status_code=409, - detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before deleting."), - ) - raise HTTPException(status_code=404, detail=f"Agent '{name}' not found") + def _remove_agent_dir() -> tuple[str, str]: + # Runs in a worker thread: resolving the base dir, probing the directory + # (`exists`), and removing it (`rmtree`) are all blocking filesystem IO + # that must stay off the event loop. + agent_dir = paths.user_agent_dir(user_id, name) + if not agent_dir.exists(): + outcome = "legacy" if paths.agent_dir(name).exists() else "missing" + return outcome, str(agent_dir) + shutil.rmtree(agent_dir) + return "deleted", str(agent_dir) try: - shutil.rmtree(agent_dir) - logger.info(f"Deleted agent '{name}' from {agent_dir}") + outcome, agent_dir = await asyncio.to_thread(_remove_agent_dir) except Exception as e: logger.error(f"Failed to delete agent '{name}': {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Failed to delete agent: {str(e)}") + + if outcome == "legacy": + raise HTTPException( + status_code=409, + detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before deleting."), + ) + if outcome == "missing": + raise HTTPException(status_code=404, detail=f"Agent '{name}' not found") + + logger.info(f"Deleted agent '{name}' from {agent_dir}") diff --git a/backend/tests/blocking_io/test_agents_router.py b/backend/tests/blocking_io/test_agents_router.py new file mode 100644 index 000000000..1f7787397 --- /dev/null +++ b/backend/tests/blocking_io/test_agents_router.py @@ -0,0 +1,64 @@ +"""Regression anchors: the custom-agent router must not block the event loop. + +``app.gateway.routers.agents.create_agent_endpoint`` and ``delete_agent`` are +async route handlers that resolve the agent directory (``Paths.base_dir`` calls +``Path.resolve``), probe it (``Path.exists``), and create/remove it (``mkdir``, +config/SOUL writes, ``shutil.rmtree``) — all blocking IO. Both offload that work +via ``asyncio.to_thread``; if any of it regresses back onto the event loop, the +strict Blockbuster gate raises ``BlockingError`` and these tests fail. + +Imports live at module scope so the one-time FastAPI app construction (which +reads files while building OpenAPI schemas) happens at collection time, not on +the event loop under test. Test-side path resolution is itself offloaded with +``asyncio.to_thread`` (matching ``test_uploads_middleware``) so only the +handlers' own filesystem access is exercised on the loop. +""" + +from __future__ import annotations + +import asyncio +from pathlib import Path + +import pytest + +from app.gateway.routers.agents import AgentCreateRequest, create_agent_endpoint, delete_agent +from deerflow.config.agents_api_config import load_agents_api_config_from_dict +from deerflow.config.paths import get_paths +from deerflow.runtime.user_context import get_effective_user_id + +pytestmark = pytest.mark.asyncio + + +async def test_create_agent_does_not_block_event_loop(tmp_path: Path, monkeypatch) -> None: + monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path)) + monkeypatch.setattr("deerflow.config.paths._paths", None) + load_agents_api_config_from_dict({"enabled": True}) + try: + response = await create_agent_endpoint(AgentCreateRequest(name="loop-make-agent", soul="You are a test agent.")) + assert response is not None + + user_id = get_effective_user_id() + # test-side check (resolution offloaded; not exercised on the loop) + agent_dir = await asyncio.to_thread(get_paths().user_agent_dir, user_id, "loop-make-agent") + assert await asyncio.to_thread((agent_dir / "config.yaml").exists) + finally: + load_agents_api_config_from_dict({}) + + +async def test_delete_agent_does_not_block_event_loop(tmp_path: Path, monkeypatch) -> None: + monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path)) + monkeypatch.setattr("deerflow.config.paths._paths", None) + load_agents_api_config_from_dict({"enabled": True}) + try: + user_id = get_effective_user_id() + user_id = get_effective_user_id() + # test-side seeding (resolution offloaded; not exercised on the loop) + agent_dir = await asyncio.to_thread(get_paths().user_agent_dir, user_id, "loop-test-agent") + await asyncio.to_thread(agent_dir.mkdir, parents=True, exist_ok=True) + await asyncio.to_thread((agent_dir / "config.yaml").write_text, "name: loop-test-agent\n", encoding="utf-8") + + await delete_agent("loop-test-agent") + + assert not await asyncio.to_thread(agent_dir.exists) + finally: + load_agents_api_config_from_dict({}) From 18bbb82f07a6893414fde8ac0ab22c837276ab19 Mon Sep 17 00:00:00 2001 From: tanghang97 <41568151+tanghang97@users.noreply.github.com> Date: Tue, 9 Jun 2026 22:37:54 +0800 Subject: [PATCH 07/11] Fix 'make dev' failure in Windows environment (#3236) * fix: Solving the problem of "make dev" failing to start in Windows environment * fix: revert the change to the startup_config and fix the lint errors * fix: Address Copilot review feedback - Validate wait-for-port input and avoid PowerShell port interpolation - Require Python 3 in serve.sh launcher detection - Keep Windows event loop policy setup in sitecustomize only - Clarify sitecustomize process-wide backend behavior --- backend/sitecustomize.py | 26 ++++++++++++++++++++++++++ scripts/serve.sh | 30 ++++++++++++++++-------------- scripts/wait-for-port.sh | 18 ++++++++++++++++++ 3 files changed, 60 insertions(+), 14 deletions(-) create mode 100644 backend/sitecustomize.py diff --git a/backend/sitecustomize.py b/backend/sitecustomize.py new file mode 100644 index 000000000..4b85819e1 --- /dev/null +++ b/backend/sitecustomize.py @@ -0,0 +1,26 @@ +"""Process-wide Python startup customizations for backend entrypoints. + +When ``backend/`` is on ``sys.path``, Python imports this module during +interpreter startup. Keep changes here suitable for all gateway, script, +migration, and test entrypoints that run in that environment. +""" + +from __future__ import annotations + +import asyncio +import sys + + +def _configure_windows_event_loop_policy() -> None: + if sys.platform != "win32": + return + + selector_policy = getattr(asyncio, "WindowsSelectorEventLoopPolicy", None) + if selector_policy is None: + return + + if not isinstance(asyncio.get_event_loop_policy(), selector_policy): + asyncio.set_event_loop_policy(selector_policy()) + + +_configure_windows_event_loop_policy() diff --git a/scripts/serve.sh b/scripts/serve.sh index bae85f22c..aac39bfa4 100755 --- a/scripts/serve.sh +++ b/scripts/serve.sh @@ -37,6 +37,17 @@ if [ -f "$REPO_ROOT/.env" ]; then set +a fi +_pick_python() { + local candidate + for candidate in python3 python py; do + if command -v "$candidate" >/dev/null 2>&1 && "$candidate" -c 'import sys; raise SystemExit(0 if sys.version_info.major >= 3 else 1)' >/dev/null 2>&1; then + printf '%s\n' "$candidate" + return 0 + fi + done + return 1 +} + # ── Argument parsing ───────────────────────────────────────────────────────── DEV_MODE=true @@ -274,11 +285,7 @@ fi if $DEV_MODE; then FRONTEND_CMD="pnpm run dev" else - if command -v python3 >/dev/null 2>&1; then - PYTHON_BIN="python3" - elif command -v python >/dev/null 2>&1; then - PYTHON_BIN="python" - else + if ! PYTHON_BIN="$(_pick_python)"; then echo "Python is required to generate BETTER_AUTH_SECRET." exit 1 fi @@ -337,15 +344,10 @@ fi # ── Install dependencies ──────────────────────────────────────────────────── -# Pick a Python for the extras detector. Falls back to plain `python` for -# Windows/Git Bash where only `python` is on PATH. -if command -v python3 >/dev/null 2>&1; then - DETECT_PYTHON="python3" -elif command -v python >/dev/null 2>&1; then - DETECT_PYTHON="python" -else - DETECT_PYTHON="" -fi +# Pick a runnable Python for the extras detector. On Windows/Git Bash, +# `python3` can resolve to the Microsoft Store alias in WindowsApps, which is +# present on PATH but not executable from Bash. +DETECT_PYTHON="$(_pick_python || true)" # Resolve uv extras (postgres, etc.) from UV_EXTRAS or config.yaml so that # `uv sync` does not wipe out optional dependencies on every restart. See diff --git a/scripts/wait-for-port.sh b/scripts/wait-for-port.sh index dc0dffa1d..ef2522e63 100755 --- a/scripts/wait-for-port.sh +++ b/scripts/wait-for-port.sh @@ -17,10 +17,28 @@ PORT="${1:?Usage: wait-for-port.sh [timeout] [service_name]}" TIMEOUT="${2:-60}" SERVICE="${3:-Service}" +case "$PORT" in + ''|*[!0-9]*) + echo "Port must be a numeric TCP port: $PORT" >&2 + exit 1 + ;; +esac + +if [ "$PORT" -lt 1 ] || [ "$PORT" -gt 65535 ]; then + echo "Port must be between 1 and 65535: $PORT" >&2 + exit 1 +fi + elapsed=0 interval=1 is_port_listening() { + if command -v powershell.exe >/dev/null 2>&1; then + if WAIT_FOR_PORT_PORT="$PORT" powershell.exe -NoProfile -ExecutionPolicy Bypass -Command "\$ErrorActionPreference='SilentlyContinue'; \$Port = [int]\$env:WAIT_FOR_PORT_PORT; if (Get-NetTCPConnection -LocalPort \$Port -State Listen) { exit 0 } else { exit 1 }" >/dev/null 2>&1; then + return 0 + fi + fi + if command -v lsof >/dev/null 2>&1; then if lsof -nP -iTCP:"$PORT" -sTCP:LISTEN -t >/dev/null 2>&1; then return 0 From 16391e35ab7106e3e1f706139dd043dadf62ca93 Mon Sep 17 00:00:00 2001 From: DanielWalnut <45447813+hetaoBackend@users.noreply.github.com> Date: Tue, 9 Jun 2026 23:07:17 +0800 Subject: [PATCH 08/11] fix(skills): harden slash skill activation across chat channels (#3466) * support slash skill activation * format slash skill activation * Preserve slash skill activation with uploads * Address slash skill review feedback * Address slash skill follow-up review * Fix lazy slash skill storage resolution * Keep slash skill activation out of system prompt * Address slash skill review issues * fix: harden slash skill command handling * feat(frontend): add slash skill autocomplete * fix: address slash skill review feedback * fix: preserve slash skill text for IM uploads --- README.md | 2 + backend/CLAUDE.md | 22 +- backend/app/channels/commands.py | 7 + backend/app/channels/dingtalk.py | 6 +- backend/app/channels/discord.py | 5 +- backend/app/channels/feishu.py | 6 +- backend/app/channels/manager.py | 144 ++- backend/app/channels/slack.py | 38 +- backend/app/channels/telegram.py | 36 +- backend/app/channels/wechat.py | 3 +- backend/app/channels/wecom.py | 3 +- .../deerflow/agents/lead_agent/agent.py | 35 +- .../deerflow/agents/lead_agent/prompt.py | 5 + .../skill_activation_middleware.py | 289 ++++++ .../agents/middlewares/uploads_middleware.py | 5 +- backend/packages/harness/deerflow/client.py | 10 +- .../packages/harness/deerflow/skills/slash.py | 65 ++ .../harness/deerflow/utils/messages.py | 31 + backend/tests/test_channels.py | 909 ++++++++++++++++++ backend/tests/test_discord_channel.py | 67 +- backend/tests/test_lead_agent_skills.py | 11 + backend/tests/test_slash_skills.py | 557 +++++++++++ .../test_uploads_middleware_core_logic.py | 48 + frontend/eslint.config.js | 2 + .../components/ai-elements/prompt-input.tsx | 5 + .../src/components/workspace/input-box.tsx | 198 +++- frontend/src/core/messages/utils.ts | 15 +- frontend/tests/e2e/chat.spec.ts | 199 ++++ frontend/tests/e2e/utils/mock-api.ts | 43 + .../tests/unit/core/messages/utils.test.ts | 33 + .../tests/unit/core/threads/export.test.ts | 16 + 31 files changed, 2758 insertions(+), 57 deletions(-) create mode 100644 backend/packages/harness/deerflow/agents/middlewares/skill_activation_middleware.py create mode 100644 backend/packages/harness/deerflow/skills/slash.py create mode 100644 backend/packages/harness/deerflow/utils/messages.py create mode 100644 backend/tests/test_slash_skills.py diff --git a/README.md b/README.md index a093b6f10..c06d5a9d0 100644 --- a/README.md +++ b/README.md @@ -585,6 +585,8 @@ A standard Agent Skill is a structured capability module — a Markdown file tha Skills are loaded progressively — only when the task needs them, not all at once. This keeps the context window lean and makes DeerFlow work well even with token-sensitive models. +Users can explicitly activate an enabled skill for a single turn by starting the request with `/skill-name`, for example `/data-analysis analyze uploads/foo.csv`. DeerFlow loads that skill's `SKILL.md` as hidden current-turn context while leaving the base prompt limited to skill metadata. Slash activation respects disabled skills, custom-agent skill whitelists, and existing channel commands such as `/new` and `/help`. + When you install `.skill` archives through the Gateway, DeerFlow accepts standard optional frontmatter metadata such as `version`, `author`, and `compatibility` instead of rejecting otherwise valid external skills. Tools follow the same philosophy. DeerFlow comes with a core toolset — web search, web fetch, file operations, bash execution — and supports custom tools via MCP servers and Python functions. Swap anything. Add anything. diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index 29a776217..a0e256e19 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -202,16 +202,17 @@ Lead-agent middlewares are assembled in strict append order across `packages/har 6. **GuardrailMiddleware** - Pre-tool-call authorization via pluggable `GuardrailProvider` protocol (optional, if `guardrails.enabled` in config). Evaluates each tool call and returns error ToolMessage on deny. Three provider options: built-in `AllowlistProvider` (zero deps), OAP policy providers (e.g. `aport-agent-guardrails`), or custom providers. See [docs/GUARDRAILS.md](docs/GUARDRAILS.md) for setup, usage, and how to implement a provider. 7. **SandboxAuditMiddleware** - Audits sandboxed shell/file operations for security logging before tool execution continues 8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting -9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled) -10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode) -11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled and merged back into the dispatching AIMessage by message position rather than message id -12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model -13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses) -14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support) -15. **DeferredToolFilterMiddleware** - Hides deferred (MCP) tool schemas from the bound model using a build-time deferred-name set + catalog hash, reading per-thread promotions from `ThreadState.promoted` (hash-scoped, no ContextVar); a tool becomes bound on subsequent turns after `tool_search` returns its schema (optional, if `tool_search.enabled`) -16. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if `subagent_enabled`) -17. **LoopDetectionMiddleware** - Detects repeated tool-call loops; hard-stop responses clear both structured `tool_calls` and raw provider tool-call metadata before forcing a final text answer -18. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last) +9. **SkillActivationMiddleware** - Detects strict `/skill-name task` syntax on the latest real user message, resolves only enabled and runtime-allowed skills, reads `SKILL.md` from trusted skill storage, injects the skill body as hidden current-turn model context, and records a `middleware:skill_activation` audit event with skill name, category, path, and content hash +10. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled) +11. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode) +12. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled and merged back into the dispatching AIMessage by message position rather than message id +13. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model +14. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses) +15. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support) +16. **DeferredToolFilterMiddleware** - Hides deferred (MCP) tool schemas from the bound model using a build-time deferred-name set + catalog hash, reading per-thread promotions from `ThreadState.promoted` (hash-scoped, no ContextVar); a tool becomes bound on subsequent turns after `tool_search` returns its schema (optional, if `tool_search.enabled`) +17. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if `subagent_enabled`) +18. **LoopDetectionMiddleware** - Detects repeated tool-call loops; hard-stop responses clear both structured `tool_calls` and raw provider tool-call metadata before forcing a final text answer +19. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last) ### Configuration System @@ -348,6 +349,7 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti - **Format**: Directory with `SKILL.md` (YAML frontmatter: name, description, license, allowed-tools) - **Loading**: `load_skills()` recursively scans `skills/{public,custom}` for `SKILL.md`, parses metadata, and reads enabled state from extensions_config.json - **Injection**: Enabled skills listed in agent system prompt with container paths +- **Slash activation**: `/skill-name task` loads that enabled skill's `SKILL.md` for the current model call only. The resolver rejects leading whitespace, missing separators, reserved channel commands (`/new`, `/help`, `/bootstrap`, `/status`, `/models`, `/memory`), disabled skills, and skills outside a custom agent's whitelist. - **Installation**: `POST /api/skills/install` extracts .skill ZIP archive to custom/ directory ### Model Factory (`packages/harness/deerflow/models/factory.py`) diff --git a/backend/app/channels/commands.py b/backend/app/channels/commands.py index 704330410..c783899c5 100644 --- a/backend/app/channels/commands.py +++ b/backend/app/channels/commands.py @@ -18,3 +18,10 @@ KNOWN_CHANNEL_COMMANDS: frozenset[str] = frozenset( "/help", } ) + + +def is_known_channel_command(text: str) -> bool: + """Return whether text starts with a registered channel control command.""" + if not text.startswith("/"): + return False + return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS diff --git a/backend/app/channels/dingtalk.py b/backend/app/channels/dingtalk.py index f2833d4ff..fb53ce272 100644 --- a/backend/app/channels/dingtalk.py +++ b/backend/app/channels/dingtalk.py @@ -14,7 +14,7 @@ from typing import Any import httpx from app.channels.base import Channel -from app.channels.commands import KNOWN_CHANNEL_COMMANDS +from app.channels.commands import is_known_channel_command from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment logger = logging.getLogger(__name__) @@ -59,9 +59,7 @@ def _normalize_allowed_users(allowed_users: Any) -> set[str]: def _is_dingtalk_command(text: str) -> bool: - if not text.startswith("/"): - return False - return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS + return is_known_channel_command(text) def _extract_text_from_rich_text(rich_text_list: list) -> str: diff --git a/backend/app/channels/discord.py b/backend/app/channels/discord.py index 3b113c28d..c88eb0239 100644 --- a/backend/app/channels/discord.py +++ b/backend/app/channels/discord.py @@ -10,6 +10,7 @@ from pathlib import Path from typing import Any from app.channels.base import Channel +from app.channels.commands import is_known_channel_command from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment logger = logging.getLogger(__name__) @@ -300,7 +301,7 @@ class DiscordChannel(Channel): # If this is a known active thread, process normally if thread_id in self._active_thread_ids: - msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT + msg_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT inbound = self._make_inbound( chat_id=chat_id, user_id=str(message.author.id), @@ -407,7 +408,7 @@ class DiscordChannel(Channel): chat_id = channel_id typing_target = message.channel # Type into the channel - msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT + msg_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT inbound = self._make_inbound( chat_id=chat_id, user_id=str(message.author.id), diff --git a/backend/app/channels/feishu.py b/backend/app/channels/feishu.py index eb6fb72ca..fddbc7186 100644 --- a/backend/app/channels/feishu.py +++ b/backend/app/channels/feishu.py @@ -11,7 +11,7 @@ import time from typing import Any, Literal from app.channels.base import Channel -from app.channels.commands import KNOWN_CHANNEL_COMMANDS +from app.channels.commands import is_known_channel_command from app.channels.message_bus import ( PENDING_CLARIFICATION_METADATA_KEY, RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY, @@ -30,9 +30,7 @@ PENDING_CLARIFICATION_TTL_SECONDS = 30 * 60 def _is_feishu_command(text: str) -> bool: - if not text.startswith("/"): - return False - return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS + return is_known_channel_command(text) class FeishuChannel(Channel): diff --git a/backend/app/channels/manager.py b/backend/app/channels/manager.py index 9beceeb3a..673723d6e 100644 --- a/backend/app/channels/manager.py +++ b/backend/app/channels/manager.py @@ -8,6 +8,7 @@ import mimetypes import re import time from collections.abc import Awaitable, Callable, Mapping +from dataclasses import dataclass from pathlib import Path from typing import Any @@ -26,8 +27,13 @@ from app.channels.message_bus import ( from app.channels.store import ChannelStore from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token from app.gateway.internal_auth import create_internal_auth_headers +from deerflow.config.agents_config import load_agent_config from deerflow.config.paths import make_safe_user_id from deerflow.runtime.user_context import get_effective_user_id +from deerflow.skills.slash import parse_slash_skill_reference +from deerflow.skills.storage import get_or_new_skill_storage +from deerflow.skills.storage.skill_storage import SkillStorage +from deerflow.utils.messages import ORIGINAL_USER_CONTENT_KEY logger = logging.getLogger(__name__) @@ -124,6 +130,16 @@ class InvalidChannelSessionConfigError(ValueError): """Raised when IM channel session overrides contain invalid agent config.""" +class SlashSkillCommandResolutionError(RuntimeError): + """Raised when IM slash-skill command resolution cannot complete safely.""" + + +@dataclass(frozen=True, slots=True) +class _SlashSkillCommandResolution: + route_to_chat: bool = False + failure_message: str | None = None + + def _is_thread_busy_error(exc: BaseException | None) -> bool: if exc is None: return False @@ -410,6 +426,46 @@ def _format_artifact_text(artifacts: list[str]) -> str: _OUTPUTS_VIRTUAL_PREFIX = "/mnt/user-data/outputs/" +def _unknown_command_reply(command: str | None = None) -> str: + available = " | ".join(sorted(KNOWN_CHANNEL_COMMANDS)) + if command: + return f"Unknown command: /{command}. Available commands: {available}" + return f"Unknown command. Available commands: {available}" + + +def _human_input_message(content: str, *, original_content: str | None = None) -> dict[str, Any]: + message: dict[str, Any] = {"role": "human", "content": content} + if original_content is not None and original_content != content: + message["additional_kwargs"] = {ORIGINAL_USER_CONTENT_KEY: original_content} + return message + + +def _resolve_slash_skill_command( + text: str, + available_skills: set[str] | None = None, + storage: SkillStorage | Callable[[], SkillStorage] | None = None, +) -> _SlashSkillCommandResolution | None: + reference = parse_slash_skill_reference(text) + if reference is None: + return None + try: + resolved_storage = storage() if callable(storage) else storage or get_or_new_skill_storage() + skills = resolved_storage.load_skills(enabled_only=False) + + skill = next((candidate for candidate in skills if candidate.name == reference.name), None) + if skill is None: + return None + if not skill.enabled: + return _SlashSkillCommandResolution(failure_message=f"Skill `/{reference.name}` is installed but disabled. Enable it before using slash activation.") + if available_skills is not None and reference.name not in available_skills: + return _SlashSkillCommandResolution(failure_message=f"Skill `/{reference.name}` is not available for this agent.") + + return _SlashSkillCommandResolution(route_to_chat=True) + except Exception as exc: + logger.exception("[Manager] failed to resolve slash skill command") + raise SlashSkillCommandResolutionError("Failed to resolve slash skill command. Please check the skill configuration.") from exc + + def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedAttachment]: """Resolve virtual artifact paths to host filesystem paths with metadata. @@ -624,6 +680,7 @@ class ChannelManager: self._default_session = _as_dict(default_session) self._channel_sessions = dict(channel_sessions or {}) self._client = None # lazy init — langgraph_sdk async client + self._skill_storage: SkillStorage | None = None self._csrf_token = generate_csrf_token() self._semaphore: asyncio.Semaphore | None = None self._running = False @@ -696,6 +753,21 @@ class ChannelManager: return assistant_id, run_config, run_context + def _resolve_available_skill_names(self, msg: InboundMessage) -> set[str] | None: + thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "" + _, _, run_context = self._resolve_run_params(msg, thread_id) + if run_context.get("is_bootstrap"): + return {"bootstrap"} + + agent_name = run_context.get("agent_name") + if not isinstance(agent_name, str) or not agent_name.strip(): + return None + + agent_config = load_agent_config(_normalize_custom_agent_name(agent_name)) + if agent_config and agent_config.skills is not None: + return set(agent_config.skills) + return None + # -- LangGraph SDK client (lazy) ---------------------------------------- def _get_client(self): @@ -713,6 +785,11 @@ class ChannelManager: ) return self._client + def _get_skill_storage(self) -> SkillStorage: + if self._skill_storage is None: + self._skill_storage = get_or_new_skill_storage() + return self._skill_storage + # -- lifecycle --------------------------------------------------------- async def start(self) -> None: @@ -782,6 +859,14 @@ class ChannelManager: exc, ) await self._send_error(msg, str(exc)) + except SlashSkillCommandResolutionError as exc: + logger.warning( + "Slash skill command resolution failed for %s (chat=%s): %s", + msg.channel_name, + msg.chat_id, + exc, + ) + await self._send_error(msg, str(exc)) except Exception: logger.exception( "Error handling message from %s (chat=%s)", @@ -836,9 +921,11 @@ class ChannelManager: if extra_context: run_context.update(extra_context) + original_text = msg.text uploaded = await _ingest_inbound_files(thread_id, msg) if uploaded: msg.text = f"{_format_uploaded_files_block(uploaded)}\n\n{msg.text}".strip() + human_message = _human_input_message(msg.text, original_content=original_text) if self._channel_supports_streaming(msg.channel_name): await self._handle_streaming_chat( @@ -848,6 +935,7 @@ class ChannelManager: assistant_id, run_config, run_context, + human_message, ) return @@ -856,7 +944,7 @@ class ChannelManager: result = await client.runs.wait( thread_id, assistant_id, - input={"messages": [{"role": "human", "content": msg.text}]}, + input={"messages": [human_message]}, config=run_config, context=run_context, multitask_strategy="reject", @@ -909,6 +997,7 @@ class ChannelManager: assistant_id: str, run_config: dict[str, Any], run_context: dict[str, Any], + human_message: dict[str, Any], ) -> None: logger.info("[Manager] invoking runs.stream(thread_id=%s, text=%r)", thread_id, msg.text[:100]) @@ -924,7 +1013,7 @@ class ChannelManager: async for chunk in client.runs.stream( thread_id, assistant_id, - input={"messages": [{"role": "human", "content": msg.text}]}, + input={"messages": [human_message]}, config=run_config, context=run_context, stream_mode=["messages-tuple", "values"], @@ -1011,11 +1100,20 @@ class ChannelManager: # -- command handling -------------------------------------------------- async def _handle_command(self, msg: InboundMessage) -> None: - text = msg.text.strip() + raw_text = msg.text + text = raw_text.strip() parts = text.split(maxsplit=1) - command = parts[0].lower().lstrip("/") + reply: str | None = None + if not parts: + command = None + reply = _unknown_command_reply() + else: + command = parts[0].lower().removeprefix("/") - if command == "bootstrap": + if reply is None and not raw_text.startswith("/"): + reply = _unknown_command_reply(command) + + if reply is None and command == "bootstrap": from dataclasses import replace as _dc_replace chat_text = parts[1] if len(parts) > 1 else "Initialize workspace" @@ -1023,7 +1121,7 @@ class ChannelManager: await self._handle_chat(chat_msg, extra_context={"is_bootstrap": True}) return - if command == "new": + if reply is None and command == "new": # Create a new thread through Gateway client = self._get_client() thread = await client.threads.create() @@ -1036,14 +1134,14 @@ class ChannelManager: user_id=msg.user_id, ) reply = "New conversation started." - elif command == "status": + elif reply is None and command == "status": thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) reply = f"Active thread: {thread_id}" if thread_id else "No active conversation." - elif command == "models": + elif reply is None and command == "models": reply = await self._fetch_gateway("/api/models", "models") - elif command == "memory": + elif reply is None and command == "memory": reply = await self._fetch_gateway("/api/memory", "memory") - elif command == "help": + elif reply is None and command == "help": reply = ( "Available commands:\n" "/bootstrap — Start a bootstrap session (enables agent setup)\n" @@ -1051,16 +1149,32 @@ class ChannelManager: "/status — Show current thread info\n" "/models — List available models\n" "/memory — Show memory status\n" + "/ — Activate an enabled skill for one turn\n" "/help — Show this help" ) - else: - available = " | ".join(sorted(KNOWN_CHANNEL_COMMANDS)) - reply = f"Unknown command: /{command}. Available commands: {available}" + elif reply is None: + slash_resolution = await asyncio.to_thread( + lambda: _resolve_slash_skill_command( + raw_text, + self._resolve_available_skill_names(msg), + self._get_skill_storage, + ) + ) + if slash_resolution and slash_resolution.failure_message: + reply = slash_resolution.failure_message + elif slash_resolution and slash_resolution.route_to_chat: + from dataclasses import replace as _dc_replace + + chat_msg = _dc_replace(msg, msg_type=InboundMessageType.CHAT) + await self._handle_chat(chat_msg) + return + else: + reply = _unknown_command_reply(command) outbound = OutboundMessage( channel_name=msg.channel_name, chat_id=msg.chat_id, - thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "", + thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "", text=reply, thread_ts=msg.thread_ts, metadata=_slim_metadata(msg.metadata), @@ -1098,7 +1212,7 @@ class ChannelManager: outbound = OutboundMessage( channel_name=msg.channel_name, chat_id=msg.chat_id, - thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "", + thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "", text=error_text, thread_ts=msg.thread_ts, metadata=_slim_metadata(msg.metadata), diff --git a/backend/app/channels/slack.py b/backend/app/channels/slack.py index 65cb36cf5..3e31a19b2 100644 --- a/backend/app/channels/slack.py +++ b/backend/app/channels/slack.py @@ -9,6 +9,7 @@ from typing import Any from markdown_to_mrkdwn import SlackMarkdownConverter from app.channels.base import Channel +from app.channels.commands import is_known_channel_command from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment logger = logging.getLogger(__name__) @@ -32,6 +33,20 @@ def _normalize_allowed_users(allowed_users: Any) -> set[str]: return {str(user_id) for user_id in values if str(user_id)} +def _strip_leading_slack_bot_mention(text: str, bot_user_id: str | None) -> str: + if not bot_user_id: + return text + if not text.startswith("<@"): + return text + end = text.find(">") + if end <= 2: + return text + mentioned_user_id = text[2:end].split("|", 1)[0].lstrip("!") + if mentioned_user_id != bot_user_id: + return text + return text[end + 1 :].lstrip() + + class SlackChannel(Channel): """Slack IM channel using Socket Mode (WebSocket, no public IP). @@ -49,6 +64,8 @@ class SlackChannel(Channel): self._web_client = None self._loop: asyncio.AbstractEventLoop | None = None self._allowed_users = _normalize_allowed_users(config.get("allowed_users", [])) + configured_bot_user_id = config.get("bot_user_id") + self._bot_user_id = str(configured_bot_user_id).lstrip("@") if configured_bot_user_id else None async def start(self) -> None: if self._running: @@ -72,6 +89,17 @@ class SlackChannel(Channel): return self._web_client = WebClient(token=bot_token) + if self._bot_user_id is None: + try: + auth_info = await asyncio.to_thread(self._web_client.auth_test) + user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None + if user_id is None: + auth_get = getattr(auth_info, "get", None) + user_id = auth_get("user_id") if callable(auth_get) else None + if isinstance(user_id, str) and user_id: + self._bot_user_id = user_id + except Exception: + logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True) self._socket_client = SocketModeClient( app_token=app_token, web_client=self._web_client, @@ -210,6 +238,12 @@ class SlackChannel(Channel): if event_type != "events_api": return + if self._bot_user_id is None: + authorization = next((item for item in req.payload.get("authorizations", []) if isinstance(item, dict)), None) + user_id = authorization.get("user_id") if authorization else None + if isinstance(user_id, str) and user_id: + self._bot_user_id = user_id + event = req.payload.get("event", {}) etype = event.get("type", "") @@ -233,13 +267,15 @@ class SlackChannel(Channel): return text = event.get("text", "").strip() + if event.get("type") == "app_mention": + text = _strip_leading_slack_bot_mention(text, self._bot_user_id) if not text: return channel_id = event.get("channel", "") thread_ts = event.get("thread_ts") or event.get("ts", "") - if text.startswith("/"): + if is_known_channel_command(text): msg_type = InboundMessageType.COMMAND else: msg_type = InboundMessageType.CHAT diff --git a/backend/app/channels/telegram.py b/backend/app/channels/telegram.py index 9985fd43f..fabdbfb61 100644 --- a/backend/app/channels/telegram.py +++ b/backend/app/channels/telegram.py @@ -60,12 +60,17 @@ class TelegramChannel(Channel): # Command handlers app.add_handler(CommandHandler("start", self._cmd_start)) + app.add_handler(CommandHandler("bootstrap", self._cmd_generic)) app.add_handler(CommandHandler("new", self._cmd_generic)) app.add_handler(CommandHandler("status", self._cmd_generic)) app.add_handler(CommandHandler("models", self._cmd_generic)) app.add_handler(CommandHandler("memory", self._cmd_generic)) app.add_handler(CommandHandler("help", self._cmd_generic)) + # Slash skill commands are dynamic and cannot all be pre-registered + # with Telegram, so route unknown slash commands through chat handling. + app.add_handler(MessageHandler(filters.TEXT & filters.COMMAND, self._on_text)) + # General message handler app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, self._on_text)) @@ -228,6 +233,33 @@ class TelegramChannel(Channel): return True return user_id in self._allowed_users + def _get_bot_username(self, context) -> str | None: + bot = getattr(context, "bot", None) + username = getattr(bot, "username", None) + if not username and self._application is not None: + username = getattr(getattr(self._application, "bot", None), "username", None) + return str(username) if username else None + + @staticmethod + def _strip_bot_username_from_leading_command(text: str, bot_username: str | None) -> str: + username = (bot_username or "").lstrip("@").lower() + if not username or not text.startswith("/"): + return text + + parts = text.split(maxsplit=1) + command_token = parts[0] + if "@" not in command_token: + return text + + command_name, addressed_username = command_token[1:].rsplit("@", 1) + if not command_name or addressed_username.lower() != username: + return text + + normalized = f"/{command_name}" + if len(parts) > 1: + normalized = f"{normalized} {parts[1]}" + return normalized + async def _cmd_start(self, update, context) -> None: """Handle /start command.""" if not self._check_user(update.effective_user.id): @@ -243,7 +275,7 @@ class TelegramChannel(Channel): if not self._check_user(update.effective_user.id): return - text = update.message.text + text = self._strip_bot_username_from_leading_command(update.message.text.strip(), self._get_bot_username(context)) chat_id = str(update.effective_chat.id) user_id = str(update.effective_user.id) msg_id = str(update.message.message_id) @@ -279,7 +311,7 @@ class TelegramChannel(Channel): if not self._check_user(update.effective_user.id): return - text = update.message.text.strip() + text = self._strip_bot_username_from_leading_command(update.message.text.strip(), self._get_bot_username(context)) if not text: return diff --git a/backend/app/channels/wechat.py b/backend/app/channels/wechat.py index a8339c2e2..9a9ddf1a6 100644 --- a/backend/app/channels/wechat.py +++ b/backend/app/channels/wechat.py @@ -22,6 +22,7 @@ from cryptography.hazmat.primitives import padding from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from app.channels.base import Channel +from app.channels.commands import is_known_channel_command from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment logger = logging.getLogger(__name__) @@ -620,7 +621,7 @@ class WechatChannel(Channel): chat_id=chat_id, user_id=chat_id, text=text, - msg_type=InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT, + msg_type=InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT, thread_ts=thread_ts, files=files, metadata={ diff --git a/backend/app/channels/wecom.py b/backend/app/channels/wecom.py index 3e0cdb3d1..33d3cf1bb 100644 --- a/backend/app/channels/wecom.py +++ b/backend/app/channels/wecom.py @@ -8,6 +8,7 @@ from collections.abc import Awaitable, Callable from typing import Any, cast from app.channels.base import Channel +from app.channels.commands import is_known_channel_command from app.channels.message_bus import ( InboundMessageType, MessageBus, @@ -270,7 +271,7 @@ class WeComChannel(Channel): user_id = (body.get("from") or {}).get("userid") - inbound_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT + inbound_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT inbound = self._make_inbound( chat_id=user_id, # keep user's conversation in memory user_id=user_id, diff --git a/backend/packages/harness/deerflow/agents/lead_agent/agent.py b/backend/packages/harness/deerflow/agents/lead_agent/agent.py index 2d87799c7..cc3eee449 100644 --- a/backend/packages/harness/deerflow/agents/lead_agent/agent.py +++ b/backend/packages/harness/deerflow/agents/lead_agent/agent.py @@ -49,6 +49,8 @@ from deerflow.tracing import build_tracing_callbacks logger = logging.getLogger(__name__) +_BOOTSTRAP_SKILL_NAMES = {"bootstrap"} + def _get_runtime_config(config: RunnableConfig) -> dict: """Merge legacy configurable options with LangGraph runtime context.""" @@ -271,6 +273,7 @@ def build_middlewares( agent_name: str | None = None, custom_middlewares: list[AgentMiddleware] | None = None, *, + available_skills: set[str] | None = None, app_config: AppConfig | None = None, deferred_setup=None, ): @@ -302,6 +305,13 @@ def build_middlewares( middlewares.append(DynamicContextMiddleware(agent_name=agent_name, app_config=resolved_app_config)) + # Deterministically load a full SKILL.md when the user starts the turn with + # /skill-name. This keeps the base system prompt metadata-only while giving + # explicit user activation priority over model-side relevance guessing. + from deerflow.agents.middlewares.skill_activation_middleware import SkillActivationMiddleware + + middlewares.append(SkillActivationMiddleware(available_skills=available_skills, app_config=resolved_app_config)) + # Add summarization middleware if enabled summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config) if summarization_middleware is not None: @@ -369,7 +379,7 @@ def build_middlewares( def _available_skill_names(agent_config, is_bootstrap: bool) -> set[str] | None: if is_bootstrap: - return {"bootstrap"} + return set(_BOOTSTRAP_SKILL_NAMES) if agent_config and agent_config.skills is not None: return set(agent_config.skills) return None @@ -475,17 +485,25 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig): if is_bootstrap: # Special bootstrap agent with minimal prompt for initial custom agent creation flow + # Keep the bootstrap skill set intentionally narrow so agent creation + # remains deterministic before the custom agent's own config exists. raw_tools = get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent] filtered = filter_tools_by_skill_allowed_tools(raw_tools, skills_for_tool_policy) final_tools, setup = assemble_deferred_tools(filtered, enabled=resolved_app_config.tool_search.enabled) return create_agent( model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config, attach_tracing=False), tools=final_tools, - middleware=build_middlewares(config, model_name=model_name, app_config=resolved_app_config, deferred_setup=setup), + middleware=build_middlewares( + config, + model_name=model_name, + available_skills=set(_BOOTSTRAP_SKILL_NAMES), + app_config=resolved_app_config, + deferred_setup=setup, + ), system_prompt=apply_prompt_template( subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, - available_skills=set(["bootstrap"]), + available_skills=set(_BOOTSTRAP_SKILL_NAMES), app_config=resolved_app_config, deferred_names=setup.deferred_names, ), @@ -502,12 +520,19 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig): return create_agent( model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config, attach_tracing=False), tools=final_tools, - middleware=build_middlewares(config, model_name=model_name, agent_name=agent_name, app_config=resolved_app_config, deferred_setup=setup), + middleware=build_middlewares( + config, + model_name=model_name, + agent_name=agent_name, + available_skills=available_skills, + app_config=resolved_app_config, + deferred_setup=setup, + ), system_prompt=apply_prompt_template( subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, - available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None, + available_skills=available_skills, app_config=resolved_app_config, deferred_names=setup.deferred_names, ), diff --git a/backend/packages/harness/deerflow/agents/lead_agent/prompt.py b/backend/packages/harness/deerflow/agents/lead_agent/prompt.py index f7d9fa8c6..7a32d0c9e 100644 --- a/backend/packages/harness/deerflow/agents/lead_agent/prompt.py +++ b/backend/packages/harness/deerflow/agents/lead_agent/prompt.py @@ -625,6 +625,11 @@ You have access to skills that provide optimized workflows for specific tasks. E 4. Load referenced resources only when needed during execution 5. Follow the skill's instructions precisely +**Explicit Slash Skill Activation:** +- If the user starts a request with `/`, that skill was explicitly requested for the current turn. +- Follow the activated skill before choosing a general workflow. +- The runtime injects the activated skill content for explicit slash activations; do not call `read_file` for that SKILL.md again unless the injected skill references supporting resources you need. + **Skills are located at:** {container_base_path} {skill_evolution_section} {skills_list} diff --git a/backend/packages/harness/deerflow/agents/middlewares/skill_activation_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/skill_activation_middleware.py new file mode 100644 index 000000000..eff634c7e --- /dev/null +++ b/backend/packages/harness/deerflow/agents/middlewares/skill_activation_middleware.py @@ -0,0 +1,289 @@ +"""Middleware for explicit slash skill activation.""" + +from __future__ import annotations + +import asyncio +import hashlib +import html +import logging +import uuid +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, override + +from langchain.agents.middleware import AgentMiddleware +from langchain.agents.middleware.types import ModelRequest, ModelResponse +from langchain_core.messages import AIMessage, HumanMessage + +from deerflow.skills.slash import parse_slash_skill_reference, resolve_slash_skill +from deerflow.skills.storage import get_or_new_skill_storage +from deerflow.skills.storage.skill_storage import SkillStorage +from deerflow.skills.types import SKILL_MD_FILE +from deerflow.utils.messages import get_original_user_content_text + +if TYPE_CHECKING: + from deerflow.config.app_config import AppConfig + +logger = logging.getLogger(__name__) + +_SLASH_SKILL_ACTIVATION_KEY = "slash_skill_activation" +_SLASH_SKILL_ACTIVATION_TARGET_ID_KEY = "slash_skill_activation_target_id" +_SUMMARY_MESSAGE_NAME = "summary" + + +@dataclass(frozen=True, slots=True) +class _Activation: + skill_name: str + category: str + container_file_path: str + skill_content: str + content_hash: str + remaining_text: str + + +@dataclass(frozen=True, slots=True) +class _ActivationResolution: + activation: _Activation | None = None + failure_message: str | None = None + + +def is_slash_skill_activation_reminder(message: object) -> bool: + """Return whether a message is hidden slash-skill activation context.""" + return isinstance(message, HumanMessage) and bool(message.additional_kwargs.get(_SLASH_SKILL_ACTIVATION_KEY)) + + +def _is_user_activation_target(message: object) -> bool: + if not isinstance(message, HumanMessage): + return False + if message.name == _SUMMARY_MESSAGE_NAME: + return False + if message.additional_kwargs.get("hide_from_ui"): + return False + return True + + +class SkillActivationMiddleware(AgentMiddleware): + """Inject full SKILL.md content when the user explicitly types /skill-name.""" + + def __init__( + self, + *, + available_skills: set[str] | None = None, + app_config: AppConfig | None = None, + ) -> None: + super().__init__() + self._available_skills = set(available_skills) if available_skills is not None else None + self._app_config = app_config + + def _storage(self) -> SkillStorage: + if self._app_config is not None: + return get_or_new_skill_storage(app_config=self._app_config) + return get_or_new_skill_storage() + + @staticmethod + def _read_skill_content(skill_file: Path, skills_root: Path) -> str: + if skill_file.name != SKILL_MD_FILE: + raise ValueError(f"Expected {SKILL_MD_FILE}, got {skill_file.name}") + resolved_root = skills_root.resolve() + resolved_file = skill_file.resolve() + try: + resolved_file.relative_to(resolved_root) + except ValueError as exc: + raise ValueError("Resolved skill file must stay within the configured skills root.") from exc + if not resolved_file.is_file(): + raise FileNotFoundError(resolved_file) + return resolved_file.read_text(encoding="utf-8") + + def _resolve_activation(self, text: str) -> _ActivationResolution | None: + reference = parse_slash_skill_reference(text) + if reference is None: + return None + + storage = self._storage() + skills = storage.load_skills(enabled_only=False) + skill = next((candidate for candidate in skills if candidate.name == reference.name), None) + if skill is None: + return _ActivationResolution(failure_message=f"Skill `/{reference.name}` is not installed.") + if not skill.enabled: + return _ActivationResolution(failure_message=f"Skill `/{reference.name}` is installed but disabled. Enable it before using slash activation.") + if self._available_skills is not None and reference.name not in self._available_skills: + return _ActivationResolution(failure_message=f"Skill `/{reference.name}` is not available for this agent.") + + resolved = resolve_slash_skill( + text, + skills, + available_skills=self._available_skills, + container_base_path=storage.get_container_root(), + ) + if resolved is None: + return _ActivationResolution(failure_message=f"Skill `/{reference.name}` could not be resolved.") + + try: + skill_content = self._read_skill_content(resolved.skill.skill_file, storage.get_skills_root_path()) + except (OSError, ValueError): + logger.exception("Failed to read slash-activated skill %s", resolved.skill.name) + return _ActivationResolution(failure_message=f"Skill `/{reference.name}` could not be loaded safely. Please check the skill installation.") + + content_hash = hashlib.sha256(skill_content.encode("utf-8")).hexdigest() + return _ActivationResolution( + activation=_Activation( + skill_name=resolved.skill.name, + category=str(resolved.skill.category), + container_file_path=resolved.container_file_path, + skill_content=skill_content, + content_hash=content_hash, + remaining_text=resolved.remaining_text, + ) + ) + + @staticmethod + def _build_activation_reminder(activation: _Activation) -> str: + user_request = activation.remaining_text or ("No additional task text was provided after the slash skill command. Ask the user what they want to do with this skill if the next step is unclear.") + escaped_user_request = html.escape(user_request, quote=False) + escaped_skill_content = html.escape(activation.skill_content, quote=False) + escaped_skill_name = html.escape(activation.skill_name, quote=True) + escaped_category = html.escape(activation.category, quote=True) + escaped_path = html.escape(activation.container_file_path, quote=True) + escaped_content_hash = html.escape(activation.content_hash, quote=True) + return f""" +The user explicitly activated the `{activation.skill_name}` skill for this turn. +Treat the task text as: + +{escaped_user_request} + + +Follow this skill before choosing a general workflow. Load supporting resources from the same skill directory only when needed. + + + +{escaped_skill_content} + + +""" + + @staticmethod + def _has_existing_activation_for_target(messages: list, target_index: int, target: HumanMessage) -> bool: + if target_index <= 0: + return False + + if target.id: + for previous in messages[:target_index]: + if not is_slash_skill_activation_reminder(previous): + continue + target_id = previous.additional_kwargs.get(_SLASH_SKILL_ACTIVATION_TARGET_ID_KEY) + if target_id == target.id or previous.id == f"{target.id}__slash_activation": + return True + + previous = messages[target_index - 1] + return is_slash_skill_activation_reminder(previous) + + def _find_activation_target(self, messages: list) -> tuple[int, HumanMessage, _ActivationResolution] | None: + if not messages: + return None + + target_index = next((idx for idx in range(len(messages) - 1, -1, -1) if _is_user_activation_target(messages[idx])), None) + if target_index is None: + return None + + target = messages[target_index] + if target is None: + return None + if self._has_existing_activation_for_target(messages, target_index, target): + return None + + content = get_original_user_content_text(target.content, target.additional_kwargs) + resolution = self._resolve_activation(content) + if resolution is None: + return None + return target_index, target, resolution + + @staticmethod + def _record_activation(request: ModelRequest, activation: _Activation, *, hook: str) -> None: + runtime = getattr(request, "runtime", None) + context = getattr(runtime, "context", None) + journal = context.get("__run_journal") if isinstance(context, dict) else None + if journal is None: + return + try: + journal.record_middleware( + "skill_activation", + name="SkillActivationMiddleware", + hook=hook, + action="activate", + changes={ + "skill_name": activation.skill_name, + "category": activation.category, + "path": activation.container_file_path, + "content_hash": activation.content_hash, + }, + ) + except Exception: + logger.debug("Failed to record slash skill activation audit event", exc_info=True) + + def _prepare_model_request(self, request: ModelRequest, *, hook: str) -> ModelRequest | AIMessage | None: + target_and_resolution = self._find_activation_target(list(request.messages)) + if target_and_resolution is None: + return None + + target_index, target, resolution = target_and_resolution + if resolution.failure_message: + return AIMessage(content=resolution.failure_message) + + activation = resolution.activation + if activation is None: + return None + + logger.info( + "SkillActivationMiddleware: activating slash skill %s category=%s path=%s hash=%s", + activation.skill_name, + activation.category, + activation.container_file_path, + activation.content_hash, + ) + self._record_activation(request, activation, hook=hook) + activation_msg = self._make_activation_message(target, self._build_activation_reminder(activation)) + messages = list(request.messages) + messages.insert(target_index, activation_msg) + return request.override(messages=messages) + + @staticmethod + def _make_activation_message(target: HumanMessage, activation_content: str) -> HumanMessage: + stable_id = target.id or str(uuid.uuid4()) + additional_kwargs = { + "hide_from_ui": True, + _SLASH_SKILL_ACTIVATION_KEY: True, + } + if target.id: + additional_kwargs[_SLASH_SKILL_ACTIVATION_TARGET_ID_KEY] = target.id + return HumanMessage( + content=activation_content, + id=f"{stable_id}__slash_activation", + additional_kwargs=additional_kwargs, + ) + + @override + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse | AIMessage: + prepared = self._prepare_model_request(request, hook="wrap_model_call") + if prepared is None: + return handler(request) + if isinstance(prepared, AIMessage): + return prepared + return handler(prepared) + + @override + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ModelResponse | AIMessage: + prepared = await asyncio.to_thread(self._prepare_model_request, request, hook="awrap_model_call") + if prepared is None: + return await handler(request) + if isinstance(prepared, AIMessage): + return prepared + return await handler(prepared) diff --git a/backend/packages/harness/deerflow/agents/middlewares/uploads_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/uploads_middleware.py index 61e822078..1d0d7a03f 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/uploads_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/uploads_middleware.py @@ -13,6 +13,7 @@ from langgraph.runtime import Runtime from deerflow.config.paths import Paths, get_paths from deerflow.runtime.user_context import get_effective_user_id from deerflow.utils.file_conversion import extract_outline +from deerflow.utils.messages import ORIGINAL_USER_CONTENT_KEY, message_content_to_text logger = logging.getLogger(__name__) @@ -265,6 +266,8 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]): # Extract original content - handle both string and list formats original_content = last_message.content + additional_kwargs = dict(last_message.additional_kwargs or {}) + additional_kwargs.setdefault(ORIGINAL_USER_CONTENT_KEY, message_content_to_text(original_content)) if isinstance(original_content, str): # Simple case: string content, just prepend files message updated_content = f"{files_message}\n\n{original_content}" @@ -285,7 +288,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]): content=updated_content, id=last_message.id, name=last_message.name, - additional_kwargs=last_message.additional_kwargs, + additional_kwargs=additional_kwargs, ) messages[last_message_index] = updated_message diff --git a/backend/packages/harness/deerflow/client.py b/backend/packages/harness/deerflow/client.py index 567338350..563c8f835 100644 --- a/backend/packages/harness/deerflow/client.py +++ b/backend/packages/harness/deerflow/client.py @@ -247,7 +247,15 @@ class DeerFlowClient: # Attaching them again on the model would emit duplicate spans. "model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled, attach_tracing=False), "tools": final_tools, - "middleware": build_middlewares(config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares, deferred_setup=deferred_setup), + "middleware": build_middlewares( + config, + model_name=model_name, + agent_name=self._agent_name, + available_skills=self._available_skills, + custom_middlewares=self._middlewares, + app_config=self._app_config, + deferred_setup=deferred_setup, + ), "system_prompt": apply_prompt_template( subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, diff --git a/backend/packages/harness/deerflow/skills/slash.py b/backend/packages/harness/deerflow/skills/slash.py new file mode 100644 index 000000000..757acea25 --- /dev/null +++ b/backend/packages/harness/deerflow/skills/slash.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass + +from deerflow.skills.types import Skill + +RESERVED_SLASH_SKILL_NAMES = frozenset({"bootstrap", "help", "memory", "models", "new", "status"}) +_SLASH_SKILL_RE = re.compile(r"^/([a-z0-9]+(?:-[a-z0-9]+)*)(?:\s+|$)") + + +@dataclass(frozen=True, slots=True) +class SlashSkillReference: + """Parsed slash-skill command with the skill name and remaining task text.""" + + name: str + remaining_text: str + + +@dataclass(frozen=True, slots=True) +class ResolvedSlashSkill: + """Slash-skill activation resolved against enabled runtime-visible skills.""" + + skill: Skill + remaining_text: str + container_file_path: str + + +def parse_slash_skill_reference(text: str) -> SlashSkillReference | None: + """Parse strict `/skill-name task` syntax, ignoring reserved control commands.""" + match = _SLASH_SKILL_RE.match(text) + if not match: + return None + name = match.group(1) + if name in RESERVED_SLASH_SKILL_NAMES: + return None + return SlashSkillReference( + name=name, + remaining_text=text[match.end() :].lstrip(), + ) + + +def resolve_slash_skill( + text: str, + skills: list[Skill], + *, + available_skills: set[str] | None = None, + container_base_path: str = "/mnt/skills", +) -> ResolvedSlashSkill | None: + """Resolve text into an enabled, whitelisted skill activation if possible.""" + reference = parse_slash_skill_reference(text) + if reference is None: + return None + if available_skills is not None and reference.name not in available_skills: + return None + + skill = next((candidate for candidate in skills if candidate.name == reference.name and candidate.enabled), None) + if skill is None: + return None + + return ResolvedSlashSkill( + skill=skill, + remaining_text=reference.remaining_text, + container_file_path=skill.get_container_file_path(container_base_path), + ) diff --git a/backend/packages/harness/deerflow/utils/messages.py b/backend/packages/harness/deerflow/utils/messages.py new file mode 100644 index 000000000..9ddf785fe --- /dev/null +++ b/backend/packages/harness/deerflow/utils/messages.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +ORIGINAL_USER_CONTENT_KEY = "original_user_content" + + +def message_content_to_text(content: Any) -> str: + """Extract text from LangChain message content shapes.""" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, str): + parts.append(item) + elif isinstance(item, dict): + text = item.get("text") + if isinstance(text, str): + parts.append(text) + return "\n".join(part for part in parts if part) + return str(content) + + +def get_original_user_content_text(content: Any, additional_kwargs: Mapping[str, Any] | None) -> str: + """Return pre-middleware user text when available, otherwise content text.""" + original_content = (additional_kwargs or {}).get(ORIGINAL_USER_CONTENT_KEY) + if isinstance(original_content, str): + return original_content + return message_content_to_text(content) diff --git a/backend/tests/test_channels.py b/backend/tests/test_channels.py index 060d414b2..ba7ce7fc3 100644 --- a/backend/tests/test_channels.py +++ b/backend/tests/test_channels.py @@ -21,6 +21,42 @@ from app.channels.message_bus import ( ResolvedAttachment, ) from app.channels.store import ChannelStore +from deerflow.skills.types import Skill, SkillCategory +from deerflow.utils.messages import ORIGINAL_USER_CONTENT_KEY + + +def test_known_channel_command_detection_only_matches_control_commands(): + from app.channels.commands import is_known_channel_command + + assert is_known_channel_command("/new") + assert is_known_channel_command("/HELP now") + assert not is_known_channel_command("/mnt/user-data/uploads/report.pdf") + assert not is_known_channel_command("/data-analysis analyze uploads/foo.csv") + assert not is_known_channel_command(" /new") + + +def _make_channel_skill(tmp_path: Path, name: str, *, enabled: bool = True) -> Skill: + skill_dir = tmp_path / name + skill_dir.mkdir(parents=True, exist_ok=True) + skill_file = skill_dir / "SKILL.md" + skill_file.write_text(f"# {name}\n", encoding="utf-8") + return Skill( + name=name, + description=f"Description for {name}", + license="MIT", + skill_dir=skill_dir, + skill_file=skill_file, + relative_path=Path(name), + category=SkillCategory.CUSTOM, + enabled=enabled, + ) + + +def _make_channel_skill_storage(skills: list[Skill]): + return SimpleNamespace( + load_skills=lambda *, enabled_only: [skill for skill in skills if skill.enabled] if enabled_only else skills, + get_container_root=lambda: "/mnt/skills", + ) def _run(coro): @@ -1334,6 +1370,496 @@ class TestChannelManager: _run(go()) + def test_handle_command_blank_text_is_reported_without_running_agent(self): + from app.channels.manager import ChannelManager + + async def go(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + manager = ChannelManager(bus=bus, store=store) + + mock_client = _make_mock_langgraph_client() + manager._client = mock_client + + outbound_received = [] + + async def capture_outbound(msg): + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + await manager.start() + + inbound = InboundMessage( + channel_name="test", + chat_id="chat1", + user_id="user1", + text=" ", + msg_type=InboundMessageType.COMMAND, + ) + await bus.publish_inbound(inbound) + await _wait_for(lambda: len(outbound_received) >= 1) + await manager.stop() + + mock_client.runs.wait.assert_not_called() + assert outbound_received[0].text.startswith("Unknown command.") + + _run(go()) + + def test_handle_command_rejects_multi_slash_control_command(self): + from app.channels.manager import ChannelManager + + async def go(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + manager = ChannelManager(bus=bus, store=store) + + mock_client = _make_mock_langgraph_client() + manager._client = mock_client + + outbound_received = [] + + async def capture_outbound(msg): + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + await manager.start() + + inbound = InboundMessage( + channel_name="test", + chat_id="chat1", + user_id="user1", + text="//help", + msg_type=InboundMessageType.COMMAND, + ) + await bus.publish_inbound(inbound) + await _wait_for(lambda: len(outbound_received) >= 1) + await manager.stop() + + mock_client.runs.wait.assert_not_called() + assert outbound_received[0].text.startswith("Unknown command: //help.") + + _run(go()) + + def test_handle_command_requires_control_command_at_start(self): + from app.channels.manager import ChannelManager + + async def go(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + manager = ChannelManager(bus=bus, store=store) + + mock_client = _make_mock_langgraph_client(thread_id="new-thread-456") + manager._client = mock_client + + outbound_received = [] + + async def capture_outbound(msg): + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + await manager.start() + + inbound = InboundMessage( + channel_name="test", + chat_id="chat1", + user_id="user1", + text=" /new", + msg_type=InboundMessageType.COMMAND, + ) + await bus.publish_inbound(inbound) + await _wait_for(lambda: len(outbound_received) >= 1) + await manager.stop() + + mock_client.threads.create.assert_not_called() + assert store.get_thread_id("test", "chat1") is None + assert outbound_received[0].text.startswith("Unknown command: /new.") + + _run(go()) + + def test_handle_command_outbound_thread_id_uses_topic_thread(self): + from app.channels.manager import ChannelManager + + async def go(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + manager = ChannelManager(bus=bus, store=store) + store.set_thread_id("test", "chat1", "base-thread") + store.set_thread_id("test", "chat1", "topic-thread", topic_id="topic-1") + + outbound_received = [] + + async def capture_outbound(msg): + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + await manager.start() + + inbound = InboundMessage( + channel_name="test", + chat_id="chat1", + user_id="user1", + text="/status", + msg_type=InboundMessageType.COMMAND, + topic_id="topic-1", + ) + await bus.publish_inbound(inbound) + await _wait_for(lambda: len(outbound_received) >= 1) + await manager.stop() + + assert outbound_received[0].text == "Active thread: topic-thread" + assert outbound_received[0].thread_id == "topic-thread" + + _run(go()) + + def test_handle_command_slash_skill_routes_to_chat(self, tmp_path): + from app.channels.manager import ChannelManager + + async def go(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + manager = ChannelManager(bus=bus, store=store) + manager._skill_storage = _make_channel_skill_storage([_make_channel_skill(tmp_path, "data-analysis")]) + + mock_client = _make_mock_langgraph_client() + manager._client = mock_client + + outbound_received = [] + + async def capture_outbound(msg): + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + await manager.start() + + inbound = InboundMessage( + channel_name="test", + chat_id="chat1", + user_id="user1", + text="/data-analysis analyze uploads/foo.csv", + msg_type=InboundMessageType.COMMAND, + ) + await bus.publish_inbound(inbound) + await _wait_for(lambda: len(outbound_received) >= 1) + await manager.stop() + + mock_client.runs.wait.assert_called_once() + call_args = mock_client.runs.wait.call_args + assert call_args[1]["input"]["messages"][0]["content"] == "/data-analysis analyze uploads/foo.csv" + assert outbound_received[0].text == "Hello from agent!" + + _run(go()) + + def test_handle_command_slash_skill_with_attachment_preserves_original_content(self, monkeypatch, tmp_path): + from app.channels.manager import ChannelManager + + async def fake_ingest(thread_id, msg): + return [ + { + "filename": "report.pdf", + "size": 12, + "path": "/mnt/user-data/uploads/report.pdf", + "is_image": False, + } + ] + + monkeypatch.setattr("app.channels.manager._ingest_inbound_files", fake_ingest) + + async def go(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + manager = ChannelManager(bus=bus, store=store) + manager._skill_storage = _make_channel_skill_storage([_make_channel_skill(tmp_path, "data-analysis")]) + + mock_client = _make_mock_langgraph_client() + manager._client = mock_client + + outbound_received = [] + + async def capture_outbound(msg): + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + await manager.start() + + original_text = "/data-analysis analyze report.pdf" + inbound = InboundMessage( + channel_name="test", + chat_id="chat1", + user_id="user1", + text=original_text, + files=[{"filename": "report.pdf"}], + msg_type=InboundMessageType.COMMAND, + ) + await bus.publish_inbound(inbound) + await _wait_for(lambda: len(outbound_received) >= 1) + await manager.stop() + + mock_client.runs.wait.assert_called_once() + human_message = mock_client.runs.wait.call_args[1]["input"]["messages"][0] + assert human_message["content"].startswith("") + assert original_text in human_message["content"] + assert human_message["additional_kwargs"][ORIGINAL_USER_CONTENT_KEY] == original_text + assert outbound_received[0].text == "Hello from agent!" + + _run(go()) + + def test_streaming_slash_skill_with_attachment_preserves_original_content(self, monkeypatch, tmp_path): + from app.channels.manager import ChannelManager + + async def fake_ingest(thread_id, msg): + return [ + { + "filename": "report.pdf", + "size": 12, + "path": "/mnt/user-data/uploads/report.pdf", + "is_image": False, + } + ] + + monkeypatch.setattr("app.channels.manager._ingest_inbound_files", fake_ingest) + + async def go(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + manager = ChannelManager(bus=bus, store=store) + manager._skill_storage = _make_channel_skill_storage([_make_channel_skill(tmp_path, "data-analysis")]) + + mock_client = _make_mock_langgraph_client() + mock_client.runs.stream = MagicMock( + return_value=_make_async_iterator( + [ + _make_stream_part( + "values", + {"messages": [{"type": "ai", "content": "streamed response"}]}, + ) + ] + ) + ) + manager._client = mock_client + + outbound_received = [] + + async def capture_outbound(msg): + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + await manager.start() + + original_text = "/data-analysis analyze report.pdf" + inbound = InboundMessage( + channel_name="feishu", + chat_id="chat1", + user_id="user1", + text=original_text, + files=[{"filename": "report.pdf"}], + msg_type=InboundMessageType.COMMAND, + ) + await bus.publish_inbound(inbound) + await _wait_for(lambda: any(message.is_final for message in outbound_received)) + await manager.stop() + + mock_client.runs.stream.assert_called_once() + human_message = mock_client.runs.stream.call_args[1]["input"]["messages"][0] + assert human_message["content"].startswith("") + assert original_text in human_message["content"] + assert human_message["additional_kwargs"][ORIGINAL_USER_CONTENT_KEY] == original_text + + _run(go()) + + def test_handle_command_slash_skill_requires_command_at_start(self, tmp_path): + from app.channels.manager import ChannelManager + + async def go(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + manager = ChannelManager(bus=bus, store=store) + manager._skill_storage = _make_channel_skill_storage([_make_channel_skill(tmp_path, "data-analysis")]) + + mock_client = _make_mock_langgraph_client() + manager._client = mock_client + + outbound_received = [] + + async def capture_outbound(msg): + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + await manager.start() + + inbound = InboundMessage( + channel_name="test", + chat_id="chat1", + user_id="user1", + text=" /data-analysis analyze uploads/foo.csv", + msg_type=InboundMessageType.COMMAND, + ) + await bus.publish_inbound(inbound) + await _wait_for(lambda: len(outbound_received) >= 1) + await manager.stop() + + mock_client.runs.wait.assert_not_called() + assert outbound_received[0].text.startswith("Unknown command: /data-analysis.") + + _run(go()) + + def test_handle_command_slash_skill_respects_custom_agent_skill_whitelist(self, monkeypatch, tmp_path): + from app.channels.manager import ChannelManager + + monkeypatch.setattr("app.channels.manager.load_agent_config", lambda name: SimpleNamespace(skills=["frontend-design"])) + + async def go(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + manager = ChannelManager( + bus=bus, + store=store, + default_session={"assistant_id": "analyst-agent"}, + ) + manager._skill_storage = _make_channel_skill_storage([_make_channel_skill(tmp_path, "data-analysis")]) + + mock_client = _make_mock_langgraph_client() + manager._client = mock_client + + outbound_received = [] + + async def capture_outbound(msg): + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + await manager.start() + + inbound = InboundMessage( + channel_name="test", + chat_id="chat1", + user_id="user1", + text="/data-analysis analyze uploads/foo.csv", + msg_type=InboundMessageType.COMMAND, + ) + await bus.publish_inbound(inbound) + await _wait_for(lambda: len(outbound_received) >= 1) + await manager.stop() + + mock_client.runs.wait.assert_not_called() + assert outbound_received[0].text == "Skill `/data-analysis` is not available for this agent." + + _run(go()) + + def test_handle_command_slash_skill_reports_disabled_skill(self, tmp_path): + from app.channels.manager import ChannelManager + + async def go(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + manager = ChannelManager(bus=bus, store=store) + manager._skill_storage = _make_channel_skill_storage([_make_channel_skill(tmp_path, "data-analysis", enabled=False)]) + + mock_client = _make_mock_langgraph_client() + manager._client = mock_client + + outbound_received = [] + + async def capture_outbound(msg): + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + await manager.start() + + inbound = InboundMessage( + channel_name="test", + chat_id="chat1", + user_id="user1", + text="/data-analysis analyze uploads/foo.csv", + msg_type=InboundMessageType.COMMAND, + ) + await bus.publish_inbound(inbound) + await _wait_for(lambda: len(outbound_received) >= 1) + await manager.stop() + + mock_client.runs.wait.assert_not_called() + assert outbound_received[0].text == "Skill `/data-analysis` is installed but disabled. Enable it before using slash activation." + + _run(go()) + + def test_handle_command_uninstalled_slash_skill_stays_unknown_command(self, tmp_path): + from app.channels.manager import ChannelManager + + async def go(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + manager = ChannelManager(bus=bus, store=store) + manager._skill_storage = _make_channel_skill_storage([_make_channel_skill(tmp_path, "frontend-design")]) + + mock_client = _make_mock_langgraph_client() + manager._client = mock_client + + outbound_received = [] + + async def capture_outbound(msg): + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + await manager.start() + + inbound = InboundMessage( + channel_name="test", + chat_id="chat1", + user_id="user1", + text="/data-analysis analyze uploads/foo.csv", + msg_type=InboundMessageType.COMMAND, + ) + await bus.publish_inbound(inbound) + await _wait_for(lambda: len(outbound_received) >= 1) + await manager.stop() + + mock_client.runs.wait.assert_not_called() + assert outbound_received[0].text.startswith("Unknown command: /data-analysis.") + + _run(go()) + + def test_handle_command_slash_skill_resolution_error_is_reported(self, monkeypatch): + from app.channels.manager import ChannelManager, SlashSkillCommandResolutionError + + def fail_resolution(text, available_skills=None, storage=None): + raise SlashSkillCommandResolutionError("Failed to resolve slash skill command. Please check the skill configuration.") + + monkeypatch.setattr("app.channels.manager._resolve_slash_skill_command", fail_resolution) + + async def go(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + manager = ChannelManager(bus=bus, store=store) + store.set_thread_id("test", "chat1", "base-thread") + store.set_thread_id("test", "chat1", "topic-thread", topic_id="topic-1") + + mock_client = _make_mock_langgraph_client() + manager._client = mock_client + + outbound_received = [] + + async def capture_outbound(msg): + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + await manager.start() + + inbound = InboundMessage( + channel_name="test", + chat_id="chat1", + user_id="user1", + text="/data-analysis analyze uploads/foo.csv", + msg_type=InboundMessageType.COMMAND, + topic_id="topic-1", + ) + await bus.publish_inbound(inbound) + await _wait_for(lambda: len(outbound_received) >= 1) + await manager.stop() + + mock_client.runs.wait.assert_not_called() + assert outbound_received[0].text == "Failed to resolve slash skill command. Please check the skill configuration." + assert outbound_received[0].thread_id == "topic-thread" + + _run(go()) + def test_handle_command_new(self): from app.channels.manager import ChannelManager @@ -2440,6 +2966,36 @@ class TestWeComChannel: _run(go()) + def test_publish_ws_inbound_treats_slash_prefixed_paths_as_chat(self, monkeypatch): + from app.channels.wecom import WeComChannel + + async def go(): + bus = MessageBus() + bus.publish_inbound = AsyncMock() + channel = WeComChannel(bus, config={}) + channel._ws_client = SimpleNamespace(reply_stream=AsyncMock()) + + monkeypatch.setitem( + __import__("sys").modules, + "aibot", + SimpleNamespace(generate_req_id=lambda prefix: "stream-1"), + ) + + frame = { + "body": { + "msgid": "msg-1", + "from": {"userid": "user-1"}, + } + } + + await channel._publish_ws_inbound(frame, "/mnt/user-data/uploads/report.pdf") + + inbound = bus.publish_inbound.await_args.args[0] + assert inbound.text == "/mnt/user-data/uploads/report.pdf" + assert inbound.msg_type == InboundMessageType.CHAT + + _run(go()) + def test_on_outbound_sends_attachment_before_clearing_context(self, tmp_path): from app.channels.wecom import WeComChannel @@ -2788,6 +3344,219 @@ class TestSlackAllowedUsers: assert inbound.chat_id == "C123" assert inbound.text == "hello from slack" + def test_app_mention_strips_leading_bot_mention_before_command_detection(self): + from app.channels.slack import SlackChannel + + bus = MessageBus() + bus.publish_inbound = AsyncMock() + channel = SlackChannel(bus=bus, config={"bot_user_id": "UBOT"}) + channel._loop = MagicMock() + channel._loop.is_running.return_value = True + channel._add_reaction = MagicMock() + channel._send_running_reply = MagicMock() + + event = { + "type": "app_mention", + "user": "U123456", + "text": "<@UBOT> /help", + "channel": "C123", + "ts": "1710000000.000100", + } + + with patch( + "app.channels.slack.asyncio.run_coroutine_threadsafe", + side_effect=self._submit_coro, + ): + channel._handle_message_event(event) + + inbound = bus.publish_inbound.call_args.args[0] + assert inbound.text == "/help" + assert inbound.msg_type == InboundMessageType.COMMAND + + def test_app_mention_strips_labelled_leading_bot_mention(self): + from app.channels.slack import SlackChannel + + bus = MessageBus() + bus.publish_inbound = AsyncMock() + channel = SlackChannel(bus=bus, config={"bot_user_id": "UBOT"}) + channel._loop = MagicMock() + channel._loop.is_running.return_value = True + channel._add_reaction = MagicMock() + channel._send_running_reply = MagicMock() + + event = { + "type": "app_mention", + "user": "U123456", + "text": "<@UBOT|deerflow> /help", + "channel": "C123", + "ts": "1710000000.000100", + } + + with patch( + "app.channels.slack.asyncio.run_coroutine_threadsafe", + side_effect=self._submit_coro, + ): + channel._handle_message_event(event) + + inbound = bus.publish_inbound.call_args.args[0] + assert inbound.text == "/help" + assert inbound.msg_type == InboundMessageType.COMMAND + + def test_app_mention_strips_leading_bot_mention_before_slash_skill(self): + from app.channels.slack import SlackChannel + + bus = MessageBus() + bus.publish_inbound = AsyncMock() + channel = SlackChannel(bus=bus, config={"bot_user_id": "UBOT"}) + channel._loop = MagicMock() + channel._loop.is_running.return_value = True + channel._add_reaction = MagicMock() + channel._send_running_reply = MagicMock() + + event = { + "type": "app_mention", + "user": "U123456", + "text": "<@UBOT> /data-analysis analyze uploads/foo.csv", + "channel": "C123", + "ts": "1710000000.000100", + } + + with patch( + "app.channels.slack.asyncio.run_coroutine_threadsafe", + side_effect=self._submit_coro, + ): + channel._handle_message_event(event) + + inbound = bus.publish_inbound.call_args.args[0] + assert inbound.text == "/data-analysis analyze uploads/foo.csv" + assert inbound.msg_type == InboundMessageType.CHAT + + def test_app_mention_preserves_following_user_mention(self): + from app.channels.slack import SlackChannel + + bus = MessageBus() + bus.publish_inbound = AsyncMock() + channel = SlackChannel(bus=bus, config={"bot_user_id": "UBOT"}) + channel._loop = MagicMock() + channel._loop.is_running.return_value = True + channel._add_reaction = MagicMock() + channel._send_running_reply = MagicMock() + + event = { + "type": "app_mention", + "user": "U123456", + "text": "<@UBOT> <@UASSIGNEE> please review this", + "channel": "C123", + "ts": "1710000000.000100", + } + + with patch( + "app.channels.slack.asyncio.run_coroutine_threadsafe", + side_effect=self._submit_coro, + ): + channel._handle_message_event(event) + + inbound = bus.publish_inbound.call_args.args[0] + assert inbound.text == "<@UASSIGNEE> please review this" + assert inbound.msg_type == InboundMessageType.CHAT + + def test_app_mention_preserves_leading_non_bot_mention_when_bot_id_known(self): + from app.channels.slack import SlackChannel + + bus = MessageBus() + bus.publish_inbound = AsyncMock() + channel = SlackChannel(bus=bus, config={"bot_user_id": "UBOT"}) + channel._loop = MagicMock() + channel._loop.is_running.return_value = True + channel._add_reaction = MagicMock() + channel._send_running_reply = MagicMock() + + event = { + "type": "app_mention", + "user": "U123456", + "text": "<@UASSIGNEE> <@UBOT> please review this", + "channel": "C123", + "ts": "1710000000.000100", + } + + with patch( + "app.channels.slack.asyncio.run_coroutine_threadsafe", + side_effect=self._submit_coro, + ): + channel._handle_message_event(event) + + inbound = bus.publish_inbound.call_args.args[0] + assert inbound.text == "<@UASSIGNEE> <@UBOT> please review this" + assert inbound.msg_type == InboundMessageType.CHAT + + def test_app_mention_preserves_leading_non_bot_mention_when_bot_id_unknown(self): + from app.channels.slack import SlackChannel + + bus = MessageBus() + bus.publish_inbound = AsyncMock() + channel = SlackChannel(bus=bus, config={}) + channel._loop = MagicMock() + channel._loop.is_running.return_value = True + channel._add_reaction = MagicMock() + channel._send_running_reply = MagicMock() + + event = { + "type": "app_mention", + "user": "U123456", + "text": "<@UASSIGNEE> /help <@UBOT>", + "channel": "C123", + "ts": "1710000000.000100", + } + + with patch( + "app.channels.slack.asyncio.run_coroutine_threadsafe", + side_effect=self._submit_coro, + ): + channel._handle_message_event(event) + + inbound = bus.publish_inbound.call_args.args[0] + assert inbound.text == "<@UASSIGNEE> /help <@UBOT>" + assert inbound.msg_type == InboundMessageType.CHAT + + def test_socket_event_resolves_bot_user_id_before_app_mention_command_detection(self): + from app.channels.slack import SlackChannel + + bus = MessageBus() + bus.publish_inbound = AsyncMock() + channel = SlackChannel(bus=bus, config={}) + channel._SocketModeResponse = lambda envelope_id: SimpleNamespace(envelope_id=envelope_id) + channel._loop = MagicMock() + channel._loop.is_running.return_value = True + channel._add_reaction = MagicMock() + channel._send_running_reply = MagicMock() + + client = SimpleNamespace(send_socket_mode_response=MagicMock()) + req = SimpleNamespace( + envelope_id="env-1", + type="events_api", + payload={ + "authorizations": [{"user_id": "UBOT"}], + "event": { + "type": "app_mention", + "user": "U123456", + "text": "<@UBOT> /help", + "channel": "C123", + "ts": "1710000000.000100", + }, + }, + ) + + with patch( + "app.channels.slack.asyncio.run_coroutine_threadsafe", + side_effect=self._submit_coro, + ): + channel._on_socket_event(client, req) + + inbound = bus.publish_inbound.call_args.args[0] + assert channel._bot_user_id == "UBOT" + assert inbound.text == "/help" + assert inbound.msg_type == InboundMessageType.COMMAND + def test_scalar_allowed_users_warns_and_matches_stringified_event_user_id(self, caplog): from app.channels.slack import SlackChannel @@ -2861,6 +3630,86 @@ class TestSlackAllowedUsers: class TestTelegramSendRetry: + def test_start_registers_known_channel_commands(self, monkeypatch): + import sys + from types import ModuleType + + from app.channels.commands import KNOWN_CHANNEL_COMMANDS + from app.channels.telegram import TelegramChannel + + class FakeFilter: + def __init__(self, expr: str): + self.expr = expr + + def __and__(self, other): + return FakeFilter(f"{self.expr}&{other.expr}") + + def __invert__(self): + return FakeFilter(f"~{self.expr}") + + class FakeApplication: + def __init__(self): + self.handlers = [] + + def add_handler(self, handler): + self.handlers.append(handler) + + fake_app = FakeApplication() + + class FakeApplicationBuilder: + def token(self, token): + assert token == "test-token" + return self + + def build(self): + return fake_app + + def fake_command_handler(command, callback): + return SimpleNamespace(kind="command", command=command, callback=callback) + + def fake_message_handler(filter_expr, callback): + return SimpleNamespace(kind="message", filter_expr=filter_expr, callback=callback) + + telegram_mod = ModuleType("telegram") + telegram_ext_mod = ModuleType("telegram.ext") + telegram_ext_mod.ApplicationBuilder = FakeApplicationBuilder + telegram_ext_mod.CommandHandler = fake_command_handler + telegram_ext_mod.MessageHandler = fake_message_handler + telegram_ext_mod.filters = SimpleNamespace(TEXT=FakeFilter("TEXT"), COMMAND=FakeFilter("COMMAND")) + telegram_mod.ext = telegram_ext_mod + monkeypatch.setitem(sys.modules, "telegram", telegram_mod) + monkeypatch.setitem(sys.modules, "telegram.ext", telegram_ext_mod) + + class FakeThread: + def __init__(self, *, target, daemon): + self.target = target + self.daemon = daemon + + def start(self): + return None + + def join(self, timeout=None): + return None + + monkeypatch.setattr("app.channels.telegram.threading.Thread", FakeThread) + + async def go(): + bus = MessageBus() + ch = TelegramChannel(bus=bus, config={"bot_token": "test-token"}) + + await ch.start() + try: + registered_commands = {handler.command for handler in fake_app.handlers if handler.kind == "command"} + expected_commands = {command.removeprefix("/") for command in KNOWN_CHANNEL_COMMANDS} + assert expected_commands <= registered_commands + assert "start" in registered_commands + message_filters = {handler.filter_expr.expr for handler in fake_app.handlers if handler.kind == "message"} + assert {"TEXT&COMMAND", "TEXT&~COMMAND"} <= message_filters + finally: + await ch.stop() + + _run(go()) + def test_retries_on_failure_then_succeeds(self): from app.channels.telegram import TelegramChannel @@ -2984,6 +3833,47 @@ class TestTelegramPrivateChatThread: _run(go()) + def test_private_chat_slash_skill_text_routes_as_chat(self): + from app.channels.telegram import TelegramChannel + + async def go(): + bus = MessageBus() + ch = TelegramChannel(bus=bus, config={"bot_token": "test-token"}) + ch._main_loop = asyncio.get_event_loop() + + update = _make_telegram_update("private", message_id=12, text="/data-analysis analyze uploads/foo.csv") + await ch._on_text(update, None) + + msg = await asyncio.wait_for(bus.get_inbound(), timeout=2) + assert msg.text == "/data-analysis analyze uploads/foo.csv" + assert msg.msg_type == InboundMessageType.CHAT + assert msg.topic_id is None + + _run(go()) + + def test_slash_skill_addressed_to_telegram_bot_strips_username(self): + from app.channels.telegram import TelegramChannel + + async def go(): + bus = MessageBus() + ch = TelegramChannel(bus=bus, config={"bot_token": "test-token"}) + ch._main_loop = asyncio.get_event_loop() + + update = _make_telegram_update( + "group", + message_id=13, + text="/data-analysis@DeerFlowBot analyze uploads/foo.csv", + ) + context = SimpleNamespace(bot=SimpleNamespace(username="DeerFlowBot")) + await ch._on_text(update, context) + + msg = await asyncio.wait_for(bus.get_inbound(), timeout=2) + assert msg.text == "/data-analysis analyze uploads/foo.csv" + assert msg.msg_type == InboundMessageType.CHAT + assert msg.topic_id == "13" + + _run(go()) + def test_private_chat_with_reply_still_uses_none_topic(self): from app.channels.telegram import TelegramChannel @@ -3099,6 +3989,25 @@ class TestTelegramPrivateChatThread: _run(go()) + def test_cmd_generic_strips_addressed_telegram_bot_username(self): + from app.channels.telegram import TelegramChannel + + async def go(): + bus = MessageBus() + ch = TelegramChannel(bus=bus, config={"bot_token": "test-token"}) + ch._main_loop = asyncio.get_event_loop() + + update = _make_telegram_update("group", message_id=33, text="/status@DeerFlowBot") + context = SimpleNamespace(bot=SimpleNamespace(username="DeerFlowBot")) + await ch._cmd_generic(update, context) + + msg = await asyncio.wait_for(bus.get_inbound(), timeout=2) + assert msg.text == "/status" + assert msg.topic_id == "33" + assert msg.msg_type == InboundMessageType.COMMAND + + _run(go()) + class TestTelegramProcessingOrder: """Ensure 'working on it...' is sent before inbound is published.""" diff --git a/backend/tests/test_discord_channel.py b/backend/tests/test_discord_channel.py index 204d03bfc..0cebce5af 100644 --- a/backend/tests/test_discord_channel.py +++ b/backend/tests/test_discord_channel.py @@ -2,9 +2,13 @@ from __future__ import annotations +from types import SimpleNamespace + +import pytest + from app.channels.discord import DiscordChannel from app.channels.manager import CHANNEL_CAPABILITIES -from app.channels.message_bus import MessageBus +from app.channels.message_bus import InboundMessageType, MessageBus from app.channels.service import _CHANNEL_REGISTRY @@ -21,3 +25,64 @@ def test_discord_channel_init() -> None: channel = DiscordChannel(bus=bus, config={"bot_token": "token"}) assert channel.name == "discord" + + +def _make_discord_message(text: str): + return SimpleNamespace( + id=111, + content=text, + author=SimpleNamespace(id=123, bot=False, display_name="alice"), + guild=SimpleNamespace(id=321), + channel=SimpleNamespace(id=456), + add_reaction=lambda _emoji: None, + ) + + +@pytest.mark.asyncio +async def test_discord_bot_mention_slash_skill_routes_as_chat() -> None: + bus = MessageBus() + channel = DiscordChannel(bus=bus, config={"bot_token": "token"}) + captured = [] + channel._running = True + channel._client = SimpleNamespace(user=SimpleNamespace(id=999, mention="<@999>")) + channel._discord_module = SimpleNamespace(Thread=type("FakeThread", (), {})) + channel._publish = captured.append + + async def noop(*_args, **_kwargs): + return None + + channel._start_typing = noop + channel._add_reaction = noop + + await channel._on_message(_make_discord_message("<@999> /data-analysis analyze uploads/foo.csv")) + + assert len(captured) == 1 + inbound = captured[0] + assert inbound.text == "/data-analysis analyze uploads/foo.csv" + assert inbound.msg_type == InboundMessageType.CHAT + assert inbound.topic_id == "456" + + +@pytest.mark.asyncio +async def test_discord_bot_mention_known_command_routes_as_command() -> None: + bus = MessageBus() + channel = DiscordChannel(bus=bus, config={"bot_token": "token"}) + captured = [] + channel._running = True + channel._client = SimpleNamespace(user=SimpleNamespace(id=999, mention="<@999>")) + channel._discord_module = SimpleNamespace(Thread=type("FakeThread", (), {})) + channel._publish = captured.append + + async def noop(*_args, **_kwargs): + return None + + channel._start_typing = noop + channel._add_reaction = noop + + await channel._on_message(_make_discord_message("<@999> /help")) + + assert len(captured) == 1 + inbound = captured[0] + assert inbound.text == "/help" + assert inbound.msg_type == InboundMessageType.COMMAND + assert inbound.topic_id == "456" diff --git a/backend/tests/test_lead_agent_skills.py b/backend/tests/test_lead_agent_skills.py index 2f625857f..f10aa6fce 100644 --- a/backend/tests/test_lead_agent_skills.py +++ b/backend/tests/test_lead_agent_skills.py @@ -60,6 +60,17 @@ def test_get_skills_prompt_section_returns_all_when_available_skills_is_none(mon assert "skill2" in result +def test_get_skills_prompt_section_includes_slash_activation_guidance(monkeypatch): + skills = [_make_skill("data-analysis")] + monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills) + + result = get_skills_prompt_section(available_skills={"data-analysis"}) + + assert "Explicit Slash Skill Activation" in result + assert "The runtime injects the activated skill content" in result + assert "do not call `read_file` for that SKILL.md again" in result + + def test_get_skills_prompt_section_includes_self_evolution_rules(monkeypatch): skills = [_make_skill("skill1")] monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills) diff --git a/backend/tests/test_slash_skills.py b/backend/tests/test_slash_skills.py new file mode 100644 index 000000000..902bb7c1a --- /dev/null +++ b/backend/tests/test_slash_skills.py @@ -0,0 +1,557 @@ +import asyncio +import hashlib +from pathlib import Path +from types import SimpleNamespace + +from langchain.agents.middleware.types import ModelRequest +from langchain_core.messages import AIMessage, HumanMessage + +from app.channels.commands import KNOWN_CHANNEL_COMMANDS +from deerflow.agents.middlewares import skill_activation_middleware as middleware_module +from deerflow.agents.middlewares.skill_activation_middleware import SkillActivationMiddleware, is_slash_skill_activation_reminder +from deerflow.skills.slash import RESERVED_SLASH_SKILL_NAMES, parse_slash_skill_reference, resolve_slash_skill +from deerflow.skills.types import Skill, SkillCategory +from deerflow.utils.messages import ORIGINAL_USER_CONTENT_KEY + + +def _make_skill(tmp_path: Path, name: str, content: str = "skill body") -> Skill: + skill_dir = tmp_path / name + skill_dir.mkdir() + skill_file = skill_dir / "SKILL.md" + skill_file.write_text(content, encoding="utf-8") + return Skill( + name=name, + description=f"Description for {name}", + license="MIT", + skill_dir=skill_dir, + skill_file=skill_file, + relative_path=Path(name), + category=SkillCategory.CUSTOM, + enabled=True, + ) + + +def _make_storage(tmp_path: Path, skills: list[Skill]): + return SimpleNamespace( + load_skills=lambda *, enabled_only: [skill for skill in skills if skill.enabled] if enabled_only else skills, + get_container_root=lambda: "/mnt/skills", + get_skills_root_path=lambda: tmp_path, + ) + + +def _make_model_request(messages: list[HumanMessage], *, runtime=None) -> ModelRequest: + return ModelRequest( + model=object(), + messages=messages, + state={"messages": list(messages)}, + runtime=runtime, + ) + + +def test_parse_slash_skill_reference_extracts_name_and_remaining_text(): + parsed = parse_slash_skill_reference("/data-analysis analyze uploads/foo.csv") + + assert parsed is not None + assert parsed.name == "data-analysis" + assert parsed.remaining_text == "analyze uploads/foo.csv" + + +def test_parse_slash_skill_reference_accepts_skill_name_without_task(): + parsed = parse_slash_skill_reference("/data-analysis") + + assert parsed is not None + assert parsed.name == "data-analysis" + assert parsed.remaining_text == "" + + +def test_parse_slash_skill_reference_rejects_invalid_names(): + assert parse_slash_skill_reference("/DataAnalysis run") is None + assert parse_slash_skill_reference("/data_analysis run") is None + assert parse_slash_skill_reference("please use /data-analysis") is None + assert parse_slash_skill_reference(" /data-analysis run") is None + assert parse_slash_skill_reference("/data-analysis分析这个文档") is None + + +def test_resolve_slash_skill_ignores_reserved_control_commands(tmp_path): + for command in ["bootstrap", "help", "memory", "models", "new", "status"]: + skill = _make_skill(tmp_path, command) + + assert resolve_slash_skill(f"/{command} create an agent", [skill]) is None + + +def test_reserved_slash_skill_names_match_channel_commands(): + assert RESERVED_SLASH_SKILL_NAMES == {command.removeprefix("/") for command in KNOWN_CHANNEL_COMMANDS} + + +def test_resolve_slash_skill_respects_available_skill_whitelist(tmp_path): + skill = _make_skill(tmp_path, "data-analysis") + + assert resolve_slash_skill("/data-analysis run", [skill], available_skills=set()) is None + + resolved = resolve_slash_skill("/data-analysis run", [skill], available_skills={"data-analysis"}) + assert resolved is not None + assert resolved.skill.name == "data-analysis" + assert resolved.remaining_text == "run" + assert resolved.container_file_path == "/mnt/skills/custom/data-analysis/SKILL.md" + + +def test_resolve_slash_skill_rejects_disabled_skills(tmp_path): + skill = _make_skill(tmp_path, "data-analysis") + skill.enabled = False + + assert resolve_slash_skill("/data-analysis run", [skill]) is None + + +def test_skill_activation_middleware_injects_hidden_human_context_for_model_call(monkeypatch, tmp_path): + skill = _make_skill(tmp_path, "data-analysis", content="# Data Analysis\nUse pandas.") + monkeypatch.setattr(middleware_module, "get_or_new_skill_storage", lambda **kwargs: _make_storage(tmp_path, [skill])) + + middleware = SkillActivationMiddleware() + original = HumanMessage(content="/data-analysis analyze uploads/foo.csv", id="msg-1") + request = _make_model_request([original]) + captured = {} + + def handler(model_request: ModelRequest): + captured["messages"] = model_request.messages + return AIMessage(content="ok") + + result = middleware.wrap_model_call(request, handler) + + assert isinstance(result, AIMessage) + assert result.content == "ok" + activation_msg, user_msg = captured["messages"] + assert is_slash_skill_activation_reminder(activation_msg) + assert activation_msg.additional_kwargs["hide_from_ui"] is True + assert "Use pandas." in activation_msg.content + assert "\nanalyze uploads/foo.csv\n" in activation_msg.content + assert user_msg.content == original.content + assert request.state["messages"] == [original] + + +def test_skill_activation_middleware_does_not_duplicate_existing_activation(monkeypatch, tmp_path): + skill = _make_skill(tmp_path, "data-analysis", content="# Data Analysis\nUse pandas.") + monkeypatch.setattr(middleware_module, "get_or_new_skill_storage", lambda **kwargs: _make_storage(tmp_path, [skill])) + + middleware = SkillActivationMiddleware() + original = HumanMessage(content="/data-analysis analyze uploads/foo.csv", id="msg-1") + first_capture = {} + + def first_handler(model_request: ModelRequest): + first_capture["messages"] = model_request.messages + return AIMessage(content="ok") + + first_result = middleware.wrap_model_call(_make_model_request([original]), first_handler) + + assert isinstance(first_result, AIMessage) + activation_msg, user_msg = first_capture["messages"] + assert is_slash_skill_activation_reminder(activation_msg) + + second_capture = {} + + def second_handler(model_request: ModelRequest): + second_capture["messages"] = model_request.messages + return AIMessage(content="ok") + + second_result = middleware.wrap_model_call(_make_model_request([activation_msg, user_msg]), second_handler) + + assert isinstance(second_result, AIMessage) + assert second_capture["messages"] == [activation_msg, user_msg] + assert sum(is_slash_skill_activation_reminder(message) for message in second_capture["messages"]) == 1 + + +def test_skill_activation_middleware_does_not_duplicate_activation_separated_by_hidden_context(monkeypatch, tmp_path): + skill = _make_skill(tmp_path, "data-analysis", content="# Data Analysis\nUse pandas.") + monkeypatch.setattr(middleware_module, "get_or_new_skill_storage", lambda **kwargs: _make_storage(tmp_path, [skill])) + + middleware = SkillActivationMiddleware() + original = HumanMessage(content="/data-analysis analyze uploads/foo.csv", id="msg-1") + first_capture = {} + + def first_handler(model_request: ModelRequest): + first_capture["messages"] = model_request.messages + return AIMessage(content="ok") + + middleware.wrap_model_call(_make_model_request([original]), first_handler) + activation_msg, user_msg = first_capture["messages"] + hidden_context = HumanMessage(content="dynamic context", additional_kwargs={"hide_from_ui": True}) + second_capture = {} + + def second_handler(model_request: ModelRequest): + second_capture["messages"] = model_request.messages + return AIMessage(content="ok") + + second_result = middleware.wrap_model_call(_make_model_request([activation_msg, hidden_context, user_msg]), second_handler) + + assert isinstance(second_result, AIMessage) + assert second_capture["messages"] == [activation_msg, hidden_context, user_msg] + assert sum(is_slash_skill_activation_reminder(message) for message in second_capture["messages"]) == 1 + + +def test_skill_activation_middleware_dedupes_immediately_previous_activation_without_target_id(monkeypatch, tmp_path): + skill = _make_skill(tmp_path, "data-analysis", content="# Data Analysis\nUse pandas.") + monkeypatch.setattr(middleware_module, "get_or_new_skill_storage", lambda **kwargs: _make_storage(tmp_path, [skill])) + + middleware = SkillActivationMiddleware() + legacy_activation_msg = SkillActivationMiddleware._make_activation_message( + HumanMessage(content="/data-analysis analyze uploads/foo.csv"), + "existing activation context", + ) + target = HumanMessage(content="/data-analysis analyze uploads/foo.csv", id="msg-1") + captured = {} + + def handler(model_request: ModelRequest): + captured["messages"] = model_request.messages + return AIMessage(content="ok") + + result = middleware.wrap_model_call(_make_model_request([legacy_activation_msg, target]), handler) + + assert isinstance(result, AIMessage) + assert captured["messages"] == [legacy_activation_msg, target] + assert sum(is_slash_skill_activation_reminder(message) for message in captured["messages"]) == 1 + + +def test_skill_activation_middleware_async_injects_hidden_human_context_for_model_call(monkeypatch, tmp_path): + skill = _make_skill(tmp_path, "data-analysis", content="# Data Analysis\nUse pandas.") + monkeypatch.setattr(middleware_module, "get_or_new_skill_storage", lambda **kwargs: _make_storage(tmp_path, [skill])) + + middleware = SkillActivationMiddleware() + original = HumanMessage(content="/data-analysis analyze uploads/foo.csv", id="msg-1") + request = _make_model_request([original]) + captured = {} + + async def handler(model_request: ModelRequest): + captured["messages"] = model_request.messages + return AIMessage(content="ok") + + result = asyncio.run(middleware.awrap_model_call(request, handler)) + + assert isinstance(result, AIMessage) + assert result.content == "ok" + activation_msg, user_msg = captured["messages"] + assert is_slash_skill_activation_reminder(activation_msg) + assert activation_msg.additional_kwargs["hide_from_ui"] is True + assert "Use pandas." in activation_msg.content + assert "\nanalyze uploads/foo.csv\n" in activation_msg.content + assert user_msg.content == original.content + assert request.state["messages"] == [original] + + +def test_skill_activation_middleware_uses_fallback_when_task_text_is_empty(monkeypatch, tmp_path): + skill = _make_skill(tmp_path, "data-analysis", content="# Data Analysis\nUse pandas.") + monkeypatch.setattr(middleware_module, "get_or_new_skill_storage", lambda **kwargs: _make_storage(tmp_path, [skill])) + + middleware = SkillActivationMiddleware() + original = HumanMessage(content="/data-analysis", id="msg-1") + captured = {} + + def handler(model_request: ModelRequest): + captured["messages"] = model_request.messages + return AIMessage(content="ok") + + result = middleware.wrap_model_call(_make_model_request([original]), handler) + + assert isinstance(result, AIMessage) + activation_msg = captured["messages"][0] + assert "No additional task text was provided after the slash skill command." in activation_msg.content + + +def test_skill_activation_middleware_uses_original_user_content_when_uploads_are_injected(monkeypatch, tmp_path): + skill = _make_skill(tmp_path, "data-analysis", content="# Data Analysis\nUse pandas.") + monkeypatch.setattr(middleware_module, "get_or_new_skill_storage", lambda **kwargs: _make_storage(tmp_path, [skill])) + + middleware = SkillActivationMiddleware() + original = HumanMessage( + content="\n- report.pdf\n\n\n/data-analysis 分析这个文档", + id="msg-1", + additional_kwargs={ORIGINAL_USER_CONTENT_KEY: "/data-analysis 分析这个文档"}, + ) + captured = {} + + def handler(model_request: ModelRequest): + captured["messages"] = model_request.messages + return AIMessage(content="ok") + + result = middleware.wrap_model_call(_make_model_request([original]), handler) + + assert isinstance(result, AIMessage) + assert result.content == "ok" + activation_msg, user_msg = captured["messages"] + assert is_slash_skill_activation_reminder(activation_msg) + assert "Use pandas." in activation_msg.content + assert "\n分析这个文档\n" in activation_msg.content + assert user_msg.content == original.content + assert user_msg.additional_kwargs[ORIGINAL_USER_CONTENT_KEY] == "/data-analysis 分析这个文档" + + +def test_skill_activation_middleware_activates_from_list_content(monkeypatch, tmp_path): + skill = _make_skill(tmp_path, "data-analysis", content="# Data Analysis\nUse pandas.") + monkeypatch.setattr(middleware_module, "get_or_new_skill_storage", lambda **kwargs: _make_storage(tmp_path, [skill])) + + middleware = SkillActivationMiddleware() + original = HumanMessage(content=[{"type": "text", "text": "/data-analysis analyze uploads/foo.csv"}], id="msg-1") + captured = {} + + def handler(model_request: ModelRequest): + captured["messages"] = model_request.messages + return AIMessage(content="ok") + + result = middleware.wrap_model_call(_make_model_request([original]), handler) + + assert isinstance(result, AIMessage) + activation_msg, user_msg = captured["messages"] + assert is_slash_skill_activation_reminder(activation_msg) + assert "\nanalyze uploads/foo.csv\n" in activation_msg.content + assert user_msg.content == original.content + + +def test_skill_activation_middleware_records_activation_audit_event(monkeypatch, tmp_path): + skill = _make_skill(tmp_path, "data-analysis", content="# Data Analysis\nUse pandas.") + monkeypatch.setattr(middleware_module, "get_or_new_skill_storage", lambda **kwargs: _make_storage(tmp_path, [skill])) + + recorded = [] + journal = SimpleNamespace(record_middleware=lambda *args, **kwargs: recorded.append((args, kwargs))) + runtime = SimpleNamespace(context={"__run_journal": journal}) + middleware = SkillActivationMiddleware() + original = HumanMessage(content="/data-analysis analyze uploads/foo.csv", id="msg-1") + + def handler(model_request: ModelRequest): + return AIMessage(content="ok") + + result = middleware.wrap_model_call(_make_model_request([original], runtime=runtime), handler) + + assert isinstance(result, AIMessage) + assert len(recorded) == 1 + args, kwargs = recorded[0] + assert args == ("skill_activation",) + assert kwargs["name"] == "SkillActivationMiddleware" + assert kwargs["hook"] == "wrap_model_call" + assert kwargs["action"] == "activate" + assert kwargs["changes"] == { + "skill_name": "data-analysis", + "category": "custom", + "path": "/mnt/skills/custom/data-analysis/SKILL.md", + "content_hash": hashlib.sha256(b"# Data Analysis\nUse pandas.").hexdigest(), + } + + +def test_skill_activation_middleware_async_records_activation_audit_event(monkeypatch, tmp_path): + skill = _make_skill(tmp_path, "data-analysis", content="# Data Analysis\nUse pandas.") + monkeypatch.setattr(middleware_module, "get_or_new_skill_storage", lambda **kwargs: _make_storage(tmp_path, [skill])) + + recorded = [] + journal = SimpleNamespace(record_middleware=lambda *args, **kwargs: recorded.append((args, kwargs))) + runtime = SimpleNamespace(context={"__run_journal": journal}) + middleware = SkillActivationMiddleware() + original = HumanMessage(content="/data-analysis analyze uploads/foo.csv", id="msg-1") + + async def handler(model_request: ModelRequest): + return AIMessage(content="ok") + + result = asyncio.run(middleware.awrap_model_call(_make_model_request([original], runtime=runtime), handler)) + + assert isinstance(result, AIMessage) + assert len(recorded) == 1 + args, kwargs = recorded[0] + assert args == ("skill_activation",) + assert kwargs["hook"] == "awrap_model_call" + assert kwargs["changes"]["skill_name"] == "data-analysis" + assert kwargs["changes"]["content_hash"] == hashlib.sha256(b"# Data Analysis\nUse pandas.").hexdigest() + + +def test_skill_activation_middleware_ignores_activation_audit_errors(monkeypatch, tmp_path): + skill = _make_skill(tmp_path, "data-analysis", content="# Data Analysis\nUse pandas.") + monkeypatch.setattr(middleware_module, "get_or_new_skill_storage", lambda **kwargs: _make_storage(tmp_path, [skill])) + + journal = SimpleNamespace(record_middleware=lambda *args, **kwargs: (_ for _ in ()).throw(RuntimeError("db down"))) + runtime = SimpleNamespace(context={"__run_journal": journal}) + middleware = SkillActivationMiddleware() + original = HumanMessage(content="/data-analysis analyze uploads/foo.csv", id="msg-1") + + def handler(model_request: ModelRequest): + return AIMessage(content="ok") + + result = middleware.wrap_model_call(_make_model_request([original], runtime=runtime), handler) + + assert isinstance(result, AIMessage) + assert result.content == "ok" + + +def test_skill_activation_middleware_activates_only_latest_real_user_message(monkeypatch, tmp_path): + skill = _make_skill(tmp_path, "data-analysis", content="# Data Analysis\nUse pandas.") + monkeypatch.setattr(middleware_module, "get_or_new_skill_storage", lambda **kwargs: _make_storage(tmp_path, [skill])) + + middleware = SkillActivationMiddleware() + old_slash = HumanMessage(content="/data-analysis old request", id="msg-1") + latest_user = HumanMessage(content="continue normally", id="msg-2") + request = _make_model_request([old_slash, AIMessage(content="done"), latest_user]) + captured = {} + + def handler(model_request: ModelRequest): + captured["messages"] = model_request.messages + return AIMessage(content="ok") + + result = middleware.wrap_model_call(request, handler) + + assert isinstance(result, AIMessage) + assert captured["messages"] == request.messages + assert not any(is_slash_skill_activation_reminder(message) for message in captured["messages"]) + + +def test_skill_activation_middleware_ignores_hidden_and_summary_user_messages(monkeypatch, tmp_path): + skill = _make_skill(tmp_path, "data-analysis", content="# Data Analysis\nUse pandas.") + monkeypatch.setattr(middleware_module, "get_or_new_skill_storage", lambda **kwargs: _make_storage(tmp_path, [skill])) + + middleware = SkillActivationMiddleware() + real_user = HumanMessage(content="continue normally", id="msg-1") + hidden_slash = HumanMessage(content="/data-analysis hidden request", id="msg-2", additional_kwargs={"hide_from_ui": True}) + summary_slash = HumanMessage(content="/data-analysis summary request", id="msg-3", name="summary") + request = _make_model_request([real_user, hidden_slash, summary_slash]) + captured = {} + + def handler(model_request: ModelRequest): + captured["messages"] = model_request.messages + return AIMessage(content="ok") + + result = middleware.wrap_model_call(request, handler) + + assert isinstance(result, AIMessage) + assert captured["messages"] == request.messages + assert not any(is_slash_skill_activation_reminder(message) for message in captured["messages"]) + + +def test_skill_activation_middleware_returns_clear_error_for_disallowed_skill(monkeypatch, tmp_path): + skill = _make_skill(tmp_path, "data-analysis") + monkeypatch.setattr(middleware_module, "get_or_new_skill_storage", lambda **kwargs: _make_storage(tmp_path, [skill])) + + middleware = SkillActivationMiddleware(available_skills={"frontend-design"}) + original = HumanMessage(content="/data-analysis run") + + def handler(model_request: ModelRequest): + raise AssertionError("handler should not be called for invalid slash skills") + + result = middleware.wrap_model_call(_make_model_request([original]), handler) + + assert isinstance(result, AIMessage) + assert "not available for this agent" in result.content + + +def test_skill_activation_middleware_returns_clear_error_for_missing_skill(monkeypatch, tmp_path): + monkeypatch.setattr(middleware_module, "get_or_new_skill_storage", lambda **kwargs: _make_storage(tmp_path, [])) + + middleware = SkillActivationMiddleware() + original = HumanMessage(content="/data-analysis run") + + def handler(model_request: ModelRequest): + raise AssertionError("handler should not be called for missing slash skills") + + result = middleware.wrap_model_call(_make_model_request([original]), handler) + + assert isinstance(result, AIMessage) + assert "not installed" in result.content + + +def test_skill_activation_middleware_returns_clear_error_for_disabled_skill(monkeypatch, tmp_path): + skill = _make_skill(tmp_path, "data-analysis") + skill.enabled = False + monkeypatch.setattr(middleware_module, "get_or_new_skill_storage", lambda **kwargs: _make_storage(tmp_path, [skill])) + + middleware = SkillActivationMiddleware() + original = HumanMessage(content="/data-analysis run") + + def handler(model_request: ModelRequest): + raise AssertionError("handler should not be called for disabled slash skills") + + result = middleware.wrap_model_call(_make_model_request([original]), handler) + + assert isinstance(result, AIMessage) + assert "installed but disabled" in result.content + + +def test_skill_activation_middleware_escapes_activation_content(monkeypatch, tmp_path): + skill = _make_skill( + tmp_path, + "data-analysis", + content="# Data Analysis\nUse & avoid collisions.\n----- END SKILL.md -----", + ) + monkeypatch.setattr(middleware_module, "get_or_new_skill_storage", lambda **kwargs: _make_storage(tmp_path, [skill])) + + middleware = SkillActivationMiddleware() + original = HumanMessage(content="/data-analysis analyze ") + captured = {} + + def handler(model_request: ModelRequest): + captured["messages"] = model_request.messages + return AIMessage(content="ok") + + result = middleware.wrap_model_call(_make_model_request([original]), handler) + + assert isinstance(result, AIMessage) + activation_msg = captured["messages"][0] + assert '' in activation_msg.content + assert "analyze </user_request>" in activation_msg.content + assert "Use <xml> & avoid </skill> collisions." in activation_msg.content + assert "----- BEGIN SKILL.md -----" not in activation_msg.content + + +def test_skill_activation_middleware_rejects_skill_file_outside_skills_root(monkeypatch, tmp_path): + skills_root = tmp_path / "skills" + skill_dir = skills_root / "custom" / "data-analysis" + skill_dir.mkdir(parents=True) + outside_dir = tmp_path / "outside" + outside_dir.mkdir() + outside_file = outside_dir / "SKILL.md" + outside_file.write_text("# Leaked\nDo not read me.", encoding="utf-8") + (skill_dir / "SKILL.md").symlink_to(outside_file) + skill = Skill( + name="data-analysis", + description="Description for data-analysis", + license="MIT", + skill_dir=skill_dir, + skill_file=skill_dir / "SKILL.md", + relative_path=Path("data-analysis"), + category=SkillCategory.CUSTOM, + enabled=True, + ) + monkeypatch.setattr(middleware_module, "get_or_new_skill_storage", lambda **kwargs: _make_storage(skills_root, [skill])) + + middleware = SkillActivationMiddleware() + + def handler(model_request: ModelRequest): + raise AssertionError("handler should not be called when SKILL.md fails safety checks") + + result = middleware.wrap_model_call(_make_model_request([HumanMessage(content="/data-analysis run")]), handler) + + assert isinstance(result, AIMessage) + assert "could not be loaded safely" in result.content + + +def test_skill_activation_middleware_reports_missing_skill_file_safely(monkeypatch, tmp_path): + skill = _make_skill(tmp_path, "data-analysis") + skill.skill_file.unlink() + monkeypatch.setattr(middleware_module, "get_or_new_skill_storage", lambda **kwargs: _make_storage(tmp_path, [skill])) + + middleware = SkillActivationMiddleware() + + def handler(model_request: ModelRequest): + raise AssertionError("handler should not be called when SKILL.md is missing") + + result = middleware.wrap_model_call(_make_model_request([HumanMessage(content="/data-analysis run")]), handler) + + assert isinstance(result, AIMessage) + assert "could not be loaded safely" in result.content + + +def test_skill_activation_middleware_reports_invalid_utf8_skill_file_safely(monkeypatch, tmp_path): + skill = _make_skill(tmp_path, "data-analysis") + skill.skill_file.write_bytes(b"\xff\xfe\x00") + monkeypatch.setattr(middleware_module, "get_or_new_skill_storage", lambda **kwargs: _make_storage(tmp_path, [skill])) + + middleware = SkillActivationMiddleware() + + def handler(model_request: ModelRequest): + raise AssertionError("handler should not be called when SKILL.md is not valid UTF-8") + + result = middleware.wrap_model_call(_make_model_request([HumanMessage(content="/data-analysis run")]), handler) + + assert isinstance(result, AIMessage) + assert "could not be loaded safely" in result.content diff --git a/backend/tests/test_uploads_middleware_core_logic.py b/backend/tests/test_uploads_middleware_core_logic.py index 6e39cda46..d0482bb71 100644 --- a/backend/tests/test_uploads_middleware_core_logic.py +++ b/backend/tests/test_uploads_middleware_core_logic.py @@ -14,6 +14,7 @@ from langchain_core.messages import AIMessage, HumanMessage from deerflow.agents.middlewares.uploads_middleware import UploadsMiddleware from deerflow.config.paths import Paths +from deerflow.utils.messages import ORIGINAL_USER_CONTENT_KEY THREAD_ID = "thread-abc123" @@ -263,6 +264,22 @@ class TestBeforeAgent: assert "" in combined_text assert "analyse this" in combined_text + def test_list_content_preserves_original_slash_skill_text(self, tmp_path): + mw = _middleware(tmp_path) + uploads_dir = _uploads_dir(tmp_path) + (uploads_dir / "data.csv").write_bytes(b"a,b") + + msg = _human( + [{"type": "text", "text": "/data-analysis analyze data.csv"}], + files=[{"filename": "data.csv", "size": 3, "path": "/mnt/user-data/uploads/data.csv"}], + ) + result = mw.before_agent(self._state(msg), _runtime()) + + assert result is not None + updated_msg = result["messages"][-1] + assert isinstance(updated_msg.content, list) + assert updated_msg.additional_kwargs[ORIGINAL_USER_CONTENT_KEY] == "/data-analysis analyze data.csv" + def test_preserves_additional_kwargs_on_updated_message(self, tmp_path): mw = _middleware(tmp_path) uploads_dir = _uploads_dir(tmp_path) @@ -278,6 +295,37 @@ class TestBeforeAgent: assert updated_kwargs.get("files") == files_meta assert updated_kwargs.get("element") == "task" + def test_preserves_original_user_content_before_upload_context(self, tmp_path): + mw = _middleware(tmp_path) + uploads_dir = _uploads_dir(tmp_path) + (uploads_dir / "report.pdf").write_bytes(b"pdf") + + msg = _human( + "/data-analysis 分析这个文档", + files=[{"filename": "report.pdf", "size": 3, "path": "/mnt/user-data/uploads/report.pdf"}], + ) + result = mw.before_agent(self._state(msg), _runtime()) + + assert result is not None + updated_msg = result["messages"][-1] + assert updated_msg.content.startswith("") + assert updated_msg.additional_kwargs[ORIGINAL_USER_CONTENT_KEY] == "/data-analysis 分析这个文档" + + def test_preserves_existing_original_user_content_marker(self, tmp_path): + mw = _middleware(tmp_path) + uploads_dir = _uploads_dir(tmp_path) + (uploads_dir / "report.pdf").write_bytes(b"pdf") + + msg = _human( + "\nold\n\n\n/data-analysis run", + files=[{"filename": "report.pdf", "size": 3, "path": "/mnt/user-data/uploads/report.pdf"}], + **{ORIGINAL_USER_CONTENT_KEY: "/data-analysis run"}, + ) + result = mw.before_agent(self._state(msg), _runtime()) + + assert result is not None + assert result["messages"][-1].additional_kwargs[ORIGINAL_USER_CONTENT_KEY] == "/data-analysis run" + def test_uploaded_files_returned_in_state_update(self, tmp_path): mw = _middleware(tmp_path) uploads_dir = _uploads_dir(tmp_path) diff --git a/frontend/eslint.config.js b/frontend/eslint.config.js index 71c172ef0..db3be484e 100644 --- a/frontend/eslint.config.js +++ b/frontend/eslint.config.js @@ -9,6 +9,8 @@ export default tseslint.config( { ignores: [ ".next", + "playwright-report", + "test-results", "src/components/ui/**", "src/components/ai-elements/**", "*.js", diff --git a/frontend/src/components/ai-elements/prompt-input.tsx b/frontend/src/components/ai-elements/prompt-input.tsx index ff19be843..dd9a20a29 100644 --- a/frontend/src/components/ai-elements/prompt-input.tsx +++ b/frontend/src/components/ai-elements/prompt-input.tsx @@ -881,6 +881,7 @@ export type PromptInputTextareaProps = ComponentProps< export const PromptInputTextarea = ({ onChange, + onKeyDown, className, placeholder = "What would you like to know?", ...props @@ -891,6 +892,10 @@ export const PromptInputTextarea = ({ const [isComposing, setIsComposing] = useState(false); const handleKeyDown: KeyboardEventHandler = (e) => { + onKeyDown?.(e); + if (e.defaultPrevented) { + return; + } if (e.key === "Enter") { if (isIMEComposing(e, isComposing)) { return; diff --git a/frontend/src/components/workspace/input-box.tsx b/frontend/src/components/workspace/input-box.tsx index 6344a26d2..6241016d5 100644 --- a/frontend/src/components/workspace/input-box.tsx +++ b/frontend/src/components/workspace/input-box.tsx @@ -20,6 +20,7 @@ import { useRef, useState, type ComponentProps, + type KeyboardEvent, } from "react"; import { @@ -59,6 +60,8 @@ import { fetch } from "@/core/api/fetcher"; import { getBackendBaseURL } from "@/core/config"; import { useI18n } from "@/core/i18n/hooks"; import { useModels } from "@/core/models/hooks"; +import type { Skill } from "@/core/skills"; +import { useSkills } from "@/core/skills/hooks"; import type { AgentThreadContext } from "@/core/threads"; import { textOfMessage } from "@/core/threads/utils"; import { cn } from "@/lib/utils"; @@ -86,6 +89,48 @@ import { Tooltip } from "./tooltip"; type InputMode = "flash" | "thinking" | "pro" | "ultra"; +const MAX_SKILL_SUGGESTIONS = 6; + +function getLeadingSlashSkillQuery(value: string): string | null { + if (!value.startsWith("/")) { + return null; + } + + const query = value.slice(1); + if (query.includes("/") || /\s/.test(query)) { + return null; + } + + return query; +} + +function getMatchingSkillSuggestions(skills: Skill[], query: string): Skill[] { + const normalizedQuery = query.toLowerCase(); + + return skills + .map((skill, index) => ({ + skill, + index, + name: skill.name.toLowerCase(), + })) + .filter(({ skill, name }) => { + if (!skill.enabled) { + return false; + } + return !normalizedQuery || name.includes(normalizedQuery); + }) + .sort((a, b) => { + const aStartsWith = a.name.startsWith(normalizedQuery); + const bStartsWith = b.name.startsWith(normalizedQuery); + if (aStartsWith !== bStartsWith) { + return aStartsWith ? -1 : 1; + } + return a.index - b.index; + }) + .slice(0, MAX_SKILL_SUGGESTIONS) + .map(({ skill }) => skill); +} + function getResolvedMode( mode: InputMode | undefined, supportsThinking: boolean, @@ -153,11 +198,17 @@ export function InputBox({ const { models } = useModels(); const { thread, isMock } = useThread(); const { textInput } = usePromptInputController(); + const { skills } = useSkills(); const promptRootRef = useRef(null); + const textareaRef = useRef(null); const [followups, setFollowups] = useState([]); const [followupsHidden, setFollowupsHidden] = useState(false); const [followupsLoading, setFollowupsLoading] = useState(false); + const [textareaFocused, setTextareaFocused] = useState(false); + const [skillSuggestionIndex, setSkillSuggestionIndex] = useState(0); + const [dismissedSkillSuggestionValue, setDismissedSkillSuggestionValue] = + useState(null); const lastGeneratedForAiIdRef = useRef(null); const wasStreamingRef = useRef(false); const messagesRef = useRef(thread.messages); @@ -347,9 +398,98 @@ export function InputBox({ setTimeout(() => requestFormSubmit(), 0); }, [pendingSuggestion, requestFormSubmit, textInput]); + const slashSkillQuery = useMemo( + () => getLeadingSlashSkillQuery(textInput.value ?? ""), + [textInput.value], + ); + const skillSuggestions = useMemo( + () => + slashSkillQuery === null + ? [] + : getMatchingSkillSuggestions(skills, slashSkillQuery), + [skills, slashSkillQuery], + ); + const showSkillSuggestions = + !disabled && + textareaFocused && + slashSkillQuery !== null && + skillSuggestions.length > 0 && + dismissedSkillSuggestionValue !== textInput.value; + + useEffect(() => { + setSkillSuggestionIndex(0); + }, [slashSkillQuery, skillSuggestions.length]); + + const applySkillSuggestion = useCallback( + (skill: Skill) => { + const nextValue = `/${skill.name} `; + textInput.setInput(nextValue); + setDismissedSkillSuggestionValue(nextValue); + requestAnimationFrame(() => { + const textarea = textareaRef.current; + if (!textarea) { + return; + } + textarea.focus(); + textarea.setSelectionRange(nextValue.length, nextValue.length); + }); + }, + [textInput], + ); + + const handleSkillSuggestionKeyDown = useCallback( + (event: KeyboardEvent) => { + if (!showSkillSuggestions) { + return; + } + + if (event.key === "ArrowDown") { + event.preventDefault(); + setSkillSuggestionIndex( + (index) => (index + 1) % skillSuggestions.length, + ); + return; + } + + if (event.key === "ArrowUp") { + event.preventDefault(); + setSkillSuggestionIndex( + (index) => + (index - 1 + skillSuggestions.length) % skillSuggestions.length, + ); + return; + } + + if (event.key === "Enter" || event.key === "Tab") { + if (event.shiftKey) { + return; + } + event.preventDefault(); + const selectedSkill = skillSuggestions[skillSuggestionIndex]; + if (selectedSkill) { + applySkillSuggestion(selectedSkill); + } + return; + } + + if (event.key === "Escape") { + event.preventDefault(); + setDismissedSkillSuggestionValue(textInput.value); + } + }, + [ + applySkillSuggestion, + showSkillSuggestions, + skillSuggestionIndex, + skillSuggestions, + textInput.value, + ], + ); + const showFollowups = !disabled && !isWelcomeMode && + !showSkillSuggestions && !followupsHidden && (followupsLoading || followups.length > 0); @@ -478,6 +618,48 @@ export function InputBox({ )} + {showSkillSuggestions && ( +
+
+ {skillSuggestions.map((skill, index) => { + const selected = index === skillSuggestionIndex; + return ( + + ); + })} +
+
+ )} setTextareaFocused(false)} + onFocus={() => setTextareaFocused(true)} + onKeyDown={handleSkillSuggestionKeyDown} + ref={textareaRef} /> @@ -860,11 +1046,13 @@ export function InputBox({ )} - {isWelcomeMode && searchParams.get("mode") !== "skill" && ( -
- -
- )} + {isWelcomeMode && + searchParams.get("mode") !== "skill" && + !showSkillSuggestions && ( +
+ +
+ )} diff --git a/frontend/src/core/messages/utils.ts b/frontend/src/core/messages/utils.ts index f1bbe4d07..9592db8b4 100644 --- a/frontend/src/core/messages/utils.ts +++ b/frontend/src/core/messages/utils.ts @@ -469,10 +469,14 @@ export function findToolCallResult(toolCallId: string, messages: Message[]) { } export function isHiddenFromUIMessage(message: Message) { + const content = extractTextFromMessage(message); return ( message.additional_kwargs?.hide_from_ui === true || (typeof message.name === "string" && - HIDDEN_CONTROL_MESSAGE_NAMES.has(message.name)) + HIDDEN_CONTROL_MESSAGE_NAMES.has(message.name)) || + (message.type === "human" && + content.includes("") && + stripUploadedFilesTag(content).length === 0) ); } @@ -488,12 +492,13 @@ export interface FileInMessage { } /** - * Strip tag from message content. - * Returns the content with the tag removed. + * Strip backend-injected human context tags from message content. + * Kept under its historical name because callers use it for uploaded-file + * display cleanup. */ export function stripUploadedFilesTag(content: string): string { return content - .replace(/[\s\S]*?<\/uploaded_files>/g, "") + .replace(/<(uploaded_files|slash_skill_activation)>[\s\S]*?<\/\1>/g, "") .trim(); } @@ -504,6 +509,7 @@ export function stripUploadedFilesTag(content: string): string { * These markers are *not* user copy — they come from: * * - ``UploadsMiddleware`` → ```` + * - ``SkillActivationMiddleware`` → ```` * - ``DynamicContextMiddleware`` → ```` (carrying * ```` / ```` inside) * - ``TodoListMiddleware`` / ``LoopDetectionMiddleware`` style reminders @@ -517,6 +523,7 @@ export function stripUploadedFilesTag(content: string): string { */ export const INTERNAL_MARKER_TAGS = [ "uploaded_files", + "slash_skill_activation", "system-reminder", "memory", "current_date", diff --git a/frontend/tests/e2e/chat.spec.ts b/frontend/tests/e2e/chat.spec.ts index e608793df..4650a3c2c 100644 --- a/frontend/tests/e2e/chat.spec.ts +++ b/frontend/tests/e2e/chat.spec.ts @@ -24,6 +24,61 @@ test.describe("Chat workspace", () => { await expect(textarea).toHaveValue("Hello, DeerFlow!"); }); + test("suggests matching skills after a leading slash", async ({ page }) => { + await page.goto("/workspace/chats/new"); + + const textarea = page.getByPlaceholder(/how can i assist you/i); + await expect(textarea).toBeVisible({ timeout: 15_000 }); + + await textarea.fill("/dat"); + await expect( + page.getByRole("option", { name: /data-analysis/i }), + ).toBeVisible(); + await expect( + page.getByRole("option", { name: /disabled-skill/i }), + ).toBeHidden(); + + await textarea.press("Enter"); + + await expect(textarea).toHaveValue("/data-analysis "); + }); + + test("keeps Shift+Enter as newline while skill suggestions are visible", async ({ + page, + }) => { + await page.goto("/workspace/chats/new"); + + const textarea = page.getByPlaceholder(/how can i assist you/i); + await expect(textarea).toBeVisible({ timeout: 15_000 }); + + await textarea.fill("/dat"); + await expect( + page.getByRole("option", { name: /data-analysis/i }), + ).toBeVisible(); + + await textarea.press("Shift+Enter"); + + await expect(textarea).toHaveValue("/dat\n"); + await expect( + page.getByRole("option", { name: /data-analysis/i }), + ).toBeHidden(); + }); + + test("does not suggest skills for slash text away from the prompt start", async ({ + page, + }) => { + await page.goto("/workspace/chats/new"); + + const textarea = page.getByPlaceholder(/how can i assist you/i); + await expect(textarea).toBeVisible({ timeout: 15_000 }); + + await textarea.fill("please /dat"); + + await expect( + page.getByRole("option", { name: /data-analysis/i }), + ).toBeHidden(); + }); + test("sending a message triggers API call and shows response", async ({ page, }) => { @@ -49,6 +104,150 @@ test.describe("Chat workspace", () => { }); }); + test("slash skill command is submitted as normal chat text", async ({ + page, + }) => { + const slashCommand = "/data-analysis analyze uploads/foo.csv"; + let submittedText: string | undefined; + await page.route("**/runs/stream", (route) => { + const body = route.request().postDataJSON() as { + input?: { messages?: Array<{ content?: unknown }> }; + }; + const content = body.input?.messages?.at(-1)?.content; + if (typeof content === "string") { + submittedText = content; + } else if (Array.isArray(content)) { + submittedText = content + .map((block) => + typeof block === "object" && + block !== null && + "text" in block && + typeof block.text === "string" + ? block.text + : "", + ) + .join(""); + } + return handleRunStream(route); + }); + + await page.goto("/workspace/chats/new"); + + const textarea = page.getByPlaceholder(/how can i assist you/i); + await expect(textarea).toBeVisible({ timeout: 15_000 }); + + await textarea.fill(slashCommand); + await textarea.press("Enter"); + + await expect + .poll(() => submittedText, { timeout: 10_000 }) + .toBe(slashCommand); + await expect(page.getByText("Hello from DeerFlow!")).toBeVisible({ + timeout: 10_000, + }); + }); + + test("slash skill command with attachment preserves command text and file metadata", async ({ + page, + }) => { + const slashCommand = "/data-analysis analyze report.docx"; + let uploadCalled = false; + let submittedText: string | undefined; + let submittedFiles: + | Array<{ filename?: string; path?: string; status?: string }> + | undefined; + + await page.route("**/api/threads/*/uploads", async (route) => { + uploadCalled = true; + return route.fulfill({ + status: 200, + contentType: "application/json", + body: JSON.stringify({ + success: true, + message: "Uploaded", + files: [ + { + filename: "report.docx", + size: 12, + path: "report.docx", + virtual_path: "/mnt/user-data/uploads/report.docx", + artifact_url: "/api/threads/test/uploads/report.docx", + extension: ".docx", + }, + ], + }), + }); + }); + + await page.route("**/runs/stream", (route) => { + const body = route.request().postDataJSON() as { + input?: { + messages?: Array<{ + content?: unknown; + additional_kwargs?: { + files?: Array<{ + filename?: string; + path?: string; + status?: string; + }>; + }; + }>; + }; + }; + const message = body.input?.messages?.at(-1); + const content = message?.content; + if (typeof content === "string") { + submittedText = content; + } else if (Array.isArray(content)) { + submittedText = content + .map((block) => + typeof block === "object" && + block !== null && + "text" in block && + typeof block.text === "string" + ? block.text + : "", + ) + .join(""); + } + submittedFiles = message?.additional_kwargs?.files; + return handleRunStream(route); + }); + + await page.goto("/workspace/chats/new"); + + const textarea = page.getByPlaceholder(/how can i assist you/i); + await expect(textarea).toBeVisible({ timeout: 15_000 }); + + await page.getByLabel("Upload files").setInputFiles({ + name: "report.docx", + mimeType: + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + buffer: Buffer.from("fake docx"), + }); + + await textarea.fill(slashCommand); + await textarea.press("Enter"); + + await expect.poll(() => uploadCalled, { timeout: 10_000 }).toBeTruthy(); + await expect + .poll(() => submittedText, { timeout: 10_000 }) + .toBe(slashCommand); + await expect + .poll(() => submittedFiles, { timeout: 10_000 }) + .toEqual([ + { + filename: "report.docx", + size: 12, + path: "/mnt/user-data/uploads/report.docx", + status: "uploaded", + }, + ]); + await expect(page.getByText("Hello from DeerFlow!")).toBeVisible({ + timeout: 10_000, + }); + }); + test("keeps attachments visible while upload submit is pending", async ({ page, }) => { diff --git a/frontend/tests/e2e/utils/mock-api.ts b/frontend/tests/e2e/utils/mock-api.ts index cf10db08b..888b066b0 100644 --- a/frontend/tests/e2e/utils/mock-api.ts +++ b/frontend/tests/e2e/utils/mock-api.ts @@ -35,11 +35,41 @@ export type MockAgent = { system_prompt?: string; }; +export type MockSkill = { + name: string; + description: string; + category?: string; + license?: string | null; + enabled?: boolean; +}; + export type MockAPIOptions = { threads?: MockThread[]; agents?: MockAgent[]; + skills?: MockSkill[]; }; +const DEFAULT_SKILLS: MockSkill[] = [ + { + name: "data-analysis", + description: "Analyze structured data and produce charts.", + category: "public", + enabled: true, + }, + { + name: "frontend-design", + description: "Create polished frontend interfaces.", + category: "public", + enabled: true, + }, + { + name: "disabled-skill", + description: "Hidden from slash autocomplete.", + category: "public", + enabled: false, + }, +]; + // --------------------------------------------------------------------------- // mockLangGraphAPI // --------------------------------------------------------------------------- @@ -52,6 +82,7 @@ export type MockAPIOptions = { export function mockLangGraphAPI(page: Page, options?: MockAPIOptions) { const threads = options?.threads ?? []; const agents = options?.agents ?? []; + const skills = options?.skills ?? DEFAULT_SKILLS; // Thread search — sidebar thread list & chats list page void page.route("**/api/langgraph/threads/search", (route) => { @@ -259,6 +290,18 @@ export function mockLangGraphAPI(page: Page, options?: MockAPIOptions) { return route.fallback(); }); + // Skills list — settings page and slash autocomplete + void page.route("**/api/skills", (route) => { + if (route.request().method() === "GET") { + return route.fulfill({ + status: 200, + contentType: "application/json", + body: JSON.stringify({ skills }), + }); + } + return route.fallback(); + }); + // Follow-up suggestions — input box auto-suggest after AI response void page.route("**/api/threads/*/suggestions", (route) => { if (route.request().method() === "POST") { diff --git a/frontend/tests/unit/core/messages/utils.test.ts b/frontend/tests/unit/core/messages/utils.test.ts index b827c95eb..510e43a05 100644 --- a/frontend/tests/unit/core/messages/utils.test.ts +++ b/frontend/tests/unit/core/messages/utils.test.ts @@ -11,6 +11,7 @@ import { hasContent, hasReasoning, isAssistantMessageGroupStreaming, + stripUploadedFilesTag, } from "@/core/messages/utils"; function aiMessage(content: string): Message { @@ -173,6 +174,38 @@ describe("inline tag splitting", () => { }); }); +describe("human message internal context stripping", () => { + test("strips slash skill activation context from display content", () => { + const content = + "\n# Secret SKILL.md\n\nreal user task"; + + expect(stripUploadedFilesTag(content)).toBe("real user task"); + }); + + test("hides leaked slash skill activation messages with no user text", () => { + const messages = [ + { + id: "slash-activation", + type: "human", + content: + "\n# Secret SKILL.md\n", + }, + { + id: "ai-1", + type: "ai", + content: "Public answer", + }, + ] as Message[]; + + const groups = getMessageGroups(messages); + + expect(groups.map((group) => group.type)).toEqual(["assistant"]); + expect( + groups.flatMap((group) => group.messages).map((message) => message.id), + ).toEqual(["ai-1"]); + }); +}); + test("hides internal todo reminder messages from message groups", () => { const messages = [ { diff --git a/frontend/tests/unit/core/threads/export.test.ts b/frontend/tests/unit/core/threads/export.test.ts index 8ee520aa3..58219f8a0 100644 --- a/frontend/tests/unit/core/threads/export.test.ts +++ b/frontend/tests/unit/core/threads/export.test.ts @@ -260,6 +260,22 @@ describe("formatThreadAsJSON", () => { expect(raw).toContain("real user text"); }); + it("strips as defence in depth", () => { + // Slash activation normally rides in a hidden HumanMessage. If a replay + // or state merge loses the flag, export must still not leak full SKILL.md + // content into a user-visible transcript. + const leaky = human("real user task", { + id: "leak-slash-skill", + content: + "\n# Secret SKILL.md\nUse internal source.\n\nreal user task", + } as unknown as Partial); + const raw = formatThreadAsJSON(makeThread(), [leaky]); + expect(raw).not.toContain(""); + expect(raw).not.toContain("Secret SKILL.md"); + expect(raw).not.toContain("internal source"); + expect(raw).toContain("real user task"); + }); + it("sanitises tool message content when includeToolMessages is true", () => { const message = { id: "t-leak", From ae9e8bc0bf31b7a8e77c102d6925e747cfdce3c4 Mon Sep 17 00:00:00 2001 From: Lucy Shen <49802413+player0718@users.noreply.github.com> Date: Tue, 9 Jun 2026 23:16:14 +0800 Subject: [PATCH 09/11] fix(sandbox): make missing sandbox.mounts host_path a loud ERROR (#3244) (#3250) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In Docker production deployments, LocalSandboxProvider runs inside the deer-flow-gateway container, so any `sandbox.mounts[].host_path` from config.yaml is resolved against the gateway container's filesystem — not the host machine. When the path isn't also bind-mounted into the gateway service, the mount was silently dropped with only a WARNING log, leaving agents reading an empty directory in production while the same config worked under `make dev`. Escalate the missing-host_path branch to logger.error with explicit guidance about Docker bind mounts and docker-compose, so the failure is hard to miss in default log configurations. Skip behaviour is preserved to avoid breaking existing deployments. Also clarify the misleading `VolumeMountConfig.host_path` field description so it documents reality for both providers: - LocalSandboxProvider checks host_path from inside the gateway process (host in `make dev`, container in `make up`). - AioSandboxProvider (DooD) passes host_path straight to `docker -v` for the sandbox container, where the host Docker daemon resolves it from the host machine's perspective. config.example.yaml's `sandbox.mounts` comment gets a Note: block pointing operators at the docker-compose bind-mount requirement so the Docker-mode gotcha is discoverable from the canonical template. Adds a regression test that: - confirms missing host_path is still skipped (no behaviour break); - asserts an ERROR record is emitted referencing the offending paths; - asserts the message contains actionable Docker/gateway/docker-compose keywords so future refactors can't quietly downgrade it. Refs: https://github.com/bytedance/deer-flow/issues/3244 --- .../harness/deerflow/config/sandbox_config.py | 15 +++++- .../sandbox/local/local_sandbox_provider.py | 22 +++++++-- .../test_local_sandbox_provider_mounts.py | 48 +++++++++++++++++++ config.example.yaml | 6 ++- 4 files changed, 86 insertions(+), 5 deletions(-) diff --git a/backend/packages/harness/deerflow/config/sandbox_config.py b/backend/packages/harness/deerflow/config/sandbox_config.py index d9aac4ab4..7aac60003 100644 --- a/backend/packages/harness/deerflow/config/sandbox_config.py +++ b/backend/packages/harness/deerflow/config/sandbox_config.py @@ -4,7 +4,20 @@ from pydantic import BaseModel, ConfigDict, Field class VolumeMountConfig(BaseModel): """Configuration for a volume mount.""" - host_path: str = Field(..., description="Path on the host machine") + host_path: str = Field( + ..., + description=( + "Source path for the mount. Resolution depends on the active provider: " + "``LocalSandboxProvider`` checks this path from the gateway process — in " + "``make dev`` that is the host machine, but in Docker deployments " + "(``make up`` / docker-compose) it is the path *inside* the " + "``deer-flow-gateway`` container, so the host directory must also be " + "bind-mounted into the gateway service for the mount to take effect. " + "``AioSandboxProvider`` (DooD) passes this value straight to ``docker -v`` " + "for the sandbox container, where it is resolved by the host Docker daemon " + "from the host machine's perspective." + ), + ) container_path: str = Field(..., description="Path inside the container") read_only: bool = Field(default=False, description="Whether the mount is read-only") diff --git a/backend/packages/harness/deerflow/sandbox/local/local_sandbox_provider.py b/backend/packages/harness/deerflow/sandbox/local/local_sandbox_provider.py index 8b6b347ca..9e6523457 100644 --- a/backend/packages/harness/deerflow/sandbox/local/local_sandbox_provider.py +++ b/backend/packages/harness/deerflow/sandbox/local/local_sandbox_provider.py @@ -147,7 +147,17 @@ class LocalSandboxProvider(SandboxProvider): mount.container_path, ) continue - # Ensure the host path exists before adding mapping + # Ensure the host path exists before adding mapping. + # + # ``host_path`` is resolved against the filesystem of the + # process running this provider — for ``make dev`` that is + # the host machine, but for ``make up`` it is the + # ``deer-flow-gateway`` container, so any host path that + # isn't bind-mounted into the gateway image will be missing + # here. Skipping silently makes this a high-cost-to-debug + # silent failure (sandbox skill / tool reads an empty dir + # instead of the configured mount), so escalate to ERROR + # and include actionable guidance. See #3244. if host_path.exists(): mappings.append( PathMapping( @@ -157,10 +167,16 @@ class LocalSandboxProvider(SandboxProvider): ) ) else: - logger.warning( - "Mount host_path does not exist, skipping: %s -> %s", + logger.error( + "sandbox.mounts entry %s -> %s ignored: host_path %s does not exist from the " + "perspective of the gateway process. In Docker deployments (make up / docker-compose), " + "this path must also be bind-mounted into the gateway container — add a matching " + "volume entry under services.gateway.volumes in docker/docker-compose.yaml (and use " + "the in-container path here), or run in local mode (make dev) where the gateway sees " + "the host filesystem directly.", mount.host_path, mount.container_path, + mount.host_path, ) except Exception as e: # Log but don't fail if config loading fails diff --git a/backend/tests/test_local_sandbox_provider_mounts.py b/backend/tests/test_local_sandbox_provider_mounts.py index add5c4ea6..a9b0ec63d 100644 --- a/backend/tests/test_local_sandbox_provider_mounts.py +++ b/backend/tests/test_local_sandbox_provider_mounts.py @@ -612,6 +612,54 @@ class TestLocalSandboxProviderMounts: assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"] + def test_setup_path_mappings_logs_actionable_error_for_missing_host_path(self, tmp_path, caplog): + """Regression for #3244. + + When ``sandbox.mounts[].host_path`` is absent from the gateway process's + filesystem (the typical symptom in Docker production mode: host_path is a + host machine path that is not bind-mounted into the gateway container), + the mount is still skipped — but the failure must be a hard-to-miss ERROR + log with explicit, actionable guidance about Docker bind mounts, not the + old DEBUG/WARNING that buried the silent failure. + """ + skills_dir = tmp_path / "skills" + skills_dir.mkdir() + missing_host_path = tmp_path / "does-not-exist" + + from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig + + sandbox_config = SandboxConfig( + use="deerflow.sandbox.local:LocalSandboxProvider", + mounts=[ + VolumeMountConfig(host_path=str(missing_host_path), container_path="/mnt/knowledge", read_only=True), + ], + ) + config = SimpleNamespace( + skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir, use="deerflow.skills.storage.local_skill_storage:LocalSkillStorage"), + sandbox=sandbox_config, + ) + + with caplog.at_level("ERROR", logger="deerflow.sandbox.local.local_sandbox_provider"): + with patch("deerflow.config.get_app_config", return_value=config): + provider = LocalSandboxProvider() + + # Silent-skip behaviour is preserved (no breaking change for existing deployments). + assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"] + + # The failure must be observable at ERROR level and reference the offending paths. + error_records = [r for r in caplog.records if r.levelname == "ERROR"] + assert error_records, "expected an ERROR log when host_path is missing" + message = "\n".join(r.getMessage() for r in error_records) + assert str(missing_host_path) in message + assert "/mnt/knowledge" in message + + # And it must include actionable Docker guidance so users don't lose hours + # to a silent empty-mount failure in production. + lowered = message.lower() + assert "docker" in lowered + assert "gateway" in lowered + assert "docker-compose" in lowered + def test_write_file_resolves_container_paths_in_content(self, tmp_path): """write_file should replace container paths in file content with local paths.""" data_dir = tmp_path / "data" diff --git a/config.example.yaml b/config.example.yaml index 290ef3302..82d2fa2ed 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -768,8 +768,12 @@ sandbox: allow_host_bash: false # Optional: Mount additional host directories into the sandbox. # Each mount maps a host path to a virtual container path accessible by the agent. + # Note: with LocalSandboxProvider under `make up` (docker-compose), host_path is + # checked from inside the deer-flow-gateway container — you must also bind-mount + # the same directory into services.gateway.volumes in docker/docker-compose.yaml + # for this mount to take effect (see issue #3244). # mounts: - # - host_path: /home/user/my-project # Absolute path on the host machine + # - host_path: /home/user/my-project # Absolute path; see note above for Docker mode # container_path: /mnt/my-project # Virtual path inside the sandbox # read_only: true # Whether the mount is read-only (default: false) From a57d05fe0a83551e928c5832183323cd29456687 Mon Sep 17 00:00:00 2001 From: Nan Gao Date: Wed, 10 Jun 2026 02:33:29 +0200 Subject: [PATCH 10/11] fix runtime journal run lifecycle events (#3470) --- .../packages/harness/deerflow/runtime/journal.py | 13 ++++++++++++- backend/tests/test_run_journal.py | 7 ++++--- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/backend/packages/harness/deerflow/runtime/journal.py b/backend/packages/harness/deerflow/runtime/journal.py index b65e1c0bb..74e03f165 100644 --- a/backend/packages/harness/deerflow/runtime/journal.py +++ b/backend/packages/harness/deerflow/runtime/journal.py @@ -164,7 +164,18 @@ class RunJournal(BaseCallbackHandler): metadata={"caller": caller, **(metadata or {})}, ) - def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None: + def on_chain_end( + self, + outputs: Any, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> None: + # Nested chain ends fire for internal graph nodes; only the root chain + # represents the user-visible run lifecycle. + if parent_run_id is not None: + return self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"}) self._flush_sync() diff --git a/backend/tests/test_run_journal.py b/backend/tests/test_run_journal.py index 0b495954b..a68895d27 100644 --- a/backend/tests/test_run_journal.py +++ b/backend/tests/test_run_journal.py @@ -179,15 +179,16 @@ class TestLifecycleCallbacks: assert "run.end" in types @pytest.mark.anyio - async def test_nested_chain_no_run_start(self, journal_setup): - """Nested chains (parent_run_id set) should NOT produce run.start.""" + async def test_nested_chain_no_run_lifecycle_events(self, journal_setup): + """Nested chains (parent_run_id set) should NOT produce root run lifecycle events.""" j, store = journal_setup parent_id = uuid4() j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=parent_id) - j.on_chain_end({}, run_id=uuid4()) + j.on_chain_end({}, run_id=uuid4(), parent_run_id=parent_id) await j.flush() events = await store.list_events("t1", "r1") assert not any(e["event_type"] == "run.start" for e in events) + assert not any(e["event_type"] == "run.end" for e in events) class TestToolCallbacks: From 2b795265e7d7ab935428ac1f1177c8d8b403dc34 Mon Sep 17 00:00:00 2001 From: DanielWalnut <45447813+hetaoBackend@users.noreply.github.com> Date: Wed, 10 Jun 2026 16:11:00 +0800 Subject: [PATCH 11/11] fix: align auth-disabled mode and mock history loading (#3471) * fix: align auth-disabled mode and mock history loading * fix: address auth-disabled review feedback * test: cover auth-disabled backend contract * style: format frontend tests * fix: address follow-up review comments --- backend/app/gateway/app.py | 2 + backend/app/gateway/auth_disabled.py | 54 +++++ backend/app/gateway/auth_middleware.py | 61 ++++-- backend/app/gateway/csrf_middleware.py | 5 + backend/app/gateway/deps.py | 11 + backend/app/gateway/langgraph_auth.py | 7 + backend/app/gateway/routers/auth.py | 10 + backend/tests/test_auth_middleware.py | 191 +++++++++++++++++- backend/tests/test_langgraph_auth.py | 9 + frontend/playwright.real-backend.config.ts | 10 +- frontend/src/core/auth/auth-disabled-user.ts | 23 +++ frontend/src/core/auth/server.ts | 10 +- frontend/src/core/threads/hooks.ts | 55 +++-- .../auth-disabled-contract.spec.ts | 16 ++ .../real-backend-render.spec.ts | 5 +- frontend/tests/e2e/chat.spec.ts | 1 + frontend/tests/e2e/thread-history.spec.ts | 79 ++++++++ frontend/tests/unit/core/auth/server.test.ts | 31 +++ 18 files changed, 528 insertions(+), 52 deletions(-) create mode 100644 backend/app/gateway/auth_disabled.py create mode 100644 frontend/src/core/auth/auth-disabled-user.ts create mode 100644 frontend/tests/e2e-real-backend/auth-disabled-contract.spec.ts diff --git a/backend/app/gateway/app.py b/backend/app/gateway/app.py index dd5701083..aa49b4ffc 100644 --- a/backend/app/gateway/app.py +++ b/backend/app/gateway/app.py @@ -6,6 +6,7 @@ from contextlib import asynccontextmanager from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from app.gateway.auth_disabled import warn_if_auth_disabled_enabled from app.gateway.auth_middleware import AuthMiddleware from app.gateway.config import get_gateway_config from app.gateway.csrf_middleware import CSRFMiddleware, get_configured_cors_origins @@ -172,6 +173,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: startup_config = get_app_config() apply_logging_level(startup_config.log_level) logger.info("Configuration loaded successfully") + warn_if_auth_disabled_enabled() except Exception as e: error_msg = f"Failed to load configuration during gateway startup: {e}" logger.exception(error_msg) diff --git a/backend/app/gateway/auth_disabled.py b/backend/app/gateway/auth_disabled.py new file mode 100644 index 000000000..396de7129 --- /dev/null +++ b/backend/app/gateway/auth_disabled.py @@ -0,0 +1,54 @@ +"""Shared helpers for local/E2E auth-disabled mode.""" + +from __future__ import annotations + +import logging +import os +from types import SimpleNamespace + +AUTH_DISABLED_ENV_VAR = "DEER_FLOW_AUTH_DISABLED" +AUTH_DISABLED_USER_ID = "e2e-user" +AUTH_DISABLED_USER_EMAIL = "e2e@test.local" + +AUTH_SOURCE_SESSION = "session" +AUTH_SOURCE_INTERNAL = "internal" +AUTH_SOURCE_AUTH_DISABLED = "auth_disabled" + +_PRODUCTION_ENV_VARS: tuple[str, ...] = ("DEER_FLOW_ENV", "ENVIRONMENT") +_PRODUCTION_ENV_VALUES: frozenset[str] = frozenset({"prod", "production"}) + +logger = logging.getLogger(__name__) + + +def is_explicit_production_environment() -> bool: + return any(os.environ.get(name, "").strip().lower() in _PRODUCTION_ENV_VALUES for name in _PRODUCTION_ENV_VARS) + + +def is_auth_disabled_requested() -> bool: + return os.environ.get(AUTH_DISABLED_ENV_VAR) == "1" + + +def is_auth_disabled() -> bool: + return is_auth_disabled_requested() and not is_explicit_production_environment() + + +def warn_if_auth_disabled_enabled() -> None: + if not is_auth_disabled(): + return + + logger.warning( + "%s=1 is active: authentication is bypassed and anonymous requests run as synthetic admin user %r. Do not enable this in shared or production deployments.", + AUTH_DISABLED_ENV_VAR, + AUTH_DISABLED_USER_ID, + ) + + +def get_auth_disabled_user(): + return SimpleNamespace( + id=AUTH_DISABLED_USER_ID, + email=AUTH_DISABLED_USER_EMAIL, + password_hash=None, + system_role="admin", + needs_setup=False, + token_version=0, + ) diff --git a/backend/app/gateway/auth_middleware.py b/backend/app/gateway/auth_middleware.py index 6b6452264..6d71186a0 100644 --- a/backend/app/gateway/auth_middleware.py +++ b/backend/app/gateway/auth_middleware.py @@ -17,6 +17,13 @@ from starlette.responses import JSONResponse from starlette.types import ASGIApp from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse +from app.gateway.auth_disabled import ( + AUTH_SOURCE_AUTH_DISABLED, + AUTH_SOURCE_INTERNAL, + AUTH_SOURCE_SESSION, + get_auth_disabled_user, + is_auth_disabled, +) from app.gateway.authz import _ALL_PERMISSIONS, AuthContext from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token from deerflow.runtime.user_context import reset_current_user, set_current_user @@ -80,8 +87,38 @@ class AuthMiddleware(BaseHTTPMiddleware): if is_valid_internal_auth_token(request.headers.get(INTERNAL_AUTH_HEADER_NAME)): internal_user = get_internal_user() + auth_source = AUTH_SOURCE_SESSION + access_token = request.cookies.get("access_token") + # Non-public path: require session cookie - if internal_user is None and not request.cookies.get("access_token"): + if internal_user is not None: + user = internal_user + auth_source = AUTH_SOURCE_INTERNAL + elif access_token: + # Strict JWT validation: reject junk/expired tokens with 401 + # right here instead of silently passing through. This closes + # the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8): + # without this, non-isolation routes like /api/models would + # accept any cookie-shaped string as authentication. + # + # We call the *strict* resolver so that fine-grained error + # codes (token_expired, token_invalid, user_not_found, …) + # propagate from AuthErrorCode, not get flattened into one + # generic code. BaseHTTPMiddleware doesn't let HTTPException + # bubble up, so we catch and render it as JSONResponse here. + from app.gateway.deps import get_current_user_from_request + + try: + user = await get_current_user_from_request(request) + except HTTPException as exc: + if not is_auth_disabled(): + return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail}) + user = get_auth_disabled_user() + auth_source = AUTH_SOURCE_AUTH_DISABLED + elif is_auth_disabled(): + user = get_auth_disabled_user() + auth_source = AUTH_SOURCE_AUTH_DISABLED + else: return JSONResponse( status_code=401, content={ @@ -92,32 +129,12 @@ class AuthMiddleware(BaseHTTPMiddleware): }, ) - # Strict JWT validation: reject junk/expired tokens with 401 - # right here instead of silently passing through. This closes - # the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8): - # without this, non-isolation routes like /api/models would - # accept any cookie-shaped string as authentication. - # - # We call the *strict* resolver so that fine-grained error - # codes (token_expired, token_invalid, user_not_found, …) - # propagate from AuthErrorCode, not get flattened into one - # generic code. BaseHTTPMiddleware doesn't let HTTPException - # bubble up, so we catch and render it as JSONResponse here. - from app.gateway.deps import get_current_user_from_request - - if internal_user is not None: - user = internal_user - else: - try: - user = await get_current_user_from_request(request) - except HTTPException as exc: - return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail}) - # Stamp both request.state.user (for the contextvar pattern) # and request.state.auth (so @require_permission's "auth is # None" branch short-circuits instead of running the entire # JWT-decode + DB-lookup pipeline a second time per request). request.state.user = user + request.state.auth_source = auth_source request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS) token = set_current_user(user) try: diff --git a/backend/app/gateway/csrf_middleware.py b/backend/app/gateway/csrf_middleware.py index f34882032..c9edb83b2 100644 --- a/backend/app/gateway/csrf_middleware.py +++ b/backend/app/gateway/csrf_middleware.py @@ -14,6 +14,8 @@ from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import JSONResponse from starlette.types import ASGIApp +from app.gateway.auth_disabled import is_auth_disabled + CSRF_COOKIE_NAME = "csrf_token" CSRF_HEADER_NAME = "X-CSRF-Token" CSRF_TOKEN_LENGTH = 64 # bytes @@ -38,6 +40,9 @@ def should_check_csrf(request: Request) -> bool: if request.method not in ("POST", "PUT", "DELETE", "PATCH"): return False + if is_auth_disabled(): + return False + path = request.url.path.rstrip("/") # Exempt /api/v1/auth/me endpoint if path == "/api/v1/auth/me": diff --git a/backend/app/gateway/deps.py b/backend/app/gateway/deps.py index 5739d217d..c192828d9 100644 --- a/backend/app/gateway/deps.py +++ b/backend/app/gateway/deps.py @@ -331,6 +331,17 @@ async def get_current_user_from_request(request: Request): Raises HTTPException 401 if not authenticated. """ + state = getattr(request, "state", None) + state_user = getattr(state, "user", None) + from app.gateway.auth_disabled import AUTH_SOURCE_AUTH_DISABLED, AUTH_SOURCE_INTERNAL, AUTH_SOURCE_SESSION + + if state_user is not None and getattr(state, "auth_source", None) in { + AUTH_SOURCE_SESSION, + AUTH_SOURCE_AUTH_DISABLED, + AUTH_SOURCE_INTERNAL, + }: + return state_user + from app.gateway.auth import decode_token from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code diff --git a/backend/app/gateway/langgraph_auth.py b/backend/app/gateway/langgraph_auth.py index 202fab2d5..3ab3d2070 100644 --- a/backend/app/gateway/langgraph_auth.py +++ b/backend/app/gateway/langgraph_auth.py @@ -20,6 +20,7 @@ from langgraph_sdk import Auth from app.gateway.auth.errors import TokenError from app.gateway.auth.jwt import decode_token +from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID, is_auth_disabled from app.gateway.deps import get_local_provider auth = Auth() @@ -38,6 +39,9 @@ def _check_csrf(request) -> None: if method.upper() not in _CSRF_METHODS: return + if is_auth_disabled(): + return + cookie_token = request.cookies.get("csrf_token") header_token = request.headers.get("x-csrf-token") @@ -66,6 +70,9 @@ async def authenticate(request): # are rejected early, even if the cookie carries a valid JWT. _check_csrf(request) + if is_auth_disabled(): + return AUTH_DISABLED_USER_ID + token = request.cookies.get("access_token") if not token: raise Auth.exceptions.HTTPException( diff --git a/backend/app/gateway/routers/auth.py b/backend/app/gateway/routers/auth.py index e57182c26..ee4f074d2 100644 --- a/backend/app/gateway/routers/auth.py +++ b/backend/app/gateway/routers/auth.py @@ -341,9 +341,19 @@ async def change_password(request: Request, response: Response, body: ChangePass - Re-issues session cookie with new token_version """ from app.gateway.auth.password import hash_password_async, verify_password_async + from app.gateway.auth_disabled import AUTH_SOURCE_AUTH_DISABLED user = await get_current_user_from_request(request) + if getattr(request.state, "auth_source", None) == AUTH_SOURCE_AUTH_DISABLED: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=AuthErrorResponse( + code=AuthErrorCode.INVALID_CREDENTIALS, + message="Password changes are not available when DEER_FLOW_AUTH_DISABLED=1.", + ).model_dump(), + ) + if user.password_hash is None: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump()) diff --git a/backend/tests/test_auth_middleware.py b/backend/tests/test_auth_middleware.py index 726786ac9..838bf57af 100644 --- a/backend/tests/test_auth_middleware.py +++ b/backend/tests/test_auth_middleware.py @@ -4,6 +4,7 @@ import pytest from starlette.testclient import TestClient from app.gateway.auth_middleware import AuthMiddleware, _is_public +from app.gateway.csrf_middleware import CSRFMiddleware # ── _is_public unit tests ───────────────────────────────────────────────── @@ -88,7 +89,9 @@ def test_unknown_api_path_is_protected(): def _make_app(): """Create a minimal FastAPI app with AuthMiddleware for testing.""" - from fastapi import FastAPI + from fastapi import FastAPI, Request + + from deerflow.runtime.user_context import get_effective_user_id app = FastAPI() app.add_middleware(AuthMiddleware) @@ -98,8 +101,16 @@ def _make_app(): return {"status": "ok"} @app.get("/api/v1/auth/me") - async def auth_me(): - return {"id": "1", "email": "test@test.com"} + async def auth_me(request: Request): + from app.gateway.deps import get_current_user_from_request + + user = await get_current_user_from_request(request) + return { + "id": str(user.id), + "email": user.email, + "system_role": user.system_role, + "needs_setup": user.needs_setup, + } @app.get("/api/v1/auth/setup-status") async def setup_status(): @@ -109,6 +120,29 @@ def _make_app(): async def models_get(): return {"models": []} + @app.get("/api/whoami") + async def whoami(request: Request): + user = request.state.user + return { + "id": str(user.id), + "email": getattr(user, "email", None), + "system_role": getattr(user, "system_role", None), + "context_user_id": get_effective_user_id(), + } + + @app.get("/api/current-user-from-dep") + async def current_user_from_dep(request: Request): + from app.gateway.deps import get_current_user_from_request + + user = await get_current_user_from_request(request) + state_user = request.state.user + return { + "id": str(user.id), + "state_id": str(state_user.id), + "auth_source": request.state.auth_source, + "context_user_id": get_effective_user_id(), + } + @app.put("/api/mcp/config") async def mcp_put(): return {"ok": True} @@ -132,8 +166,24 @@ def _make_app(): return app +def _make_auth_csrf_app(): + """Create a minimal app with production middleware ordering.""" + from fastapi import FastAPI + + app = FastAPI() + app.add_middleware(AuthMiddleware) + app.add_middleware(CSRFMiddleware) + + @app.post("/api/threads/abc/runs/stream") + async def protected_mutation(): + return {"ok": True} + + return app + + @pytest.fixture -def client(): +def client(monkeypatch): + monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False) return TestClient(_make_app()) @@ -161,6 +211,139 @@ def test_protected_path_no_cookie_returns_401(client): assert body["detail"]["code"] == "not_authenticated" +def test_auth_disabled_allows_protected_path_without_cookie(monkeypatch): + monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1") + client = TestClient(_make_app()) + + res = client.get("/api/models") + + assert res.status_code == 200 + assert res.json() == {"models": []} + + +def test_auth_disabled_stamps_e2e_admin_user_without_cookie(monkeypatch): + monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1") + client = TestClient(_make_app()) + + res = client.get("/api/whoami") + + assert res.status_code == 200 + assert res.json() == { + "id": "e2e-user", + "email": "e2e@test.local", + "system_role": "admin", + "context_user_id": "e2e-user", + } + + +def test_auth_disabled_auth_me_reuses_middleware_user_without_cookie(monkeypatch): + monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1") + client = TestClient(_make_app()) + + res = client.get("/api/v1/auth/me") + + assert res.status_code == 200 + assert res.json() == { + "id": "e2e-user", + "email": "e2e@test.local", + "system_role": "admin", + "needs_setup": False, + } + + +def test_auth_disabled_does_not_clobber_valid_session_cookie(monkeypatch): + from types import SimpleNamespace + + async def fake_current_user(request): + return SimpleNamespace( + id="session-user", + email="session@test.local", + system_role="user", + needs_setup=False, + ) + + monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1") + monkeypatch.setattr("app.gateway.deps.get_current_user_from_request", fake_current_user) + client = TestClient(_make_app()) + + res = client.get("/api/whoami", cookies={"access_token": "valid-session"}) + + assert res.status_code == 200 + assert res.json() == { + "id": "session-user", + "email": "session@test.local", + "system_role": "user", + "context_user_id": "session-user", + } + + +def test_auth_disabled_does_not_clobber_internal_auth_identity(monkeypatch): + from app.gateway.internal_auth import create_internal_auth_headers + from deerflow.runtime.user_context import DEFAULT_USER_ID + + monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1") + client = TestClient(_make_app()) + + res = client.get( + "/api/current-user-from-dep", + headers=create_internal_auth_headers(), + ) + + assert res.status_code == 200 + assert res.json() == { + "id": DEFAULT_USER_ID, + "state_id": DEFAULT_USER_ID, + "auth_source": "internal", + "context_user_id": DEFAULT_USER_ID, + } + + +def test_auth_disabled_skips_csrf_for_state_changing_requests(monkeypatch): + monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1") + client = TestClient(_make_auth_csrf_app()) + + res = client.post("/api/threads/abc/runs/stream") + + assert res.status_code == 200 + assert res.json() == {"ok": True} + + +def test_auth_disabled_is_ignored_in_explicit_production_env(monkeypatch): + monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1") + monkeypatch.setenv("DEER_FLOW_ENV", "production") + client = TestClient(_make_app()) + + res = client.get("/api/models") + + assert res.status_code == 401 + + +def test_auth_disabled_startup_warning_when_effective(monkeypatch, caplog): + from app.gateway.auth_disabled import warn_if_auth_disabled_enabled + + monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1") + monkeypatch.delenv("DEER_FLOW_ENV", raising=False) + monkeypatch.delenv("ENVIRONMENT", raising=False) + + with caplog.at_level("WARNING", logger="app.gateway.auth_disabled"): + warn_if_auth_disabled_enabled() + + assert "authentication is bypassed" in caplog.text + assert "e2e-user" in caplog.text + + +def test_auth_disabled_startup_warning_suppressed_in_explicit_production_env(monkeypatch, caplog): + from app.gateway.auth_disabled import warn_if_auth_disabled_enabled + + monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1") + monkeypatch.setenv("ENVIRONMENT", "production") + + with caplog.at_level("WARNING", logger="app.gateway.auth_disabled"): + warn_if_auth_disabled_enabled() + + assert "authentication is bypassed" not in caplog.text + + def test_protected_path_with_junk_cookie_rejected(client): """Junk cookie → 401. Middleware strictly validates the JWT now (AUTH_TEST_PLAN test 7.5.8); it no longer silently passes bad diff --git a/backend/tests/test_langgraph_auth.py b/backend/tests/test_langgraph_auth.py index d2ee81051..1d2d71e7c 100644 --- a/backend/tests/test_langgraph_auth.py +++ b/backend/tests/test_langgraph_auth.py @@ -21,6 +21,7 @@ from langgraph_sdk import Auth from app.gateway.auth.config import AuthConfig, set_auth_config from app.gateway.auth.jwt import create_access_token, decode_token from app.gateway.auth.models import User +from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID from app.gateway.langgraph_auth import add_owner_filter, authenticate # ── Helpers ─────────────────────────────────────────────────────────────── @@ -59,6 +60,14 @@ def test_no_cookie_raises_401(): assert "Not authenticated" in str(exc.value.detail) +def test_auth_disabled_skips_csrf_and_authenticates_e2e_user(monkeypatch): + monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1") + + identity = asyncio.run(authenticate(_req(method="POST"))) + + assert identity == AUTH_DISABLED_USER_ID + + def test_invalid_jwt_raises_401(): with pytest.raises(Auth.exceptions.HTTPException) as exc: asyncio.run(authenticate(_req({"access_token": "garbage"}))) diff --git a/frontend/playwright.real-backend.config.ts b/frontend/playwright.real-backend.config.ts index 9db673b90..091386686 100644 --- a/frontend/playwright.real-backend.config.ts +++ b/frontend/playwright.real-backend.config.ts @@ -7,8 +7,9 @@ import { defineConfig, devices } from "@playwright/test"; * so the mock-based suite is untouched. * * Two webServers are started: the replay gateway (:8011) and the frontend - * (:3000, pointed at the gateway). Auth uses a throwaway test account the spec - * registers at runtime — no secrets. + * (:3000, pointed at the gateway). Auth-disabled mode is enabled on both + * servers so the no-cookie e2e contract is covered; specs that need session + * cookies still register a throwaway test account at runtime. */ export default defineConfig({ testDir: "./tests/e2e-real-backend", @@ -38,7 +39,10 @@ export default defineConfig({ // Mount the test-only run/message seeder used by multi-run-order.spec.ts // (#3352). The endpoint exists only on this replay gateway, never in the // production app. - env: { DEERFLOW_ENABLE_TEST_SEED: "1" }, + env: { + DEERFLOW_ENABLE_TEST_SEED: "1", + DEER_FLOW_AUTH_DISABLED: "1", + }, }, { command: "pnpm build && pnpm start", diff --git a/frontend/src/core/auth/auth-disabled-user.ts b/frontend/src/core/auth/auth-disabled-user.ts new file mode 100644 index 000000000..2e26a8911 --- /dev/null +++ b/frontend/src/core/auth/auth-disabled-user.ts @@ -0,0 +1,23 @@ +import type { User } from "./types"; + +export const AUTH_DISABLED_USER: User = { + id: "e2e-user", + email: "e2e@test.local", + system_role: "admin", + needs_setup: false, +}; + +const PRODUCTION_ENV_VALUES = new Set(["prod", "production"]); + +function isExplicitProductionEnvironment() { + return ["DEER_FLOW_ENV", "ENVIRONMENT"].some((name) => + PRODUCTION_ENV_VALUES.has((process.env[name] ?? "").trim().toLowerCase()), + ); +} + +export function isAuthDisabledMode() { + return ( + process.env.DEER_FLOW_AUTH_DISABLED === "1" && + !isExplicitProductionEnvironment() + ); +} diff --git a/frontend/src/core/auth/server.ts b/frontend/src/core/auth/server.ts index 5712f1e89..198ef087c 100644 --- a/frontend/src/core/auth/server.ts +++ b/frontend/src/core/auth/server.ts @@ -2,6 +2,7 @@ import { cookies } from "next/headers"; import { isStaticWebsiteOnly } from "../static-mode"; +import { AUTH_DISABLED_USER, isAuthDisabledMode } from "./auth-disabled-user"; import { getGatewayConfig } from "./gateway-config"; import { STATIC_WEBSITE_USER } from "./static-user"; import { type AuthResult, userSchema } from "./types"; @@ -20,15 +21,10 @@ export async function getServerSideUser(): Promise { }; } - if (process.env.DEER_FLOW_AUTH_DISABLED === "1") { + if (isAuthDisabledMode()) { return { tag: "authenticated", - user: { - id: "e2e-user", - email: "e2e@test.local", - system_role: "admin", - needs_setup: false, - }, + user: AUTH_DISABLED_USER, }; } diff --git a/frontend/src/core/threads/hooks.ts b/frontend/src/core/threads/hooks.ts index 4418a9e26..2ac1a1814 100644 --- a/frontend/src/core/threads/hooks.ts +++ b/frontend/src/core/threads/hooks.ts @@ -364,7 +364,7 @@ export function useThreadStream({ loadMore: loadMoreHistory, loading: isHistoryLoading, appendMessages, - } = useThreadHistory(onStreamThreadId ?? ""); + } = useThreadHistory(onStreamThreadId ?? "", { enabled: !isMock }); // Keep listeners ref updated with latest callbacks useEffect(() => { @@ -854,8 +854,15 @@ export function useThreadStream({ } as const; } -export function useThreadHistory(threadId: string) { - const runs = useThreadRuns(threadId); +type ThreadHistoryOptions = { + enabled?: boolean; +}; + +export function useThreadHistory( + threadId: string, + { enabled = true }: ThreadHistoryOptions = {}, +) { + const runs = useThreadRuns(threadId, { enabled }); const threadIdRef = useRef(threadId); const runsRef = useRef(runs.data ?? []); const indexRef = useRef(-1); @@ -864,10 +871,15 @@ export function useThreadHistory(threadId: string) { const loadingRunIdRef = useRef(null); const loadedRunIdsRef = useRef>(new Set()); const runBeforeSeqRef = useRef>(new Map()); + const loadGenerationRef = useRef(0); const [loading, setLoading] = useState(false); const [messages, setMessages] = useState([]); const loadMessages = useCallback(async () => { + if (!enabled) { + return; + } + const loadGeneration = loadGenerationRef.current; if (loadingRef.current) { const pendingRunIndex = findLatestUnloadedRunIndex( runsRef.current, @@ -921,12 +933,15 @@ export function useThreadHistory(threadId: string) { }).then((res) => { return res.json(); }); + if ( + loadGenerationRef.current !== loadGeneration || + threadIdRef.current !== requestThreadId + ) { + return; + } const _messages = result.data .filter((m) => !m.metadata.caller?.startsWith("middleware:")) .map((m) => m.content); - if (threadIdRef.current !== requestThreadId) { - return; - } setMessages((prev) => dedupeMessagesByIdentity([..._messages, ...prev]), ); @@ -961,16 +976,19 @@ export function useThreadHistory(threadId: string) { } catch (err) { console.error(err); } finally { - loadingRef.current = false; - loadingRunIdRef.current = null; - setLoading(false); + if (loadGenerationRef.current === loadGeneration) { + loadingRef.current = false; + loadingRunIdRef.current = null; + setLoading(false); + } } - }, []); + }, [enabled]); useEffect(() => { const threadChanged = threadIdRef.current !== threadId; threadIdRef.current = threadId; - if (threadChanged) { + if (!enabled || threadChanged) { + loadGenerationRef.current += 1; runsRef.current = []; indexRef.current = -1; pendingLoadRef.current = false; @@ -982,6 +1000,10 @@ export function useThreadHistory(threadId: string) { setMessages([]); } + if (!enabled) { + return; + } + if (runs.data && runs.data.length > 0) { runsRef.current = runs.data ?? []; indexRef.current = findLatestUnloadedRunIndex( @@ -992,14 +1014,15 @@ export function useThreadHistory(threadId: string) { loadMessages().catch(() => { toast.error("Failed to load thread history."); }); - }, [threadId, runs.data, loadMessages]); + }, [enabled, threadId, runs.data, loadMessages]); const appendMessages = useCallback((_messages: Message[]) => { setMessages((prev) => { return dedupeMessagesByIdentity([...prev, ..._messages]); }); }, []); - const hasMore = indexRef.current >= 0 || !runs.data; + const hasMore = + enabled && Boolean(threadId) && (indexRef.current >= 0 || !runs.data); return { runs: runs.data, messages, @@ -1077,7 +1100,10 @@ export function useThreads( }); } -export function useThreadRuns(threadId?: string) { +export function useThreadRuns( + threadId?: string, + { enabled = true }: { enabled?: boolean } = {}, +) { const apiClient = getAPIClient(); return useQuery({ queryKey: ["thread", threadId], @@ -1088,6 +1114,7 @@ export function useThreadRuns(threadId?: string) { const response = await apiClient.runs.list(threadId); return response; }, + enabled: enabled && Boolean(threadId), refetchOnWindowFocus: false, }); } diff --git a/frontend/tests/e2e-real-backend/auth-disabled-contract.spec.ts b/frontend/tests/e2e-real-backend/auth-disabled-contract.spec.ts new file mode 100644 index 000000000..23cb08d40 --- /dev/null +++ b/frontend/tests/e2e-real-backend/auth-disabled-contract.spec.ts @@ -0,0 +1,16 @@ +import { expect, test } from "@playwright/test"; + +import { AUTH_DISABLED_USER } from "../../src/core/auth/auth-disabled-user"; + +const APP = "http://localhost:3000"; + +test.describe("auth-disabled contract (real backend)", () => { + test("gateway /auth/me returns the frontend synthetic user without a cookie", async ({ + context, + }) => { + const resp = await context.request.get(`${APP}/api/v1/auth/me`); + + expect(resp.status(), await resp.text()).toBe(200); + await expect(resp.json()).resolves.toEqual(AUTH_DISABLED_USER); + }); +}); diff --git a/frontend/tests/e2e-real-backend/real-backend-render.spec.ts b/frontend/tests/e2e-real-backend/real-backend-render.spec.ts index 97c367d41..19047445b 100644 --- a/frontend/tests/e2e-real-backend/real-backend-render.spec.ts +++ b/frontend/tests/e2e-real-backend/real-backend-render.spec.ts @@ -101,10 +101,11 @@ test.describe("real backend render (replay, no API key)", () => { EXPECTED_SUGGESTION, "fixture should contain a suggestions turn (re-record; the record spec waits for /suggestions)", ).not.toBe(""); - await expect(page.getByText(EXPECTED_TITLE)).toBeVisible({ + const chat = page.locator("#chat"); + await expect(chat.getByText(EXPECTED_TITLE)).toBeVisible({ timeout: 60_000, }); - await expect(page.getByText(EXPECTED_SUGGESTION)).toBeVisible({ + await expect(chat.getByText(EXPECTED_SUGGESTION)).toBeVisible({ timeout: 30_000, }); diff --git a/frontend/tests/e2e/chat.spec.ts b/frontend/tests/e2e/chat.spec.ts index 4650a3c2c..50ab3c871 100644 --- a/frontend/tests/e2e/chat.spec.ts +++ b/frontend/tests/e2e/chat.spec.ts @@ -12,6 +12,7 @@ test.describe("Chat workspace", () => { const textarea = page.getByPlaceholder(/how can i assist you/i); await expect(textarea).toBeVisible({ timeout: 15_000 }); + await expect(page.getByRole("button", { name: /load more/i })).toBeHidden(); }); test("can type a message in the input box", async ({ page }) => { diff --git a/frontend/tests/e2e/thread-history.spec.ts b/frontend/tests/e2e/thread-history.spec.ts index 19fce310a..9476ca4ab 100644 --- a/frontend/tests/e2e/thread-history.spec.ts +++ b/frontend/tests/e2e/thread-history.spec.ts @@ -18,6 +18,7 @@ const THREADS = [ updated_at: "2025-06-02T12:00:00Z", }, ]; +const DEMO_THREAD_ID = "7cfa5f8f-a2f8-47ad-acbd-da7137baf990"; test.describe("Thread history", () => { test("sidebar shows existing threads", async ({ page }) => { @@ -61,6 +62,84 @@ test.describe("Thread history", () => { ).toBeVisible({ timeout: 15_000 }); }); + test("mock thread does not load real backend run history", async ({ + page, + }) => { + mockLangGraphAPI(page, { + threads: [ + { + thread_id: DEMO_THREAD_ID, + title: "Forecasting 2026 Trends and Opportunities", + updated_at: "2025-06-01T12:00:00Z", + messages: [ + { + type: "human", + id: `run-human-${DEMO_THREAD_ID}`, + content: [ + { + type: "text", + text: "This run-message endpoint should not be called.", + }, + ], + }, + ], + }, + ], + }); + const backendRunHistoryUrls: string[] = []; + await page.route( + /\/api\/langgraph\/threads\/[^/]+\/runs(?:\?|$)/, + (route) => { + if ( + route.request().method() === "GET" && + route + .request() + .url() + .includes(`/api/langgraph/threads/${DEMO_THREAD_ID}/runs`) + ) { + backendRunHistoryUrls.push(route.request().url()); + return route.fulfill({ + status: 500, + contentType: "application/json", + body: JSON.stringify({ + error: "mock=true must not load real runs", + }), + }); + } + return route.fallback(); + }, + ); + await page.route( + /\/api\/threads\/[^/]+\/runs\/[^/]+\/messages(?:\?|$)/, + (route) => { + if ( + route.request().method() === "GET" && + route.request().url().includes(`/api/threads/${DEMO_THREAD_ID}/runs/`) + ) { + backendRunHistoryUrls.push(route.request().url()); + return route.fulfill({ + status: 500, + contentType: "application/json", + body: JSON.stringify({ + error: "mock=true must not load real run messages", + }), + }); + } + return route.fallback(); + }, + ); + + await page.goto(`/workspace/chats/${DEMO_THREAD_ID}?mock=true`); + + await expect( + page.getByText("What might be the trends and opportunities in 2026?"), + ).toBeVisible({ timeout: 15_000 }); + await expect( + page.getByText("I've created a modern, minimalist website"), + ).toBeVisible(); + expect(backendRunHistoryUrls).toEqual([]); + }); + test("chats list page shows all threads", async ({ page }) => { mockLangGraphAPI(page, { threads: THREADS }); diff --git a/frontend/tests/unit/core/auth/server.test.ts b/frontend/tests/unit/core/auth/server.test.ts index fea6ef830..1dd02da33 100644 --- a/frontend/tests/unit/core/auth/server.test.ts +++ b/frontend/tests/unit/core/auth/server.test.ts @@ -1,5 +1,6 @@ import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { AUTH_DISABLED_USER } from "@/core/auth/auth-disabled-user"; import { STATIC_WEBSITE_USER } from "@/core/auth/static-user"; vi.mock("next/headers", () => ({ @@ -10,6 +11,8 @@ vi.mock("next/headers", () => ({ const ENV_KEYS = [ "DEER_FLOW_AUTH_DISABLED", + "DEER_FLOW_ENV", + "ENVIRONMENT", "NEXT_PUBLIC_STATIC_WEBSITE_ONLY", ] as const; @@ -51,6 +54,8 @@ describe("getServerSideUser", () => { beforeEach(() => { saved = snapshotEnv(); setEnv("DEER_FLOW_AUTH_DISABLED", undefined); + setEnv("DEER_FLOW_ENV", undefined); + setEnv("ENVIRONMENT", undefined); setEnv("NEXT_PUBLIC_STATIC_WEBSITE_ONLY", undefined); }); @@ -74,4 +79,30 @@ describe("getServerSideUser", () => { }); expect(fetchSpy).not.toHaveBeenCalled(); }); + + test("bypasses gateway auth in auth-disabled mode", async () => { + setEnv("DEER_FLOW_AUTH_DISABLED", "1"); + const fetchSpy = vi.fn(() => { + throw new Error("fetch should not be called in auth-disabled mode"); + }); + vi.stubGlobal("fetch", fetchSpy); + + const { getServerSideUser } = await loadFreshServerAuth(); + + await expect(getServerSideUser()).resolves.toEqual({ + tag: "authenticated", + user: AUTH_DISABLED_USER, + }); + expect(fetchSpy).not.toHaveBeenCalled(); + }); + + test("does not enable auth-disabled mode in explicit production environments", async () => { + setEnv("DEER_FLOW_AUTH_DISABLED", "1"); + setEnv("DEER_FLOW_ENV", "production"); + + const { isAuthDisabledMode } = + await import("@/core/auth/auth-disabled-user"); + + expect(isAuthDisabledMode()).toBe(false); + }); });