mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 00:16:48 +00:00
fix(channels): preserve clarification conversation history across follow-up turns (#2444)
* fix(channels): preserve clarification conversation history across follow-up turns Pin channel-triggered runs to the root checkpoint namespace and ensure thread_id is always present in configurable run config so follow-up replies resume the same conversation state. Add regression coverage to channel tests: assert checkpoint_ns/thread_id are passed in wait and stream paths add an integration-style clarification flow test that verifies the second user reply continues prior context instead of starting a new session This addresses history loss after ask_clarification interruptions (issue #2425). * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix(channels): copy configurable dict before injecting run-scoped fields When configurable was already a plain dict, _resolve_run_params mutated it in place, leaking checkpoint_ns and thread_id back into the shared session config. Always copy via dict() before mutating to prevent cross-user or cross-channel config pollution. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -589,6 +589,17 @@ class ChannelManager:
|
|||||||
user_layer.get("config"),
|
user_layer.get("config"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
configurable = run_config.get("configurable")
|
||||||
|
if isinstance(configurable, Mapping):
|
||||||
|
configurable = dict(configurable)
|
||||||
|
else:
|
||||||
|
configurable = {}
|
||||||
|
run_config["configurable"] = configurable
|
||||||
|
# Pin channel-triggered runs to the root graph namespace so follow-up
|
||||||
|
# turns continue from the same conversation checkpoint.
|
||||||
|
configurable["checkpoint_ns"] = ""
|
||||||
|
configurable["thread_id"] = thread_id
|
||||||
|
|
||||||
run_context = _merge_dicts(
|
run_context = _merge_dicts(
|
||||||
DEFAULT_RUN_CONTEXT,
|
DEFAULT_RUN_CONTEXT,
|
||||||
self._default_session.get("context"),
|
self._default_session.get("context"),
|
||||||
|
|||||||
@@ -530,6 +530,8 @@ class TestChannelManager:
|
|||||||
assert call_args[0][0] == "test-thread-123" # thread_id
|
assert call_args[0][0] == "test-thread-123" # thread_id
|
||||||
assert call_args[0][1] == "lead_agent" # assistant_id
|
assert call_args[0][1] == "lead_agent" # assistant_id
|
||||||
assert call_args[1]["input"]["messages"][0]["content"] == "hi"
|
assert call_args[1]["input"]["messages"][0]["content"] == "hi"
|
||||||
|
assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == ""
|
||||||
|
assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123"
|
||||||
|
|
||||||
assert len(outbound_received) == 1
|
assert len(outbound_received) == 1
|
||||||
assert outbound_received[0].text == "Hello from agent!"
|
assert outbound_received[0].text == "Hello from agent!"
|
||||||
@@ -661,12 +663,135 @@ class TestChannelManager:
|
|||||||
call_args = mock_client.runs.wait.call_args
|
call_args = mock_client.runs.wait.call_args
|
||||||
assert call_args[0][1] == "lead_agent"
|
assert call_args[0][1] == "lead_agent"
|
||||||
assert call_args[1]["config"]["recursion_limit"] == 55
|
assert call_args[1]["config"]["recursion_limit"] == 55
|
||||||
|
assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == ""
|
||||||
|
assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123"
|
||||||
assert call_args[1]["context"]["thinking_enabled"] is False
|
assert call_args[1]["context"]["thinking_enabled"] is False
|
||||||
assert call_args[1]["context"]["subagent_enabled"] is True
|
assert call_args[1]["context"]["subagent_enabled"] is True
|
||||||
assert call_args[1]["context"]["agent_name"] == "mobile-agent"
|
assert call_args[1]["context"]["agent_name"] == "mobile-agent"
|
||||||
|
|
||||||
_run(go())
|
_run(go())
|
||||||
|
|
||||||
|
def test_clarification_follow_up_preserves_history(self):
|
||||||
|
"""Conversation should continue after ask_clarification instead of resetting history."""
|
||||||
|
from app.channels.manager import ChannelManager
|
||||||
|
|
||||||
|
async def go():
|
||||||
|
bus = MessageBus()
|
||||||
|
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||||
|
manager = ChannelManager(bus=bus, store=store)
|
||||||
|
|
||||||
|
outbound_received = []
|
||||||
|
|
||||||
|
async def capture_outbound(msg):
|
||||||
|
outbound_received.append(msg)
|
||||||
|
|
||||||
|
bus.subscribe_outbound(capture_outbound)
|
||||||
|
|
||||||
|
history_by_checkpoint: dict[tuple[str, str], list[str]] = {}
|
||||||
|
|
||||||
|
async def _runs_wait(thread_id, assistant_id, *, input, config, context):
|
||||||
|
del assistant_id, context # unused in this test, kept for signature parity
|
||||||
|
|
||||||
|
checkpoint_ns = config.get("configurable", {}).get("checkpoint_ns")
|
||||||
|
key = (thread_id, str(checkpoint_ns))
|
||||||
|
history = history_by_checkpoint.setdefault(key, [])
|
||||||
|
|
||||||
|
human_text = input["messages"][0]["content"]
|
||||||
|
history.append(human_text)
|
||||||
|
|
||||||
|
if len(history) == 1:
|
||||||
|
return {
|
||||||
|
"messages": [
|
||||||
|
{"type": "human", "content": history[0]},
|
||||||
|
{
|
||||||
|
"type": "ai",
|
||||||
|
"content": "",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"name": "ask_clarification",
|
||||||
|
"args": {"question": "Which environment should I use?"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "tool",
|
||||||
|
"name": "ask_clarification",
|
||||||
|
"content": "Which environment should I use?",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(history) == 2 and history[0] == "Deploy my app" and history[1] == "prod":
|
||||||
|
return {
|
||||||
|
"messages": [
|
||||||
|
{"type": "human", "content": history[0]},
|
||||||
|
{
|
||||||
|
"type": "ai",
|
||||||
|
"content": "",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"name": "ask_clarification",
|
||||||
|
"args": {"question": "Which environment should I use?"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "tool",
|
||||||
|
"name": "ask_clarification",
|
||||||
|
"content": "Which environment should I use?",
|
||||||
|
},
|
||||||
|
{"type": "human", "content": history[1]},
|
||||||
|
{"type": "ai", "content": "Got it. I will deploy to prod."},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"messages": [
|
||||||
|
{"type": "human", "content": history[-1]},
|
||||||
|
{"type": "ai", "content": "History missing; clarification repeated."},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.threads.create = AsyncMock(return_value={"thread_id": "clarify-thread-1"})
|
||||||
|
mock_client.threads.get = AsyncMock(return_value={"thread_id": "clarify-thread-1"})
|
||||||
|
mock_client.runs.wait = AsyncMock(side_effect=_runs_wait)
|
||||||
|
manager._client = mock_client
|
||||||
|
|
||||||
|
await manager.start()
|
||||||
|
|
||||||
|
await bus.publish_inbound(
|
||||||
|
InboundMessage(
|
||||||
|
channel_name="test",
|
||||||
|
chat_id="chat1",
|
||||||
|
user_id="user1",
|
||||||
|
text="Deploy my app",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await _wait_for(lambda: len(outbound_received) >= 1)
|
||||||
|
|
||||||
|
await bus.publish_inbound(
|
||||||
|
InboundMessage(
|
||||||
|
channel_name="test",
|
||||||
|
chat_id="chat1",
|
||||||
|
user_id="user1",
|
||||||
|
text="prod",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await _wait_for(lambda: len(outbound_received) >= 2)
|
||||||
|
await manager.stop()
|
||||||
|
|
||||||
|
assert outbound_received[0].text == "Which environment should I use?"
|
||||||
|
assert outbound_received[1].text == "Got it. I will deploy to prod."
|
||||||
|
|
||||||
|
assert mock_client.runs.wait.call_count == 2
|
||||||
|
first_call = mock_client.runs.wait.call_args_list[0]
|
||||||
|
second_call = mock_client.runs.wait.call_args_list[1]
|
||||||
|
assert first_call.kwargs["config"]["configurable"]["checkpoint_ns"] == ""
|
||||||
|
assert second_call.kwargs["config"]["configurable"]["checkpoint_ns"] == ""
|
||||||
|
|
||||||
|
_run(go())
|
||||||
|
|
||||||
def test_handle_chat_uses_user_session_overrides(self):
|
def test_handle_chat_uses_user_session_overrides(self):
|
||||||
from app.channels.manager import ChannelManager
|
from app.channels.manager import ChannelManager
|
||||||
|
|
||||||
@@ -1343,6 +1468,8 @@ class TestChannelManager:
|
|||||||
call_args = mock_client.runs.stream.call_args
|
call_args = mock_client.runs.stream.call_args
|
||||||
|
|
||||||
assert call_args[1]["input"]["messages"][0]["content"] == "hello"
|
assert call_args[1]["input"]["messages"][0]["content"] == "hello"
|
||||||
|
assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == ""
|
||||||
|
assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123"
|
||||||
assert call_args[1]["context"]["is_bootstrap"] is True
|
assert call_args[1]["context"]["is_bootstrap"] is True
|
||||||
|
|
||||||
# Final message should be published
|
# Final message should be published
|
||||||
|
|||||||
Reference in New Issue
Block a user