mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 08:55:59 +00:00
fix(skills): enforce allowed-tools metadata (#2626)
* fix(skills): parse allowed-tools frontmatter
* fix(skills): validate allowed-tools metadata
* fix(skills): add shared allowed-tools policy
* fix(subagents): enforce skill allowed-tools
* fix(agent): enforce skill allowed-tools
* refactor(skills): dedupe TypeVar and reuse cached enabled skills
  - Drop redundant module-level TypeVar in tool_policy; rely on PEP 695 syntax.
  - Expose get_cached_enabled_skills() and have the lead agent reuse it instead of synchronously rescanning skills on every request.
* fix(agent): expose config-scoped skill cache
* fix(subagents): pass filtered tools explicitly
* fix(skills): clean allowed-tools policy feedback
This commit is contained in:
@@ -1,17 +1,20 @@
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
|
||||
import anyio
|
||||
|
||||
from deerflow.agents.lead_agent import prompt as prompt_module
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.subagents_config import CustomSubagentConfig, SubagentsAppConfig
|
||||
from deerflow.skills.types import Skill
|
||||
from deerflow.skills.types import Skill, SkillCategory
|
||||
|
||||
|
||||
def _set_skills_cache_state(*, skills=None, active=False, version=0):
|
||||
prompt_module._get_cached_skills_prompt_section.cache_clear()
|
||||
with prompt_module._enabled_skills_lock:
|
||||
prompt_module._enabled_skills_cache = skills
|
||||
prompt_module._enabled_skills_by_config_cache.clear()
|
||||
prompt_module._enabled_skills_refresh_active = active
|
||||
prompt_module._enabled_skills_refresh_version = version
|
||||
prompt_module._enabled_skills_refresh_event.clear()
|
||||
@@ -232,7 +235,7 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
|
||||
skill_dir=skill_dir,
|
||||
skill_file=skill_dir / "SKILL.md",
|
||||
relative_path=skill_dir.relative_to(tmp_path),
|
||||
category="custom",
|
||||
category=SkillCategory.CUSTOM,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
@@ -252,6 +255,58 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
|
||||
_set_skills_cache_state()
|
||||
|
||||
|
||||
def test_explicit_config_enabled_skills_are_cached_by_config_identity(monkeypatch, tmp_path):
    """Building the skills prompt section twice with one config loads skills once.

    A fake storage factory counts ``load_skills`` calls; the second call to
    ``get_skills_prompt_section`` with the same ``app_config`` object must not
    increase that count.
    """

    def make_skill(name: str) -> Skill:
        skill_dir = tmp_path / name
        return Skill(
            name=name,
            description=f"Description for {name}",
            license="MIT",
            skill_dir=skill_dir,
            skill_file=skill_dir / "SKILL.md",
            relative_path=skill_dir.relative_to(tmp_path),
            category=SkillCategory.CUSTOM,
            enabled=True,
        )

    # The double cast only satisfies the type checker: a SimpleNamespace
    # stands in for AppConfig with the two attributes set below.
    config = cast(
        AppConfig,
        cast(
            object,
            SimpleNamespace(
                skills=SimpleNamespace(container_path="/mnt/skills"),
                skill_evolution=SimpleNamespace(enabled=False),
            ),
        ),
    )
    load_count = 0

    def fake_get_or_new_skill_storage(**kwargs):
        # The storage factory must be handed exactly the explicit config.
        assert kwargs == {"app_config": config}

        def load_skills(*, enabled_only):
            nonlocal load_count
            load_count += 1
            assert enabled_only is True
            return [make_skill("cached-skill")]

        return SimpleNamespace(load_skills=load_skills)

    monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", fake_get_or_new_skill_storage)
    _set_skills_cache_state()

    try:
        first = prompt_module.get_skills_prompt_section(app_config=config)
        second = prompt_module.get_skills_prompt_section(app_config=config)

        assert "cached-skill" in first
        assert "cached-skill" in second
        assert load_count == 1
    finally:
        # Reset the module-level cache state so later tests start clean.
        _set_skills_cache_state()
|
||||
|
||||
|
||||
def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_path):
|
||||
started = threading.Event()
|
||||
release = threading.Event()
|
||||
@@ -269,7 +324,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa
|
||||
skill_dir=skill_dir,
|
||||
skill_file=skill_dir / "SKILL.md",
|
||||
relative_path=skill_dir.relative_to(tmp_path),
|
||||
category="custom",
|
||||
category=SkillCategory.CUSTOM,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -6,7 +6,12 @@ from deerflow.config.agents_config import AgentConfig
|
||||
from deerflow.skills.types import Skill
|
||||
|
||||
|
||||
def _make_skill(name: str) -> Skill:
|
||||
class NamedTool:
    """Minimal tool stand-in that exposes only a ``name`` attribute."""

    def __init__(self, name: str):
        self.name = name
|
||||
|
||||
|
||||
def _make_skill(name: str, allowed_tools: list[str] | None = None) -> Skill:
|
||||
return Skill(
|
||||
name=name,
|
||||
description=f"Description for {name}",
|
||||
@@ -15,6 +20,7 @@ def _make_skill(name: str) -> Skill:
|
||||
skill_file=Path(f"/tmp/{name}/SKILL.md"),
|
||||
relative_path=Path(name),
|
||||
category="public",
|
||||
allowed_tools=allowed_tools,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
@@ -132,6 +138,7 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_load_enabled_skills_for_tool_policy", lambda available_skills, *, app_config: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
|
||||
|
||||
@@ -164,3 +171,106 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
|
||||
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["skill1"]))
|
||||
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
|
||||
assert captured_skills[-1] == {"skill1"}
|
||||
|
||||
|
||||
def test_make_lead_agent_filters_tools_from_available_skills(monkeypatch):
    """A skill with explicit ``allowed_tools`` restricts the agent's tool list.

    The policy loader is stubbed to return one skill allowing ``read_file``
    and one skill with ``allowed_tools=None``; of the three available tools
    only ``read_file`` may reach ``create_agent``.
    """
    from unittest.mock import MagicMock

    from deerflow.agents.lead_agent import agent as lead_agent_module

    monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
    monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
    monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
    monkeypatch.setattr(lead_agent_module, "apply_prompt_template", lambda **kwargs: "mock_prompt")
    # create_agent echoes its kwargs so the test can inspect the tools passed in.
    monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
    monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["restricted", "legacy"]))
    monkeypatch.setattr(lead_agent_module, "_load_enabled_skills_for_tool_policy", lambda available_skills, *, app_config: [_make_skill("restricted", ["read_file"]), _make_skill("legacy", None)])
    monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")])

    mock_app_config = MagicMock()
    mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
    monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)

    agent_kwargs = lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})

    assert [tool.name for tool in agent_kwargs["tools"]] == ["read_file"]
|
||||
|
||||
|
||||
def test_make_lead_agent_all_legacy_skills_preserve_all_tools(monkeypatch):
    """When every skill has ``allowed_tools=None`` no tool is filtered out.

    The expected list also contains ``update_agent``, which is not among the
    stubbed available tools, so ``make_lead_agent`` is expected to add it.
    """
    from unittest.mock import MagicMock

    from deerflow.agents.lead_agent import agent as lead_agent_module

    monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
    monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
    monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
    monkeypatch.setattr(lead_agent_module, "apply_prompt_template", lambda **kwargs: "mock_prompt")
    # create_agent echoes its kwargs so the test can inspect the tools passed in.
    monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
    monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=None))
    monkeypatch.setattr(lead_agent_module, "_load_enabled_skills_for_tool_policy", lambda available_skills, *, app_config: [_make_skill("legacy", None)])
    monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [NamedTool("bash"), NamedTool("read_file")])

    mock_app_config = MagicMock()
    mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
    monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)

    agent_kwargs = lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})

    assert [tool.name for tool in agent_kwargs["tools"]] == ["bash", "read_file", "update_agent"]
|
||||
|
||||
|
||||
def test_make_lead_agent_enforces_allowed_tools_when_skill_cache_is_cold(monkeypatch):
    """Tool filtering still applies when the enabled-skills cache starts empty.

    The cache is set to ``None`` and the storage factory is stubbed to return
    a single skill allowing ``read_file``; only that tool may reach
    ``create_agent``.
    """
    from unittest.mock import MagicMock

    from deerflow.agents.lead_agent import agent as lead_agent_module
    from deerflow.agents.lead_agent import prompt as prompt_module

    monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
    monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
    monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
    monkeypatch.setattr(lead_agent_module, "apply_prompt_template", lambda **kwargs: "mock_prompt")
    # create_agent echoes its kwargs so the test can inspect the tools passed in.
    monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
    monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["restricted"]))
    monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")])

    mock_app_config = MagicMock()
    mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
    mock_storage = SimpleNamespace(load_skills=lambda *, enabled_only: [_make_skill("restricted", ["read_file"])])

    with prompt_module._enabled_skills_lock:
        # Go through monkeypatch so the previous cache value is restored on
        # teardown instead of leaving this test's state in the module global.
        monkeypatch.setattr(prompt_module, "_enabled_skills_cache", None)
    monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", lambda app_config=None, **kwargs: mock_storage)
    monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)

    agent_kwargs = lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})

    assert [tool.name for tool in agent_kwargs["tools"]] == ["read_file"]
|
||||
|
||||
|
||||
def test_make_lead_agent_fails_closed_when_skill_policy_load_fails(monkeypatch):
    """A failing skill storage aborts agent construction.

    The storage factory raises ``RuntimeError``; ``make_lead_agent`` must
    propagate that error and must never call ``create_agent``.
    """
    from unittest.mock import MagicMock

    import pytest

    from deerflow.agents.lead_agent import agent as lead_agent_module
    from deerflow.agents.lead_agent import prompt as prompt_module

    monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
    monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
    create_agent_mock = MagicMock()
    monkeypatch.setattr(lead_agent_module, "create_agent", create_agent_mock)
    monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["restricted"]))

    mock_app_config = MagicMock()
    mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)

    def fail_storage(*args, **kwargs):
        raise RuntimeError("skill storage unavailable")

    monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", fail_storage)
    monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)

    with pytest.raises(RuntimeError, match="skill storage unavailable"):
        lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})

    create_agent_mock.assert_not_called()
|
||||
|
||||
@@ -86,6 +86,33 @@ def test_parse_license_field(tmp_path):
|
||||
assert skill.license == "MIT"
|
||||
|
||||
|
||||
def test_parse_missing_allowed_tools_returns_none(tmp_path):
    """Frontmatter without ``allowed-tools`` parses with ``allowed_tools`` set to None."""
    skill_file = _write_skill(tmp_path, "name: my-skill\ndescription: Test")
    skill = parse_skill_file(skill_file, category="custom")
    assert skill is not None
    assert skill.allowed_tools is None
|
||||
|
||||
|
||||
def test_parse_allowed_tools_list(tmp_path):
    """A list-valued ``allowed-tools`` entry is parsed into the same list of names."""
    skill_file = _write_skill(tmp_path, 'name: my-skill\ndescription: Test\nallowed-tools: ["bash", "read_file"]')
    skill = parse_skill_file(skill_file, category="custom")
    assert skill is not None
    assert skill.allowed_tools == ["bash", "read_file"]
|
||||
|
||||
|
||||
def test_parse_empty_allowed_tools_list(tmp_path):
    """An empty ``allowed-tools`` list stays an empty list rather than becoming None."""
    skill_file = _write_skill(tmp_path, "name: my-skill\ndescription: Test\nallowed-tools: []")
    skill = parse_skill_file(skill_file, category="custom")
    assert skill is not None
    assert skill.allowed_tools == []
|
||||
|
||||
|
||||
def test_parse_invalid_allowed_tools_returns_none(tmp_path):
    """A scalar (non-list) ``allowed-tools`` value makes the whole skill fail to parse."""
    skill_file = _write_skill(tmp_path, "name: my-skill\ndescription: Test\nallowed-tools: bash")
    skill = parse_skill_file(skill_file, category="custom")
    assert skill is None
|
||||
|
||||
|
||||
def test_parse_missing_name_returns_none(tmp_path):
|
||||
"""Skills missing a name field are rejected."""
|
||||
skill_file = _write_skill(tmp_path, "description: A test skill")
|
||||
|
||||
@@ -30,13 +30,47 @@ class TestValidateSkillFrontmatter:
|
||||
def test_valid_with_all_allowed_fields(self, tmp_path):
    """Frontmatter using every permitted field, including allowed-tools, validates."""
    skill_dir = _write_skill(
        tmp_path,
        "---\nname: my-skill\ndescription: A skill\nlicense: MIT\nversion: '1.0'\nauthor: test\nallowed-tools: [bash, read_file]\n---\n\nBody\n",
    )
    valid, msg, name = _validate_skill_frontmatter(skill_dir)
    assert valid is True
    assert msg == "Skill is valid!"
    assert name == "my-skill"
|
||||
|
||||
def test_allows_empty_allowed_tools(self, tmp_path):
    """An empty ``allowed-tools`` list is accepted by frontmatter validation."""
    skill_dir = _write_skill(
        tmp_path,
        "---\nname: my-skill\ndescription: A skill\nallowed-tools: []\n---\n\nBody\n",
    )
    valid, msg, name = _validate_skill_frontmatter(skill_dir)
    assert valid is True
    assert msg == "Skill is valid!"
    assert name == "my-skill"
|
||||
|
||||
def test_rejects_allowed_tools_string(self, tmp_path):
    """A scalar ``allowed-tools`` value is rejected with a path-free message."""
    skill_dir = _write_skill(
        tmp_path,
        "---\nname: my-skill\ndescription: A skill\nallowed-tools: bash\n---\n\nBody\n",
    )
    valid, msg, name = _validate_skill_frontmatter(skill_dir)
    assert valid is False
    assert "allowed-tools" in msg
    # The message names the file but must not contain the temp directory path.
    assert str(tmp_path) not in msg
    assert "SKILL.md" in msg
    assert name is None
|
||||
|
||||
def test_rejects_allowed_tools_non_string_entry(self, tmp_path):
    """An ``allowed-tools`` list containing a non-string entry is rejected."""
    skill_dir = _write_skill(
        tmp_path,
        "---\nname: my-skill\ndescription: A skill\nallowed-tools: [bash, 1]\n---\n\nBody\n",
    )
    valid, msg, name = _validate_skill_frontmatter(skill_dir)
    assert valid is False
    assert "allowed-tools" in msg
    # The message names the file but must not contain the temp directory path.
    assert str(tmp_path) not in msg
    assert "SKILL.md" in msg
    assert name is None
|
||||
|
||||
def test_missing_skill_md(self, tmp_path):
|
||||
valid, msg, name = _validate_skill_frontmatter(tmp_path)
|
||||
assert valid is False
|
||||
|
||||
@@ -17,11 +17,14 @@ import asyncio
|
||||
import sys
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.skills.types import Skill
|
||||
|
||||
# Module names that need to be mocked to break circular imports
|
||||
_MOCKED_MODULE_NAMES = [
|
||||
"deerflow.agents",
|
||||
@@ -32,14 +35,15 @@ _MOCKED_MODULE_NAMES = [
|
||||
"deerflow.sandbox.middleware",
|
||||
"deerflow.sandbox.security",
|
||||
"deerflow.models",
|
||||
"deerflow.skills.storage",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup_executor_classes():
|
||||
"""Set up mocked modules and import real executor classes.
|
||||
|
||||
This fixture runs once per session and yields the executor classes.
|
||||
This fixture runs once per test and yields the executor classes.
|
||||
It handles module cleanup to avoid affecting other test files.
|
||||
"""
|
||||
# Save original modules
|
||||
@@ -53,6 +57,9 @@ def _setup_executor_classes():
|
||||
# Set up mocks
|
||||
for name in _MOCKED_MODULE_NAMES:
|
||||
sys.modules[name] = MagicMock()
|
||||
storage_module = ModuleType("deerflow.skills.storage")
|
||||
storage_module.get_or_new_skill_storage = lambda **kwargs: SimpleNamespace(load_skills=lambda *, enabled_only: [])
|
||||
sys.modules["deerflow.skills.storage"] = storage_module
|
||||
|
||||
# Import real classes inside fixture
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
@@ -117,6 +124,26 @@ class MockAIMessage:
|
||||
return msg
|
||||
|
||||
|
||||
class NamedTool:
    """Minimal tool stand-in that exposes only a ``name`` attribute."""

    def __init__(self, name: str):
        self.name = name
|
||||
|
||||
|
||||
def _skill(name: str, allowed_tools: list[str] | None) -> Skill:
    """Build an enabled custom ``Skill`` rooted at ``/tmp/<name>``.

    Args:
        name: Skill name; also used for the directory and relative path.
        allowed_tools: Value stored on the skill's ``allowed_tools`` field
            (``None``, an empty list, or a list of tool names).
    """
    skill_dir = Path(f"/tmp/{name}")
    return Skill(
        name=name,
        description=f"{name} skill",
        license=None,
        skill_dir=skill_dir,
        skill_file=skill_dir / "SKILL.md",
        relative_path=Path(name),
        category="custom",
        allowed_tools=allowed_tools,
        enabled=True,
    )
|
||||
|
||||
|
||||
async def async_iterator(items):
|
||||
"""Helper to create an async iterator from a list."""
|
||||
for item in items:
|
||||
@@ -288,7 +315,7 @@ class TestAgentConstruction:
|
||||
captured["app_config"] = app_config
|
||||
return SimpleNamespace(load_skills=lambda *, enabled_only: [SimpleNamespace(name="demo-skill", skill_file=skill_file)])
|
||||
|
||||
monkeypatch.setattr("deerflow.skills.storage.get_or_new_skill_storage", fake_get_or_new_skill_storage)
|
||||
monkeypatch.setattr(sys.modules["deerflow.skills.storage"], "get_or_new_skill_storage", fake_get_or_new_skill_storage)
|
||||
|
||||
executor = SubagentExecutor(
|
||||
config=base_config,
|
||||
@@ -297,7 +324,8 @@ class TestAgentConstruction:
|
||||
thread_id="test-thread",
|
||||
)
|
||||
|
||||
messages = await executor._load_skill_messages()
|
||||
skills = await executor._load_skills()
|
||||
messages = await executor._load_skill_messages(skills)
|
||||
|
||||
assert captured["app_config"] is app_config
|
||||
assert len(messages) == 1
|
||||
@@ -487,6 +515,115 @@ class TestAsyncExecutionPath:
|
||||
assert "Task" in result.result
|
||||
|
||||
|
||||
class TestSkillAllowedTools:
    """Checks how skill ``allowed_tools`` metadata filters a subagent's tools.

    Each test replaces ``_load_skills`` on the executor with a stub and
    inspects the tool list passed as the first positional argument to
    ``_create_agent``. The executor's own ``tools`` attribute is expected to
    stay unfiltered.
    """

    @pytest.mark.anyio
    async def test_skill_allowed_tools_union_filters_agent_tools(self, classes, base_config, mock_agent, msg):
        """The union of all skills' allowed tools is what reaches the agent."""
        SubagentExecutor = classes["SubagentExecutor"]

        final_state = {"messages": [msg.human("Task"), msg.ai("Done", "msg-1")]}
        mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
        tools = [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")]
        executor = SubagentExecutor(config=base_config, tools=tools, thread_id="test-thread")

        async def load_skills():
            return [_skill("a", ["bash"]), _skill("b", ["read_file"])]

        with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock:
            await executor._aexecute("Task")

        create_agent_mock.assert_called_once()
        assert [tool.name for tool in create_agent_mock.call_args.args[0]] == ["bash", "read_file"]
        assert [tool.name for tool in executor.tools] == ["bash", "read_file", "web_search"]

    @pytest.mark.anyio
    async def test_all_missing_allowed_tools_preserves_legacy_allow_all(self, classes, base_config, mock_agent, msg):
        """Skills that all have ``allowed_tools=None`` leave every tool available."""
        SubagentExecutor = classes["SubagentExecutor"]

        final_state = {"messages": [msg.human("Task"), msg.ai("Done", "msg-1")]}
        mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
        tools = [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")]
        executor = SubagentExecutor(config=base_config, tools=tools, thread_id="test-thread")

        async def load_skills():
            return [_skill("legacy-a", None), _skill("legacy-b", None)]

        with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock:
            await executor._aexecute("Task")

        assert [tool.name for tool in create_agent_mock.call_args.args[0]] == ["bash", "read_file", "web_search"]
        assert [tool.name for tool in executor.tools] == ["bash", "read_file", "web_search"]

    @pytest.mark.anyio
    async def test_mixed_missing_allowed_tools_does_not_disable_explicit_restrictions(self, classes, base_config, mock_agent, msg):
        """A ``None`` skill listed first does not lift another skill's restriction."""
        SubagentExecutor = classes["SubagentExecutor"]

        final_state = {"messages": [msg.human("Task"), msg.ai("Done", "msg-1")]}
        mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
        tools = [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")]
        executor = SubagentExecutor(config=base_config, tools=tools, thread_id="test-thread")

        async def load_skills():
            return [_skill("legacy", None), _skill("restricted", ["bash"])]

        with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock:
            await executor._aexecute("Task")

        assert [tool.name for tool in create_agent_mock.call_args.args[0]] == ["bash"]
        assert [tool.name for tool in executor.tools] == ["bash", "read_file", "web_search"]

    @pytest.mark.anyio
    async def test_mixed_missing_allowed_tools_order_does_not_disable_explicit_restrictions(self, classes, base_config, mock_agent, msg):
        """Same as the mixed case, with the restricted skill listed first."""
        SubagentExecutor = classes["SubagentExecutor"]

        final_state = {"messages": [msg.human("Task"), msg.ai("Done", "msg-1")]}
        mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
        tools = [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")]
        executor = SubagentExecutor(config=base_config, tools=tools, thread_id="test-thread")

        async def load_skills():
            return [_skill("restricted", ["bash"]), _skill("legacy", None)]

        with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock:
            await executor._aexecute("Task")

        assert [tool.name for tool in create_agent_mock.call_args.args[0]] == ["bash"]
        assert [tool.name for tool in executor.tools] == ["bash", "read_file", "web_search"]

    @pytest.mark.anyio
    async def test_empty_allowed_tools_contributes_no_tools(self, classes, base_config, mock_agent, msg, caplog):
        """An empty ``allowed_tools`` list adds nothing to the union and is logged."""
        SubagentExecutor = classes["SubagentExecutor"]

        final_state = {"messages": [msg.human("Task"), msg.ai("Done", "msg-1")]}
        mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
        tools = [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")]
        executor = SubagentExecutor(config=base_config, tools=tools, thread_id="test-thread")

        async def load_skills():
            return [_skill("empty", []), _skill("reader", ["read_file"])]

        with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock, caplog.at_level("INFO"):
            await executor._aexecute("Task")

        assert [tool.name for tool in create_agent_mock.call_args.args[0]] == ["read_file"]
        assert [tool.name for tool in executor.tools] == ["bash", "read_file", "web_search"]
        assert "declared empty allowed-tools" in caplog.text

    @pytest.mark.anyio
    async def test_skill_load_failure_fails_without_creating_agent(self, classes, base_config, mock_agent):
        """If loading skills raises, the run is FAILED and no agent is created."""
        SubagentExecutor = classes["SubagentExecutor"]
        executor = SubagentExecutor(config=base_config, tools=[NamedTool("bash")], thread_id="test-thread")

        async def load_skills():
            raise RuntimeError("skill storage unavailable")

        with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock:
            result = await executor._aexecute("Task")

        assert result.status == classes["SubagentStatus"].FAILED
        assert result.error == "skill storage unavailable"
        create_agent_mock.assert_not_called()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Sync Execution Path Tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user