mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 15:36:48 +00:00
ci: enforce code formatting checks for backend and frontend (#1536)
This commit is contained in:
@@ -12,6 +12,7 @@ test:
|
||||
|
||||
lint:
|
||||
uvx ruff check .
|
||||
uvx ruff format --check .
|
||||
|
||||
format:
|
||||
uvx ruff check . --fix && uvx ruff format .
|
||||
|
||||
@@ -66,14 +66,9 @@ def _normalize_custom_agent_name(raw_value: str) -> str:
|
||||
"""Normalize legacy channel assistant IDs into valid custom agent names."""
|
||||
normalized = raw_value.strip().lower().replace("_", "-")
|
||||
if not normalized:
|
||||
raise InvalidChannelSessionConfigError(
|
||||
"Channel session assistant_id is empty. Use 'lead_agent' or a valid custom agent name."
|
||||
)
|
||||
raise InvalidChannelSessionConfigError("Channel session assistant_id is empty. Use 'lead_agent' or a valid custom agent name.")
|
||||
if not CUSTOM_AGENT_NAME_PATTERN.fullmatch(normalized):
|
||||
raise InvalidChannelSessionConfigError(
|
||||
f"Invalid channel session assistant_id {raw_value!r}. "
|
||||
"Use 'lead_agent' or a custom agent name containing only letters, digits, and hyphens."
|
||||
)
|
||||
raise InvalidChannelSessionConfigError(f"Invalid channel session assistant_id {raw_value!r}. Use 'lead_agent' or a custom agent name containing only letters, digits, and hyphens.")
|
||||
return normalized
|
||||
|
||||
|
||||
|
||||
@@ -48,9 +48,7 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
|
||||
logger.warning("Skipping sandbox chmod for symlinked upload path: %s", file_path)
|
||||
return
|
||||
|
||||
writable_mode = (
|
||||
stat.S_IMODE(file_stat.st_mode) | stat.S_IWUSR | stat.S_IWGRP | stat.S_IWOTH
|
||||
)
|
||||
writable_mode = stat.S_IMODE(file_stat.st_mode) | stat.S_IWUSR | stat.S_IWGRP | stat.S_IWOTH
|
||||
chmod_kwargs = {"follow_symlinks": False} if os.chmod in os.supports_follow_symlinks else {}
|
||||
os.chmod(file_path, writable_mode, **chmod_kwargs)
|
||||
|
||||
|
||||
@@ -71,9 +71,7 @@ class FileMemoryStorage(MemoryStorage):
|
||||
if not agent_name:
|
||||
raise ValueError("Agent name must be a non-empty string.")
|
||||
if not AGENT_NAME_PATTERN.match(agent_name):
|
||||
raise ValueError(
|
||||
f"Invalid agent name {agent_name!r}: names must match {AGENT_NAME_PATTERN.pattern}"
|
||||
)
|
||||
raise ValueError(f"Invalid agent name {agent_name!r}: names must match {AGENT_NAME_PATTERN.pattern}")
|
||||
|
||||
def _get_memory_file_path(self, agent_name: str | None = None) -> Path:
|
||||
"""Get the path to the memory file."""
|
||||
@@ -180,18 +178,15 @@ def get_memory_storage() -> MemoryStorage:
|
||||
try:
|
||||
module_path, class_name = storage_class_path.rsplit(".", 1)
|
||||
import importlib
|
||||
|
||||
module = importlib.import_module(module_path)
|
||||
storage_class = getattr(module, class_name)
|
||||
|
||||
# Validate that the configured storage is a MemoryStorage implementation
|
||||
if not isinstance(storage_class, type):
|
||||
raise TypeError(
|
||||
f"Configured memory storage '{storage_class_path}' is not a class: {storage_class!r}"
|
||||
)
|
||||
raise TypeError(f"Configured memory storage '{storage_class_path}' is not a class: {storage_class!r}")
|
||||
if not issubclass(storage_class, MemoryStorage):
|
||||
raise TypeError(
|
||||
f"Configured memory storage '{storage_class_path}' is not a subclass of MemoryStorage"
|
||||
)
|
||||
raise TypeError(f"Configured memory storage '{storage_class_path}' is not a subclass of MemoryStorage")
|
||||
|
||||
_storage_instance = storage_class()
|
||||
except Exception as e:
|
||||
|
||||
@@ -27,10 +27,12 @@ def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = N
|
||||
"""Backward-compatible wrapper around the configured memory storage save path."""
|
||||
return get_memory_storage().save(memory_data, agent_name)
|
||||
|
||||
|
||||
def get_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Get the current memory data via storage provider."""
|
||||
return get_memory_storage().load(agent_name)
|
||||
|
||||
|
||||
def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Reload memory data via storage provider."""
|
||||
return get_memory_storage().reload(agent_name)
|
||||
|
||||
@@ -162,10 +162,7 @@ class ClaudeChatModel(ChatAnthropic):
|
||||
system = payload.get("system")
|
||||
if isinstance(system, list):
|
||||
# Remove any existing billing blocks, then insert a single one at index 0.
|
||||
filtered = [
|
||||
b for b in system
|
||||
if not (isinstance(b, dict) and OAUTH_BILLING_HEADER in b.get("text", ""))
|
||||
]
|
||||
filtered = [b for b in system if not (isinstance(b, dict) and OAUTH_BILLING_HEADER in b.get("text", ""))]
|
||||
payload["system"] = [billing_block] + filtered
|
||||
elif isinstance(system, str):
|
||||
if OAUTH_BILLING_HEADER in system:
|
||||
@@ -183,11 +180,13 @@ class ClaudeChatModel(ChatAnthropic):
|
||||
hostname = socket.gethostname()
|
||||
device_id = hashlib.sha256(f"deerflow-{hostname}".encode()).hexdigest()
|
||||
session_id = str(uuid.uuid4())
|
||||
payload["metadata"]["user_id"] = json.dumps({
|
||||
"device_id": device_id,
|
||||
"account_uuid": "deerflow",
|
||||
"session_id": session_id,
|
||||
})
|
||||
payload["metadata"]["user_id"] = json.dumps(
|
||||
{
|
||||
"device_id": device_id,
|
||||
"account_uuid": "deerflow",
|
||||
"session_id": session_id,
|
||||
}
|
||||
)
|
||||
|
||||
def _apply_prompt_caching(self, payload: dict) -> None:
|
||||
"""Apply ephemeral cache_control to system and recent messages."""
|
||||
|
||||
@@ -84,9 +84,7 @@ class PatchedChatOpenAI(ChatOpenAI):
|
||||
else:
|
||||
# Fallback: match assistant-role entries positionally against AIMessages.
|
||||
ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
|
||||
assistant_payloads = [
|
||||
(i, m) for i, m in enumerate(payload_messages) if m.get("role") == "assistant"
|
||||
]
|
||||
assistant_payloads = [(i, m) for i, m in enumerate(payload_messages) if m.get("role") == "assistant"]
|
||||
for (_, payload_msg), ai_msg in zip(assistant_payloads, ai_messages):
|
||||
_restore_tool_call_signatures(payload_msg, ai_msg)
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ def _resolve_skills_path(path: str) -> str:
|
||||
if path == skills_container:
|
||||
return skills_host
|
||||
|
||||
relative = path[len(skills_container):].lstrip("/")
|
||||
relative = path[len(skills_container) :].lstrip("/")
|
||||
return _join_path_preserving_style(skills_host, relative)
|
||||
|
||||
|
||||
|
||||
@@ -197,6 +197,7 @@ async def task_tool(
|
||||
writer({"type": "task_timed_out", "task_id": task_id})
|
||||
return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
|
||||
except asyncio.CancelledError:
|
||||
|
||||
async def cleanup_when_done() -> None:
|
||||
max_cleanup_polls = max_poll_count
|
||||
cleanup_poll_count = 0
|
||||
@@ -211,9 +212,7 @@ async def task_tool(
|
||||
return
|
||||
|
||||
if cleanup_poll_count > max_cleanup_polls:
|
||||
logger.warning(
|
||||
f"[trace={trace_id}] Deferred cleanup for task {task_id} timed out after {cleanup_poll_count} polls"
|
||||
)
|
||||
logger.warning(f"[trace={trace_id}] Deferred cleanup for task {task_id} timed out after {cleanup_poll_count} polls")
|
||||
return
|
||||
|
||||
await asyncio.sleep(5)
|
||||
|
||||
@@ -118,9 +118,7 @@ def _regex_score(pattern: str, entry: DeferredToolEntry) -> int:
|
||||
# loop.run_in_executor, Python copies the current context to the worker thread,
|
||||
# so the ContextVar value is correctly inherited there too.
|
||||
|
||||
_registry_var: contextvars.ContextVar[DeferredToolRegistry | None] = contextvars.ContextVar(
|
||||
"deferred_tool_registry", default=None
|
||||
)
|
||||
_registry_var: contextvars.ContextVar[DeferredToolRegistry | None] = contextvars.ContextVar("deferred_tool_registry", default=None)
|
||||
|
||||
|
||||
def get_deferred_registry() -> DeferredToolRegistry | None:
|
||||
|
||||
@@ -600,10 +600,7 @@ class TestChannelManager:
|
||||
await manager.stop()
|
||||
|
||||
mock_client.runs.wait.assert_not_called()
|
||||
assert outbound_received[0].text == (
|
||||
"Invalid channel session assistant_id 'bad agent!'. "
|
||||
"Use 'lead_agent' or a custom agent name containing only letters, digits, and hyphens."
|
||||
)
|
||||
assert outbound_received[0].text == ("Invalid channel session assistant_id 'bad agent!'. Use 'lead_agent' or a custom agent name containing only letters, digits, and hyphens.")
|
||||
|
||||
_run(go())
|
||||
|
||||
|
||||
@@ -56,10 +56,7 @@ def test_billing_not_duplicated_on_second_call(model):
|
||||
payload = {"system": [{"type": "text", "text": "prompt"}]}
|
||||
model._apply_oauth_billing(payload)
|
||||
model._apply_oauth_billing(payload)
|
||||
billing_count = sum(
|
||||
1 for b in payload["system"]
|
||||
if isinstance(b, dict) and OAUTH_BILLING_HEADER in b.get("text", "")
|
||||
)
|
||||
billing_count = sum(1 for b in payload["system"] if isinstance(b, dict) and OAUTH_BILLING_HEADER in b.get("text", ""))
|
||||
assert billing_count == 1
|
||||
|
||||
|
||||
|
||||
@@ -65,14 +65,7 @@ class TestClientInit:
|
||||
def test_custom_params(self, mock_app_config):
|
||||
mock_middleware = MagicMock()
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
c = DeerFlowClient(
|
||||
model_name="gpt-4",
|
||||
thinking_enabled=False,
|
||||
subagent_enabled=True,
|
||||
plan_mode=True,
|
||||
agent_name="test-agent",
|
||||
middlewares=[mock_middleware]
|
||||
)
|
||||
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", middlewares=[mock_middleware])
|
||||
assert c._model_name == "gpt-4"
|
||||
assert c._thinking_enabled is False
|
||||
assert c._subagent_enabled is True
|
||||
|
||||
@@ -132,18 +132,13 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
|
||||
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda: None)
|
||||
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
|
||||
|
||||
middlewares = lead_agent_module._build_middlewares(
|
||||
{"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}},
|
||||
model_name="vision-model",
|
||||
custom_middlewares=[MagicMock()]
|
||||
)
|
||||
middlewares = lead_agent_module._build_middlewares({"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}}, model_name="vision-model", custom_middlewares=[MagicMock()])
|
||||
|
||||
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)
|
||||
# verify the custom middleware is injected correctly
|
||||
assert len(middlewares) > 0 and isinstance(middlewares[-2], MagicMock)
|
||||
|
||||
|
||||
|
||||
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
lead_agent_module,
|
||||
|
||||
@@ -33,6 +33,7 @@ class TestMemoryStorageInterface:
|
||||
|
||||
def test_abstract_methods(self):
|
||||
"""Should raise TypeError when trying to instantiate abstract class."""
|
||||
|
||||
class TestStorage(MemoryStorage):
|
||||
pass
|
||||
|
||||
@@ -45,6 +46,7 @@ class TestFileMemoryStorage:
|
||||
|
||||
def test_get_memory_file_path_global(self, tmp_path):
|
||||
"""Should return global memory file path when agent_name is None."""
|
||||
|
||||
def mock_get_paths():
|
||||
mock_paths = MagicMock()
|
||||
mock_paths.memory_file = tmp_path / "memory.json"
|
||||
@@ -58,6 +60,7 @@ class TestFileMemoryStorage:
|
||||
|
||||
def test_get_memory_file_path_agent(self, tmp_path):
|
||||
"""Should return per-agent memory file path when agent_name is provided."""
|
||||
|
||||
def mock_get_paths():
|
||||
mock_paths = MagicMock()
|
||||
mock_paths.agent_memory_file.return_value = tmp_path / "agents" / "test-agent" / "memory.json"
|
||||
@@ -68,9 +71,7 @@ class TestFileMemoryStorage:
|
||||
path = storage._get_memory_file_path("test-agent")
|
||||
assert path == tmp_path / "agents" / "test-agent" / "memory.json"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_name", ["", "../etc/passwd", "agent/name", "agent\\name", "agent name", "agent@123", "agent_name"]
|
||||
)
|
||||
@pytest.mark.parametrize("invalid_name", ["", "../etc/passwd", "agent/name", "agent\\name", "agent name", "agent@123", "agent_name"])
|
||||
def test_validate_agent_name_invalid(self, invalid_name):
|
||||
"""Should raise ValueError for invalid agent names that don't match the pattern."""
|
||||
storage = FileMemoryStorage()
|
||||
@@ -79,6 +80,7 @@ class TestFileMemoryStorage:
|
||||
|
||||
def test_load_creates_empty_memory(self, tmp_path):
|
||||
"""Should create empty memory when file doesn't exist."""
|
||||
|
||||
def mock_get_paths():
|
||||
mock_paths = MagicMock()
|
||||
mock_paths.memory_file = tmp_path / "non_existent_memory.json"
|
||||
@@ -125,10 +127,10 @@ class TestFileMemoryStorage:
|
||||
# First load
|
||||
memory1 = storage.load()
|
||||
assert memory1["facts"][0]["content"] == "initial fact"
|
||||
|
||||
|
||||
# Update file directly
|
||||
memory_file.write_text('{"version": "1.0", "facts": [{"content": "updated fact"}]}')
|
||||
|
||||
|
||||
# Reload should get updated data
|
||||
memory2 = storage.reload()
|
||||
assert memory2["facts"][0]["content"] == "updated fact"
|
||||
@@ -141,6 +143,7 @@ class TestGetMemoryStorage:
|
||||
def reset_storage_instance(self):
|
||||
"""Reset the global storage instance before and after each test."""
|
||||
import deerflow.agents.memory.storage as storage_mod
|
||||
|
||||
storage_mod._storage_instance = None
|
||||
yield
|
||||
storage_mod._storage_instance = None
|
||||
@@ -167,6 +170,7 @@ class TestGetMemoryStorage:
|
||||
def test_get_memory_storage_thread_safety(self):
|
||||
"""Should safely initialize the singleton even with concurrent calls."""
|
||||
results = []
|
||||
|
||||
def get_storage():
|
||||
# get_memory_storage is called concurrently from multiple threads while
|
||||
# get_memory_config is patched once around thread creation. This verifies
|
||||
|
||||
Reference in New Issue
Block a user