mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-20 23:21:06 +00:00
feat: implement process-local internal authentication for Gateway and enhance CSRF handling
This commit is contained in:
@@ -174,6 +174,20 @@ def test_protected_post_no_cookie_returns_401(client):
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_protected_post_with_internal_auth_header_passes():
|
||||
from app.gateway.internal_auth import create_internal_auth_headers
|
||||
|
||||
app = _make_app()
|
||||
client = TestClient(app)
|
||||
|
||||
res = client.post(
|
||||
"/api/threads/abc/runs/stream",
|
||||
headers=create_internal_auth_headers(),
|
||||
)
|
||||
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
# ── Method matrix: PUT/DELETE/PATCH also protected ────────────────────────
|
||||
|
||||
|
||||
|
||||
@@ -414,6 +414,27 @@ def _make_async_iterator(items):
|
||||
|
||||
|
||||
class TestChannelManager:
|
||||
def test_get_client_includes_csrf_header_and_cookie(self):
|
||||
from app.channels.manager import ChannelManager
|
||||
|
||||
bus = MessageBus()
|
||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||
manager = ChannelManager(bus=bus, store=store, langgraph_url="http://localhost:8001")
|
||||
|
||||
with patch("langgraph_sdk.get_client") as get_client:
|
||||
get_client.return_value = object()
|
||||
|
||||
manager._get_client()
|
||||
|
||||
get_client.assert_called_once()
|
||||
kwargs = get_client.call_args.kwargs
|
||||
assert kwargs["url"] == "http://localhost:8001"
|
||||
headers = kwargs["headers"]
|
||||
csrf_token = headers["X-CSRF-Token"]
|
||||
assert csrf_token
|
||||
assert headers["Cookie"] == f"csrf_token={csrf_token}"
|
||||
assert headers["X-DeerFlow-Internal-Token"]
|
||||
|
||||
def test_handle_chat_calls_channel_receive_file_for_inbound_files(self, monkeypatch):
|
||||
from app.channels.manager import ChannelManager
|
||||
|
||||
|
||||
@@ -7,6 +7,15 @@ from deerflow.agents.lead_agent import prompt as prompt_module
|
||||
from deerflow.skills.types import Skill
|
||||
|
||||
|
||||
def _set_skills_cache_state(*, skills=None, active=False, version=0):
|
||||
prompt_module._get_cached_skills_prompt_section.cache_clear()
|
||||
with prompt_module._enabled_skills_lock:
|
||||
prompt_module._enabled_skills_cache = skills
|
||||
prompt_module._enabled_skills_refresh_active = active
|
||||
prompt_module._enabled_skills_refresh_version = version
|
||||
prompt_module._enabled_skills_refresh_event.clear()
|
||||
|
||||
|
||||
def test_build_custom_mounts_section_returns_empty_when_no_mounts(monkeypatch):
|
||||
config = SimpleNamespace(sandbox=SimpleNamespace(mounts=[]))
|
||||
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
|
||||
@@ -84,7 +93,7 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
|
||||
|
||||
state = {"skills": [make_skill("first-skill")]}
|
||||
monkeypatch.setattr(prompt_module, "load_skills", lambda enabled_only=True: list(state["skills"]))
|
||||
prompt_module.clear_skills_system_prompt_cache()
|
||||
_set_skills_cache_state()
|
||||
|
||||
try:
|
||||
prompt_module.warm_enabled_skills_cache()
|
||||
@@ -95,7 +104,7 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
|
||||
|
||||
assert [skill.name for skill in prompt_module._get_enabled_skills()] == ["second-skill"]
|
||||
finally:
|
||||
prompt_module.clear_skills_system_prompt_cache()
|
||||
_set_skills_cache_state()
|
||||
|
||||
|
||||
def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_path):
|
||||
@@ -137,7 +146,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa
|
||||
return [make_skill(f"skill-{current_call}")]
|
||||
|
||||
monkeypatch.setattr(prompt_module, "load_skills", fake_load_skills)
|
||||
prompt_module.clear_skills_system_prompt_cache()
|
||||
_set_skills_cache_state()
|
||||
|
||||
try:
|
||||
prompt_module.clear_skills_system_prompt_cache()
|
||||
@@ -151,7 +160,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa
|
||||
assert [skill.name for skill in prompt_module._get_enabled_skills()] == ["skill-2"]
|
||||
finally:
|
||||
release.set()
|
||||
prompt_module.clear_skills_system_prompt_cache()
|
||||
_set_skills_cache_state()
|
||||
|
||||
|
||||
def test_warm_enabled_skills_cache_logs_on_timeout(monkeypatch, caplog):
|
||||
|
||||
@@ -93,7 +93,10 @@ class TestTitleMiddlewareCoreLogic:
|
||||
assert title == "短标题"
|
||||
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False)
|
||||
model.ainvoke.assert_awaited_once()
|
||||
assert model.ainvoke.await_args.kwargs["config"] == {"run_name": "title_agent"}
|
||||
assert model.ainvoke.await_args.kwargs["config"] == {
|
||||
"run_name": "title_agent",
|
||||
"tags": ["middleware:title"],
|
||||
}
|
||||
|
||||
def test_generate_title_normalizes_structured_message_content(self, monkeypatch):
|
||||
_set_test_title_config(max_chars=20)
|
||||
|
||||
@@ -49,7 +49,7 @@ def test_upload_files_skips_acquire_when_thread_data_is_mounted(tmp_path):
|
||||
patch.object(uploads, "get_sandbox_provider", return_value=provider),
|
||||
):
|
||||
file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
|
||||
result = asyncio.run(uploads.upload_files("thread-mounted", files=[file]))
|
||||
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-mounted", request=MagicMock(), files=[file]))
|
||||
|
||||
assert result.success is True
|
||||
assert (thread_uploads_dir / "notes.txt").read_bytes() == b"hello uploads"
|
||||
@@ -75,7 +75,7 @@ def test_upload_files_does_not_auto_convert_documents_by_default(tmp_path):
|
||||
patch.object(uploads, "convert_file_to_markdown", AsyncMock()) as convert_mock,
|
||||
):
|
||||
file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes"))
|
||||
result = asyncio.run(uploads.upload_files("thread-local", files=[file]))
|
||||
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file]))
|
||||
|
||||
assert result.success is True
|
||||
assert len(result.files) == 1
|
||||
|
||||
Reference in New Issue
Block a user