mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-18 13:46:02 +00:00
This commit is contained in:
@@ -35,6 +35,20 @@ def serialize_lc_object(obj: Any) -> Any:
|
|||||||
return obj.dict()
|
return obj.dict()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
# Interrupt is a __slots__ class — no model_dump/dict/__dict__, so it
|
||||||
|
# would reach str() and produce a malformed payload.
|
||||||
|
try:
|
||||||
|
from langgraph.types import Interrupt
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
if isinstance(obj, Interrupt):
|
||||||
|
return serialize_lc_object(
|
||||||
|
{
|
||||||
|
"value": obj.value,
|
||||||
|
"id": getattr(obj, "id", None),
|
||||||
|
}
|
||||||
|
)
|
||||||
# Last resort
|
# Last resort
|
||||||
try:
|
try:
|
||||||
return str(obj)
|
return str(obj)
|
||||||
@@ -45,12 +59,13 @@ def serialize_lc_object(obj: Any) -> Any:
|
|||||||
def serialize_channel_values(channel_values: dict[str, Any]) -> dict[str, Any]:
|
def serialize_channel_values(channel_values: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Serialize channel values, stripping internal LangGraph keys.
|
"""Serialize channel values, stripping internal LangGraph keys.
|
||||||
|
|
||||||
Internal keys like ``__pregel_*`` and ``__interrupt__`` are removed
|
Only ``__pregel_*`` keys are removed — ``__interrupt__`` is deliberately
|
||||||
to match what the LangGraph Platform API returns.
|
preserved so the LangGraph SDK can detect interrupt events from values
|
||||||
|
chunks (see issue #3595).
|
||||||
"""
|
"""
|
||||||
result: dict[str, Any] = {}
|
result: dict[str, Any] = {}
|
||||||
for key, value in channel_values.items():
|
for key, value in channel_values.items():
|
||||||
if key.startswith("__pregel_") or key == "__interrupt__":
|
if key.startswith("__pregel_"):
|
||||||
continue
|
continue
|
||||||
result[key] = serialize_lc_object(value)
|
result[key] = serialize_lc_object(value)
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -0,0 +1,71 @@
|
|||||||
|
"""Regression tests for issue #3595: __interrupt__ must survive serialize_channel_values."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langgraph.graph import StateGraph
|
||||||
|
from langgraph.types import Interrupt, interrupt
|
||||||
|
|
||||||
|
|
||||||
|
def _interrupting_node(state: dict) -> dict[str, Any]:
|
||||||
|
result = interrupt("Please provide API credentials")
|
||||||
|
return {"result": result}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_test_graph():
|
||||||
|
builder = StateGraph(dict)
|
||||||
|
builder.add_node("ask_credential", _interrupting_node)
|
||||||
|
builder.set_entry_point("ask_credential")
|
||||||
|
builder.set_finish_point("ask_credential")
|
||||||
|
return builder.compile()
|
||||||
|
|
||||||
|
|
||||||
|
class _StreamCollector:
|
||||||
|
def __init__(self):
|
||||||
|
self.events: list[tuple[str, Any]] = []
|
||||||
|
|
||||||
|
async def publish(self, _run_id: str, event: str, data: Any):
|
||||||
|
self.events.append((event, data))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_values_mode_includes_interrupt():
|
||||||
|
from deerflow.runtime.serialization import serialize
|
||||||
|
|
||||||
|
graph = _build_test_graph()
|
||||||
|
collector = _StreamCollector()
|
||||||
|
async for chunk in graph.astream({"messages": []}, stream_mode="values"):
|
||||||
|
data = serialize(chunk, mode="values")
|
||||||
|
await collector.publish("test", "values", data)
|
||||||
|
interrupt_events = [e for e in collector.events if isinstance(e[1], dict) and "__interrupt__" in e[1]]
|
||||||
|
assert len(interrupt_events) > 0, "__interrupt__ was stripped from values events"
|
||||||
|
# Verify the payload is structured (not a str fallback from serialize_lc_object)
|
||||||
|
interrupt_value = interrupt_events[0][1]["__interrupt__"]
|
||||||
|
assert isinstance(interrupt_value, list)
|
||||||
|
assert len(interrupt_value) > 0
|
||||||
|
assert isinstance(interrupt_value[0], dict)
|
||||||
|
assert interrupt_value[0]["value"] == "Please provide API credentials"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_serialize_channel_values_keeps_interrupt():
|
||||||
|
from deerflow.runtime.serialization import serialize_channel_values
|
||||||
|
|
||||||
|
interrupt_obj = Interrupt(value={"question": "Enter API key"}, id="test-interrupt-id")
|
||||||
|
result = serialize_channel_values(
|
||||||
|
{
|
||||||
|
"__interrupt__": (interrupt_obj,),
|
||||||
|
"__pregel_tasks": "internal",
|
||||||
|
"messages": [],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert "__interrupt__" in result
|
||||||
|
assert "__pregel_tasks" not in result
|
||||||
|
assert "messages" in result
|
||||||
|
# Verify payload shape: Interrupt must serialize to a dict, not str
|
||||||
|
assert isinstance(result["__interrupt__"], list)
|
||||||
|
assert len(result["__interrupt__"]) > 0
|
||||||
|
assert isinstance(result["__interrupt__"][0], dict)
|
||||||
|
assert result["__interrupt__"][0]["value"] == {"question": "Enter API key"}
|
||||||
@@ -96,7 +96,7 @@ def test_serialize_channel_values_strips_pregel_keys():
|
|||||||
"messages": ["hello"],
|
"messages": ["hello"],
|
||||||
"__pregel_tasks": "internal",
|
"__pregel_tasks": "internal",
|
||||||
"__pregel_resuming": True,
|
"__pregel_resuming": True,
|
||||||
"__interrupt__": "stop",
|
"__interrupt__": [{"value": "ask_human", "resumable": True}],
|
||||||
"title": "Test",
|
"title": "Test",
|
||||||
}
|
}
|
||||||
result = serialize_channel_values(raw)
|
result = serialize_channel_values(raw)
|
||||||
@@ -104,7 +104,10 @@ def test_serialize_channel_values_strips_pregel_keys():
|
|||||||
assert "title" in result
|
assert "title" in result
|
||||||
assert "__pregel_tasks" not in result
|
assert "__pregel_tasks" not in result
|
||||||
assert "__pregel_resuming" not in result
|
assert "__pregel_resuming" not in result
|
||||||
assert "__interrupt__" not in result
|
assert "__interrupt__" in result
|
||||||
|
assert isinstance(result["__interrupt__"], list)
|
||||||
|
assert len(result["__interrupt__"]) == 1
|
||||||
|
assert result["__interrupt__"][0]["value"] == "ask_human"
|
||||||
|
|
||||||
|
|
||||||
def test_serialize_channel_values_serializes_objects():
|
def test_serialize_channel_values_serializes_objects():
|
||||||
|
|||||||
Reference in New Issue
Block a user