Compare commits

...

4 Commits

Author SHA1 Message Date
hetao d9aa92afaa fix: fix unittes & background investigation search logic 2025-05-28 12:46:55 +08:00
wushiai1109 29be360954 Update nodes.py (#242)
SELECTED_SEARCH_ENGINE impossible equal to SearchEngine.ARXIV, should be SearchEngine.ARXIV.value, or use the encapsulated get_web_search_tool
2025-05-27 18:58:14 +08:00
Harsha Vardhan Mannem 3ed70e11d5 Fix/server error handling (#212)
* chore: add venv/ to gitignore

* fix: add server error handling and graceful shutdown

* Fix linting issues in server.py
2025-05-22 13:45:07 +08:00
laundry 55ce399969 test: add background node unit test (#198)
* test: add background node unit test

Change-Id: Ia99f5a1687464387dcb01bbee04deaa371c6e490

* test: add background node unit test

Change-Id: I9aabcf02ff04fda40c56f3ea22abe6b8f93bf9b6

* test: fix test error

Change-Id: I3997dc53a2cfaa35501a1fbda5902ee15528124e

* test: fix unit test error

Change-Id: If4c4cd10673e76a30945674c7cda198aeabf28d0

* test: fix unit test error

Change-Id: I3dd7a6179132e5497a30ada443d88de0c47af3d4
2025-05-20 14:25:35 +08:00
5 changed files with 158 additions and 16 deletions
+1
View File
@@ -11,6 +11,7 @@ static/browser_history/*.gif
# Virtual environments
.venv
venv/
# Environment variables
.env
+25 -11
View File
@@ -7,7 +7,8 @@ Server script for running the DeerFlow API.
import argparse
import logging
import signal
import sys
import uvicorn
# Configure logging
@@ -18,6 +19,17 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
def handle_shutdown(signum, frame):
"""Handle graceful shutdown on SIGTERM/SIGINT"""
logger.info("Received shutdown signal. Starting graceful shutdown...")
sys.exit(0)
# Register signal handlers
signal.signal(signal.SIGTERM, handle_shutdown)
signal.signal(signal.SIGINT, handle_shutdown)
if __name__ == "__main__":
# Parse command line arguments
parser = argparse.ArgumentParser(description="Run the DeerFlow API server")
@@ -50,16 +62,18 @@ if __name__ == "__main__":
# Determine reload setting
reload = False
# Command line arguments override defaults
if args.reload:
reload = True
logger.info("Starting DeerFlow API server")
uvicorn.run(
"src.server:app",
host=args.host,
port=args.port,
reload=reload,
log_level=args.log_level,
)
try:
logger.info(f"Starting DeerFlow API server on {args.host}:{args.port}")
uvicorn.run(
"src.server:app",
host=args.host,
port=args.port,
reload=reload,
log_level=args.log_level,
)
except Exception as e:
logger.error(f"Failed to start server: {str(e)}")
sys.exit(1)
+2 -2
View File
@@ -50,10 +50,10 @@ def background_investigation_node(
logger.info("background investigation node is running.")
configurable = Configuration.from_runnable_config(config)
query = state["messages"][-1].content
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY:
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
searched_content = LoggedTavilySearch(
max_results=configurable.max_search_results
).invoke({"query": query})
).invoke(query)
background_investigation_results = None
if isinstance(searched_content, list):
background_investigation_results = [
+2 -3
View File
@@ -44,13 +44,12 @@ def get_llm_by_type(
return llm
# Initialize LLMs for different purposes - now these will be cached
basic_llm = get_llm_by_type("basic")
# In the future, we will use reasoning_llm and vl_llm for different purposes
# reasoning_llm = get_llm_by_type("reasoning")
# vl_llm = get_llm_by_type("vision")
if __name__ == "__main__":
# Initialize LLMs for different purposes - now these will be cached
basic_llm = get_llm_by_type("basic")
print(basic_llm.invoke("Hello"))
+128
View File
@@ -0,0 +1,128 @@
import json
import pytest
from unittest.mock import patch, MagicMock
# 在这里 mock 掉 get_llm_by_type,避免 ValueError
with patch("src.llms.llm.get_llm_by_type", return_value=MagicMock()):
from langgraph.types import Command
from src.graph.nodes import background_investigation_node
from src.config import SearchEngine
from langchain_core.messages import HumanMessage
# Mock data
MOCK_SEARCH_RESULTS = [
{"title": "Test Title 1", "content": "Test Content 1"},
{"title": "Test Title 2", "content": "Test Content 2"},
]
@pytest.fixture
def mock_state():
return {
"messages": [HumanMessage(content="test query")],
"background_investigation_results": None,
}
@pytest.fixture
def mock_configurable():
mock = MagicMock()
mock.max_search_results = 5
return mock
@pytest.fixture
def mock_config():
# 你可以根据实际需要返回一个 MagicMock 或 dict
return MagicMock()
@pytest.fixture
def patch_config_from_runnable_config(mock_configurable):
with patch(
"src.graph.nodes.Configuration.from_runnable_config",
return_value=mock_configurable,
):
yield
@pytest.fixture
def mock_tavily_search():
with patch("src.graph.nodes.LoggedTavilySearch") as mock:
instance = mock.return_value
instance.invoke.return_value = [
{"title": "Test Title 1", "content": "Test Content 1"},
{"title": "Test Title 2", "content": "Test Content 2"},
]
yield mock
@pytest.fixture
def mock_web_search_tool():
with patch("src.graph.nodes.get_web_search_tool") as mock:
instance = mock.return_value
instance.invoke.return_value = [
{"title": "Test Title 1", "content": "Test Content 1"},
{"title": "Test Title 2", "content": "Test Content 2"},
]
yield mock
@pytest.mark.parametrize("search_engine", [SearchEngine.TAVILY.value, "other"])
def test_background_investigation_node_tavily(
mock_state,
mock_tavily_search,
mock_web_search_tool,
search_engine,
patch_config_from_runnable_config,
mock_config,
):
"""Test background_investigation_node with Tavily search engine"""
with patch("src.graph.nodes.SELECTED_SEARCH_ENGINE", search_engine):
result = background_investigation_node(mock_state, mock_config)
# Verify the result structure
assert isinstance(result, Command)
assert result.goto == "planner"
# Verify the update contains background_investigation_results
update = result.update
assert "background_investigation_results" in update
# Parse and verify the JSON content
results = json.loads(update["background_investigation_results"])
assert isinstance(results, list)
if search_engine == SearchEngine.TAVILY.value:
mock_tavily_search.return_value.invoke.assert_called_once_with("test query")
assert len(results) == 2
assert results[0]["title"] == "Test Title 1"
assert results[0]["content"] == "Test Content 1"
else:
mock_web_search_tool.return_value.invoke.assert_called_once_with(
"test query"
)
assert len(results) == 2
def test_background_investigation_node_malformed_response(
mock_state, mock_tavily_search, patch_config_from_runnable_config, mock_config
):
"""Test background_investigation_node with malformed Tavily response"""
with patch("src.graph.nodes.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value):
# Mock a malformed response
mock_tavily_search.return_value.invoke.return_value = "invalid response"
result = background_investigation_node(mock_state, mock_config)
# Verify the result structure
assert isinstance(result, Command)
assert result.goto == "planner"
# Verify the update contains background_investigation_results
update = result.update
assert "background_investigation_results" in update
# Parse and verify the JSON content
results = json.loads(update["background_investigation_results"])
assert results is None