refactor: thread release config through lead path (#2612)

Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
This commit is contained in:
greatmengqi
2026-04-28 14:53:18 +08:00
committed by GitHub
parent 69649d8aae
commit e82940c03d
20 changed files with 325 additions and 179 deletions
+1
View File
@@ -148,6 +148,7 @@ def get_run_context(request: Request) -> RunContext:
event_store=get_run_event_store(request), event_store=get_run_event_store(request),
run_events_config=getattr(config, "run_events", None), run_events_config=getattr(config, "run_events", None),
thread_store=get_thread_store(request), thread_store=get_thread_store(request),
app_config=config,
) )
+5 -6
View File
@@ -1,7 +1,8 @@
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from deerflow.config import get_app_config from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
router = APIRouter(prefix="/api", tags=["models"]) router = APIRouter(prefix="/api", tags=["models"])
@@ -36,7 +37,7 @@ class ModelsListResponse(BaseModel):
summary="List All Models", summary="List All Models",
description="Retrieve a list of all available AI models configured in the system.", description="Retrieve a list of all available AI models configured in the system.",
) )
async def list_models() -> ModelsListResponse: async def list_models(config: AppConfig = Depends(get_config)) -> ModelsListResponse:
"""List all available models from configuration. """List all available models from configuration.
Returns model information suitable for frontend display, Returns model information suitable for frontend display,
@@ -72,7 +73,6 @@ async def list_models() -> ModelsListResponse:
} }
``` ```
""" """
config = get_app_config()
models = [ models = [
ModelResponse( ModelResponse(
name=model.name, name=model.name,
@@ -96,7 +96,7 @@ async def list_models() -> ModelsListResponse:
summary="Get Model Details", summary="Get Model Details",
description="Retrieve detailed information about a specific AI model by its name.", description="Retrieve detailed information about a specific AI model by its name.",
) )
async def get_model(model_name: str) -> ModelResponse: async def get_model(model_name: str, config: AppConfig = Depends(get_config)) -> ModelResponse:
"""Get a specific model by name. """Get a specific model by name.
Args: Args:
@@ -118,7 +118,6 @@ async def get_model(model_name: str) -> ModelResponse:
} }
``` ```
""" """
config = get_app_config()
model = config.get_model_config(model_name) model = config.get_model_config(model_name)
if model is None: if model is None:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
+37 -33
View File
@@ -4,11 +4,13 @@ import logging
import shutil import shutil
from pathlib import Path from pathlib import Path
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.gateway.deps import get_config
from app.gateway.path_utils import resolve_thread_virtual_path from app.gateway.path_utils import resolve_thread_virtual_path
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
from deerflow.skills import Skill, load_skills from deerflow.skills import Skill, load_skills
from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
@@ -101,9 +103,9 @@ def _skill_to_response(skill: Skill) -> SkillResponse:
summary="List All Skills", summary="List All Skills",
description="Retrieve a list of all available skills from both public and custom directories.", description="Retrieve a list of all available skills from both public and custom directories.",
) )
async def list_skills() -> SkillsListResponse: async def list_skills(config: AppConfig = Depends(get_config)) -> SkillsListResponse:
try: try:
skills = load_skills(enabled_only=False) skills = load_skills(enabled_only=False, app_config=config)
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills]) return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
except Exception as e: except Exception as e:
logger.error(f"Failed to load skills: {e}", exc_info=True) logger.error(f"Failed to load skills: {e}", exc_info=True)
@@ -136,9 +138,9 @@ async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
@router.get("/skills/custom", response_model=SkillsListResponse, summary="List Custom Skills") @router.get("/skills/custom", response_model=SkillsListResponse, summary="List Custom Skills")
async def list_custom_skills() -> SkillsListResponse: async def list_custom_skills(config: AppConfig = Depends(get_config)) -> SkillsListResponse:
try: try:
skills = [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"] skills = [skill for skill in load_skills(enabled_only=False, app_config=config) if skill.category == "custom"]
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills]) return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
except Exception as e: except Exception as e:
logger.error("Failed to list custom skills: %s", e, exc_info=True) logger.error("Failed to list custom skills: %s", e, exc_info=True)
@@ -146,13 +148,13 @@ async def list_custom_skills() -> SkillsListResponse:
@router.get("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Get Custom Skill Content") @router.get("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Get Custom Skill Content")
async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse: async def get_custom_skill(skill_name: str, config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
try: try:
skills = load_skills(enabled_only=False) skills = load_skills(enabled_only=False, app_config=config)
skill = next((s for s in skills if s.name == skill_name and s.category == "custom"), None) skill = next((s for s in skills if s.name == skill_name and s.category == "custom"), None)
if skill is None: if skill is None:
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found") raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name)) return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name, app_config=config))
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@@ -161,14 +163,14 @@ async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:
@router.put("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Edit Custom Skill") @router.put("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Edit Custom Skill")
async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest) -> CustomSkillContentResponse: async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest, config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
try: try:
ensure_custom_skill_is_editable(skill_name) ensure_custom_skill_is_editable(skill_name, app_config=config)
validate_skill_markdown_content(skill_name, request.content) validate_skill_markdown_content(skill_name, request.content)
scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/SKILL.md") scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/SKILL.md", app_config=config)
if scan.decision == "block": if scan.decision == "block":
raise HTTPException(status_code=400, detail=f"Security scan blocked the edit: {scan.reason}") raise HTTPException(status_code=400, detail=f"Security scan blocked the edit: {scan.reason}")
skill_file = get_custom_skill_dir(skill_name) / "SKILL.md" skill_file = get_custom_skill_dir(skill_name, app_config=config) / "SKILL.md"
prev_content = skill_file.read_text(encoding="utf-8") prev_content = skill_file.read_text(encoding="utf-8")
atomic_write(skill_file, request.content) atomic_write(skill_file, request.content)
append_history( append_history(
@@ -182,9 +184,10 @@ async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest
"new_content": request.content, "new_content": request.content,
"scanner": {"decision": scan.decision, "reason": scan.reason}, "scanner": {"decision": scan.decision, "reason": scan.reason},
}, },
app_config=config,
) )
await refresh_skills_system_prompt_cache_async() await refresh_skills_system_prompt_cache_async()
return await get_custom_skill(skill_name) return await get_custom_skill(skill_name, config)
except HTTPException: except HTTPException:
raise raise
except FileNotFoundError as e: except FileNotFoundError as e:
@@ -197,11 +200,11 @@ async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest
@router.delete("/skills/custom/{skill_name}", summary="Delete Custom Skill") @router.delete("/skills/custom/{skill_name}", summary="Delete Custom Skill")
async def delete_custom_skill(skill_name: str) -> dict[str, bool]: async def delete_custom_skill(skill_name: str, config: AppConfig = Depends(get_config)) -> dict[str, bool]:
try: try:
ensure_custom_skill_is_editable(skill_name) ensure_custom_skill_is_editable(skill_name, app_config=config)
skill_dir = get_custom_skill_dir(skill_name) skill_dir = get_custom_skill_dir(skill_name, app_config=config)
prev_content = read_custom_skill_content(skill_name) prev_content = read_custom_skill_content(skill_name, app_config=config)
try: try:
append_history( append_history(
skill_name, skill_name,
@@ -214,6 +217,7 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
"new_content": None, "new_content": None,
"scanner": {"decision": "allow", "reason": "Deletion requested."}, "scanner": {"decision": "allow", "reason": "Deletion requested."},
}, },
app_config=config,
) )
except OSError as e: except OSError as e:
if not isinstance(e, PermissionError) and e.errno not in {errno.EACCES, errno.EPERM, errno.EROFS}: if not isinstance(e, PermissionError) and e.errno not in {errno.EACCES, errno.EPERM, errno.EROFS}:
@@ -232,11 +236,11 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
@router.get("/skills/custom/{skill_name}/history", response_model=CustomSkillHistoryResponse, summary="Get Custom Skill History") @router.get("/skills/custom/{skill_name}/history", response_model=CustomSkillHistoryResponse, summary="Get Custom Skill History")
async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryResponse: async def get_custom_skill_history(skill_name: str, config: AppConfig = Depends(get_config)) -> CustomSkillHistoryResponse:
try: try:
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists(): if not custom_skill_exists(skill_name, app_config=config) and not get_skill_history_file(skill_name, app_config=config).exists():
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found") raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
return CustomSkillHistoryResponse(history=read_history(skill_name)) return CustomSkillHistoryResponse(history=read_history(skill_name, app_config=config))
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@@ -245,11 +249,11 @@ async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryRespons
@router.post("/skills/custom/{skill_name}/rollback", response_model=CustomSkillContentResponse, summary="Rollback Custom Skill") @router.post("/skills/custom/{skill_name}/rollback", response_model=CustomSkillContentResponse, summary="Rollback Custom Skill")
async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest) -> CustomSkillContentResponse: async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest, config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
try: try:
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists(): if not custom_skill_exists(skill_name, app_config=config) and not get_skill_history_file(skill_name, app_config=config).exists():
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found") raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
history = read_history(skill_name) history = read_history(skill_name, app_config=config)
if not history: if not history:
raise HTTPException(status_code=400, detail=f"Custom skill '{skill_name}' has no history") raise HTTPException(status_code=400, detail=f"Custom skill '{skill_name}' has no history")
record = history[request.history_index] record = history[request.history_index]
@@ -257,8 +261,8 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
if target_content is None: if target_content is None:
raise HTTPException(status_code=400, detail="Selected history entry has no previous content to roll back to") raise HTTPException(status_code=400, detail="Selected history entry has no previous content to roll back to")
validate_skill_markdown_content(skill_name, target_content) validate_skill_markdown_content(skill_name, target_content)
scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/SKILL.md") scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/SKILL.md", app_config=config)
skill_file = get_custom_skill_file(skill_name) skill_file = get_custom_skill_file(skill_name, app_config=config)
current_content = skill_file.read_text(encoding="utf-8") if skill_file.exists() else None current_content = skill_file.read_text(encoding="utf-8") if skill_file.exists() else None
history_entry = { history_entry = {
"action": "rollback", "action": "rollback",
@@ -271,12 +275,12 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
"scanner": {"decision": scan.decision, "reason": scan.reason}, "scanner": {"decision": scan.decision, "reason": scan.reason},
} }
if scan.decision == "block": if scan.decision == "block":
append_history(skill_name, history_entry) append_history(skill_name, history_entry, app_config=config)
raise HTTPException(status_code=400, detail=f"Rollback blocked by security scanner: {scan.reason}") raise HTTPException(status_code=400, detail=f"Rollback blocked by security scanner: {scan.reason}")
atomic_write(skill_file, target_content) atomic_write(skill_file, target_content)
append_history(skill_name, history_entry) append_history(skill_name, history_entry, app_config=config)
await refresh_skills_system_prompt_cache_async() await refresh_skills_system_prompt_cache_async()
return await get_custom_skill(skill_name) return await get_custom_skill(skill_name, config)
except HTTPException: except HTTPException:
raise raise
except IndexError: except IndexError:
@@ -296,9 +300,9 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
summary="Get Skill Details", summary="Get Skill Details",
description="Retrieve detailed information about a specific skill by its name.", description="Retrieve detailed information about a specific skill by its name.",
) )
async def get_skill(skill_name: str) -> SkillResponse: async def get_skill(skill_name: str, config: AppConfig = Depends(get_config)) -> SkillResponse:
try: try:
skills = load_skills(enabled_only=False) skills = load_skills(enabled_only=False, app_config=config)
skill = next((s for s in skills if s.name == skill_name), None) skill = next((s for s in skills if s.name == skill_name), None)
if skill is None: if skill is None:
@@ -318,9 +322,9 @@ async def get_skill(skill_name: str) -> SkillResponse:
summary="Update Skill", summary="Update Skill",
description="Update a skill's enabled status by modifying the extensions_config.json file.", description="Update a skill's enabled status by modifying the extensions_config.json file.",
) )
async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillResponse: async def update_skill(skill_name: str, request: SkillUpdateRequest, config: AppConfig = Depends(get_config)) -> SkillResponse:
try: try:
skills = load_skills(enabled_only=False) skills = load_skills(enabled_only=False, app_config=config)
skill = next((s for s in skills if s.name == skill_name), None) skill = next((s for s in skills if s.name == skill_name), None)
if skill is None: if skill is None:
@@ -346,7 +350,7 @@ async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillRes
reload_extensions_config() reload_extensions_config()
await refresh_skills_system_prompt_cache_async() await refresh_skills_system_prompt_cache_async()
skills = load_skills(enabled_only=False) skills = load_skills(enabled_only=False, app_config=config)
updated_skill = next((s for s in skills if s.name == skill_name), None) updated_skill = next((s for s in skills if s.name == skill_name), None)
if updated_skill is None: if updated_skill is None:
+10 -3
View File
@@ -1,11 +1,13 @@
import json import json
import logging import logging
from fastapi import APIRouter, Request from fastapi import APIRouter, Depends, Request
from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.gateway.authz import require_permission from app.gateway.authz import require_permission
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -100,7 +102,12 @@ def _format_conversation(messages: list[SuggestionMessage]) -> str:
description="Generate short follow-up questions a user might ask next, based on recent conversation context.", description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
) )
@require_permission("threads", "read", owner_check=True) @require_permission("threads", "read", owner_check=True)
async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request) -> SuggestionsResponse: async def generate_suggestions(
thread_id: str,
body: SuggestionsRequest,
request: Request,
config: AppConfig = Depends(get_config),
) -> SuggestionsResponse:
if not body.messages: if not body.messages:
return SuggestionsResponse(suggestions=[]) return SuggestionsResponse(suggestions=[])
@@ -122,7 +129,7 @@ async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions" user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
try: try:
model = create_chat_model(name=body.model_name, thinking_enabled=False) model = create_chat_model(name=body.model_name, thinking_enabled=False, app_config=config)
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)], config={"run_name": "suggest_agent"}) response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)], config={"run_name": "suggest_agent"})
raw = _extract_response_text(response.content) raw = _extract_response_text(response.content)
suggestions = _parse_json_string_list(raw) or [] suggestions = _parse_json_string_list(raw) or []
+9 -8
View File
@@ -4,11 +4,12 @@ import logging
import os import os
import stat import stat
from fastapi import APIRouter, File, HTTPException, Request, UploadFile from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from pydantic import BaseModel from pydantic import BaseModel
from app.gateway.authz import require_permission from app.gateway.authz import require_permission
from deerflow.config.app_config import get_app_config from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import get_paths from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
@@ -60,23 +61,22 @@ def _uses_thread_data_mounts(sandbox_provider: SandboxProvider) -> bool:
return bool(getattr(sandbox_provider, "uses_thread_data_mounts", False)) return bool(getattr(sandbox_provider, "uses_thread_data_mounts", False))
def _get_uploads_config_value(key: str, default: object) -> object: def _get_uploads_config_value(app_config: AppConfig, key: str, default: object) -> object:
"""Read a value from the uploads config, supporting dict and attribute access.""" """Read a value from the uploads config, supporting dict and attribute access."""
cfg = get_app_config() uploads_cfg = getattr(app_config, "uploads", None)
uploads_cfg = getattr(cfg, "uploads", None)
if isinstance(uploads_cfg, dict): if isinstance(uploads_cfg, dict):
return uploads_cfg.get(key, default) return uploads_cfg.get(key, default)
return getattr(uploads_cfg, key, default) return getattr(uploads_cfg, key, default)
def _auto_convert_documents_enabled() -> bool: def _auto_convert_documents_enabled(app_config: AppConfig) -> bool:
"""Return whether automatic host-side document conversion is enabled. """Return whether automatic host-side document conversion is enabled.
The secure default is disabled unless an operator explicitly opts in via The secure default is disabled unless an operator explicitly opts in via
uploads.auto_convert_documents in config.yaml. uploads.auto_convert_documents in config.yaml.
""" """
try: try:
raw = _get_uploads_config_value("auto_convert_documents", False) raw = _get_uploads_config_value(app_config, "auto_convert_documents", False)
if isinstance(raw, str): if isinstance(raw, str):
return raw.strip().lower() in {"1", "true", "yes", "on"} return raw.strip().lower() in {"1", "true", "yes", "on"}
return bool(raw) return bool(raw)
@@ -90,6 +90,7 @@ async def upload_files(
thread_id: str, thread_id: str,
request: Request, request: Request,
files: list[UploadFile] = File(...), files: list[UploadFile] = File(...),
config: AppConfig = Depends(get_config),
) -> UploadResponse: ) -> UploadResponse:
"""Upload multiple files to a thread's uploads directory.""" """Upload multiple files to a thread's uploads directory."""
if not files: if not files:
@@ -108,7 +109,7 @@ async def upload_files(
if sync_to_sandbox: if sync_to_sandbox:
sandbox_id = sandbox_provider.acquire(thread_id) sandbox_id = sandbox_provider.acquire(thread_id)
sandbox = sandbox_provider.get(sandbox_id) sandbox = sandbox_provider.get(sandbox_id)
auto_convert_documents = _auto_convert_documents_enabled() auto_convert_documents = _auto_convert_documents_enabled(config)
for file in files: for file in files:
if not file.filename: if not file.filename:
@@ -18,7 +18,7 @@ from deerflow.agents.middlewares.tool_error_handling_middleware import build_lea
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
from deerflow.agents.thread_state import ThreadState from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import load_agent_config, validate_agent_name from deerflow.config.agents_config import load_agent_config, validate_agent_name
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import AppConfig, get_app_config
from deerflow.config.memory_config import get_memory_config from deerflow.config.memory_config import get_memory_config
from deerflow.config.summarization_config import get_summarization_config from deerflow.config.summarization_config import get_summarization_config
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
@@ -35,9 +35,9 @@ def _get_runtime_config(config: RunnableConfig) -> dict:
return cfg return cfg
def _resolve_model_name(requested_model_name: str | None = None) -> str: def _resolve_model_name(requested_model_name: str | None = None, *, app_config: AppConfig | None = None) -> str:
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured.""" """Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
app_config = get_app_config() app_config = app_config or get_app_config()
default_model_name = app_config.models[0].name if app_config.models else None default_model_name = app_config.models[0].name if app_config.models else None
if default_model_name is None: if default_model_name is None:
raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.") raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
@@ -50,7 +50,7 @@ def _resolve_model_name(requested_model_name: str | None = None) -> str:
return default_model_name return default_model_name
def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None: def _create_summarization_middleware(*, app_config: AppConfig | None = None) -> DeerFlowSummarizationMiddleware | None:
"""Create and configure the summarization middleware from config.""" """Create and configure the summarization middleware from config."""
config = get_summarization_config() config = get_summarization_config()
@@ -73,9 +73,9 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None
# as middleware rather than lead_agent (SummarizationMiddleware is a # as middleware rather than lead_agent (SummarizationMiddleware is a
# LangChain built-in, so we tag the model at creation time). # LangChain built-in, so we tag the model at creation time).
if config.model_name: if config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False) model = create_chat_model(name=config.model_name, thinking_enabled=False, app_config=app_config)
else: else:
model = create_chat_model(thinking_enabled=False) model = create_chat_model(thinking_enabled=False, app_config=app_config)
model = model.with_config(tags=["middleware:summarize"]) model = model.with_config(tags=["middleware:summarize"])
# Prepare kwargs # Prepare kwargs
@@ -99,7 +99,8 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None
# the sole entry point for DeerFlowSummarizationMiddleware, and the runtime # the sole entry point for DeerFlowSummarizationMiddleware, and the runtime
# config is not expected to change after startup. # config is not expected to change after startup.
try: try:
skills_container_path = get_app_config().skills.container_path or "/mnt/skills" resolved_app_config = app_config or get_app_config()
skills_container_path = resolved_app_config.skills.container_path or "/mnt/skills"
except Exception: except Exception:
logger.exception("Failed to resolve skills container path; falling back to default") logger.exception("Failed to resolve skills container path; falling back to default")
skills_container_path = "/mnt/skills" skills_container_path = "/mnt/skills"
@@ -240,7 +241,14 @@ Being proactive with task management demonstrates thoroughness and ensures all r
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM # ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages # ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
# ClarificationMiddleware should be last to intercept clarification requests after model calls # ClarificationMiddleware should be last to intercept clarification requests after model calls
def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None, custom_middlewares: list[AgentMiddleware] | None = None): def _build_middlewares(
config: RunnableConfig,
model_name: str | None,
agent_name: str | None = None,
custom_middlewares: list[AgentMiddleware] | None = None,
*,
app_config: AppConfig | None = None,
):
"""Build middleware chain based on runtime configuration. """Build middleware chain based on runtime configuration.
Args: Args:
@@ -252,9 +260,10 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
List of middleware instances. List of middleware instances.
""" """
middlewares = build_lead_runtime_middlewares(lazy_init=True) middlewares = build_lead_runtime_middlewares(lazy_init=True)
resolved_app_config = app_config or get_app_config()
# Add summarization middleware if enabled # Add summarization middleware if enabled
summarization_middleware = _create_summarization_middleware() summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config)
if summarization_middleware is not None: if summarization_middleware is not None:
middlewares.append(summarization_middleware) middlewares.append(summarization_middleware)
@@ -266,7 +275,7 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
middlewares.append(todo_list_middleware) middlewares.append(todo_list_middleware)
# Add TokenUsageMiddleware when token_usage tracking is enabled # Add TokenUsageMiddleware when token_usage tracking is enabled
if get_app_config().token_usage.enabled: if resolved_app_config.token_usage.enabled:
middlewares.append(TokenUsageMiddleware()) middlewares.append(TokenUsageMiddleware())
# Add TitleMiddleware # Add TitleMiddleware
@@ -277,13 +286,12 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
# Add ViewImageMiddleware only if the current model supports vision. # Add ViewImageMiddleware only if the current model supports vision.
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values. # Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
app_config = get_app_config() model_config = resolved_app_config.get_model_config(model_name) if model_name else None
model_config = app_config.get_model_config(model_name) if model_name else None
if model_config is not None and model_config.supports_vision: if model_config is not None and model_config.supports_vision:
middlewares.append(ViewImageMiddleware()) middlewares.append(ViewImageMiddleware())
# Add DeferredToolFilterMiddleware to hide deferred tool schemas from model binding # Add DeferredToolFilterMiddleware to hide deferred tool schemas from model binding
if app_config.tool_search.enabled: if resolved_app_config.tool_search.enabled:
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
middlewares.append(DeferredToolFilterMiddleware()) middlewares.append(DeferredToolFilterMiddleware())
@@ -306,12 +314,13 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
return middlewares return middlewares
def make_lead_agent(config: RunnableConfig): def make_lead_agent(config: RunnableConfig, app_config: AppConfig | None = None):
# Lazy import to avoid circular dependency # Lazy import to avoid circular dependency
from deerflow.tools import get_available_tools from deerflow.tools import get_available_tools
from deerflow.tools.builtins import setup_agent from deerflow.tools.builtins import setup_agent
cfg = _get_runtime_config(config) cfg = _get_runtime_config(config)
resolved_app_config = app_config or get_app_config()
thinking_enabled = cfg.get("thinking_enabled", True) thinking_enabled = cfg.get("thinking_enabled", True)
reasoning_effort = cfg.get("reasoning_effort", None) reasoning_effort = cfg.get("reasoning_effort", None)
@@ -327,10 +336,9 @@ def make_lead_agent(config: RunnableConfig):
agent_model_name = agent_config.model if agent_config and agent_config.model else None agent_model_name = agent_config.model if agent_config and agent_config.model else None
# Final model name resolution: request → agent config → global default, with fallback for unknown names # Final model name resolution: request → agent config → global default, with fallback for unknown names
model_name = _resolve_model_name(requested_model_name or agent_model_name) model_name = _resolve_model_name(requested_model_name or agent_model_name, app_config=resolved_app_config)
app_config = get_app_config() model_config = resolved_app_config.get_model_config(model_name)
model_config = app_config.get_model_config(model_name)
if model_config is None: if model_config is None:
raise ValueError("No chat model could be resolved. Please configure at least one model in config.yaml or provide a valid 'model_name'/'model' in the request.") raise ValueError("No chat model could be resolved. Please configure at least one model in config.yaml or provide a valid 'model_name'/'model' in the request.")
@@ -369,20 +377,34 @@ def make_lead_agent(config: RunnableConfig):
if is_bootstrap: if is_bootstrap:
# Special bootstrap agent with minimal prompt for initial custom agent creation flow # Special bootstrap agent with minimal prompt for initial custom agent creation flow
return create_agent( return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled), model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config),
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) + [setup_agent], tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent],
middleware=_build_middlewares(config, model_name=model_name), middleware=_build_middlewares(config, model_name=model_name, app_config=resolved_app_config),
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])), system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents,
available_skills=set(["bootstrap"]),
app_config=resolved_app_config,
),
state_schema=ThreadState, state_schema=ThreadState,
) )
# Default lead agent (unchanged behavior) # Default lead agent (unchanged behavior)
return create_agent( return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort), model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config),
tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled), tools=get_available_tools(
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name), model_name=model_name,
groups=agent_config.tool_groups if agent_config else None,
subagent_enabled=subagent_enabled,
app_config=resolved_app_config,
),
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name, app_config=resolved_app_config),
system_prompt=apply_prompt_template( system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents,
agent_name=agent_name,
available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None,
app_config=resolved_app_config,
), ),
state_schema=ThreadState, state_schema=ThreadState,
) )
@@ -1,14 +1,20 @@
from __future__ import annotations
import asyncio import asyncio
import logging import logging
import threading import threading
from datetime import datetime from datetime import datetime
from functools import lru_cache from functools import lru_cache
from typing import TYPE_CHECKING
from deerflow.config.agents_config import load_agent_soul from deerflow.config.agents_config import load_agent_soul
from deerflow.skills import load_skills from deerflow.skills import load_skills
from deerflow.skills.types import Skill from deerflow.skills.types import Skill
from deerflow.subagents import get_available_subagent_names from deerflow.subagents import get_available_subagent_names
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS = 5.0 _ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS = 5.0
@@ -111,6 +117,19 @@ def _get_enabled_skills():
return [] return []
def _get_enabled_skills_for_config(app_config: AppConfig | None = None) -> list[Skill]:
"""Return enabled skills using the caller's config source.
When a concrete ``app_config`` is supplied, bypass the global enabled-skills
cache so the skill list and skill paths are resolved from the same config
object. This keeps request-scoped config injection consistent even while the
release branch still supports global fallback paths.
"""
if app_config is None:
return _get_enabled_skills()
return list(load_skills(enabled_only=True, app_config=app_config))
def _skill_mutability_label(category: str) -> str: def _skill_mutability_label(category: str) -> str:
return "[custom, editable]" if category == "custom" else "[built-in]" return "[custom, editable]" if category == "custom" else "[built-in]"
@@ -576,14 +595,14 @@ You have access to skills that provide optimized workflows for specific tasks. E
</skill_system>""" </skill_system>"""
def get_skills_prompt_section(available_skills: set[str] | None = None) -> str: def get_skills_prompt_section(available_skills: set[str] | None = None, *, app_config: AppConfig | None = None) -> str:
"""Generate the skills prompt section with available skills list.""" """Generate the skills prompt section with available skills list."""
skills = _get_enabled_skills() skills = _get_enabled_skills_for_config(app_config)
try: try:
from deerflow.config import get_app_config from deerflow.config import get_app_config
config = get_app_config() config = app_config or get_app_config()
container_base_path = config.skills.container_path container_base_path = config.skills.container_path
skill_evolution_enabled = config.skill_evolution.enabled skill_evolution_enabled = config.skill_evolution.enabled
except Exception: except Exception:
@@ -612,7 +631,7 @@ def get_agent_soul(agent_name: str | None) -> str:
return "" return ""
def get_deferred_tools_prompt_section() -> str: def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) -> str:
"""Generate <available-deferred-tools> block for the system prompt. """Generate <available-deferred-tools> block for the system prompt.
Lists only deferred tool names so the agent knows what exists Lists only deferred tool names so the agent knows what exists
@@ -624,7 +643,8 @@ def get_deferred_tools_prompt_section() -> str:
try: try:
from deerflow.config import get_app_config from deerflow.config import get_app_config
if not get_app_config().tool_search.enabled: config = app_config or get_app_config()
if not config.tool_search.enabled:
return "" return ""
except Exception: except Exception:
return "" return ""
@@ -657,12 +677,13 @@ def _build_acp_section() -> str:
) )
def _build_custom_mounts_section() -> str: def _build_custom_mounts_section(*, app_config: AppConfig | None = None) -> str:
"""Build a prompt section for explicitly configured sandbox mounts.""" """Build a prompt section for explicitly configured sandbox mounts."""
try: try:
from deerflow.config import get_app_config from deerflow.config import get_app_config
mounts = get_app_config().sandbox.mounts or [] config = app_config or get_app_config()
mounts = config.sandbox.mounts or []
except Exception: except Exception:
logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt") logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
return "" return ""
@@ -679,7 +700,14 @@ def _build_custom_mounts_section() -> str:
return f"\n**Custom Mounted Directories:**\n{mounts_list}\n- If the user needs files outside `/mnt/user-data`, use these absolute container paths directly when they match the requested directory" return f"\n**Custom Mounted Directories:**\n{mounts_list}\n- If the user needs files outside `/mnt/user-data`, use these absolute container paths directly when they match the requested directory"
def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagents: int = 3, *, agent_name: str | None = None, available_skills: set[str] | None = None) -> str: def apply_prompt_template(
subagent_enabled: bool = False,
max_concurrent_subagents: int = 3,
*,
agent_name: str | None = None,
available_skills: set[str] | None = None,
app_config: AppConfig | None = None,
) -> str:
# Get memory context # Get memory context
memory_context = _get_memory_context(agent_name) memory_context = _get_memory_context(agent_name)
@@ -706,14 +734,14 @@ def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagen
) )
# Get skills section # Get skills section
skills_section = get_skills_prompt_section(available_skills) skills_section = get_skills_prompt_section(available_skills, app_config=app_config)
# Get deferred tools section (tool_search) # Get deferred tools section (tool_search)
deferred_tools_section = get_deferred_tools_prompt_section() deferred_tools_section = get_deferred_tools_prompt_section(app_config=app_config)
# Build ACP agent section only if ACP agents are configured # Build ACP agent section only if ACP agents are configured
acp_section = _build_acp_section() acp_section = _build_acp_section()
custom_mounts_section = _build_custom_mounts_section() custom_mounts_section = _build_custom_mounts_section(app_config=app_config)
acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section) acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)
# Format the prompt with dynamic skills and memory # Format the prompt with dynamic skills and memory
@@ -3,6 +3,7 @@ import logging
from langchain.chat_models import BaseChatModel from langchain.chat_models import BaseChatModel
from deerflow.config import get_app_config from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_class from deerflow.reflection import resolve_class
from deerflow.tracing import build_tracing_callbacks from deerflow.tracing import build_tracing_callbacks
@@ -46,7 +47,7 @@ def _enable_stream_usage_by_default(model_use_path: str, model_settings_from_con
model_settings_from_config["stream_usage"] = True model_settings_from_config["stream_usage"] = True
def create_chat_model(name: str | None = None, thinking_enabled: bool = False, **kwargs) -> BaseChatModel: def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *, app_config: AppConfig | None = None, **kwargs) -> BaseChatModel:
"""Create a chat model instance from the config. """Create a chat model instance from the config.
Args: Args:
@@ -55,7 +56,7 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
Returns: Returns:
A chat model instance. A chat model instance.
""" """
config = get_app_config() config = app_config or get_app_config()
if name is None: if name is None:
name = config.models[0].name name = config.models[0].name
model_config = config.get_model_config(name) model_config = config.get_model_config(name)
@@ -20,11 +20,13 @@ import copy
import inspect import inspect
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from deerflow.config.app_config import AppConfig
from deerflow.runtime.serialization import serialize from deerflow.runtime.serialization import serialize
from deerflow.runtime.stream_bridge import StreamBridge from deerflow.runtime.stream_bridge import StreamBridge
@@ -51,6 +53,27 @@ class RunContext:
event_store: Any | None = field(default=None) event_store: Any | None = field(default=None)
run_events_config: Any | None = field(default=None) run_events_config: Any | None = field(default=None)
thread_store: Any | None = field(default=None) thread_store: Any | None = field(default=None)
app_config: AppConfig | None = field(default=None)
def _compute_agent_factory_supports_app_config(agent_factory: Any) -> bool:
try:
return "app_config" in inspect.signature(agent_factory).parameters
except (TypeError, ValueError):
return False
@lru_cache(maxsize=128)
def _cached_agent_factory_supports_app_config(agent_factory: Any) -> bool:
return _compute_agent_factory_supports_app_config(agent_factory)
def _agent_factory_supports_app_config(agent_factory: Any) -> bool:
try:
return _cached_agent_factory_supports_app_config(agent_factory)
except TypeError:
# Some callable instances are unhashable; fall back to a direct check.
return _compute_agent_factory_supports_app_config(agent_factory)
async def run_agent( async def run_agent(
@@ -163,7 +186,10 @@ async def run_agent(
config.setdefault("callbacks", []).append(journal) config.setdefault("callbacks", []).append(journal)
runnable_config = RunnableConfig(**config) runnable_config = RunnableConfig(**config)
agent = agent_factory(config=runnable_config) if ctx.app_config is not None and _agent_factory_supports_app_config(agent_factory):
agent = agent_factory(config=runnable_config, app_config=ctx.app_config)
else:
agent = agent_factory(config=runnable_config)
# 4. Attach checkpointer and store # 4. Attach checkpointer and store
if checkpointer is not None: if checkpointer is not None:
@@ -2,6 +2,8 @@ import logging
import os import os
from pathlib import Path from pathlib import Path
from deerflow.config.app_config import AppConfig
from .parser import parse_skill_file from .parser import parse_skill_file
from .types import Skill from .types import Skill
@@ -22,7 +24,7 @@ def get_skills_root_path() -> Path:
return skills_dir return skills_dir
def load_skills(skills_path: Path | None = None, use_config: bool = True, enabled_only: bool = False) -> list[Skill]: def load_skills(skills_path: Path | None = None, use_config: bool = True, enabled_only: bool = False, *, app_config: AppConfig | None = None) -> list[Skill]:
""" """
Load all skills from the skills directory. Load all skills from the skills directory.
@@ -44,7 +46,7 @@ def load_skills(skills_path: Path | None = None, use_config: bool = True, enable
try: try:
from deerflow.config import get_app_config from deerflow.config import get_app_config
config = get_app_config() config = app_config or get_app_config()
skills_path = config.skills.get_skills_path() skills_path = config.skills.get_skills_path()
except Exception: except Exception:
# Fallback to default if config fails # Fallback to default if config fails
@@ -10,6 +10,7 @@ from pathlib import Path
from typing import Any from typing import Any
from deerflow.config import get_app_config from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.skills.loader import load_skills from deerflow.skills.loader import load_skills
from deerflow.skills.validation import _validate_skill_frontmatter from deerflow.skills.validation import _validate_skill_frontmatter
@@ -20,16 +21,17 @@ ALLOWED_SUPPORT_SUBDIRS = {"references", "templates", "scripts", "assets"}
_SKILL_NAME_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$") _SKILL_NAME_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")
def get_skills_root_dir() -> Path: def get_skills_root_dir(*, app_config: AppConfig | None = None) -> Path:
return get_app_config().skills.get_skills_path() config = app_config or get_app_config()
return config.skills.get_skills_path()
def get_public_skills_dir() -> Path: def get_public_skills_dir(*, app_config: AppConfig | None = None) -> Path:
return get_skills_root_dir() / "public" return get_skills_root_dir(app_config=app_config) / "public"
def get_custom_skills_dir() -> Path: def get_custom_skills_dir(*, app_config: AppConfig | None = None) -> Path:
path = get_skills_root_dir() / "custom" path = get_skills_root_dir(app_config=app_config) / "custom"
path.mkdir(parents=True, exist_ok=True) path.mkdir(parents=True, exist_ok=True)
return path return path
@@ -43,46 +45,46 @@ def validate_skill_name(name: str) -> str:
return normalized return normalized
def get_custom_skill_dir(name: str) -> Path: def get_custom_skill_dir(name: str, *, app_config: AppConfig | None = None) -> Path:
return get_custom_skills_dir() / validate_skill_name(name) return get_custom_skills_dir(app_config=app_config) / validate_skill_name(name)
def get_custom_skill_file(name: str) -> Path: def get_custom_skill_file(name: str, *, app_config: AppConfig | None = None) -> Path:
return get_custom_skill_dir(name) / SKILL_FILE_NAME return get_custom_skill_dir(name, app_config=app_config) / SKILL_FILE_NAME
def get_custom_skill_history_dir() -> Path: def get_custom_skill_history_dir(*, app_config: AppConfig | None = None) -> Path:
path = get_custom_skills_dir() / HISTORY_DIR_NAME path = get_custom_skills_dir(app_config=app_config) / HISTORY_DIR_NAME
path.mkdir(parents=True, exist_ok=True) path.mkdir(parents=True, exist_ok=True)
return path return path
def get_skill_history_file(name: str) -> Path: def get_skill_history_file(name: str, *, app_config: AppConfig | None = None) -> Path:
return get_custom_skill_history_dir() / f"{validate_skill_name(name)}.jsonl" return get_custom_skill_history_dir(app_config=app_config) / f"{validate_skill_name(name)}.jsonl"
def get_public_skill_dir(name: str) -> Path: def get_public_skill_dir(name: str, *, app_config: AppConfig | None = None) -> Path:
return get_public_skills_dir() / validate_skill_name(name) return get_public_skills_dir(app_config=app_config) / validate_skill_name(name)
def custom_skill_exists(name: str) -> bool: def custom_skill_exists(name: str, *, app_config: AppConfig | None = None) -> bool:
return get_custom_skill_file(name).exists() return get_custom_skill_file(name, app_config=app_config).exists()
def public_skill_exists(name: str) -> bool: def public_skill_exists(name: str, *, app_config: AppConfig | None = None) -> bool:
return (get_public_skill_dir(name) / SKILL_FILE_NAME).exists() return (get_public_skill_dir(name, app_config=app_config) / SKILL_FILE_NAME).exists()
def ensure_custom_skill_is_editable(name: str) -> None: def ensure_custom_skill_is_editable(name: str, *, app_config: AppConfig | None = None) -> None:
if custom_skill_exists(name): if custom_skill_exists(name, app_config=app_config):
return return
if public_skill_exists(name): if public_skill_exists(name, app_config=app_config):
raise ValueError(f"'{name}' is a built-in skill. To customise it, create a new skill with the same name under skills/custom/.") raise ValueError(f"'{name}' is a built-in skill. To customise it, create a new skill with the same name under skills/custom/.")
raise FileNotFoundError(f"Custom skill '{name}' not found.") raise FileNotFoundError(f"Custom skill '{name}' not found.")
def ensure_safe_support_path(name: str, relative_path: str) -> Path: def ensure_safe_support_path(name: str, relative_path: str, *, app_config: AppConfig | None = None) -> Path:
skill_dir = get_custom_skill_dir(name).resolve() skill_dir = get_custom_skill_dir(name, app_config=app_config).resolve()
if not relative_path or relative_path.endswith("/"): if not relative_path or relative_path.endswith("/"):
raise ValueError("Supporting file path must include a filename.") raise ValueError("Supporting file path must include a filename.")
relative = Path(relative_path) relative = Path(relative_path)
@@ -124,8 +126,8 @@ def atomic_write(path: Path, content: str) -> None:
tmp_path.replace(path) tmp_path.replace(path)
def append_history(name: str, record: dict[str, Any]) -> None: def append_history(name: str, record: dict[str, Any], *, app_config: AppConfig | None = None) -> None:
history_path = get_skill_history_file(name) history_path = get_skill_history_file(name, app_config=app_config)
history_path.parent.mkdir(parents=True, exist_ok=True) history_path.parent.mkdir(parents=True, exist_ok=True)
payload = { payload = {
"ts": datetime.now(UTC).isoformat(), "ts": datetime.now(UTC).isoformat(),
@@ -136,8 +138,8 @@ def append_history(name: str, record: dict[str, Any]) -> None:
f.write("\n") f.write("\n")
def read_history(name: str) -> list[dict[str, Any]]: def read_history(name: str, *, app_config: AppConfig | None = None) -> list[dict[str, Any]]:
history_path = get_skill_history_file(name) history_path = get_skill_history_file(name, app_config=app_config)
if not history_path.exists(): if not history_path.exists():
return [] return []
records: list[dict[str, Any]] = [] records: list[dict[str, Any]] = []
@@ -148,12 +150,12 @@ def read_history(name: str) -> list[dict[str, Any]]:
return records return records
def list_custom_skills() -> list: def list_custom_skills(*, app_config: AppConfig | None = None) -> list:
return [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"] return [skill for skill in load_skills(enabled_only=False, app_config=app_config) if skill.category == "custom"]
def read_custom_skill_content(name: str) -> str: def read_custom_skill_content(name: str, *, app_config: AppConfig | None = None) -> str:
skill_file = get_custom_skill_file(name) skill_file = get_custom_skill_file(name, app_config=app_config)
if not skill_file.exists(): if not skill_file.exists():
raise FileNotFoundError(f"Custom skill '{name}' not found.") raise FileNotFoundError(f"Custom skill '{name}' not found.")
return skill_file.read_text(encoding="utf-8") return skill_file.read_text(encoding="utf-8")
@@ -8,6 +8,7 @@ import re
from dataclasses import dataclass from dataclasses import dataclass
from deerflow.config import get_app_config from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -35,7 +36,7 @@ def _extract_json_object(raw: str) -> dict | None:
return None return None
async def scan_skill_content(content: str, *, executable: bool = False, location: str = "SKILL.md") -> ScanResult: async def scan_skill_content(content: str, *, executable: bool = False, location: str = "SKILL.md", app_config: AppConfig | None = None) -> ScanResult:
"""Screen skill content before it is written to disk.""" """Screen skill content before it is written to disk."""
rubric = ( rubric = (
"You are a security reviewer for AI agent skills. " "You are a security reviewer for AI agent skills. "
@@ -47,9 +48,9 @@ async def scan_skill_content(content: str, *, executable: bool = False, location
prompt = f"Location: {location}\nExecutable: {str(executable).lower()}\n\nReview this content:\n-----\n{content}\n-----" prompt = f"Location: {location}\nExecutable: {str(executable).lower()}\n\nReview this content:\n-----\n{content}\n-----"
try: try:
config = get_app_config() config = app_config or get_app_config()
model_name = config.skill_evolution.moderation_model_name model_name = config.skill_evolution.moderation_model_name
model = create_chat_model(name=model_name, thinking_enabled=False) if model_name else create_chat_model(thinking_enabled=False) model = create_chat_model(name=model_name, thinking_enabled=False, app_config=config) if model_name else create_chat_model(thinking_enabled=False, app_config=config)
response = await model.ainvoke( response = await model.ainvoke(
[ [
{"role": "system", "content": rubric}, {"role": "system", "content": rubric},
@@ -3,6 +3,7 @@ import logging
from langchain.tools import BaseTool from langchain.tools import BaseTool
from deerflow.config import get_app_config from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_variable from deerflow.reflection import resolve_variable
from deerflow.sandbox.security import is_host_bash_allowed from deerflow.sandbox.security import is_host_bash_allowed
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
@@ -37,6 +38,8 @@ def get_available_tools(
include_mcp: bool = True, include_mcp: bool = True,
model_name: str | None = None, model_name: str | None = None,
subagent_enabled: bool = False, subagent_enabled: bool = False,
*,
app_config: AppConfig | None = None,
) -> list[BaseTool]: ) -> list[BaseTool]:
"""Get all available tools from config. """Get all available tools from config.
@@ -52,7 +55,7 @@ def get_available_tools(
Returns: Returns:
List of available tools. List of available tools.
""" """
config = get_app_config() config = app_config or get_app_config()
tool_configs = [tool for tool in config.tools if groups is None or tool.group in groups] tool_configs = [tool for tool in config.tools if groups is None or tool.group in groups]
# Do not expose host bash by default when LocalSandboxProvider is active. # Do not expose host bash by default when LocalSandboxProvider is active.
@@ -84,14 +84,15 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config) monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(tools_module, "get_available_tools", lambda **kwargs: []) monkeypatch.setattr(tools_module, "get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None: []) monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None, **kwargs: [])
captured: dict[str, object] = {} captured: dict[str, object] = {}
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None): def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None):
captured["name"] = name captured["name"] = name
captured["thinking_enabled"] = thinking_enabled captured["thinking_enabled"] = thinking_enabled
captured["reasoning_effort"] = reasoning_effort captured["reasoning_effort"] = reasoning_effort
captured["app_config"] = app_config
return object() return object()
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model) monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
@@ -110,6 +111,7 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
assert captured["name"] == "safe-model" assert captured["name"] == "safe-model"
assert captured["thinking_enabled"] is False assert captured["thinking_enabled"] is False
assert captured["app_config"] is app_config
assert result["model"] is not None assert result["model"] is not None
@@ -126,14 +128,15 @@ def test_make_lead_agent_reads_runtime_options_from_context(monkeypatch):
get_available_tools = MagicMock(return_value=[]) get_available_tools = MagicMock(return_value=[])
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config) monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(tools_module, "get_available_tools", get_available_tools) monkeypatch.setattr(tools_module, "get_available_tools", get_available_tools)
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None: []) monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None, **kwargs: [])
captured: dict[str, object] = {} captured: dict[str, object] = {}
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None): def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None):
captured["name"] = name captured["name"] = name
captured["thinking_enabled"] = thinking_enabled captured["thinking_enabled"] = thinking_enabled
captured["reasoning_effort"] = reasoning_effort captured["reasoning_effort"] = reasoning_effort
captured["app_config"] = app_config
return object() return object()
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model) monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
@@ -156,8 +159,9 @@ def test_make_lead_agent_reads_runtime_options_from_context(monkeypatch):
"name": "context-model", "name": "context-model",
"thinking_enabled": False, "thinking_enabled": False,
"reasoning_effort": "high", "reasoning_effort": "high",
"app_config": app_config,
} }
get_available_tools.assert_called_once_with(model_name="context-model", groups=None, subagent_enabled=True) get_available_tools.assert_called_once_with(model_name="context-model", groups=None, subagent_enabled=True, app_config=app_config)
assert result["model"] is not None assert result["model"] is not None
@@ -198,10 +202,15 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
) )
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config) monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda: None) monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda **kwargs: None)
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None) monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
middlewares = lead_agent_module._build_middlewares({"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}}, model_name="vision-model", custom_middlewares=[MagicMock()]) middlewares = lead_agent_module._build_middlewares(
{"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}},
model_name="vision-model",
custom_middlewares=[MagicMock()],
app_config=app_config,
)
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares) assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)
# verify the custom middleware is injected correctly # verify the custom middleware is injected correctly
@@ -222,18 +231,20 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
fake_model = MagicMock() fake_model = MagicMock()
fake_model.with_config.return_value = fake_model fake_model.with_config.return_value = fake_model
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None): def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None, app_config=None):
captured["name"] = name captured["name"] = name
captured["thinking_enabled"] = thinking_enabled captured["thinking_enabled"] = thinking_enabled
captured["reasoning_effort"] = reasoning_effort captured["reasoning_effort"] = reasoning_effort
captured["app_config"] = app_config
return fake_model return fake_model
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model) monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs) monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs)
middleware = lead_agent_module._create_summarization_middleware() middleware = lead_agent_module._create_summarization_middleware(app_config=_make_app_config([_make_model("model-masswork", supports_thinking=False)]))
assert captured["name"] == "model-masswork" assert captured["name"] == "model-masswork"
assert captured["thinking_enabled"] is False assert captured["thinking_enabled"] is False
assert captured["app_config"] is not None
assert middleware["model"] is fake_model assert middleware["model"] is fake_model
fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"]) fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])
+2 -2
View File
@@ -48,7 +48,7 @@ def test_apply_prompt_template_includes_custom_mounts(monkeypatch):
) )
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config) monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: []) monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "") monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda **kwargs: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "") monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "") monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "") monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")
@@ -66,7 +66,7 @@ def test_apply_prompt_template_includes_relative_path_guidance(monkeypatch):
) )
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config) monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: []) monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "") monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda **kwargs: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "") monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "") monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "") monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")
+19 -1
View File
@@ -100,6 +100,24 @@ def test_get_skills_prompt_section_cache_respects_skill_evolution_toggle(monkeyp
assert "Skill Self-Evolution" not in disabled_result assert "Skill Self-Evolution" not in disabled_result
def test_get_skills_prompt_section_uses_explicit_config_for_enabled_skills(monkeypatch):
explicit_config = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/alt-skills"),
skill_evolution=SimpleNamespace(enabled=False),
)
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: [_make_skill("global-skill")])
monkeypatch.setattr(
"deerflow.agents.lead_agent.prompt.load_skills",
lambda enabled_only=True, app_config=None: [_make_skill("explicit-skill")] if app_config is explicit_config else [],
)
result = get_skills_prompt_section(app_config=explicit_config)
assert "explicit-skill" in result
assert "global-skill" not in result
def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch): def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
from unittest.mock import MagicMock from unittest.mock import MagicMock
@@ -107,7 +125,7 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
# Mock dependencies # Mock dependencies
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: MagicMock()) monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: MagicMock())
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None: "default-model") monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model") monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: []) monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
+18 -1
View File
@@ -2,7 +2,7 @@ from unittest.mock import AsyncMock, call
import pytest import pytest
from deerflow.runtime.runs.worker import _rollback_to_pre_run_checkpoint from deerflow.runtime.runs.worker import _agent_factory_supports_app_config, _rollback_to_pre_run_checkpoint
class FakeCheckpointer: class FakeCheckpointer:
@@ -212,3 +212,20 @@ async def test_rollback_propagates_aput_writes_failure():
# aput succeeded, aput_writes was called but failed # aput succeeded, aput_writes was called but failed
checkpointer.aput.assert_awaited_once() checkpointer.aput.assert_awaited_once()
checkpointer.aput_writes.assert_awaited_once() checkpointer.aput_writes.assert_awaited_once()
def test_agent_factory_supports_app_config_detects_supported_signature():
def factory(*, config, app_config=None):
return (config, app_config)
assert _agent_factory_supports_app_config(factory) is True
def test_agent_factory_supports_app_config_returns_false_when_signature_lookup_fails(monkeypatch):
class BrokenCallable:
def __call__(self, **kwargs):
return kwargs
monkeypatch.setattr("deerflow.runtime.runs.worker.inspect.signature", lambda _obj: (_ for _ in ()).throw(ValueError("boom")))
assert _agent_factory_supports_app_config(BrokenCallable()) is False
+15 -14
View File
@@ -35,6 +35,13 @@ def _make_skill(name: str, *, enabled: bool) -> Skill:
) )
def _make_test_app(config) -> FastAPI:
app = FastAPI()
app.state.config = config
app.include_router(skills_router.router)
return app
def test_custom_skills_router_lifecycle(monkeypatch, tmp_path): def test_custom_skills_router_lifecycle(monkeypatch, tmp_path):
skills_root = tmp_path / "skills" skills_root = tmp_path / "skills"
custom_dir = skills_root / "custom" / "demo-skill" custom_dir = skills_root / "custom" / "demo-skill"
@@ -54,8 +61,7 @@ def test_custom_skills_router_lifecycle(monkeypatch, tmp_path):
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh) monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
app = FastAPI() app = _make_test_app(config)
app.include_router(skills_router.router)
with TestClient(app) as client: with TestClient(app) as client:
response = client.get("/api/skills/custom") response = client.get("/api/skills/custom")
@@ -96,7 +102,7 @@ def test_custom_skill_rollback_blocked_by_scanner(monkeypatch, tmp_path):
) )
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config) monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config) monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
get_skill_history_file("demo-skill").write_text( get_skill_history_file("demo-skill", app_config=config).write_text(
'{"action":"human_edit","prev_content":' + json.dumps(original_content) + ',"new_content":' + json.dumps(edited_content) + "}\n", '{"action":"human_edit","prev_content":' + json.dumps(original_content) + ',"new_content":' + json.dumps(edited_content) + "}\n",
encoding="utf-8", encoding="utf-8",
) )
@@ -113,8 +119,7 @@ def test_custom_skill_rollback_blocked_by_scanner(monkeypatch, tmp_path):
monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", _scan) monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", _scan)
app = FastAPI() app = _make_test_app(config)
app.include_router(skills_router.router)
with TestClient(app) as client: with TestClient(app) as client:
rollback_response = client.post("/api/skills/custom/demo-skill/rollback", json={"history_index": -1}) rollback_response = client.post("/api/skills/custom/demo-skill/rollback", json={"history_index": -1})
@@ -146,8 +151,7 @@ def test_custom_skill_delete_preserves_history_and_allows_restore(monkeypatch, t
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh) monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
app = FastAPI() app = _make_test_app(config)
app.include_router(skills_router.router)
with TestClient(app) as client: with TestClient(app) as client:
delete_response = client.delete("/api/skills/custom/demo-skill") delete_response = client.delete("/api/skills/custom/demo-skill")
@@ -187,8 +191,7 @@ def test_custom_skill_delete_continues_when_history_write_is_readonly(monkeypatc
monkeypatch.setattr("app.gateway.routers.skills.append_history", _readonly_history) monkeypatch.setattr("app.gateway.routers.skills.append_history", _readonly_history)
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh) monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
app = FastAPI() app = _make_test_app(config)
app.include_router(skills_router.router)
with TestClient(app) as client: with TestClient(app) as client:
delete_response = client.delete("/api/skills/custom/demo-skill") delete_response = client.delete("/api/skills/custom/demo-skill")
@@ -221,8 +224,7 @@ def test_custom_skill_delete_fails_when_skill_dir_removal_fails(monkeypatch, tmp
monkeypatch.setattr("app.gateway.routers.skills.shutil.rmtree", _fail_rmtree) monkeypatch.setattr("app.gateway.routers.skills.shutil.rmtree", _fail_rmtree)
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh) monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
app = FastAPI() app = _make_test_app(config)
app.include_router(skills_router.router)
with TestClient(app) as client: with TestClient(app) as client:
delete_response = client.delete("/api/skills/custom/demo-skill") delete_response = client.delete("/api/skills/custom/demo-skill")
@@ -238,7 +240,7 @@ def test_update_skill_refreshes_prompt_cache_before_return(monkeypatch, tmp_path
enabled_state = {"value": True} enabled_state = {"value": True}
refresh_calls = [] refresh_calls = []
def _load_skills(*, enabled_only: bool): def _load_skills(*, enabled_only: bool, app_config=None):
skill = _make_skill("demo-skill", enabled=enabled_state["value"]) skill = _make_skill("demo-skill", enabled=enabled_state["value"])
if enabled_only and not skill.enabled: if enabled_only and not skill.enabled:
return [] return []
@@ -254,8 +256,7 @@ def test_update_skill_refreshes_prompt_cache_before_return(monkeypatch, tmp_path
monkeypatch.setattr(skills_router.ExtensionsConfig, "resolve_config_path", staticmethod(lambda: config_path)) monkeypatch.setattr(skills_router.ExtensionsConfig, "resolve_config_path", staticmethod(lambda: config_path))
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh) monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
app = FastAPI() app = _make_test_app(SimpleNamespace())
app.include_router(skills_router.router)
with TestClient(app) as client: with TestClient(app) as client:
response = client.put("/api/skills/demo-skill", json={"enabled": False}) response = client.put("/api/skills/demo-skill", json={"enabled": False})
+5 -4
View File
@@ -1,4 +1,5 @@
import asyncio import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
from app.gateway.routers import suggestions from app.gateway.routers import suggestions
@@ -48,7 +49,7 @@ def test_generate_suggestions_parses_and_limits(monkeypatch):
# Bypass the require_permission decorator (which needs request + # Bypass the require_permission decorator (which needs request +
# thread_store) — these tests cover the parsing logic. # thread_store) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None)) result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None, config=SimpleNamespace()))
assert result.suggestions == ["Q1", "Q2", "Q3"] assert result.suggestions == ["Q1", "Q2", "Q3"]
fake_model.ainvoke.assert_awaited_once() fake_model.ainvoke.assert_awaited_once()
@@ -70,7 +71,7 @@ def test_generate_suggestions_parses_list_block_content(monkeypatch):
# Bypass the require_permission decorator (which needs request + # Bypass the require_permission decorator (which needs request +
# thread_store) — these tests cover the parsing logic. # thread_store) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None)) result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None, config=SimpleNamespace()))
assert result.suggestions == ["Q1", "Q2"] assert result.suggestions == ["Q1", "Q2"]
fake_model.ainvoke.assert_awaited_once() fake_model.ainvoke.assert_awaited_once()
@@ -92,7 +93,7 @@ def test_generate_suggestions_parses_output_text_block_content(monkeypatch):
# Bypass the require_permission decorator (which needs request + # Bypass the require_permission decorator (which needs request +
# thread_store) — these tests cover the parsing logic. # thread_store) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None)) result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None, config=SimpleNamespace()))
assert result.suggestions == ["Q1", "Q2"] assert result.suggestions == ["Q1", "Q2"]
fake_model.ainvoke.assert_awaited_once() fake_model.ainvoke.assert_awaited_once()
@@ -111,6 +112,6 @@ def test_generate_suggestions_returns_empty_on_model_error(monkeypatch):
# Bypass the require_permission decorator (which needs request + # Bypass the require_permission decorator (which needs request +
# thread_store) — these tests cover the parsing logic. # thread_store) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None)) result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None, config=SimpleNamespace()))
assert result.suggestions == [] assert result.suggestions == []
+21 -20
View File
@@ -2,6 +2,7 @@ import asyncio
import stat import stat
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from _router_auth_helpers import call_unwrapped from _router_auth_helpers import call_unwrapped
@@ -26,7 +27,7 @@ def test_upload_files_writes_thread_storage_and_skips_local_sandbox_sync(tmp_pat
patch.object(uploads, "get_sandbox_provider", return_value=provider), patch.object(uploads, "get_sandbox_provider", return_value=provider),
): ):
file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads")) file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file])) result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True assert result.success is True
assert len(result.files) == 1 assert len(result.files) == 1
@@ -49,7 +50,7 @@ def test_upload_files_skips_acquire_when_thread_data_is_mounted(tmp_path):
patch.object(uploads, "get_sandbox_provider", return_value=provider), patch.object(uploads, "get_sandbox_provider", return_value=provider),
): ):
file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads")) file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-mounted", request=MagicMock(), files=[file])) result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-mounted", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True assert result.success is True
assert (thread_uploads_dir / "notes.txt").read_bytes() == b"hello uploads" assert (thread_uploads_dir / "notes.txt").read_bytes() == b"hello uploads"
@@ -75,7 +76,7 @@ def test_upload_files_does_not_auto_convert_documents_by_default(tmp_path):
patch.object(uploads, "convert_file_to_markdown", AsyncMock()) as convert_mock, patch.object(uploads, "convert_file_to_markdown", AsyncMock()) as convert_mock,
): ):
file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes")) file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file])) result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True assert result.success is True
assert len(result.files) == 1 assert len(result.files) == 1
@@ -108,7 +109,7 @@ def test_upload_files_syncs_non_local_sandbox_and_marks_markdown_file(tmp_path):
patch.object(uploads, "convert_file_to_markdown", AsyncMock(side_effect=fake_convert)), patch.object(uploads, "convert_file_to_markdown", AsyncMock(side_effect=fake_convert)),
): ):
file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes")) file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file])) result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True assert result.success is True
assert len(result.files) == 1 assert len(result.files) == 1
@@ -147,7 +148,7 @@ def test_upload_files_makes_non_local_files_sandbox_writable(tmp_path):
patch.object(uploads, "_make_file_sandbox_writable") as make_writable, patch.object(uploads, "_make_file_sandbox_writable") as make_writable,
): ):
file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes")) file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file])) result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True assert result.success is True
make_writable.assert_any_call(thread_uploads_dir / "report.pdf") make_writable.assert_any_call(thread_uploads_dir / "report.pdf")
@@ -171,7 +172,7 @@ def test_upload_files_does_not_adjust_permissions_for_local_sandbox(tmp_path):
patch.object(uploads, "_make_file_sandbox_writable") as make_writable, patch.object(uploads, "_make_file_sandbox_writable") as make_writable,
): ):
file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads")) file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file])) result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True assert result.success is True
make_writable.assert_not_called() make_writable.assert_not_called()
@@ -222,13 +223,13 @@ def test_upload_files_rejects_dotdot_and_dot_filenames(tmp_path):
# These filenames must be rejected outright # These filenames must be rejected outright
for bad_name in ["..", "."]: for bad_name in ["..", "."]:
file = UploadFile(filename=bad_name, file=BytesIO(b"data")) file = UploadFile(filename=bad_name, file=BytesIO(b"data"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file])) result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True assert result.success is True
assert result.files == [], f"Expected no files for unsafe filename {bad_name!r}" assert result.files == [], f"Expected no files for unsafe filename {bad_name!r}"
# Path-traversal prefixes are stripped to the basename and accepted safely # Path-traversal prefixes are stripped to the basename and accepted safely
file = UploadFile(filename="../etc/passwd", file=BytesIO(b"data")) file = UploadFile(filename="../etc/passwd", file=BytesIO(b"data"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file])) result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True assert result.success is True
assert len(result.files) == 1 assert len(result.files) == 1
assert result.files[0]["filename"] == "passwd" assert result.files[0]["filename"] == "passwd"
@@ -252,16 +253,20 @@ def test_delete_uploaded_file_removes_generated_markdown_companion(tmp_path):
def test_auto_convert_documents_enabled_defaults_to_false_on_config_errors(): def test_auto_convert_documents_enabled_defaults_to_false_on_config_errors():
with patch.object(uploads, "get_app_config", side_effect=RuntimeError("boom")): class BrokenConfig:
assert uploads._auto_convert_documents_enabled() is False def __getattribute__(self, name):
if name == "uploads":
raise RuntimeError("boom")
return super().__getattribute__(name)
assert uploads._auto_convert_documents_enabled(BrokenConfig()) is False
def test_auto_convert_documents_enabled_reads_dict_backed_uploads_config(): def test_auto_convert_documents_enabled_reads_dict_backed_uploads_config():
cfg = MagicMock() cfg = MagicMock()
cfg.uploads = {"auto_convert_documents": True} cfg.uploads = {"auto_convert_documents": True}
with patch.object(uploads, "get_app_config", return_value=cfg): assert uploads._auto_convert_documents_enabled(cfg) is True
assert uploads._auto_convert_documents_enabled() is True
def test_auto_convert_documents_enabled_accepts_boolean_and_string_truthy_values(): def test_auto_convert_documents_enabled_accepts_boolean_and_string_truthy_values():
@@ -277,11 +282,7 @@ def test_auto_convert_documents_enabled_accepts_boolean_and_string_truthy_values
string_false_cfg = MagicMock() string_false_cfg = MagicMock()
string_false_cfg.uploads = MagicMock(auto_convert_documents="false") string_false_cfg.uploads = MagicMock(auto_convert_documents="false")
with patch.object(uploads, "get_app_config", return_value=false_cfg): assert uploads._auto_convert_documents_enabled(false_cfg) is False
assert uploads._auto_convert_documents_enabled() is False assert uploads._auto_convert_documents_enabled(true_cfg) is True
with patch.object(uploads, "get_app_config", return_value=true_cfg): assert uploads._auto_convert_documents_enabled(string_true_cfg) is True
assert uploads._auto_convert_documents_enabled() is True assert uploads._auto_convert_documents_enabled(string_false_cfg) is False
with patch.object(uploads, "get_app_config", return_value=string_true_cfg):
assert uploads._auto_convert_documents_enabled() is True
with patch.object(uploads, "get_app_config", return_value=string_false_cfg):
assert uploads._auto_convert_documents_enabled() is False