mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 23:46:50 +00:00
Merge branch 'main' into rayhpeng/persistence-scaffold
# Conflicts: # .env.example # backend/packages/harness/deerflow/agents/middlewares/title_middleware.py
This commit is contained in:
@@ -49,6 +49,7 @@ class Fact(BaseModel):
|
||||
confidence: float = Field(default=0.5, description="Confidence score (0-1)")
|
||||
createdAt: str = Field(default="", description="Creation timestamp")
|
||||
source: str = Field(default="unknown", description="Source thread ID")
|
||||
sourceError: str | None = Field(default=None, description="Optional description of the prior mistake or wrong approach")
|
||||
|
||||
|
||||
class MemoryResponse(BaseModel):
|
||||
@@ -108,6 +109,7 @@ class MemoryStatusResponse(BaseModel):
|
||||
@router.get(
|
||||
"/memory",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Get Memory Data",
|
||||
description="Retrieve the current global memory data including user context, history, and facts.",
|
||||
)
|
||||
@@ -152,6 +154,7 @@ async def get_memory() -> MemoryResponse:
|
||||
@router.post(
|
||||
"/memory/reload",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Reload Memory Data",
|
||||
description="Reload memory data from the storage file, refreshing the in-memory cache.",
|
||||
)
|
||||
@@ -171,6 +174,7 @@ async def reload_memory() -> MemoryResponse:
|
||||
@router.delete(
|
||||
"/memory",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Clear All Memory Data",
|
||||
description="Delete all saved memory data and reset the memory structure to an empty state.",
|
||||
)
|
||||
@@ -187,6 +191,7 @@ async def clear_memory() -> MemoryResponse:
|
||||
@router.post(
|
||||
"/memory/facts",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Create Memory Fact",
|
||||
description="Create a single saved memory fact manually.",
|
||||
)
|
||||
@@ -209,6 +214,7 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
|
||||
@router.delete(
|
||||
"/memory/facts/{fact_id}",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Delete Memory Fact",
|
||||
description="Delete a single saved memory fact by its fact id.",
|
||||
)
|
||||
@@ -227,6 +233,7 @@ async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
|
||||
@router.patch(
|
||||
"/memory/facts/{fact_id}",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Patch Memory Fact",
|
||||
description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
|
||||
)
|
||||
@@ -252,6 +259,7 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
|
||||
@router.get(
|
||||
"/memory/export",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Export Memory Data",
|
||||
description="Export the current global memory data as JSON for backup or transfer.",
|
||||
)
|
||||
@@ -264,6 +272,7 @@ async def export_memory() -> MemoryResponse:
|
||||
@router.post(
|
||||
"/memory/import",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Import Memory Data",
|
||||
description="Import and overwrite the current global memory data from a JSON payload.",
|
||||
)
|
||||
@@ -317,6 +326,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
||||
@router.get(
|
||||
"/memory/status",
|
||||
response_model=MemoryStatusResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Get Memory Status",
|
||||
description="Retrieve both memory configuration and current data in a single request.",
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.models import create_chat_model
|
||||
@@ -106,22 +107,21 @@ async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> S
|
||||
if not conversation:
|
||||
return SuggestionsResponse(suggestions=[])
|
||||
|
||||
prompt = (
|
||||
system_instruction = (
|
||||
"You are generating follow-up questions to help the user continue the conversation.\n"
|
||||
f"Based on the conversation below, produce EXACTLY {n} short questions the user might ask next.\n"
|
||||
"Requirements:\n"
|
||||
"- Questions must be relevant to the conversation.\n"
|
||||
"- Questions must be relevant to the preceding conversation.\n"
|
||||
"- Questions must be written in the same language as the user.\n"
|
||||
"- Keep each question concise (ideally <= 20 words / <= 40 Chinese characters).\n"
|
||||
"- Do NOT include numbering, markdown, or any extra text.\n"
|
||||
"- Output MUST be a JSON array of strings only.\n\n"
|
||||
"Conversation:\n"
|
||||
f"{conversation}\n"
|
||||
"- Output MUST be a JSON array of strings only.\n"
|
||||
)
|
||||
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
|
||||
|
||||
try:
|
||||
model = create_chat_model(name=request.model_name, thinking_enabled=False)
|
||||
response = model.invoke(prompt)
|
||||
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
|
||||
raw = _extract_response_text(response.content)
|
||||
suggestions = _parse_json_string_list(raw) or []
|
||||
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
|
||||
|
||||
@@ -38,6 +38,7 @@ class RunCreateRequest(BaseModel):
|
||||
command: dict[str, Any] | None = Field(default=None, description="LangGraph Command")
|
||||
metadata: dict[str, Any] | None = Field(default=None, description="Run metadata")
|
||||
config: dict[str, Any] | None = Field(default=None, description="RunnableConfig overrides")
|
||||
context: dict[str, Any] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
|
||||
webhook: str | None = Field(default=None, description="Completion callback URL")
|
||||
checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
|
||||
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
|
||||
|
||||
@@ -413,16 +413,19 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
||||
"metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
|
||||
}
|
||||
|
||||
status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle") # type: ignore[union-attr]
|
||||
if record is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle")
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} if checkpoint_tuple is not None else {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=status,
|
||||
created_at=str(record.get("created_at", "")), # type: ignore[union-attr]
|
||||
updated_at=str(record.get("updated_at", "")), # type: ignore[union-attr]
|
||||
metadata=record.get("metadata", {}), # type: ignore[union-attr]
|
||||
created_at=str(record.get("created_at", "")),
|
||||
updated_at=str(record.get("updated_at", "")),
|
||||
metadata=record.get("metadata", {}),
|
||||
values=serialize_channel_values(channel_values),
|
||||
)
|
||||
|
||||
|
||||
@@ -129,26 +129,38 @@ def build_run_config(
|
||||
the LangGraph Platform-compatible HTTP API and the IM channel path behave
|
||||
identically.
|
||||
"""
|
||||
configurable: dict[str, Any] = {"thread_id": thread_id}
|
||||
config: dict[str, Any] = {"recursion_limit": 100}
|
||||
if request_config:
|
||||
configurable.update(request_config.get("configurable", {}))
|
||||
# LangGraph >= 0.6.0 introduced ``context`` as the preferred way to
|
||||
# pass thread-level data and rejects requests that include both
|
||||
# ``configurable`` and ``context``. If the caller already sends
|
||||
# ``context``, honour it and skip our own ``configurable`` dict.
|
||||
if "context" in request_config:
|
||||
if "configurable" in request_config:
|
||||
logger.warning(
|
||||
"build_run_config: client sent both 'context' and 'configurable'; preferring 'context' (LangGraph >= 0.6.0). thread_id=%s, caller_configurable keys=%s",
|
||||
thread_id,
|
||||
list(request_config.get("configurable", {}).keys()),
|
||||
)
|
||||
config["context"] = request_config["context"]
|
||||
else:
|
||||
configurable = {"thread_id": thread_id}
|
||||
configurable.update(request_config.get("configurable", {}))
|
||||
config["configurable"] = configurable
|
||||
for k, v in request_config.items():
|
||||
if k not in ("configurable", "context"):
|
||||
config[k] = v
|
||||
else:
|
||||
config["configurable"] = {"thread_id": thread_id}
|
||||
|
||||
# Inject custom agent name when the caller specified a non-default assistant.
|
||||
# Honour an explicit configurable["agent_name"] in the request if already set.
|
||||
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID and "agent_name" not in configurable:
|
||||
# Normalize the same way ChannelManager does: strip, lowercase,
|
||||
# replace underscores with hyphens, then validate to prevent path
|
||||
# traversal and invalid agent directory lookups.
|
||||
normalized = assistant_id.strip().lower().replace("_", "-")
|
||||
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
|
||||
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
|
||||
configurable["agent_name"] = normalized
|
||||
|
||||
config: dict[str, Any] = {"configurable": configurable, "recursion_limit": 100}
|
||||
if request_config:
|
||||
for k, v in request_config.items():
|
||||
if k != "configurable":
|
||||
config[k] = v
|
||||
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID and "configurable" in config:
|
||||
if "agent_name" not in config["configurable"]:
|
||||
normalized = assistant_id.strip().lower().replace("_", "-")
|
||||
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
|
||||
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
|
||||
config["configurable"]["agent_name"] = normalized
|
||||
if metadata:
|
||||
config.setdefault("metadata", {}).update(metadata)
|
||||
return config
|
||||
@@ -304,6 +316,27 @@ async def start_run(
|
||||
agent_factory = resolve_agent_factory(body.assistant_id)
|
||||
graph_input = normalize_input(body.input)
|
||||
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
||||
|
||||
# Merge DeerFlow-specific context overrides into configurable.
|
||||
# The ``context`` field is a custom extension for the langgraph-compat layer
|
||||
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
||||
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
||||
context = getattr(body, "context", None)
|
||||
if context:
|
||||
_CONTEXT_CONFIGURABLE_KEYS = {
|
||||
"model_name",
|
||||
"mode",
|
||||
"thinking_enabled",
|
||||
"reasoning_effort",
|
||||
"is_plan_mode",
|
||||
"subagent_enabled",
|
||||
"max_concurrent_subagents",
|
||||
}
|
||||
configurable = config.setdefault("configurable", {})
|
||||
for key in _CONTEXT_CONFIGURABLE_KEYS:
|
||||
if key in context:
|
||||
configurable.setdefault(key, context[key])
|
||||
|
||||
stream_modes = normalize_stream_modes(body.stream_mode)
|
||||
|
||||
task = asyncio.create_task(
|
||||
|
||||
Reference in New Issue
Block a user