mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 16:35:59 +00:00
refactor: thread app_config through lead and subagent task path (#2666)
* refactor: thread app config through lead prompt * fix: honor explicit app config across runtime paths * style: format subagent executor tests * fix: thread resolved app config and guard subagents-only fallback Address two PR review findings: 1. _create_summarization_middleware passed the original (possibly None) app_config into create_chat_model, forcing the model factory back to ambient get_app_config() and risking config drift between the middleware's resolved view and the model's view. Pass the resolved AppConfig instance through end-to-end. 2. get_available_subagent_names accepted Any-typed config and forwarded it to is_host_bash_allowed, which reads ``.sandbox``. A SubagentsAppConfig (also accepted upstream as a sum-type input) has no ``.sandbox`` attribute and would be silently treated as "no sandbox configured", incorrectly disabling the bash subagent. Guard on hasattr and fall back to ambient lookup otherwise. Adds regression tests for both paths. * chore: simplify hasattr guard and tighten regression tests - Collapse if/else into ternary in get_available_subagent_names; hasattr(None, ...) is False so the explicit None check was redundant. - Drop comments that narrate the change rather than explain non-obvious WHY (test names already convey intent). - Replace stringly-typed sentinel "no-arg" in regression test with direct args tuple comparison. --------- Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
This commit is contained in:
@@ -19,8 +19,6 @@ from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddlewar
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.config.agents_config import load_agent_config, validate_agent_name
|
||||
from deerflow.config.app_config import AppConfig, get_app_config
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.config.summarization_config import get_summarization_config
|
||||
from deerflow.models import create_chat_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -52,7 +50,8 @@ def _resolve_model_name(requested_model_name: str | None = None, *, app_config:
|
||||
|
||||
def _create_summarization_middleware(*, app_config: AppConfig | None = None) -> DeerFlowSummarizationMiddleware | None:
|
||||
"""Create and configure the summarization middleware from config."""
|
||||
config = get_summarization_config()
|
||||
resolved_app_config = app_config or get_app_config()
|
||||
config = resolved_app_config.summarization
|
||||
|
||||
if not config.enabled:
|
||||
return None
|
||||
@@ -73,9 +72,9 @@ def _create_summarization_middleware(*, app_config: AppConfig | None = None) ->
|
||||
# as middleware rather than lead_agent (SummarizationMiddleware is a
|
||||
# LangChain built-in, so we tag the model at creation time).
|
||||
if config.model_name:
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False, app_config=app_config)
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False, app_config=resolved_app_config)
|
||||
else:
|
||||
model = create_chat_model(thinking_enabled=False, app_config=app_config)
|
||||
model = create_chat_model(thinking_enabled=False, app_config=resolved_app_config)
|
||||
model = model.with_config(tags=["middleware:summarize"])
|
||||
|
||||
# Prepare kwargs
|
||||
@@ -92,18 +91,13 @@ def _create_summarization_middleware(*, app_config: AppConfig | None = None) ->
|
||||
kwargs["summary_prompt"] = config.summary_prompt
|
||||
|
||||
hooks: list[BeforeSummarizationHook] = []
|
||||
if get_memory_config().enabled:
|
||||
if resolved_app_config.memory.enabled:
|
||||
hooks.append(memory_flush_hook)
|
||||
|
||||
# The logic below relies on two assumptions holding true: this factory is
|
||||
# the sole entry point for DeerFlowSummarizationMiddleware, and the runtime
|
||||
# config is not expected to change after startup.
|
||||
try:
|
||||
resolved_app_config = app_config or get_app_config()
|
||||
skills_container_path = resolved_app_config.skills.container_path or "/mnt/skills"
|
||||
except Exception:
|
||||
logger.exception("Failed to resolve skills container path; falling back to default")
|
||||
skills_container_path = "/mnt/skills"
|
||||
skills_container_path = resolved_app_config.skills.container_path or "/mnt/skills"
|
||||
|
||||
return DeerFlowSummarizationMiddleware(
|
||||
**kwargs,
|
||||
@@ -279,10 +273,10 @@ def _build_middlewares(
|
||||
middlewares.append(TokenUsageMiddleware())
|
||||
|
||||
# Add TitleMiddleware
|
||||
middlewares.append(TitleMiddleware())
|
||||
middlewares.append(TitleMiddleware(app_config=resolved_app_config))
|
||||
|
||||
# Add MemoryMiddleware (after TitleMiddleware)
|
||||
middlewares.append(MemoryMiddleware(agent_name=agent_name))
|
||||
middlewares.append(MemoryMiddleware(agent_name=agent_name, memory_config=resolved_app_config.memory))
|
||||
|
||||
# Add ViewImageMiddleware only if the current model supports vision.
|
||||
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
|
||||
@@ -316,7 +310,9 @@ def _build_middlewares(
|
||||
|
||||
def make_lead_agent(config: RunnableConfig):
|
||||
"""LangGraph graph factory; keep the signature compatible with LangGraph Server."""
|
||||
return _make_lead_agent(config, app_config=get_app_config())
|
||||
runtime_config = _get_runtime_config(config)
|
||||
runtime_app_config = runtime_config.get("app_config")
|
||||
return _make_lead_agent(config, app_config=runtime_app_config or get_app_config())
|
||||
|
||||
|
||||
def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
|
||||
|
||||
@@ -158,7 +158,7 @@ Skip simple one-off tasks.
|
||||
"""
|
||||
|
||||
|
||||
def _build_available_subagents_description(available_names: list[str], bash_available: bool) -> str:
|
||||
def _build_available_subagents_description(available_names: list[str], bash_available: bool, *, app_config: AppConfig | None = None) -> str:
|
||||
"""Dynamically build subagent type descriptions from registry.
|
||||
|
||||
Mirrors Codex's pattern where agent_type_description is dynamically generated
|
||||
@@ -180,7 +180,7 @@ def _build_available_subagents_description(available_names: list[str], bash_avai
|
||||
if name in builtin_descriptions:
|
||||
lines.append(f"- **{name}**: {builtin_descriptions[name]}")
|
||||
else:
|
||||
config = get_subagent_config(name)
|
||||
config = get_subagent_config(name, app_config=app_config)
|
||||
if config is not None:
|
||||
desc = config.description.split("\n")[0].strip() # First line only for brevity
|
||||
lines.append(f"- **{name}**: {desc}")
|
||||
@@ -188,7 +188,7 @@ def _build_available_subagents_description(available_names: list[str], bash_avai
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _build_subagent_section(max_concurrent: int) -> str:
|
||||
def _build_subagent_section(max_concurrent: int, *, app_config: AppConfig | None = None) -> str:
|
||||
"""Build the subagent system prompt section with dynamic concurrency limit.
|
||||
|
||||
Args:
|
||||
@@ -198,12 +198,12 @@ def _build_subagent_section(max_concurrent: int) -> str:
|
||||
Formatted subagent section string.
|
||||
"""
|
||||
n = max_concurrent
|
||||
available_names = get_available_subagent_names()
|
||||
available_names = get_available_subagent_names(app_config=app_config) if app_config is not None else get_available_subagent_names()
|
||||
bash_available = "bash" in available_names
|
||||
|
||||
# Dynamically build subagent type descriptions from registry (aligned with Codex's
|
||||
# agent_type_description pattern where all registered roles are listed in the tool spec).
|
||||
available_subagents = _build_available_subagents_description(available_names, bash_available)
|
||||
available_subagents = _build_available_subagents_description(available_names, bash_available, app_config=app_config)
|
||||
direct_tool_examples = "bash, ls, read_file, web_search, etc." if bash_available else "ls, read_file, web_search, etc."
|
||||
direct_execution_example = (
|
||||
'# User asks: "Run the tests"\n# Thinking: Cannot decompose into parallel sub-tasks\n# → Execute directly\n\nbash("npm test") # Direct execution, not task()'
|
||||
@@ -530,21 +530,28 @@ combined with a FastAPI gateway for REST API access [citation:FastAPI](https://f
|
||||
"""
|
||||
|
||||
|
||||
def _get_memory_context(agent_name: str | None = None) -> str:
|
||||
def _get_memory_context(agent_name: str | None = None, *, app_config: AppConfig | None = None) -> str:
|
||||
"""Get memory context for injection into system prompt.
|
||||
|
||||
Args:
|
||||
agent_name: If provided, loads per-agent memory. If None, loads global memory.
|
||||
app_config: Explicit application config. When provided, memory options
|
||||
are read from this value instead of the global config singleton.
|
||||
|
||||
Returns:
|
||||
Formatted memory context string wrapped in XML tags, or empty string if disabled.
|
||||
"""
|
||||
try:
|
||||
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
config = get_memory_config()
|
||||
if app_config is None:
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
|
||||
config = get_memory_config()
|
||||
else:
|
||||
config = app_config.memory
|
||||
|
||||
if not config.enabled or not config.injection_enabled:
|
||||
return ""
|
||||
|
||||
@@ -558,8 +565,8 @@ def _get_memory_context(agent_name: str | None = None) -> str:
|
||||
{memory_content}
|
||||
</memory>
|
||||
"""
|
||||
except Exception as e:
|
||||
logger.error("Failed to load memory context: %s", e)
|
||||
except Exception:
|
||||
logger.exception("Failed to load memory context")
|
||||
return ""
|
||||
|
||||
|
||||
@@ -599,15 +606,20 @@ def get_skills_prompt_section(available_skills: set[str] | None = None, *, app_c
|
||||
"""Generate the skills prompt section with available skills list."""
|
||||
skills = _get_enabled_skills_for_config(app_config)
|
||||
|
||||
try:
|
||||
from deerflow.config import get_app_config
|
||||
if app_config is None:
|
||||
try:
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
config = app_config or get_app_config()
|
||||
config = get_app_config()
|
||||
container_base_path = config.skills.container_path
|
||||
skill_evolution_enabled = config.skill_evolution.enabled
|
||||
except Exception:
|
||||
container_base_path = "/mnt/skills"
|
||||
skill_evolution_enabled = False
|
||||
else:
|
||||
config = app_config
|
||||
container_base_path = config.skills.container_path
|
||||
skill_evolution_enabled = config.skill_evolution.enabled
|
||||
except Exception:
|
||||
container_base_path = "/mnt/skills"
|
||||
skill_evolution_enabled = False
|
||||
|
||||
if not skills and not skill_evolution_enabled:
|
||||
return ""
|
||||
@@ -640,13 +652,17 @@ def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) ->
|
||||
"""
|
||||
from deerflow.tools.builtins.tool_search import get_deferred_registry
|
||||
|
||||
try:
|
||||
from deerflow.config import get_app_config
|
||||
if app_config is None:
|
||||
try:
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
config = app_config or get_app_config()
|
||||
if not config.tool_search.enabled:
|
||||
config = get_app_config()
|
||||
except Exception:
|
||||
return ""
|
||||
except Exception:
|
||||
else:
|
||||
config = app_config
|
||||
|
||||
if not config.tool_search.enabled:
|
||||
return ""
|
||||
|
||||
registry = get_deferred_registry()
|
||||
@@ -657,15 +673,19 @@ def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) ->
|
||||
return f"<available-deferred-tools>\n{names}\n</available-deferred-tools>"
|
||||
|
||||
|
||||
def _build_acp_section() -> str:
|
||||
def _build_acp_section(*, app_config: AppConfig | None = None) -> str:
|
||||
"""Build the ACP agent prompt section, only if ACP agents are configured."""
|
||||
try:
|
||||
from deerflow.config.acp_config import get_acp_agents
|
||||
if app_config is None:
|
||||
try:
|
||||
from deerflow.config.acp_config import get_acp_agents
|
||||
|
||||
agents = get_acp_agents()
|
||||
if not agents:
|
||||
agents = get_acp_agents()
|
||||
except Exception:
|
||||
return ""
|
||||
except Exception:
|
||||
else:
|
||||
agents = getattr(app_config, "acp_agents", {}) or {}
|
||||
|
||||
if not agents:
|
||||
return ""
|
||||
|
||||
return (
|
||||
@@ -679,14 +699,18 @@ def _build_acp_section() -> str:
|
||||
|
||||
def _build_custom_mounts_section(*, app_config: AppConfig | None = None) -> str:
|
||||
"""Build a prompt section for explicitly configured sandbox mounts."""
|
||||
try:
|
||||
from deerflow.config import get_app_config
|
||||
if app_config is None:
|
||||
try:
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
config = app_config or get_app_config()
|
||||
mounts = config.sandbox.mounts or []
|
||||
except Exception:
|
||||
logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
|
||||
return ""
|
||||
config = get_app_config()
|
||||
except Exception:
|
||||
logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
|
||||
return ""
|
||||
else:
|
||||
config = app_config
|
||||
|
||||
mounts = config.sandbox.mounts or []
|
||||
|
||||
if not mounts:
|
||||
return ""
|
||||
@@ -709,11 +733,11 @@ def apply_prompt_template(
|
||||
app_config: AppConfig | None = None,
|
||||
) -> str:
|
||||
# Get memory context
|
||||
memory_context = _get_memory_context(agent_name)
|
||||
memory_context = _get_memory_context(agent_name, app_config=app_config)
|
||||
|
||||
# Include subagent section only if enabled (from runtime parameter)
|
||||
n = max_concurrent_subagents
|
||||
subagent_section = _build_subagent_section(n) if subagent_enabled else ""
|
||||
subagent_section = _build_subagent_section(n, app_config=app_config) if subagent_enabled else ""
|
||||
|
||||
# Add subagent reminder to critical_reminders if enabled
|
||||
subagent_reminder = (
|
||||
@@ -740,7 +764,7 @@ def apply_prompt_template(
|
||||
deferred_tools_section = get_deferred_tools_prompt_section(app_config=app_config)
|
||||
|
||||
# Build ACP agent section only if ACP agents are configured
|
||||
acp_section = _build_acp_section()
|
||||
acp_section = _build_acp_section(app_config=app_config)
|
||||
custom_mounts_section = _build_custom_mounts_section(app_config=app_config)
|
||||
acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Middleware for memory mechanism."""
|
||||
|
||||
import logging
|
||||
from typing import override
|
||||
from typing import TYPE_CHECKING, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
@@ -13,6 +13,9 @@ from deerflow.agents.memory.queue import get_memory_queue
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -34,14 +37,17 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
|
||||
state_schema = MemoryMiddlewareState
|
||||
|
||||
def __init__(self, agent_name: str | None = None):
|
||||
def __init__(self, agent_name: str | None = None, *, memory_config: "MemoryConfig | None" = None):
|
||||
"""Initialize the MemoryMiddleware.
|
||||
|
||||
Args:
|
||||
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
||||
memory_config: Explicit memory config. When omitted, legacy global
|
||||
config fallback is used.
|
||||
"""
|
||||
super().__init__()
|
||||
self._agent_name = agent_name
|
||||
self._memory_config = memory_config
|
||||
|
||||
@override
|
||||
def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
@@ -54,7 +60,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
Returns:
|
||||
None (no state changes needed from this middleware).
|
||||
"""
|
||||
config = get_memory_config()
|
||||
config = self._memory_config or get_memory_config()
|
||||
if not config.enabled:
|
||||
return None
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, NotRequired, override
|
||||
from typing import TYPE_CHECKING, Any, NotRequired, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
@@ -12,6 +12,10 @@ from langgraph.runtime import Runtime
|
||||
from deerflow.config.title_config import get_title_config
|
||||
from deerflow.models import create_chat_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.title_config import TitleConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -26,6 +30,18 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
|
||||
state_schema = TitleMiddlewareState
|
||||
|
||||
def __init__(self, *, app_config: "AppConfig | None" = None, title_config: "TitleConfig | None" = None):
|
||||
super().__init__()
|
||||
self._app_config = app_config
|
||||
self._title_config = title_config
|
||||
|
||||
def _get_title_config(self):
|
||||
if self._title_config is not None:
|
||||
return self._title_config
|
||||
if self._app_config is not None:
|
||||
return self._app_config.title
|
||||
return get_title_config()
|
||||
|
||||
def _normalize_content(self, content: object) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
@@ -47,7 +63,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
|
||||
def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
|
||||
"""Check if we should generate a title for this thread."""
|
||||
config = get_title_config()
|
||||
config = self._get_title_config()
|
||||
if not config.enabled:
|
||||
return False
|
||||
|
||||
@@ -72,7 +88,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
|
||||
Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
|
||||
"""
|
||||
config = get_title_config()
|
||||
config = self._get_title_config()
|
||||
messages = state.get("messages", [])
|
||||
|
||||
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
|
||||
@@ -94,14 +110,14 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
|
||||
def _parse_title(self, content: object) -> str:
|
||||
"""Normalize model output into a clean title string."""
|
||||
config = get_title_config()
|
||||
config = self._get_title_config()
|
||||
title_content = self._normalize_content(content)
|
||||
title_content = self._strip_think_tags(title_content)
|
||||
title = title_content.strip().strip('"').strip("'")
|
||||
return title[: config.max_chars] if len(title) > config.max_chars else title
|
||||
|
||||
def _fallback_title(self, user_msg: str) -> str:
|
||||
config = get_title_config()
|
||||
config = self._get_title_config()
|
||||
fallback_chars = min(config.max_chars, 50)
|
||||
if len(user_msg) > fallback_chars:
|
||||
return user_msg[:fallback_chars].rstrip() + "..."
|
||||
@@ -135,14 +151,17 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
if not self._should_generate_title(state):
|
||||
return None
|
||||
|
||||
config = get_title_config()
|
||||
config = self._get_title_config()
|
||||
prompt, user_msg = self._build_title_prompt(state)
|
||||
|
||||
try:
|
||||
model_kwargs = {"thinking_enabled": False}
|
||||
if self._app_config is not None:
|
||||
model_kwargs["app_config"] = self._app_config
|
||||
if config.model_name:
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False)
|
||||
model = create_chat_model(name=config.model_name, **model_kwargs)
|
||||
else:
|
||||
model = create_chat_model(thinking_enabled=False)
|
||||
model = create_chat_model(**model_kwargs)
|
||||
response = await model.ainvoke(prompt, config=self._get_runnable_config())
|
||||
title = self._parse_title(response.content)
|
||||
if title:
|
||||
|
||||
Reference in New Issue
Block a user