fix(sandbox): add input sanitisation guard to SandboxAuditMiddleware (#1872)

* fix(sandbox): add L2 input sanitisation to SandboxAuditMiddleware

Add _validate_input() to reject malformed bash commands before regex
classification: empty commands, oversized commands (>10 000 chars), and
null bytes that could cause detection/execution layer inconsistency.

* fix(sandbox): address Copilot review — type guard, log truncation, reject reason

- Coerce None/non-string command to str before validation
- Truncate oversized commands in audit logs to prevent log amplification
- Propagate reject_reason through _pre_process() to block message
- Remove L2 label from comments and test class names

* fix(sandbox): isinstance type guard + async input sanitisation tests

Address review comments:
- Replace str() coercion with isinstance(raw_command, str) guard so
  non-string truthy values (0, [], False) fall back to empty string
  instead of passing validation as "0"/"[]"/"False".
- Add TestInputSanitisationBlocksInAwrapToolCall with 4 async tests
  covering empty, null-byte, oversized, and None command via
  awrap_tool_call path.
This commit is contained in:
KKK
2026-04-06 17:21:58 +08:00
committed by GitHub
parent 1ced6e977c
commit 055e4df049
2 changed files with 198 additions and 12 deletions
@@ -1,5 +1,6 @@
"""Tests for SandboxAuditMiddleware - command classification and audit logging."""
import unittest.mock
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
@@ -134,6 +135,98 @@ class TestClassifyCommand:
assert _classify_command(cmd) == "pass", f"Expected 'pass' for: {cmd!r}"
# ---------------------------------------------------------------------------
# _validate_input unit tests (input sanitisation)
# ---------------------------------------------------------------------------
class TestValidateInput:
def setup_method(self):
self.mw = SandboxAuditMiddleware()
def test_empty_string_rejected(self):
assert self.mw._validate_input("") == "empty command"
def test_whitespace_only_rejected(self):
assert self.mw._validate_input(" \t\n ") == "empty command"
def test_normal_command_accepted(self):
assert self.mw._validate_input("ls -la") is None
def test_command_at_max_length_accepted(self):
cmd = "a" * 10_000
assert self.mw._validate_input(cmd) is None
def test_command_exceeding_max_length_rejected(self):
cmd = "a" * 10_001
assert self.mw._validate_input(cmd) == "command too long"
def test_null_byte_rejected(self):
assert self.mw._validate_input("ls\x00; rm -rf /") == "null byte detected"
def test_null_byte_at_start_rejected(self):
assert self.mw._validate_input("\x00ls") == "null byte detected"
def test_null_byte_at_end_rejected(self):
assert self.mw._validate_input("ls\x00") == "null byte detected"
class TestInputSanitisationBlocksInWrapToolCall:
"""Verify that input sanitisation rejections flow through wrap_tool_call correctly."""
def setup_method(self):
self.mw = SandboxAuditMiddleware()
def test_empty_command_blocked_with_reason(self):
request = _make_request("")
handler = _make_handler()
result = self.mw.wrap_tool_call(request, handler)
assert not handler.called
assert isinstance(result, ToolMessage)
assert result.status == "error"
assert "empty command" in result.content.lower()
def test_null_byte_command_blocked_with_reason(self):
request = _make_request("echo\x00rm -rf /")
handler = _make_handler()
result = self.mw.wrap_tool_call(request, handler)
assert not handler.called
assert isinstance(result, ToolMessage)
assert result.status == "error"
assert "null byte" in result.content.lower()
def test_oversized_command_blocked_with_reason(self):
request = _make_request("a" * 10_001)
handler = _make_handler()
result = self.mw.wrap_tool_call(request, handler)
assert not handler.called
assert isinstance(result, ToolMessage)
assert result.status == "error"
assert "command too long" in result.content.lower()
def test_none_command_coerced_to_empty(self):
"""args.get('command') returning None should be coerced to str and rejected as empty."""
request = _make_request("")
# Simulate None value by patching args directly
request.tool_call["args"]["command"] = None
handler = _make_handler()
result = self.mw.wrap_tool_call(request, handler)
assert not handler.called
assert isinstance(result, ToolMessage)
assert result.status == "error"
def test_oversized_command_audit_log_truncated(self):
"""Oversized commands should be truncated in audit logs to prevent log amplification."""
big_cmd = "x" * 10_001
request = _make_request(big_cmd)
handler = _make_handler()
with unittest.mock.patch.object(self.mw, "_write_audit", wraps=self.mw._write_audit) as spy:
self.mw.wrap_tool_call(request, handler)
spy.assert_called_once()
_, kwargs = spy.call_args
assert kwargs.get("truncate") is True
# ---------------------------------------------------------------------------
# SandboxAuditMiddleware.wrap_tool_call integration tests
# ---------------------------------------------------------------------------
@@ -301,6 +394,63 @@ class TestSandboxAuditMiddlewareAwrapToolCall:
assert result == handler_mock.return_value
# ---------------------------------------------------------------------------
# Input sanitisation via awrap_tool_call (async path)
# ---------------------------------------------------------------------------
class TestInputSanitisationBlocksInAwrapToolCall:
"""Verify that input sanitisation rejections flow through awrap_tool_call correctly."""
def setup_method(self):
self.mw = SandboxAuditMiddleware()
async def _call_async(self, request):
handler_mock = _make_handler()
async def async_handler(req):
return handler_mock(req)
result = await self.mw.awrap_tool_call(request, async_handler)
return result, handler_mock.called
@pytest.mark.anyio
async def test_empty_command_blocked_with_reason(self):
request = _make_request("")
result, called = await self._call_async(request)
assert not called
assert isinstance(result, ToolMessage)
assert result.status == "error"
assert "empty command" in result.content.lower()
@pytest.mark.anyio
async def test_null_byte_command_blocked_with_reason(self):
request = _make_request("echo\x00rm -rf /")
result, called = await self._call_async(request)
assert not called
assert isinstance(result, ToolMessage)
assert result.status == "error"
assert "null byte" in result.content.lower()
@pytest.mark.anyio
async def test_oversized_command_blocked_with_reason(self):
request = _make_request("a" * 10_001)
result, called = await self._call_async(request)
assert not called
assert isinstance(result, ToolMessage)
assert result.status == "error"
assert "command too long" in result.content.lower()
@pytest.mark.anyio
async def test_none_command_coerced_to_empty(self):
request = _make_request("")
request.tool_call["args"]["command"] = None
result, called = await self._call_async(request)
assert not called
assert isinstance(result, ToolMessage)
assert result.status == "error"
# ---------------------------------------------------------------------------
# Precision / recall summary (asserted metrics for benchmark reporting)
# ---------------------------------------------------------------------------