mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-20 15:11:09 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d9aa92afaa | |||
| 29be360954 | |||
| 3ed70e11d5 | |||
| 55ce399969 |
@@ -11,6 +11,7 @@ static/browser_history/*.gif
|
|||||||
|
|
||||||
# Virtual environments
|
# Virtual environments
|
||||||
.venv
|
.venv
|
||||||
|
venv/
|
||||||
|
|
||||||
# Environment variables
|
# Environment variables
|
||||||
.env
|
.env
|
||||||
|
|||||||
@@ -7,7 +7,8 @@ Server script for running the DeerFlow API.
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
@@ -18,6 +19,17 @@ logging.basicConfig(
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_shutdown(signum, frame):
|
||||||
|
"""Handle graceful shutdown on SIGTERM/SIGINT"""
|
||||||
|
logger.info("Received shutdown signal. Starting graceful shutdown...")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
# Register signal handlers
|
||||||
|
signal.signal(signal.SIGTERM, handle_shutdown)
|
||||||
|
signal.signal(signal.SIGINT, handle_shutdown)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Parse command line arguments
|
# Parse command line arguments
|
||||||
parser = argparse.ArgumentParser(description="Run the DeerFlow API server")
|
parser = argparse.ArgumentParser(description="Run the DeerFlow API server")
|
||||||
@@ -50,16 +62,18 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Determine reload setting
|
# Determine reload setting
|
||||||
reload = False
|
reload = False
|
||||||
|
|
||||||
# Command line arguments override defaults
|
|
||||||
if args.reload:
|
if args.reload:
|
||||||
reload = True
|
reload = True
|
||||||
|
|
||||||
logger.info("Starting DeerFlow API server")
|
try:
|
||||||
uvicorn.run(
|
logger.info(f"Starting DeerFlow API server on {args.host}:{args.port}")
|
||||||
"src.server:app",
|
uvicorn.run(
|
||||||
host=args.host,
|
"src.server:app",
|
||||||
port=args.port,
|
host=args.host,
|
||||||
reload=reload,
|
port=args.port,
|
||||||
log_level=args.log_level,
|
reload=reload,
|
||||||
)
|
log_level=args.log_level,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to start server: {str(e)}")
|
||||||
|
sys.exit(1)
|
||||||
|
|||||||
+2
-2
@@ -50,10 +50,10 @@ def background_investigation_node(
|
|||||||
logger.info("background investigation node is running.")
|
logger.info("background investigation node is running.")
|
||||||
configurable = Configuration.from_runnable_config(config)
|
configurable = Configuration.from_runnable_config(config)
|
||||||
query = state["messages"][-1].content
|
query = state["messages"][-1].content
|
||||||
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY:
|
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
|
||||||
searched_content = LoggedTavilySearch(
|
searched_content = LoggedTavilySearch(
|
||||||
max_results=configurable.max_search_results
|
max_results=configurable.max_search_results
|
||||||
).invoke({"query": query})
|
).invoke(query)
|
||||||
background_investigation_results = None
|
background_investigation_results = None
|
||||||
if isinstance(searched_content, list):
|
if isinstance(searched_content, list):
|
||||||
background_investigation_results = [
|
background_investigation_results = [
|
||||||
|
|||||||
+2
-3
@@ -44,13 +44,12 @@ def get_llm_by_type(
|
|||||||
return llm
|
return llm
|
||||||
|
|
||||||
|
|
||||||
# Initialize LLMs for different purposes - now these will be cached
|
|
||||||
basic_llm = get_llm_by_type("basic")
|
|
||||||
|
|
||||||
# In the future, we will use reasoning_llm and vl_llm for different purposes
|
# In the future, we will use reasoning_llm and vl_llm for different purposes
|
||||||
# reasoning_llm = get_llm_by_type("reasoning")
|
# reasoning_llm = get_llm_by_type("reasoning")
|
||||||
# vl_llm = get_llm_by_type("vision")
|
# vl_llm = get_llm_by_type("vision")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# Initialize LLMs for different purposes - now these will be cached
|
||||||
|
basic_llm = get_llm_by_type("basic")
|
||||||
print(basic_llm.invoke("Hello"))
|
print(basic_llm.invoke("Hello"))
|
||||||
|
|||||||
@@ -0,0 +1,128 @@
|
|||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
# 在这里 mock 掉 get_llm_by_type,避免 ValueError
|
||||||
|
with patch("src.llms.llm.get_llm_by_type", return_value=MagicMock()):
|
||||||
|
from langgraph.types import Command
|
||||||
|
from src.graph.nodes import background_investigation_node
|
||||||
|
from src.config import SearchEngine
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
# Mock data
|
||||||
|
MOCK_SEARCH_RESULTS = [
|
||||||
|
{"title": "Test Title 1", "content": "Test Content 1"},
|
||||||
|
{"title": "Test Title 2", "content": "Test Content 2"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_state():
|
||||||
|
return {
|
||||||
|
"messages": [HumanMessage(content="test query")],
|
||||||
|
"background_investigation_results": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_configurable():
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.max_search_results = 5
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_config():
|
||||||
|
# 你可以根据实际需要返回一个 MagicMock 或 dict
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def patch_config_from_runnable_config(mock_configurable):
|
||||||
|
with patch(
|
||||||
|
"src.graph.nodes.Configuration.from_runnable_config",
|
||||||
|
return_value=mock_configurable,
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_tavily_search():
|
||||||
|
with patch("src.graph.nodes.LoggedTavilySearch") as mock:
|
||||||
|
instance = mock.return_value
|
||||||
|
instance.invoke.return_value = [
|
||||||
|
{"title": "Test Title 1", "content": "Test Content 1"},
|
||||||
|
{"title": "Test Title 2", "content": "Test Content 2"},
|
||||||
|
]
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_web_search_tool():
|
||||||
|
with patch("src.graph.nodes.get_web_search_tool") as mock:
|
||||||
|
instance = mock.return_value
|
||||||
|
instance.invoke.return_value = [
|
||||||
|
{"title": "Test Title 1", "content": "Test Content 1"},
|
||||||
|
{"title": "Test Title 2", "content": "Test Content 2"},
|
||||||
|
]
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("search_engine", [SearchEngine.TAVILY.value, "other"])
|
||||||
|
def test_background_investigation_node_tavily(
|
||||||
|
mock_state,
|
||||||
|
mock_tavily_search,
|
||||||
|
mock_web_search_tool,
|
||||||
|
search_engine,
|
||||||
|
patch_config_from_runnable_config,
|
||||||
|
mock_config,
|
||||||
|
):
|
||||||
|
"""Test background_investigation_node with Tavily search engine"""
|
||||||
|
with patch("src.graph.nodes.SELECTED_SEARCH_ENGINE", search_engine):
|
||||||
|
result = background_investigation_node(mock_state, mock_config)
|
||||||
|
|
||||||
|
# Verify the result structure
|
||||||
|
assert isinstance(result, Command)
|
||||||
|
assert result.goto == "planner"
|
||||||
|
|
||||||
|
# Verify the update contains background_investigation_results
|
||||||
|
update = result.update
|
||||||
|
assert "background_investigation_results" in update
|
||||||
|
|
||||||
|
# Parse and verify the JSON content
|
||||||
|
results = json.loads(update["background_investigation_results"])
|
||||||
|
assert isinstance(results, list)
|
||||||
|
|
||||||
|
if search_engine == SearchEngine.TAVILY.value:
|
||||||
|
mock_tavily_search.return_value.invoke.assert_called_once_with("test query")
|
||||||
|
assert len(results) == 2
|
||||||
|
assert results[0]["title"] == "Test Title 1"
|
||||||
|
assert results[0]["content"] == "Test Content 1"
|
||||||
|
else:
|
||||||
|
mock_web_search_tool.return_value.invoke.assert_called_once_with(
|
||||||
|
"test query"
|
||||||
|
)
|
||||||
|
assert len(results) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_background_investigation_node_malformed_response(
|
||||||
|
mock_state, mock_tavily_search, patch_config_from_runnable_config, mock_config
|
||||||
|
):
|
||||||
|
"""Test background_investigation_node with malformed Tavily response"""
|
||||||
|
with patch("src.graph.nodes.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value):
|
||||||
|
# Mock a malformed response
|
||||||
|
mock_tavily_search.return_value.invoke.return_value = "invalid response"
|
||||||
|
|
||||||
|
result = background_investigation_node(mock_state, mock_config)
|
||||||
|
|
||||||
|
# Verify the result structure
|
||||||
|
assert isinstance(result, Command)
|
||||||
|
assert result.goto == "planner"
|
||||||
|
|
||||||
|
# Verify the update contains background_investigation_results
|
||||||
|
update = result.update
|
||||||
|
assert "background_investigation_results" in update
|
||||||
|
|
||||||
|
# Parse and verify the JSON content
|
||||||
|
results = json.loads(update["background_investigation_results"])
|
||||||
|
assert results is None
|
||||||
Reference in New Issue
Block a user