mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 09:25:57 +00:00
16391e35ab
* support slash skill activation * format slash skill activation * Preserve slash skill activation with uploads * Address slash skill review feedback * Address slash skill follow-up review * Fix lazy slash skill storage resolution * Keep slash skill activation out of system prompt * Address slash skill review issues * fix: harden slash skill command handling * feat(frontend): add slash skill autocomplete * fix: address slash skill review feedback * fix: preserve slash skill text for IM uploads
290 lines
12 KiB
Python
290 lines
12 KiB
Python
"""Middleware for explicit slash skill activation."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import html
|
|
import logging
|
|
import uuid
|
|
from collections.abc import Awaitable, Callable
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, override
|
|
|
|
from langchain.agents.middleware import AgentMiddleware
|
|
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
|
from langchain_core.messages import AIMessage, HumanMessage
|
|
|
|
from deerflow.skills.slash import parse_slash_skill_reference, resolve_slash_skill
|
|
from deerflow.skills.storage import get_or_new_skill_storage
|
|
from deerflow.skills.storage.skill_storage import SkillStorage
|
|
from deerflow.skills.types import SKILL_MD_FILE
|
|
from deerflow.utils.messages import get_original_user_content_text
|
|
|
|
if TYPE_CHECKING:
|
|
from deerflow.config.app_config import AppConfig
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_SLASH_SKILL_ACTIVATION_KEY = "slash_skill_activation"
|
|
_SLASH_SKILL_ACTIVATION_TARGET_ID_KEY = "slash_skill_activation_target_id"
|
|
_SUMMARY_MESSAGE_NAME = "summary"
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class _Activation:
|
|
skill_name: str
|
|
category: str
|
|
container_file_path: str
|
|
skill_content: str
|
|
content_hash: str
|
|
remaining_text: str
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class _ActivationResolution:
|
|
activation: _Activation | None = None
|
|
failure_message: str | None = None
|
|
|
|
|
|
def is_slash_skill_activation_reminder(message: object) -> bool:
|
|
"""Return whether a message is hidden slash-skill activation context."""
|
|
return isinstance(message, HumanMessage) and bool(message.additional_kwargs.get(_SLASH_SKILL_ACTIVATION_KEY))
|
|
|
|
|
|
def _is_user_activation_target(message: object) -> bool:
|
|
if not isinstance(message, HumanMessage):
|
|
return False
|
|
if message.name == _SUMMARY_MESSAGE_NAME:
|
|
return False
|
|
if message.additional_kwargs.get("hide_from_ui"):
|
|
return False
|
|
return True
|
|
|
|
|
|
class SkillActivationMiddleware(AgentMiddleware):
|
|
"""Inject full SKILL.md content when the user explicitly types /skill-name."""
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
available_skills: set[str] | None = None,
|
|
app_config: AppConfig | None = None,
|
|
) -> None:
|
|
super().__init__()
|
|
self._available_skills = set(available_skills) if available_skills is not None else None
|
|
self._app_config = app_config
|
|
|
|
def _storage(self) -> SkillStorage:
|
|
if self._app_config is not None:
|
|
return get_or_new_skill_storage(app_config=self._app_config)
|
|
return get_or_new_skill_storage()
|
|
|
|
@staticmethod
|
|
def _read_skill_content(skill_file: Path, skills_root: Path) -> str:
|
|
if skill_file.name != SKILL_MD_FILE:
|
|
raise ValueError(f"Expected {SKILL_MD_FILE}, got {skill_file.name}")
|
|
resolved_root = skills_root.resolve()
|
|
resolved_file = skill_file.resolve()
|
|
try:
|
|
resolved_file.relative_to(resolved_root)
|
|
except ValueError as exc:
|
|
raise ValueError("Resolved skill file must stay within the configured skills root.") from exc
|
|
if not resolved_file.is_file():
|
|
raise FileNotFoundError(resolved_file)
|
|
return resolved_file.read_text(encoding="utf-8")
|
|
|
|
def _resolve_activation(self, text: str) -> _ActivationResolution | None:
|
|
reference = parse_slash_skill_reference(text)
|
|
if reference is None:
|
|
return None
|
|
|
|
storage = self._storage()
|
|
skills = storage.load_skills(enabled_only=False)
|
|
skill = next((candidate for candidate in skills if candidate.name == reference.name), None)
|
|
if skill is None:
|
|
return _ActivationResolution(failure_message=f"Skill `/{reference.name}` is not installed.")
|
|
if not skill.enabled:
|
|
return _ActivationResolution(failure_message=f"Skill `/{reference.name}` is installed but disabled. Enable it before using slash activation.")
|
|
if self._available_skills is not None and reference.name not in self._available_skills:
|
|
return _ActivationResolution(failure_message=f"Skill `/{reference.name}` is not available for this agent.")
|
|
|
|
resolved = resolve_slash_skill(
|
|
text,
|
|
skills,
|
|
available_skills=self._available_skills,
|
|
container_base_path=storage.get_container_root(),
|
|
)
|
|
if resolved is None:
|
|
return _ActivationResolution(failure_message=f"Skill `/{reference.name}` could not be resolved.")
|
|
|
|
try:
|
|
skill_content = self._read_skill_content(resolved.skill.skill_file, storage.get_skills_root_path())
|
|
except (OSError, ValueError):
|
|
logger.exception("Failed to read slash-activated skill %s", resolved.skill.name)
|
|
return _ActivationResolution(failure_message=f"Skill `/{reference.name}` could not be loaded safely. Please check the skill installation.")
|
|
|
|
content_hash = hashlib.sha256(skill_content.encode("utf-8")).hexdigest()
|
|
return _ActivationResolution(
|
|
activation=_Activation(
|
|
skill_name=resolved.skill.name,
|
|
category=str(resolved.skill.category),
|
|
container_file_path=resolved.container_file_path,
|
|
skill_content=skill_content,
|
|
content_hash=content_hash,
|
|
remaining_text=resolved.remaining_text,
|
|
)
|
|
)
|
|
|
|
@staticmethod
|
|
def _build_activation_reminder(activation: _Activation) -> str:
|
|
user_request = activation.remaining_text or ("No additional task text was provided after the slash skill command. Ask the user what they want to do with this skill if the next step is unclear.")
|
|
escaped_user_request = html.escape(user_request, quote=False)
|
|
escaped_skill_content = html.escape(activation.skill_content, quote=False)
|
|
escaped_skill_name = html.escape(activation.skill_name, quote=True)
|
|
escaped_category = html.escape(activation.category, quote=True)
|
|
escaped_path = html.escape(activation.container_file_path, quote=True)
|
|
escaped_content_hash = html.escape(activation.content_hash, quote=True)
|
|
return f"""<slash_skill_activation>
|
|
The user explicitly activated the `{activation.skill_name}` skill for this turn.
|
|
Treat the task text as:
|
|
<user_request>
|
|
{escaped_user_request}
|
|
</user_request>
|
|
|
|
Follow this skill before choosing a general workflow. Load supporting resources from the same skill directory only when needed.
|
|
|
|
<skill name="{escaped_skill_name}" category="{escaped_category}" path="{escaped_path}" sha256="{escaped_content_hash}">
|
|
<skill_content encoding="xml-escaped">
|
|
{escaped_skill_content}
|
|
</skill_content>
|
|
</skill>
|
|
</slash_skill_activation>"""
|
|
|
|
@staticmethod
|
|
def _has_existing_activation_for_target(messages: list, target_index: int, target: HumanMessage) -> bool:
|
|
if target_index <= 0:
|
|
return False
|
|
|
|
if target.id:
|
|
for previous in messages[:target_index]:
|
|
if not is_slash_skill_activation_reminder(previous):
|
|
continue
|
|
target_id = previous.additional_kwargs.get(_SLASH_SKILL_ACTIVATION_TARGET_ID_KEY)
|
|
if target_id == target.id or previous.id == f"{target.id}__slash_activation":
|
|
return True
|
|
|
|
previous = messages[target_index - 1]
|
|
return is_slash_skill_activation_reminder(previous)
|
|
|
|
def _find_activation_target(self, messages: list) -> tuple[int, HumanMessage, _ActivationResolution] | None:
|
|
if not messages:
|
|
return None
|
|
|
|
target_index = next((idx for idx in range(len(messages) - 1, -1, -1) if _is_user_activation_target(messages[idx])), None)
|
|
if target_index is None:
|
|
return None
|
|
|
|
target = messages[target_index]
|
|
if target is None:
|
|
return None
|
|
if self._has_existing_activation_for_target(messages, target_index, target):
|
|
return None
|
|
|
|
content = get_original_user_content_text(target.content, target.additional_kwargs)
|
|
resolution = self._resolve_activation(content)
|
|
if resolution is None:
|
|
return None
|
|
return target_index, target, resolution
|
|
|
|
@staticmethod
|
|
def _record_activation(request: ModelRequest, activation: _Activation, *, hook: str) -> None:
|
|
runtime = getattr(request, "runtime", None)
|
|
context = getattr(runtime, "context", None)
|
|
journal = context.get("__run_journal") if isinstance(context, dict) else None
|
|
if journal is None:
|
|
return
|
|
try:
|
|
journal.record_middleware(
|
|
"skill_activation",
|
|
name="SkillActivationMiddleware",
|
|
hook=hook,
|
|
action="activate",
|
|
changes={
|
|
"skill_name": activation.skill_name,
|
|
"category": activation.category,
|
|
"path": activation.container_file_path,
|
|
"content_hash": activation.content_hash,
|
|
},
|
|
)
|
|
except Exception:
|
|
logger.debug("Failed to record slash skill activation audit event", exc_info=True)
|
|
|
|
def _prepare_model_request(self, request: ModelRequest, *, hook: str) -> ModelRequest | AIMessage | None:
|
|
target_and_resolution = self._find_activation_target(list(request.messages))
|
|
if target_and_resolution is None:
|
|
return None
|
|
|
|
target_index, target, resolution = target_and_resolution
|
|
if resolution.failure_message:
|
|
return AIMessage(content=resolution.failure_message)
|
|
|
|
activation = resolution.activation
|
|
if activation is None:
|
|
return None
|
|
|
|
logger.info(
|
|
"SkillActivationMiddleware: activating slash skill %s category=%s path=%s hash=%s",
|
|
activation.skill_name,
|
|
activation.category,
|
|
activation.container_file_path,
|
|
activation.content_hash,
|
|
)
|
|
self._record_activation(request, activation, hook=hook)
|
|
activation_msg = self._make_activation_message(target, self._build_activation_reminder(activation))
|
|
messages = list(request.messages)
|
|
messages.insert(target_index, activation_msg)
|
|
return request.override(messages=messages)
|
|
|
|
@staticmethod
|
|
def _make_activation_message(target: HumanMessage, activation_content: str) -> HumanMessage:
|
|
stable_id = target.id or str(uuid.uuid4())
|
|
additional_kwargs = {
|
|
"hide_from_ui": True,
|
|
_SLASH_SKILL_ACTIVATION_KEY: True,
|
|
}
|
|
if target.id:
|
|
additional_kwargs[_SLASH_SKILL_ACTIVATION_TARGET_ID_KEY] = target.id
|
|
return HumanMessage(
|
|
content=activation_content,
|
|
id=f"{stable_id}__slash_activation",
|
|
additional_kwargs=additional_kwargs,
|
|
)
|
|
|
|
@override
|
|
def wrap_model_call(
|
|
self,
|
|
request: ModelRequest,
|
|
handler: Callable[[ModelRequest], ModelResponse],
|
|
) -> ModelResponse | AIMessage:
|
|
prepared = self._prepare_model_request(request, hook="wrap_model_call")
|
|
if prepared is None:
|
|
return handler(request)
|
|
if isinstance(prepared, AIMessage):
|
|
return prepared
|
|
return handler(prepared)
|
|
|
|
@override
|
|
async def awrap_model_call(
|
|
self,
|
|
request: ModelRequest,
|
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
|
) -> ModelResponse | AIMessage:
|
|
prepared = await asyncio.to_thread(self._prepare_model_request, request, hook="awrap_model_call")
|
|
if prepared is None:
|
|
return await handler(request)
|
|
if isinstance(prepared, AIMessage):
|
|
return prepared
|
|
return await handler(prepared)
|