mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 08:55:59 +00:00
fix(tools): introduce Runtime type alias to eliminate Pydantic serialization warning (#2774)
* fix(tools): introduce Runtime type alias to eliminate Pydantic serialization warning
Add deerflow/tools/types.py with:
Runtime = ToolRuntime[dict[str, Any], ThreadState]
Replace every runtime: ToolRuntime[ContextT, ThreadState] and
runtime: ToolRuntime[dict[str, Any], ThreadState] annotation in
sandbox/tools.py, present_file_tool.py, task_tool.py, view_image_tool.py,
and skill_manage_tool.py with the new Runtime alias.
The unbound ContextT TypeVar (default None) caused
PydanticSerializationUnexpectedValue warnings on every tool call because
LangChain's BaseTool._parse_input calls model_dump() on the auto-generated
args_schema while DeerFlow passes a dict as runtime context.
Binding the context to dict[str, Any] aligns Pydantic's serialization
expectations with reality and removes the noise from all run modes.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
* fix(tools): extend Runtime alias to setup_agent and update_agent tools
Replace bare ToolRuntime annotations in setup_agent_tool.py and
update_agent_tool.py with the shared Runtime alias introduced in the
previous commit, and add both tools to the Pydantic serialization
warning regression test (13 cases total).
Co-authored-by: Cursor <cursoragent@cursor.com>
* test(tools): loosen Pydantic warning filter to avoid version-specific format
Replace the brittle "field_name='context'" substring check with a looser
"context" match so the assertion stays valid if Pydantic changes its
internal warning format across versions.
Co-authored-by: Cursor <cursoragent@cursor.com>
* test(tools): simplify warning filter and clean up docstring
Remove the "context" substring condition from the Pydantic warning
filter — asserting that no PydanticSerializationUnexpectedValue fires
at all is both simpler and more comprehensive, since the test payload
contains only the tool's own args plus runtime.
Also update the module docstring to remove the version-specific warning
format example that was inconsistent with the looser filter.
Co-authored-by: Cursor <cursoragent@cursor.com>
---------
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -1,20 +1,19 @@
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
||||
from langchain.tools import InjectedToolCallId, tool
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.config import get_config
|
||||
from langgraph.types import Command
|
||||
from langgraph.typing import ContextT
|
||||
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"
|
||||
|
||||
|
||||
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState]) -> str | None:
|
||||
def _get_thread_id(runtime: Runtime) -> str | None:
|
||||
"""Resolve the current thread id from runtime context or RunnableConfig."""
|
||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
||||
if thread_id:
|
||||
@@ -32,7 +31,7 @@ def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState]) -> str | None:
|
||||
|
||||
|
||||
def _normalize_presented_filepath(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
filepath: str,
|
||||
) -> str:
|
||||
"""Normalize a presented file path to the `/mnt/user-data/outputs/*` contract.
|
||||
@@ -83,7 +82,7 @@ def _normalize_presented_filepath(
|
||||
|
||||
@tool("present_files", parse_docstring=True)
|
||||
def present_file_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
filepaths: list[str],
|
||||
tool_call_id: Annotated[str, InjectedToolCallId],
|
||||
) -> Command:
|
||||
|
||||
@@ -3,12 +3,12 @@ import logging
|
||||
import yaml
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.prebuilt import ToolRuntime
|
||||
from langgraph.types import Command
|
||||
|
||||
from deerflow.config.agents_config import validate_agent_name
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
|
||||
def setup_agent(
|
||||
soul: str,
|
||||
description: str,
|
||||
runtime: ToolRuntime,
|
||||
runtime: Runtime,
|
||||
skills: list[str] | None = None,
|
||||
) -> Command:
|
||||
"""Setup the custom DeerFlow agent.
|
||||
|
||||
@@ -6,11 +6,9 @@ import uuid
|
||||
from dataclasses import replace
|
||||
from typing import TYPE_CHECKING, Annotated, Any, cast
|
||||
|
||||
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
||||
from langchain.tools import InjectedToolCallId, tool
|
||||
from langgraph.config import get_stream_writer
|
||||
from langgraph.typing import ContextT
|
||||
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed
|
||||
from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config
|
||||
@@ -21,6 +19,7 @@ from deerflow.subagents.executor import (
|
||||
get_background_task_result,
|
||||
request_cancel_background_task,
|
||||
)
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.config.app_config import AppConfig
|
||||
@@ -50,7 +49,7 @@ def _merge_skill_allowlists(parent: list[str] | None, child: list[str] | None) -
|
||||
|
||||
@tool("task", parse_docstring=True)
|
||||
async def task_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
description: str,
|
||||
prompt: str,
|
||||
subagent_type: str,
|
||||
|
||||
@@ -22,13 +22,13 @@ from typing import Any
|
||||
import yaml
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.prebuilt import ToolRuntime
|
||||
from langgraph.types import Command
|
||||
|
||||
from deerflow.config.agents_config import load_agent_config, validate_agent_name
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -69,7 +69,7 @@ def _cleanup_temps(temps: list[Path]) -> None:
|
||||
|
||||
@tool
|
||||
def update_agent(
|
||||
runtime: ToolRuntime,
|
||||
runtime: Runtime,
|
||||
soul: str | None = None,
|
||||
description: str | None = None,
|
||||
skills: list[str] | None = None,
|
||||
|
||||
@@ -3,13 +3,13 @@ import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
||||
from langchain.tools import InjectedToolCallId, tool
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
from langgraph.typing import ContextT
|
||||
|
||||
from deerflow.agents.thread_state import ThreadDataState, ThreadState
|
||||
from deerflow.agents.thread_state import ThreadDataState
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
_ALLOWED_IMAGE_VIRTUAL_ROOTS = (
|
||||
f"{VIRTUAL_PATH_PREFIX}/workspace",
|
||||
@@ -48,7 +48,7 @@ def _sanitize_image_error(error: Exception, thread_data: ThreadDataState | None)
|
||||
|
||||
@tool("view_image", parse_docstring=True)
|
||||
def view_image_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
image_path: str,
|
||||
tool_call_id: Annotated[str, InjectedToolCallId],
|
||||
) -> Command:
|
||||
|
||||
@@ -7,16 +7,15 @@ import logging
|
||||
from typing import Any
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from langchain.tools import ToolRuntime, tool
|
||||
from langgraph.typing import ContextT
|
||||
from langchain.tools import tool
|
||||
|
||||
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.mcp.tools import _make_sync_tool_wrapper
|
||||
from deerflow.skills.security_scanner import scan_skill_content
|
||||
from deerflow.skills.storage import get_or_new_skill_storage
|
||||
from deerflow.skills.storage.skill_storage import SkillStorage
|
||||
from deerflow.skills.types import SKILL_MD_FILE
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,7 +30,7 @@ def _get_lock(name: str) -> asyncio.Lock:
|
||||
return lock
|
||||
|
||||
|
||||
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState] | None) -> str | None:
|
||||
def _get_thread_id(runtime: Runtime | None) -> str | None:
|
||||
if runtime is None:
|
||||
return None
|
||||
if runtime.context and runtime.context.get("thread_id"):
|
||||
@@ -65,7 +64,7 @@ async def _to_thread(func, /, *args, **kwargs):
|
||||
|
||||
|
||||
async def _skill_manage_impl(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
action: str,
|
||||
name: str,
|
||||
content: str | None = None,
|
||||
@@ -204,7 +203,7 @@ async def _skill_manage_impl(
|
||||
|
||||
@tool("skill_manage", parse_docstring=True)
|
||||
async def skill_manage_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
runtime: Runtime,
|
||||
action: str,
|
||||
name: str,
|
||||
content: str | None = None,
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
|
||||
# Concrete runtime type used by all DeerFlow tools.
|
||||
# Using dict[str, Any] for the context parameter instead of the unbound ContextT
|
||||
# TypeVar prevents PydanticSerializationUnexpectedValue warnings when LangChain
|
||||
# calls model_dump() on a tool's auto-generated args_schema.
|
||||
Runtime = ToolRuntime[dict[str, Any], ThreadState]
|
||||
Reference in New Issue
Block a user