fix(serialization): stop stripping __interrupt__ from channel values (#3595) (#3605)

This commit is contained in:
Zhipeng Zheng
2026-06-17 15:29:22 +08:00
committed by GitHub
parent a72af8ea37
commit c81ab268fb
3 changed files with 94 additions and 5 deletions
@@ -35,6 +35,20 @@ def serialize_lc_object(obj: Any) -> Any:
return obj.dict() return obj.dict()
except Exception: except Exception:
pass pass
# Interrupt is a __slots__ class — no model_dump/dict/__dict__, so it
# would reach str() and produce a malformed payload.
try:
from langgraph.types import Interrupt
except ImportError:
pass
else:
if isinstance(obj, Interrupt):
return serialize_lc_object(
{
"value": obj.value,
"id": getattr(obj, "id", None),
}
)
# Last resort # Last resort
try: try:
return str(obj) return str(obj)
@@ -45,12 +59,13 @@ def serialize_lc_object(obj: Any) -> Any:
def serialize_channel_values(channel_values: dict[str, Any]) -> dict[str, Any]: def serialize_channel_values(channel_values: dict[str, Any]) -> dict[str, Any]:
"""Serialize channel values, stripping internal LangGraph keys. """Serialize channel values, stripping internal LangGraph keys.
Internal keys like ``__pregel_*`` and ``__interrupt__`` are removed Only ``__pregel_*`` keys are removed — ``__interrupt__`` is deliberately
to match what the LangGraph Platform API returns. preserved so the LangGraph SDK can detect interrupt events from values
chunks (see issue #3595).
""" """
result: dict[str, Any] = {} result: dict[str, Any] = {}
for key, value in channel_values.items(): for key, value in channel_values.items():
if key.startswith("__pregel_") or key == "__interrupt__": if key.startswith("__pregel_"):
continue continue
result[key] = serialize_lc_object(value) result[key] = serialize_lc_object(value)
return result return result
@@ -0,0 +1,71 @@
"""Regression tests for issue #3595: __interrupt__ must survive serialize_channel_values."""
from __future__ import annotations
from typing import Any
import pytest
from langgraph.graph import StateGraph
from langgraph.types import Interrupt, interrupt
def _interrupting_node(state: dict) -> dict[str, Any]:
result = interrupt("Please provide API credentials")
return {"result": result}
def _build_test_graph():
builder = StateGraph(dict)
builder.add_node("ask_credential", _interrupting_node)
builder.set_entry_point("ask_credential")
builder.set_finish_point("ask_credential")
return builder.compile()
class _StreamCollector:
def __init__(self):
self.events: list[tuple[str, Any]] = []
async def publish(self, _run_id: str, event: str, data: Any):
self.events.append((event, data))
@pytest.mark.asyncio
async def test_values_mode_includes_interrupt():
from deerflow.runtime.serialization import serialize
graph = _build_test_graph()
collector = _StreamCollector()
async for chunk in graph.astream({"messages": []}, stream_mode="values"):
data = serialize(chunk, mode="values")
await collector.publish("test", "values", data)
interrupt_events = [e for e in collector.events if isinstance(e[1], dict) and "__interrupt__" in e[1]]
assert len(interrupt_events) > 0, "__interrupt__ was stripped from values events"
# Verify the payload is structured (not a str fallback from serialize_lc_object)
interrupt_value = interrupt_events[0][1]["__interrupt__"]
assert isinstance(interrupt_value, list)
assert len(interrupt_value) > 0
assert isinstance(interrupt_value[0], dict)
assert interrupt_value[0]["value"] == "Please provide API credentials"
@pytest.mark.asyncio
async def test_serialize_channel_values_keeps_interrupt():
from deerflow.runtime.serialization import serialize_channel_values
interrupt_obj = Interrupt(value={"question": "Enter API key"}, id="test-interrupt-id")
result = serialize_channel_values(
{
"__interrupt__": (interrupt_obj,),
"__pregel_tasks": "internal",
"messages": [],
}
)
assert "__interrupt__" in result
assert "__pregel_tasks" not in result
assert "messages" in result
# Verify payload shape: Interrupt must serialize to a dict, not str
assert isinstance(result["__interrupt__"], list)
assert len(result["__interrupt__"]) > 0
assert isinstance(result["__interrupt__"][0], dict)
assert result["__interrupt__"][0]["value"] == {"question": "Enter API key"}
+5 -2
View File
@@ -96,7 +96,7 @@ def test_serialize_channel_values_strips_pregel_keys():
"messages": ["hello"], "messages": ["hello"],
"__pregel_tasks": "internal", "__pregel_tasks": "internal",
"__pregel_resuming": True, "__pregel_resuming": True,
"__interrupt__": "stop", "__interrupt__": [{"value": "ask_human", "resumable": True}],
"title": "Test", "title": "Test",
} }
result = serialize_channel_values(raw) result = serialize_channel_values(raw)
@@ -104,7 +104,10 @@ def test_serialize_channel_values_strips_pregel_keys():
assert "title" in result assert "title" in result
assert "__pregel_tasks" not in result assert "__pregel_tasks" not in result
assert "__pregel_resuming" not in result assert "__pregel_resuming" not in result
assert "__interrupt__" not in result assert "__interrupt__" in result
assert isinstance(result["__interrupt__"], list)
assert len(result["__interrupt__"]) == 1
assert result["__interrupt__"][0]["value"] == "ask_human"
def test_serialize_channel_values_serializes_objects(): def test_serialize_channel_values_serializes_objects():