mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 16:06:50 +00:00
feat(gateway): implement LangGraph Platform API in Gateway, replace langgraph-cli (#1403)
* feat(gateway): implement LangGraph Platform API in Gateway, replace langgraph-cli Implement all core LangGraph Platform API endpoints in the Gateway, allowing it to fully replace the langgraph-cli dev server for local development. This eliminates a heavyweight dependency and simplifies the development stack. Changes: - Add runs lifecycle endpoints (create, stream, wait, cancel, join) - Add threads CRUD and search endpoints - Add assistants compatibility endpoints (search, get, graph, schemas) - Add StreamBridge (in-memory pub/sub for SSE) and async provider - Add RunManager with atomic create_or_reject (eliminates TOCTOU race) - Add worker with interrupt/rollback cancel actions and runtime context injection - Route /api/langgraph/* to Gateway in nginx config - Skip langgraph-cli startup by default (SKIP_LANGGRAPH_SERVER=0 to restore) - Add unit tests for RunManager, SSE format, and StreamBridge * fix: drain bridge queue on client disconnect to prevent backpressure When on_disconnect=continue, keep consuming events from the bridge without yielding, so the worker is not blocked by a full queue. Only on_disconnect=cancel breaks out immediately. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: remove pytest import Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: Fix default stream_mode to ["values", "messages-tuple"] Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: Remove unused if_exists field from ThreadCreateRequest Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: address review comments on gateway LangGraph API - Mount runs.py router in app.py (missing include_router) - Normalize interrupt_before/after "*" to node list before run_agent() - Use entry.id for SSE event ID instead of counter - Drain bridge queue on disconnect when on_disconnect=continue - Reuse serialization helper in wait_run() for consistent wire format - Reject unsupported multitask_strategy with 400 - Remove SKIP_LANGGRAPH_SERVER fallback, always use Gateway * feat: extract app.state access into deps.py Encapsulate read/write operations for singleton objects (RunManager, StreamBridge, checkpointer) held in app.state into a shared utility, reducing repeated access patterns across router modules. * feat: extract deerflow.runtime.serialization module with tests Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * refactor: replace duplicated serialization with deerflow.runtime.serialization Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * feat: extract app/gateway/services.py with run lifecycle logic Create a service layer that centralizes SSE formatting, input/config normalization, and run lifecycle management. Router modules will delegate to these functions instead of using private cross-imported helpers. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * refactor: wire routers to use services layer, remove cross-module private imports Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * style: apply ruff formatting to refactored files Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * feat(runtime): support LangGraph dev server and add compat route - Enable official LangGraph dev server for local development workflow - Decouple runtime components from agents package for better separation - Provide gateway-backed fallback route when dev server is skipped - Simplify lifecycle management using context manager in gateway * feat(runtime): add Store providers with auto-backend selection - Add async_provider.py and provider.py under deerflow/runtime/store/ - Support memory, sqlite, postgres backends matching checkpointer config - Integrate into FastAPI lifespan via AsyncExitStack in deps.py - Replace hardcoded InMemoryStore with config-driven factory * refactor(gateway): migrate thread management from checkpointer to Store and resolve multiple endpoint failures - Add Store-backed CRUD helpers (_store_get, _store_put, _store_upsert) - Replace checkpoint-scanning search with two-phase strategy: phase 1 reads Store (O(threads)), phase 2 backfills from checkpointer for legacy/LangGraph Server threads with lazy migration - Extend Store record schema with values field for title persistence - Sync thread title from checkpoint to Store after run completion - Fix /threads/{id}/runs/{run_id}/stream 405 by accepting both GET and POST methods; POST handles interrupt/rollback actions - Fix /threads/{id}/state 500 by separating read_config and write_config, adding checkpoint_ns to configurable, and shallow-copying checkpoint/metadata before mutation - Sync title to Store on state update for immediate search reflection - Move _upsert_thread_in_store into services.py, remove duplicate logic - Add _sync_thread_title_after_run: await run task, read final checkpoint title, write back to Store record - Spawn title sync as background task from start_run when Store exists * refactor(runtime): deduplicate store and checkpointer provider logic Extract _ensure_sqlite_parent_dir() helper into checkpointer/provider.py and use it in all three places that previously inlined the same mkdir logic. Consolidate duplicate error constants in store/async_provider.py by importing from store/provider.py instead of redefining them. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * refactor(runtime): move SQLite helpers to runtime/store, checkpointer imports from store _resolve_sqlite_conn_str and _ensure_sqlite_parent_dir now live in runtime/store/provider.py. agents/checkpointer/provider and agents/checkpointer/async_provider import from there, reversing the previous dependency direction (store → checkpointer becomes checkpointer → store). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * refactor(runtime): extract SQLite helpers into runtime/store/_sqlite_utils.py Move resolve_sqlite_conn_str and ensure_sqlite_parent_dir out of checkpointer/provider.py into a dedicated _sqlite_utils module. Functions are now public (no underscore prefix), making cross-module imports semantically correct. All four provider files import from the single shared location. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix(gateway): use adelete_thread to fully remove thread checkpoints on delete AsyncSqliteSaver has no adelete method — the previous hasattr check always evaluated to False, silently leaving all checkpoint rows in the database. Switch to adelete_thread(thread_id) which deletes every checkpoint and pending-write row for the thread across all namespaces (including sub-graph checkpoints). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix(gateway): remove dead bridge_cm/ckpt_cm code and fix StrEnum lint app.py had unreachable code after the async-with lifespan refactor: bridge_cm and ckpt_cm were referenced but never defined (F821), and the channel service startup/shutdown was outside the langgraph_runtime block so it never ran. Move channel service lifecycle inside the async-with block where it belongs. Replace str+Enum inheritance in RunStatus and DisconnectMode with StrEnum as suggested by UP042. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * style: format with ruff --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: JeffJiang <for-eleven@hotmail.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -0,0 +1,102 @@
|
||||
"""Tests for app.gateway.services — run lifecycle service layer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
|
||||
def test_format_sse_basic():
|
||||
from app.gateway.services import format_sse
|
||||
|
||||
frame = format_sse("metadata", {"run_id": "abc"})
|
||||
assert frame.startswith("event: metadata\n")
|
||||
assert "data: " in frame
|
||||
parsed = json.loads(frame.split("data: ")[1].split("\n")[0])
|
||||
assert parsed["run_id"] == "abc"
|
||||
|
||||
|
||||
def test_format_sse_with_event_id():
|
||||
from app.gateway.services import format_sse
|
||||
|
||||
frame = format_sse("metadata", {"run_id": "abc"}, event_id="123-0")
|
||||
assert "id: 123-0" in frame
|
||||
|
||||
|
||||
def test_format_sse_end_event_null():
|
||||
from app.gateway.services import format_sse
|
||||
|
||||
frame = format_sse("end", None)
|
||||
assert "data: null" in frame
|
||||
|
||||
|
||||
def test_format_sse_no_event_id():
|
||||
from app.gateway.services import format_sse
|
||||
|
||||
frame = format_sse("values", {"x": 1})
|
||||
assert "id:" not in frame
|
||||
|
||||
|
||||
def test_normalize_stream_modes_none():
|
||||
from app.gateway.services import normalize_stream_modes
|
||||
|
||||
assert normalize_stream_modes(None) == ["values"]
|
||||
|
||||
|
||||
def test_normalize_stream_modes_string():
|
||||
from app.gateway.services import normalize_stream_modes
|
||||
|
||||
assert normalize_stream_modes("messages-tuple") == ["messages-tuple"]
|
||||
|
||||
|
||||
def test_normalize_stream_modes_list():
|
||||
from app.gateway.services import normalize_stream_modes
|
||||
|
||||
assert normalize_stream_modes(["values", "messages-tuple"]) == ["values", "messages-tuple"]
|
||||
|
||||
|
||||
def test_normalize_stream_modes_empty_list():
|
||||
from app.gateway.services import normalize_stream_modes
|
||||
|
||||
assert normalize_stream_modes([]) == ["values"]
|
||||
|
||||
|
||||
def test_normalize_input_none():
|
||||
from app.gateway.services import normalize_input
|
||||
|
||||
assert normalize_input(None) == {}
|
||||
|
||||
|
||||
def test_normalize_input_with_messages():
|
||||
from app.gateway.services import normalize_input
|
||||
|
||||
result = normalize_input({"messages": [{"role": "user", "content": "hi"}]})
|
||||
assert len(result["messages"]) == 1
|
||||
assert result["messages"][0].content == "hi"
|
||||
|
||||
|
||||
def test_normalize_input_passthrough():
|
||||
from app.gateway.services import normalize_input
|
||||
|
||||
result = normalize_input({"custom_key": "value"})
|
||||
assert result == {"custom_key": "value"}
|
||||
|
||||
|
||||
def test_build_run_config_basic():
|
||||
from app.gateway.services import build_run_config
|
||||
|
||||
config = build_run_config("thread-1", None, None)
|
||||
assert config["configurable"]["thread_id"] == "thread-1"
|
||||
assert config["recursion_limit"] == 100
|
||||
|
||||
|
||||
def test_build_run_config_with_overrides():
|
||||
from app.gateway.services import build_run_config
|
||||
|
||||
config = build_run_config(
|
||||
"thread-1",
|
||||
{"configurable": {"model_name": "gpt-4"}, "tags": ["test"]},
|
||||
{"user": "alice"},
|
||||
)
|
||||
assert config["configurable"]["model_name"] == "gpt-4"
|
||||
assert config["tags"] == ["test"]
|
||||
assert config["metadata"]["user"] == "alice"
|
||||
@@ -0,0 +1,131 @@
|
||||
"""Tests for RunManager."""
|
||||
|
||||
import re
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime import RunManager, RunStatus
|
||||
|
||||
ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager() -> RunManager:
|
||||
return RunManager()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_and_get(manager: RunManager):
|
||||
"""Created run should be retrievable with new fields."""
|
||||
record = await manager.create(
|
||||
"thread-1",
|
||||
"lead_agent",
|
||||
metadata={"key": "val"},
|
||||
kwargs={"input": {}},
|
||||
multitask_strategy="reject",
|
||||
)
|
||||
assert record.status == RunStatus.pending
|
||||
assert record.thread_id == "thread-1"
|
||||
assert record.assistant_id == "lead_agent"
|
||||
assert record.metadata == {"key": "val"}
|
||||
assert record.kwargs == {"input": {}}
|
||||
assert record.multitask_strategy == "reject"
|
||||
assert ISO_RE.match(record.created_at)
|
||||
assert ISO_RE.match(record.updated_at)
|
||||
|
||||
fetched = manager.get(record.run_id)
|
||||
assert fetched is record
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_status_transitions(manager: RunManager):
|
||||
"""Status should transition pending -> running -> success."""
|
||||
record = await manager.create("thread-1")
|
||||
assert record.status == RunStatus.pending
|
||||
|
||||
await manager.set_status(record.run_id, RunStatus.running)
|
||||
assert record.status == RunStatus.running
|
||||
assert ISO_RE.match(record.updated_at)
|
||||
|
||||
await manager.set_status(record.run_id, RunStatus.success)
|
||||
assert record.status == RunStatus.success
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cancel(manager: RunManager):
|
||||
"""Cancel should set abort_event and transition to interrupted."""
|
||||
record = await manager.create("thread-1")
|
||||
await manager.set_status(record.run_id, RunStatus.running)
|
||||
|
||||
cancelled = await manager.cancel(record.run_id)
|
||||
assert cancelled is True
|
||||
assert record.abort_event.is_set()
|
||||
assert record.status == RunStatus.interrupted
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cancel_not_inflight(manager: RunManager):
|
||||
"""Cancelling a completed run should return False."""
|
||||
record = await manager.create("thread-1")
|
||||
await manager.set_status(record.run_id, RunStatus.success)
|
||||
|
||||
cancelled = await manager.cancel(record.run_id)
|
||||
assert cancelled is False
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread(manager: RunManager):
|
||||
"""Same thread should return multiple runs, newest first."""
|
||||
r1 = await manager.create("thread-1")
|
||||
r2 = await manager.create("thread-1")
|
||||
await manager.create("thread-2")
|
||||
|
||||
runs = await manager.list_by_thread("thread-1")
|
||||
assert len(runs) == 2
|
||||
assert runs[0].run_id == r2.run_id
|
||||
assert runs[1].run_id == r1.run_id
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_has_inflight(manager: RunManager):
|
||||
"""has_inflight should be True when a run is pending or running."""
|
||||
record = await manager.create("thread-1")
|
||||
assert await manager.has_inflight("thread-1") is True
|
||||
|
||||
await manager.set_status(record.run_id, RunStatus.success)
|
||||
assert await manager.has_inflight("thread-1") is False
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cleanup(manager: RunManager):
|
||||
"""After cleanup, the run should be gone."""
|
||||
record = await manager.create("thread-1")
|
||||
run_id = record.run_id
|
||||
|
||||
await manager.cleanup(run_id, delay=0)
|
||||
assert manager.get(run_id) is None
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_set_status_with_error(manager: RunManager):
|
||||
"""Error message should be stored on the record."""
|
||||
record = await manager.create("thread-1")
|
||||
await manager.set_status(record.run_id, RunStatus.error, error="Something went wrong")
|
||||
assert record.status == RunStatus.error
|
||||
assert record.error == "Something went wrong"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_nonexistent(manager: RunManager):
|
||||
"""Getting a nonexistent run should return None."""
|
||||
assert manager.get("does-not-exist") is None
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_defaults(manager: RunManager):
|
||||
"""Create with no optional args should use defaults."""
|
||||
record = await manager.create("thread-1")
|
||||
assert record.metadata == {}
|
||||
assert record.kwargs == {}
|
||||
assert record.multitask_strategy == "reject"
|
||||
assert record.assistant_id is None
|
||||
@@ -0,0 +1,159 @@
|
||||
"""Tests for deerflow.runtime.serialization."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class _FakePydanticV2:
|
||||
"""Object with model_dump (Pydantic v2)."""
|
||||
|
||||
def model_dump(self):
|
||||
return {"key": "v2"}
|
||||
|
||||
|
||||
class _FakePydanticV1:
|
||||
"""Object with dict (Pydantic v1)."""
|
||||
|
||||
def dict(self):
|
||||
return {"key": "v1"}
|
||||
|
||||
|
||||
class _Unprintable:
|
||||
"""Object whose str() raises."""
|
||||
|
||||
def __str__(self):
|
||||
raise RuntimeError("no str")
|
||||
|
||||
def __repr__(self):
|
||||
return "<Unprintable>"
|
||||
|
||||
|
||||
def test_serialize_none():
|
||||
from deerflow.runtime.serialization import serialize_lc_object
|
||||
|
||||
assert serialize_lc_object(None) is None
|
||||
|
||||
|
||||
def test_serialize_primitives():
|
||||
from deerflow.runtime.serialization import serialize_lc_object
|
||||
|
||||
assert serialize_lc_object("hello") == "hello"
|
||||
assert serialize_lc_object(42) == 42
|
||||
assert serialize_lc_object(3.14) == 3.14
|
||||
assert serialize_lc_object(True) is True
|
||||
|
||||
|
||||
def test_serialize_dict():
|
||||
from deerflow.runtime.serialization import serialize_lc_object
|
||||
|
||||
obj = {"a": _FakePydanticV2(), "b": [1, "two"]}
|
||||
result = serialize_lc_object(obj)
|
||||
assert result == {"a": {"key": "v2"}, "b": [1, "two"]}
|
||||
|
||||
|
||||
def test_serialize_list():
|
||||
from deerflow.runtime.serialization import serialize_lc_object
|
||||
|
||||
result = serialize_lc_object([_FakePydanticV1(), 1])
|
||||
assert result == [{"key": "v1"}, 1]
|
||||
|
||||
|
||||
def test_serialize_tuple():
|
||||
from deerflow.runtime.serialization import serialize_lc_object
|
||||
|
||||
result = serialize_lc_object((_FakePydanticV2(),))
|
||||
assert result == [{"key": "v2"}]
|
||||
|
||||
|
||||
def test_serialize_pydantic_v2():
|
||||
from deerflow.runtime.serialization import serialize_lc_object
|
||||
|
||||
assert serialize_lc_object(_FakePydanticV2()) == {"key": "v2"}
|
||||
|
||||
|
||||
def test_serialize_pydantic_v1():
|
||||
from deerflow.runtime.serialization import serialize_lc_object
|
||||
|
||||
assert serialize_lc_object(_FakePydanticV1()) == {"key": "v1"}
|
||||
|
||||
|
||||
def test_serialize_fallback_str():
|
||||
from deerflow.runtime.serialization import serialize_lc_object
|
||||
|
||||
result = serialize_lc_object(object())
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
def test_serialize_fallback_repr():
|
||||
from deerflow.runtime.serialization import serialize_lc_object
|
||||
|
||||
assert serialize_lc_object(_Unprintable()) == "<Unprintable>"
|
||||
|
||||
|
||||
def test_serialize_channel_values_strips_pregel_keys():
|
||||
from deerflow.runtime.serialization import serialize_channel_values
|
||||
|
||||
raw = {
|
||||
"messages": ["hello"],
|
||||
"__pregel_tasks": "internal",
|
||||
"__pregel_resuming": True,
|
||||
"__interrupt__": "stop",
|
||||
"title": "Test",
|
||||
}
|
||||
result = serialize_channel_values(raw)
|
||||
assert "messages" in result
|
||||
assert "title" in result
|
||||
assert "__pregel_tasks" not in result
|
||||
assert "__pregel_resuming" not in result
|
||||
assert "__interrupt__" not in result
|
||||
|
||||
|
||||
def test_serialize_channel_values_serializes_objects():
|
||||
from deerflow.runtime.serialization import serialize_channel_values
|
||||
|
||||
result = serialize_channel_values({"obj": _FakePydanticV2()})
|
||||
assert result == {"obj": {"key": "v2"}}
|
||||
|
||||
|
||||
def test_serialize_messages_tuple():
|
||||
from deerflow.runtime.serialization import serialize_messages_tuple
|
||||
|
||||
chunk = _FakePydanticV2()
|
||||
metadata = {"langgraph_node": "agent"}
|
||||
result = serialize_messages_tuple((chunk, metadata))
|
||||
assert result == [{"key": "v2"}, {"langgraph_node": "agent"}]
|
||||
|
||||
|
||||
def test_serialize_messages_tuple_non_dict_metadata():
|
||||
from deerflow.runtime.serialization import serialize_messages_tuple
|
||||
|
||||
result = serialize_messages_tuple((_FakePydanticV2(), "not-a-dict"))
|
||||
assert result == [{"key": "v2"}, {}]
|
||||
|
||||
|
||||
def test_serialize_messages_tuple_fallback():
|
||||
from deerflow.runtime.serialization import serialize_messages_tuple
|
||||
|
||||
result = serialize_messages_tuple("not-a-tuple")
|
||||
assert result == "not-a-tuple"
|
||||
|
||||
|
||||
def test_serialize_dispatcher_messages_mode():
|
||||
from deerflow.runtime.serialization import serialize
|
||||
|
||||
chunk = _FakePydanticV2()
|
||||
result = serialize((chunk, {"node": "x"}), mode="messages")
|
||||
assert result == [{"key": "v2"}, {"node": "x"}]
|
||||
|
||||
|
||||
def test_serialize_dispatcher_values_mode():
|
||||
from deerflow.runtime.serialization import serialize
|
||||
|
||||
result = serialize({"msg": "hi", "__pregel_tasks": "x"}, mode="values")
|
||||
assert result == {"msg": "hi"}
|
||||
|
||||
|
||||
def test_serialize_dispatcher_default_mode():
|
||||
from deerflow.runtime.serialization import serialize
|
||||
|
||||
result = serialize(_FakePydanticV1())
|
||||
assert result == {"key": "v1"}
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Tests for SSE frame formatting utilities."""
|
||||
|
||||
import json
|
||||
|
||||
|
||||
def _format_sse(event: str, data, *, event_id: str | None = None) -> str:
|
||||
from app.gateway.services import format_sse
|
||||
|
||||
return format_sse(event, data, event_id=event_id)
|
||||
|
||||
|
||||
def test_sse_end_event_data_null():
|
||||
"""End event should have data: null."""
|
||||
frame = _format_sse("end", None)
|
||||
assert "data: null" in frame
|
||||
|
||||
|
||||
def test_sse_metadata_event():
|
||||
"""Metadata event should include run_id and attempt."""
|
||||
frame = _format_sse("metadata", {"run_id": "abc", "attempt": 1}, event_id="123-0")
|
||||
assert "event: metadata" in frame
|
||||
assert "id: 123-0" in frame
|
||||
|
||||
|
||||
def test_sse_error_format():
|
||||
"""Error event should use message/name format."""
|
||||
frame = _format_sse("error", {"message": "boom", "name": "ValueError"})
|
||||
parsed = json.loads(frame.split("data: ")[1].split("\n")[0])
|
||||
assert parsed["message"] == "boom"
|
||||
assert parsed["name"] == "ValueError"
|
||||
@@ -0,0 +1,152 @@
|
||||
"""Tests for the in-memory StreamBridge implementation."""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, make_stream_bridge
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests for MemoryStreamBridge
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bridge() -> MemoryStreamBridge:
|
||||
return MemoryStreamBridge(queue_maxsize=256)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_publish_subscribe(bridge: MemoryStreamBridge):
|
||||
"""Three events followed by end should be received in order."""
|
||||
run_id = "run-1"
|
||||
|
||||
await bridge.publish(run_id, "metadata", {"run_id": run_id})
|
||||
await bridge.publish(run_id, "values", {"messages": []})
|
||||
await bridge.publish(run_id, "updates", {"step": 1})
|
||||
await bridge.publish_end(run_id)
|
||||
|
||||
received = []
|
||||
async for entry in bridge.subscribe(run_id, heartbeat_interval=1.0):
|
||||
received.append(entry)
|
||||
if entry is END_SENTINEL:
|
||||
break
|
||||
|
||||
assert len(received) == 4
|
||||
assert received[0].event == "metadata"
|
||||
assert received[1].event == "values"
|
||||
assert received[2].event == "updates"
|
||||
assert received[3] is END_SENTINEL
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_heartbeat(bridge: MemoryStreamBridge):
|
||||
"""When no events arrive within the heartbeat interval, yield a heartbeat."""
|
||||
run_id = "run-heartbeat"
|
||||
bridge._get_or_create_queue(run_id) # ensure queue exists
|
||||
|
||||
received = []
|
||||
|
||||
async def consumer():
|
||||
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
|
||||
received.append(entry)
|
||||
if entry is HEARTBEAT_SENTINEL:
|
||||
break
|
||||
|
||||
await asyncio.wait_for(consumer(), timeout=2.0)
|
||||
assert len(received) == 1
|
||||
assert received[0] is HEARTBEAT_SENTINEL
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cleanup(bridge: MemoryStreamBridge):
|
||||
"""After cleanup, the run's queue is removed."""
|
||||
run_id = "run-cleanup"
|
||||
await bridge.publish(run_id, "test", {})
|
||||
assert run_id in bridge._queues
|
||||
|
||||
await bridge.cleanup(run_id)
|
||||
assert run_id not in bridge._queues
|
||||
assert run_id not in bridge._counters
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_backpressure():
|
||||
"""With maxsize=1, publish should not block forever."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=1)
|
||||
run_id = "run-bp"
|
||||
|
||||
await bridge.publish(run_id, "first", {})
|
||||
|
||||
# Second publish should either succeed after queue drains or warn+drop
|
||||
# It should not hang indefinitely
|
||||
async def publish_second():
|
||||
await bridge.publish(run_id, "second", {})
|
||||
|
||||
# Give it a generous timeout — the publish timeout is 30s but we don't
|
||||
# want to wait that long in tests. Instead, drain the queue first.
|
||||
async def drain():
|
||||
await asyncio.sleep(0.05)
|
||||
bridge._queues[run_id].get_nowait()
|
||||
|
||||
await asyncio.gather(publish_second(), drain())
|
||||
assert bridge._queues[run_id].qsize() == 1
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_multiple_runs(bridge: MemoryStreamBridge):
|
||||
"""Two different run_ids should not interfere with each other."""
|
||||
await bridge.publish("run-a", "event-a", {"a": 1})
|
||||
await bridge.publish("run-b", "event-b", {"b": 2})
|
||||
await bridge.publish_end("run-a")
|
||||
await bridge.publish_end("run-b")
|
||||
|
||||
events_a = []
|
||||
async for entry in bridge.subscribe("run-a", heartbeat_interval=1.0):
|
||||
events_a.append(entry)
|
||||
if entry is END_SENTINEL:
|
||||
break
|
||||
|
||||
events_b = []
|
||||
async for entry in bridge.subscribe("run-b", heartbeat_interval=1.0):
|
||||
events_b.append(entry)
|
||||
if entry is END_SENTINEL:
|
||||
break
|
||||
|
||||
assert len(events_a) == 2
|
||||
assert events_a[0].event == "event-a"
|
||||
assert events_a[0].data == {"a": 1}
|
||||
|
||||
assert len(events_b) == 2
|
||||
assert events_b[0].event == "event-b"
|
||||
assert events_b[0].data == {"b": 2}
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_event_id_format(bridge: MemoryStreamBridge):
|
||||
"""Event IDs should use timestamp-sequence format."""
|
||||
run_id = "run-id-format"
|
||||
await bridge.publish(run_id, "test", {"key": "value"})
|
||||
await bridge.publish_end(run_id)
|
||||
|
||||
received = []
|
||||
async for entry in bridge.subscribe(run_id, heartbeat_interval=1.0):
|
||||
received.append(entry)
|
||||
if entry is END_SENTINEL:
|
||||
break
|
||||
|
||||
event = received[0]
|
||||
assert re.match(r"^\d+-\d+$", event.id), f"Expected timestamp-seq format, got {event.id}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Factory tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_make_stream_bridge_defaults():
|
||||
"""make_stream_bridge() with no config yields a MemoryStreamBridge."""
|
||||
async with make_stream_bridge() as bridge:
|
||||
assert isinstance(bridge, MemoryStreamBridge)
|
||||
Reference in New Issue
Block a user