fix(backend): make clarification messages idempotent (#2350) (#2351)

This commit is contained in:
Nan Gao
2026-04-19 16:00:58 +02:00
committed by GitHub
parent 7c87dc5bca
commit f514e35a36
2 changed files with 68 additions and 0 deletions
@@ -3,6 +3,7 @@
import json import json
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from hashlib import sha256
from typing import override from typing import override
from langchain.agents import AgentState from langchain.agents import AgentState
@@ -36,6 +37,13 @@ class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
state_schema = ClarificationMiddlewareState state_schema = ClarificationMiddlewareState
def _stable_message_id(self, tool_call_id: str, formatted_message: str) -> str:
"""Build a deterministic message ID so retried clarification calls replace, not append."""
if tool_call_id:
return f"clarification:{tool_call_id}"
digest = sha256(formatted_message.encode("utf-8")).hexdigest()[:16]
return f"clarification:{digest}"
def _is_chinese(self, text: str) -> bool: def _is_chinese(self, text: str) -> bool:
"""Check if text contains Chinese characters. """Check if text contains Chinese characters.
@@ -131,6 +139,7 @@ class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
# Create a ToolMessage with the formatted question # Create a ToolMessage with the formatted question
# This will be added to the message history # This will be added to the message history
tool_message = ToolMessage( tool_message = ToolMessage(
id=self._stable_message_id(tool_call_id, formatted_message),
content=formatted_message, content=formatted_message,
tool_call_id=tool_call_id, tool_call_id=tool_call_id,
name="ask_clarification", name="ask_clarification",
@@ -1,8 +1,10 @@
"""Tests for ClarificationMiddleware, focusing on options type coercion.""" """Tests for ClarificationMiddleware, focusing on options type coercion."""
import json import json
from types import SimpleNamespace
import pytest import pytest
from langgraph.graph.message import add_messages
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
@@ -118,3 +120,60 @@ class TestFormatClarificationMessage:
assert "2. 2" in result assert "2. 2" in result
assert "3. True" in result assert "3. True" in result
assert "4. None" in result assert "4. None" in result
class TestClarificationCommandIdempotency:
"""Clarification tool-call retries should not duplicate messages in state."""
def test_repeated_tool_call_uses_stable_message_id(self, middleware):
request = SimpleNamespace(
tool_call={
"name": "ask_clarification",
"id": "call-clarify-1",
"args": {
"question": "Which environment should I use?",
"clarification_type": "approach_choice",
"options": ["dev", "prod"],
},
}
)
first = middleware.wrap_tool_call(request, lambda _req: pytest.fail("handler should not be called"))
second = middleware.wrap_tool_call(request, lambda _req: pytest.fail("handler should not be called"))
first_message = first.update["messages"][0]
second_message = second.update["messages"][0]
assert first_message.id == "clarification:call-clarify-1"
assert second_message.id == first_message.id
assert second_message.tool_call_id == first_message.tool_call_id
merged = add_messages(add_messages([], [first_message]), [second_message])
assert len(merged) == 1
assert merged[0].id == "clarification:call-clarify-1"
assert merged[0].content == first_message.content
def test_missing_tool_call_id_still_gets_stable_message_id(self, middleware):
request = SimpleNamespace(
tool_call={
"name": "ask_clarification",
"args": {
"question": "Which environment should I use?",
"clarification_type": "missing_info",
},
}
)
first = middleware.wrap_tool_call(request, lambda _req: pytest.fail("handler should not be called"))
second = middleware.wrap_tool_call(request, lambda _req: pytest.fail("handler should not be called"))
first_message = first.update["messages"][0]
second_message = second.update["messages"][0]
assert first_message.id.startswith("clarification:")
assert second_message.id == first_message.id
merged = add_messages(add_messages([], [first_message]), [second_message])
assert len(merged) == 1