mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 17:35:57 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9a859085fa | |||
| dad3997459 | |||
| b67c2a4e56 |
@@ -59,7 +59,7 @@ smoke-test/
|
||||
2. **Check pnpm** - Package manager
|
||||
3. **Check uv** - Python package manager
|
||||
4. **Check nginx** - Reverse proxy
|
||||
5. **Check required ports** - Confirm that ports 2026, 3000, and 8001 are not occupied
|
||||
5. **Check required ports** - Confirm that ports 2026, 3000, 8001, and 2024 are not occupied
|
||||
|
||||
**Docker mode environment check** (if Docker is selected):
|
||||
1. **Check whether Docker is installed** - Run `docker --version`
|
||||
@@ -93,17 +93,17 @@ smoke-test/
|
||||
### Phase 5: Service Health Check
|
||||
|
||||
**Local mode health check**:
|
||||
1. **Check process status** - Confirm that Gateway, Frontend, and Nginx processes are all running
|
||||
1. **Check process status** - Confirm that LangGraph, Gateway, Frontend, and Nginx processes are all running
|
||||
2. **Check frontend service** - Visit `http://localhost:2026` and verify that the page loads
|
||||
3. **Check API Gateway** - Verify the `http://localhost:2026/health` endpoint
|
||||
4. **Check LangGraph-compatible API** - Verify the `/api/langgraph/*` route exposed by Gateway
|
||||
4. **Check LangGraph service** - Verify the availability of relevant endpoints
|
||||
5. **Frontend route smoke check** - Run `bash .agent/skills/smoke-test/scripts/frontend_check.sh` to verify key routes under `/workspace`
|
||||
|
||||
**Docker mode health check** (when using Docker):
|
||||
1. **Check container status** - Run `docker ps` and confirm that all containers are running
|
||||
2. **Check frontend service** - Visit `http://localhost:2026` and verify that the page loads
|
||||
3. **Check API Gateway** - Verify the `http://localhost:2026/health` endpoint
|
||||
4. **Check LangGraph-compatible API** - Verify the `/api/langgraph/*` route exposed by Gateway
|
||||
4. **Check LangGraph service** - Verify the availability of relevant endpoints
|
||||
5. **Frontend route smoke check** - Run `bash .agent/skills/smoke-test/scripts/frontend_check.sh` to verify key routes under `/workspace`
|
||||
|
||||
### Optional Functional Verification
|
||||
@@ -135,7 +135,7 @@ smoke-test/
|
||||
|
||||
The following warnings can appear during smoke testing and do not block a successful result:
|
||||
- Feishu/Lark SSL errors in Gateway logs (certificate verification failure) can be ignored if that channel is not enabled
|
||||
- Warnings in Gateway logs about missing methods in the custom checkpointer, such as `adelete_for_runs` or `aprune`, do not affect the core functionality
|
||||
- Warnings in LangGraph logs about missing methods in the custom checkpointer, such as `adelete_for_runs` or `aprune`, do not affect the core functionality
|
||||
|
||||
## Key Tools
|
||||
|
||||
|
||||
@@ -138,6 +138,7 @@ This document describes the detailed operating steps for each phase of the DeerF
|
||||
lsof -i :2026 # Main port
|
||||
lsof -i :3000 # Frontend
|
||||
lsof -i :8001 # Gateway
|
||||
lsof -i :2024 # LangGraph
|
||||
```
|
||||
|
||||
**Success Criteria**: All ports are free, or they are occupied only by DeerFlow-related processes.
|
||||
@@ -257,7 +258,7 @@ This document describes the detailed operating steps for each phase of the DeerF
|
||||
**Steps**:
|
||||
1. Run `make dev-daemon` (background mode)
|
||||
|
||||
**Description**: This command starts all services (Gateway embedded runtime, Frontend, Nginx).
|
||||
**Description**: This command starts all services (LangGraph, Gateway, Frontend, Nginx).
|
||||
|
||||
**Notes**:
|
||||
- `make dev` runs in the foreground and stops with Ctrl+C
|
||||
@@ -271,6 +272,7 @@ This document describes the detailed operating steps for each phase of the DeerF
|
||||
**Steps**:
|
||||
1. Wait 90-120 seconds for all services to start completely
|
||||
2. You can monitor startup progress by checking these log files:
|
||||
- `logs/langgraph.log`
|
||||
- `logs/gateway.log`
|
||||
- `logs/frontend.log`
|
||||
- `logs/nginx.log`
|
||||
@@ -314,10 +316,11 @@ This document describes the detailed operating steps for each phase of the DeerF
|
||||
**Steps**:
|
||||
1. Run the following command to check processes:
|
||||
```bash
|
||||
ps aux | grep -E "(uvicorn|next|nginx)" | grep -v grep
|
||||
ps aux | grep -E "(langgraph|uvicorn|next|nginx)" | grep -v grep
|
||||
```
|
||||
|
||||
**Success Criteria**: Confirm that the following processes are running:
|
||||
- LangGraph (`langgraph dev`)
|
||||
- Gateway (`uvicorn app.gateway.app:app`)
|
||||
- Frontend (`next dev` or `next start`)
|
||||
- Nginx (`nginx`)
|
||||
@@ -353,11 +356,10 @@ curl http://localhost:2026/health
|
||||
|
||||
---
|
||||
|
||||
#### 5.1.4 Check LangGraph-compatible API
|
||||
#### 5.1.4 Check LangGraph Service
|
||||
|
||||
**Steps**:
|
||||
1. Visit `http://localhost:2026/api/langgraph/assistants/lead_agent` to verify Gateway's LangGraph-compatible API route is reachable.
|
||||
2. A `401` response is acceptable when authentication is enabled and no session cookie is provided.
|
||||
1. Visit relevant LangGraph endpoints to verify availability
|
||||
|
||||
---
|
||||
|
||||
@@ -371,6 +373,7 @@ curl http://localhost:2026/health
|
||||
- `deer-flow-nginx`
|
||||
- `deer-flow-frontend`
|
||||
- `deer-flow-gateway`
|
||||
- `deer-flow-langgraph` (if not in gateway mode)
|
||||
|
||||
---
|
||||
|
||||
@@ -403,11 +406,10 @@ curl http://localhost:2026/health
|
||||
|
||||
---
|
||||
|
||||
#### 5.2.4 Check LangGraph-compatible API
|
||||
#### 5.2.4 Check LangGraph Service
|
||||
|
||||
**Steps**:
|
||||
1. Visit `http://localhost:2026/api/langgraph/assistants/lead_agent` to verify Gateway's LangGraph-compatible API route is reachable.
|
||||
2. A `401` response is acceptable when authentication is enabled and no session cookie is provided.
|
||||
1. Visit relevant LangGraph endpoints to verify availability
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -254,6 +254,7 @@ Processes exit quickly after running `make dev-daemon`.
|
||||
**Solutions**:
|
||||
1. Check log files:
|
||||
```bash
|
||||
tail -f logs/langgraph.log
|
||||
tail -f logs/gateway.log
|
||||
tail -f logs/frontend.log
|
||||
tail -f logs/nginx.log
|
||||
@@ -366,7 +367,24 @@ Errors appear in `gateway.log`.
|
||||
uv sync
|
||||
```
|
||||
|
||||
4. Confirm that the Gateway process is running normally.
|
||||
4. Confirm that the LangGraph service is running normally (if not in gateway mode)
|
||||
|
||||
---
|
||||
|
||||
### Issue: LangGraph Fails to Start
|
||||
|
||||
**Symptoms**:
|
||||
Errors appear in `langgraph.log`.
|
||||
|
||||
**Solutions**:
|
||||
1. Check LangGraph logs:
|
||||
```bash
|
||||
tail -f logs/langgraph.log
|
||||
```
|
||||
|
||||
2. Check config.yaml
|
||||
3. Check whether Python dependencies are complete
|
||||
4. Confirm that port 2024 is not occupied
|
||||
|
||||
---
|
||||
|
||||
@@ -501,7 +519,7 @@ Accessing `/health` returns an error or times out.
|
||||
|
||||
2. Confirm that config.yaml exists and has valid formatting
|
||||
3. Check whether Python dependencies are complete
|
||||
4. Confirm that the Gateway process is running normally.
|
||||
4. Confirm that the LangGraph service is running normally
|
||||
|
||||
**Solutions** (Docker mode):
|
||||
1. Check gateway container logs:
|
||||
@@ -511,7 +529,7 @@ Accessing `/health` returns an error or times out.
|
||||
|
||||
2. Confirm that config.yaml is mounted correctly
|
||||
3. Check whether Python dependencies are complete
|
||||
4. Confirm that the Gateway process is running normally.
|
||||
4. Confirm that the LangGraph service is running normally
|
||||
|
||||
---
|
||||
|
||||
@@ -521,7 +539,7 @@ Accessing `/health` returns an error or times out.
|
||||
|
||||
#### View All Service Processes
|
||||
```bash
|
||||
ps aux | grep -E "(uvicorn|next|nginx)" | grep -v grep
|
||||
ps aux | grep -E "(langgraph|uvicorn|next|nginx)" | grep -v grep
|
||||
```
|
||||
|
||||
#### View Service Logs
|
||||
@@ -530,6 +548,7 @@ ps aux | grep -E "(uvicorn|next|nginx)" | grep -v grep
|
||||
tail -f logs/*.log
|
||||
|
||||
# View specific service logs
|
||||
tail -f logs/langgraph.log
|
||||
tail -f logs/gateway.log
|
||||
tail -f logs/frontend.log
|
||||
tail -f logs/nginx.log
|
||||
|
||||
@@ -65,7 +65,7 @@ if ! command -v lsof >/dev/null 2>&1; then
|
||||
echo " Install lsof and rerun this check"
|
||||
all_passed=false
|
||||
else
|
||||
for port in 2026 3000 8001; do
|
||||
for port in 2026 3000 8001 2024; do
|
||||
if lsof -i :$port >/dev/null 2>&1; then
|
||||
echo "⚠ Port $port is already in use:"
|
||||
lsof -i :$port | head -2
|
||||
|
||||
@@ -54,6 +54,7 @@ echo "=========================================="
|
||||
echo ""
|
||||
echo "🌐 Access URL: http://localhost:2026"
|
||||
echo "📋 View logs:"
|
||||
echo " - logs/langgraph.log"
|
||||
echo " - logs/gateway.log"
|
||||
echo " - logs/frontend.log"
|
||||
echo " - logs/nginx.log"
|
||||
|
||||
@@ -76,11 +76,12 @@ if [ "$mode" = "docker" ]; then
|
||||
all_passed=false
|
||||
fi
|
||||
else
|
||||
summary_hint="logs/{gateway,frontend,nginx}.log"
|
||||
summary_hint="logs/{langgraph,gateway,frontend,nginx}.log"
|
||||
print_step "1. Checking local service ports..."
|
||||
check_listen_port "Nginx" 2026
|
||||
check_listen_port "Frontend" 3000
|
||||
check_listen_port "Gateway" 8001
|
||||
check_listen_port "LangGraph" 2024
|
||||
fi
|
||||
echo ""
|
||||
|
||||
@@ -103,8 +104,8 @@ else
|
||||
fi
|
||||
echo ""
|
||||
|
||||
echo "5. Checking LangGraph-compatible Gateway API..."
|
||||
check_http_status "LangGraph-compatible Gateway API" "http://localhost:2026/api/langgraph/assistants/lead_agent" "200|401"
|
||||
echo "5. Checking LangGraph service..."
|
||||
check_http_status "LangGraph service" "http://localhost:2024/" "200|301|302|307|308|404"
|
||||
echo ""
|
||||
|
||||
echo "=========================================="
|
||||
|
||||
@@ -78,7 +78,7 @@
|
||||
- [x] Container status - {{status_containers}}
|
||||
- [x] Frontend service - {{status_frontend}}
|
||||
- [x] API Gateway - {{status_api_gateway}}
|
||||
- [x] LangGraph-compatible Gateway API - {{status_langgraph}}
|
||||
- [x] LangGraph service - {{status_langgraph}}
|
||||
|
||||
**Phase Status**: {{stage5_status}}
|
||||
|
||||
@@ -147,6 +147,7 @@ Commit Message: {{git_commit_message}}
|
||||
| deer-flow-nginx | {{nginx_status}} | {{nginx_uptime}} |
|
||||
| deer-flow-frontend | {{frontend_status}} | {{frontend_uptime}} |
|
||||
| deer-flow-gateway | {{gateway_status}} | {{gateway_uptime}} |
|
||||
| deer-flow-langgraph | {{langgraph_status}} | {{langgraph_uptime}} |
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@
|
||||
- [x] Process status - {{status_processes}}
|
||||
- [x] Frontend service - {{status_frontend}}
|
||||
- [x] API Gateway - {{status_api_gateway}}
|
||||
- [x] LangGraph-compatible Gateway API - {{status_langgraph}}
|
||||
- [x] LangGraph service - {{status_langgraph}}
|
||||
|
||||
**Phase Status**: {{stage5_status}}
|
||||
|
||||
@@ -152,7 +152,7 @@ Commit Message: {{git_commit_message}}
|
||||
| Nginx | {{nginx_status}} | {{nginx_endpoint}} |
|
||||
| Frontend | {{frontend_status}} | {{frontend_endpoint}} |
|
||||
| Gateway | {{gateway_status}} | {{gateway_endpoint}} |
|
||||
| Gateway LangGraph API | {{langgraph_status}} | {{langgraph_endpoint}} |
|
||||
| LangGraph | {{langgraph_status}} | {{langgraph_endpoint}} |
|
||||
|
||||
---
|
||||
|
||||
@@ -166,7 +166,7 @@ Commit Message: {{git_commit_message}}
|
||||
|
||||
### If the Test Fails
|
||||
1. [ ] Review references/troubleshooting.md for common solutions
|
||||
2. [ ] Check local logs: `logs/{gateway,frontend,nginx}.log`
|
||||
2. [ ] Check local logs: `logs/{langgraph,gateway,frontend,nginx}.log`
|
||||
3. [ ] Verify configuration file format and content
|
||||
4. [ ] If needed, fully reset the environment: `make stop && make clean && make install && make dev-daemon`
|
||||
|
||||
|
||||
+5
-10
@@ -122,14 +122,10 @@ Blocking-IO runtime gate (`tests/blocking_io/`):
|
||||
`tests/support/detectors/blocking_io_runtime.py`). Any sync blocking IO
|
||||
call whose stack passes through DeerFlow business code while running on
|
||||
the asyncio event loop raises `BlockingError` and fails the test.
|
||||
- Regression anchors live there: `test_skills_load.py` (locks the
|
||||
- Two regression anchors live there: `test_skills_load.py` (locks the
|
||||
`asyncio.to_thread` offload around `LocalSkillStorage.load_skills`, fix
|
||||
for #1917); `test_sqlite_lifespan.py` (locks the offload around
|
||||
SQLite path resolution plus `ensure_sqlite_parent_dir`, fix for #1912);
|
||||
`test_jsonl_run_event_store.py` (locks `JsonlRunEventStore`'s async
|
||||
API offloading its file IO via `asyncio.to_thread`, fix #3084); and
|
||||
`test_uploads_middleware.py` (locks `UploadsMiddleware.abefore_agent`
|
||||
offloading the uploads-directory scan off the event loop).
|
||||
for #1917) and `test_sqlite_lifespan.py` (locks the offload around
|
||||
SQLite path resolution plus `ensure_sqlite_parent_dir`, fix for #1912).
|
||||
- `test_gate_smoke.py` is a meta-test asserting the gate actually catches
|
||||
unoffloaded blocking IO and that the `@pytest.mark.allow_blocking_io`
|
||||
opt-out works.
|
||||
@@ -281,7 +277,6 @@ CORS is same-origin by default when requests enter through nginx on port 2026. S
|
||||
- When a persistent `RunStore` is configured, `get()` and `list_by_thread()` hydrate historical runs from the store. In-memory records win for the same `run_id` so task, abort, and stream-control state stays attached to active local runs.
|
||||
- `cancel()` and `create_or_reject(..., multitask_strategy="interrupt"|"rollback")` persist interrupted status through `RunStore.update_status()`, matching normal `set_status()` transitions.
|
||||
- Store-only hydrated runs are readable history. If the current worker has no in-memory task/control state for that run, cancellation APIs can return 409 because this worker cannot stop the task.
|
||||
- `POST /wait` (both thread-scoped and `/api/runs/wait`) drains the stream bridge via `wait_for_run_completion()` instead of bare `await record.task`, so it honours the run's `on_disconnect` setting and cancels the background run on real client disconnect rather than returning a stale checkpoint (issue #3265).
|
||||
|
||||
Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runtime, all other `/api/*` → Gateway REST APIs.
|
||||
|
||||
@@ -347,7 +342,7 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
|
||||
- **Cache invalidation**: Detects config file changes via mtime comparison
|
||||
- **Transports**: stdio (command-based), SSE, HTTP
|
||||
- **OAuth (HTTP/SSE)**: Supports token endpoint flows (`client_credentials`, `refresh_token`) with automatic token refresh + Authorization header injection
|
||||
- **Runtime updates**: Gateway API saves to extensions_config.json; the Gateway-embedded runtime detects changes via mtime
|
||||
- **Runtime updates**: Gateway API saves to extensions_config.json; LangGraph detects via mtime
|
||||
|
||||
### Skills System (`packages/harness/deerflow/skills/`)
|
||||
|
||||
@@ -374,7 +369,7 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
|
||||
|
||||
### IM Channels System (`app/channels/`)
|
||||
|
||||
Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the DeerFlow agent via Gateway's LangGraph-compatible API.
|
||||
Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the DeerFlow agent via the LangGraph Server.
|
||||
|
||||
|
||||
**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. The internal SDK client injects process-local internal auth plus a matching CSRF cookie/header pair so Gateway accepts state-changing thread/run requests from channel workers without relying on browser session cookies.
|
||||
|
||||
@@ -173,8 +173,6 @@ def _extract_response_text(result: dict | list) -> str:
|
||||
|
||||
# Stop at the last human message — anything before it is a previous turn
|
||||
if msg_type == "human":
|
||||
if _is_hidden_human_control_message(msg):
|
||||
continue
|
||||
break
|
||||
|
||||
# Check for tool messages from ask_clarification (interrupt case)
|
||||
@@ -315,8 +313,6 @@ def _extract_artifacts(result: dict | list) -> list[str]:
|
||||
continue
|
||||
# Stop at the last human message — anything before it is a previous turn
|
||||
if msg.get("type") == "human":
|
||||
if _is_hidden_human_control_message(msg):
|
||||
continue
|
||||
break
|
||||
# Look for AI messages with present_files tool calls
|
||||
if msg.get("type") == "ai":
|
||||
@@ -329,18 +325,6 @@ def _extract_artifacts(result: dict | list) -> list[str]:
|
||||
return artifacts
|
||||
|
||||
|
||||
def _is_hidden_human_control_message(msg: Mapping[str, Any]) -> bool:
|
||||
"""Return whether a human message is an internal control message hidden from UI."""
|
||||
if msg.get("type") != "human":
|
||||
return False
|
||||
|
||||
additional_kwargs = msg.get("additional_kwargs")
|
||||
if not isinstance(additional_kwargs, Mapping):
|
||||
return False
|
||||
|
||||
return additional_kwargs.get("hide_from_ui") is True
|
||||
|
||||
|
||||
def _format_artifact_text(artifacts: list[str]) -> str:
|
||||
"""Format artifact paths into a human-readable text block listing filenames."""
|
||||
import posixpath
|
||||
|
||||
@@ -276,9 +276,10 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
|
||||
|
||||
logger.info(f"MCP configuration updated and saved to: {config_path}")
|
||||
|
||||
# Reload the Gateway configuration and update the global cache. The
|
||||
# agent runtime lives in Gateway, so this keeps API reads and tool
|
||||
# execution aligned after extensions_config.json changes.
|
||||
# NOTE: No need to reload/reset cache here - LangGraph Server (separate process)
|
||||
# will detect config file changes via mtime and reinitialize MCP tools automatically
|
||||
|
||||
# Reload the configuration and update the global cache
|
||||
reloaded_config = reload_extensions_config()
|
||||
servers = {name: _mask_server_config(McpServerConfigResponse(**server.model_dump())) for name, server in reloaded_config.mcp_servers.items()}
|
||||
return McpConfigResponse(mcp_servers=servers)
|
||||
|
||||
@@ -7,6 +7,7 @@ is reused so that conversation history is preserved across calls.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
@@ -16,7 +17,7 @@ from fastapi.responses import StreamingResponse
|
||||
from app.gateway.authz import require_permission
|
||||
from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||
from app.gateway.routers.thread_runs import RunCreateRequest
|
||||
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
|
||||
from app.gateway.services import sse_consumer, start_run
|
||||
from deerflow.runtime import serialize_channel_values
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -65,25 +66,24 @@ async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
|
||||
Otherwise a new temporary thread is created.
|
||||
"""
|
||||
thread_id = _resolve_thread_id(body)
|
||||
bridge = get_stream_bridge(request)
|
||||
run_mgr = get_run_manager(request)
|
||||
record = await start_run(body, thread_id, request)
|
||||
|
||||
completed = True
|
||||
if record.task is not None:
|
||||
completed = await wait_for_run_completion(bridge, record, request, run_mgr)
|
||||
|
||||
if completed:
|
||||
checkpointer = get_checkpointer(request)
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
if checkpoint_tuple is not None:
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
return serialize_channel_values(channel_values)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||
await record.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
checkpointer = get_checkpointer(request)
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
if checkpoint_tuple is not None:
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
return serialize_channel_values(channel_values)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||
|
||||
return {"status": record.status.value, "error": record.error}
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
|
||||
from app.gateway.services import sse_consumer, start_run
|
||||
from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -175,25 +175,24 @@ async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -
|
||||
@require_permission("runs", "create", owner_check=True, require_existing=True)
|
||||
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
|
||||
"""Create a run and block until it completes, returning the final state."""
|
||||
bridge = get_stream_bridge(request)
|
||||
run_mgr = get_run_manager(request)
|
||||
record = await start_run(body, thread_id, request)
|
||||
|
||||
completed = True
|
||||
if record.task is not None:
|
||||
completed = await wait_for_run_completion(bridge, record, request, run_mgr)
|
||||
|
||||
if completed:
|
||||
checkpointer = get_checkpointer(request)
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
if checkpoint_tuple is not None:
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
return serialize_channel_values(channel_values)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||
await record.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
checkpointer = get_checkpointer(request)
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
if checkpoint_tuple is not None:
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
return serialize_channel_values(channel_values)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||
|
||||
return {"status": record.status.value, "error": record.error}
|
||||
|
||||
@@ -278,12 +277,7 @@ async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingRe
|
||||
)
|
||||
|
||||
|
||||
# Register GET and POST as separate routes so each method gets a unique OpenAPI
|
||||
# operationId. ``api_route(methods=["GET", "POST"])`` shares one route registration
|
||||
# across both methods, which makes FastAPI emit the same ``operationId`` twice and
|
||||
# warn about a duplicate operation id during OpenAPI generation.
|
||||
@router.get("/{thread_id}/runs/{run_id}/stream", response_model=None)
|
||||
@router.post("/{thread_id}/runs/{run_id}/stream", response_model=None)
|
||||
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
|
||||
@require_permission("runs", "read", owner_check=True)
|
||||
async def stream_existing_run(
|
||||
thread_id: str,
|
||||
|
||||
@@ -402,51 +402,3 @@ async def sse_consumer(
|
||||
if record.status in (RunStatus.pending, RunStatus.running):
|
||||
if record.on_disconnect == DisconnectMode.cancel:
|
||||
await run_mgr.cancel(record.run_id)
|
||||
|
||||
|
||||
async def wait_for_run_completion(
|
||||
bridge: StreamBridge,
|
||||
record: RunRecord,
|
||||
request: Request,
|
||||
run_mgr: RunManager,
|
||||
) -> bool:
|
||||
"""Block until the run publishes ``END_SENTINEL``, honouring on_disconnect.
|
||||
|
||||
The non-streaming ``/wait`` endpoints used to ``await record.task``
|
||||
directly with no disconnect handling. When the client (or an
|
||||
intermediate HTTP proxy) timed out during a long tool call such as
|
||||
``pip install``, the handler would swallow ``CancelledError`` and
|
||||
serialize whatever checkpoint happened to exist — masking a half-finished
|
||||
run as a normal completion (issue #3265).
|
||||
|
||||
This helper consumes the same bridge that ``sse_consumer`` does so the
|
||||
wait path shares its disconnect semantics: each wake-up polls
|
||||
``request.is_disconnected()``; on a real disconnect it cancels the
|
||||
background run when ``record.on_disconnect`` is ``cancel``. The bridge's
|
||||
heartbeat sentinels guarantee at least one wake-up per
|
||||
``heartbeat_interval`` even when the agent emits no events for a while.
|
||||
|
||||
Returns:
|
||||
``True`` when ``END_SENTINEL`` was observed (run reached a terminal
|
||||
state), ``False`` when the loop exited because the client
|
||||
disconnected. Callers must skip checkpoint serialization on
|
||||
``False`` so a partial checkpoint is not returned as a normal
|
||||
response.
|
||||
"""
|
||||
completed = False
|
||||
try:
|
||||
async for entry in bridge.subscribe(record.run_id):
|
||||
# END_SENTINEL means the run reached a terminal state; honour it
|
||||
# even if the client just disconnected so the caller still serializes
|
||||
# the real final checkpoint.
|
||||
if entry is END_SENTINEL:
|
||||
completed = True
|
||||
return True
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
# Heartbeats and regular events: keep waiting for END_SENTINEL.
|
||||
return completed
|
||||
finally:
|
||||
if not completed and record.status in (RunStatus.pending, RunStatus.running):
|
||||
if record.on_disconnect == DisconnectMode.cancel:
|
||||
await run_mgr.cancel(record.run_id)
|
||||
|
||||
@@ -4,12 +4,10 @@
|
||||
|
||||
| 模式 | 启动命令 | Auth 层 | 端口 |
|
||||
|------|---------|---------|------|
|
||||
| 标准模式 | `make dev` | Gateway AuthMiddleware(全量) | 2026 (nginx) |
|
||||
| 标准模式 | `make dev` | Gateway AuthMiddleware + LangGraph auth | 2026 (nginx) |
|
||||
| Gateway 模式 | `make dev-pro` | Gateway AuthMiddleware(全量) | 2026 (nginx) |
|
||||
| 直连 Gateway | `cd backend && make gateway` | Gateway AuthMiddleware | 8001 |
|
||||
| 直连 LangGraph 兼容性 | 手动运行 LangGraph 工具链时使用 | LangGraph auth | 2024 |
|
||||
|
||||
`make dev`、Docker dev 和生产部署默认都运行 Gateway embedded runtime。
|
||||
`app.gateway.langgraph_auth` 仅用于保留的直连 LangGraph 工具链 / Studio 兼容性测试,不是标准服务启动路径。
|
||||
| 直连 LangGraph | `cd backend && make dev` | LangGraph auth | 2024 |
|
||||
|
||||
每种模式下都需执行以下测试。
|
||||
|
||||
@@ -23,8 +21,10 @@
|
||||
# 清除已有数据
|
||||
rm -f backend/.deer-flow/data/deerflow.db
|
||||
|
||||
# 启动标准模式(Gateway embedded runtime)
|
||||
make dev
|
||||
# 选择模式启动
|
||||
make dev # 标准模式
|
||||
# 或
|
||||
make dev-pro # Gateway 模式
|
||||
```
|
||||
|
||||
**验证点:**
|
||||
@@ -57,7 +57,7 @@ make dev
|
||||
|
||||
## 二、接口流程测试
|
||||
|
||||
> 以下用 `BASE=http://localhost:2026` 为例。标准模式经 nginx 暴露此地址。
|
||||
> 以下用 `BASE=http://localhost:2026` 为例。标准模式和 Gateway 模式都用此地址。
|
||||
> 直连测试替换为对应端口。
|
||||
>
|
||||
> **CSRF token 提取**:多处用到从 cookie jar 提取 CSRF token,统一使用:
|
||||
@@ -211,18 +211,20 @@ curl -s -X POST $BASE/api/threads/search \
|
||||
|
||||
**预期:** 返回 0 或仅包含 user2 自己的 thread
|
||||
|
||||
### 2.3 LangGraph-compatible Gateway 路由隔离
|
||||
### 2.3 标准模式 LangGraph Server 隔离
|
||||
|
||||
#### TC-API-10: LangGraph-compatible 端点需要 cookie
|
||||
> 仅在标准模式下测试。Gateway 模式不跑 LangGraph Server。
|
||||
|
||||
#### TC-API-10: LangGraph 端点需要 cookie
|
||||
|
||||
```bash
|
||||
# 不带 cookie 访问 LangGraph-compatible 接口
|
||||
# 不带 cookie 访问 LangGraph 接口
|
||||
curl -s -w "%{http_code}" $BASE/api/langgraph/threads
|
||||
```
|
||||
|
||||
**预期:** 401
|
||||
|
||||
#### TC-API-11: LangGraph-compatible 路由带 cookie 可访问
|
||||
#### TC-API-11: LangGraph 带 cookie 可访问
|
||||
|
||||
```bash
|
||||
curl -s $BASE/api/langgraph/threads -b user1.txt | jq length
|
||||
@@ -230,10 +232,10 @@ curl -s $BASE/api/langgraph/threads -b user1.txt | jq length
|
||||
|
||||
**预期:** 200,返回 user1 的 thread 列表
|
||||
|
||||
#### TC-API-12: LangGraph-compatible 路由隔离 — 用户只看到自己的
|
||||
#### TC-API-12: LangGraph 隔离 — 用户只看到自己的
|
||||
|
||||
```bash
|
||||
# user2 查 threads
|
||||
# user2 查 LangGraph threads
|
||||
curl -s $BASE/api/langgraph/threads -b user2.txt | jq length
|
||||
```
|
||||
|
||||
@@ -1232,11 +1234,21 @@ P2=$(awk -F': ' '/^password:/ {print $2}' /tmp/deerflow-reset-p2.txt)
|
||||
## 七、模式差异测试
|
||||
|
||||
> 以下用 `GW=http://localhost:8001` 表示直连 Gateway,`BASE=http://localhost:2026` 表示经 nginx。
|
||||
> 标准启动命令:`make dev`(或 `./scripts/serve.sh --dev`)。
|
||||
> Gateway 模式启动命令:`make dev-pro`(或 `./scripts/serve.sh --dev --gateway`)。
|
||||
|
||||
### 7.1 标准启动模式
|
||||
### 7.1 标准模式独有
|
||||
|
||||
#### TC-MODE-01: Gateway AuthMiddleware 的 token_version 检查
|
||||
> 启动命令:`make dev`(或 `./scripts/serve.sh --dev`)
|
||||
|
||||
#### TC-MODE-01: LangGraph Server 独立运行,需 cookie
|
||||
|
||||
```bash
|
||||
# 无 cookie 访问 LangGraph
|
||||
curl -s -w "%{http_code}" -o /dev/null $BASE/api/langgraph/threads/search
|
||||
# 预期: 403(LangGraph auth handler 拒绝)
|
||||
```
|
||||
|
||||
#### TC-MODE-02: LangGraph auth 的 token_version 检查
|
||||
|
||||
```bash
|
||||
# 登录拿 cookie
|
||||
@@ -1249,9 +1261,9 @@ curl -s -X POST $BASE/api/v1/auth/change-password \
|
||||
-b cookies.txt -H "Content-Type: application/json" -H "X-CSRF-Token: $CSRF" \
|
||||
-d '{"current_password":"正确密码","new_password":"NewPass1!"}' -c new_cookies.txt
|
||||
|
||||
# 用旧 cookie 访问 LangGraph-compatible 路由
|
||||
# 用旧 cookie 访问 LangGraph
|
||||
curl -s -w "%{http_code}" $BASE/api/langgraph/threads/search -b cookies.txt
|
||||
# 预期: 401(token_version 不匹配)
|
||||
# 预期: 403(token_version 不匹配)
|
||||
|
||||
# 用新 cookie 访问
|
||||
CSRF2=$(grep csrf_token new_cookies.txt | awk '{print $NF}')
|
||||
@@ -1260,7 +1272,7 @@ curl -s -w "%{http_code}" -X POST $BASE/api/langgraph/threads/search \
|
||||
# 预期: 200
|
||||
```
|
||||
|
||||
#### TC-MODE-02: Gateway owner filter 隔离
|
||||
#### TC-MODE-03: LangGraph auth 的 owner filter 隔离
|
||||
|
||||
```bash
|
||||
# user1 创建 thread
|
||||
@@ -1285,9 +1297,18 @@ print('OK: user2 sees', len(threads), 'threads, none belong to user1')
|
||||
"
|
||||
```
|
||||
|
||||
#### TC-MODE-03: 所有请求经 AuthMiddleware
|
||||
### 7.2 Gateway 模式独有
|
||||
|
||||
> 启动命令:`make dev-pro`(或 `./scripts/serve.sh --dev --gateway`)
|
||||
> 无 LangGraph Server 进程,agent runtime 嵌入 Gateway。
|
||||
|
||||
#### TC-MODE-04: 所有请求经 AuthMiddleware
|
||||
|
||||
```bash
|
||||
# 确认 LangGraph Server 未运行
|
||||
curl -s -w "%{http_code}" -o /dev/null http://localhost:2024/ok
|
||||
# 预期: 000(连接被拒)
|
||||
|
||||
# Gateway API 受保护
|
||||
curl -s -w "%{http_code}" -o /dev/null $BASE/api/models
|
||||
# 预期: 401
|
||||
@@ -1298,7 +1319,7 @@ curl -s -w "%{http_code}" -o /dev/null -X POST $BASE/api/langgraph/threads/searc
|
||||
# 预期: 401
|
||||
```
|
||||
|
||||
#### TC-MODE-04: 标准模式下完整 auth 流程
|
||||
#### TC-MODE-05: Gateway 模式下完整 auth 流程
|
||||
|
||||
```bash
|
||||
# 登录
|
||||
@@ -1313,7 +1334,7 @@ curl -s -X POST $BASE/api/langgraph/threads \
|
||||
-d '{"metadata":{}}' | python3 -c "import sys,json; print(json.load(sys.stdin)['thread_id'])"
|
||||
# 预期: 返回 thread_id
|
||||
|
||||
# CSRF 保护(CSRFMiddleware 覆盖所有 Gateway 路由)
|
||||
# CSRF 保护(Gateway 模式下 CSRFMiddleware 直接覆盖所有路由)
|
||||
curl -s -w "%{http_code}" -o /dev/null -X POST $BASE/api/langgraph/threads \
|
||||
-b cookies.txt -H "Content-Type: application/json" -d '{"metadata":{}}'
|
||||
# 预期: 403(CSRF token missing)
|
||||
@@ -1412,7 +1433,7 @@ done
|
||||
|
||||
### 7.4 Docker 部署
|
||||
|
||||
> 启动命令:`./scripts/deploy.sh`
|
||||
> 启动命令:`./scripts/deploy.sh`(标准)或 `./scripts/deploy.sh --gateway`(Gateway 模式)
|
||||
> Docker Compose 文件:`docker/docker-compose.yaml`
|
||||
>
|
||||
> 前置条件:
|
||||
@@ -1521,16 +1542,16 @@ docker logs deer-flow-gateway 2>&1 | grep -iE "Password: .{15,}" && echo "FAIL:
|
||||
- 容器日志输出**路径**(不是密码本身),符合 CodeQL `py/clear-text-logging-sensitive-data` 规则
|
||||
- `grep "Password:"` 在日志中**应当无匹配**(旧行为已废弃,simplify pass 移除了日志泄露路径)
|
||||
|
||||
#### TC-DOCKER-06: Docker 部署
|
||||
#### TC-DOCKER-06: Gateway 模式 Docker 部署
|
||||
|
||||
```bash
|
||||
# 标准 Docker 模式:runtime 嵌入 gateway 容器
|
||||
./scripts/deploy.sh
|
||||
# Gateway 模式:无 langgraph 容器
|
||||
./scripts/deploy.sh --gateway
|
||||
sleep 15
|
||||
|
||||
# 确认 gateway 容器存在
|
||||
docker ps --filter name=deer-flow-gateway --format '{{.Names}}'
|
||||
# 预期: deer-flow-gateway
|
||||
# 确认 langgraph 容器不存在
|
||||
docker ps --filter name=deer-flow-langgraph --format '{{.Names}}' | wc -l
|
||||
# 预期: 0
|
||||
|
||||
# auth 流程正常:未登录受保护接口返回 401
|
||||
curl -s -w "%{http_code}" -o /dev/null $BASE/api/models
|
||||
|
||||
@@ -1,154 +0,0 @@
|
||||
# Blocking IO detection usage and maintenance
|
||||
|
||||
This document describes how to use and maintain DeerFlow backend blocking-IO
|
||||
detection for async event-loop safety.
|
||||
|
||||
The goal is narrow: find and prevent synchronous IO from blocking backend
|
||||
async event-loop paths. Static and runtime detection are complementary, but
|
||||
they have different jobs.
|
||||
|
||||
## Static detector
|
||||
|
||||
The static detector is the discovery tool. It scans backend source code and
|
||||
reports candidate blocking-IO call sites that may need human review.
|
||||
|
||||
Run it from the repository root:
|
||||
|
||||
```bash
|
||||
make detect-blocking-io
|
||||
```
|
||||
|
||||
Or from `backend/`:
|
||||
|
||||
```bash
|
||||
make detect-blocking-io
|
||||
```
|
||||
|
||||
The report is written to:
|
||||
|
||||
```text
|
||||
.deer-flow/blocking-io-findings.json
|
||||
```
|
||||
|
||||
Use this output for review and triage. A static finding is a candidate, not
|
||||
proof that production blocks the event loop at runtime. The current static
|
||||
rules are intentionally broad; prefer triaging existing output before adding
|
||||
new static rules.
|
||||
|
||||
Add a static rule only when review finds a recurring high-risk blocking
|
||||
pattern that is invisible to the current detector.
|
||||
|
||||
## Runtime detector
|
||||
|
||||
The runtime detector is the CI regression guard. It uses Blockbuster to fail a
|
||||
focused test when code under `app.*` or `deerflow.*` performs blocking IO on
|
||||
the asyncio event-loop thread.
|
||||
|
||||
Run it from `backend/`:
|
||||
|
||||
```bash
|
||||
make test-blocking-io
|
||||
```
|
||||
|
||||
The runtime gate starts from confirmed production bugs and protects those
|
||||
paths from regressing. It does not prove that the entire backend is free of
|
||||
blocking IO; it only covers the production paths exercised by
|
||||
`backend/tests/blocking_io/`.
|
||||
|
||||
## Maintenance workflow
|
||||
|
||||
Use the static detector to find candidates, then use review to decide which
|
||||
async production paths are worth protecting in CI.
|
||||
|
||||
The normal workflow is:
|
||||
|
||||
1. Run the static detector to find backend blocking-IO candidates.
|
||||
2. Use human review to pick high-risk production async paths.
|
||||
3. Add or update a focused runtime anchor in `backend/tests/blocking_io/`.
|
||||
4. Let CI prevent that path from regressing.
|
||||
|
||||
Runtime detection has two maintenance paths.
|
||||
|
||||
### Add a runtime rule
|
||||
|
||||
Add a runtime rule when Blockbuster's default rules do not cover a generic
|
||||
blocking primitive used by production code.
|
||||
|
||||
Rules belong in:
|
||||
|
||||
```text
|
||||
backend/tests/support/detectors/blocking_io_runtime.py
|
||||
```
|
||||
|
||||
Add them to `_PROJECT_BLOCKING_RULES`, not directly inside individual tests.
|
||||
Keeping rules centralized makes it clear which extra primitives DeerFlow
|
||||
expects Blockbuster to catch.
|
||||
|
||||
Example shape:
|
||||
|
||||
```python
|
||||
import subprocess
|
||||
|
||||
from blockbuster import BlockBusterFunction
|
||||
|
||||
_PROJECT_BLOCKING_RULES = (
|
||||
(
|
||||
"subprocess.Popen.__init__",
|
||||
BlockBusterFunction(
|
||||
subprocess.Popen,
|
||||
"__init__",
|
||||
scanned_modules=["app", "deerflow"],
|
||||
),
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
Do not add a runtime rule just because a business path is not tested. A rule
|
||||
only expands what Blockbuster can intercept after code runs.
|
||||
|
||||
### Add a runtime anchor
|
||||
|
||||
Add a runtime anchor when a high-risk async production path should be protected
|
||||
by CI but no existing `backend/tests/blocking_io/` test executes it.
|
||||
|
||||
Anchors belong in:
|
||||
|
||||
```text
|
||||
backend/tests/blocking_io/
|
||||
```
|
||||
|
||||
A good anchor should:
|
||||
|
||||
- Call the real production async entry point.
|
||||
- Avoid bypassing the blocking surface with test-only `asyncio.to_thread`
|
||||
wrappers.
|
||||
- Use real local filesystem inputs when the bug shape is filesystem IO.
|
||||
- Mock only the external dependency boundary, such as a network service or
|
||||
third-party saver class.
|
||||
- Fail if a future change moves the blocking operation back onto the event
|
||||
loop.
|
||||
|
||||
Avoid testing only the low-level helper unless that helper is the production
|
||||
async entry point. The runtime gate is most useful when it protects the caller
|
||||
that production actually executes.
|
||||
|
||||
## Current runtime coverage
|
||||
|
||||
The runtime anchors protect confirmed blocking-IO bug shapes:
|
||||
|
||||
- SQLite checkpointer setup, including path resolution and parent-directory
|
||||
creation.
|
||||
- Subagent skill metadata loading through `SubagentExecutor._load_skills()`.
|
||||
- `JsonlRunEventStore` async API (`put` / `list_*` / `delete_*`): the JSONL
|
||||
run-event backend offloads its synchronous file IO via `asyncio.to_thread`
|
||||
(fix #3084); this anchor drives the real async API under the gate so any
|
||||
blocking IO reintroduced on the loop fails, not only removal of one
|
||||
`to_thread` call.
|
||||
- `UploadsMiddleware.before_agent` uploads-directory scan: a sync-only middleware
|
||||
hook runs on the event loop under async graph execution, so the scan is
|
||||
offloaded via `abefore_agent` + `run_in_executor`.
|
||||
- Gate health checks: Blockbuster catches unoffloaded calls, opt-out works, and
|
||||
patches are restored after exceptions.
|
||||
|
||||
As static detection and review identify more high-risk async paths, add new
|
||||
runtime anchors incrementally.
|
||||
@@ -36,7 +36,6 @@ models:
|
||||
- OpenAI (`langchain_openai:ChatOpenAI`)
|
||||
- Anthropic (`langchain_anthropic:ChatAnthropic`)
|
||||
- DeepSeek (`langchain_deepseek:ChatDeepSeek`)
|
||||
- Xiaomi MiMo (`deerflow.models.patched_mimo:PatchedChatMiMo`)
|
||||
- Claude Code OAuth (`deerflow.models.claude_provider:ClaudeChatModel`)
|
||||
- Codex CLI (`deerflow.models.openai_codex_provider:CodexChatModel`)
|
||||
- Any LangChain-compatible provider
|
||||
@@ -167,37 +166,6 @@ models:
|
||||
|
||||
For Gemini accessed **without** thinking (e.g. via OpenRouter where thinking is not activated), the plain `langchain_openai:ChatOpenAI` with `supports_thinking: false` is sufficient and no patch is needed.
|
||||
|
||||
**MiMo with thinking via OpenAI-compatible API**:
|
||||
|
||||
MiMo returns `reasoning_content` on assistant messages in thinking mode. In multi-turn agent conversations with tool calls, subsequent requests must preserve that historical `reasoning_content` on assistant messages or the MiMo API can return HTTP 400. Standard `langchain_openai:ChatOpenAI` drops this provider-specific field, so use `deerflow.models.patched_mimo:PatchedChatMiMo`:
|
||||
|
||||
For pay-as-you-go API keys (`sk-...`), use `https://api.xiaomimimo.com/v1`. For Token Plan keys (`tp-...`), use the regional Token Plan Base URL shown in the MiMo console, such as `https://token-plan-cn.xiaomimimo.com/v1`. MiMo documents these key types as separate and non-interchangeable.
|
||||
|
||||
`PatchedChatMiMo` is model-id agnostic. Use it for every MiMo thinking model entry you configure, including model entries referenced by `subagents.*.model` overrides (for example `mimo-v2.5-pro`, `mimo-v2.5`, `mimo-v2-pro`, `mimo-v2-omni`, or `mimo-v2-flash`).
|
||||
|
||||
```yaml
|
||||
models:
|
||||
- name: mimo-v2.5-pro
|
||||
display_name: MiMo V2.5 Pro
|
||||
use: deerflow.models.patched_mimo:PatchedChatMiMo
|
||||
model: mimo-v2.5-pro
|
||||
api_key: $MIMO_API_KEY
|
||||
base_url: https://api.xiaomimimo.com/v1
|
||||
max_tokens: 8192
|
||||
supports_thinking: true
|
||||
supports_vision: false
|
||||
when_thinking_enabled:
|
||||
extra_body:
|
||||
thinking:
|
||||
type: enabled
|
||||
when_thinking_disabled:
|
||||
extra_body:
|
||||
thinking:
|
||||
type: disabled
|
||||
```
|
||||
|
||||
`PatchedChatMiMo` preserves MiMo's `choices[].message.reasoning_content`, streaming `delta.reasoning_content`, and request-history assistant `reasoning_content` fields. It does not reuse the DeepSeek provider.
|
||||
|
||||
### Tool Groups
|
||||
|
||||
Organize tools into logical groups:
|
||||
@@ -351,7 +319,6 @@ models:
|
||||
- `OPENAI_API_KEY` - OpenAI API key
|
||||
- `ANTHROPIC_API_KEY` - Anthropic API key
|
||||
- `DEEPSEEK_API_KEY` - DeepSeek API key
|
||||
- `MIMO_API_KEY` - Xiaomi MiMo API key
|
||||
- `NOVITA_API_KEY` - Novita API key (OpenAI-compatible endpoint)
|
||||
- `TAVILY_API_KEY` - Tavily search API key
|
||||
- `DEER_FLOW_PROJECT_ROOT` - Project root for relative runtime paths
|
||||
|
||||
@@ -26,7 +26,7 @@
|
||||
- Replace sync `requests` with `httpx.AsyncClient` in community tools (tavily, jina_ai, firecrawl, infoquest, image_search)
|
||||
- [x] Replace sync `model.invoke()` with async `model.ainvoke()` in title_middleware and memory updater
|
||||
- Consider `asyncio.to_thread()` wrapper for remaining blocking file I/O
|
||||
- For production: tune Gateway worker/runtime settings for long-running agent workloads
|
||||
- For production: use `langgraph up` (multi-worker) instead of `langgraph dev` (single-worker)
|
||||
|
||||
## Resolved Issues
|
||||
|
||||
|
||||
@@ -227,110 +227,6 @@ def _extract_text(content: Any) -> str:
|
||||
return str(content)
|
||||
|
||||
|
||||
_REQUIRED_MEMORY_UPDATE_TOP_LEVEL_KEYS = frozenset({"user", "history", "newFacts", "factsToRemove"})
|
||||
|
||||
|
||||
def _normalize_memory_update_fact(fact: Any) -> dict[str, Any] | None:
|
||||
"""Normalize a single fact entry from a model-produced memory update."""
|
||||
if not isinstance(fact, dict):
|
||||
return None
|
||||
|
||||
raw_content = fact.get("content")
|
||||
if not isinstance(raw_content, str):
|
||||
return None
|
||||
content = raw_content.strip()
|
||||
if not content:
|
||||
return None
|
||||
|
||||
raw_category = fact.get("category")
|
||||
category = raw_category.strip() if isinstance(raw_category, str) and raw_category.strip() else "context"
|
||||
|
||||
raw_confidence = fact.get("confidence", 0.5)
|
||||
if isinstance(raw_confidence, bool):
|
||||
return None
|
||||
if isinstance(raw_confidence, str):
|
||||
raw_confidence = raw_confidence.strip()
|
||||
if not raw_confidence:
|
||||
return None
|
||||
try:
|
||||
raw_confidence = float(raw_confidence)
|
||||
except ValueError:
|
||||
return None
|
||||
elif isinstance(raw_confidence, (int, float)):
|
||||
raw_confidence = float(raw_confidence)
|
||||
else:
|
||||
return None
|
||||
|
||||
if not math.isfinite(raw_confidence):
|
||||
return None
|
||||
|
||||
normalized_fact = {
|
||||
"content": content,
|
||||
"category": category,
|
||||
"confidence": raw_confidence,
|
||||
}
|
||||
source_error = fact.get("sourceError")
|
||||
if isinstance(source_error, str):
|
||||
normalized_source_error = source_error.strip()
|
||||
if normalized_source_error:
|
||||
normalized_fact["sourceError"] = normalized_source_error
|
||||
|
||||
return normalized_fact
|
||||
|
||||
|
||||
def _normalize_memory_update_data(update_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce parsed memory update data into the shape consumed by _apply_updates."""
|
||||
user = update_data.get("user")
|
||||
history = update_data.get("history")
|
||||
new_facts = update_data.get("newFacts")
|
||||
facts_to_remove = update_data.get("factsToRemove")
|
||||
normalized_facts_to_remove = [fact_id for fact_id in facts_to_remove if isinstance(fact_id, str)] if isinstance(facts_to_remove, list) else []
|
||||
normalized_new_facts = []
|
||||
dropped_new_fact = not isinstance(new_facts, list)
|
||||
if isinstance(new_facts, list):
|
||||
for fact in new_facts:
|
||||
normalized_fact = _normalize_memory_update_fact(fact)
|
||||
if normalized_fact is not None:
|
||||
normalized_new_facts.append(normalized_fact)
|
||||
else:
|
||||
dropped_new_fact = True
|
||||
|
||||
if normalized_facts_to_remove and dropped_new_fact:
|
||||
raise json.JSONDecodeError(
|
||||
"Unsafe partial memory update: factsToRemove with malformed newFacts",
|
||||
json.dumps(update_data, ensure_ascii=False),
|
||||
0,
|
||||
)
|
||||
|
||||
return {
|
||||
"user": user if isinstance(user, dict) else {},
|
||||
"history": history if isinstance(history, dict) else {},
|
||||
"newFacts": normalized_new_facts,
|
||||
"factsToRemove": normalized_facts_to_remove,
|
||||
}
|
||||
|
||||
|
||||
def _parse_memory_update_response(response_content: Any) -> dict[str, Any]:
|
||||
"""Parse the first valid memory-update JSON object from an LLM response.
|
||||
|
||||
Some providers may wrap JSON in thinking traces, prose, or markdown fences
|
||||
even when prompted to return JSON only. This parser accepts safely
|
||||
extractable JSON objects but does not repair truncated or malformed JSON.
|
||||
"""
|
||||
response_text = _extract_text(response_content).strip()
|
||||
decoder = json.JSONDecoder()
|
||||
|
||||
for match in re.finditer(r"\{", response_text):
|
||||
try:
|
||||
parsed, _end = decoder.raw_decode(response_text[match.start() :])
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if isinstance(parsed, dict) and _REQUIRED_MEMORY_UPDATE_TOP_LEVEL_KEYS.issubset(parsed):
|
||||
return _normalize_memory_update_data(parsed)
|
||||
|
||||
raise json.JSONDecodeError("No valid memory update JSON object found", response_text, 0)
|
||||
|
||||
|
||||
# Matches sentences that describe a file-upload *event* rather than general
|
||||
# file-related work. Deliberately narrow to avoid removing legitimate facts
|
||||
# such as "User works with CSV files" or "prefers PDF export".
|
||||
@@ -457,7 +353,13 @@ class MemoryUpdater:
|
||||
user_id: str | None = None,
|
||||
) -> bool:
|
||||
"""Parse the model response, apply updates, and persist memory."""
|
||||
update_data = _parse_memory_update_response(response_content)
|
||||
response_text = _extract_text(response_content).strip()
|
||||
|
||||
if response_text.startswith("```"):
|
||||
lines = response_text.split("\n")
|
||||
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
|
||||
|
||||
update_data = json.loads(response_text)
|
||||
# Deep-copy before in-place mutation so a subsequent save() failure
|
||||
# cannot corrupt the still-cached original object reference.
|
||||
updated_memory = self._apply_updates(copy.deepcopy(current_memory), update_data, thread_id)
|
||||
|
||||
+2
-23
@@ -26,11 +26,6 @@ from langchain_core.messages import ToolMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Workaround for issue #2894: malformed write_file calls can carry huge Markdown
|
||||
# payloads in invalid tool-call args. Keep recovery error details short so the
|
||||
# synthetic ToolMessage does not echo large or malformed content back to the model.
|
||||
_MAX_RECOVERY_ERROR_DETAIL_LEN = 500
|
||||
|
||||
|
||||
class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Inserts placeholder ToolMessages for dangling tool calls before model invocation.
|
||||
@@ -103,25 +98,9 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
@staticmethod
|
||||
def _synthetic_tool_message_content(tool_call: dict) -> str:
|
||||
if tool_call.get("invalid"):
|
||||
name = tool_call.get("name")
|
||||
error = tool_call.get("error")
|
||||
error_text = error[:_MAX_RECOVERY_ERROR_DETAIL_LEN] if isinstance(error, str) and error else ""
|
||||
# Workaround for issue #2894: malformed write_file calls can carry huge Markdown
|
||||
# payloads in invalid tool-call args. Keep recovery guidance actionable without
|
||||
# echoing large or malformed content back to the model.
|
||||
if name == "write_file":
|
||||
details = f" Parser error: {error_text}" if error_text else ""
|
||||
return (
|
||||
"[write_file failed before execution: the tool-call arguments were not valid JSON, "
|
||||
"so no file was written. This often happens when the model tries to write a very "
|
||||
"large Markdown file in a single tool call, especially when `content` contains "
|
||||
"unescaped quotes, inline JSON, backslashes, or code fences. Do not retry the same "
|
||||
"large `write_file` payload for this artifact; provide the report/content directly "
|
||||
"as normal assistant text in your next response. If a file write is still needed "
|
||||
f"later, split the file into smaller sections instead of one large payload.{details}]"
|
||||
)
|
||||
if error_text:
|
||||
return f"[Tool call could not be executed because its arguments were invalid: {error_text}]"
|
||||
if isinstance(error, str) and error:
|
||||
return f"[Tool call could not be executed because its arguments were invalid: {error}]"
|
||||
return "[Tool call could not be executed because its arguments were invalid.]"
|
||||
return "[Tool call was interrupted and did not return a result.]"
|
||||
|
||||
|
||||
+1
-3
@@ -77,11 +77,9 @@ def _build_runtime_middlewares(
|
||||
"""Build shared base middlewares for agent execution."""
|
||||
from deerflow.agents.middlewares.llm_error_handling_middleware import LLMErrorHandlingMiddleware
|
||||
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
from deerflow.agents.middlewares.tool_output_budget_middleware import ToolOutputBudgetMiddleware
|
||||
from deerflow.sandbox.middleware import SandboxMiddleware
|
||||
|
||||
middlewares: list[AgentMiddleware] = [
|
||||
ToolOutputBudgetMiddleware.from_app_config(app_config),
|
||||
ThreadDataMiddleware(lazy_init=lazy_init),
|
||||
SandboxMiddleware(lazy_init=lazy_init),
|
||||
]
|
||||
@@ -89,7 +87,7 @@ def _build_runtime_middlewares(
|
||||
if include_uploads:
|
||||
from deerflow.agents.middlewares.uploads_middleware import UploadsMiddleware
|
||||
|
||||
middlewares.insert(2, UploadsMiddleware())
|
||||
middlewares.insert(1, UploadsMiddleware())
|
||||
|
||||
if include_dangling_tool_call_patch:
|
||||
from deerflow.agents.middlewares.dangling_tool_call_middleware import DanglingToolCallMiddleware
|
||||
|
||||
-489
@@ -1,489 +0,0 @@
|
||||
"""Middleware that enforces a per-result budget on tool outputs.
|
||||
|
||||
Oversized tool results are persisted to disk and replaced with a compact
|
||||
preview containing a file reference. When disk persistence is
|
||||
unavailable the middleware falls back to head+tail truncation so the
|
||||
model context is never blown by a single large tool return.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import replace as dc_replace
|
||||
from typing import Any, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||
from langgraph.types import Command
|
||||
|
||||
from deerflow.config.tool_output_config import ToolOutputConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _default_config() -> ToolOutputConfig:
|
||||
return ToolOutputConfig()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Text helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _message_text(content: Any) -> str | None:
|
||||
"""Extract a plain-text representation from a ToolMessage content field.
|
||||
|
||||
Returns ``None`` for non-string / multimodal content so the caller
|
||||
can skip budget enforcement (images, structured blocks, etc.).
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if content is None:
|
||||
return None
|
||||
if isinstance(content, list):
|
||||
pieces: list[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, str):
|
||||
pieces.append(part)
|
||||
elif isinstance(part, dict) and isinstance(part.get("text"), str):
|
||||
pieces.append(part["text"])
|
||||
else:
|
||||
return None
|
||||
return "\n".join(pieces) if pieces else None
|
||||
return None
|
||||
|
||||
|
||||
def _snap_to_line_boundary(text: str, pos: int) -> int:
|
||||
"""Return *pos* or the nearest preceding newline+1, whichever is closer.
|
||||
|
||||
Used so that previews and truncations end on a complete line when
|
||||
possible. If no newline exists in the second half of ``text[:pos]``
|
||||
the original *pos* is returned unchanged.
|
||||
"""
|
||||
if pos <= 0 or pos >= len(text):
|
||||
return pos
|
||||
half = pos // 2
|
||||
nl = text.rfind("\n", half, pos)
|
||||
if nl >= 0:
|
||||
return nl + 1
|
||||
return pos
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Disk persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_EXT_MAP: dict[str, str] = {
|
||||
"bash": "log",
|
||||
"bash_tool": "log",
|
||||
"web_fetch": "log",
|
||||
}
|
||||
|
||||
|
||||
def _sanitize_tool_name(name: str) -> str:
|
||||
"""Strip path separators and traversal components from a tool name."""
|
||||
base = os.path.basename(name)
|
||||
safe = base.replace("..", "").replace("/", "_").replace("\\", "_")
|
||||
return safe or "unknown"
|
||||
|
||||
|
||||
def _externalize(
|
||||
content: str,
|
||||
*,
|
||||
tool_name: str,
|
||||
tool_call_id: str,
|
||||
outputs_path: str,
|
||||
storage_subdir: str,
|
||||
) -> str | None:
|
||||
"""Write *content* to disk and return the virtual path, or ``None`` on failure."""
|
||||
if os.path.isabs(storage_subdir) or ".." in storage_subdir:
|
||||
return None
|
||||
storage_dir = os.path.join(outputs_path, storage_subdir)
|
||||
try:
|
||||
os.makedirs(storage_dir, exist_ok=True)
|
||||
except OSError:
|
||||
return None
|
||||
|
||||
safe_name = _sanitize_tool_name(tool_name)
|
||||
ext = _EXT_MAP.get(tool_name, "txt")
|
||||
short_id = uuid.uuid4().hex[:12]
|
||||
filename = f"{safe_name}-{short_id}.{ext}"
|
||||
filepath = os.path.join(storage_dir, filename)
|
||||
|
||||
if not os.path.abspath(filepath).startswith(os.path.abspath(storage_dir)):
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
except OSError:
|
||||
return None
|
||||
|
||||
virtual_base = "/mnt/user-data/outputs"
|
||||
return f"{virtual_base}/{storage_subdir}/{filename}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Preview / fallback builders
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_preview(
|
||||
content: str,
|
||||
*,
|
||||
tool_name: str,
|
||||
virtual_path: str,
|
||||
head_chars: int,
|
||||
tail_chars: int,
|
||||
) -> str:
|
||||
"""Build a preview with a file reference for externalized output."""
|
||||
total = len(content)
|
||||
head_end = _snap_to_line_boundary(content, min(head_chars, total))
|
||||
tail_start = max(head_end, total - tail_chars)
|
||||
tail_start_snapped = _snap_to_line_boundary(content, tail_start)
|
||||
if tail_start_snapped > head_end:
|
||||
tail_start = tail_start_snapped
|
||||
|
||||
head = content[:head_end]
|
||||
tail = content[tail_start:] if tail_start < total else ""
|
||||
|
||||
omitted = total - len(head) - len(tail)
|
||||
ref = f"\n\n[Full {tool_name} output saved to {virtual_path} ({total} chars, ~{total // 4} tokens). Use read_file with start_line and end_line to access specific sections. {omitted} chars omitted from this preview.]\n\n"
|
||||
|
||||
parts = [head, ref]
|
||||
if tail:
|
||||
parts.append(tail)
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def _build_fallback(
|
||||
content: str,
|
||||
*,
|
||||
tool_name: str,
|
||||
max_chars: int,
|
||||
head_chars: int,
|
||||
tail_chars: int,
|
||||
) -> str:
|
||||
"""Build a head+tail truncation when disk persistence is unavailable.
|
||||
|
||||
The returned string is guaranteed to be no longer than *max_chars*.
|
||||
"""
|
||||
total = len(content)
|
||||
if max_chars <= 0 or total <= max_chars:
|
||||
return content
|
||||
|
||||
marker_template = "\n\n[... {n} chars omitted from {tn} output. Persistent storage unavailable. Consider narrowing the query or using more specific parameters.]\n\n"
|
||||
marker_overhead = len(marker_template.format(n=total, tn=tool_name))
|
||||
|
||||
if marker_overhead >= max_chars:
|
||||
return content[:max_chars]
|
||||
|
||||
budget = max_chars - marker_overhead
|
||||
effective_head = min(head_chars, budget)
|
||||
effective_tail = min(tail_chars, max(0, budget - effective_head))
|
||||
|
||||
head_end = _snap_to_line_boundary(content, min(effective_head, total))
|
||||
tail_start = max(head_end, total - effective_tail)
|
||||
tail_start_snapped = _snap_to_line_boundary(content, tail_start)
|
||||
if tail_start_snapped > head_end:
|
||||
tail_start = tail_start_snapped
|
||||
|
||||
head = content[:head_end]
|
||||
tail = content[tail_start:] if tail_start < total else ""
|
||||
omitted = total - len(head) - len(tail)
|
||||
|
||||
marker = marker_template.format(n=omitted, tn=tool_name)
|
||||
|
||||
parts = [head, marker]
|
||||
if tail:
|
||||
parts.append(tail)
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core budget logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_outputs_path(request: ToolCallRequest) -> str | None:
|
||||
"""Best-effort extraction of the thread outputs path."""
|
||||
runtime = getattr(request, "runtime", None)
|
||||
if runtime is None:
|
||||
return None
|
||||
state = getattr(runtime, "state", None)
|
||||
if state is None:
|
||||
return None
|
||||
thread_data = state.get("thread_data")
|
||||
if not isinstance(thread_data, dict):
|
||||
return None
|
||||
outputs_path = thread_data.get("outputs_path")
|
||||
return outputs_path if isinstance(outputs_path, str) else None
|
||||
|
||||
|
||||
def _budget_content(
|
||||
content: str,
|
||||
*,
|
||||
tool_name: str,
|
||||
tool_call_id: str,
|
||||
outputs_path: str | None,
|
||||
config: ToolOutputConfig,
|
||||
) -> str | None:
|
||||
"""Apply budget to *content*. Returns ``None`` if no change needed."""
|
||||
threshold = config.tool_overrides.get(tool_name, config.externalize_min_chars)
|
||||
if threshold <= 0 and config.fallback_max_chars <= 0:
|
||||
return None
|
||||
if len(content) <= threshold and len(content) <= config.fallback_max_chars:
|
||||
return None
|
||||
|
||||
if threshold > 0 and len(content) > threshold and outputs_path:
|
||||
virtual_path = _externalize(
|
||||
content,
|
||||
tool_name=tool_name,
|
||||
tool_call_id=tool_call_id,
|
||||
outputs_path=outputs_path,
|
||||
storage_subdir=config.storage_subdir,
|
||||
)
|
||||
if virtual_path is not None:
|
||||
logger.info(
|
||||
"Externalized %s output (%d chars) to %s",
|
||||
tool_name,
|
||||
len(content),
|
||||
virtual_path,
|
||||
)
|
||||
return _build_preview(
|
||||
content,
|
||||
tool_name=tool_name,
|
||||
virtual_path=virtual_path,
|
||||
head_chars=config.preview_head_chars,
|
||||
tail_chars=config.preview_tail_chars,
|
||||
)
|
||||
|
||||
if config.fallback_max_chars > 0 and len(content) > config.fallback_max_chars:
|
||||
logger.warning(
|
||||
"Fallback-truncating %s output: %d chars → %d max",
|
||||
tool_name,
|
||||
len(content),
|
||||
config.fallback_max_chars,
|
||||
)
|
||||
return _build_fallback(
|
||||
content,
|
||||
tool_name=tool_name,
|
||||
max_chars=config.fallback_max_chars,
|
||||
head_chars=config.fallback_head_chars,
|
||||
tail_chars=config.fallback_tail_chars,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Result patchers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _patch_tool_message(msg: ToolMessage, config: ToolOutputConfig, outputs_path: str | None) -> ToolMessage:
|
||||
"""Apply budget to a single ToolMessage. Returns the original if unchanged."""
|
||||
tool_name = msg.name or "unknown"
|
||||
if tool_name in config.exempt_tools:
|
||||
return msg
|
||||
|
||||
text = _message_text(msg.content)
|
||||
if text is None:
|
||||
return msg
|
||||
|
||||
replacement = _budget_content(
|
||||
text,
|
||||
tool_name=tool_name,
|
||||
tool_call_id=msg.tool_call_id or "",
|
||||
outputs_path=outputs_path,
|
||||
config=config,
|
||||
)
|
||||
if replacement is None:
|
||||
return msg
|
||||
|
||||
update: dict[str, Any] = {"content": replacement}
|
||||
if getattr(msg, "response_metadata", None):
|
||||
update["response_metadata"] = dict(msg.response_metadata)
|
||||
if getattr(msg, "additional_kwargs", None):
|
||||
update["additional_kwargs"] = dict(msg.additional_kwargs)
|
||||
return msg.model_copy(update=update)
|
||||
|
||||
|
||||
def _effective_trigger(tool_name: str, config: ToolOutputConfig) -> int:
|
||||
"""Smallest content length that could trigger budgeting for *tool_name*.
|
||||
|
||||
Mirrors the trigger conditions in :func:`_budget_content` (per-tool
|
||||
externalize threshold OR global fallback), so the pre-scan never produces
|
||||
a false negative. Returns ``-1`` when nothing could ever trigger.
|
||||
"""
|
||||
candidates: list[int] = []
|
||||
externalize = config.tool_overrides.get(tool_name, config.externalize_min_chars)
|
||||
if externalize > 0:
|
||||
candidates.append(externalize)
|
||||
if config.fallback_max_chars > 0:
|
||||
candidates.append(config.fallback_max_chars)
|
||||
return min(candidates) if candidates else -1
|
||||
|
||||
|
||||
def _tool_message_over_budget(msg: ToolMessage, config: ToolOutputConfig) -> bool:
|
||||
"""Cheap, per-tool-aware check: is this ToolMessage non-exempt and over its trigger?"""
|
||||
if (msg.name or "") in config.exempt_tools:
|
||||
return False
|
||||
trigger = _effective_trigger(msg.name or "", config)
|
||||
if trigger < 0:
|
||||
return False
|
||||
text = _message_text(msg.content)
|
||||
return text is not None and len(text) > trigger
|
||||
|
||||
|
||||
def _needs_budget(result: ToolMessage | Command, config: ToolOutputConfig) -> bool:
|
||||
"""Fast check whether *result* could need budgeting (avoids thread offload for small outputs)."""
|
||||
if isinstance(result, ToolMessage):
|
||||
return _tool_message_over_budget(result, config)
|
||||
update = getattr(result, "update", None)
|
||||
if isinstance(update, dict):
|
||||
for msg in update.get("messages", []):
|
||||
if isinstance(msg, ToolMessage) and _tool_message_over_budget(msg, config):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _patch_result(result: ToolMessage | Command, config: ToolOutputConfig, outputs_path: str | None) -> ToolMessage | Command:
|
||||
"""Apply budget to a tool call result (ToolMessage or Command)."""
|
||||
if isinstance(result, ToolMessage):
|
||||
return _patch_tool_message(result, config, outputs_path)
|
||||
|
||||
update = getattr(result, "update", None)
|
||||
if not isinstance(update, dict):
|
||||
return result
|
||||
|
||||
messages = update.get("messages")
|
||||
if not isinstance(messages, list):
|
||||
return result
|
||||
|
||||
new_messages: list[Any] = []
|
||||
changed = False
|
||||
for msg in messages:
|
||||
if isinstance(msg, ToolMessage):
|
||||
patched = _patch_tool_message(msg, config, outputs_path)
|
||||
if patched is not msg:
|
||||
changed = True
|
||||
new_messages.append(patched)
|
||||
else:
|
||||
new_messages.append(msg)
|
||||
|
||||
if not changed:
|
||||
return result
|
||||
|
||||
return dc_replace(result, update={**update, "messages": new_messages})
|
||||
|
||||
|
||||
def _patch_model_messages(messages: list[Any], config: ToolOutputConfig) -> list[Any] | None:
|
||||
"""Apply budget to historical ToolMessages in a model request. Returns ``None`` if unchanged.
|
||||
|
||||
A cheap pre-scan bails out before allocating a new list when no historical
|
||||
ToolMessage exceeds the budget — the common case once every result has
|
||||
already been budgeted at tool-call time, so a long history is not rebuilt
|
||||
on every model call.
|
||||
"""
|
||||
if not any(isinstance(msg, ToolMessage) and _tool_message_over_budget(msg, config) for msg in messages):
|
||||
return None
|
||||
|
||||
updated: list[Any] = []
|
||||
changed = False
|
||||
for msg in messages:
|
||||
if isinstance(msg, ToolMessage):
|
||||
patched = _patch_tool_message(msg, config, outputs_path=None)
|
||||
if patched is not msg:
|
||||
changed = True
|
||||
updated.append(patched)
|
||||
else:
|
||||
updated.append(msg)
|
||||
return updated if changed else None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Middleware class
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ToolOutputBudgetMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Enforce per-result budget on tool outputs via externalization or truncation."""
|
||||
|
||||
def __init__(self, config: ToolOutputConfig | None = None) -> None:
|
||||
super().__init__()
|
||||
self._config = config if config is not None else _default_config()
|
||||
|
||||
@classmethod
|
||||
def from_app_config(cls, app_config: Any) -> ToolOutputBudgetMiddleware:
|
||||
tool_output = getattr(app_config, "tool_output", None)
|
||||
if isinstance(tool_output, ToolOutputConfig):
|
||||
return cls(config=tool_output)
|
||||
return cls()
|
||||
|
||||
# -- tool call hooks ---------------------------------------------------
|
||||
|
||||
@override
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
result = handler(request)
|
||||
if not self._config.enabled:
|
||||
return result
|
||||
if not _needs_budget(result, self._config):
|
||||
return result
|
||||
outputs_path = _resolve_outputs_path(request)
|
||||
return _patch_result(result, self._config, outputs_path)
|
||||
|
||||
@override
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
) -> ToolMessage | Command:
|
||||
result = await handler(request)
|
||||
if not self._config.enabled:
|
||||
return result
|
||||
if not _needs_budget(result, self._config):
|
||||
return result
|
||||
outputs_path = _resolve_outputs_path(request)
|
||||
return await asyncio.to_thread(_patch_result, result, self._config, outputs_path)
|
||||
|
||||
# -- model call hooks (historical message truncation) ------------------
|
||||
|
||||
@override
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
if self._config.enabled:
|
||||
messages = getattr(request, "messages", None)
|
||||
if isinstance(messages, list):
|
||||
patched = _patch_model_messages(messages, self._config)
|
||||
if patched is not None:
|
||||
request = request.override(messages=patched)
|
||||
return handler(request)
|
||||
|
||||
@override
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
if self._config.enabled:
|
||||
messages = getattr(request, "messages", None)
|
||||
if isinstance(messages, list):
|
||||
patched = _patch_model_messages(messages, self._config)
|
||||
if patched is not None:
|
||||
request = request.override(messages=patched)
|
||||
return await handler(request)
|
||||
@@ -7,7 +7,6 @@ from typing import NotRequired, override
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import run_in_executor
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
@@ -294,16 +293,3 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
"uploaded_files": new_files,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
@override
|
||||
async def abefore_agent(self, state: UploadsMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
"""Async hook that offloads the synchronous uploads scan off the event loop.
|
||||
|
||||
``before_agent`` performs blocking filesystem IO (directory enumeration,
|
||||
``stat``, reading sibling ``.md`` outlines). When the graph runs async,
|
||||
langgraph would otherwise execute the sync hook directly on the event
|
||||
loop, so it is dispatched to a worker thread via ``run_in_executor``.
|
||||
``run_in_executor`` copies the current context, so the ``user_id``
|
||||
contextvar read by ``get_effective_user_id()`` is preserved.
|
||||
"""
|
||||
return await run_in_executor(None, self.before_agent, state, runtime)
|
||||
|
||||
@@ -119,6 +119,7 @@ class AioSandboxProvider(SandboxProvider):
|
||||
port: 8080 # Base port for local containers
|
||||
container_prefix: deer-flow-sandbox
|
||||
idle_timeout: 600 # Idle timeout in seconds (0 to disable)
|
||||
auto_restart: true # Restart crashed containers automatically
|
||||
replicas: 3 # Max concurrent sandbox containers (LRU eviction when exceeded)
|
||||
mounts: # Volume mounts for local containers
|
||||
- host_path: /path/on/host
|
||||
@@ -203,12 +204,14 @@ class AioSandboxProvider(SandboxProvider):
|
||||
|
||||
idle_timeout = getattr(sandbox_config, "idle_timeout", None)
|
||||
replicas = getattr(sandbox_config, "replicas", None)
|
||||
auto_restart = getattr(sandbox_config, "auto_restart", True)
|
||||
|
||||
return {
|
||||
"image": sandbox_config.image or DEFAULT_IMAGE,
|
||||
"port": sandbox_config.port or DEFAULT_PORT,
|
||||
"container_prefix": sandbox_config.container_prefix or DEFAULT_CONTAINER_PREFIX,
|
||||
"idle_timeout": idle_timeout if idle_timeout is not None else DEFAULT_IDLE_TIMEOUT,
|
||||
"auto_restart": auto_restart,
|
||||
"replicas": replicas if replicas is not None else DEFAULT_REPLICAS,
|
||||
"mounts": sandbox_config.mounts or [],
|
||||
"environment": self._resolve_env_vars(sandbox_config.environment or {}),
|
||||
@@ -771,18 +774,58 @@ class AioSandboxProvider(SandboxProvider):
|
||||
def get(self, sandbox_id: str) -> Sandbox | None:
|
||||
"""Get a sandbox by ID. Updates last activity timestamp.
|
||||
|
||||
When ``auto_restart`` is enabled (the default), the container's liveness
|
||||
is verified on each lookup. If the underlying container has crashed, the
|
||||
sandbox is evicted from all caches so that the next ``acquire()`` call will
|
||||
transparently create a fresh container.
|
||||
|
||||
Args:
|
||||
sandbox_id: The ID of the sandbox.
|
||||
|
||||
Returns:
|
||||
The sandbox instance if found, None otherwise.
|
||||
The sandbox instance if found and alive, None otherwise.
|
||||
"""
|
||||
with self._lock:
|
||||
sandbox = self._sandboxes.get(sandbox_id)
|
||||
if sandbox is not None:
|
||||
self._last_activity[sandbox_id] = time.time()
|
||||
if sandbox is None:
|
||||
return None
|
||||
self._last_activity[sandbox_id] = time.time()
|
||||
auto_restart = self._config.get("auto_restart", True)
|
||||
info = self._sandbox_infos.get(sandbox_id) if auto_restart else None
|
||||
|
||||
if not info:
|
||||
return sandbox
|
||||
|
||||
if self._backend.is_alive(info):
|
||||
return sandbox
|
||||
|
||||
info_to_destroy = None
|
||||
with self._lock:
|
||||
current_sandbox = self._sandboxes.get(sandbox_id)
|
||||
current_info = self._sandbox_infos.get(sandbox_id)
|
||||
if current_sandbox is None:
|
||||
return None
|
||||
if current_info is not info:
|
||||
self._last_activity[sandbox_id] = time.time()
|
||||
return current_sandbox
|
||||
|
||||
logger.warning(f"Sandbox {sandbox_id} container is not alive, evicting from cache for auto-restart")
|
||||
self._sandboxes.pop(sandbox_id, None)
|
||||
self._sandbox_infos.pop(sandbox_id, None)
|
||||
self._last_activity.pop(sandbox_id, None)
|
||||
self._warm_pool.pop(sandbox_id, None)
|
||||
thread_ids = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
|
||||
for tid in thread_ids:
|
||||
del self._thread_sandboxes[tid]
|
||||
info_to_destroy = info
|
||||
|
||||
if info_to_destroy:
|
||||
try:
|
||||
self._backend.destroy(info_to_destroy)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cleanup dead sandbox {sandbox_id}: {e}")
|
||||
return None
|
||||
|
||||
def release(self, sandbox_id: str) -> None:
|
||||
"""Release a sandbox from active use into the warm pool.
|
||||
|
||||
|
||||
@@ -30,7 +30,6 @@ from deerflow.config.summarization_config import SummarizationConfig, load_summa
|
||||
from deerflow.config.title_config import TitleConfig, load_title_config_from_dict
|
||||
from deerflow.config.token_usage_config import TokenUsageConfig
|
||||
from deerflow.config.tool_config import ToolConfig, ToolGroupConfig
|
||||
from deerflow.config.tool_output_config import ToolOutputConfig
|
||||
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
|
||||
|
||||
load_dotenv()
|
||||
@@ -94,7 +93,6 @@ class AppConfig(BaseModel):
|
||||
skills: SkillsConfig = Field(default_factory=SkillsConfig, description="Skills configuration")
|
||||
skill_evolution: SkillEvolutionConfig = Field(default_factory=SkillEvolutionConfig, description="Agent-managed skill evolution configuration")
|
||||
extensions: ExtensionsConfig = Field(default_factory=ExtensionsConfig, description="Extensions configuration (MCP servers and skills state)")
|
||||
tool_output: ToolOutputConfig = Field(default_factory=ToolOutputConfig, description="Tool output budget protection configuration")
|
||||
tool_search: ToolSearchConfig = Field(default_factory=ToolSearchConfig, description="Tool search / deferred loading configuration")
|
||||
title: TitleConfig = Field(default_factory=TitleConfig, description="Automatic title generation configuration")
|
||||
summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration")
|
||||
|
||||
@@ -23,6 +23,9 @@ class SandboxConfig(BaseModel):
|
||||
replicas: Maximum number of concurrent sandbox containers (default: 3). When the limit is reached the least-recently-used sandbox is evicted to make room.
|
||||
container_prefix: Prefix for container names (default: deer-flow-sandbox)
|
||||
idle_timeout: Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.
|
||||
auto_restart: Automatically restart sandbox containers that have crashed (default: true). When a tool call
|
||||
detects the container is no longer alive, the sandbox is evicted from cache and transparently recreated
|
||||
on the next acquire. Set to false to disable.
|
||||
mounts: List of volume mounts to share directories with the container
|
||||
environment: Environment variables to inject into the container (values starting with $ are resolved from host env)
|
||||
"""
|
||||
@@ -55,6 +58,10 @@ class SandboxConfig(BaseModel):
|
||||
default=None,
|
||||
description="Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.",
|
||||
)
|
||||
auto_restart: bool = Field(
|
||||
default=True,
|
||||
description="Automatically restart sandbox containers that have crashed. When a tool call detects the container is no longer alive, the sandbox is evicted from cache and transparently recreated on the next acquire.",
|
||||
)
|
||||
mounts: list[VolumeMountConfig] = Field(
|
||||
default_factory=list,
|
||||
description="List of volume mounts to share directories between host and container",
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
"""Configuration for tool output budget protection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ToolOutputConfig(BaseModel):
|
||||
"""Config section for tool-result output budget enforcement.
|
||||
|
||||
When a tool returns more than ``externalize_min_chars`` characters,
|
||||
the full output is persisted to disk and replaced with a compact
|
||||
preview + file reference. If disk persistence is unavailable the
|
||||
output falls back to head+tail truncation.
|
||||
"""
|
||||
|
||||
enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable the tool output budget middleware.",
|
||||
)
|
||||
externalize_min_chars: int = Field(
|
||||
default=12_000,
|
||||
ge=0,
|
||||
description="Character threshold to trigger disk externalization. Outputs below this pass through unchanged. Set to 0 to disable externalization (fallback truncation still applies when output exceeds fallback_max_chars).",
|
||||
)
|
||||
preview_head_chars: int = Field(
|
||||
default=2_000,
|
||||
ge=0,
|
||||
description="Characters to keep from the head of the output in the preview.",
|
||||
)
|
||||
preview_tail_chars: int = Field(
|
||||
default=1_000,
|
||||
ge=0,
|
||||
description="Characters to keep from the tail of the output in the preview.",
|
||||
)
|
||||
fallback_max_chars: int = Field(
|
||||
default=30_000,
|
||||
ge=0,
|
||||
description="Maximum characters when disk persistence is unavailable. 0 disables fallback truncation.",
|
||||
)
|
||||
fallback_head_chars: int = Field(
|
||||
default=8_000,
|
||||
ge=0,
|
||||
description="Head characters for fallback truncation.",
|
||||
)
|
||||
fallback_tail_chars: int = Field(
|
||||
default=3_000,
|
||||
ge=0,
|
||||
description="Tail characters for fallback truncation.",
|
||||
)
|
||||
storage_subdir: str = Field(
|
||||
default=".tool-results",
|
||||
description="Subdirectory under the thread outputs path for persisted tool results.",
|
||||
)
|
||||
exempt_tools: list[str] = Field(
|
||||
default_factory=lambda: ["read_file", "read_file_tool"],
|
||||
description="Tool names exempt from budget enforcement (prevents persist→read→persist loops).",
|
||||
)
|
||||
tool_overrides: dict[str, int] = Field(
|
||||
default_factory=dict,
|
||||
description="Per-tool externalize_min_chars overrides. Keys are tool names, values are char thresholds. Use 0 to disable externalization for a specific tool.",
|
||||
)
|
||||
@@ -87,7 +87,8 @@ def get_cached_mcp_tools() -> list[BaseTool]:
|
||||
|
||||
Also checks if the config file has been modified since last initialization,
|
||||
and re-initializes if needed. This ensures that changes made through the
|
||||
Gateway API are reflected in the Gateway-embedded LangGraph runtime.
|
||||
Gateway API (which runs in a separate process) are reflected in the
|
||||
LangGraph Server.
|
||||
|
||||
Returns:
|
||||
List of cached MCP tools.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Load MCP tools using langchain-mcp-adapters with stdio session pooling."""
|
||||
"""Load MCP tools using langchain-mcp-adapters with persistent sessions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -173,10 +173,8 @@ def _make_session_pool_tool(
|
||||
async def get_mcp_tools() -> list[BaseTool]:
|
||||
"""Get all tools from enabled MCP servers.
|
||||
|
||||
Tools using stdio transport are wrapped with persistent-session logic so
|
||||
consecutive calls within the same thread reuse the same MCP session.
|
||||
HTTP/SSE tools are returned unwrapped to avoid cross-task TaskGroup
|
||||
cleanup errors.
|
||||
Tools are wrapped with persistent-session logic so that consecutive
|
||||
calls within the same thread reuse the same MCP session.
|
||||
|
||||
Returns:
|
||||
List of LangChain tools from all enabled MCP servers.
|
||||
@@ -253,9 +251,6 @@ async def get_mcp_tools() -> list[BaseTool]:
|
||||
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")
|
||||
|
||||
# Wrap each tool with persistent-session logic.
|
||||
# Only pool stdio sessions. HTTP/SSE transports use anyio TaskGroups
|
||||
# internally which cannot be closed from a different async task, so
|
||||
# pooling them causes RuntimeError on cleanup (see #3203).
|
||||
wrapped_tools: list[BaseTool] = []
|
||||
for tool in tools:
|
||||
tool_server: str | None = None
|
||||
@@ -265,11 +260,7 @@ async def get_mcp_tools() -> list[BaseTool]:
|
||||
break
|
||||
|
||||
if tool_server is not None:
|
||||
transport = servers_config[tool_server].get("transport", "stdio")
|
||||
if transport == "stdio":
|
||||
wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server], tool_interceptors))
|
||||
else:
|
||||
wrapped_tools.append(tool)
|
||||
wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server], tool_interceptors))
|
||||
else:
|
||||
wrapped_tools.append(tool)
|
||||
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
"""Helpers for replaying provider-specific assistant message fields.
|
||||
|
||||
Several provider adapters need to preserve fields that LangChain stores on the
|
||||
original ``AIMessage`` but drops when serializing request payloads. This module
|
||||
keeps the assistant-message matching logic shared while letting each provider
|
||||
decide which fields to restore.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage, BaseMessage
|
||||
|
||||
AssistantPayloadRestorer = Callable[[dict[str, Any], AIMessage], None]
|
||||
|
||||
|
||||
def restore_assistant_payloads(
|
||||
payload_messages: Sequence[dict[str, Any]],
|
||||
original_messages: Sequence[BaseMessage],
|
||||
restore: AssistantPayloadRestorer,
|
||||
) -> None:
|
||||
"""Restore provider-specific fields onto serialized assistant payloads."""
|
||||
if len(payload_messages) == len(original_messages):
|
||||
for payload_msg, orig_msg in zip(payload_messages, original_messages):
|
||||
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
|
||||
restore(payload_msg, orig_msg)
|
||||
return
|
||||
|
||||
ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
|
||||
assistant_payloads = [m for m in payload_messages if m.get("role") == "assistant"]
|
||||
used_ai_indexes: set[int] = set()
|
||||
|
||||
for ordinal, payload_msg in enumerate(assistant_payloads):
|
||||
ai_msg = _match_ai_message(payload_msg, ai_messages, used_ai_indexes, ordinal)
|
||||
if ai_msg is not None:
|
||||
restore(payload_msg, ai_msg)
|
||||
|
||||
|
||||
def restore_additional_kwargs_field(payload_msg: dict[str, Any], orig_msg: AIMessage, field_name: str) -> None:
|
||||
"""Copy a provider-specific ``additional_kwargs`` field onto a payload message."""
|
||||
value = orig_msg.additional_kwargs.get(field_name)
|
||||
if value is not None:
|
||||
payload_msg[field_name] = value
|
||||
|
||||
|
||||
def restore_reasoning_content(payload_msg: dict[str, Any], orig_msg: AIMessage) -> None:
|
||||
"""Copy provider reasoning content onto a serialized assistant payload."""
|
||||
restore_additional_kwargs_field(payload_msg, orig_msg, "reasoning_content")
|
||||
|
||||
|
||||
def _match_ai_message(
|
||||
payload_msg: dict[str, Any],
|
||||
ai_messages: Sequence[AIMessage],
|
||||
used_ai_indexes: set[int],
|
||||
fallback_ordinal: int,
|
||||
) -> AIMessage | None:
|
||||
payload_key = _assistant_signature(payload_msg)
|
||||
if payload_key is not None:
|
||||
matches = [index for index, ai_msg in enumerate(ai_messages) if index not in used_ai_indexes and _ai_signature(ai_msg) == payload_key]
|
||||
if len(matches) == 1:
|
||||
used_ai_indexes.add(matches[0])
|
||||
return ai_messages[matches[0]]
|
||||
|
||||
fallback_index = _next_unused_index_at_or_after(len(ai_messages), used_ai_indexes, fallback_ordinal)
|
||||
if fallback_index is not None:
|
||||
used_ai_indexes.add(fallback_index)
|
||||
return ai_messages[fallback_index]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _next_unused_index_at_or_after(count: int, used_ai_indexes: set[int], start: int) -> int | None:
|
||||
"""Return the next unused AI index at or after ``start``.
|
||||
|
||||
Scanning forward from the payload's ordinal preserves the positional bias of
|
||||
the previous behaviour while still recovering when serialization drops or
|
||||
reorders messages so the exact ordinal index is already taken. It does not
|
||||
wrap to earlier indexes because those messages may be represented by payload
|
||||
entries that were already dropped.
|
||||
"""
|
||||
if count == 0 or start >= count:
|
||||
return None
|
||||
for index in range(start, count):
|
||||
if index not in used_ai_indexes:
|
||||
return index
|
||||
return None
|
||||
|
||||
|
||||
def _assistant_signature(payload_msg: dict[str, Any]) -> tuple[str, str] | None:
|
||||
return _signature(
|
||||
payload_msg.get("content"),
|
||||
_tool_call_ids(payload_msg.get("tool_calls") or []),
|
||||
)
|
||||
|
||||
|
||||
def _ai_signature(message: AIMessage) -> tuple[str, str] | None:
|
||||
tool_calls = message.tool_calls or message.additional_kwargs.get("tool_calls") or []
|
||||
return _signature(message.content, _tool_call_ids(tool_calls))
|
||||
|
||||
|
||||
def _signature(content: Any, tool_call_ids: tuple[str, ...]) -> tuple[str, str] | None:
|
||||
if content in (None, "") and not tool_call_ids:
|
||||
return None
|
||||
return (_stable_repr(content), "|".join(tool_call_ids))
|
||||
|
||||
|
||||
def _stable_repr(value: Any) -> str:
|
||||
try:
|
||||
return json.dumps(value, sort_keys=True, ensure_ascii=False)
|
||||
except TypeError:
|
||||
return repr(value)
|
||||
|
||||
|
||||
def _tool_call_ids(tool_calls: Sequence[Any]) -> tuple[str, ...]:
|
||||
ids: list[str] = []
|
||||
for tool_call in tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
call_id = tool_call.get("id")
|
||||
if isinstance(call_id, str) and call_id:
|
||||
ids.append(call_id)
|
||||
return tuple(ids)
|
||||
@@ -10,10 +10,9 @@ on all assistant messages when thinking mode is enabled.
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
|
||||
from deerflow.models.assistant_payload_replay import restore_assistant_payloads, restore_reasoning_content
|
||||
|
||||
|
||||
class PatchedChatDeepSeek(ChatDeepSeek):
|
||||
"""ChatDeepSeek with proper reasoning_content preservation.
|
||||
@@ -50,10 +49,25 @@ class PatchedChatDeepSeek(ChatDeepSeek):
|
||||
# Call parent to get the base payload
|
||||
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||
|
||||
restore_assistant_payloads(
|
||||
payload.get("messages", []),
|
||||
original_messages,
|
||||
restore_reasoning_content,
|
||||
)
|
||||
# Match payload messages with original messages to restore reasoning_content
|
||||
payload_messages = payload.get("messages", [])
|
||||
|
||||
# The payload messages and original messages should be in the same order
|
||||
# Iterate through both and match by position
|
||||
if len(payload_messages) == len(original_messages):
|
||||
for payload_msg, orig_msg in zip(payload_messages, original_messages):
|
||||
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
|
||||
reasoning_content = orig_msg.additional_kwargs.get("reasoning_content")
|
||||
if reasoning_content is not None:
|
||||
payload_msg["reasoning_content"] = reasoning_content
|
||||
else:
|
||||
# Fallback: match by counting assistant messages
|
||||
ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
|
||||
assistant_payloads = [(i, m) for i, m in enumerate(payload_messages) if m.get("role") == "assistant"]
|
||||
|
||||
for (idx, payload_msg), ai_msg in zip(assistant_payloads, ai_messages):
|
||||
reasoning_content = ai_msg.additional_kwargs.get("reasoning_content")
|
||||
if reasoning_content is not None:
|
||||
payload_messages[idx]["reasoning_content"] = reasoning_content
|
||||
|
||||
return payload
|
||||
|
||||
@@ -1,140 +0,0 @@
|
||||
"""Patched ChatOpenAI adapter for Xiaomi MiMo reasoning_content replay.
|
||||
|
||||
MiMo's OpenAI-compatible API returns ``reasoning_content`` in thinking mode and
|
||||
requires that value to be replayed on historical assistant messages in
|
||||
multi-turn agent conversations. Standard ``langchain_openai.ChatOpenAI`` drops
|
||||
that provider-specific field, which can cause HTTP 400 errors once tool calls
|
||||
enter the conversation history.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from deerflow.models.assistant_payload_replay import restore_assistant_payloads, restore_reasoning_content
|
||||
|
||||
_MISSING = object()
|
||||
|
||||
|
||||
def _extract_reasoning_content(value: Any) -> str | object:
|
||||
"""Return reasoning_content from a dict/Pydantic object, preserving empty strings."""
|
||||
if isinstance(value, Mapping):
|
||||
if "reasoning_content" in value and value["reasoning_content"] is not None:
|
||||
return value["reasoning_content"]
|
||||
return _MISSING
|
||||
|
||||
reasoning = getattr(value, "reasoning_content", _MISSING)
|
||||
if reasoning is not _MISSING and reasoning is not None:
|
||||
return reasoning
|
||||
|
||||
model_extra = getattr(value, "model_extra", None)
|
||||
if isinstance(model_extra, Mapping) and "reasoning_content" in model_extra and model_extra["reasoning_content"] is not None:
|
||||
return model_extra["reasoning_content"]
|
||||
|
||||
return _MISSING
|
||||
|
||||
|
||||
def _with_reasoning_content(message: AIMessage | AIMessageChunk, reasoning: str) -> AIMessage | AIMessageChunk:
|
||||
additional_kwargs = dict(message.additional_kwargs)
|
||||
if additional_kwargs.get("reasoning_content") != reasoning:
|
||||
additional_kwargs["reasoning_content"] = reasoning
|
||||
return message.model_copy(update={"additional_kwargs": additional_kwargs})
|
||||
|
||||
|
||||
def _get_typed_choice_message(response: Any, index: int) -> Any:
|
||||
choices = getattr(response, "choices", None)
|
||||
if choices is None:
|
||||
return None
|
||||
try:
|
||||
return choices[index].message
|
||||
except (AttributeError, IndexError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
class PatchedChatMiMo(ChatOpenAI):
|
||||
"""ChatOpenAI with ``reasoning_content`` preservation for MiMo thinking mode."""
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> dict[str, str]:
|
||||
return {"api_key": "MIMO_API_KEY", "openai_api_key": "MIMO_API_KEY"}
|
||||
|
||||
def _get_request_payload(
|
||||
self,
|
||||
input_: LanguageModelInput,
|
||||
*,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
original_messages = self._convert_input(input_).to_messages()
|
||||
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||
restore_assistant_payloads(
|
||||
payload.get("messages", []),
|
||||
original_messages,
|
||||
restore_reasoning_content,
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
def _convert_chunk_to_generation_chunk(
|
||||
self,
|
||||
chunk: dict,
|
||||
default_chunk_class: type,
|
||||
base_generation_info: dict | None,
|
||||
) -> ChatGenerationChunk | None:
|
||||
generation_chunk = super()._convert_chunk_to_generation_chunk(
|
||||
chunk,
|
||||
default_chunk_class,
|
||||
base_generation_info,
|
||||
)
|
||||
if generation_chunk is None:
|
||||
return None
|
||||
|
||||
choices = chunk.get("choices", [])
|
||||
if choices:
|
||||
delta = choices[0].get("delta") or {}
|
||||
reasoning = _extract_reasoning_content(delta)
|
||||
if reasoning is not _MISSING and isinstance(generation_chunk.message, AIMessageChunk):
|
||||
generation_chunk = ChatGenerationChunk(
|
||||
message=_with_reasoning_content(generation_chunk.message, reasoning),
|
||||
generation_info=generation_chunk.generation_info,
|
||||
)
|
||||
|
||||
return generation_chunk
|
||||
|
||||
def _create_chat_result(
|
||||
self,
|
||||
response: dict | Any,
|
||||
generation_info: dict | None = None,
|
||||
) -> ChatResult:
|
||||
result = super()._create_chat_result(response, generation_info)
|
||||
response_dict = response if isinstance(response, dict) else response.model_dump()
|
||||
choices = response_dict.get("choices", [])
|
||||
|
||||
patched_generations: list[ChatGeneration] | None = None
|
||||
for index, generation in enumerate(result.generations):
|
||||
choice = choices[index] if index < len(choices) else {}
|
||||
choice_message = choice.get("message", {}) if isinstance(choice, Mapping) else {}
|
||||
reasoning = _extract_reasoning_content(choice_message)
|
||||
if reasoning is _MISSING and not isinstance(response, dict):
|
||||
reasoning = _extract_reasoning_content(_get_typed_choice_message(response, index))
|
||||
|
||||
message = generation.message
|
||||
if reasoning is not _MISSING and isinstance(message, AIMessage):
|
||||
if patched_generations is None:
|
||||
patched_generations = list(result.generations)
|
||||
patched_generations[index] = ChatGeneration(
|
||||
message=_with_reasoning_content(message, reasoning),
|
||||
generation_info=generation.generation_info,
|
||||
)
|
||||
|
||||
return ChatResult(generations=patched_generations or result.generations, llm_output=result.llm_output)
|
||||
@@ -27,8 +27,6 @@ from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from deerflow.models.assistant_payload_replay import restore_assistant_payloads
|
||||
|
||||
|
||||
class PatchedChatOpenAI(ChatOpenAI):
|
||||
"""ChatOpenAI with ``thought_signature`` preservation for Gemini thinking via OpenAI gateway.
|
||||
@@ -77,7 +75,18 @@ class PatchedChatOpenAI(ChatOpenAI):
|
||||
# Obtain the base payload from the parent implementation.
|
||||
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||
|
||||
restore_assistant_payloads(payload.get("messages", []), original_messages, _restore_tool_call_signatures)
|
||||
payload_messages = payload.get("messages", [])
|
||||
|
||||
if len(payload_messages) == len(original_messages):
|
||||
for payload_msg, orig_msg in zip(payload_messages, original_messages):
|
||||
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
|
||||
_restore_tool_call_signatures(payload_msg, orig_msg)
|
||||
else:
|
||||
# Fallback: match assistant-role entries positionally against AIMessages.
|
||||
ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
|
||||
assistant_payloads = [(i, m) for i, m in enumerate(payload_messages) if m.get("role") == "assistant"]
|
||||
for (_, payload_msg), ai_msg in zip(assistant_payloads, ai_messages):
|
||||
_restore_tool_call_signatures(payload_msg, ai_msg)
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
@@ -144,13 +144,10 @@ class DbRunEventStore(RunEventStore):
|
||||
async def put_batch(self, events):
|
||||
if not events:
|
||||
return []
|
||||
thread_ids = {e["thread_id"] for e in events}
|
||||
if len(thread_ids) > 1:
|
||||
raise ValueError(f"put_batch requires all events to belong to the same thread; got {thread_ids!r}")
|
||||
user_id = self._user_id_from_context()
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# All events belong to the same thread (validated above).
|
||||
# Get max seq for the thread (assume all events in batch belong to same thread).
|
||||
thread_id = events[0]["thread_id"]
|
||||
max_seq = await self._max_seq_for_thread(session, thread_id)
|
||||
seq = max_seq or 0
|
||||
|
||||
@@ -6,15 +6,6 @@ Each run's events are stored in a single file:
|
||||
All categories (message, trace, lifecycle) are in the same file.
|
||||
This backend is suitable for lightweight single-node deployments.
|
||||
|
||||
**Single-process guarantee**: the in-memory seq counter is process-local.
|
||||
Multi-process deployments sharing the same directory will produce duplicate
|
||||
or non-monotonic seq values. Use ``DbRunEventStore`` for multi-process or
|
||||
high-concurrency deployments.
|
||||
|
||||
File I/O is offloaded to a thread pool via ``asyncio.to_thread`` so the
|
||||
event loop is never blocked. Per-thread ``asyncio.Lock`` objects serialise
|
||||
writes within a single process to prevent interleaved JSONL lines.
|
||||
|
||||
Known trade-off: ``list_messages()`` must scan all run files for a
|
||||
thread since messages from multiple runs need unified seq ordering.
|
||||
``list_events()`` reads only one file -- the fast path.
|
||||
@@ -22,7 +13,6 @@ thread since messages from multiple runs need unified seq ordering.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
@@ -40,11 +30,6 @@ class JsonlRunEventStore(RunEventStore):
|
||||
def __init__(self, base_dir: str | Path | None = None):
|
||||
self._base_dir = Path(base_dir) if base_dir else Path(".deer-flow")
|
||||
self._seq_counters: dict[str, int] = {} # thread_id -> current max seq
|
||||
# Per-thread asyncio.Lock — serialises concurrent writes within one process.
|
||||
self._write_locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
def _get_write_lock(self, thread_id: str) -> asyncio.Lock:
|
||||
return self._write_locks.setdefault(thread_id, asyncio.Lock())
|
||||
|
||||
@staticmethod
|
||||
def _validate_id(value: str, label: str) -> str:
|
||||
@@ -65,8 +50,10 @@ class JsonlRunEventStore(RunEventStore):
|
||||
self._seq_counters[thread_id] = self._seq_counters.get(thread_id, 0) + 1
|
||||
return self._seq_counters[thread_id]
|
||||
|
||||
def _compute_max_seq(self, thread_id: str) -> int:
|
||||
"""Scan all run files for a thread and return the current max seq (blocking I/O)."""
|
||||
def _ensure_seq_loaded(self, thread_id: str) -> None:
|
||||
"""Load max seq from existing files if not yet cached."""
|
||||
if thread_id in self._seq_counters:
|
||||
return
|
||||
max_seq = 0
|
||||
thread_dir = self._thread_dir(thread_id)
|
||||
if thread_dir.exists():
|
||||
@@ -77,13 +64,7 @@ class JsonlRunEventStore(RunEventStore):
|
||||
max_seq = max(max_seq, record.get("seq", 0))
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Skipping malformed JSONL line in %s", f)
|
||||
return max_seq
|
||||
|
||||
async def _ensure_seq_loaded(self, thread_id: str) -> None:
|
||||
"""Load max seq from existing files into the in-memory counter (non-blocking)."""
|
||||
if thread_id in self._seq_counters:
|
||||
return
|
||||
max_seq = await asyncio.to_thread(self._compute_max_seq, thread_id)
|
||||
continue
|
||||
self._seq_counters[thread_id] = max_seq
|
||||
|
||||
def _write_record(self, record: dict) -> None:
|
||||
@@ -93,7 +74,7 @@ class JsonlRunEventStore(RunEventStore):
|
||||
f.write(json.dumps(record, default=str, ensure_ascii=False) + "\n")
|
||||
|
||||
def _read_thread_events(self, thread_id: str) -> list[dict]:
|
||||
"""Read all events for a thread, sorted by seq (blocking I/O)."""
|
||||
"""Read all events for a thread, sorted by seq."""
|
||||
events = []
|
||||
thread_dir = self._thread_dir(thread_id)
|
||||
if not thread_dir.exists():
|
||||
@@ -106,11 +87,12 @@ class JsonlRunEventStore(RunEventStore):
|
||||
events.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Skipping malformed JSONL line in %s", f)
|
||||
continue
|
||||
events.sort(key=lambda e: e.get("seq", 0))
|
||||
return events
|
||||
|
||||
def _read_run_events(self, thread_id: str, run_id: str) -> list[dict]:
|
||||
"""Read events for a specific run file (blocking I/O)."""
|
||||
"""Read events for a specific run file."""
|
||||
path = self._run_file(thread_id, run_id)
|
||||
if not path.exists():
|
||||
return []
|
||||
@@ -122,36 +104,25 @@ class JsonlRunEventStore(RunEventStore):
|
||||
events.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Skipping malformed JSONL line in %s", path)
|
||||
continue
|
||||
events.sort(key=lambda e: e.get("seq", 0))
|
||||
return events
|
||||
|
||||
def _delete_thread_files(self, thread_id: str) -> None:
|
||||
thread_dir = self._thread_dir(thread_id)
|
||||
if thread_dir.exists():
|
||||
for f in thread_dir.glob("*.jsonl"):
|
||||
f.unlink()
|
||||
|
||||
def _delete_run_file(self, thread_id: str, run_id: str) -> None:
|
||||
path = self._run_file(thread_id, run_id)
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
|
||||
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):
|
||||
async with self._get_write_lock(thread_id):
|
||||
await self._ensure_seq_loaded(thread_id)
|
||||
seq = self._next_seq(thread_id)
|
||||
record = {
|
||||
"thread_id": thread_id,
|
||||
"run_id": run_id,
|
||||
"event_type": event_type,
|
||||
"category": category,
|
||||
"content": content,
|
||||
"metadata": metadata or {},
|
||||
"seq": seq,
|
||||
"created_at": created_at or datetime.now(UTC).isoformat(),
|
||||
}
|
||||
await asyncio.to_thread(self._write_record, record)
|
||||
return record
|
||||
self._ensure_seq_loaded(thread_id)
|
||||
seq = self._next_seq(thread_id)
|
||||
record = {
|
||||
"thread_id": thread_id,
|
||||
"run_id": run_id,
|
||||
"event_type": event_type,
|
||||
"category": category,
|
||||
"content": content,
|
||||
"metadata": metadata or {},
|
||||
"seq": seq,
|
||||
"created_at": created_at or datetime.now(UTC).isoformat(),
|
||||
}
|
||||
self._write_record(record)
|
||||
return record
|
||||
|
||||
async def put_batch(self, events):
|
||||
if not events:
|
||||
@@ -163,7 +134,7 @@ class JsonlRunEventStore(RunEventStore):
|
||||
return results
|
||||
|
||||
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
all_events = await asyncio.to_thread(self._read_thread_events, thread_id)
|
||||
all_events = self._read_thread_events(thread_id)
|
||||
messages = [e for e in all_events if e.get("category") == "message"]
|
||||
|
||||
if before_seq is not None:
|
||||
@@ -176,13 +147,13 @@ class JsonlRunEventStore(RunEventStore):
|
||||
return messages[-limit:]
|
||||
|
||||
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
|
||||
events = await asyncio.to_thread(self._read_run_events, thread_id, run_id)
|
||||
events = self._read_run_events(thread_id, run_id)
|
||||
if event_types is not None:
|
||||
events = [e for e in events if e.get("event_type") in event_types]
|
||||
return events[:limit]
|
||||
|
||||
async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
events = await asyncio.to_thread(self._read_run_events, thread_id, run_id)
|
||||
events = self._read_run_events(thread_id, run_id)
|
||||
filtered = [e for e in events if e.get("category") == "message"]
|
||||
if before_seq is not None:
|
||||
filtered = [e for e in filtered if e.get("seq", 0) < before_seq]
|
||||
@@ -194,25 +165,23 @@ class JsonlRunEventStore(RunEventStore):
|
||||
return filtered[-limit:] if len(filtered) > limit else filtered
|
||||
|
||||
async def count_messages(self, thread_id):
|
||||
all_events = await asyncio.to_thread(self._read_thread_events, thread_id)
|
||||
all_events = self._read_thread_events(thread_id)
|
||||
return sum(1 for e in all_events if e.get("category") == "message")
|
||||
|
||||
async def delete_by_thread(self, thread_id):
|
||||
async with self._get_write_lock(thread_id):
|
||||
all_events = await asyncio.to_thread(self._read_thread_events, thread_id)
|
||||
count = len(all_events)
|
||||
await asyncio.to_thread(self._delete_thread_files, thread_id)
|
||||
self._seq_counters.pop(thread_id, None)
|
||||
# Pop the lock inside the held scope to minimise the window where a new caller
|
||||
# could obtain a fresh lock while a waiting coroutine still holds the old one.
|
||||
# Note: coroutines that already acquired a reference to this lock before the
|
||||
# delete will still proceed after we release — this is an accepted narrow race.
|
||||
self._write_locks.pop(thread_id, None)
|
||||
return count
|
||||
all_events = self._read_thread_events(thread_id)
|
||||
count = len(all_events)
|
||||
thread_dir = self._thread_dir(thread_id)
|
||||
if thread_dir.exists():
|
||||
for f in thread_dir.glob("*.jsonl"):
|
||||
f.unlink()
|
||||
self._seq_counters.pop(thread_id, None)
|
||||
return count
|
||||
|
||||
async def delete_by_run(self, thread_id, run_id):
|
||||
async with self._get_write_lock(thread_id):
|
||||
events = await asyncio.to_thread(self._read_run_events, thread_id, run_id)
|
||||
count = len(events)
|
||||
await asyncio.to_thread(self._delete_run_file, thread_id, run_id)
|
||||
return count
|
||||
events = self._read_run_events(thread_id, run_id)
|
||||
count = len(events)
|
||||
path = self._run_file(thread_id, run_id)
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
return count
|
||||
|
||||
@@ -1,39 +1,16 @@
|
||||
"""Run lifecycle management for LangGraph Platform API compatibility."""
|
||||
|
||||
from .domain import (
|
||||
AssistantId,
|
||||
CancelAction,
|
||||
DisconnectMode,
|
||||
EventSeq,
|
||||
InvalidRunTransition,
|
||||
MultitaskStrategy,
|
||||
Run,
|
||||
RunId,
|
||||
RunScope,
|
||||
RunStatus,
|
||||
ThreadId,
|
||||
UserId,
|
||||
)
|
||||
from .manager import ConflictError, RunManager, RunRecord, UnsupportedStrategyError
|
||||
from .schemas import DisconnectMode, RunStatus
|
||||
from .worker import RunContext, run_agent
|
||||
|
||||
__all__ = [
|
||||
"AssistantId",
|
||||
"CancelAction",
|
||||
"ConflictError",
|
||||
"DisconnectMode",
|
||||
"EventSeq",
|
||||
"InvalidRunTransition",
|
||||
"MultitaskStrategy",
|
||||
"Run",
|
||||
"RunContext",
|
||||
"RunId",
|
||||
"RunManager",
|
||||
"RunRecord",
|
||||
"RunScope",
|
||||
"RunStatus",
|
||||
"ThreadId",
|
||||
"UnsupportedStrategyError",
|
||||
"UserId",
|
||||
"run_agent",
|
||||
]
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
"""Application-layer DTOs and services for run runtime use cases."""
|
||||
|
||||
from .commands import CancelRunCommand, CreateRunCommand, JoinRunStreamCommand
|
||||
from .dto import RunMessageView, RunSnapshot, RunStreamHandle, StoredRunEvent
|
||||
from .queries import GetRunQuery, ListRunMessagesQuery, ListRunsQuery
|
||||
from .services import RunsApplicationService
|
||||
|
||||
__all__ = [
|
||||
"CancelRunCommand",
|
||||
"CreateRunCommand",
|
||||
"GetRunQuery",
|
||||
"JoinRunStreamCommand",
|
||||
"ListRunMessagesQuery",
|
||||
"ListRunsQuery",
|
||||
"RunMessageView",
|
||||
"RunSnapshot",
|
||||
"RunStreamHandle",
|
||||
"RunsApplicationService",
|
||||
"StoredRunEvent",
|
||||
]
|
||||
@@ -1,46 +0,0 @@
|
||||
"""Application command DTOs for run use cases."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal
|
||||
|
||||
from ..domain import AssistantId, CancelAction, DisconnectMode, MultitaskStrategy, RunId, RunScope, ThreadId
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CreateRunCommand:
|
||||
thread_id: ThreadId
|
||||
assistant_id: AssistantId | None = None
|
||||
input: dict[str, Any] | None = None
|
||||
command: dict[str, Any] | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
config: dict[str, Any] = field(default_factory=dict)
|
||||
context: dict[str, Any] = field(default_factory=dict)
|
||||
scope: RunScope = RunScope.stateful
|
||||
on_disconnect: DisconnectMode = DisconnectMode.cancel
|
||||
multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject
|
||||
stream_mode: list[str] | str | None = None
|
||||
stream_subgraphs: bool = False
|
||||
interrupt_before: list[str] | Literal["*"] | None = None
|
||||
interrupt_after: list[str] | Literal["*"] | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CancelRunCommand:
|
||||
run_id: RunId
|
||||
action: CancelAction = CancelAction.interrupt
|
||||
wait: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JoinRunStreamCommand:
|
||||
run_id: RunId
|
||||
last_event_id: str | None = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CancelRunCommand",
|
||||
"CreateRunCommand",
|
||||
"JoinRunStreamCommand",
|
||||
]
|
||||
@@ -1,76 +0,0 @@
|
||||
"""Application output DTOs for run use cases."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..domain import AssistantId, EventSeq, Run, RunId, RunStatus, ThreadId
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RunSnapshot:
|
||||
run_id: RunId
|
||||
thread_id: ThreadId
|
||||
assistant_id: AssistantId | None = None
|
||||
status: RunStatus = RunStatus.pending
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
created_at: str = ""
|
||||
updated_at: str = ""
|
||||
error: str | None = None
|
||||
model_name: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_run(cls, run: Run) -> RunSnapshot:
|
||||
return cls(
|
||||
run_id=run.run_id,
|
||||
thread_id=run.thread_id,
|
||||
assistant_id=run.assistant_id,
|
||||
status=run.status,
|
||||
metadata=dict(run.metadata),
|
||||
kwargs=dict(run.kwargs),
|
||||
created_at=run.created_at,
|
||||
updated_at=run.updated_at,
|
||||
error=run.error,
|
||||
model_name=run.model_name,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RunMessageView:
|
||||
thread_id: ThreadId
|
||||
run_id: RunId
|
||||
seq: EventSeq
|
||||
event_type: str
|
||||
content: str | dict[str, Any] = ""
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
created_at: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StoredRunEvent:
|
||||
thread_id: ThreadId
|
||||
run_id: RunId
|
||||
seq: EventSeq
|
||||
event_type: str
|
||||
category: str
|
||||
content: str | dict[str, Any] = ""
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
created_at: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RunStreamHandle:
|
||||
run_id: RunId
|
||||
thread_id: ThreadId
|
||||
events: AsyncIterator[Any]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RunMessageView",
|
||||
"RunSnapshot",
|
||||
"RunStreamHandle",
|
||||
"StoredRunEvent",
|
||||
]
|
||||
@@ -1,37 +0,0 @@
|
||||
"""Application query DTOs for run use cases."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..domain import RunId, ThreadId, UserId
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GetRunQuery:
|
||||
run_id: RunId
|
||||
thread_id: ThreadId | None = None
|
||||
user_id: UserId | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ListRunsQuery:
|
||||
thread_id: ThreadId
|
||||
user_id: UserId | None = None
|
||||
limit: int = 100
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ListRunMessagesQuery:
|
||||
thread_id: ThreadId
|
||||
run_id: RunId
|
||||
limit: int = 50
|
||||
before_seq: int | None = None
|
||||
after_seq: int | None = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"GetRunQuery",
|
||||
"ListRunMessagesQuery",
|
||||
"ListRunsQuery",
|
||||
]
|
||||
@@ -1,74 +0,0 @@
|
||||
"""Application service skeleton for run use cases."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..execution import RunExecutionScheduler, RunSupervisor
|
||||
from ..repositories import RunEventLog, RunRepository
|
||||
from ..streams import RunStreamBroker
|
||||
from .commands import CancelRunCommand, CreateRunCommand, JoinRunStreamCommand
|
||||
from .dto import RunMessageView, RunSnapshot, RunStreamHandle
|
||||
from .queries import GetRunQuery, ListRunMessagesQuery, ListRunsQuery
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunsApplicationService:
|
||||
"""Use-case orchestration boundary for run runtime operations.
|
||||
|
||||
PR1 only introduces the boundary and dependency shape. Existing Gateway
|
||||
handlers continue to call the legacy service functions until later PRs move
|
||||
behavior into this class.
|
||||
"""
|
||||
|
||||
run_repository: RunRepository
|
||||
run_event_log: RunEventLog
|
||||
stream_broker: RunStreamBroker
|
||||
scheduler: RunExecutionScheduler
|
||||
supervisor: RunSupervisor
|
||||
|
||||
async def create_background(self, command: CreateRunCommand) -> RunSnapshot:
|
||||
# PR1 defines the application boundary; later PRs move Gateway runtime
|
||||
# behavior behind this method.
|
||||
raise NotImplementedError("RunsApplicationService is not wired in PR1")
|
||||
|
||||
async def create_and_stream(self, command: CreateRunCommand) -> RunStreamHandle:
|
||||
raise NotImplementedError("RunsApplicationService is not wired in PR1")
|
||||
|
||||
async def create_and_wait(self, command: CreateRunCommand) -> RunSnapshot:
|
||||
raise NotImplementedError("RunsApplicationService is not wired in PR1")
|
||||
|
||||
async def join_stream(self, command: JoinRunStreamCommand) -> RunStreamHandle:
|
||||
raise NotImplementedError("RunsApplicationService is not wired in PR1")
|
||||
|
||||
async def cancel(self, command: CancelRunCommand) -> bool:
|
||||
return await self.supervisor.cancel(command.run_id, action=command.action)
|
||||
|
||||
async def get_run(self, query: GetRunQuery) -> RunSnapshot | None:
|
||||
run = await self.run_repository.get(query.run_id, user_id=query.user_id)
|
||||
if run is None:
|
||||
return None
|
||||
if query.thread_id is not None and run.thread_id != query.thread_id:
|
||||
return None
|
||||
return RunSnapshot.from_run(run)
|
||||
|
||||
async def list_runs(self, query: ListRunsQuery) -> list[RunSnapshot]:
|
||||
return await self.run_repository.list_by_thread(
|
||||
query.thread_id,
|
||||
user_id=query.user_id,
|
||||
limit=query.limit,
|
||||
)
|
||||
|
||||
async def list_run_messages(self, query: ListRunMessagesQuery) -> list[RunMessageView]:
|
||||
return await self.run_event_log.list_messages_by_run(
|
||||
query.thread_id,
|
||||
query.run_id,
|
||||
limit=query.limit,
|
||||
before_seq=query.before_seq,
|
||||
after_seq=query.after_seq,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RunsApplicationService",
|
||||
]
|
||||
@@ -1,33 +0,0 @@
|
||||
"""Run runtime domain model."""
|
||||
|
||||
from .errors import InvalidRunTransition, RunDomainError
|
||||
from .events import RunCancelled, RunCompleted, RunCreated, RunEvent, RunFailed, RunStarted
|
||||
from .identifiers import AssistantId, RunId, ThreadId, UserId
|
||||
from .model import Run
|
||||
from .policies import CancelPolicy, MultitaskDecision, MultitaskPolicy
|
||||
from .value_objects import CancelAction, DisconnectMode, EventSeq, MultitaskStrategy, RunScope, RunStatus
|
||||
|
||||
__all__ = [
|
||||
"AssistantId",
|
||||
"CancelAction",
|
||||
"CancelPolicy",
|
||||
"DisconnectMode",
|
||||
"EventSeq",
|
||||
"InvalidRunTransition",
|
||||
"MultitaskDecision",
|
||||
"MultitaskPolicy",
|
||||
"MultitaskStrategy",
|
||||
"Run",
|
||||
"RunCancelled",
|
||||
"RunCompleted",
|
||||
"RunCreated",
|
||||
"RunDomainError",
|
||||
"RunEvent",
|
||||
"RunFailed",
|
||||
"RunId",
|
||||
"RunScope",
|
||||
"RunStarted",
|
||||
"RunStatus",
|
||||
"ThreadId",
|
||||
"UserId",
|
||||
]
|
||||
@@ -1,24 +0,0 @@
|
||||
"""Domain-level errors for run lifecycle operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .value_objects import RunStatus
|
||||
|
||||
|
||||
class RunDomainError(Exception):
|
||||
"""Base class for run runtime domain errors."""
|
||||
|
||||
|
||||
class InvalidRunTransition(RunDomainError):
|
||||
"""Raised when a run status transition violates lifecycle rules."""
|
||||
|
||||
def __init__(self, current: RunStatus, target: RunStatus) -> None:
|
||||
super().__init__(f"Cannot transition run from {current.value!r} to {target.value!r}")
|
||||
self.current = current
|
||||
self.target = target
|
||||
|
||||
|
||||
__all__ = [
|
||||
"InvalidRunTransition",
|
||||
"RunDomainError",
|
||||
]
|
||||
@@ -1,64 +0,0 @@
|
||||
"""Domain events emitted by the run aggregate."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from deerflow.utils.time import now_iso
|
||||
|
||||
from .identifiers import AssistantId, RunId, ThreadId
|
||||
from .value_objects import CancelAction, RunStatus
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RunCreated:
|
||||
run_id: RunId
|
||||
thread_id: ThreadId
|
||||
occurred_at: str = field(default_factory=now_iso)
|
||||
assistant_id: AssistantId | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RunStarted:
|
||||
run_id: RunId
|
||||
thread_id: ThreadId
|
||||
occurred_at: str = field(default_factory=now_iso)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RunCompleted:
|
||||
run_id: RunId
|
||||
thread_id: ThreadId
|
||||
occurred_at: str = field(default_factory=now_iso)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RunFailed:
|
||||
run_id: RunId
|
||||
thread_id: ThreadId
|
||||
status: RunStatus
|
||||
occurred_at: str = field(default_factory=now_iso)
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RunCancelled:
|
||||
run_id: RunId
|
||||
thread_id: ThreadId
|
||||
occurred_at: str = field(default_factory=now_iso)
|
||||
action: CancelAction = CancelAction.interrupt
|
||||
|
||||
|
||||
RunEvent = RunCreated | RunStarted | RunCompleted | RunFailed | RunCancelled
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RunCancelled",
|
||||
"RunCompleted",
|
||||
"RunCreated",
|
||||
"RunEvent",
|
||||
"RunFailed",
|
||||
"RunStarted",
|
||||
]
|
||||
@@ -1,27 +0,0 @@
|
||||
"""Lightweight identifiers for the run runtime domain."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NewType
|
||||
|
||||
RunId = NewType("RunId", str)
|
||||
ThreadId = NewType("ThreadId", str)
|
||||
AssistantId = NewType("AssistantId", str)
|
||||
UserId = NewType("UserId", str)
|
||||
|
||||
|
||||
def require_non_empty(value: str, *, field_name: str) -> str:
|
||||
"""Return a stripped identifier value, rejecting empty identifiers."""
|
||||
normalized = value.strip()
|
||||
if not normalized:
|
||||
raise ValueError(f"{field_name} must not be empty")
|
||||
return normalized
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AssistantId",
|
||||
"RunId",
|
||||
"ThreadId",
|
||||
"UserId",
|
||||
"require_non_empty",
|
||||
]
|
||||
@@ -1,193 +0,0 @@
|
||||
"""Run aggregate root and lifecycle invariants."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from deerflow.utils.time import now_iso
|
||||
|
||||
from .errors import InvalidRunTransition
|
||||
from .events import RunCancelled, RunCompleted, RunCreated, RunEvent, RunFailed, RunStarted
|
||||
from .identifiers import AssistantId, RunId, ThreadId, require_non_empty
|
||||
from .value_objects import CancelAction, MultitaskStrategy, RunScope, RunStatus
|
||||
|
||||
# Keep lifecycle transitions explicit so later application code cannot invent
|
||||
# ad hoc status moves outside the aggregate.
|
||||
_ALLOWED_TRANSITIONS: dict[RunStatus, frozenset[RunStatus]] = {
|
||||
RunStatus.pending: frozenset(
|
||||
{
|
||||
RunStatus.running,
|
||||
RunStatus.error,
|
||||
RunStatus.timeout,
|
||||
RunStatus.interrupted,
|
||||
}
|
||||
),
|
||||
RunStatus.running: frozenset(
|
||||
{
|
||||
RunStatus.success,
|
||||
RunStatus.error,
|
||||
RunStatus.timeout,
|
||||
RunStatus.interrupted,
|
||||
}
|
||||
),
|
||||
RunStatus.success: frozenset(),
|
||||
RunStatus.error: frozenset(),
|
||||
RunStatus.timeout: frozenset(),
|
||||
RunStatus.interrupted: frozenset(),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Run:
|
||||
"""Run aggregate root.
|
||||
|
||||
The aggregate owns lifecycle invariants only. Infrastructure concerns such
|
||||
as SQL sessions, SSE frames, Redis clients, and FastAPI requests stay out of
|
||||
this model.
|
||||
"""
|
||||
|
||||
run_id: RunId
|
||||
thread_id: ThreadId
|
||||
status: RunStatus
|
||||
assistant_id: AssistantId | None = None
|
||||
scope: RunScope = RunScope.stateful
|
||||
multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
created_at: str = field(default_factory=now_iso)
|
||||
updated_at: str = field(default_factory=now_iso)
|
||||
error: str | None = None
|
||||
model_name: str | None = None
|
||||
_pending_events: list[RunEvent] = field(default_factory=list, init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.run_id = RunId(require_non_empty(str(self.run_id), field_name="run_id"))
|
||||
self.thread_id = ThreadId(require_non_empty(str(self.thread_id), field_name="thread_id"))
|
||||
if self.assistant_id is not None:
|
||||
self.assistant_id = AssistantId(require_non_empty(str(self.assistant_id), field_name="assistant_id"))
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
*,
|
||||
run_id: RunId,
|
||||
thread_id: ThreadId,
|
||||
assistant_id: AssistantId | None = None,
|
||||
scope: RunScope = RunScope.stateful,
|
||||
multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
model_name: str | None = None,
|
||||
created_at: str | None = None,
|
||||
) -> Run:
|
||||
timestamp = created_at or now_iso()
|
||||
run = cls(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
status=RunStatus.pending,
|
||||
scope=scope,
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata=metadata or {},
|
||||
kwargs=kwargs or {},
|
||||
created_at=timestamp,
|
||||
updated_at=timestamp,
|
||||
model_name=model_name,
|
||||
)
|
||||
run._record_event(
|
||||
RunCreated(
|
||||
run_id=run.run_id,
|
||||
thread_id=run.thread_id,
|
||||
occurred_at=timestamp,
|
||||
assistant_id=run.assistant_id,
|
||||
metadata=dict(run.metadata),
|
||||
)
|
||||
)
|
||||
return run
|
||||
|
||||
@property
|
||||
def is_terminal(self) -> bool:
|
||||
return not _ALLOWED_TRANSITIONS[self.status]
|
||||
|
||||
def pull_events(self) -> tuple[RunEvent, ...]:
|
||||
# Domain events are drained by the application layer after the aggregate
|
||||
# has accepted a state change.
|
||||
events = tuple(self._pending_events)
|
||||
self._pending_events.clear()
|
||||
return events
|
||||
|
||||
def mark_started(self, *, at: str | None = None) -> None:
|
||||
self._transition_to(RunStatus.running, at=at)
|
||||
|
||||
def mark_completed(self, *, at: str | None = None) -> None:
|
||||
self._transition_to(RunStatus.success, at=at)
|
||||
|
||||
def mark_failed(self, error: str | None = None, *, at: str | None = None) -> None:
|
||||
self._transition_to(RunStatus.error, error=error, at=at)
|
||||
|
||||
def mark_timed_out(self, error: str | None = None, *, at: str | None = None) -> None:
|
||||
self._transition_to(RunStatus.timeout, error=error, at=at)
|
||||
|
||||
def mark_cancelled(self, *, action: CancelAction = CancelAction.interrupt, at: str | None = None) -> None:
|
||||
self._transition_to(RunStatus.interrupted, action=action, at=at)
|
||||
|
||||
def _transition_to(
|
||||
self,
|
||||
target: RunStatus,
|
||||
*,
|
||||
error: str | None = None,
|
||||
action: CancelAction = CancelAction.interrupt,
|
||||
at: str | None = None,
|
||||
) -> None:
|
||||
if target == self.status:
|
||||
return
|
||||
if target not in _ALLOWED_TRANSITIONS[self.status]:
|
||||
raise InvalidRunTransition(self.status, target)
|
||||
|
||||
timestamp = at or now_iso()
|
||||
self.status = target
|
||||
self.updated_at = timestamp
|
||||
if error is not None:
|
||||
self.error = error
|
||||
self._record_event(self._event_for_transition(target, timestamp, error=error, action=action))
|
||||
|
||||
def _event_for_transition(
|
||||
self,
|
||||
target: RunStatus,
|
||||
occurred_at: str,
|
||||
*,
|
||||
error: str | None,
|
||||
action: CancelAction,
|
||||
) -> RunEvent:
|
||||
# Keep event construction next to the transition rules so a new status
|
||||
# cannot be added without an explicit durable event shape.
|
||||
if target == RunStatus.running:
|
||||
return RunStarted(run_id=self.run_id, thread_id=self.thread_id, occurred_at=occurred_at)
|
||||
if target == RunStatus.success:
|
||||
return RunCompleted(run_id=self.run_id, thread_id=self.thread_id, occurred_at=occurred_at)
|
||||
if target in (RunStatus.error, RunStatus.timeout):
|
||||
return RunFailed(
|
||||
run_id=self.run_id,
|
||||
thread_id=self.thread_id,
|
||||
status=target,
|
||||
occurred_at=occurred_at,
|
||||
error=error,
|
||||
)
|
||||
if target == RunStatus.interrupted:
|
||||
return RunCancelled(
|
||||
run_id=self.run_id,
|
||||
thread_id=self.thread_id,
|
||||
occurred_at=occurred_at,
|
||||
action=action,
|
||||
)
|
||||
raise InvalidRunTransition(self.status, target)
|
||||
|
||||
def _record_event(self, event: RunEvent) -> None:
|
||||
self._pending_events.append(event)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Run",
|
||||
"RunStatus",
|
||||
]
|
||||
@@ -1,50 +0,0 @@
|
||||
"""Domain policies for run concurrency and cancellation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
|
||||
from .model import Run
|
||||
from .value_objects import CancelAction, MultitaskStrategy, RunStatus
|
||||
|
||||
|
||||
class MultitaskDecision(StrEnum):
|
||||
"""Application-level decision produced by a multitask policy."""
|
||||
|
||||
allow = "allow"
|
||||
reject = "reject"
|
||||
cancel_existing = "cancel_existing"
|
||||
enqueue = "enqueue"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MultitaskPolicy:
|
||||
strategy: MultitaskStrategy = MultitaskStrategy.reject
|
||||
|
||||
def decide(self, active_runs: Sequence[Run]) -> MultitaskDecision:
|
||||
inflight = [run for run in active_runs if run.status in (RunStatus.pending, RunStatus.running)]
|
||||
if not inflight:
|
||||
return MultitaskDecision.allow
|
||||
if self.strategy == MultitaskStrategy.reject:
|
||||
return MultitaskDecision.reject
|
||||
if self.strategy in (MultitaskStrategy.interrupt, MultitaskStrategy.rollback):
|
||||
return MultitaskDecision.cancel_existing
|
||||
return MultitaskDecision.enqueue
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CancelPolicy:
|
||||
action: CancelAction = CancelAction.interrupt
|
||||
|
||||
@property
|
||||
def rolls_back_checkpoint(self) -> bool:
|
||||
return self.action == CancelAction.rollback
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CancelPolicy",
|
||||
"MultitaskDecision",
|
||||
"MultitaskPolicy",
|
||||
]
|
||||
@@ -1,88 +0,0 @@
|
||||
"""Domain value objects for run lifecycle semantics."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class RunStatus(StrEnum):
|
||||
"""Lifecycle status of a single run."""
|
||||
|
||||
pending = "pending"
|
||||
running = "running"
|
||||
success = "success"
|
||||
error = "error"
|
||||
timeout = "timeout"
|
||||
interrupted = "interrupted"
|
||||
|
||||
|
||||
class DisconnectMode(StrEnum):
|
||||
"""Behaviour when the SSE consumer disconnects."""
|
||||
|
||||
cancel = "cancel"
|
||||
continue_ = "continue"
|
||||
|
||||
|
||||
class RunScope(StrEnum):
|
||||
"""Conversation scope for a run."""
|
||||
|
||||
stateful = "stateful"
|
||||
stateless = "stateless"
|
||||
temporary_thread = "temporary_thread"
|
||||
|
||||
|
||||
class MultitaskStrategy(StrEnum):
|
||||
"""Concurrency strategy for a new run on a thread."""
|
||||
|
||||
reject = "reject"
|
||||
interrupt = "interrupt"
|
||||
rollback = "rollback"
|
||||
enqueue = "enqueue"
|
||||
|
||||
|
||||
class CancelAction(StrEnum):
|
||||
"""Cancellation action requested by an API or supervisor."""
|
||||
|
||||
interrupt = "interrupt"
|
||||
rollback = "rollback"
|
||||
|
||||
|
||||
TERMINAL_RUN_STATUSES: frozenset[RunStatus] = frozenset(
|
||||
{
|
||||
RunStatus.success,
|
||||
RunStatus.error,
|
||||
RunStatus.timeout,
|
||||
RunStatus.interrupted,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def is_terminal_status(status: RunStatus) -> bool:
|
||||
return status in TERMINAL_RUN_STATUSES
|
||||
|
||||
|
||||
@dataclass(frozen=True, order=True)
|
||||
class EventSeq:
|
||||
"""Thread-local event sequence number."""
|
||||
|
||||
value: int
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.value < 0:
|
||||
raise ValueError("EventSeq must be non-negative")
|
||||
|
||||
def next(self) -> EventSeq:
|
||||
return EventSeq(self.value + 1)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CancelAction",
|
||||
"DisconnectMode",
|
||||
"EventSeq",
|
||||
"MultitaskStrategy",
|
||||
"RunScope",
|
||||
"RunStatus",
|
||||
"TERMINAL_RUN_STATUSES",
|
||||
"is_terminal_status",
|
||||
]
|
||||
@@ -1,12 +0,0 @@
|
||||
"""Execution contracts for run lifecycle orchestration."""
|
||||
|
||||
from .executor import RunExecutor
|
||||
from .scheduler import RunExecutionHandle, RunExecutionScheduler
|
||||
from .supervisor import RunSupervisor
|
||||
|
||||
__all__ = [
|
||||
"RunExecutionHandle",
|
||||
"RunExecutionScheduler",
|
||||
"RunExecutor",
|
||||
"RunSupervisor",
|
||||
]
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Run executor contract."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from ..domain import Run
|
||||
|
||||
|
||||
class RunExecutor(Protocol):
|
||||
"""Executes one run against the underlying agent or graph runtime."""
|
||||
|
||||
async def execute(self, run: Run) -> None:
|
||||
pass
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RunExecutor",
|
||||
]
|
||||
@@ -1,26 +0,0 @@
|
||||
"""Run execution scheduler contract."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
from ..domain import RunId
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RunExecutionHandle:
|
||||
run_id: RunId
|
||||
|
||||
|
||||
class RunExecutionScheduler(Protocol):
|
||||
"""Starts background execution for an accepted run."""
|
||||
|
||||
async def start(self, run_id: RunId) -> RunExecutionHandle:
|
||||
pass
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RunExecutionHandle",
|
||||
"RunExecutionScheduler",
|
||||
]
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Run execution supervision contract."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from ..domain import CancelAction, RunId
|
||||
|
||||
|
||||
class RunSupervisor(Protocol):
|
||||
"""Controls lifecycle operations for already scheduled runs."""
|
||||
|
||||
async def cancel(self, run_id: RunId, *, action: CancelAction = CancelAction.interrupt) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RunSupervisor",
|
||||
]
|
||||
@@ -1,9 +0,0 @@
|
||||
"""Repository contracts for the run runtime application layer."""
|
||||
|
||||
from .run_event_log import RunEventLog
|
||||
from .run_repository import RunRepository
|
||||
|
||||
__all__ = [
|
||||
"RunEventLog",
|
||||
"RunRepository",
|
||||
]
|
||||
@@ -1,42 +0,0 @@
|
||||
"""Durable run event log contract."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from ..domain import RunEvent, RunId, ThreadId
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..application.dto import RunMessageView, StoredRunEvent
|
||||
|
||||
|
||||
class RunEventLog(Protocol):
|
||||
"""Persistence boundary for run messages and execution trace events."""
|
||||
|
||||
async def append(self, events: list[RunEvent]) -> list[StoredRunEvent]:
|
||||
pass
|
||||
|
||||
async def list_messages_by_run(
|
||||
self,
|
||||
thread_id: ThreadId,
|
||||
run_id: RunId,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
) -> list[RunMessageView]:
|
||||
pass
|
||||
|
||||
async def list_events_by_run(
|
||||
self,
|
||||
thread_id: ThreadId,
|
||||
run_id: RunId,
|
||||
*,
|
||||
limit: int = 500,
|
||||
) -> list[StoredRunEvent]:
|
||||
pass
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RunEventLog",
|
||||
]
|
||||
@@ -1,37 +0,0 @@
|
||||
"""Run state repository contract."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from ..domain import Run, RunId, ThreadId, UserId
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..application.dto import RunSnapshot
|
||||
|
||||
|
||||
class RunRepository(Protocol):
|
||||
"""Persistence boundary for run state snapshots."""
|
||||
|
||||
async def save(self, run: Run) -> None:
|
||||
pass
|
||||
|
||||
async def get(self, run_id: RunId, *, user_id: UserId | None = None) -> Run | None:
|
||||
pass
|
||||
|
||||
async def list_by_thread(
|
||||
self,
|
||||
thread_id: ThreadId,
|
||||
*,
|
||||
user_id: UserId | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[RunSnapshot]:
|
||||
pass
|
||||
|
||||
async def delete(self, run_id: RunId) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RunRepository",
|
||||
]
|
||||
@@ -1,10 +1,21 @@
|
||||
"""Compatibility exports for run status and disconnect mode enums."""
|
||||
"""Run status and disconnect mode enums."""
|
||||
|
||||
# Existing callers import these enums from ``runs.schemas``. Re-export the
|
||||
# domain definitions until all imports move to ``runs.domain``.
|
||||
from .domain import DisconnectMode, RunStatus
|
||||
from enum import StrEnum
|
||||
|
||||
__all__ = [
|
||||
"DisconnectMode",
|
||||
"RunStatus",
|
||||
]
|
||||
|
||||
class RunStatus(StrEnum):
|
||||
"""Lifecycle status of a single run."""
|
||||
|
||||
pending = "pending"
|
||||
running = "running"
|
||||
success = "success"
|
||||
error = "error"
|
||||
timeout = "timeout"
|
||||
interrupted = "interrupted"
|
||||
|
||||
|
||||
class DisconnectMode(StrEnum):
|
||||
"""Behaviour when the SSE consumer disconnects."""
|
||||
|
||||
cancel = "cancel"
|
||||
continue_ = "continue"
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
"""Realtime stream contracts for run application use cases."""
|
||||
|
||||
from .run_stream_broker import RunStreamBroker, RunStreamEvent
|
||||
|
||||
__all__ = [
|
||||
"RunStreamBroker",
|
||||
"RunStreamEvent",
|
||||
]
|
||||
@@ -1,44 +0,0 @@
|
||||
"""Realtime run stream broker contract."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
|
||||
from ..domain import RunId
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RunStreamEvent:
|
||||
id: str
|
||||
event: str
|
||||
data: Any
|
||||
|
||||
|
||||
class RunStreamBroker(Protocol):
|
||||
"""Realtime publish/subscribe boundary for run streams."""
|
||||
|
||||
async def publish(self, run_id: RunId, event: str, data: Any) -> None:
|
||||
pass
|
||||
|
||||
async def publish_terminal(self, run_id: RunId, *, event: str = "end", data: Any = None) -> None:
|
||||
pass
|
||||
|
||||
def subscribe(
|
||||
self,
|
||||
run_id: RunId,
|
||||
*,
|
||||
last_event_id: str | None = None,
|
||||
heartbeat_interval: float = 15.0,
|
||||
) -> AsyncIterator[RunStreamEvent]:
|
||||
pass
|
||||
|
||||
async def cleanup(self, run_id: RunId, *, delay: float = 0) -> None:
|
||||
pass
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RunStreamBroker",
|
||||
"RunStreamEvent",
|
||||
]
|
||||
@@ -13,7 +13,6 @@ import stat
|
||||
import zipfile
|
||||
from pathlib import Path, PurePosixPath, PureWindowsPath
|
||||
|
||||
from deerflow.skills.permissions import make_skill_tree_sandbox_readable
|
||||
from deerflow.skills.security_scanner import scan_skill_content
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -140,7 +139,6 @@ def _move_staged_skill_into_reserved_target(staging_target: Path, target: Path)
|
||||
reserved = True
|
||||
for child in staging_target.iterdir():
|
||||
shutil.move(str(child), target / child.name)
|
||||
make_skill_tree_sandbox_readable(target)
|
||||
installed = True
|
||||
except FileExistsError as e:
|
||||
raise SkillAlreadyExistsError(f"Skill '{target.name}' already exists") from e
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
"""Filesystem permission helpers for installed skill trees."""
|
||||
|
||||
import stat
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def make_skill_path_sandbox_readable(path: Path) -> None:
|
||||
if path.is_symlink():
|
||||
return
|
||||
mode = stat.S_IMODE(path.stat().st_mode)
|
||||
without_sandbox_write = mode & ~(stat.S_IWGRP | stat.S_IWOTH)
|
||||
if path.is_dir():
|
||||
path.chmod(without_sandbox_write | 0o555)
|
||||
elif path.is_file():
|
||||
path.chmod(without_sandbox_write | 0o444)
|
||||
|
||||
|
||||
def make_skill_tree_sandbox_readable(target: Path) -> None:
|
||||
make_skill_path_sandbox_readable(target)
|
||||
for path in target.rglob("*"):
|
||||
make_skill_path_sandbox_readable(path)
|
||||
|
||||
|
||||
def make_skill_written_path_sandbox_readable(skill_root: Path, target: Path) -> None:
|
||||
resolved_root = skill_root.resolve()
|
||||
resolved_target = target.resolve()
|
||||
resolved_target.relative_to(resolved_root)
|
||||
|
||||
make_skill_path_sandbox_readable(resolved_root)
|
||||
current = resolved_root
|
||||
for part in resolved_target.parent.relative_to(resolved_root).parts:
|
||||
current = current / part
|
||||
make_skill_path_sandbox_readable(current)
|
||||
make_skill_path_sandbox_readable(resolved_target)
|
||||
@@ -13,7 +13,6 @@ from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.config.runtime_paths import resolve_path
|
||||
from deerflow.skills.permissions import make_skill_written_path_sandbox_readable
|
||||
from deerflow.skills.storage.skill_storage import SKILL_MD_FILE, SkillStorage
|
||||
from deerflow.skills.types import SkillCategory
|
||||
|
||||
@@ -91,7 +90,6 @@ class LocalSkillStorage(SkillStorage):
|
||||
tmp_file.write(content)
|
||||
tmp_path = Path(tmp_file.name)
|
||||
tmp_path.replace(target)
|
||||
make_skill_written_path_sandbox_readable(self.get_custom_skill_dir(name), target)
|
||||
|
||||
async def ainstall_skill_from_archive(self, archive_path: str | Path) -> dict:
|
||||
import zipfile
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
"""Regression anchor: JsonlRunEventStore async API must not block the loop.
|
||||
|
||||
``JsonlRunEventStore`` is the ``run_events.backend == "jsonl"`` implementation.
|
||||
Its ``async def`` methods perform synchronous filesystem IO (``Path.glob``,
|
||||
``read_text``, ``open``, ``unlink``) that must be offloaded with
|
||||
``asyncio.to_thread`` (fixed in #3084). ``put`` runs on every emitted run event,
|
||||
so any blocking IO here stalls the event loop on the hot path.
|
||||
|
||||
#3084 added a mock-based offload assertion in
|
||||
``tests/test_jsonl_event_store_async_io.py`` that covers ``put`` only. This
|
||||
anchor complements it by driving the **full** async surface (``put``,
|
||||
``put_batch``, ``list_messages``, ``list_events``, ``list_messages_by_run``,
|
||||
``count_messages``, ``delete_by_run``, ``delete_by_thread``) under the strict
|
||||
Blockbuster runtime gate, so any blocking IO reintroduced on the event loop in
|
||||
any of these methods — not just removal of a specific ``to_thread`` call —
|
||||
fails CI.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
async def test_jsonl_run_event_store_async_api_does_not_block_event_loop(tmp_path: Path) -> None:
|
||||
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
|
||||
|
||||
store = JsonlRunEventStore(base_dir=str(tmp_path))
|
||||
|
||||
# Seed an existing run file so put()'s seq-load globs + reads, and the
|
||||
# read/delete paths have files to scan. Test-side IO is invisible to the
|
||||
# gate (this module is not in scanned_modules).
|
||||
thread_dir = tmp_path / "threads" / "t1" / "runs"
|
||||
thread_dir.mkdir(parents=True, exist_ok=True)
|
||||
(thread_dir / "r0.jsonl").write_text('{"seq": 1, "category": "message", "run_id": "r0"}\n', encoding="utf-8")
|
||||
|
||||
# writes: put + put_batch
|
||||
record = await store.put(thread_id="t1", run_id="r1", event_type="message", category="message", content="hi")
|
||||
assert record["seq"] >= 2
|
||||
batch = await store.put_batch(
|
||||
[
|
||||
{"thread_id": "t1", "run_id": "r2", "event_type": "message", "category": "message", "content": "a"},
|
||||
{"thread_id": "t1", "run_id": "r2", "event_type": "trace", "category": "trace", "content": "b"},
|
||||
]
|
||||
)
|
||||
assert len(batch) == 2
|
||||
|
||||
# reads: list_messages / list_events / list_messages_by_run / count_messages.
|
||||
# list_events is exercised both without and with the event_types filter so
|
||||
# the filter branch runs after _read_run_events' filesystem IO.
|
||||
assert isinstance(await store.list_messages("t1"), list)
|
||||
assert isinstance(await store.list_events("t1", "r1"), list)
|
||||
assert isinstance(await store.list_events("t1", "r1", event_types=["message"]), list)
|
||||
assert isinstance(await store.list_messages_by_run("t1", "r2"), list)
|
||||
assert await store.count_messages("t1") >= 1
|
||||
|
||||
# deletes: delete_by_run (single file) then delete_by_thread (remaining)
|
||||
assert await store.delete_by_run("t1", "r2") >= 1
|
||||
assert await store.delete_by_thread("t1") >= 1
|
||||
@@ -1,56 +0,0 @@
|
||||
"""Regression anchor: UploadsMiddleware must not block the event loop.
|
||||
|
||||
``before_agent`` scans the thread uploads directory (``exists`` / ``iterdir`` /
|
||||
``stat`` plus reading sibling ``.md`` outlines). LangChain wires a sync-only
|
||||
``before_agent`` as ``RunnableCallable(before_agent, None)``; langgraph's
|
||||
``ainvoke`` runs it directly on the event loop when ``afunc is None``. So the
|
||||
filesystem scan must be offloaded (the middleware provides ``abefore_agent``).
|
||||
|
||||
This anchor drives the real ``create_agent`` graph via ``ainvoke`` under the
|
||||
strict Blockbuster gate. If the scan regresses back onto the event loop,
|
||||
Blockbuster raises ``BlockingError`` and this test fails.
|
||||
|
||||
The graph/middleware construction is offloaded with ``asyncio.to_thread`` only
|
||||
because ``Paths.__init__`` resolves paths synchronously; the surface under test
|
||||
(``before_agent``'s directory scan) is exercised on the event loop, not
|
||||
bypassed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
class _FakeModel(FakeMessagesListChatModel):
|
||||
"""FakeMessagesListChatModel with a no-op ``bind_tools`` for create_agent."""
|
||||
|
||||
def bind_tools(self, tools, **kwargs): # type: ignore[override]
|
||||
return self
|
||||
|
||||
|
||||
async def test_before_agent_uploads_scan_does_not_block_event_loop(tmp_path: Path) -> None:
|
||||
from langchain.agents import create_agent
|
||||
|
||||
from deerflow.agents.middlewares.uploads_middleware import UploadsMiddleware
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
mw = await asyncio.to_thread(UploadsMiddleware, str(tmp_path))
|
||||
uploads_dir = await asyncio.to_thread(mw._paths.sandbox_uploads_dir, "t1", user_id=get_effective_user_id())
|
||||
uploads_dir.mkdir(parents=True, exist_ok=True) # test-side seeding (not in scanned_modules)
|
||||
(uploads_dir / "existing.txt").write_text("hello", encoding="utf-8")
|
||||
|
||||
agent = await asyncio.to_thread(lambda: create_agent(model=_FakeModel(responses=[AIMessage(content="ok")]), tools=[], middleware=[mw]))
|
||||
|
||||
result = await agent.ainvoke(
|
||||
{"messages": [HumanMessage(content="hi")]},
|
||||
{"configurable": {"thread_id": "t1"}},
|
||||
)
|
||||
|
||||
assert result["messages"]
|
||||
@@ -0,0 +1,210 @@
|
||||
"""Tests for AioSandboxProvider auto-restart of crashed containers."""
|
||||
|
||||
import importlib
|
||||
import threading
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def _import_provider():
|
||||
return importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||
|
||||
|
||||
def _make_provider(*, auto_restart=True, alive=True):
|
||||
"""Build a minimal AioSandboxProvider with a mock backend.
|
||||
|
||||
Args:
|
||||
auto_restart: Value for the auto_restart config key.
|
||||
alive: Whether the mock backend reports containers as alive.
|
||||
"""
|
||||
mod = _import_provider()
|
||||
with patch.object(mod.AioSandboxProvider, "_start_idle_checker"):
|
||||
provider = mod.AioSandboxProvider.__new__(mod.AioSandboxProvider)
|
||||
provider._config = {"auto_restart": auto_restart}
|
||||
provider._lock = threading.Lock()
|
||||
provider._sandboxes = {}
|
||||
provider._sandbox_infos = {}
|
||||
provider._thread_sandboxes = {}
|
||||
provider._thread_locks = {}
|
||||
provider._last_activity = {}
|
||||
provider._warm_pool = {}
|
||||
provider._shutdown_called = False
|
||||
provider._idle_checker_stop = threading.Event()
|
||||
|
||||
backend = MagicMock()
|
||||
backend.is_alive.return_value = alive
|
||||
provider._backend = backend
|
||||
|
||||
return provider, backend
|
||||
|
||||
|
||||
def _seed_sandbox(provider, sandbox_id="dead-beef", thread_id="thread-1"):
|
||||
"""Insert a sandbox into the provider's caches as if it were acquired."""
|
||||
sandbox = MagicMock()
|
||||
info = MagicMock()
|
||||
|
||||
provider._sandboxes[sandbox_id] = sandbox
|
||||
provider._sandbox_infos[sandbox_id] = info
|
||||
provider._last_activity[sandbox_id] = 0.0
|
||||
if thread_id:
|
||||
provider._thread_sandboxes[thread_id] = sandbox_id
|
||||
|
||||
return sandbox, info
|
||||
|
||||
|
||||
# ── get() returns sandbox when container is alive ──────────────────────────
|
||||
|
||||
|
||||
def test_get_returns_sandbox_when_container_alive():
|
||||
"""When auto_restart is on and the container is alive, get() returns the sandbox."""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=True)
|
||||
sandbox, _ = _seed_sandbox(provider)
|
||||
|
||||
result = provider.get("dead-beef")
|
||||
|
||||
assert result is sandbox
|
||||
backend.is_alive.assert_called_once()
|
||||
|
||||
|
||||
def test_get_returns_sandbox_when_auto_restart_disabled():
|
||||
"""When auto_restart is off, get() skips the health check entirely."""
|
||||
provider, backend = _make_provider(auto_restart=False)
|
||||
sandbox, _ = _seed_sandbox(provider)
|
||||
|
||||
result = provider.get("dead-beef")
|
||||
|
||||
assert result is sandbox
|
||||
backend.is_alive.assert_not_called()
|
||||
|
||||
|
||||
# ── get() evicts dead sandbox when auto_restart is on ──────────────────────
|
||||
|
||||
|
||||
def test_get_evicts_dead_sandbox_when_auto_restart_enabled():
|
||||
"""When the container is dead and auto_restart is on, get() returns None and cleans caches."""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=False)
|
||||
_, info = _seed_sandbox(provider, sandbox_id="dead-beef", thread_id="thread-1")
|
||||
|
||||
result = provider.get("dead-beef")
|
||||
|
||||
assert result is None
|
||||
assert "dead-beef" not in provider._sandboxes
|
||||
assert "dead-beef" not in provider._sandbox_infos
|
||||
assert "dead-beef" not in provider._last_activity
|
||||
assert "thread-1" not in provider._thread_sandboxes
|
||||
backend.destroy.assert_called_once_with(info)
|
||||
|
||||
|
||||
def test_get_returns_dead_sandbox_when_auto_restart_disabled():
|
||||
"""When auto_restart is off, get() returns the cached sandbox even if the container is dead."""
|
||||
provider, backend = _make_provider(auto_restart=False, alive=False)
|
||||
sandbox, _ = _seed_sandbox(provider)
|
||||
|
||||
result = provider.get("dead-beef")
|
||||
|
||||
assert result is sandbox
|
||||
# Caches are untouched
|
||||
assert "dead-beef" in provider._sandboxes
|
||||
|
||||
|
||||
def test_get_eviction_cleans_multiple_thread_mappings():
|
||||
"""A sandbox mapped to multiple thread IDs has all mappings cleaned on eviction."""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=False)
|
||||
_seed_sandbox(provider, sandbox_id="sid-1", thread_id="t-a")
|
||||
# Manually add a second thread mapping to the same sandbox
|
||||
provider._thread_sandboxes["t-b"] = "sid-1"
|
||||
|
||||
result = provider.get("sid-1")
|
||||
|
||||
assert result is None
|
||||
assert "t-a" not in provider._thread_sandboxes
|
||||
assert "t-b" not in provider._thread_sandboxes
|
||||
|
||||
|
||||
# ── get() does not check health for unknown sandbox IDs ────────────────────
|
||||
|
||||
|
||||
def test_get_returns_none_for_unknown_id():
|
||||
"""If the sandbox_id is not in cache, get() returns None without checking health."""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=True)
|
||||
|
||||
result = provider.get("nonexistent")
|
||||
|
||||
assert result is None
|
||||
backend.is_alive.assert_not_called()
|
||||
|
||||
|
||||
# ── get() handles missing sandbox_info gracefully ──────────────────────────
|
||||
|
||||
|
||||
def test_get_handles_missing_info_gracefully():
|
||||
"""If sandbox is cached but info is missing, get() skips the health check."""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=False)
|
||||
sandbox = MagicMock()
|
||||
provider._sandboxes["sid-x"] = sandbox
|
||||
provider._sandbox_infos.pop("sid-x", None) # Ensure no info
|
||||
provider._last_activity["sid-x"] = 0.0
|
||||
|
||||
result = provider.get("sid-x")
|
||||
|
||||
# No info → cannot call is_alive → sandbox returned as-is
|
||||
assert result is sandbox
|
||||
backend.is_alive.assert_not_called()
|
||||
|
||||
|
||||
def test_get_liveness_check_runs_outside_provider_lock():
|
||||
"""get() should not hold the provider lock while checking backend liveness."""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=False)
|
||||
_seed_sandbox(provider, sandbox_id="sid-locked", thread_id="thread-1")
|
||||
|
||||
def _assert_lock_not_held(_):
|
||||
assert not provider._lock.locked()
|
||||
return False
|
||||
|
||||
backend.is_alive.side_effect = _assert_lock_not_held
|
||||
|
||||
assert provider.get("sid-locked") is None
|
||||
|
||||
|
||||
def test_get_still_evicts_when_backend_destroy_fails():
|
||||
"""Cleanup errors should not keep stale sandbox state in memory."""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=False)
|
||||
_seed_sandbox(provider, sandbox_id="sid-fail", thread_id="thread-1")
|
||||
backend.destroy.side_effect = RuntimeError("boom")
|
||||
|
||||
assert provider.get("sid-fail") is None
|
||||
assert "sid-fail" not in provider._sandboxes
|
||||
assert "sid-fail" not in provider._sandbox_infos
|
||||
assert "thread-1" not in provider._thread_sandboxes
|
||||
backend.destroy.assert_called_once()
|
||||
|
||||
|
||||
# ── Integration: eviction clears caches for recreation ─────────────────────
|
||||
|
||||
|
||||
def test_eviction_clears_all_caches_for_recreation():
|
||||
"""After eviction, all caches are clean so _acquire_internal can recreate.
|
||||
|
||||
This verifies the preconditions for transparent restart: when get() evicts
|
||||
a dead sandbox, the next _acquire_internal call will find no cached entry,
|
||||
no warm-pool entry, and fall through to _create_sandbox.
|
||||
"""
|
||||
provider, backend = _make_provider(auto_restart=True, alive=False)
|
||||
_seed_sandbox(provider, sandbox_id="sid-1", thread_id="thread-1")
|
||||
|
||||
# Before eviction: caches populated
|
||||
assert "sid-1" in provider._sandboxes
|
||||
assert "sid-1" in provider._sandbox_infos
|
||||
assert "thread-1" in provider._thread_sandboxes
|
||||
|
||||
# get() detects the dead container and evicts
|
||||
assert provider.get("sid-1") is None
|
||||
|
||||
# After eviction: all caches clean
|
||||
assert "sid-1" not in provider._sandboxes
|
||||
assert "sid-1" not in provider._sandbox_infos
|
||||
assert "thread-1" not in provider._thread_sandboxes
|
||||
assert "sid-1" not in provider._warm_pool
|
||||
|
||||
# _acquire_internal for the same thread would find nothing cached
|
||||
# and generate the deterministic ID, then discover fails (container
|
||||
# is gone), falling through to _create_sandbox — a fresh start.
|
||||
@@ -1,166 +0,0 @@
|
||||
"""Tests for shared assistant payload replay helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from deerflow.models.assistant_payload_replay import (
|
||||
restore_additional_kwargs_field,
|
||||
restore_assistant_payloads,
|
||||
restore_reasoning_content,
|
||||
)
|
||||
|
||||
|
||||
def _restore_reasoning(payload_msg: dict, orig_msg: AIMessage) -> None:
|
||||
restore_additional_kwargs_field(payload_msg, orig_msg, "reasoning_content")
|
||||
|
||||
|
||||
def test_restore_additional_kwargs_field_copies_present_values_only():
|
||||
payload_message = {"role": "assistant"}
|
||||
orig_message = AIMessage(
|
||||
content="answer",
|
||||
additional_kwargs={
|
||||
"reasoning_content": "",
|
||||
"ignored_none": None,
|
||||
},
|
||||
)
|
||||
|
||||
restore_additional_kwargs_field(payload_message, orig_message, "reasoning_content")
|
||||
restore_additional_kwargs_field(payload_message, orig_message, "ignored_none")
|
||||
restore_additional_kwargs_field(payload_message, orig_message, "missing")
|
||||
|
||||
assert payload_message == {"role": "assistant", "reasoning_content": ""}
|
||||
|
||||
|
||||
def test_restore_reasoning_content_copies_reasoning_content():
|
||||
payload_message = {"role": "assistant"}
|
||||
orig_message = AIMessage(content="answer", additional_kwargs={"reasoning_content": "thought"})
|
||||
|
||||
restore_reasoning_content(payload_message, orig_message)
|
||||
|
||||
assert payload_message["reasoning_content"] == "thought"
|
||||
|
||||
|
||||
def test_restore_assistant_payloads_matches_by_position_when_lengths_match():
|
||||
original_messages = [
|
||||
HumanMessage(content="question"),
|
||||
AIMessage(content="answer", additional_kwargs={"reasoning_content": "thought"}),
|
||||
]
|
||||
payload_messages = [
|
||||
{"role": "user", "content": "question"},
|
||||
{"role": "assistant", "content": "answer"},
|
||||
]
|
||||
|
||||
restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
|
||||
|
||||
assert payload_messages[1]["reasoning_content"] == "thought"
|
||||
|
||||
|
||||
def test_restore_assistant_payloads_fallback_matches_unique_content_signature():
|
||||
original_messages = [
|
||||
AIMessage(content="first", additional_kwargs={"reasoning_content": "first-thought"}),
|
||||
AIMessage(content="second", additional_kwargs={"reasoning_content": "second-thought"}),
|
||||
]
|
||||
payload_messages = [{"role": "assistant", "content": "second"}]
|
||||
|
||||
restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
|
||||
|
||||
assert payload_messages[0]["reasoning_content"] == "second-thought"
|
||||
|
||||
|
||||
def test_restore_assistant_payloads_fallback_matches_unique_tool_call_signature():
|
||||
original_messages = [
|
||||
AIMessage(
|
||||
content="",
|
||||
additional_kwargs={"reasoning_content": "first-thought"},
|
||||
tool_calls=[{"id": "call_first", "name": "tool", "args": {}}],
|
||||
),
|
||||
AIMessage(
|
||||
content="",
|
||||
additional_kwargs={"reasoning_content": "second-thought"},
|
||||
tool_calls=[{"id": "call_second", "name": "tool", "args": {}}],
|
||||
),
|
||||
]
|
||||
payload_messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [{"id": "call_second", "type": "function", "function": {"name": "tool", "arguments": "{}"}}],
|
||||
}
|
||||
]
|
||||
|
||||
restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
|
||||
|
||||
assert payload_messages[0]["reasoning_content"] == "second-thought"
|
||||
|
||||
|
||||
def test_restore_assistant_payloads_fallback_matches_structured_content_signature():
|
||||
original_messages = [
|
||||
AIMessage(
|
||||
content=[{"type": "text", "text": "first"}],
|
||||
additional_kwargs={"reasoning_content": "first-thought"},
|
||||
),
|
||||
AIMessage(
|
||||
content=[{"type": "text", "text": "second"}],
|
||||
additional_kwargs={"reasoning_content": "second-thought"},
|
||||
),
|
||||
]
|
||||
payload_messages = [{"role": "assistant", "content": [{"text": "second", "type": "text"}]}]
|
||||
|
||||
restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
|
||||
|
||||
assert payload_messages[0]["reasoning_content"] == "second-thought"
|
||||
|
||||
|
||||
def test_restore_assistant_payloads_fallback_uses_order_when_signature_is_ambiguous():
|
||||
original_messages = [
|
||||
AIMessage(content="", additional_kwargs={"reasoning_content": "first-thought"}),
|
||||
AIMessage(content="", additional_kwargs={"reasoning_content": "second-thought"}),
|
||||
]
|
||||
payload_messages = [{"role": "assistant", "content": ""}]
|
||||
|
||||
restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
|
||||
|
||||
assert payload_messages[0]["reasoning_content"] == "first-thought"
|
||||
|
||||
|
||||
def test_restore_assistant_payloads_fallback_uses_next_unused_when_ordinal_taken():
|
||||
# Serialization dropped a leading empty assistant message, so payload ordinals
|
||||
# no longer line up with the original AIMessage indices. The first payload
|
||||
# uniquely matches a non-ordinal index by signature, which leaves the later
|
||||
# ambiguous payload's exact ordinal index already used. It must still fall
|
||||
# back to the remaining unused AIMessage (scanning forward from the ordinal)
|
||||
# instead of silently dropping the field.
|
||||
original_messages = [
|
||||
AIMessage(content="", additional_kwargs={"reasoning_content": "dropped-thought"}),
|
||||
AIMessage(content="unique", additional_kwargs={"reasoning_content": "unique-thought"}),
|
||||
AIMessage(content="", additional_kwargs={"reasoning_content": "trailing-thought"}),
|
||||
]
|
||||
payload_messages = [
|
||||
{"role": "assistant", "content": "unique"},
|
||||
{"role": "assistant", "content": ""},
|
||||
]
|
||||
|
||||
restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
|
||||
|
||||
assert payload_messages[0]["reasoning_content"] == "unique-thought"
|
||||
# Forward scan from the taken ordinal picks the trailing message, not the
|
||||
# dropped leading one (which a naive min-unused scan would wrongly select).
|
||||
assert payload_messages[1]["reasoning_content"] == "trailing-thought"
|
||||
|
||||
|
||||
def test_restore_assistant_payloads_does_not_wrap_to_earlier_unused_message():
|
||||
original_messages = [
|
||||
HumanMessage(content="leading user"),
|
||||
AIMessage(content="", additional_kwargs={"reasoning_content": "dropped-leading-thought"}),
|
||||
AIMessage(content="unique", additional_kwargs={"reasoning_content": "unique-thought"}),
|
||||
]
|
||||
payload_messages = [
|
||||
{"role": "assistant", "content": "unique"},
|
||||
{"role": "assistant", "content": ""},
|
||||
]
|
||||
|
||||
restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
|
||||
|
||||
assert payload_messages[0]["reasoning_content"] == "unique-thought"
|
||||
assert "reasoning_content" not in payload_messages[1]
|
||||
@@ -372,25 +372,6 @@ class TestExtractResponseText:
|
||||
# Should return "" (no text in current turn), NOT "Hi there!" from previous turn
|
||||
assert _extract_response_text(result) == ""
|
||||
|
||||
def test_ignores_hidden_human_control_messages(self):
|
||||
"""Hidden control messages should not terminate current-turn response extraction."""
|
||||
from app.channels.manager import _extract_response_text
|
||||
|
||||
result = {
|
||||
"messages": [
|
||||
{"type": "human", "content": "plan this"},
|
||||
{"type": "ai", "content": "Here is the plan."},
|
||||
{
|
||||
"type": "human",
|
||||
"name": "todo_reminder",
|
||||
"content": "keep todos updated",
|
||||
"additional_kwargs": {"hide_from_ui": True},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
assert _extract_response_text(result) == "Here is the plan."
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChannelManager tests
|
||||
@@ -1697,31 +1678,6 @@ class TestExtractArtifacts:
|
||||
}
|
||||
assert _extract_artifacts(result) == ["/mnt/user-data/outputs/a.txt", "/mnt/user-data/outputs/b.csv"]
|
||||
|
||||
def test_ignores_hidden_human_control_messages(self):
|
||||
"""Hidden control messages should not hide current-turn present_files artifacts."""
|
||||
from app.channels.manager import _extract_artifacts
|
||||
|
||||
result = {
|
||||
"messages": [
|
||||
{"type": "human", "content": "export"},
|
||||
{
|
||||
"type": "ai",
|
||||
"content": "Done.",
|
||||
"tool_calls": [
|
||||
{"name": "present_files", "args": {"filepaths": ["/mnt/user-data/outputs/plan.md"]}},
|
||||
],
|
||||
},
|
||||
{
|
||||
"type": "human",
|
||||
"name": "todo_completion_reminder",
|
||||
"content": "mark tasks complete",
|
||||
"additional_kwargs": {"hide_from_ui": True},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
assert _extract_artifacts(result) == ["/mnt/user-data/outputs/plan.md"]
|
||||
|
||||
|
||||
class TestFormatArtifactText:
|
||||
def test_single_artifact(self):
|
||||
@@ -1834,50 +1790,6 @@ class TestHandleChatWithArtifacts:
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_hidden_human_control_message_does_not_trigger_no_response_fallback(self):
|
||||
"""Plan-mode hidden control messages should not mask the final AI response."""
|
||||
from app.channels.manager import ChannelManager
|
||||
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||
manager = ChannelManager(bus=bus, store=store)
|
||||
|
||||
run_result = {
|
||||
"messages": [
|
||||
{"type": "human", "content": "make a plan"},
|
||||
{"type": "ai", "content": "Here is a concrete plan."},
|
||||
{
|
||||
"type": "human",
|
||||
"name": "todo_reminder",
|
||||
"content": "sync todos",
|
||||
"additional_kwargs": {"hide_from_ui": True},
|
||||
},
|
||||
]
|
||||
}
|
||||
mock_client = _make_mock_langgraph_client(run_result=run_result)
|
||||
manager._client = mock_client
|
||||
|
||||
outbound_received = []
|
||||
bus.subscribe_outbound(lambda msg: outbound_received.append(msg))
|
||||
await manager.start()
|
||||
|
||||
await bus.publish_inbound(
|
||||
InboundMessage(
|
||||
channel_name="test",
|
||||
chat_id="c1",
|
||||
user_id="u1",
|
||||
text="make a plan",
|
||||
)
|
||||
)
|
||||
await _wait_for(lambda: len(outbound_received) >= 1)
|
||||
await manager.stop()
|
||||
|
||||
assert len(outbound_received) == 1
|
||||
assert outbound_received[0].text == "Here is a concrete plan."
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_only_last_turn_artifacts_returned(self):
|
||||
"""Only artifacts from the current turn's present_files calls should be included."""
|
||||
from app.channels.manager import ChannelManager
|
||||
|
||||
@@ -333,27 +333,8 @@ class TestBuildPatchedMessagesPatching:
|
||||
assert patched[1].tool_call_id == "write_file:36"
|
||||
assert patched[1].name == "write_file"
|
||||
assert patched[1].status == "error"
|
||||
assert "write_file failed before execution" in patched[1].content
|
||||
assert "no file was written" in patched[1].content
|
||||
assert "very large Markdown file in a single tool call" in patched[1].content
|
||||
assert "Do not retry the same large `write_file` payload" in patched[1].content
|
||||
assert "split the file into smaller sections" in patched[1].content
|
||||
assert "normal assistant text" in patched[1].content
|
||||
assert "Failed to parse tool arguments" in patched[1].content
|
||||
assert 'bad {"json"}' not in patched[1].content
|
||||
|
||||
def test_non_write_file_invalid_tool_call_uses_generic_recovery_message(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [_ai_with_invalid_tool_calls([_invalid_tc(name="search", tc_id="search:1")])]
|
||||
|
||||
patched = mw._build_patched_messages(msgs)
|
||||
|
||||
assert patched is not None
|
||||
assert patched[1].tool_call_id == "search:1"
|
||||
assert patched[1].name == "search"
|
||||
assert "arguments were invalid" in patched[1].content
|
||||
assert "Failed to parse tool arguments" in patched[1].content
|
||||
assert "write_file failed before execution" not in patched[1].content
|
||||
|
||||
def test_valid_and_invalid_tool_calls_are_both_patched(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
|
||||
@@ -83,24 +83,3 @@ def test_frontend_rewrites_langgraph_prefix_to_gateway():
|
||||
assert "DEER_FLOW_INTERNAL_LANGGRAPH_BASE_URL" not in next_config
|
||||
assert "http://127.0.0.1:2024" not in next_config
|
||||
assert "langgraph-compat" not in api_client
|
||||
|
||||
|
||||
def test_smoke_test_docs_do_not_expect_standalone_langgraph_server():
|
||||
smoke_files = {
|
||||
".agent/skills/smoke-test/SKILL.md": _read(".agent/skills/smoke-test/SKILL.md"),
|
||||
".agent/skills/smoke-test/references/SOP.md": _read(".agent/skills/smoke-test/references/SOP.md"),
|
||||
".agent/skills/smoke-test/references/troubleshooting.md": _read(".agent/skills/smoke-test/references/troubleshooting.md"),
|
||||
".agent/skills/smoke-test/scripts/check_local_env.sh": _read(".agent/skills/smoke-test/scripts/check_local_env.sh"),
|
||||
".agent/skills/smoke-test/scripts/deploy_local.sh": _read(".agent/skills/smoke-test/scripts/deploy_local.sh"),
|
||||
".agent/skills/smoke-test/scripts/health_check.sh": _read(".agent/skills/smoke-test/scripts/health_check.sh"),
|
||||
".agent/skills/smoke-test/templates/report.local.template.md": _read(".agent/skills/smoke-test/templates/report.local.template.md"),
|
||||
".agent/skills/smoke-test/templates/report.docker.template.md": _read(".agent/skills/smoke-test/templates/report.docker.template.md"),
|
||||
}
|
||||
|
||||
for path, content in smoke_files.items():
|
||||
assert "localhost:2024" not in content, path
|
||||
assert "127.0.0.1:2024" not in content, path
|
||||
assert "deer-flow-langgraph" not in content, path
|
||||
assert "langgraph.log" not in content, path
|
||||
assert "LangGraph service" not in content, path
|
||||
assert "langgraph dev" not in content, path
|
||||
|
||||
@@ -1,223 +0,0 @@
|
||||
"""Concurrency-safety tests for JsonlRunEventStore async I/O hardening (#2816).
|
||||
|
||||
Verifies:
|
||||
- write-lock serialises concurrent puts within the same thread_id
|
||||
- put_batch keeps monotonic seq even under concurrent callers
|
||||
- seq recovery from disk on fresh store init
|
||||
- DB put_batch rejects mixed-thread batches
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_store(base_dir: Path) -> JsonlRunEventStore:
|
||||
return JsonlRunEventStore(base_dir=base_dir)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Write-lock: per-thread lock exists and is reused
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_write_lock_returns_asyncio_lock():
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
store = _make_store(Path(tmp))
|
||||
lock = store._get_write_lock("t1")
|
||||
assert isinstance(lock, asyncio.Lock)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_write_lock_same_thread_reuses_lock():
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
store = _make_store(Path(tmp))
|
||||
lock_a = store._get_write_lock("t1")
|
||||
lock_b = store._get_write_lock("t1")
|
||||
assert lock_a is lock_b
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_write_lock_different_threads_get_different_locks():
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
store = _make_store(Path(tmp))
|
||||
lock_a = store._get_write_lock("t1")
|
||||
lock_b = store._get_write_lock("t2")
|
||||
assert lock_a is not lock_b
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Seq monotonicity under concurrent puts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_concurrent_puts_produce_unique_monotonic_seqs():
|
||||
"""10 concurrent puts on the same thread must yield distinct, monotonic seq values."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
store = _make_store(Path(tmp))
|
||||
results = await asyncio.gather(*[store.put(thread_id="t1", run_id=f"r{i}", event_type="trace", category="trace", content=f"msg{i}") for i in range(10)])
|
||||
seqs = sorted(r["seq"] for r in results)
|
||||
assert seqs == list(range(1, 11)), f"Expected 1-10, got {seqs}"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_concurrent_puts_different_threads_independent_seqs():
|
||||
"""Concurrent puts on different threads keep independent seq counters."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
store = _make_store(Path(tmp))
|
||||
t1_results, t2_results = await asyncio.gather(
|
||||
asyncio.gather(*[store.put(thread_id="t1", run_id="r1", event_type="trace", category="trace") for _ in range(5)]),
|
||||
asyncio.gather(*[store.put(thread_id="t2", run_id="r2", event_type="trace", category="trace") for _ in range(5)]),
|
||||
)
|
||||
t1_seqs = sorted(r["seq"] for r in t1_results)
|
||||
t2_seqs = sorted(r["seq"] for r in t2_results)
|
||||
assert t1_seqs == [1, 2, 3, 4, 5]
|
||||
assert t2_seqs == [1, 2, 3, 4, 5]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# put_batch: delegates to put() and preserves order
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_put_batch_seqs_are_monotonic():
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
store = _make_store(Path(tmp))
|
||||
events = [{"thread_id": "t1", "run_id": "r1", "event_type": "trace", "category": "trace", "content": str(i)} for i in range(5)]
|
||||
results = await store.put_batch(events)
|
||||
seqs = [r["seq"] for r in results]
|
||||
assert seqs == sorted(seqs)
|
||||
assert len(set(seqs)) == 5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _ensure_seq_loaded: recovers max_seq from disk after fresh store init
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_ensure_seq_loaded_recovers_from_disk():
|
||||
"""A fresh JsonlRunEventStore should pick up the max seq written by a previous instance."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
base = Path(tmp)
|
||||
store1 = _make_store(base)
|
||||
for i in range(3):
|
||||
await store1.put(thread_id="t1", run_id="r1", event_type="trace", category="trace", content=str(i))
|
||||
|
||||
store2 = _make_store(base)
|
||||
record = await store2.put(thread_id="t1", run_id="r1", event_type="trace", category="trace", content="new")
|
||||
assert record["seq"] == 4, f"Expected seq=4 after recovery, got {record['seq']}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# asyncio.to_thread regression guard
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_put_offloads_write_via_to_thread():
|
||||
"""Regression guard: put() must call asyncio.to_thread for _write_record."""
|
||||
original = asyncio.to_thread
|
||||
calls: list[str] = []
|
||||
|
||||
async def spy(*args, **kwargs):
|
||||
calls.append(args[0].__name__ if callable(args[0]) else repr(args[0]))
|
||||
return await original(*args, **kwargs)
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
store = _make_store(Path(tmp))
|
||||
with patch("asyncio.to_thread", new=spy):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="trace", category="trace", content="x")
|
||||
|
||||
assert "_write_record" in calls, f"Expected asyncio.to_thread(_write_record, ...) — got: {calls}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Read methods are non-blocking (asyncio.to_thread path exercised)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_messages_reads_written_records():
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
store = _make_store(Path(tmp))
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hello")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="world")
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 2
|
||||
assert messages[0]["content"] == "hello"
|
||||
assert messages[1]["content"] == "world"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_count_messages_accurate_after_concurrent_writes():
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
store = _make_store(Path(tmp))
|
||||
await asyncio.gather(*[store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") for _ in range(7)])
|
||||
count = await store.count_messages("t1")
|
||||
assert count == 7
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# delete_by_thread and delete_by_run use the write lock
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_by_thread_clears_seq_counter_and_lock():
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
store = _make_store(Path(tmp))
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="trace", category="trace")
|
||||
await store.delete_by_thread("t1")
|
||||
assert "t1" not in store._seq_counters
|
||||
assert "t1" not in store._write_locks
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_by_run_removes_run_events():
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
store = _make_store(Path(tmp))
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="trace", category="trace")
|
||||
await store.put(thread_id="t1", run_id="r2", event_type="trace", category="trace")
|
||||
await store.delete_by_run("t1", "r1")
|
||||
events = await store.list_events("t1", "r1")
|
||||
assert events == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DB put_batch: rejects mixed-thread batches
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_db_put_batch_rejects_mixed_thread_ids():
|
||||
"""DbRunEventStore.put_batch must raise ValueError for cross-thread batches."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
mock_sf = MagicMock()
|
||||
store = DbRunEventStore(session_factory=mock_sf)
|
||||
|
||||
events = [
|
||||
{"thread_id": "t1", "run_id": "r1", "event_type": "trace", "category": "trace"},
|
||||
{"thread_id": "t2", "run_id": "r2", "event_type": "trace", "category": "trace"},
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="same thread"):
|
||||
await store.put_batch(events)
|
||||
@@ -476,24 +476,6 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
|
||||
fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])
|
||||
|
||||
|
||||
def test_create_summarization_middleware_uses_frontend_supported_update_key(monkeypatch):
|
||||
"""LangGraph update keys use the middleware class name plus hook name."""
|
||||
|
||||
app_config = _make_app_config([_make_model("safe-model", supports_thinking=False)])
|
||||
app_config.summarization = SummarizationConfig(enabled=True)
|
||||
app_config.memory = MemoryConfig(enabled=False)
|
||||
|
||||
fake_model = MagicMock()
|
||||
fake_model.with_config.return_value = fake_model
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
middleware = lead_agent_module._create_summarization_middleware(app_config=app_config)
|
||||
|
||||
assert middleware is not None
|
||||
update_key = f"{type(middleware).__name__}.before_model"
|
||||
assert update_key == "DeerFlowSummarizationMiddleware.before_model"
|
||||
|
||||
|
||||
def test_create_summarization_middleware_threads_resolved_app_config_to_model(monkeypatch):
|
||||
fallback_app_config = _make_app_config([_make_model("fallback-model", supports_thinking=False)])
|
||||
fallback_app_config.summarization = SummarizationConfig(enabled=True, model_name="fallback-model")
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import stat
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -44,20 +43,6 @@ def test_write_is_atomic_overwrite(tmp_path, storage):
|
||||
assert (tmp_path / "custom" / "demo-skill" / "SKILL.md").read_text() == "second"
|
||||
|
||||
|
||||
def test_write_makes_written_path_sandbox_readable(tmp_path, storage):
|
||||
skill_dir = tmp_path / "custom" / "demo-skill"
|
||||
skill_dir.mkdir(parents=True)
|
||||
skill_dir.chmod(0o700)
|
||||
|
||||
storage.write_custom_skill("demo-skill", "references/ref.md", "# ref")
|
||||
|
||||
ref_dir = skill_dir / "references"
|
||||
ref_file = ref_dir / "ref.md"
|
||||
assert stat.S_IMODE(skill_dir.stat().st_mode) & 0o055 == 0o055
|
||||
assert stat.S_IMODE(ref_dir.stat().st_mode) & 0o055 == 0o055
|
||||
assert stat.S_IMODE(ref_file.stat().st_mode) & 0o044 == 0o044
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Empty / blank path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -407,80 +407,3 @@ def test_session_pool_tool_sync_wrapper_path_is_safe():
|
||||
wrapped.func(url="https://example.com")
|
||||
|
||||
mock_session.call_tool.assert_called_once_with("navigate", {"url": "https://example.com"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_mcp_tools: HTTP transport should NOT be pooled
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_transport_tools_not_pooled():
|
||||
"""HTTP/SSE transport tools should NOT be wrapped with the session pool."""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.mcp.tools import get_mcp_tools
|
||||
|
||||
class Args(BaseModel):
|
||||
query: str = Field(..., description="query")
|
||||
|
||||
http_tool = StructuredTool(
|
||||
name="myserver_search",
|
||||
description="Search tool",
|
||||
args_schema=Args,
|
||||
coroutine=AsyncMock(),
|
||||
response_format="content_and_artifact",
|
||||
)
|
||||
|
||||
stdio_tool = StructuredTool(
|
||||
name="playwright_navigate",
|
||||
description="Navigate browser",
|
||||
args_schema=Args,
|
||||
coroutine=AsyncMock(),
|
||||
response_format="content_and_artifact",
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
extensions_config = MagicMock()
|
||||
extensions_config.get_enabled_mcp_servers.return_value = {
|
||||
"myserver": MagicMock(type="http", url="http://localhost:8000/mcp", headers=None, command=None, args=[], env=None),
|
||||
"playwright": MagicMock(type="stdio", command="npx", args=["-y", "@anthropic/mcp-server-playwright"], env=None, url=None, headers=None),
|
||||
}
|
||||
extensions_config.model_extra = {}
|
||||
|
||||
servers_config = {
|
||||
"myserver": {"transport": "http", "url": "http://localhost:8000/mcp"},
|
||||
"playwright": {"transport": "stdio", "command": "npx", "args": ["-y", "@anthropic/mcp-server-playwright"]},
|
||||
}
|
||||
|
||||
with (
|
||||
patch("deerflow.mcp.tools.ExtensionsConfig.from_file", return_value=extensions_config),
|
||||
patch("deerflow.mcp.tools.build_servers_config", return_value=servers_config),
|
||||
patch("deerflow.mcp.tools.get_initial_oauth_headers", return_value={}),
|
||||
patch("deerflow.mcp.tools.build_oauth_tool_interceptor", return_value=None),
|
||||
patch("langchain_mcp_adapters.client.MultiServerMCPClient") as MockClient,
|
||||
patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm),
|
||||
):
|
||||
mock_client_instance = MockClient.return_value
|
||||
mock_client_instance.get_tools = AsyncMock(return_value=[http_tool, stdio_tool])
|
||||
|
||||
tools = await get_mcp_tools()
|
||||
|
||||
pool = get_session_pool()
|
||||
# Tool discovery is lazy: no pooled sessions are created until a wrapped tool is invoked.
|
||||
assert list(pool._entries.keys()) == []
|
||||
|
||||
# Verify the HTTP tool was NOT wrapped with the pool (it's the original tool).
|
||||
http_tools = [t for t in tools if t.name == "myserver_search"]
|
||||
assert len(http_tools) == 1
|
||||
assert http_tools[0].coroutine is http_tool.coroutine
|
||||
|
||||
# Verify the stdio tool WAS wrapped with the pool.
|
||||
stdio_tools = [t for t in tools if t.name == "playwright_navigate"]
|
||||
assert len(stdio_tools) == 1
|
||||
assert stdio_tools[0].coroutine is not stdio_tool.coroutine
|
||||
|
||||
@@ -563,28 +563,6 @@ class TestUpdateMemoryStructuredResponse:
|
||||
model.invoke = MagicMock(return_value=response)
|
||||
return model
|
||||
|
||||
def _run_update_with_response(self, content):
|
||||
updater = MemoryUpdater()
|
||||
mock_storage = MagicMock()
|
||||
mock_storage.save = MagicMock(return_value=True)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=self._make_mock_model(content)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True, fact_confidence_threshold=0.7, max_facts=100)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage),
|
||||
):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = "Remember that I prefer concise updates."
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.type = "ai"
|
||||
ai_msg.content = "Got it."
|
||||
ai_msg.tool_calls = []
|
||||
result = updater.update_memory([msg, ai_msg], thread_id="thread-memory")
|
||||
|
||||
return result, mock_storage
|
||||
|
||||
def test_string_response_parses(self):
|
||||
updater = MemoryUpdater()
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
@@ -631,82 +609,6 @@ class TestUpdateMemoryStructuredResponse:
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_wrapped_json_responses_parse(self):
|
||||
"""Memory update should tolerate provider wrappers around valid JSON."""
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [{"content": "User prefers concise updates", "category": "preference", "confidence": 0.9}], "factsToRemove": []}'
|
||||
response_variants = [
|
||||
f"<think>Analyze the conversation first.</think>\n{valid_json}",
|
||||
f"<think>Analyze the conversation first.\n{valid_json}",
|
||||
f"Here is the memory update:\n{valid_json}",
|
||||
f"{valid_json}\nDone.",
|
||||
f"```json\n{valid_json}\n```",
|
||||
]
|
||||
|
||||
for content in response_variants:
|
||||
result, mock_storage = self._run_update_with_response(content)
|
||||
|
||||
assert result is True
|
||||
saved_memory = mock_storage.save.call_args.args[0]
|
||||
assert saved_memory["facts"][0]["content"] == "User prefers concise updates"
|
||||
|
||||
def test_ignores_unrelated_json_before_memory_update(self):
|
||||
"""Parser should not select unrelated JSON objects before the memory update."""
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [{"content": "Remember the actual update", "category": "context", "confidence": 0.9}], "factsToRemove": []}'
|
||||
response = f'Example object: {{"user": "alice"}}\nActual memory update:\n{valid_json}'
|
||||
|
||||
result, mock_storage = self._run_update_with_response(response)
|
||||
|
||||
assert result is True
|
||||
saved_memory = mock_storage.save.call_args.args[0]
|
||||
assert saved_memory["facts"][0]["content"] == "Remember the actual update"
|
||||
|
||||
def test_invalid_json_response_is_skipped_without_saving(self):
|
||||
"""Truncated JSON should remain a safe skipped update, not guessed repair."""
|
||||
result, mock_storage = self._run_update_with_response('{"user": {}, "history": {}, "newFacts": [')
|
||||
|
||||
assert result is False
|
||||
mock_storage.save.assert_not_called()
|
||||
|
||||
def test_schema_guard_ignores_invalid_update_fields(self):
|
||||
"""Parsed JSON with bad field types should not break the memory update."""
|
||||
response = '{"user": "bad", "history": [], "newFacts": ["bad", {"content": "User works on DeerFlow", "category": "context", "confidence": 0.91}], "factsToRemove": "bad"}'
|
||||
|
||||
result, mock_storage = self._run_update_with_response(response)
|
||||
|
||||
assert result is True
|
||||
saved_memory = mock_storage.save.call_args.args[0]
|
||||
assert [fact["content"] for fact in saved_memory["facts"]] == ["User works on DeerFlow"]
|
||||
|
||||
def test_fact_schema_guard_coerces_and_filters_nested_fields(self):
|
||||
"""Malformed fact entries should be normalized per fact, not fail the whole update."""
|
||||
response = (
|
||||
'{"user": {}, "history": {}, "newFacts": ['
|
||||
'{"content": " User likes async updates ", "category": 9, "confidence": "0.91", "sourceError": " parse issue "}, '
|
||||
'{"content": "skip invalid confidence", "category": "context", "confidence": "high"}, '
|
||||
'{"content": 12, "category": "context", "confidence": 0.9}, '
|
||||
'{"content": " ", "category": "context", "confidence": 0.9}'
|
||||
'], "factsToRemove": []}'
|
||||
)
|
||||
|
||||
result, mock_storage = self._run_update_with_response(response)
|
||||
|
||||
assert result is True
|
||||
saved_memory = mock_storage.save.call_args.args[0]
|
||||
assert len(saved_memory["facts"]) == 1
|
||||
assert saved_memory["facts"][0]["content"] == "User likes async updates"
|
||||
assert saved_memory["facts"][0]["category"] == "context"
|
||||
assert saved_memory["facts"][0]["confidence"] == 0.91
|
||||
assert saved_memory["facts"][0]["sourceError"] == "parse issue"
|
||||
|
||||
def test_malformed_replacement_update_fails_closed(self):
|
||||
"""Malformed replacement facts should not turn remove+add into delete-only."""
|
||||
response = '{"user": {}, "history": {}, "newFacts": [{"content": "replacement fact", "category": "context", "confidence": "bad"}], "factsToRemove": ["fact_old"]}'
|
||||
|
||||
result, mock_storage = self._run_update_with_response(response)
|
||||
|
||||
assert result is False
|
||||
mock_storage.save.assert_not_called()
|
||||
|
||||
def test_async_update_memory_delegates_to_sync(self):
|
||||
"""aupdate_memory should delegate to sync _do_update_memory_sync via to_thread."""
|
||||
updater = MemoryUpdater()
|
||||
|
||||
@@ -995,41 +995,6 @@ def test_openai_responses_api_settings_are_passed_to_chatopenai(monkeypatch):
|
||||
assert captured.get("output_version") == "responses/v1"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider class path resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", ["mimo-v2.5-pro", "mimo-v2.5", "mimo-v2-flash"])
|
||||
def test_create_chat_model_resolves_patched_mimo_provider(model_id):
|
||||
from deerflow.models.patched_mimo import PatchedChatMiMo
|
||||
|
||||
model = ModelConfig(
|
||||
name=f"{model_id}-thinking",
|
||||
display_name=f"{model_id} Thinking",
|
||||
description=None,
|
||||
use="deerflow.models.patched_mimo:PatchedChatMiMo",
|
||||
model=model_id,
|
||||
api_key="test-key",
|
||||
base_url="https://api.xiaomimimo.com/v1",
|
||||
supports_thinking=True,
|
||||
when_thinking_enabled={"extra_body": {"thinking": {"type": "enabled"}}},
|
||||
supports_vision=False,
|
||||
)
|
||||
cfg = _make_app_config([model])
|
||||
|
||||
chat_model = factory_module.create_chat_model(
|
||||
name=f"{model_id}-thinking",
|
||||
thinking_enabled=True,
|
||||
app_config=cfg,
|
||||
attach_tracing=False,
|
||||
)
|
||||
|
||||
assert isinstance(chat_model, PatchedChatMiMo)
|
||||
assert chat_model.model_name == model_id
|
||||
assert chat_model.extra_body["thinking"]["type"] == "enabled"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Duplicate keyword argument collision (issue #1977)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
"""Regression tests for the generated OpenAPI spec.
|
||||
|
||||
The Gateway exposes its FastAPI ``app.openapi()`` schema at ``/openapi.json``
|
||||
and downstream tooling (SDK codegen, schema validators, client generators)
|
||||
relies on ``operationId`` values being globally unique. FastAPI emits a
|
||||
``UserWarning`` during spec generation when two routes share the same
|
||||
``operationId`` — concretely this happens when ``@router.api_route`` registers
|
||||
one route for multiple HTTP methods, because the auto-generated unique id is
|
||||
computed from a single method picked out of ``route.methods`` while OpenAPI
|
||||
generation iterates over every method on that route.
|
||||
|
||||
These tests pin that invariant so the warning cannot silently come back.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def openapi_spec() -> dict:
|
||||
"""Build the OpenAPI spec for the Gateway app once per module."""
|
||||
from app.gateway.app import app
|
||||
|
||||
# ``app.openapi()`` caches the result on the FastAPI instance, so reset to
|
||||
# force a fresh generation pass that triggers any duplicate-id warnings.
|
||||
app.openapi_schema = None
|
||||
return app.openapi()
|
||||
|
||||
|
||||
def test_openapi_spec_has_no_duplicate_operation_warnings() -> None:
|
||||
"""Generating the OpenAPI schema must not emit any ``Duplicate Operation ID`` UserWarning."""
|
||||
from app.gateway.app import app
|
||||
|
||||
app.openapi_schema = None
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
warnings.simplefilter("always")
|
||||
app.openapi()
|
||||
|
||||
dup_messages = [str(item.message) for item in caught if "Duplicate Operation ID" in str(item.message)]
|
||||
assert dup_messages == [], f"OpenAPI generation emitted duplicate operation id warnings: {dup_messages}"
|
||||
|
||||
|
||||
def test_openapi_operation_ids_are_unique(openapi_spec: dict) -> None:
|
||||
"""Every (path, method) operation in the spec must carry a unique ``operationId``."""
|
||||
op_id_to_locations: dict[str, list[tuple[str, str]]] = {}
|
||||
|
||||
for path, path_item in openapi_spec.get("paths", {}).items():
|
||||
for method, operation in path_item.items():
|
||||
if not isinstance(operation, dict):
|
||||
continue
|
||||
op_id = operation.get("operationId")
|
||||
if op_id is None:
|
||||
continue
|
||||
op_id_to_locations.setdefault(op_id, []).append((path, method))
|
||||
|
||||
duplicates = {op_id: locations for op_id, locations in op_id_to_locations.items() if len(locations) > 1}
|
||||
assert not duplicates, f"Duplicate operationIds in OpenAPI spec: {duplicates}"
|
||||
|
||||
|
||||
def test_stream_existing_run_exposes_distinct_get_and_post(openapi_spec: dict) -> None:
|
||||
"""The ``/runs/{run_id}/stream`` endpoint must expose GET and POST as distinct operations.
|
||||
|
||||
LangGraph SDK ``joinStream`` uses GET while ``useStream``'s stop button uses POST, so
|
||||
both methods must remain registered with their own ``operationId``.
|
||||
"""
|
||||
path = "/api/threads/{thread_id}/runs/{run_id}/stream"
|
||||
path_item = openapi_spec["paths"].get(path)
|
||||
assert path_item is not None, f"Expected {path} to be present in the OpenAPI spec"
|
||||
|
||||
assert "get" in path_item, f"Expected GET handler on {path}"
|
||||
assert "post" in path_item, f"Expected POST handler on {path}"
|
||||
|
||||
get_op_id = path_item["get"].get("operationId")
|
||||
post_op_id = path_item["post"].get("operationId")
|
||||
assert get_op_id and post_op_id, "Both GET and POST must have operationIds"
|
||||
assert get_op_id != post_op_id, f"GET and POST share operationId {get_op_id!r}, which breaks OpenAPI codegen"
|
||||
@@ -1,169 +0,0 @@
|
||||
"""Tests for deerflow.models.patched_mimo.PatchedChatMiMo."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
|
||||
|
||||
|
||||
def _make_model(**kwargs):
|
||||
from deerflow.models.patched_mimo import PatchedChatMiMo
|
||||
|
||||
return PatchedChatMiMo(
|
||||
model="mimo-v2.5-pro",
|
||||
api_key="test-key",
|
||||
base_url="https://api.xiaomimimo.com/v1",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def test_is_lc_serializable_returns_true():
|
||||
from deerflow.models.patched_mimo import PatchedChatMiMo
|
||||
|
||||
assert PatchedChatMiMo.is_lc_serializable() is True
|
||||
|
||||
|
||||
def test_lc_secrets_contains_mimo_api_key_mapping():
|
||||
model = _make_model()
|
||||
|
||||
assert model.lc_secrets["api_key"] == "MIMO_API_KEY"
|
||||
assert model.lc_secrets["openai_api_key"] == "MIMO_API_KEY"
|
||||
|
||||
|
||||
def test_reasoning_content_injected_into_assistant_tool_call_message():
|
||||
model = _make_model()
|
||||
|
||||
human = HumanMessage(content="Check Beijing weather.")
|
||||
ai = AIMessage(
|
||||
content="",
|
||||
additional_kwargs={"reasoning_content": "I need to call the weather tool."},
|
||||
)
|
||||
payload_message = {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_weather",
|
||||
"type": "function",
|
||||
"function": {"name": "get_weather", "arguments": '{"location":"Beijing"}'},
|
||||
}
|
||||
],
|
||||
}
|
||||
base_payload = {
|
||||
"messages": [
|
||||
{"role": "user", "content": "Check Beijing weather."},
|
||||
payload_message,
|
||||
]
|
||||
}
|
||||
|
||||
with patch.object(type(model).__bases__[0], "_get_request_payload", return_value=base_payload):
|
||||
with patch.object(model, "_convert_input") as mock_convert:
|
||||
mock_convert.return_value = MagicMock(to_messages=lambda: [human, ai])
|
||||
payload = model._get_request_payload([human, ai])
|
||||
|
||||
assert payload["messages"][1]["reasoning_content"] == "I need to call the weather tool."
|
||||
|
||||
|
||||
def test_reasoning_content_is_noop_when_missing():
|
||||
model = _make_model()
|
||||
|
||||
human = HumanMessage(content="hello")
|
||||
ai = AIMessage(content="hi", additional_kwargs={})
|
||||
base_payload = {
|
||||
"messages": [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi"},
|
||||
]
|
||||
}
|
||||
|
||||
with patch.object(type(model).__bases__[0], "_get_request_payload", return_value=base_payload):
|
||||
with patch.object(model, "_convert_input") as mock_convert:
|
||||
mock_convert.return_value = MagicMock(to_messages=lambda: [human, ai])
|
||||
payload = model._get_request_payload([human, ai])
|
||||
|
||||
assert "reasoning_content" not in payload["messages"][1]
|
||||
|
||||
|
||||
def test_create_chat_result_maps_message_reasoning_content():
|
||||
model = _make_model()
|
||||
response = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "The weather is sunny.",
|
||||
"reasoning_content": "The tool returned sunny weather, so answer directly.",
|
||||
"tool_calls": None,
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"model": "mimo-v2.5-pro",
|
||||
}
|
||||
|
||||
result = model._create_chat_result(response)
|
||||
message = result.generations[0].message
|
||||
|
||||
assert message.content == "The weather is sunny."
|
||||
assert message.additional_kwargs["reasoning_content"] == "The tool returned sunny weather, so answer directly."
|
||||
|
||||
|
||||
def test_create_chat_result_reads_reasoning_content_from_message_attribute():
|
||||
model = _make_model()
|
||||
|
||||
class FakeMessage:
|
||||
reasoning_content = "Reasoning stored on the SDK message object."
|
||||
|
||||
class FakeChoice:
|
||||
message = FakeMessage()
|
||||
|
||||
class FakeResponse:
|
||||
choices = [FakeChoice()]
|
||||
|
||||
def model_dump(self, **kwargs):
|
||||
return {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Answer.",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"model": "mimo-v2.5-pro",
|
||||
}
|
||||
|
||||
result = model._create_chat_result(FakeResponse())
|
||||
|
||||
assert result.generations[0].message.additional_kwargs["reasoning_content"] == "Reasoning stored on the SDK message object."
|
||||
|
||||
|
||||
def test_convert_chunk_to_generation_chunk_preserves_reasoning_deltas():
|
||||
model = _make_model()
|
||||
|
||||
first = model._convert_chunk_to_generation_chunk(
|
||||
{"choices": [{"delta": {"role": "assistant", "reasoning_content": "I need "}}]},
|
||||
AIMessageChunk,
|
||||
{},
|
||||
)
|
||||
second = model._convert_chunk_to_generation_chunk(
|
||||
{"choices": [{"delta": {"reasoning_content": "a tool."}}]},
|
||||
AIMessageChunk,
|
||||
{},
|
||||
)
|
||||
answer = model._convert_chunk_to_generation_chunk(
|
||||
{"choices": [{"delta": {"content": "Done."}, "finish_reason": "stop"}], "model": "mimo-v2.5-pro"},
|
||||
AIMessageChunk,
|
||||
{},
|
||||
)
|
||||
|
||||
assert first is not None
|
||||
assert second is not None
|
||||
assert answer is not None
|
||||
|
||||
combined = first.message + second.message + answer.message
|
||||
|
||||
assert combined.additional_kwargs["reasoning_content"] == "I need a tool."
|
||||
assert combined.content == "Done."
|
||||
@@ -1,109 +0,0 @@
|
||||
"""Tests for the DDD run domain skeleton."""
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.runs import DisconnectMode, RunStatus
|
||||
from deerflow.runtime.runs.domain import (
|
||||
AssistantId,
|
||||
CancelAction,
|
||||
EventSeq,
|
||||
InvalidRunTransition,
|
||||
MultitaskStrategy,
|
||||
Run,
|
||||
RunCancelled,
|
||||
RunCompleted,
|
||||
RunCreated,
|
||||
RunFailed,
|
||||
RunId,
|
||||
RunScope,
|
||||
RunStarted,
|
||||
ThreadId,
|
||||
)
|
||||
from deerflow.runtime.runs.schemas import DisconnectMode as CompatDisconnectMode
|
||||
from deerflow.runtime.runs.schemas import RunStatus as CompatRunStatus
|
||||
|
||||
|
||||
def test_compat_schema_exports_use_domain_enums() -> None:
|
||||
assert CompatRunStatus is RunStatus
|
||||
assert CompatDisconnectMode is DisconnectMode
|
||||
|
||||
|
||||
def test_create_run_records_pending_state_and_created_event() -> None:
|
||||
run = Run.create(
|
||||
run_id=RunId("run-1"),
|
||||
thread_id=ThreadId("thread-1"),
|
||||
assistant_id=AssistantId("lead_agent"),
|
||||
scope=RunScope.stateful,
|
||||
multitask_strategy=MultitaskStrategy.reject,
|
||||
metadata={"source": "test"},
|
||||
kwargs={"input": {"messages": []}},
|
||||
created_at="2026-01-01T00:00:00+00:00",
|
||||
)
|
||||
|
||||
assert run.status == RunStatus.pending
|
||||
assert run.run_id == "run-1"
|
||||
assert run.thread_id == "thread-1"
|
||||
assert run.assistant_id == "lead_agent"
|
||||
assert run.created_at == "2026-01-01T00:00:00+00:00"
|
||||
assert run.updated_at == "2026-01-01T00:00:00+00:00"
|
||||
|
||||
events = run.pull_events()
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], RunCreated)
|
||||
assert events[0].metadata == {"source": "test"}
|
||||
assert run.pull_events() == ()
|
||||
|
||||
|
||||
def test_run_allows_pending_running_success_transition() -> None:
|
||||
run = Run.create(run_id=RunId("run-1"), thread_id=ThreadId("thread-1"))
|
||||
run.pull_events()
|
||||
|
||||
run.mark_started(at="2026-01-01T00:00:01+00:00")
|
||||
run.mark_completed(at="2026-01-01T00:00:02+00:00")
|
||||
|
||||
assert run.status == RunStatus.success
|
||||
assert run.updated_at == "2026-01-01T00:00:02+00:00"
|
||||
events = run.pull_events()
|
||||
assert [type(event) for event in events] == [RunStarted, RunCompleted]
|
||||
|
||||
|
||||
def test_run_records_failed_and_cancelled_domain_events() -> None:
|
||||
failed = Run.create(run_id=RunId("run-failed"), thread_id=ThreadId("thread-1"))
|
||||
failed.pull_events()
|
||||
failed.mark_started()
|
||||
failed.mark_failed("boom", at="2026-01-01T00:00:03+00:00")
|
||||
failed_events = failed.pull_events()
|
||||
|
||||
assert failed.status == RunStatus.error
|
||||
assert isinstance(failed_events[-1], RunFailed)
|
||||
assert failed_events[-1].status == RunStatus.error
|
||||
assert failed_events[-1].error == "boom"
|
||||
|
||||
cancelled = Run.create(run_id=RunId("run-cancelled"), thread_id=ThreadId("thread-1"))
|
||||
cancelled.pull_events()
|
||||
cancelled.mark_cancelled(action=CancelAction.rollback)
|
||||
cancelled_events = cancelled.pull_events()
|
||||
|
||||
assert cancelled.status == RunStatus.interrupted
|
||||
assert isinstance(cancelled_events[-1], RunCancelled)
|
||||
assert cancelled_events[-1].action == CancelAction.rollback
|
||||
|
||||
|
||||
def test_terminal_run_cannot_transition_again() -> None:
|
||||
run = Run.create(run_id=RunId("run-1"), thread_id=ThreadId("thread-1"))
|
||||
run.mark_started()
|
||||
run.mark_completed()
|
||||
|
||||
with pytest.raises(InvalidRunTransition) as exc:
|
||||
run.mark_failed("too late")
|
||||
|
||||
assert exc.value.current == RunStatus.success
|
||||
assert exc.value.target == RunStatus.error
|
||||
|
||||
|
||||
def test_domain_value_objects_validate_minimal_invariants() -> None:
|
||||
assert EventSeq(1).next() == EventSeq(2)
|
||||
with pytest.raises(ValueError, match="EventSeq"):
|
||||
EventSeq(-1)
|
||||
with pytest.raises(ValueError, match="run_id"):
|
||||
Run.create(run_id=RunId(" "), thread_id=ThreadId("thread-1"))
|
||||
@@ -96,30 +96,25 @@ class _ScriptedAgent:
|
||||
del subgraphs
|
||||
self.controller.started.set()
|
||||
|
||||
try:
|
||||
thread_id = _thread_id_from_config(config)
|
||||
human_text = _last_human_text(graph_input)
|
||||
human = HumanMessage(content=human_text)
|
||||
ai = await self.model.ainvoke([human], config=config)
|
||||
state = {"messages": [human.model_dump(), ai.model_dump()], "title": self.title}
|
||||
thread_id = _thread_id_from_config(config)
|
||||
human_text = _last_human_text(graph_input)
|
||||
human = HumanMessage(content=human_text)
|
||||
ai = await self.model.ainvoke([human], config=config)
|
||||
state = {"messages": [human.model_dump(), ai.model_dump()], "title": self.title}
|
||||
|
||||
if self.checkpointer is not None:
|
||||
await _write_checkpoint(self.checkpointer, thread_id=thread_id, state=state)
|
||||
self.controller.checkpoint_written.set()
|
||||
if self.checkpointer is not None:
|
||||
await _write_checkpoint(self.checkpointer, thread_id=thread_id, state=state)
|
||||
self.controller.checkpoint_written.set()
|
||||
|
||||
yield _stream_item_for_mode(stream_mode, state)
|
||||
yield _stream_item_for_mode(stream_mode, state)
|
||||
|
||||
if self.block_after_first_chunk:
|
||||
if self.block_after_first_chunk:
|
||||
try:
|
||||
while not self.controller.release.is_set():
|
||||
await asyncio.sleep(0.05)
|
||||
except asyncio.CancelledError:
|
||||
# Catch cancellation arriving anywhere in the body — including the
|
||||
# `await ainvoke()` / `_write_checkpoint()` / `yield` points between
|
||||
# ``started.set()`` and the original inner ``try`` — so tests that
|
||||
# wait for ``cancelled`` after issuing ``POST /cancel`` no longer
|
||||
# race with cancellation arriving early.
|
||||
self.controller.cancelled.set()
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
self.controller.cancelled.set()
|
||||
raise
|
||||
|
||||
|
||||
def _make_agent_factory(controller: _RunController, **agent_kwargs):
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
import stat
|
||||
|
||||
from deerflow.skills.permissions import make_skill_tree_sandbox_readable, make_skill_written_path_sandbox_readable
|
||||
|
||||
|
||||
def _mode(path):
|
||||
return stat.S_IMODE(path.stat().st_mode)
|
||||
|
||||
|
||||
def test_skill_tree_readability_includes_hidden_paths_and_removes_sandbox_write(tmp_path):
|
||||
root = tmp_path / "demo-skill"
|
||||
hidden_dir = root / ".hidden"
|
||||
scripts_dir = root / "scripts"
|
||||
hidden_dir.mkdir(parents=True)
|
||||
scripts_dir.mkdir()
|
||||
env_file = root / ".env"
|
||||
hidden_file = hidden_dir / ".secret"
|
||||
script_file = scripts_dir / "run.sh"
|
||||
env_file.write_text("secret", encoding="utf-8")
|
||||
hidden_file.write_text("secret", encoding="utf-8")
|
||||
script_file.write_text("#!/bin/sh\n", encoding="utf-8")
|
||||
|
||||
root.chmod(0o777)
|
||||
hidden_dir.chmod(0o777)
|
||||
scripts_dir.chmod(0o777)
|
||||
env_file.chmod(0o666)
|
||||
hidden_file.chmod(0o600)
|
||||
script_file.chmod(0o777)
|
||||
|
||||
make_skill_tree_sandbox_readable(root)
|
||||
|
||||
assert _mode(root) == 0o755
|
||||
assert _mode(hidden_dir) == 0o755
|
||||
assert _mode(scripts_dir) == 0o755
|
||||
assert _mode(env_file) == 0o644
|
||||
assert _mode(hidden_file) == 0o644
|
||||
assert _mode(script_file) == 0o755
|
||||
|
||||
|
||||
def test_written_path_readability_is_limited_to_written_path(tmp_path):
|
||||
root = tmp_path / "demo-skill"
|
||||
ref_dir = root / "references"
|
||||
sibling_dir = root / "templates"
|
||||
ref_dir.mkdir(parents=True)
|
||||
sibling_dir.mkdir()
|
||||
target = ref_dir / "guide.md"
|
||||
sibling = sibling_dir / "note.md"
|
||||
target.write_text("guide", encoding="utf-8")
|
||||
sibling.write_text("note", encoding="utf-8")
|
||||
|
||||
root.chmod(0o700)
|
||||
ref_dir.chmod(0o700)
|
||||
target.chmod(0o600)
|
||||
sibling_dir.chmod(0o700)
|
||||
sibling.chmod(0o600)
|
||||
|
||||
make_skill_written_path_sandbox_readable(root, target)
|
||||
|
||||
assert _mode(root) == 0o755
|
||||
assert _mode(ref_dir) == 0o755
|
||||
assert _mode(target) == 0o644
|
||||
assert _mode(sibling_dir) == 0o700
|
||||
assert _mode(sibling) == 0o600
|
||||
@@ -1,18 +1,14 @@
|
||||
import errno
|
||||
import json
|
||||
import stat
|
||||
import zipfile
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
from _router_auth_helpers import make_authed_test_app
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.deps import get_config
|
||||
from app.gateway.routers import skills as skills_router
|
||||
from app.gateway.routers import uploads as uploads_router
|
||||
from deerflow.skills.storage import get_or_new_skill_storage
|
||||
from deerflow.skills.types import Skill
|
||||
|
||||
@@ -57,15 +53,6 @@ def _make_skill_archive(tmp_path: Path, name: str, content: str | None = None) -
|
||||
return archive
|
||||
|
||||
|
||||
def _make_skill_archive_bytes(name: str, content: str | None = None) -> bytes:
|
||||
buffer = BytesIO()
|
||||
skill_content = content or _skill_content(name)
|
||||
with zipfile.ZipFile(buffer, "w") as zf:
|
||||
zf.writestr(f"{name}/SKILL.md", skill_content)
|
||||
zf.writestr(f"{name}/references/guide.md", "# Guide\n")
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
def test_install_skill_archive_runs_security_scan(monkeypatch, tmp_path):
|
||||
skills_root = tmp_path / "skills"
|
||||
(skills_root / "custom").mkdir(parents=True)
|
||||
@@ -114,65 +101,6 @@ def test_install_skill_archive_runs_security_scan(monkeypatch, tmp_path):
|
||||
assert refresh_calls == ["refresh"]
|
||||
|
||||
|
||||
def test_uploaded_skill_archive_installs_sandbox_readable_tree(monkeypatch, tmp_path):
|
||||
home = tmp_path / "home"
|
||||
skills_root = tmp_path / "skills"
|
||||
skills_root.mkdir()
|
||||
refresh_calls = []
|
||||
|
||||
async def _scan(*args, **kwargs):
|
||||
from deerflow.skills.security_scanner import ScanResult
|
||||
|
||||
return ScanResult(decision="allow", reason="ok")
|
||||
|
||||
async def _refresh():
|
||||
refresh_calls.append("refresh")
|
||||
|
||||
config = SimpleNamespace(
|
||||
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills", use="deerflow.skills.storage.local_skill_storage:LocalSkillStorage"),
|
||||
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
|
||||
uploads=SimpleNamespace(auto_convert_documents=False),
|
||||
)
|
||||
provider = SimpleNamespace(uses_thread_data_mounts=True)
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_HOME", str(home))
|
||||
monkeypatch.setattr("deerflow.config.paths._paths", None)
|
||||
monkeypatch.setattr(uploads_router, "get_sandbox_provider", lambda: provider)
|
||||
monkeypatch.setattr("deerflow.skills.installer.scan_skill_content", _scan)
|
||||
monkeypatch.setattr(skills_router, "refresh_skills_system_prompt_cache_async", _refresh)
|
||||
|
||||
app = make_authed_test_app()
|
||||
app.state.config = config
|
||||
app.dependency_overrides[get_config] = lambda: config
|
||||
app.include_router(uploads_router.router)
|
||||
app.include_router(skills_router.router)
|
||||
|
||||
thread_id = "thread-uploaded-skill"
|
||||
archive_bytes = _make_skill_archive_bytes("uploaded-skill")
|
||||
|
||||
with TestClient(app) as client:
|
||||
upload_response = client.post(
|
||||
f"/api/threads/{thread_id}/uploads",
|
||||
files=[("files", ("uploaded-skill.skill", archive_bytes, "application/octet-stream"))],
|
||||
)
|
||||
assert upload_response.status_code == 200
|
||||
uploaded_file = upload_response.json()["files"][0]
|
||||
uploaded_path = Path(uploaded_file["path"])
|
||||
assert uploaded_path.is_file()
|
||||
|
||||
install_response = client.post("/api/skills/install", json={"thread_id": thread_id, "path": uploaded_file["virtual_path"]})
|
||||
|
||||
assert install_response.status_code == 200
|
||||
assert install_response.json()["skill_name"] == "uploaded-skill"
|
||||
installed_dir = skills_root / "custom" / "uploaded-skill"
|
||||
nested_dir = installed_dir / "references"
|
||||
assert stat.S_IMODE(installed_dir.stat().st_mode) & 0o055 == 0o055
|
||||
assert stat.S_IMODE(nested_dir.stat().st_mode) & 0o055 == 0o055
|
||||
assert stat.S_IMODE((installed_dir / "SKILL.md").stat().st_mode) & 0o044 == 0o044
|
||||
assert stat.S_IMODE((nested_dir / "guide.md").stat().st_mode) & 0o044 == 0o044
|
||||
assert refresh_calls == ["refresh"]
|
||||
|
||||
|
||||
def test_install_skill_archive_security_scan_block_returns_400(monkeypatch, tmp_path):
|
||||
skills_root = tmp_path / "skills"
|
||||
(skills_root / "custom").mkdir(parents=True)
|
||||
@@ -247,7 +175,6 @@ def test_custom_skills_router_lifecycle(monkeypatch, tmp_path):
|
||||
)
|
||||
assert update_response.status_code == 200
|
||||
assert update_response.json()["description"] == "Edited skill"
|
||||
assert stat.S_IMODE((custom_dir / "SKILL.md").stat().st_mode) & 0o044 == 0o044
|
||||
|
||||
history_response = client.get("/api/skills/custom/demo-skill/history")
|
||||
assert history_response.status_code == 200
|
||||
@@ -256,7 +183,6 @@ def test_custom_skills_router_lifecycle(monkeypatch, tmp_path):
|
||||
rollback_response = client.post("/api/skills/custom/demo-skill/rollback", json={"history_index": -1})
|
||||
assert rollback_response.status_code == 200
|
||||
assert rollback_response.json()["description"] == "Demo skill"
|
||||
assert stat.S_IMODE((custom_dir / "SKILL.md").stat().st_mode) & 0o044 == 0o044
|
||||
assert refresh_calls == ["refresh", "refresh"]
|
||||
|
||||
|
||||
|
||||
@@ -198,26 +198,6 @@ class TestInstallSkillFromArchive:
|
||||
assert result["skill_name"] == "test-skill"
|
||||
assert (skills_root / "custom" / "test-skill" / "SKILL.md").exists()
|
||||
|
||||
def test_installed_skill_tree_is_readable_by_sandbox_mount(self, tmp_path):
|
||||
zip_path = tmp_path / "test-skill.skill"
|
||||
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||
zf.writestr("test-skill/SKILL.md", "---\nname: test-skill\ndescription: A test skill\n---\n\n# test-skill\n")
|
||||
zf.writestr("test-skill/references/guide.md", "# Guide\n")
|
||||
skills_root = tmp_path / "skills"
|
||||
skills_root.mkdir()
|
||||
|
||||
get_or_new_skill_storage(skills_path=skills_root).install_skill_from_archive(zip_path)
|
||||
|
||||
installed_dir = skills_root / "custom" / "test-skill"
|
||||
nested_dir = installed_dir / "references"
|
||||
skill_file = installed_dir / "SKILL.md"
|
||||
guide_file = nested_dir / "guide.md"
|
||||
|
||||
assert stat.S_IMODE(installed_dir.stat().st_mode) & 0o055 == 0o055
|
||||
assert stat.S_IMODE(nested_dir.stat().st_mode) & 0o055 == 0o055
|
||||
assert stat.S_IMODE(skill_file.stat().st_mode) & 0o044 == 0o044
|
||||
assert stat.S_IMODE(guide_file.stat().st_mode) & 0o044 == 0o044
|
||||
|
||||
def test_scans_skill_markdown_before_install(self, tmp_path, monkeypatch):
|
||||
zip_path = self._make_skill_zip(tmp_path)
|
||||
skills_root = tmp_path / "skills"
|
||||
|
||||
@@ -5,10 +5,7 @@ from unittest import mock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
|
||||
from deerflow.agents.memory.summarization_hook import memory_flush_hook
|
||||
from deerflow.agents.middlewares.dynamic_context_middleware import _DYNAMIC_CONTEXT_REMINDER_KEY, DynamicContextMiddleware
|
||||
@@ -25,23 +22,6 @@ def _messages() -> list:
|
||||
]
|
||||
|
||||
|
||||
class _StaticChatModel(BaseChatModel):
|
||||
text: str = "ok"
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "static-test-chat-model"
|
||||
|
||||
def bind_tools(self, tools, **kwargs):
|
||||
return self
|
||||
|
||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=self.text))])
|
||||
|
||||
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
|
||||
def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage:
|
||||
return HumanMessage(
|
||||
content="<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>",
|
||||
@@ -134,32 +114,6 @@ def test_before_summarization_hook_receives_messages_before_compression() -> Non
|
||||
assert result["messages"][1].content.startswith("Here is a summary")
|
||||
|
||||
|
||||
def test_summarization_middleware_emits_frontend_update_key_in_agent_stream() -> None:
|
||||
middleware = DeerFlowSummarizationMiddleware(
|
||||
model=_StaticChatModel(text="compressed summary"),
|
||||
trigger=("messages", 4),
|
||||
keep=("messages", 2),
|
||||
token_counter=len,
|
||||
)
|
||||
agent = create_agent(
|
||||
model=_StaticChatModel(text="done"),
|
||||
tools=[],
|
||||
middleware=[middleware],
|
||||
)
|
||||
|
||||
chunks = list(agent.stream({"messages": _messages()}, stream_mode="updates"))
|
||||
update = next(
|
||||
(chunk["DeerFlowSummarizationMiddleware.before_model"] for chunk in chunks if "DeerFlowSummarizationMiddleware.before_model" in chunk),
|
||||
None,
|
||||
)
|
||||
|
||||
assert update is not None
|
||||
emitted = update["messages"]
|
||||
assert isinstance(emitted[0], RemoveMessage)
|
||||
assert emitted[1].name == "summary"
|
||||
assert emitted[1].content == ("Here is a summary of the conversation to date:\n\ncompressed summary")
|
||||
|
||||
|
||||
def test_dynamic_context_reminder_is_preserved_across_summarization() -> None:
|
||||
captured: list[SummarizationEvent] = []
|
||||
middleware = _middleware(before_summarization=[captured.append])
|
||||
|
||||
@@ -134,14 +134,12 @@ def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware
|
||||
middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False)
|
||||
|
||||
assert captured["app_config"] is app_config
|
||||
# 7 baseline (ToolOutputBudget, ThreadData, Sandbox, DanglingToolCall,
|
||||
# LLMErrorHandling, SandboxAudit, ToolErrorHandling)
|
||||
# + 1 SafetyFinishReasonMiddleware (enabled by default).
|
||||
# 6 baseline (ThreadData, Sandbox, DanglingToolCall, LLMErrorHandling,
|
||||
# SandboxAudit, ToolErrorHandling) + 1 SafetyFinishReasonMiddleware
|
||||
# (enabled by default — see SafetyFinishReasonConfig).
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
from deerflow.agents.middlewares.tool_output_budget_middleware import ToolOutputBudgetMiddleware
|
||||
|
||||
assert len(middlewares) == 8
|
||||
assert isinstance(middlewares[0], ToolOutputBudgetMiddleware)
|
||||
assert len(middlewares) == 7
|
||||
assert any(isinstance(m, ToolErrorHandlingMiddleware) for m in middlewares)
|
||||
assert isinstance(middlewares[-1], SafetyFinishReasonMiddleware)
|
||||
|
||||
|
||||
@@ -1,890 +0,0 @@
|
||||
"""Comprehensive tests for ToolOutputBudgetMiddleware.
|
||||
|
||||
Covers: pass-through, disk externalization, fallback truncation, UTF-8
|
||||
boundaries, Command results, model-request history patching, config
|
||||
variations, exempt tools, per-tool overrides, edge cases, and both
|
||||
sync/async code paths.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from langchain.agents.middleware.types import ModelRequest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
from deerflow.agents.middlewares.tool_output_budget_middleware import (
|
||||
ToolOutputBudgetMiddleware,
|
||||
_build_fallback,
|
||||
_build_preview,
|
||||
_effective_trigger,
|
||||
_externalize,
|
||||
_message_text,
|
||||
_needs_budget,
|
||||
_patch_model_messages,
|
||||
_sanitize_tool_name,
|
||||
_snap_to_line_boundary,
|
||||
_tool_message_over_budget,
|
||||
)
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.config.tool_output_config import ToolOutputConfig
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_request(tool_name: str = "remote_executor", tool_call_id: str = "tc-1", outputs_path: str | None = None) -> SimpleNamespace:
|
||||
thread_data = {"outputs_path": outputs_path} if outputs_path else None
|
||||
state = {"thread_data": thread_data} if thread_data else {}
|
||||
runtime = SimpleNamespace(state=state)
|
||||
return SimpleNamespace(
|
||||
tool_call={"name": tool_name, "id": tool_call_id},
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
|
||||
def _tm(content: str = "ok", name: str = "tool", tool_call_id: str = "tc-1") -> ToolMessage:
|
||||
return ToolMessage(content=content, name=name, tool_call_id=tool_call_id)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Unit tests for helper functions
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestMessageText:
|
||||
def test_string_content(self):
|
||||
assert _message_text("hello") == "hello"
|
||||
|
||||
def test_none_content(self):
|
||||
assert _message_text(None) is None
|
||||
|
||||
def test_list_of_strings(self):
|
||||
assert _message_text(["a", "b"]) == "a\nb"
|
||||
|
||||
def test_list_of_text_dicts(self):
|
||||
assert _message_text([{"text": "x"}, {"text": "y"}]) == "x\ny"
|
||||
|
||||
def test_list_with_image_returns_none(self):
|
||||
assert _message_text([{"type": "image", "data": "..."}]) is None
|
||||
|
||||
def test_empty_list(self):
|
||||
assert _message_text([]) is None
|
||||
|
||||
def test_non_string_non_list(self):
|
||||
assert _message_text(42) is None
|
||||
|
||||
|
||||
class TestSnapToLineBoundary:
|
||||
def test_snaps_to_newline(self):
|
||||
text = "line1\nline2\nline3"
|
||||
pos = 14 # inside "line3"
|
||||
result = _snap_to_line_boundary(text, pos)
|
||||
assert text[result - 1] == "\n"
|
||||
|
||||
def test_no_snap_when_no_newline_in_range(self):
|
||||
text = "abcdefghij"
|
||||
assert _snap_to_line_boundary(text, 8) == 8
|
||||
|
||||
def test_zero_pos(self):
|
||||
assert _snap_to_line_boundary("abc", 0) == 0
|
||||
|
||||
def test_pos_beyond_length(self):
|
||||
assert _snap_to_line_boundary("abc", 10) == 10
|
||||
|
||||
|
||||
class TestExternalize:
|
||||
def test_writes_file_and_returns_virtual_path(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = _externalize(
|
||||
"full content here",
|
||||
tool_name="bash",
|
||||
tool_call_id="tc-1",
|
||||
outputs_path=tmpdir,
|
||||
storage_subdir=".tool-results",
|
||||
)
|
||||
assert path is not None
|
||||
assert path.startswith("/mnt/user-data/outputs/.tool-results/bash-")
|
||||
assert path.endswith(".log")
|
||||
|
||||
# Verify actual file on disk
|
||||
storage_dir = os.path.join(tmpdir, ".tool-results")
|
||||
files = os.listdir(storage_dir)
|
||||
assert len(files) == 1
|
||||
with open(os.path.join(storage_dir, files[0]), encoding="utf-8") as f:
|
||||
assert f.read() == "full content here"
|
||||
|
||||
def test_returns_none_on_invalid_path(self):
|
||||
path = _externalize(
|
||||
"data",
|
||||
tool_name="test",
|
||||
tool_call_id="tc-1",
|
||||
outputs_path="/nonexistent/path/that/should/not/exist",
|
||||
storage_subdir=".tool-results",
|
||||
)
|
||||
assert path is None
|
||||
|
||||
def test_txt_extension_for_unknown_tool(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = _externalize(
|
||||
"data",
|
||||
tool_name="unknown_tool",
|
||||
tool_call_id="tc-1",
|
||||
outputs_path=tmpdir,
|
||||
storage_subdir=".tool-results",
|
||||
)
|
||||
assert path is not None
|
||||
assert path.endswith(".txt")
|
||||
|
||||
|
||||
class TestSanitizeToolName:
|
||||
def test_strips_path_separators(self):
|
||||
assert _sanitize_tool_name("../../etc/passwd") == "passwd"
|
||||
|
||||
def test_strips_backslashes(self):
|
||||
result = _sanitize_tool_name("..\\..\\windows\\system32")
|
||||
assert ".." not in result
|
||||
assert "/" not in result
|
||||
|
||||
def test_normal_name_unchanged(self):
|
||||
assert _sanitize_tool_name("bash") == "bash"
|
||||
|
||||
def test_empty_becomes_unknown(self):
|
||||
assert _sanitize_tool_name("") == "unknown"
|
||||
|
||||
def test_dots_only_becomes_unknown(self):
|
||||
assert _sanitize_tool_name("..") == "unknown"
|
||||
|
||||
|
||||
class TestExternalizePathTraversal:
|
||||
def test_traversal_tool_name_is_sanitized(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = _externalize(
|
||||
"data",
|
||||
tool_name="../../etc/passwd",
|
||||
tool_call_id="tc-1",
|
||||
outputs_path=tmpdir,
|
||||
storage_subdir=".tool-results",
|
||||
)
|
||||
assert path is not None
|
||||
assert "passwd-" in path
|
||||
assert "../" not in path
|
||||
|
||||
def test_absolute_storage_subdir_rejected(self):
|
||||
path = _externalize(
|
||||
"data",
|
||||
tool_name="tool",
|
||||
tool_call_id="tc-1",
|
||||
outputs_path="/tmp",
|
||||
storage_subdir="/etc/evil",
|
||||
)
|
||||
assert path is None
|
||||
|
||||
def test_traversal_storage_subdir_rejected(self):
|
||||
path = _externalize(
|
||||
"data",
|
||||
tool_name="tool",
|
||||
tool_call_id="tc-1",
|
||||
outputs_path="/tmp",
|
||||
storage_subdir="../../../etc",
|
||||
)
|
||||
assert path is None
|
||||
|
||||
|
||||
class TestNeedsBudget:
|
||||
def test_small_output_does_not_need_budget(self):
|
||||
config = ToolOutputConfig(externalize_min_chars=1000)
|
||||
msg = _tm("small", name="tool")
|
||||
assert _needs_budget(msg, config) is False
|
||||
|
||||
def test_large_output_needs_budget(self):
|
||||
config = ToolOutputConfig(externalize_min_chars=50)
|
||||
msg = _tm("x" * 100, name="tool")
|
||||
assert _needs_budget(msg, config) is True
|
||||
|
||||
def test_exempt_tool_does_not_need_budget(self):
|
||||
config = ToolOutputConfig(externalize_min_chars=10)
|
||||
msg = _tm("x" * 100, name="read_file")
|
||||
assert _needs_budget(msg, config) is False
|
||||
|
||||
def test_multimodal_does_not_need_budget(self):
|
||||
config = ToolOutputConfig(externalize_min_chars=10)
|
||||
msg = ToolMessage(content=[{"type": "image", "data": "x" * 100}], name="tool", tool_call_id="tc-1")
|
||||
assert _needs_budget(msg, config) is False
|
||||
|
||||
|
||||
class TestBuildPreview:
|
||||
def test_contains_head_and_tail_and_reference(self):
|
||||
content = "HEAD_" + "x" * 5000 + "_TAIL"
|
||||
preview = _build_preview(
|
||||
content,
|
||||
tool_name="bash",
|
||||
virtual_path="/mnt/test/bash-abc.log",
|
||||
head_chars=100,
|
||||
tail_chars=50,
|
||||
)
|
||||
assert preview.startswith("HEAD_")
|
||||
assert "_TAIL" in preview
|
||||
assert "/mnt/test/bash-abc.log" in preview
|
||||
assert "read_file" in preview
|
||||
assert "start_line and end_line" in preview
|
||||
|
||||
def test_reports_total_chars(self):
|
||||
content = "a" * 10000
|
||||
preview = _build_preview(
|
||||
content,
|
||||
tool_name="web_search",
|
||||
virtual_path="/mnt/test/file.txt",
|
||||
head_chars=200,
|
||||
tail_chars=100,
|
||||
)
|
||||
assert "10000 chars" in preview
|
||||
|
||||
|
||||
class TestBuildFallback:
|
||||
def test_short_content_unchanged(self):
|
||||
assert _build_fallback("short", tool_name="t", max_chars=100, head_chars=50, tail_chars=50) == "short"
|
||||
|
||||
def test_zero_max_disables(self):
|
||||
content = "a" * 1000
|
||||
assert _build_fallback(content, tool_name="t", max_chars=0, head_chars=50, tail_chars=50) == content
|
||||
|
||||
def test_truncates_long_content(self):
|
||||
content = "H" * 5000 + "M" * 20000 + "T" * 5000
|
||||
result = _build_fallback(content, tool_name="bash", max_chars=12000, head_chars=6000, tail_chars=3000)
|
||||
assert len(result) < len(content)
|
||||
assert "omitted from bash output" in result
|
||||
assert "Persistent storage unavailable" in result
|
||||
|
||||
def test_preserves_head_and_tail(self):
|
||||
content = "HEADSTART" + "x" * 50000 + "TAILEND"
|
||||
result = _build_fallback(content, tool_name="t", max_chars=20000, head_chars=10000, tail_chars=5000)
|
||||
assert result.startswith("HEADSTART")
|
||||
assert "TAILEND" in result
|
||||
|
||||
def test_result_never_exceeds_max_chars(self):
|
||||
"""The marker itself has non-zero length; total must still respect max_chars."""
|
||||
for max_chars in [200, 500, 1000, 5000, 20000]:
|
||||
content = "x" * 50000
|
||||
result = _build_fallback(content, tool_name="long_tool_name", max_chars=max_chars, head_chars=max_chars // 2, tail_chars=max_chars // 4)
|
||||
assert len(result) <= max_chars, f"max_chars={max_chars}: got {len(result)}"
|
||||
|
||||
def test_very_small_max_chars_does_not_crash(self):
|
||||
content = "x" * 1000
|
||||
result = _build_fallback(content, tool_name="t", max_chars=50, head_chars=20, tail_chars=10)
|
||||
assert len(result) <= 50
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Middleware integration tests — wrap_tool_call
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestWrapToolCallPassThrough:
|
||||
def test_small_output_passes_through(self):
|
||||
mw = ToolOutputBudgetMiddleware(config=ToolOutputConfig(externalize_min_chars=1000))
|
||||
msg = _tm("small output", name="bash")
|
||||
result = mw.wrap_tool_call(_make_request(), lambda _: msg)
|
||||
assert result is msg
|
||||
|
||||
def test_disabled_middleware_passes_through(self):
|
||||
mw = ToolOutputBudgetMiddleware(config=ToolOutputConfig(enabled=False, externalize_min_chars=10, fallback_max_chars=20))
|
||||
msg = _tm("x" * 50000, name="bash")
|
||||
result = mw.wrap_tool_call(_make_request(), lambda _: msg)
|
||||
assert result is msg
|
||||
|
||||
|
||||
class TestWrapToolCallExternalize:
|
||||
def test_oversized_output_externalized_to_disk(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config = ToolOutputConfig(externalize_min_chars=100, preview_head_chars=50, preview_tail_chars=30)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
content = "x" * 500
|
||||
msg = _tm(content, name="remote_executor")
|
||||
req = _make_request(outputs_path=tmpdir)
|
||||
|
||||
result = mw.wrap_tool_call(req, lambda _: msg)
|
||||
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result is not msg
|
||||
assert "Full remote_executor output saved to" in result.content
|
||||
assert "read_file" in result.content
|
||||
assert result.tool_call_id == "tc-1"
|
||||
|
||||
# Verify file was written
|
||||
storage_dir = os.path.join(tmpdir, ".tool-results")
|
||||
assert os.path.isdir(storage_dir)
|
||||
files = os.listdir(storage_dir)
|
||||
assert len(files) == 1
|
||||
with open(os.path.join(storage_dir, files[0]), encoding="utf-8") as f:
|
||||
assert f.read() == content
|
||||
|
||||
def test_preview_contains_head_and_tail(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config = ToolOutputConfig(externalize_min_chars=50, preview_head_chars=20, preview_tail_chars=10)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
content = "HEADPART_" + "m" * 200 + "_TAILPART"
|
||||
msg = _tm(content, name="web_search")
|
||||
req = _make_request(outputs_path=tmpdir)
|
||||
|
||||
result = mw.wrap_tool_call(req, lambda _: msg)
|
||||
|
||||
assert result.content.startswith("HEADPART_")
|
||||
assert "_TAILPART" in result.content
|
||||
|
||||
|
||||
class TestWrapToolCallFallback:
|
||||
def test_fallback_when_no_outputs_path(self):
|
||||
config = ToolOutputConfig(
|
||||
externalize_min_chars=50,
|
||||
fallback_max_chars=200,
|
||||
fallback_head_chars=80,
|
||||
fallback_tail_chars=40,
|
||||
)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
content = "x" * 500
|
||||
msg = _tm(content, name="mcp_tool")
|
||||
req = _make_request(outputs_path=None)
|
||||
|
||||
result = mw.wrap_tool_call(req, lambda _: msg)
|
||||
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result is not msg
|
||||
assert "omitted from mcp_tool output" in result.content
|
||||
assert "Persistent storage unavailable" in result.content
|
||||
assert len(result.content) < len(content)
|
||||
|
||||
def test_fallback_when_disk_write_fails(self):
|
||||
config = ToolOutputConfig(
|
||||
externalize_min_chars=50,
|
||||
fallback_max_chars=200,
|
||||
fallback_head_chars=80,
|
||||
fallback_tail_chars=40,
|
||||
)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
content = "x" * 500
|
||||
msg = _tm(content, name="tool")
|
||||
req = _make_request(outputs_path="/nonexistent/impossible/path")
|
||||
|
||||
result = mw.wrap_tool_call(req, lambda _: msg)
|
||||
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert "omitted from tool output" in result.content
|
||||
|
||||
|
||||
class TestWrapToolCallExemption:
|
||||
def test_read_file_exempt(self):
|
||||
config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=50)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
content = "x" * 100
|
||||
msg = _tm(content, name="read_file")
|
||||
|
||||
result = mw.wrap_tool_call(_make_request(tool_name="read_file"), lambda _: msg)
|
||||
|
||||
assert result is msg
|
||||
|
||||
def test_read_file_tool_exempt(self):
|
||||
config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=50)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
content = "x" * 100
|
||||
msg = _tm(content, name="read_file_tool")
|
||||
|
||||
result = mw.wrap_tool_call(_make_request(tool_name="read_file_tool"), lambda _: msg)
|
||||
|
||||
assert result is msg
|
||||
|
||||
def test_custom_exempt_tool(self):
|
||||
config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=50, exempt_tools=["my_tool"])
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
content = "x" * 100
|
||||
msg = _tm(content, name="my_tool")
|
||||
|
||||
result = mw.wrap_tool_call(_make_request(tool_name="my_tool"), lambda _: msg)
|
||||
|
||||
assert result is msg
|
||||
|
||||
|
||||
class TestWrapToolCallPerToolOverride:
|
||||
def test_per_tool_threshold(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config = ToolOutputConfig(
|
||||
externalize_min_chars=50000, # global: high
|
||||
tool_overrides={"sensitive_tool": 100}, # override: low
|
||||
)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
content = "x" * 500
|
||||
msg = _tm(content, name="sensitive_tool")
|
||||
req = _make_request(tool_name="sensitive_tool", outputs_path=tmpdir)
|
||||
|
||||
result = mw.wrap_tool_call(req, lambda _: msg)
|
||||
|
||||
assert result is not msg
|
||||
assert "Full sensitive_tool output saved to" in result.content
|
||||
|
||||
def test_per_tool_zero_disables_externalization(self):
|
||||
config = ToolOutputConfig(
|
||||
externalize_min_chars=50,
|
||||
tool_overrides={"bash": 0},
|
||||
fallback_max_chars=200,
|
||||
fallback_head_chars=80,
|
||||
fallback_tail_chars=40,
|
||||
)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
content = "x" * 500
|
||||
msg = _tm(content, name="bash")
|
||||
# Even with outputs_path, externalization disabled for bash
|
||||
req = _make_request(tool_name="bash", outputs_path="/tmp/test")
|
||||
|
||||
result = mw.wrap_tool_call(req, lambda _: msg)
|
||||
|
||||
assert isinstance(result, ToolMessage)
|
||||
# Should use fallback instead of externalization
|
||||
assert "Persistent storage unavailable" in result.content or "omitted" in result.content
|
||||
|
||||
|
||||
class TestWrapToolCallCommand:
|
||||
def test_command_messages_are_patched(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config = ToolOutputConfig(externalize_min_chars=50, preview_head_chars=20, preview_tail_chars=10)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
tool_msg = _tm("x" * 200, name="present_files")
|
||||
command = Command(update={"messages": [tool_msg], "artifacts": ["/mnt/report.html"]})
|
||||
req = _make_request(tool_name="present_files", outputs_path=tmpdir)
|
||||
|
||||
result = mw.wrap_tool_call(req, lambda _: command)
|
||||
|
||||
assert isinstance(result, Command)
|
||||
assert result is not command
|
||||
assert result.update["artifacts"] == ["/mnt/report.html"]
|
||||
new_msg = result.update["messages"][0]
|
||||
assert isinstance(new_msg, ToolMessage)
|
||||
assert "Full present_files output saved to" in new_msg.content
|
||||
|
||||
def test_command_without_messages_unchanged(self):
|
||||
config = ToolOutputConfig(externalize_min_chars=10)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
command = Command(update={"key": "value"})
|
||||
result = mw.wrap_tool_call(_make_request(), lambda _: command)
|
||||
assert result is command
|
||||
|
||||
|
||||
class TestWrapToolCallEdgeCases:
|
||||
def test_none_content_passes_through(self):
|
||||
config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=20)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
msg = ToolMessage(content=None, name="tool", tool_call_id="tc-1")
|
||||
|
||||
result = mw.wrap_tool_call(_make_request(), lambda _: msg)
|
||||
|
||||
assert result is msg
|
||||
|
||||
def test_empty_string_passes_through(self):
|
||||
config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=20)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
msg = _tm("", name="tool")
|
||||
|
||||
result = mw.wrap_tool_call(_make_request(), lambda _: msg)
|
||||
|
||||
assert result is msg
|
||||
|
||||
def test_multimodal_content_skipped(self):
|
||||
config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=20)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
content = [{"type": "image", "data": "x" * 100}]
|
||||
msg = ToolMessage(content=content, name="tool", tool_call_id="tc-1")
|
||||
|
||||
result = mw.wrap_tool_call(_make_request(), lambda _: msg)
|
||||
|
||||
assert result is msg
|
||||
|
||||
def test_exactly_at_threshold_passes_through(self):
|
||||
config = ToolOutputConfig(externalize_min_chars=100, fallback_max_chars=100)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
msg = _tm("x" * 100, name="tool")
|
||||
|
||||
result = mw.wrap_tool_call(_make_request(), lambda _: msg)
|
||||
|
||||
assert result is msg
|
||||
|
||||
def test_one_char_over_threshold_triggers(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config = ToolOutputConfig(externalize_min_chars=100)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
msg = _tm("x" * 101, name="tool")
|
||||
req = _make_request(outputs_path=tmpdir)
|
||||
|
||||
result = mw.wrap_tool_call(req, lambda _: msg)
|
||||
|
||||
assert result is not msg
|
||||
|
||||
def test_chinese_content_preserved(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config = ToolOutputConfig(externalize_min_chars=50, preview_head_chars=20, preview_tail_chars=10)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
content = "你好世界" * 50
|
||||
msg = _tm(content, name="tool")
|
||||
req = _make_request(outputs_path=tmpdir)
|
||||
|
||||
result = mw.wrap_tool_call(req, lambda _: msg)
|
||||
|
||||
assert isinstance(result, ToolMessage)
|
||||
# File should contain the full Chinese content
|
||||
storage_dir = os.path.join(tmpdir, ".tool-results")
|
||||
files = os.listdir(storage_dir)
|
||||
with open(os.path.join(storage_dir, files[0]), encoding="utf-8") as f:
|
||||
assert f.read() == content
|
||||
|
||||
def test_no_runtime_state_uses_fallback(self):
|
||||
config = ToolOutputConfig(
|
||||
externalize_min_chars=50,
|
||||
fallback_max_chars=500,
|
||||
fallback_head_chars=100,
|
||||
fallback_tail_chars=50,
|
||||
)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
content = "x" * 1000
|
||||
msg = _tm(content, name="tool")
|
||||
req = SimpleNamespace(
|
||||
tool_call={"name": "tool", "id": "tc-1"},
|
||||
runtime=None,
|
||||
)
|
||||
|
||||
result = mw.wrap_tool_call(req, lambda _: msg)
|
||||
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert "omitted" in result.content
|
||||
assert len(result.content) <= 500
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# MCP content_and_artifact format tests
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestMCPContentAndArtifact:
|
||||
"""MCP tools return content as list of content blocks, not plain strings."""
|
||||
|
||||
def test_text_content_blocks_are_budgeted(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config = ToolOutputConfig(externalize_min_chars=50, preview_head_chars=20, preview_tail_chars=10)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
content = [{"type": "text", "text": "x" * 200}]
|
||||
msg = ToolMessage(content=content, name="mcp_tool", tool_call_id="tc-mcp")
|
||||
req = _make_request(tool_name="mcp_tool", outputs_path=tmpdir)
|
||||
|
||||
result = mw.wrap_tool_call(req, lambda _: msg)
|
||||
|
||||
assert result is not msg
|
||||
assert isinstance(result.content, str)
|
||||
assert "Full mcp_tool output saved to" in result.content
|
||||
assert result.tool_call_id == "tc-mcp"
|
||||
|
||||
def test_multiple_text_blocks_joined_and_budgeted(self):
|
||||
config = ToolOutputConfig(externalize_min_chars=50, fallback_max_chars=500, fallback_head_chars=100, fallback_tail_chars=50)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
content = [{"type": "text", "text": "a" * 300}, {"type": "text", "text": "b" * 300}]
|
||||
msg = ToolMessage(content=content, name="mcp_tool", tool_call_id="tc-mcp2")
|
||||
req = _make_request(tool_name="mcp_tool")
|
||||
|
||||
result = mw.wrap_tool_call(req, lambda _: msg)
|
||||
|
||||
assert result is not msg
|
||||
assert "omitted" in result.content
|
||||
|
||||
def test_image_content_blocks_are_skipped(self):
|
||||
config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=20)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
content = [{"type": "image", "data": "base64data" * 100}]
|
||||
msg = ToolMessage(content=content, name="mcp_tool", tool_call_id="tc-img")
|
||||
req = _make_request(tool_name="mcp_tool")
|
||||
|
||||
result = mw.wrap_tool_call(req, lambda _: msg)
|
||||
|
||||
assert result is msg
|
||||
|
||||
def test_mixed_text_and_image_blocks_are_skipped(self):
|
||||
config = ToolOutputConfig(externalize_min_chars=10)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
content = [{"type": "text", "text": "x" * 100}, {"type": "image", "data": "base64"}]
|
||||
msg = ToolMessage(content=content, name="mcp_tool", tool_call_id="tc-mix")
|
||||
req = _make_request(tool_name="mcp_tool")
|
||||
|
||||
result = mw.wrap_tool_call(req, lambda _: msg)
|
||||
|
||||
assert result is msg
|
||||
|
||||
def test_small_text_blocks_pass_through(self):
|
||||
config = ToolOutputConfig(externalize_min_chars=1000)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
content = [{"type": "text", "text": "small result"}]
|
||||
msg = ToolMessage(content=content, name="mcp_tool", tool_call_id="tc-sm")
|
||||
req = _make_request(tool_name="mcp_tool")
|
||||
|
||||
result = mw.wrap_tool_call(req, lambda _: msg)
|
||||
|
||||
assert result is msg
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Async path tests
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestAsyncPaths:
|
||||
@pytest.mark.anyio
|
||||
async def test_async_tool_call_externalized(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config = ToolOutputConfig(externalize_min_chars=50, preview_head_chars=20, preview_tail_chars=10)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
content = "x" * 200
|
||||
msg = _tm(content, name="async_tool")
|
||||
req = _make_request(tool_name="async_tool", outputs_path=tmpdir)
|
||||
|
||||
async def handler(_):
|
||||
return msg
|
||||
|
||||
result = await mw.awrap_tool_call(req, handler)
|
||||
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result is not msg
|
||||
assert "Full async_tool output saved to" in result.content
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_async_model_call_patches_history(self):
|
||||
config = ToolOutputConfig(fallback_max_chars=500, fallback_head_chars=100, fallback_tail_chars=50)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
oversized = _tm("h" * 1000, name="tool", tool_call_id="tc-h")
|
||||
request = ModelRequest(model=None, messages=[oversized], tools=[], state={})
|
||||
captured: dict[str, ModelRequest] = {}
|
||||
|
||||
async def handler(req):
|
||||
captured["request"] = req
|
||||
return []
|
||||
|
||||
await mw.awrap_model_call(request, handler)
|
||||
|
||||
forwarded = captured["request"]
|
||||
assert forwarded is not request
|
||||
msg = forwarded.messages[0]
|
||||
assert isinstance(msg, ToolMessage)
|
||||
assert "omitted" in msg.content
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# wrap_model_call — historical message patching
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestWrapModelCall:
|
||||
def test_oversized_historical_messages_truncated(self):
|
||||
config = ToolOutputConfig(fallback_max_chars=500, fallback_head_chars=100, fallback_tail_chars=50)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
oversized = _tm("q" * 1000, name="tool", tool_call_id="tc-q")
|
||||
request = ModelRequest(model=None, messages=[oversized], tools=[], state={})
|
||||
captured: dict[str, ModelRequest] = {}
|
||||
|
||||
def handler(req):
|
||||
captured["request"] = req
|
||||
return []
|
||||
|
||||
mw.wrap_model_call(request, handler)
|
||||
|
||||
forwarded = captured["request"]
|
||||
assert forwarded is not request
|
||||
msg = forwarded.messages[0]
|
||||
assert isinstance(msg, ToolMessage)
|
||||
assert "omitted" in msg.content
|
||||
assert len(msg.content) < len(oversized.content) + 150
|
||||
|
||||
def test_small_historical_messages_unchanged(self):
|
||||
config = ToolOutputConfig(fallback_max_chars=1000)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
small = _tm("small", name="tool")
|
||||
request = ModelRequest(model=None, messages=[small], tools=[], state={})
|
||||
captured: dict[str, ModelRequest] = {}
|
||||
|
||||
def handler(req):
|
||||
captured["request"] = req
|
||||
return []
|
||||
|
||||
mw.wrap_model_call(request, handler)
|
||||
|
||||
assert captured["request"] is request
|
||||
|
||||
def test_exempt_tools_in_history_unchanged(self):
|
||||
config = ToolOutputConfig(fallback_max_chars=50)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
read_msg = _tm("x" * 200, name="read_file", tool_call_id="tc-r")
|
||||
request = ModelRequest(model=None, messages=[read_msg], tools=[], state={})
|
||||
captured: dict[str, ModelRequest] = {}
|
||||
|
||||
def handler(req):
|
||||
captured["request"] = req
|
||||
return []
|
||||
|
||||
mw.wrap_model_call(request, handler)
|
||||
|
||||
assert captured["request"] is request
|
||||
|
||||
def test_non_tool_messages_preserved(self):
|
||||
config = ToolOutputConfig(fallback_max_chars=500, fallback_head_chars=100, fallback_tail_chars=50)
|
||||
mw = ToolOutputBudgetMiddleware(config=config)
|
||||
human = HumanMessage(content="x" * 200)
|
||||
ai = AIMessage(content="y" * 200)
|
||||
oversized_tool = _tm("z" * 1000, name="tool")
|
||||
request = ModelRequest(model=None, messages=[human, ai, oversized_tool], tools=[], state={})
|
||||
captured: dict[str, ModelRequest] = {}
|
||||
|
||||
def handler(req):
|
||||
captured["request"] = req
|
||||
return []
|
||||
|
||||
mw.wrap_model_call(request, handler)
|
||||
|
||||
msgs = captured["request"].messages
|
||||
assert msgs[0] is human
|
||||
assert msgs[1] is ai
|
||||
assert isinstance(msgs[2], ToolMessage)
|
||||
assert "omitted" in msgs[2].content
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Config integration
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestFromAppConfig:
|
||||
def test_from_app_config_with_tool_output(self):
|
||||
config = AppConfig(
|
||||
sandbox=SandboxConfig(use="test"),
|
||||
tool_output={"externalize_min_chars": 5000, "preview_head_chars": 500},
|
||||
)
|
||||
mw = ToolOutputBudgetMiddleware.from_app_config(config)
|
||||
assert mw._config.externalize_min_chars == 5000
|
||||
assert mw._config.preview_head_chars == 500
|
||||
|
||||
def test_from_app_config_defaults(self):
|
||||
config = AppConfig(sandbox=SandboxConfig(use="test"))
|
||||
mw = ToolOutputBudgetMiddleware.from_app_config(config)
|
||||
assert mw._config.externalize_min_chars == 12000
|
||||
|
||||
|
||||
class TestPatchModelMessages:
|
||||
def test_returns_none_when_no_changes(self):
|
||||
config = ToolOutputConfig(fallback_max_chars=1000)
|
||||
messages = [_tm("short", name="tool")]
|
||||
assert _patch_model_messages(messages, config) is None
|
||||
|
||||
def test_patches_oversized_messages(self):
|
||||
config = ToolOutputConfig(fallback_max_chars=500, fallback_head_chars=100, fallback_tail_chars=50)
|
||||
messages = [_tm("x" * 1000, name="tool")]
|
||||
result = _patch_model_messages(messages, config)
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert "omitted" in result[0].content
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Pre-scan helpers (_effective_trigger / _tool_message_over_budget / _needs_budget)
|
||||
# These guard the fast-path optimization — a false negative here is a real bug
|
||||
# (budgeting silently skipped), so per-tool overrides must be honored.
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestPreScanHelpers:
|
||||
def test_effective_trigger_uses_global_externalize(self):
|
||||
config = ToolOutputConfig(externalize_min_chars=12000, fallback_max_chars=30000)
|
||||
# smallest of the two thresholds wins
|
||||
assert _effective_trigger("any_tool", config) == 12000
|
||||
|
||||
def test_effective_trigger_respects_per_tool_override(self):
|
||||
config = ToolOutputConfig(externalize_min_chars=50000, fallback_max_chars=0, tool_overrides={"sensitive": 100})
|
||||
assert _effective_trigger("sensitive", config) == 100
|
||||
# other tools fall back to the (high) global
|
||||
assert _effective_trigger("other", config) == 50000
|
||||
|
||||
def test_effective_trigger_per_tool_zero_falls_to_fallback(self):
|
||||
config = ToolOutputConfig(externalize_min_chars=50, tool_overrides={"bash": 0}, fallback_max_chars=200)
|
||||
# externalize disabled for bash → only fallback can trigger
|
||||
assert _effective_trigger("bash", config) == 200
|
||||
|
||||
def test_effective_trigger_returns_negative_when_fully_disabled(self):
|
||||
config = ToolOutputConfig(externalize_min_chars=0, fallback_max_chars=0)
|
||||
assert _effective_trigger("any", config) == -1
|
||||
|
||||
def test_pre_scan_does_not_short_circuit_per_tool_override(self):
|
||||
"""Regression: pre-scan must honor per-tool overrides, not just global threshold."""
|
||||
config = ToolOutputConfig(externalize_min_chars=50000, fallback_max_chars=0, tool_overrides={"sensitive": 100})
|
||||
msg = _tm("x" * 500, name="sensitive")
|
||||
# 500 < global 50000 but > per-tool 100 → must still be flagged
|
||||
assert _tool_message_over_budget(msg, config) is True
|
||||
assert _needs_budget(msg, config) is True
|
||||
|
||||
def test_exempt_tool_never_over_budget(self):
|
||||
config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=20, exempt_tools=["read_file"])
|
||||
msg = _tm("x" * 1000, name="read_file")
|
||||
assert _tool_message_over_budget(msg, config) is False
|
||||
|
||||
def test_model_call_pre_scan_skips_when_nothing_oversized(self):
|
||||
"""_patch_model_messages returns None (no list rebuild) when all messages are small."""
|
||||
config = ToolOutputConfig(externalize_min_chars=12000, fallback_max_chars=30000)
|
||||
messages = [_tm("small", name="tool"), HumanMessage(content="hi"), _tm("also small", name="bash")]
|
||||
assert _patch_model_messages(messages, config) is None
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Middleware ordering in the chain
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestMiddlewareChainIntegration:
|
||||
def test_budget_middleware_is_first_in_chain(self):
|
||||
from deerflow.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares
|
||||
|
||||
app_config = AppConfig(sandbox=SandboxConfig(use="test"))
|
||||
middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False)
|
||||
|
||||
assert isinstance(middlewares[0], ToolOutputBudgetMiddleware)
|
||||
|
||||
def test_budget_middleware_in_lead_chain(self):
|
||||
from deerflow.agents.middlewares.tool_error_handling_middleware import build_lead_runtime_middlewares
|
||||
|
||||
app_config = AppConfig(sandbox=SandboxConfig(use="test"))
|
||||
middlewares = build_lead_runtime_middlewares(app_config=app_config, lazy_init=False)
|
||||
|
||||
assert isinstance(middlewares[0], ToolOutputBudgetMiddleware)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Config version bump
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestConfigVersion:
|
||||
def test_config_version_bumped(self):
|
||||
import yaml
|
||||
|
||||
example_path = os.path.join(os.path.dirname(__file__), "..", "..", "config.example.yaml")
|
||||
if os.path.exists(example_path):
|
||||
with open(example_path, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
assert data.get("config_version", 0) >= 11
|
||||
|
||||
def test_config_example_has_tool_output_section(self):
|
||||
import yaml
|
||||
|
||||
example_path = os.path.join(os.path.dirname(__file__), "..", "..", "config.example.yaml")
|
||||
if os.path.exists(example_path):
|
||||
with open(example_path, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
assert "tool_output" in data
|
||||
tool_output = data["tool_output"]
|
||||
assert tool_output["enabled"] is True
|
||||
assert tool_output["externalize_min_chars"] == 12000
|
||||
assert "read_file" in tool_output["exempt_tools"]
|
||||
@@ -1,177 +0,0 @@
|
||||
"""Regression tests for issue #3265.
|
||||
|
||||
The non-streaming ``/wait`` endpoints used to ``await record.task`` with no
|
||||
disconnect handling and silently swallow ``CancelledError``. When a long
|
||||
tool call (e.g. ``pip install`` inside a custom skill) kept the connection
|
||||
idle long enough for an intermediate HTTP layer to time out, the handler
|
||||
would return a stale checkpoint that looked like a normal completion.
|
||||
|
||||
The fix introduces ``wait_for_run_completion`` in ``app.gateway.services``:
|
||||
it subscribes to the stream bridge until ``END_SENTINEL``, polls
|
||||
``request.is_disconnected()`` on every wake-up, and honours the record's
|
||||
``on_disconnect`` mode by cancelling the background run on real client
|
||||
disconnect.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from deerflow.runtime import RunManager, RunStatus
|
||||
from deerflow.runtime.runs.schemas import DisconnectMode
|
||||
from deerflow.runtime.stream_bridge.memory import MemoryStreamBridge
|
||||
|
||||
THREAD_ID = "thread-wait-3265"
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeRequest:
|
||||
"""Minimal stand-in for FastAPI ``Request`` with controllable disconnect.
|
||||
|
||||
``is_disconnected`` is awaited each iteration of the helper's loop, so the
|
||||
counter lets a test transition from "still connected" to "disconnected"
|
||||
after N polls without racing the event loop.
|
||||
"""
|
||||
|
||||
disconnect_after: int = 10**9 # effectively "never" by default
|
||||
_polls: int = 0
|
||||
|
||||
async def is_disconnected(self) -> bool:
|
||||
self._polls += 1
|
||||
return self._polls > self.disconnect_after
|
||||
|
||||
|
||||
async def _create_running_record(mgr: RunManager, *, on_disconnect: DisconnectMode) -> Any:
|
||||
record = await mgr.create_or_reject(
|
||||
THREAD_ID,
|
||||
assistant_id=None,
|
||||
on_disconnect=on_disconnect,
|
||||
)
|
||||
await mgr.set_status(record.run_id, RunStatus.running)
|
||||
return record
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper-level unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWaitForRunCompletion:
|
||||
def test_returns_when_run_publishes_end(self) -> None:
|
||||
"""Happy path: helper returns once the bridge publishes END_SENTINEL."""
|
||||
from app.gateway.services import wait_for_run_completion
|
||||
|
||||
async def run() -> None:
|
||||
mgr = RunManager()
|
||||
bridge = MemoryStreamBridge()
|
||||
record = await _create_running_record(mgr, on_disconnect=DisconnectMode.cancel)
|
||||
request = _FakeRequest()
|
||||
|
||||
async def finish_soon() -> None:
|
||||
await asyncio.sleep(0)
|
||||
await bridge.publish(record.run_id, "values", {"messages": []})
|
||||
await mgr.set_status(record.run_id, RunStatus.success)
|
||||
await bridge.publish_end(record.run_id)
|
||||
|
||||
asyncio.create_task(finish_soon())
|
||||
completed = await asyncio.wait_for(
|
||||
wait_for_run_completion(bridge, record, request, mgr),
|
||||
timeout=2.0,
|
||||
)
|
||||
assert completed is True
|
||||
assert record.status == RunStatus.success
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
def test_cancels_run_on_disconnect_when_cancel_mode(self) -> None:
|
||||
"""on_disconnect=cancel: real disconnect must call run_mgr.cancel()."""
|
||||
from app.gateway.services import wait_for_run_completion
|
||||
|
||||
async def run() -> None:
|
||||
mgr = RunManager()
|
||||
bridge = MemoryStreamBridge()
|
||||
record = await _create_running_record(mgr, on_disconnect=DisconnectMode.cancel)
|
||||
# Attach a real (idle) task so cancel() actually has something to cancel.
|
||||
sleeper = asyncio.create_task(asyncio.sleep(30))
|
||||
record.task = sleeper
|
||||
request = _FakeRequest(disconnect_after=0) # disconnected on first poll
|
||||
|
||||
async def publish_until_cancel() -> None:
|
||||
# Emit one event so subscribe wakes up immediately; helper polls
|
||||
# is_disconnected after each yield.
|
||||
await asyncio.sleep(0)
|
||||
await bridge.publish(record.run_id, "values", {"step": 1})
|
||||
|
||||
asyncio.create_task(publish_until_cancel())
|
||||
completed = await asyncio.wait_for(
|
||||
wait_for_run_completion(bridge, record, request, mgr),
|
||||
timeout=2.0,
|
||||
)
|
||||
|
||||
assert completed is False
|
||||
assert record.status == RunStatus.interrupted
|
||||
# Drain the cancelled sleeper so it does not linger past the test.
|
||||
try:
|
||||
await asyncio.wait_for(sleeper, timeout=1.0)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
assert sleeper.done()
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
def test_does_not_cancel_when_continue_mode(self) -> None:
|
||||
"""on_disconnect=continue: disconnect must NOT cancel the run."""
|
||||
from app.gateway.services import wait_for_run_completion
|
||||
|
||||
async def run() -> None:
|
||||
mgr = RunManager()
|
||||
bridge = MemoryStreamBridge()
|
||||
record = await _create_running_record(mgr, on_disconnect=DisconnectMode.continue_)
|
||||
sleeper = asyncio.create_task(asyncio.sleep(30))
|
||||
record.task = sleeper
|
||||
request = _FakeRequest(disconnect_after=0)
|
||||
|
||||
async def publish_then_end() -> None:
|
||||
await asyncio.sleep(0)
|
||||
await bridge.publish(record.run_id, "values", {"step": 1})
|
||||
|
||||
asyncio.create_task(publish_then_end())
|
||||
completed = await asyncio.wait_for(
|
||||
wait_for_run_completion(bridge, record, request, mgr),
|
||||
timeout=2.0,
|
||||
)
|
||||
|
||||
# Disconnected before END — helper still reports incomplete so the
|
||||
# caller skips checkpoint serialization, but the run keeps going.
|
||||
assert completed is False
|
||||
assert record.status == RunStatus.running
|
||||
sleeper.cancel()
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
def test_no_cancel_when_run_already_finished(self) -> None:
|
||||
"""If the run ended (END_SENTINEL) before disconnect is observed, the
|
||||
finally block must not call cancel — the run is already terminal."""
|
||||
from app.gateway.services import wait_for_run_completion
|
||||
|
||||
async def run() -> None:
|
||||
mgr = RunManager()
|
||||
bridge = MemoryStreamBridge()
|
||||
record = await _create_running_record(mgr, on_disconnect=DisconnectMode.cancel)
|
||||
# Publish END before subscribe — helper should see ended=True first
|
||||
# poll and return without ever observing the "disconnect".
|
||||
await mgr.set_status(record.run_id, RunStatus.success)
|
||||
await bridge.publish_end(record.run_id)
|
||||
request = _FakeRequest(disconnect_after=0)
|
||||
|
||||
completed = await asyncio.wait_for(
|
||||
wait_for_run_completion(bridge, record, request, mgr),
|
||||
timeout=2.0,
|
||||
)
|
||||
|
||||
assert completed is True
|
||||
assert record.status == RunStatus.success
|
||||
|
||||
asyncio.run(run())
|
||||
+6
-61
@@ -15,7 +15,7 @@
|
||||
# ============================================================================
|
||||
# Bump this number when the config schema changes.
|
||||
# Run `make config-upgrade` to merge new fields into your local config.yaml.
|
||||
config_version: 11
|
||||
config_version: 10
|
||||
|
||||
# ============================================================================
|
||||
# Logging
|
||||
@@ -177,38 +177,6 @@ models:
|
||||
# thinking:
|
||||
# type: disabled
|
||||
|
||||
# Example: Xiaomi MiMo model (with thinking support)
|
||||
# MiMo thinking mode returns reasoning_content and requires that field to be
|
||||
# replayed on historical assistant messages in multi-turn agent/tool-call
|
||||
# conversations. Use PatchedChatMiMo instead of plain ChatOpenAI.
|
||||
# Use https://api.xiaomimimo.com/v1 with pay-as-you-go `sk-...` keys.
|
||||
# Use your Token Plan regional URL (for example
|
||||
# https://token-plan-cn.xiaomimimo.com/v1) with Token Plan `tp-...` keys.
|
||||
# PatchedChatMiMo is model-id agnostic; use it for every MiMo thinking model
|
||||
# entry you configure (for example mimo-v2.5-pro, mimo-v2.5, mimo-v2-pro,
|
||||
# mimo-v2-omni, or mimo-v2-flash), including models referenced by subagent
|
||||
# model overrides.
|
||||
# See: https://platform.xiaomimimo.com/docs/en-US/usage-guide/passing-back-reasoning_content
|
||||
# - name: mimo-v2.5-pro
|
||||
# display_name: MiMo V2.5 Pro
|
||||
# use: deerflow.models.patched_mimo:PatchedChatMiMo
|
||||
# model: mimo-v2.5-pro
|
||||
# api_key: $MIMO_API_KEY
|
||||
# base_url: https://api.xiaomimimo.com/v1
|
||||
# request_timeout: 600.0
|
||||
# max_retries: 2
|
||||
# max_tokens: 8192
|
||||
# supports_thinking: true
|
||||
# supports_vision: false
|
||||
# when_thinking_enabled:
|
||||
# extra_body:
|
||||
# thinking:
|
||||
# type: enabled
|
||||
# when_thinking_disabled:
|
||||
# extra_body:
|
||||
# thinking:
|
||||
# type: disabled
|
||||
|
||||
# Example: DeepSeek model (with thinking support)
|
||||
# - name: deepseek-v3
|
||||
# display_name: DeepSeek V3 (Thinking)
|
||||
@@ -544,34 +512,6 @@ tools:
|
||||
tool_search:
|
||||
enabled: false
|
||||
|
||||
# ============================================================================
|
||||
# Tool Output Budget Protection
|
||||
# ============================================================================
|
||||
# Prevents oversized tool results from blowing the model context window.
|
||||
# Outputs exceeding `externalize_min_chars` are persisted to disk and replaced
|
||||
# with a compact preview + file reference. The model can read the full output
|
||||
# via read_file. When disk persistence is unavailable, outputs exceeding
|
||||
# `fallback_max_chars` are head+tail truncated instead.
|
||||
#
|
||||
# `exempt_tools` prevents persist→read→persist infinite loops for read tools.
|
||||
# `tool_overrides` allows per-tool threshold customization.
|
||||
|
||||
tool_output:
|
||||
enabled: true
|
||||
externalize_min_chars: 12000
|
||||
preview_head_chars: 2000
|
||||
preview_tail_chars: 1000
|
||||
fallback_max_chars: 30000
|
||||
fallback_head_chars: 8000
|
||||
fallback_tail_chars: 3000
|
||||
storage_subdir: ".tool-results"
|
||||
exempt_tools:
|
||||
- read_file
|
||||
- read_file_tool
|
||||
# tool_overrides:
|
||||
# web_search: 8000
|
||||
# bash: 20000
|
||||
|
||||
# ============================================================================
|
||||
# Loop Detection Configuration
|
||||
# ============================================================================
|
||||
@@ -702,6 +642,11 @@ sandbox:
|
||||
# # Optional: Prefix for container names (default: deer-flow-sandbox)
|
||||
# # container_prefix: deer-flow-sandbox
|
||||
#
|
||||
# # Optional: Automatically restart crashed sandbox containers (default: true)
|
||||
# # When enabled, a dead container is detected on the next tool call and
|
||||
# # transparently replaced with a fresh one. Set to false to disable.
|
||||
# # auto_restart: true
|
||||
#
|
||||
# # Optional: Additional mount directories from host to container
|
||||
# # NOTE: Skills directory is automatically mounted from skills.path to skills.container_path
|
||||
# # mounts:
|
||||
|
||||
@@ -10,11 +10,10 @@
|
||||
# should be updated accordingly.
|
||||
|
||||
# Backend API URLs (optional)
|
||||
# Leave these commented out to use the default nginx proxy (recommended for `make dev`).
|
||||
# Only set these if you need to connect to the Gateway service directly.
|
||||
# For split-origin browser access, also configure GATEWAY_CORS_ORIGINS.
|
||||
# Leave these commented out to use the default nginx proxy (recommended for `make dev`)
|
||||
# Only set these if you need to connect to backend services directly
|
||||
# NEXT_PUBLIC_BACKEND_BASE_URL="http://localhost:8001"
|
||||
# NEXT_PUBLIC_LANGGRAPH_BASE_URL="http://localhost:8001/api"
|
||||
# NEXT_PUBLIC_LANGGRAPH_BASE_URL="http://localhost:2024"
|
||||
|
||||
# Server-only Gateway wiring used by SSR (auth checks, /api/* rewrites).
|
||||
# Defaults to localhost — only override for non-local deployments.
|
||||
|
||||
+1
-5
@@ -88,11 +88,7 @@ Backend API URLs are optional; an nginx proxy is used by default:
|
||||
|
||||
```
|
||||
NEXT_PUBLIC_BACKEND_BASE_URL=http://localhost:8001
|
||||
NEXT_PUBLIC_LANGGRAPH_BASE_URL=http://localhost:8001/api
|
||||
NEXT_PUBLIC_LANGGRAPH_BASE_URL=http://localhost:2024
|
||||
```
|
||||
|
||||
Leave these unset for the standard `make dev` / Docker flow, where nginx serves
|
||||
the public `/api/langgraph/*` prefix and rewrites it to Gateway's native `/api/*`
|
||||
routes.
|
||||
|
||||
Requires Node.js 22+ and pnpm 10.26.2+.
|
||||
|
||||
Generated
+17
-17
@@ -299,16 +299,16 @@ packages:
|
||||
'@antfu/install-pkg@1.1.0':
|
||||
resolution: {integrity: sha512-MGQsmw10ZyI+EJo45CdSER4zEb+p31LpDAFp2Z3gkSd1yqVZGi0Ebx++YTEMonJy4oChEMLsxZ64j8FH6sSqtQ==}
|
||||
|
||||
'@babel/helper-string-parser@7.29.7':
|
||||
resolution: {integrity: sha512-Pb5ijPrZ89GDH8223L4UP8i6QApWxs04RbPQJTeWDV0/keR2E36MeKnyr6LYmUUvqRRI+Iv87SuF1W6ErINzYw==}
|
||||
'@babel/helper-string-parser@7.27.1':
|
||||
resolution: {integrity: sha512-qMlSxKbpRlAridDExk92nSobyDdpPijUq2DW6oDnUqd0iOGxmQjyqhMIihI9+zv4LPyZdRje2cavWPbCbWm3eA==}
|
||||
engines: {node: '>=6.9.0'}
|
||||
|
||||
'@babel/helper-validator-identifier@7.29.7':
|
||||
resolution: {integrity: sha512-qehxGkRj55h/ff8EMaJ+cYhyaKlHIxqYDn682wQD7RNp9UujOQsHog2uS0r2vzr4pW+sXf90NeeayjcNaX3fFg==}
|
||||
'@babel/helper-validator-identifier@7.28.5':
|
||||
resolution: {integrity: sha512-qSs4ifwzKJSV39ucNjsvc6WVHs6b7S03sOh2OcHF9UHfVPqWWALUsNUVzhSBiItjRZoLHx7nIarVjqKVusUZ1Q==}
|
||||
engines: {node: '>=6.9.0'}
|
||||
|
||||
'@babel/parser@7.29.7':
|
||||
resolution: {integrity: sha512-hnORnjP/1P/zFEndoeX+n+t1RwWRJiJpM/jO7FW32Kn9r5+sJB2JWOdYo4L6k78j15eCwY3Gm/7364B1EMwtNg==}
|
||||
'@babel/parser@7.29.3':
|
||||
resolution: {integrity: sha512-b3ctpQwp+PROvU/cttc4OYl4MzfJUWy6FZg+PMXfzmt/+39iHVF0sDfqay8TQM3JA2EUOyKcFZt75jWriQijsA==}
|
||||
engines: {node: '>=6.0.0'}
|
||||
hasBin: true
|
||||
|
||||
@@ -316,8 +316,8 @@ packages:
|
||||
resolution: {integrity: sha512-05WQkdpL9COIMz4LjTxGpPNCdlpyimKppYNoJ5Di5EUObifl8t4tuLuUBBZEpoLYOmfvIWrsp9fCl0HoPRVTdA==}
|
||||
engines: {node: '>=6.9.0'}
|
||||
|
||||
'@babel/types@7.29.7':
|
||||
resolution: {integrity: sha512-4zBIxpPzowiZpusoFkyGVwakdRJUyuH5PxQ/PrqghfdFWWasvnCdPfQXHrenDai+gyLARulZjZowCOj6fjT4pA==}
|
||||
'@babel/types@7.29.0':
|
||||
resolution: {integrity: sha512-LwdZHpScM4Qz8Xw2iKSzS+cfglZzJGvofQICy7W7v4caru4EaAmyUuO6BGrbyQ2mYV11W0U8j5mBhd14dd3B0A==}
|
||||
engines: {node: '>=6.9.0'}
|
||||
|
||||
'@braintree/sanitize-url@7.1.2':
|
||||
@@ -5777,20 +5777,20 @@ snapshots:
|
||||
package-manager-detector: 1.6.0
|
||||
tinyexec: 1.0.2
|
||||
|
||||
'@babel/helper-string-parser@7.29.7': {}
|
||||
'@babel/helper-string-parser@7.27.1': {}
|
||||
|
||||
'@babel/helper-validator-identifier@7.29.7': {}
|
||||
'@babel/helper-validator-identifier@7.28.5': {}
|
||||
|
||||
'@babel/parser@7.29.7':
|
||||
'@babel/parser@7.29.3':
|
||||
dependencies:
|
||||
'@babel/types': 7.29.7
|
||||
'@babel/types': 7.29.0
|
||||
|
||||
'@babel/runtime@7.28.6': {}
|
||||
|
||||
'@babel/types@7.29.7':
|
||||
'@babel/types@7.29.0':
|
||||
dependencies:
|
||||
'@babel/helper-string-parser': 7.29.7
|
||||
'@babel/helper-validator-identifier': 7.29.7
|
||||
'@babel/helper-string-parser': 7.27.1
|
||||
'@babel/helper-validator-identifier': 7.28.5
|
||||
|
||||
'@braintree/sanitize-url@7.1.2': {}
|
||||
|
||||
@@ -8047,7 +8047,7 @@ snapshots:
|
||||
|
||||
'@vue/compiler-core@3.5.28':
|
||||
dependencies:
|
||||
'@babel/parser': 7.29.7
|
||||
'@babel/parser': 7.29.3
|
||||
'@vue/shared': 3.5.28
|
||||
entities: 7.0.1
|
||||
estree-walker: 2.0.2
|
||||
@@ -8060,7 +8060,7 @@ snapshots:
|
||||
|
||||
'@vue/compiler-sfc@3.5.28':
|
||||
dependencies:
|
||||
'@babel/parser': 7.29.7
|
||||
'@babel/parser': 7.29.3
|
||||
'@vue/compiler-core': 3.5.28
|
||||
'@vue/compiler-dom': 3.5.28
|
||||
'@vue/compiler-ssr': 3.5.28
|
||||
|
||||
@@ -146,27 +146,6 @@ export default function NewAgentPage() {
|
||||
err.reason === "backend_unreachable"
|
||||
) {
|
||||
setNameError(t.agents.nameStepNetworkError);
|
||||
} else if (
|
||||
err instanceof AgentNameCheckError &&
|
||||
err.reason === "request_failed"
|
||||
) {
|
||||
// Surface the backend-provided detail (e.g. validation error) when
|
||||
// one is present, wrapped in a localised prefix so zh-CN users
|
||||
// don't see a bare English string next to the surrounding Chinese
|
||||
// UI. Falls back to the generic localised fallback when the backend
|
||||
// sent no detail — `err.message` is unreliable for this branch
|
||||
// because `checkAgentName` substitutes a generated fallback string
|
||||
// ("Failed to check agent name: ${statusText}") when `detail` is
|
||||
// missing, so testing `err.message` would always be truthy and the
|
||||
// generated fallback would leak through.
|
||||
setNameError(
|
||||
err.detail
|
||||
? t.agents.nameStepCheckErrorWithDetail.replace(
|
||||
"{detail}",
|
||||
err.detail,
|
||||
)
|
||||
: t.agents.nameStepCheckError,
|
||||
);
|
||||
} else {
|
||||
setNameError(t.agents.nameStepCheckError);
|
||||
}
|
||||
@@ -193,7 +172,6 @@ export default function NewAgentPage() {
|
||||
t.agents.nameStepNetworkError,
|
||||
t.agents.nameStepBootstrapMessage,
|
||||
t.agents.nameStepCheckError,
|
||||
t.agents.nameStepCheckErrorWithDetail,
|
||||
t.agents.nameStepInvalidError,
|
||||
threadId,
|
||||
]);
|
||||
|
||||
@@ -7,10 +7,7 @@ import {
|
||||
MessageResponse,
|
||||
type MessageResponseProps,
|
||||
} from "@/components/ai-elements/message";
|
||||
import {
|
||||
preprocessStreamdownMarkdown,
|
||||
streamdownPlugins,
|
||||
} from "@/core/streamdown";
|
||||
import { streamdownPlugins } from "@/core/streamdown";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
import { CitationLink } from "../citations/citation-link";
|
||||
@@ -36,10 +33,6 @@ export function MarkdownContent({
|
||||
remarkPlugins = streamdownPlugins.remarkPlugins,
|
||||
components: componentsFromProps,
|
||||
}: MarkdownContentProps) {
|
||||
const normalizedContent = useMemo(
|
||||
() => preprocessStreamdownMarkdown(content),
|
||||
[content],
|
||||
);
|
||||
const components = useMemo(() => {
|
||||
return {
|
||||
a: (props: AnchorHTMLAttributes<HTMLAnchorElement>) => {
|
||||
@@ -77,7 +70,7 @@ export function MarkdownContent({
|
||||
rehypePlugins={rehypePlugins}
|
||||
components={components}
|
||||
>
|
||||
{normalizedContent}
|
||||
{content}
|
||||
</MessageResponse>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -9,15 +9,6 @@ export class AgentNameCheckError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public readonly reason: "backend_unreachable" | "request_failed",
|
||||
/**
|
||||
* Raw backend `detail` string when the failure came from a backend
|
||||
* response carrying one. `null` when no detail was provided (e.g.
|
||||
* network-layer failure, empty response body, unparseable body) — in
|
||||
* which case `message` is a generated fallback like "Failed to check
|
||||
* agent name: Bad Gateway" and the UI should prefer its own localized
|
||||
* fallback instead of surfacing the generated string.
|
||||
*/
|
||||
public readonly detail: string | null = null,
|
||||
) {
|
||||
super(message);
|
||||
this.name = "AgentNameCheckError";
|
||||
@@ -113,11 +104,9 @@ export async function checkAgentName(
|
||||
"backend_unreachable",
|
||||
);
|
||||
}
|
||||
const backendDetail = typeof err.detail === "string" ? err.detail : null;
|
||||
throw new AgentNameCheckError(
|
||||
backendDetail ?? `Failed to check agent name: ${res.statusText}`,
|
||||
err.detail ?? `Failed to check agent name: ${res.statusText}`,
|
||||
"request_failed",
|
||||
backendDetail,
|
||||
);
|
||||
}
|
||||
return res.json() as Promise<{ available: boolean; name: string }>;
|
||||
|
||||
@@ -38,47 +38,6 @@ function injectCsrfHeader(_url: URL, init: RequestInit): RequestInit {
|
||||
return { ...init, headers };
|
||||
}
|
||||
|
||||
export function isInactiveRunStreamError(error: unknown): boolean {
|
||||
const status =
|
||||
typeof error === "object" && error !== null
|
||||
? Reflect.get(error, "status")
|
||||
: undefined;
|
||||
const message =
|
||||
typeof error === "string"
|
||||
? error
|
||||
: error instanceof Error
|
||||
? error.message
|
||||
: typeof error === "object" && error !== null
|
||||
? String(Reflect.get(error, "message") ?? "")
|
||||
: "";
|
||||
|
||||
// Match the gateway's store-only run response in
|
||||
// backend/app/gateway/routers/thread_runs.py until the API exposes a
|
||||
// structured error code for inactive run streams.
|
||||
return (
|
||||
(status === 409 || message.includes("HTTP 409")) &&
|
||||
message.includes("not active on this worker") &&
|
||||
message.includes("cannot be streamed")
|
||||
);
|
||||
}
|
||||
|
||||
export function clearReconnectRun(
|
||||
threadId: string | null | undefined,
|
||||
runId: string,
|
||||
): void {
|
||||
if (typeof window === "undefined" || !threadId) return;
|
||||
|
||||
const key = `lg:stream:${threadId}`;
|
||||
try {
|
||||
const storage = window.sessionStorage;
|
||||
if (storage.getItem(key) === runId) {
|
||||
storage.removeItem(key);
|
||||
}
|
||||
} catch {
|
||||
// Ignore storage access failures so reconnect cleanup never throws.
|
||||
}
|
||||
}
|
||||
|
||||
function createCompatibleClient(isMock?: boolean): LangGraphClient {
|
||||
if (isStaticWebsiteOnly() && !isMock) {
|
||||
return createStaticClient();
|
||||
@@ -100,21 +59,12 @@ function createCompatibleClient(isMock?: boolean): LangGraphClient {
|
||||
)) as typeof client.runs.stream;
|
||||
|
||||
const originalJoinStream = client.runs.joinStream.bind(client.runs);
|
||||
client.runs.joinStream = async function* (threadId, runId, options) {
|
||||
try {
|
||||
yield* originalJoinStream(
|
||||
threadId,
|
||||
runId,
|
||||
sanitizeRunStreamOptions(options),
|
||||
);
|
||||
} catch (error) {
|
||||
if (isInactiveRunStreamError(error)) {
|
||||
clearReconnectRun(threadId, runId);
|
||||
return;
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
} as typeof client.runs.joinStream;
|
||||
client.runs.joinStream = ((threadId, runId, options) =>
|
||||
originalJoinStream(
|
||||
threadId,
|
||||
runId,
|
||||
sanitizeRunStreamOptions(options),
|
||||
)) as typeof client.runs.joinStream;
|
||||
|
||||
return client;
|
||||
}
|
||||
|
||||
@@ -204,7 +204,6 @@ export const enUS: Translations = {
|
||||
nameStepNetworkError:
|
||||
"Network request failed — check your network or backend connection",
|
||||
nameStepCheckError: "Could not verify name availability — please try again",
|
||||
nameStepCheckErrorWithDetail: "Name check failed: {detail}",
|
||||
nameStepApiDisabledError:
|
||||
"Custom agent management is not enabled on this server. Please contact your administrator.",
|
||||
nameStepBootstrapMessage:
|
||||
|
||||
@@ -141,7 +141,6 @@ export interface Translations {
|
||||
nameStepAlreadyExistsError: string;
|
||||
nameStepNetworkError: string;
|
||||
nameStepCheckError: string;
|
||||
nameStepCheckErrorWithDetail: string;
|
||||
nameStepApiDisabledError: string;
|
||||
nameStepBootstrapMessage: string;
|
||||
save: string;
|
||||
|
||||
@@ -192,7 +192,6 @@ export const zhCN: Translations = {
|
||||
nameStepAlreadyExistsError: "已存在同名智能体",
|
||||
nameStepNetworkError: "网络请求失败,请检查网络或后端连接",
|
||||
nameStepCheckError: "无法验证名称可用性,请稍后重试",
|
||||
nameStepCheckErrorWithDetail: "名称校验失败:{detail}",
|
||||
nameStepApiDisabledError:
|
||||
"服务器未开启自定义智能体管理功能,请联系管理员。",
|
||||
nameStepBootstrapMessage:
|
||||
|
||||
@@ -1,3 +1 @@
|
||||
export * from "./mermaid";
|
||||
export * from "./preprocess";
|
||||
export * from "./plugins";
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
const MERMAID_OPENING_FENCE_RE =
|
||||
/^[ \t]{0,3}(`{3,}|~{3,})[ \t]*mermaid(?:[ \t].*)?$/i;
|
||||
|
||||
const WINDOWS_LINE_ENDING_RE = /\r\n?/g;
|
||||
|
||||
const LABELLED_DOTTED_ARROW_RE =
|
||||
/^(\s*)(.+?)\s*--\s*("[^"\n]+"|'[^'\n]+')\s*-\.->\s*(.+?)\s*$/;
|
||||
|
||||
function normalizeMermaidCode(code: string): string {
|
||||
return code
|
||||
.split("\n")
|
||||
.map((line) =>
|
||||
line.replace(
|
||||
LABELLED_DOTTED_ARROW_RE,
|
||||
(
|
||||
_match,
|
||||
indent: string,
|
||||
source: string,
|
||||
label: string,
|
||||
target: string,
|
||||
) => `${indent}${source} -. ${label} .-> ${target}`,
|
||||
),
|
||||
)
|
||||
.join("\n");
|
||||
}
|
||||
|
||||
function isClosingFence(line: string, fence: string): boolean {
|
||||
const trimmedLine = line.trimEnd();
|
||||
const indentationLength = trimmedLine.length - trimmedLine.trimStart().length;
|
||||
const fenceMarker = trimmedLine.slice(indentationLength);
|
||||
const fenceChar = fence.charAt(0);
|
||||
|
||||
if (indentationLength > 3 || !fenceMarker.startsWith(fenceChar)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return (
|
||||
fenceMarker.length >= fence.length &&
|
||||
[...fenceMarker].every((char) => char === fenceChar)
|
||||
);
|
||||
}
|
||||
|
||||
export function normalizeMermaidMarkdown(markdown: string): string {
|
||||
const lines = markdown.replace(WINDOWS_LINE_ENDING_RE, "\n").split("\n");
|
||||
const normalizedLines: string[] = [];
|
||||
|
||||
for (let index = 0; index < lines.length; index += 1) {
|
||||
const line = lines[index]!;
|
||||
|
||||
const openingFenceMatch = MERMAID_OPENING_FENCE_RE.exec(line);
|
||||
|
||||
if (!openingFenceMatch) {
|
||||
normalizedLines.push(line);
|
||||
continue;
|
||||
}
|
||||
|
||||
const openingFence = openingFenceMatch[1];
|
||||
|
||||
if (openingFence === undefined) {
|
||||
normalizedLines.push(line);
|
||||
continue;
|
||||
}
|
||||
|
||||
const codeLines: string[] = [];
|
||||
let closingLine: string | undefined;
|
||||
let cursor = index + 1;
|
||||
|
||||
for (; cursor < lines.length; cursor += 1) {
|
||||
const candidateLine = lines[cursor]!;
|
||||
|
||||
if (isClosingFence(candidateLine, openingFence)) {
|
||||
closingLine = candidateLine;
|
||||
break;
|
||||
}
|
||||
|
||||
codeLines.push(candidateLine);
|
||||
}
|
||||
|
||||
if (closingLine === undefined) {
|
||||
normalizedLines.push(line, ...codeLines);
|
||||
index = cursor - 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
normalizedLines.push(line);
|
||||
|
||||
if (codeLines.length > 0) {
|
||||
normalizedLines.push(
|
||||
...normalizeMermaidCode(codeLines.join("\n")).split("\n"),
|
||||
);
|
||||
}
|
||||
|
||||
normalizedLines.push(closingLine);
|
||||
index = cursor;
|
||||
}
|
||||
|
||||
return normalizedLines.join("\n");
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
import { normalizeMermaidMarkdown } from "./mermaid";
|
||||
|
||||
const MERMAID_BLOCK_HINT_RE = /mermaid/i;
|
||||
|
||||
export function preprocessStreamdownMarkdown(markdown: string): string {
|
||||
if (!MERMAID_BLOCK_HINT_RE.test(markdown) || !markdown.includes("-.->")) {
|
||||
return markdown;
|
||||
}
|
||||
|
||||
return normalizeMermaidMarkdown(markdown);
|
||||
}
|
||||
@@ -1,12 +1,7 @@
|
||||
import type { AIMessage, Message, Run } from "@langchain/langgraph-sdk";
|
||||
import type { ThreadsClient } from "@langchain/langgraph-sdk/client";
|
||||
import { useStream } from "@langchain/langgraph-sdk/react";
|
||||
import {
|
||||
type QueryClient,
|
||||
useMutation,
|
||||
useQuery,
|
||||
useQueryClient,
|
||||
} from "@tanstack/react-query";
|
||||
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
|
||||
@@ -16,7 +11,6 @@ import { getAPIClient } from "../api";
|
||||
import { fetch } from "../api/fetcher";
|
||||
import { getBackendBaseURL } from "../config";
|
||||
import { useI18n } from "../i18n/hooks";
|
||||
import { isHiddenFromUIMessage } from "../messages/utils";
|
||||
import type { FileInMessage } from "../messages/utils";
|
||||
import type { LocalSettings } from "../settings";
|
||||
import { useUpdateSubtask } from "../tasks/context";
|
||||
@@ -55,11 +49,6 @@ function isNonEmptyString(value: string | undefined): value is string {
|
||||
return typeof value === "string" && value.length > 0;
|
||||
}
|
||||
|
||||
const SUMMARIZATION_MIDDLEWARE_UPDATE_KEYS = new Set([
|
||||
"SummarizationMiddleware.before_model",
|
||||
"DeerFlowSummarizationMiddleware.before_model",
|
||||
]);
|
||||
|
||||
function messageIdentity(message: Message): string | undefined {
|
||||
if (
|
||||
"tool_call_id" in message &&
|
||||
@@ -76,33 +65,17 @@ function messageIdentity(message: Message): string | undefined {
|
||||
|
||||
function dedupeMessagesByIdentity(messages: Message[]): Message[] {
|
||||
const lastIndexByIdentity = new Map<string, number>();
|
||||
const lastVisibleIndexByIdentity = new Map<string, number>();
|
||||
|
||||
// This is a UI-display dedupe rule, not a general LangChain message-stream
|
||||
// contract. Hidden messages that share an identity with a visible message are
|
||||
// treated as control messages for this merged view; hidden messages carrying
|
||||
// independent tracing/task semantics should use a distinct id or a custom
|
||||
// stream/state channel instead of relying on message dedupe preservation.
|
||||
messages.forEach((message, index) => {
|
||||
const identity = messageIdentity(message);
|
||||
if (identity) {
|
||||
lastIndexByIdentity.set(identity, index);
|
||||
if (!isHiddenFromUIMessage(message)) {
|
||||
lastVisibleIndexByIdentity.set(identity, index);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return messages.filter((message, index) => {
|
||||
const identity = messageIdentity(message);
|
||||
if (!identity) {
|
||||
return true;
|
||||
}
|
||||
const visibleIndex = lastVisibleIndexByIdentity.get(identity);
|
||||
if (visibleIndex !== undefined) {
|
||||
return visibleIndex === index;
|
||||
}
|
||||
return lastIndexByIdentity.get(identity) === index;
|
||||
return !identity || lastIndexByIdentity.get(identity) === index;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -124,15 +97,8 @@ export function mergeMessages(
|
||||
threadMessages: Message[],
|
||||
optimisticMessages: Message[],
|
||||
): Message[] {
|
||||
// Only visible live messages should trim overlapping history. Hidden messages
|
||||
// are UI control messages in this path, not observability records; any hidden
|
||||
// message that must survive as task/tracing data should use custom events or a
|
||||
// separate state channel instead of participating in this overlap heuristic.
|
||||
const threadMessageIds = new Set(
|
||||
threadMessages
|
||||
.filter((message) => !isHiddenFromUIMessage(message))
|
||||
.map(messageIdentity)
|
||||
.filter(isNonEmptyString),
|
||||
threadMessages.map(messageIdentity).filter(isNonEmptyString),
|
||||
);
|
||||
|
||||
// The overlap is a contiguous suffix of historyMessages (newest history == oldest thread).
|
||||
@@ -183,72 +149,6 @@ export function getVisibleOptimisticMessages(
|
||||
return optimisticMessages;
|
||||
}
|
||||
|
||||
export function getSummarizationMiddlewareMessages(
|
||||
data: unknown,
|
||||
): Message[] | undefined {
|
||||
if (typeof data !== "object" || data === null) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
for (const [key, update] of Object.entries(data)) {
|
||||
if (!SUMMARIZATION_MIDDLEWARE_UPDATE_KEYS.has(key)) {
|
||||
continue;
|
||||
}
|
||||
if (typeof update !== "object" || update === null) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const messages = Reflect.get(update, "messages");
|
||||
if (Array.isArray(messages)) {
|
||||
return [...messages] as Message[];
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
export function upsertThreadInSearchCache(
|
||||
queryClient: QueryClient,
|
||||
thread: AgentThread,
|
||||
) {
|
||||
queryClient.setQueriesData(
|
||||
{
|
||||
queryKey: ["threads", "search"],
|
||||
exact: false,
|
||||
},
|
||||
(oldData: Array<AgentThread> | undefined) => {
|
||||
if (!oldData) {
|
||||
return [thread];
|
||||
}
|
||||
|
||||
const existingIndex = oldData.findIndex(
|
||||
(t) => t.thread_id === thread.thread_id,
|
||||
);
|
||||
if (existingIndex === -1) {
|
||||
return [thread, ...oldData];
|
||||
}
|
||||
|
||||
return oldData.map((t, index) => {
|
||||
if (index !== existingIndex) {
|
||||
return t;
|
||||
}
|
||||
return {
|
||||
...thread,
|
||||
...t,
|
||||
metadata: {
|
||||
...(thread.metadata ?? {}),
|
||||
...(t.metadata ?? {}),
|
||||
},
|
||||
values: {
|
||||
...thread.values,
|
||||
...t.values,
|
||||
},
|
||||
};
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
function getStreamErrorMessage(error: unknown): string {
|
||||
if (typeof error === "string" && error.trim()) {
|
||||
return error;
|
||||
@@ -341,20 +241,6 @@ export function useThreadStream({
|
||||
fetchStateHistory: { limit: 1 },
|
||||
onCreated(meta) {
|
||||
handleStreamStart(meta.thread_id, meta.run_id);
|
||||
const now = new Date().toISOString();
|
||||
upsertThreadInSearchCache(queryClient, {
|
||||
thread_id: meta.thread_id,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
metadata: context.agent_name ? { agent_name: context.agent_name } : {},
|
||||
status: "busy",
|
||||
values: {
|
||||
title: t.pages.newChat,
|
||||
messages: [],
|
||||
artifacts: [],
|
||||
},
|
||||
interrupts: {},
|
||||
});
|
||||
if (context.agent_name && !isMock) {
|
||||
void getAPIClient()
|
||||
.threads.update(meta.thread_id, {
|
||||
@@ -372,25 +258,24 @@ export function useThreadStream({
|
||||
}
|
||||
},
|
||||
onUpdateEvent(data) {
|
||||
const _messages = getSummarizationMiddlewareMessages(data);
|
||||
if (_messages && _messages.length >= 2) {
|
||||
if (data["SummarizationMiddleware.before_model"]) {
|
||||
const _messages = [
|
||||
...(data["SummarizationMiddleware.before_model"].messages ?? []),
|
||||
];
|
||||
|
||||
if (_messages.length < 2) {
|
||||
return;
|
||||
}
|
||||
for (const m of _messages) {
|
||||
if (m.name === "summary" && m.type === "human") {
|
||||
summarizedRef.current?.add(m.id ?? "");
|
||||
}
|
||||
}
|
||||
const firstRetainedVisibleIdentity = _messages
|
||||
.filter((message) => message.type !== "remove")
|
||||
.filter((message) => !isHiddenFromUIMessage(message))
|
||||
.map(messageIdentity)
|
||||
.find(isNonEmptyString);
|
||||
const _lastKeepMessage = _messages[2];
|
||||
const _currentMessages = [...messagesRef.current];
|
||||
const _movedMessages: Message[] = [];
|
||||
for (const m of _currentMessages) {
|
||||
if (
|
||||
firstRetainedVisibleIdentity &&
|
||||
messageIdentity(m) === firstRetainedVisibleIdentity
|
||||
) {
|
||||
if (m.id !== undefined && m.id === _lastKeepMessage?.id) {
|
||||
break;
|
||||
}
|
||||
if (!summarizedRef.current?.has(m.id ?? "")) {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user