mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 16:35:59 +00:00
8ba01dfd83
* refactor: thread app config through lead prompt * fix: honor explicit app config across runtime paths * style: format subagent executor tests * fix: thread resolved app config and guard subagents-only fallback Address two PR review findings: 1. _create_summarization_middleware passed the original (possibly None) app_config into create_chat_model, forcing the model factory back to ambient get_app_config() and risking config drift between the middleware's resolved view and the model's view. Pass the resolved AppConfig instance through end-to-end. 2. get_available_subagent_names accepted Any-typed config and forwarded it to is_host_bash_allowed, which reads ``.sandbox``. A SubagentsAppConfig (also accepted upstream as a sum-type input) has no ``.sandbox`` attribute and would be silently treated as "no sandbox configured", incorrectly disabling the bash subagent. Guard on hasattr and fall back to ambient lookup otherwise. Adds regression tests for both paths. * chore: simplify hasattr guard and tighten regression tests - Collapse if/else into ternary in get_available_subagent_names; hasattr(None, ...) is False so the explicit None check was redundant. - Drop comments that narrate the change rather than explain non-obvious WHY (test names already convey intent). - Replace stringly-typed sentinel "no-arg" in regression test with direct args tuple comparison. --------- Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
180 lines
6.9 KiB
Python
180 lines
6.9 KiB
Python
"""Middleware for automatic thread title generation."""
|
|
|
|
import logging
|
|
import re
|
|
from typing import TYPE_CHECKING, Any, NotRequired, override
|
|
|
|
from langchain.agents import AgentState
|
|
from langchain.agents.middleware import AgentMiddleware
|
|
from langgraph.config import get_config
|
|
from langgraph.runtime import Runtime
|
|
|
|
from deerflow.config.title_config import get_title_config
|
|
from deerflow.models import create_chat_model
|
|
|
|
if TYPE_CHECKING:
|
|
from deerflow.config.app_config import AppConfig
|
|
from deerflow.config.title_config import TitleConfig
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TitleMiddlewareState(AgentState):
    """Agent state extended with an optional thread title.

    Kept compatible with the `ThreadState` schema.
    """

    # Thread title; the key may be missing entirely, and may hold None.
    title: NotRequired[str | None]
|
|
|
|
|
|
class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
    """Automatically generate a title for the thread after the first user message.

    The synchronous hook (``after_model``) only derives a local fallback title
    from the first user message. The asynchronous hook (``aafter_model``) asks
    a chat model for a title and falls back to the local title on failure.
    """

    state_schema = TitleMiddlewareState

    def __init__(self, *, app_config: "AppConfig | None" = None, title_config: "TitleConfig | None" = None):
        """Create the middleware.

        Args:
            app_config: Optional explicit application config. When given, its
                ``title`` section is used (unless ``title_config`` is also
                given) and it is forwarded to ``create_chat_model``.
            title_config: Optional explicit title config; takes precedence
                over ``app_config.title``.
        """
        super().__init__()
        self._app_config = app_config
        self._title_config = title_config

    def _get_title_config(self) -> "TitleConfig":
        """Resolve the title config: explicit, then app config, then ambient."""
        if self._title_config is not None:
            return self._title_config
        if self._app_config is not None:
            return self._app_config.title
        return get_title_config()

    def _normalize_content(self, content: object) -> str:
        """Flatten message content into plain text.

        Strings are returned unchanged. Lists are normalized item by item and
        the non-empty parts are joined with newlines. Dicts yield their
        ``"text"`` value when it is a string, otherwise their ``"content"``
        value is normalized recursively. Anything else becomes ``""``.
        """
        if isinstance(content, str):
            return content

        if isinstance(content, list):
            parts = [self._normalize_content(item) for item in content]
            return "\n".join(part for part in parts if part)

        if isinstance(content, dict):
            text_value = content.get("text")
            if isinstance(text_value, str):
                return text_value

            nested_content = content.get("content")
            if nested_content is not None:
                return self._normalize_content(nested_content)

        return ""

    def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
        """Check if we should generate a title for this thread."""
        config = self._get_title_config()
        if not config.enabled:
            return False

        # Check if thread already has a title in state
        if state.get("title"):
            return False

        # Check if this is the first turn (has at least one user message and one assistant response)
        messages = state.get("messages", [])
        if len(messages) < 2:
            return False

        # Count user and assistant messages
        user_count = sum(1 for m in messages if m.type == "human")
        has_assistant = any(m.type == "ai" for m in messages)

        # Generate title after first complete exchange
        return user_count == 1 and has_assistant

    def _first_user_message(self, state: TitleMiddlewareState) -> str:
        """Return the normalized text of the first human message, or ``""``."""
        messages = state.get("messages", [])
        user_msg_content = next((m.content for m in messages if m.type == "human"), "")
        return self._normalize_content(user_msg_content)

    def _build_title_prompt(self, state: TitleMiddlewareState) -> tuple[str, str]:
        """Extract user/assistant messages and build the title prompt.

        Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
        """
        config = self._get_title_config()
        messages = state.get("messages", [])

        user_msg = self._first_user_message(state)
        assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")
        assistant_msg = self._strip_think_tags(self._normalize_content(assistant_msg_content))

        # Only the first 500 characters of each side are sent to the model.
        prompt = config.prompt_template.format(
            max_words=config.max_words,
            user_msg=user_msg[:500],
            assistant_msg=assistant_msg[:500],
        )
        return prompt, user_msg

    def _strip_think_tags(self, text: str) -> str:
        """Remove <think>...</think> blocks emitted by reasoning models (e.g. minimax, DeepSeek-R1)."""
        return re.sub(r"<think>[\s\S]*?</think>", "", text, flags=re.IGNORECASE).strip()

    def _parse_title(self, content: object) -> str:
        """Normalize model output into a clean title string.

        Think blocks, surrounding whitespace and surrounding quotes are
        removed, then the result is cut to ``max_chars`` characters.
        """
        config = self._get_title_config()
        title_content = self._strip_think_tags(self._normalize_content(content))
        title = title_content.strip().strip('"').strip("'")
        # Slicing is already a no-op for strings shorter than the limit.
        return title[: config.max_chars]

    def _fallback_title(self, user_msg: str) -> str:
        """Derive a title from the user's message without calling a model.

        Short messages are returned unchanged and an empty message yields
        ``"New Conversation"``. Longer messages are truncated and suffixed
        with ``"..."``; the truncated result never exceeds ``max_chars``.
        """
        if not user_msg:
            return "New Conversation"

        config = self._get_title_config()
        fallback_chars = min(config.max_chars, 50)
        if len(user_msg) <= fallback_chars:
            return user_msg

        ellipsis = "..."
        # Leave room for the ellipsis so the final title honors max_chars.
        prefix_len = min(fallback_chars, config.max_chars - len(ellipsis))
        if prefix_len <= 0:
            # Limit too small to fit any text plus an ellipsis: plain cut.
            return user_msg[: max(config.max_chars, 0)]
        return user_msg[:prefix_len].rstrip() + ellipsis

    def _get_runnable_config(self) -> dict[str, Any]:
        """Inherit the parent RunnableConfig and add middleware tag.

        This ensures RunJournal identifies LLM calls from this middleware
        as ``middleware:title`` instead of ``lead_agent``.
        """
        try:
            parent = get_config()
        except Exception:
            # Best effort: with no parent config available, start from empty.
            parent = {}
        config = dict(parent)
        config["run_name"] = "title_agent"
        config["tags"] = [*(config.get("tags") or []), "middleware:title"]
        return config

    def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
        """Generate a local fallback title without blocking on an LLM call.

        Returns ``None`` when no title should be generated, otherwise a state
        update of the form ``{"title": ...}``.
        """
        if not self._should_generate_title(state):
            return None

        # Only the user message is needed here, so the LLM prompt is not built.
        return {"title": self._fallback_title(self._first_user_message(state))}

    async def _agenerate_title_result(self, state: TitleMiddlewareState) -> dict | None:
        """Generate a title asynchronously and fall back locally on failure.

        Returns ``None`` when no title should be generated, otherwise a state
        update of the form ``{"title": ...}``.
        """
        if not self._should_generate_title(state):
            return None

        config = self._get_title_config()
        user_msg = self._first_user_message(state)

        try:
            # Built inside the guard so a template formatting error also
            # degrades to the local fallback title.
            prompt, _ = self._build_title_prompt(state)

            model_kwargs: dict[str, Any] = {"thinking_enabled": False}
            if self._app_config is not None:
                model_kwargs["app_config"] = self._app_config
            if config.model_name:
                model_kwargs["name"] = config.model_name
            model = create_chat_model(**model_kwargs)

            response = await model.ainvoke(prompt, config=self._get_runnable_config())
            title = self._parse_title(response.content)
            if title:
                return {"title": title}
        except Exception:
            logger.debug("Failed to generate async title; falling back to local title", exc_info=True)

        # Reached on any failure above or when the model produced an empty title.
        return {"title": self._fallback_title(user_msg)}

    @override
    def after_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
        """Sync hook: set a local fallback title when one is due."""
        return self._generate_title_result(state)

    @override
    async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
        """Async hook: set a model-generated (or fallback) title when one is due."""
        return await self._agenerate_title_result(state)
|