mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 17:35:57 +00:00
fix(tool-search): reliably hide deferred MCP schemas by removing the ContextVar (closures + graph state) (#3342)
* feat(tool-search): add hash-scoped promoted state to ThreadState * feat(tool-search): add immutable DeferredToolCatalog with stable hash * feat(tool-search): add build_deferred_tool_setup + Command-writing tool_search * refactor(tool-search): replace deferred-tool ContextVar with closures + graph state (#3272) Build the deferred catalog + tool_search tool per agent from the policy-filtered tool list (after skill allowed-tools), pass deferred_names + catalog_hash explicitly to DeferredToolFilterMiddleware and the prompt, and record promotions in ThreadState.promoted (scoped by catalog_hash) via a Command-returning tool_search. Removes DeferredToolRegistry and the _registry_var ContextVar so deferral no longer depends on build/execute sharing an async context. MCP tools are tagged with metadata[deerflow_mcp]; client.py assembles deferral the same way. Catalog is built AFTER tool-policy filtering (no policy-excluded tool can leak via tool_search) and assembly is fail-closed. Migrate tests off the deleted registry APIs; delete the obsolete ContextVar-based #2884 regression (re-covered by state-based tests in a follow-up). * test(tool-search): lock tool_search promotion into next model turn via graph state * test(tool-search): cross-context, policy-leak, fail-closed, #2884 isolation regressions * test(tool-search): align real-LLM e2e with closure-based deferred setup * docs: update DeferredToolFilterMiddleware description for closure+state design * style(tests): drop unused import in test_deferred_setup (ruff) * test(tool-search): harden merge_promoted + replace tautological catalog test From independent code review: - merge_promoted: use existing.get("catalog_hash") so a forward-incompatible or externally-injected persisted promoted dict triggers a replace instead of a KeyError crash; add regression test for the malformed-existing case. - test_deferred_catalog: replace the `== [] or True` tautology (a test that could never fail) with a deterministic invalid-regex->literal-fallback check (positive match on calc + negative empty match). - DeferredToolCatalog: comment why frozen-without-slots is required for the cached_property hash/names fields (adding slots=True would break them). * fix(tool-search): read tool_search.enabled from self._app_config in client DeerFlowClient._ensure_agent called get_app_config() directly to read tool_search.enabled, but the client already resolves and stores its config as self._app_config at construction (and uses it everywhere else). The bare call re-resolves config from disk at agent-build time, which raises FileNotFoundError in environments without a config.yaml (CI) — test_client.py's fixture only patches get_app_config during __init__, so the later call hit the real loader. Use self._app_config, matching the rest of the client. * test(tool-search): lock tool_search post-policy append ordering tool_search is appended after skill-allowlist filtering, so the allowlist can no longer deny it by name. Lock the intended contract: it only appears when allowed MCP tools survive the filter, and its catalog (derived from the already policy-filtered list) can never expose a denied tool. Addresses the ordering observation from the Copilot review on #3342.
This commit is contained in:
@@ -0,0 +1,66 @@
|
||||
import pytest
|
||||
from langchain_core.tools import tool as as_tool
|
||||
|
||||
from deerflow.tools.builtins.tool_search import DeferredToolCatalog
|
||||
|
||||
|
||||
@as_tool
|
||||
def alpha_search(query: str) -> str:
|
||||
"Search alpha records by query."
|
||||
return query
|
||||
|
||||
|
||||
@as_tool
|
||||
def beta_translate(text: str) -> str:
|
||||
"Translate beta text."
|
||||
return text
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def catalog() -> DeferredToolCatalog:
|
||||
return DeferredToolCatalog((alpha_search, beta_translate))
|
||||
|
||||
|
||||
def test_names(catalog):
|
||||
assert catalog.names == frozenset({"alpha_search", "beta_translate"})
|
||||
|
||||
|
||||
def test_search_select(catalog):
|
||||
got = catalog.search("select:alpha_search")
|
||||
assert [t.name for t in got] == ["alpha_search"]
|
||||
|
||||
|
||||
def test_search_plus_keyword(catalog):
|
||||
got = catalog.search("+beta translate")
|
||||
assert [t.name for t in got] == ["beta_translate"]
|
||||
|
||||
|
||||
def test_search_regex_on_description(catalog):
|
||||
got = catalog.search("translate")
|
||||
assert "beta_translate" in [t.name for t in got]
|
||||
|
||||
|
||||
def test_search_invalid_regex_falls_back_to_literal():
|
||||
@as_tool
|
||||
def calc(expr: str) -> str:
|
||||
"Compute sum(a, b) style expressions."
|
||||
return expr
|
||||
|
||||
cat = DeferredToolCatalog((calc, alpha_search))
|
||||
# "sum(" is an invalid regex (unbalanced paren). search() must not raise; it
|
||||
# falls back to a literal match, which finds calc's "sum(" in its description.
|
||||
assert [t.name for t in cat.search("sum(")] == ["calc"]
|
||||
# A literal with no match is deterministically empty (and still must not raise).
|
||||
assert cat.search("zzz(") == []
|
||||
|
||||
|
||||
def test_hash_stable_across_instances():
|
||||
c1 = DeferredToolCatalog((alpha_search, beta_translate))
|
||||
c2 = DeferredToolCatalog((beta_translate, alpha_search))
|
||||
assert c1.hash == c2.hash
|
||||
|
||||
|
||||
def test_hash_changes_with_membership():
|
||||
c1 = DeferredToolCatalog((alpha_search, beta_translate))
|
||||
c2 = DeferredToolCatalog((alpha_search,))
|
||||
assert c1.hash != c2.hash
|
||||
@@ -0,0 +1,87 @@
|
||||
"""Tests for DeferredToolFilterMiddleware (closure deferred-set + state promotion)."""
|
||||
|
||||
from langchain_core.tools import tool as as_tool
|
||||
|
||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||
|
||||
|
||||
@as_tool
|
||||
def mcp_a(x: str) -> str:
|
||||
"a"
|
||||
return x
|
||||
|
||||
|
||||
@as_tool
|
||||
def mcp_b(x: str) -> str:
|
||||
"b"
|
||||
return x
|
||||
|
||||
|
||||
@as_tool
|
||||
def active_c(x: str) -> str:
|
||||
"c"
|
||||
return x
|
||||
|
||||
|
||||
class _Req:
|
||||
def __init__(self, tools, state):
|
||||
self.tools = tools
|
||||
self.state = state
|
||||
self.overridden = None
|
||||
|
||||
def override(self, tools):
|
||||
self.overridden = tools
|
||||
return self
|
||||
|
||||
|
||||
def _mw():
|
||||
return DeferredToolFilterMiddleware(frozenset({"mcp_a", "mcp_b"}), "h1")
|
||||
|
||||
|
||||
def test_hides_all_deferred_when_no_promotion():
|
||||
req = _Req([mcp_a, mcp_b, active_c], {})
|
||||
out = _mw()._filter_tools(req)
|
||||
assert [t.name for t in out.overridden] == ["active_c"]
|
||||
|
||||
|
||||
def test_promoted_under_matching_hash_passes_through():
|
||||
req = _Req([mcp_a, mcp_b, active_c], {"promoted": {"catalog_hash": "h1", "names": ["mcp_a"]}})
|
||||
out = _mw()._filter_tools(req)
|
||||
assert {t.name for t in out.overridden} == {"mcp_a", "active_c"}
|
||||
|
||||
|
||||
def test_promotion_ignored_when_hash_mismatch():
|
||||
req = _Req([mcp_a, mcp_b, active_c], {"promoted": {"catalog_hash": "STALE", "names": ["mcp_a"]}})
|
||||
out = _mw()._filter_tools(req)
|
||||
assert [t.name for t in out.overridden] == ["active_c"]
|
||||
|
||||
|
||||
def test_no_deferred_names_is_noop():
|
||||
req = _Req([active_c], {})
|
||||
out = DeferredToolFilterMiddleware(frozenset(), "h1")._filter_tools(req)
|
||||
assert out.overridden is None # returned unchanged
|
||||
|
||||
|
||||
def test_blocked_message_for_unpromoted_deferred_call():
|
||||
class _TCReq:
|
||||
tool_call = {"name": "mcp_a", "id": "tc1"}
|
||||
state = {}
|
||||
|
||||
msg = _mw()._blocked_tool_message(_TCReq())
|
||||
assert msg is not None and msg.status == "error" and "tool_search" in msg.content
|
||||
|
||||
|
||||
def test_no_block_for_promoted_call():
|
||||
class _TCReq:
|
||||
tool_call = {"name": "mcp_a", "id": "tc1"}
|
||||
state = {"promoted": {"catalog_hash": "h1", "names": ["mcp_a"]}}
|
||||
|
||||
assert _mw()._blocked_tool_message(_TCReq()) is None
|
||||
|
||||
|
||||
def test_no_block_for_non_deferred_call():
|
||||
class _TCReq:
|
||||
tool_call = {"name": "active_c", "id": "tc1"}
|
||||
state = {}
|
||||
|
||||
assert _mw()._blocked_tool_message(_TCReq()) is None
|
||||
@@ -0,0 +1,77 @@
|
||||
"""End-to-end: tool_search promotes a deferred tool into the next model turn.
|
||||
|
||||
Locks the full loop through a real ``create_agent`` graph:
|
||||
turn 1 -> deferred MCP tools hidden from bind_tools; model calls tool_search
|
||||
ToolNode-> tool_search returns Command(update={"promoted": {...}}) -> state
|
||||
turn 2 -> middleware reads state["promoted"] (hash-scoped) -> the searched
|
||||
tool's schema is now bound; un-searched deferred tools stay hidden
|
||||
|
||||
This is the behavior #3272's redesign depends on (no ContextVar): promotion
|
||||
flows through graph state, so it works regardless of build/execute context.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.tools import tool as as_tool
|
||||
|
||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.tools.builtins.tool_search import build_deferred_tool_setup
|
||||
|
||||
|
||||
@as_tool
|
||||
def active_tool(x: str) -> str:
|
||||
"An always-active tool."
|
||||
return x
|
||||
|
||||
|
||||
@as_tool
|
||||
def mcp_calc(expression: str) -> str:
|
||||
"Evaluate arithmetic."
|
||||
return expression
|
||||
|
||||
|
||||
@as_tool
|
||||
def mcp_other(x: str) -> str:
|
||||
"Another deferred MCP tool."
|
||||
return x
|
||||
|
||||
|
||||
def _tag(t):
|
||||
t.metadata = {**(t.metadata or {}), "deerflow_mcp": True}
|
||||
return t
|
||||
|
||||
|
||||
def test_tool_search_promotes_into_next_turn():
|
||||
bound: list[list[str]] = []
|
||||
|
||||
class RecordingModel(GenericFakeChatModel):
|
||||
def bind_tools(self, tools, **kwargs):
|
||||
bound.append([getattr(t, "name", None) for t in tools])
|
||||
return self
|
||||
|
||||
setup = build_deferred_tool_setup([active_tool, _tag(mcp_calc), _tag(mcp_other)], enabled=True)
|
||||
turn1 = AIMessage(content="", tool_calls=[{"name": "tool_search", "args": {"query": "select:mcp_calc"}, "id": "c1", "type": "tool_call"}])
|
||||
turn2 = AIMessage(content="done")
|
||||
model = RecordingModel(messages=iter([turn1, turn2]))
|
||||
|
||||
graph = create_agent(
|
||||
model=model,
|
||||
tools=[active_tool, mcp_calc, mcp_other, setup.tool_search_tool],
|
||||
middleware=[DeferredToolFilterMiddleware(setup.deferred_names, setup.catalog_hash)],
|
||||
state_schema=ThreadState,
|
||||
)
|
||||
|
||||
result = asyncio.run(graph.ainvoke({"messages": [HumanMessage(content="use the deferred calculator")]}))
|
||||
|
||||
assert len(bound) >= 2, f"expected >=2 model binds, got {bound}"
|
||||
# Turn 1: both deferred MCP tools hidden.
|
||||
assert "mcp_calc" not in bound[0] and "mcp_other" not in bound[0]
|
||||
# Turn 2: the searched tool is promoted (visible); the un-searched one stays hidden.
|
||||
assert "mcp_calc" in bound[1]
|
||||
assert "mcp_other" not in bound[1]
|
||||
# Promotion recorded in graph state, scoped by catalog hash.
|
||||
assert result["promoted"] == {"catalog_hash": setup.catalog_hash, "names": ["mcp_calc"]}
|
||||
@@ -0,0 +1,66 @@
|
||||
from langchain_core.tools import tool as as_tool
|
||||
from langgraph.types import Command
|
||||
|
||||
from deerflow.tools.builtins.tool_search import DeferredToolCatalog, _is_mcp_tool, build_deferred_tool_setup, build_tool_search_tool
|
||||
|
||||
|
||||
@as_tool
|
||||
def mcp_calc(expression: str) -> str:
|
||||
"Evaluate arithmetic."
|
||||
return expression
|
||||
|
||||
|
||||
@as_tool
|
||||
def local_echo(text: str) -> str:
|
||||
"Echo text."
|
||||
return text
|
||||
|
||||
|
||||
def _tag_mcp(t):
|
||||
t.metadata = {**(t.metadata or {}), "deerflow_mcp": True}
|
||||
return t
|
||||
|
||||
|
||||
def test_is_mcp_tool_reads_metadata():
|
||||
assert _is_mcp_tool(_tag_mcp(mcp_calc)) is True
|
||||
assert _is_mcp_tool(local_echo) is False
|
||||
|
||||
|
||||
def test_setup_disabled_returns_empty():
|
||||
setup = build_deferred_tool_setup([_tag_mcp(mcp_calc), local_echo], enabled=False)
|
||||
assert setup.tool_search_tool is None
|
||||
assert setup.deferred_names == frozenset()
|
||||
assert setup.catalog_hash is None
|
||||
|
||||
|
||||
def test_setup_no_mcp_returns_empty():
|
||||
setup = build_deferred_tool_setup([local_echo], enabled=True)
|
||||
assert setup.tool_search_tool is None
|
||||
assert setup.deferred_names == frozenset()
|
||||
|
||||
|
||||
def test_setup_builds_from_mcp_survivors():
|
||||
setup = build_deferred_tool_setup([_tag_mcp(mcp_calc), local_echo], enabled=True)
|
||||
assert setup.deferred_names == frozenset({"mcp_calc"})
|
||||
assert setup.tool_search_tool is not None
|
||||
assert setup.tool_search_tool.name == "tool_search"
|
||||
assert setup.catalog_hash
|
||||
|
||||
|
||||
def test_tool_search_returns_command_with_hash_scoped_promotion():
|
||||
catalog = DeferredToolCatalog((mcp_calc,))
|
||||
ts = build_tool_search_tool(catalog)
|
||||
out = ts.invoke({"type": "tool_call", "name": "tool_search", "args": {"query": "select:mcp_calc"}, "id": "tc1"})
|
||||
assert isinstance(out, Command)
|
||||
promoted = out.update["promoted"]
|
||||
assert promoted == {"catalog_hash": catalog.hash, "names": ["mcp_calc"]}
|
||||
msg = out.update["messages"][0]
|
||||
assert msg.tool_call_id == "tc1" and msg.name == "tool_search"
|
||||
assert "mcp_calc" in msg.content
|
||||
|
||||
|
||||
def test_tool_search_no_match_empty_names():
|
||||
catalog = DeferredToolCatalog((mcp_calc,))
|
||||
ts = build_tool_search_tool(catalog)
|
||||
out = ts.invoke({"type": "tool_call", "name": "tool_search", "args": {"query": "select:nonexistent"}, "id": "tc2"})
|
||||
assert out.update["promoted"]["names"] == []
|
||||
@@ -0,0 +1,183 @@
|
||||
"""Regressions for the deferred-tool redesign (#3272).
|
||||
|
||||
- Cross-context: building the graph in one async context and running it in a
|
||||
sibling context (that did NOT inherit the build context) must still hide
|
||||
deferred tools. The old ContextVar implementation failed this; the closure +
|
||||
graph-state implementation must pass.
|
||||
- Policy leak (Finding 1): a tool removed by policy must not be searchable.
|
||||
- Fail-closed (Finding 2): a wiring regression must raise, not silently leak.
|
||||
- #2884 isolation: a second (subagent-style) setup build must not affect the
|
||||
lead agent's middleware/promotion.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.tools import tool as as_tool
|
||||
|
||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
|
||||
from deerflow.skills.types import Skill
|
||||
from deerflow.tools.builtins.tool_search import DeferredToolSetup, build_deferred_tool_setup
|
||||
|
||||
|
||||
@as_tool
|
||||
def active_tool(x: str) -> str:
|
||||
"active"
|
||||
return x
|
||||
|
||||
|
||||
@as_tool
|
||||
def mcp_secret(x: str) -> str:
|
||||
"deferred mcp tool — must be hidden from bind_tools until promoted"
|
||||
return x
|
||||
|
||||
|
||||
def _tag(t):
|
||||
t.metadata = {**(t.metadata or {}), "deerflow_mcp": True}
|
||||
return t
|
||||
|
||||
|
||||
_BOUND: list[list[str]] = []
|
||||
|
||||
|
||||
class _RecordingModel(GenericFakeChatModel):
|
||||
def bind_tools(self, tools, **kwargs):
|
||||
_BOUND.append([getattr(t, "name", None) for t in tools])
|
||||
return self
|
||||
|
||||
|
||||
def _build_graph():
|
||||
filtered = [active_tool, _tag(mcp_secret)]
|
||||
setup = build_deferred_tool_setup(filtered, enabled=True)
|
||||
final = [*filtered, setup.tool_search_tool]
|
||||
model = _RecordingModel(messages=iter([AIMessage(content="done")] * 4))
|
||||
return create_agent(
|
||||
model=model,
|
||||
tools=final,
|
||||
middleware=[DeferredToolFilterMiddleware(setup.deferred_names, setup.catalog_hash)],
|
||||
system_prompt="t",
|
||||
)
|
||||
|
||||
|
||||
async def _abuild():
|
||||
return _build_graph()
|
||||
|
||||
|
||||
def test_deferred_hidden_when_built_and_run_in_different_contexts():
|
||||
"""Build in one task, run in a sibling task that did not inherit it."""
|
||||
_BOUND.clear()
|
||||
|
||||
async def main():
|
||||
graph = await asyncio.create_task(_abuild())
|
||||
|
||||
async def run():
|
||||
await graph.ainvoke({"messages": [HumanMessage(content="hi")]})
|
||||
|
||||
await asyncio.create_task(run())
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
assert _BOUND, "model was never bound"
|
||||
assert not any("mcp_secret" in names for names in _BOUND), f"deferred MCP tool leaked into bind_tools: {_BOUND}"
|
||||
|
||||
|
||||
def test_policy_excluded_mcp_tool_not_in_catalog():
|
||||
"""Finding 1: a tool removed by policy is not searchable/exposed."""
|
||||
filtered_after_policy = [active_tool] # mcp_secret denied by skill allowed-tools
|
||||
setup = build_deferred_tool_setup(filtered_after_policy, enabled=True)
|
||||
assert setup.deferred_names == frozenset()
|
||||
assert setup.tool_search_tool is None
|
||||
|
||||
|
||||
def test_fail_closed_when_mcp_survives_without_setup(monkeypatch):
|
||||
"""Finding 2: simulate a wiring regression and assert it fails loudly.
|
||||
|
||||
``_assemble_deferred`` lazy-imports ``build_deferred_tool_setup`` from the
|
||||
source module, so patch it there (not on the agent module).
|
||||
"""
|
||||
from deerflow.agents.lead_agent import agent as agentmod
|
||||
|
||||
monkeypatch.setattr(
|
||||
"deerflow.tools.builtins.tool_search.build_deferred_tool_setup",
|
||||
lambda tools, *, enabled: DeferredToolSetup(None, frozenset(), None),
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="fail-closed"):
|
||||
agentmod._assemble_deferred([_tag(mcp_secret)], enabled=True)
|
||||
|
||||
|
||||
def test_subagent_reentry_does_not_touch_lead_state():
|
||||
"""#2884: building a second (subagent) setup must not affect the lead's
|
||||
middleware. With no shared registry/ContextVar, the lead middleware depends
|
||||
only on its own deferred_names + the passed state."""
|
||||
lead_setup = build_deferred_tool_setup([active_tool, _tag(mcp_secret)], enabled=True)
|
||||
mw = DeferredToolFilterMiddleware(lead_setup.deferred_names, lead_setup.catalog_hash)
|
||||
|
||||
# Simulate a subagent build re-entering tool assembly with its own setup.
|
||||
_ = build_deferred_tool_setup([_tag(mcp_secret)], enabled=True)
|
||||
|
||||
class _Req:
|
||||
def __init__(self):
|
||||
self.tools = [active_tool, mcp_secret]
|
||||
self.state = {"promoted": {"catalog_hash": lead_setup.catalog_hash, "names": ["mcp_secret"]}}
|
||||
|
||||
def override(self, tools):
|
||||
self.tools = tools
|
||||
return self
|
||||
|
||||
out = mw._filter_tools(_Req())
|
||||
assert {t.name for t in out.tools} == {"active_tool", "mcp_secret"} # promotion intact
|
||||
|
||||
|
||||
def _make_skill(allowed_tools):
|
||||
"""Skill carrying an explicit allowed-tools allowlist (None = legacy allow-all)."""
|
||||
return Skill(
|
||||
name="s",
|
||||
description="d",
|
||||
license="MIT",
|
||||
skill_dir=Path("/tmp/s"),
|
||||
skill_file=Path("/tmp/s/SKILL.md"),
|
||||
relative_path=Path("s"),
|
||||
category="public",
|
||||
allowed_tools=allowed_tools,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
|
||||
def test_policy_denied_mcp_yields_no_tool_search_end_to_end():
|
||||
"""An allowlist that denies the MCP tool gates it end-to-end: after the real
|
||||
policy filter no MCP tool survives, so ``_assemble_deferred`` adds no
|
||||
tool_search (and does not fail-closed, because no MCP tool leaked through)."""
|
||||
from deerflow.agents.lead_agent import agent as agentmod
|
||||
|
||||
filtered = filter_tools_by_skill_allowed_tools([active_tool, _tag(mcp_secret)], [_make_skill(["active_tool"])])
|
||||
final_tools, setup = agentmod._assemble_deferred(filtered, enabled=True)
|
||||
|
||||
assert [t.name for t in final_tools] == ["active_tool"]
|
||||
assert "tool_search" not in {t.name for t in final_tools}
|
||||
assert setup.deferred_names == frozenset()
|
||||
|
||||
|
||||
def test_tool_search_appended_after_policy_but_never_exposes_denied_tool():
|
||||
"""Intentional behavior change vs. upstream (Copilot review on PR #3342).
|
||||
|
||||
``tool_search`` is appended AFTER skill-allowlist filtering, so an allowlist
|
||||
can no longer deny ``tool_search`` by name. This is safe by construction: the
|
||||
tool only appears when allowed MCP tools survive the filter, and its catalog
|
||||
is derived from the already policy-filtered list — so it can never expose a
|
||||
tool the allowlist denied. Locks that contract so the ordering cannot regress.
|
||||
"""
|
||||
from deerflow.agents.lead_agent import agent as agentmod
|
||||
|
||||
allowed = ["active_tool", "mcp_secret"] # permits the MCP tool, does NOT list tool_search
|
||||
filtered = filter_tools_by_skill_allowed_tools([active_tool, _tag(mcp_secret)], [_make_skill(allowed)])
|
||||
final_tools, setup = agentmod._assemble_deferred(filtered, enabled=True)
|
||||
|
||||
names = {t.name for t in final_tools}
|
||||
assert "tool_search" in names # appended despite not being in the allowlist
|
||||
assert setup.deferred_names == frozenset({"mcp_secret"})
|
||||
assert set(setup.deferred_names) <= set(allowed) # catalog never exceeds the allowlist
|
||||
@@ -82,15 +82,6 @@ def fake_translator(text: str, target_lang: str) -> str:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_registry_between_tests():
|
||||
from deerflow.tools.builtins.tool_search import reset_deferred_registry
|
||||
|
||||
reset_deferred_registry()
|
||||
yield
|
||||
reset_deferred_registry()
|
||||
|
||||
|
||||
def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None:
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
|
||||
|
||||
@@ -145,6 +136,7 @@ async def test_real_llm_promotes_then_invokes_with_subagent_reentry(monkeypatch:
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||
from deerflow.tools.builtins.tool_search import build_deferred_tool_setup
|
||||
from deerflow.tools.tools import get_available_tools
|
||||
|
||||
_patch_mcp_pipeline(monkeypatch, [fake_calculator, fake_translator])
|
||||
@@ -158,18 +150,17 @@ async def test_real_llm_promotes_then_invokes_with_subagent_reentry(monkeypatch:
|
||||
Use this whenever the user asks you to delegate work — pass a short
|
||||
description as ``prompt``.
|
||||
"""
|
||||
# ``task_tool`` does this internally. Whether the registry-reset that
|
||||
# used to happen here actually leaks back to the parent task depends
|
||||
# on asyncio's implicit context-copying semantics (gather creates
|
||||
# child tasks with copied contexts, so reset_deferred_registry is
|
||||
# task-local) — but the fix in this PR is what GUARANTEES the
|
||||
# promotion sticks regardless of which integration path triggers a
|
||||
# re-entrant ``get_available_tools`` call.
|
||||
# ``task_tool`` does this internally. With the closure + graph-state
|
||||
# design there is no shared registry/ContextVar, so a re-entrant
|
||||
# ``get_available_tools`` call here cannot affect the lead agent's
|
||||
# deferred middleware or its promotion state.
|
||||
get_available_tools(subagent_enabled=False)
|
||||
_calls.append(f"fake_subagent_trigger:{prompt}")
|
||||
return "subagent completed"
|
||||
|
||||
tools = get_available_tools() + [fake_subagent_trigger]
|
||||
raw_tools = get_available_tools() + [fake_subagent_trigger]
|
||||
setup = build_deferred_tool_setup(raw_tools, enabled=True)
|
||||
tools = [*raw_tools, setup.tool_search_tool] if setup.tool_search_tool else raw_tools
|
||||
|
||||
model = ChatOpenAI(
|
||||
model=os.environ.get("ONEAPI_MODEL", "claude-sonnet-4-6"),
|
||||
@@ -195,7 +186,7 @@ async def test_real_llm_promotes_then_invokes_with_subagent_reentry(monkeypatch:
|
||||
graph = create_agent(
|
||||
model=model,
|
||||
tools=tools,
|
||||
middleware=[DeferredToolFilterMiddleware()],
|
||||
middleware=[DeferredToolFilterMiddleware(setup.deferred_names, setup.catalog_hash)],
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,390 +0,0 @@
|
||||
"""Reproduce + regression-guard issue #2884.
|
||||
|
||||
Hypothesis from the issue:
|
||||
``tools.tools.get_available_tools`` unconditionally calls
|
||||
``reset_deferred_registry()`` and constructs a fresh ``DeferredToolRegistry``
|
||||
every time it is invoked. If anything calls ``get_available_tools`` again
|
||||
during the same async context (after the agent has promoted tools via
|
||||
``tool_search``), the promotion is wiped and the next model call hides the
|
||||
tool's schema again.
|
||||
|
||||
These tests pin two things:
|
||||
|
||||
A. **At the unit boundary** — verify the failure mode directly. Promote a
|
||||
tool in the registry, then call ``get_available_tools`` again and observe
|
||||
that the ContextVar registry is reset and the promotion is lost.
|
||||
|
||||
B. **At the graph-execution boundary** — drive a real ``create_agent`` graph
|
||||
with the real ``DeferredToolFilterMiddleware`` through two model turns.
|
||||
The first turn calls ``tool_search`` which promotes a tool. The second
|
||||
turn must see that tool's schema in ``request.tools``. If
|
||||
``get_available_tools`` were to run again between the two turns and reset
|
||||
the registry, the second turn's filter would strip the tool.
|
||||
|
||||
Strategy: use the production ``deerflow.tools.tools.get_available_tools``
|
||||
unmodified; mock only the LLM and the MCP tool source. Patch
|
||||
``deerflow.mcp.cache.get_cached_mcp_tools`` (the symbol that
|
||||
``get_available_tools`` resolves via lazy import) to return our fixture
|
||||
tools so we don't need a real MCP server.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import tool as as_tool
|
||||
|
||||
|
||||
class FakeToolCallingModel(FakeMessagesListChatModel):
|
||||
"""FakeMessagesListChatModel + no-op bind_tools so create_agent works."""
|
||||
|
||||
def bind_tools( # type: ignore[override]
|
||||
self,
|
||||
tools: Any,
|
||||
*,
|
||||
tool_choice: Any = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable:
|
||||
return self
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures: a fake MCP tool source + a way to force config.tool_search.enabled
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@as_tool
|
||||
def fake_mcp_search(query: str) -> str:
|
||||
"""Pretend to search a knowledge base for the given query."""
|
||||
return f"results for {query}"
|
||||
|
||||
|
||||
@as_tool
|
||||
def fake_mcp_fetch(url: str) -> str:
|
||||
"""Pretend to fetch a page at the given URL."""
|
||||
return f"content of {url}"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _supply_env(monkeypatch: pytest.MonkeyPatch):
|
||||
"""config.yaml references $OPENAI_API_KEY at parse time; supply a placeholder."""
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-fake-not-used")
|
||||
monkeypatch.setenv("OPENAI_API_BASE", "https://example.invalid")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_deferred_registry_between_tests():
|
||||
"""Each test must start with a clean ContextVar.
|
||||
|
||||
The registry lives in a module-level ContextVar with no per-task isolation
|
||||
in a synchronous test runner, so one test's promotion can leak into the
|
||||
next and silently break filter assertions.
|
||||
"""
|
||||
from deerflow.tools.builtins.tool_search import reset_deferred_registry
|
||||
|
||||
reset_deferred_registry()
|
||||
yield
|
||||
reset_deferred_registry()
|
||||
|
||||
|
||||
def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None:
|
||||
"""Make get_available_tools believe an MCP server is registered.
|
||||
|
||||
Build a real ``ExtensionsConfig`` with one enabled MCP server entry so
|
||||
that both ``AppConfig.from_file`` (which calls
|
||||
``ExtensionsConfig.from_file().model_dump()``) and ``tools.get_available_tools``
|
||||
(which calls ``ExtensionsConfig.from_file().get_enabled_mcp_servers()``)
|
||||
see a valid instance. Then point the MCP tool cache at our fixture tools.
|
||||
"""
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
|
||||
|
||||
real_ext = ExtensionsConfig(
|
||||
mcpServers={"fake-server": McpServerConfig(type="stdio", command="echo", enabled=True)},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
|
||||
classmethod(lambda cls: real_ext),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", lambda: list(mcp_tools))
|
||||
|
||||
|
||||
def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Force config.tool_search.enabled=True without touching the yaml.
|
||||
|
||||
Calling the real ``get_app_config()`` would trigger ``_apply_singleton_configs``
|
||||
which permanently mutates module-level singletons (``_memory_config``,
|
||||
``_title_config``, …) to match the developer's ``config.yaml`` — even
|
||||
after pytest restores our patch. That leaks across tests later in the
|
||||
run that rely on those singletons' DEFAULTS (e.g. memory queue tests
|
||||
require ``_memory_config.enabled = True``, which is the dataclass default
|
||||
but FALSE in the actual yaml).
|
||||
|
||||
Build a minimal mock AppConfig instead and never call the real loader.
|
||||
"""
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.tool_search_config import ToolSearchConfig
|
||||
|
||||
mock_cfg = AppConfig.model_construct(
|
||||
log_level="info",
|
||||
models=[],
|
||||
tools=[],
|
||||
tool_groups=[],
|
||||
sandbox=AppConfig.model_fields["sandbox"].annotation.model_construct(use="x"),
|
||||
tool_search=ToolSearchConfig(enabled=True),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: mock_cfg)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Section A — direct unit-level reproduction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_available_tools_preserves_promotions_across_reentrant_calls(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Re-entrant ``get_available_tools()`` must preserve prior promotions.
|
||||
|
||||
Step 1: call get_available_tools() — registers MCP tools as deferred.
|
||||
Step 2: simulate the agent calling tool_search by promoting one tool.
|
||||
Step 3: call get_available_tools() again (the same code path
|
||||
``task_tool`` exercises mid-run).
|
||||
|
||||
Assertion: after step 3, the promoted tool is STILL promoted (not
|
||||
re-deferred). On ``main`` before the fix, step 3's
|
||||
``reset_deferred_registry()`` wiped the promotion and re-registered
|
||||
every MCP tool as deferred — this assertion fired with
|
||||
``REGRESSION (#2884)``.
|
||||
"""
|
||||
from deerflow.tools.builtins.tool_search import get_deferred_registry
|
||||
from deerflow.tools.tools import get_available_tools
|
||||
|
||||
_patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
|
||||
_force_tool_search_enabled(monkeypatch)
|
||||
|
||||
# Step 1: first call — both MCP tools start deferred
|
||||
get_available_tools()
|
||||
reg1 = get_deferred_registry()
|
||||
assert reg1 is not None
|
||||
assert {e.name for e in reg1.entries} == {"fake_mcp_search", "fake_mcp_fetch"}
|
||||
|
||||
# Step 2: simulate tool_search promoting one of them
|
||||
reg1.promote({"fake_mcp_search"})
|
||||
assert {e.name for e in reg1.entries} == {"fake_mcp_fetch"}, "Sanity: promote should remove fake_mcp_search"
|
||||
|
||||
# Step 3: second call — registry must NOT silently undo the promotion
|
||||
get_available_tools()
|
||||
reg2 = get_deferred_registry()
|
||||
assert reg2 is not None
|
||||
deferred_after = {e.name for e in reg2.entries}
|
||||
assert "fake_mcp_search" not in deferred_after, f"REGRESSION (#2884): get_available_tools wiped the deferred registry, re-deferring a tool that was already promoted by tool_search. deferred_after_second_call={deferred_after!r}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Section B — graph-execution reproduction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _ToolSearchPromotingModel(FakeToolCallingModel):
|
||||
"""Two-turn model that:
|
||||
|
||||
Turn 1 → emit a tool_call for ``tool_search`` (the real one)
|
||||
Turn 2 → emit a tool_call for ``fake_mcp_search`` (the promoted tool)
|
||||
|
||||
Records the tools it received on each turn so the test can inspect what
|
||||
DeferredToolFilterMiddleware actually fed to ``bind_tools``.
|
||||
"""
|
||||
|
||||
bound_tools_per_turn: list[list[str]] = []
|
||||
|
||||
def bind_tools( # type: ignore[override]
|
||||
self,
|
||||
tools: Any,
|
||||
*,
|
||||
tool_choice: Any = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable:
|
||||
# Record the tool names the model would see in this turn
|
||||
names = [getattr(t, "name", getattr(t, "__name__", repr(t))) for t in tools]
|
||||
self.bound_tools_per_turn.append(names)
|
||||
return self
|
||||
|
||||
|
||||
def _build_promoting_model() -> _ToolSearchPromotingModel:
|
||||
return _ToolSearchPromotingModel(
|
||||
responses=[
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "tool_search",
|
||||
"args": {"query": "select:fake_mcp_search"},
|
||||
"id": "call_search_1",
|
||||
"type": "tool_call",
|
||||
}
|
||||
],
|
||||
),
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "fake_mcp_search",
|
||||
"args": {"query": "hello"},
|
||||
"id": "call_mcp_1",
|
||||
"type": "tool_call",
|
||||
}
|
||||
],
|
||||
),
|
||||
AIMessage(content="all done"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_promoted_tool_is_visible_to_model_on_second_turn(monkeypatch: pytest.MonkeyPatch):
|
||||
"""End-to-end: drive a real create_agent graph through two turns.
|
||||
|
||||
Without the fix, the second-turn bind_tools call should NOT contain
|
||||
fake_mcp_search (because DeferredToolFilterMiddleware sees it in the
|
||||
registry and strips it). With the fix, the model sees the schema and can
|
||||
invoke it.
|
||||
"""
|
||||
from langchain.agents import create_agent
|
||||
|
||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||
from deerflow.tools.tools import get_available_tools
|
||||
|
||||
_patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
|
||||
_force_tool_search_enabled(monkeypatch)
|
||||
|
||||
tools = get_available_tools()
|
||||
# Sanity: the assembled tool list includes the deferred tools (they're in
|
||||
# bind_tools but DeferredToolFilterMiddleware strips deferred ones before
|
||||
# they reach the model)
|
||||
tool_names = {getattr(t, "name", "") for t in tools}
|
||||
assert {"tool_search", "fake_mcp_search", "fake_mcp_fetch"} <= tool_names
|
||||
|
||||
model = _build_promoting_model()
|
||||
model.bound_tools_per_turn = [] # reset class-level recorder
|
||||
|
||||
graph = create_agent(
|
||||
model=model,
|
||||
tools=tools,
|
||||
middleware=[DeferredToolFilterMiddleware()],
|
||||
system_prompt="bug-2884-repro",
|
||||
)
|
||||
|
||||
graph.invoke({"messages": [HumanMessage(content="use the search tool")]})
|
||||
|
||||
# Turn 1: model should NOT see fake_mcp_search (it's deferred)
|
||||
turn1 = set(model.bound_tools_per_turn[0])
|
||||
assert "fake_mcp_search" not in turn1, f"Turn 1 sanity: deferred tools must be hidden from the model. Saw: {turn1!r}"
|
||||
assert "tool_search" in turn1, f"Turn 1 sanity: tool_search must be visible so the agent can discover. Saw: {turn1!r}"
|
||||
|
||||
# Turn 2: AFTER tool_search promotes fake_mcp_search, the model must see it.
|
||||
# This is the load-bearing assertion for issue #2884.
|
||||
assert len(model.bound_tools_per_turn) >= 2, f"Expected at least 2 model turns, got {len(model.bound_tools_per_turn)}"
|
||||
turn2 = set(model.bound_tools_per_turn[1])
|
||||
assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): tool_search promoted fake_mcp_search in turn 1, but the deferred-tool filter still hid it from the model in turn 2. Turn 2 bound tools: {turn2!r}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Section C — the actual issue #2884 trigger: a re-entrant
|
||||
# get_available_tools call (e.g. when task_tool spawns a subagent) must not
|
||||
# wipe the parent's promotion.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_reentrant_get_available_tools_preserves_promotion(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Issue #2884 in its real shape: a re-entrant get_available_tools call
|
||||
(the same pattern that happens when ``task_tool`` builds a subagent's
|
||||
toolset mid-run) must not wipe the parent agent's tool_search promotions.
|
||||
|
||||
Turn 1's tool batch contains BOTH ``tool_search`` (which promotes
|
||||
``fake_mcp_search``) AND ``fake_subagent_trigger`` (which calls
|
||||
``get_available_tools`` again — exactly what ``task_tool`` does when it
|
||||
builds a subagent's toolset). With the fix, turn 2's bind_tools sees the
|
||||
promoted tool. Without the fix, the re-entry wipes the registry and
|
||||
the filter re-hides it.
|
||||
"""
|
||||
from langchain.agents import create_agent
|
||||
|
||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||
from deerflow.tools.tools import get_available_tools
|
||||
|
||||
_patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
|
||||
_force_tool_search_enabled(monkeypatch)
|
||||
|
||||
# The trigger tool simulates what task_tool does internally: rebuild the
|
||||
# toolset by calling get_available_tools while the registry is live.
|
||||
@as_tool
|
||||
def fake_subagent_trigger(prompt: str) -> str:
|
||||
"""Pretend to spawn a subagent. Internally rebuilds the toolset."""
|
||||
get_available_tools(subagent_enabled=False)
|
||||
return f"spawned subagent for: {prompt}"
|
||||
|
||||
tools = get_available_tools() + [fake_subagent_trigger]
|
||||
|
||||
bound_per_turn: list[list[str]] = []
|
||||
|
||||
class _Model(FakeToolCallingModel):
|
||||
def bind_tools(self, tools_arg, **kwargs): # type: ignore[override]
|
||||
bound_per_turn.append([getattr(t, "name", repr(t)) for t in tools_arg])
|
||||
return self
|
||||
|
||||
model = _Model(
|
||||
responses=[
|
||||
# Turn 1: do both in one batch — promote AND trigger the
|
||||
# subagent-style rebuild. LangGraph executes them in order in the
|
||||
# same agent step.
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "tool_search",
|
||||
"args": {"query": "select:fake_mcp_search"},
|
||||
"id": "call_search_1",
|
||||
"type": "tool_call",
|
||||
},
|
||||
{
|
||||
"name": "fake_subagent_trigger",
|
||||
"args": {"prompt": "go"},
|
||||
"id": "call_trigger_1",
|
||||
"type": "tool_call",
|
||||
},
|
||||
],
|
||||
),
|
||||
# Turn 2: try to invoke the promoted tool. The model gets this
|
||||
# turn only if turn 1's bind_tools recorded what the filter sent.
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "fake_mcp_search",
|
||||
"args": {"query": "hello"},
|
||||
"id": "call_mcp_1",
|
||||
"type": "tool_call",
|
||||
}
|
||||
],
|
||||
),
|
||||
AIMessage(content="all done"),
|
||||
]
|
||||
)
|
||||
|
||||
graph = create_agent(
|
||||
model=model,
|
||||
tools=tools,
|
||||
middleware=[DeferredToolFilterMiddleware()],
|
||||
system_prompt="bug-2884-subagent-repro",
|
||||
)
|
||||
graph.invoke({"messages": [HumanMessage(content="use the search tool")]})
|
||||
|
||||
# Turn 1 sanity: deferred tool not visible yet
|
||||
assert "fake_mcp_search" not in set(bound_per_turn[0]), bound_per_turn[0]
|
||||
|
||||
# The smoking-gun assertion: turn 2 sees the promoted tool DESPITE the
|
||||
# re-entrant get_available_tools call that happened in turn 1's tool batch.
|
||||
assert len(bound_per_turn) >= 2, f"Expected ≥2 turns, got {len(bound_per_turn)}"
|
||||
turn2 = set(bound_per_turn[1])
|
||||
assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): a re-entrant get_available_tools call (e.g. task_tool spawning a subagent) wiped the parent agent's promotion. Turn 2 bound tools: {turn2!r}"
|
||||
@@ -0,0 +1,43 @@
|
||||
from deerflow.agents.thread_state import merge_promoted
|
||||
|
||||
|
||||
def test_merge_promoted_preserves_existing_when_new_is_none():
|
||||
existing = {"catalog_hash": "abc", "names": ["search"]}
|
||||
|
||||
assert merge_promoted(existing, None) is existing
|
||||
|
||||
|
||||
def test_merge_promoted_preserves_existing_when_new_is_empty_dict():
|
||||
existing = {"catalog_hash": "abc", "names": ["search"]}
|
||||
|
||||
assert merge_promoted(existing, {}) is existing
|
||||
|
||||
|
||||
def test_merge_promoted_replaces_none_existing_with_deduplicated_new_names():
|
||||
result = merge_promoted(None, {"catalog_hash": "abc", "names": ["search", "search", "fetch"]})
|
||||
|
||||
assert result == {"catalog_hash": "abc", "names": ["search", "fetch"]}
|
||||
|
||||
|
||||
def test_merge_promoted_replaces_when_catalog_hash_changes():
|
||||
existing = {"catalog_hash": "abc", "names": ["old"]}
|
||||
|
||||
result = merge_promoted(existing, {"catalog_hash": "def", "names": ["new", "new", "old"]})
|
||||
|
||||
assert result == {"catalog_hash": "def", "names": ["new", "old"]}
|
||||
|
||||
|
||||
def test_merge_promoted_unions_names_when_catalog_hash_matches():
|
||||
existing = {"catalog_hash": "abc", "names": ["search", "fetch"]}
|
||||
|
||||
result = merge_promoted(existing, {"catalog_hash": "abc", "names": ["fetch", "scrape"]})
|
||||
|
||||
assert result == {"catalog_hash": "abc", "names": ["search", "fetch", "scrape"]}
|
||||
|
||||
|
||||
def test_merge_promoted_replaces_malformed_existing_without_crash():
|
||||
# A forward-incompatible / externally-injected persisted state could lack
|
||||
# catalog_hash; the reducer must treat it as a mismatch and replace, not crash.
|
||||
result = merge_promoted({"names": ["stale"]}, {"catalog_hash": "abc", "names": ["search"]})
|
||||
|
||||
assert result == {"catalog_hash": "abc", "names": ["search"]}
|
||||
@@ -1,609 +1,38 @@
|
||||
"""Tests for the tool_search (deferred tool loading) feature."""
|
||||
"""Tests for the tool_search (deferred tool loading) config + prompt section.
|
||||
|
||||
import json
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool as langchain_tool
|
||||
Catalog search, setup assembly, the Command-writing tool_search tool, and the
|
||||
filter middleware are covered by:
|
||||
- tests/test_deferred_catalog.py
|
||||
- tests/test_deferred_setup.py
|
||||
- tests/test_deferred_filter_middleware.py
|
||||
- tests/test_thread_state_promoted.py
|
||||
"""
|
||||
|
||||
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
|
||||
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
|
||||
from deerflow.tools.builtins.tool_search import (
|
||||
DeferredToolRegistry,
|
||||
get_deferred_registry,
|
||||
reset_deferred_registry,
|
||||
set_deferred_registry,
|
||||
)
|
||||
|
||||
# ── Fixtures ──
|
||||
|
||||
|
||||
def _make_mock_tool(name: str, description: str):
|
||||
"""Create a minimal LangChain tool for testing."""
|
||||
|
||||
@langchain_tool(name)
|
||||
def mock_tool(arg: str) -> str:
|
||||
"""Mock tool."""
|
||||
return f"{name}: {arg}"
|
||||
|
||||
mock_tool.description = description
|
||||
return mock_tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def registry():
|
||||
"""Create a fresh DeferredToolRegistry with test tools."""
|
||||
reg = DeferredToolRegistry()
|
||||
reg.register(_make_mock_tool("github_create_issue", "Create a new issue in a GitHub repository"))
|
||||
reg.register(_make_mock_tool("github_list_repos", "List repositories for a GitHub user"))
|
||||
reg.register(_make_mock_tool("slack_send_message", "Send a message to a Slack channel"))
|
||||
reg.register(_make_mock_tool("slack_list_channels", "List available Slack channels"))
|
||||
reg.register(_make_mock_tool("sentry_list_issues", "List issues from Sentry error tracking"))
|
||||
reg.register(_make_mock_tool("database_query", "Execute a SQL query against the database"))
|
||||
return reg
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_singleton():
|
||||
"""Reset the module-level singleton before/after each test."""
|
||||
reset_deferred_registry()
|
||||
yield
|
||||
reset_deferred_registry()
|
||||
|
||||
|
||||
# ── ToolSearchConfig Tests ──
|
||||
|
||||
|
||||
class TestToolSearchConfig:
|
||||
def test_default_disabled(self):
|
||||
config = ToolSearchConfig()
|
||||
assert config.enabled is False
|
||||
assert ToolSearchConfig().enabled is False
|
||||
|
||||
def test_enabled(self):
|
||||
config = ToolSearchConfig(enabled=True)
|
||||
assert config.enabled is True
|
||||
assert ToolSearchConfig(enabled=True).enabled is True
|
||||
|
||||
def test_load_from_dict(self):
|
||||
config = load_tool_search_config_from_dict({"enabled": True})
|
||||
assert config.enabled is True
|
||||
assert load_tool_search_config_from_dict({"enabled": True}).enabled is True
|
||||
|
||||
def test_load_from_empty_dict(self):
|
||||
config = load_tool_search_config_from_dict({})
|
||||
assert config.enabled is False
|
||||
|
||||
|
||||
# ── DeferredToolRegistry Tests ──
|
||||
|
||||
|
||||
class TestDeferredToolRegistry:
|
||||
def test_register_and_len(self, registry):
|
||||
assert len(registry) == 6
|
||||
|
||||
def test_entries(self, registry):
|
||||
names = [e.name for e in registry.entries]
|
||||
assert "github_create_issue" in names
|
||||
assert "slack_send_message" in names
|
||||
|
||||
def test_deferred_names(self, registry):
|
||||
names = registry.deferred_names
|
||||
assert "github_create_issue" in names
|
||||
assert "slack_send_message" in names
|
||||
assert len(names) == 6
|
||||
|
||||
def test_contains(self, registry):
|
||||
assert registry.contains("github_create_issue") is True
|
||||
assert registry.contains("not_registered") is False
|
||||
|
||||
def test_search_select_single(self, registry):
|
||||
results = registry.search("select:github_create_issue")
|
||||
assert len(results) == 1
|
||||
assert results[0].name == "github_create_issue"
|
||||
|
||||
def test_search_select_multiple(self, registry):
|
||||
results = registry.search("select:github_create_issue,slack_send_message")
|
||||
names = {t.name for t in results}
|
||||
assert names == {"github_create_issue", "slack_send_message"}
|
||||
|
||||
def test_search_select_nonexistent(self, registry):
|
||||
results = registry.search("select:nonexistent_tool")
|
||||
assert results == []
|
||||
|
||||
def test_search_plus_keyword(self, registry):
|
||||
results = registry.search("+github")
|
||||
names = {t.name for t in results}
|
||||
assert names == {"github_create_issue", "github_list_repos"}
|
||||
|
||||
def test_search_plus_keyword_with_ranking(self, registry):
|
||||
results = registry.search("+github issue")
|
||||
assert len(results) == 2
|
||||
# "github_create_issue" should rank higher (has "issue" in name)
|
||||
assert results[0].name == "github_create_issue"
|
||||
|
||||
def test_search_regex_keyword(self, registry):
|
||||
results = registry.search("slack")
|
||||
names = {t.name for t in results}
|
||||
assert "slack_send_message" in names
|
||||
assert "slack_list_channels" in names
|
||||
|
||||
def test_search_regex_description(self, registry):
|
||||
results = registry.search("SQL")
|
||||
assert len(results) == 1
|
||||
assert results[0].name == "database_query"
|
||||
|
||||
def test_search_regex_case_insensitive(self, registry):
|
||||
results = registry.search("GITHUB")
|
||||
assert len(results) == 2
|
||||
|
||||
def test_search_invalid_regex_falls_back_to_literal(self, registry):
|
||||
# "[" is invalid regex, should be escaped and used as literal
|
||||
results = registry.search("[")
|
||||
assert results == []
|
||||
|
||||
def test_search_name_match_ranks_higher(self, registry):
|
||||
# "issue" appears in both github_create_issue (name) and sentry_list_issues (name+desc)
|
||||
results = registry.search("issue")
|
||||
names = [t.name for t in results]
|
||||
# Both should be found (both have "issue" in name)
|
||||
assert "github_create_issue" in names
|
||||
assert "sentry_list_issues" in names
|
||||
|
||||
def test_search_max_results(self):
|
||||
reg = DeferredToolRegistry()
|
||||
for i in range(10):
|
||||
reg.register(_make_mock_tool(f"tool_{i}", f"Tool number {i}"))
|
||||
results = reg.search("tool")
|
||||
assert len(results) <= 5 # MAX_RESULTS = 5
|
||||
|
||||
def test_search_empty_registry(self):
|
||||
reg = DeferredToolRegistry()
|
||||
assert reg.search("anything") == []
|
||||
|
||||
def test_empty_registry_len(self):
|
||||
reg = DeferredToolRegistry()
|
||||
assert len(reg) == 0
|
||||
|
||||
|
||||
# ── Singleton Tests ──
|
||||
|
||||
|
||||
class TestSingleton:
|
||||
def test_default_none(self):
|
||||
assert get_deferred_registry() is None
|
||||
|
||||
def test_set_and_get(self, registry):
|
||||
set_deferred_registry(registry)
|
||||
assert get_deferred_registry() is registry
|
||||
|
||||
def test_reset(self, registry):
|
||||
set_deferred_registry(registry)
|
||||
reset_deferred_registry()
|
||||
assert get_deferred_registry() is None
|
||||
|
||||
def test_contextvar_isolation_across_contexts(self, registry):
|
||||
"""P2: Each async context gets its own independent registry value."""
|
||||
import contextvars
|
||||
|
||||
reg_a = DeferredToolRegistry()
|
||||
reg_a.register(_make_mock_tool("tool_a", "Tool A"))
|
||||
|
||||
reg_b = DeferredToolRegistry()
|
||||
reg_b.register(_make_mock_tool("tool_b", "Tool B"))
|
||||
|
||||
seen: dict[str, object] = {}
|
||||
|
||||
def run_in_context_a():
|
||||
set_deferred_registry(reg_a)
|
||||
seen["ctx_a"] = get_deferred_registry()
|
||||
|
||||
def run_in_context_b():
|
||||
set_deferred_registry(reg_b)
|
||||
seen["ctx_b"] = get_deferred_registry()
|
||||
|
||||
ctx_a = contextvars.copy_context()
|
||||
ctx_b = contextvars.copy_context()
|
||||
ctx_a.run(run_in_context_a)
|
||||
ctx_b.run(run_in_context_b)
|
||||
|
||||
# Each context got its own registry, neither bleeds into the other
|
||||
assert seen["ctx_a"] is reg_a
|
||||
assert seen["ctx_b"] is reg_b
|
||||
# The current context is unchanged
|
||||
assert get_deferred_registry() is None
|
||||
|
||||
|
||||
# ── tool_search Tool Tests ──
|
||||
|
||||
|
||||
class TestToolSearchTool:
|
||||
def test_no_registry(self):
|
||||
from deerflow.tools.builtins.tool_search import tool_search
|
||||
|
||||
result = tool_search.invoke({"query": "github"})
|
||||
assert result == "No deferred tools available."
|
||||
|
||||
def test_no_match(self, registry):
|
||||
from deerflow.tools.builtins.tool_search import tool_search
|
||||
|
||||
set_deferred_registry(registry)
|
||||
result = tool_search.invoke({"query": "nonexistent_xyz_tool"})
|
||||
assert "No tools found matching" in result
|
||||
|
||||
def test_returns_valid_json(self, registry):
|
||||
from deerflow.tools.builtins.tool_search import tool_search
|
||||
|
||||
set_deferred_registry(registry)
|
||||
result = tool_search.invoke({"query": "select:github_create_issue"})
|
||||
parsed = json.loads(result)
|
||||
assert isinstance(parsed, list)
|
||||
assert len(parsed) == 1
|
||||
assert parsed[0]["name"] == "github_create_issue"
|
||||
|
||||
def test_returns_openai_function_format(self, registry):
|
||||
from deerflow.tools.builtins.tool_search import tool_search
|
||||
|
||||
set_deferred_registry(registry)
|
||||
result = tool_search.invoke({"query": "select:slack_send_message"})
|
||||
parsed = json.loads(result)
|
||||
func_def = parsed[0]
|
||||
# OpenAI function format should have these keys
|
||||
assert "name" in func_def
|
||||
assert "description" in func_def
|
||||
assert "parameters" in func_def
|
||||
|
||||
def test_keyword_search_returns_json(self, registry):
|
||||
from deerflow.tools.builtins.tool_search import tool_search
|
||||
|
||||
set_deferred_registry(registry)
|
||||
result = tool_search.invoke({"query": "github"})
|
||||
parsed = json.loads(result)
|
||||
assert len(parsed) == 2
|
||||
names = {d["name"] for d in parsed}
|
||||
assert names == {"github_create_issue", "github_list_repos"}
|
||||
|
||||
|
||||
# ── Prompt Section Tests ──
|
||||
assert load_tool_search_config_from_dict({}).enabled is False
|
||||
|
||||
|
||||
class TestDeferredToolsPromptSection:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_app_config(self, monkeypatch):
|
||||
"""Provide a minimal AppConfig mock so tests don't need config.yaml."""
|
||||
from unittest.mock import MagicMock
|
||||
def test_empty_without_names(self):
|
||||
assert get_deferred_tools_prompt_section() == ""
|
||||
|
||||
from deerflow.config.tool_search_config import ToolSearchConfig
|
||||
def test_empty_with_empty_frozenset(self):
|
||||
assert get_deferred_tools_prompt_section(deferred_names=frozenset()) == ""
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.tool_search = ToolSearchConfig() # disabled by default
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: mock_config)
|
||||
|
||||
def test_empty_when_disabled(self):
|
||||
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
|
||||
|
||||
# tool_search.enabled defaults to False
|
||||
section = get_deferred_tools_prompt_section()
|
||||
assert section == ""
|
||||
|
||||
def test_empty_when_enabled_but_no_registry(self, monkeypatch):
|
||||
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
monkeypatch.setattr(get_app_config().tool_search, "enabled", True)
|
||||
section = get_deferred_tools_prompt_section()
|
||||
assert section == ""
|
||||
|
||||
def test_empty_when_enabled_but_empty_registry(self, monkeypatch):
|
||||
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
monkeypatch.setattr(get_app_config().tool_search, "enabled", True)
|
||||
set_deferred_registry(DeferredToolRegistry())
|
||||
section = get_deferred_tools_prompt_section()
|
||||
assert section == ""
|
||||
|
||||
def test_lists_tool_names(self, registry, monkeypatch):
|
||||
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
monkeypatch.setattr(get_app_config().tool_search, "enabled", True)
|
||||
set_deferred_registry(registry)
|
||||
section = get_deferred_tools_prompt_section()
|
||||
assert "<available-deferred-tools>" in section
|
||||
assert "</available-deferred-tools>" in section
|
||||
assert "github_create_issue" in section
|
||||
assert "slack_send_message" in section
|
||||
assert "sentry_list_issues" in section
|
||||
# Should only have names, no descriptions
|
||||
assert "Create a new issue" not in section
|
||||
|
||||
|
||||
# ── DeferredToolFilterMiddleware Tests ──
|
||||
|
||||
|
||||
class TestDeferredToolFilterMiddleware:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _ensure_middlewares_package(self):
|
||||
"""Remove mock entries injected by test_subagent_executor.py.
|
||||
|
||||
That file replaces deerflow.agents and deerflow.agents.middlewares with
|
||||
MagicMock objects in sys.modules (session-scoped) to break circular imports.
|
||||
We must clear those mocks so real submodule imports work.
|
||||
"""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_keys = [
|
||||
"deerflow.agents",
|
||||
"deerflow.agents.middlewares",
|
||||
"deerflow.agents.middlewares.deferred_tool_filter_middleware",
|
||||
]
|
||||
for key in mock_keys:
|
||||
if isinstance(sys.modules.get(key), MagicMock):
|
||||
del sys.modules[key]
|
||||
|
||||
def test_filters_deferred_tools(self, registry):
|
||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||
|
||||
set_deferred_registry(registry)
|
||||
middleware = DeferredToolFilterMiddleware()
|
||||
|
||||
# Build a mock tools list: 2 active + 1 deferred
|
||||
active_tool = _make_mock_tool("my_active_tool", "An active tool")
|
||||
deferred_tool = registry.entries[0].tool # github_create_issue
|
||||
|
||||
class FakeRequest:
|
||||
def __init__(self, tools):
|
||||
self.tools = tools
|
||||
|
||||
def override(self, **kwargs):
|
||||
return FakeRequest(kwargs.get("tools", self.tools))
|
||||
|
||||
request = FakeRequest(tools=[active_tool, deferred_tool])
|
||||
filtered = middleware._filter_tools(request)
|
||||
|
||||
assert len(filtered.tools) == 1
|
||||
assert filtered.tools[0].name == "my_active_tool"
|
||||
|
||||
def test_no_op_when_no_registry(self):
|
||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||
|
||||
middleware = DeferredToolFilterMiddleware()
|
||||
active_tool = _make_mock_tool("my_tool", "A tool")
|
||||
|
||||
class FakeRequest:
|
||||
def __init__(self, tools):
|
||||
self.tools = tools
|
||||
|
||||
def override(self, **kwargs):
|
||||
return FakeRequest(kwargs.get("tools", self.tools))
|
||||
|
||||
request = FakeRequest(tools=[active_tool])
|
||||
filtered = middleware._filter_tools(request)
|
||||
|
||||
assert len(filtered.tools) == 1
|
||||
assert filtered.tools[0].name == "my_tool"
|
||||
|
||||
def test_preserves_dict_tools(self, registry):
|
||||
"""Dict tools (provider built-ins) should not be filtered."""
|
||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||
|
||||
set_deferred_registry(registry)
|
||||
middleware = DeferredToolFilterMiddleware()
|
||||
|
||||
dict_tool = {"type": "function", "function": {"name": "some_builtin"}}
|
||||
active_tool = _make_mock_tool("my_active_tool", "Active")
|
||||
|
||||
class FakeRequest:
|
||||
def __init__(self, tools):
|
||||
self.tools = tools
|
||||
|
||||
def override(self, **kwargs):
|
||||
return FakeRequest(kwargs.get("tools", self.tools))
|
||||
|
||||
request = FakeRequest(tools=[dict_tool, active_tool])
|
||||
filtered = middleware._filter_tools(request)
|
||||
|
||||
# dict_tool has no .name attr → getattr returns None → not in deferred_names → kept
|
||||
assert len(filtered.tools) == 2
|
||||
|
||||
|
||||
# ── Promote Tests ──
|
||||
|
||||
|
||||
class TestDeferredToolRegistryPromote:
|
||||
def test_promote_removes_tools(self, registry):
|
||||
assert len(registry) == 6
|
||||
registry.promote({"github_create_issue", "slack_send_message"})
|
||||
assert len(registry) == 4
|
||||
remaining = {e.name for e in registry.entries}
|
||||
assert "github_create_issue" not in remaining
|
||||
assert "slack_send_message" not in remaining
|
||||
assert "github_list_repos" in remaining
|
||||
|
||||
def test_promote_nonexistent_is_noop(self, registry):
|
||||
assert len(registry) == 6
|
||||
registry.promote({"nonexistent_tool"})
|
||||
assert len(registry) == 6
|
||||
|
||||
def test_promote_empty_set_is_noop(self, registry):
|
||||
assert len(registry) == 6
|
||||
registry.promote(set())
|
||||
assert len(registry) == 6
|
||||
|
||||
def test_promote_all(self, registry):
|
||||
all_names = {e.name for e in registry.entries}
|
||||
registry.promote(all_names)
|
||||
assert len(registry) == 0
|
||||
|
||||
def test_search_after_promote_excludes_promoted(self, registry):
|
||||
"""After promoting github tools, searching 'github' returns nothing."""
|
||||
registry.promote({"github_create_issue", "github_list_repos"})
|
||||
results = registry.search("github")
|
||||
assert results == []
|
||||
|
||||
def test_filter_after_promote_passes_through(self, registry):
|
||||
"""After tool_search promotes a tool, the middleware lets it through."""
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Clear any mock entries
|
||||
mock_keys = [
|
||||
"deerflow.agents",
|
||||
"deerflow.agents.middlewares",
|
||||
"deerflow.agents.middlewares.deferred_tool_filter_middleware",
|
||||
]
|
||||
for key in mock_keys:
|
||||
if isinstance(sys.modules.get(key), MagicMock):
|
||||
del sys.modules[key]
|
||||
|
||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||
|
||||
set_deferred_registry(registry)
|
||||
middleware = DeferredToolFilterMiddleware()
|
||||
|
||||
target_tool = registry.entries[0].tool # github_create_issue
|
||||
active_tool = _make_mock_tool("my_active_tool", "Active")
|
||||
|
||||
class FakeRequest:
|
||||
def __init__(self, tools):
|
||||
self.tools = tools
|
||||
|
||||
def override(self, **kwargs):
|
||||
return FakeRequest(kwargs.get("tools", self.tools))
|
||||
|
||||
# Before promote: deferred tool is filtered
|
||||
request = FakeRequest(tools=[active_tool, target_tool])
|
||||
filtered = middleware._filter_tools(request)
|
||||
assert len(filtered.tools) == 1
|
||||
assert filtered.tools[0].name == "my_active_tool"
|
||||
|
||||
# Promote the tool
|
||||
registry.promote({"github_create_issue"})
|
||||
|
||||
# After promote: tool passes through the filter
|
||||
request2 = FakeRequest(tools=[active_tool, target_tool])
|
||||
filtered2 = middleware._filter_tools(request2)
|
||||
assert len(filtered2.tools) == 2
|
||||
tool_names = {t.name for t in filtered2.tools}
|
||||
assert "github_create_issue" in tool_names
|
||||
assert "my_active_tool" in tool_names
|
||||
|
||||
|
||||
class TestToolSearchPromotion:
|
||||
def test_tool_search_promotes_matched_tools(self, registry):
|
||||
"""tool_search should promote matched tools so they become callable."""
|
||||
from deerflow.tools.builtins.tool_search import tool_search
|
||||
|
||||
set_deferred_registry(registry)
|
||||
assert len(registry) == 6
|
||||
|
||||
# Search for github tools — should return schemas AND promote them
|
||||
result = tool_search.invoke({"query": "select:github_create_issue"})
|
||||
parsed = json.loads(result)
|
||||
assert len(parsed) == 1
|
||||
assert parsed[0]["name"] == "github_create_issue"
|
||||
|
||||
# The tool should now be promoted (removed from registry)
|
||||
assert len(registry) == 5
|
||||
remaining = {e.name for e in registry.entries}
|
||||
assert "github_create_issue" not in remaining
|
||||
|
||||
def test_tool_search_keyword_promotes_all_matches(self, registry):
|
||||
"""Keyword search promotes all matched tools."""
|
||||
from deerflow.tools.builtins.tool_search import tool_search
|
||||
|
||||
set_deferred_registry(registry)
|
||||
result = tool_search.invoke({"query": "slack"})
|
||||
parsed = json.loads(result)
|
||||
assert len(parsed) == 2
|
||||
|
||||
# Both slack tools promoted
|
||||
remaining = {e.name for e in registry.entries}
|
||||
assert "slack_send_message" not in remaining
|
||||
assert "slack_list_channels" not in remaining
|
||||
assert len(registry) == 4
|
||||
|
||||
|
||||
class TestDeferredToolExecutionGate:
|
||||
def test_unpromoted_deferred_tool_call_is_blocked(self, registry):
|
||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||
|
||||
set_deferred_registry(registry)
|
||||
middleware = DeferredToolFilterMiddleware()
|
||||
request = SimpleNamespace(tool_call={"name": "github_create_issue", "id": "call-1"})
|
||||
called = False
|
||||
|
||||
def handler(_request):
|
||||
nonlocal called
|
||||
called = True
|
||||
return ToolMessage(content="executed", tool_call_id="call-1", name="github_create_issue")
|
||||
|
||||
result = middleware.wrap_tool_call(request, handler)
|
||||
|
||||
assert called is False
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.status == "error"
|
||||
assert result.tool_call_id == "call-1"
|
||||
assert "tool_search" in result.content
|
||||
assert "github_create_issue" in result.content
|
||||
|
||||
def test_promoted_deferred_tool_call_is_allowed(self, registry):
|
||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||
|
||||
registry.promote({"github_create_issue"})
|
||||
set_deferred_registry(registry)
|
||||
middleware = DeferredToolFilterMiddleware()
|
||||
request = SimpleNamespace(tool_call={"name": "github_create_issue", "id": "call-1"})
|
||||
called = False
|
||||
|
||||
def handler(_request):
|
||||
nonlocal called
|
||||
called = True
|
||||
return ToolMessage(content="executed", tool_call_id="call-1", name="github_create_issue")
|
||||
|
||||
result = middleware.wrap_tool_call(request, handler)
|
||||
|
||||
assert called is True
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.content == "executed"
|
||||
|
||||
def test_non_deferred_tool_call_is_allowed(self, registry):
|
||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||
|
||||
set_deferred_registry(registry)
|
||||
middleware = DeferredToolFilterMiddleware()
|
||||
request = SimpleNamespace(tool_call={"name": "local_tool", "id": "call-1"})
|
||||
called = False
|
||||
|
||||
def handler(_request):
|
||||
nonlocal called
|
||||
called = True
|
||||
return ToolMessage(content="executed", tool_call_id="call-1", name="local_tool")
|
||||
|
||||
result = middleware.wrap_tool_call(request, handler)
|
||||
|
||||
assert called is True
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.content == "executed"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_unpromoted_deferred_tool_call_is_blocked_async(self, registry):
|
||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||
|
||||
set_deferred_registry(registry)
|
||||
middleware = DeferredToolFilterMiddleware()
|
||||
request = SimpleNamespace(tool_call={"name": "github_create_issue", "id": "call-1"})
|
||||
called = False
|
||||
|
||||
async def handler(_request):
|
||||
nonlocal called
|
||||
called = True
|
||||
return ToolMessage(content="executed", tool_call_id="call-1", name="github_create_issue")
|
||||
|
||||
result = await middleware.awrap_tool_call(request, handler)
|
||||
|
||||
assert called is False
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.status == "error"
|
||||
assert result.tool_call_id == "call-1"
|
||||
assert "tool_search" in result.content
|
||||
assert "github_create_issue" in result.content
|
||||
def test_lists_sorted_names(self):
|
||||
out = get_deferred_tools_prompt_section(deferred_names=frozenset({"b_tool", "a_tool"}))
|
||||
assert out == "<available-deferred-tools>\na_tool\nb_tool\n</available-deferred-tools>"
|
||||
|
||||
Reference in New Issue
Block a user