mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 17:35:57 +00:00
fix(channels): preserve Feishu clarification thread continuity (#3285)
* fix(channels): preserve Feishu clarification thread continuity * fix(channels): address Feishu clarification review feedback --------- Co-authored-by: zzp1221 <zzp1221@users.noreply.github.com> Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -12,7 +12,14 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
from app.channels.message_bus import (
|
||||
PENDING_CLARIFICATION_METADATA_KEY,
|
||||
InboundMessage,
|
||||
InboundMessageType,
|
||||
MessageBus,
|
||||
OutboundMessage,
|
||||
ResolvedAttachment,
|
||||
)
|
||||
from app.channels.store import ChannelStore
|
||||
|
||||
|
||||
@@ -392,6 +399,47 @@ class TestExtractResponseText:
|
||||
assert _extract_response_text(result) == "Here is the plan."
|
||||
|
||||
|
||||
class TestClarificationDetection:
|
||||
def test_final_clarification_tool_message_is_pending(self):
|
||||
from app.channels.manager import _has_current_turn_clarification
|
||||
|
||||
result = {
|
||||
"messages": [
|
||||
{"type": "human", "content": "deploy"},
|
||||
{"type": "ai", "content": "", "tool_calls": [{"name": "ask_clarification", "args": {}}]},
|
||||
{"type": "tool", "name": "ask_clarification", "content": "Which environment?"},
|
||||
]
|
||||
}
|
||||
assert _has_current_turn_clarification(result) is True
|
||||
|
||||
def test_clarification_followed_by_regular_ai_is_not_pending(self):
|
||||
from app.channels.manager import _has_current_turn_clarification
|
||||
|
||||
result = {
|
||||
"messages": [
|
||||
{"type": "human", "content": "deploy"},
|
||||
{"type": "ai", "content": "", "tool_calls": [{"name": "ask_clarification", "args": {}}]},
|
||||
{"type": "tool", "name": "ask_clarification", "content": "Which environment?"},
|
||||
{"type": "ai", "content": "I will continue without pending clarification."},
|
||||
]
|
||||
}
|
||||
assert _has_current_turn_clarification(result) is False
|
||||
|
||||
def test_previous_turn_clarification_does_not_mark_current_turn(self):
|
||||
from app.channels.manager import _has_current_turn_clarification
|
||||
|
||||
result = {
|
||||
"messages": [
|
||||
{"type": "human", "content": "deploy"},
|
||||
{"type": "ai", "content": "", "tool_calls": [{"name": "ask_clarification", "args": {}}]},
|
||||
{"type": "tool", "name": "ask_clarification", "content": "Which environment?"},
|
||||
{"type": "human", "content": "prod"},
|
||||
{"type": "ai", "content": "Deploying to prod."},
|
||||
]
|
||||
}
|
||||
assert _has_current_turn_clarification(result) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChannelManager tests
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -637,6 +685,74 @@ class TestChannelManager:
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_handle_chat_marks_clarification_outbound_metadata(self):
|
||||
from app.channels.manager import ChannelManager
|
||||
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||
manager = ChannelManager(bus=bus, store=store)
|
||||
outbound_received: list[OutboundMessage] = []
|
||||
|
||||
async def capture_outbound(msg: OutboundMessage) -> None:
|
||||
outbound_received.append(msg)
|
||||
|
||||
bus.subscribe_outbound(capture_outbound)
|
||||
mock_client = _make_mock_langgraph_client(
|
||||
run_result={
|
||||
"messages": [
|
||||
{"type": "human", "content": "deploy"},
|
||||
{"type": "ai", "content": "", "tool_calls": [{"name": "ask_clarification", "args": {}}]},
|
||||
{"type": "tool", "name": "ask_clarification", "content": "Which environment?"},
|
||||
]
|
||||
}
|
||||
)
|
||||
manager._client = mock_client
|
||||
await manager.start()
|
||||
|
||||
inbound = InboundMessage(
|
||||
channel_name="test",
|
||||
chat_id="chat1",
|
||||
user_id="user1",
|
||||
text="deploy",
|
||||
metadata={"message_id": "msg-1"},
|
||||
)
|
||||
await bus.publish_inbound(inbound)
|
||||
await _wait_for(lambda: len(outbound_received) >= 1)
|
||||
await manager.stop()
|
||||
|
||||
assert outbound_received[0].text == "Which environment?"
|
||||
assert outbound_received[0].metadata["message_id"] == "msg-1"
|
||||
assert outbound_received[0].metadata[PENDING_CLARIFICATION_METADATA_KEY] is True
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_handle_chat_does_not_mark_regular_outbound_as_clarification(self):
|
||||
from app.channels.manager import ChannelManager
|
||||
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||
manager = ChannelManager(bus=bus, store=store)
|
||||
outbound_received: list[OutboundMessage] = []
|
||||
|
||||
async def capture_outbound(msg: OutboundMessage) -> None:
|
||||
outbound_received.append(msg)
|
||||
|
||||
bus.subscribe_outbound(capture_outbound)
|
||||
mock_client = _make_mock_langgraph_client()
|
||||
manager._client = mock_client
|
||||
await manager.start()
|
||||
|
||||
await bus.publish_inbound(InboundMessage(channel_name="test", chat_id="chat1", user_id="user1", text="hi"))
|
||||
await _wait_for(lambda: len(outbound_received) >= 1)
|
||||
await manager.stop()
|
||||
|
||||
assert outbound_received[0].text == "Hello from agent!"
|
||||
assert PENDING_CLARIFICATION_METADATA_KEY not in outbound_received[0].metadata
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_handle_chat_outbound_drops_large_metadata_keys(self):
|
||||
"""Large metadata keys like raw_message should be stripped from outbound messages."""
|
||||
from app.channels.manager import ChannelManager
|
||||
@@ -1018,6 +1134,67 @@ class TestChannelManager:
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_handle_feishu_streaming_marks_only_final_clarification_outbound(self, monkeypatch):
|
||||
from app.channels.manager import ChannelManager
|
||||
|
||||
monkeypatch.setattr("app.channels.manager.STREAM_UPDATE_MIN_INTERVAL_SECONDS", 0.0)
|
||||
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||
manager = ChannelManager(bus=bus, store=store)
|
||||
outbound_received: list[OutboundMessage] = []
|
||||
|
||||
async def capture_outbound(msg: OutboundMessage) -> None:
|
||||
outbound_received.append(msg)
|
||||
|
||||
bus.subscribe_outbound(capture_outbound)
|
||||
stream_events = [
|
||||
_make_stream_part(
|
||||
"messages-tuple",
|
||||
[
|
||||
{"id": "ai-1", "content": "Thinking", "type": "AIMessageChunk"},
|
||||
{"langgraph_node": "agent"},
|
||||
],
|
||||
),
|
||||
_make_stream_part(
|
||||
"values",
|
||||
{
|
||||
"messages": [
|
||||
{"type": "human", "content": "deploy"},
|
||||
{"type": "ai", "content": "", "tool_calls": [{"name": "ask_clarification", "args": {}}]},
|
||||
{"type": "tool", "name": "ask_clarification", "content": "Which environment?"},
|
||||
],
|
||||
"artifacts": [],
|
||||
},
|
||||
),
|
||||
]
|
||||
mock_client = _make_mock_langgraph_client()
|
||||
mock_client.runs.stream = MagicMock(return_value=_make_async_iterator(stream_events))
|
||||
manager._client = mock_client
|
||||
await manager.start()
|
||||
|
||||
await bus.publish_inbound(
|
||||
InboundMessage(
|
||||
channel_name="feishu",
|
||||
chat_id="chat1",
|
||||
user_id="user1",
|
||||
text="deploy",
|
||||
thread_ts="om-source-1",
|
||||
)
|
||||
)
|
||||
await _wait_for(lambda: len(outbound_received) >= 2)
|
||||
await manager.stop()
|
||||
|
||||
assert [msg.is_final for msg in outbound_received] == [False, False, True]
|
||||
assert outbound_received[0].text == "Thinking"
|
||||
assert outbound_received[1].text == "Which environment?"
|
||||
assert outbound_received[2].text == "Which environment?"
|
||||
assert all(PENDING_CLARIFICATION_METADATA_KEY not in msg.metadata for msg in outbound_received[:-1])
|
||||
assert outbound_received[-1].metadata[PENDING_CLARIFICATION_METADATA_KEY] is True
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_handle_feishu_stream_error_still_sends_final(self, monkeypatch):
|
||||
"""When the stream raises mid-way, a final outbound with is_final=True must still be published."""
|
||||
from app.channels.manager import ChannelManager
|
||||
@@ -2010,7 +2187,8 @@ class TestFeishuChannel:
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
bus.publish_inbound = AsyncMock()
|
||||
channel = FeishuChannel(bus, config={})
|
||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||
channel = FeishuChannel(bus, config={"channel_store": store})
|
||||
channel._api_client = MagicMock()
|
||||
|
||||
reply_started = asyncio.Event()
|
||||
@@ -2046,6 +2224,11 @@ class TestFeishuChannel:
|
||||
text="Hello",
|
||||
is_final=False,
|
||||
thread_ts="om-source-msg",
|
||||
metadata={
|
||||
"user_id": "user-1",
|
||||
"root_id": "om-root-msg",
|
||||
"topic_id": "om-root-msg",
|
||||
},
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -2060,6 +2243,9 @@ class TestFeishuChannel:
|
||||
assert channel._reply_card.await_count == 1
|
||||
channel._update_card.assert_awaited_once_with("om-running-card", "Hello")
|
||||
assert "om-source-msg" not in channel._running_card_tasks
|
||||
assert store.get_thread_id("feishu", "chat-1", topic_id="om-source-msg") == "thread-1"
|
||||
assert store.get_thread_id("feishu", "chat-1", topic_id="om-running-card") == "thread-1"
|
||||
assert store.get_thread_id("feishu", "chat-1", topic_id="om-root-msg") == "thread-1"
|
||||
|
||||
_run(go())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user