refactor: thread release config through lead path (#2612)

Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
This commit is contained in:
greatmengqi
2026-04-28 14:53:18 +08:00
committed by GitHub
parent 69649d8aae
commit e82940c03d
20 changed files with 325 additions and 179 deletions
@@ -84,14 +84,15 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(tools_module, "get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None, **kwargs: [])
captured: dict[str, object] = {}
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None):
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None):
captured["name"] = name
captured["thinking_enabled"] = thinking_enabled
captured["reasoning_effort"] = reasoning_effort
captured["app_config"] = app_config
return object()
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
@@ -110,6 +111,7 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
assert captured["name"] == "safe-model"
assert captured["thinking_enabled"] is False
assert captured["app_config"] is app_config
assert result["model"] is not None
@@ -126,14 +128,15 @@ def test_make_lead_agent_reads_runtime_options_from_context(monkeypatch):
get_available_tools = MagicMock(return_value=[])
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(tools_module, "get_available_tools", get_available_tools)
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None, **kwargs: [])
captured: dict[str, object] = {}
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None):
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None):
captured["name"] = name
captured["thinking_enabled"] = thinking_enabled
captured["reasoning_effort"] = reasoning_effort
captured["app_config"] = app_config
return object()
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
@@ -156,8 +159,9 @@ def test_make_lead_agent_reads_runtime_options_from_context(monkeypatch):
"name": "context-model",
"thinking_enabled": False,
"reasoning_effort": "high",
"app_config": app_config,
}
get_available_tools.assert_called_once_with(model_name="context-model", groups=None, subagent_enabled=True)
get_available_tools.assert_called_once_with(model_name="context-model", groups=None, subagent_enabled=True, app_config=app_config)
assert result["model"] is not None
@@ -198,10 +202,15 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda: None)
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda **kwargs: None)
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
middlewares = lead_agent_module._build_middlewares({"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}}, model_name="vision-model", custom_middlewares=[MagicMock()])
middlewares = lead_agent_module._build_middlewares(
{"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}},
model_name="vision-model",
custom_middlewares=[MagicMock()],
app_config=app_config,
)
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)
# verify the custom middleware is injected correctly
@@ -222,18 +231,20 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
fake_model = MagicMock()
fake_model.with_config.return_value = fake_model
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None):
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None, app_config=None):
captured["name"] = name
captured["thinking_enabled"] = thinking_enabled
captured["reasoning_effort"] = reasoning_effort
captured["app_config"] = app_config
return fake_model
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs)
middleware = lead_agent_module._create_summarization_middleware()
middleware = lead_agent_module._create_summarization_middleware(app_config=_make_app_config([_make_model("model-masswork", supports_thinking=False)]))
assert captured["name"] == "model-masswork"
assert captured["thinking_enabled"] is False
assert captured["app_config"] is not None
assert middleware["model"] is fake_model
fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])
+2 -2
View File
@@ -48,7 +48,7 @@ def test_apply_prompt_template_includes_custom_mounts(monkeypatch):
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda **kwargs: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")
@@ -66,7 +66,7 @@ def test_apply_prompt_template_includes_relative_path_guidance(monkeypatch):
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda **kwargs: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")
+19 -1
View File
@@ -100,6 +100,24 @@ def test_get_skills_prompt_section_cache_respects_skill_evolution_toggle(monkeyp
assert "Skill Self-Evolution" not in disabled_result
def test_get_skills_prompt_section_uses_explicit_config_for_enabled_skills(monkeypatch):
explicit_config = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/alt-skills"),
skill_evolution=SimpleNamespace(enabled=False),
)
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: [_make_skill("global-skill")])
monkeypatch.setattr(
"deerflow.agents.lead_agent.prompt.load_skills",
lambda enabled_only=True, app_config=None: [_make_skill("explicit-skill")] if app_config is explicit_config else [],
)
result = get_skills_prompt_section(app_config=explicit_config)
assert "explicit-skill" in result
assert "global-skill" not in result
def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
from unittest.mock import MagicMock
@@ -107,7 +125,7 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
# Mock dependencies
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: MagicMock())
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None: "default-model")
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
+18 -1
View File
@@ -2,7 +2,7 @@ from unittest.mock import AsyncMock, call
import pytest
from deerflow.runtime.runs.worker import _rollback_to_pre_run_checkpoint
from deerflow.runtime.runs.worker import _agent_factory_supports_app_config, _rollback_to_pre_run_checkpoint
class FakeCheckpointer:
@@ -212,3 +212,20 @@ async def test_rollback_propagates_aput_writes_failure():
# aput succeeded, aput_writes was called but failed
checkpointer.aput.assert_awaited_once()
checkpointer.aput_writes.assert_awaited_once()
def test_agent_factory_supports_app_config_detects_supported_signature():
def factory(*, config, app_config=None):
return (config, app_config)
assert _agent_factory_supports_app_config(factory) is True
def test_agent_factory_supports_app_config_returns_false_when_signature_lookup_fails(monkeypatch):
class BrokenCallable:
def __call__(self, **kwargs):
return kwargs
monkeypatch.setattr("deerflow.runtime.runs.worker.inspect.signature", lambda _obj: (_ for _ in ()).throw(ValueError("boom")))
assert _agent_factory_supports_app_config(BrokenCallable()) is False
+15 -14
View File
@@ -35,6 +35,13 @@ def _make_skill(name: str, *, enabled: bool) -> Skill:
)
def _make_test_app(config) -> FastAPI:
app = FastAPI()
app.state.config = config
app.include_router(skills_router.router)
return app
def test_custom_skills_router_lifecycle(monkeypatch, tmp_path):
skills_root = tmp_path / "skills"
custom_dir = skills_root / "custom" / "demo-skill"
@@ -54,8 +61,7 @@ def test_custom_skills_router_lifecycle(monkeypatch, tmp_path):
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
app = FastAPI()
app.include_router(skills_router.router)
app = _make_test_app(config)
with TestClient(app) as client:
response = client.get("/api/skills/custom")
@@ -96,7 +102,7 @@ def test_custom_skill_rollback_blocked_by_scanner(monkeypatch, tmp_path):
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
get_skill_history_file("demo-skill").write_text(
get_skill_history_file("demo-skill", app_config=config).write_text(
'{"action":"human_edit","prev_content":' + json.dumps(original_content) + ',"new_content":' + json.dumps(edited_content) + "}\n",
encoding="utf-8",
)
@@ -113,8 +119,7 @@ def test_custom_skill_rollback_blocked_by_scanner(monkeypatch, tmp_path):
monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", _scan)
app = FastAPI()
app.include_router(skills_router.router)
app = _make_test_app(config)
with TestClient(app) as client:
rollback_response = client.post("/api/skills/custom/demo-skill/rollback", json={"history_index": -1})
@@ -146,8 +151,7 @@ def test_custom_skill_delete_preserves_history_and_allows_restore(monkeypatch, t
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
app = FastAPI()
app.include_router(skills_router.router)
app = _make_test_app(config)
with TestClient(app) as client:
delete_response = client.delete("/api/skills/custom/demo-skill")
@@ -187,8 +191,7 @@ def test_custom_skill_delete_continues_when_history_write_is_readonly(monkeypatc
monkeypatch.setattr("app.gateway.routers.skills.append_history", _readonly_history)
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
app = FastAPI()
app.include_router(skills_router.router)
app = _make_test_app(config)
with TestClient(app) as client:
delete_response = client.delete("/api/skills/custom/demo-skill")
@@ -221,8 +224,7 @@ def test_custom_skill_delete_fails_when_skill_dir_removal_fails(monkeypatch, tmp
monkeypatch.setattr("app.gateway.routers.skills.shutil.rmtree", _fail_rmtree)
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
app = FastAPI()
app.include_router(skills_router.router)
app = _make_test_app(config)
with TestClient(app) as client:
delete_response = client.delete("/api/skills/custom/demo-skill")
@@ -238,7 +240,7 @@ def test_update_skill_refreshes_prompt_cache_before_return(monkeypatch, tmp_path
enabled_state = {"value": True}
refresh_calls = []
def _load_skills(*, enabled_only: bool):
def _load_skills(*, enabled_only: bool, app_config=None):
skill = _make_skill("demo-skill", enabled=enabled_state["value"])
if enabled_only and not skill.enabled:
return []
@@ -254,8 +256,7 @@ def test_update_skill_refreshes_prompt_cache_before_return(monkeypatch, tmp_path
monkeypatch.setattr(skills_router.ExtensionsConfig, "resolve_config_path", staticmethod(lambda: config_path))
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
app = FastAPI()
app.include_router(skills_router.router)
app = _make_test_app(SimpleNamespace())
with TestClient(app) as client:
response = client.put("/api/skills/demo-skill", json={"enabled": False})
+5 -4
View File
@@ -1,4 +1,5 @@
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
from app.gateway.routers import suggestions
@@ -48,7 +49,7 @@ def test_generate_suggestions_parses_and_limits(monkeypatch):
# Bypass the require_permission decorator (which needs request +
# thread_store) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None, config=SimpleNamespace()))
assert result.suggestions == ["Q1", "Q2", "Q3"]
fake_model.ainvoke.assert_awaited_once()
@@ -70,7 +71,7 @@ def test_generate_suggestions_parses_list_block_content(monkeypatch):
# Bypass the require_permission decorator (which needs request +
# thread_store) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None, config=SimpleNamespace()))
assert result.suggestions == ["Q1", "Q2"]
fake_model.ainvoke.assert_awaited_once()
@@ -92,7 +93,7 @@ def test_generate_suggestions_parses_output_text_block_content(monkeypatch):
# Bypass the require_permission decorator (which needs request +
# thread_store) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None, config=SimpleNamespace()))
assert result.suggestions == ["Q1", "Q2"]
fake_model.ainvoke.assert_awaited_once()
@@ -111,6 +112,6 @@ def test_generate_suggestions_returns_empty_on_model_error(monkeypatch):
# Bypass the require_permission decorator (which needs request +
# thread_store) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None, config=SimpleNamespace()))
assert result.suggestions == []
+21 -20
View File
@@ -2,6 +2,7 @@ import asyncio
import stat
from io import BytesIO
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
from _router_auth_helpers import call_unwrapped
@@ -26,7 +27,7 @@ def test_upload_files_writes_thread_storage_and_skips_local_sandbox_sync(tmp_pat
patch.object(uploads, "get_sandbox_provider", return_value=provider),
):
file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True
assert len(result.files) == 1
@@ -49,7 +50,7 @@ def test_upload_files_skips_acquire_when_thread_data_is_mounted(tmp_path):
patch.object(uploads, "get_sandbox_provider", return_value=provider),
):
file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-mounted", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-mounted", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True
assert (thread_uploads_dir / "notes.txt").read_bytes() == b"hello uploads"
@@ -75,7 +76,7 @@ def test_upload_files_does_not_auto_convert_documents_by_default(tmp_path):
patch.object(uploads, "convert_file_to_markdown", AsyncMock()) as convert_mock,
):
file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True
assert len(result.files) == 1
@@ -108,7 +109,7 @@ def test_upload_files_syncs_non_local_sandbox_and_marks_markdown_file(tmp_path):
patch.object(uploads, "convert_file_to_markdown", AsyncMock(side_effect=fake_convert)),
):
file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True
assert len(result.files) == 1
@@ -147,7 +148,7 @@ def test_upload_files_makes_non_local_files_sandbox_writable(tmp_path):
patch.object(uploads, "_make_file_sandbox_writable") as make_writable,
):
file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True
make_writable.assert_any_call(thread_uploads_dir / "report.pdf")
@@ -171,7 +172,7 @@ def test_upload_files_does_not_adjust_permissions_for_local_sandbox(tmp_path):
patch.object(uploads, "_make_file_sandbox_writable") as make_writable,
):
file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True
make_writable.assert_not_called()
@@ -222,13 +223,13 @@ def test_upload_files_rejects_dotdot_and_dot_filenames(tmp_path):
# These filenames must be rejected outright
for bad_name in ["..", "."]:
file = UploadFile(filename=bad_name, file=BytesIO(b"data"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True
assert result.files == [], f"Expected no files for unsafe filename {bad_name!r}"
# Path-traversal prefixes are stripped to the basename and accepted safely
file = UploadFile(filename="../etc/passwd", file=BytesIO(b"data"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True
assert len(result.files) == 1
assert result.files[0]["filename"] == "passwd"
@@ -252,16 +253,20 @@ def test_delete_uploaded_file_removes_generated_markdown_companion(tmp_path):
def test_auto_convert_documents_enabled_defaults_to_false_on_config_errors():
with patch.object(uploads, "get_app_config", side_effect=RuntimeError("boom")):
assert uploads._auto_convert_documents_enabled() is False
class BrokenConfig:
def __getattribute__(self, name):
if name == "uploads":
raise RuntimeError("boom")
return super().__getattribute__(name)
assert uploads._auto_convert_documents_enabled(BrokenConfig()) is False
def test_auto_convert_documents_enabled_reads_dict_backed_uploads_config():
cfg = MagicMock()
cfg.uploads = {"auto_convert_documents": True}
with patch.object(uploads, "get_app_config", return_value=cfg):
assert uploads._auto_convert_documents_enabled() is True
assert uploads._auto_convert_documents_enabled(cfg) is True
def test_auto_convert_documents_enabled_accepts_boolean_and_string_truthy_values():
@@ -277,11 +282,7 @@ def test_auto_convert_documents_enabled_accepts_boolean_and_string_truthy_values
string_false_cfg = MagicMock()
string_false_cfg.uploads = MagicMock(auto_convert_documents="false")
with patch.object(uploads, "get_app_config", return_value=false_cfg):
assert uploads._auto_convert_documents_enabled() is False
with patch.object(uploads, "get_app_config", return_value=true_cfg):
assert uploads._auto_convert_documents_enabled() is True
with patch.object(uploads, "get_app_config", return_value=string_true_cfg):
assert uploads._auto_convert_documents_enabled() is True
with patch.object(uploads, "get_app_config", return_value=string_false_cfg):
assert uploads._auto_convert_documents_enabled() is False
assert uploads._auto_convert_documents_enabled(false_cfg) is False
assert uploads._auto_convert_documents_enabled(true_cfg) is True
assert uploads._auto_convert_documents_enabled(string_true_cfg) is True
assert uploads._auto_convert_documents_enabled(string_false_cfg) is False