mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 15:36:48 +00:00
fix title generation with dynamic context reminder (#2830)
This commit is contained in:
@@ -53,6 +53,11 @@ def _extract_date(content: str) -> str | None:
|
|||||||
return m.group(1) if m else None
|
return m.group(1) if m else None
|
||||||
|
|
||||||
|
|
||||||
|
def is_dynamic_context_reminder(message: object) -> bool:
|
||||||
|
"""Return whether *message* is a hidden dynamic-context reminder."""
|
||||||
|
return isinstance(message, HumanMessage) and bool(message.additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY))
|
||||||
|
|
||||||
|
|
||||||
def _last_injected_date(messages: list) -> str | None:
|
def _last_injected_date(messages: list) -> str | None:
|
||||||
"""Scan messages in reverse and return the most recently injected date.
|
"""Scan messages in reverse and return the most recently injected date.
|
||||||
|
|
||||||
@@ -61,7 +66,7 @@ def _last_injected_date(messages: list) -> str | None:
|
|||||||
are not mistakenly treated as injected reminders.
|
are not mistakenly treated as injected reminders.
|
||||||
"""
|
"""
|
||||||
for msg in reversed(messages):
|
for msg in reversed(messages):
|
||||||
if isinstance(msg, HumanMessage) and msg.additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY):
|
if is_dynamic_context_reminder(msg):
|
||||||
content_str = msg.content if isinstance(msg.content, str) else str(msg.content)
|
content_str = msg.content if isinstance(msg.content, str) else str(msg.content)
|
||||||
return _extract_date(content_str)
|
return _extract_date(content_str)
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from langchain.agents.middleware import AgentMiddleware
|
|||||||
from langgraph.config import get_config
|
from langgraph.config import get_config
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
from deerflow.agents.middlewares.dynamic_context_middleware import is_dynamic_context_reminder
|
||||||
from deerflow.config.title_config import get_title_config
|
from deerflow.config.title_config import get_title_config
|
||||||
from deerflow.models import create_chat_model
|
from deerflow.models import create_chat_model
|
||||||
|
|
||||||
@@ -61,6 +62,10 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
|||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_user_message_for_title(message: object) -> bool:
|
||||||
|
return getattr(message, "type", None) == "human" and not is_dynamic_context_reminder(message)
|
||||||
|
|
||||||
def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
|
def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
|
||||||
"""Check if we should generate a title for this thread."""
|
"""Check if we should generate a title for this thread."""
|
||||||
config = self._get_title_config()
|
config = self._get_title_config()
|
||||||
@@ -77,7 +82,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Count user and assistant messages
|
# Count user and assistant messages
|
||||||
user_messages = [m for m in messages if m.type == "human"]
|
user_messages = [m for m in messages if self._is_user_message_for_title(m)]
|
||||||
assistant_messages = [m for m in messages if m.type == "ai"]
|
assistant_messages = [m for m in messages if m.type == "ai"]
|
||||||
|
|
||||||
# Generate title after first complete exchange
|
# Generate title after first complete exchange
|
||||||
@@ -91,7 +96,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
|||||||
config = self._get_title_config()
|
config = self._get_title_config()
|
||||||
messages = state.get("messages", [])
|
messages = state.get("messages", [])
|
||||||
|
|
||||||
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
|
user_msg_content = next((m.content for m in messages if self._is_user_message_for_title(m)), "")
|
||||||
assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")
|
assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")
|
||||||
|
|
||||||
user_msg = self._normalize_content(user_msg_content)
|
user_msg = self._normalize_content(user_msg_content)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
from deerflow.agents.middlewares import title_middleware as title_middleware_module
|
from deerflow.agents.middlewares import title_middleware as title_middleware_module
|
||||||
|
from deerflow.agents.middlewares.dynamic_context_middleware import _DYNAMIC_CONTEXT_REMINDER_KEY
|
||||||
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
||||||
from deerflow.config.title_config import TitleConfig, get_title_config, set_title_config
|
from deerflow.config.title_config import TitleConfig, get_title_config, set_title_config
|
||||||
|
|
||||||
@@ -44,6 +45,22 @@ class TestTitleMiddlewareCoreLogic:
|
|||||||
|
|
||||||
assert middleware._should_generate_title(state) is True
|
assert middleware._should_generate_title(state) is True
|
||||||
|
|
||||||
|
def test_should_generate_title_with_dynamic_context_reminder(self):
|
||||||
|
_set_test_title_config(enabled=True)
|
||||||
|
middleware = TitleMiddleware()
|
||||||
|
state = {
|
||||||
|
"messages": [
|
||||||
|
HumanMessage(
|
||||||
|
content="<system-reminder>\n<memory>User prefers Python.</memory>\n</system-reminder>",
|
||||||
|
additional_kwargs={_DYNAMIC_CONTEXT_REMINDER_KEY: True},
|
||||||
|
),
|
||||||
|
HumanMessage(content="帮我总结这段代码"),
|
||||||
|
AIMessage(content="好的,我先看结构"),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
assert middleware._should_generate_title(state) is True
|
||||||
|
|
||||||
def test_should_not_generate_title_when_disabled_or_already_set(self):
|
def test_should_not_generate_title_when_disabled_or_already_set(self):
|
||||||
middleware = TitleMiddleware()
|
middleware = TitleMiddleware()
|
||||||
|
|
||||||
@@ -243,6 +260,25 @@ class TestTitleMiddlewareCoreLogic:
|
|||||||
prompt, _ = middleware._build_title_prompt(state)
|
prompt, _ = middleware._build_title_prompt(state)
|
||||||
assert "<think>" not in prompt
|
assert "<think>" not in prompt
|
||||||
|
|
||||||
|
def test_build_title_prompt_uses_real_user_message_with_dynamic_context_reminder(self):
|
||||||
|
_set_test_title_config(enabled=True)
|
||||||
|
middleware = TitleMiddleware()
|
||||||
|
state = {
|
||||||
|
"messages": [
|
||||||
|
HumanMessage(
|
||||||
|
content="<system-reminder>\n<memory>User prefers Python.</memory>\n</system-reminder>",
|
||||||
|
additional_kwargs={_DYNAMIC_CONTEXT_REMINDER_KEY: True},
|
||||||
|
),
|
||||||
|
HumanMessage(content="请帮我写测试"),
|
||||||
|
AIMessage(content="好的"),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt, user_msg = middleware._build_title_prompt(state)
|
||||||
|
assert user_msg == "请帮我写测试"
|
||||||
|
assert "<system-reminder>" not in prompt
|
||||||
|
assert "User prefers Python" not in prompt
|
||||||
|
|
||||||
def test_generate_title_async_strips_think_tags_in_response(self, monkeypatch):
|
def test_generate_title_async_strips_think_tags_in_response(self, monkeypatch):
|
||||||
"""Async title generation strips <think> blocks from the model response."""
|
"""Async title generation strips <think> blocks from the model response."""
|
||||||
_set_test_title_config(max_chars=50)
|
_set_test_title_config(max_chars=50)
|
||||||
|
|||||||
Reference in New Issue
Block a user