mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 07:56:48 +00:00
fix: gate deferred MCP tool execution (#2513)
* fix: gate deferred MCP tool execution * style: format deferred tool middleware * fix: address deferred tool review feedback
This commit is contained in:
@@ -2,8 +2,10 @@
|
||||
|
||||
import json
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool as langchain_tool
|
||||
|
||||
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
|
||||
@@ -83,6 +85,16 @@ class TestDeferredToolRegistry:
|
||||
assert "github_create_issue" in names
|
||||
assert "slack_send_message" in names
|
||||
|
||||
def test_deferred_names(self, registry):
|
||||
names = registry.deferred_names
|
||||
assert "github_create_issue" in names
|
||||
assert "slack_send_message" in names
|
||||
assert len(names) == 6
|
||||
|
||||
def test_contains(self, registry):
|
||||
assert registry.contains("github_create_issue") is True
|
||||
assert registry.contains("not_registered") is False
|
||||
|
||||
def test_search_select_single(self, registry):
|
||||
results = registry.search("select:github_create_issue")
|
||||
assert len(results) == 1
|
||||
@@ -509,3 +521,89 @@ class TestToolSearchPromotion:
|
||||
assert "slack_send_message" not in remaining
|
||||
assert "slack_list_channels" not in remaining
|
||||
assert len(registry) == 4
|
||||
|
||||
|
||||
class TestDeferredToolExecutionGate:
|
||||
def test_unpromoted_deferred_tool_call_is_blocked(self, registry):
|
||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||
|
||||
set_deferred_registry(registry)
|
||||
middleware = DeferredToolFilterMiddleware()
|
||||
request = SimpleNamespace(tool_call={"name": "github_create_issue", "id": "call-1"})
|
||||
called = False
|
||||
|
||||
def handler(_request):
|
||||
nonlocal called
|
||||
called = True
|
||||
return ToolMessage(content="executed", tool_call_id="call-1", name="github_create_issue")
|
||||
|
||||
result = middleware.wrap_tool_call(request, handler)
|
||||
|
||||
assert called is False
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.status == "error"
|
||||
assert result.tool_call_id == "call-1"
|
||||
assert "tool_search" in result.content
|
||||
assert "github_create_issue" in result.content
|
||||
|
||||
def test_promoted_deferred_tool_call_is_allowed(self, registry):
|
||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||
|
||||
registry.promote({"github_create_issue"})
|
||||
set_deferred_registry(registry)
|
||||
middleware = DeferredToolFilterMiddleware()
|
||||
request = SimpleNamespace(tool_call={"name": "github_create_issue", "id": "call-1"})
|
||||
called = False
|
||||
|
||||
def handler(_request):
|
||||
nonlocal called
|
||||
called = True
|
||||
return ToolMessage(content="executed", tool_call_id="call-1", name="github_create_issue")
|
||||
|
||||
result = middleware.wrap_tool_call(request, handler)
|
||||
|
||||
assert called is True
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.content == "executed"
|
||||
|
||||
def test_non_deferred_tool_call_is_allowed(self, registry):
|
||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||
|
||||
set_deferred_registry(registry)
|
||||
middleware = DeferredToolFilterMiddleware()
|
||||
request = SimpleNamespace(tool_call={"name": "local_tool", "id": "call-1"})
|
||||
called = False
|
||||
|
||||
def handler(_request):
|
||||
nonlocal called
|
||||
called = True
|
||||
return ToolMessage(content="executed", tool_call_id="call-1", name="local_tool")
|
||||
|
||||
result = middleware.wrap_tool_call(request, handler)
|
||||
|
||||
assert called is True
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.content == "executed"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_unpromoted_deferred_tool_call_is_blocked_async(self, registry):
|
||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||
|
||||
set_deferred_registry(registry)
|
||||
middleware = DeferredToolFilterMiddleware()
|
||||
request = SimpleNamespace(tool_call={"name": "github_create_issue", "id": "call-1"})
|
||||
called = False
|
||||
|
||||
async def handler(_request):
|
||||
nonlocal called
|
||||
called = True
|
||||
return ToolMessage(content="executed", tool_call_id="call-1", name="github_create_issue")
|
||||
|
||||
result = await middleware.awrap_tool_call(request, handler)
|
||||
|
||||
assert called is False
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.status == "error"
|
||||
assert result.tool_call_id == "call-1"
|
||||
assert "tool_search" in result.content
|
||||
assert "github_create_issue" in result.content
|
||||
|
||||
Reference in New Issue
Block a user