Compare commits

...

5 Commits

Author SHA1 Message Date
hetao d9aa92afaa fix: fix unittes & background investigation search logic 2025-05-28 12:46:55 +08:00
wushiai1109 29be360954 Update nodes.py (#242)
SELECTED_SEARCH_ENGINE impossible equal to SearchEngine.ARXIV, should be SearchEngine.ARXIV.value, or use the encapsulated get_web_search_tool
2025-05-27 18:58:14 +08:00
Harsha Vardhan Mannem 3ed70e11d5 Fix/server error handling (#212)
* chore: add venv/ to gitignore

* fix: add server error handling and graceful shutdown

* Fix linting issues in server.py
2025-05-22 13:45:07 +08:00
laundry 55ce399969 test: add background node unit test (#198)
* test: add background node unit test

Change-Id: Ia99f5a1687464387dcb01bbee04deaa371c6e490

* test: add background node unit test

Change-Id: I9aabcf02ff04fda40c56f3ea22abe6b8f93bf9b6

* test: fix test error

Change-Id: I3997dc53a2cfaa35501a1fbda5902ee15528124e

* test: fix unit test error

Change-Id: If4c4cd10673e76a30945674c7cda198aeabf28d0

* test: fix unit test error

Change-Id: I3dd7a6179132e5497a30ada443d88de0c47af3d4
2025-05-20 14:25:35 +08:00
DanielWalnut 8bbcdbe4de feat: config max_search_results for search engine (#192)
* feat: implement UI

* feat: config max_search_results for search engine via api

---------

Co-authored-by: Henry Li <henry1943@163.com>
2025-05-18 13:23:52 +08:00
18 changed files with 258 additions and 100 deletions
+1
View File
@@ -11,6 +11,7 @@ static/browser_history/*.gif
# Virtual environments
.venv
venv/
# Environment variables
.env
+25 -11
View File
@@ -7,7 +7,8 @@ Server script for running the DeerFlow API.
import argparse
import logging
import signal
import sys
import uvicorn
# Configure logging
@@ -18,6 +19,17 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
def handle_shutdown(signum, frame):
"""Handle graceful shutdown on SIGTERM/SIGINT"""
logger.info("Received shutdown signal. Starting graceful shutdown...")
sys.exit(0)
# Register signal handlers
signal.signal(signal.SIGTERM, handle_shutdown)
signal.signal(signal.SIGINT, handle_shutdown)
if __name__ == "__main__":
# Parse command line arguments
parser = argparse.ArgumentParser(description="Run the DeerFlow API server")
@@ -50,16 +62,18 @@ if __name__ == "__main__":
# Determine reload setting
reload = False
# Command line arguments override defaults
if args.reload:
reload = True
logger.info("Starting DeerFlow API server")
uvicorn.run(
"src.server:app",
host=args.host,
port=args.port,
reload=reload,
log_level=args.log_level,
)
try:
logger.info(f"Starting DeerFlow API server on {args.host}:{args.port}")
uvicorn.run(
"src.server:app",
host=args.host,
port=args.port,
reload=reload,
log_level=args.log_level,
)
except Exception as e:
logger.error(f"Failed to start server: {str(e)}")
sys.exit(1)
+2 -2
View File
@@ -1,6 +1,6 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from .agents import research_agent, coder_agent
from .agents import create_agent
__all__ = ["research_agent", "coder_agent"]
__all__ = ["create_agent"]
-13
View File
@@ -4,12 +4,6 @@
from langgraph.prebuilt import create_react_agent
from src.prompts import apply_prompt_template
from src.tools import (
crawl_tool,
python_repl_tool,
web_search_tool,
)
from src.llms.llm import get_llm_by_type
from src.config.agents import AGENT_LLM_MAP
@@ -23,10 +17,3 @@ def create_agent(agent_name: str, agent_type: str, tools: list, prompt_template:
tools=tools,
prompt=lambda state: apply_prompt_template(prompt_template, state),
)
# Create agents using the factory function
research_agent = create_agent(
"researcher", "researcher", [web_search_tool, crawl_tool], "researcher"
)
coder_agent = create_agent("coder", "coder", [python_repl_tool], "coder")
+1 -2
View File
@@ -1,7 +1,7 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from .tools import SEARCH_MAX_RESULTS, SELECTED_SEARCH_ENGINE, SearchEngine
from .tools import SELECTED_SEARCH_ENGINE, SearchEngine
from .loader import load_yaml_config
from .questions import BUILT_IN_QUESTIONS, BUILT_IN_QUESTIONS_ZH_CN
@@ -42,7 +42,6 @@ __all__ = [
# Other configurations
"TEAM_MEMBERS",
"TEAM_MEMBER_CONFIGRATIONS",
"SEARCH_MAX_RESULTS",
"SELECTED_SEARCH_ENGINE",
"SearchEngine",
"BUILT_IN_QUESTIONS",
+1
View File
@@ -14,6 +14,7 @@ class Configuration:
max_plan_iterations: int = 1 # Maximum number of plan iterations
max_step_num: int = 3 # Maximum number of steps in a plan
max_search_results: int = 3 # Maximum number of search results
mcp_settings: dict = None # MCP settings, including dynamic loaded tools
@classmethod
-1
View File
@@ -17,4 +17,3 @@ class SearchEngine(enum.Enum):
# Tool configuration
SELECTED_SEARCH_ENGINE = os.getenv("SEARCH_API", SearchEngine.TAVILY.value)
SEARCH_MAX_RESULTS = 3
+19 -17
View File
@@ -12,12 +12,11 @@ from langchain_core.tools import tool
from langgraph.types import Command, interrupt
from langchain_mcp_adapters.client import MultiServerMCPClient
from src.agents.agents import coder_agent, research_agent, create_agent
from src.agents import create_agent
from src.tools.search import LoggedTavilySearch
from src.tools import (
crawl_tool,
web_search_tool,
get_web_search_tool,
python_repl_tool,
)
@@ -29,7 +28,7 @@ from src.prompts.template import apply_prompt_template
from src.utils.json_utils import repair_json_output
from .types import State
from ..config import SEARCH_MAX_RESULTS, SELECTED_SEARCH_ENGINE, SearchEngine
from ..config import SELECTED_SEARCH_ENGINE, SearchEngine
logger = logging.getLogger(__name__)
@@ -45,13 +44,16 @@ def handoff_to_planner(
return
def background_investigation_node(state: State) -> Command[Literal["planner"]]:
def background_investigation_node(
state: State, config: RunnableConfig
) -> Command[Literal["planner"]]:
logger.info("background investigation node is running.")
configurable = Configuration.from_runnable_config(config)
query = state["messages"][-1].content
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY:
searched_content = LoggedTavilySearch(max_results=SEARCH_MAX_RESULTS).invoke(
{"query": query}
)
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
searched_content = LoggedTavilySearch(
max_results=configurable.max_search_results
).invoke(query)
background_investigation_results = None
if isinstance(searched_content, list):
background_investigation_results = [
@@ -63,7 +65,9 @@ def background_investigation_node(state: State) -> Command[Literal["planner"]]:
f"Tavily search returned malformed response: {searched_content}"
)
else:
background_investigation_results = web_search_tool.invoke(query)
background_investigation_results = get_web_search_tool(
configurable.max_search_results
).invoke(query)
return Command(
update={
"background_investigation_results": json.dumps(
@@ -403,7 +407,6 @@ async def _setup_and_execute_agent_step(
state: State,
config: RunnableConfig,
agent_type: str,
default_agent,
default_tools: list,
) -> Command[Literal["research_team"]]:
"""Helper function to set up an agent with appropriate tools and execute a step.
@@ -417,7 +420,6 @@ async def _setup_and_execute_agent_step(
state: The current state
config: The runnable config
agent_type: The type of agent ("researcher" or "coder")
default_agent: The default agent to use if no MCP servers are configured
default_tools: The default tools to add to the agent
Returns:
@@ -455,8 +457,9 @@ async def _setup_and_execute_agent_step(
agent = create_agent(agent_type, agent_type, loaded_tools, agent_type)
return await _execute_agent_step(state, agent, agent_type)
else:
# Use default agent if no MCP servers are configured
return await _execute_agent_step(state, default_agent, agent_type)
# Use default tools if no MCP servers are configured
agent = create_agent(agent_type, agent_type, default_tools, agent_type)
return await _execute_agent_step(state, agent, agent_type)
async def researcher_node(
@@ -464,12 +467,12 @@ async def researcher_node(
) -> Command[Literal["research_team"]]:
"""Researcher node that do research"""
logger.info("Researcher node is researching.")
configurable = Configuration.from_runnable_config(config)
return await _setup_and_execute_agent_step(
state,
config,
"researcher",
research_agent,
[web_search_tool, crawl_tool],
[get_web_search_tool(configurable.max_search_results), crawl_tool],
)
@@ -482,6 +485,5 @@ async def coder_node(
state,
config,
"coder",
coder_agent,
[python_repl_tool],
)
+2 -3
View File
@@ -44,13 +44,12 @@ def get_llm_by_type(
return llm
# Initialize LLMs for different purposes - now these will be cached
basic_llm = get_llm_by_type("basic")
# In the future, we will use reasoning_llm and vl_llm for different purposes
# reasoning_llm = get_llm_by_type("reasoning")
# vl_llm = get_llm_by_type("vision")
if __name__ == "__main__":
# Initialize LLMs for different purposes - now these will be cached
basic_llm = get_llm_by_type("basic")
print(basic_llm.invoke("Hello"))
+3
View File
@@ -61,6 +61,7 @@ async def chat_stream(request: ChatRequest):
thread_id,
request.max_plan_iterations,
request.max_step_num,
request.max_search_results,
request.auto_accepted_plan,
request.interrupt_feedback,
request.mcp_settings,
@@ -75,6 +76,7 @@ async def _astream_workflow_generator(
thread_id: str,
max_plan_iterations: int,
max_step_num: int,
max_search_results: int,
auto_accepted_plan: bool,
interrupt_feedback: str,
mcp_settings: dict,
@@ -101,6 +103,7 @@ async def _astream_workflow_generator(
"thread_id": thread_id,
"max_plan_iterations": max_plan_iterations,
"max_step_num": max_step_num,
"max_search_results": max_search_results,
"mcp_settings": mcp_settings,
},
stream_mode=["messages", "updates"],
+3
View File
@@ -38,6 +38,9 @@ class ChatRequest(BaseModel):
max_step_num: Optional[int] = Field(
3, description="The maximum number of steps in a plan"
)
max_search_results: Optional[int] = Field(
3, description="The maximum number of search results"
)
auto_accepted_plan: Optional[bool] = Field(
False, description="Whether to automatically accept the plan"
)
+2 -18
View File
@@ -5,28 +5,12 @@ import os
from .crawl import crawl_tool
from .python_repl import python_repl_tool
from .search import (
tavily_search_tool,
duckduckgo_search_tool,
brave_search_tool,
arxiv_search_tool,
)
from .search import get_web_search_tool
from .tts import VolcengineTTS
from src.config import SELECTED_SEARCH_ENGINE, SearchEngine
# Map search engine names to their respective tools
search_tool_mappings = {
SearchEngine.TAVILY.value: tavily_search_tool,
SearchEngine.DUCKDUCKGO.value: duckduckgo_search_tool,
SearchEngine.BRAVE_SEARCH.value: brave_search_tool,
SearchEngine.ARXIV.value: arxiv_search_tool,
}
web_search_tool = search_tool_mappings.get(SELECTED_SEARCH_ENGINE, tavily_search_tool)
__all__ = [
"crawl_tool",
"web_search_tool",
"python_repl_tool",
"get_web_search_tool",
"VolcengineTTS",
]
+37 -33
View File
@@ -9,7 +9,7 @@ from langchain_community.tools import BraveSearch, DuckDuckGoSearchResults
from langchain_community.tools.arxiv import ArxivQueryRun
from langchain_community.utilities import ArxivAPIWrapper, BraveSearchWrapper
from src.config import SEARCH_MAX_RESULTS, SearchEngine
from src.config import SearchEngine, SELECTED_SEARCH_ENGINE
from src.tools.tavily_search.tavily_search_results_with_images import (
TavilySearchResultsWithImages,
)
@@ -18,44 +18,48 @@ from src.tools.decorators import create_logged_tool
logger = logging.getLogger(__name__)
# Create logged versions of the search tools
LoggedTavilySearch = create_logged_tool(TavilySearchResultsWithImages)
if os.getenv("SEARCH_API", "") == SearchEngine.TAVILY.value:
tavily_search_tool = LoggedTavilySearch(
name="web_search",
max_results=SEARCH_MAX_RESULTS,
include_raw_content=True,
include_images=True,
include_image_descriptions=True,
)
else:
tavily_search_tool = None
LoggedDuckDuckGoSearch = create_logged_tool(DuckDuckGoSearchResults)
duckduckgo_search_tool = LoggedDuckDuckGoSearch(
name="web_search", max_results=SEARCH_MAX_RESULTS
)
LoggedBraveSearch = create_logged_tool(BraveSearch)
brave_search_tool = LoggedBraveSearch(
name="web_search",
search_wrapper=BraveSearchWrapper(
api_key=os.getenv("BRAVE_SEARCH_API_KEY", ""),
search_kwargs={"count": SEARCH_MAX_RESULTS},
),
)
LoggedArxivSearch = create_logged_tool(ArxivQueryRun)
arxiv_search_tool = LoggedArxivSearch(
name="web_search",
api_wrapper=ArxivAPIWrapper(
top_k_results=SEARCH_MAX_RESULTS,
load_max_docs=SEARCH_MAX_RESULTS,
load_all_available_meta=True,
),
)
# Get the selected search tool
def get_web_search_tool(max_search_results: int):
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
return LoggedTavilySearch(
name="web_search",
max_results=max_search_results,
include_raw_content=True,
include_images=True,
include_image_descriptions=True,
)
elif SELECTED_SEARCH_ENGINE == SearchEngine.DUCKDUCKGO.value:
return LoggedDuckDuckGoSearch(name="web_search", max_results=max_search_results)
elif SELECTED_SEARCH_ENGINE == SearchEngine.BRAVE_SEARCH.value:
return LoggedBraveSearch(
name="web_search",
search_wrapper=BraveSearchWrapper(
api_key=os.getenv("BRAVE_SEARCH_API_KEY", ""),
search_kwargs={"count": max_search_results},
),
)
elif SELECTED_SEARCH_ENGINE == SearchEngine.ARXIV.value:
return LoggedArxivSearch(
name="web_search",
api_wrapper=ArxivAPIWrapper(
top_k_results=max_search_results,
load_max_docs=max_search_results,
load_all_available_meta=True,
),
)
else:
raise ValueError(f"Unsupported search engine: {SELECTED_SEARCH_ENGINE}")
if __name__ == "__main__":
results = LoggedDuckDuckGoSearch(
name="web_search", max_results=SEARCH_MAX_RESULTS, output_format="list"
name="web_search", max_results=3, output_format="list"
).invoke("cute panda")
print(json.dumps(results, indent=2, ensure_ascii=False))
+128
View File
@@ -0,0 +1,128 @@
import json
import pytest
from unittest.mock import patch, MagicMock
# 在这里 mock 掉 get_llm_by_type,避免 ValueError
with patch("src.llms.llm.get_llm_by_type", return_value=MagicMock()):
from langgraph.types import Command
from src.graph.nodes import background_investigation_node
from src.config import SearchEngine
from langchain_core.messages import HumanMessage
# Mock data
MOCK_SEARCH_RESULTS = [
{"title": "Test Title 1", "content": "Test Content 1"},
{"title": "Test Title 2", "content": "Test Content 2"},
]
@pytest.fixture
def mock_state():
return {
"messages": [HumanMessage(content="test query")],
"background_investigation_results": None,
}
@pytest.fixture
def mock_configurable():
mock = MagicMock()
mock.max_search_results = 5
return mock
@pytest.fixture
def mock_config():
# 你可以根据实际需要返回一个 MagicMock 或 dict
return MagicMock()
@pytest.fixture
def patch_config_from_runnable_config(mock_configurable):
with patch(
"src.graph.nodes.Configuration.from_runnable_config",
return_value=mock_configurable,
):
yield
@pytest.fixture
def mock_tavily_search():
with patch("src.graph.nodes.LoggedTavilySearch") as mock:
instance = mock.return_value
instance.invoke.return_value = [
{"title": "Test Title 1", "content": "Test Content 1"},
{"title": "Test Title 2", "content": "Test Content 2"},
]
yield mock
@pytest.fixture
def mock_web_search_tool():
with patch("src.graph.nodes.get_web_search_tool") as mock:
instance = mock.return_value
instance.invoke.return_value = [
{"title": "Test Title 1", "content": "Test Content 1"},
{"title": "Test Title 2", "content": "Test Content 2"},
]
yield mock
@pytest.mark.parametrize("search_engine", [SearchEngine.TAVILY.value, "other"])
def test_background_investigation_node_tavily(
mock_state,
mock_tavily_search,
mock_web_search_tool,
search_engine,
patch_config_from_runnable_config,
mock_config,
):
"""Test background_investigation_node with Tavily search engine"""
with patch("src.graph.nodes.SELECTED_SEARCH_ENGINE", search_engine):
result = background_investigation_node(mock_state, mock_config)
# Verify the result structure
assert isinstance(result, Command)
assert result.goto == "planner"
# Verify the update contains background_investigation_results
update = result.update
assert "background_investigation_results" in update
# Parse and verify the JSON content
results = json.loads(update["background_investigation_results"])
assert isinstance(results, list)
if search_engine == SearchEngine.TAVILY.value:
mock_tavily_search.return_value.invoke.assert_called_once_with("test query")
assert len(results) == 2
assert results[0]["title"] == "Test Title 1"
assert results[0]["content"] == "Test Content 1"
else:
mock_web_search_tool.return_value.invoke.assert_called_once_with(
"test query"
)
assert len(results) == 2
def test_background_investigation_node_malformed_response(
mock_state, mock_tavily_search, patch_config_from_runnable_config, mock_config
):
"""Test background_investigation_node with malformed Tavily response"""
with patch("src.graph.nodes.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value):
# Mock a malformed response
mock_tavily_search.return_value.invoke.return_value = "invalid response"
result = background_investigation_node(mock_state, mock_config)
# Verify the result structure
assert isinstance(result, Command)
assert result.goto == "planner"
# Verify the update contains background_investigation_results
update = result.update
assert "background_investigation_results" in update
# Parse and verify the JSON content
results = json.loads(update["background_investigation_results"])
assert results is None
+27
View File
@@ -32,6 +32,9 @@ const generalFormSchema = z.object({
maxStepNum: z.number().min(1, {
message: "Max step number must be at least 1.",
}),
maxSearchResults: z.number().min(1, {
message: "Max search results must be at least 1.",
}),
});
export const GeneralTab: Tab = ({
@@ -143,6 +146,30 @@ export const GeneralTab: Tab = ({
</FormItem>
)}
/>
<FormField
control={form.control}
name="maxSearchResults"
render={({ field }) => (
<FormItem>
<FormLabel>Max search results</FormLabel>
<FormControl>
<Input
className="w-60"
type="number"
defaultValue={field.value}
min={1}
onChange={(event) =>
field.onChange(parseInt(event.target.value || "0"))
}
/>
</FormControl>
<FormDescription>
By default, each search step has 3 results.
</FormDescription>
<FormMessage />
</FormItem>
)}
/>
</form>
</Form>
</main>
+4
View File
@@ -18,6 +18,7 @@ export async function* chatStream(
auto_accepted_plan: boolean;
max_plan_iterations: number;
max_step_num: number;
max_search_results?: number;
interrupt_feedback?: string;
enable_background_investigation: boolean;
mcp_settings?: {
@@ -61,12 +62,14 @@ async function* chatReplayStream(
auto_accepted_plan: boolean;
max_plan_iterations: number;
max_step_num: number;
max_search_results?: number;
interrupt_feedback?: string;
} = {
thread_id: "__mock__",
auto_accepted_plan: false,
max_plan_iterations: 3,
max_step_num: 1,
max_search_results: 3,
interrupt_feedback: undefined,
},
options: { abortSignal?: AbortSignal } = {},
@@ -157,6 +160,7 @@ export async function fetchReplayTitle() {
auto_accepted_plan: false,
max_plan_iterations: 3,
max_step_num: 1,
max_search_results: 3,
},
{},
);
+2
View File
@@ -13,6 +13,7 @@ const DEFAULT_SETTINGS: SettingsState = {
enableBackgroundInvestigation: false,
maxPlanIterations: 1,
maxStepNum: 3,
maxSearchResults: 3,
},
mcp: {
servers: [],
@@ -25,6 +26,7 @@ export type SettingsState = {
enableBackgroundInvestigation: boolean;
maxPlanIterations: number;
maxStepNum: number;
maxSearchResults: number;
};
mcp: {
servers: MCPServerMetadata[];
+1
View File
@@ -104,6 +104,7 @@ export async function sendMessage(
settings.enableBackgroundInvestigation ?? true,
max_plan_iterations: settings.maxPlanIterations,
max_step_num: settings.maxStepNum,
max_search_results: settings.maxSearchResults,
mcp_settings: settings.mcpSettings,
},
options,