Merge branch 'main' into rayhpeng/persistence-scaffold

# Conflicts:
#	.env.example
#	backend/packages/harness/deerflow/agents/middlewares/title_middleware.py
This commit is contained in:
rayhpeng
2026-04-04 21:28:07 +08:00
180 changed files with 10945 additions and 787 deletions
+50
View File
@@ -131,3 +131,53 @@ class TestListDirSerialization:
result = sandbox.list_dir("/test")
assert result == ["/a", "/b"]
assert lock_was_held == [True], "list_dir must hold the lock during exec_command"
class TestConcurrentFileWrites:
"""Verify file write paths do not lose concurrent updates."""
def test_append_should_preserve_both_parallel_writes(self, sandbox):
storage = {"content": "seed\n"}
active_reads = 0
state_lock = threading.Lock()
overlap_detected = threading.Event()
def overlapping_read_file(path):
nonlocal active_reads
with state_lock:
active_reads += 1
snapshot = storage["content"]
if active_reads == 2:
overlap_detected.set()
overlap_detected.wait(0.05)
with state_lock:
active_reads -= 1
return snapshot
def write_back(*, file, content, **kwargs):
storage["content"] = content
return SimpleNamespace(data=SimpleNamespace())
sandbox.read_file = overlapping_read_file
sandbox._client.file.write_file = write_back
barrier = threading.Barrier(2)
def writer(payload: str):
barrier.wait()
sandbox.write_file("/tmp/shared.log", payload, append=True)
threads = [
threading.Thread(target=writer, args=("A\n",)),
threading.Thread(target=writer, args=("B\n",)),
]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
assert storage["content"] in {"seed\nA\nB\n", "seed\nB\nA\n"}
+198 -1
View File
@@ -12,7 +12,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from app.channels.base import Channel
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.store import ChannelStore
@@ -1718,6 +1718,159 @@ class TestFeishuChannel:
_run(go())
class TestWeComChannel:
def test_publish_ws_inbound_starts_stream_and_publishes_message(self, monkeypatch):
from app.channels.wecom import WeComChannel
async def go():
bus = MessageBus()
bus.publish_inbound = AsyncMock()
channel = WeComChannel(bus, config={})
channel._ws_client = SimpleNamespace(reply_stream=AsyncMock())
monkeypatch.setitem(
__import__("sys").modules,
"aibot",
SimpleNamespace(generate_req_id=lambda prefix: "stream-1"),
)
frame = {
"body": {
"msgid": "msg-1",
"from": {"userid": "user-1"},
"aibotid": "bot-1",
"chattype": "single",
}
}
files = [{"type": "image", "url": "https://example.com/image.png"}]
await channel._publish_ws_inbound(frame, "hello", files=files)
channel._ws_client.reply_stream.assert_awaited_once_with(frame, "stream-1", "Working on it...", False)
bus.publish_inbound.assert_awaited_once()
inbound = bus.publish_inbound.await_args.args[0]
assert inbound.channel_name == "wecom"
assert inbound.chat_id == "user-1"
assert inbound.user_id == "user-1"
assert inbound.text == "hello"
assert inbound.thread_ts == "msg-1"
assert inbound.topic_id == "user-1"
assert inbound.files == files
assert inbound.metadata == {"aibotid": "bot-1", "chattype": "single"}
assert channel._ws_frames["msg-1"] is frame
assert channel._ws_stream_ids["msg-1"] == "stream-1"
_run(go())
def test_publish_ws_inbound_uses_configured_working_message(self, monkeypatch):
from app.channels.wecom import WeComChannel
async def go():
bus = MessageBus()
bus.publish_inbound = AsyncMock()
channel = WeComChannel(bus, config={"working_message": "Please wait..."})
channel._ws_client = SimpleNamespace(reply_stream=AsyncMock())
channel._working_message = "Please wait..."
monkeypatch.setitem(
__import__("sys").modules,
"aibot",
SimpleNamespace(generate_req_id=lambda prefix: "stream-1"),
)
frame = {
"body": {
"msgid": "msg-1",
"from": {"userid": "user-1"},
}
}
await channel._publish_ws_inbound(frame, "hello")
channel._ws_client.reply_stream.assert_awaited_once_with(frame, "stream-1", "Please wait...", False)
_run(go())
def test_on_outbound_sends_attachment_before_clearing_context(self, tmp_path):
from app.channels.wecom import WeComChannel
async def go():
bus = MessageBus()
channel = WeComChannel(bus, config={})
frame = {"body": {"msgid": "msg-1"}}
ws_client = SimpleNamespace(
reply_stream=AsyncMock(),
reply=AsyncMock(),
)
channel._ws_client = ws_client
channel._ws_frames["msg-1"] = frame
channel._ws_stream_ids["msg-1"] = "stream-1"
channel._upload_media_ws = AsyncMock(return_value="media-1")
attachment_path = tmp_path / "image.png"
attachment_path.write_bytes(b"png")
attachment = ResolvedAttachment(
virtual_path="/mnt/user-data/outputs/image.png",
actual_path=attachment_path,
filename="image.png",
mime_type="image/png",
size=attachment_path.stat().st_size,
is_image=True,
)
msg = OutboundMessage(
channel_name="wecom",
chat_id="user-1",
thread_id="thread-1",
text="done",
attachments=[attachment],
is_final=True,
thread_ts="msg-1",
)
await channel._on_outbound(msg)
ws_client.reply_stream.assert_awaited_once_with(frame, "stream-1", "done", True)
channel._upload_media_ws.assert_awaited_once_with(
media_type="image",
filename="image.png",
path=str(attachment_path),
size=attachment.size,
)
ws_client.reply.assert_awaited_once_with(frame, {"image": {"media_id": "media-1"}, "msgtype": "image"})
assert "msg-1" not in channel._ws_frames
assert "msg-1" not in channel._ws_stream_ids
_run(go())
def test_send_falls_back_to_send_message_without_thread_context(self):
from app.channels.wecom import WeComChannel
async def go():
bus = MessageBus()
channel = WeComChannel(bus, config={})
channel._ws_client = SimpleNamespace(send_message=AsyncMock())
msg = OutboundMessage(
channel_name="wecom",
chat_id="user-1",
thread_id="thread-1",
text="hello",
thread_ts=None,
)
await channel.send(msg)
channel._ws_client.send_message.assert_awaited_once_with(
"user-1",
{"msgtype": "markdown", "markdown": {"content": "hello"}},
)
_run(go())
class TestChannelService:
def test_get_status_no_channels(self):
from app.channels.service import ChannelService
@@ -1854,6 +2007,20 @@ class TestSlackSendRetry:
_run(go())
def test_raises_runtime_error_when_no_attempts_configured(self):
from app.channels.slack import SlackChannel
async def go():
bus = MessageBus()
ch = SlackChannel(bus=bus, config={"bot_token": "xoxb-test", "app_token": "xapp-test"})
ch._web_client = MagicMock()
msg = OutboundMessage(channel_name="slack", chat_id="C123", thread_id="t1", text="hello")
with pytest.raises(RuntimeError, match="without an exception"):
await ch.send(msg, _max_retries=0)
_run(go())
# ---------------------------------------------------------------------------
# Telegram send retry tests
@@ -1912,6 +2079,36 @@ class TestTelegramSendRetry:
_run(go())
def test_raises_runtime_error_when_no_attempts_configured(self):
from app.channels.telegram import TelegramChannel
async def go():
bus = MessageBus()
ch = TelegramChannel(bus=bus, config={"bot_token": "test-token"})
ch._application = MagicMock()
msg = OutboundMessage(channel_name="telegram", chat_id="12345", thread_id="t1", text="hello")
with pytest.raises(RuntimeError, match="without an exception"):
await ch.send(msg, _max_retries=0)
_run(go())
class TestFeishuSendRetry:
def test_raises_runtime_error_when_no_attempts_configured(self):
from app.channels.feishu import FeishuChannel
async def go():
bus = MessageBus()
ch = FeishuChannel(bus=bus, config={"app_id": "id", "app_secret": "secret"})
ch._api_client = MagicMock()
msg = OutboundMessage(channel_name="feishu", chat_id="chat", thread_id="t1", text="hello")
with pytest.raises(RuntimeError, match="without an exception"):
await ch.send(msg, _max_retries=0)
_run(go())
# ---------------------------------------------------------------------------
# Telegram private-chat thread context tests
+11 -2
View File
@@ -59,18 +59,20 @@ class TestClientInit:
assert client._subagent_enabled is False
assert client._plan_mode is False
assert client._agent_name is None
assert client._available_skills is None
assert client._checkpointer is None
assert client._agent is None
def test_custom_params(self, mock_app_config):
mock_middleware = MagicMock()
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", middlewares=[mock_middleware])
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", available_skills={"skill1", "skill2"}, middlewares=[mock_middleware])
assert c._model_name == "gpt-4"
assert c._thinking_enabled is False
assert c._subagent_enabled is True
assert c._plan_mode is True
assert c._agent_name == "test-agent"
assert c._available_skills == {"skill1", "skill2"}
assert c._middlewares == [mock_middleware]
def test_invalid_agent_name(self, mock_app_config):
@@ -394,8 +396,10 @@ class TestEnsureAgent:
patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
):
client._agent_name = "custom-agent"
client._available_skills = {"test_skill"}
client._ensure_agent(config)
assert client._agent is mock_agent
@@ -404,6 +408,7 @@ class TestEnsureAgent:
assert mock_build_middlewares.call_args.kwargs.get("agent_name") == "custom-agent"
mock_apply_prompt.assert_called_once()
assert mock_apply_prompt.call_args.kwargs.get("agent_name") == "custom-agent"
assert mock_apply_prompt.call_args.kwargs.get("available_skills") == {"test_skill"}
def test_uses_default_checkpointer_when_available(self, client):
mock_agent = MagicMock()
@@ -441,6 +446,7 @@ class TestEnsureAgent:
patch("deerflow.client._build_middlewares", side_effect=fake_build_middlewares),
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
):
client._ensure_agent(config)
@@ -469,7 +475,7 @@ class TestEnsureAgent:
"""_ensure_agent does not recreate if config key unchanged."""
mock_agent = MagicMock()
client._agent = mock_agent
client._agent_config_key = (None, True, False, False)
client._agent_config_key = (None, True, False, False, None, None)
config = client._get_runnable_config("t1")
client._ensure_agent(config)
@@ -1276,6 +1282,7 @@ class TestScenarioAgentRecreation:
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
):
client._ensure_agent(config_a)
first_agent = client._agent
@@ -1303,6 +1310,7 @@ class TestScenarioAgentRecreation:
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
):
client._ensure_agent(config)
client._ensure_agent(config)
@@ -1327,6 +1335,7 @@ class TestScenarioAgentRecreation:
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
):
client._ensure_agent(config)
client.reset_agent()
+22
View File
@@ -164,6 +164,28 @@ class TestLoadAgentConfig:
assert cfg.tool_groups == ["file:read", "file:write"]
def test_load_config_with_skills_empty_list(self, tmp_path):
config_dict = {"name": "no-skills-agent", "skills": []}
_write_agent(tmp_path, "no-skills-agent", config_dict)
with patch("deerflow.config.agents_config.get_paths", return_value=_make_paths(tmp_path)):
from deerflow.config.agents_config import load_agent_config
cfg = load_agent_config("no-skills-agent")
assert cfg.skills == []
def test_load_config_with_skills_omitted(self, tmp_path):
config_dict = {"name": "default-skills-agent"}
_write_agent(tmp_path, "default-skills-agent", config_dict)
with patch("deerflow.config.agents_config.get_paths", return_value=_make_paths(tmp_path)):
from deerflow.config.agents_config import load_agent_config
cfg = load_agent_config("default-skills-agent")
assert cfg.skills is None
def test_legacy_prompt_file_field_ignored(self, tmp_path):
"""Unknown fields like the old prompt_file should be silently ignored."""
agent_dir = tmp_path / "agents" / "legacy-agent"
+55
View File
@@ -3,6 +3,7 @@ from unittest.mock import MagicMock
import pytest
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.feishu import FeishuChannel
from app.channels.message_bus import MessageBus
@@ -68,3 +69,57 @@ def test_feishu_on_message_rich_text():
assert "Paragraph 1, part 1. Paragraph 1, part 2." in parsed_text
assert "@bot Paragraph 2." in parsed_text
assert "\n\n" in parsed_text
@pytest.mark.parametrize("command", sorted(KNOWN_CHANNEL_COMMANDS))
def test_feishu_recognizes_all_known_slash_commands(command):
"""Every entry in KNOWN_CHANNEL_COMMANDS must be classified as a command."""
bus = MessageBus()
config = {"app_id": "test", "app_secret": "test"}
channel = FeishuChannel(bus, config)
event = MagicMock()
event.event.message.chat_id = "chat_1"
event.event.message.message_id = "msg_1"
event.event.message.root_id = None
event.event.sender.sender_id.open_id = "user_1"
event.event.message.content = json.dumps({"text": command})
with pytest.MonkeyPatch.context() as m:
mock_make_inbound = MagicMock()
m.setattr(channel, "_make_inbound", mock_make_inbound)
channel._on_message(event)
mock_make_inbound.assert_called_once()
assert mock_make_inbound.call_args[1]["msg_type"].value == "command", f"{command!r} should be classified as COMMAND"
@pytest.mark.parametrize(
"text",
[
"/unknown",
"/mnt/user-data/outputs/prd/technical-design.md",
"/etc/passwd",
"/not-a-command at all",
],
)
def test_feishu_treats_unknown_slash_text_as_chat(text):
"""Slash-prefixed text that is not a known command must be classified as CHAT."""
bus = MessageBus()
config = {"app_id": "test", "app_secret": "test"}
channel = FeishuChannel(bus, config)
event = MagicMock()
event.event.message.chat_id = "chat_1"
event.event.message.message_id = "msg_1"
event.event.message.root_id = None
event.event.sender.sender_id.open_id = "user_1"
event.event.message.content = json.dumps({"text": text})
with pytest.MonkeyPatch.context() as m:
mock_make_inbound = MagicMock()
m.setattr(channel, "_make_inbound", mock_make_inbound)
channel._on_message(event)
mock_make_inbound.assert_called_once()
assert mock_make_inbound.call_args[1]["msg_type"].value == "chat", f"{text!r} should be classified as CHAT"
+459
View File
@@ -0,0 +1,459 @@
"""Tests for file_conversion utilities (PR1: pymupdf4llm + asyncio.to_thread; PR2: extract_outline)."""
from __future__ import annotations
import asyncio
import sys
from types import ModuleType
from unittest.mock import MagicMock, patch
from deerflow.utils.file_conversion import (
_ASYNC_THRESHOLD_BYTES,
_MIN_CHARS_PER_PAGE,
MAX_OUTLINE_ENTRIES,
_do_convert,
_pymupdf_output_too_sparse,
convert_file_to_markdown,
extract_outline,
)
def _make_pymupdf_mock(page_count: int) -> ModuleType:
"""Return a fake *pymupdf* module whose ``open()`` reports *page_count* pages."""
mock_doc = MagicMock()
mock_doc.__len__ = MagicMock(return_value=page_count)
fake_pymupdf = ModuleType("pymupdf")
fake_pymupdf.open = MagicMock(return_value=mock_doc) # type: ignore[attr-defined]
return fake_pymupdf
def _run(coro):
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(coro)
finally:
loop.close()
# ---------------------------------------------------------------------------
# _pymupdf_output_too_sparse
# ---------------------------------------------------------------------------
class TestPymupdfOutputTooSparse:
"""Check the chars-per-page sparsity heuristic."""
def test_dense_text_pdf_not_sparse(self, tmp_path):
"""Normal text PDF: many chars per page → not sparse."""
pdf = tmp_path / "dense.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
# 10 pages × 10 000 chars → 1000/page ≫ threshold
with patch.dict(sys.modules, {"pymupdf": _make_pymupdf_mock(page_count=10)}):
result = _pymupdf_output_too_sparse("x" * 10_000, pdf)
assert result is False
def test_image_based_pdf_is_sparse(self, tmp_path):
"""Image-based PDF: near-zero chars per page → sparse."""
pdf = tmp_path / "image.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
# 612 chars / 31 pages ≈ 19.7/page < _MIN_CHARS_PER_PAGE (50)
with patch.dict(sys.modules, {"pymupdf": _make_pymupdf_mock(page_count=31)}):
result = _pymupdf_output_too_sparse("x" * 612, pdf)
assert result is True
def test_fallback_when_pymupdf_unavailable(self, tmp_path):
"""When pymupdf is not installed, fall back to absolute 200-char threshold."""
pdf = tmp_path / "broken.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
# Remove pymupdf from sys.modules so the `import pymupdf` inside the
# function raises ImportError, triggering the absolute-threshold fallback.
with patch.dict(sys.modules, {"pymupdf": None}):
sparse = _pymupdf_output_too_sparse("x" * 100, pdf)
not_sparse = _pymupdf_output_too_sparse("x" * 300, pdf)
assert sparse is True
assert not_sparse is False
def test_exactly_at_threshold_is_not_sparse(self, tmp_path):
"""Chars-per-page == threshold is treated as NOT sparse (boundary inclusive)."""
pdf = tmp_path / "boundary.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
# 2 pages × _MIN_CHARS_PER_PAGE chars = exactly at threshold
with patch.dict(sys.modules, {"pymupdf": _make_pymupdf_mock(page_count=2)}):
result = _pymupdf_output_too_sparse("x" * (_MIN_CHARS_PER_PAGE * 2), pdf)
assert result is False
# ---------------------------------------------------------------------------
# _do_convert — routing logic
# ---------------------------------------------------------------------------
class TestDoConvert:
"""Verify that _do_convert routes to the right sub-converter."""
def test_non_pdf_always_uses_markitdown(self, tmp_path):
"""DOCX / XLSX / PPTX always go through MarkItDown regardless of setting."""
docx = tmp_path / "report.docx"
docx.write_bytes(b"PK fake docx")
with patch(
"deerflow.utils.file_conversion._convert_with_markitdown",
return_value="# Markdown from MarkItDown",
) as mock_md:
result = _do_convert(docx, "auto")
mock_md.assert_called_once_with(docx)
assert result == "# Markdown from MarkItDown"
def test_pdf_auto_uses_pymupdf4llm_when_dense(self, tmp_path):
"""auto mode: use pymupdf4llm output when it's dense enough."""
pdf = tmp_path / "report.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
dense_text = "# Heading\n" + "word " * 2000 # clearly dense
with (
patch(
"deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm",
return_value=dense_text,
),
patch(
"deerflow.utils.file_conversion._pymupdf_output_too_sparse",
return_value=False,
),
patch("deerflow.utils.file_conversion._convert_with_markitdown") as mock_md,
):
result = _do_convert(pdf, "auto")
mock_md.assert_not_called()
assert result == dense_text
def test_pdf_auto_falls_back_when_sparse(self, tmp_path):
"""auto mode: fall back to MarkItDown when pymupdf4llm output is sparse."""
pdf = tmp_path / "scanned.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
with (
patch(
"deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm",
return_value="x" * 612, # 19.7 chars/page for 31-page doc
),
patch(
"deerflow.utils.file_conversion._pymupdf_output_too_sparse",
return_value=True,
),
patch(
"deerflow.utils.file_conversion._convert_with_markitdown",
return_value="OCR result via MarkItDown",
) as mock_md,
):
result = _do_convert(pdf, "auto")
mock_md.assert_called_once_with(pdf)
assert result == "OCR result via MarkItDown"
def test_pdf_explicit_pymupdf4llm_skips_sparsity_check(self, tmp_path):
"""'pymupdf4llm' mode: use output as-is even if sparse."""
pdf = tmp_path / "explicit.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
sparse_text = "x" * 10 # very short
with (
patch(
"deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm",
return_value=sparse_text,
),
patch("deerflow.utils.file_conversion._convert_with_markitdown") as mock_md,
):
result = _do_convert(pdf, "pymupdf4llm")
mock_md.assert_not_called()
assert result == sparse_text
def test_pdf_explicit_markitdown_skips_pymupdf4llm(self, tmp_path):
"""'markitdown' mode: never attempt pymupdf4llm."""
pdf = tmp_path / "force_md.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
with (
patch("deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm") as mock_pymu,
patch(
"deerflow.utils.file_conversion._convert_with_markitdown",
return_value="MarkItDown result",
),
):
result = _do_convert(pdf, "markitdown")
mock_pymu.assert_not_called()
assert result == "MarkItDown result"
def test_pdf_auto_falls_back_when_pymupdf4llm_not_installed(self, tmp_path):
"""auto mode: if pymupdf4llm is not installed, use MarkItDown directly."""
pdf = tmp_path / "no_pymupdf.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
with (
patch(
"deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm",
return_value=None, # None signals not installed
),
patch(
"deerflow.utils.file_conversion._convert_with_markitdown",
return_value="MarkItDown fallback",
) as mock_md,
):
result = _do_convert(pdf, "auto")
mock_md.assert_called_once_with(pdf)
assert result == "MarkItDown fallback"
# ---------------------------------------------------------------------------
# convert_file_to_markdown — async + file writing
# ---------------------------------------------------------------------------
class TestConvertFileToMarkdown:
def test_small_file_runs_synchronously(self, tmp_path):
"""Small files (< 1 MB) are converted in the event loop thread."""
pdf = tmp_path / "small.pdf"
pdf.write_bytes(b"%PDF-1.4 " + b"x" * 100) # well under 1 MB
with (
patch("deerflow.utils.file_conversion._get_pdf_converter", return_value="auto"),
patch(
"deerflow.utils.file_conversion._do_convert",
return_value="# Small PDF",
) as mock_convert,
patch("asyncio.to_thread") as mock_thread,
):
md_path = _run(convert_file_to_markdown(pdf))
# asyncio.to_thread must NOT have been called
mock_thread.assert_not_called()
mock_convert.assert_called_once()
assert md_path == pdf.with_suffix(".md")
assert md_path.read_text() == "# Small PDF"
def test_large_file_offloaded_to_thread(self, tmp_path):
"""Large files (> 1 MB) are offloaded via asyncio.to_thread."""
pdf = tmp_path / "large.pdf"
# Write slightly more than the threshold
pdf.write_bytes(b"%PDF-1.4 " + b"x" * (_ASYNC_THRESHOLD_BYTES + 1))
async def fake_to_thread(fn, *args, **kwargs):
return fn(*args, **kwargs)
with (
patch("deerflow.utils.file_conversion._get_pdf_converter", return_value="auto"),
patch(
"deerflow.utils.file_conversion._do_convert",
return_value="# Large PDF",
),
patch("asyncio.to_thread", side_effect=fake_to_thread) as mock_thread,
):
md_path = _run(convert_file_to_markdown(pdf))
mock_thread.assert_called_once()
assert md_path == pdf.with_suffix(".md")
assert md_path.read_text() == "# Large PDF"
def test_returns_none_on_conversion_error(self, tmp_path):
"""If conversion raises, return None without propagating the exception."""
pdf = tmp_path / "broken.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
with (
patch("deerflow.utils.file_conversion._get_pdf_converter", return_value="auto"),
patch(
"deerflow.utils.file_conversion._do_convert",
side_effect=RuntimeError("conversion failed"),
),
):
result = _run(convert_file_to_markdown(pdf))
assert result is None
def test_writes_utf8_markdown_file(self, tmp_path):
"""Generated .md file is written with UTF-8 encoding."""
pdf = tmp_path / "report.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
chinese_content = "# 中文报告\n\n这是测试内容。"
with (
patch("deerflow.utils.file_conversion._get_pdf_converter", return_value="auto"),
patch(
"deerflow.utils.file_conversion._do_convert",
return_value=chinese_content,
),
):
md_path = _run(convert_file_to_markdown(pdf))
assert md_path is not None
assert md_path.read_text(encoding="utf-8") == chinese_content
# ---------------------------------------------------------------------------
# extract_outline
# ---------------------------------------------------------------------------
class TestExtractOutline:
"""Tests for extract_outline()."""
def test_empty_file_returns_empty(self, tmp_path):
"""Empty markdown file yields no outline entries."""
md = tmp_path / "empty.md"
md.write_text("", encoding="utf-8")
assert extract_outline(md) == []
def test_missing_file_returns_empty(self, tmp_path):
"""Non-existent path returns [] without raising."""
assert extract_outline(tmp_path / "nonexistent.md") == []
def test_standard_markdown_headings(self, tmp_path):
"""# / ## / ### headings are all recognised."""
md = tmp_path / "doc.md"
md.write_text(
"# Chapter One\n\nSome text.\n\n## Section 1.1\n\nMore text.\n\n### Sub 1.1.1\n",
encoding="utf-8",
)
outline = extract_outline(md)
assert len(outline) == 3
assert outline[0] == {"title": "Chapter One", "line": 1}
assert outline[1] == {"title": "Section 1.1", "line": 5}
assert outline[2] == {"title": "Sub 1.1.1", "line": 9}
def test_bold_sec_item_heading(self, tmp_path):
"""**ITEM N. TITLE** lines in SEC filings are recognised."""
md = tmp_path / "10k.md"
md.write_text(
"Cover page text.\n\n**ITEM 1. BUSINESS**\n\nBody.\n\n**ITEM 1A. RISK FACTORS**\n",
encoding="utf-8",
)
outline = extract_outline(md)
assert len(outline) == 2
assert outline[0] == {"title": "ITEM 1. BUSINESS", "line": 3}
assert outline[1] == {"title": "ITEM 1A. RISK FACTORS", "line": 7}
def test_bold_part_heading(self, tmp_path):
"""**PART I** / **PART II** headings are recognised."""
md = tmp_path / "10k.md"
md.write_text("**PART I**\n\n**PART II**\n\n**PART III**\n", encoding="utf-8")
outline = extract_outline(md)
assert len(outline) == 3
titles = [e["title"] for e in outline]
assert "PART I" in titles
assert "PART II" in titles
assert "PART III" in titles
def test_sec_cover_page_boilerplate_excluded(self, tmp_path):
"""Address lines and short cover boilerplate must NOT appear in outline."""
md = tmp_path / "8k.md"
md.write_text(
"## **UNITED STATES SECURITIES AND EXCHANGE COMMISSION**\n\n**WASHINGTON, DC 20549**\n\n**CURRENT REPORT**\n\n**SIGNATURES**\n\n**TESLA, INC.**\n\n**ITEM 2.02. RESULTS OF OPERATIONS**\n",
encoding="utf-8",
)
outline = extract_outline(md)
titles = [e["title"] for e in outline]
# Cover-page boilerplate should be excluded
assert "WASHINGTON, DC 20549" not in titles
assert "CURRENT REPORT" not in titles
assert "SIGNATURES" not in titles
assert "TESLA, INC." not in titles
# Real SEC heading must be included
assert "ITEM 2.02. RESULTS OF OPERATIONS" in titles
def test_chinese_headings_via_standard_markdown(self, tmp_path):
"""Chinese annual report headings emitted as # by pymupdf4llm are captured."""
md = tmp_path / "annual.md"
md.write_text(
"# 第一节 公司简介\n\n内容。\n\n## 第三节 管理层讨论与分析\n\n分析内容。\n",
encoding="utf-8",
)
outline = extract_outline(md)
assert len(outline) == 2
assert outline[0]["title"] == "第一节 公司简介"
assert outline[1]["title"] == "第三节 管理层讨论与分析"
def test_outline_capped_at_max_entries(self, tmp_path):
"""When truncated, result has MAX_OUTLINE_ENTRIES real entries + 1 sentinel."""
lines = [f"# Heading {i}" for i in range(MAX_OUTLINE_ENTRIES + 10)]
md = tmp_path / "long.md"
md.write_text("\n".join(lines), encoding="utf-8")
outline = extract_outline(md)
# Last entry is the truncation sentinel
assert outline[-1] == {"truncated": True}
# Visible entries are exactly MAX_OUTLINE_ENTRIES
visible = [e for e in outline if not e.get("truncated")]
assert len(visible) == MAX_OUTLINE_ENTRIES
def test_no_truncation_sentinel_when_under_limit(self, tmp_path):
"""Short documents produce no sentinel entry."""
lines = [f"# Heading {i}" for i in range(5)]
md = tmp_path / "short.md"
md.write_text("\n".join(lines), encoding="utf-8")
outline = extract_outline(md)
assert len(outline) == 5
assert not any(e.get("truncated") for e in outline)
def test_blank_lines_and_whitespace_ignored(self, tmp_path):
"""Blank lines between headings do not produce empty entries."""
md = tmp_path / "spaced.md"
md.write_text("\n\n# Title One\n\n\n\n# Title Two\n\n", encoding="utf-8")
outline = extract_outline(md)
assert len(outline) == 2
assert all(e["title"] for e in outline)
def test_inline_bold_not_confused_with_heading(self, tmp_path):
"""Mid-sentence bold text must not be mistaken for a heading."""
md = tmp_path / "prose.md"
md.write_text(
"This sentence has **bold words** inside it.\n\nAnother with **MULTIPLE CAPS** inline.\n",
encoding="utf-8",
)
outline = extract_outline(md)
assert outline == []
def test_split_bold_heading_academic_paper(self, tmp_path):
"""**<num>** **<title>** lines from academic papers are recognised (Style 3)."""
md = tmp_path / "paper.md"
md.write_text(
"## **Attention Is All You Need**\n\n**1** **Introduction**\n\nBody text.\n\n**2** **Background**\n\nMore text.\n\n**3.1** **Encoder and Decoder Stacks**\n",
encoding="utf-8",
)
outline = extract_outline(md)
titles = [e["title"] for e in outline]
assert "1 Introduction" in titles
assert "2 Background" in titles
assert "3.1 Encoder and Decoder Stacks" in titles
def test_split_bold_year_columns_excluded(self, tmp_path):
"""Financial table headers like **2023** **2022** **2021** are NOT headings."""
md = tmp_path / "annual.md"
md.write_text(
"# Financial Summary\n\n**2023** **2022** **2021**\n\nRevenue 100 90 80\n",
encoding="utf-8",
)
outline = extract_outline(md)
titles = [e["title"] for e in outline]
# Only the # heading should appear, not the year-column row
assert titles == ["Financial Summary"]
def test_adjacent_bold_spans_merged_in_markdown_heading(self, tmp_path):
"""** ** artefacts inside a # heading are merged into clean plain text."""
md = tmp_path / "sec.md"
md.write_text(
"## **UNITED STATES** **SECURITIES AND EXCHANGE COMMISSION**\n\nBody text.\n",
encoding="utf-8",
)
outline = extract_outline(md)
assert len(outline) == 1
# Title must be clean — no ** ** artefacts
assert outline[0]["title"] == "UNITED STATES SECURITIES AND EXCHANGE COMMISSION"
+189 -9
View File
@@ -109,17 +109,11 @@ def test_build_run_config_with_overrides():
def test_build_run_config_custom_agent_injects_agent_name():
"""Custom assistant_id must be forwarded as configurable['agent_name'].
Regression test for #1644: when the LangGraph Platform-compatible
/runs endpoint receives a custom assistant_id (e.g. 'finalis'), the
Gateway must inject configurable['agent_name'] so that make_lead_agent
loads the correct agents/finalis/SOUL.md.
"""
"""Custom assistant_id must be forwarded as configurable['agent_name']."""
from app.gateway.services import build_run_config
config = build_run_config("thread-1", None, None, assistant_id="finalis")
assert config["configurable"]["agent_name"] == "finalis", "Custom assistant_id must be forwarded as configurable['agent_name'] so that make_lead_agent loads the correct SOUL.md"
assert config["configurable"]["agent_name"] == "finalis"
def test_build_run_config_lead_agent_no_agent_name():
@@ -148,7 +142,7 @@ def test_build_run_config_explicit_agent_name_not_overwritten():
None,
assistant_id="other-agent",
)
assert config["configurable"]["agent_name"] == "explicit-agent", "An explicit configurable['agent_name'] in the request body must not be overwritten by the assistant_id mapping"
assert config["configurable"]["agent_name"] == "explicit-agent"
def test_resolve_agent_factory_returns_make_lead_agent():
@@ -160,3 +154,189 @@ def test_resolve_agent_factory_returns_make_lead_agent():
assert resolve_agent_factory("lead_agent") is make_lead_agent
assert resolve_agent_factory("finalis") is make_lead_agent
assert resolve_agent_factory("custom-agent-123") is make_lead_agent
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Regression tests for issue #1699:
# context field in langgraph-compat requests not merged into configurable
# ---------------------------------------------------------------------------
def test_run_create_request_accepts_context():
"""RunCreateRequest must accept the ``context`` field without dropping it."""
from app.gateway.routers.thread_runs import RunCreateRequest
body = RunCreateRequest(
input={"messages": [{"role": "user", "content": "hi"}]},
context={
"model_name": "deepseek-v3",
"thinking_enabled": True,
"is_plan_mode": True,
"subagent_enabled": True,
"thread_id": "some-thread-id",
},
)
assert body.context is not None
assert body.context["model_name"] == "deepseek-v3"
assert body.context["is_plan_mode"] is True
assert body.context["subagent_enabled"] is True
def test_run_create_request_context_defaults_to_none():
"""RunCreateRequest without context should default to None (backward compat)."""
from app.gateway.routers.thread_runs import RunCreateRequest
body = RunCreateRequest(input=None)
assert body.context is None
def test_context_merges_into_configurable():
"""Context values must be merged into config['configurable'] by start_run.
Since start_run is async and requires many dependencies, we test the
merging logic directly by simulating what start_run does.
"""
from app.gateway.services import build_run_config
# Simulate the context merging logic from start_run
config = build_run_config("thread-1", None, None)
context = {
"model_name": "deepseek-v3",
"mode": "ultra",
"reasoning_effort": "high",
"thinking_enabled": True,
"is_plan_mode": True,
"subagent_enabled": True,
"max_concurrent_subagents": 5,
"thread_id": "should-be-ignored",
}
_CONTEXT_CONFIGURABLE_KEYS = {
"model_name",
"mode",
"thinking_enabled",
"reasoning_effort",
"is_plan_mode",
"subagent_enabled",
"max_concurrent_subagents",
}
configurable = config.setdefault("configurable", {})
for key in _CONTEXT_CONFIGURABLE_KEYS:
if key in context:
configurable.setdefault(key, context[key])
assert config["configurable"]["model_name"] == "deepseek-v3"
assert config["configurable"]["thinking_enabled"] is True
assert config["configurable"]["is_plan_mode"] is True
assert config["configurable"]["subagent_enabled"] is True
assert config["configurable"]["max_concurrent_subagents"] == 5
assert config["configurable"]["reasoning_effort"] == "high"
assert config["configurable"]["mode"] == "ultra"
# thread_id from context should NOT override the one from build_run_config
assert config["configurable"]["thread_id"] == "thread-1"
# Non-allowlisted keys should not appear
assert "thread_id" not in {k for k in context if k in _CONTEXT_CONFIGURABLE_KEYS}
def test_context_does_not_override_existing_configurable():
"""Values already in config.configurable must NOT be overridden by context."""
from app.gateway.services import build_run_config
config = build_run_config(
"thread-1",
{"configurable": {"model_name": "gpt-4", "is_plan_mode": False}},
None,
)
context = {
"model_name": "deepseek-v3",
"is_plan_mode": True,
"subagent_enabled": True,
}
_CONTEXT_CONFIGURABLE_KEYS = {
"model_name",
"mode",
"thinking_enabled",
"reasoning_effort",
"is_plan_mode",
"subagent_enabled",
"max_concurrent_subagents",
}
configurable = config.setdefault("configurable", {})
for key in _CONTEXT_CONFIGURABLE_KEYS:
if key in context:
configurable.setdefault(key, context[key])
# Existing values must NOT be overridden
assert config["configurable"]["model_name"] == "gpt-4"
assert config["configurable"]["is_plan_mode"] is False
# New values should be added
assert config["configurable"]["subagent_enabled"] is True
# ---------------------------------------------------------------------------
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
# ---------------------------------------------------------------------------
def test_build_run_config_with_context():
"""When caller sends 'context', prefer it over 'configurable'."""
from app.gateway.services import build_run_config
config = build_run_config(
"thread-1",
{"context": {"user_id": "u-42", "thread_id": "thread-1"}},
None,
)
assert "context" in config
assert config["context"]["user_id"] == "u-42"
assert "configurable" not in config
assert config["recursion_limit"] == 100
def test_build_run_config_context_plus_configurable_warns(caplog):
"""When caller sends both 'context' and 'configurable', prefer 'context' and log a warning."""
import logging
from app.gateway.services import build_run_config
with caplog.at_level(logging.WARNING, logger="app.gateway.services"):
config = build_run_config(
"thread-1",
{
"context": {"user_id": "u-42"},
"configurable": {"model_name": "gpt-4"},
},
None,
)
assert "context" in config
assert config["context"]["user_id"] == "u-42"
assert "configurable" not in config
assert any("both 'context' and 'configurable'" in r.message for r in caplog.records)
def test_build_run_config_context_passthrough_other_keys():
"""Non-conflicting keys from request_config are still passed through when context is used."""
from app.gateway.services import build_run_config
config = build_run_config(
"thread-1",
{"context": {"thread_id": "thread-1"}, "tags": ["prod"]},
None,
)
assert config["context"]["thread_id"] == "thread-1"
assert "configurable" not in config
assert config["tags"] == ["prod"]
def test_build_run_config_no_request_config():
"""When request_config is None, fall back to basic configurable with thread_id."""
from app.gateway.services import build_run_config
config = build_run_config("thread-abc", None, None)
assert config["configurable"] == {"thread_id": "thread-abc"}
assert "context" not in config
+135 -3
View File
@@ -8,6 +8,7 @@ import pytest
from deerflow.config.acp_config import ACPAgentConfig
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig, set_extensions_config
from deerflow.tools.builtins.invoke_acp_agent_tool import (
_build_acp_mcp_servers,
_build_mcp_servers,
_build_permission_response,
_get_work_dir,
@@ -42,6 +43,43 @@ def test_build_mcp_servers_filters_disabled_and_maps_transports():
set_extensions_config(ExtensionsConfig(mcp_servers={}, skills={}))
def test_build_acp_mcp_servers_formats_list_payload():
set_extensions_config(ExtensionsConfig(mcp_servers={"stale": McpServerConfig(enabled=True, type="stdio", command="echo")}, skills={}))
fresh_config = ExtensionsConfig(
mcp_servers={
"stdio": McpServerConfig(enabled=True, type="stdio", command="npx", args=["srv"], env={"FOO": "bar"}),
"http": McpServerConfig(enabled=True, type="http", url="https://example.com/mcp", headers={"Authorization": "Bearer token"}),
"disabled": McpServerConfig(enabled=False, type="stdio", command="echo"),
},
skills={},
)
monkeypatch = pytest.MonkeyPatch()
monkeypatch.setattr(
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
classmethod(lambda cls: fresh_config),
)
try:
assert _build_acp_mcp_servers() == [
{
"name": "stdio",
"type": "stdio",
"command": "npx",
"args": ["srv"],
"env": [{"name": "FOO", "value": "bar"}],
},
{
"name": "http",
"type": "http",
"url": "https://example.com/mcp",
"headers": [{"name": "Authorization", "value": "Bearer token"}],
},
]
finally:
monkeypatch.undo()
set_extensions_config(ExtensionsConfig(mcp_servers={}, skills={}))
def test_build_permission_response_prefers_allow_once():
response = _build_permission_response(
[
@@ -251,9 +289,15 @@ async def test_invoke_acp_agent_uses_fixed_acp_workspace(monkeypatch, tmp_path):
assert captured["spawn"] == {"cmd": "codex-acp", "args": ["--json"], "cwd": expected_cwd}
assert captured["new_session"] == {
"cwd": expected_cwd,
"mcp_servers": {
"github": {"transport": "stdio", "command": "npx", "args": ["github-mcp"]},
},
"mcp_servers": [
{
"name": "github",
"type": "stdio",
"command": "npx",
"args": ["github-mcp"],
"env": [],
}
],
"model": "gpt-5-codex",
}
assert captured["prompt"] == {
@@ -448,6 +492,94 @@ async def test_invoke_acp_agent_passes_env_to_spawn(monkeypatch, tmp_path):
assert captured["env"] == {"OPENAI_API_KEY": "sk-from-env", "FOO": "bar"}
@pytest.mark.anyio
async def test_invoke_acp_agent_skips_invalid_mcp_servers(monkeypatch, tmp_path, caplog):
"""Invalid MCP config should be logged and skipped instead of failing ACP invocation."""
from deerflow.config import paths as paths_module
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
monkeypatch.setattr(
"deerflow.tools.builtins.invoke_acp_agent_tool._build_acp_mcp_servers",
lambda: (_ for _ in ()).throw(ValueError("missing command")),
)
captured: dict[str, object] = {}
class DummyClient:
def __init__(self) -> None:
self._chunks: list[str] = []
@property
def collected_text(self) -> str:
return ""
async def session_update(self, session_id, update, **kwargs):
pass
async def request_permission(self, options, session_id, tool_call, **kwargs):
raise AssertionError("should not be called")
class DummyConn:
async def initialize(self, **kwargs):
pass
async def new_session(self, **kwargs):
captured["new_session"] = kwargs
return SimpleNamespace(session_id="s1")
async def prompt(self, **kwargs):
pass
class DummyProcessContext:
def __init__(self, client, cmd, *args, env=None, cwd=None):
captured["spawn"] = {"cmd": cmd, "args": list(args), "env": env, "cwd": cwd}
async def __aenter__(self):
return DummyConn(), object()
async def __aexit__(self, exc_type, exc, tb):
return False
class DummyRequestError(Exception):
@staticmethod
def method_not_found(method):
return DummyRequestError(method)
monkeypatch.setitem(
sys.modules,
"acp",
SimpleNamespace(
PROTOCOL_VERSION="2026-03-24",
Client=DummyClient,
RequestError=DummyRequestError,
spawn_agent_process=lambda client, cmd, *args, env=None, cwd: DummyProcessContext(client, cmd, *args, env=env, cwd=cwd),
text_block=lambda text: {"type": "text", "text": text},
),
)
monkeypatch.setitem(
sys.modules,
"acp.schema",
SimpleNamespace(
ClientCapabilities=lambda: {},
Implementation=lambda **kwargs: kwargs,
TextContentBlock=type("TextContentBlock", (), {"__init__": lambda self, text: setattr(self, "text", text)}),
),
)
tool = build_invoke_acp_agent_tool({"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")})
caplog.set_level("WARNING")
try:
await tool.coroutine(agent="codex", prompt="Do something")
finally:
sys.modules.pop("acp", None)
sys.modules.pop("acp.schema", None)
assert captured["new_session"]["mcp_servers"] == []
assert "continuing without MCP servers" in caplog.text
assert "missing command" in caplog.text
@pytest.mark.anyio
async def test_invoke_acp_agent_passes_none_env_when_not_configured(monkeypatch, tmp_path):
"""When env is empty, None is passed to spawn_agent_process (subprocess inherits parent env)."""
+177
View File
@@ -0,0 +1,177 @@
"""Tests for JinaClient async crawl method."""
import logging
from unittest.mock import MagicMock
import httpx
import pytest
import deerflow.community.jina_ai.jina_client as jina_client_module
from deerflow.community.jina_ai.jina_client import JinaClient
from deerflow.community.jina_ai.tools import web_fetch_tool
@pytest.fixture
def jina_client():
return JinaClient()
@pytest.mark.anyio
async def test_crawl_success(jina_client, monkeypatch):
"""Test successful crawl returns response text."""
async def mock_post(self, url, **kwargs):
return httpx.Response(200, text="<html><body>Hello</body></html>", request=httpx.Request("POST", url))
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
result = await jina_client.crawl("https://example.com")
assert result == "<html><body>Hello</body></html>"
@pytest.mark.anyio
async def test_crawl_non_200_status(jina_client, monkeypatch):
"""Test that non-200 status returns error message."""
async def mock_post(self, url, **kwargs):
return httpx.Response(429, text="Rate limited", request=httpx.Request("POST", url))
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
result = await jina_client.crawl("https://example.com")
assert result.startswith("Error:")
assert "429" in result
@pytest.mark.anyio
async def test_crawl_empty_response(jina_client, monkeypatch):
"""Test that empty response returns error message."""
async def mock_post(self, url, **kwargs):
return httpx.Response(200, text="", request=httpx.Request("POST", url))
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
result = await jina_client.crawl("https://example.com")
assert result.startswith("Error:")
assert "empty" in result.lower()
@pytest.mark.anyio
async def test_crawl_whitespace_only_response(jina_client, monkeypatch):
"""Test that whitespace-only response returns error message."""
async def mock_post(self, url, **kwargs):
return httpx.Response(200, text=" \n ", request=httpx.Request("POST", url))
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
result = await jina_client.crawl("https://example.com")
assert result.startswith("Error:")
assert "empty" in result.lower()
@pytest.mark.anyio
async def test_crawl_network_error(jina_client, monkeypatch):
"""Test that network errors are handled gracefully."""
async def mock_post(self, url, **kwargs):
raise httpx.ConnectError("Connection refused")
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
result = await jina_client.crawl("https://example.com")
assert result.startswith("Error:")
assert "failed" in result.lower()
@pytest.mark.anyio
async def test_crawl_passes_headers(jina_client, monkeypatch):
"""Test that correct headers are sent."""
captured_headers = {}
async def mock_post(self, url, **kwargs):
captured_headers.update(kwargs.get("headers", {}))
return httpx.Response(200, text="ok", request=httpx.Request("POST", url))
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
await jina_client.crawl("https://example.com", return_format="markdown", timeout=30)
assert captured_headers["X-Return-Format"] == "markdown"
assert captured_headers["X-Timeout"] == "30"
@pytest.mark.anyio
async def test_crawl_includes_api_key_when_set(jina_client, monkeypatch):
"""Test that Authorization header is set when JINA_API_KEY is available."""
captured_headers = {}
async def mock_post(self, url, **kwargs):
captured_headers.update(kwargs.get("headers", {}))
return httpx.Response(200, text="ok", request=httpx.Request("POST", url))
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
monkeypatch.setenv("JINA_API_KEY", "test-key-123")
await jina_client.crawl("https://example.com")
assert captured_headers["Authorization"] == "Bearer test-key-123"
@pytest.mark.anyio
async def test_crawl_warns_once_when_api_key_missing(jina_client, monkeypatch, caplog):
"""Test that the missing API key warning is logged only once."""
jina_client_module._api_key_warned = False
async def mock_post(self, url, **kwargs):
return httpx.Response(200, text="ok", request=httpx.Request("POST", url))
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
monkeypatch.delenv("JINA_API_KEY", raising=False)
with caplog.at_level(logging.WARNING, logger="deerflow.community.jina_ai.jina_client"):
await jina_client.crawl("https://example.com")
await jina_client.crawl("https://example.com")
warning_count = sum(1 for record in caplog.records if "Jina API key is not set" in record.message)
assert warning_count == 1
@pytest.mark.anyio
async def test_crawl_no_auth_header_without_api_key(jina_client, monkeypatch):
"""Test that no Authorization header is set when JINA_API_KEY is not available."""
jina_client_module._api_key_warned = False
captured_headers = {}
async def mock_post(self, url, **kwargs):
captured_headers.update(kwargs.get("headers", {}))
return httpx.Response(200, text="ok", request=httpx.Request("POST", url))
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
monkeypatch.delenv("JINA_API_KEY", raising=False)
await jina_client.crawl("https://example.com")
assert "Authorization" not in captured_headers
@pytest.mark.anyio
async def test_web_fetch_tool_returns_error_on_crawl_failure(monkeypatch):
"""Test that web_fetch_tool short-circuits and returns the error string when crawl fails."""
async def mock_crawl(self, url, **kwargs):
return "Error: Jina API returned status 429: Rate limited"
mock_config = MagicMock()
mock_config.get_tool_config.return_value = None
monkeypatch.setattr("deerflow.community.jina_ai.tools.get_app_config", lambda: mock_config)
monkeypatch.setattr(JinaClient, "crawl", mock_crawl)
result = await web_fetch_tool.ainvoke("https://example.com")
assert result.startswith("Error:")
assert "429" in result
@pytest.mark.anyio
async def test_web_fetch_tool_returns_markdown_on_success(monkeypatch):
"""Test that web_fetch_tool returns extracted markdown on successful crawl."""
async def mock_crawl(self, url, **kwargs):
return "<html><body><p>Hello world</p></body></html>"
mock_config = MagicMock()
mock_config.get_tool_config.return_value = None
monkeypatch.setattr("deerflow.community.jina_ai.tools.get_app_config", lambda: mock_config)
monkeypatch.setattr(JinaClient, "crawl", mock_crawl)
result = await web_fetch_tool.ainvoke("https://example.com")
assert "Hello world" in result
assert not result.startswith("Error:")
+96
View File
@@ -0,0 +1,96 @@
from pathlib import Path
from deerflow.agents.lead_agent.prompt import get_skills_prompt_section
from deerflow.config.agents_config import AgentConfig
from deerflow.skills.types import Skill
def _make_skill(name: str) -> Skill:
return Skill(
name=name,
description=f"Description for {name}",
license="MIT",
skill_dir=Path(f"/tmp/{name}"),
skill_file=Path(f"/tmp/{name}/SKILL.md"),
relative_path=Path(name),
category="public",
enabled=True,
)
def test_get_skills_prompt_section_returns_empty_when_no_skills_match(monkeypatch):
skills = [_make_skill("skill1"), _make_skill("skill2")]
monkeypatch.setattr("deerflow.agents.lead_agent.prompt.load_skills", lambda enabled_only: skills)
result = get_skills_prompt_section(available_skills={"non_existent_skill"})
assert result == ""
def test_get_skills_prompt_section_returns_empty_when_available_skills_empty(monkeypatch):
skills = [_make_skill("skill1"), _make_skill("skill2")]
monkeypatch.setattr("deerflow.agents.lead_agent.prompt.load_skills", lambda enabled_only: skills)
result = get_skills_prompt_section(available_skills=set())
assert result == ""
def test_get_skills_prompt_section_returns_skills(monkeypatch):
skills = [_make_skill("skill1"), _make_skill("skill2")]
monkeypatch.setattr("deerflow.agents.lead_agent.prompt.load_skills", lambda enabled_only: skills)
result = get_skills_prompt_section(available_skills={"skill1"})
assert "skill1" in result
assert "skill2" not in result
def test_get_skills_prompt_section_returns_all_when_available_skills_is_none(monkeypatch):
skills = [_make_skill("skill1"), _make_skill("skill2")]
monkeypatch.setattr("deerflow.agents.lead_agent.prompt.load_skills", lambda enabled_only: skills)
result = get_skills_prompt_section(available_skills=None)
assert "skill1" in result
assert "skill2" in result
def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
from unittest.mock import MagicMock
from deerflow.agents.lead_agent import agent as lead_agent_module
# Mock dependencies
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: MagicMock())
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
class MockModelConfig:
supports_thinking = False
mock_app_config = MagicMock()
mock_app_config.get_model_config.return_value = MockModelConfig()
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
captured_skills = []
def mock_apply_prompt_template(**kwargs):
captured_skills.append(kwargs.get("available_skills"))
return "mock_prompt"
monkeypatch.setattr(lead_agent_module, "apply_prompt_template", mock_apply_prompt_template)
# Case 1: Empty skills list
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=[]))
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
assert captured_skills[-1] == set()
# Case 2: None skills list
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=None))
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
assert captured_skills[-1] is None
# Case 3: Some skills list
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["skill1"]))
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
assert captured_skills[-1] == {"skill1"}
@@ -0,0 +1,136 @@
from __future__ import annotations
import asyncio
from types import SimpleNamespace
import pytest
from langchain_core.messages import AIMessage
from langgraph.errors import GraphBubbleUp
from deerflow.agents.middlewares.llm_error_handling_middleware import (
LLMErrorHandlingMiddleware,
)
class FakeError(Exception):
def __init__(
self,
message: str,
*,
status_code: int | None = None,
code: str | None = None,
headers: dict[str, str] | None = None,
body: dict | None = None,
) -> None:
super().__init__(message)
self.status_code = status_code
self.code = code
self.body = body
self.response = SimpleNamespace(status_code=status_code, headers=headers or {}) if status_code is not None or headers else None
def _build_middleware(**attrs: int) -> LLMErrorHandlingMiddleware:
middleware = LLMErrorHandlingMiddleware()
for key, value in attrs.items():
setattr(middleware, key, value)
return middleware
def test_async_model_call_retries_busy_provider_then_succeeds(
monkeypatch: pytest.MonkeyPatch,
) -> None:
middleware = _build_middleware(retry_max_attempts=3, retry_base_delay_ms=25, retry_cap_delay_ms=25)
attempts = 0
waits: list[float] = []
events: list[dict] = []
async def fake_sleep(delay: float) -> None:
waits.append(delay)
def fake_writer():
return events.append
async def handler(_request) -> AIMessage:
nonlocal attempts
attempts += 1
if attempts < 3:
raise FakeError("当前服务集群负载较高,请稍后重试,感谢您的耐心等待。 (2064)")
return AIMessage(content="ok")
monkeypatch.setattr("asyncio.sleep", fake_sleep)
monkeypatch.setattr(
"langgraph.config.get_stream_writer",
fake_writer,
)
result = asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))
assert isinstance(result, AIMessage)
assert result.content == "ok"
assert attempts == 3
assert waits == [0.025, 0.025]
assert [event["type"] for event in events] == ["llm_retry", "llm_retry"]
def test_async_model_call_returns_user_message_for_quota_errors() -> None:
middleware = _build_middleware(retry_max_attempts=3)
async def handler(_request) -> AIMessage:
raise FakeError(
"insufficient_quota: account balance is empty",
status_code=429,
code="insufficient_quota",
)
result = asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))
assert isinstance(result, AIMessage)
assert "out of quota" in str(result.content)
def test_sync_model_call_uses_retry_after_header(monkeypatch: pytest.MonkeyPatch) -> None:
middleware = _build_middleware(retry_max_attempts=2, retry_base_delay_ms=10, retry_cap_delay_ms=10)
waits: list[float] = []
attempts = 0
def fake_sleep(delay: float) -> None:
waits.append(delay)
def handler(_request) -> AIMessage:
nonlocal attempts
attempts += 1
if attempts == 1:
raise FakeError(
"server busy",
status_code=503,
headers={"Retry-After": "2"},
)
return AIMessage(content="ok")
monkeypatch.setattr("time.sleep", fake_sleep)
result = middleware.wrap_model_call(SimpleNamespace(), handler)
assert isinstance(result, AIMessage)
assert result.content == "ok"
assert waits == [2.0]
def test_sync_model_call_propagates_graph_bubble_up() -> None:
middleware = _build_middleware()
def handler(_request) -> AIMessage:
raise GraphBubbleUp()
with pytest.raises(GraphBubbleUp):
middleware.wrap_model_call(SimpleNamespace(), handler)
def test_async_model_call_propagates_graph_bubble_up() -> None:
middleware = _build_middleware()
async def handler(_request) -> AIMessage:
raise GraphBubbleUp()
with pytest.raises(GraphBubbleUp):
asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))
@@ -0,0 +1,388 @@
import errno
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from deerflow.sandbox.local.local_sandbox import LocalSandbox, PathMapping
from deerflow.sandbox.local.local_sandbox_provider import LocalSandboxProvider
class TestPathMapping:
def test_path_mapping_dataclass(self):
mapping = PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True)
assert mapping.container_path == "/mnt/skills"
assert mapping.local_path == "/home/user/skills"
assert mapping.read_only is True
def test_path_mapping_defaults_to_false(self):
mapping = PathMapping(container_path="/mnt/data", local_path="/home/user/data")
assert mapping.read_only is False
class TestLocalSandboxPathResolution:
def test_resolve_path_exact_match(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"),
],
)
resolved = sandbox._resolve_path("/mnt/skills")
assert resolved == "/home/user/skills"
def test_resolve_path_nested_path(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"),
],
)
resolved = sandbox._resolve_path("/mnt/skills/agent/prompt.py")
assert resolved == "/home/user/skills/agent/prompt.py"
def test_resolve_path_no_mapping(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"),
],
)
resolved = sandbox._resolve_path("/mnt/other/file.txt")
assert resolved == "/mnt/other/file.txt"
def test_resolve_path_longest_prefix_first(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"),
PathMapping(container_path="/mnt", local_path="/var/mnt"),
],
)
resolved = sandbox._resolve_path("/mnt/skills/file.py")
# Should match /mnt/skills first (longer prefix)
assert resolved == "/home/user/skills/file.py"
def test_reverse_resolve_path_exact_match(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path=str(skills_dir)),
],
)
resolved = sandbox._reverse_resolve_path(str(skills_dir))
assert resolved == "/mnt/skills"
def test_reverse_resolve_path_nested(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
file_path = skills_dir / "agent" / "prompt.py"
file_path.parent.mkdir()
file_path.write_text("test")
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path=str(skills_dir)),
],
)
resolved = sandbox._reverse_resolve_path(str(file_path))
assert resolved == "/mnt/skills/agent/prompt.py"
class TestReadOnlyPath:
def test_is_read_only_true(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True),
],
)
assert sandbox._is_read_only_path("/home/user/skills/file.py") is True
def test_is_read_only_false_for_writable(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/data", local_path="/home/user/data", read_only=False),
],
)
assert sandbox._is_read_only_path("/home/user/data/file.txt") is False
def test_is_read_only_false_for_unmapped_path(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True),
],
)
# Path not under any mapping
assert sandbox._is_read_only_path("/tmp/other/file.txt") is False
def test_is_read_only_true_for_exact_match(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True),
],
)
assert sandbox._is_read_only_path("/home/user/skills") is True
def test_write_file_blocked_on_read_only(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path=str(skills_dir), read_only=True),
],
)
# Skills dir is read-only, write should be blocked
with pytest.raises(OSError) as exc_info:
sandbox.write_file("/mnt/skills/new_file.py", "content")
assert exc_info.value.errno == errno.EROFS
def test_write_file_allowed_on_writable_mount(self, tmp_path):
data_dir = tmp_path / "data"
data_dir.mkdir()
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/data", local_path=str(data_dir), read_only=False),
],
)
sandbox.write_file("/mnt/data/file.txt", "content")
assert (data_dir / "file.txt").read_text() == "content"
def test_update_file_blocked_on_read_only(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
existing_file = skills_dir / "existing.py"
existing_file.write_bytes(b"original")
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path=str(skills_dir), read_only=True),
],
)
with pytest.raises(OSError) as exc_info:
sandbox.update_file("/mnt/skills/existing.py", b"updated")
assert exc_info.value.errno == errno.EROFS
class TestMultipleMounts:
def test_multiple_read_write_mounts(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
data_dir = tmp_path / "data"
data_dir.mkdir()
external_dir = tmp_path / "external"
external_dir.mkdir()
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path=str(skills_dir), read_only=True),
PathMapping(container_path="/mnt/data", local_path=str(data_dir), read_only=False),
PathMapping(container_path="/mnt/external", local_path=str(external_dir), read_only=True),
],
)
# Skills is read-only
with pytest.raises(OSError):
sandbox.write_file("/mnt/skills/file.py", "content")
# Data is writable
sandbox.write_file("/mnt/data/file.txt", "data content")
assert (data_dir / "file.txt").read_text() == "data content"
# External is read-only
with pytest.raises(OSError):
sandbox.write_file("/mnt/external/file.txt", "content")
def test_nested_mounts_writable_under_readonly(self, tmp_path):
"""A writable mount nested under a read-only mount should allow writes."""
ro_dir = tmp_path / "ro"
ro_dir.mkdir()
rw_dir = ro_dir / "writable"
rw_dir.mkdir()
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/repo", local_path=str(ro_dir), read_only=True),
PathMapping(container_path="/mnt/repo/writable", local_path=str(rw_dir), read_only=False),
],
)
# Parent mount is read-only
with pytest.raises(OSError):
sandbox.write_file("/mnt/repo/file.txt", "content")
# Nested writable mount should allow writes
sandbox.write_file("/mnt/repo/writable/file.txt", "content")
assert (rw_dir / "file.txt").read_text() == "content"
def test_execute_command_path_replacement(self, tmp_path, monkeypatch):
data_dir = tmp_path / "data"
data_dir.mkdir()
test_file = data_dir / "test.txt"
test_file.write_text("hello")
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/data", local_path=str(data_dir)),
],
)
# Mock subprocess to capture the resolved command
captured = {}
original_run = __import__("subprocess").run
def mock_run(*args, **kwargs):
if len(args) > 0:
captured["command"] = args[0]
return original_run(*args, **kwargs)
monkeypatch.setattr("deerflow.sandbox.local.local_sandbox.subprocess.run", mock_run)
monkeypatch.setattr("deerflow.sandbox.local.local_sandbox.LocalSandbox._get_shell", lambda self: "/bin/sh")
sandbox.execute_command("cat /mnt/data/test.txt")
# Verify the command received the resolved local path
assert str(data_dir) in captured.get("command", "")
def test_reverse_resolve_path_does_not_match_partial_prefix(self, tmp_path):
foo_dir = tmp_path / "foo"
foo_dir.mkdir()
foobar_dir = tmp_path / "foobar"
foobar_dir.mkdir()
target = foobar_dir / "file.txt"
target.write_text("test")
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/foo", local_path=str(foo_dir)),
],
)
resolved = sandbox._reverse_resolve_path(str(target))
assert resolved == str(target.resolve())
def test_reverse_resolve_paths_in_output_supports_backslash_separator(self, tmp_path):
mount_dir = tmp_path / "mount"
mount_dir.mkdir()
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/data", local_path=str(mount_dir)),
],
)
output = f"Copied: {mount_dir}\\file.txt"
masked = sandbox._reverse_resolve_paths_in_output(output)
assert "/mnt/data/file.txt" in masked
assert str(mount_dir) not in masked
class TestLocalSandboxProviderMounts:
def test_setup_path_mappings_uses_configured_skills_container_path_as_reserved_prefix(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
custom_dir = tmp_path / "custom"
custom_dir.mkdir()
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
sandbox_config = SandboxConfig(
use="deerflow.sandbox.local:LocalSandboxProvider",
mounts=[
VolumeMountConfig(host_path=str(custom_dir), container_path="/custom-skills/nested", read_only=False),
],
)
config = SimpleNamespace(
skills=SimpleNamespace(container_path="/custom-skills", get_skills_path=lambda: skills_dir),
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
provider = LocalSandboxProvider()
assert [m.container_path for m in provider._path_mappings] == ["/custom-skills"]
def test_setup_path_mappings_skips_relative_host_path(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
sandbox_config = SandboxConfig(
use="deerflow.sandbox.local:LocalSandboxProvider",
mounts=[
VolumeMountConfig(host_path="relative/path", container_path="/mnt/data", read_only=False),
],
)
config = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir),
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
provider = LocalSandboxProvider()
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
def test_setup_path_mappings_skips_non_absolute_container_path(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
custom_dir = tmp_path / "custom"
custom_dir.mkdir()
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
sandbox_config = SandboxConfig(
use="deerflow.sandbox.local:LocalSandboxProvider",
mounts=[
VolumeMountConfig(host_path=str(custom_dir), container_path="mnt/data", read_only=False),
],
)
config = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir),
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
provider = LocalSandboxProvider()
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
def test_setup_path_mappings_normalizes_container_path_trailing_slash(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
custom_dir = tmp_path / "custom"
custom_dir.mkdir()
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
sandbox_config = SandboxConfig(
use="deerflow.sandbox.local:LocalSandboxProvider",
mounts=[
VolumeMountConfig(host_path=str(custom_dir), container_path="/mnt/data/", read_only=False),
],
)
config = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir),
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
provider = LocalSandboxProvider()
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills", "/mnt/data"]
+119 -2
View File
@@ -1,5 +1,6 @@
"""Tests for LoopDetectionMiddleware."""
import copy
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
@@ -19,8 +20,13 @@ def _make_runtime(thread_id="test-thread"):
def _make_state(tool_calls=None, content=""):
"""Build a minimal AgentState dict with an AIMessage."""
msg = AIMessage(content=content, tool_calls=tool_calls or [])
"""Build a minimal AgentState dict with an AIMessage.
Deep-copies *content* when it is mutable (e.g. list) so that
successive calls never share the same object reference.
"""
safe_content = copy.deepcopy(content) if isinstance(content, list) else content
msg = AIMessage(content=safe_content, tool_calls=tool_calls or [])
return {"messages": [msg]}
@@ -229,3 +235,114 @@ class TestLoopDetection:
mw._apply(_make_state(tool_calls=call), runtime)
assert "default" in mw._history
class TestAppendText:
"""Unit tests for LoopDetectionMiddleware._append_text."""
def test_none_content_returns_text(self):
result = LoopDetectionMiddleware._append_text(None, "hello")
assert result == "hello"
def test_str_content_concatenates(self):
result = LoopDetectionMiddleware._append_text("existing", "appended")
assert result == "existing\n\nappended"
def test_empty_str_content_concatenates(self):
result = LoopDetectionMiddleware._append_text("", "appended")
assert result == "\n\nappended"
def test_list_content_appends_text_block(self):
"""List content (e.g. Anthropic thinking mode) should get a new text block."""
content = [
{"type": "thinking", "text": "Let me think..."},
{"type": "text", "text": "Here is my answer"},
]
result = LoopDetectionMiddleware._append_text(content, "stop msg")
assert isinstance(result, list)
assert len(result) == 3
assert result[0] == content[0]
assert result[1] == content[1]
assert result[2] == {"type": "text", "text": "\n\nstop msg"}
def test_empty_list_content_appends_text_block(self):
result = LoopDetectionMiddleware._append_text([], "stop msg")
assert isinstance(result, list)
assert len(result) == 1
assert result[0] == {"type": "text", "text": "\n\nstop msg"}
def test_unexpected_type_coerced_to_str(self):
"""Unexpected content types should be coerced to str as a fallback."""
result = LoopDetectionMiddleware._append_text(42, "stop msg")
assert isinstance(result, str)
assert result == "42\n\nstop msg"
def test_list_content_not_mutated_in_place(self):
"""_append_text must not modify the original list."""
original = [{"type": "text", "text": "hello"}]
result = LoopDetectionMiddleware._append_text(original, "appended")
assert len(original) == 1 # original unchanged
assert len(result) == 2 # new list has the appended block
class TestHardStopWithListContent:
"""Regression tests: hard stop must not crash when AIMessage.content is a list."""
def test_hard_stop_with_list_content(self):
"""Hard stop on list content should not raise TypeError (regression)."""
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
runtime = _make_runtime()
call = [_bash_call("ls")]
# Build state with list content (e.g. Anthropic thinking mode)
list_content = [
{"type": "thinking", "text": "Let me think..."},
{"type": "text", "text": "I'll run ls"},
]
for _ in range(3):
mw._apply(_make_state(tool_calls=call, content=list_content), runtime)
# Fourth call triggers hard stop — must not raise TypeError
result = mw._apply(_make_state(tool_calls=call, content=list_content), runtime)
assert result is not None
msg = result["messages"][0]
assert isinstance(msg, AIMessage)
assert msg.tool_calls == []
# Content should remain a list with the stop message appended
assert isinstance(msg.content, list)
assert len(msg.content) == 3
assert msg.content[2]["type"] == "text"
assert _HARD_STOP_MSG in msg.content[2]["text"]
def test_hard_stop_with_none_content(self):
"""Hard stop on None content should produce a plain string."""
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
runtime = _make_runtime()
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime)
# Fourth call with default empty-string content
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
msg = result["messages"][0]
assert isinstance(msg.content, str)
assert _HARD_STOP_MSG in msg.content
def test_hard_stop_with_str_content(self):
"""Hard stop on str content should concatenate the stop message."""
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
runtime = _make_runtime()
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call, content="thinking..."), runtime)
result = mw._apply(_make_state(tool_calls=call, content="thinking..."), runtime)
assert result is not None
msg = result["messages"][0]
assert isinstance(msg.content, str)
assert msg.content.startswith("thinking...")
assert _HARD_STOP_MSG in msg.content
@@ -119,3 +119,57 @@ def test_format_memory_skips_non_string_content_facts() -> None:
# The formatted line for a list content would be "- [knowledge | 0.85] ['list']".
assert "| 0.85]" not in result
assert "Valid fact" in result
def test_format_memory_renders_correction_source_error() -> None:
memory_data = {
"facts": [
{
"content": "Use make dev for local development.",
"category": "correction",
"confidence": 0.95,
"sourceError": "The agent previously suggested npm start.",
}
]
}
result = format_memory_for_injection(memory_data, max_tokens=2000)
assert "Use make dev for local development." in result
assert "avoid: The agent previously suggested npm start." in result
def test_format_memory_renders_correction_without_source_error_normally() -> None:
memory_data = {
"facts": [
{
"content": "Use make dev for local development.",
"category": "correction",
"confidence": 0.95,
}
]
}
result = format_memory_for_injection(memory_data, max_tokens=2000)
assert "Use make dev for local development." in result
assert "avoid:" not in result
def test_format_memory_includes_long_term_background() -> None:
"""longTermBackground in history must be injected into the prompt."""
memory_data = {
"user": {},
"history": {
"recentMonths": {"summary": "Recent activity summary"},
"earlierContext": {"summary": "Earlier context summary"},
"longTermBackground": {"summary": "Core expertise in distributed systems"},
},
"facts": [],
}
result = format_memory_for_injection(memory_data, max_tokens=2000)
assert "Background: Core expertise in distributed systems" in result
assert "Recent: Recent activity summary" in result
assert "Earlier: Earlier context summary" in result
+50
View File
@@ -0,0 +1,50 @@
from unittest.mock import MagicMock, patch
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
from deerflow.config.memory_config import MemoryConfig
def _memory_config(**overrides: object) -> MemoryConfig:
config = MemoryConfig()
for key, value in overrides.items():
setattr(config, key, value)
return config
def test_queue_add_preserves_existing_correction_flag_for_same_thread() -> None:
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(thread_id="thread-1", messages=["first"], correction_detected=True)
queue.add(thread_id="thread-1", messages=["second"], correction_detected=False)
assert len(queue._queue) == 1
assert queue._queue[0].messages == ["second"]
assert queue._queue[0].correction_detected is True
def test_process_queue_forwards_correction_flag_to_updater() -> None:
queue = MemoryUpdateQueue()
queue._queue = [
ConversationContext(
thread_id="thread-1",
messages=["conversation"],
agent_name="lead_agent",
correction_detected=True,
)
]
mock_updater = MagicMock()
mock_updater.update_memory.return_value = True
with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater):
queue._process_queue()
mock_updater.update_memory.assert_called_once_with(
messages=["conversation"],
thread_id="thread-1",
agent_name="lead_agent",
correction_detected=True,
)
+50
View File
@@ -72,6 +72,56 @@ def test_import_memory_route_returns_imported_memory() -> None:
assert response.json()["facts"] == imported_memory["facts"]
def test_export_memory_route_preserves_source_error() -> None:
app = FastAPI()
app.include_router(memory.router)
exported_memory = _sample_memory(
facts=[
{
"id": "fact_correction",
"content": "Use make dev for local development.",
"category": "correction",
"confidence": 0.95,
"createdAt": "2026-03-20T00:00:00Z",
"source": "thread-1",
"sourceError": "The agent previously suggested npm start.",
}
]
)
with patch("app.gateway.routers.memory.get_memory_data", return_value=exported_memory):
with TestClient(app) as client:
response = client.get("/api/memory/export")
assert response.status_code == 200
assert response.json()["facts"][0]["sourceError"] == "The agent previously suggested npm start."
def test_import_memory_route_preserves_source_error() -> None:
app = FastAPI()
app.include_router(memory.router)
imported_memory = _sample_memory(
facts=[
{
"id": "fact_correction",
"content": "Use make dev for local development.",
"category": "correction",
"confidence": 0.95,
"createdAt": "2026-03-20T00:00:00Z",
"source": "thread-1",
"sourceError": "The agent previously suggested npm start.",
}
]
)
with patch("app.gateway.routers.memory.import_memory_data", return_value=imported_memory):
with TestClient(app) as client:
response = client.post("/api/memory/import", json=imported_memory)
assert response.status_code == 200
assert response.json()["facts"][0]["sourceError"] == "The agent previously suggested npm start."
def test_clear_memory_route_returns_cleared_memory() -> None:
app = FastAPI()
app.include_router(memory.router)
+97
View File
@@ -146,6 +146,53 @@ def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
assert result["facts"][1]["source"] == "thread-9"
def test_apply_updates_preserves_source_error() -> None:
updater = MemoryUpdater()
current_memory = _make_memory()
update_data = {
"newFacts": [
{
"content": "Use make dev for local development.",
"category": "correction",
"confidence": 0.95,
"sourceError": "The agent previously suggested npm start.",
}
]
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
assert result["facts"][0]["sourceError"] == "The agent previously suggested npm start."
assert result["facts"][0]["category"] == "correction"
def test_apply_updates_ignores_empty_source_error() -> None:
updater = MemoryUpdater()
current_memory = _make_memory()
update_data = {
"newFacts": [
{
"content": "Use make dev for local development.",
"category": "correction",
"confidence": 0.95,
"sourceError": " ",
}
]
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
assert "sourceError" not in result["facts"][0]
def test_clear_memory_data_resets_all_sections() -> None:
with patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True):
result = clear_memory_data()
@@ -522,3 +569,53 @@ class TestUpdateMemoryStructuredResponse:
result = updater.update_memory([msg, ai_msg])
assert result is True
def test_correction_hint_injected_when_detected(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "No, that's wrong."
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Understood"
ai_msg.tool_calls = []
result = updater.update_memory([msg, ai_msg], correction_detected=True)
assert result is True
prompt = model.invoke.call_args[0][0]
assert "Explicit correction signals were detected" in prompt
def test_correction_hint_empty_when_not_detected(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Let's talk about memory."
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Sure"
ai_msg.tool_calls = []
result = updater.update_memory([msg, ai_msg], correction_detected=False)
assert result is True
prompt = model.invoke.call_args[0][0]
assert "Explicit correction signals were detected" not in prompt
+59 -1
View File
@@ -10,7 +10,7 @@ persisting in long-term memory:
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction
# ---------------------------------------------------------------------------
# Helpers
@@ -134,6 +134,64 @@ class TestFilterMessagesForMemory:
assert "<uploaded_files>" not in all_content
# ===========================================================================
# detect_correction
# ===========================================================================
class TestDetectCorrection:
def test_detects_english_correction_signal(self):
msgs = [
_human("Please help me run the project."),
_ai("Use npm start."),
_human("That's wrong, use make dev instead."),
_ai("Understood."),
]
assert detect_correction(msgs) is True
def test_detects_chinese_correction_signal(self):
msgs = [
_human("帮我启动项目"),
_ai("用 npm start"),
_human("不对,改用 make dev"),
_ai("明白了"),
]
assert detect_correction(msgs) is True
def test_returns_false_without_signal(self):
msgs = [
_human("Please explain the build setup."),
_ai("Here is the build setup."),
_human("Thanks, that makes sense."),
]
assert detect_correction(msgs) is False
def test_only_checks_recent_messages(self):
msgs = [
_human("That is wrong, use make dev instead."),
_ai("Noted."),
_human("Let's discuss tests."),
_ai("Sure."),
_human("What about linting?"),
_ai("Use ruff."),
_human("And formatting?"),
_ai("Use make format."),
]
assert detect_correction(msgs) is False
def test_handles_list_content(self):
msgs = [
HumanMessage(content=["That is wrong,", {"type": "text", "text": "use make dev instead."}]),
_ai("Updated."),
]
assert detect_correction(msgs) is True
# ===========================================================================
# _strip_upload_mentions_from_memory
# ===========================================================================
+13 -2
View File
@@ -73,7 +73,7 @@ def _patch_factory(monkeypatch, app_config: AppConfig, model_class=FakeChatModel
"""Patch get_app_config, resolve_class, and tracing for isolated unit tests."""
monkeypatch.setattr(factory_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: model_class)
monkeypatch.setattr(factory_module, "is_tracing_enabled", lambda: False)
monkeypatch.setattr(factory_module, "build_tracing_callbacks", lambda: [])
# ---------------------------------------------------------------------------
@@ -95,12 +95,23 @@ def test_uses_first_model_when_name_is_none(monkeypatch):
def test_raises_when_model_not_found(monkeypatch):
cfg = _make_app_config([_make_model("only-model")])
monkeypatch.setattr(factory_module, "get_app_config", lambda: cfg)
monkeypatch.setattr(factory_module, "is_tracing_enabled", lambda: False)
monkeypatch.setattr(factory_module, "build_tracing_callbacks", lambda: [])
with pytest.raises(ValueError, match="ghost-model"):
factory_module.create_chat_model(name="ghost-model")
def test_appends_all_tracing_callbacks(monkeypatch):
cfg = _make_app_config([_make_model("alpha")])
_patch_factory(monkeypatch, cfg)
monkeypatch.setattr(factory_module, "build_tracing_callbacks", lambda: ["smith-callback", "langfuse-callback"])
FakeChatModel.captured_kwargs = {}
model = factory_module.create_chat_model(name="alpha")
assert model.callbacks == ["smith-callback", "langfuse-callback"]
# ---------------------------------------------------------------------------
# thinking_enabled=True
# ---------------------------------------------------------------------------
+393
View File
@@ -0,0 +1,393 @@
from types import SimpleNamespace
from unittest.mock import patch
from deerflow.community.aio_sandbox.aio_sandbox import AioSandbox
from deerflow.sandbox.local.local_sandbox import LocalSandbox
from deerflow.sandbox.search import GrepMatch, find_glob_matches, find_grep_matches
from deerflow.sandbox.tools import glob_tool, grep_tool
def _make_runtime(tmp_path):
workspace = tmp_path / "workspace"
uploads = tmp_path / "uploads"
outputs = tmp_path / "outputs"
workspace.mkdir()
uploads.mkdir()
outputs.mkdir()
return SimpleNamespace(
state={
"sandbox": {"sandbox_id": "local"},
"thread_data": {
"workspace_path": str(workspace),
"uploads_path": str(uploads),
"outputs_path": str(outputs),
},
},
context={"thread_id": "thread-1"},
)
def test_glob_tool_returns_virtual_paths_and_ignores_common_dirs(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
workspace = tmp_path / "workspace"
(workspace / "app.py").write_text("print('hi')\n", encoding="utf-8")
(workspace / "pkg").mkdir()
(workspace / "pkg" / "util.py").write_text("print('util')\n", encoding="utf-8")
(workspace / "node_modules").mkdir()
(workspace / "node_modules" / "skip.py").write_text("ignored\n", encoding="utf-8")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
result = glob_tool.func(
runtime=runtime,
description="find python files",
pattern="**/*.py",
path="/mnt/user-data/workspace",
)
assert "/mnt/user-data/workspace/app.py" in result
assert "/mnt/user-data/workspace/pkg/util.py" in result
assert "node_modules" not in result
assert str(workspace) not in result
def test_glob_tool_supports_skills_virtual_paths(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
skills_dir = tmp_path / "skills"
(skills_dir / "public" / "demo").mkdir(parents=True)
(skills_dir / "public" / "demo" / "SKILL.md").write_text("# Demo\n", encoding="utf-8")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
with (
patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"),
patch("deerflow.sandbox.tools._get_skills_host_path", return_value=str(skills_dir)),
):
result = glob_tool.func(
runtime=runtime,
description="find skills",
pattern="**/SKILL.md",
path="/mnt/skills",
)
assert "/mnt/skills/public/demo/SKILL.md" in result
assert str(skills_dir) not in result
def test_grep_tool_filters_by_glob_and_skips_binary_files(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
workspace = tmp_path / "workspace"
(workspace / "main.py").write_text("TODO = 'ship it'\nprint(TODO)\n", encoding="utf-8")
(workspace / "notes.txt").write_text("TODO in txt should be filtered\n", encoding="utf-8")
(workspace / "image.bin").write_bytes(b"\0binary TODO")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
result = grep_tool.func(
runtime=runtime,
description="find todo references",
pattern="TODO",
path="/mnt/user-data/workspace",
glob="**/*.py",
)
assert "/mnt/user-data/workspace/main.py:1: TODO = 'ship it'" in result
assert "notes.txt" not in result
assert "image.bin" not in result
assert str(workspace) not in result
def test_grep_tool_truncates_results(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
workspace = tmp_path / "workspace"
(workspace / "main.py").write_text("TODO one\nTODO two\nTODO three\n", encoding="utf-8")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
# Prevent config.yaml tool config from overriding the caller-supplied max_results=2.
monkeypatch.setattr("deerflow.sandbox.tools.get_app_config", lambda: SimpleNamespace(get_tool_config=lambda name: None))
result = grep_tool.func(
runtime=runtime,
description="limit matches",
pattern="TODO",
path="/mnt/user-data/workspace",
max_results=2,
)
assert "Found 2 matches under /mnt/user-data/workspace (showing first 2)" in result
assert "TODO one" in result
assert "TODO two" in result
assert "TODO three" not in result
assert "Results truncated." in result
def test_glob_tool_include_dirs_filters_nested_ignored_paths(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
workspace = tmp_path / "workspace"
(workspace / "src").mkdir()
(workspace / "src" / "main.py").write_text("x\n", encoding="utf-8")
(workspace / "node_modules").mkdir()
(workspace / "node_modules" / "lib").mkdir()
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
result = glob_tool.func(
runtime=runtime,
description="find dirs",
pattern="**",
path="/mnt/user-data/workspace",
include_dirs=True,
)
assert "src" in result
assert "node_modules" not in result
def test_grep_tool_literal_mode(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
workspace = tmp_path / "workspace"
(workspace / "file.py").write_text("price = (a+b)\nresult = a+b\n", encoding="utf-8")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
# literal=True should treat (a+b) as a plain string, not a regex group
result = grep_tool.func(
runtime=runtime,
description="literal search",
pattern="(a+b)",
path="/mnt/user-data/workspace",
literal=True,
)
assert "price = (a+b)" in result
assert "result = a+b" not in result
def test_grep_tool_case_sensitive(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
workspace = tmp_path / "workspace"
(workspace / "file.py").write_text("TODO: fix\ntodo: also fix\n", encoding="utf-8")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
result = grep_tool.func(
runtime=runtime,
description="case sensitive search",
pattern="TODO",
path="/mnt/user-data/workspace",
case_sensitive=True,
)
assert "TODO: fix" in result
assert "todo: also fix" not in result
def test_grep_tool_invalid_regex_returns_error(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
result = grep_tool.func(
runtime=runtime,
description="bad pattern",
pattern="[invalid",
path="/mnt/user-data/workspace",
)
assert "Invalid regex pattern" in result
def test_aio_sandbox_glob_include_dirs_filters_nested_ignored(monkeypatch) -> None:
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
monkeypatch.setattr(
sandbox._client.file,
"list_path",
lambda **kwargs: SimpleNamespace(
data=SimpleNamespace(
files=[
SimpleNamespace(name="src", path="/mnt/workspace/src"),
SimpleNamespace(name="node_modules", path="/mnt/workspace/node_modules"),
# child of node_modules — should be filtered via should_ignore_path
SimpleNamespace(name="lib", path="/mnt/workspace/node_modules/lib"),
]
)
),
)
matches, truncated = sandbox.glob("/mnt/workspace", "**", include_dirs=True)
assert "/mnt/workspace/src" in matches
assert "/mnt/workspace/node_modules" not in matches
assert "/mnt/workspace/node_modules/lib" not in matches
assert truncated is False
def test_aio_sandbox_grep_invalid_regex_raises() -> None:
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
import re
try:
sandbox.grep("/mnt/workspace", "[invalid")
assert False, "Expected re.error"
except re.error:
pass
def test_aio_sandbox_glob_parses_json(monkeypatch) -> None:
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
monkeypatch.setattr(
sandbox._client.file,
"find_files",
lambda **kwargs: SimpleNamespace(data=SimpleNamespace(files=["/mnt/user-data/workspace/app.py", "/mnt/user-data/workspace/node_modules/skip.py"])),
)
matches, truncated = sandbox.glob("/mnt/user-data/workspace", "**/*.py")
assert matches == ["/mnt/user-data/workspace/app.py"]
assert truncated is False
def test_aio_sandbox_grep_parses_json(monkeypatch) -> None:
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
monkeypatch.setattr(
sandbox._client.file,
"list_path",
lambda **kwargs: SimpleNamespace(
data=SimpleNamespace(
files=[
SimpleNamespace(
name="app.py",
path="/mnt/user-data/workspace/app.py",
is_directory=False,
)
]
)
),
)
monkeypatch.setattr(
sandbox._client.file,
"search_in_file",
lambda **kwargs: SimpleNamespace(data=SimpleNamespace(line_numbers=[7], matches=["TODO = True"])),
)
matches, truncated = sandbox.grep("/mnt/user-data/workspace", "TODO")
assert matches == [GrepMatch(path="/mnt/user-data/workspace/app.py", line_number=7, line="TODO = True")]
assert truncated is False
def test_find_glob_matches_raises_not_a_directory(tmp_path) -> None:
file_path = tmp_path / "file.txt"
file_path.write_text("x\n", encoding="utf-8")
try:
find_glob_matches(file_path, "**/*.py")
assert False, "Expected NotADirectoryError"
except NotADirectoryError:
pass
def test_find_grep_matches_raises_not_a_directory(tmp_path) -> None:
file_path = tmp_path / "file.txt"
file_path.write_text("TODO\n", encoding="utf-8")
try:
find_grep_matches(file_path, "TODO")
assert False, "Expected NotADirectoryError"
except NotADirectoryError:
pass
def test_find_grep_matches_skips_symlink_outside_root(tmp_path) -> None:
workspace = tmp_path / "workspace"
workspace.mkdir()
outside = tmp_path / "outside.txt"
outside.write_text("TODO outside\n", encoding="utf-8")
(workspace / "outside-link.txt").symlink_to(outside)
matches, truncated = find_grep_matches(workspace, "TODO")
assert matches == []
assert truncated is False
def test_glob_tool_honors_smaller_requested_max_results(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
workspace = tmp_path / "workspace"
(workspace / "a.py").write_text("print('a')\n", encoding="utf-8")
(workspace / "b.py").write_text("print('b')\n", encoding="utf-8")
(workspace / "c.py").write_text("print('c')\n", encoding="utf-8")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
monkeypatch.setattr(
"deerflow.sandbox.tools.get_app_config",
lambda: SimpleNamespace(get_tool_config=lambda name: SimpleNamespace(model_extra={"max_results": 50})),
)
result = glob_tool.func(
runtime=runtime,
description="limit glob matches",
pattern="**/*.py",
path="/mnt/user-data/workspace",
max_results=2,
)
assert "Found 2 paths under /mnt/user-data/workspace (showing first 2)" in result
assert "Results truncated." in result
def test_aio_sandbox_glob_include_dirs_enforces_root_boundary(monkeypatch) -> None:
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
monkeypatch.setattr(
sandbox._client.file,
"list_path",
lambda **kwargs: SimpleNamespace(
data=SimpleNamespace(
files=[
SimpleNamespace(name="src", path="/mnt/workspace/src"),
SimpleNamespace(name="src2", path="/mnt/workspace2/src2"),
]
)
),
)
matches, truncated = sandbox.glob("/mnt/workspace", "**", include_dirs=True)
assert matches == ["/mnt/workspace/src"]
assert truncated is False
def test_aio_sandbox_grep_skips_mismatched_line_number_payloads(monkeypatch) -> None:
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
monkeypatch.setattr(
sandbox._client.file,
"list_path",
lambda **kwargs: SimpleNamespace(
data=SimpleNamespace(
files=[
SimpleNamespace(
name="app.py",
path="/mnt/user-data/workspace/app.py",
is_directory=False,
)
]
)
),
)
monkeypatch.setattr(
sandbox._client.file,
"search_in_file",
lambda **kwargs: SimpleNamespace(data=SimpleNamespace(line_numbers=[7], matches=["TODO = True", "extra"])),
)
matches, truncated = sandbox.grep("/mnt/user-data/workspace", "TODO")
assert matches == [GrepMatch(path="/mnt/user-data/workspace/app.py", line_number=7, line="TODO = True")]
assert truncated is False
@@ -1,3 +1,4 @@
import threading
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import patch
@@ -7,7 +8,10 @@ import pytest
from deerflow.sandbox.tools import (
VIRTUAL_PATH_PREFIX,
_apply_cwd_prefix,
_get_custom_mount_for_path,
_get_custom_mounts,
_is_acp_workspace_path,
_is_custom_mount_path,
_is_skills_path,
_reject_path_traversal,
_resolve_acp_workspace_path,
@@ -17,8 +21,10 @@ from deerflow.sandbox.tools import (
mask_local_paths_in_output,
replace_virtual_path,
replace_virtual_paths_in_command,
str_replace_tool,
validate_local_bash_command_paths,
validate_local_tool_path,
write_file_tool,
)
_THREAD_DATA = {
@@ -93,6 +99,25 @@ def test_validate_local_tool_path_rejects_non_virtual_path() -> None:
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA)
def test_validate_local_tool_path_rejects_non_virtual_path_mentions_configured_mounts() -> None:
with pytest.raises(PermissionError, match="configured mount paths"):
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA)
def test_validate_local_tool_path_prioritizes_user_data_before_custom_mounts() -> None:
from deerflow.config.sandbox_config import VolumeMountConfig
mounts = [
VolumeMountConfig(host_path="/tmp/host-user-data", container_path=VIRTUAL_PATH_PREFIX, read_only=False),
]
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, read_only=True)
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
with pytest.raises(PermissionError, match="path traversal"):
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA, read_only=True)
def test_validate_local_tool_path_rejects_bare_virtual_root() -> None:
"""The bare /mnt/user-data root without trailing slash is not a valid sub-path."""
with pytest.raises(PermissionError, match="Only paths under"):
@@ -321,6 +346,56 @@ def test_validate_local_bash_command_paths_allows_skills_path() -> None:
)
def test_validate_local_bash_command_paths_allows_urls() -> None:
"""URLs in bash commands should not be mistaken for absolute paths (issue #1385)."""
# HTTPS URLs
validate_local_bash_command_paths(
"curl -X POST https://example.com/api/v1/risk/check",
_THREAD_DATA,
)
# HTTP URLs
validate_local_bash_command_paths(
"curl http://localhost:8080/health",
_THREAD_DATA,
)
# URLs with query strings
validate_local_bash_command_paths(
"curl https://api.example.com/v2/search?q=test",
_THREAD_DATA,
)
# FTP URLs
validate_local_bash_command_paths(
"curl ftp://ftp.example.com/pub/file.tar.gz",
_THREAD_DATA,
)
# URL mixed with valid virtual path
validate_local_bash_command_paths(
"curl https://example.com/data -o /mnt/user-data/workspace/data.json",
_THREAD_DATA,
)
def test_validate_local_bash_command_paths_blocks_file_urls() -> None:
"""file:// URLs should be treated as unsafe and blocked."""
with pytest.raises(PermissionError):
validate_local_bash_command_paths("curl file:///etc/passwd", _THREAD_DATA)
def test_validate_local_bash_command_paths_blocks_file_urls_case_insensitive() -> None:
"""file:// URL detection should be case-insensitive."""
with pytest.raises(PermissionError):
validate_local_bash_command_paths("curl FILE:///etc/shadow", _THREAD_DATA)
def test_validate_local_bash_command_paths_blocks_file_urls_mixed_with_valid() -> None:
"""file:// URLs should be blocked even when mixed with valid paths."""
with pytest.raises(PermissionError):
validate_local_bash_command_paths(
"curl file:///etc/passwd -o /mnt/user-data/workspace/out.txt",
_THREAD_DATA,
)
def test_validate_local_bash_command_paths_still_blocks_other_paths() -> None:
"""Paths outside virtual and system prefixes must still be blocked."""
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
@@ -512,3 +587,371 @@ def test_validate_local_bash_command_paths_allows_mcp_filesystem_paths() -> None
with patch("deerflow.config.extensions_config.get_extensions_config", return_value=disabled_config):
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)
# ---------- Custom mount path tests ----------
def _mock_custom_mounts():
"""Create mock VolumeMountConfig objects for testing."""
from deerflow.config.sandbox_config import VolumeMountConfig
return [
VolumeMountConfig(host_path="/home/user/code-read", container_path="/mnt/code-read", read_only=True),
VolumeMountConfig(host_path="/home/user/data", container_path="/mnt/data", read_only=False),
]
def test_is_custom_mount_path_recognises_configured_mounts() -> None:
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
assert _is_custom_mount_path("/mnt/code-read") is True
assert _is_custom_mount_path("/mnt/code-read/src/main.py") is True
assert _is_custom_mount_path("/mnt/data") is True
assert _is_custom_mount_path("/mnt/data/file.txt") is True
assert _is_custom_mount_path("/mnt/code-read-extra/foo") is False
assert _is_custom_mount_path("/mnt/other") is False
def test_get_custom_mount_for_path_returns_longest_prefix() -> None:
from deerflow.config.sandbox_config import VolumeMountConfig
mounts = [
VolumeMountConfig(host_path="/var/mnt", container_path="/mnt", read_only=False),
VolumeMountConfig(host_path="/home/user/code", container_path="/mnt/code", read_only=True),
]
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
mount = _get_custom_mount_for_path("/mnt/code/file.py")
assert mount is not None
assert mount.container_path == "/mnt/code"
def test_validate_local_tool_path_allows_custom_mount_read() -> None:
"""read_file / ls should be able to access custom mount paths."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, read_only=True)
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, read_only=True)
def test_validate_local_tool_path_blocks_read_only_mount_write() -> None:
"""write_file / str_replace must NOT write to read-only custom mounts."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
with pytest.raises(PermissionError, match="Write access to read-only mount is not allowed"):
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, read_only=False)
def test_validate_local_tool_path_allows_writable_mount_write() -> None:
"""write_file / str_replace should succeed on writable custom mounts."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, read_only=False)
def test_validate_local_tool_path_blocks_traversal_in_custom_mount() -> None:
"""Path traversal via .. in custom mount paths must be rejected."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
with pytest.raises(PermissionError, match="path traversal"):
validate_local_tool_path("/mnt/code-read/../../etc/passwd", _THREAD_DATA, read_only=True)
def test_validate_local_bash_command_paths_allows_custom_mount() -> None:
"""bash commands referencing custom mount paths should be allowed."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
validate_local_bash_command_paths("cat /mnt/code-read/src/main.py", _THREAD_DATA)
validate_local_bash_command_paths("ls /mnt/data", _THREAD_DATA)
def test_validate_local_bash_command_paths_blocks_traversal_in_custom_mount() -> None:
"""Bash commands with traversal in custom mount paths should be blocked."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
with pytest.raises(PermissionError, match="path traversal"):
validate_local_bash_command_paths("cat /mnt/code-read/../../etc/passwd", _THREAD_DATA)
def test_validate_local_bash_command_paths_still_blocks_non_mount_paths() -> None:
"""Paths not matching any custom mount should still be blocked."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA)
def test_get_custom_mounts_caching(monkeypatch, tmp_path) -> None:
"""_get_custom_mounts should cache after first successful load."""
# Clear any existing cache
if hasattr(_get_custom_mounts, "_cached"):
monkeypatch.delattr(_get_custom_mounts, "_cached")
# Use real directories so host_path.exists() filtering passes
dir_a = tmp_path / "code-read"
dir_a.mkdir()
dir_b = tmp_path / "data"
dir_b.mkdir()
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
mounts = [
VolumeMountConfig(host_path=str(dir_a), container_path="/mnt/code-read", read_only=True),
VolumeMountConfig(host_path=str(dir_b), container_path="/mnt/data", read_only=False),
]
mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
mock_config = SimpleNamespace(sandbox=mock_sandbox)
with patch("deerflow.config.get_app_config", return_value=mock_config):
result = _get_custom_mounts()
assert len(result) == 2
# After caching, should return cached value even without mock
assert hasattr(_get_custom_mounts, "_cached")
assert len(_get_custom_mounts()) == 2
# Cleanup
monkeypatch.delattr(_get_custom_mounts, "_cached")
def test_get_custom_mounts_filters_nonexistent_host_path(monkeypatch, tmp_path) -> None:
"""_get_custom_mounts should only return mounts whose host_path exists."""
if hasattr(_get_custom_mounts, "_cached"):
monkeypatch.delattr(_get_custom_mounts, "_cached")
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
existing_dir = tmp_path / "existing"
existing_dir.mkdir()
mounts = [
VolumeMountConfig(host_path=str(existing_dir), container_path="/mnt/existing", read_only=True),
VolumeMountConfig(host_path="/nonexistent/path/12345", container_path="/mnt/ghost", read_only=False),
]
mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
mock_config = SimpleNamespace(sandbox=mock_sandbox)
with patch("deerflow.config.get_app_config", return_value=mock_config):
result = _get_custom_mounts()
assert len(result) == 1
assert result[0].container_path == "/mnt/existing"
# Cleanup
monkeypatch.delattr(_get_custom_mounts, "_cached")
def test_get_custom_mount_for_path_boundary_no_false_prefix_match() -> None:
"""_get_custom_mount_for_path must not match /mnt/code-read-extra for /mnt/code-read."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
mount = _get_custom_mount_for_path("/mnt/code-read-extra/foo")
assert mount is None
def test_str_replace_parallel_updates_should_preserve_both_edits(monkeypatch) -> None:
class SharedSandbox:
def __init__(self) -> None:
self.content = "alpha\nbeta\n"
self._active_reads = 0
self._state_lock = threading.Lock()
self._overlap_detected = threading.Event()
def read_file(self, path: str) -> str:
with self._state_lock:
self._active_reads += 1
snapshot = self.content
if self._active_reads == 2:
self._overlap_detected.set()
self._overlap_detected.wait(0.05)
with self._state_lock:
self._active_reads -= 1
return snapshot
def write_file(self, path: str, content: str, append: bool = False) -> None:
self.content = content
sandbox = SharedSandbox()
runtimes = [
SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
]
failures: list[BaseException] = []
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
def worker(runtime: SimpleNamespace, old_str: str, new_str: str) -> None:
try:
result = str_replace_tool.func(
runtime=runtime,
description="并发替换同一文件",
path="/mnt/user-data/workspace/shared.txt",
old_str=old_str,
new_str=new_str,
)
assert result == "OK"
except BaseException as exc: # pragma: no cover - failure is asserted below
failures.append(exc)
threads = [
threading.Thread(target=worker, args=(runtimes[0], "alpha", "ALPHA")),
threading.Thread(target=worker, args=(runtimes[1], "beta", "BETA")),
]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
assert failures == []
assert "ALPHA" in sandbox.content
assert "BETA" in sandbox.content
def test_str_replace_parallel_updates_in_isolated_sandboxes_should_not_share_path_lock(monkeypatch) -> None:
class IsolatedSandbox:
def __init__(self, sandbox_id: str, shared_state: dict[str, object]) -> None:
self.id = sandbox_id
self.content = "alpha\nbeta\n"
self._shared_state = shared_state
def read_file(self, path: str) -> str:
state_lock = self._shared_state["state_lock"]
with state_lock:
active_reads = self._shared_state["active_reads"]
self._shared_state["active_reads"] = active_reads + 1
snapshot = self.content
if self._shared_state["active_reads"] == 2:
overlap_detected = self._shared_state["overlap_detected"]
overlap_detected.set()
overlap_detected = self._shared_state["overlap_detected"]
overlap_detected.wait(0.05)
with state_lock:
active_reads = self._shared_state["active_reads"]
self._shared_state["active_reads"] = active_reads - 1
return snapshot
def write_file(self, path: str, content: str, append: bool = False) -> None:
self.content = content
shared_state: dict[str, object] = {
"active_reads": 0,
"state_lock": threading.Lock(),
"overlap_detected": threading.Event(),
}
sandboxes = {
"sandbox-a": IsolatedSandbox("sandbox-a", shared_state),
"sandbox-b": IsolatedSandbox("sandbox-b", shared_state),
}
runtimes = [
SimpleNamespace(state={}, context={"thread_id": "thread-1", "sandbox_key": "sandbox-a"}, config={}),
SimpleNamespace(state={}, context={"thread_id": "thread-2", "sandbox_key": "sandbox-b"}, config={}),
]
failures: list[BaseException] = []
monkeypatch.setattr(
"deerflow.sandbox.tools.ensure_sandbox_initialized",
lambda runtime: sandboxes[runtime.context["sandbox_key"]],
)
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
def worker(runtime: SimpleNamespace, old_str: str, new_str: str) -> None:
try:
result = str_replace_tool.func(
runtime=runtime,
description="隔离 sandbox 并发替换同一路径",
path="/mnt/user-data/workspace/shared.txt",
old_str=old_str,
new_str=new_str,
)
assert result == "OK"
except BaseException as exc: # pragma: no cover - failure is asserted below
failures.append(exc)
threads = [
threading.Thread(target=worker, args=(runtimes[0], "alpha", "ALPHA")),
threading.Thread(target=worker, args=(runtimes[1], "beta", "BETA")),
]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
assert failures == []
assert sandboxes["sandbox-a"].content == "ALPHA\nbeta\n"
assert sandboxes["sandbox-b"].content == "alpha\nBETA\n"
assert shared_state["overlap_detected"].is_set()
def test_str_replace_and_append_on_same_path_should_preserve_both_updates(monkeypatch) -> None:
class SharedSandbox:
def __init__(self) -> None:
self.id = "sandbox-1"
self.content = "alpha\n"
self.state_lock = threading.Lock()
self.str_replace_has_snapshot = threading.Event()
self.append_finished = threading.Event()
def read_file(self, path: str) -> str:
with self.state_lock:
snapshot = self.content
self.str_replace_has_snapshot.set()
self.append_finished.wait(0.05)
return snapshot
def write_file(self, path: str, content: str, append: bool = False) -> None:
with self.state_lock:
if append:
self.content += content
self.append_finished.set()
else:
self.content = content
sandbox = SharedSandbox()
runtimes = [
SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
]
failures: list[BaseException] = []
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
def replace_worker() -> None:
try:
result = str_replace_tool.func(
runtime=runtimes[0],
description="替换旧内容",
path="/mnt/user-data/workspace/shared.txt",
old_str="alpha",
new_str="ALPHA",
)
assert result == "OK"
except BaseException as exc: # pragma: no cover - failure is asserted below
failures.append(exc)
def append_worker() -> None:
try:
sandbox.str_replace_has_snapshot.wait(0.05)
result = write_file_tool.func(
runtime=runtimes[1],
description="追加新内容",
path="/mnt/user-data/workspace/shared.txt",
content="tail\n",
append=True,
)
assert result == "OK"
except BaseException as exc: # pragma: no cover - failure is asserted below
failures.append(exc)
replace_thread = threading.Thread(target=replace_worker)
append_thread = threading.Thread(target=append_worker)
replace_thread.start()
append_thread.start()
replace_thread.join()
append_thread.join()
assert failures == []
assert sandbox.content == "ALPHA\ntail\n"
+21
View File
@@ -93,6 +93,27 @@ class TestParseSkillFile:
assert result is not None
assert result.description == "A skill: does things"
def test_multiline_yaml_folded_description(self, tmp_path):
skill_file = _write_skill(
tmp_path,
"---\nname: multiline-skill\ndescription: >\n This is a multiline\n description for a skill.\n\n It spans multiple lines.\nlicense: MIT\n---\n\nBody\n",
)
result = parse_skill_file(skill_file, "public")
assert result is not None
assert result.name == "multiline-skill"
assert result.description == "This is a multiline description for a skill.\n\nIt spans multiple lines."
assert result.license == "MIT"
def test_multiline_yaml_literal_description(self, tmp_path):
skill_file = _write_skill(
tmp_path,
"---\nname: pipe-skill\ndescription: |\n First line.\n Second line.\n---\n\nBody\n",
)
result = parse_skill_file(skill_file, "public")
assert result is not None
assert result.name == "pipe-skill"
assert result.description == "First line.\nSecond line."
def test_empty_front_matter_returns_none(self, tmp_path):
skill_file = _write_skill(tmp_path, "---\n\n---\n\nBody\n")
assert parse_skill_file(skill_file, "public") is None
+187
View File
@@ -140,6 +140,193 @@ async def test_event_id_format(bridge: MemoryStreamBridge):
assert re.match(r"^\d+-\d+$", event.id), f"Expected timestamp-seq format, got {event.id}"
# ---------------------------------------------------------------------------
# END sentinel guarantee tests
# ---------------------------------------------------------------------------
@pytest.mark.anyio
async def test_end_sentinel_delivered_when_queue_full():
"""END sentinel must always be delivered, even when the queue is completely full.
This is the critical regression test for the bug where publish_end()
would silently drop the END sentinel when the queue was full, causing
subscribe() to hang forever and leaking resources.
"""
bridge = MemoryStreamBridge(queue_maxsize=2)
run_id = "run-end-full"
# Fill the queue to capacity
await bridge.publish(run_id, "event-1", {"n": 1})
await bridge.publish(run_id, "event-2", {"n": 2})
assert bridge._queues[run_id].full()
# publish_end should succeed by evicting old events
await bridge.publish_end(run_id)
# Subscriber must receive END_SENTINEL
events = []
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
events.append(entry)
if entry is END_SENTINEL:
break
assert any(e is END_SENTINEL for e in events), "END sentinel was not delivered"
@pytest.mark.anyio
async def test_end_sentinel_evicts_oldest_events():
"""When queue is full, publish_end evicts the oldest events to make room."""
bridge = MemoryStreamBridge(queue_maxsize=1)
run_id = "run-evict"
# Fill queue with one event
await bridge.publish(run_id, "will-be-evicted", {})
assert bridge._queues[run_id].full()
# publish_end must succeed
await bridge.publish_end(run_id)
# The only event we should get is END_SENTINEL (the regular event was evicted)
events = []
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
events.append(entry)
if entry is END_SENTINEL:
break
assert len(events) == 1
assert events[0] is END_SENTINEL
@pytest.mark.anyio
async def test_end_sentinel_no_eviction_when_space_available():
"""When queue has space, publish_end should not evict anything."""
bridge = MemoryStreamBridge(queue_maxsize=10)
run_id = "run-no-evict"
await bridge.publish(run_id, "event-1", {"n": 1})
await bridge.publish(run_id, "event-2", {"n": 2})
await bridge.publish_end(run_id)
events = []
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
events.append(entry)
if entry is END_SENTINEL:
break
# All events plus END should be present
assert len(events) == 3
assert events[0].event == "event-1"
assert events[1].event == "event-2"
assert events[2] is END_SENTINEL
@pytest.mark.anyio
async def test_concurrent_tasks_end_sentinel():
"""Multiple concurrent producer/consumer pairs should all terminate properly.
Simulates the production scenario where multiple runs share a single
bridge instance — each must receive its own END sentinel.
"""
bridge = MemoryStreamBridge(queue_maxsize=4)
num_runs = 4
async def producer(run_id: str):
for i in range(10): # More events than queue capacity
await bridge.publish(run_id, f"event-{i}", {"i": i})
await bridge.publish_end(run_id)
async def consumer(run_id: str) -> list:
events = []
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
events.append(entry)
if entry is END_SENTINEL:
return events
return events # pragma: no cover
# Run producers and consumers concurrently
run_ids = [f"concurrent-{i}" for i in range(num_runs)]
producers = [producer(rid) for rid in run_ids]
consumers = [consumer(rid) for rid in run_ids]
# Start consumers first, then producers
consumer_tasks = [asyncio.create_task(c) for c in consumers]
await asyncio.gather(*producers)
results = await asyncio.wait_for(
asyncio.gather(*consumer_tasks),
timeout=10.0,
)
for i, events in enumerate(results):
assert events[-1] is END_SENTINEL, f"Run {run_ids[i]} did not receive END sentinel"
# ---------------------------------------------------------------------------
# Drop counter tests
# ---------------------------------------------------------------------------
@pytest.mark.anyio
async def test_dropped_count_tracking():
"""Dropped events should be tracked per run_id."""
bridge = MemoryStreamBridge(queue_maxsize=1)
run_id = "run-drop-count"
# Fill the queue
await bridge.publish(run_id, "first", {})
# This publish will time out and be dropped (we patch timeout to be instant)
# Instead, we verify the counter after publish_end eviction
await bridge.publish_end(run_id)
# dropped_count tracks publish() drops, not publish_end evictions
assert bridge.dropped_count(run_id) == 0
# cleanup should also clear the counter
await bridge.cleanup(run_id)
assert bridge.dropped_count(run_id) == 0
@pytest.mark.anyio
async def test_dropped_total():
"""dropped_total should sum across all runs."""
bridge = MemoryStreamBridge(queue_maxsize=256)
# No drops yet
assert bridge.dropped_total == 0
# Manually set some counts to verify the property
bridge._dropped_counts["run-a"] = 3
bridge._dropped_counts["run-b"] = 7
assert bridge.dropped_total == 10
@pytest.mark.anyio
async def test_cleanup_clears_dropped_counts():
"""cleanup() should clear the dropped counter for the run."""
bridge = MemoryStreamBridge(queue_maxsize=256)
run_id = "run-cleanup-drops"
bridge._get_or_create_queue(run_id)
bridge._dropped_counts[run_id] = 5
await bridge.cleanup(run_id)
assert run_id not in bridge._dropped_counts
@pytest.mark.anyio
async def test_close_clears_dropped_counts():
"""close() should clear all dropped counters."""
bridge = MemoryStreamBridge(queue_maxsize=256)
bridge._dropped_counts["run-x"] = 10
bridge._dropped_counts["run-y"] = 20
await bridge.close()
assert bridge.dropped_total == 0
assert len(bridge._dropped_counts) == 0
# ---------------------------------------------------------------------------
# Factory tests
# ---------------------------------------------------------------------------
+5 -5
View File
@@ -1,5 +1,5 @@
import asyncio
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock
from app.gateway.routers import suggestions
@@ -43,7 +43,7 @@ def test_generate_suggestions_parses_and_limits(monkeypatch):
model_name=None,
)
fake_model = MagicMock()
fake_model.invoke.return_value = MagicMock(content='```json\n["Q1", "Q2", "Q3", "Q4"]\n```')
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content='```json\n["Q1", "Q2", "Q3", "Q4"]\n```'))
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
result = asyncio.run(suggestions.generate_suggestions("t1", req))
@@ -61,7 +61,7 @@ def test_generate_suggestions_parses_list_block_content(monkeypatch):
model_name=None,
)
fake_model = MagicMock()
fake_model.invoke.return_value = MagicMock(content=[{"type": "text", "text": '```json\n["Q1", "Q2"]\n```'}])
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content=[{"type": "text", "text": '```json\n["Q1", "Q2"]\n```'}]))
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
result = asyncio.run(suggestions.generate_suggestions("t1", req))
@@ -79,7 +79,7 @@ def test_generate_suggestions_parses_output_text_block_content(monkeypatch):
model_name=None,
)
fake_model = MagicMock()
fake_model.invoke.return_value = MagicMock(content=[{"type": "output_text", "text": '```json\n["Q1", "Q2"]\n```'}])
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content=[{"type": "output_text", "text": '```json\n["Q1", "Q2"]\n```'}]))
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
result = asyncio.run(suggestions.generate_suggestions("t1", req))
@@ -94,7 +94,7 @@ def test_generate_suggestions_returns_empty_on_model_error(monkeypatch):
model_name=None,
)
fake_model = MagicMock()
fake_model.invoke.side_effect = RuntimeError("boom")
fake_model.ainvoke = AsyncMock(side_effect=RuntimeError("boom"))
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
result = asyncio.run(suggestions.generate_suggestions("t1", req))
@@ -5,6 +5,7 @@ from unittest.mock import AsyncMock, MagicMock
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.agents.middlewares import title_middleware as title_middleware_module
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
from deerflow.config.title_config import TitleConfig, get_title_config, set_title_config
@@ -73,37 +74,32 @@ class TestTitleMiddlewareCoreLogic:
assert middleware._should_generate_title(state) is False
def test_generate_title_trims_quotes_and_respects_max_chars(self, monkeypatch):
def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch):
_set_test_title_config(max_chars=12)
middleware = TitleMiddleware()
fake_model = MagicMock()
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content='"A very long generated title"'))
monkeypatch.setattr("deerflow.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
model = MagicMock()
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
state = {
"messages": [
HumanMessage(content="请帮我写一个脚本"),
HumanMessage(content="请帮我写一个很长很长的脚本标题"),
AIMessage(content="好的,先确认需求"),
]
}
result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
assert '"' not in title
assert "'" not in title
assert len(title) == 12
assert title == "短标题"
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False)
model.ainvoke.assert_awaited_once()
def test_generate_title_normalizes_structured_message_and_response_content(self, monkeypatch):
def test_generate_title_normalizes_structured_message_content(self, monkeypatch):
_set_test_title_config(max_chars=20)
middleware = TitleMiddleware()
fake_model = MagicMock()
fake_model.ainvoke = AsyncMock(
return_value=MagicMock(content=[{"type": "text", "text": '"结构总结"'}]),
)
monkeypatch.setattr(
"deerflow.agents.middlewares.title_middleware.create_chat_model",
lambda **kwargs: fake_model,
)
model = MagicMock()
model.ainvoke = AsyncMock(return_value=AIMessage(content="请帮我总结这段代码"))
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
state = {
"messages": [
@@ -115,21 +111,14 @@ class TestTitleMiddlewareCoreLogic:
result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
prompt = fake_model.ainvoke.await_args.args[0]
assert "请帮我总结这段代码" in prompt
assert "好的,先看结构" in prompt
# Ensure structured message dict/JSON reprs are not leaking into the prompt.
assert "{'type':" not in prompt
assert "'type':" not in prompt
assert '"type":' not in prompt
assert title == "结构总结"
assert title == "请帮我总结这段代码"
def test_generate_title_fallback_when_model_fails(self, monkeypatch):
def test_generate_title_fallback_for_long_message(self, monkeypatch):
_set_test_title_config(max_chars=20)
middleware = TitleMiddleware()
fake_model = MagicMock()
fake_model.ainvoke = AsyncMock(side_effect=RuntimeError("LLM unavailable"))
monkeypatch.setattr("deerflow.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
model = MagicMock()
model.ainvoke = AsyncMock(side_effect=RuntimeError("model unavailable"))
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
state = {
"messages": [
@@ -164,13 +153,10 @@ class TestTitleMiddlewareCoreLogic:
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value=None))
assert middleware.after_model({"messages": []}, runtime=MagicMock()) is None
def test_sync_generate_title_with_model(self, monkeypatch):
"""Sync path calls model.invoke and produces a title."""
def test_sync_generate_title_uses_fallback_without_model(self):
"""Sync path avoids LLM calls and derives a local fallback title."""
_set_test_title_config(max_chars=20)
middleware = TitleMiddleware()
fake_model = MagicMock()
fake_model.invoke = MagicMock(return_value=MagicMock(content='"同步生成的标题"'))
monkeypatch.setattr("deerflow.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
state = {
"messages": [
@@ -179,22 +165,19 @@ class TestTitleMiddlewareCoreLogic:
]
}
result = middleware._generate_title_result(state)
assert result == {"title": "同步生成的标题"}
fake_model.invoke.assert_called_once()
assert result == {"title": "请帮我写测试"}
def test_empty_title_falls_back(self, monkeypatch):
"""Empty model response triggers fallback title."""
def test_sync_generate_title_respects_fallback_truncation(self):
"""Sync fallback path still respects max_chars truncation rules."""
_set_test_title_config(max_chars=50)
middleware = TitleMiddleware()
fake_model = MagicMock()
fake_model.invoke = MagicMock(return_value=MagicMock(content=" "))
monkeypatch.setattr("deerflow.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
state = {
"messages": [
HumanMessage(content="空标题测试"),
HumanMessage(content="这是一个非常长的问题描述,需要被截断以形成fallback标题,而且这里继续补充更多上下文,确保超过本地fallback截断阈值"),
AIMessage(content="回复"),
]
}
result = middleware._generate_title_result(state)
assert result["title"] == "空标题测试"
assert result["title"].endswith("...")
assert result["title"].startswith("这是一个非常长的问题描述")
@@ -0,0 +1,161 @@
"""Unit tests for tool output truncation functions.
These functions truncate long tool outputs to prevent context window overflow.
- _truncate_bash_output: middle-truncation (head + tail), for bash tool
- _truncate_read_file_output: head-truncation, for read_file tool
"""
from deerflow.sandbox.tools import _truncate_bash_output, _truncate_read_file_output
# ---------------------------------------------------------------------------
# _truncate_bash_output
# ---------------------------------------------------------------------------
class TestTruncateBashOutput:
def test_short_output_returned_unchanged(self):
output = "hello world"
assert _truncate_bash_output(output, 20000) == output
def test_output_equal_to_limit_returned_unchanged(self):
output = "A" * 20000
assert _truncate_bash_output(output, 20000) == output
def test_long_output_is_truncated(self):
output = "A" * 30000
result = _truncate_bash_output(output, 20000)
assert len(result) < len(output)
def test_result_never_exceeds_max_chars(self):
output = "A" * 30000
max_chars = 20000
result = _truncate_bash_output(output, max_chars)
assert len(result) <= max_chars
def test_head_is_preserved(self):
head = "HEAD_CONTENT"
output = head + "M" * 30000
result = _truncate_bash_output(output, 20000)
assert result.startswith(head)
def test_tail_is_preserved(self):
tail = "TAIL_CONTENT"
output = "M" * 30000 + tail
result = _truncate_bash_output(output, 20000)
assert result.endswith(tail)
def test_middle_truncation_marker_present(self):
output = "A" * 30000
result = _truncate_bash_output(output, 20000)
assert "[middle truncated:" in result
assert "chars skipped" in result
def test_skipped_chars_count_is_correct(self):
output = "A" * 25000
result = _truncate_bash_output(output, 20000)
# Extract the reported skipped count and verify it equals len(output) - kept.
# (kept = max_chars - marker_max_len, where marker_max_len is computed from
# the worst-case marker string — so the exact value is implementation-defined,
# but it must equal len(output) minus the chars actually preserved.)
import re
m = re.search(r"(\d+) chars skipped", result)
assert m is not None
reported_skipped = int(m.group(1))
# Verify the number is self-consistent: head + skipped + tail == total
assert reported_skipped > 0
# The marker reports exactly the chars between head and tail
head_and_tail = len(output) - reported_skipped
assert result.startswith(output[: head_and_tail // 2])
def test_max_chars_zero_disables_truncation(self):
output = "A" * 100000
assert _truncate_bash_output(output, 0) == output
def test_50_50_split(self):
# head and tail should each be roughly max_chars // 2
output = "H" * 20000 + "M" * 10000 + "T" * 20000
result = _truncate_bash_output(output, 20000)
assert result[:100] == "H" * 100
assert result[-100:] == "T" * 100
def test_small_max_chars_does_not_crash(self):
output = "A" * 1000
result = _truncate_bash_output(output, 10)
assert len(result) <= 10
def test_result_never_exceeds_max_chars_various_sizes(self):
output = "X" * 50000
for max_chars in [100, 1000, 5000, 20000, 49999]:
result = _truncate_bash_output(output, max_chars)
assert len(result) <= max_chars, f"failed for max_chars={max_chars}"
# ---------------------------------------------------------------------------
# _truncate_read_file_output
# ---------------------------------------------------------------------------
class TestTruncateReadFileOutput:
def test_short_output_returned_unchanged(self):
output = "def foo():\n pass\n"
assert _truncate_read_file_output(output, 50000) == output
def test_output_equal_to_limit_returned_unchanged(self):
output = "X" * 50000
assert _truncate_read_file_output(output, 50000) == output
def test_long_output_is_truncated(self):
output = "X" * 60000
result = _truncate_read_file_output(output, 50000)
assert len(result) < len(output)
def test_result_never_exceeds_max_chars(self):
output = "X" * 60000
max_chars = 50000
result = _truncate_read_file_output(output, max_chars)
assert len(result) <= max_chars
def test_head_is_preserved(self):
head = "import os\nimport sys\n"
output = head + "X" * 60000
result = _truncate_read_file_output(output, 50000)
assert result.startswith(head)
def test_truncation_marker_present(self):
output = "X" * 60000
result = _truncate_read_file_output(output, 50000)
assert "[truncated:" in result
assert "showing first" in result
def test_total_chars_reported_correctly(self):
output = "X" * 60000
result = _truncate_read_file_output(output, 50000)
assert "of 60000 chars" in result
def test_start_line_hint_present(self):
output = "X" * 60000
result = _truncate_read_file_output(output, 50000)
assert "start_line" in result
assert "end_line" in result
def test_max_chars_zero_disables_truncation(self):
output = "X" * 100000
assert _truncate_read_file_output(output, 0) == output
def test_tail_is_not_preserved(self):
# head-truncation: tail should be cut off
output = "H" * 50000 + "TAIL_SHOULD_NOT_APPEAR"
result = _truncate_read_file_output(output, 50000)
assert "TAIL_SHOULD_NOT_APPEAR" not in result
def test_small_max_chars_does_not_crash(self):
output = "X" * 1000
result = _truncate_read_file_output(output, 10)
assert len(result) <= 10
def test_result_never_exceeds_max_chars_various_sizes(self):
output = "X" * 50000
for max_chars in [100, 1000, 5000, 20000, 49999]:
result = _truncate_read_file_output(output, max_chars)
assert len(result) <= max_chars, f"failed for max_chars={max_chars}"
+85 -10
View File
@@ -2,6 +2,8 @@
from __future__ import annotations
import pytest
from deerflow.config import tracing_config as tracing_module
@@ -9,6 +11,29 @@ def _reset_tracing_cache() -> None:
tracing_module._tracing_config = None
@pytest.fixture(autouse=True)
def clear_tracing_env(monkeypatch):
for name in (
"LANGSMITH_TRACING",
"LANGCHAIN_TRACING_V2",
"LANGCHAIN_TRACING",
"LANGSMITH_API_KEY",
"LANGCHAIN_API_KEY",
"LANGSMITH_PROJECT",
"LANGCHAIN_PROJECT",
"LANGSMITH_ENDPOINT",
"LANGCHAIN_ENDPOINT",
"LANGFUSE_TRACING",
"LANGFUSE_PUBLIC_KEY",
"LANGFUSE_SECRET_KEY",
"LANGFUSE_BASE_URL",
):
monkeypatch.delenv(name, raising=False)
_reset_tracing_cache()
yield
_reset_tracing_cache()
def test_prefers_langsmith_env_names(monkeypatch):
monkeypatch.setenv("LANGSMITH_TRACING", "true")
monkeypatch.setenv("LANGSMITH_API_KEY", "lsv2_key")
@@ -18,11 +43,12 @@ def test_prefers_langsmith_env_names(monkeypatch):
_reset_tracing_cache()
cfg = tracing_module.get_tracing_config()
assert cfg.enabled is True
assert cfg.api_key == "lsv2_key"
assert cfg.project == "smith-project"
assert cfg.endpoint == "https://smith.example.com"
assert cfg.langsmith.enabled is True
assert cfg.langsmith.api_key == "lsv2_key"
assert cfg.langsmith.project == "smith-project"
assert cfg.langsmith.endpoint == "https://smith.example.com"
assert tracing_module.is_tracing_enabled() is True
assert tracing_module.get_enabled_tracing_providers() == ["langsmith"]
def test_falls_back_to_langchain_env_names(monkeypatch):
@@ -39,11 +65,12 @@ def test_falls_back_to_langchain_env_names(monkeypatch):
_reset_tracing_cache()
cfg = tracing_module.get_tracing_config()
assert cfg.enabled is True
assert cfg.api_key == "legacy-key"
assert cfg.project == "legacy-project"
assert cfg.endpoint == "https://legacy.example.com"
assert cfg.langsmith.enabled is True
assert cfg.langsmith.api_key == "legacy-key"
assert cfg.langsmith.project == "legacy-project"
assert cfg.langsmith.endpoint == "https://legacy.example.com"
assert tracing_module.is_tracing_enabled() is True
assert tracing_module.get_enabled_tracing_providers() == ["langsmith"]
def test_langsmith_tracing_false_overrides_langchain_tracing_v2_true(monkeypatch):
@@ -55,8 +82,9 @@ def test_langsmith_tracing_false_overrides_langchain_tracing_v2_true(monkeypatch
_reset_tracing_cache()
cfg = tracing_module.get_tracing_config()
assert cfg.enabled is False
assert cfg.langsmith.enabled is False
assert tracing_module.is_tracing_enabled() is False
assert tracing_module.get_enabled_tracing_providers() == []
def test_defaults_when_project_not_set(monkeypatch):
@@ -68,4 +96,51 @@ def test_defaults_when_project_not_set(monkeypatch):
_reset_tracing_cache()
cfg = tracing_module.get_tracing_config()
assert cfg.project == "deer-flow"
assert cfg.langsmith.project == "deer-flow"
def test_langfuse_config_is_loaded(monkeypatch):
monkeypatch.setenv("LANGFUSE_TRACING", "true")
monkeypatch.setenv("LANGFUSE_PUBLIC_KEY", "pk-lf-test")
monkeypatch.setenv("LANGFUSE_SECRET_KEY", "sk-lf-test")
monkeypatch.setenv("LANGFUSE_BASE_URL", "https://langfuse.example.com")
_reset_tracing_cache()
cfg = tracing_module.get_tracing_config()
assert cfg.langfuse.enabled is True
assert cfg.langfuse.public_key == "pk-lf-test"
assert cfg.langfuse.secret_key == "sk-lf-test"
assert cfg.langfuse.host == "https://langfuse.example.com"
assert tracing_module.get_enabled_tracing_providers() == ["langfuse"]
def test_dual_provider_config_is_loaded(monkeypatch):
monkeypatch.setenv("LANGSMITH_TRACING", "true")
monkeypatch.setenv("LANGSMITH_API_KEY", "lsv2_key")
monkeypatch.setenv("LANGFUSE_TRACING", "true")
monkeypatch.setenv("LANGFUSE_PUBLIC_KEY", "pk-lf-test")
monkeypatch.setenv("LANGFUSE_SECRET_KEY", "sk-lf-test")
_reset_tracing_cache()
cfg = tracing_module.get_tracing_config()
assert cfg.langsmith.is_configured is True
assert cfg.langfuse.is_configured is True
assert tracing_module.is_tracing_enabled() is True
assert tracing_module.get_enabled_tracing_providers() == ["langsmith", "langfuse"]
def test_langfuse_enabled_requires_public_and_secret_keys(monkeypatch):
monkeypatch.setenv("LANGFUSE_TRACING", "true")
monkeypatch.delenv("LANGFUSE_PUBLIC_KEY", raising=False)
monkeypatch.setenv("LANGFUSE_SECRET_KEY", "sk-lf-test")
_reset_tracing_cache()
assert tracing_module.get_tracing_config().is_configured is False
assert tracing_module.get_enabled_tracing_providers() == []
assert tracing_module.get_tracing_config().explicitly_enabled_providers == ["langfuse"]
with pytest.raises(ValueError, match="LANGFUSE_PUBLIC_KEY"):
tracing_module.validate_enabled_tracing_providers()
+173
View File
@@ -0,0 +1,173 @@
"""Tests for deerflow.tracing.factory."""
from __future__ import annotations
import sys
import types
import pytest
from deerflow.tracing import factory as tracing_factory
@pytest.fixture(autouse=True)
def clear_tracing_env(monkeypatch):
from deerflow.config import tracing_config as tracing_module
for name in (
"LANGSMITH_TRACING",
"LANGCHAIN_TRACING_V2",
"LANGCHAIN_TRACING",
"LANGSMITH_API_KEY",
"LANGCHAIN_API_KEY",
"LANGSMITH_PROJECT",
"LANGCHAIN_PROJECT",
"LANGSMITH_ENDPOINT",
"LANGCHAIN_ENDPOINT",
"LANGFUSE_TRACING",
"LANGFUSE_PUBLIC_KEY",
"LANGFUSE_SECRET_KEY",
"LANGFUSE_BASE_URL",
):
monkeypatch.delenv(name, raising=False)
tracing_module._tracing_config = None
yield
tracing_module._tracing_config = None
def test_build_tracing_callbacks_returns_empty_list_when_disabled(monkeypatch):
monkeypatch.setattr(tracing_factory, "validate_enabled_tracing_providers", lambda: None)
monkeypatch.setattr(tracing_factory, "get_enabled_tracing_providers", lambda: [])
callbacks = tracing_factory.build_tracing_callbacks()
assert callbacks == []
def test_build_tracing_callbacks_creates_langsmith_and_langfuse(monkeypatch):
class FakeLangSmithTracer:
def __init__(self, *, project_name: str):
self.project_name = project_name
class FakeLangfuseHandler:
def __init__(self, *, public_key: str):
self.public_key = public_key
monkeypatch.setattr(tracing_factory, "get_enabled_tracing_providers", lambda: ["langsmith", "langfuse"])
monkeypatch.setattr(tracing_factory, "validate_enabled_tracing_providers", lambda: None)
monkeypatch.setattr(
tracing_factory,
"get_tracing_config",
lambda: type(
"Cfg",
(),
{
"langsmith": type("LangSmithCfg", (), {"project": "smith-project"})(),
"langfuse": type(
"LangfuseCfg",
(),
{
"secret_key": "sk-lf-test",
"public_key": "pk-lf-test",
"host": "https://langfuse.example.com",
},
)(),
},
)(),
)
monkeypatch.setattr(tracing_factory, "_create_langsmith_tracer", lambda cfg: FakeLangSmithTracer(project_name=cfg.project))
monkeypatch.setattr(
tracing_factory,
"_create_langfuse_handler",
lambda cfg: FakeLangfuseHandler(public_key=cfg.public_key),
)
callbacks = tracing_factory.build_tracing_callbacks()
assert len(callbacks) == 2
assert callbacks[0].project_name == "smith-project"
assert callbacks[1].public_key == "pk-lf-test"
def test_build_tracing_callbacks_raises_when_enabled_provider_fails(monkeypatch):
monkeypatch.setattr(tracing_factory, "get_enabled_tracing_providers", lambda: ["langfuse"])
monkeypatch.setattr(tracing_factory, "validate_enabled_tracing_providers", lambda: None)
monkeypatch.setattr(
tracing_factory,
"get_tracing_config",
lambda: type(
"Cfg",
(),
{
"langfuse": type(
"LangfuseCfg",
(),
{"secret_key": "sk-lf-test", "public_key": "pk-lf-test", "host": "https://langfuse.example.com"},
)(),
},
)(),
)
monkeypatch.setattr(tracing_factory, "_create_langfuse_handler", lambda cfg: (_ for _ in ()).throw(RuntimeError("boom")))
with pytest.raises(RuntimeError, match="Langfuse tracing initialization failed"):
tracing_factory.build_tracing_callbacks()
def test_build_tracing_callbacks_raises_for_explicitly_enabled_misconfigured_provider(monkeypatch):
from deerflow.config import tracing_config as tracing_module
monkeypatch.setenv("LANGFUSE_TRACING", "true")
monkeypatch.delenv("LANGFUSE_PUBLIC_KEY", raising=False)
monkeypatch.setenv("LANGFUSE_SECRET_KEY", "sk-lf-test")
tracing_module._tracing_config = None
with pytest.raises(ValueError, match="LANGFUSE_PUBLIC_KEY"):
tracing_factory.build_tracing_callbacks()
def test_create_langfuse_handler_initializes_client_before_handler(monkeypatch):
calls: list[tuple[str, dict]] = []
class FakeLangfuse:
def __init__(self, **kwargs):
calls.append(("client", kwargs))
class FakeCallbackHandler:
def __init__(self, **kwargs):
calls.append(("handler", kwargs))
fake_langfuse_module = types.ModuleType("langfuse")
fake_langfuse_module.Langfuse = FakeLangfuse
fake_langfuse_langchain_module = types.ModuleType("langfuse.langchain")
fake_langfuse_langchain_module.CallbackHandler = FakeCallbackHandler
monkeypatch.setitem(sys.modules, "langfuse", fake_langfuse_module)
monkeypatch.setitem(sys.modules, "langfuse.langchain", fake_langfuse_langchain_module)
cfg = type(
"LangfuseCfg",
(),
{
"secret_key": "sk-lf-test",
"public_key": "pk-lf-test",
"host": "https://langfuse.example.com",
},
)()
tracing_factory._create_langfuse_handler(cfg)
assert calls == [
(
"client",
{
"secret_key": "sk-lf-test",
"public_key": "pk-lf-test",
"host": "https://langfuse.example.com",
},
),
(
"handler",
{
"public_key": "pk-lf-test",
},
),
]
@@ -289,6 +289,8 @@ class TestBeforeAgent:
"size": 5,
"path": "/mnt/user-data/uploads/notes.txt",
"extension": ".txt",
"outline": [],
"outline_preview": [],
}
]
@@ -339,3 +341,130 @@ class TestBeforeAgent:
result = mw.before_agent(self._state(msg), _runtime())
assert result["messages"][-1].id == "original-id-42"
def test_outline_injected_when_md_file_exists(self, tmp_path):
"""When a converted .md file exists alongside the upload, its outline is injected."""
mw = _middleware(tmp_path)
uploads_dir = _uploads_dir(tmp_path)
(uploads_dir / "report.pdf").write_bytes(b"%PDF fake")
# Simulate the .md produced by the conversion pipeline
(uploads_dir / "report.md").write_text(
"# PART I\n\n## ITEM 1. BUSINESS\n\nBody text.\n\n## ITEM 2. RISK\n",
encoding="utf-8",
)
msg = _human("summarise", files=[{"filename": "report.pdf", "size": 9, "path": "/mnt/user-data/uploads/report.pdf"}])
result = mw.before_agent(self._state(msg), _runtime())
assert result is not None
content = result["messages"][-1].content
assert "Document outline" in content
assert "PART I" in content
assert "ITEM 1. BUSINESS" in content
assert "ITEM 2. RISK" in content
assert "read_file" in content
def test_no_outline_when_no_md_file(self, tmp_path):
"""Files without a sibling .md have no outline section."""
mw = _middleware(tmp_path)
uploads_dir = _uploads_dir(tmp_path)
(uploads_dir / "data.xlsx").write_bytes(b"fake-xlsx")
msg = _human("analyse", files=[{"filename": "data.xlsx", "size": 9, "path": "/mnt/user-data/uploads/data.xlsx"}])
result = mw.before_agent(self._state(msg), _runtime())
assert result is not None
content = result["messages"][-1].content
assert "Document outline" not in content
def test_outline_truncation_hint_shown(self, tmp_path):
"""When outline is truncated, a hint line is appended after the last visible entry."""
from deerflow.utils.file_conversion import MAX_OUTLINE_ENTRIES
mw = _middleware(tmp_path)
uploads_dir = _uploads_dir(tmp_path)
(uploads_dir / "big.pdf").write_bytes(b"%PDF fake")
# Write MAX_OUTLINE_ENTRIES + 5 headings so truncation is triggered
headings = "\n".join(f"# Heading {i}" for i in range(MAX_OUTLINE_ENTRIES + 5))
(uploads_dir / "big.md").write_text(headings, encoding="utf-8")
msg = _human("read", files=[{"filename": "big.pdf", "size": 9, "path": "/mnt/user-data/uploads/big.pdf"}])
result = mw.before_agent(self._state(msg), _runtime())
assert result is not None
content = result["messages"][-1].content
assert f"showing first {MAX_OUTLINE_ENTRIES} headings" in content
assert "use `read_file` to explore further" in content
def test_no_truncation_hint_for_short_outline(self, tmp_path):
"""Short outlines (under the cap) must not show a truncation hint."""
mw = _middleware(tmp_path)
uploads_dir = _uploads_dir(tmp_path)
(uploads_dir / "short.pdf").write_bytes(b"%PDF fake")
(uploads_dir / "short.md").write_text("# Intro\n\n# Conclusion\n", encoding="utf-8")
msg = _human("read", files=[{"filename": "short.pdf", "size": 9, "path": "/mnt/user-data/uploads/short.pdf"}])
result = mw.before_agent(self._state(msg), _runtime())
assert result is not None
content = result["messages"][-1].content
assert "showing first" not in content
def test_historical_file_outline_injected(self, tmp_path):
"""Outline is also shown for historical (previously uploaded) files."""
mw = _middleware(tmp_path)
uploads_dir = _uploads_dir(tmp_path)
# Historical file with .md
(uploads_dir / "old_report.pdf").write_bytes(b"%PDF old")
(uploads_dir / "old_report.md").write_text(
"# Chapter 1\n\n# Chapter 2\n",
encoding="utf-8",
)
# New file without .md
(uploads_dir / "new.txt").write_bytes(b"new")
msg = _human("go", files=[{"filename": "new.txt", "size": 3, "path": "/mnt/user-data/uploads/new.txt"}])
result = mw.before_agent(self._state(msg), _runtime())
assert result is not None
content = result["messages"][-1].content
assert "Chapter 1" in content
assert "Chapter 2" in content
def test_fallback_preview_shown_when_outline_empty(self, tmp_path):
"""When .md exists but has no headings, first lines are shown as a preview."""
mw = _middleware(tmp_path)
uploads_dir = _uploads_dir(tmp_path)
(uploads_dir / "report.pdf").write_bytes(b"%PDF fake")
# .md with no # headings — plain prose only
(uploads_dir / "report.md").write_text(
"Annual Financial Report 2024\n\nThis document summarises key findings.\n\nRevenue grew by 12%.\n",
encoding="utf-8",
)
msg = _human("analyse", files=[{"filename": "report.pdf", "size": 9, "path": "/mnt/user-data/uploads/report.pdf"}])
result = mw.before_agent(self._state(msg), _runtime())
assert result is not None
content = result["messages"][-1].content
# Outline section must NOT appear
assert "Document outline" not in content
# Preview lines must appear
assert "Annual Financial Report 2024" in content
assert "No structural headings detected" in content
# grep hint must appear
assert "grep" in content
def test_fallback_grep_hint_shown_when_no_md_file(self, tmp_path):
"""Files with no sibling .md still get the grep hint (outline is empty)."""
mw = _middleware(tmp_path)
uploads_dir = _uploads_dir(tmp_path)
(uploads_dir / "data.csv").write_bytes(b"a,b,c\n1,2,3\n")
msg = _human("analyse", files=[{"filename": "data.csv", "size": 12, "path": "/mnt/user-data/uploads/data.csv"}])
result = mw.before_agent(self._state(msg), _runtime())
assert result is not None
content = result["messages"][-1].content
assert "Document outline" not in content
assert "grep" in content