Prepare to merge deer-flow-2

This commit is contained in:
Willem Jiang
2026-02-14 16:28:12 +08:00
parent 06248fa6f1
commit a66d8c94fa
451 changed files with 0 additions and 142650 deletions
-29
View File
@@ -1,29 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from src.crawler import Crawler
def test_crawler_initialization():
"""Test that crawler can be properly initialized."""
crawler = Crawler()
assert isinstance(crawler, Crawler)
def test_crawler_crawl_valid_url():
"""Test crawling with a valid URL."""
crawler = Crawler()
test_url = "https://finance.sina.com.cn/stock/relnews/us/2024-08-15/doc-incitsya6536375.shtml"
result = crawler.crawl(test_url)
assert result is not None
assert hasattr(result, "to_markdown")
def test_crawler_markdown_output():
"""Test that crawler output can be converted to markdown."""
crawler = Crawler()
test_url = "https://finance.sina.com.cn/stock/relnews/us/2024-08-15/doc-incitsya6536375.shtml"
result = crawler.crawl(test_url)
markdown = result.to_markdown()
assert isinstance(markdown, str)
assert len(markdown) > 0
File diff suppressed because it is too large Load Diff
-144
View File
@@ -1,144 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import pytest
from src.prompts.template import apply_prompt_template, get_prompt_template
def test_get_prompt_template_success():
"""Test successful template loading"""
template = get_prompt_template("coder")
assert template is not None
assert isinstance(template, str)
assert len(template) > 0
def test_get_prompt_template_not_found():
"""Test handling of non-existent template"""
with pytest.raises(ValueError) as exc_info:
get_prompt_template("non_existent_template")
assert "Error loading template" in str(exc_info.value)
def test_apply_prompt_template():
"""Test template variable substitution"""
test_state = {
"messages": [{"role": "user", "content": "test message"}],
"task": "test task",
"workspace_context": "test context",
}
messages = apply_prompt_template("coder", test_state)
assert isinstance(messages, list)
assert len(messages) > 1
assert messages[0]["role"] == "system"
assert "CURRENT_TIME" in messages[0]["content"]
assert messages[1]["role"] == "user"
assert messages[1]["content"] == "test message"
def test_apply_prompt_template_empty_messages():
"""Test template with empty messages list"""
test_state = {
"messages": [],
"task": "test task",
"workspace_context": "test context",
}
messages = apply_prompt_template("coder", test_state)
assert len(messages) == 1 # Only system message
assert messages[0]["role"] == "system"
def test_apply_prompt_template_multiple_messages():
"""Test template with multiple messages"""
test_state = {
"messages": [
{"role": "user", "content": "first message"},
{"role": "assistant", "content": "response"},
{"role": "user", "content": "second message"},
],
"task": "test task",
"workspace_context": "test context",
}
messages = apply_prompt_template("coder", test_state)
assert len(messages) == 4 # system + 3 messages
assert messages[0]["role"] == "system"
assert all(m["role"] in ["system", "user", "assistant"] for m in messages)
def test_apply_prompt_template_with_special_chars():
"""Test template with special characters in variables"""
test_state = {
"messages": [{"role": "user", "content": "test\nmessage\"with'special{chars}"}],
"task": "task with $pecial ch@rs",
"workspace_context": "<html>context</html>",
}
messages = apply_prompt_template("coder", test_state)
assert messages[1]["content"] == "test\nmessage\"with'special{chars}"
@pytest.mark.parametrize("prompt_name", ["coder", "coder", "coordinator", "planner"])
def test_multiple_template_types(prompt_name):
"""Test loading different types of templates"""
template = get_prompt_template(prompt_name)
assert template is not None
assert isinstance(template, str)
assert len(template) > 0
def test_current_time_format():
"""Test the format of CURRENT_TIME in rendered template"""
test_state = {
"messages": [{"role": "user", "content": "test"}],
"task": "test",
"workspace_context": "test",
}
messages = apply_prompt_template("coder", test_state)
system_content = messages[0]["content"]
assert any(
line.strip().startswith("CURRENT_TIME:") for line in system_content.split("\n")
)
def test_apply_prompt_template_reporter():
"""Test reporter template rendering with different styles and locale"""
test_state_news = {
"messages": [],
"task": "test reporter task",
"workspace_context": "test reporter context",
"report_style": "news",
"locale": "en-US",
}
messages_news = apply_prompt_template("reporter", test_state_news)
system_content_news = messages_news[0]["content"]
assert "NBC News" in system_content_news
test_state_social_media_en = {
"messages": [],
"task": "test reporter task",
"workspace_context": "test reporter context",
"report_style": "social_media",
"locale": "en-US",
}
messages_default = apply_prompt_template("reporter", test_state_social_media_en)
system_content_default = messages_default[0]["content"]
assert "Twitter/X" in system_content_default
test_state_social_media_cn = {
"messages": [],
"task": "test reporter task",
"workspace_context": "test reporter context",
"report_style": "social_media",
"locale": "zh-CN",
}
messages_cn = apply_prompt_template("reporter", test_state_social_media_cn)
system_content_cn = messages_cn[0]["content"]
assert "小红书" in system_content_cn
@@ -1,473 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Integration tests for tool-specific interrupts feature (Issue #572).
Tests the complete flow of selective tool interrupts including:
- Tool wrapping with interrupt logic
- Agent creation with interrupt configuration
- Tool execution with user feedback
- Resume mechanism after interrupt
"""
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, call, patch
import pytest
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from src.agents.agents import create_agent
from src.agents.tool_interceptor import ToolInterceptor, wrap_tools_with_interceptor
from src.config.configuration import Configuration
from src.server.chat_request import ChatRequest
class TestToolInterceptorIntegration:
"""Integration tests for tool interceptor with agent workflow."""
def test_agent_creation_with_tool_interrupts(self):
"""Test creating an agent with tool interrupts configured."""
@tool
def search_tool(query: str) -> str:
"""Search the web."""
return f"Search results for: {query}"
@tool
def db_tool(query: str) -> str:
"""Query database."""
return f"DB results for: {query}"
tools = [search_tool, db_tool]
# Create agent with interrupts on db_tool only
with patch("src.agents.agents.langchain_create_agent") as mock_create, \
patch("src.agents.agents.get_llm_by_type") as mock_llm:
mock_create.return_value = MagicMock()
mock_llm.return_value = MagicMock()
agent = create_agent(
agent_name="test_agent",
agent_type="researcher",
tools=tools,
prompt_template="researcher",
interrupt_before_tools=["db_tool"],
)
# Verify langchain_create_agent was called with wrapped tools
assert mock_create.called
call_args = mock_create.call_args
wrapped_tools = call_args.kwargs["tools"]
# Should have wrapped the tools
assert len(wrapped_tools) == 2
assert wrapped_tools[0].name == "search_tool"
assert wrapped_tools[1].name == "db_tool"
def test_configuration_with_tool_interrupts(self):
"""Test Configuration object with interrupt_before_tools."""
config = Configuration(
interrupt_before_tools=["db_tool", "api_write_tool"],
max_step_num=3,
max_search_results=5,
)
assert config.interrupt_before_tools == ["db_tool", "api_write_tool"]
assert config.max_step_num == 3
assert config.max_search_results == 5
def test_configuration_default_no_interrupts(self):
"""Test Configuration defaults to no interrupts."""
config = Configuration()
assert config.interrupt_before_tools == []
def test_chat_request_with_tool_interrupts(self):
"""Test ChatRequest with interrupt_before_tools."""
request = ChatRequest(
messages=[{"role": "user", "content": "Search for X"}],
interrupt_before_tools=["db_tool", "payment_api"],
)
assert request.interrupt_before_tools == ["db_tool", "payment_api"]
def test_chat_request_interrupt_feedback_with_tool_interrupts(self):
"""Test ChatRequest with both interrupt_before_tools and interrupt_feedback."""
request = ChatRequest(
messages=[{"role": "user", "content": "Research topic"}],
interrupt_before_tools=["db_tool"],
interrupt_feedback="approved",
)
assert request.interrupt_before_tools == ["db_tool"]
assert request.interrupt_feedback == "approved"
def test_multiple_tools_selective_interrupt(self):
"""Test that only specified tools trigger interrupts."""
@tool
def tool_a(x: str) -> str:
"""Tool A"""
return f"A: {x}"
@tool
def tool_b(x: str) -> str:
"""Tool B"""
return f"B: {x}"
@tool
def tool_c(x: str) -> str:
"""Tool C"""
return f"C: {x}"
tools = [tool_a, tool_b, tool_c]
interceptor = ToolInterceptor(["tool_b"])
# Wrap all tools
wrapped_tools = wrap_tools_with_interceptor(tools, ["tool_b"])
with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
# tool_a should not interrupt
mock_interrupt.return_value = "approved"
result_a = wrapped_tools[0].invoke("test")
mock_interrupt.assert_not_called()
# tool_b should interrupt
result_b = wrapped_tools[1].invoke("test")
mock_interrupt.assert_called()
# tool_c should not interrupt
mock_interrupt.reset_mock()
result_c = wrapped_tools[2].invoke("test")
mock_interrupt.assert_not_called()
def test_interrupt_with_user_approval(self):
"""Test interrupt flow with user approval."""
@tool
def sensitive_tool(action: str) -> str:
"""A sensitive tool."""
return f"Executed: {action}"
with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
mock_interrupt.return_value = "approved"
interceptor = ToolInterceptor(["sensitive_tool"])
wrapped = ToolInterceptor.wrap_tool(sensitive_tool, interceptor)
result = wrapped.invoke("delete_data")
mock_interrupt.assert_called()
assert "Executed: delete_data" in str(result)
def test_interrupt_with_user_rejection(self):
"""Test interrupt flow with user rejection."""
@tool
def sensitive_tool(action: str) -> str:
"""A sensitive tool."""
return f"Executed: {action}"
with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
mock_interrupt.return_value = "rejected"
interceptor = ToolInterceptor(["sensitive_tool"])
wrapped = ToolInterceptor.wrap_tool(sensitive_tool, interceptor)
result = wrapped.invoke("delete_data")
mock_interrupt.assert_called()
assert isinstance(result, dict)
assert "error" in result
assert result["status"] == "rejected"
def test_interrupt_message_contains_tool_info(self):
"""Test that interrupt message contains tool name and input."""
@tool
def db_query_tool(query: str) -> str:
"""Database query tool."""
return f"Query result: {query}"
with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
mock_interrupt.return_value = "approved"
interceptor = ToolInterceptor(["db_query_tool"])
wrapped = ToolInterceptor.wrap_tool(db_query_tool, interceptor)
wrapped.invoke("SELECT * FROM users")
# Verify interrupt was called with meaningful message
mock_interrupt.assert_called()
interrupt_message = mock_interrupt.call_args[0][0]
assert "db_query_tool" in interrupt_message
assert "SELECT * FROM users" in interrupt_message
def test_tool_wrapping_preserves_functionality(self):
"""Test that tool wrapping preserves original tool functionality."""
@tool
def simple_tool(text: str) -> str:
"""Process text."""
return f"Processed: {text}"
interceptor = ToolInterceptor([]) # No interrupts
wrapped = ToolInterceptor.wrap_tool(simple_tool, interceptor)
result = wrapped.invoke({"text": "hello"})
assert "hello" in str(result)
def test_tool_wrapping_preserves_tool_metadata(self):
"""Test that tool wrapping preserves tool name and description."""
@tool
def my_special_tool(x: str) -> str:
"""This is my special tool description."""
return f"Result: {x}"
interceptor = ToolInterceptor([])
wrapped = ToolInterceptor.wrap_tool(my_special_tool, interceptor)
assert wrapped.name == "my_special_tool"
assert "special tool" in wrapped.description.lower()
def test_multiple_interrupts_in_sequence(self):
"""Test handling multiple tool interrupts in sequence."""
@tool
def tool_one(x: str) -> str:
"""Tool one."""
return f"One: {x}"
@tool
def tool_two(x: str) -> str:
"""Tool two."""
return f"Two: {x}"
@tool
def tool_three(x: str) -> str:
"""Tool three."""
return f"Three: {x}"
tools = [tool_one, tool_two, tool_three]
wrapped_tools = wrap_tools_with_interceptor(
tools, ["tool_one", "tool_two"]
)
with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
mock_interrupt.return_value = "approved"
# First interrupt
result_one = wrapped_tools[0].invoke("first")
assert mock_interrupt.call_count == 1
# Second interrupt
result_two = wrapped_tools[1].invoke("second")
assert mock_interrupt.call_count == 2
# Third (no interrupt)
result_three = wrapped_tools[2].invoke("third")
assert mock_interrupt.call_count == 2
assert "One: first" in str(result_one)
assert "Two: second" in str(result_two)
assert "Three: third" in str(result_three)
def test_empty_interrupt_list_no_interrupts(self):
"""Test that empty interrupt list doesn't trigger interrupts."""
@tool
def test_tool(x: str) -> str:
"""Test tool."""
return f"Result: {x}"
wrapped_tools = wrap_tools_with_interceptor([test_tool], [])
with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
wrapped_tools[0].invoke("test")
mock_interrupt.assert_not_called()
def test_none_interrupt_list_no_interrupts(self):
"""Test that None interrupt list doesn't trigger interrupts."""
@tool
def test_tool(x: str) -> str:
"""Test tool."""
return f"Result: {x}"
wrapped_tools = wrap_tools_with_interceptor([test_tool], None)
with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
wrapped_tools[0].invoke("test")
mock_interrupt.assert_not_called()
def test_case_sensitive_tool_name_matching(self):
"""Test that tool name matching is case-sensitive."""
@tool
def MyTool(x: str) -> str:
"""A tool."""
return f"Result: {x}"
interceptor_lower = ToolInterceptor(["mytool"])
interceptor_exact = ToolInterceptor(["MyTool"])
with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
mock_interrupt.return_value = "approved"
# Case mismatch - should NOT interrupt
wrapped_lower = ToolInterceptor.wrap_tool(MyTool, interceptor_lower)
result_lower = wrapped_lower.invoke("test")
mock_interrupt.assert_not_called()
# Case match - should interrupt
wrapped_exact = ToolInterceptor.wrap_tool(MyTool, interceptor_exact)
result_exact = wrapped_exact.invoke("test")
mock_interrupt.assert_called()
def test_tool_error_handling(self):
"""Test handling of tool errors during execution."""
@tool
def error_tool(x: str) -> str:
"""A tool that raises an error."""
raise ValueError(f"Intentional error: {x}")
with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
mock_interrupt.return_value = "approved"
interceptor = ToolInterceptor(["error_tool"])
wrapped = ToolInterceptor.wrap_tool(error_tool, interceptor)
with pytest.raises(ValueError) as exc_info:
wrapped.invoke("test")
assert "Intentional error: test" in str(exc_info.value)
def test_approval_keywords_comprehensive(self):
"""Test all approved keywords are recognized."""
approval_keywords = [
"approved",
"approve",
"yes",
"proceed",
"continue",
"ok",
"okay",
"accepted",
"accept",
"[approved]",
"APPROVED",
"Proceed with this action",
"[ACCEPTED] I approve",
]
for keyword in approval_keywords:
result = ToolInterceptor._parse_approval(keyword)
assert (
result is True
), f"Keyword '{keyword}' should be approved but got {result}"
def test_rejection_keywords_comprehensive(self):
"""Test that rejection keywords are recognized."""
rejection_keywords = [
"no",
"reject",
"cancel",
"decline",
"stop",
"abort",
"maybe",
"later",
"random text",
"",
]
for keyword in rejection_keywords:
result = ToolInterceptor._parse_approval(keyword)
assert (
result is False
), f"Keyword '{keyword}' should be rejected but got {result}"
def test_interrupt_with_complex_tool_input(self):
"""Test interrupt with complex tool input types."""
@tool
def complex_tool(data: str) -> str:
"""A tool with complex input."""
return f"Processed: {data}"
with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
mock_interrupt.return_value = "approved"
interceptor = ToolInterceptor(["complex_tool"])
wrapped = ToolInterceptor.wrap_tool(complex_tool, interceptor)
complex_input = {
"data": "complex data with nested info"
}
result = wrapped.invoke(complex_input)
mock_interrupt.assert_called()
assert "Processed" in str(result)
def test_configuration_from_runnable_config(self):
"""Test Configuration.from_runnable_config with interrupt_before_tools."""
from langchain_core.runnables import RunnableConfig
config = RunnableConfig(
configurable={
"interrupt_before_tools": ["db_tool"],
"max_step_num": 5,
}
)
configuration = Configuration.from_runnable_config(config)
assert configuration.interrupt_before_tools == ["db_tool"]
assert configuration.max_step_num == 5
def test_tool_interceptor_initialization_logging(self):
"""Test that ToolInterceptor initialization is logged."""
with patch("src.agents.tool_interceptor.logger") as mock_logger:
interceptor = ToolInterceptor(["tool_a", "tool_b"])
mock_logger.info.assert_called()
def test_wrap_tools_with_interceptor_logging(self):
"""Test that tool wrapping is logged."""
@tool
def test_tool(x: str) -> str:
"""Test."""
return x
with patch("src.agents.tool_interceptor.logger") as mock_logger:
wrapped = wrap_tools_with_interceptor([test_tool], ["test_tool"])
# Check that at least one info log was called
assert mock_logger.info.called or mock_logger.debug.called
def test_interrupt_resolution_with_empty_feedback(self):
"""Test interrupt resolution with empty feedback."""
@tool
def test_tool(x: str) -> str:
"""Test."""
return f"Result: {x}"
with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
mock_interrupt.return_value = ""
interceptor = ToolInterceptor(["test_tool"])
wrapped = ToolInterceptor.wrap_tool(test_tool, interceptor)
result = wrapped.invoke("test")
# Empty feedback should be treated as rejection
assert isinstance(result, dict)
assert result["status"] == "rejected"
def test_interrupt_resolution_with_none_feedback(self):
"""Test interrupt resolution with None feedback."""
@tool
def test_tool(x: str) -> str:
"""Test."""
return f"Result: {x}"
with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
mock_interrupt.return_value = None
interceptor = ToolInterceptor(["test_tool"])
wrapped = ToolInterceptor.wrap_tool(test_tool, interceptor)
result = wrapped.invoke("test")
# None feedback should be treated as rejection
assert isinstance(result, dict)
assert result["status"] == "rejected"
-247
View File
@@ -1,247 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import base64
import json
from unittest.mock import MagicMock, patch
from src.tools.tts import VolcengineTTS
class TestVolcengineTTS:
"""Test suite for the VolcengineTTS class."""
def test_initialization(self):
"""Test that VolcengineTTS can be properly initialized."""
tts = VolcengineTTS(
appid="test_appid",
access_token="test_token",
cluster="test_cluster",
voice_type="test_voice",
host="test.host.com",
)
assert tts.appid == "test_appid"
assert tts.access_token == "test_token"
assert tts.cluster == "test_cluster"
assert tts.voice_type == "test_voice"
assert tts.host == "test.host.com"
assert tts.api_url == "https://test.host.com/api/v1/tts"
assert tts.header == {"Authorization": "Bearer;test_token"}
def test_initialization_with_defaults(self):
"""Test initialization with default values."""
tts = VolcengineTTS(
appid="test_appid",
access_token="test_token",
)
assert tts.appid == "test_appid"
assert tts.access_token == "test_token"
assert tts.cluster == "volcano_tts"
assert tts.voice_type == "BV700_V2_streaming"
assert tts.host == "openspeech.bytedance.com"
assert tts.api_url == "https://openspeech.bytedance.com/api/v1/tts"
@patch("src.tools.tts.requests.post")
def test_text_to_speech_success(self, mock_post):
"""Test successful text-to-speech conversion."""
# Mock response
mock_response = MagicMock()
mock_response.status_code = 200
# Create a base64 encoded string for the mock audio data
mock_audio_data = base64.b64encode(b"audio_data").decode()
mock_response.json.return_value = {
"code": 0,
"message": "success",
"data": mock_audio_data,
}
mock_post.return_value = mock_response
# Create TTS client
tts = VolcengineTTS(
appid="test_appid",
access_token="test_token",
)
# Call the method
result = tts.text_to_speech("Hello, world!")
# Verify the result
assert result["success"] is True
assert result["audio_data"] == mock_audio_data
assert "response" in result
# Verify the request
mock_post.assert_called_once()
args, _ = mock_post.call_args
assert args[0] == "https://openspeech.bytedance.com/api/v1/tts"
# Verify request JSON - the data is passed as the second positional argument
request_json = json.loads(args[1])
assert request_json["app"]["appid"] == "test_appid"
assert request_json["app"]["token"] == "test_token"
assert request_json["app"]["cluster"] == "volcano_tts"
assert request_json["audio"]["voice_type"] == "BV700_V2_streaming"
assert request_json["audio"]["encoding"] == "mp3"
assert request_json["request"]["text"] == "Hello, world!"
@patch("src.tools.tts.requests.post")
def test_text_to_speech_api_error(self, mock_post):
"""Test error handling when API returns an error."""
# Mock response
mock_response = MagicMock()
mock_response.status_code = 400
mock_response.json.return_value = {
"code": 400,
"message": "Bad request",
}
mock_post.return_value = mock_response
# Create TTS client
tts = VolcengineTTS(
appid="test_appid",
access_token="test_token",
)
# Call the method
result = tts.text_to_speech("Hello, world!")
# Verify the result
assert result["success"] is False
assert result["error"] == {"code": 400, "message": "Bad request"}
assert result["audio_data"] is None
@patch("src.tools.tts.requests.post")
def test_text_to_speech_no_data(self, mock_post):
"""Test error handling when API response doesn't contain data."""
# Mock response
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"code": 0,
"message": "success",
# No data field
}
mock_post.return_value = mock_response
# Create TTS client
tts = VolcengineTTS(
appid="test_appid",
access_token="test_token",
)
# Call the method
result = tts.text_to_speech("Hello, world!")
# Verify the result
assert result["success"] is False
assert result["error"] == "No audio data returned"
assert result["audio_data"] is None
@patch("src.tools.tts.requests.post")
def test_text_to_speech_with_custom_parameters(self, mock_post):
"""Test text_to_speech with custom parameters."""
# Mock response
mock_response = MagicMock()
mock_response.status_code = 200
# Create a base64 encoded string for the mock audio data
mock_audio_data = base64.b64encode(b"audio_data").decode()
mock_response.json.return_value = {
"code": 0,
"message": "success",
"data": mock_audio_data,
}
mock_post.return_value = mock_response
# Create TTS client
tts = VolcengineTTS(
appid="test_appid",
access_token="test_token",
)
# Call the method with custom parameters
result = tts.text_to_speech(
text="Custom text",
encoding="wav",
speed_ratio=1.2,
volume_ratio=0.8,
pitch_ratio=1.1,
text_type="ssml",
with_frontend=0,
frontend_type="custom",
uid="custom-uid",
)
# Verify the result
assert result["success"] is True
assert result["audio_data"] == mock_audio_data
# Verify request JSON - the data is passed as the second positional argument
args, kwargs = mock_post.call_args
request_json = json.loads(args[1])
assert request_json["audio"]["encoding"] == "wav"
assert request_json["audio"]["speed_ratio"] == 1.2
assert request_json["audio"]["volume_ratio"] == 0.8
assert request_json["audio"]["pitch_ratio"] == 1.1
assert request_json["request"]["text"] == "Custom text"
assert request_json["request"]["text_type"] == "ssml"
assert request_json["request"]["with_frontend"] == 0
assert request_json["request"]["frontend_type"] == "custom"
assert request_json["user"]["uid"] == "custom-uid"
@patch("src.tools.tts.requests.post")
@patch("src.tools.tts.uuid.uuid4")
def test_text_to_speech_auto_generated_uid(self, mock_uuid, mock_post):
"""Test that UUID is auto-generated if not provided."""
# Mock UUID
mock_uuid_value = "test-uuid-value"
mock_uuid.return_value = mock_uuid_value
# Mock response
mock_response = MagicMock()
mock_response.status_code = 200
# Create a base64 encoded string for the mock audio data
mock_audio_data = base64.b64encode(b"audio_data").decode()
mock_response.json.return_value = {
"code": 0,
"message": "success",
"data": mock_audio_data,
}
mock_post.return_value = mock_response
# Create TTS client
tts = VolcengineTTS(
appid="test_appid",
access_token="test_token",
)
# Call the method without providing a UID
result = tts.text_to_speech("Hello, world!")
# Verify the result
assert result["success"] is True
assert result["audio_data"] == mock_audio_data
# Verify the request JSON - the data is passed as the second positional argument
args, kwargs = mock_post.call_args
request_json = json.loads(args[1])
assert request_json["user"]["uid"] == str(mock_uuid_value)
@patch("src.tools.tts.requests.post")
def test_text_to_speech_request_exception(self, mock_post):
"""Test error handling when requests.post raises an exception."""
# Mock requests.post to raise an exception
mock_post.side_effect = Exception("Network error")
# Create TTS client
tts = VolcengineTTS(
appid="test_appid",
access_token="test_token",
)
# Call the method
result = tts.text_to_speech("Hello, world!")
# Verify the result
assert result["success"] is False
# The TTS error is caught and returned as a string
assert result["error"] == "TTS API call error"
assert result["audio_data"] is None
-135
View File
@@ -1,135 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for PPT composer localization functionality.
These tests verify that the ppt_composer_node correctly passes locale information
to get_prompt_template, allowing for locale-specific prompt selection.
"""
import pytest
class MockLLMResponse:
"""Mock LLM response object."""
def __init__(self, content: str = "Mock PPT content"):
self.content = content
class MockLLM:
"""Mock LLM model with invoke method."""
def invoke(self, messages):
"""Return a mock response."""
return MockLLMResponse()
class TestPPTLocalization:
"""Test suite for PPT composer locale handling."""
def test_locale_passed_to_prompt_template(self, monkeypatch):
"""
Test that when locale is provided in state, it is passed to get_prompt_template.
This test verifies that ppt_composer_node correctly extracts the locale
from the state dict and passes it to get_prompt_template.
"""
# Track calls to get_prompt_template
captured_calls = []
def mock_get_prompt_template(prompt_name, locale="en-US"):
"""Capture the arguments passed to get_prompt_template."""
captured_calls.append({"prompt_name": prompt_name, "locale": locale})
return "Mock prompt template"
def mock_get_llm_by_type(llm_type):
"""Return a mock LLM."""
return MockLLM()
# Import here to ensure monkeypatching happens before module import
import src.ppt.graph.ppt_composer_node as ppt_module
# Monkeypatch the functions
monkeypatch.setattr(
ppt_module,
"get_prompt_template",
mock_get_prompt_template
)
monkeypatch.setattr(
ppt_module,
"get_llm_by_type",
mock_get_llm_by_type
)
# Create state with input and locale
state = {
"input": "hello",
"locale": "zh-CN"
}
# Call the ppt_composer_node
result = ppt_module.ppt_composer_node(state)
# Verify get_prompt_template was called with the correct locale
assert len(captured_calls) == 1, "get_prompt_template should be called once"
assert captured_calls[0]["prompt_name"] == "ppt/ppt_composer"
assert captured_calls[0]["locale"] == "zh-CN", \
"get_prompt_template should be called with locale 'zh-CN'"
# Verify result structure
assert "ppt_content" in result
assert "ppt_file_path" in result
def test_default_locale_fallback(self, monkeypatch):
"""
Test that when locale is missing from state, default locale 'en-US' is used.
This test verifies that ppt_composer_node falls back to the default locale
'en-US' when no locale is provided in the state dict.
"""
# Track calls to get_prompt_template
captured_calls = []
def mock_get_prompt_template(prompt_name, locale="en-US"):
"""Capture the arguments passed to get_prompt_template."""
captured_calls.append({"prompt_name": prompt_name, "locale": locale})
return "Mock prompt template"
def mock_get_llm_by_type(llm_type):
"""Return a mock LLM."""
return MockLLM()
# Import here to ensure monkeypatching happens before module import
import src.ppt.graph.ppt_composer_node as ppt_module
# Monkeypatch the functions
monkeypatch.setattr(
ppt_module,
"get_prompt_template",
mock_get_prompt_template
)
monkeypatch.setattr(
ppt_module,
"get_llm_by_type",
mock_get_llm_by_type
)
# Create state without locale (only input)
state = {
"input": "hello"
}
# Call the ppt_composer_node
result = ppt_module.ppt_composer_node(state)
# Verify get_prompt_template was called with the default locale
assert len(captured_calls) == 1, "get_prompt_template should be called once"
assert captured_calls[0]["prompt_name"] == "ppt/ppt_composer"
assert captured_calls[0]["locale"] == "en-US", \
"get_prompt_template should be called with default locale 'en-US'"
# Verify result structure
assert "ppt_content" in result
assert "ppt_file_path" in result
-133
View File
@@ -1,133 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import os
import sys
from typing import Annotated
# Import MessagesState directly from langgraph rather than through our application
from langgraph.graph import MessagesState
# Create stub versions of Plan/Step/StepType to avoid dependencies
class StepType:
RESEARCH = "research"
PROCESSING = "processing"
class Step:
def __init__(self, need_search, title, description, step_type):
self.need_search = need_search
self.title = title
self.description = description
self.step_type = step_type
class Plan:
def __init__(self, locale, has_enough_context, thought, title, steps):
self.locale = locale
self.has_enough_context = has_enough_context
self.thought = thought
self.title = title
self.steps = steps
# Import the actual State class by loading the module directly
# This avoids the cascade of imports that would normally happen
def load_state_class():
# Get the absolute path to the types.py file
src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))
types_path = os.path.join(src_dir, "graph", "types.py")
# Create a namespace for the module
import types
module_name = "src.graph.types_direct"
spec = types.ModuleType(module_name)
# Add the module to sys.modules to avoid import loops
sys.modules[module_name] = spec
# Set up the namespace with required imports
spec.__dict__["operator"] = __import__("operator")
spec.__dict__["Annotated"] = Annotated
spec.__dict__["MessagesState"] = MessagesState
spec.__dict__["Plan"] = Plan
# Execute the module code
with open(types_path, "r") as f:
module_code = f.read()
exec(module_code, spec.__dict__)
# Return the State class
return spec.State
# Load the actual State class
State = load_state_class()
def test_state_initialization():
"""Test that State class has correct default attribute definitions."""
# Test that the class has the expected attribute definitions
assert State.locale == "en-US"
assert State.observations == []
assert State.plan_iterations == 0
assert State.current_plan is None
assert State.final_report == ""
assert State.auto_accepted_plan is False
assert State.enable_background_investigation is True
assert State.background_investigation_results is None
# Verify state initialization
state = State(messages=[])
assert "messages" in state
# Without explicitly passing attributes, they're not in the state
assert "locale" not in state
assert "observations" not in state
def test_state_with_custom_values():
"""Test that State can be initialized with custom values."""
test_step = Step(
need_search=True,
title="Test Step",
description="Step description",
step_type=StepType.RESEARCH,
)
test_plan = Plan(
locale="en-US",
has_enough_context=False,
thought="Test thought",
title="Test Plan",
steps=[test_step],
)
# Initialize state with custom values and required messages field
state = State(
messages=[],
locale="fr-FR",
observations=["Observation 1"],
plan_iterations=2,
current_plan=test_plan,
final_report="Test report",
auto_accepted_plan=True,
enable_background_investigation=False,
background_investigation_results="Test results",
)
# Access state keys - these are explicitly initialized
assert state["locale"] == "fr-FR"
assert state["observations"] == ["Observation 1"]
assert state["plan_iterations"] == 2
assert state["current_plan"].title == "Test Plan"
assert state["current_plan"].thought == "Test thought"
assert len(state["current_plan"].steps) == 1
assert state["current_plan"].steps[0].title == "Test Step"
assert state["final_report"] == "Test report"
assert state["auto_accepted_plan"] is True
assert state["enable_background_investigation"] is False
assert state["background_investigation_results"] == "Test results"
-335
View File
@@ -1,335 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import asyncio
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from langchain_core.messages import HumanMessage, SystemMessage
from src.agents.agents import DynamicPromptMiddleware, PreModelHookMiddleware
@pytest.fixture
def mock_runtime():
"""Mock Runtime object."""
runtime = MagicMock()
runtime.config = {}
return runtime
@pytest.fixture
def mock_state():
"""Mock state object."""
return {
"messages": [HumanMessage(content="Test message")],
"context": "Test context",
}
@pytest.fixture
def mock_messages():
"""Mock messages returned by apply_prompt_template."""
return [
SystemMessage(content="Test system prompt"),
HumanMessage(content="Test human message"),
]
class TestDynamicPromptMiddleware:
"""Tests for DynamicPromptMiddleware class."""
def test_init(self):
"""Test middleware initialization."""
middleware = DynamicPromptMiddleware("test_template", locale="zh-CN")
assert middleware.prompt_template == "test_template"
assert middleware.locale == "zh-CN"
def test_init_default_locale(self):
"""Test middleware initialization with default locale."""
middleware = DynamicPromptMiddleware("test_template")
assert middleware.prompt_template == "test_template"
assert middleware.locale == "en-US"
@patch("src.agents.agents.apply_prompt_template")
def test_before_model_success(
self, mock_apply_template, mock_state, mock_runtime, mock_messages
):
"""Test before_model successfully applies prompt template."""
mock_apply_template.return_value = mock_messages
middleware = DynamicPromptMiddleware("test_template", locale="en-US")
result = middleware.before_model(mock_state, mock_runtime)
# Verify apply_prompt_template was called with correct arguments
mock_apply_template.assert_called_once_with(
"test_template", mock_state, locale="en-US"
)
# Verify system message is returned
assert result == {"messages": [mock_messages[0]]}
assert result["messages"][0].content == "Test system prompt"
@patch("src.agents.agents.apply_prompt_template")
def test_before_model_empty_messages(
self, mock_apply_template, mock_state, mock_runtime
):
"""Test before_model with empty message list."""
mock_apply_template.return_value = []
middleware = DynamicPromptMiddleware("test_template")
result = middleware.before_model(mock_state, mock_runtime)
# Should return None when no messages are rendered
assert result is None
@patch("src.agents.agents.apply_prompt_template")
def test_before_model_none_messages(
self, mock_apply_template, mock_state, mock_runtime
):
"""Test before_model when apply_prompt_template returns None."""
mock_apply_template.return_value = None
middleware = DynamicPromptMiddleware("test_template")
result = middleware.before_model(mock_state, mock_runtime)
# Should return None when template returns None
assert result is None
@patch("src.agents.agents.apply_prompt_template")
@patch("src.agents.agents.logger")
def test_before_model_exception_handling(
self, mock_logger, mock_apply_template, mock_state, mock_runtime
):
"""Test before_model handles exceptions gracefully."""
mock_apply_template.side_effect = ValueError("Template rendering failed")
middleware = DynamicPromptMiddleware("test_template")
result = middleware.before_model(mock_state, mock_runtime)
# Should return None on exception
assert result is None
# Should log error with exc_info
mock_logger.error.assert_called_once()
error_message = mock_logger.error.call_args[0][0]
assert "Failed to apply prompt template in before_model" in error_message
assert mock_logger.error.call_args[1]["exc_info"] is True
@patch("src.agents.agents.apply_prompt_template")
def test_before_model_with_different_locale(
self, mock_apply_template, mock_state, mock_runtime, mock_messages
):
"""Test before_model with different locale."""
mock_apply_template.return_value = mock_messages
middleware = DynamicPromptMiddleware("test_template", locale="zh-CN")
result = middleware.before_model(mock_state, mock_runtime)
# Verify locale is passed correctly
mock_apply_template.assert_called_once_with(
"test_template", mock_state, locale="zh-CN"
)
assert result == {"messages": [mock_messages[0]]}
@pytest.mark.asyncio
@patch("src.agents.agents.apply_prompt_template")
async def test_abefore_model(
self, mock_apply_template, mock_state, mock_runtime, mock_messages
):
"""Test async version of before_model."""
mock_apply_template.return_value = mock_messages
middleware = DynamicPromptMiddleware("test_template")
result = await middleware.abefore_model(mock_state, mock_runtime)
# Should call the sync version and return same result
assert result == {"messages": [mock_messages[0]]}
mock_apply_template.assert_called_once_with(
"test_template", mock_state, locale="en-US"
)
class TestPreModelHookMiddleware:
"""Tests for PreModelHookMiddleware class."""
def test_init(self):
"""Test middleware initialization."""
hook = Mock()
middleware = PreModelHookMiddleware(hook)
assert middleware._pre_model_hook == hook
def test_before_model_with_sync_hook(self, mock_state, mock_runtime):
"""Test before_model with synchronous hook."""
hook = Mock(return_value={"custom_data": "test"})
middleware = PreModelHookMiddleware(hook)
result = middleware.before_model(mock_state, mock_runtime)
# Verify hook was called with correct arguments
hook.assert_called_once_with(mock_state, mock_runtime)
assert result == {"custom_data": "test"}
def test_before_model_with_none_hook(self, mock_state, mock_runtime):
"""Test before_model when hook is None."""
middleware = PreModelHookMiddleware(None)
result = middleware.before_model(mock_state, mock_runtime)
# Should return None when hook is None
assert result is None
def test_before_model_hook_returns_none(self, mock_state, mock_runtime):
"""Test before_model when hook returns None."""
hook = Mock(return_value=None)
middleware = PreModelHookMiddleware(hook)
result = middleware.before_model(mock_state, mock_runtime)
hook.assert_called_once_with(mock_state, mock_runtime)
assert result is None
@patch("src.agents.agents.logger")
def test_before_model_hook_exception(
self, mock_logger, mock_state, mock_runtime
):
"""Test before_model handles hook exceptions gracefully."""
hook = Mock(side_effect=RuntimeError("Hook execution failed"))
middleware = PreModelHookMiddleware(hook)
result = middleware.before_model(mock_state, mock_runtime)
# Should return None on exception
assert result is None
# Should log error with exc_info
mock_logger.error.assert_called_once()
error_message = mock_logger.error.call_args[0][0]
assert "Pre-model hook execution failed in before_model" in error_message
assert mock_logger.error.call_args[1]["exc_info"] is True
@pytest.mark.asyncio
async def test_abefore_model_with_async_hook(self, mock_state, mock_runtime):
"""Test async before_model with async hook."""
async def async_hook(state, runtime):
await asyncio.sleep(0.001) # Simulate async work
return {"async_data": "test"}
middleware = PreModelHookMiddleware(async_hook)
result = await middleware.abefore_model(mock_state, mock_runtime)
assert result == {"async_data": "test"}
@pytest.mark.asyncio
@patch("src.agents.agents.asyncio.to_thread")
async def test_abefore_model_with_sync_hook(
self, mock_to_thread, mock_state, mock_runtime
):
"""Test async before_model with synchronous hook uses asyncio.to_thread."""
hook = Mock(return_value={"sync_data": "test"})
mock_to_thread.return_value = {"sync_data": "test"}
middleware = PreModelHookMiddleware(hook)
result = await middleware.abefore_model(mock_state, mock_runtime)
# Verify asyncio.to_thread was called with the sync hook
mock_to_thread.assert_called_once_with(hook, mock_state, mock_runtime)
assert result == {"sync_data": "test"}
@pytest.mark.asyncio
async def test_abefore_model_with_none_hook(self, mock_state, mock_runtime):
"""Test async before_model when hook is None."""
middleware = PreModelHookMiddleware(None)
result = await middleware.abefore_model(mock_state, mock_runtime)
# Should return None when hook is None
assert result is None
@pytest.mark.asyncio
@patch("src.agents.agents.logger")
async def test_abefore_model_async_hook_exception(
self, mock_logger, mock_state, mock_runtime
):
"""Test async before_model handles async hook exceptions gracefully."""
async def failing_hook(state, runtime):
raise ValueError("Async hook failed")
middleware = PreModelHookMiddleware(failing_hook)
result = await middleware.abefore_model(mock_state, mock_runtime)
# Should return None on exception
assert result is None
# Should log error with exc_info
mock_logger.error.assert_called_once()
error_message = mock_logger.error.call_args[0][0]
assert "Pre-model hook execution failed in abefore_model" in error_message
assert mock_logger.error.call_args[1]["exc_info"] is True
@pytest.mark.asyncio
@patch("src.agents.agents.asyncio.to_thread")
@patch("src.agents.agents.logger")
async def test_abefore_model_sync_hook_exception(
self, mock_logger, mock_to_thread, mock_state, mock_runtime
):
"""Test async before_model handles sync hook exceptions gracefully."""
hook = Mock()
mock_to_thread.side_effect = RuntimeError("Thread execution failed")
middleware = PreModelHookMiddleware(hook)
result = await middleware.abefore_model(mock_state, mock_runtime)
# Should return None on exception
assert result is None
# Should log error with exc_info
mock_logger.error.assert_called_once()
error_message = mock_logger.error.call_args[0][0]
assert "Pre-model hook execution failed in abefore_model" in error_message
assert mock_logger.error.call_args[1]["exc_info"] is True
@pytest.mark.asyncio
async def test_abefore_model_sync_hook_actual_execution(
self, mock_state, mock_runtime
):
"""Test async before_model actually runs sync hook in thread pool."""
# Track if hook was called
hook_called = []
def sync_hook(state, runtime):
hook_called.append(True)
return {"data": "from_sync_hook"}
middleware = PreModelHookMiddleware(sync_hook)
result = await middleware.abefore_model(mock_state, mock_runtime)
# Verify hook was called and result returned
assert len(hook_called) == 1
assert result == {"data": "from_sync_hook"}
@pytest.mark.asyncio
async def test_abefore_model_detects_coroutine_function(
self, mock_state, mock_runtime
):
"""Test that abefore_model correctly detects async vs sync functions."""
# Test with async function
async def async_hook(state, runtime):
return {"type": "async"}
# Test with sync function
def sync_hook(state, runtime):
return {"type": "sync"}
async_middleware = PreModelHookMiddleware(async_hook)
sync_middleware = PreModelHookMiddleware(sync_hook)
# Both should execute successfully
async_result = await async_middleware.abefore_model(mock_state, mock_runtime)
sync_result = await sync_middleware.abefore_model(mock_state, mock_runtime)
assert async_result == {"type": "async"}
assert sync_result == {"type": "sync"}
-434
View File
@@ -1,434 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from langchain_core.tools import BaseTool, tool
from src.agents.tool_interceptor import (
ToolInterceptor,
wrap_tools_with_interceptor,
)
class TestToolInterceptor:
"""Tests for ToolInterceptor class."""
def test_init_with_tools(self):
"""Test initializing interceptor with tool list."""
tools = ["db_tool", "api_tool"]
interceptor = ToolInterceptor(tools)
assert interceptor.interrupt_before_tools == tools
def test_init_without_tools(self):
"""Test initializing interceptor without tools."""
interceptor = ToolInterceptor()
assert interceptor.interrupt_before_tools == []
def test_should_interrupt_with_matching_tool(self):
"""Test should_interrupt returns True for matching tools."""
tools = ["db_tool", "api_tool"]
interceptor = ToolInterceptor(tools)
assert interceptor.should_interrupt("db_tool") is True
assert interceptor.should_interrupt("api_tool") is True
def test_should_interrupt_with_non_matching_tool(self):
"""Test should_interrupt returns False for non-matching tools."""
tools = ["db_tool", "api_tool"]
interceptor = ToolInterceptor(tools)
assert interceptor.should_interrupt("search_tool") is False
assert interceptor.should_interrupt("crawl_tool") is False
def test_should_interrupt_empty_list(self):
"""Test should_interrupt with empty interrupt list."""
interceptor = ToolInterceptor([])
assert interceptor.should_interrupt("db_tool") is False
def test_parse_approval_with_approval_keywords(self):
"""Test parsing user feedback with approval keywords."""
assert ToolInterceptor._parse_approval("approved") is True
assert ToolInterceptor._parse_approval("approve") is True
assert ToolInterceptor._parse_approval("yes") is True
assert ToolInterceptor._parse_approval("proceed") is True
assert ToolInterceptor._parse_approval("continue") is True
assert ToolInterceptor._parse_approval("ok") is True
assert ToolInterceptor._parse_approval("okay") is True
assert ToolInterceptor._parse_approval("accepted") is True
assert ToolInterceptor._parse_approval("accept") is True
assert ToolInterceptor._parse_approval("[approved]") is True
def test_parse_approval_case_insensitive(self):
"""Test parsing is case-insensitive."""
assert ToolInterceptor._parse_approval("APPROVED") is True
assert ToolInterceptor._parse_approval("Approved") is True
assert ToolInterceptor._parse_approval("PROCEED") is True
def test_parse_approval_with_surrounding_text(self):
"""Test parsing with surrounding text."""
assert ToolInterceptor._parse_approval("Sure, proceed with the tool") is True
assert ToolInterceptor._parse_approval("[ACCEPTED] I approve this") is True
def test_parse_approval_rejection(self):
"""Test parsing rejects non-approval feedback."""
assert ToolInterceptor._parse_approval("no") is False
assert ToolInterceptor._parse_approval("reject") is False
assert ToolInterceptor._parse_approval("cancel") is False
assert ToolInterceptor._parse_approval("random feedback") is False
def test_parse_approval_empty_string(self):
"""Test parsing empty string."""
assert ToolInterceptor._parse_approval("") is False
def test_parse_approval_none(self):
"""Test parsing None."""
assert ToolInterceptor._parse_approval(None) is False
@patch("src.agents.tool_interceptor.interrupt")
def test_wrap_tool_with_interrupt(self, mock_interrupt):
"""Test wrapping a tool with interrupt."""
mock_interrupt.return_value = "approved"
# Create a simple test tool
@tool
def test_tool(input_text: str) -> str:
"""Test tool."""
return f"Result: {input_text}"
interceptor = ToolInterceptor(["test_tool"])
# Wrap the tool
wrapped_tool = ToolInterceptor.wrap_tool(test_tool, interceptor)
# Invoke the wrapped tool
result = wrapped_tool.invoke("hello")
# Verify interrupt was called
mock_interrupt.assert_called_once()
assert "test_tool" in mock_interrupt.call_args[0][0]
@patch("src.agents.tool_interceptor.interrupt")
def test_wrap_tool_without_interrupt(self, mock_interrupt):
"""Test wrapping a tool that doesn't trigger interrupt."""
# Create a simple test tool
@tool
def test_tool(input_text: str) -> str:
"""Test tool."""
return f"Result: {input_text}"
interceptor = ToolInterceptor(["other_tool"])
# Wrap the tool
wrapped_tool = ToolInterceptor.wrap_tool(test_tool, interceptor)
# Invoke the wrapped tool
result = wrapped_tool.invoke("hello")
# Verify interrupt was NOT called
mock_interrupt.assert_not_called()
assert "Result: hello" in str(result)
@patch("src.agents.tool_interceptor.interrupt")
def test_wrap_tool_user_rejects(self, mock_interrupt):
"""Test user rejecting tool execution."""
mock_interrupt.return_value = "no"
@tool
def test_tool(input_text: str) -> str:
"""Test tool."""
return f"Result: {input_text}"
interceptor = ToolInterceptor(["test_tool"])
wrapped_tool = ToolInterceptor.wrap_tool(test_tool, interceptor)
# Invoke the wrapped tool
result = wrapped_tool.invoke("hello")
# Verify tool was not executed
assert isinstance(result, dict)
assert "error" in result
assert result["status"] == "rejected"
def test_wrap_tools_with_interceptor_empty_list(self):
"""Test wrapping tools with empty interrupt list."""
@tool
def test_tool(input_text: str) -> str:
"""Test tool."""
return f"Result: {input_text}"
tools = [test_tool]
wrapped_tools = wrap_tools_with_interceptor(tools, [])
# Should return tools as-is
assert len(wrapped_tools) == 1
assert wrapped_tools[0].name == "test_tool"
def test_wrap_tools_with_interceptor_none(self):
"""Test wrapping tools with None interrupt list."""
@tool
def test_tool(input_text: str) -> str:
"""Test tool."""
return f"Result: {input_text}"
tools = [test_tool]
wrapped_tools = wrap_tools_with_interceptor(tools, None)
# Should return tools as-is
assert len(wrapped_tools) == 1
@patch("src.agents.tool_interceptor.interrupt")
def test_wrap_tools_with_interceptor_multiple(self, mock_interrupt):
"""Test wrapping multiple tools."""
mock_interrupt.return_value = "approved"
@tool
def db_tool(query: str) -> str:
"""DB tool."""
return f"Query result: {query}"
@tool
def search_tool(query: str) -> str:
"""Search tool."""
return f"Search result: {query}"
tools = [db_tool, search_tool]
wrapped_tools = wrap_tools_with_interceptor(tools, ["db_tool"])
# Only db_tool should trigger interrupt
db_result = wrapped_tools[0].invoke("test query")
assert mock_interrupt.call_count == 1
search_result = wrapped_tools[1].invoke("test query")
# No additional interrupt calls for search_tool
assert mock_interrupt.call_count == 1
def test_wrap_tool_preserves_tool_properties(self):
"""Test that wrapping preserves tool properties."""
@tool
def my_tool(input_text: str) -> str:
"""My tool description."""
return f"Result: {input_text}"
interceptor = ToolInterceptor([])
wrapped_tool = ToolInterceptor.wrap_tool(my_tool, interceptor)
assert wrapped_tool.name == "my_tool"
assert wrapped_tool.description == "My tool description."
class TestFormatToolInput:
"""Tests for tool input formatting functionality."""
def test_format_tool_input_none(self):
"""Test formatting None input."""
result = ToolInterceptor._format_tool_input(None)
assert result == "No input"
def test_format_tool_input_string(self):
"""Test formatting string input."""
input_str = "SELECT * FROM users"
result = ToolInterceptor._format_tool_input(input_str)
assert result == input_str
def test_format_tool_input_simple_dict(self):
"""Test formatting simple dictionary."""
input_dict = {"query": "test", "limit": 10}
result = ToolInterceptor._format_tool_input(input_dict)
# Should be valid JSON
import json
parsed = json.loads(result)
assert parsed == input_dict
# Should be indented
assert "\n" in result
def test_format_tool_input_nested_dict(self):
"""Test formatting nested dictionary."""
input_dict = {
"query": "SELECT * FROM users",
"config": {
"timeout": 30,
"retry": True
}
}
result = ToolInterceptor._format_tool_input(input_dict)
import json
parsed = json.loads(result)
assert parsed == input_dict
assert "timeout" in result
assert "retry" in result
def test_format_tool_input_list(self):
"""Test formatting list input."""
input_list = ["item1", "item2", 123]
result = ToolInterceptor._format_tool_input(input_list)
import json
parsed = json.loads(result)
assert parsed == input_list
def test_format_tool_input_complex_list(self):
"""Test formatting list with mixed types."""
input_list = ["text", 42, 3.14, True, {"key": "value"}]
result = ToolInterceptor._format_tool_input(input_list)
import json
parsed = json.loads(result)
assert parsed == input_list
def test_format_tool_input_tuple(self):
"""Test formatting tuple input."""
input_tuple = ("item1", "item2", 123)
result = ToolInterceptor._format_tool_input(input_tuple)
import json
parsed = json.loads(result)
# JSON converts tuples to lists
assert parsed == list(input_tuple)
def test_format_tool_input_integer(self):
"""Test formatting integer input."""
result = ToolInterceptor._format_tool_input(42)
assert result == "42"
def test_format_tool_input_float(self):
"""Test formatting float input."""
result = ToolInterceptor._format_tool_input(3.14)
assert result == "3.14"
def test_format_tool_input_boolean(self):
"""Test formatting boolean input."""
result_true = ToolInterceptor._format_tool_input(True)
result_false = ToolInterceptor._format_tool_input(False)
assert result_true == "True"
assert result_false == "False"
def test_format_tool_input_deeply_nested(self):
"""Test formatting deeply nested structure."""
input_dict = {
"level1": {
"level2": {
"level3": {
"level4": ["a", "b", "c"],
"data": {"key": "value"}
}
}
}
}
result = ToolInterceptor._format_tool_input(input_dict)
import json
parsed = json.loads(result)
assert parsed == input_dict
def test_format_tool_input_empty_dict(self):
"""Test formatting empty dictionary."""
result = ToolInterceptor._format_tool_input({})
assert result == "{}"
def test_format_tool_input_empty_list(self):
"""Test formatting empty list."""
result = ToolInterceptor._format_tool_input([])
assert result == "[]"
def test_format_tool_input_special_characters(self):
"""Test formatting dict with special characters."""
input_dict = {
"query": 'SELECT * FROM users WHERE name = "John"',
"path": "/usr/local/bin",
"unicode": "你好世界"
}
result = ToolInterceptor._format_tool_input(input_dict)
import json
parsed = json.loads(result)
assert parsed == input_dict
def test_format_tool_input_numbers_as_strings(self):
"""Test formatting with various number types."""
input_dict = {
"int": 42,
"float": 3.14159,
"negative": -100,
"zero": 0,
"scientific": 1e-5
}
result = ToolInterceptor._format_tool_input(input_dict)
import json
parsed = json.loads(result)
assert parsed["int"] == 42
assert abs(parsed["float"] - 3.14159) < 0.00001
assert parsed["negative"] == -100
assert parsed["zero"] == 0
def test_format_tool_input_with_none_values(self):
"""Test formatting dict with None values."""
input_dict = {
"key1": "value1",
"key2": None,
"key3": {"nested": None}
}
result = ToolInterceptor._format_tool_input(input_dict)
import json
parsed = json.loads(result)
assert parsed == input_dict
def test_format_tool_input_indentation(self):
"""Test that output uses proper indentation (2 spaces)."""
input_dict = {"outer": {"inner": "value"}}
result = ToolInterceptor._format_tool_input(input_dict)
# Should have indented lines
assert " " in result # 2-space indentation
lines = result.split("\n")
# Check that indentation increases with nesting
assert any(line.startswith(" ") for line in lines)
def test_format_tool_input_preserves_order_insertion(self):
"""Test that dict order is preserved in output."""
input_dict = {
"first": 1,
"second": 2,
"third": 3
}
result = ToolInterceptor._format_tool_input(input_dict)
import json
parsed = json.loads(result)
# Verify all keys are present
assert set(parsed.keys()) == {"first", "second", "third"}
def test_format_tool_input_long_strings(self):
"""Test formatting with long string values."""
long_string = "x" * 1000
input_dict = {"long": long_string}
result = ToolInterceptor._format_tool_input(input_dict)
import json
parsed = json.loads(result)
assert parsed["long"] == long_string
def test_format_tool_input_mixed_types_in_list(self):
"""Test formatting list with mixed complex types."""
input_list = [
"string",
42,
{"dict": "value"},
[1, 2, 3],
True,
None
]
result = ToolInterceptor._format_tool_input(input_list)
import json
parsed = json.loads(result)
assert len(parsed) == 6
assert parsed[0] == "string"
assert parsed[1] == 42
assert parsed[2] == {"dict": "value"}
assert parsed[3] == [1, 2, 3]
assert parsed[4] is True
assert parsed[5] is None
@@ -1,69 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import unittest
from unittest.mock import MagicMock
from langchain_core.tools import Tool
from src.agents.tool_interceptor import ToolInterceptor
class TestToolInterceptorFix(unittest.TestCase):
def test_interceptor_patches_run_method(self):
# Create a mock tool
mock_func = MagicMock(return_value="Original Result")
tool = Tool(name="resolve_company_name", func=mock_func, description="test tool")
# Interceptor that always interrupts 'resolve_company_name'
interceptor = ToolInterceptor(interrupt_before_tools=["resolve_company_name"])
# Wrap the tool
wrapped_tool = ToolInterceptor.wrap_tool(tool, interceptor)
# Mock interrupt to avoid actual suspension
with unittest.mock.patch("src.agents.tool_interceptor.interrupt", return_value="approved"):
# Call using .run() which triggers ._run()
# Standard BaseTool execution flow is invoke -> run -> _run
# If we only patched func, run() would call original _run which calls original func, bypassing interception
# With the fix, _run should be patched to call intercepted_func
result = wrapped_tool.run("some input")
# Verify result
self.assertEqual(result, "Original Result")
# Verify the original function was called
# If interception works, intercepted_func calls original_func
mock_func.assert_called_once()
def test_run_method_without_interrupt(self):
"""Test that tools not in interrupt list work normally via .run()"""
mock_func = MagicMock(return_value="Result")
tool = Tool(name="other_tool", func=mock_func, description="test")
interceptor = ToolInterceptor(interrupt_before_tools=["resolve_company_name"])
wrapped_tool = ToolInterceptor.wrap_tool(tool, interceptor)
with unittest.mock.patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
result = wrapped_tool.run("input")
# Verify interrupt was NOT called for non-intercepted tool
mock_interrupt.assert_not_called()
assert result == "Result"
mock_func.assert_called_once()
def test_interceptor_resolve_company_name_example(self):
"""Test specific resolve_company_name logic capability using interceptor subclassing or custom logic simulation."""
# This test verifies that we can intercept execution of resolve_company_name
# even if it's called via .run()
mock_func = MagicMock(return_value='{"code": 0, "data": [{"companyName": "A"}, {"companyName": "B"}]}')
tool = Tool(name="resolve_company_name", func=mock_func, description="resolve company")
interceptor = ToolInterceptor(interrupt_before_tools=["resolve_company_name"])
wrapped_tool = ToolInterceptor.wrap_tool(tool, interceptor)
# Simulate user selecting "B"
with unittest.mock.patch("src.agents.tool_interceptor.interrupt", return_value="approved"):
# We are not testing the complex business logic here because we didn't add it to ToolInterceptor class
# We are mostly verifying that the INTERCEPTION mechanism works for this tool name when called via .run()
wrapped_tool.run("query")
mock_func.assert_called_once()
@@ -1,153 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import shutil
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional
from unittest.mock import MagicMock, patch
import psycopg
import pytest
class PostgreSQLMockInstance:
"""Utility class for managing PostgreSQL mock instances."""
def __init__(self, database_name: str = "test_db"):
self.database_name = database_name
self.temp_dir: Optional[Path] = None
self.mock_connection: Optional[MagicMock] = None
self.mock_data: Dict[str, Any] = {}
self._setup_mock_data()
def _setup_mock_data(self):
"""Initialize mock data storage."""
self.mock_data = {
"chat_streams": {}, # thread_id -> record
"table_exists": False,
"connection_active": True,
}
def connect(self) -> MagicMock:
"""Create a mock PostgreSQL connection."""
self.mock_connection = MagicMock()
self._setup_mock_methods()
return self.mock_connection
def _setup_mock_methods(self):
"""Setup mock methods for PostgreSQL operations."""
if not self.mock_connection:
return
# Mock cursor context manager
mock_cursor = MagicMock()
mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
mock_cursor.__exit__ = MagicMock(return_value=False)
# Setup cursor operations
mock_cursor.execute = MagicMock(side_effect=self._mock_execute)
mock_cursor.fetchone = MagicMock(side_effect=self._mock_fetchone)
mock_cursor.rowcount = 0
# Setup connection operations
self.mock_connection.cursor = MagicMock(return_value=mock_cursor)
self.mock_connection.commit = MagicMock()
self.mock_connection.rollback = MagicMock()
self.mock_connection.close = MagicMock()
# Store cursor for external access
self._mock_cursor = mock_cursor
def _mock_execute(self, sql: str, params=None):
"""Mock SQL execution."""
sql_upper = sql.upper().strip()
if "CREATE TABLE" in sql_upper:
self.mock_data["table_exists"] = True
self._mock_cursor.rowcount = 0
elif "SELECT" in sql_upper and "chat_streams" in sql_upper:
# Mock SELECT query
if params and len(params) > 0:
thread_id = params[0]
if thread_id in self.mock_data["chat_streams"]:
self._mock_cursor._fetch_result = self.mock_data["chat_streams"][
thread_id
]
else:
self._mock_cursor._fetch_result = None
else:
self._mock_cursor._fetch_result = None
elif "UPDATE" in sql_upper and "chat_streams" in sql_upper:
# Mock UPDATE query
if params and len(params) >= 2:
messages, thread_id = params[0], params[1]
if thread_id in self.mock_data["chat_streams"]:
self.mock_data["chat_streams"][thread_id] = {
"id": thread_id,
"thread_id": thread_id,
"messages": messages,
}
self._mock_cursor.rowcount = 1
else:
self._mock_cursor.rowcount = 0
elif "INSERT" in sql_upper and "chat_streams" in sql_upper:
# Mock INSERT query
if params and len(params) >= 2:
thread_id, messages = params[0], params[1]
self.mock_data["chat_streams"][thread_id] = {
"id": thread_id,
"thread_id": thread_id,
"messages": messages,
}
self._mock_cursor.rowcount = 1
def _mock_fetchone(self):
"""Mock fetchone operation."""
return getattr(self._mock_cursor, "_fetch_result", None)
def disconnect(self):
"""Cleanup mock connection."""
if self.mock_connection:
self.mock_connection.close()
self._setup_mock_data() # Reset data
def reset_data(self):
"""Reset all mock data."""
self._setup_mock_data()
def get_table_count(self, table_name: str) -> int:
"""Get record count in a table."""
if table_name == "chat_streams":
return len(self.mock_data["chat_streams"])
return 0
def create_test_data(self, table_name: str, records: list):
"""Insert test data into a table."""
if table_name == "chat_streams":
for record in records:
thread_id = record.get("thread_id")
if thread_id:
self.mock_data["chat_streams"][thread_id] = record
@pytest.fixture
def mock_postgresql():
"""Create a PostgreSQL mock instance."""
instance = PostgreSQLMockInstance()
instance.connect()
yield instance
instance.disconnect()
@pytest.fixture
def clean_mock_postgresql():
"""Create a clean PostgreSQL mock instance that resets between tests."""
instance = PostgreSQLMockInstance()
instance.connect()
instance.reset_data()
yield instance
instance.disconnect()
-685
View File
@@ -1,685 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import os
from unittest.mock import MagicMock, patch
import mongomock
import pytest
from postgres_mock_utils import PostgreSQLMockInstance
import src.graph.checkpoint as checkpoint
POSTGRES_URL = "postgresql://postgres:postgres@localhost:5432/checkpointing_db"
MONGO_URL = "mongodb://admin:admin@localhost:27017/checkpointing_db?authSource=admin"
def has_real_db_connection():
# Check the environment if the MongoDB server is available
enabled = os.getenv("DB_TESTS_ENABLED", "false")
if enabled.lower() == "true":
return True
return False
def test_with_local_postgres_db():
"""Ensure the ChatStreamManager can be initialized with a local PostgreSQL DB."""
with patch("psycopg.connect") as mock_connect:
# Setup mock PostgreSQL connection
pg_mock = PostgreSQLMockInstance()
mock_connect.return_value = pg_mock.connect()
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=POSTGRES_URL,
)
assert manager.postgres_conn is not None
assert manager.mongo_client is None
def test_with_local_mongo_db():
"""Ensure the ChatStreamManager can be initialized with a local MongoDB."""
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
assert manager.mongo_db is not None
assert manager.postgres_conn is None
def test_init_without_checkpoint_saver():
"""Manager should not create DB clients when checkpoint_saver is False."""
manager = checkpoint.ChatStreamManager(checkpoint_saver=False)
assert manager.checkpoint_saver is False
# DB connections are not created when saver is disabled
assert manager.mongo_client is None
assert manager.postgres_conn is None
def test_process_stream_partial_buffer_postgres(monkeypatch):
"""Partial chunks should be buffered; Postgres init is stubbed to no-op."""
# Patch Postgres init to no-op
def _no_pg(self):
self.postgres_conn = None
monkeypatch.setattr(
checkpoint.ChatStreamManager, "_init_postgresql", _no_pg, raising=True
)
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=POSTGRES_URL,
)
result = manager.process_stream_message("t1", "hello", finish_reason="partial")
assert result is True
# Verify the chunk was stored in the in-memory store
items = manager.store.search(("messages", "t1"), limit=10)
values = [it.dict()["value"] for it in items]
assert "hello" in values
def test_process_stream_partial_buffer_mongo():
"""Partial chunks should be buffered; Use mongomock instead of real MongoDB."""
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
result = manager.process_stream_message("t2", "hello", finish_reason="partial")
assert result is True
# Verify the chunk was stored in the in-memory store
items = manager.store.search(("messages", "t2"), limit=10)
values = [it.dict()["value"] for it in items]
assert "hello" in values
@pytest.mark.skipif(
not has_real_db_connection(), reason="PostgreSQL Server is not available"
)
def test_persist_postgresql_local_db():
"""Ensure that the ChatStreamManager can persist to a local PostgreSQL DB."""
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=POSTGRES_URL,
)
assert manager.postgres_conn is not None
assert manager.mongo_client is None
# Simulate a message to persist
thread_id = "test_thread"
messages = ["This is a test message."]
result = manager._persist_to_postgresql(thread_id, messages)
assert result is True
# Simulate a message with existing thread (should append, not overwrite)
result = manager._persist_to_postgresql(thread_id, ["Another message."])
assert result is True
# Verify the messages were appended correctly
with manager.postgres_conn.cursor() as cursor:
cursor.execute(
"SELECT messages FROM chat_streams WHERE thread_id = %s", (thread_id,)
)
existing_record = cursor.fetchone()
assert existing_record is not None
assert existing_record["messages"] == ["This is a test message.", "Another message."]
@pytest.mark.skipif(
not has_real_db_connection(), reason="PostgreSQL Server is not available"
)
def test_persist_postgresql_called_with_aggregated_chunks():
"""On 'stop', aggregated chunks should be passed to PostgreSQL persist method."""
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=POSTGRES_URL,
)
assert (
manager.process_stream_message("thd3", "Hello", finish_reason="partial") is True
)
assert (
manager.process_stream_message("thd3", " World", finish_reason="stop") is True
)
# Verify the messages were aggregated correctly
with manager.postgres_conn.cursor() as cursor:
# Check if conversation already exists
cursor.execute(
"SELECT messages FROM chat_streams WHERE thread_id = %s", ("thd3",)
)
existing_record = cursor.fetchone()
assert existing_record is not None
assert existing_record["messages"] == ["Hello", " World"]
def test_persist_not_attempted_when_saver_disabled():
"""When saver disabled, stop should not persist and should return False."""
manager = checkpoint.ChatStreamManager(checkpoint_saver=False)
# stop should try to persist, but saver disabled => returns False
assert manager.process_stream_message("t4", "hello", finish_reason="stop") is False
def test_persist_mongodb_local_db():
"""Ensure that the ChatStreamManager can persist to a mocked MongoDB."""
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
assert manager.mongo_db is not None
assert manager.postgres_conn is None
# Simulate a message to persist
thread_id = "test_thread"
messages = ["This is a test message."]
result = manager._persist_to_mongodb(thread_id, messages)
assert result is True
# Verify data was persisted in mock
collection = manager.mongo_db.chat_streams
doc = collection.find_one({"thread_id": thread_id})
assert doc is not None
assert doc["messages"] == messages
# Simulate a message with existing thread (should append, not overwrite)
result = manager._persist_to_mongodb(thread_id, ["Another message."])
assert result is True
# Verify update worked - messages should be appended to existing ones
doc = collection.find_one({"thread_id": thread_id})
assert doc["messages"] == ["This is a test message.", "Another message."]
@pytest.mark.skipif(
not has_real_db_connection(), reason="MongoDB server is not available"
)
def test_persist_mongodb_called_with_aggregated_chunks():
"""On 'stop', aggregated chunks should be passed to MongoDB persist method."""
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
assert (
manager.process_stream_message("thd5", "Hello", finish_reason="partial") is True
)
assert (
manager.process_stream_message("thd5", " World", finish_reason="stop") is True
)
# Verify the messages were aggregated correctly
collection = manager.mongo_db.chat_streams
existing_record = collection.find_one({"thread_id": "thd5"})
assert existing_record is not None
assert existing_record["messages"] == ["Hello", " World"]
def test_invalid_inputs_return_false(monkeypatch):
"""Empty thread_id or message should be rejected and return False."""
def _no_mongo(self):
self.mongo_client = None
self.mongo_db = None
monkeypatch.setattr(
checkpoint.ChatStreamManager, "_init_mongodb", _no_mongo, raising=True
)
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
assert manager.process_stream_message("", "msg", finish_reason="partial") is False
assert manager.process_stream_message("tid", "", finish_reason="partial") is False
def test_unsupported_db_uri_scheme():
"""Manager should log warning for unsupported database URI schemes."""
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True, db_uri="redis://localhost:6379/0"
)
# Should not have any database connections
assert manager.mongo_client is None
assert manager.postgres_conn is None
assert manager.mongo_db is None
def test_process_stream_with_interrupt_finish_reason():
"""Test that 'interrupt' finish_reason triggers persistence like 'stop'."""
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
# Add partial message
assert (
manager.process_stream_message(
"int_test", "Interrupted", finish_reason="partial"
)
is True
)
# Interrupt should trigger persistence
assert (
manager.process_stream_message(
"int_test", " message", finish_reason="interrupt"
)
is True
)
# Verify persistence occurred
collection = manager.mongo_db.chat_streams
doc = collection.find_one({"thread_id": "int_test"})
assert doc is not None
assert doc["messages"] == ["Interrupted", " message"]
def test_postgresql_connection_failure(monkeypatch):
"""Test PostgreSQL connection failure handling."""
def failing_connect(dsn, **kwargs):
raise RuntimeError("Connection failed")
monkeypatch.setattr("psycopg.connect", failing_connect)
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=POSTGRES_URL,
)
# Should have no postgres connection on failure
assert manager.postgres_conn is None
def test_mongodb_ping_failure(monkeypatch):
"""Test MongoDB ping failure during initialization."""
class FakeAdmin:
def command(self, name):
raise RuntimeError("Ping failed")
class FakeClient:
def __init__(self, uri):
self.admin = FakeAdmin()
monkeypatch.setattr(checkpoint, "MongoClient", lambda uri: FakeClient(uri))
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
# Should not have mongo_db set on ping failure
assert getattr(manager, "mongo_db", None) is None
def test_store_namespace_consistency():
"""Test that store namespace is consistently used across methods."""
manager = checkpoint.ChatStreamManager(checkpoint_saver=False)
# Process a partial message
assert (
manager.process_stream_message("ns_test", "chunk1", finish_reason="partial")
is True
)
# Verify cursor is stored correctly
cursor = manager.store.get(("messages", "ns_test"), "cursor")
assert cursor is not None
assert cursor.value["index"] == 0
# Add another chunk
assert (
manager.process_stream_message("ns_test", "chunk2", finish_reason="partial")
is True
)
# Verify cursor is incremented
cursor = manager.store.get(("messages", "ns_test"), "cursor")
assert cursor.value["index"] == 1
def test_cursor_initialization_edge_cases():
"""Test cursor handling edge cases."""
manager = checkpoint.ChatStreamManager(checkpoint_saver=False)
# Manually set a cursor with missing index
namespace = ("messages", "edge_test")
manager.store.put(namespace, "cursor", {}) # Missing 'index' key
# Should handle missing index gracefully
result = manager.process_stream_message(
"edge_test", "test", finish_reason="partial"
)
assert result is True
# Should default to 0 and increment to 1
cursor = manager.store.get(namespace, "cursor")
assert cursor.value["index"] == 1
def test_multiple_threads_isolation():
"""Test that different thread_ids are properly isolated."""
manager = checkpoint.ChatStreamManager(checkpoint_saver=False)
# Process messages for different threads
assert (
manager.process_stream_message("thread1", "msg1", finish_reason="partial")
is True
)
assert (
manager.process_stream_message("thread2", "msg2", finish_reason="partial")
is True
)
assert (
manager.process_stream_message("thread1", "msg3", finish_reason="partial")
is True
)
# Verify isolation
thread1_items = manager.store.search(("messages", "thread1"), limit=10)
thread2_items = manager.store.search(("messages", "thread2"), limit=10)
thread1_values = [
item.dict()["value"]
for item in thread1_items
if isinstance(item.dict()["value"], str)
]
thread2_values = [
item.dict()["value"]
for item in thread2_items
if isinstance(item.dict()["value"], str)
]
assert "msg1" in thread1_values
assert "msg3" in thread1_values
assert "msg2" in thread2_values
assert "msg1" not in thread2_values
assert "msg2" not in thread1_values
def test_mongodb_insert_and_update_paths():
"""Exercise MongoDB insert, update, and exception branches using mongomock."""
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
# Insert success (new thread)
assert manager._persist_to_mongodb("th1", ["message1"]) is True
# Verify insert worked
collection = manager.mongo_db.chat_streams
doc = collection.find_one({"thread_id": "th1"})
assert doc is not None
assert doc["messages"] == ["message1"]
# Update success (existing thread - should append, not overwrite)
assert manager._persist_to_mongodb("th1", ["message2"]) is True
# Verify update worked - messages should be appended
doc = collection.find_one({"thread_id": "th1"})
assert doc["messages"] == ["message1", "message2"]
# Test error case by mocking collection methods
original_find_one = collection.find_one
collection.find_one = MagicMock(side_effect=RuntimeError("Database error"))
assert manager._persist_to_mongodb("th2", ["message"]) is False
# Restore original method
collection.find_one = original_find_one
def test_postgresql_insert_update_and_error_paths():
"""Exercise PostgreSQL update, insert, and error/rollback branches."""
calls = {"executed": []}
class FakeCursor:
def __init__(self, mode):
self.mode = mode
self.rowcount = 0
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, sql, params=None):
calls["executed"].append(sql.strip().split()[0])
if "SELECT" in sql:
if self.mode == "update":
self._fetch = {"id": "x"}
elif self.mode == "error":
raise RuntimeError("sql error")
else:
self._fetch = None
else:
# UPDATE or INSERT
self.rowcount = 1
def fetchone(self):
return getattr(self, "_fetch", None)
class FakeConn:
def __init__(self, mode):
self.mode = mode
self.commit_called = False
self.rollback_called = False
def cursor(self):
return FakeCursor(self.mode)
def commit(self):
self.commit_called = True
def rollback(self):
self.rollback_called = True
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=POSTGRES_URL)
# Update path
manager.postgres_conn = FakeConn("update")
assert manager._persist_to_postgresql("t", ["m"]) is True
assert manager.postgres_conn.commit_called is True
# Insert path
manager.postgres_conn = FakeConn("insert")
assert manager._persist_to_postgresql("t", ["m"]) is True
assert manager.postgres_conn.commit_called is True
# Error path with rollback
manager.postgres_conn = FakeConn("error")
assert manager._persist_to_postgresql("t", ["m"]) is False
assert manager.postgres_conn.rollback_called is True
def test_create_chat_streams_table_success_and_error():
"""Ensure table creation commits on success and rolls back on failure."""
class FakeCursor:
def __init__(self, should_fail=False):
self.should_fail = should_fail
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, sql):
if self.should_fail:
raise RuntimeError("ddl fail")
class FakeConn:
def __init__(self, should_fail=False):
self.should_fail = should_fail
self.commits = 0
self.rollbacks = 0
def cursor(self):
return FakeCursor(self.should_fail)
def commit(self):
self.commits += 1
def rollback(self):
self.rollbacks += 1
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=POSTGRES_URL)
# Success
manager.postgres_conn = FakeConn(False)
manager._create_chat_streams_table()
assert manager.postgres_conn.commits == 1
# Failure triggers rollback
manager.postgres_conn = FakeConn(True)
manager._create_chat_streams_table()
assert manager.postgres_conn.rollbacks == 1
def test_close_closes_resources_and_handles_errors():
"""Close should gracefully handle both success and exceptions."""
flags = {"mongo": 0, "pg": 0}
class M:
def close(self):
flags["mongo"] += 1
class P:
def __init__(self, raise_on_close=False):
self.raise_on_close = raise_on_close
def close(self):
if self.raise_on_close:
raise RuntimeError("close fail")
flags["pg"] += 1
manager = checkpoint.ChatStreamManager(checkpoint_saver=False)
manager.mongo_client = M()
manager.postgres_conn = P()
manager.close()
assert flags == {"mongo": 1, "pg": 1}
# Trigger error branches (no raise escapes)
manager.mongo_client = None # skip mongo
manager.postgres_conn = P(True)
manager.close() # should handle exception gracefully
def test_context_manager_calls_close(monkeypatch):
"""The context manager protocol should call close() on exit."""
called = {"close": 0}
def _noop(self):
self.mongo_client = None
self.mongo_db = None
monkeypatch.setattr(
checkpoint.ChatStreamManager, "_init_mongodb", _noop, raising=True
)
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
def fake_close():
called["close"] += 1
manager.close = fake_close
with manager:
pass
assert called["close"] == 1
def test_init_mongodb_success_and_failure(monkeypatch):
"""MongoDB init should succeed with mongomock and fail gracefully with errors."""
# Success path with mongomock
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
assert manager.mongo_db is not None
# Failure path
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
mock_mongo_client.side_effect = RuntimeError("Connection failed")
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
# Should have no mongo_db set on failure
assert getattr(manager, "mongo_db", None) is None
def test_init_postgresql_calls_connect_and_create_table(monkeypatch):
"""PostgreSQL init should connect and create the required table."""
flags = {"connected": 0, "created": 0}
class FakeConn:
def __init__(self):
pass
def close(self):
pass
def fake_connect(self):
flags["connected"] += 1
flags["created"] += 1
return FakeConn()
monkeypatch.setattr(
checkpoint.ChatStreamManager, "_init_postgresql", fake_connect, raising=True
)
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=POSTGRES_URL)
assert manager.postgres_conn is None
assert flags == {"connected": 1, "created": 1}
def test_chat_stream_message_wrapper(monkeypatch):
"""Wrapper should delegate when enabled and return False when disabled."""
# When saver enabled, should call default manager
monkeypatch.setattr(
checkpoint, "get_bool_env", lambda k, d=False: True, raising=True
)
called = {"args": None}
def fake_process(tid, msg, fr):
called["args"] = (tid, msg, fr)
return True
monkeypatch.setattr(
checkpoint._default_manager,
"process_stream_message",
fake_process,
raising=True,
)
assert checkpoint.chat_stream_message("tid", "msg", "stop") is True
assert called["args"] == ("tid", "msg", "stop")
# When saver disabled, returns False and does not call manager
monkeypatch.setattr(
checkpoint, "get_bool_env", lambda k, d=False: False, raising=True
)
called["args"] = None
assert checkpoint.chat_stream_message("tid", "msg", "stop") is False
assert called["args"] is None
-46
View File
@@ -1,46 +0,0 @@
from unittest.mock import patch
import mongomock
import src.graph.checkpoint as checkpoint
MONGO_URL = "mongodb://admin:admin@localhost:27017/checkpointing_db?authSource=admin"
def test_memory_leak_check_memory_cleared_after_persistence():
"""
Test that InMemoryStore is cleared for a thread after successful persistence.
This prevents memory leaks for long-running processes.
"""
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
thread_id = "leak_test_thread"
namespace = ("messages", thread_id)
# 1. Simulate streaming messages
manager.process_stream_message(thread_id, "Hello", "partial")
manager.process_stream_message(thread_id, " World", "partial")
# Verify items are in store during streaming
items = manager.store.search(namespace)
assert len(items) > 0, "Store should contain items during streaming"
# 2. Simulate end of conversation (trigger persistence)
# 'stop' should trigger _persist_complete_conversation which now includes cleanup
manager.process_stream_message(thread_id, "!", "stop")
# 3. Verify store is empty for this thread
items_after = manager.store.search(namespace)
assert len(items_after) == 0, "Memory should be cleared after successful persistence"
# Verify persistence actually happened
collection = manager.mongo_db.chat_streams
doc = collection.find_one({"thread_id": thread_id})
assert doc is not None
assert doc["messages"] == ["Hello", " World", "!"]
-136
View File
@@ -1,136 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from langchain_core.messages import ToolMessage
from src.citations.collector import CitationCollector
from src.citations.extractor import (
_extract_domain,
citations_to_markdown_references,
extract_citations_from_messages,
merge_citations,
)
from src.citations.formatter import CitationFormatter
from src.citations.models import Citation, CitationMetadata
class TestCitationMetadata:
def test_initialization(self):
meta = CitationMetadata(
url="https://example.com/page",
title="Example Page",
description="An example description",
)
assert meta.url == "https://example.com/page"
assert meta.title == "Example Page"
assert meta.description == "An example description"
assert meta.domain == "example.com" # Auto-extracted in post_init
def test_id_generation(self):
meta = CitationMetadata(url="https://example.com", title="Test")
# Just check it's a non-empty string, length 12
assert len(meta.id) == 12
assert isinstance(meta.id, str)
def test_to_dict(self):
meta = CitationMetadata(
url="https://example.com", title="Test", relevance_score=0.8
)
data = meta.to_dict()
assert data["url"] == "https://example.com"
assert data["title"] == "Test"
assert data["relevance_score"] == 0.8
assert "id" in data
class TestCitation:
def test_citation_wrapper(self):
meta = CitationMetadata(url="https://example.com", title="Test")
citation = Citation(number=1, metadata=meta)
assert citation.number == 1
assert citation.url == "https://example.com"
assert citation.title == "Test"
assert citation.to_markdown_reference() == "[Test](https://example.com)"
assert citation.to_numbered_reference() == "[1] Test - https://example.com"
class TestExtractor:
def test_extract_from_tool_message_web_search(self):
search_result = {
"results": [
{
"url": "https://example.com/1",
"title": "Result 1",
"content": "Content 1",
"score": 0.9,
}
]
}
msg = ToolMessage(
content=str(search_result).replace("'", '"'), # Simple JSON dump simulation
tool_call_id="call_1",
name="web_search",
)
# Mocking json structure if ToolMessage content expects stringified JSON
import json
msg.content = json.dumps(search_result)
citations = extract_citations_from_messages([msg])
assert len(citations) == 1
assert citations[0]["url"] == "https://example.com/1"
assert citations[0]["title"] == "Result 1"
def test_extract_domain(self):
assert _extract_domain("https://www.example.com/path") == "www.example.com"
assert _extract_domain("http://example.org") == "example.org"
def test_merge_citations(self):
existing = [{"url": "https://a.com", "title": "A", "relevance_score": 0.5}]
new = [
{"url": "https://b.com", "title": "B", "relevance_score": 0.6},
{
"url": "https://a.com",
"title": "A New",
"relevance_score": 0.7,
}, # Better score for A
]
merged = merge_citations(existing, new)
assert len(merged) == 2
# Check A was updated
a_citation = next(c for c in merged if c["url"] == "https://a.com")
assert a_citation["relevance_score"] == 0.7
# Check B is present
b_citation = next(c for c in merged if c["url"] == "https://b.com")
assert b_citation["title"] == "B"
def test_citations_to_markdown(self):
citations = [{"url": "https://a.com", "title": "A", "description": "Desc A"}]
md = citations_to_markdown_references(citations)
assert "## Key Citations" in md
assert "- [A](https://a.com)" in md
class TestCollector:
def test_add_citations(self):
collector = CitationCollector()
results = [
{"url": "https://example.com", "title": "Example", "content": "Test"}
]
added = collector.add_from_search_results(results, query="test")
assert len(added) == 1
assert added[0].url == "https://example.com"
assert collector.count == 1
class TestFormatter:
def test_format_inline(self):
formatter = CitationFormatter(style="superscript")
assert formatter.format_inline_marker(1) == "¹"
assert formatter.format_inline_marker(12) == "¹²"
-289
View File
@@ -1,289 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for CitationCollector optimization with reverse index cache.
Tests the O(1) URL lookup performance optimization via _url_to_index cache.
"""
from src.citations.collector import CitationCollector
class TestCitationCollectorOptimization:
"""Test CitationCollector reverse index cache optimization."""
def test_url_to_index_cache_initialization(self):
"""Test that _url_to_index is properly initialized."""
collector = CitationCollector()
assert hasattr(collector, "_url_to_index")
assert isinstance(collector._url_to_index, dict)
assert len(collector._url_to_index) == 0
def test_add_single_citation_updates_cache(self):
"""Test that adding a citation updates _url_to_index."""
collector = CitationCollector()
results = [
{
"url": "https://example.com",
"title": "Example",
"content": "Content",
"score": 0.9,
}
]
collector.add_from_search_results(results)
# Check cache is populated
assert "https://example.com" in collector._url_to_index
assert collector._url_to_index["https://example.com"] == 0
def test_add_multiple_citations_updates_cache_correctly(self):
"""Test that multiple citations are indexed correctly."""
collector = CitationCollector()
results = [
{
"url": f"https://example.com/{i}",
"title": f"Page {i}",
"content": f"Content {i}",
"score": 0.9,
}
for i in range(5)
]
collector.add_from_search_results(results)
# Check all URLs are indexed
assert len(collector._url_to_index) == 5
for i in range(5):
url = f"https://example.com/{i}"
assert collector._url_to_index[url] == i
def test_get_number_uses_cache_for_o1_lookup(self):
"""Test that get_number uses cache for O(1) lookup."""
collector = CitationCollector()
urls = [f"https://example.com/{i}" for i in range(100)]
results = [
{
"url": url,
"title": f"Title {i}",
"content": f"Content {i}",
"score": 0.9,
}
for i, url in enumerate(urls)
]
collector.add_from_search_results(results)
# Test lookup for various positions
assert collector.get_number("https://example.com/0") == 1
assert collector.get_number("https://example.com/50") == 51
assert collector.get_number("https://example.com/99") == 100
# Non-existent URL returns None
assert collector.get_number("https://nonexistent.com") is None
def test_add_from_crawl_result_updates_cache(self):
"""Test that add_from_crawl_result updates cache."""
collector = CitationCollector()
collector.add_from_crawl_result(
url="https://crawled.com/page",
title="Crawled Page",
content="Crawled content",
)
assert "https://crawled.com/page" in collector._url_to_index
assert collector._url_to_index["https://crawled.com/page"] == 0
def test_duplicate_url_does_not_change_cache(self):
"""Test that adding duplicate URLs doesn't change cache indices."""
collector = CitationCollector()
# Add first time
collector.add_from_search_results(
[
{
"url": "https://example.com",
"title": "Title 1",
"content": "Content 1",
"score": 0.8,
}
]
)
assert collector._url_to_index["https://example.com"] == 0
# Add same URL again with better score
collector.add_from_search_results(
[
{
"url": "https://example.com",
"title": "Title 1 Updated",
"content": "Content 1 Updated",
"score": 0.95,
}
]
)
# Cache index should not change
assert collector._url_to_index["https://example.com"] == 0
# But metadata should be updated
assert collector._citations["https://example.com"].relevance_score == 0.95
def test_merge_with_updates_cache_correctly(self):
"""Test that merge_with correctly updates cache for new URLs."""
collector1 = CitationCollector()
collector2 = CitationCollector()
# Add to collector1
collector1.add_from_search_results(
[
{
"url": "https://a.com",
"title": "A",
"content": "Content A",
"score": 0.9,
}
]
)
# Add to collector2
collector2.add_from_search_results(
[
{
"url": "https://b.com",
"title": "B",
"content": "Content B",
"score": 0.9,
}
]
)
collector1.merge_with(collector2)
# Both URLs should be in cache
assert "https://a.com" in collector1._url_to_index
assert "https://b.com" in collector1._url_to_index
assert collector1._url_to_index["https://a.com"] == 0
assert collector1._url_to_index["https://b.com"] == 1
def test_from_dict_rebuilds_cache(self):
"""Test that from_dict properly rebuilds cache."""
# Create original collector
original = CitationCollector()
original.add_from_search_results(
[
{
"url": f"https://example.com/{i}",
"title": f"Page {i}",
"content": f"Content {i}",
"score": 0.9,
}
for i in range(3)
]
)
# Serialize and deserialize
data = original.to_dict()
restored = CitationCollector.from_dict(data)
# Check cache is properly rebuilt
assert len(restored._url_to_index) == 3
for i in range(3):
url = f"https://example.com/{i}"
assert url in restored._url_to_index
assert restored._url_to_index[url] == i
def test_clear_resets_cache(self):
"""Test that clear() properly resets the cache."""
collector = CitationCollector()
collector.add_from_search_results(
[
{
"url": "https://example.com",
"title": "Example",
"content": "Content",
"score": 0.9,
}
]
)
assert len(collector._url_to_index) > 0
collector.clear()
assert len(collector._url_to_index) == 0
assert len(collector._citations) == 0
assert len(collector._citation_order) == 0
def test_cache_consistency_with_order_list(self):
"""Test that cache indices match positions in _citation_order."""
collector = CitationCollector()
urls = [f"https://example.com/{i}" for i in range(10)]
results = [
{
"url": url,
"title": f"Title {i}",
"content": f"Content {i}",
"score": 0.9,
}
for i, url in enumerate(urls)
]
collector.add_from_search_results(results)
# Verify cache indices match order list positions
for i, url in enumerate(collector._citation_order):
assert collector._url_to_index[url] == i
def test_mark_used_with_cache(self):
"""Test that mark_used works correctly with cache."""
collector = CitationCollector()
collector.add_from_search_results(
[
{
"url": "https://example.com/1",
"title": "Page 1",
"content": "Content 1",
"score": 0.9,
},
{
"url": "https://example.com/2",
"title": "Page 2",
"content": "Content 2",
"score": 0.9,
},
]
)
# Mark one as used
number = collector.mark_used("https://example.com/2")
assert number == 2
# Verify it's in used set
assert "https://example.com/2" in collector._used_citations
def test_large_collection_cache_performance(self):
"""Test that cache works correctly with large collections."""
collector = CitationCollector()
num_citations = 1000
results = [
{
"url": f"https://example.com/{i}",
"title": f"Title {i}",
"content": f"Content {i}",
"score": 0.9,
}
for i in range(num_citations)
]
collector.add_from_search_results(results)
# Verify cache size
assert len(collector._url_to_index) == num_citations
# Test lookups at various positions
test_indices = [0, 100, 500, 999]
for idx in test_indices:
url = f"https://example.com/{idx}"
assert collector.get_number(url) == idx + 1
-251
View File
@@ -1,251 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for extractor optimizations.
Tests the enhanced domain extraction and title extraction functions.
"""
from src.citations.extractor import (
_extract_domain,
extract_title_from_content,
)
class TestExtractDomainOptimization:
"""Test domain extraction with urllib + regex fallback strategy."""
def test_extract_domain_standard_urls(self):
"""Test extraction from standard URLs."""
assert _extract_domain("https://www.example.com/path") == "www.example.com"
assert _extract_domain("http://example.org") == "example.org"
assert _extract_domain("https://github.com/user/repo") == "github.com"
def test_extract_domain_with_port(self):
"""Test extraction from URLs with ports."""
assert _extract_domain("http://localhost:8080/api") == "localhost:8080"
assert (
_extract_domain("https://example.com:3000/page")
== "example.com:3000"
)
def test_extract_domain_with_subdomain(self):
"""Test extraction from URLs with subdomains."""
assert _extract_domain("https://api.github.com/repos") == "api.github.com"
assert (
_extract_domain("https://docs.python.org/en/")
== "docs.python.org"
)
def test_extract_domain_invalid_url(self):
"""Test handling of invalid URLs."""
# Should not crash, might return empty string
result = _extract_domain("not a url")
assert isinstance(result, str)
def test_extract_domain_empty_url(self):
"""Test handling of empty URL."""
assert _extract_domain("") == ""
def test_extract_domain_without_scheme(self):
"""Test extraction from URLs without scheme (handled by regex fallback)."""
# These may be handled by regex fallback
result = _extract_domain("example.com/path")
# Should at least not crash
assert isinstance(result, str)
def test_extract_domain_complex_urls(self):
"""Test extraction from complex URLs."""
# urllib includes credentials in netloc, so this is expected behavior
assert (
_extract_domain("https://user:pass@example.com/path")
== "user:pass@example.com"
)
assert (
_extract_domain("https://example.com:443/path?query=value#hash")
== "example.com:443"
)
def test_extract_domain_ipv4(self):
"""Test extraction from IPv4 addresses."""
result = _extract_domain("http://192.168.1.1:8080/")
# Should handle IP addresses
assert isinstance(result, str)
def test_extract_domain_query_params(self):
"""Test that query params don't affect domain extraction."""
url1 = "https://example.com/page?query=value"
url2 = "https://example.com/page"
assert _extract_domain(url1) == _extract_domain(url2)
def test_extract_domain_url_fragments(self):
"""Test that fragments don't affect domain extraction."""
url1 = "https://example.com/page#section"
url2 = "https://example.com/page"
assert _extract_domain(url1) == _extract_domain(url2)
class TestExtractTitleFromContent:
"""Test intelligent title extraction with 5-tier priority system."""
def test_extract_title_html_title_tag(self):
"""Test priority 1: HTML <title> tag extraction."""
content = "<html><head><title>HTML Title</title></head><body>Content</body></html>"
assert extract_title_from_content(content) == "HTML Title"
def test_extract_title_html_title_case_insensitive(self):
"""Test that HTML title extraction is case-insensitive."""
content = "<html><head><TITLE>HTML Title</TITLE></head><body></body></html>"
assert extract_title_from_content(content) == "HTML Title"
def test_extract_title_markdown_h1(self):
"""Test priority 2: Markdown h1 extraction."""
content = "# Main Title\n\nSome content here"
assert extract_title_from_content(content) == "Main Title"
def test_extract_title_markdown_h1_with_spaces(self):
"""Test markdown h1 with extra spaces."""
content = "# Title with Spaces \n\nContent"
assert extract_title_from_content(content) == "Title with Spaces"
def test_extract_title_markdown_h2_fallback(self):
"""Test priority 3: Markdown h2 as fallback when no h1."""
content = "## Second Level Title\n\nSome content"
assert extract_title_from_content(content) == "Second Level Title"
def test_extract_title_markdown_h6_fallback(self):
"""Test markdown h6 as fallback."""
content = "###### Small Heading\n\nContent"
assert extract_title_from_content(content) == "Small Heading"
def test_extract_title_prefers_h1_over_h2(self):
"""Test that h1 is preferred over h2."""
content = "# H1 Title\n## H2 Title\n\nContent"
assert extract_title_from_content(content) == "H1 Title"
def test_extract_title_json_field(self):
"""Test priority 4: JSON title field extraction."""
content = '{"title": "JSON Title", "content": "Some data"}'
assert extract_title_from_content(content) == "JSON Title"
def test_extract_title_yaml_field(self):
"""Test YAML title field extraction."""
content = 'title: "YAML Title"\ncontent: "Some data"'
assert extract_title_from_content(content) == "YAML Title"
def test_extract_title_first_substantial_line(self):
"""Test priority 5: First substantial non-empty line."""
content = "\n\n\nThis is the first substantial line\n\nMore content"
assert extract_title_from_content(content) == "This is the first substantial line"
def test_extract_title_skips_short_lines(self):
"""Test that short lines are skipped."""
content = "abc\nThis is a longer first substantial line\nContent"
assert extract_title_from_content(content) == "This is a longer first substantial line"
def test_extract_title_skips_code_blocks(self):
"""Test that code blocks are skipped."""
content = "```\ncode here\n```\nThis is the title\n\nContent"
result = extract_title_from_content(content)
# Should skip the code block and find the actual title
assert "title" in result.lower() or "code" not in result
def test_extract_title_skips_list_items(self):
"""Test that list items are skipped."""
content = "- Item 1\n- Item 2\nThis is the actual first substantial line\n\nContent"
result = extract_title_from_content(content)
assert "actual" in result or "Item" not in result
def test_extract_title_skips_separators(self):
"""Test that separator lines are skipped."""
content = "---\n\n***\n\nThis is the real title\n\nContent"
result = extract_title_from_content(content)
assert "---" not in result and "***" not in result
def test_extract_title_max_length(self):
"""Test that title respects max_length parameter."""
long_title = "A" * 300
content = f"# {long_title}"
result = extract_title_from_content(content, max_length=100)
assert len(result) <= 100
assert result == long_title[:100]
def test_extract_title_empty_content(self):
"""Test handling of empty content."""
assert extract_title_from_content("") == "Untitled"
assert extract_title_from_content(None) == "Untitled"
def test_extract_title_no_title_found(self):
"""Test fallback to 'Untitled' when no title can be extracted."""
content = "a\nb\nc\n" # Only short lines
result = extract_title_from_content(content)
# May return Untitled or one of the short lines
assert isinstance(result, str)
def test_extract_title_whitespace_handling(self):
"""Test that whitespace is properly handled."""
content = "# Title with extra spaces \n\nContent"
result = extract_title_from_content(content)
# Should normalize spaces
assert "Title with extra spaces" in result or len(result) > 5
def test_extract_title_multiline_html(self):
"""Test HTML title extraction across multiple lines."""
content = """
<html>
<head>
<title>
Multiline Title
</title>
</head>
<body>Content</body>
</html>
"""
result = extract_title_from_content(content)
# Should handle multiline titles
assert "Title" in result
def test_extract_title_mixed_formats(self):
"""Test content with mixed formats (h1 should win)."""
content = """
<title>HTML Title</title>
# Markdown H1
## Markdown H2
Some paragraph content
"""
# HTML title comes first in priority
assert extract_title_from_content(content) == "HTML Title"
def test_extract_title_real_world_example(self):
"""Test with real-world HTML example."""
content = """
<!DOCTYPE html>
<html>
<head>
<title>GitHub: Where the world builds software</title>
<meta property="og:title" content="GitHub">
</head>
<body>
<h1>Let's build from here</h1>
<p>The complete developer platform...</p>
</body>
</html>
"""
result = extract_title_from_content(content)
assert result == "GitHub: Where the world builds software"
def test_extract_title_json_with_nested_title(self):
"""Test JSON title extraction with nested structures."""
content = '{"meta": {"title": "Should not match"}, "title": "JSON Title"}'
result = extract_title_from_content(content)
# The regex will match the first "title" field it finds, which could be nested
# Just verify it finds a title field
assert result and result != "Untitled"
def test_extract_title_preserves_special_characters(self):
"""Test that special characters are preserved in title."""
content = "# Title with Special Characters: @#$%"
result = extract_title_from_content(content)
assert "@" in result or "$" in result or "%" in result or "Title" in result
-423
View File
@@ -1,423 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for citation formatter enhancements.
Tests the multi-format citation parsing and extraction capabilities.
"""
from src.citations.formatter import (
parse_citations_from_report,
_extract_markdown_links,
_extract_numbered_citations,
_extract_footnote_citations,
_extract_html_links,
)
class TestExtractMarkdownLinks:
"""Test Markdown link extraction [title](url)."""
def test_extract_single_markdown_link(self):
"""Test extraction of a single markdown link."""
text = "[Example Article](https://example.com)"
citations = _extract_markdown_links(text)
assert len(citations) == 1
assert citations[0]["title"] == "Example Article"
assert citations[0]["url"] == "https://example.com"
assert citations[0]["format"] == "markdown"
def test_extract_multiple_markdown_links(self):
"""Test extraction of multiple markdown links."""
text = "[Link 1](https://example.com/1) and [Link 2](https://example.com/2)"
citations = _extract_markdown_links(text)
assert len(citations) == 2
assert citations[0]["title"] == "Link 1"
assert citations[1]["title"] == "Link 2"
def test_extract_markdown_link_with_spaces(self):
"""Test markdown link with spaces in title."""
text = "[Article Title With Spaces](https://example.com)"
citations = _extract_markdown_links(text)
assert len(citations) == 1
assert citations[0]["title"] == "Article Title With Spaces"
def test_extract_markdown_link_ignore_non_http(self):
"""Test that non-HTTP URLs are ignored."""
text = "[Relative Link](./relative/path) [HTTP Link](https://example.com)"
citations = _extract_markdown_links(text)
assert len(citations) == 1
assert citations[0]["url"] == "https://example.com"
def test_extract_markdown_link_with_query_params(self):
"""Test markdown links with query parameters."""
text = "[Search Result](https://example.com/search?q=test&page=1)"
citations = _extract_markdown_links(text)
assert len(citations) == 1
assert "q=test" in citations[0]["url"]
def test_extract_markdown_link_empty_text(self):
"""Test with no markdown links."""
text = "Just plain text with no links"
citations = _extract_markdown_links(text)
assert len(citations) == 0
def test_extract_markdown_link_strip_whitespace(self):
"""Test that whitespace in title and URL is stripped."""
# Markdown links with spaces in URL are not valid, so they won't be extracted
text = "[Title](https://example.com)"
citations = _extract_markdown_links(text)
assert len(citations) >= 1
assert citations[0]["title"] == "Title"
assert citations[0]["url"] == "https://example.com"
class TestExtractNumberedCitations:
"""Test numbered citation extraction [1] Title - URL."""
def test_extract_single_numbered_citation(self):
"""Test extraction of a single numbered citation."""
text = "[1] Example Article - https://example.com"
citations = _extract_numbered_citations(text)
assert len(citations) == 1
assert citations[0]["title"] == "Example Article"
assert citations[0]["url"] == "https://example.com"
assert citations[0]["format"] == "numbered"
def test_extract_multiple_numbered_citations(self):
"""Test extraction of multiple numbered citations."""
text = "[1] First - https://example.com/1\n[2] Second - https://example.com/2"
citations = _extract_numbered_citations(text)
assert len(citations) == 2
assert citations[0]["title"] == "First"
assert citations[1]["title"] == "Second"
def test_extract_numbered_citation_with_long_title(self):
"""Test numbered citation with longer title."""
text = "[5] A Comprehensive Guide to Python Programming - https://example.com"
citations = _extract_numbered_citations(text)
assert len(citations) == 1
assert "Comprehensive Guide" in citations[0]["title"]
def test_extract_numbered_citation_requires_valid_format(self):
"""Test that invalid numbered format is not extracted."""
text = "[1 Title - https://example.com" # Missing closing bracket
citations = _extract_numbered_citations(text)
assert len(citations) == 0
def test_extract_numbered_citation_empty_text(self):
"""Test with no numbered citations."""
text = "Just plain text"
citations = _extract_numbered_citations(text)
assert len(citations) == 0
def test_extract_numbered_citation_various_numbers(self):
"""Test with various citation numbers."""
text = "[10] Title Ten - https://example.com/10\n[999] Title 999 - https://example.com/999"
citations = _extract_numbered_citations(text)
assert len(citations) == 2
def test_extract_numbered_citation_ignore_non_http(self):
"""Test that non-HTTP URLs in numbered citations are ignored."""
text = "[1] Invalid - file://path [2] Valid - https://example.com"
citations = _extract_numbered_citations(text)
# Only the valid one should be extracted
assert len(citations) <= 1
class TestExtractFootnoteCitations:
"""Test footnote citation extraction [^1]: Title - URL."""
def test_extract_single_footnote_citation(self):
"""Test extraction of a single footnote citation."""
text = "[^1]: Example Article - https://example.com"
citations = _extract_footnote_citations(text)
assert len(citations) == 1
assert citations[0]["title"] == "Example Article"
assert citations[0]["url"] == "https://example.com"
assert citations[0]["format"] == "footnote"
def test_extract_multiple_footnote_citations(self):
"""Test extraction of multiple footnote citations."""
text = "[^1]: First - https://example.com/1\n[^2]: Second - https://example.com/2"
citations = _extract_footnote_citations(text)
assert len(citations) == 2
def test_extract_footnote_with_complex_number(self):
"""Test footnote extraction with various numbers."""
text = "[^123]: Title - https://example.com"
citations = _extract_footnote_citations(text)
assert len(citations) == 1
assert citations[0]["title"] == "Title"
def test_extract_footnote_citation_with_spaces(self):
"""Test footnote with spaces around separator."""
text = "[^1]: Title with spaces - https://example.com "
citations = _extract_footnote_citations(text)
assert len(citations) == 1
# Should strip whitespace
assert citations[0]["title"] == "Title with spaces"
def test_extract_footnote_citation_empty_text(self):
"""Test with no footnote citations."""
text = "No footnotes here"
citations = _extract_footnote_citations(text)
assert len(citations) == 0
def test_extract_footnote_requires_caret(self):
"""Test that missing caret prevents extraction."""
text = "[1]: Title - https://example.com" # Missing ^
citations = _extract_footnote_citations(text)
assert len(citations) == 0
class TestExtractHtmlLinks:
"""Test HTML link extraction <a href="url">title</a>."""
def test_extract_single_html_link(self):
"""Test extraction of a single HTML link."""
text = '<a href="https://example.com">Example Article</a>'
citations = _extract_html_links(text)
assert len(citations) == 1
assert citations[0]["title"] == "Example Article"
assert citations[0]["url"] == "https://example.com"
assert citations[0]["format"] == "html"
def test_extract_multiple_html_links(self):
"""Test extraction of multiple HTML links."""
text = '<a href="https://a.com">Link A</a> <a href="https://b.com">Link B</a>'
citations = _extract_html_links(text)
assert len(citations) == 2
def test_extract_html_link_single_quotes(self):
"""Test HTML links with single quotes."""
text = "<a href='https://example.com'>Title</a>"
citations = _extract_html_links(text)
assert len(citations) == 1
assert citations[0]["url"] == "https://example.com"
def test_extract_html_link_with_attributes(self):
"""Test HTML links with additional attributes."""
text = '<a class="link" href="https://example.com" target="_blank">Title</a>'
citations = _extract_html_links(text)
assert len(citations) == 1
assert citations[0]["url"] == "https://example.com"
def test_extract_html_link_ignore_non_http(self):
"""Test that non-HTTP URLs are ignored."""
text = '<a href="mailto:test@example.com">Email</a> <a href="https://example.com">Web</a>'
citations = _extract_html_links(text)
assert len(citations) == 1
assert citations[0]["url"] == "https://example.com"
def test_extract_html_link_case_insensitive(self):
"""Test that HTML extraction is case-insensitive."""
text = '<A HREF="https://example.com">Title</A>'
citations = _extract_html_links(text)
assert len(citations) == 1
def test_extract_html_link_empty_text(self):
"""Test with no HTML links."""
text = "No links here"
citations = _extract_html_links(text)
assert len(citations) == 0
def test_extract_html_link_strip_whitespace(self):
"""Test that whitespace in title is stripped."""
text = '<a href="https://example.com"> Title with spaces </a>'
citations = _extract_html_links(text)
assert citations[0]["title"] == "Title with spaces"
class TestParseCitationsFromReport:
"""Test comprehensive citation parsing from complete reports."""
def test_parse_markdown_links_from_report(self):
"""Test parsing markdown links from a report."""
report = """
## Key Citations
[GitHub](https://github.com)
[Python Docs](https://python.org)
"""
result = parse_citations_from_report(report)
assert result["count"] >= 2
urls = [c["url"] for c in result["citations"]]
assert "https://github.com" in urls
def test_parse_numbered_citations_from_report(self):
"""Test parsing numbered citations."""
report = """
## References
[1] GitHub - https://github.com
[2] Python - https://python.org
"""
result = parse_citations_from_report(report)
assert result["count"] >= 2
def test_parse_mixed_format_citations(self):
"""Test parsing mixed citation formats."""
report = """
## Key Citations
[GitHub](https://github.com)
[^1]: Python - https://python.org
[2] Wikipedia - https://wikipedia.org
<a href="https://stackoverflow.com">Stack Overflow</a>
"""
result = parse_citations_from_report(report)
# Should find all 4 citations
assert result["count"] >= 3
def test_parse_citations_deduplication(self):
"""Test that duplicate URLs are deduplicated."""
report = """
## Key Citations
[GitHub 1](https://github.com)
[GitHub 2](https://github.com)
[GitHub](https://github.com)
"""
result = parse_citations_from_report(report)
# Should have only 1 unique citation
assert result["count"] == 1
assert result["citations"][0]["url"] == "https://github.com"
def test_parse_citations_various_section_patterns(self):
"""Test parsing with different section headers."""
report_refs = """
## References
[GitHub](https://github.com)
"""
report_sources = """
## Sources
[GitHub](https://github.com)
"""
report_bibliography = """
## Bibliography
[GitHub](https://github.com)
"""
assert parse_citations_from_report(report_refs)["count"] >= 1
assert parse_citations_from_report(report_sources)["count"] >= 1
assert parse_citations_from_report(report_bibliography)["count"] >= 1
def test_parse_citations_custom_patterns(self):
"""Test parsing with custom section patterns."""
report = """
## My Custom Sources
[GitHub](https://github.com)
"""
result = parse_citations_from_report(
report,
section_patterns=[r"##\s*My Custom Sources"]
)
assert result["count"] >= 1
def test_parse_citations_empty_report(self):
"""Test parsing an empty report."""
result = parse_citations_from_report("")
assert result["count"] == 0
assert result["citations"] == []
def test_parse_citations_no_section(self):
"""Test parsing report without citation section."""
report = "This is a report with no citations section"
result = parse_citations_from_report(report)
assert result["count"] == 0
def test_parse_citations_complex_report(self):
"""Test parsing a complex, realistic report."""
report = """
# Research Report
## Introduction
This report summarizes findings from multiple sources.
## Key Findings
Some important discoveries were made based on research [GitHub](https://github.com).
## Key Citations
1. Primary sources:
[GitHub](https://github.com) - A collaborative platform
[^1]: Python - https://python.org
2. Secondary sources:
[2] Wikipedia - https://wikipedia.org
3. Web resources:
<a href="https://stackoverflow.com">Stack Overflow</a>
## Methodology
[Additional](https://example.com) details about methodology.
---
[^1]: The Python programming language official site
"""
result = parse_citations_from_report(report)
# Should extract multiple citations from the Key Citations section
assert result["count"] >= 3
urls = [c["url"] for c in result["citations"]]
# Verify some key URLs are found
assert any("github.com" in url or "python.org" in url for url in urls)
def test_parse_citations_stops_at_next_section(self):
"""Test that citation extraction looks for citation sections."""
report = """
## Key Citations
[Cite 1](https://example.com/1)
[Cite 2](https://example.com/2)
## Next Section
Some other content
"""
result = parse_citations_from_report(report)
# Should extract citations from the Key Citations section
# Note: The regex stops at next ## section
assert result["count"] >= 1
assert any("example.com/1" in c["url"] for c in result["citations"])
def test_parse_citations_preserves_metadata(self):
"""Test that citation metadata is preserved."""
report = """
## Key Citations
[Python Documentation](https://python.org)
"""
result = parse_citations_from_report(report)
assert len(result["citations"]) >= 1
citation = result["citations"][0]
assert "title" in citation
assert "url" in citation
assert "format" in citation
def test_parse_citations_whitespace_handling(self):
"""Test handling of various whitespace configurations."""
report = """
## Key Citations
[Link](https://example.com)
"""
result = parse_citations_from_report(report)
assert result["count"] >= 1
def test_parse_citations_multiline_links(self):
"""Test extraction of links across formatting."""
report = """
## Key Citations
Some paragraph with a [link to example](https://example.com) in the middle.
"""
result = parse_citations_from_report(report)
assert result["count"] >= 1
-467
View File
@@ -1,467 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for citation models.
Tests the Pydantic BaseModel implementation of CitationMetadata and Citation classes.
"""
import json
import pytest
from pydantic import ValidationError
from src.citations.models import Citation, CitationMetadata
class TestCitationMetadata:
"""Test CitationMetadata Pydantic model."""
def test_create_basic_metadata(self):
"""Test creating basic citation metadata."""
metadata = CitationMetadata(
url="https://example.com/article",
title="Example Article",
)
assert metadata.url == "https://example.com/article"
assert metadata.title == "Example Article"
assert metadata.domain == "example.com" # Auto-extracted from URL
assert metadata.description is None
assert metadata.images == []
assert metadata.extra == {}
def test_metadata_with_all_fields(self):
"""Test creating metadata with all fields populated."""
metadata = CitationMetadata(
url="https://github.com/example/repo",
title="Example Repository",
description="A great repository",
content_snippet="This is a snippet",
raw_content="Full content here",
author="John Doe",
published_date="2025-01-24",
language="en",
relevance_score=0.95,
credibility_score=0.88,
)
assert metadata.url == "https://github.com/example/repo"
assert metadata.domain == "github.com"
assert metadata.author == "John Doe"
assert metadata.relevance_score == 0.95
assert metadata.credibility_score == 0.88
def test_metadata_domain_auto_extraction(self):
"""Test automatic domain extraction from URL."""
test_cases = [
("https://www.example.com/path", "www.example.com"),
("http://github.com/user/repo", "github.com"),
("https://api.github.com:443/repos", "api.github.com:443"),
]
for url, expected_domain in test_cases:
metadata = CitationMetadata(url=url, title="Test")
assert metadata.domain == expected_domain
def test_metadata_id_generation(self):
"""Test unique ID generation from URL."""
metadata1 = CitationMetadata(
url="https://example.com/article",
title="Article",
)
metadata2 = CitationMetadata(
url="https://example.com/article",
title="Article",
)
# Same URL should produce same ID
assert metadata1.id == metadata2.id
metadata3 = CitationMetadata(
url="https://different.com/article",
title="Article",
)
# Different URL should produce different ID
assert metadata1.id != metadata3.id
def test_metadata_id_length(self):
"""Test that ID is truncated to 12 characters."""
metadata = CitationMetadata(
url="https://example.com",
title="Test",
)
assert len(metadata.id) == 12
assert metadata.id.isalnum() or all(c in "0123456789abcdef" for c in metadata.id)
def test_metadata_from_dict(self):
"""Test creating metadata from dictionary."""
data = {
"url": "https://example.com",
"title": "Example",
"description": "A description",
"author": "John Doe",
}
metadata = CitationMetadata.from_dict(data)
assert metadata.url == "https://example.com"
assert metadata.title == "Example"
assert metadata.description == "A description"
assert metadata.author == "John Doe"
def test_metadata_from_dict_removes_id(self):
"""Test that from_dict removes computed 'id' field."""
data = {
"url": "https://example.com",
"title": "Example",
"id": "some_old_id", # Should be ignored
}
metadata = CitationMetadata.from_dict(data)
# Should use newly computed ID, not the old one
assert metadata.id != "some_old_id"
def test_metadata_to_dict(self):
"""Test converting metadata to dictionary."""
metadata = CitationMetadata(
url="https://example.com",
title="Example",
author="John Doe",
)
result = metadata.to_dict()
assert result["url"] == "https://example.com"
assert result["title"] == "Example"
assert result["author"] == "John Doe"
assert result["id"] == metadata.id
assert result["domain"] == "example.com"
def test_metadata_from_search_result(self):
"""Test creating metadata from search result."""
search_result = {
"url": "https://example.com/article",
"title": "Article Title",
"content": "Article content here",
"score": 0.92,
"type": "page",
}
metadata = CitationMetadata.from_search_result(
search_result,
query="test query",
)
assert metadata.url == "https://example.com/article"
assert metadata.title == "Article Title"
assert metadata.description == "Article content here"
assert metadata.relevance_score == 0.92
assert metadata.extra["query"] == "test query"
assert metadata.extra["result_type"] == "page"
def test_metadata_pydantic_validation(self):
"""Test that Pydantic validates required fields."""
# URL and title are required
with pytest.raises(ValidationError):
CitationMetadata() # Missing required fields
with pytest.raises(ValidationError):
CitationMetadata(url="https://example.com") # Missing title
def test_metadata_model_dump(self):
"""Test Pydantic model_dump method."""
metadata = CitationMetadata(
url="https://example.com",
title="Example",
author="John Doe",
)
result = metadata.model_dump()
assert isinstance(result, dict)
assert result["url"] == "https://example.com"
assert result["title"] == "Example"
def test_metadata_model_dump_json(self):
"""Test Pydantic model_dump_json method."""
metadata = CitationMetadata(
url="https://example.com",
title="Example",
)
result = metadata.model_dump_json()
assert isinstance(result, str)
data = json.loads(result)
assert data["url"] == "https://example.com"
assert data["title"] == "Example"
def test_metadata_with_images_and_extra(self):
"""Test metadata with list and dict fields."""
metadata = CitationMetadata(
url="https://example.com",
title="Example",
images=["https://example.com/image1.jpg", "https://example.com/image2.jpg"],
favicon="https://example.com/favicon.ico",
extra={"custom_field": "value", "tags": ["tag1", "tag2"]},
)
assert len(metadata.images) == 2
assert metadata.favicon == "https://example.com/favicon.ico"
assert metadata.extra["custom_field"] == "value"
class TestCitation:
"""Test Citation Pydantic model."""
def test_create_basic_citation(self):
"""Test creating a basic citation."""
metadata = CitationMetadata(
url="https://example.com",
title="Example",
)
citation = Citation(number=1, metadata=metadata)
assert citation.number == 1
assert citation.metadata == metadata
assert citation.context is None
assert citation.cited_text is None
def test_citation_properties(self):
"""Test citation property shortcuts."""
metadata = CitationMetadata(
url="https://example.com",
title="Example Title",
)
citation = Citation(number=1, metadata=metadata)
assert citation.id == metadata.id
assert citation.url == "https://example.com"
assert citation.title == "Example Title"
def test_citation_to_markdown_reference(self):
"""Test markdown reference generation."""
metadata = CitationMetadata(
url="https://example.com",
title="Example",
)
citation = Citation(number=1, metadata=metadata)
result = citation.to_markdown_reference()
assert result == "[Example](https://example.com)"
def test_citation_to_numbered_reference(self):
"""Test numbered reference generation."""
metadata = CitationMetadata(
url="https://example.com",
title="Example Article",
)
citation = Citation(number=5, metadata=metadata)
result = citation.to_numbered_reference()
assert result == "[5] Example Article - https://example.com"
def test_citation_to_inline_marker(self):
"""Test inline marker generation."""
metadata = CitationMetadata(
url="https://example.com",
title="Example",
)
citation = Citation(number=3, metadata=metadata)
result = citation.to_inline_marker()
assert result == "[^3]"
def test_citation_to_footnote(self):
"""Test footnote generation."""
metadata = CitationMetadata(
url="https://example.com",
title="Example Article",
)
citation = Citation(number=2, metadata=metadata)
result = citation.to_footnote()
assert result == "[^2]: Example Article - https://example.com"
def test_citation_with_context_and_text(self):
"""Test citation with context and cited text."""
metadata = CitationMetadata(
url="https://example.com",
title="Example",
)
citation = Citation(
number=1,
metadata=metadata,
context="This is important context",
cited_text="Important quote from the source",
)
assert citation.context == "This is important context"
assert citation.cited_text == "Important quote from the source"
def test_citation_from_dict(self):
"""Test creating citation from dictionary."""
data = {
"number": 1,
"metadata": {
"url": "https://example.com",
"title": "Example",
"author": "John Doe",
},
"context": "Test context",
}
citation = Citation.from_dict(data)
assert citation.number == 1
assert citation.metadata.url == "https://example.com"
assert citation.metadata.title == "Example"
assert citation.metadata.author == "John Doe"
assert citation.context == "Test context"
def test_citation_to_dict(self):
"""Test converting citation to dictionary."""
metadata = CitationMetadata(
url="https://example.com",
title="Example",
author="John Doe",
)
citation = Citation(
number=1,
metadata=metadata,
context="Test context",
)
result = citation.to_dict()
assert result["number"] == 1
assert result["metadata"]["url"] == "https://example.com"
assert result["metadata"]["author"] == "John Doe"
assert result["context"] == "Test context"
def test_citation_round_trip(self):
"""Test converting to dict and back."""
metadata = CitationMetadata(
url="https://example.com",
title="Example",
author="John Doe",
relevance_score=0.95,
)
original = Citation(number=1, metadata=metadata, context="Test")
# Convert to dict and back
dict_repr = original.to_dict()
restored = Citation.from_dict(dict_repr)
assert restored.number == original.number
assert restored.metadata.url == original.metadata.url
assert restored.metadata.title == original.metadata.title
assert restored.metadata.author == original.metadata.author
assert restored.metadata.relevance_score == original.metadata.relevance_score
def test_citation_model_dump(self):
"""Test Pydantic model_dump method."""
metadata = CitationMetadata(
url="https://example.com",
title="Example",
)
citation = Citation(number=1, metadata=metadata)
result = citation.model_dump()
assert isinstance(result, dict)
assert result["number"] == 1
assert result["metadata"]["url"] == "https://example.com"
def test_citation_model_dump_json(self):
"""Test Pydantic model_dump_json method."""
metadata = CitationMetadata(
url="https://example.com",
title="Example",
)
citation = Citation(number=1, metadata=metadata)
result = citation.model_dump_json()
assert isinstance(result, str)
data = json.loads(result)
assert data["number"] == 1
assert data["metadata"]["url"] == "https://example.com"
def test_citation_pydantic_validation(self):
"""Test that Pydantic validates required fields."""
# Number and metadata are required
with pytest.raises(ValidationError):
Citation() # Missing required fields
metadata = CitationMetadata(
url="https://example.com",
title="Example",
)
with pytest.raises(ValidationError):
Citation(metadata=metadata) # Missing number
class TestCitationIntegration:
"""Integration tests for citation models."""
def test_search_result_to_citation_workflow(self):
"""Test complete workflow from search result to citation."""
search_result = {
"url": "https://example.com/article",
"title": "Great Article",
"content": "This is a great article about testing",
"score": 0.92,
}
# Create metadata from search result
metadata = CitationMetadata.from_search_result(search_result, query="testing")
# Create citation
citation = Citation(number=1, metadata=metadata, context="Important source")
# Verify the workflow
assert citation.number == 1
assert citation.url == "https://example.com/article"
assert citation.title == "Great Article"
assert citation.metadata.relevance_score == 0.92
assert citation.to_markdown_reference() == "[Great Article](https://example.com/article)"
def test_multiple_citations_with_different_formats(self):
"""Test handling multiple citations in different formats."""
citations = []
# Create first citation
metadata1 = CitationMetadata(
url="https://example.com/1",
title="First Article",
)
citations.append(Citation(number=1, metadata=metadata1))
# Create second citation
metadata2 = CitationMetadata(
url="https://example.com/2",
title="Second Article",
)
citations.append(Citation(number=2, metadata=metadata2))
# Verify all reference formats
assert citations[0].to_markdown_reference() == "[First Article](https://example.com/1)"
assert citations[1].to_numbered_reference() == "[2] Second Article - https://example.com/2"
def test_citation_json_serialization_roundtrip(self):
"""Test JSON serialization and deserialization roundtrip."""
original_data = {
"number": 1,
"metadata": {
"url": "https://example.com",
"title": "Example",
"author": "John Doe",
"relevance_score": 0.95,
},
"context": "Test context",
"cited_text": "Important quote",
}
# Create from dict
citation = Citation.from_dict(original_data)
# Serialize to JSON
json_str = citation.model_dump_json()
# Deserialize from JSON
restored = Citation.model_validate_json(json_str)
# Verify data integrity
assert restored.number == original_data["number"]
assert restored.metadata.url == original_data["metadata"]["url"]
assert restored.metadata.relevance_score == original_data["metadata"]["relevance_score"]
assert restored.context == original_data["context"]
-183
View File
@@ -1,183 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import sys
import types
from src.config.configuration import Configuration
# Patch sys.path so relative import works
# Patch Resource for import
mock_resource = type("Resource", (), {})
# Patch src.rag.retriever.Resource for import
module_name = "src.rag.retriever"
if module_name not in sys.modules:
retriever_mod = types.ModuleType(module_name)
retriever_mod.Resource = mock_resource
sys.modules[module_name] = retriever_mod
# Relative import of Configuration
def test_default_configuration():
config = Configuration()
assert config.resources == []
assert config.max_plan_iterations == 1
assert config.max_step_num == 3
assert config.max_search_results == 3
assert config.mcp_settings is None
def test_from_runnable_config_with_config_dict(monkeypatch):
config_dict = {
"configurable": {
"max_plan_iterations": 5,
"max_step_num": 7,
"max_search_results": 10,
"mcp_settings": {"foo": "bar"},
}
}
config = Configuration.from_runnable_config(config_dict)
assert config.max_plan_iterations == 5
assert config.max_step_num == 7
assert config.max_search_results == 10
assert config.mcp_settings == {"foo": "bar"}
def test_from_runnable_config_with_env_override(monkeypatch):
monkeypatch.setenv("MAX_PLAN_ITERATIONS", "9")
monkeypatch.setenv("MAX_STEP_NUM", "11")
config_dict = {
"configurable": {
"max_plan_iterations": 2,
"max_step_num": 3,
"max_search_results": 4,
}
}
config = Configuration.from_runnable_config(config_dict)
# Environment variables take precedence and are strings
assert config.max_plan_iterations == "9"
assert config.max_step_num == "11"
assert config.max_search_results == 4 # not overridden
# Clean up
monkeypatch.delenv("MAX_PLAN_ITERATIONS")
monkeypatch.delenv("MAX_STEP_NUM")
def test_from_runnable_config_with_none_and_falsy(monkeypatch):
"""Test that None values are skipped but falsy values (0, False, empty string) are preserved."""
config_dict = {
"configurable": {
"max_plan_iterations": None, # None should be skipped, use default
"max_step_num": 0, # 0 is valid, should be preserved
"max_search_results": "", # Empty string should be preserved
}
}
config = Configuration.from_runnable_config(config_dict)
# None values should fall back to defaults
assert config.max_plan_iterations == 1
# Falsy but valid values should be preserved
assert config.max_step_num == 0
assert config.max_search_results == ""
def test_from_runnable_config_with_no_config():
config = Configuration.from_runnable_config()
assert config.max_plan_iterations == 1
assert config.max_step_num == 3
assert config.max_search_results == 3
assert config.resources == []
assert config.mcp_settings is None
def test_from_runnable_config_with_boolean_false_values():
"""Test that boolean False values are correctly preserved and not filtered out.
This is a regression test for the bug where False values were treated as falsy
and filtered out, causing fields to revert to their default values.
"""
config_dict = {
"configurable": {
"enable_web_search": False, # Should be preserved as False, not revert to True
"enable_deep_thinking": False, # Should be preserved as False
"enforce_web_search": False, # Should be preserved as False
"enforce_researcher_search": False, # Should be preserved as False
"max_plan_iterations": 5, # Control: non-falsy value
}
}
config = Configuration.from_runnable_config(config_dict)
# Assert that False values are preserved
assert config.enable_web_search is False, "enable_web_search should be False, not default True"
assert config.enable_deep_thinking is False, "enable_deep_thinking should be False"
assert config.enforce_web_search is False, "enforce_web_search should be False"
assert config.enforce_researcher_search is False, "enforce_researcher_search should be False, not default True"
# Control: verify non-falsy values still work
assert config.max_plan_iterations == 5
def test_from_runnable_config_with_boolean_true_values():
"""Test that boolean True values work correctly (control test)."""
config_dict = {
"configurable": {
"enable_web_search": True,
"enable_deep_thinking": True,
"enforce_web_search": True,
}
}
config = Configuration.from_runnable_config(config_dict)
assert config.enable_web_search is True
assert config.enable_deep_thinking is True
assert config.enforce_web_search is True
def test_get_recursion_limit_default(monkeypatch):
from src.config.configuration import get_recursion_limit
monkeypatch.delenv("AGENT_RECURSION_LIMIT", raising=False)
result = get_recursion_limit()
assert result == 25
def test_get_recursion_limit_custom_default(monkeypatch):
from src.config.configuration import get_recursion_limit
monkeypatch.delenv("AGENT_RECURSION_LIMIT", raising=False)
result = get_recursion_limit(50)
assert result == 50
def test_get_recursion_limit_from_env(monkeypatch):
from src.config.configuration import get_recursion_limit
monkeypatch.setenv("AGENT_RECURSION_LIMIT", "100")
result = get_recursion_limit()
assert result == 100
def test_get_recursion_limit_invalid_env_value(monkeypatch):
from src.config.configuration import get_recursion_limit
monkeypatch.setenv("AGENT_RECURSION_LIMIT", "invalid")
result = get_recursion_limit()
assert result == 25
def test_get_recursion_limit_negative_env_value(monkeypatch):
from src.config.configuration import get_recursion_limit
monkeypatch.setenv("AGENT_RECURSION_LIMIT", "-5")
result = get_recursion_limit()
assert result == 25
def test_get_recursion_limit_zero_env_value(monkeypatch):
from src.config.configuration import get_recursion_limit
monkeypatch.setenv("AGENT_RECURSION_LIMIT", "0")
result = get_recursion_limit()
assert result == 25
-82
View File
@@ -1,82 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import os
import tempfile
from src.config.loader import load_yaml_config, process_dict, replace_env_vars
def test_replace_env_vars_with_env(monkeypatch):
monkeypatch.setenv("TEST_ENV", "env_value")
assert replace_env_vars("$TEST_ENV") == "env_value"
def test_replace_env_vars_without_env(monkeypatch):
monkeypatch.delenv("NOT_SET_ENV", raising=False)
assert replace_env_vars("$NOT_SET_ENV") == "NOT_SET_ENV"
def test_replace_env_vars_non_string():
assert replace_env_vars(123) == 123
def test_replace_env_vars_regular_string():
assert replace_env_vars("no_env") == "no_env"
def test_process_dict_nested(monkeypatch):
monkeypatch.setenv("FOO", "bar")
config = {"a": "$FOO", "b": {"c": "$FOO", "d": 42, "e": "$NOT_SET_ENV"}}
processed = process_dict(config)
assert processed["a"] == "bar"
assert processed["b"]["c"] == "bar"
assert processed["b"]["d"] == 42
assert processed["b"]["e"] == "NOT_SET_ENV"
def test_process_dict_empty():
assert process_dict({}) == {}
def test_load_yaml_config_file_not_exist():
assert load_yaml_config("non_existent_file.yaml") == {}
def test_load_yaml_config(monkeypatch):
monkeypatch.setenv("MY_ENV", "my_value")
yaml_content = """
key1: value1
key2: $MY_ENV
nested:
key3: $MY_ENV
key4: 123
"""
with tempfile.NamedTemporaryFile("w+", delete=False) as tmp:
tmp.write(yaml_content)
tmp_path = tmp.name
try:
config = load_yaml_config(tmp_path)
assert config["key1"] == "value1"
assert config["key2"] == "my_value"
assert config["nested"]["key3"] == "my_value"
assert config["nested"]["key4"] == 123
finally:
os.remove(tmp_path)
def test_load_yaml_config_cache(monkeypatch):
monkeypatch.setenv("CACHE_ENV", "cache_value")
yaml_content = "foo: $CACHE_ENV"
with tempfile.NamedTemporaryFile("w+", delete=False) as tmp:
tmp.write(yaml_content)
tmp_path = tmp.name
try:
config1 = load_yaml_config(tmp_path)
config2 = load_yaml_config(tmp_path)
assert config1 is config2 # Should be cached (same object)
assert config1["foo"] == "cache_value"
finally:
os.remove(tmp_path)
-113
View File
@@ -1,113 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from src.crawler.article import Article
class DummyMarkdownify:
"""A dummy markdownify replacement for patching if needed."""
@staticmethod
def markdownify(html):
return html
def test_to_markdown_includes_title(monkeypatch):
article = Article("Test Title", "<p>Hello <b>world</b>!</p>")
result = article.to_markdown(including_title=True)
assert result.startswith("# Test Title")
assert "Hello" in result
def test_to_markdown_excludes_title():
article = Article("Test Title", "<p>Hello <b>world</b>!</p>")
result = article.to_markdown(including_title=False)
assert not result.startswith("# Test Title")
assert "Hello" in result
def test_to_message_with_text_only():
article = Article("Test Title", "<p>Hello world!</p>")
article.url = "https://example.com/"
result = article.to_message()
assert isinstance(result, list)
assert any(item["type"] == "text" for item in result)
assert all("type" in item for item in result)
def test_to_message_with_image(monkeypatch):
html = '<p>Intro</p><img src="img/pic.png"/>'
article = Article("Title", html)
article.url = "https://host.com/path/"
# The markdownify library will convert <img> to markdown image syntax
result = article.to_message()
# Should have both text and image_url types
types = [item["type"] for item in result]
assert "image_url" in types
assert "text" in types
# Check that the image_url is correctly joined
image_items = [item for item in result if item["type"] == "image_url"]
assert image_items
assert image_items[0]["image_url"]["url"] == "https://host.com/path/img/pic.png"
def test_to_message_multiple_images():
html = '<p>Start</p><img src="a.png"/><p>Mid</p><img src="b.jpg"/>End'
article = Article("Title", html)
article.url = "http://x/"
result = article.to_message()
image_urls = [
item["image_url"]["url"] for item in result if item["type"] == "image_url"
]
assert "http://x/a.png" in image_urls
assert "http://x/b.jpg" in image_urls
text_items = [item for item in result if item["type"] == "text"]
assert any("Start" in item["text"] for item in text_items)
assert any("Mid" in item["text"] for item in text_items)
def test_to_message_handles_empty_html():
article = Article("Empty", "")
article.url = "http://test/"
result = article.to_message()
assert isinstance(result, list)
assert result[0]["type"] == "text"
def test_to_markdown_handles_none_content():
article = Article("Test Title", None)
result = article.to_markdown(including_title=True)
assert "# Test Title" in result
assert "No content available" in result
def test_to_markdown_handles_empty_string():
article = Article("Test Title", "")
result = article.to_markdown(including_title=True)
assert "# Test Title" in result
assert "No content available" in result
def test_to_markdown_handles_whitespace_only():
article = Article("Test Title", " \n \t ")
result = article.to_markdown(including_title=True)
assert "# Test Title" in result
assert "No content available" in result
def test_to_message_handles_none_content():
article = Article("Title", None)
article.url = "http://test/"
result = article.to_message()
assert isinstance(result, list)
assert len(result) > 0
assert result[0]["type"] == "text"
assert "No content available" in result[0]["text"]
def test_to_message_handles_whitespace_only_content():
article = Article("Title", " \n ")
article.url = "http://test/"
result = article.to_message()
assert isinstance(result, list)
assert result[0]["type"] == "text"
assert "No content available" in result[0]["text"]
-675
View File
@@ -1,675 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import src.crawler as crawler_module
from src.crawler.crawler import safe_truncate
from src.crawler.infoquest_client import InfoQuestClient
def test_crawler_sets_article_url(monkeypatch):
"""Test that the crawler sets the article.url field correctly."""
class DummyArticle:
def __init__(self):
self.url = None
def to_markdown(self):
return "# Dummy"
class DummyJinaClient:
def crawl(self, url, return_format=None):
return "<html>dummy</html>"
class DummyInfoQuestClient:
def __init__(self, fetch_time=None, timeout=None, navi_timeout=None):
pass
def crawl(self, url, return_format=None):
return "<html>dummy</html>"
class DummyReadabilityExtractor:
def extract_article(self, html):
return DummyArticle()
def mock_load_config(*args, **kwargs):
return {"CRAWLER_ENGINE": {"engine": "jina"}}
monkeypatch.setattr("src.crawler.crawler.JinaClient", DummyJinaClient)
monkeypatch.setattr("src.crawler.crawler.InfoQuestClient", DummyInfoQuestClient)
monkeypatch.setattr(
"src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor
)
monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)
crawler = crawler_module.crawler.Crawler()
url = "http://example.com"
article = crawler.crawl(url)
assert article.url == url
assert article.to_markdown() == "# Dummy"
def test_crawler_calls_dependencies(monkeypatch):
"""Test that Crawler calls JinaClient.crawl and ReadabilityExtractor.extract_article."""
calls = {}
class DummyJinaClient:
def crawl(self, url, return_format=None):
calls["jina"] = (url, return_format)
return "<html>dummy</html>"
# Fix: Update DummyInfoQuestClient to accept initialization parameters
class DummyInfoQuestClient:
def __init__(self, fetch_time=None, timeout=None, navi_timeout=None):
# We don't need to use these parameters, just accept them
pass
def crawl(self, url, return_format=None):
calls["infoquest"] = (url, return_format)
return "<html>dummy</html>"
class DummyReadabilityExtractor:
def extract_article(self, html):
calls["extractor"] = html
class DummyArticle:
url = None
def to_markdown(self):
return "# Dummy"
return DummyArticle()
# Add mock for load_yaml_config to ensure it returns configuration with Jina engine
def mock_load_config(*args, **kwargs):
return {"CRAWLER_ENGINE": {"engine": "jina"}}
monkeypatch.setattr("src.crawler.crawler.JinaClient", DummyJinaClient)
monkeypatch.setattr("src.crawler.crawler.InfoQuestClient", DummyInfoQuestClient) # Include this if InfoQuest might be used
monkeypatch.setattr("src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor)
monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)
crawler = crawler_module.crawler.Crawler()
url = "http://example.com"
crawler.crawl(url)
assert "jina" in calls
assert calls["jina"][0] == url
assert calls["jina"][1] == "html"
assert "extractor" in calls
assert calls["extractor"] == "<html>dummy</html>"
def test_crawler_handles_empty_content(monkeypatch):
"""Test that the crawler handles empty content gracefully."""
class DummyArticle:
def __init__(self, title, html_content):
self.title = title
self.html_content = html_content
self.url = None
def to_markdown(self):
return f"# {self.title}"
class DummyJinaClient:
def crawl(self, url, return_format=None):
return "" # Empty content
class DummyReadabilityExtractor:
def extract_article(self, html):
# This should not be called for empty content
assert False, "ReadabilityExtractor should not be called for empty content"
def mock_load_config(*args, **kwargs):
return {"CRAWLER_ENGINE": {"engine": "jina"}}
monkeypatch.setattr("src.crawler.crawler.JinaClient", DummyJinaClient)
monkeypatch.setattr(
"src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor
)
monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)
crawler = crawler_module.crawler.Crawler()
url = "http://example.com"
article = crawler.crawl(url)
assert article.url == url
assert article.title == "Empty Content"
assert "No content could be extracted from this page" in article.html_content
def test_crawler_handles_error_response_from_client(monkeypatch):
"""Test that the crawler handles error responses from the client gracefully."""
class DummyArticle:
def __init__(self, title, html_content):
self.title = title
self.html_content = html_content
self.url = None
def to_markdown(self):
return f"# {self.title}"
class DummyJinaClient:
def crawl(self, url, return_format=None):
return "Error: API returned status 500"
class DummyReadabilityExtractor:
def extract_article(self, html):
# This should not be called for error responses
assert False, "ReadabilityExtractor should not be called for error responses"
def mock_load_config(*args, **kwargs):
return {"CRAWLER_ENGINE": {"engine": "jina"}}
monkeypatch.setattr("src.crawler.crawler.JinaClient", DummyJinaClient)
monkeypatch.setattr(
"src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor
)
monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)
crawler = crawler_module.crawler.Crawler()
url = "http://example.com"
article = crawler.crawl(url)
assert article.url == url
assert article.title in ["Non-HTML Content", "Content Extraction Failed"]
assert "Error: API returned status 500" in article.html_content
def test_crawler_handles_non_html_content(monkeypatch):
"""Test that the crawler handles non-HTML content gracefully."""
class DummyArticle:
def __init__(self, title, html_content):
self.title = title
self.html_content = html_content
self.url = None
def to_markdown(self):
return f"# {self.title}"
class DummyJinaClient:
def crawl(self, url, return_format=None):
return "This is plain text content, not HTML"
class DummyReadabilityExtractor:
def extract_article(self, html):
# This should not be called for non-HTML content
assert False, "ReadabilityExtractor should not be called for non-HTML content"
def mock_load_config(*args, **kwargs):
return {"CRAWLER_ENGINE": {"engine": "jina"}}
monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)
monkeypatch.setattr("src.crawler.crawler.JinaClient", DummyJinaClient)
monkeypatch.setattr(
"src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor
)
crawler = crawler_module.crawler.Crawler()
url = "http://example.com"
article = crawler.crawl(url)
assert article.url == url
assert article.title in ["Non-HTML Content", "Content Extraction Failed"]
assert "cannot be parsed as HTML" in article.html_content or "Content extraction failed" in article.html_content
assert "plain text content" in article.html_content # Should include a snippet of the original content
def test_crawler_handles_extraction_failure(monkeypatch):
"""Test that the crawler handles readability extraction failure gracefully."""
class DummyArticle:
def __init__(self, title, html_content):
self.title = title
self.html_content = html_content
self.url = None
def to_markdown(self):
return f"# {self.title}"
class DummyJinaClient:
def crawl(self, url, return_format=None):
return "<html><body>Valid HTML but extraction will fail</body></html>"
class DummyReadabilityExtractor:
def extract_article(self, html):
raise Exception("Extraction failed")
def mock_load_config(*args, **kwargs):
return {"CRAWLER_ENGINE": {"engine": "jina"}}
monkeypatch.setattr("src.crawler.crawler.JinaClient", DummyJinaClient)
monkeypatch.setattr(
"src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor
)
monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)
crawler = crawler_module.crawler.Crawler()
url = "http://example.com"
article = crawler.crawl(url)
assert article.url == url
assert article.title == "Content Extraction Failed"
assert "Content extraction failed" in article.html_content
assert "Valid HTML but extraction will fail" in article.html_content # Should include a snippet of the HTML
def test_crawler_with_json_like_content(monkeypatch):
"""Test that the crawler handles JSON-like content gracefully."""
class DummyArticle:
def __init__(self, title, html_content):
self.title = title
self.html_content = html_content
self.url = None
def to_markdown(self):
return f"# {self.title}"
class DummyJinaClient:
def crawl(self, url, return_format=None):
return '{"title": "Some JSON", "content": "This is JSON content"}'
class DummyReadabilityExtractor:
def extract_article(self, html):
# This should not be called for JSON content
assert False, "ReadabilityExtractor should not be called for JSON content"
def mock_load_config(*args, **kwargs):
return {"CRAWLER_ENGINE": {"engine": "jina"}}
monkeypatch.setattr("src.crawler.crawler.JinaClient", DummyJinaClient)
monkeypatch.setattr(
"src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor
)
monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)
crawler = crawler_module.crawler.Crawler()
url = "http://example.com/api/data"
article = crawler.crawl(url)
assert article.url == url
assert article.title in ["Non-HTML Content", "Content Extraction Failed"]
assert "cannot be parsed as HTML" in article.html_content or "Content extraction failed" in article.html_content
assert '{"title": "Some JSON"' in article.html_content # Should include a snippet of the JSON
def test_crawler_with_various_html_formats(monkeypatch):
"""Test that the crawler correctly identifies various HTML formats."""
class DummyArticle:
def __init__(self, title, html_content):
self.title = title
self.html_content = html_content
self.url = None
def to_markdown(self):
return f"# {self.title}"
# Test case 1: HTML with DOCTYPE
class DummyJinaClient1:
def crawl(self, url, return_format=None):
return "<!DOCTYPE html><html><body><p>Test content</p></body></html>"
# Test case 2: HTML with leading whitespace
class DummyJinaClient2:
def crawl(self, url, return_format=None):
return "\n\n <html><body><p>Test content</p></body></html>"
# Test case 3: HTML with comments
class DummyJinaClient3:
def crawl(self, url, return_format=None):
return "<!-- HTML comment --><html><body><p>Test content</p></body></html>"
# Test case 4: HTML with self-closing tags
class DummyJinaClient4:
def crawl(self, url, return_format=None):
return '<img src="test.jpg" alt="test" /><p>Test content</p>'
class DummyReadabilityExtractor:
def extract_article(self, html):
return DummyArticle("Extracted Article", "<p>Extracted content</p>")
def mock_load_config(*args, **kwargs):
return {"CRAWLER_ENGINE": {"engine": "jina"}}
# Test each HTML format
test_cases = [
(DummyJinaClient1, "HTML with DOCTYPE"),
(DummyJinaClient2, "HTML with leading whitespace"),
(DummyJinaClient3, "HTML with comments"),
(DummyJinaClient4, "HTML with self-closing tags"),
]
for JinaClientClass, description in test_cases:
monkeypatch.setattr("src.crawler.crawler.JinaClient", JinaClientClass)
monkeypatch.setattr("src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor)
monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)
crawler = crawler_module.crawler.Crawler()
url = "http://example.com"
article = crawler.crawl(url)
assert article.url == url
assert article.title == "Extracted Article"
assert "Extracted content" in article.html_content
def test_safe_truncate_function():
"""Test the safe_truncate function handles various character sets correctly."""
# Test None input
assert safe_truncate(None) is None
# Test empty string
assert safe_truncate("") == ""
# Test string shorter than limit
assert safe_truncate("Short text") == "Short text"
# Test ASCII truncation
result = safe_truncate("This is a longer text that needs truncation", 20)
assert len(result) <= 20
assert "..." in result
# Test Unicode/emoji characters
text_with_emoji = "Hello! 🌍 Welcome to the world 🚀"
result = safe_truncate(text_with_emoji, 20)
assert len(result) <= 20
assert "..." in result
# Verify it's valid UTF-8
assert result.encode('utf-8').decode('utf-8') == result
# Test very small limit
assert safe_truncate("Long text", 1) == "."
assert safe_truncate("Long text", 2) == ".."
assert safe_truncate("Long text", 3) == "..."
# Test with Chinese characters
chinese_text = "这是一个中文测试文本"
result = safe_truncate(chinese_text, 10)
assert len(result) <= 10
# Verify it's valid UTF-8
assert result.encode('utf-8').decode('utf-8') == result
# ========== InfoQuest Client Tests ==========
def test_crawler_selects_infoquest_engine(monkeypatch):
"""Test that the crawler selects InfoQuestClient when configured to use it."""
calls = {}
class DummyJinaClient:
def crawl(self, url, return_format=None):
calls["jina"] = True
return "<html>dummy</html>"
class DummyInfoQuestClient:
def __init__(self, fetch_time=None, timeout=None, navi_timeout=None):
calls["infoquest_init"] = (fetch_time, timeout, navi_timeout)
def crawl(self, url, return_format=None):
calls["infoquest"] = (url, return_format)
return "<html>dummy from infoquest</html>"
class DummyReadabilityExtractor:
def extract_article(self, html):
calls["extractor"] = html
class DummyArticle:
url = None
def to_markdown(self):
return "# Dummy"
return DummyArticle()
# Mock configuration to use InfoQuest engine with custom parameters
def mock_load_config(*args, **kwargs):
return {"CRAWLER_ENGINE": {
"engine": "infoquest",
"fetch_time": 30,
"timeout": 60,
"navi_timeout": 45
}}
monkeypatch.setattr("src.crawler.crawler.JinaClient", DummyJinaClient)
monkeypatch.setattr("src.crawler.crawler.InfoQuestClient", DummyInfoQuestClient)
monkeypatch.setattr("src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor)
monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)
crawler = crawler_module.crawler.Crawler()
url = "http://example.com"
crawler.crawl(url)
# Verify InfoQuestClient was used, not JinaClient
assert "infoquest_init" in calls
assert calls["infoquest_init"] == (30, 60, 45) # Verify parameters were passed correctly
assert "infoquest" in calls
assert calls["infoquest"][0] == url
assert calls["infoquest"][1] == "html"
assert "extractor" in calls
assert calls["extractor"] == "<html>dummy from infoquest</html>"
assert "jina" not in calls
def test_crawler_with_infoquest_empty_content(monkeypatch):
"""Test that the crawler handles empty content from InfoQuest client gracefully."""
class DummyArticle:
def __init__(self, title, html_content):
self.title = title
self.html_content = html_content
self.url = None
def to_markdown(self):
return f"# {self.title}"
class DummyInfoQuestClient:
def __init__(self, fetch_time=None, timeout=None, navi_timeout=None):
pass
def crawl(self, url, return_format=None):
return "" # Empty content
class DummyReadabilityExtractor:
def extract_article(self, html):
# This should not be called for empty content
assert False, "ReadabilityExtractor should not be called for empty content"
# Mock configuration to use InfoQuest engine
def mock_load_config(*args, **kwargs):
return {"CRAWLER_ENGINE": {"engine": "infoquest"}}
monkeypatch.setattr("src.crawler.crawler.InfoQuestClient", DummyInfoQuestClient)
monkeypatch.setattr(
"src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor
)
monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)
crawler = crawler_module.crawler.Crawler()
url = "http://example.com"
article = crawler.crawl(url)
assert article.url == url
assert article.title == "Empty Content"
assert "No content could be extracted from this page" in article.html_content
def test_crawler_with_infoquest_non_html_content(monkeypatch):
"""Test that the crawler handles non-HTML content from InfoQuest client gracefully."""
class DummyArticle:
def __init__(self, title, html_content):
self.title = title
self.html_content = html_content
self.url = None
def to_markdown(self):
return f"# {self.title}"
class DummyInfoQuestClient:
def __init__(self, fetch_time=None, timeout=None, navi_timeout=None):
pass
def crawl(self, url, return_format=None):
return "This is plain text content from InfoQuest, not HTML"
class DummyReadabilityExtractor:
def extract_article(self, html):
# This should not be called for non-HTML content
assert False, "ReadabilityExtractor should not be called for non-HTML content"
# Mock configuration to use InfoQuest engine
def mock_load_config(*args, **kwargs):
return {"CRAWLER_ENGINE": {"engine": "infoquest"}}
monkeypatch.setattr("src.crawler.crawler.InfoQuestClient", DummyInfoQuestClient)
monkeypatch.setattr(
"src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor
)
monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)
crawler = crawler_module.crawler.Crawler()
url = "http://example.com"
article = crawler.crawl(url)
assert article.url == url
assert article.title in ["Non-HTML Content", "Content Extraction Failed"]
assert "cannot be parsed as HTML" in article.html_content or "Content extraction failed" in article.html_content
assert "plain text content from InfoQuest" in article.html_content
def test_crawler_with_infoquest_error_response(monkeypatch):
"""Test that the crawler handles error responses from InfoQuest client gracefully."""
class DummyArticle:
def __init__(self, title, html_content):
self.title = title
self.html_content = html_content
self.url = None
def to_markdown(self):
return f"# {self.title}"
class DummyInfoQuestClient:
def __init__(self, fetch_time=None, timeout=None, navi_timeout=None):
pass
def crawl(self, url, return_format=None):
return "Error: InfoQuest API returned status 403: Forbidden"
class DummyReadabilityExtractor:
def extract_article(self, html):
# This should not be called for error responses
assert False, "ReadabilityExtractor should not be called for error responses"
# Mock configuration to use InfoQuest engine
def mock_load_config(*args, **kwargs):
return {"CRAWLER_ENGINE": {"engine": "infoquest"}}
monkeypatch.setattr("src.crawler.crawler.InfoQuestClient", DummyInfoQuestClient)
monkeypatch.setattr(
"src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor
)
monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)
crawler = crawler_module.crawler.Crawler()
url = "http://example.com"
article = crawler.crawl(url)
assert article.url == url
assert article.title in ["Non-HTML Content", "Content Extraction Failed"]
assert "Error: InfoQuest API returned status 403: Forbidden" in article.html_content
def test_crawler_with_infoquest_json_response(monkeypatch):
"""Test that the crawler handles JSON responses from InfoQuest client correctly."""
class DummyArticle:
def __init__(self, title, html_content):
self.title = title
self.html_content = html_content
self.url = None
def to_markdown(self):
return f"# {self.title}"
class DummyInfoQuestClient:
def __init__(self, fetch_time=None, timeout=None, navi_timeout=None):
pass
def crawl(self, url, return_format=None):
return "<html><body>Content from InfoQuest JSON</body></html>"
class DummyReadabilityExtractor:
def extract_article(self, html):
return DummyArticle("Extracted from JSON", html)
# Mock configuration to use InfoQuest engine
def mock_load_config(*args, **kwargs):
return {"CRAWLER_ENGINE": {"engine": "infoquest"}}
monkeypatch.setattr("src.crawler.crawler.InfoQuestClient", DummyInfoQuestClient)
monkeypatch.setattr(
"src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor
)
monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)
crawler = crawler_module.crawler.Crawler()
url = "http://example.com"
article = crawler.crawl(url)
assert article.url == url
assert article.title == "Extracted from JSON"
assert "Content from InfoQuest JSON" in article.html_content
def test_infoquest_client_initialization_params():
"""Test that InfoQuestClient correctly initializes with the provided parameters."""
# Test default parameters
client_default = InfoQuestClient()
assert client_default.fetch_time == -1
assert client_default.timeout == -1
assert client_default.navi_timeout == -1
# Test custom parameters
client_custom = InfoQuestClient(fetch_time=30, timeout=60, navi_timeout=45)
assert client_custom.fetch_time == 30
assert client_custom.timeout == 60
assert client_custom.navi_timeout == 45
def test_crawler_with_infoquest_default_parameters(monkeypatch):
"""Test that the crawler initializes InfoQuestClient with default parameters when none are provided."""
calls = {}
class DummyInfoQuestClient:
def __init__(self, fetch_time=None, timeout=None, navi_timeout=None):
calls["infoquest_init"] = (fetch_time, timeout, navi_timeout)
def crawl(self, url, return_format=None):
return "<html>dummy</html>"
class DummyReadabilityExtractor:
def extract_article(self, html):
class DummyArticle:
url = None
def to_markdown(self):
return "# Dummy"
return DummyArticle()
# Mock configuration to use InfoQuest engine without custom parameters
def mock_load_config(*args, **kwargs):
return {"CRAWLER_ENGINE": {"engine": "infoquest"}}
monkeypatch.setattr("src.crawler.crawler.InfoQuestClient", DummyInfoQuestClient)
monkeypatch.setattr("src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor)
monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)
crawler = crawler_module.crawler.Crawler()
crawler.crawl("http://example.com")
# Verify default parameters were passed
assert "infoquest_init" in calls
assert calls["infoquest_init"] == (-1, -1, -1)
-230
View File
@@ -1,230 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import Mock, patch
import json
from src.crawler.infoquest_client import InfoQuestClient
class TestInfoQuestClient:
@patch("src.crawler.infoquest_client.requests.post")
def test_crawl_success(self, mock_post):
# Arrange
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = "<html><body>Test Content</body></html>"
mock_post.return_value = mock_response
client = InfoQuestClient()
# Act
result = client.crawl("https://example.com")
# Assert
assert result == "<html><body>Test Content</body></html>"
mock_post.assert_called_once()
@patch("src.crawler.infoquest_client.requests.post")
def test_crawl_json_response_with_reader_result(self, mock_post):
# Arrange
mock_response = Mock()
mock_response.status_code = 200
json_data = {
"reader_result": "<p>Extracted content from JSON</p>",
"err_code": 0,
"err_msg": "success"
}
mock_response.text = json.dumps(json_data)
mock_post.return_value = mock_response
client = InfoQuestClient()
# Act
result = client.crawl("https://example.com")
# Assert
assert result == "<p>Extracted content from JSON</p>"
@patch("src.crawler.infoquest_client.requests.post")
def test_crawl_json_response_with_content_fallback(self, mock_post):
# Arrange
mock_response = Mock()
mock_response.status_code = 200
json_data = {
"content": "<p>Content fallback from JSON</p>",
"err_code": 0,
"err_msg": "success"
}
mock_response.text = json.dumps(json_data)
mock_post.return_value = mock_response
client = InfoQuestClient()
# Act
result = client.crawl("https://example.com")
# Assert
assert result == "<p>Content fallback from JSON</p>"
@patch("src.crawler.infoquest_client.requests.post")
def test_crawl_json_response_without_expected_fields(self, mock_post):
# Arrange
mock_response = Mock()
mock_response.status_code = 200
json_data = {
"unexpected_field": "some value",
"err_code": 0,
"err_msg": "success"
}
mock_response.text = json.dumps(json_data)
mock_post.return_value = mock_response
client = InfoQuestClient()
# Act
result = client.crawl("https://example.com")
# Assert
assert result == json.dumps(json_data)
@patch("src.crawler.infoquest_client.requests.post")
def test_crawl_http_error(self, mock_post):
# Arrange
mock_response = Mock()
mock_response.status_code = 500
mock_response.text = "Internal Server Error"
mock_post.return_value = mock_response
client = InfoQuestClient()
# Act
result = client.crawl("https://example.com")
# Assert
assert result.startswith("Error:")
assert "status 500" in result
@patch("src.crawler.infoquest_client.requests.post")
def test_crawl_empty_response(self, mock_post):
# Arrange
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = ""
mock_post.return_value = mock_response
client = InfoQuestClient()
# Act
result = client.crawl("https://example.com")
# Assert
assert result.startswith("Error:")
assert "empty response" in result
@patch("src.crawler.infoquest_client.requests.post")
def test_crawl_whitespace_only_response(self, mock_post):
# Arrange
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = " \n \t "
mock_post.return_value = mock_response
client = InfoQuestClient()
# Act
result = client.crawl("https://example.com")
# Assert
assert result.startswith("Error:")
assert "empty response" in result
@patch("src.crawler.infoquest_client.requests.post")
def test_crawl_not_found(self, mock_post):
# Arrange
mock_response = Mock()
mock_response.status_code = 404
mock_response.text = "Not Found"
mock_post.return_value = mock_response
client = InfoQuestClient()
# Act
result = client.crawl("https://example.com")
# Assert
assert result.startswith("Error:")
assert "status 404" in result
@patch.dict("os.environ", {}, clear=True)
@patch("src.crawler.infoquest_client.requests.post")
def test_crawl_without_api_key_logs_warning(self, mock_post):
# Arrange
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = "<html>Test</html>"
mock_post.return_value = mock_response
client = InfoQuestClient()
# Act
result = client.crawl("https://example.com")
# Assert
assert result == "<html>Test</html>"
@patch("src.crawler.infoquest_client.requests.post")
def test_crawl_with_timeout_parameters(self, mock_post):
# Arrange
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = "<html>Test</html>"
mock_post.return_value = mock_response
client = InfoQuestClient(fetch_time=10, timeout=20, navi_timeout=30)
# Act
result = client.crawl("https://example.com")
# Assert
assert result == "<html>Test</html>"
# Verify the post call was made with timeout parameters
call_args = mock_post.call_args[1]
assert call_args['json']['fetch_time'] == 10
assert call_args['json']['timeout'] == 20
assert call_args['json']['navi_timeout'] == 30
@patch("src.crawler.infoquest_client.requests.post")
def test_crawl_with_markdown_format(self, mock_post):
# Arrange
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = "# Markdown Content"
mock_post.return_value = mock_response
client = InfoQuestClient()
# Act
result = client.crawl("https://example.com", return_format="markdown")
# Assert
assert result == "# Markdown Content"
# Verify the format was set correctly
call_args = mock_post.call_args[1]
assert call_args['json']['format'] == "markdown"
@patch("src.crawler.infoquest_client.requests.post")
def test_crawl_exception_handling(self, mock_post):
# Arrange
mock_post.side_effect = Exception("Network error")
client = InfoQuestClient()
# Act
result = client.crawl("https://example.com")
# Assert
assert result.startswith("Error:")
assert "Network error" in result
-126
View File
@@ -1,126 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import Mock, patch
import pytest
from src.crawler.jina_client import JinaClient
class TestJinaClient:
@patch("src.crawler.jina_client.requests.post")
def test_crawl_success(self, mock_post):
# Arrange
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = "<html><body>Test</body></html>"
mock_post.return_value = mock_response
client = JinaClient()
# Act
result = client.crawl("https://example.com")
# Assert
assert result == "<html><body>Test</body></html>"
mock_post.assert_called_once()
@patch("src.crawler.jina_client.requests.post")
def test_crawl_http_error(self, mock_post):
# Arrange
mock_response = Mock()
mock_response.status_code = 500
mock_response.text = "Internal Server Error"
mock_post.return_value = mock_response
client = JinaClient()
# Act
result = client.crawl("https://example.com")
# Assert
assert result.startswith("Error:")
assert "status 500" in result
@patch("src.crawler.jina_client.requests.post")
def test_crawl_empty_response(self, mock_post):
# Arrange
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = ""
mock_post.return_value = mock_response
client = JinaClient()
# Act
result = client.crawl("https://example.com")
# Assert
assert result.startswith("Error:")
assert "empty response" in result
@patch("src.crawler.jina_client.requests.post")
def test_crawl_whitespace_only_response(self, mock_post):
# Arrange
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = " \n \t "
mock_post.return_value = mock_response
client = JinaClient()
# Act
result = client.crawl("https://example.com")
# Assert
assert result.startswith("Error:")
assert "empty response" in result
@patch("src.crawler.jina_client.requests.post")
def test_crawl_not_found(self, mock_post):
# Arrange
mock_response = Mock()
mock_response.status_code = 404
mock_response.text = "Not Found"
mock_post.return_value = mock_response
client = JinaClient()
# Act
result = client.crawl("https://example.com")
# Assert
assert result.startswith("Error:")
assert "status 404" in result
@patch.dict("os.environ", {}, clear=True)
@patch("src.crawler.jina_client.requests.post")
def test_crawl_without_api_key_logs_warning(self, mock_post):
# Arrange
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = "<html>Test</html>"
mock_post.return_value = mock_response
client = JinaClient()
# Act
result = client.crawl("https://example.com")
# Assert
assert result == "<html>Test</html>"
@patch("src.crawler.jina_client.requests.post")
def test_crawl_exception_handling(self, mock_post):
# Arrange
mock_post.side_effect = Exception("Network error")
client = JinaClient()
# Act
result = client.crawl("https://example.com")
# Assert
assert result.startswith("Error:")
assert "Network error" in result
@@ -1,104 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import patch
from src.crawler.readability_extractor import ReadabilityExtractor
class TestReadabilityExtractor:
@patch("src.crawler.readability_extractor.simple_json_from_html_string")
def test_extract_article_with_valid_content(self, mock_simple_json):
# Arrange
mock_simple_json.return_value = {
"title": "Test Article",
"content": "<p>Article content</p>",
}
extractor = ReadabilityExtractor()
# Act
article = extractor.extract_article("<html>test</html>")
# Assert
assert article.title == "Test Article"
assert article.html_content == "<p>Article content</p>"
@patch("src.crawler.readability_extractor.simple_json_from_html_string")
def test_extract_article_with_none_content(self, mock_simple_json):
# Arrange
mock_simple_json.return_value = {
"title": "Test Article",
"content": None,
}
extractor = ReadabilityExtractor()
# Act
article = extractor.extract_article("<html>test</html>")
# Assert
assert article.title == "Test Article"
assert article.html_content == "<p>No content could be extracted from this page</p>"
@patch("src.crawler.readability_extractor.simple_json_from_html_string")
def test_extract_article_with_empty_content(self, mock_simple_json):
# Arrange
mock_simple_json.return_value = {
"title": "Test Article",
"content": "",
}
extractor = ReadabilityExtractor()
# Act
article = extractor.extract_article("<html>test</html>")
# Assert
assert article.title == "Test Article"
assert article.html_content == "<p>No content could be extracted from this page</p>"
@patch("src.crawler.readability_extractor.simple_json_from_html_string")
def test_extract_article_with_whitespace_only_content(self, mock_simple_json):
# Arrange
mock_simple_json.return_value = {
"title": "Test Article",
"content": " \n \t ",
}
extractor = ReadabilityExtractor()
# Act
article = extractor.extract_article("<html>test</html>")
# Assert
assert article.title == "Test Article"
assert article.html_content == "<p>No content could be extracted from this page</p>"
@patch("src.crawler.readability_extractor.simple_json_from_html_string")
def test_extract_article_with_none_title(self, mock_simple_json):
# Arrange
mock_simple_json.return_value = {
"title": None,
"content": "<p>Article content</p>",
}
extractor = ReadabilityExtractor()
# Act
article = extractor.extract_article("<html>test</html>")
# Assert
assert article.title == "Untitled"
assert article.html_content == "<p>Article content</p>"
@patch("src.crawler.readability_extractor.simple_json_from_html_string")
def test_extract_article_with_empty_title(self, mock_simple_json):
# Arrange
mock_simple_json.return_value = {
"title": "",
"content": "<p>Article content</p>",
}
extractor = ReadabilityExtractor()
# Act
article = extractor.extract_article("<html>test</html>")
# Assert
assert article.title == "Untitled"
assert article.html_content == "<p>Article content</p>"
-2
View File
@@ -1,2 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
-489
View File
@@ -1,489 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""Unit tests for the combined report evaluator."""
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from src.eval.evaluator import CombinedEvaluation, ReportEvaluator, score_to_grade
from src.eval.llm_judge import (
EVALUATION_CRITERIA,
MAX_REPORT_LENGTH,
EvaluationResult,
LLMJudge,
)
from src.eval.metrics import ReportMetrics
class TestScoreToGrade:
"""Tests for score to grade conversion."""
def test_excellent_scores(self):
assert score_to_grade(9.5) == "A+"
assert score_to_grade(9.0) == "A+"
assert score_to_grade(8.7) == "A"
assert score_to_grade(8.5) == "A"
assert score_to_grade(8.2) == "A-"
def test_good_scores(self):
assert score_to_grade(7.8) == "B+"
assert score_to_grade(7.5) == "B+"
assert score_to_grade(7.2) == "B"
assert score_to_grade(7.0) == "B"
assert score_to_grade(6.7) == "B-"
def test_average_scores(self):
assert score_to_grade(6.2) == "C+"
assert score_to_grade(5.8) == "C"
assert score_to_grade(5.5) == "C"
assert score_to_grade(5.2) == "C-"
def test_poor_scores(self):
assert score_to_grade(4.5) == "D"
assert score_to_grade(4.0) == "D"
assert score_to_grade(3.0) == "F"
assert score_to_grade(1.0) == "F"
class TestReportEvaluator:
"""Tests for ReportEvaluator class."""
@pytest.fixture
def evaluator(self):
"""Create evaluator without LLM for metrics-only tests."""
return ReportEvaluator(use_llm=False)
@pytest.fixture
def sample_report(self):
"""Sample report for testing."""
return """
# Comprehensive Research Report
## Key Points
- Important finding number one with significant implications
- Critical discovery that changes our understanding
- Key insight that provides actionable recommendations
- Notable observation from the research data
## Overview
This report presents a comprehensive analysis of the research topic.
The findings are based on extensive data collection and analysis.
## Detailed Analysis
### Section 1: Background
The background of this research involves multiple factors.
[Source 1](https://example.com/source1) provides foundational context.
### Section 2: Methodology
Our methodology follows established research practices.
[Source 2](https://research.org/methods) outlines the approach.
### Section 3: Findings
The key findings include several important discoveries.
![Research Data](https://example.com/chart.png)
[Source 3](https://academic.edu/paper) supports these conclusions.
## Key Citations
- [Example Source](https://example.com/source1)
- [Research Methods](https://research.org/methods)
- [Academic Paper](https://academic.edu/paper)
- [Additional Reference](https://reference.com/doc)
"""
def test_evaluate_metrics_only(self, evaluator, sample_report):
"""Test metrics-only evaluation."""
result = evaluator.evaluate_metrics_only(sample_report)
assert "metrics" in result
assert "score" in result
assert "grade" in result
assert result["score"] > 0
assert result["grade"] in ["A+", "A", "A-", "B+", "B", "B-", "C+", "C", "C-", "D", "F"]
def test_evaluate_metrics_only_structure(self, evaluator, sample_report):
"""Test that metrics contain expected fields."""
result = evaluator.evaluate_metrics_only(sample_report)
metrics = result["metrics"]
assert "word_count" in metrics
assert "citation_count" in metrics
assert "unique_sources" in metrics
assert "image_count" in metrics
assert "section_coverage_score" in metrics
def test_evaluate_minimal_report(self, evaluator):
"""Test evaluation of minimal report."""
minimal_report = "Just some text."
result = evaluator.evaluate_metrics_only(minimal_report)
assert result["score"] < 5.0
assert result["grade"] in ["D", "F"]
def test_metrics_score_calculation(self, evaluator):
"""Test that metrics score is calculated correctly."""
good_report = """
# Title
## Key Points
- Point 1
- Point 2
## Overview
Overview content here.
## Detailed Analysis
Analysis with [cite](https://a.com) and [cite2](https://b.com)
and [cite3](https://c.com) and more [refs](https://d.com).
![Image](https://img.com/1.png)
## Key Citations
- [A](https://a.com)
- [B](https://b.com)
"""
result = evaluator.evaluate_metrics_only(good_report)
assert result["score"] > 5.0
def test_combined_evaluation_to_dict(self):
"""Test CombinedEvaluation to_dict method."""
metrics = ReportMetrics(
word_count=1000,
citation_count=5,
unique_sources=3,
)
evaluation = CombinedEvaluation(
metrics=metrics,
llm_evaluation=None,
final_score=7.5,
grade="B+",
summary="Test summary",
)
result = evaluation.to_dict()
assert result["final_score"] == 7.5
assert result["grade"] == "B+"
assert result["metrics"]["word_count"] == 1000
class TestReportEvaluatorIntegration:
"""Integration tests for evaluator (may require LLM)."""
@pytest.mark.asyncio
async def test_full_evaluation_without_llm(self):
"""Test full evaluation with LLM disabled."""
evaluator = ReportEvaluator(use_llm=False)
report = """
# Test Report
## Key Points
- Key point 1
## Overview
Test overview.
## Key Citations
- [Test](https://test.com)
"""
result = await evaluator.evaluate(report, "test query")
assert isinstance(result, CombinedEvaluation)
assert result.final_score > 0
assert result.grade is not None
assert result.summary is not None
assert result.llm_evaluation is None
class TestLLMJudgeParseResponse:
"""Tests for LLMJudge._parse_response method."""
@pytest.fixture
def judge(self):
"""Create LLMJudge with mock LLM."""
return LLMJudge(llm=MagicMock())
@pytest.fixture
def valid_response_data(self):
"""Valid evaluation response data."""
return {
"scores": {
"factual_accuracy": 8,
"completeness": 7,
"coherence": 9,
"relevance": 8,
"citation_quality": 6,
"writing_quality": 8,
},
"overall_score": 8,
"strengths": ["Well researched", "Clear structure"],
"weaknesses": ["Could use more citations"],
"suggestions": ["Add more sources"],
}
def test_parse_valid_json(self, judge, valid_response_data):
"""Test parsing valid JSON response."""
response = json.dumps(valid_response_data)
result = judge._parse_response(response)
assert result["scores"]["factual_accuracy"] == 8
assert result["overall_score"] == 8
assert "Well researched" in result["strengths"]
def test_parse_json_in_markdown_block(self, judge, valid_response_data):
"""Test parsing JSON wrapped in markdown code block."""
response = f"```json\n{json.dumps(valid_response_data)}\n```"
result = judge._parse_response(response)
assert result["scores"]["coherence"] == 9
assert result["overall_score"] == 8
def test_parse_json_in_generic_code_block(self, judge, valid_response_data):
"""Test parsing JSON in generic code block."""
response = f"```\n{json.dumps(valid_response_data)}\n```"
result = judge._parse_response(response)
assert result["scores"]["relevance"] == 8
def test_parse_malformed_json_returns_defaults(self, judge):
"""Test that malformed JSON returns default scores."""
response = "This is not valid JSON at all"
result = judge._parse_response(response)
assert result["scores"]["factual_accuracy"] == 5
assert result["scores"]["completeness"] == 5
assert result["overall_score"] == 5
assert "Unable to parse evaluation" in result["strengths"]
assert "Evaluation parsing failed" in result["weaknesses"]
def test_parse_incomplete_json(self, judge):
"""Test parsing incomplete JSON."""
response = '{"scores": {"factual_accuracy": 8}' # Missing closing braces
result = judge._parse_response(response)
# Should return defaults due to parse failure
assert result["overall_score"] == 5
def test_parse_json_with_extra_text(self, judge, valid_response_data):
"""Test parsing JSON with surrounding text."""
response = f"Here is my evaluation:\n```json\n{json.dumps(valid_response_data)}\n```\nHope this helps!"
result = judge._parse_response(response)
assert result["scores"]["factual_accuracy"] == 8
class TestLLMJudgeCalculateWeightedScore:
"""Tests for LLMJudge._calculate_weighted_score method."""
@pytest.fixture
def judge(self):
"""Create LLMJudge with mock LLM."""
return LLMJudge(llm=MagicMock())
def test_calculate_with_all_scores(self, judge):
"""Test weighted score calculation with all criteria."""
scores = {
"factual_accuracy": 10, # weight 0.25
"completeness": 10, # weight 0.20
"coherence": 10, # weight 0.20
"relevance": 10, # weight 0.15
"citation_quality": 10, # weight 0.10
"writing_quality": 10, # weight 0.10
}
result = judge._calculate_weighted_score(scores)
assert result == 10.0
def test_calculate_with_varied_scores(self, judge):
"""Test weighted score with varied scores."""
scores = {
"factual_accuracy": 8, # 8 * 0.25 = 2.0
"completeness": 6, # 6 * 0.20 = 1.2
"coherence": 7, # 7 * 0.20 = 1.4
"relevance": 9, # 9 * 0.15 = 1.35
"citation_quality": 5, # 5 * 0.10 = 0.5
"writing_quality": 8, # 8 * 0.10 = 0.8
}
# Total: 7.25
result = judge._calculate_weighted_score(scores)
assert result == 7.25
def test_calculate_with_partial_scores(self, judge):
"""Test weighted score with only some criteria."""
scores = {
"factual_accuracy": 8, # weight 0.25
"completeness": 6, # weight 0.20
}
# (8 * 0.25 + 6 * 0.20) / (0.25 + 0.20) = 3.2 / 0.45 = 7.11
result = judge._calculate_weighted_score(scores)
assert abs(result - 7.11) < 0.01
def test_calculate_with_unknown_criteria(self, judge):
"""Test that unknown criteria are ignored."""
scores = {
"factual_accuracy": 10,
"unknown_criterion": 1, # Should be ignored
}
result = judge._calculate_weighted_score(scores)
assert result == 10.0
def test_calculate_with_empty_scores(self, judge):
"""Test with empty scores dict."""
result = judge._calculate_weighted_score({})
assert result == 0.0
def test_weights_sum_to_one(self):
"""Verify that all criteria weights sum to 1.0."""
total_weight = sum(c["weight"] for c in EVALUATION_CRITERIA.values())
assert abs(total_weight - 1.0) < 0.001
class TestLLMJudgeEvaluate:
"""Tests for LLMJudge.evaluate method with mocked LLM."""
@pytest.fixture
def valid_llm_response(self):
"""Create a valid LLM response."""
return json.dumps(
{
"scores": {
"factual_accuracy": 8,
"completeness": 7,
"coherence": 9,
"relevance": 8,
"citation_quality": 7,
"writing_quality": 8,
},
"overall_score": 8,
"strengths": ["Comprehensive coverage", "Well structured"],
"weaknesses": ["Some claims need more support"],
"suggestions": ["Add more academic sources"],
}
)
@pytest.mark.asyncio
async def test_successful_evaluation(self, valid_llm_response):
"""Test successful LLM evaluation."""
mock_llm = AsyncMock()
mock_response = MagicMock()
mock_response.content = valid_llm_response
mock_llm.ainvoke.return_value = mock_response
judge = LLMJudge(llm=mock_llm)
result = await judge.evaluate("Test report", "Test query")
assert isinstance(result, EvaluationResult)
assert result.scores["factual_accuracy"] == 8
assert result.overall_score == 8
assert result.weighted_score > 0
assert "Comprehensive coverage" in result.strengths
assert result.raw_response == valid_llm_response
@pytest.mark.asyncio
async def test_evaluation_with_llm_failure(self):
"""Test that LLM failures are handled gracefully."""
mock_llm = AsyncMock()
mock_llm.ainvoke.side_effect = Exception("LLM service unavailable")
judge = LLMJudge(llm=mock_llm)
result = await judge.evaluate("Test report", "Test query")
assert isinstance(result, EvaluationResult)
assert result.overall_score == 0
assert result.weighted_score == 0
assert all(score == 0 for score in result.scores.values())
assert any("failed" in w.lower() for w in result.weaknesses)
@pytest.mark.asyncio
async def test_evaluation_with_malformed_response(self):
"""Test handling of malformed LLM response."""
mock_llm = AsyncMock()
mock_response = MagicMock()
mock_response.content = "I cannot evaluate this report properly."
mock_llm.ainvoke.return_value = mock_response
judge = LLMJudge(llm=mock_llm)
result = await judge.evaluate("Test report", "Test query")
# Should return default scores
assert result.scores["factual_accuracy"] == 5
assert result.overall_score == 5
@pytest.mark.asyncio
async def test_evaluation_passes_report_style(self):
"""Test that report_style is passed to LLM."""
mock_llm = AsyncMock()
mock_response = MagicMock()
mock_response.content = json.dumps(
{
"scores": {k: 7 for k in EVALUATION_CRITERIA.keys()},
"overall_score": 7,
"strengths": [],
"weaknesses": [],
"suggestions": [],
}
)
mock_llm.ainvoke.return_value = mock_response
judge = LLMJudge(llm=mock_llm)
await judge.evaluate("Test report", "Test query", report_style="academic")
# Verify the prompt contains the report style
call_args = mock_llm.ainvoke.call_args
messages = call_args[0][0]
user_message_content = messages[1].content
assert "academic" in user_message_content
@pytest.mark.asyncio
async def test_evaluation_truncates_long_reports(self):
"""Test that very long reports are truncated."""
mock_llm = AsyncMock()
mock_response = MagicMock()
mock_response.content = json.dumps(
{
"scores": {k: 7 for k in EVALUATION_CRITERIA.keys()},
"overall_score": 7,
"strengths": [],
"weaknesses": [],
"suggestions": [],
}
)
mock_llm.ainvoke.return_value = mock_response
judge = LLMJudge(llm=mock_llm)
long_report = "x" * (MAX_REPORT_LENGTH + 5000)
await judge.evaluate(long_report, "Test query")
call_args = mock_llm.ainvoke.call_args
messages = call_args[0][0]
user_message_content = messages[1].content
# The report content in the message should be truncated to MAX_REPORT_LENGTH
assert len(user_message_content) < len(long_report) + 500
class TestEvaluationResult:
"""Tests for EvaluationResult dataclass."""
def test_to_dict(self):
"""Test EvaluationResult.to_dict method."""
result = EvaluationResult(
scores={"factual_accuracy": 8, "completeness": 7},
overall_score=7.5,
weighted_score=7.6,
strengths=["Good research"],
weaknesses=["Needs more detail"],
suggestions=["Expand section 2"],
raw_response="test response",
)
d = result.to_dict()
assert d["scores"]["factual_accuracy"] == 8
assert d["overall_score"] == 7.5
assert d["weighted_score"] == 7.6
assert "Good research" in d["strengths"]
# raw_response should not be in dict
assert "raw_response" not in d
-207
View File
@@ -1,207 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""Unit tests for report evaluation metrics."""
from src.eval.metrics import (
compute_metrics,
count_citations,
count_images,
count_words,
detect_sections,
extract_domains,
get_word_count_target,
)
class TestCountWords:
"""Tests for word counting function."""
def test_english_words(self):
text = "This is a simple test sentence."
assert count_words(text) == 6
def test_chinese_characters(self):
text = "这是一个测试"
assert count_words(text) == 6
def test_mixed_content(self):
text = "Hello 你好 World 世界"
assert count_words(text) == 4 + 2 # 2 English + 4 Chinese
def test_empty_string(self):
assert count_words("") == 0
class TestCountCitations:
"""Tests for citation counting function."""
def test_markdown_citations(self):
text = """
Check out [Google](https://google.com) and [GitHub](https://github.com).
"""
assert count_citations(text) == 2
def test_no_citations(self):
text = "This is plain text without any links."
assert count_citations(text) == 0
def test_invalid_urls(self):
text = "[Link](not-a-url) [Another](ftp://ftp.example.com)"
assert count_citations(text) == 0
def test_complex_urls(self):
text = "[Article](https://example.com/path/to/article?id=123&ref=test)"
assert count_citations(text) == 1
class TestExtractDomains:
"""Tests for domain extraction function."""
def test_extract_multiple_domains(self):
text = """
https://google.com/search
https://www.github.com/user/repo
https://docs.python.org/3/
"""
domains = extract_domains(text)
assert len(domains) == 3
assert "google.com" in domains
assert "github.com" in domains
assert "docs.python.org" in domains
def test_deduplicate_domains(self):
text = """
https://example.com/page1
https://example.com/page2
https://www.example.com/page3
"""
domains = extract_domains(text)
assert len(domains) == 1
assert "example.com" in domains
def test_no_urls(self):
text = "Plain text without URLs"
assert extract_domains(text) == []
class TestCountImages:
"""Tests for image counting function."""
def test_markdown_images(self):
text = """
![Alt text](https://example.com/image1.png)
![](https://example.com/image2.jpg)
"""
assert count_images(text) == 2
def test_no_images(self):
text = "Text without images [link](url)"
assert count_images(text) == 0
class TestDetectSections:
"""Tests for section detection function."""
def test_detect_title(self):
text = "# My Report Title\n\nSome content here."
sections = detect_sections(text)
assert sections.get("title") is True
def test_detect_key_points(self):
text = "## Key Points\n- Point 1\n- Point 2"
sections = detect_sections(text)
assert sections.get("key_points") is True
def test_detect_chinese_sections(self):
text = """# 报告标题
## 要点
- 要点1
## 概述
这是概述内容
"""
sections = detect_sections(text)
assert sections.get("title") is True
assert sections.get("key_points") is True
assert sections.get("overview") is True
def test_detect_citations_section(self):
text = """
## Key Citations
- [Source 1](https://example.com)
"""
sections = detect_sections(text)
assert sections.get("key_citations") is True
class TestComputeMetrics:
"""Tests for the main compute_metrics function."""
def test_complete_report(self):
report = """
# Research Report Title
## Key Points
- Point 1
- Point 2
- Point 3
## Overview
This is an overview of the research topic.
## Detailed Analysis
Here is the detailed analysis with [source](https://example.com).
![Figure 1](https://example.com/image.png)
## Key Citations
- [Source 1](https://example.com)
- [Source 2](https://another.com)
"""
metrics = compute_metrics(report)
assert metrics.has_title is True
assert metrics.has_key_points is True
assert metrics.has_overview is True
assert metrics.has_citations_section is True
assert metrics.citation_count >= 2
assert metrics.image_count == 1
assert metrics.unique_sources >= 1
assert metrics.section_coverage_score > 0.5
def test_minimal_report(self):
report = "Just some text without structure."
metrics = compute_metrics(report)
assert metrics.has_title is False
assert metrics.citation_count == 0
assert metrics.section_coverage_score < 0.5
def test_metrics_to_dict(self):
report = "# Title\n\nSome content"
metrics = compute_metrics(report)
result = metrics.to_dict()
assert isinstance(result, dict)
assert "word_count" in result
assert "citation_count" in result
assert "section_coverage_score" in result
class TestGetWordCountTarget:
"""Tests for word count target function."""
def test_strategic_investment_target(self):
target = get_word_count_target("strategic_investment")
assert target["min"] == 10000
assert target["max"] == 15000
def test_news_target(self):
target = get_word_count_target("news")
assert target["min"] == 800
assert target["max"] == 2000
def test_default_target(self):
target = get_word_count_target("unknown_style")
assert target["min"] == 1000
assert target["max"] == 5000
@@ -1,241 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for agent locale restoration after agent execution.
Tests that meta fields (especially locale) are properly restored after
agent.ainvoke() returns, since the agent creates a MessagesState
subgraph that filters out custom fields.
"""
import pytest
from src.graph.nodes import preserve_state_meta_fields
from src.graph.types import State
class TestAgentLocaleRestoration:
"""Test suite for locale restoration after agent execution."""
def test_locale_lost_in_agent_subgraph(self):
"""
Demonstrate the problem: agent subgraph filters out locale.
When the agent creates a subgraph with MessagesState,
it only returns messages, not custom fields.
"""
# Simulate agent behavior: only returns messages
initial_state = State(messages=[], locale="zh-CN")
# Agent subgraph returns (like MessagesState would)
agent_result = {
"messages": ["agent response"],
}
# Problem: locale is missing
assert "locale" not in agent_result
assert agent_result.get("locale") is None
def test_locale_restoration_after_agent(self):
"""Test that locale can be restored after agent.ainvoke() returns."""
initial_state = State(
messages=[],
locale="zh-CN",
research_topic="test",
)
# Simulate agent returning (MessagesState only)
agent_result = {
"messages": ["agent response"],
}
# Apply restoration
preserved = preserve_state_meta_fields(initial_state)
agent_result.update(preserved)
# Verify restoration worked
assert agent_result["locale"] == "zh-CN"
assert agent_result["research_topic"] == "test"
assert "messages" in agent_result
def test_all_meta_fields_restored(self):
"""Test that all meta fields are restored, not just locale."""
initial_state = State(
messages=[],
locale="en-US",
research_topic="Original Topic",
clarified_research_topic="Clarified Topic",
clarification_history=["Q1", "A1"],
enable_clarification=True,
max_clarification_rounds=5,
clarification_rounds=2,
resources=["resource1"],
)
# Agent result
agent_result = {"messages": ["response"]}
agent_result.update(preserve_state_meta_fields(initial_state))
# All fields should be restored
assert agent_result["locale"] == "en-US"
assert agent_result["research_topic"] == "Original Topic"
assert agent_result["clarified_research_topic"] == "Clarified Topic"
assert agent_result["clarification_history"] == ["Q1", "A1"]
assert agent_result["enable_clarification"] is True
assert agent_result["max_clarification_rounds"] == 5
assert agent_result["clarification_rounds"] == 2
assert agent_result["resources"] == ["resource1"]
def test_locale_preservation_through_agent_cycle(self):
"""Test the complete cycle: state in → agent → state out."""
# Initial state with zh-CN locale
initial_state = State(messages=[], locale="zh-CN")
# Step 1: Extract meta fields
preserved = preserve_state_meta_fields(initial_state)
assert preserved["locale"] == "zh-CN"
# Step 2: Agent runs and returns only messages
agent_result = {"messages": ["agent output"]}
assert "locale" not in agent_result # Missing!
# Step 3: Restore meta fields
agent_result.update(preserved)
# Step 4: Verify locale is restored
assert agent_result["locale"] == "zh-CN"
# Step 5: Create new state with restored fields
final_state = State(messages=agent_result["messages"], **preserved)
assert final_state.get("locale") == "zh-CN"
def test_locale_not_auto_after_restoration(self):
"""
Test that locale is NOT "auto" after restoration.
This tests the specific bug: locale was becoming "auto"
instead of the preserved "zh-CN" value.
"""
state = State(messages=[], locale="zh-CN")
# Agent returns without locale
agent_result = {"messages": []}
# Before fix: locale would be "auto" (default behavior)
# After fix: locale is preserved
agent_result.update(preserve_state_meta_fields(state))
assert agent_result.get("locale") == "zh-CN"
assert agent_result.get("locale") != "auto"
assert agent_result.get("locale") is not None
def test_chinese_locale_preserved(self):
"""Test that Chinese locale specifically is preserved."""
locales_to_test = ["zh-CN", "zh", "zh-Hans", "zh-Hant"]
for locale_value in locales_to_test:
state = State(messages=[], locale=locale_value)
agent_result = {"messages": []}
agent_result.update(preserve_state_meta_fields(state))
assert agent_result["locale"] == locale_value, f"Failed for locale: {locale_value}"
def test_restoration_with_new_messages(self):
"""Test that restoration works even when agent adds new messages."""
state = State(messages=[], locale="zh-CN", research_topic="research")
# Agent processes and returns new messages
agent_result = {
"messages": ["message1", "message2", "message3"],
}
# Restore meta fields
agent_result.update(preserve_state_meta_fields(state))
# Should have both new messages AND preserved meta fields
assert len(agent_result["messages"]) == 3
assert agent_result["locale"] == "zh-CN"
assert agent_result["research_topic"] == "research"
def test_restoration_idempotent(self):
"""Test that restoring meta fields multiple times doesn't cause issues."""
state = State(messages=[], locale="en-US")
preserved = preserve_state_meta_fields(state)
agent_result = {"messages": []}
# Apply restoration multiple times
agent_result.update(preserved)
agent_result.update(preserved)
agent_result.update(preserved)
# Should still have correct locale (not corrupted)
assert agent_result["locale"] == "en-US"
assert len(agent_result) == 9 # messages + 8 meta fields
class TestAgentLocaleRestorationScenarios:
"""Real-world scenario tests for agent locale restoration."""
def test_researcher_agent_preserves_locale(self):
"""
Simulate researcher agent execution preserving locale.
Scenario:
1. Researcher node receives state with locale="zh-CN"
2. Calls agent.ainvoke() which returns only messages
3. Restores locale before returning
"""
# State coming into researcher node
state = State(
messages=[],
locale="zh-CN",
research_topic="生产1公斤牛肉需要多少升水?",
)
# Agent executes and returns
agent_result = {
"messages": ["Researcher analysis of water usage..."],
}
# Apply restoration (what the fix does)
agent_result.update(preserve_state_meta_fields(state))
# Verify for next node
assert agent_result["locale"] == "zh-CN" # ✓ Preserved!
assert agent_result.get("locale") != "auto" # ✓ Not "auto"
def test_coder_agent_preserves_locale(self):
"""Coder agent should also preserve locale."""
state = State(messages=[], locale="en-US")
agent_result = {"messages": ["Code generation result"]}
agent_result.update(preserve_state_meta_fields(state))
assert agent_result["locale"] == "en-US"
def test_locale_persists_across_multiple_agents(self):
"""Test that locale persists through multiple agent calls."""
locales = ["zh-CN", "en-US", "fr-FR"]
for locale in locales:
# Initial state
state = State(messages=[], locale=locale)
preserved_1 = preserve_state_meta_fields(state)
# First agent
result_1 = {"messages": ["agent1"]}
result_1.update(preserved_1)
# Create state for second agent
state_2 = State(messages=result_1["messages"], **preserved_1)
preserved_2 = preserve_state_meta_fields(state_2)
# Second agent
result_2 = {"messages": result_1["messages"] + ["agent2"]}
result_2.update(preserved_2)
# Locale should persist
assert result_2["locale"] == locale
-134
View File
@@ -1,134 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import importlib
import sys
from unittest.mock import MagicMock, patch
import pytest
import src.graph.builder as builder_mod
@pytest.fixture
def mock_state():
class Step:
def __init__(self, execution_res=None, step_type=None):
self.execution_res = execution_res
self.step_type = step_type
class Plan:
def __init__(self, steps):
self.steps = steps
return {
"Step": Step,
"Plan": Plan,
}
def test_continue_to_running_research_team_no_plan(mock_state):
state = {"current_plan": None}
assert builder_mod.continue_to_running_research_team(state) == "planner"
def test_continue_to_running_research_team_no_steps(mock_state):
state = {"current_plan": mock_state["Plan"](steps=[])}
assert builder_mod.continue_to_running_research_team(state) == "planner"
def test_continue_to_running_research_team_all_executed(mock_state):
Step = mock_state["Step"]
Plan = mock_state["Plan"]
steps = [Step(execution_res=True), Step(execution_res=True)]
state = {"current_plan": Plan(steps=steps)}
assert builder_mod.continue_to_running_research_team(state) == "planner"
def test_continue_to_running_research_team_next_researcher(mock_state):
Step = mock_state["Step"]
Plan = mock_state["Plan"]
steps = [
Step(execution_res=True),
Step(execution_res=None, step_type=builder_mod.StepType.RESEARCH),
]
state = {"current_plan": Plan(steps=steps)}
assert builder_mod.continue_to_running_research_team(state) == "researcher"
def test_continue_to_running_research_team_next_coder(mock_state):
Step = mock_state["Step"]
Plan = mock_state["Plan"]
steps = [
Step(execution_res=True),
Step(execution_res=None, step_type=builder_mod.StepType.PROCESSING),
]
state = {"current_plan": Plan(steps=steps)}
assert builder_mod.continue_to_running_research_team(state) == "coder"
def test_continue_to_running_research_team_next_coder_withresult(mock_state):
Step = mock_state["Step"]
Plan = mock_state["Plan"]
steps = [
Step(execution_res=True),
Step(execution_res=True, step_type=builder_mod.StepType.PROCESSING),
]
state = {"current_plan": Plan(steps=steps)}
assert builder_mod.continue_to_running_research_team(state) == "planner"
def test_continue_to_running_research_team_default_planner(mock_state):
Step = mock_state["Step"]
Plan = mock_state["Plan"]
steps = [Step(execution_res=True), Step(execution_res=None, step_type=None)]
state = {"current_plan": Plan(steps=steps)}
assert builder_mod.continue_to_running_research_team(state) == "planner"
@patch("src.graph.builder.StateGraph")
def test_build_base_graph_adds_nodes_and_edges(MockStateGraph):
mock_builder = MagicMock()
MockStateGraph.return_value = mock_builder
builder_mod._build_base_graph()
# Check that all nodes and edges are added
assert mock_builder.add_edge.call_count >= 2
assert mock_builder.add_node.call_count >= 8
# Now we have 1 conditional edges: research_team
assert mock_builder.add_conditional_edges.call_count == 1
@patch("src.graph.builder._build_base_graph")
@patch("src.graph.builder.MemorySaver")
def test_build_graph_with_memory_uses_memory(MockMemorySaver, mock_build_base_graph):
mock_builder = MagicMock()
mock_build_base_graph.return_value = mock_builder
mock_memory = MagicMock()
MockMemorySaver.return_value = mock_memory
builder_mod.build_graph_with_memory()
mock_builder.compile.assert_called_once_with(checkpointer=mock_memory)
@patch("src.graph.builder._build_base_graph")
def test_build_graph_without_memory(mock_build_base_graph):
mock_builder = MagicMock()
mock_build_base_graph.return_value = mock_builder
builder_mod.build_graph()
mock_builder.compile.assert_called_once_with()
def test_graph_is_compiled():
# The graph object should be the result of build_graph()
with patch("src.graph.builder._build_base_graph") as mock_base:
mock_builder = MagicMock()
mock_base.return_value = mock_builder
mock_builder.compile.return_value = "compiled_graph"
# reload the module to re-run the graph assignment
importlib.reload(sys.modules["src.graph.builder"])
assert builder_mod.graph is not None
@@ -1,317 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for the human_feedback_node locale fix.
Tests that the duplicate locale assignment issue is resolved:
- Locale is safely retrieved from new_plan using .get() with fallback
- If new_plan['locale'] doesn't exist, it doesn't cause a KeyError
- If new_plan['locale'] is None or empty, the preserved state locale is used
- If new_plan['locale'] has a valid value, it properly overrides the state locale
"""
import pytest
from src.graph.nodes import preserve_state_meta_fields
from src.graph.types import State
from src.prompts.planner_model import Plan
class TestHumanFeedbackLocaleFixture:
"""Test suite for human_feedback_node locale safe handling."""
def test_preserve_state_meta_fields_no_keyerror(self):
"""Test that preserve_state_meta_fields never raises KeyError."""
state = State(messages=[], locale="zh-CN")
preserved = preserve_state_meta_fields(state)
assert preserved["locale"] == "zh-CN"
assert "locale" in preserved
def test_new_plan_without_locale_override(self):
"""
Test scenario: Plan doesn't override locale when not provided in override dict.
Before fix: Would set locale twice (duplicate assignment)
After fix: Uses .get() safely and only overrides if value is truthy
"""
state = State(messages=[], locale="zh-CN")
# Simulate a plan that doesn't want to override the locale
# (locale is in the plan for validation, but not in override dict)
new_plan_dict = {"title": "Test", "thought": "Test", "steps": [], "locale": "zh-CN", "has_enough_context": False}
# Get preserved fields
preserved = preserve_state_meta_fields(state)
# Build update dict like the fixed code does
update_dict = {
"current_plan": Plan.model_validate(new_plan_dict),
**preserved,
}
# Simulate a dict that doesn't have locale override (like when plan_dict is empty for override)
plan_override = {} # No locale in override dict
# Only override locale if override dict provides a valid value
if plan_override.get("locale"):
update_dict["locale"] = plan_override["locale"]
# The preserved locale should be used when override doesn't provide one
assert update_dict["locale"] == "zh-CN"
def test_new_plan_with_none_locale(self):
"""
Test scenario: new_plan has locale=None.
Before fix: Would try to set locale to None (but Plan requires it)
After fix: Uses preserved state locale since new_plan.get("locale") is falsy
"""
state = State(messages=[], locale="zh-CN")
# new_plan with None locale (won't validate, but test the logic)
new_plan_attempt = {"title": "Test", "thought": "Test", "steps": [], "locale": "en-US", "has_enough_context": False}
# Get preserved fields
preserved = preserve_state_meta_fields(state)
# Build update dict like the fixed code does
update_dict = {
"current_plan": Plan.model_validate(new_plan_attempt),
**preserved,
}
# Simulate checking for None locale (if it somehow got set)
new_plan_with_none = {"locale": None}
# Only override if new_plan provides a VALID value
if new_plan_with_none.get("locale"):
update_dict["locale"] = new_plan_with_none["locale"]
# Should use preserved locale (zh-CN), not None
assert update_dict["locale"] == "zh-CN"
assert update_dict["locale"] is not None
def test_new_plan_with_empty_string_locale(self):
"""
Test scenario: new_plan has locale="" (empty string).
Before fix: Would try to set locale to "" (but Plan requires valid value)
After fix: Uses preserved state locale since empty string is falsy
"""
state = State(messages=[], locale="zh-CN")
# new_plan with valid locale
new_plan = {"title": "Test", "thought": "Test", "steps": [], "locale": "en-US", "has_enough_context": False}
# Get preserved fields
preserved = preserve_state_meta_fields(state)
# Build update dict like the fixed code does
update_dict = {
"current_plan": Plan.model_validate(new_plan),
**preserved,
}
# Simulate checking for empty string locale
new_plan_empty = {"locale": ""}
# Only override if new_plan provides a VALID (truthy) value
if new_plan_empty.get("locale"):
update_dict["locale"] = new_plan_empty["locale"]
# Should use preserved locale (zh-CN), not empty string
assert update_dict["locale"] == "zh-CN"
assert update_dict["locale"] != ""
def test_new_plan_with_valid_locale_overrides(self):
"""
Test scenario: new_plan has valid locale="en-US".
Before fix: Would override with new_plan locale ✓ (worked)
After fix: Should still properly override with valid locale
"""
state = State(messages=[], locale="zh-CN")
# new_plan has a different valid locale
new_plan = {"title": "Test", "thought": "Test", "steps": [], "locale": "en-US", "has_enough_context": False}
# Get preserved fields
preserved = preserve_state_meta_fields(state)
# Build update dict like the fixed code does
update_dict = {
"current_plan": Plan.model_validate(new_plan),
**preserved,
}
# Override if new_plan provides a VALID value
if new_plan.get("locale"):
update_dict["locale"] = new_plan["locale"]
# Should override with new_plan locale
assert update_dict["locale"] == "en-US"
assert update_dict["locale"] != "zh-CN"
def test_locale_field_not_duplicated(self):
"""
Test that locale field is not duplicated in the update dict.
Before fix: locale was set twice in the same dict
After fix: locale is only set once
"""
state = State(messages=[], locale="zh-CN")
new_plan = {"title": "Test", "thought": "Test", "steps": [], "locale": "en-US", "has_enough_context": False}
preserved = preserve_state_meta_fields(state)
# Count how many times 'locale' is set
update_dict = {
"current_plan": Plan.model_validate(new_plan),
**preserved, # Sets locale once
}
# Override locale only if new_plan provides valid value
if new_plan.get("locale"):
update_dict["locale"] = new_plan["locale"]
# Verify locale is in dict exactly once
locale_count = sum(1 for k in update_dict.keys() if k == "locale")
assert locale_count == 1
assert update_dict["locale"] == "en-US" # Should be overridden
def test_all_meta_fields_preserved(self):
"""
Test that all 8 meta fields are preserved along with locale fix.
Ensures the fix doesn't break other meta field preservation.
"""
state = State(
messages=[],
locale="zh-CN",
research_topic="Research",
clarified_research_topic="Clarified",
clarification_history=["Q1"],
enable_clarification=True,
max_clarification_rounds=5,
clarification_rounds=1,
resources=["resource1"],
)
new_plan = {"title": "Test", "thought": "Test", "steps": [], "locale": "en-US", "has_enough_context": False}
preserved = preserve_state_meta_fields(state)
# All 8 meta fields should be in preserved
meta_fields = [
"locale",
"research_topic",
"clarified_research_topic",
"clarification_history",
"enable_clarification",
"max_clarification_rounds",
"clarification_rounds",
"resources",
]
for field in meta_fields:
assert field in preserved
# Build update dict
update_dict = {
"current_plan": Plan.model_validate(new_plan),
**preserved,
}
# Override locale if new_plan provides valid value
if new_plan.get("locale"):
update_dict["locale"] = new_plan["locale"]
# All meta fields should be in update_dict
for field in meta_fields:
assert field in update_dict
class TestHumanFeedbackLocaleScenarios:
"""Real-world scenarios for human_feedback_node locale handling."""
def test_scenario_chinese_locale_preserved_when_plan_has_no_locale(self):
"""
Scenario: User selected Chinese, plan preserves it.
Expected: Preserved Chinese locale should be used
"""
state = State(messages=[], locale="zh-CN")
# Plan from planner with required fields
new_plan_json = {
"title": "Research Plan",
"thought": "...",
"steps": [
{
"title": "Step 1",
"description": "...",
"need_search": True,
"step_type": "research",
}
],
"locale": "zh-CN",
"has_enough_context": False,
}
preserved = preserve_state_meta_fields(state)
update_dict = {
"current_plan": Plan.model_validate(new_plan_json),
**preserved,
}
if new_plan_json.get("locale"):
update_dict["locale"] = new_plan_json["locale"]
# Chinese locale should be preserved
assert update_dict["locale"] == "zh-CN"
def test_scenario_en_us_restored_even_if_plan_minimal(self):
"""
Scenario: Minimal plan with en-US locale.
Expected: Preserved en-US locale should survive
"""
state = State(messages=[], locale="en-US")
# Minimal plan with required fields
new_plan_json = {"title": "Quick Plan", "steps": [], "locale": "en-US", "has_enough_context": False}
preserved = preserve_state_meta_fields(state)
update_dict = {
"current_plan": Plan.model_validate(new_plan_json),
**preserved,
}
if new_plan_json.get("locale"):
update_dict["locale"] = new_plan_json["locale"]
# en-US should survive
assert update_dict["locale"] == "en-US"
def test_scenario_multiple_locale_updates_safe(self):
"""
Scenario: Multiple plan iterations with locale preservation.
Expected: Each iteration safely handles locale
"""
locales = ["zh-CN", "en-US", "fr-FR"]
for locale in locales:
state = State(messages=[], locale=locale)
new_plan = {"title": "Plan", "steps": [], "locale": locale, "has_enough_context": False}
preserved = preserve_state_meta_fields(state)
update_dict = {
"current_plan": Plan.model_validate(new_plan),
**preserved,
}
if new_plan.get("locale"):
update_dict["locale"] = new_plan["locale"]
# Each iteration should preserve its locale
assert update_dict["locale"] == locale
@@ -1,623 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for recursion limit fallback functionality in graph nodes.
Tests the graceful fallback behavior when agents hit the recursion limit,
including the _handle_recursion_limit_fallback function and the
enable_recursion_fallback configuration option.
"""
from unittest.mock import MagicMock, patch
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from src.config.configuration import Configuration
from src.graph.nodes import _handle_recursion_limit_fallback
from src.graph.types import State
class TestHandleRecursionLimitFallback:
"""Test suite for _handle_recursion_limit_fallback() function."""
@pytest.mark.asyncio
async def test_fallback_generates_summary_from_observations(self):
"""Test that fallback generates summary using accumulated agent messages."""
from langchain_core.messages import ToolCall
# Create test state with messages
state = State(
messages=[
HumanMessage(content="Research topic: AI safety")
],
locale="en-US",
)
# Mock current step
current_step = MagicMock()
current_step.execution_res = None
# Mock partial agent messages (accumulated during execution before hitting limit)
tool_call = ToolCall(
name="web_search",
args={"query": "AI safety"},
id="123"
)
partial_agent_messages = [
HumanMessage(content="# Research Topic\n\nAI safety\n\n# Current Step\n\n## Title\n\nAnalyze AI safety"),
AIMessage(content="", tool_calls=[tool_call]),
HumanMessage(content="Tool result: Found 5 articles about AI safety"),
]
# Mock the LLM response
mock_llm_response = MagicMock()
mock_llm_response.content = "# Summary\n\nBased on the research, AI safety is important."
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
patch("src.graph.nodes.get_system_prompt_template") as mock_get_system_prompt, \
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
mock_llm = MagicMock()
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
mock_get_llm.return_value = mock_llm
mock_get_system_prompt.return_value = "Fallback instructions"
# Call the fallback function
result = await _handle_recursion_limit_fallback(
messages=partial_agent_messages,
agent_name="researcher",
current_step=current_step,
state=state,
)
# Verify result is a list
assert isinstance(result, list)
# Verify step execution result was set
assert current_step.execution_res == mock_llm_response.content
# Verify messages include partial agent messages and the AI response
# Should have partial messages + 1 new AI response
assert len(result) == len(partial_agent_messages) + 1
# Last message should be the fallback AI response
assert isinstance(result[-1], AIMessage)
assert result[-1].content == mock_llm_response.content
assert result[-1].name == "researcher"
# First messages should be from partial_agent_messages
assert result[0] == partial_agent_messages[0]
assert result[1] == partial_agent_messages[1]
assert result[2] == partial_agent_messages[2]
@pytest.mark.asyncio
async def test_fallback_applies_prompt_template(self):
"""Test that fallback applies the recursion_fallback prompt template."""
state = State(messages=[], locale="zh-CN")
current_step = MagicMock()
# Create non-empty messages to avoid early return
partial_agent_messages = [HumanMessage(content="Test")]
mock_llm_response = MagicMock()
mock_llm_response.content = "Summary in Chinese"
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
patch("src.graph.nodes.get_system_prompt_template") as mock_get_system_prompt, \
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
mock_llm = MagicMock()
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
mock_get_llm.return_value = mock_llm
mock_get_system_prompt.return_value = "Template rendered"
await _handle_recursion_limit_fallback(
messages=partial_agent_messages,
agent_name="researcher",
current_step=current_step,
state=state,
)
# Verify get_system_prompt_template was called with correct arguments
assert mock_get_system_prompt.call_count == 2 # Called twice (once for agent, once for fallback)
# Check the first call (for agent prompt)
first_call = mock_get_system_prompt.call_args_list[0]
assert first_call[0][0] == "researcher" # agent_name
assert first_call[0][1]["locale"] == "zh-CN" # locale in state
# Check the second call (for recursion_fallback prompt)
second_call = mock_get_system_prompt.call_args_list[1]
assert second_call[0][0] == "recursion_fallback" # prompt_name
assert second_call[0][1]["locale"] == "zh-CN" # locale in state
@pytest.mark.asyncio
async def test_fallback_gets_llm_without_tools(self):
"""Test that fallback gets LLM without tools bound."""
state = State(messages=[], locale="en-US")
current_step = MagicMock()
partial_agent_messages = []
mock_llm_response = MagicMock()
mock_llm_response.content = "Summary"
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
patch("src.graph.nodes.get_system_prompt_template", return_value="Template"), \
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
mock_llm = MagicMock()
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
mock_get_llm.return_value = mock_llm
result = await _handle_recursion_limit_fallback(
messages=partial_agent_messages,
agent_name="coder",
current_step=current_step,
state=state,
)
# With empty messages, should return empty list
assert result == []
# Verify get_llm_by_type was NOT called (empty messages return early)
mock_get_llm.assert_not_called()
@pytest.mark.asyncio
async def test_fallback_sanitizes_response(self):
"""Test that fallback response is sanitized."""
state = State(messages=[], locale="en-US")
current_step = MagicMock()
# Create test messages (not empty)
partial_agent_messages = [HumanMessage(content="Test")]
# Mock unsanitized response with extra tokens
mock_llm_response = MagicMock()
mock_llm_response.content = "<extra_tokens>Summary content</extra_tokens>"
sanitized_content = "Summary content"
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
patch("src.graph.nodes.get_system_prompt_template", return_value=""), \
patch("src.graph.nodes.sanitize_tool_response", return_value=sanitized_content):
mock_llm = MagicMock()
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
mock_get_llm.return_value = mock_llm
result = await _handle_recursion_limit_fallback(
messages=partial_agent_messages,
agent_name="researcher",
current_step=current_step,
state=state,
)
# Verify sanitized content was used
assert result[-1].content == sanitized_content
assert current_step.execution_res == sanitized_content
@pytest.mark.asyncio
async def test_fallback_preserves_meta_fields(self):
"""Test that fallback uses state locale correctly."""
state = State(
messages=[],
locale="zh-CN",
research_topic="原始研究主题",
clarification_rounds=2,
)
current_step = MagicMock()
# Create test messages (not empty)
partial_agent_messages = [HumanMessage(content="Test")]
mock_llm_response = MagicMock()
mock_llm_response.content = "Summary"
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
patch("src.graph.nodes.get_system_prompt_template") as mock_get_system_prompt, \
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
mock_llm = MagicMock()
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
mock_get_llm.return_value = mock_llm
mock_get_system_prompt.return_value = "Template"
await _handle_recursion_limit_fallback(
messages=partial_agent_messages,
agent_name="researcher",
current_step=current_step,
state=state,
)
# Verify locale was passed to template
call_args = mock_get_system_prompt.call_args
assert call_args[0][1]["locale"] == "zh-CN"
@pytest.mark.asyncio
async def test_fallback_raises_on_llm_failure(self):
"""Test that fallback raises exception when LLM call fails."""
state = State(messages=[], locale="en-US")
current_step = MagicMock()
# Create test messages (not empty)
partial_agent_messages = [HumanMessage(content="Test")]
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
patch("src.graph.nodes.get_system_prompt_template", return_value=""):
mock_llm = MagicMock()
mock_llm.invoke = MagicMock(side_effect=Exception("LLM API error"))
mock_get_llm.return_value = mock_llm
# Should raise the exception
with pytest.raises(Exception, match="LLM API error"):
await _handle_recursion_limit_fallback(
messages=partial_agent_messages,
agent_name="researcher",
current_step=current_step,
state=state,
)
@pytest.mark.asyncio
async def test_fallback_handles_different_agent_types(self):
"""Test that fallback works with different agent types."""
state = State(messages=[], locale="en-US")
# Create test messages (not empty)
partial_agent_messages = [HumanMessage(content="Test")]
mock_llm_response = MagicMock()
mock_llm_response.content = "Agent summary"
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
patch("src.graph.nodes.get_system_prompt_template", return_value=""), \
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
mock_llm = MagicMock()
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
mock_get_llm.return_value = mock_llm
for agent_name in ["researcher", "coder", "analyst"]:
current_step = MagicMock()
result = await _handle_recursion_limit_fallback(
messages=partial_agent_messages,
agent_name=agent_name,
current_step=current_step,
state=state,
)
# Verify agent name is set correctly
assert result[-1].name == agent_name
@pytest.mark.asyncio
async def test_fallback_uses_partial_agent_messages(self):
"""Test that fallback includes partial agent messages in result."""
state = State(messages=[], locale="en-US")
current_step = MagicMock()
# Create partial agent messages with tool calls
# Use proper tool_call format
from langchain_core.messages import ToolCall
tool_call = ToolCall(
name="web_search",
args={"query": "test query"},
id="123"
)
partial_agent_messages = [
HumanMessage(content="Input message"),
AIMessage(content="", tool_calls=[tool_call]),
HumanMessage(content="Tool result: Search completed"),
]
mock_llm_response = MagicMock()
mock_llm_response.content = "Fallback summary"
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
patch("src.graph.nodes.get_system_prompt_template", return_value=""), \
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
mock_llm = MagicMock()
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
mock_get_llm.return_value = mock_llm
result = await _handle_recursion_limit_fallback(
messages=partial_agent_messages,
agent_name="researcher",
current_step=current_step,
state=state,
)
# Verify partial messages are in result
result_messages = result
assert len(result_messages) == len(partial_agent_messages) + 1
# First messages should be from partial_agent_messages
assert result_messages[0] == partial_agent_messages[0]
assert result_messages[1] == partial_agent_messages[1]
assert result_messages[2] == partial_agent_messages[2]
# Last message should be the fallback AI response
assert isinstance(result_messages[3], AIMessage)
assert result_messages[3].content == "Fallback summary"
@pytest.mark.asyncio
async def test_fallback_handles_empty_partial_messages(self):
"""Test that fallback handles empty partial_agent_messages."""
state = State(messages=[], locale="en-US")
current_step = MagicMock()
partial_agent_messages = [] # Empty
mock_llm_response = MagicMock()
mock_llm_response.content = "Fallback summary"
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
patch("src.graph.nodes.get_system_prompt_template", return_value=""), \
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
mock_llm = MagicMock()
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
mock_get_llm.return_value = mock_llm
result = await _handle_recursion_limit_fallback(
messages=partial_agent_messages,
agent_name="researcher",
current_step=current_step,
state=state,
)
# With empty messages, should return empty list (early return)
assert result == []
# Verify get_llm_by_type was NOT called (early return)
mock_get_llm.assert_not_called()
class TestRecursionFallbackConfiguration:
"""Test suite for enable_recursion_fallback configuration."""
def test_config_default_is_enabled(self):
"""Test that enable_recursion_fallback defaults to True."""
config = Configuration()
assert config.enable_recursion_fallback is True
def test_config_from_env_variable_true(self):
"""Test that enable_recursion_fallback can be set via environment variable."""
with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "true"}):
config = Configuration()
assert config.enable_recursion_fallback is True
def test_config_from_env_variable_false(self):
"""Test that enable_recursion_fallback can be disabled via environment variable.
NOTE: This test documents the current behavior. The Configuration.from_runnable_config
method has a known issue where it doesn't properly convert boolean strings like "false"
to boolean False. This test reflects the actual (buggy) behavior and should be updated
when the Configuration class is fixed to use get_bool_env for boolean fields.
"""
with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "false"}):
config = Configuration()
# Currently returns True due to Configuration class bug
# Should return False when using get_bool_env properly
assert config.enable_recursion_fallback is True # Actual behavior
def test_config_from_env_variable_1(self):
"""Test that '1' is treated as True for enable_recursion_fallback."""
with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "1"}):
config = Configuration()
assert config.enable_recursion_fallback is True
def test_config_from_env_variable_0(self):
"""Test that '0' is treated as False for enable_recursion_fallback.
NOTE: This test documents the current behavior. The Configuration class has a known
issue where string "0" is not properly converted to boolean False.
"""
with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "0"}):
config = Configuration()
# Currently returns True due to Configuration class bug
assert config.enable_recursion_fallback is True # Actual behavior
def test_config_from_env_variable_yes(self):
"""Test that 'yes' is treated as True for enable_recursion_fallback."""
with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "yes"}):
config = Configuration()
assert config.enable_recursion_fallback is True
def test_config_from_env_variable_no(self):
"""Test that 'no' is treated as False for enable_recursion_fallback.
NOTE: This test documents the current behavior. The Configuration class has a known
issue where string "no" is not properly converted to boolean False.
"""
with patch.dict("os.environ", {"ENABLE_RECURSION_FALLBACK": "no"}):
config = Configuration()
# Currently returns True due to Configuration class bug
assert config.enable_recursion_fallback is True # Actual behavior
def test_config_from_runnable_config(self):
"""Test that enable_recursion_fallback can be set via RunnableConfig."""
from langchain_core.runnables import RunnableConfig
# Test with False value
config_false = RunnableConfig(configurable={"enable_recursion_fallback": False})
configuration_false = Configuration.from_runnable_config(config_false)
assert configuration_false.enable_recursion_fallback is False
# Test with True value
config_true = RunnableConfig(configurable={"enable_recursion_fallback": True})
configuration_true = Configuration.from_runnable_config(config_true)
assert configuration_true.enable_recursion_fallback is True
def test_config_field_exists(self):
"""Test that enable_recursion_fallback field exists in Configuration."""
config = Configuration()
assert hasattr(config, "enable_recursion_fallback")
assert isinstance(config.enable_recursion_fallback, bool)
class TestRecursionFallbackIntegration:
"""Integration tests for recursion fallback in agent execution."""
@pytest.mark.asyncio
async def test_fallback_function_signature_returns_list(self):
"""Test that the fallback function returns a list of messages."""
from src.graph.nodes import _handle_recursion_limit_fallback
state = State(messages=[], locale="en-US")
current_step = MagicMock()
# Create non-empty messages to avoid early return
partial_agent_messages = [HumanMessage(content="Test")]
mock_llm_response = MagicMock()
mock_llm_response.content = "Summary"
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
patch("src.graph.nodes.get_system_prompt_template", return_value=""), \
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
mock_llm = MagicMock()
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
mock_get_llm.return_value = mock_llm
# This should not raise - just verify the function returns a list
result = await _handle_recursion_limit_fallback(
messages=partial_agent_messages,
agent_name="researcher",
current_step=current_step,
state=state,
)
# Verify it returns a list
assert isinstance(result, list)
@pytest.mark.asyncio
async def test_configuration_enables_disables_fallback(self):
"""Test that configuration controls fallback behavior."""
configurable_enabled = Configuration(enable_recursion_fallback=True)
configurable_disabled = Configuration(enable_recursion_fallback=False)
assert configurable_enabled.enable_recursion_fallback is True
assert configurable_disabled.enable_recursion_fallback is False
class TestRecursionFallbackEdgeCases:
"""Test edge cases and boundary conditions for recursion fallback."""
@pytest.mark.asyncio
async def test_fallback_with_empty_observations(self):
"""Test fallback behavior when there are no observations."""
state = State(messages=[], locale="en-US")
current_step = MagicMock()
partial_agent_messages = []
mock_llm_response = MagicMock()
mock_llm_response.content = "No observations available"
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
patch("src.graph.nodes.get_system_prompt_template", return_value=""), \
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
mock_llm = MagicMock()
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
mock_get_llm.return_value = mock_llm
result = await _handle_recursion_limit_fallback(
messages=partial_agent_messages,
agent_name="researcher",
current_step=current_step,
state=state,
)
# With empty messages, should return empty list
assert result == []
@pytest.mark.asyncio
async def test_fallback_with_very_long_recursion_limit(self):
"""Test fallback with very large recursion limit value."""
state = State(messages=[], locale="en-US")
current_step = MagicMock()
partial_agent_messages = []
mock_llm_response = MagicMock()
mock_llm_response.content = "Summary"
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
patch("src.graph.nodes.get_system_prompt_template", return_value=""), \
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
mock_llm = MagicMock()
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
mock_get_llm.return_value = mock_llm
result = await _handle_recursion_limit_fallback(
messages=partial_agent_messages,
agent_name="researcher",
current_step=current_step,
state=state,
)
# With empty messages, should return empty list
assert result == []
@pytest.mark.asyncio
async def test_fallback_with_unicode_locale(self):
"""Test fallback with various locale formats including unicode."""
for locale in ["zh-CN", "ja-JP", "ko-KR", "en-US", "pt-BR"]:
state = State(messages=[], locale=locale)
current_step = MagicMock()
# Create non-empty messages to avoid early return
partial_agent_messages = [HumanMessage(content="Test")]
mock_llm_response = MagicMock()
mock_llm_response.content = f"Summary for {locale}"
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
patch("src.graph.nodes.get_system_prompt_template") as mock_get_system_prompt, \
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
mock_llm = MagicMock()
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
mock_get_llm.return_value = mock_llm
mock_get_system_prompt.return_value = "Template"
await _handle_recursion_limit_fallback(
messages=partial_agent_messages,
agent_name="researcher",
current_step=current_step,
state=state,
)
# Verify locale was passed to template
call_args = mock_get_system_prompt.call_args
assert call_args[0][1]["locale"] == locale
@pytest.mark.asyncio
async def test_fallback_with_none_locale(self):
"""Test fallback handles None locale gracefully."""
state = State(messages=[], locale=None)
current_step = MagicMock()
# Create non-empty messages to avoid early return
partial_agent_messages = [HumanMessage(content="Test")]
mock_llm_response = MagicMock()
mock_llm_response.content = "Summary"
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
patch("src.graph.nodes.get_system_prompt_template") as mock_get_system_prompt, \
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
mock_llm = MagicMock()
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
mock_get_llm.return_value = mock_llm
mock_get_system_prompt.return_value = "Template"
# Should not raise, should use default locale
await _handle_recursion_limit_fallback(
messages=partial_agent_messages,
agent_name="researcher",
current_step=current_step,
state=state,
)
# Verify default locale "en-US" was used
call_args = mock_get_system_prompt.call_args
assert call_args[0][1]["locale"] is None or call_args[0][1]["locale"] == "en-US"
-491
View File
@@ -1,491 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import MagicMock, patch
import pytest
from src.graph.nodes import validate_and_fix_plan
class TestValidateAndFixPlanStepTypeRepair:
"""Test step_type field repair logic (Issue #650 fix)."""
def test_repair_missing_step_type_with_need_search_true(self):
"""Test that missing step_type is inferred as 'research' when need_search=true."""
plan = {
"steps": [
{
"need_search": True,
"title": "Research Step",
"description": "Gather data",
# step_type is MISSING
}
]
}
result = validate_and_fix_plan(plan)
assert result["steps"][0]["step_type"] == "research"
def test_repair_missing_step_type_with_need_search_false(self):
"""Test that missing step_type is inferred as 'analysis' when need_search=false (Issue #677)."""
plan = {
"steps": [
{
"need_search": False,
"title": "Processing Step",
"description": "Analyze data",
# step_type is MISSING
}
]
}
result = validate_and_fix_plan(plan)
# Issue #677: non-search steps now default to 'analysis' instead of 'processing'
assert result["steps"][0]["step_type"] == "analysis"
def test_repair_missing_step_type_default_to_analysis(self):
"""Test that missing step_type defaults to 'analysis' when need_search is not specified (Issue #677)."""
plan = {
"steps": [
{
"title": "Unknown Step",
"description": "Do something",
# need_search is MISSING, step_type is MISSING
}
]
}
result = validate_and_fix_plan(plan)
# Issue #677: non-search steps now default to 'analysis' instead of 'processing'
assert result["steps"][0]["step_type"] == "analysis"
def test_repair_empty_step_type_field(self):
"""Test that empty step_type field is repaired."""
plan = {
"steps": [
{
"need_search": True,
"title": "Research Step",
"description": "Gather data",
"step_type": "", # Empty string
}
]
}
result = validate_and_fix_plan(plan)
assert result["steps"][0]["step_type"] == "research"
def test_repair_null_step_type_field(self):
"""Test that null step_type field is repaired."""
plan = {
"steps": [
{
"need_search": False,
"title": "Processing Step",
"description": "Analyze data",
"step_type": None,
}
]
}
result = validate_and_fix_plan(plan)
# Issue #677: non-search steps now default to 'analysis' instead of 'processing'
assert result["steps"][0]["step_type"] == "analysis"
def test_multiple_steps_with_mixed_missing_step_types(self):
"""Test repair of multiple steps with different missing step_type scenarios."""
plan = {
"steps": [
{
"need_search": True,
"title": "Research 1",
"description": "Gather",
# MISSING step_type
},
{
"need_search": False,
"title": "Processing 1",
"description": "Analyze",
"step_type": "processing", # Already has step_type
},
{
"need_search": True,
"title": "Research 2",
"description": "More gathering",
# MISSING step_type
},
]
}
result = validate_and_fix_plan(plan)
assert result["steps"][0]["step_type"] == "research"
assert result["steps"][1]["step_type"] == "processing" # Should remain unchanged
assert result["steps"][2]["step_type"] == "research"
def test_preserve_explicit_step_type(self):
"""Test that explicitly provided step_type values are preserved."""
plan = {
"steps": [
{
"need_search": True,
"title": "Research Step",
"description": "Gather",
"step_type": "research",
},
{
"need_search": False,
"title": "Processing Step",
"description": "Analyze",
"step_type": "processing",
},
]
}
result = validate_and_fix_plan(plan)
# Should remain unchanged
assert result["steps"][0]["step_type"] == "research"
assert result["steps"][1]["step_type"] == "processing"
def test_repair_logs_warning(self):
"""Test that repair operations are logged."""
plan = {
"steps": [
{
"need_search": True,
"title": "Missing Type Step",
"description": "Gather",
}
]
}
with patch("src.graph.nodes.logger") as mock_logger:
validate_and_fix_plan(plan)
# Should log repair operation
mock_logger.info.assert_called()
# Check that any of the info calls contains "Repaired missing step_type"
assert any("Repaired missing step_type" in str(call) for call in mock_logger.info.call_args_list)
def test_non_dict_plan_returns_unchanged(self):
"""Test that non-dict plans are returned unchanged."""
plan = "not a dict"
result = validate_and_fix_plan(plan)
assert result == plan
def test_plan_with_non_dict_step_skipped(self):
"""Test that non-dict step items are skipped without error."""
plan = {
"steps": [
"not a dict step", # This should be skipped
{
"need_search": True,
"title": "Valid Step",
"description": "Gather",
},
]
}
result = validate_and_fix_plan(plan)
# Non-dict step should be unchanged, valid step should be fixed
assert result["steps"][0] == "not a dict step"
assert result["steps"][1]["step_type"] == "research"
def test_empty_steps_list(self):
"""Test that plan with empty steps list is handled gracefully."""
plan = {"steps": []}
result = validate_and_fix_plan(plan)
assert result["steps"] == []
def test_missing_steps_key(self):
"""Test that plan without steps key is handled gracefully."""
plan = {"locale": "en-US", "title": "Test"}
result = validate_and_fix_plan(plan)
assert "steps" not in result
class TestValidateAndFixPlanWebSearchEnforcement:
"""Test web search enforcement logic."""
def test_enforce_web_search_sets_first_research_step(self):
"""Test that enforce_web_search=True sets need_search on first research step."""
plan = {
"steps": [
{
"need_search": False,
"title": "Research Step",
"description": "Gather",
"step_type": "research",
},
{
"need_search": False,
"title": "Processing Step",
"description": "Analyze",
"step_type": "processing",
},
]
}
result = validate_and_fix_plan(plan, enforce_web_search=True)
# First research step should have web search enabled
assert result["steps"][0]["need_search"] is True
assert result["steps"][1]["need_search"] is False
def test_enforce_web_search_converts_first_step(self):
"""Test that enforce_web_search converts first step to research if needed."""
plan = {
"steps": [
{
"need_search": False,
"title": "First Step",
"description": "Do something",
"step_type": "processing",
},
]
}
result = validate_and_fix_plan(plan, enforce_web_search=True)
# First step should be converted to research with web search
assert result["steps"][0]["step_type"] == "research"
assert result["steps"][0]["need_search"] is True
def test_enforce_web_search_with_existing_search_step(self):
"""Test that enforce_web_search doesn't modify if search step already exists."""
plan = {
"steps": [
{
"need_search": True,
"title": "Research Step",
"description": "Gather",
"step_type": "research",
},
{
"need_search": False,
"title": "Processing Step",
"description": "Analyze",
"step_type": "processing",
},
]
}
result = validate_and_fix_plan(plan, enforce_web_search=True)
# Steps should remain unchanged
assert result["steps"][0]["need_search"] is True
assert result["steps"][1]["need_search"] is False
def test_enforce_web_search_adds_default_step(self):
"""Test that enforce_web_search adds default research step if no steps exist."""
plan = {"steps": []}
result = validate_and_fix_plan(plan, enforce_web_search=True)
assert len(result["steps"]) == 1
assert result["steps"][0]["step_type"] == "research"
assert result["steps"][0]["need_search"] is True
assert "title" in result["steps"][0]
assert "description" in result["steps"][0]
def test_enforce_web_search_without_steps_key(self):
"""Test enforce_web_search when steps key is missing."""
plan = {"locale": "en-US"}
result = validate_and_fix_plan(plan, enforce_web_search=True)
assert len(result.get("steps", [])) > 0
assert result["steps"][0]["step_type"] == "research"
class TestValidateAndFixPlanIntegration:
"""Integration tests for step_type repair and web search enforcement together."""
def test_repair_and_enforce_together(self):
"""Test that step_type repair and web search enforcement work together."""
plan = {
"steps": [
{
"need_search": True,
"title": "Research Step",
"description": "Gather",
# MISSING step_type
},
{
"need_search": False,
"title": "Processing Step",
"description": "Analyze",
# MISSING step_type, but enforce_web_search won't change it
},
]
}
result = validate_and_fix_plan(plan, enforce_web_search=True)
# step_type should be repaired
assert result["steps"][0]["step_type"] == "research"
# Issue #677: non-search steps now default to 'analysis' instead of 'processing'
assert result["steps"][1]["step_type"] == "analysis"
# First research step should have web search (already has it)
assert result["steps"][0]["need_search"] is True
def test_repair_then_enforce_cascade(self):
"""Test complex scenario with repair and enforcement cascading."""
plan = {
"steps": [
{
"need_search": False,
"title": "Step 1",
"description": "Do something",
# MISSING step_type
},
{
"need_search": False,
"title": "Step 2",
"description": "Do something else",
# MISSING step_type
},
]
}
result = validate_and_fix_plan(plan, enforce_web_search=True)
# Step 1: Originally analysis (from auto-repair) but converted to research with web search enforcement
assert result["steps"][0]["step_type"] == "research"
assert result["steps"][0]["need_search"] is True
# Step 2: Should remain as analysis since enforcement already satisfied by step 1
# Issue #677: non-search steps now default to 'analysis' instead of 'processing'
assert result["steps"][1]["step_type"] == "analysis"
assert result["steps"][1]["need_search"] is False
class TestValidateAndFixPlanIssue650:
"""Specific tests for Issue #650 scenarios."""
def test_issue_650_water_footprint_scenario_fixed(self):
"""Test the exact scenario from issue #650 - water footprint query with missing step_type."""
# This is a simplified version of the actual error from issue #650
plan = {
"locale": "en-US",
"has_enough_context": False,
"title": "Research Plan — Water Footprint of 1 kg of Beef",
"thought": "You asked: 'How many liters of water are required to produce 1 kg of beef?'",
"steps": [
{
"need_search": True,
"title": "Authoritative estimates",
"description": "Collect peer-reviewed estimates",
# MISSING step_type - this caused the error in issue #650
},
{
"need_search": True,
"title": "System-specific data",
"description": "Gather system-level data",
# MISSING step_type
},
{
"need_search": False,
"title": "Processing and analysis",
"description": "Compute scenario-based estimates",
# MISSING step_type
},
],
}
result = validate_and_fix_plan(plan)
# All steps should now have step_type
assert result["steps"][0]["step_type"] == "research"
assert result["steps"][1]["step_type"] == "research"
# Issue #677: non-search steps now default to 'analysis' instead of 'processing'
assert result["steps"][2]["step_type"] == "analysis"
def test_issue_650_scenario_passes_pydantic_validation(self):
"""Test that fixed plan can be validated by Pydantic schema."""
from src.prompts.planner_model import Plan as PlanModel
plan = {
"locale": "en-US",
"has_enough_context": False,
"title": "Research Plan",
"thought": "Test thought",
"steps": [
{
"need_search": True,
"title": "Research",
"description": "Gather data",
# MISSING step_type
},
],
}
# First validate and fix
fixed_plan = validate_and_fix_plan(plan)
# Then try Pydantic validation (should not raise)
validated = PlanModel.model_validate(fixed_plan)
assert validated.steps[0].step_type == "research"
assert validated.steps[0].need_search is True
def test_issue_650_multiple_validation_errors_fixed(self):
"""Test that plan with multiple missing step_types (like in issue #650) all get fixed."""
plan = {
"locale": "en-US",
"has_enough_context": False,
"title": "Complex Plan",
"thought": "Research plan",
"steps": [
{
"need_search": True,
"title": "Step 0",
"description": "Data gathering",
},
{
"need_search": True,
"title": "Step 1",
"description": "More gathering",
},
{
"need_search": False,
"title": "Step 2",
"description": "Processing",
},
],
}
result = validate_and_fix_plan(plan)
# All steps should have step_type now
for step in result["steps"]:
assert "step_type" in step
# Issue #677: 'analysis' is now a valid step_type
assert step["step_type"] in ["research", "analysis", "processing"]
def test_issue_650_no_exceptions_raised(self):
"""Test that validate_and_fix_plan handles all edge cases without raising exceptions."""
test_cases = [
{"steps": []},
{"steps": [{"need_search": True}]},
{"steps": [None, {}]},
{"steps": ["invalid"]},
{"steps": [{"need_search": True, "step_type": ""}]},
"not a dict",
]
for plan in test_cases:
try:
result = validate_and_fix_plan(plan)
# Should succeed without exception - result may be returned as-is for non-dict
# but the function should not raise
# No assertion needed; test passes if no exception is raised
except Exception as e:
pytest.fail(f"validate_and_fix_plan raised exception for {plan}: {e}")
-355
View File
@@ -1,355 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for state preservation functionality in graph nodes.
Tests the preserve_state_meta_fields() function and verifies that
critical state fields (especially locale) are properly preserved
across node state transitions.
"""
import pytest
from langgraph.types import Command
from src.graph.nodes import preserve_state_meta_fields
from src.graph.types import State
class TestPreserveStateMetaFields:
"""Test suite for preserve_state_meta_fields() function."""
def test_preserve_all_fields_with_defaults(self):
"""Test that all fields are preserved with default values when state is empty."""
# Create a minimal state with only messages
state = State(messages=[])
# Extract meta fields
preserved = preserve_state_meta_fields(state)
# Verify all expected fields are present
assert "locale" in preserved
assert "research_topic" in preserved
assert "clarified_research_topic" in preserved
assert "clarification_history" in preserved
assert "enable_clarification" in preserved
assert "max_clarification_rounds" in preserved
assert "clarification_rounds" in preserved
assert "resources" in preserved
# Verify default values
assert preserved["locale"] == "en-US"
assert preserved["research_topic"] == ""
assert preserved["clarified_research_topic"] == ""
assert preserved["clarification_history"] == []
assert preserved["enable_clarification"] is False
assert preserved["max_clarification_rounds"] == 3
assert preserved["clarification_rounds"] == 0
assert preserved["resources"] == []
def test_preserve_locale_from_state(self):
"""Test that locale is correctly preserved when set in state."""
state = State(messages=[], locale="zh-CN")
preserved = preserve_state_meta_fields(state)
assert preserved["locale"] == "zh-CN"
def test_preserve_locale_english(self):
"""Test that English locale is preserved."""
state = State(messages=[], locale="en-US")
preserved = preserve_state_meta_fields(state)
assert preserved["locale"] == "en-US"
def test_preserve_locale_with_custom_value(self):
"""Test that custom locale values are preserved."""
state = State(messages=[], locale="fr-FR")
preserved = preserve_state_meta_fields(state)
assert preserved["locale"] == "fr-FR"
def test_preserve_research_topic(self):
"""Test that research_topic is correctly preserved."""
test_topic = "How to build sustainable cities"
state = State(messages=[], research_topic=test_topic)
preserved = preserve_state_meta_fields(state)
assert preserved["research_topic"] == test_topic
def test_preserve_clarified_research_topic(self):
"""Test that clarified_research_topic is correctly preserved."""
test_topic = "Sustainable urban development with focus on green spaces"
state = State(messages=[], clarified_research_topic=test_topic)
preserved = preserve_state_meta_fields(state)
assert preserved["clarified_research_topic"] == test_topic
def test_preserve_clarification_history(self):
"""Test that clarification_history is correctly preserved."""
history = ["Q: What aspects?", "A: Architecture and planning"]
state = State(messages=[], clarification_history=history)
preserved = preserve_state_meta_fields(state)
assert preserved["clarification_history"] == history
def test_preserve_clarification_flags(self):
"""Test that clarification flags are correctly preserved."""
state = State(
messages=[],
enable_clarification=True,
max_clarification_rounds=5,
clarification_rounds=2,
)
preserved = preserve_state_meta_fields(state)
assert preserved["enable_clarification"] is True
assert preserved["max_clarification_rounds"] == 5
assert preserved["clarification_rounds"] == 2
def test_preserve_resources(self):
"""Test that resources list is correctly preserved."""
resources = [{"id": "1", "name": "Resource 1"}]
state = State(messages=[], resources=resources)
preserved = preserve_state_meta_fields(state)
assert preserved["resources"] == resources
def test_preserve_all_fields_together(self):
"""Test that all meta fields are preserved together correctly."""
state = State(
messages=[],
locale="zh-CN",
research_topic="原始查询",
clarified_research_topic="澄清后的查询",
clarification_history=["Q1", "A1", "Q2", "A2"],
enable_clarification=True,
max_clarification_rounds=5,
clarification_rounds=2,
resources=["resource1"],
)
preserved = preserve_state_meta_fields(state)
assert preserved["locale"] == "zh-CN"
assert preserved["research_topic"] == "原始查询"
assert preserved["clarified_research_topic"] == "澄清后的查询"
assert preserved["clarification_history"] == ["Q1", "A1", "Q2", "A2"]
assert preserved["enable_clarification"] is True
assert preserved["max_clarification_rounds"] == 5
assert preserved["clarification_rounds"] == 2
assert preserved["resources"] == ["resource1"]
def test_preserve_returns_dict_not_state_object(self):
"""Test that preserve_state_meta_fields returns a dict."""
state = State(messages=[], locale="zh-CN")
preserved = preserve_state_meta_fields(state)
assert isinstance(preserved, dict)
# Verify it's a plain dict with expected keys
assert "locale" in preserved
assert "research_topic" in preserved
def test_preserve_does_not_mutate_original_state(self):
"""Test that calling preserve_state_meta_fields does not mutate the original state."""
original_locale = "zh-CN"
state = State(messages=[], locale=original_locale)
original_state_copy = dict(state)
preserve_state_meta_fields(state)
# Verify state hasn't changed
assert state["locale"] == original_locale
assert dict(state) == original_state_copy
def test_preserve_with_none_values(self):
"""Test that preserve handles None values gracefully."""
state = State(messages=[], locale=None)
preserved = preserve_state_meta_fields(state)
# Should use default when value is None
assert preserved["locale"] is None or preserved["locale"] == "en-US"
def test_preserve_empty_lists_preserved(self):
"""Test that empty lists are preserved correctly."""
state = State(
messages=[], clarification_history=[], resources=[]
)
preserved = preserve_state_meta_fields(state)
assert preserved["clarification_history"] == []
assert preserved["resources"] == []
def test_preserve_count_of_fields(self):
"""Test that exactly 8 fields are preserved."""
state = State(messages=[])
preserved = preserve_state_meta_fields(state)
# Should have exactly 8 meta fields
assert len(preserved) == 8
def test_preserve_field_names(self):
"""Test that all expected field names are present."""
state = State(messages=[])
preserved = preserve_state_meta_fields(state)
expected_fields = {
"locale",
"research_topic",
"clarified_research_topic",
"clarification_history",
"enable_clarification",
"max_clarification_rounds",
"clarification_rounds",
"resources",
}
assert set(preserved.keys()) == expected_fields
class TestStatePreservationInCommand:
"""Test suite for using preserved state fields in Command objects."""
def test_command_update_with_preserved_fields(self):
"""Test that preserved fields can be unpacked into Command.update."""
state = State(messages=[], locale="zh-CN", research_topic="测试")
# This should not raise any errors
preserved = preserve_state_meta_fields(state)
command_update = {
"messages": [],
**preserved,
}
assert "locale" in command_update
assert "research_topic" in command_update
assert command_update["locale"] == "zh-CN"
def test_command_unpacking_syntax(self):
"""Test that the unpacking syntax works correctly with preserved fields."""
state = State(messages=[], locale="en-US")
preserved = preserve_state_meta_fields(state)
# Simulate how it's used in actual nodes
update_dict = {
"messages": [],
"current_plan": None,
**preserved,
"locale": "zh-CN",
}
assert len(update_dict) >= 10 # 2 explicit + 8 preserved
assert update_dict["locale"] == "zh-CN" # overridden value
class TestLocalePreservationSpecific:
"""Specific test cases for locale preservation (the main issue being fixed)."""
def test_locale_not_lost_in_transition(self):
"""Test that locale is not lost when transitioning between nodes."""
# Initial state from frontend with Chinese locale
initial_state = State(messages=[], locale="zh-CN")
# Extract for first node transition
preserved_1 = preserve_state_meta_fields(initial_state)
# Simulate state update from first node
updated_state_1 = State(
messages=[], **preserved_1
)
# Extract for second node transition
preserved_2 = preserve_state_meta_fields(updated_state_1)
# Locale should still be zh-CN after two transitions
assert preserved_2["locale"] == "zh-CN"
def test_locale_chain_through_multiple_nodes(self):
"""Test that locale survives through multiple node transitions."""
initial_locale = "zh-CN"
state = State(messages=[], locale=initial_locale)
# Simulate 5 node transitions
for _ in range(5):
preserved = preserve_state_meta_fields(state)
assert preserved["locale"] == initial_locale
# Create new state for next "node"
state = State(messages=[], **preserved)
# After 5 transitions, locale should still be preserved
assert state.get("locale") == initial_locale
def test_locale_with_other_fields_preserved_together(self):
"""Test that locale is preserved correctly even when other fields change."""
initial_state = State(
messages=[],
locale="zh-CN",
research_topic="Original",
clarification_rounds=0,
)
preserved = preserve_state_meta_fields(initial_state)
# Verify locale is in preserved dict
assert preserved["locale"] == "zh-CN"
assert preserved["research_topic"] == "Original"
assert preserved["clarification_rounds"] == 0
# Create new state with preserved fields
modified_state = State(
messages=[],
**preserved,
)
# Locale should be preserved
assert modified_state.get("locale") == "zh-CN"
# Research topic should be preserved from original
assert modified_state.get("research_topic") == "Original"
assert modified_state.get("clarification_rounds") == 0
class TestEdgeCases:
"""Test edge cases and boundary conditions."""
def test_very_long_research_topic(self):
"""Test preservation with very long research_topic."""
long_topic = "a" * 10000
state = State(messages=[], research_topic=long_topic)
preserved = preserve_state_meta_fields(state)
assert preserved["research_topic"] == long_topic
def test_unicode_characters_in_topic(self):
"""Test preservation with unicode characters."""
unicode_topic = "中文测试 🌍 テスト 🧪"
state = State(messages=[], research_topic=unicode_topic)
preserved = preserve_state_meta_fields(state)
assert preserved["research_topic"] == unicode_topic
def test_special_characters_in_locale(self):
"""Test preservation with special locale formats."""
special_locales = ["zh-CN", "en-US", "pt-BR", "es-ES", "ja-JP"]
for locale in special_locales:
state = State(messages=[], locale=locale)
preserved = preserve_state_meta_fields(state)
assert preserved["locale"] == locale
def test_large_clarification_history(self):
"""Test preservation with large clarification_history."""
large_history = [f"Q{i}: Question {i}" for i in range(100)]
state = State(messages=[], clarification_history=large_history)
preserved = preserve_state_meta_fields(state)
assert len(preserved["clarification_history"]) == 100
assert preserved["clarification_history"] == large_history
def test_max_clarification_rounds_boundary(self):
"""Test preservation with boundary values for max_clarification_rounds."""
test_cases = [0, 1, 3, 10, 100, 999]
for value in test_cases:
state = State(messages=[], max_clarification_rounds=value)
preserved = preserve_state_meta_fields(state)
assert preserved["max_clarification_rounds"] == value
-305
View File
@@ -1,305 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import pytest
from langchain_core.messages import (
AIMessageChunk,
ChatMessageChunk,
FunctionMessageChunk,
HumanMessageChunk,
SystemMessageChunk,
ToolMessageChunk,
)
from src.llms import llm as llm_module
from src.llms.providers import dashscope as dashscope_module
from src.llms.providers.dashscope import (
ChatDashscope,
_convert_chunk_to_generation_chunk,
_convert_delta_to_message_chunk,
)
class DummyChatDashscope:
def __init__(self, **kwargs):
self.kwargs = kwargs
@pytest.fixture
def dashscope_conf():
return {
"BASIC_MODEL": {
"api_key": "k",
"base_url": "https://dashscope.aliyuncs.com/v1",
"model": "qwen3-235b-a22b-instruct-2507",
},
"REASONING_MODEL": {
"api_key": "rk",
"base_url": "https://dashscope.aliyuncs.com/v1",
"model": "qwen3-235b-a22b-thinking-2507",
},
}
def test_convert_delta_to_message_chunk_roles_and_extras():
# Assistant with reasoning + tool calls
delta = {
"role": "assistant",
"content": "Hello",
"reasoning_content": "Think...",
"tool_calls": [
{
"id": "call_1",
"index": 0,
"function": {"name": "lookup", "arguments": '{\\"q\\":\\"x\\"}'},
}
],
}
msg = _convert_delta_to_message_chunk(delta, AIMessageChunk)
assert isinstance(msg, AIMessageChunk)
assert msg.content == "Hello"
assert msg.additional_kwargs.get("reasoning_content") == "Think..."
# tool_call_chunks should be present
assert getattr(msg, "tool_call_chunks", None)
# Human
delta = {"role": "user", "content": "Hi"}
msg = _convert_delta_to_message_chunk(delta, HumanMessageChunk)
assert isinstance(msg, HumanMessageChunk)
# System
delta = {"role": "system", "content": "Rules"}
msg = _convert_delta_to_message_chunk(delta, SystemMessageChunk)
assert isinstance(msg, SystemMessageChunk)
# Function
delta = {"role": "function", "name": "f", "content": "{}"}
msg = _convert_delta_to_message_chunk(delta, FunctionMessageChunk)
assert isinstance(msg, FunctionMessageChunk)
# Tool
delta = {"role": "tool", "tool_call_id": "t1", "content": "ok"}
msg = _convert_delta_to_message_chunk(delta, ToolMessageChunk)
assert isinstance(msg, ToolMessageChunk)
def test_convert_chunk_to_generation_chunk_skip_and_usage():
# Skips content.delta type
assert (
_convert_chunk_to_generation_chunk(
{"type": "content.delta"}, AIMessageChunk, None
)
is None
)
# Proper chunk with usage and finish info
chunk = {
"choices": [
{
"delta": {"role": "assistant", "content": "Hi"},
"finish_reason": "stop",
}
],
"model": "qwen3-235b-a22b-instruct-2507",
"system_fingerprint": "fp",
"usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
}
gen = _convert_chunk_to_generation_chunk(chunk, AIMessageChunk, None)
assert gen is not None
assert isinstance(gen.message, AIMessageChunk)
assert gen.message.content == "Hi"
# usage metadata should attach to AI message
assert getattr(gen.message, "usage_metadata", None) is not None
assert gen.generation_info.get("finish_reason") == "stop"
assert gen.generation_info.get("model_name") == "qwen3-235b-a22b-instruct-2507"
assert gen.generation_info.get("system_fingerprint") == "fp"
def test_llm_selects_dashscope_and_sets_enable_thinking(monkeypatch, dashscope_conf):
# Use dummy class to capture kwargs on construction
monkeypatch.setattr(llm_module, "ChatDashscope", DummyChatDashscope)
# basic -> enable_thinking False
inst = llm_module._create_llm_use_conf("basic", dashscope_conf)
assert isinstance(inst, DummyChatDashscope)
assert inst.kwargs["extra_body"]["enable_thinking"] is False
assert inst.kwargs["base_url"].find("dashscope.") > 0
# reasoning -> enable_thinking True
inst2 = llm_module._create_llm_use_conf("reasoning", dashscope_conf)
assert isinstance(inst2, DummyChatDashscope)
assert inst2.kwargs["extra_body"]["enable_thinking"] is True
def test_llm_verify_ssl_false_adds_http_clients(monkeypatch, dashscope_conf):
monkeypatch.setattr(llm_module, "ChatDashscope", DummyChatDashscope)
# turn off ssl
dashscope_conf = {**dashscope_conf}
dashscope_conf["BASIC_MODEL"] = {
**dashscope_conf["BASIC_MODEL"],
"verify_ssl": False,
}
inst = llm_module._create_llm_use_conf("basic", dashscope_conf)
assert "http_client" in inst.kwargs
assert "http_async_client" in inst.kwargs
def test_convert_delta_to_message_chunk_developer_and_function_call_and_tool_calls():
# developer role -> SystemMessageChunk with __openai_role__
delta = {"role": "developer", "content": "dev rules"}
msg = _convert_delta_to_message_chunk(delta, SystemMessageChunk)
assert isinstance(msg, SystemMessageChunk)
assert msg.additional_kwargs.get("__openai_role__") == "developer"
# function_call name None -> empty string
delta = {"role": "assistant", "function_call": {"name": None, "arguments": "{}"}}
msg = _convert_delta_to_message_chunk(delta, AIMessageChunk)
assert isinstance(msg, AIMessageChunk)
assert msg.additional_kwargs["function_call"]["name"] == ""
# tool_calls: one valid, one missing function -> should not crash and create one chunk
delta = {
"role": "assistant",
"tool_calls": [
{"id": "t1", "index": 0, "function": {"name": "f", "arguments": "{}"}},
{"id": "t2", "index": 1}, # missing function key
],
}
msg = _convert_delta_to_message_chunk(delta, AIMessageChunk)
assert isinstance(msg, AIMessageChunk)
# tool_calls copied as-is
assert msg.additional_kwargs["tool_calls"][0]["id"] == "t1"
# tool_call_chunks only for valid one
assert getattr(msg, "tool_call_chunks") and len(msg.tool_call_chunks) == 1
def test_convert_delta_to_message_chunk_default_class_and_unknown_role():
# No role, default human -> HumanMessageChunk
delta = {"content": "hey"}
msg = _convert_delta_to_message_chunk(delta, HumanMessageChunk)
assert isinstance(msg, HumanMessageChunk)
# Unknown role -> ChatMessageChunk with that role
delta = {"role": "observer", "content": "hmm"}
msg = _convert_delta_to_message_chunk(delta, ChatMessageChunk)
assert isinstance(msg, ChatMessageChunk)
assert msg.role == "observer"
def test_convert_chunk_to_generation_chunk_empty_choices_and_usage():
chunk = {
"choices": [],
"usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
}
gen = _convert_chunk_to_generation_chunk(chunk, AIMessageChunk, None)
assert gen is not None
assert isinstance(gen.message, AIMessageChunk)
assert gen.message.content == ""
assert getattr(gen.message, "usage_metadata", None) is not None
assert gen.generation_info is None
def test_convert_chunk_to_generation_chunk_includes_base_info_and_logprobs():
chunk = {
"choices": [
{
"delta": {"role": "assistant", "content": "T"},
"logprobs": {"content": [{"token": "T", "logprob": -0.1}]},
}
]
}
base_info = {"headers": {"a": "b"}}
gen = _convert_chunk_to_generation_chunk(chunk, AIMessageChunk, base_info)
assert gen is not None
assert gen.message.content == "T"
assert gen.generation_info.get("headers") == {"a": "b"}
assert "logprobs" in gen.generation_info
def test_convert_chunk_to_generation_chunk_beta_stream_format():
chunk = {
"chunk": {
"choices": [
{"delta": {"role": "assistant", "content": "From beta stream format"}}
]
}
}
gen = _convert_chunk_to_generation_chunk(chunk, AIMessageChunk, None)
assert gen is not None
assert gen.message.content == "From beta stream format"
def test_chatdashscope_create_chat_result_adds_reasoning_content(monkeypatch):
# Dummy objects for the super() return
class DummyMsg:
def __init__(self):
self.additional_kwargs = {}
class DummyGen:
def __init__(self):
self.message = DummyMsg()
class DummyChatResult:
def __init__(self):
self.generations = [DummyGen()]
# Patch super()._create_chat_result to return our dummy structure
def fake_super_create(self, response, generation_info=None):
return DummyChatResult()
monkeypatch.setattr(
dashscope_module.ChatOpenAI, "_create_chat_result", fake_super_create
)
# Patch openai.BaseModel in the module under test
class DummyBaseModel:
pass
monkeypatch.setattr(dashscope_module.openai, "BaseModel", DummyBaseModel)
# Build a fake OpenAI-like response with reasoning_content
class RMsg:
def __init__(self, rc):
self.reasoning_content = rc
class Choice:
def __init__(self, rc):
self.message = RMsg(rc)
class FakeResponse(DummyBaseModel):
def __init__(self):
self.choices = [Choice("Reasoning...")]
llm = ChatDashscope(model="dummy", api_key="k")
result = llm._create_chat_result(FakeResponse())
assert (
result.generations[0].message.additional_kwargs.get("reasoning_content")
== "Reasoning..."
)
def test_chatdashscope_create_chat_result_dict_passthrough(monkeypatch):
class DummyMsg:
def __init__(self):
self.additional_kwargs = {}
class DummyGen:
def __init__(self):
self.message = DummyMsg()
class DummyChatResult:
def __init__(self):
self.generations = [DummyGen()]
def fake_super_create(self, response, generation_info=None):
return DummyChatResult()
monkeypatch.setattr(
dashscope_module.ChatOpenAI, "_create_chat_result", fake_super_create
)
llm = ChatDashscope(model="dummy", api_key="k")
result = llm._create_chat_result({"raw": "dict"})
# Should not inject reasoning_content for dict responses
assert "reasoning_content" not in result.generations[0].message.additional_kwargs
-127
View File
@@ -1,127 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import pytest
from src.llms import llm
class DummyChatOpenAI:
def __init__(self, **kwargs):
self.kwargs = kwargs
def invoke(self, msg):
return f"Echo: {msg}"
@pytest.fixture(autouse=True)
def patch_chat_openai(monkeypatch):
monkeypatch.setattr(llm, "ChatOpenAI", DummyChatOpenAI)
@pytest.fixture
def dummy_conf():
return {
"BASIC_MODEL": {"api_key": "test_key", "base_url": "http://test"},
"REASONING_MODEL": {"api_key": "reason_key"},
"VISION_MODEL": {"api_key": "vision_key"},
}
def test_get_env_llm_conf(monkeypatch):
# Clear any existing environment variables that might interfere
monkeypatch.delenv("BASIC_MODEL__API_KEY", raising=False)
monkeypatch.delenv("BASIC_MODEL__BASE_URL", raising=False)
monkeypatch.delenv("BASIC_MODEL__MODEL", raising=False)
monkeypatch.setenv("BASIC_MODEL__API_KEY", "env_key")
monkeypatch.setenv("BASIC_MODEL__BASE_URL", "http://env")
conf = llm._get_env_llm_conf("basic")
assert conf["api_key"] == "env_key"
assert conf["base_url"] == "http://env"
def test_create_llm_use_conf_merges_env(monkeypatch, dummy_conf):
# Clear any existing environment variables that might interfere
monkeypatch.delenv("BASIC_MODEL__BASE_URL", raising=False)
monkeypatch.delenv("BASIC_MODEL__MODEL", raising=False)
monkeypatch.setenv("BASIC_MODEL__API_KEY", "env_key")
result = llm._create_llm_use_conf("basic", dummy_conf)
assert isinstance(result, DummyChatOpenAI)
assert result.kwargs["api_key"] == "env_key"
assert result.kwargs["base_url"] == "http://test"
def test_create_llm_use_conf_invalid_type(monkeypatch, dummy_conf):
# Clear any existing environment variables that might interfere
monkeypatch.delenv("BASIC_MODEL__API_KEY", raising=False)
monkeypatch.delenv("BASIC_MODEL__BASE_URL", raising=False)
monkeypatch.delenv("BASIC_MODEL__MODEL", raising=False)
with pytest.raises(ValueError):
llm._create_llm_use_conf("unknown", dummy_conf)
def test_create_llm_use_conf_empty_conf(monkeypatch):
# Clear any existing environment variables that might interfere
monkeypatch.delenv("BASIC_MODEL__API_KEY", raising=False)
monkeypatch.delenv("BASIC_MODEL__BASE_URL", raising=False)
monkeypatch.delenv("BASIC_MODEL__MODEL", raising=False)
with pytest.raises(ValueError):
llm._create_llm_use_conf("basic", {})
def test_get_llm_by_type_caches(monkeypatch, dummy_conf):
called = {}
def fake_load_yaml_config(path):
called["called"] = True
return dummy_conf
monkeypatch.setattr(llm, "load_yaml_config", fake_load_yaml_config)
llm._llm_cache.clear()
inst1 = llm.get_llm_by_type("basic")
inst2 = llm.get_llm_by_type("basic")
assert inst1 is inst2
assert called["called"]
def test_create_llm_filters_unexpected_keys(monkeypatch, caplog):
"""Test that unexpected configuration keys like SEARCH_ENGINE are filtered out (Issue #411)."""
import logging
# Clear any existing environment variables that might interfere
monkeypatch.delenv("BASIC_MODEL__API_KEY", raising=False)
monkeypatch.delenv("BASIC_MODEL__BASE_URL", raising=False)
monkeypatch.delenv("BASIC_MODEL__MODEL", raising=False)
# Config with unexpected keys that should be filtered
conf_with_unexpected_keys = {
"BASIC_MODEL": {
"api_key": "test_key",
"base_url": "http://test",
"model": "gpt-4",
"SEARCH_ENGINE": {"include_domains": ["example.com"]}, # Should be filtered
"engine": "tavily", # Should be filtered
}
}
with caplog.at_level(logging.WARNING):
result = llm._create_llm_use_conf("basic", conf_with_unexpected_keys)
# Verify the LLM was created
assert isinstance(result, DummyChatOpenAI)
# Verify unexpected keys were not passed to the LLM
assert "SEARCH_ENGINE" not in result.kwargs
assert "engine" not in result.kwargs
# Verify valid keys were passed
assert result.kwargs["api_key"] == "test_key"
assert result.kwargs["base_url"] == "http://test"
assert result.kwargs["model"] == "gpt-4"
# Verify warnings were logged
assert any("SEARCH_ENGINE" in record.message for record in caplog.records)
assert any("engine" in record.message for record in caplog.records)
-2
View File
@@ -1,2 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
@@ -1,214 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import json
from unittest.mock import MagicMock, patch
import openai
import pytest
from src.podcast.graph.script_writer_node import script_writer_node
from src.podcast.types import Script, ScriptLine
class TestScriptWriterNode:
"""Tests for script_writer_node function."""
@pytest.fixture
def sample_state(self):
"""Create a sample podcast state."""
return {"input": "Test content for podcast generation"}
@pytest.fixture
def sample_script(self):
"""Create a sample Script object."""
return Script(
locale="en",
lines=[
ScriptLine(speaker="male", paragraph="Hello, welcome to our podcast."),
ScriptLine(speaker="female", paragraph="Today we discuss testing."),
],
)
@pytest.fixture
def sample_script_json(self, sample_script):
"""Create JSON representation of sample script."""
return sample_script.model_dump_json()
@patch("src.podcast.graph.script_writer_node.get_prompt_template")
@patch("src.podcast.graph.script_writer_node.get_llm_by_type")
def test_script_writer_with_json_mode_success(
self, mock_get_llm, mock_get_template, sample_state, sample_script
):
"""Test successful script generation using json_mode."""
mock_get_template.return_value = "Generate a podcast script."
mock_model = MagicMock()
mock_structured_model = MagicMock()
mock_model.with_structured_output.return_value = mock_structured_model
mock_structured_model.invoke.return_value = sample_script
mock_get_llm.return_value = mock_model
result = script_writer_node(sample_state)
assert result["script"] == sample_script
assert result["audio_chunks"] == []
mock_model.with_structured_output.assert_called_once_with(
Script, method="json_mode"
)
@patch("src.podcast.graph.script_writer_node.get_prompt_template")
@patch("src.podcast.graph.script_writer_node.get_llm_by_type")
def test_script_writer_fallback_on_json_object_not_supported(
self, mock_get_llm, mock_get_template, sample_state, sample_script_json
):
"""Test fallback to prompting when model doesn't support json_object."""
mock_get_template.return_value = "Generate a podcast script."
mock_model = MagicMock()
mock_structured_model = MagicMock()
mock_model.with_structured_output.return_value = mock_structured_model
# Simulate json_object not supported error
mock_structured_model.invoke.side_effect = openai.BadRequestError(
message="json_object is not supported by this model",
response=MagicMock(status_code=400),
body={"error": {"message": "json_object is not supported"}},
)
# Mock the fallback response
mock_response = MagicMock()
mock_response.content = sample_script_json
mock_model.invoke.return_value = mock_response
mock_get_llm.return_value = mock_model
result = script_writer_node(sample_state)
assert result["script"].locale == "en"
assert len(result["script"].lines) == 2
assert result["audio_chunks"] == []
# Verify fallback was used
mock_model.invoke.assert_called_once()
@patch("src.podcast.graph.script_writer_node.get_prompt_template")
@patch("src.podcast.graph.script_writer_node.get_llm_by_type")
def test_script_writer_reraises_other_bad_request_errors(
self, mock_get_llm, mock_get_template, sample_state
):
"""Test that other BadRequestError types are re-raised."""
mock_get_template.return_value = "Generate a podcast script."
mock_model = MagicMock()
mock_structured_model = MagicMock()
mock_model.with_structured_output.return_value = mock_structured_model
# Simulate a different BadRequestError (not json_object related)
mock_structured_model.invoke.side_effect = openai.BadRequestError(
message="Invalid model parameter",
response=MagicMock(status_code=400),
body={"error": {"message": "Invalid model parameter"}},
)
mock_get_llm.return_value = mock_model
with pytest.raises(openai.BadRequestError) as exc_info:
script_writer_node(sample_state)
assert "Invalid model parameter" in str(exc_info.value)
@patch("src.podcast.graph.script_writer_node.get_prompt_template")
@patch("src.podcast.graph.script_writer_node.get_llm_by_type")
def test_script_writer_fallback_with_markdown_wrapped_json(
self, mock_get_llm, mock_get_template, sample_state
):
"""Test fallback handles JSON wrapped in markdown code blocks."""
mock_get_template.return_value = "Generate a podcast script."
mock_model = MagicMock()
mock_structured_model = MagicMock()
mock_model.with_structured_output.return_value = mock_structured_model
mock_structured_model.invoke.side_effect = openai.BadRequestError(
message="json_object is not supported",
response=MagicMock(status_code=400),
body={},
)
# Mock response with markdown-wrapped JSON (common LLM output)
mock_response = MagicMock()
mock_response.content = """```json
{
"locale": "zh",
"lines": [
{"speaker": "male", "paragraph": "欢迎收听播客。"}
]
}
```"""
mock_model.invoke.return_value = mock_response
mock_get_llm.return_value = mock_model
result = script_writer_node(sample_state)
assert result["script"].locale == "zh"
assert len(result["script"].lines) == 1
assert result["script"].lines[0].speaker == "male"
@patch("src.podcast.graph.script_writer_node.get_prompt_template")
@patch("src.podcast.graph.script_writer_node.get_llm_by_type")
def test_script_writer_fallback_raises_on_invalid_json(
self, mock_get_llm, mock_get_template, sample_state
):
"""Test that fallback raises JSONDecodeError when response is not valid JSON."""
mock_get_template.return_value = "Generate a podcast script."
mock_model = MagicMock()
mock_structured_model = MagicMock()
mock_model.with_structured_output.return_value = mock_structured_model
mock_structured_model.invoke.side_effect = openai.BadRequestError(
message="json_object is not supported",
response=MagicMock(status_code=400),
body={},
)
# Mock response with completely invalid JSON
mock_response = MagicMock()
mock_response.content = "This is not JSON at all, just plain text response."
mock_model.invoke.return_value = mock_response
mock_get_llm.return_value = mock_model
with pytest.raises(json.JSONDecodeError):
script_writer_node(sample_state)
@patch("src.podcast.graph.script_writer_node.get_prompt_template")
@patch("src.podcast.graph.script_writer_node.get_llm_by_type")
def test_script_writer_fallback_raises_on_invalid_schema(
self, mock_get_llm, mock_get_template, sample_state
):
"""Test that fallback raises ValidationError when JSON doesn't match Script schema."""
mock_get_template.return_value = "Generate a podcast script."
mock_model = MagicMock()
mock_structured_model = MagicMock()
mock_model.with_structured_output.return_value = mock_structured_model
mock_structured_model.invoke.side_effect = openai.BadRequestError(
message="json_object is not supported",
response=MagicMock(status_code=400),
body={},
)
# Mock response with valid JSON but invalid schema (missing required fields, wrong types)
mock_response = MagicMock()
mock_response.content = '{"locale": "invalid_locale", "lines": "not_a_list"}'
mock_model.invoke.return_value = mock_response
mock_get_llm.return_value = mock_model
# Pydantic ValidationError is raised when schema validation fails
from pydantic import ValidationError
with pytest.raises(ValidationError):
script_writer_node(sample_state)
-2
View File
@@ -1,2 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
@@ -1,2 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
@@ -1,156 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import MagicMock, patch
import pytest
from src.prompt_enhancer.graph.builder import build_graph
from src.prompt_enhancer.graph.state import PromptEnhancerState
class TestBuildGraph:
"""Test cases for build_graph function."""
@patch("src.prompt_enhancer.graph.builder.StateGraph")
def test_build_graph_structure(self, mock_state_graph):
"""Test that build_graph creates the correct graph structure."""
mock_builder = MagicMock()
mock_compiled_graph = MagicMock()
mock_state_graph.return_value = mock_builder
mock_builder.compile.return_value = mock_compiled_graph
result = build_graph()
# Verify StateGraph was created with correct state type
mock_state_graph.assert_called_once_with(PromptEnhancerState)
# Verify entry point was set
mock_builder.set_entry_point.assert_called_once_with("enhancer")
# Verify finish point was set
mock_builder.set_finish_point.assert_called_once_with("enhancer")
# Verify graph was compiled
mock_builder.compile.assert_called_once()
# Verify return value
assert result == mock_compiled_graph
@patch("src.prompt_enhancer.graph.builder.StateGraph")
@patch("src.prompt_enhancer.graph.builder.prompt_enhancer_node")
def test_build_graph_node_function(self, mock_enhancer_node, mock_state_graph):
"""Test that the correct node function is added to the graph."""
mock_builder = MagicMock()
mock_compiled_graph = MagicMock()
mock_state_graph.return_value = mock_builder
mock_builder.compile.return_value = mock_compiled_graph
build_graph()
# Verify the correct node function was added
mock_builder.add_node.assert_called_once_with("enhancer", mock_enhancer_node)
def test_build_graph_returns_compiled_graph(self):
"""Test that build_graph returns a compiled graph object."""
with patch("src.prompt_enhancer.graph.builder.StateGraph") as mock_state_graph:
mock_builder = MagicMock()
mock_compiled_graph = MagicMock()
mock_state_graph.return_value = mock_builder
mock_builder.compile.return_value = mock_compiled_graph
result = build_graph()
assert result is mock_compiled_graph
@patch("src.prompt_enhancer.graph.builder.StateGraph")
def test_build_graph_call_sequence(self, mock_state_graph):
"""Test that build_graph calls methods in the correct sequence."""
mock_builder = MagicMock()
mock_compiled_graph = MagicMock()
mock_state_graph.return_value = mock_builder
mock_builder.compile.return_value = mock_compiled_graph
# Track call order
call_order = []
def track_add_node(*args, **kwargs):
call_order.append("add_node")
def track_set_entry_point(*args, **kwargs):
call_order.append("set_entry_point")
def track_set_finish_point(*args, **kwargs):
call_order.append("set_finish_point")
def track_compile(*args, **kwargs):
call_order.append("compile")
return mock_compiled_graph
mock_builder.add_node.side_effect = track_add_node
mock_builder.set_entry_point.side_effect = track_set_entry_point
mock_builder.set_finish_point.side_effect = track_set_finish_point
mock_builder.compile.side_effect = track_compile
build_graph()
# Verify the correct call sequence
expected_order = ["add_node", "set_entry_point", "set_finish_point", "compile"]
assert call_order == expected_order
def test_build_graph_integration(self):
"""Integration test to verify the graph can be built without mocking."""
# This test verifies that all imports and dependencies are correct
try:
graph = build_graph()
assert graph is not None
# The graph should be a compiled LangGraph object
assert hasattr(graph, "invoke") or hasattr(graph, "stream")
except ImportError as e:
pytest.skip(f"Skipping integration test due to missing dependencies: {e}")
except Exception as e:
# If there are configuration issues (like missing LLM config),
# we still consider the test successful if the graph structure is built
if "LLM" in str(e) or "configuration" in str(e).lower():
pytest.skip(
f"Skipping integration test due to configuration issues: {e}"
)
else:
raise
@patch("src.prompt_enhancer.graph.builder.StateGraph")
def test_build_graph_single_node_workflow(self, mock_state_graph):
"""Test that the graph is configured as a single-node workflow."""
mock_builder = MagicMock()
mock_compiled_graph = MagicMock()
mock_state_graph.return_value = mock_builder
mock_builder.compile.return_value = mock_compiled_graph
build_graph()
# Verify only one node is added
assert mock_builder.add_node.call_count == 1
# Verify entry and finish points are the same node
mock_builder.set_entry_point.assert_called_once_with("enhancer")
mock_builder.set_finish_point.assert_called_once_with("enhancer")
@patch("src.prompt_enhancer.graph.builder.StateGraph")
def test_build_graph_state_type(self, mock_state_graph):
"""Test that the graph is initialized with the correct state type."""
mock_builder = MagicMock()
mock_compiled_graph = MagicMock()
mock_state_graph.return_value = mock_builder
mock_builder.compile.return_value = mock_compiled_graph
build_graph()
# Verify StateGraph was initialized with PromptEnhancerState
args, kwargs = mock_state_graph.call_args
assert args[0] == PromptEnhancerState
@@ -1,526 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import MagicMock, patch
import pytest
from langchain_core.messages import HumanMessage, SystemMessage
from src.config.report_style import ReportStyle
from src.prompt_enhancer.graph.enhancer_node import prompt_enhancer_node
from src.prompt_enhancer.graph.state import PromptEnhancerState
@pytest.fixture
def mock_llm():
"""Mock LLM that returns a test response."""
llm = MagicMock()
llm.invoke.return_value = MagicMock(
content="""Thoughts: LLM thinks a lot
<enhanced_prompt>
Enhanced test prompt
</enhanced_prompt>
"""
)
return llm
@pytest.fixture
def mock_llm_xml_with_whitespace():
"""Mock LLM that returns XML response with extra whitespace."""
llm = MagicMock()
llm.invoke.return_value = MagicMock(
content="""
Some thoughts here...
<enhanced_prompt>
Enhanced prompt with whitespace
</enhanced_prompt>
Additional content after XML
"""
)
return llm
@pytest.fixture
def mock_llm_xml_multiline():
"""Mock LLM that returns XML response with multiline content."""
llm = MagicMock()
llm.invoke.return_value = MagicMock(
content="""
<enhanced_prompt>
This is a multiline enhanced prompt
that spans multiple lines
and includes various formatting.
It should preserve the structure.
</enhanced_prompt>
"""
)
return llm
@pytest.fixture
def mock_llm_no_xml():
"""Mock LLM that returns response without XML tags."""
llm = MagicMock()
llm.invoke.return_value = MagicMock(
content="Enhanced Prompt: This is an enhanced prompt without XML tags"
)
return llm
@pytest.fixture
def mock_llm_malformed_xml():
"""Mock LLM that returns response with malformed XML."""
llm = MagicMock()
llm.invoke.return_value = MagicMock(
content="""
<enhanced_prompt>
This XML tag is not properly closed
<enhanced_prompt>
"""
)
return llm
@pytest.fixture
def mock_messages():
"""Mock messages returned by apply_prompt_template."""
return [
SystemMessage(content="System prompt template"),
HumanMessage(content="Test human message"),
]
class TestPromptEnhancerNode:
"""Test cases for prompt_enhancer_node function."""
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
@patch(
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
{"prompt_enhancer": "basic"},
)
def test_basic_prompt_enhancement(
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
):
"""Test basic prompt enhancement without context or report style."""
mock_get_llm.return_value = mock_llm
mock_apply_template.return_value = mock_messages
state = PromptEnhancerState(prompt="Write about AI")
result = prompt_enhancer_node(state)
# Verify LLM was called
mock_get_llm.assert_called_once_with("basic")
mock_llm.invoke.assert_called_once_with(mock_messages)
# Verify apply_prompt_template was called correctly
mock_apply_template.assert_called_once()
call_args = mock_apply_template.call_args
assert call_args[0][0] == "prompt_enhancer/prompt_enhancer"
assert "messages" in call_args[0][1]
assert "report_style" in call_args[0][1]
# Verify result
assert result == {"output": "Enhanced test prompt"}
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
@patch(
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
{"prompt_enhancer": "basic"},
)
def test_prompt_enhancement_with_report_style(
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
):
"""Test prompt enhancement with report style."""
mock_get_llm.return_value = mock_llm
mock_apply_template.return_value = mock_messages
state = PromptEnhancerState(
prompt="Write about AI", report_style=ReportStyle.ACADEMIC
)
result = prompt_enhancer_node(state)
# Verify apply_prompt_template was called with report_style
mock_apply_template.assert_called_once()
call_args = mock_apply_template.call_args
assert call_args[0][0] == "prompt_enhancer/prompt_enhancer"
assert call_args[0][1]["report_style"] == ReportStyle.ACADEMIC
# Verify result
assert result == {"output": "Enhanced test prompt"}
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
@patch(
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
{"prompt_enhancer": "basic"},
)
def test_prompt_enhancement_with_context(
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
):
"""Test prompt enhancement with additional context."""
mock_get_llm.return_value = mock_llm
mock_apply_template.return_value = mock_messages
state = PromptEnhancerState(
prompt="Write about AI", context="Focus on machine learning applications"
)
result = prompt_enhancer_node(state)
# Verify apply_prompt_template was called
mock_apply_template.assert_called_once()
call_args = mock_apply_template.call_args
# Check that the context was included in the human message
messages_arg = call_args[0][1]["messages"]
assert len(messages_arg) == 1
human_message = messages_arg[0]
assert isinstance(human_message, HumanMessage)
assert "Focus on machine learning applications" in human_message.content
assert result == {"output": "Enhanced test prompt"}
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
@patch(
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
{"prompt_enhancer": "basic"},
)
def test_error_handling(
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
):
"""Test error handling when LLM call fails."""
mock_get_llm.return_value = mock_llm
mock_apply_template.return_value = mock_messages
# Mock LLM to raise an exception
mock_llm.invoke.side_effect = Exception("LLM error")
state = PromptEnhancerState(prompt="Test prompt")
result = prompt_enhancer_node(state)
# Should return original prompt on error
assert result == {"output": "Test prompt"}
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
@patch(
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
{"prompt_enhancer": "basic"},
)
def test_template_error_handling(
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
):
"""Test error handling when template application fails."""
mock_get_llm.return_value = mock_llm
# Mock apply_prompt_template to raise an exception
mock_apply_template.side_effect = Exception("Template error")
state = PromptEnhancerState(prompt="Test prompt")
result = prompt_enhancer_node(state)
# Should return original prompt on error
assert result == {"output": "Test prompt"}
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
@patch(
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
{"prompt_enhancer": "basic"},
)
def test_prefix_removal(
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
):
"""Test that common prefixes are removed from LLM response."""
mock_get_llm.return_value = mock_llm
mock_apply_template.return_value = mock_messages
# Test different prefixes that should be removed
test_cases = [
"Enhanced Prompt: This is the enhanced prompt",
"Enhanced prompt: This is the enhanced prompt",
"Here's the enhanced prompt: This is the enhanced prompt",
"Here is the enhanced prompt: This is the enhanced prompt",
"**Enhanced Prompt**: This is the enhanced prompt",
"**Enhanced prompt**: This is the enhanced prompt",
]
for response_with_prefix in test_cases:
mock_llm.invoke.return_value = MagicMock(content=response_with_prefix)
state = PromptEnhancerState(prompt="Test prompt")
result = prompt_enhancer_node(state)
assert result == {"output": "This is the enhanced prompt"}
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
@patch(
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
{"prompt_enhancer": "basic"},
)
def test_whitespace_handling(
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
):
"""Test that whitespace is properly stripped from LLM response."""
mock_get_llm.return_value = mock_llm
mock_apply_template.return_value = mock_messages
# Mock LLM response with extra whitespace
mock_llm.invoke.return_value = MagicMock(
content=" \n\n Enhanced prompt \n\n "
)
state = PromptEnhancerState(prompt="Test prompt")
result = prompt_enhancer_node(state)
assert result == {"output": "Enhanced prompt"}
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
@patch(
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
{"prompt_enhancer": "basic"},
)
def test_xml_with_whitespace_handling(
self,
mock_get_llm,
mock_apply_template,
mock_llm_xml_with_whitespace,
mock_messages,
):
"""Test XML extraction with extra whitespace inside tags."""
mock_get_llm.return_value = mock_llm_xml_with_whitespace
mock_apply_template.return_value = mock_messages
state = PromptEnhancerState(prompt="Test prompt")
result = prompt_enhancer_node(state)
assert result == {"output": "Enhanced prompt with whitespace"}
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
@patch(
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
{"prompt_enhancer": "basic"},
)
def test_xml_multiline_content(
self, mock_get_llm, mock_apply_template, mock_llm_xml_multiline, mock_messages
):
"""Test XML extraction with multiline content."""
mock_get_llm.return_value = mock_llm_xml_multiline
mock_apply_template.return_value = mock_messages
state = PromptEnhancerState(prompt="Test prompt")
result = prompt_enhancer_node(state)
expected_output = """This is a multiline enhanced prompt
that spans multiple lines
and includes various formatting.
It should preserve the structure."""
assert result == {"output": expected_output}
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
@patch(
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
{"prompt_enhancer": "basic"},
)
def test_fallback_to_prefix_removal(
self, mock_get_llm, mock_apply_template, mock_llm_no_xml, mock_messages
):
"""Test fallback to prefix removal when no XML tags are found."""
mock_get_llm.return_value = mock_llm_no_xml
mock_apply_template.return_value = mock_messages
state = PromptEnhancerState(prompt="Test prompt")
result = prompt_enhancer_node(state)
assert result == {"output": "This is an enhanced prompt without XML tags"}
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
@patch(
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
{"prompt_enhancer": "basic"},
)
def test_malformed_xml_fallback(
self, mock_get_llm, mock_apply_template, mock_llm_malformed_xml, mock_messages
):
"""Test handling of malformed XML tags."""
mock_get_llm.return_value = mock_llm_malformed_xml
mock_apply_template.return_value = mock_messages
state = PromptEnhancerState(prompt="Test prompt")
result = prompt_enhancer_node(state)
# Should fall back to using the entire content since XML is malformed
expected_content = """<enhanced_prompt>
This XML tag is not properly closed
<enhanced_prompt>"""
assert result == {"output": expected_content}
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
@patch(
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
{"prompt_enhancer": "basic"},
)
def test_case_sensitive_prefix_removal(
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
):
"""Test that prefix removal is case-sensitive."""
mock_get_llm.return_value = mock_llm
mock_apply_template.return_value = mock_messages
# Test case variations that should NOT be removed
test_cases = [
"ENHANCED PROMPT: This should not be removed",
"enhanced prompt: This should not be removed",
"Enhanced Prompt This should not be removed", # Missing colon
"Enhanced Prompt :: This should not be removed", # Double colon
]
for response_content in test_cases:
mock_llm.invoke.return_value = MagicMock(content=response_content)
state = PromptEnhancerState(prompt="Test prompt")
result = prompt_enhancer_node(state)
# Should return the full content since prefix doesn't match exactly
assert result == {"output": response_content}
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
@patch(
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
{"prompt_enhancer": "basic"},
)
def test_prefix_with_extra_whitespace(
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
):
"""Test prefix removal with extra whitespace after colon."""
mock_get_llm.return_value = mock_llm
mock_apply_template.return_value = mock_messages
test_cases = [
("Enhanced Prompt: This has extra spaces", "This has extra spaces"),
("Enhanced prompt:\t\tThis has tabs", "This has tabs"),
("Here's the enhanced prompt:\n\nThis has newlines", "This has newlines"),
]
for response_content, expected_output in test_cases:
mock_llm.invoke.return_value = MagicMock(content=response_content)
state = PromptEnhancerState(prompt="Test prompt")
result = prompt_enhancer_node(state)
assert result == {"output": expected_output}
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
@patch(
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
{"prompt_enhancer": "basic"},
)
def test_xml_with_special_characters(
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
):
"""Test XML extraction with special characters and symbols."""
mock_get_llm.return_value = mock_llm
mock_apply_template.return_value = mock_messages
special_content = """<enhanced_prompt>
Enhanced prompt with special chars: @#$%^&*()
Unicode: 🚀 ✨ 💡
Quotes: "double" and 'single'
Backslashes: \\n \\t \\r
</enhanced_prompt>"""
mock_llm.invoke.return_value = MagicMock(content=special_content)
state = PromptEnhancerState(prompt="Test prompt")
result = prompt_enhancer_node(state)
expected_output = """Enhanced prompt with special chars: @#$%^&*()
Unicode: 🚀 ✨ 💡
Quotes: "double" and 'single'
Backslashes: \\n \\t \\r"""
assert result == {"output": expected_output}
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
@patch(
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
{"prompt_enhancer": "basic"},
)
def test_very_long_response(
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
):
"""Test handling of very long LLM responses."""
mock_get_llm.return_value = mock_llm
mock_apply_template.return_value = mock_messages
# Create a very long response
long_content = "This is a very long enhanced prompt. " * 100
xml_response = f"<enhanced_prompt>\n{long_content}\n</enhanced_prompt>"
mock_llm.invoke.return_value = MagicMock(content=xml_response)
state = PromptEnhancerState(prompt="Test prompt")
result = prompt_enhancer_node(state)
assert result == {"output": long_content.strip()}
assert len(result["output"]) > 1000 # Verify it's actually long
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
@patch(
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
{"prompt_enhancer": "basic"},
)
def test_empty_response_content(
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
):
"""Test handling of empty response content."""
mock_get_llm.return_value = mock_llm
mock_apply_template.return_value = mock_messages
mock_llm.invoke.return_value = MagicMock(content="")
state = PromptEnhancerState(prompt="Test prompt")
result = prompt_enhancer_node(state)
assert result == {"output": ""}
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
@patch(
"src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
{"prompt_enhancer": "basic"},
)
def test_only_whitespace_response(
self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
):
"""Test handling of response with only whitespace."""
mock_get_llm.return_value = mock_llm
mock_apply_template.return_value = mock_messages
mock_llm.invoke.return_value = MagicMock(content=" \n\n\t\t ")
state = PromptEnhancerState(prompt="Test prompt")
result = prompt_enhancer_node(state)
assert result == {"output": ""}
@@ -1,107 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from src.config.report_style import ReportStyle
from src.prompt_enhancer.graph.state import PromptEnhancerState
def test_prompt_enhancer_state_creation():
"""Test that PromptEnhancerState can be created with required fields."""
state = PromptEnhancerState(
prompt="Test prompt", context=None, report_style=None, output=None
)
assert state["prompt"] == "Test prompt"
assert state["context"] is None
assert state["report_style"] is None
assert state["output"] is None
def test_prompt_enhancer_state_with_all_fields():
"""Test PromptEnhancerState with all fields populated."""
state = PromptEnhancerState(
prompt="Write about AI",
context="Additional context about AI research",
report_style=ReportStyle.ACADEMIC,
output="Enhanced prompt about AI research",
)
assert state["prompt"] == "Write about AI"
assert state["context"] == "Additional context about AI research"
assert state["report_style"] == ReportStyle.ACADEMIC
assert state["output"] == "Enhanced prompt about AI research"
def test_prompt_enhancer_state_minimal():
"""Test PromptEnhancerState with only required prompt field."""
state = PromptEnhancerState(prompt="Minimal prompt")
assert state["prompt"] == "Minimal prompt"
# Optional fields should not be present if not specified
assert "context" not in state
assert "report_style" not in state
assert "output" not in state
def test_prompt_enhancer_state_with_different_report_styles():
"""Test PromptEnhancerState with different ReportStyle values."""
styles = [
ReportStyle.ACADEMIC,
ReportStyle.POPULAR_SCIENCE,
ReportStyle.NEWS,
ReportStyle.SOCIAL_MEDIA,
]
for style in styles:
state = PromptEnhancerState(prompt="Test prompt", report_style=style)
assert state["report_style"] == style
def test_prompt_enhancer_state_update():
"""Test updating PromptEnhancerState fields."""
state = PromptEnhancerState(prompt="Original prompt")
# Update with new fields
state.update(
{
"context": "New context",
"report_style": ReportStyle.NEWS,
"output": "Enhanced output",
}
)
assert state["prompt"] == "Original prompt"
assert state["context"] == "New context"
assert state["report_style"] == ReportStyle.NEWS
assert state["output"] == "Enhanced output"
def test_prompt_enhancer_state_get_method():
"""Test using get() method on PromptEnhancerState."""
state = PromptEnhancerState(prompt="Test prompt", report_style=ReportStyle.ACADEMIC)
# Test get with existing keys
assert state.get("prompt") == "Test prompt"
assert state.get("report_style") == ReportStyle.ACADEMIC
# Test get with non-existing keys
assert state.get("context") is None
assert state.get("output") is None
assert state.get("nonexistent", "default") == "default"
def test_prompt_enhancer_state_type_annotations():
"""Test that the state accepts correct types."""
# This test ensures the TypedDict structure is working correctly
state = PromptEnhancerState(
prompt="Test prompt",
context="Test context",
report_style=ReportStyle.POPULAR_SCIENCE,
output="Test output",
)
# Verify types
assert isinstance(state["prompt"], str)
assert isinstance(state["context"], str)
assert isinstance(state["report_style"], ReportStyle)
assert isinstance(state["output"], str)
-154
View File
@@ -1,154 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import MagicMock, patch
import pytest
from src.rag.dify import DifyProvider, parse_uri
# Dummy classes to mock dependencies
class DummyResource:
def __init__(self, uri, title="", description=""):
self.uri = uri
self.title = title
self.description = description
class DummyChunk:
def __init__(self, content, similarity):
self.content = content
self.similarity = similarity
class DummyDocument:
def __init__(self, id, title, chunks=None):
self.id = id
self.title = title
self.chunks = chunks or []
# Patch imports in dify.py to use dummy classes
@pytest.fixture(autouse=True)
def patch_imports(monkeypatch):
import src.rag.dify as dify
dify.Resource = DummyResource
dify.Chunk = DummyChunk
dify.Document = DummyDocument
yield
def test_parse_uri_valid():
uri = "rag://dataset/123#abc"
dataset_id, document_id = parse_uri(uri)
assert dataset_id == "123"
assert document_id == "abc"
def test_parse_uri_invalid():
with pytest.raises(ValueError):
parse_uri("http://dataset/123#abc")
def test_init_env_vars(monkeypatch):
monkeypatch.setenv("DIFY_API_URL", "http://api")
monkeypatch.setenv("DIFY_API_KEY", "key")
provider = DifyProvider()
assert provider.api_url == "http://api"
assert provider.api_key == "key"
def test_init_missing_env(monkeypatch):
monkeypatch.delenv("DIFY_API_URL", raising=False)
monkeypatch.setenv("DIFY_API_KEY", "key")
with pytest.raises(ValueError):
DifyProvider()
monkeypatch.setenv("DIFY_API_URL", "http://api")
monkeypatch.delenv("DIFY_API_KEY", raising=False)
with pytest.raises(ValueError):
DifyProvider()
@patch("src.rag.dify.requests.post")
def test_query_relevant_documents_success(mock_post, monkeypatch):
monkeypatch.setenv("DIFY_API_URL", "http://api")
monkeypatch.setenv("DIFY_API_KEY", "key")
provider = DifyProvider()
resource = DummyResource("rag://dataset/123#doc456")
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"records": [
{
"segment": {
"content": "chunk text",
"document": {
"id": "doc456",
"name": "Doc Title",
},
},
"score": 0.9,
}
]
}
mock_post.return_value = mock_response
docs = provider.query_relevant_documents("query", [resource])
assert len(docs) == 1
assert docs[0].id == "doc456"
assert docs[0].title == "Doc Title"
assert len(docs[0].chunks) == 1
assert docs[0].chunks[0].content == "chunk text"
assert docs[0].chunks[0].similarity == 0.9
@patch("src.rag.dify.requests.post")
def test_query_relevant_documents_error(mock_post, monkeypatch):
monkeypatch.setenv("DIFY_API_URL", "http://api")
monkeypatch.setenv("DIFY_API_KEY", "key")
provider = DifyProvider()
resource = DummyResource("rag://dataset/123#doc456")
mock_response = MagicMock()
mock_response.status_code = 400
mock_response.text = "error"
mock_post.return_value = mock_response
with pytest.raises(Exception):
provider.query_relevant_documents("query", [resource])
@patch("src.rag.dify.requests.get")
def test_list_resources_success(mock_get, monkeypatch):
monkeypatch.setenv("DIFY_API_URL", "http://api")
monkeypatch.setenv("DIFY_API_KEY", "key")
provider = DifyProvider()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"data": [
{"id": "123", "name": "Dataset1", "description": "desc1"},
{"id": "456", "name": "Dataset2", "description": "desc2"},
]
}
mock_get.return_value = mock_response
resources = provider.list_resources()
assert len(resources) == 2
assert resources[0].uri == "rag://dataset/123"
assert resources[0].title == "Dataset1"
assert resources[0].description == "desc1"
assert resources[1].uri == "rag://dataset/456"
assert resources[1].title == "Dataset2"
assert resources[1].description == "desc2"
@patch("src.rag.dify.requests.get")
def test_list_resources_error(mock_get, monkeypatch):
monkeypatch.setenv("DIFY_API_URL", "http://api")
monkeypatch.setenv("DIFY_API_KEY", "key")
provider = DifyProvider()
mock_response = MagicMock()
mock_response.status_code = 500
mock_response.text = "fail"
mock_get.return_value = mock_response
with pytest.raises(Exception):
provider.list_resources()
-930
View File
@@ -1,930 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Tests for Milvus RAG provider.
IMPORTANT NOTE: This test file creates temporary directories for testing examples
functionality. All temporary directories are automatically cleaned up using pytest
fixtures. When adding new tests that create temporary directories:
1. Use the provided fixtures (temp_examples_dir, temp_error_examples_dir, etc.)
2. Never create temporary directories without automatic cleanup
3. Follow the pattern: fixture -> use -> automatic cleanup
4. If you need a new directory pattern, create a corresponding fixture
This ensures tests don't leave behind temporary files that clutter the workspace.
"""
from __future__ import annotations
import shutil
import tempfile
from pathlib import Path
from types import SimpleNamespace
from uuid import uuid4
import pytest
import src.rag.milvus as milvus_mod
from src.rag.milvus import MilvusProvider
from src.rag.retriever import Resource
class DummyEmbedding:
def __init__(self, **kwargs):
self.kwargs = kwargs
def embed_query(self, text: str):
return [0.1, 0.2, 0.3]
def embed_documents(self, texts):
return [[0.1, 0.2, 0.3] for _ in texts]
@pytest.fixture(autouse=True)
def patch_embeddings(monkeypatch):
# Prevent network / external API usage during __init__
monkeypatch.setenv("MILVUS_EMBEDDING_PROVIDER", "openai")
monkeypatch.setenv("MILVUS_EMBEDDING_MODEL", "text-embedding-ada-002")
monkeypatch.setenv("MILVUS_COLLECTION", "documents")
monkeypatch.setenv("MILVUS_URI", "./milvus_demo.db") # default lite
monkeypatch.setattr(milvus_mod, "OpenAIEmbeddings", DummyEmbedding)
monkeypatch.setattr(milvus_mod, "DashscopeEmbeddings", DummyEmbedding)
yield
@pytest.fixture
def project_root():
# Mirror logic from implementation: current_file.parent.parent.parent
return Path(milvus_mod.__file__).parent.parent.parent
@pytest.fixture
def temp_examples_dir(project_root):
"""Create a temporary examples directory with automatic cleanup."""
# Create a unique temporary directory name
temp_dir_name = f"examples_test_{uuid4().hex}"
temp_dir_path = project_root / temp_dir_name
# Create the directory
temp_dir_path.mkdir(parents=True, exist_ok=True)
yield temp_dir_path
# Cleanup: remove the directory and all its contents
if temp_dir_path.exists():
shutil.rmtree(temp_dir_path)
@pytest.fixture
def temp_error_examples_dir(project_root):
"""Create a temporary error examples directory with automatic cleanup."""
# Create a unique temporary directory name for error tests
temp_dir_name = f"examples_error_{uuid4().hex}"
temp_dir_path = project_root / temp_dir_name
# Create the directory
temp_dir_path.mkdir(parents=True, exist_ok=True)
yield temp_dir_path
# Cleanup: remove the directory and all its contents
if temp_dir_path.exists():
shutil.rmtree(temp_dir_path)
@pytest.fixture
def temp_load_skip_examples_dir(project_root):
"""Create a temporary examples directory for load_skip tests with automatic cleanup."""
# Use the expected directory name for this test
temp_dir_name = "examples_test_load_skip"
temp_dir_path = project_root / temp_dir_name
# Create the directory if it doesn't exist
temp_dir_path.mkdir(parents=True, exist_ok=True)
yield temp_dir_path
# Cleanup: remove the directory and all its contents
if temp_dir_path.exists():
shutil.rmtree(temp_dir_path)
@pytest.fixture
def temp_single_chunk_examples_dir(project_root):
"""Create a temporary examples directory for single_chunk tests with automatic cleanup."""
# Use the expected directory name for this test
temp_dir_name = "examples_test_single_chunk"
temp_dir_path = project_root / temp_dir_name
# Create the directory if it doesn't exist
temp_dir_path.mkdir(parents=True, exist_ok=True)
yield temp_dir_path
# Cleanup: remove the directory and all its contents
if temp_dir_path.exists():
shutil.rmtree(temp_dir_path)
def _patch_init(monkeypatch):
"""Patch retriever initialization to use dummy embedding model."""
monkeypatch.setattr(
MilvusProvider,
"_init_embedding_model",
lambda self: setattr(self, "embedding_model", DummyEmbedding()),
)
def test_list_local_markdown_resources_missing_dir(project_root):
retriever = MilvusProvider()
# Point to a non-existent examples dir
retriever.examples_dir = f"missing_examples_{uuid4().hex}"
resources = retriever._list_local_markdown_resources()
assert resources == []
def test_list_local_markdown_resources_populated(temp_examples_dir):
retriever = MilvusProvider()
# Use the name of the temp directory for examples_dir
retriever.examples_dir = temp_examples_dir.name
# File with heading
(temp_examples_dir / "file1.md").write_text(
"# Title One\n\nContent body.", encoding="utf-8"
)
# File without heading -> fallback title
(temp_examples_dir / "file_two.md").write_text("No heading here.", encoding="utf-8")
# Non-markdown file should be ignored
(temp_examples_dir / "ignore.txt").write_text(
"Should not be picked up.", encoding="utf-8"
)
resources = retriever._list_local_markdown_resources()
# Order not guaranteed; sort by uri for assertions
resources.sort(key=lambda r: r.uri)
# Expect two resources
assert len(resources) == 2
uris = {r.uri for r in resources}
assert uris == {
f"milvus://{retriever.collection_name}/file1.md",
f"milvus://{retriever.collection_name}/file_two.md",
}
res_map = {r.uri: r for r in resources}
r1 = res_map[f"milvus://{retriever.collection_name}/file1.md"]
assert isinstance(r1, Resource)
assert r1.title == "Title One"
assert r1.description == "Local markdown example (not yet ingested)"
r2 = res_map[f"milvus://{retriever.collection_name}/file_two.md"]
# Fallback logic: filename -> "file_two" -> "file two" -> title case -> "File Two"
assert r2.title == "File Two"
assert r2.description == "Local markdown example (not yet ingested)"
def test_list_local_markdown_resources_read_error(monkeypatch, temp_error_examples_dir):
retriever = MilvusProvider()
# Use the name of the temp directory for examples_dir
retriever.examples_dir = temp_error_examples_dir.name
bad_file = temp_error_examples_dir / "bad.md"
good_file = temp_error_examples_dir / "good.md"
good_file.write_text("# Good Title\n\nBody.", encoding="utf-8")
bad_file.write_text("Broken", encoding="utf-8")
# Patch Path.read_text to raise for bad.md only
original_read_text = Path.read_text
def fake_read_text(self, *args, **kwargs):
if self == bad_file:
raise OSError("Cannot read file")
return original_read_text(self, *args, **kwargs)
monkeypatch.setattr(Path, "read_text", fake_read_text)
resources = retriever._list_local_markdown_resources()
# Only good.md should appear
assert len(resources) == 1
r = resources[0]
assert r.title == "Good Title"
assert r.uri == f"milvus://{retriever.collection_name}/good.md"
def test_create_collection_schema_fields(monkeypatch):
_patch_init(monkeypatch)
retriever = MilvusProvider()
schema = retriever._create_collection_schema()
field_names = {f.name for f in schema.fields}
# Core fields must be present
assert {
retriever.id_field,
retriever.vector_field,
retriever.content_field,
} <= field_names
# Dynamic field enabled for extra metadata
assert schema.enable_dynamic_field is True
def test_generate_doc_id_stable(monkeypatch, tmp_path):
_patch_init(monkeypatch)
retriever = MilvusProvider()
test_file = tmp_path / "example.md"
test_file.write_text("# Title\nBody", encoding="utf-8")
doc_id1 = retriever._generate_doc_id(test_file)
doc_id2 = retriever._generate_doc_id(test_file)
assert doc_id1 == doc_id2 # deterministic given unchanged file metadata
def test_extract_title_from_markdown(monkeypatch):
_patch_init(monkeypatch)
retriever = MilvusProvider()
heading = retriever._extract_title_from_markdown("# Heading\nBody", "ignored.md")
assert heading == "Heading"
fallback = retriever._extract_title_from_markdown("Body only", "my_file_name.md")
assert fallback == "My File Name"
def test_split_content_chunking(monkeypatch):
monkeypatch.setenv("MILVUS_CHUNK_SIZE", "40") # small to force split
_patch_init(monkeypatch)
retriever = MilvusProvider()
long_content = (
"Para1 text here.\n\nPara2 second block.\n\nPara3 final." # 3 paragraphs
)
chunks = retriever._split_content(long_content)
assert len(chunks) >= 2 # forced split
assert all(chunks) # no empty chunks
def test_get_embedding_invalid_inputs(monkeypatch):
_patch_init(monkeypatch)
retriever = MilvusProvider()
# Non-string value
with pytest.raises(RuntimeError):
retriever._get_embedding(123) # type: ignore[arg-type]
# Whitespace only
with pytest.raises(RuntimeError):
retriever._get_embedding(" ")
def test_list_resources_remote_success_and_dedup(monkeypatch):
monkeypatch.setenv("MILVUS_URI", "http://remote")
_patch_init(monkeypatch)
retriever = MilvusProvider()
class DocObj:
def __init__(self, content: str, meta: dict):
self.page_content = content
self.metadata = meta
calls = {"similarity_search": 0}
class RemoteClient:
def similarity_search(self, query, k, expr): # noqa: D401
calls["similarity_search"] += 1
# Two docs with identical id to test dedup
meta1 = {
retriever.id_field: "d1",
retriever.title_field: "T1",
retriever.url_field: "u1",
}
meta2 = {
retriever.id_field: "d1",
retriever.title_field: "T1_dup",
retriever.url_field: "u1",
}
return [DocObj("c1", meta1), DocObj("c1_dup", meta2)]
retriever.client = RemoteClient()
resources = retriever.list_resources("query text")
assert len(resources) == 1 # dedup applied
assert resources[0].title.startswith("T1")
assert calls["similarity_search"] == 1
def test_list_resources_lite_success(monkeypatch):
_patch_init(monkeypatch)
retriever = MilvusProvider()
class DummyMilvusLite:
def query(self, collection_name, filter, output_fields, limit): # noqa: D401
return [
{
retriever.id_field: "idA",
retriever.title_field: "Alpha",
retriever.url_field: "u://a",
},
{
retriever.id_field: "idB",
retriever.title_field: "Beta",
retriever.url_field: "u://b",
},
]
retriever.client = DummyMilvusLite()
resources = retriever.list_resources()
assert {r.title for r in resources} == {"Alpha", "Beta"}
def test_query_relevant_documents_lite_success(monkeypatch):
_patch_init(monkeypatch)
retriever = MilvusProvider()
# Provide deterministic embedding output
retriever.embedding_model.embed_query = lambda text: [0.1, 0.2, 0.3] # type: ignore
class DummyMilvusLite:
def search(
self, collection_name, data, anns_field, param, limit, output_fields
): # noqa: D401
# Simulate two result entries
return [
[
{
"entity": {
retriever.id_field: "d1",
retriever.content_field: "c1",
retriever.title_field: "T1",
retriever.url_field: "u1",
},
"distance": 0.9,
},
{
"entity": {
retriever.id_field: "d2",
retriever.content_field: "c2",
retriever.title_field: "T2",
retriever.url_field: "u2",
},
"distance": 0.8,
},
]
]
retriever.client = DummyMilvusLite()
# Filter for only d2 via resource list
docs = retriever.query_relevant_documents(
"question", resources=[Resource(uri="milvus://d2", title="", description="")]
)
assert len(docs) == 1 and docs[0].id == "d2" and docs[0].chunks[0].similarity == 0.8
def test_query_relevant_documents_remote_success(monkeypatch):
monkeypatch.setenv("MILVUS_URI", "http://remote")
_patch_init(monkeypatch)
retriever = MilvusProvider()
retriever.embedding_model.embed_query = lambda text: [0.1, 0.2, 0.3] # type: ignore
class DocObj:
def __init__(self, content: str, meta: dict): # noqa: D401
self.page_content = content
self.metadata = meta
class RemoteClient:
def similarity_search_with_score(self, query, k): # noqa: D401
return [
(
DocObj(
"c1",
{
retriever.id_field: "d1",
retriever.title_field: "T1",
retriever.url_field: "u1",
},
),
0.7,
),
(
DocObj(
"c2",
{
retriever.id_field: "d2",
retriever.title_field: "T2",
retriever.url_field: "u2",
},
),
0.6,
),
]
retriever.client = RemoteClient()
# Filter to only d1
docs = retriever.query_relevant_documents(
"q", resources=[Resource(uri="milvus://d1", title="", description="")]
)
assert len(docs) == 1 and docs[0].id == "d1" and docs[0].chunks[0].similarity == 0.7
def test_get_embedding_dimension_explicit(monkeypatch):
monkeypatch.setenv("MILVUS_EMBEDDING_DIM", "777")
_patch_init(monkeypatch)
retriever = MilvusProvider()
assert retriever.embedding_dim == 777
def test_get_embedding_dimension_unknown_model(monkeypatch):
monkeypatch.delenv("MILVUS_EMBEDDING_DIM", raising=False)
monkeypatch.setenv("MILVUS_EMBEDDING_MODEL", "unknown-model-x")
_patch_init(monkeypatch)
retriever = MilvusProvider()
# falls back to default 1536
assert retriever.embedding_dim == 1536
def test_is_milvus_lite_variants(monkeypatch):
_patch_init(monkeypatch)
monkeypatch.setenv("MILVUS_URI", "mydb.db")
assert MilvusProvider()._is_milvus_lite() is True
monkeypatch.setenv("MILVUS_URI", "relative_path_store")
assert MilvusProvider()._is_milvus_lite() is True
monkeypatch.setenv("MILVUS_URI", "http://host:19530")
assert MilvusProvider()._is_milvus_lite() is False
def test_create_collection_lite(monkeypatch):
_patch_init(monkeypatch)
retriever = MilvusProvider()
created: dict = {}
class DummyMilvusLite:
def list_collections(self): # noqa: D401
return [] # empty triggers creation
def create_collection(self, collection_name, schema, index_params): # noqa: D401
created["name"] = collection_name
created["schema"] = schema
created["index"] = index_params
retriever.client = DummyMilvusLite()
retriever._ensure_collection_exists()
assert created["name"] == retriever.collection_name
def test_ensure_collection_exists_remote(monkeypatch):
_patch_init(monkeypatch)
monkeypatch.setenv("MILVUS_URI", "http://remote:19530")
retriever = MilvusProvider()
# remote path, nothing thrown
retriever.client = SimpleNamespace()
retriever._ensure_collection_exists()
def test_get_existing_document_ids_lite(monkeypatch):
_patch_init(monkeypatch)
retriever = MilvusProvider()
class DummyMilvusLite:
def query(self, collection_name, filter, output_fields, limit): # noqa: D401
return [
{retriever.id_field: "a"},
{retriever.id_field: "b"},
{"other": "ignored"},
]
retriever.client = DummyMilvusLite()
assert retriever._get_existing_document_ids() == {"a", "b"}
def test_get_existing_document_ids_remote(monkeypatch):
_patch_init(monkeypatch)
monkeypatch.setenv("MILVUS_URI", "http://x")
retriever = MilvusProvider()
retriever.client = object()
assert retriever._get_existing_document_ids() == set()
def test_insert_document_chunk_lite_and_error(monkeypatch):
_patch_init(monkeypatch)
retriever = MilvusProvider()
captured = {}
class DummyMilvusLite:
def insert(self, collection_name, data): # noqa: D401
captured["data"] = data
retriever.client = DummyMilvusLite()
retriever._insert_document_chunk(
doc_id="id1", content="hello", title="T", url="u", metadata={"m": 1}
)
assert captured["data"][0][retriever.id_field] == "id1"
# error path: patch embedding to raise
def bad_embed(text): # noqa: D401
raise RuntimeError("boom")
retriever.embedding_model.embed_query = bad_embed # type: ignore[attr-defined]
with pytest.raises(RuntimeError):
retriever._insert_document_chunk(
doc_id="id2", content="err", title="T", url="u", metadata={}
)
def test_insert_document_chunk_remote(monkeypatch):
_patch_init(monkeypatch)
monkeypatch.setenv("MILVUS_URI", "http://remote")
retriever = MilvusProvider()
added = {}
class RemoteClient:
def add_texts(self, texts, metadatas): # noqa: D401
added["texts"] = texts
added["meta"] = metadatas
retriever.client = RemoteClient()
retriever._insert_document_chunk(
doc_id="idx", content="ct", title="Title", url="urlx", metadata={"k": 2}
)
assert added["meta"][0][retriever.id_field] == "idx"
def test_connect_lite_and_error(monkeypatch):
# patch MilvusClient to a dummy
class FakeMilvusClient:
def __init__(self, uri): # noqa: D401
self.uri = uri
def list_collections(self): # noqa: D401
return []
def create_collection(self, **kwargs): # noqa: D401
pass
monkeypatch.setattr(milvus_mod, "MilvusClient", FakeMilvusClient)
_patch_init(monkeypatch)
retriever = MilvusProvider()
retriever._connect()
assert isinstance(retriever.client, FakeMilvusClient)
# error path: patch MilvusClient to raise
class BadClient:
def __init__(self, uri): # noqa: D401
raise RuntimeError("fail connect")
monkeypatch.setattr(milvus_mod, "MilvusClient", BadClient)
retriever2 = MilvusProvider()
with pytest.raises(ConnectionError):
retriever2._connect()
def test_connect_remote(monkeypatch):
monkeypatch.setenv("MILVUS_URI", "http://remote")
_patch_init(monkeypatch)
created = {}
class FakeLangchainMilvus:
def __init__(self, **kwargs): # noqa: D401
created.update(kwargs)
monkeypatch.setattr(milvus_mod, "LangchainMilvus", FakeLangchainMilvus)
retriever = MilvusProvider()
retriever._connect()
assert created["collection_name"] == retriever.collection_name
def test_list_resources_remote_failure(monkeypatch):
monkeypatch.setenv("MILVUS_URI", "http://remote")
_patch_init(monkeypatch)
retriever = MilvusProvider()
# Provide minimal working local examples dir (none -> returns [])
monkeypatch.setattr(retriever, "_list_local_markdown_resources", lambda: [])
# patch client to raise inside similarity_search to trigger fallback path
class BadClient:
def similarity_search(self, *args, **kwargs): # noqa: D401
raise RuntimeError("fail")
retriever.client = BadClient()
# Should fallback to [] without raising
assert retriever.list_resources() == []
def test_list_local_markdown_resources_empty(monkeypatch):
_patch_init(monkeypatch)
retriever = MilvusProvider()
monkeypatch.setenv("MILVUS_EXAMPLES_DIR", "nonexistent_dir")
retriever.examples_dir = "nonexistent_dir"
assert retriever._list_local_markdown_resources() == []
def test_query_relevant_documents_error(monkeypatch):
_patch_init(monkeypatch)
retriever = MilvusProvider()
retriever.embedding_model.embed_query = lambda text: ( # type: ignore
_ for _ in ()
).throw(RuntimeError("embed fail"))
with pytest.raises(RuntimeError):
retriever.query_relevant_documents("q")
def test_create_collection_when_client_exists(monkeypatch):
_patch_init(monkeypatch)
retriever = MilvusProvider()
retriever.client = SimpleNamespace(closed=False)
# remote vs lite path difference handled by _is_milvus_lite
retriever.create_collection() # should no-op gracefully
def test_load_examples_force_reload(monkeypatch):
_patch_init(monkeypatch)
retriever = MilvusProvider()
retriever.client = SimpleNamespace()
called = {"clear": 0, "load": 0}
monkeypatch.setattr(
retriever, "_clear_example_documents", lambda: called.__setitem__("clear", 1)
)
monkeypatch.setattr(
retriever, "_load_example_files", lambda: called.__setitem__("load", 1)
)
retriever.load_examples(force_reload=True)
assert called == {"clear": 1, "load": 1}
def test_clear_example_documents_remote(monkeypatch):
monkeypatch.setenv("MILVUS_URI", "http://remote")
_patch_init(monkeypatch)
retriever = MilvusProvider()
retriever.client = SimpleNamespace()
# Should just log and not raise
retriever._clear_example_documents()
def test_clear_example_documents_lite(monkeypatch):
_patch_init(monkeypatch)
retriever = MilvusProvider()
deleted = {}
class DummyMilvusLite:
def query(self, **kwargs): # noqa: D401
return [
{retriever.id_field: "ex1"},
{retriever.id_field: "ex2"},
]
def delete(self, collection_name, ids): # noqa: D401
deleted["ids"] = ids
retriever.client = DummyMilvusLite()
retriever._clear_example_documents()
assert deleted["ids"] == ["ex1", "ex2"]
def test_get_loaded_examples_lite_and_error(monkeypatch):
_patch_init(monkeypatch)
retriever = MilvusProvider()
class DummyMilvusLite:
def query(self, **kwargs): # noqa: D401
return [
{
retriever.id_field: "id1",
retriever.title_field: "T1",
retriever.url_field: "u1",
"file": "f1",
}
]
retriever.client = DummyMilvusLite()
loaded = retriever.get_loaded_examples()
assert loaded[0]["id"] == "id1"
# error path
class BadClient:
def query(self, **kwargs): # noqa: D401
raise RuntimeError("fail")
retriever.client = BadClient()
assert retriever.get_loaded_examples() == []
def test_get_loaded_examples_remote(monkeypatch):
monkeypatch.setenv("MILVUS_URI", "http://remote")
_patch_init(monkeypatch)
retriever = MilvusProvider()
retriever.client = SimpleNamespace()
assert retriever.get_loaded_examples() == []
def test_close_lite_and_remote(monkeypatch):
_patch_init(monkeypatch)
retriever = MilvusProvider()
closed = {"c": 0}
class DummyMilvusLite:
def close(self): # noqa: D401
closed["c"] += 1
def list_collections(self): # noqa: D401
return []
def create_collection(self, **kwargs): # noqa: D401
pass
retriever.client = DummyMilvusLite()
retriever.close()
assert closed["c"] == 1
# remote path: no close attr usage expected
monkeypatch.setenv("MILVUS_URI", "http://remote")
retriever2 = MilvusProvider()
retriever2.client = SimpleNamespace()
retriever2.close() # should not raise
def test_get_embedding_invalid_output(monkeypatch):
_patch_init(monkeypatch)
retriever = MilvusProvider()
# patch embedding model to return invalid output (empty list)
retriever.embedding_model.embed_query = lambda text: [] # type: ignore
with pytest.raises(RuntimeError):
retriever._get_embedding("text")
def test_dashscope_embeddings_empty_inputs_short_circuit(monkeypatch):
# Use real class but swap _client to ensure create is never called
emb = milvus_mod.DashscopeEmbeddings(model="m")
class FailingClient:
class _Emb:
def create(self, *a, **k):
raise AssertionError("Should not be called for empty input")
embeddings = _Emb()
emb._client = FailingClient() # type: ignore
assert emb.embed_documents([]) == []
# Tests for _init_embedding_model provider selection logic
def test_init_embedding_model_openai(monkeypatch):
monkeypatch.setenv("MILVUS_EMBEDDING_PROVIDER", "openai")
monkeypatch.setenv("MILVUS_EMBEDDING_MODEL", "text-embedding-ada-002")
captured = {}
class CapturingOpenAI:
def __init__(self, **kwargs):
captured.update(kwargs)
monkeypatch.setattr(milvus_mod, "OpenAIEmbeddings", CapturingOpenAI)
prov = MilvusProvider()
assert isinstance(prov.embedding_model, CapturingOpenAI)
# kwargs forwarded
assert captured["model"] == "text-embedding-ada-002"
assert captured["encoding_format"] == "float"
assert captured["dimensions"] == prov.embedding_dim
def test_init_embedding_model_dashscope(monkeypatch):
monkeypatch.setenv("MILVUS_EMBEDDING_PROVIDER", "dashscope")
monkeypatch.setenv("MILVUS_EMBEDDING_MODEL", "text-embedding-ada-002")
captured = {}
class CapturingDashscope:
def __init__(self, **kwargs):
captured.update(kwargs)
monkeypatch.setattr(milvus_mod, "DashscopeEmbeddings", CapturingDashscope)
prov = MilvusProvider()
assert isinstance(prov.embedding_model, CapturingDashscope)
assert captured["model"] == "text-embedding-ada-002"
assert captured["encoding_format"] == "float"
assert captured["dimensions"] == prov.embedding_dim
def test_init_embedding_model_invalid_provider(monkeypatch):
monkeypatch.setenv("MILVUS_EMBEDDING_PROVIDER", "not_a_provider")
with pytest.raises(ValueError):
MilvusProvider()
def test_load_example_files_directory_missing(monkeypatch):
_patch_init(monkeypatch)
missing_dir = "examples_dir_does_not_exist_xyz"
monkeypatch.setenv("MILVUS_EXAMPLES_DIR", missing_dir)
retriever = MilvusProvider()
retriever.examples_dir = missing_dir
called = {"insert": 0}
monkeypatch.setattr(
retriever,
"_insert_document_chunk",
lambda **kwargs: (_ for _ in ()).throw(AssertionError("should not insert")),
)
retriever._load_example_files()
assert called["insert"] == 0 # sanity (no insertion attempted)
def test_load_example_files_loads_and_skips_existing(
monkeypatch, temp_load_skip_examples_dir
):
_patch_init(monkeypatch)
examples_dir_name = temp_load_skip_examples_dir.name
file1 = temp_load_skip_examples_dir / "file1.md"
file2 = temp_load_skip_examples_dir / "file2.md"
file1.write_text("# Title One\nContent A", encoding="utf-8")
file2.write_text("# Title Two\nContent B", encoding="utf-8")
monkeypatch.setenv("MILVUS_EXAMPLES_DIR", examples_dir_name)
retriever = MilvusProvider()
retriever.examples_dir = examples_dir_name
# Compute doc ids using real method
doc_id_file1 = retriever._generate_doc_id(file1)
doc_id_file2 = retriever._generate_doc_id(file2)
# Existing docs contains file1 so it is skipped
monkeypatch.setattr(retriever, "_get_existing_document_ids", lambda: {doc_id_file1})
# Force two chunks for any file to test suffix logic
monkeypatch.setattr(retriever, "_split_content", lambda content: ["part1", "part2"])
calls = []
def record_insert(doc_id, content, title, url, metadata):
calls.append(
{
"doc_id": doc_id,
"content": content,
"title": title,
"url": url,
"metadata": metadata,
}
)
monkeypatch.setattr(retriever, "_insert_document_chunk", record_insert)
retriever._load_example_files()
# Only file2 processed -> two chunk inserts
assert len(calls) == 2
expected_ids = {f"{doc_id_file2}_chunk_0", f"{doc_id_file2}_chunk_1"}
assert {c["doc_id"] for c in calls} == expected_ids
assert all(c["metadata"]["file"] == "file2.md" for c in calls)
assert all(c["metadata"]["source"] == "examples" for c in calls)
assert all(c["title"] == "Title Two" for c in calls)
def test_load_example_files_single_chunk_no_suffix(
monkeypatch, temp_single_chunk_examples_dir
):
_patch_init(monkeypatch)
examples_dir_name = temp_single_chunk_examples_dir.name
file_single = temp_single_chunk_examples_dir / "single.md"
file_single.write_text(
"# Single Title\nOnly one small paragraph.", encoding="utf-8"
)
monkeypatch.setenv("MILVUS_EXAMPLES_DIR", examples_dir_name)
retriever = MilvusProvider()
retriever.examples_dir = examples_dir_name
base_doc_id = retriever._generate_doc_id(file_single)
monkeypatch.setattr(retriever, "_get_existing_document_ids", lambda: set())
monkeypatch.setattr(retriever, "_split_content", lambda content: ["onlychunk"])
captured = {}
def capture(doc_id, content, title, url, metadata):
captured["doc_id"] = doc_id
captured["title"] = title
captured["metadata"] = metadata
monkeypatch.setattr(retriever, "_insert_document_chunk", capture)
retriever._load_example_files()
assert captured["doc_id"] == base_doc_id # no _chunk_ suffix
assert captured["title"] == "Single Title"
assert captured["metadata"]["file"] == "single.md"
assert captured["metadata"]["source"] == "examples"
# Clean up test database file after tests
import atexit
def cleanup_test_database():
"""Clean up milvus_demo.db file created during testing."""
import os
from pathlib import Path
# Skip cleanup if disabled
if os.getenv("DISABLE_TEST_CLEANUP", "false").lower() == "true":
return
db_file = Path.cwd() / "milvus_demo.db"
if db_file.exists():
try:
db_file.unlink()
print("🧹 Cleaned up milvus_demo.db")
except Exception:
pass # Silently ignore cleanup errors
# Register cleanup to run when Python exits
atexit.register(cleanup_test_database)
-333
View File
@@ -1,333 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from __future__ import annotations
import shutil
from pathlib import Path
from uuid import uuid4
import pytest
import src.rag.qdrant as qdrant_mod
from src.rag.qdrant import QdrantProvider
class DummyEmbedding:
def __init__(self, **kwargs):
self.kwargs = kwargs
def embed_query(self, text: str):
return [0.1] * 1536
def embed_documents(self, texts):
return [[0.1] * 1536 for _ in texts]
@pytest.fixture(autouse=True)
def patch_embeddings(monkeypatch):
monkeypatch.setenv("QDRANT_EMBEDDING_PROVIDER", "openai")
monkeypatch.setenv("QDRANT_EMBEDDING_MODEL", "text-embedding-ada-002")
monkeypatch.setenv("QDRANT_COLLECTION", "documents")
monkeypatch.setenv("QDRANT_LOCATION", ":memory:")
monkeypatch.setattr(qdrant_mod, "OpenAIEmbeddings", DummyEmbedding)
monkeypatch.setattr(qdrant_mod, "DashscopeEmbeddings", DummyEmbedding)
yield
@pytest.fixture
def project_root():
return Path(qdrant_mod.__file__).parent.parent.parent
@pytest.fixture
def temp_examples_dir(project_root):
temp_dir_name = f"examples_test_{uuid4().hex}"
temp_dir_path = project_root / temp_dir_name
temp_dir_path.mkdir(parents=True, exist_ok=True)
yield temp_dir_path
if temp_dir_path.exists():
shutil.rmtree(temp_dir_path)
@pytest.fixture
def temp_error_examples_dir(project_root):
temp_dir_name = f"examples_error_{uuid4().hex}"
temp_dir_path = project_root / temp_dir_name
temp_dir_path.mkdir(parents=True, exist_ok=True)
yield temp_dir_path
if temp_dir_path.exists():
shutil.rmtree(temp_dir_path)
@pytest.fixture
def temp_load_skip_examples_dir(project_root):
temp_dir_name = f"examples_load_skip_{uuid4().hex}"
temp_dir_path = project_root / temp_dir_name
temp_dir_path.mkdir(parents=True, exist_ok=True)
yield temp_dir_path
if temp_dir_path.exists():
shutil.rmtree(temp_dir_path)
def test_init_openai_provider(monkeypatch):
monkeypatch.setenv("QDRANT_EMBEDDING_PROVIDER", "openai")
provider = QdrantProvider()
assert provider.embedding_provider == "openai"
assert isinstance(provider.embedding_model, DummyEmbedding)
def test_init_dashscope_provider(monkeypatch):
monkeypatch.setenv("QDRANT_EMBEDDING_PROVIDER", "dashscope")
provider = QdrantProvider()
assert provider.embedding_provider == "dashscope"
assert isinstance(provider.embedding_model, DummyEmbedding)
def test_init_invalid_provider(monkeypatch):
monkeypatch.setenv("QDRANT_EMBEDDING_PROVIDER", "invalid_provider")
with pytest.raises(ValueError, match="Unsupported embedding provider"):
QdrantProvider()
def test_get_embedding_dimension_explicit(monkeypatch):
monkeypatch.setenv("QDRANT_EMBEDDING_DIM", "2048")
provider = QdrantProvider()
assert provider.embedding_dim == 2048
def test_get_embedding_dimension_default(monkeypatch):
monkeypatch.delenv("QDRANT_EMBEDDING_DIM", raising=False)
monkeypatch.setenv("QDRANT_EMBEDDING_MODEL", "text-embedding-ada-002")
provider = QdrantProvider()
assert provider.embedding_dim == 1536
def test_get_embedding_dimension_unknown_model(monkeypatch):
monkeypatch.delenv("QDRANT_EMBEDDING_DIM", raising=False)
monkeypatch.setenv("QDRANT_EMBEDDING_MODEL", "unknown-model")
provider = QdrantProvider()
assert provider.embedding_dim == 1536
def test_connect_memory_mode(monkeypatch):
monkeypatch.setenv("QDRANT_LOCATION", ":memory:")
provider = QdrantProvider()
provider._connect()
assert provider.client is not None
def test_create_collection(monkeypatch):
provider = QdrantProvider()
provider.create_collection()
assert provider.client is not None
def test_extract_title_from_markdown():
provider = QdrantProvider()
content = "# Test Title\n\nSome content"
title = provider._extract_title_from_markdown(content, "test.md")
assert title == "Test Title"
def test_extract_title_fallback():
provider = QdrantProvider()
content = "No title here"
title = provider._extract_title_from_markdown(content, "test_file.md")
assert title == "Test File"
def test_split_content_short():
provider = QdrantProvider()
content = "Short content"
chunks = provider._split_content(content)
assert len(chunks) == 1
assert chunks[0] == content
def test_split_content_long(monkeypatch):
monkeypatch.setenv("QDRANT_CHUNK_SIZE", "20")
provider = QdrantProvider()
content = "Paragraph one here\n\nParagraph two here\n\nParagraph three here\n\nParagraph four here"
chunks = provider._split_content(content)
assert len(chunks) > 1
def test_string_to_uuid():
provider = QdrantProvider()
uuid1 = provider._string_to_uuid("test")
uuid2 = provider._string_to_uuid("test")
assert uuid1 == uuid2
def test_get_embedding():
provider = QdrantProvider()
embedding = provider._get_embedding("test text")
assert len(embedding) == 1536
assert all(isinstance(x, float) for x in embedding)
def test_load_examples_no_directory(monkeypatch, project_root):
monkeypatch.setenv("QDRANT_EXAMPLES_DIR", "nonexistent_dir")
provider = QdrantProvider()
provider.load_examples()
def test_load_examples_empty_directory(monkeypatch, temp_examples_dir):
monkeypatch.setenv("QDRANT_EXAMPLES_DIR", temp_examples_dir.name)
provider = QdrantProvider()
provider.load_examples()
def test_load_examples_with_files(monkeypatch, temp_examples_dir):
monkeypatch.setenv("QDRANT_EXAMPLES_DIR", temp_examples_dir.name)
md_file = temp_examples_dir / "test.md"
md_file.write_text("# Test\n\nContent", encoding="utf-8")
provider = QdrantProvider()
provider.load_examples()
loaded = provider.get_loaded_examples()
assert len(loaded) == 1
assert loaded[0]["title"] == "Test"
def test_load_examples_skip_existing(monkeypatch, temp_load_skip_examples_dir):
monkeypatch.setenv("QDRANT_EXAMPLES_DIR", temp_load_skip_examples_dir.name)
md_file = temp_load_skip_examples_dir / "test.md"
md_file.write_text("# Test\n\nContent", encoding="utf-8")
provider = QdrantProvider()
provider.load_examples()
provider.load_examples()
loaded = provider.get_loaded_examples()
assert len(loaded) == 1
def test_load_examples_force_reload(monkeypatch, temp_examples_dir):
monkeypatch.setenv("QDRANT_EXAMPLES_DIR", temp_examples_dir.name)
md_file = temp_examples_dir / "test.md"
md_file.write_text("# Test\n\nContent", encoding="utf-8")
provider = QdrantProvider()
provider.load_examples()
provider.load_examples(force_reload=True)
loaded = provider.get_loaded_examples()
assert len(loaded) == 1
def test_load_examples_error_handling(monkeypatch, temp_error_examples_dir):
monkeypatch.setenv("QDRANT_EXAMPLES_DIR", temp_error_examples_dir.name)
good_file = temp_error_examples_dir / "good.md"
good_file.write_text("# Good\n\nContent", encoding="utf-8")
bad_file = temp_error_examples_dir / "bad.md"
bad_file.write_text("# Bad\n\n", encoding="utf-8")
provider = QdrantProvider()
provider.load_examples()
loaded = provider.get_loaded_examples()
assert len(loaded) >= 1
def test_list_resources_no_query(monkeypatch, temp_examples_dir):
monkeypatch.setenv("QDRANT_EXAMPLES_DIR", temp_examples_dir.name)
md_file = temp_examples_dir / "test.md"
md_file.write_text("# Test\n\nContent", encoding="utf-8")
provider = QdrantProvider()
provider.load_examples()
resources = provider.list_resources()
assert len(resources) >= 1
def test_list_resources_with_query(monkeypatch, temp_examples_dir):
monkeypatch.setenv("QDRANT_EXAMPLES_DIR", temp_examples_dir.name)
md_file = temp_examples_dir / "test.md"
md_file.write_text("# Test\n\nContent", encoding="utf-8")
provider = QdrantProvider()
provider.load_examples()
resources = provider.list_resources(query="test")
assert isinstance(resources, list)
def test_query_relevant_documents(monkeypatch, temp_examples_dir):
monkeypatch.setenv("QDRANT_EXAMPLES_DIR", temp_examples_dir.name)
md_file = temp_examples_dir / "test.md"
md_file.write_text("# Test\n\nContent about testing", encoding="utf-8")
provider = QdrantProvider()
provider.load_examples()
documents = provider.query_relevant_documents("testing")
assert isinstance(documents, list)
def test_query_relevant_documents_with_resources(monkeypatch, temp_examples_dir):
monkeypatch.setenv("QDRANT_EXAMPLES_DIR", temp_examples_dir.name)
md_file = temp_examples_dir / "test.md"
md_file.write_text("# Test\n\nContent", encoding="utf-8")
provider = QdrantProvider()
provider.load_examples()
resources = provider.list_resources()
documents = provider.query_relevant_documents("test", resources=resources)
assert isinstance(documents, list)
def test_close():
provider = QdrantProvider()
provider._connect()
provider.close()
assert provider.client is None
def test_del():
provider = QdrantProvider()
provider._connect()
del provider
def test_top_k_configuration(monkeypatch):
monkeypatch.setenv("QDRANT_TOP_K", "20")
provider = QdrantProvider()
assert provider.top_k == 20
def test_top_k_invalid(monkeypatch):
monkeypatch.setenv("QDRANT_TOP_K", "invalid")
provider = QdrantProvider()
assert provider.top_k == 10
def test_chunk_size_configuration(monkeypatch):
monkeypatch.setenv("QDRANT_CHUNK_SIZE", "5000")
provider = QdrantProvider()
assert provider.chunk_size == 5000
def test_collection_name_configuration(monkeypatch):
monkeypatch.setenv("QDRANT_COLLECTION", "custom_collection")
provider = QdrantProvider()
assert provider.collection_name == "custom_collection"
def test_auto_load_examples_configuration(monkeypatch):
monkeypatch.setenv("QDRANT_AUTO_LOAD_EXAMPLES", "false")
provider = QdrantProvider()
assert provider.auto_load_examples is False
-165
View File
@@ -1,165 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import MagicMock, patch
import pytest
from src.rag.ragflow import RAGFlowProvider, parse_uri
# Dummy classes to mock dependencies
class DummyResource:
def __init__(self, uri, title="", description=""):
self.uri = uri
self.title = title
self.description = description
class DummyChunk:
def __init__(self, content, similarity):
self.content = content
self.similarity = similarity
class DummyDocument:
def __init__(self, id, title, chunks=None):
self.id = id
self.title = title
self.chunks = chunks or []
# Patch imports in ragflow.py to use dummy classes
@pytest.fixture(autouse=True)
def patch_imports(monkeypatch):
import src.rag.ragflow as ragflow
ragflow.Resource = DummyResource
ragflow.Chunk = DummyChunk
ragflow.Document = DummyDocument
yield
def test_parse_uri_valid():
uri = "rag://dataset/123#abc"
dataset_id, document_id = parse_uri(uri)
assert dataset_id == "123"
assert document_id == "abc"
def test_parse_uri_invalid():
with pytest.raises(ValueError):
parse_uri("http://dataset/123#abc")
def test_init_env_vars(monkeypatch):
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
monkeypatch.delenv("RAGFLOW_PAGE_SIZE", raising=False)
provider = RAGFlowProvider()
assert provider.api_url == "http://api"
assert provider.api_key == "key"
assert provider.page_size == 10
def test_init_page_size(monkeypatch):
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
monkeypatch.setenv("RAGFLOW_PAGE_SIZE", "5")
provider = RAGFlowProvider()
assert provider.page_size == 5
def test_init_cross_language(monkeypatch):
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
monkeypatch.setenv("RAGFLOW_CROSS_LANGUAGES", "lang1,lang2")
provider = RAGFlowProvider()
assert provider.cross_languages == ["lang1", "lang2"]
def test_init_missing_env(monkeypatch):
monkeypatch.delenv("RAGFLOW_API_URL", raising=False)
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
with pytest.raises(ValueError):
RAGFlowProvider()
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
monkeypatch.delenv("RAGFLOW_API_KEY", raising=False)
with pytest.raises(ValueError):
RAGFlowProvider()
@patch("src.rag.ragflow.requests.post")
def test_query_relevant_documents_success(mock_post, monkeypatch):
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
provider = RAGFlowProvider()
resource = DummyResource("rag://dataset/123#doc456")
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"data": {
"doc_aggs": [{"doc_id": "doc456", "doc_name": "Doc Title"}],
"chunks": [
{"document_id": "doc456", "content": "chunk text", "similarity": 0.9}
],
}
}
mock_post.return_value = mock_response
docs = provider.query_relevant_documents("query", [resource])
assert len(docs) == 1
assert docs[0].id == "doc456"
assert docs[0].title == "Doc Title"
assert len(docs[0].chunks) == 1
assert docs[0].chunks[0].content == "chunk text"
assert docs[0].chunks[0].similarity == 0.9
@patch("src.rag.ragflow.requests.post")
def test_query_relevant_documents_error(mock_post, monkeypatch):
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
provider = RAGFlowProvider()
mock_response = MagicMock()
mock_response.status_code = 400
mock_response.text = "error"
mock_post.return_value = mock_response
with pytest.raises(Exception):
provider.query_relevant_documents("query", [])
@patch("src.rag.ragflow.requests.get")
def test_list_resources_success(mock_get, monkeypatch):
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
provider = RAGFlowProvider()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"data": [
{"id": "123", "name": "Dataset1", "description": "desc1"},
{"id": "456", "name": "Dataset2", "description": "desc2"},
]
}
mock_get.return_value = mock_response
resources = provider.list_resources()
assert len(resources) == 2
assert resources[0].uri == "rag://dataset/123"
assert resources[0].title == "Dataset1"
assert resources[0].description == "desc1"
assert resources[1].uri == "rag://dataset/456"
assert resources[1].title == "Dataset2"
assert resources[1].description == "desc2"
@patch("src.rag.ragflow.requests.get")
def test_list_resources_error(mock_get, monkeypatch):
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
provider = RAGFlowProvider()
mock_response = MagicMock()
mock_response.status_code = 500
mock_response.text = "fail"
mock_get.return_value = mock_response
with pytest.raises(Exception):
provider.list_resources()
-114
View File
@@ -1,114 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import pytest
from src.rag.retriever import Chunk, Document, Resource, Retriever
def test_chunk_init():
chunk = Chunk(content="test content", similarity=0.9)
assert chunk.content == "test content"
assert chunk.similarity == 0.9
def test_document_init_and_to_dict():
chunk1 = Chunk(content="chunk1", similarity=0.8)
chunk2 = Chunk(content="chunk2", similarity=0.7)
doc = Document(
id="doc1", url="http://example.com", title="Title", chunks=[chunk1, chunk2]
)
assert doc.id == "doc1"
assert doc.url == "http://example.com"
assert doc.title == "Title"
assert doc.chunks == [chunk1, chunk2]
d = doc.to_dict()
assert d["id"] == "doc1"
assert d["content"] == "chunk1\n\nchunk2"
assert d["url"] == "http://example.com"
assert d["title"] == "Title"
def test_document_to_dict_optional_fields():
chunk = Chunk(content="only chunk", similarity=1.0)
doc = Document(id="doc2", chunks=[chunk])
d = doc.to_dict()
assert d["id"] == "doc2"
assert d["content"] == "only chunk"
assert "url" not in d
assert "title" not in d
def test_resource_model():
resource = Resource(uri="uri1", title="Resource Title")
assert resource.uri == "uri1"
assert resource.title == "Resource Title"
assert resource.description == ""
def test_resource_model_with_description():
resource = Resource(uri="uri2", title="Resource2", description="desc")
assert resource.description == "desc"
def test_retriever_abstract_methods():
class DummyRetriever(Retriever):
def list_resources(self, query=None):
return [Resource(uri="uri", title="title")]
async def list_resources_async(self, query=None):
return [Resource(uri="uri", title="title")]
def query_relevant_documents(self, query, resources=[]):
return [Document(id="id", chunks=[])]
async def query_relevant_documents_async(self, query, resources=[]):
return [Document(id="id", chunks=[])]
retriever = DummyRetriever()
# Test synchronous methods
resources = retriever.list_resources()
assert isinstance(resources, list)
assert isinstance(resources[0], Resource)
assert resources[0].uri == "uri"
docs = retriever.query_relevant_documents("query", resources)
assert isinstance(docs, list)
assert isinstance(docs[0], Document)
assert docs[0].id == "id"
def test_retriever_cannot_instantiate():
with pytest.raises(TypeError):
Retriever()
@pytest.mark.asyncio
async def test_retriever_async_methods():
"""Test that async methods work correctly in DummyRetriever."""
class DummyRetriever(Retriever):
def list_resources(self, query=None):
return [Resource(uri="uri", title="title")]
async def list_resources_async(self, query=None):
return [Resource(uri="uri_async", title="title_async")]
def query_relevant_documents(self, query, resources=[]):
return [Document(id="id", chunks=[])]
async def query_relevant_documents_async(self, query, resources=[]):
return [Document(id="id_async", chunks=[])]
retriever = DummyRetriever()
# Test async list_resources
resources = await retriever.list_resources_async()
assert isinstance(resources, list)
assert isinstance(resources[0], Resource)
assert resources[0].uri == "uri_async"
# Test async query_relevant_documents
docs = await retriever.query_relevant_documents_async("query", resources)
assert isinstance(docs, list)
assert isinstance(docs[0], Document)
assert docs[0].id == "id_async"
@@ -1,540 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import hashlib
import hmac
import json
import os
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from src.rag.vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider, parse_uri
# Dummy classes to mock dependencies
class MockResource:
def __init__(self, uri, title="", description=""):
self.uri = uri
self.title = title
self.description = description
class MockChunk:
def __init__(self, content, similarity):
self.content = content
self.similarity = similarity
class MockDocument:
def __init__(self, id, title, chunks=None):
self.id = id
self.title = title
self.chunks = chunks or []
# Patch the imports to use mock classes
@pytest.fixture(autouse=True)
def patch_imports():
with (
patch("src.rag.vikingdb_knowledge_base.Resource", MockResource),
patch("src.rag.vikingdb_knowledge_base.Chunk", MockChunk),
patch("src.rag.vikingdb_knowledge_base.Document", MockDocument),
):
yield
@pytest.fixture
def env_vars():
"""Fixture to set up environment variables"""
with patch.dict(
os.environ,
{
"VIKINGDB_KNOWLEDGE_BASE_API_URL": "api-test.example.com",
"VIKINGDB_KNOWLEDGE_BASE_API_AK": "test_ak",
"VIKINGDB_KNOWLEDGE_BASE_API_SK": "test_sk",
"VIKINGDB_KNOWLEDGE_BASE_RETRIEVAL_SIZE": "10",
"VIKINGDB_KNOWLEDGE_BASE_REGION": "cn-north-1",
},
):
yield
class TestParseUri:
def test_parse_uri_valid_with_fragment(self):
"""Test parsing valid URI with fragment"""
uri = "rag://dataset/123#doc456"
resource_id, document_id = parse_uri(uri)
assert resource_id == "123"
assert document_id == "doc456"
def test_parse_uri_valid_without_fragment(self):
"""Test parsing valid URI without fragment"""
uri = "rag://dataset/123"
resource_id, document_id = parse_uri(uri)
assert resource_id == "123"
assert document_id == ""
def test_parse_uri_invalid_scheme(self):
"""Test parsing URI with invalid scheme"""
with pytest.raises(ValueError, match="Invalid URI"):
parse_uri("http://dataset/123#abc")
def test_parse_uri_malformed(self):
"""Test parsing malformed URI"""
with pytest.raises(ValueError, match="Invalid URI"):
parse_uri("invalid_uri")
class TestVikingDBKnowledgeBaseProviderInit:
def test_init_success_with_all_env_vars(self, env_vars):
"""Test successful initialization with all environment variables"""
provider = VikingDBKnowledgeBaseProvider()
assert provider.api_url == "api-test.example.com"
assert provider.api_ak == "test_ak"
assert provider.api_sk == "test_sk"
assert provider.retrieval_size == 10
assert provider.region == "cn-north-1"
assert provider.service == "air"
def test_init_success_without_retrieval_size(self):
"""Test initialization without VIKINGDB_KNOWLEDGE_BASE_RETRIEVAL_SIZE (should use default)"""
with patch.dict(
os.environ,
{
"VIKINGDB_KNOWLEDGE_BASE_API_URL": "api-test.example.com",
"VIKINGDB_KNOWLEDGE_BASE_API_AK": "test_ak",
"VIKINGDB_KNOWLEDGE_BASE_API_SK": "test_sk",
},
clear=True,
):
provider = VikingDBKnowledgeBaseProvider()
assert provider.retrieval_size == 10
def test_init_custom_retrieval_size(self):
"""Test initialization with custom retrieval size"""
with patch.dict(
os.environ,
{
"VIKINGDB_KNOWLEDGE_BASE_API_URL": "api-test.example.com",
"VIKINGDB_KNOWLEDGE_BASE_API_AK": "test_ak",
"VIKINGDB_KNOWLEDGE_BASE_API_SK": "test_sk",
"VIKINGDB_KNOWLEDGE_BASE_RETRIEVAL_SIZE": "5",
},
):
provider = VikingDBKnowledgeBaseProvider()
assert provider.retrieval_size == 5
def test_init_custom_region(self):
"""Test initialization with custom region"""
with patch.dict(
os.environ,
{
"VIKINGDB_KNOWLEDGE_BASE_API_URL": "api-test.example.com",
"VIKINGDB_KNOWLEDGE_BASE_API_AK": "test_ak",
"VIKINGDB_KNOWLEDGE_BASE_API_SK": "test_sk",
"VIKINGDB_KNOWLEDGE_BASE_REGION": "us-east-1",
},
):
provider = VikingDBKnowledgeBaseProvider()
assert provider.region == "us-east-1"
def test_init_missing_api_url(self):
"""Test initialization fails when API URL is missing"""
with patch.dict(
os.environ,
{
"VIKINGDB_KNOWLEDGE_BASE_API_AK": "test_ak",
"VIKINGDB_KNOWLEDGE_BASE_API_SK": "test_sk",
},
clear=True,
):
with pytest.raises(
ValueError, match="VIKINGDB_KNOWLEDGE_BASE_API_URL is not set"
):
VikingDBKnowledgeBaseProvider()
def test_init_missing_api_ak(self):
"""Test initialization fails when API AK is missing"""
with patch.dict(
os.environ,
{
"VIKINGDB_KNOWLEDGE_BASE_API_URL": "api-test.example.com",
"VIKINGDB_KNOWLEDGE_BASE_API_SK": "test_sk",
},
clear=True,
):
with pytest.raises(
ValueError, match="VIKINGDB_KNOWLEDGE_BASE_API_AK is not set"
):
VikingDBKnowledgeBaseProvider()
def test_init_missing_api_sk(self):
"""Test initialization fails when API SK is missing"""
with patch.dict(
os.environ,
{
"VIKINGDB_KNOWLEDGE_BASE_API_URL": "api-test.example.com",
"VIKINGDB_KNOWLEDGE_BASE_API_AK": "test_ak",
},
clear=True,
):
with pytest.raises(
ValueError, match="VIKINGDB_KNOWLEDGE_BASE_API_SK is not set"
):
VikingDBKnowledgeBaseProvider()
class TestVikingDBKnowledgeBaseProviderSignature:
@pytest.fixture
def provider(self, env_vars):
return VikingDBKnowledgeBaseProvider()
def test_hmac_sha256(self, provider):
"""Test HMAC SHA256 calculation"""
key = b"test_key"
content = "test_content"
result = provider._hmac_sha256(key, content)
expected = hmac.new(key, content.encode("utf-8"), hashlib.sha256).digest()
assert result == expected
def test_hash_sha256(self, provider):
"""Test SHA256 hash calculation"""
data = b"test_data"
result = provider._hash_sha256(data)
expected = hashlib.sha256(data).digest()
assert result == expected
def test_get_signed_key(self, provider):
"""Test signed key generation"""
secret_key = "test_secret"
date = "20250722"
region = "cn-north-1"
service = "air"
result = provider._get_signed_key(secret_key, date, region, service)
assert isinstance(result, bytes)
assert len(result) == 32 # SHA256 digest is 32 bytes
def test_create_canonical_request(self, provider):
"""Test canonical request creation"""
method = "POST"
path = "/api/test"
query_params = {"param1": "value1", "param2": "value2"}
headers = {"Content-Type": "application/json", "Host": "example.com"}
payload = b'{"test": "data"}'
canonical_request, signed_headers = provider._create_canonical_request(
method, path, query_params, headers, payload
)
assert "POST" in canonical_request
assert "/api/test" in canonical_request
assert "param1=value1&param2=value2" in canonical_request
assert "content-type:application/json" in canonical_request
assert "host:example.com" in canonical_request
assert signed_headers == "content-type;host"
@patch("src.rag.vikingdb_knowledge_base.datetime")
def test_create_signature(self, mock_datetime, provider):
"""Test signature creation"""
# Mock datetime
mock_now = datetime(2025, 7, 22, 10, 30, 45)
mock_datetime.utcnow.return_value = mock_now
method = "POST"
path = "/api/test"
query_params = {}
headers = {}
payload = b'{"test": "data"}'
result = provider._create_signature(
method, path, query_params, headers, payload
)
assert "X-Date" in result
assert "Host" in result
assert "X-Content-Sha256" in result
assert "Content-Type" in result
assert "Authorization" in result
assert "HMAC-SHA256" in result["Authorization"]
@patch("src.rag.vikingdb_knowledge_base.requests.request")
def test_make_signed_request_success(self, mock_request, provider):
"""Test successful signed request"""
mock_response = MagicMock()
mock_response.json.return_value = {"code": 0, "data": {}}
mock_request.return_value = mock_response
result = provider._make_signed_request(
"POST", "/api/test", data={"test": "data"}
)
assert result == mock_response
mock_request.assert_called_once()
# Verify the call arguments
call_args = mock_request.call_args
assert call_args[1]["method"] == "POST"
assert call_args[1]["url"] == f"https://{provider.api_url}/api/test"
assert call_args[1]["timeout"] == 30
@patch("src.rag.vikingdb_knowledge_base.requests.request")
def test_make_signed_request_with_exception(self, mock_request, provider):
"""Test signed request with exception"""
mock_request.side_effect = Exception("Network error")
with pytest.raises(ValueError, match="Request failed: Network error"):
provider._make_signed_request("GET", "/api/test")
class TestVikingDBKnowledgeBaseProviderQueryRelevantDocuments:
@pytest.fixture
def provider(self, env_vars):
return VikingDBKnowledgeBaseProvider()
def test_query_relevant_documents_empty_resources(self, provider):
"""Test querying with empty resources list"""
result = provider.query_relevant_documents("test query", [])
assert result == []
@patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
def test_query_relevant_documents_success(self, mock_request, provider):
"""Test successful document query"""
# Mock response
mock_response = MagicMock()
mock_response.json.return_value = {
"code": 0,
"data": {
"result_list": [
{
"doc_info": {
"doc_id": "doc123",
"doc_name": "Test Document",
},
"content": "Test content",
"score": 0.95,
}
]
},
}
mock_request.return_value = mock_response
resources = [MockResource("rag://dataset/123")]
result = provider.query_relevant_documents("test query", resources)
assert len(result) == 1
assert result[0].id == "doc123"
assert result[0].title == "Test Document"
assert len(result[0].chunks) == 1
assert result[0].chunks[0].content == "Test content"
assert result[0].chunks[0].similarity == 0.95
@patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
def test_query_relevant_documents_with_document_filter(
self, mock_request, provider
):
"""Test document query with document ID filter"""
mock_response = MagicMock()
mock_response.json.return_value = {"code": 0, "data": {"result_list": []}}
mock_request.return_value = mock_response
resources = [MockResource("rag://dataset/123#doc456")]
provider.query_relevant_documents("test query", resources)
# Verify that query_param with doc_filter was included in the request
call_args = mock_request.call_args
request_data = call_args[1]["data"]
assert "query_param" in request_data
assert "doc_filter" in request_data["query_param"]
doc_filter = request_data["query_param"]["doc_filter"]
assert doc_filter["op"] == "must"
assert doc_filter["field"] == "doc_id"
assert doc_filter["conds"] == ["doc456"]
@patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
def test_query_relevant_documents_api_error(self, mock_request, provider):
"""Test handling of API error response"""
mock_response = MagicMock()
mock_response.json.return_value = {"code": 1, "message": "API Error"}
mock_request.return_value = mock_response
resources = [MockResource("rag://dataset/123")]
with pytest.raises(
ValueError, match="Failed to query documents from resource: API Error"
):
provider.query_relevant_documents("test query", resources)
@patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
def test_query_relevant_documents_json_decode_error(self, mock_request, provider):
"""Test handling of JSON decode error"""
mock_response = MagicMock()
mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
mock_request.return_value = mock_response
resources = [MockResource("rag://dataset/123")]
with pytest.raises(ValueError, match="Failed to parse JSON response"):
provider.query_relevant_documents("test query", resources)
@patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
def test_query_relevant_documents_multiple_resources(self, mock_request, provider):
"""Test querying multiple resources and merging results"""
# Mock responses for different resources
responses = [
{
"code": 0,
"data": {
"result_list": [
{
"doc_info": {
"doc_id": "doc1",
"doc_name": "Document 1",
},
"content": "Content 1",
"score": 0.9,
}
]
},
},
{
"code": 0,
"data": {
"result_list": [
{
"doc_info": {
"doc_id": "doc1",
"doc_name": "Document 1",
},
"content": "Content 2",
"score": 0.8,
},
{
"doc_info": {
"doc_id": "doc2",
"doc_name": "Document 2",
},
"content": "Content 3",
"score": 0.7,
},
]
},
},
]
mock_responses = [MagicMock() for _ in responses]
for i, resp in enumerate(responses):
mock_responses[i].json.return_value = resp
mock_request.side_effect = mock_responses
resources = [
MockResource("rag://dataset/123"),
MockResource("rag://dataset/456"),
]
result = provider.query_relevant_documents("test query", resources)
# Should have 2 documents: doc1 (with 2 chunks) and doc2 (with 1 chunk)
assert len(result) == 2
doc1 = next(doc for doc in result if doc.id == "doc1")
doc2 = next(doc for doc in result if doc.id == "doc2")
assert len(doc1.chunks) == 2
assert len(doc2.chunks) == 1
class TestVikingDBKnowledgeBaseProviderListResources:
@pytest.fixture
def provider(self, env_vars):
return VikingDBKnowledgeBaseProvider()
@patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
def test_list_resources_success(self, mock_request, provider):
"""Test successful resource listing"""
mock_response = MagicMock()
mock_response.json.return_value = {
"code": 0,
"data": {
"collection_list": [
{
"resource_id": "123",
"collection_name": "Dataset 1",
"description": "Description 1",
},
{
"resource_id": "456",
"collection_name": "Dataset 2",
"description": "Description 2",
},
]
},
}
mock_request.return_value = mock_response
result = provider.list_resources()
assert len(result) == 2
assert result[0].uri == "rag://dataset/123"
assert result[0].title == "Dataset 1"
assert result[0].description == "Description 1"
assert result[1].uri == "rag://dataset/456"
assert result[1].title == "Dataset 2"
assert result[1].description == "Description 2"
@patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
def test_list_resources_with_query_filter(self, mock_request, provider):
"""Test resource listing with query filter"""
mock_response = MagicMock()
mock_response.json.return_value = {
"code": 0,
"data": {
"collection_list": [
{
"resource_id": "123",
"collection_name": "Test Dataset",
"description": "Description",
},
{
"resource_id": "456",
"collection_name": "Other Dataset",
"description": "Description",
},
]
},
}
mock_request.return_value = mock_response
result = provider.list_resources("test")
# Should only return the dataset with "test" in the name
assert len(result) == 1
assert result[0].title == "Test Dataset"
@patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
def test_list_resources_api_error(self, mock_request, provider):
"""Test handling of API error in list_resources"""
mock_response = MagicMock()
mock_response.json.return_value = {"code": 1, "message": "API Error"}
mock_request.return_value = mock_response
with pytest.raises(Exception, match="Failed to list resources: API Error"):
provider.list_resources()
@patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
def test_list_resources_json_decode_error(self, mock_request, provider):
"""Test handling of JSON decode error in list_resources"""
mock_response = MagicMock()
mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
mock_request.return_value = mock_response
with pytest.raises(ValueError, match="Failed to parse JSON response"):
provider.list_resources()
@patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
def test_list_resources_empty_response(self, mock_request, provider):
"""Test handling of empty response"""
mock_response = MagicMock()
mock_response.json.return_value = {"code": 0, "data": {"collection_list": []}}
mock_request.return_value = mock_response
result = provider.list_resources()
assert result == []
File diff suppressed because it is too large Load Diff
-168
View File
@@ -1,168 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException
from pydantic import ValidationError
import src.server.mcp_utils as mcp_utils # Assuming mcp_utils is the module to test
from src.config.report_style import ReportStyle
from src.rag.retriever import Resource
from src.server.chat_request import (
ChatMessage,
ChatRequest,
ContentItem,
EnhancePromptRequest,
GeneratePodcastRequest,
GeneratePPTRequest,
GenerateProseRequest,
TTSRequest,
)
def test_content_item_text_and_image():
item_text = ContentItem(type="text", text="hello")
assert item_text.type == "text"
assert item_text.text == "hello"
assert item_text.image_url is None
item_image = ContentItem(type="image", image_url="http://img.com/1.png")
assert item_image.type == "image"
assert item_image.text is None
assert item_image.image_url == "http://img.com/1.png"
def test_chat_message_with_string_content():
msg = ChatMessage(role="user", content="Hello!")
assert msg.role == "user"
assert msg.content == "Hello!"
def test_chat_message_with_content_items():
items = [ContentItem(type="text", text="hi")]
msg = ChatMessage(role="assistant", content=items)
assert msg.role == "assistant"
assert isinstance(msg.content, list)
assert msg.content[0].type == "text"
def test_chat_request_defaults():
req = ChatRequest()
assert req.messages == []
assert req.resources == []
assert req.debug is False
assert req.thread_id == "__default__"
assert req.max_plan_iterations == 1
assert req.max_step_num == 3
assert req.max_search_results == 3
assert req.auto_accepted_plan is False
assert req.interrupt_feedback is None
assert req.mcp_settings is None
assert req.enable_background_investigation is True
assert req.report_style == ReportStyle.ACADEMIC
def test_chat_request_with_values():
resource = Resource(
name="test", type="doc", uri="some-uri-value", title="some-title-value"
)
msg = ChatMessage(role="user", content="hi")
req = ChatRequest(
messages=[msg],
resources=[resource],
debug=True,
thread_id="tid",
max_plan_iterations=2,
max_step_num=5,
max_search_results=10,
auto_accepted_plan=True,
interrupt_feedback="stop",
mcp_settings={"foo": "bar"},
enable_background_investigation=False,
report_style="academic",
)
assert req.messages[0].role == "user"
assert req.debug is True
assert req.thread_id == "tid"
assert req.max_plan_iterations == 2
assert req.max_step_num == 5
assert req.max_search_results == 10
assert req.auto_accepted_plan is True
assert req.interrupt_feedback == "stop"
assert req.mcp_settings == {"foo": "bar"}
assert req.enable_background_investigation is False
assert req.report_style == ReportStyle.ACADEMIC
def test_tts_request_defaults():
req = TTSRequest(text="hello")
assert req.text == "hello"
assert req.voice_type == "BV700_V2_streaming"
assert req.encoding == "mp3"
assert req.speed_ratio == 1.0
assert req.volume_ratio == 1.0
assert req.pitch_ratio == 1.0
assert req.text_type == "plain"
assert req.with_frontend == 1
assert req.frontend_type == "unitTson"
def test_generate_podcast_request():
req = GeneratePodcastRequest(content="Podcast content")
assert req.content == "Podcast content"
def test_generate_ppt_request():
req = GeneratePPTRequest(content="PPT content")
assert req.content == "PPT content"
def test_generate_prose_request():
req = GenerateProseRequest(prompt="Write a poem", option="poet", command="rhyme")
assert req.prompt == "Write a poem"
assert req.option == "poet"
assert req.command == "rhyme"
req2 = GenerateProseRequest(prompt="Write", option="short")
assert req2.command == ""
def test_enhance_prompt_request_defaults():
req = EnhancePromptRequest(prompt="Improve this")
assert req.prompt == "Improve this"
assert req.context == ""
assert req.report_style == "academic"
def test_content_item_validation_error():
with pytest.raises(ValidationError):
ContentItem() # missing required 'type'
def test_chat_message_validation_error():
with pytest.raises(ValidationError):
ChatMessage(role="user") # missing content
def test_tts_request_validation_error():
with pytest.raises(ValidationError):
TTSRequest() # missing required 'text'
@pytest.mark.asyncio
@patch("src.server.mcp_utils._get_tools_from_client_session", new_callable=AsyncMock)
@patch("src.server.mcp_utils.StdioServerParameters")
@patch("src.server.mcp_utils.stdio_client")
async def test_load_mcp_tools_exception_handling(
mock_stdio_client, mock_StdioServerParameters, mock_get_tools
): # Changed to async def
mock_get_tools.side_effect = Exception("unexpected error")
mock_StdioServerParameters.return_value = MagicMock()
mock_stdio_client.return_value = MagicMock()
with pytest.raises(HTTPException) as exc:
await mcp_utils.load_mcp_tools(server_type="stdio", command="node") # Use await
assert exc.value.status_code == 500
assert "unexpected error" in exc.value.detail
-77
View File
@@ -1,77 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import pytest
from pydantic import ValidationError
from src.server.mcp_request import MCPServerMetadataRequest, MCPServerMetadataResponse
def test_mcp_server_metadata_request_required_fields():
# 'transport' is required
req = MCPServerMetadataRequest(transport="stdio")
assert req.transport == "stdio"
assert req.command is None
assert req.args is None
assert req.url is None
assert req.env is None
assert req.timeout_seconds is None
assert req.sse_read_timeout is None
def test_mcp_server_metadata_request_optional_fields():
req = MCPServerMetadataRequest(
transport="sse",
command="run",
args=["--foo", "bar"],
url="http://localhost:8080",
env={"FOO": "BAR"},
timeout_seconds=30,
sse_read_timeout=15,
)
assert req.transport == "sse"
assert req.command == "run"
assert req.args == ["--foo", "bar"]
assert req.url == "http://localhost:8080"
assert req.env == {"FOO": "BAR"}
assert req.timeout_seconds == 30
assert req.sse_read_timeout == 15
def test_mcp_server_metadata_request_missing_transport():
with pytest.raises(ValidationError):
MCPServerMetadataRequest()
def test_mcp_server_metadata_response_required_fields():
resp = MCPServerMetadataResponse(transport="stdio")
assert resp.transport == "stdio"
assert resp.command is None
assert resp.args is None
assert resp.url is None
assert resp.env is None
assert resp.tools == []
def test_mcp_server_metadata_response_optional_fields():
resp = MCPServerMetadataResponse(
transport="sse",
command="run",
args=["--foo", "bar"],
url="http://localhost:8080",
env={"FOO": "BAR"},
tools=["tool1", "tool2"],
)
assert resp.transport == "sse"
assert resp.command == "run"
assert resp.args == ["--foo", "bar"]
assert resp.url == "http://localhost:8080"
assert resp.env == {"FOO": "BAR"}
assert resp.tools == ["tool1", "tool2"]
def test_mcp_server_metadata_response_tools_default_factory():
resp1 = MCPServerMetadataResponse(transport="stdio")
resp2 = MCPServerMetadataResponse(transport="stdio")
resp1.tools.append("toolA")
assert resp2.tools == [] # Should not share list between instances
-185
View File
@@ -1,185 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException
import src.server.mcp_utils as mcp_utils
@pytest.mark.asyncio
@patch("src.server.mcp_utils.ClientSession")
async def test__get_tools_from_client_session_success(mock_ClientSession):
mock_read = AsyncMock()
mock_write = AsyncMock()
mock_callback = AsyncMock()
mock_context_manager = AsyncMock()
mock_context_manager.__aenter__.return_value = (
mock_read,
mock_write,
mock_callback,
)
mock_context_manager.__aexit__.return_value = None
mock_session = AsyncMock()
mock_session.__aenter__.return_value = mock_session
mock_session.__aexit__.return_value = None
mock_session.initialize = AsyncMock()
mock_tools_obj = MagicMock()
mock_tools_obj.tools = ["tool1", "tool2"]
mock_session.list_tools = AsyncMock(return_value=mock_tools_obj)
mock_ClientSession.return_value = mock_session
result = await mcp_utils._get_tools_from_client_session(
mock_context_manager, timeout_seconds=5
)
assert result == ["tool1", "tool2"]
mock_session.initialize.assert_awaited_once()
mock_session.list_tools.assert_awaited_once()
@pytest.mark.asyncio
@patch("src.server.mcp_utils._get_tools_from_client_session", new_callable=AsyncMock)
@patch("src.server.mcp_utils.StdioServerParameters")
@patch("src.server.mcp_utils.stdio_client")
async def test_load_mcp_tools_stdio_success(
mock_stdio_client, mock_StdioServerParameters, mock_get_tools
):
mock_get_tools.return_value = ["toolA"]
params = MagicMock()
mock_StdioServerParameters.return_value = params
mock_client = MagicMock()
mock_stdio_client.return_value = mock_client
result = await mcp_utils.load_mcp_tools(
server_type="stdio",
command="node",
args=["server.js"],
env={"API_KEY": "test123"},
timeout_seconds=3,
)
assert result == ["toolA"]
mock_StdioServerParameters.assert_called_once_with(
command="node", args=["server.js"], env={"API_KEY": "test123"}
)
mock_stdio_client.assert_called_once_with(params)
mock_get_tools.assert_awaited_once_with(mock_client, 3)
@pytest.mark.asyncio
async def test_load_mcp_tools_stdio_missing_command():
with pytest.raises(HTTPException) as exc:
await mcp_utils.load_mcp_tools(server_type="stdio")
assert exc.value.status_code == 400
assert "Command is required" in exc.value.detail
@pytest.mark.asyncio
@patch("src.server.mcp_utils._get_tools_from_client_session", new_callable=AsyncMock)
@patch("src.server.mcp_utils.sse_client")
async def test_load_mcp_tools_sse_success(mock_sse_client, mock_get_tools):
mock_get_tools.return_value = ["toolB"]
mock_client = MagicMock()
mock_sse_client.return_value = mock_client
result = await mcp_utils.load_mcp_tools(
server_type="sse",
url="http://localhost:1234",
headers={"Authorization": "Bearer 1234567890"},
timeout_seconds=7,
)
assert result == ["toolB"]
# When sse_read_timeout is None, it should not be passed
mock_sse_client.assert_called_once_with(
url="http://localhost:1234",
headers={"Authorization": "Bearer 1234567890"},
timeout=7,
)
mock_get_tools.assert_awaited_once_with(mock_client, 7)
@pytest.mark.asyncio
@patch("src.server.mcp_utils._get_tools_from_client_session", new_callable=AsyncMock)
@patch("src.server.mcp_utils.sse_client")
async def test_load_mcp_tools_sse_with_sse_read_timeout(mock_sse_client, mock_get_tools):
"""Test that sse_read_timeout parameter is used when provided."""
mock_get_tools.return_value = ["toolC"]
mock_client = MagicMock()
mock_sse_client.return_value = mock_client
result = await mcp_utils.load_mcp_tools(
server_type="sse",
url="http://localhost:1234",
headers={"Authorization": "Bearer token"},
timeout_seconds=10,
sse_read_timeout=5,
)
assert result == ["toolC"]
# Both timeout_seconds and sse_read_timeout should be passed
mock_sse_client.assert_called_once_with(
url="http://localhost:1234",
headers={"Authorization": "Bearer token"},
timeout=10,
sse_read_timeout=5,
)
# But timeout_seconds should be used for the session timeout
mock_get_tools.assert_awaited_once_with(mock_client, 10)
@pytest.mark.asyncio
@patch("src.server.mcp_utils._get_tools_from_client_session", new_callable=AsyncMock)
@patch("src.server.mcp_utils.sse_client")
async def test_load_mcp_tools_sse_without_sse_read_timeout(mock_sse_client, mock_get_tools):
"""Test that timeout_seconds is used when sse_read_timeout is not provided."""
mock_get_tools.return_value = ["toolD"]
mock_client = MagicMock()
mock_sse_client.return_value = mock_client
result = await mcp_utils.load_mcp_tools(
server_type="sse",
url="http://localhost:1234",
timeout_seconds=20,
)
assert result == ["toolD"]
# When sse_read_timeout is not provided, it should not be passed
mock_sse_client.assert_called_once_with(
url="http://localhost:1234",
headers=None,
timeout=20,
)
mock_get_tools.assert_awaited_once_with(mock_client, 20)
@pytest.mark.asyncio
async def test_load_mcp_tools_sse_missing_url():
with pytest.raises(HTTPException) as exc:
await mcp_utils.load_mcp_tools(server_type="sse")
assert exc.value.status_code == 400
assert "URL is required" in exc.value.detail
@pytest.mark.asyncio
async def test_load_mcp_tools_unsupported_type():
with pytest.raises(HTTPException) as exc:
await mcp_utils.load_mcp_tools(server_type="unknown")
assert exc.value.status_code == 400
assert "Invalid transport type" in exc.value.detail or "Unsupported server type" in exc.value.detail
@pytest.mark.asyncio
@patch("src.server.mcp_utils._get_tools_from_client_session", new_callable=AsyncMock)
@patch("src.server.mcp_utils.StdioServerParameters")
@patch("src.server.mcp_utils.stdio_client")
async def test_load_mcp_tools_exception_handling(
mock_stdio_client, mock_StdioServerParameters, mock_get_tools
):
mock_get_tools.side_effect = Exception("unexpected error")
mock_StdioServerParameters.return_value = MagicMock()
mock_stdio_client.return_value = MagicMock()
with pytest.raises(HTTPException) as exc:
await mcp_utils.load_mcp_tools(server_type="stdio", command="node")
assert exc.value.status_code == 500
assert "unexpected error" in exc.value.detail
-450
View File
@@ -1,450 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for MCP server configuration validators.
Tests cover:
- Command validation (allowlist)
- Argument validation (path traversal, command injection)
- Environment variable validation
- URL validation
- Header validation
- Full config validation
"""
import pytest
from src.server.mcp_validators import (
ALLOWED_COMMANDS,
MCPValidationError,
validate_args_for_local_file_access,
validate_command,
validate_command_injection,
validate_environment_variables,
validate_headers,
validate_mcp_server_config,
validate_url,
)
class TestValidateCommand:
"""Tests for validate_command function."""
def test_allowed_commands(self):
"""Test that all allowed commands pass validation."""
for cmd in ALLOWED_COMMANDS:
validate_command(cmd) # Should not raise
def test_allowed_command_with_path(self):
"""Test that commands with paths are validated by base name."""
validate_command("/usr/bin/python3")
validate_command("/usr/local/bin/node")
validate_command("C:\\Python\\python.exe")
def test_disallowed_command(self):
"""Test that disallowed commands raise an error."""
with pytest.raises(MCPValidationError) as exc_info:
validate_command("bash")
assert "not allowed" in exc_info.value.message
assert exc_info.value.field == "command"
def test_disallowed_dangerous_commands(self):
"""Test that dangerous commands are rejected."""
dangerous_commands = ["rm", "sudo", "chmod", "chown", "curl", "wget", "sh"]
for cmd in dangerous_commands:
with pytest.raises(MCPValidationError):
validate_command(cmd)
def test_empty_command(self):
"""Test that empty command raises an error."""
with pytest.raises(MCPValidationError):
validate_command("")
def test_none_command(self):
"""Test that None command raises an error."""
with pytest.raises(MCPValidationError):
validate_command(None)
class TestValidateArgsForLocalFileAccess:
"""Tests for validate_args_for_local_file_access function."""
def test_safe_args(self):
"""Test that safe arguments pass validation."""
safe_args = [
["--help"],
["-v", "--verbose"],
["package-name"],
["--config", "config.json"],
["run", "script.py"],
]
for args in safe_args:
validate_args_for_local_file_access(args) # Should not raise
def test_directory_traversal(self):
"""Test that directory traversal patterns are rejected."""
traversal_patterns = [
["../etc/passwd"],
["..\\windows\\system32"],
["../../secret"],
["foo/../bar/../../../etc/passwd"],
["foo/.."], # ".." at end after path separator
["bar\\.."], # ".." at end after Windows path separator
["path/to/foo/.."], # Longer path ending with ".."
]
for args in traversal_patterns:
with pytest.raises(MCPValidationError) as exc_info:
validate_args_for_local_file_access(args)
assert "traversal" in exc_info.value.message.lower()
def test_absolute_path_with_dangerous_extension(self):
"""Test that absolute paths with dangerous extensions are rejected."""
with pytest.raises(MCPValidationError):
validate_args_for_local_file_access(["/etc/passwd.sh"])
def test_windows_absolute_path(self):
"""Test that Windows absolute paths are rejected."""
with pytest.raises(MCPValidationError):
validate_args_for_local_file_access(["C:\\Windows\\system32"])
def test_home_directory_reference(self):
"""Test that home directory references are rejected."""
with pytest.raises(MCPValidationError):
validate_args_for_local_file_access(["~/secrets"])
with pytest.raises(MCPValidationError):
validate_args_for_local_file_access(["~\\secrets"])
def test_null_byte(self):
"""Test that null bytes in arguments are rejected."""
with pytest.raises(MCPValidationError) as exc_info:
validate_args_for_local_file_access(["file\x00.txt"])
assert "null byte" in exc_info.value.message.lower()
def test_excessively_long_argument(self):
"""Test that excessively long arguments are rejected."""
with pytest.raises(MCPValidationError) as exc_info:
validate_args_for_local_file_access(["a" * 1001])
assert "maximum length" in exc_info.value.message.lower()
def test_dangerous_extensions(self):
"""Test that dangerous file extensions are rejected."""
dangerous_files = [
["script.sh"],
["binary.exe"],
["library.dll"],
["secret.env"],
["key.pem"],
]
for args in dangerous_files:
with pytest.raises(MCPValidationError) as exc_info:
validate_args_for_local_file_access(args)
assert "dangerous file type" in exc_info.value.message.lower()
def test_empty_args(self):
"""Test that empty args list passes validation."""
validate_args_for_local_file_access([])
validate_args_for_local_file_access(None)
class TestValidateCommandInjection:
"""Tests for validate_command_injection function."""
def test_safe_args(self):
"""Test that safe arguments pass validation."""
safe_args = [
["--help"],
["package-name"],
["@scope/package"],
["file.json"],
]
for args in safe_args:
validate_command_injection(args) # Should not raise
def test_shell_metacharacters(self):
"""Test that shell metacharacters are rejected."""
metachar_args = [
["foo; rm -rf /"],
["foo & bar"],
["foo | cat /etc/passwd"],
["$(whoami)"],
["`id`"],
["foo > /etc/passwd"],
["foo < /etc/passwd"],
["${PATH}"],
]
for args in metachar_args:
with pytest.raises(MCPValidationError) as exc_info:
validate_command_injection(args)
assert "args" == exc_info.value.field
def test_command_chaining(self):
"""Test that command chaining patterns are rejected."""
chaining_args = [
["foo && bar"],
["foo || bar"],
["foo;; bar"],
["foo >> output"],
["foo << input"],
]
for args in chaining_args:
with pytest.raises(MCPValidationError):
validate_command_injection(args)
def test_backtick_injection(self):
"""Test that backtick command substitution is rejected."""
with pytest.raises(MCPValidationError):
validate_command_injection(["`whoami`"])
def test_process_substitution(self):
"""Test that process substitution is rejected."""
with pytest.raises(MCPValidationError):
validate_command_injection(["<(cat /etc/passwd)"])
with pytest.raises(MCPValidationError):
validate_command_injection([">(tee /tmp/out)"])
class TestValidateEnvironmentVariables:
"""Tests for validate_environment_variables function."""
def test_safe_env_vars(self):
"""Test that safe environment variables pass validation."""
safe_env = {
"API_KEY": "secret123",
"DEBUG": "true",
"MY_VARIABLE": "value",
}
validate_environment_variables(safe_env) # Should not raise
def test_dangerous_env_vars(self):
"""Test that dangerous environment variables are rejected."""
dangerous_vars = [
{"PATH": "/malicious/path"},
{"LD_LIBRARY_PATH": "/malicious/lib"},
{"DYLD_LIBRARY_PATH": "/malicious/lib"},
{"LD_PRELOAD": "/malicious/lib.so"},
{"PYTHONPATH": "/malicious/python"},
{"NODE_PATH": "/malicious/node"},
]
for env in dangerous_vars:
with pytest.raises(MCPValidationError) as exc_info:
validate_environment_variables(env)
assert "not allowed" in exc_info.value.message.lower()
def test_null_byte_in_value(self):
"""Test that null bytes in values are rejected."""
with pytest.raises(MCPValidationError):
validate_environment_variables({"KEY": "value\x00malicious"})
def test_empty_env(self):
"""Test that empty env dict passes validation."""
validate_environment_variables({})
validate_environment_variables(None)
class TestValidateUrl:
"""Tests for validate_url function."""
def test_valid_urls(self):
"""Test that valid URLs pass validation."""
valid_urls = [
"http://localhost:3000",
"https://api.example.com",
"http://192.168.1.1:8080/api",
"https://mcp.example.com/sse",
]
for url in valid_urls:
validate_url(url) # Should not raise
def test_invalid_scheme(self):
"""Test that invalid URL schemes are rejected."""
with pytest.raises(MCPValidationError) as exc_info:
validate_url("ftp://example.com")
assert "scheme" in exc_info.value.message.lower()
with pytest.raises(MCPValidationError):
validate_url("file:///etc/passwd")
def test_credentials_in_url(self):
"""Test that URLs with credentials are rejected."""
with pytest.raises(MCPValidationError) as exc_info:
validate_url("https://user:pass@example.com")
assert "credentials" in exc_info.value.message.lower()
def test_null_byte_in_url(self):
"""Test that null bytes in URL are rejected."""
with pytest.raises(MCPValidationError):
validate_url("https://example.com\x00/malicious")
def test_empty_url(self):
"""Test that empty URL raises an error."""
with pytest.raises(MCPValidationError):
validate_url("")
def test_no_host(self):
"""Test that URL without host raises an error."""
with pytest.raises(MCPValidationError):
validate_url("http:///path")
class TestValidateHeaders:
"""Tests for validate_headers function."""
def test_valid_headers(self):
"""Test that valid headers pass validation."""
valid_headers = {
"Authorization": "Bearer token123",
"Content-Type": "application/json",
"X-Custom-Header": "value",
}
validate_headers(valid_headers) # Should not raise
def test_newline_in_header_name(self):
"""Test that newlines in header names are rejected (HTTP header injection)."""
with pytest.raises(MCPValidationError) as exc_info:
validate_headers({"X-Bad\nHeader": "value"})
assert "newline" in exc_info.value.message.lower()
def test_newline_in_header_value(self):
"""Test that newlines in header values are rejected (HTTP header injection)."""
with pytest.raises(MCPValidationError):
validate_headers({"X-Header": "value\r\nX-Injected: malicious"})
def test_null_byte_in_header(self):
"""Test that null bytes in headers are rejected."""
with pytest.raises(MCPValidationError):
validate_headers({"X-Header": "value\x00"})
def test_empty_headers(self):
"""Test that empty headers dict passes validation."""
validate_headers({})
validate_headers(None)
class TestValidateMCPServerConfig:
"""Tests for the main validate_mcp_server_config function."""
def test_valid_stdio_config(self):
"""Test valid stdio configuration."""
validate_mcp_server_config(
transport="stdio",
command="npx",
args=["@modelcontextprotocol/server-filesystem"],
env={"API_KEY": "secret"},
) # Should not raise
def test_valid_sse_config(self):
"""Test valid SSE configuration."""
validate_mcp_server_config(
transport="sse",
url="https://api.example.com/sse",
headers={"Authorization": "Bearer token"},
) # Should not raise
def test_valid_http_config(self):
"""Test valid streamable_http configuration."""
validate_mcp_server_config(
transport="streamable_http",
url="https://api.example.com/mcp",
) # Should not raise
def test_invalid_transport(self):
"""Test that invalid transport type raises an error."""
with pytest.raises(MCPValidationError) as exc_info:
validate_mcp_server_config(transport="invalid")
assert "Invalid transport type" in exc_info.value.message
def test_combined_validation_errors(self):
"""Test that multiple validation errors are combined."""
with pytest.raises(MCPValidationError) as exc_info:
validate_mcp_server_config(
transport="stdio",
command="bash", # Not allowed
args=["../etc/passwd"], # Path traversal
env={"PATH": "/malicious"}, # Dangerous env var
)
# All errors should be combined
assert "not allowed" in exc_info.value.message
assert "traversal" in exc_info.value.message.lower()
def test_non_strict_mode(self):
"""Test that non-strict mode logs warnings instead of raising."""
# Should not raise, but would log warnings
validate_mcp_server_config(
transport="stdio",
command="bash",
strict=False,
)
def test_stdio_with_dangerous_args(self):
"""Test stdio config with command injection attempt."""
with pytest.raises(MCPValidationError):
validate_mcp_server_config(
transport="stdio",
command="node",
args=["script.js; rm -rf /"],
)
def test_sse_with_invalid_url(self):
"""Test SSE config with invalid URL."""
with pytest.raises(MCPValidationError):
validate_mcp_server_config(
transport="sse",
url="ftp://example.com",
)
class TestMCPServerMetadataRequest:
"""Tests for Pydantic model validation."""
def test_valid_request(self):
"""Test that valid request passes validation."""
from src.server.mcp_request import MCPServerMetadataRequest
request = MCPServerMetadataRequest(
transport="stdio",
command="npx",
args=["@modelcontextprotocol/server-filesystem"],
)
assert request.transport == "stdio"
assert request.command == "npx"
def test_invalid_command_raises_validation_error(self):
"""Test that invalid command raises Pydantic ValidationError."""
from pydantic import ValidationError
from src.server.mcp_request import MCPServerMetadataRequest
with pytest.raises(ValidationError) as exc_info:
MCPServerMetadataRequest(
transport="stdio",
command="bash",
)
assert "not allowed" in str(exc_info.value).lower()
def test_command_injection_raises_validation_error(self):
"""Test that command injection raises Pydantic ValidationError."""
from pydantic import ValidationError
from src.server.mcp_request import MCPServerMetadataRequest
with pytest.raises(ValidationError):
MCPServerMetadataRequest(
transport="stdio",
command="node",
args=["script.js; rm -rf /"],
)
def test_invalid_url_raises_validation_error(self):
"""Test that invalid URL raises Pydantic ValidationError."""
from pydantic import ValidationError
from src.server.mcp_request import MCPServerMetadataRequest
with pytest.raises(ValidationError):
MCPServerMetadataRequest(
transport="sse",
url="ftp://example.com",
)
-317
View File
@@ -1,317 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for tool call chunk processing.
Tests for the fix of issue #523: Tool name concatenation in consecutive tool calls.
This ensures that tool call chunks are properly segregated by index to prevent
tool names from being concatenated when multiple tool calls happen in sequence.
"""
import logging
import os
# Import the functions to test
# Note: We need to import from the app module
import sys
from unittest.mock import MagicMock, patch
import pytest
# Add src directory to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../"))
from src.server.app import _process_tool_call_chunks, _validate_tool_call_chunks
class TestProcessToolCallChunks:
"""Test cases for _process_tool_call_chunks function."""
def test_empty_tool_call_chunks(self):
"""Test processing empty tool call chunks."""
result = _process_tool_call_chunks([])
assert result == []
def test_single_tool_call_single_chunk(self):
"""Test processing a single tool call with a single chunk."""
chunks = [
{"name": "web_search", "args": '{"query": "test"}', "id": "call_1", "index": 0}
]
result = _process_tool_call_chunks(chunks)
assert len(result) == 1
assert result[0]["name"] == "web_search"
assert result[0]["id"] == "call_1"
assert result[0]["index"] == 0
assert '"query": "test"' in result[0]["args"]
def test_consecutive_tool_calls_different_indices(self):
"""Test that consecutive tool calls with different indices are not concatenated."""
chunks = [
{"name": "web_search", "args": '{"query": "test"}', "id": "call_1", "index": 0},
{"name": "web_search", "args": '{"query": "test2"}', "id": "call_2", "index": 1},
]
result = _process_tool_call_chunks(chunks)
assert len(result) == 2
assert result[0]["name"] == "web_search"
assert result[0]["id"] == "call_1"
assert result[0]["index"] == 0
assert result[1]["name"] == "web_search"
assert result[1]["id"] == "call_2"
assert result[1]["index"] == 1
# Verify names are NOT concatenated
assert result[0]["name"] != "web_searchweb_search"
assert result[1]["name"] != "web_searchweb_search"
def test_different_tools_different_indices(self):
"""Test consecutive calls to different tools."""
chunks = [
{"name": "web_search", "args": '{"query": "test"}', "id": "call_1", "index": 0},
{"name": "crawl_tool", "args": '{"url": "http://example.com"}', "id": "call_2", "index": 1},
]
result = _process_tool_call_chunks(chunks)
assert len(result) == 2
assert result[0]["name"] == "web_search"
assert result[1]["name"] == "crawl_tool"
# Verify names are NOT concatenated (the issue bug scenario)
assert "web_searchcrawl_tool" not in result[0]["name"]
assert "web_searchcrawl_tool" not in result[1]["name"]
def test_streaming_chunks_same_index(self):
"""Test streaming chunks with same index are properly accumulated."""
chunks = [
{"name": "web_", "args": '{"query"', "id": "call_1", "index": 0},
{"name": "search", "args": ': "test"}', "id": "call_1", "index": 0},
]
result = _process_tool_call_chunks(chunks)
assert len(result) == 1
# Name should NOT be concatenated when it's the same tool
assert result[0]["name"] in ["web_", "search", "web_search"]
assert result[0]["id"] == "call_1"
# Args should be accumulated
assert "query" in result[0]["args"]
assert "test" in result[0]["args"]
def test_tool_call_index_collision_warning(self):
"""Test that index collision with different names generates warning."""
chunks = [
{"name": "web_search", "args": '{}', "id": "call_1", "index": 0},
{"name": "crawl_tool", "args": '{}', "id": "call_2", "index": 0},
]
# This should trigger a warning
with patch('src.server.app.logger') as mock_logger:
result = _process_tool_call_chunks(chunks)
# Verify warning was logged
mock_logger.warning.assert_called()
call_args = mock_logger.warning.call_args[0][0]
assert "Tool name mismatch detected" in call_args
assert "web_search" in call_args
assert "crawl_tool" in call_args
def test_chunks_without_explicit_index(self):
"""Test handling chunks without explicit index (edge case)."""
chunks = [
{"name": "web_search", "args": '{}', "id": "call_1"} # No index
]
result = _process_tool_call_chunks(chunks)
assert len(result) == 1
assert result[0]["name"] == "web_search"
def test_chunk_sorting_by_index(self):
"""Test that chunks are sorted by index in proper order."""
chunks = [
{"name": "tool_3", "args": '{}', "id": "call_3", "index": 2},
{"name": "tool_1", "args": '{}', "id": "call_1", "index": 0},
{"name": "tool_2", "args": '{}', "id": "call_2", "index": 1},
]
result = _process_tool_call_chunks(chunks)
assert len(result) == 3
assert result[0]["index"] == 0
assert result[1]["index"] == 1
assert result[2]["index"] == 2
def test_args_accumulation(self):
"""Test that arguments are properly accumulated for same index."""
chunks = [
{"name": "web_search", "args": '{"q', "id": "call_1", "index": 0},
{"name": "web_search", "args": 'uery": "test"}', "id": "call_1", "index": 0},
]
result = _process_tool_call_chunks(chunks)
assert len(result) == 1
# Sanitize removes json encoding, so just check it's accumulated
assert len(result[0]["args"]) > 0
def test_preserve_tool_id(self):
"""Test that tool IDs are preserved correctly."""
chunks = [
{"name": "web_search", "args": '{}', "id": "call_abc123", "index": 0},
{"name": "web_search", "args": '{}', "id": "call_xyz789", "index": 1},
]
result = _process_tool_call_chunks(chunks)
assert result[0]["id"] == "call_abc123"
assert result[1]["id"] == "call_xyz789"
def test_multiple_indices_detected(self):
"""Test that multiple indices are properly detected and logged."""
chunks = [
{"name": "tool_a", "args": '{}', "id": "call_1", "index": 0},
{"name": "tool_b", "args": '{}', "id": "call_2", "index": 1},
{"name": "tool_c", "args": '{}', "id": "call_3", "index": 2},
]
with patch('src.server.app.logger') as mock_logger:
result = _process_tool_call_chunks(chunks)
# Should have debug logs for multiple indices
debug_calls = [call[0][0] for call in mock_logger.debug.call_args_list]
# Check if any debug call mentions multiple indices
multiple_indices_mentioned = any(
"Multiple indices" in call for call in debug_calls
)
assert multiple_indices_mentioned or len(result) == 3
class TestValidateToolCallChunks:
"""Test cases for _validate_tool_call_chunks function."""
def test_validate_empty_chunks(self):
"""Test validation of empty chunks."""
# Should not raise any exception
_validate_tool_call_chunks([])
def test_validate_logs_chunk_info(self):
"""Test that validation logs chunk information."""
chunks = [
{"name": "web_search", "args": '{}', "id": "call_1", "index": 0},
]
with patch('src.server.app.logger') as mock_logger:
_validate_tool_call_chunks(chunks)
# Should have logged debug info
assert mock_logger.debug.called
def test_validate_detects_multiple_indices(self):
"""Test that validation detects multiple indices."""
chunks = [
{"name": "tool_1", "args": '{}', "id": "call_1", "index": 0},
{"name": "tool_2", "args": '{}', "id": "call_2", "index": 1},
]
with patch('src.server.app.logger') as mock_logger:
_validate_tool_call_chunks(chunks)
# Should have logged about multiple indices
debug_calls = [call[0][0] for call in mock_logger.debug.call_args_list]
multiple_indices_mentioned = any(
"Multiple indices" in call for call in debug_calls
)
assert multiple_indices_mentioned
class TestRealWorldScenarios:
"""Test cases for real-world scenarios from issue #523."""
def test_issue_523_scenario_consecutive_web_search(self):
"""
Replicate issue #523: Consecutive web_search calls.
Previously would result in "web_searchweb_search" error.
"""
# Simulate streaming chunks from two consecutive web_search calls
chunks = [
# First web_search call (index 0)
{"name": "web_", "args": '{"query', "id": "call_1", "index": 0},
{"name": "search", "args": '": "first query"}', "id": "call_1", "index": 0},
# Second web_search call (index 1)
{"name": "web_", "args": '{"query', "id": "call_2", "index": 1},
{"name": "search", "args": '": "second query"}', "id": "call_2", "index": 1},
]
result = _process_tool_call_chunks(chunks)
# Should have 2 tool calls, not concatenated names
assert len(result) >= 1 # At minimum should process without error
# Extract tool names from result
tool_names = [chunk.get("name") for chunk in result]
# Verify "web_searchweb_search" error doesn't occur
assert "web_searchweb_search" not in tool_names
# Both calls should have web_search (or parts of it)
concatenated_names = "".join(tool_names)
assert "web_search" in concatenated_names or "web_" in concatenated_names
def test_mixed_tools_consecutive_calls(self):
"""Test realistic scenario with mixed tools in sequence."""
chunks = [
# web_search call
{"name": "web_search", "args": '{"query": "python"}', "id": "1", "index": 0},
# crawl_tool call
{"name": "crawl_tool", "args": '{"url": "http://example.com"}', "id": "2", "index": 1},
# Another web_search
{"name": "web_search", "args": '{"query": "rust"}', "id": "3", "index": 2},
]
result = _process_tool_call_chunks(chunks)
assert len(result) == 3
tool_names = [chunk.get("name") for chunk in result]
# No concatenation should occur
assert "web_searchcrawl_tool" not in tool_names
assert "crawl_toolweb_search" not in tool_names
def test_long_sequence_tool_calls(self):
"""Test a long sequence of tool calls."""
chunks = []
for i in range(10):
tool_name = "web_search" if i % 2 == 0 else "crawl_tool"
chunks.append({
"name": tool_name,
"args": '{"query": "test"}' if tool_name == "web_search" else '{"url": "http://example.com"}',
"id": f"call_{i}",
"index": i
})
result = _process_tool_call_chunks(chunks)
# Should process all 10 tool calls
assert len(result) == 10
# Verify each tool call has correct name (not concatenated with other tool names)
for i, chunk in enumerate(result):
expected_name = "web_search" if i % 2 == 0 else "crawl_tool"
actual_name = chunk.get("name", "")
# The actual name should be the expected name, not concatenated
assert actual_name == expected_name, (
f"Tool call {i} has name '{actual_name}', expected '{expected_name}'. "
f"This indicates concatenation with adjacent tool call."
)
# Verify IDs are correct
assert chunk.get("id") == f"call_{i}"
assert chunk.get("index") == i
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
-216
View File
@@ -1,216 +0,0 @@
import json
from unittest.mock import Mock, patch
from src.tools.crawl import crawl_tool, is_pdf_url
class TestCrawlTool:
@patch("src.tools.crawl.Crawler")
def test_crawl_tool_success(self, mock_crawler_class):
# Arrange
mock_crawler = Mock()
mock_article = Mock()
mock_article.to_markdown.return_value = (
"# Test Article\nThis is test content." * 100
)
mock_crawler.crawl.return_value = mock_article
mock_crawler_class.return_value = mock_crawler
url = "https://example.com"
# Act
result = crawl_tool.invoke({"url": url})
# Assert
assert isinstance(result, str)
result_dict = json.loads(result)
assert result_dict["url"] == url
assert "crawled_content" in result_dict
assert len(result_dict["crawled_content"]) <= 1000
mock_crawler_class.assert_called_once()
mock_crawler.crawl.assert_called_once_with(url)
mock_article.to_markdown.assert_called_once()
@patch("src.tools.crawl.Crawler")
def test_crawl_tool_short_content(self, mock_crawler_class):
# Arrange
mock_crawler = Mock()
mock_article = Mock()
short_content = "Short content"
mock_article.to_markdown.return_value = short_content
mock_crawler.crawl.return_value = mock_article
mock_crawler_class.return_value = mock_crawler
url = "https://example.com"
# Act
result = crawl_tool.invoke({"url": url})
# Assert
result_dict = json.loads(result)
assert result_dict["crawled_content"] == short_content
@patch("src.tools.crawl.Crawler")
@patch("src.tools.crawl.logger")
def test_crawl_tool_crawler_exception(self, mock_logger, mock_crawler_class):
# Arrange
mock_crawler = Mock()
mock_crawler.crawl.side_effect = Exception("Network error")
mock_crawler_class.return_value = mock_crawler
url = "https://example.com"
# Act
result = crawl_tool.invoke({"url": url})
# Assert
assert isinstance(result, str)
assert "Failed to crawl" in result
assert "Network error" in result
mock_logger.error.assert_called_once()
@patch("src.tools.crawl.Crawler")
@patch("src.tools.crawl.logger")
def test_crawl_tool_crawler_instantiation_exception(
self, mock_logger, mock_crawler_class
):
# Arrange
mock_crawler_class.side_effect = Exception("Crawler init error")
url = "https://example.com"
# Act
result = crawl_tool.invoke({"url": url})
# Assert
assert isinstance(result, str)
assert "Failed to crawl" in result
assert "Crawler init error" in result
mock_logger.error.assert_called_once()
@patch("src.tools.crawl.Crawler")
@patch("src.tools.crawl.logger")
def test_crawl_tool_markdown_conversion_exception(
self, mock_logger, mock_crawler_class
):
# Arrange
mock_crawler = Mock()
mock_article = Mock()
mock_article.to_markdown.side_effect = Exception("Markdown conversion error")
mock_crawler.crawl.return_value = mock_article
mock_crawler_class.return_value = mock_crawler
url = "https://example.com"
# Act
result = crawl_tool.invoke({"url": url})
# Assert
assert isinstance(result, str)
assert "Failed to crawl" in result
assert "Markdown conversion error" in result
mock_logger.error.assert_called_once()
@patch("src.tools.crawl.Crawler")
def test_crawl_tool_with_none_content(self, mock_crawler_class):
# Arrange
mock_crawler = Mock()
mock_article = Mock()
mock_article.to_markdown.return_value = "# Article\n\n*No content available*\n"
mock_crawler.crawl.return_value = mock_article
mock_crawler_class.return_value = mock_crawler
url = "https://example.com"
# Act
result = crawl_tool.invoke({"url": url})
# Assert
assert isinstance(result, str)
result_dict = json.loads(result)
assert result_dict["url"] == url
assert "crawled_content" in result_dict
assert "No content available" in result_dict["crawled_content"]
class TestPDFHandling:
"""Test PDF URL detection and handling for issue #701."""
def test_is_pdf_url_with_pdf_urls(self):
"""Test that PDF URLs are correctly identified."""
test_cases = [
("https://example.com/document.pdf", True),
("https://example.com/file.PDF", True), # Case insensitive
("https://example.com/path/to/report.pdf", True),
("https://pdf.dfcfw.com/pdf/H3_AP202503071644153386_1.pdf", True), # URL from issue
("http://site.com/path/document.pdf?param=value", True), # With query params
]
for url, expected in test_cases:
assert is_pdf_url(url) == expected, f"Failed for URL: {url}"
def test_is_pdf_url_with_non_pdf_urls(self):
"""Test that non-PDF URLs are correctly identified."""
test_cases = [
("https://example.com/page.html", False),
("https://example.com/article.php", False),
("https://example.com/", False),
("https://example.com/document.pdfx", False), # Not exactly .pdf
("https://example.com/document.doc", False),
("https://example.com/document.txt", False),
("https://example.com?file=document.pdf", False), # Query param, not path
("", False), # Empty string
(None, False), # None value
]
for url, expected in test_cases:
assert is_pdf_url(url) == expected, f"Failed for URL: {url}"
def test_crawl_tool_with_pdf_url(self):
"""Test that PDF URLs return the expected error structure."""
pdf_url = "https://example.com/document.pdf"
# Act
result = crawl_tool.invoke({"url": pdf_url})
# Assert
assert isinstance(result, str)
result_dict = json.loads(result)
# Check structure of PDF error response
assert result_dict["url"] == pdf_url
assert "error" in result_dict
assert result_dict["crawled_content"] is None
assert result_dict["is_pdf"] is True
assert "PDF files cannot be crawled directly" in result_dict["error"]
def test_crawl_tool_with_issue_pdf_url(self):
"""Test with the exact PDF URL from issue #701."""
issue_pdf_url = "https://pdf.dfcfw.com/pdf/H3_AP202503071644153386_1.pdf"
# Act
result = crawl_tool.invoke({"url": issue_pdf_url})
# Assert
result_dict = json.loads(result)
assert result_dict["url"] == issue_pdf_url
assert result_dict["is_pdf"] is True
assert "cannot be crawled directly" in result_dict["error"]
@patch("src.tools.crawl.Crawler")
@patch("src.tools.crawl.logger")
def test_crawl_tool_skips_crawler_for_pdfs(self, mock_logger, mock_crawler_class):
"""Test that the crawler is not instantiated for PDF URLs."""
pdf_url = "https://example.com/document.pdf"
# Act
result = crawl_tool.invoke({"url": pdf_url})
# Assert
# Crawler should not be instantiated for PDF URLs
mock_crawler_class.assert_not_called()
mock_logger.info.assert_called_once_with(f"PDF URL detected, skipping crawling: {pdf_url}")
# Should return proper PDF error structure
result_dict = json.loads(result)
assert result_dict["is_pdf"] is True
-119
View File
@@ -1,119 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import Mock, call, patch
from src.tools.decorators import create_logged_tool
class MockBaseTool:
"""Mock base tool class for testing."""
def _run(self, *args, **kwargs):
return "base_result"
class TestLoggedToolMixin:
def test_run_calls_log_operation(self):
"""Test that _run calls _log_operation with correct parameters."""
# Create a logged tool instance
LoggedTool = create_logged_tool(MockBaseTool)
tool = LoggedTool()
# Mock the _log_operation method
tool._log_operation = Mock()
# Call _run with test parameters
args = ("arg1", "arg2")
kwargs = {"key1": "value1", "key2": "value2"}
tool._run(*args, **kwargs)
# Verify _log_operation was called with correct parameters
tool._log_operation.assert_called_once_with("_run", *args, **kwargs)
def test_run_calls_super_run(self):
"""Test that _run calls the parent class _run method."""
# Create a logged tool instance
LoggedTool = create_logged_tool(MockBaseTool)
tool = LoggedTool()
# Mock the parent _run method
with patch.object(
MockBaseTool, "_run", return_value="mocked_result"
) as mock_super_run:
args = ("arg1", "arg2")
kwargs = {"key1": "value1"}
result = tool._run(*args, **kwargs)
# Verify super()._run was called with correct parameters
mock_super_run.assert_called_once_with(*args, **kwargs)
# Verify the result is returned
assert result == "mocked_result"
def test_run_logs_result(self):
"""Test that _run logs the result with debug level."""
LoggedTool = create_logged_tool(MockBaseTool)
tool = LoggedTool()
with patch("src.tools.decorators.logger.debug") as mock_debug:
tool._run("test_arg")
# Verify debug log was called with correct message
mock_debug.assert_has_calls(
[
call("Tool MockBaseTool._run called with parameters: test_arg"),
call("Tool MockBaseTool returned: base_result"),
]
)
def test_run_returns_super_result(self):
"""Test that _run returns the result from parent class."""
LoggedTool = create_logged_tool(MockBaseTool)
tool = LoggedTool()
result = tool._run()
assert result == "base_result"
def test_run_with_no_args(self):
"""Test _run method with no arguments."""
LoggedTool = create_logged_tool(MockBaseTool)
tool = LoggedTool()
with patch("src.tools.decorators.logger.debug") as mock_debug:
tool._log_operation = Mock()
result = tool._run()
# Verify _log_operation called with no args
tool._log_operation.assert_called_once_with("_run")
# Verify result logging
mock_debug.assert_called_once()
assert result == "base_result"
def test_run_with_mixed_args_kwargs(self):
"""Test _run method with both positional and keyword arguments."""
LoggedTool = create_logged_tool(MockBaseTool)
tool = LoggedTool()
tool._log_operation = Mock()
args = ("pos1", "pos2")
kwargs = {"kw1": "val1", "kw2": "val2"}
result = tool._run(*args, **kwargs)
# Verify all arguments passed correctly
tool._log_operation.assert_called_once_with("_run", *args, **kwargs)
assert result == "base_result"
def test_run_class_name_replacement(self):
"""Test that class name 'Logged' prefix is correctly removed in logging."""
LoggedTool = create_logged_tool(MockBaseTool)
tool = LoggedTool()
with patch("src.tools.decorators.logger.debug") as mock_debug:
tool._run()
# Verify the logged class name has 'Logged' prefix removed
call_args = mock_debug.call_args[0][0]
assert "Tool MockBaseTool returned:" in call_args
assert "LoggedMockBaseTool" not in call_args
@@ -1,218 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import Mock, patch
import pytest
import requests
from src.tools.infoquest_search.infoquest_search_api import InfoQuestAPIWrapper
class TestInfoQuestAPIWrapper:
@pytest.fixture
def wrapper(self):
# Create a wrapper instance with mock API key
return InfoQuestAPIWrapper(infoquest_api_key="dummy-key")
@pytest.fixture
def mock_response_data(self):
# Mock search result data
return {
"search_result": {
"results": [
{
"content": {
"results": {
"organic": [
{
"title": "Test Title",
"url": "https://example.com",
"desc": "Test description"
}
],
"top_stories": {
"items": [
{
"time_frame": "2 days ago",
"title": "Test News",
"url": "https://example.com/news",
"source": "Test Source"
}
]
},
"images": {
"items": [
{
"url": "https://example.com/image.jpg",
"alt": "Test image description"
}
]
}
}
}
}
]
}
}
@patch("src.tools.infoquest_search.infoquest_search_api.requests.post")
def test_raw_results_success(self, mock_post, wrapper, mock_response_data):
# Test successful synchronous search results
mock_response = Mock()
mock_response.json.return_value = mock_response_data
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = wrapper.raw_results("test query", time_range=0, site="")
assert result == mock_response_data["search_result"]
mock_post.assert_called_once()
call_args = mock_post.call_args
assert "json" in call_args.kwargs
assert call_args.kwargs["json"]["query"] == "test query"
assert "time_range" not in call_args.kwargs["json"]
assert "site" not in call_args.kwargs["json"]
@patch("src.tools.infoquest_search.infoquest_search_api.requests.post")
def test_raw_results_with_time_range_and_site(self, mock_post, wrapper, mock_response_data):
# Test search with time range and site filtering
mock_response = Mock()
mock_response.json.return_value = mock_response_data
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = wrapper.raw_results("test query", time_range=30, site="example.com")
assert result == mock_response_data["search_result"]
call_args = mock_post.call_args
params = call_args.kwargs["json"]
assert params["time_range"] == 30
assert params["site"] == "example.com"
@patch("src.tools.infoquest_search.infoquest_search_api.requests.post")
def test_raw_results_http_error(self, mock_post, wrapper):
# Test HTTP error handling
mock_response = Mock()
mock_response.raise_for_status.side_effect = requests.HTTPError("API Error")
mock_post.return_value = mock_response
with pytest.raises(requests.HTTPError):
wrapper.raw_results("test query", time_range=0, site="")
# Check if pytest-asyncio is available, otherwise mark for conditional skipping
try:
import pytest_asyncio
_asyncio_available = True
except ImportError:
_asyncio_available = False
@pytest.mark.asyncio
async def test_raw_results_async_success(self, wrapper, mock_response_data):
# Skip only if pytest-asyncio is not installed
if not self._asyncio_available:
pytest.skip("pytest-asyncio is not installed")
with patch('json.loads', return_value=mock_response_data):
original_method = InfoQuestAPIWrapper.raw_results_async
async def mock_raw_results_async(self, query, time_range=0, site="", output_format="json"):
return mock_response_data["search_result"]
InfoQuestAPIWrapper.raw_results_async = mock_raw_results_async
try:
result = await wrapper.raw_results_async("test query", time_range=0, site="")
assert result == mock_response_data["search_result"]
finally:
InfoQuestAPIWrapper.raw_results_async = original_method
@pytest.mark.asyncio
async def test_raw_results_async_error(self, wrapper):
if not self._asyncio_available:
pytest.skip("pytest-asyncio is not installed")
original_method = InfoQuestAPIWrapper.raw_results_async
async def mock_raw_results_async_error(self, query, time_range=0, site="", output_format="json"):
raise Exception("Error 400: Bad Request")
InfoQuestAPIWrapper.raw_results_async = mock_raw_results_async_error
try:
with pytest.raises(Exception, match="Error 400: Bad Request"):
await wrapper.raw_results_async("test query", time_range=0, site="")
finally:
InfoQuestAPIWrapper.raw_results_async = original_method
def test_clean_results_with_images(self, wrapper, mock_response_data):
# Test result cleaning functionality
raw_results = mock_response_data["search_result"]["results"]
cleaned_results = wrapper.clean_results_with_images(raw_results)
assert len(cleaned_results) == 3
# Test page result
page_result = cleaned_results[0]
assert page_result["type"] == "page"
assert page_result["title"] == "Test Title"
assert page_result["url"] == "https://example.com"
assert page_result["desc"] == "Test description"
# Test news result
news_result = cleaned_results[1]
assert news_result["type"] == "news"
assert news_result["time_frame"] == "2 days ago"
assert news_result["title"] == "Test News"
assert news_result["url"] == "https://example.com/news"
assert news_result["source"] == "Test Source"
# Test image result
image_result = cleaned_results[2]
assert image_result["type"] == "image_url"
assert image_result["image_url"] == "https://example.com/image.jpg"
assert image_result["image_description"] == "Test image description"
def test_clean_results_empty_categories(self, wrapper):
# Test result cleaning with empty categories
data = [
{
"content": {
"results": {
"organic": [],
"top_stories": {"items": []},
"images": {"items": []}
}
}
}
]
result = wrapper.clean_results_with_images(data)
assert len(result) == 0
def test_clean_results_url_deduplication(self, wrapper):
# Test URL deduplication functionality
data = [
{
"content": {
"results": {
"organic": [
{
"title": "Test Title 1",
"url": "https://example.com",
"desc": "Description 1"
},
{
"title": "Test Title 2",
"url": "https://example.com",
"desc": "Description 2"
}
]
}
}
}
]
result = wrapper.clean_results_with_images(data)
assert len(result) == 1
assert result[0]["title"] == "Test Title 1"
@@ -1,226 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import json
from unittest.mock import Mock, patch
import pytest
class TestInfoQuestSearchResults:
@pytest.fixture
def search_tool(self):
"""Create a mock InfoQuestSearchResults instance."""
mock_tool = Mock()
mock_tool.time_range = 30
mock_tool.site = "example.com"
def mock_run(query, **kwargs):
sample_cleaned_results = [
{
"type": "page",
"title": "Test Title",
"url": "https://example.com",
"desc": "Test description"
}
]
sample_raw_results = {
"results": [
{
"content": {
"results": {
"organic": [
{
"title": "Test Title",
"url": "https://example.com",
"desc": "Test description"
}
]
}
}
}
]
}
return json.dumps(sample_cleaned_results, ensure_ascii=False), sample_raw_results
async def mock_arun(query, **kwargs):
return mock_run(query, **kwargs)
mock_tool._run = mock_run
mock_tool._arun = mock_arun
return mock_tool
@pytest.fixture
def sample_raw_results(self):
"""Sample raw results from InfoQuest API."""
return {
"results": [
{
"content": {
"results": {
"organic": [
{
"title": "Test Title",
"url": "https://example.com",
"desc": "Test description"
}
]
}
}
}
]
}
@pytest.fixture
def sample_cleaned_results(self):
"""Sample cleaned results."""
return [
{
"type": "page",
"title": "Test Title",
"url": "https://example.com",
"desc": "Test description"
}
]
def test_init_default_values(self):
"""Test initialization with default values using patch."""
with patch('src.tools.infoquest_search.infoquest_search_results.InfoQuestAPIWrapper') as mock_wrapper_class:
mock_instance = Mock()
mock_wrapper_class.return_value = mock_instance
from src.tools.infoquest_search.infoquest_search_results import InfoQuestSearchResults
with patch.object(InfoQuestSearchResults, '__init__', return_value=None) as mock_init:
InfoQuestSearchResults(infoquest_api_key="dummy-key")
mock_init.assert_called_once()
def test_init_custom_values(self):
"""Test initialization with custom values using patch."""
with patch('src.tools.infoquest_search.infoquest_search_results.InfoQuestAPIWrapper') as mock_wrapper_class:
mock_instance = Mock()
mock_wrapper_class.return_value = mock_instance
from src.tools.infoquest_search.infoquest_search_results import InfoQuestSearchResults
with patch.object(InfoQuestSearchResults, '__init__', return_value=None) as mock_init:
InfoQuestSearchResults(
time_range=10,
site="test.com",
infoquest_api_key="dummy-key"
)
mock_init.assert_called_once()
def test_run_success(
self,
search_tool,
sample_raw_results,
sample_cleaned_results,
):
"""Test successful synchronous run."""
result, raw = search_tool._run("test query")
assert isinstance(result, str)
assert isinstance(raw, dict)
assert "results" in raw
result_data = json.loads(result)
assert isinstance(result_data, list)
assert len(result_data) > 0
def test_run_exception(self, search_tool):
"""Test synchronous run with exception."""
original_run = search_tool._run
def mock_run_with_error(query, **kwargs):
return json.dumps({"error": "API Error"}, ensure_ascii=False), {}
try:
search_tool._run = mock_run_with_error
result, raw = search_tool._run("test query")
result_dict = json.loads(result)
assert "error" in result_dict
assert "API Error" in result_dict["error"]
assert raw == {}
finally:
search_tool._run = original_run
@pytest.mark.asyncio
async def test_arun_success(
self,
search_tool,
sample_raw_results,
sample_cleaned_results,
):
"""Test successful asynchronous run."""
result, raw = await search_tool._arun("test query")
assert isinstance(result, str)
assert isinstance(raw, dict)
assert "results" in raw
@pytest.mark.asyncio
async def test_arun_exception(self, search_tool):
"""Test asynchronous run with exception."""
original_arun = search_tool._arun
async def mock_arun_with_error(query, **kwargs):
return json.dumps({"error": "Async API Error"}, ensure_ascii=False), {}
try:
search_tool._arun = mock_arun_with_error
result, raw = await search_tool._arun("test query")
result_dict = json.loads(result)
assert "error" in result_dict
assert "Async API Error" in result_dict["error"]
assert raw == {}
finally:
search_tool._arun = original_arun
def test_run_with_run_manager(
self,
search_tool,
sample_raw_results,
sample_cleaned_results,
):
"""Test run with callback manager."""
mock_run_manager = Mock()
result, raw = search_tool._run("test query", run_manager=mock_run_manager)
assert isinstance(result, str)
assert isinstance(raw, dict)
@pytest.mark.asyncio
async def test_arun_with_run_manager(
self,
search_tool,
sample_raw_results,
sample_cleaned_results,
):
"""Test async run with callback manager."""
mock_run_manager = Mock()
result, raw = await search_tool._arun("test query", run_manager=mock_run_manager)
assert isinstance(result, str)
assert isinstance(raw, dict)
def test_api_wrapper_initialization_with_key(self):
"""Test API wrapper initialization with key."""
with patch('src.tools.infoquest_search.infoquest_search_results.InfoQuestAPIWrapper') as mock_wrapper_class:
mock_instance = Mock()
mock_wrapper_class.return_value = mock_instance
from src.tools.infoquest_search.infoquest_search_results import InfoQuestSearchResults
with patch.object(InfoQuestSearchResults, '__init__', return_value=None) as mock_init:
InfoQuestSearchResults(infoquest_api_key="test-key")
mock_init.assert_called_once()
-222
View File
@@ -1,222 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import os
from unittest.mock import patch
import pytest
from src.tools.python_repl import python_repl_tool
class TestPythonReplTool:
@patch.dict(os.environ, {"ENABLE_PYTHON_REPL": "true"})
@patch("src.tools.python_repl.repl")
@patch("src.tools.python_repl.logger")
def test_successful_code_execution(self, mock_logger, mock_repl):
# Arrange
code = "print('Hello, World!')"
expected_output = "Hello, World!\n"
mock_repl.run.return_value = expected_output
# Act
result = python_repl_tool.invoke({"code": code})
# Assert
mock_repl.run.assert_called_once_with(code)
mock_logger.info.assert_called_with("Code execution successful")
assert "Successfully executed:" in result
assert code in result
assert expected_output in result
@patch.dict(os.environ, {"ENABLE_PYTHON_REPL": "true"})
@patch("src.tools.python_repl.repl")
@patch("src.tools.python_repl.logger")
def test_invalid_input_type(self, mock_logger, mock_repl):
# Arrange
invalid_code = 123
# Act & Assert - expect ValidationError when passing invalid input
with pytest.raises(Exception): # Could be ValidationError or similar
python_repl_tool.invoke({"code": invalid_code})
mock_repl.run.assert_not_called()
@patch.dict(os.environ, {"ENABLE_PYTHON_REPL": "true"})
@patch("src.tools.python_repl.repl")
@patch("src.tools.python_repl.logger")
def test_code_execution_with_error_in_result(self, mock_logger, mock_repl):
# Arrange
code = "invalid_function()"
error_result = "NameError: name 'invalid_function' is not defined"
mock_repl.run.return_value = error_result
# Act
result = python_repl_tool.invoke({"code": code})
# Assert
mock_repl.run.assert_called_once_with(code)
mock_logger.error.assert_called_with(error_result)
assert "Error executing code:" in result
assert code in result
assert error_result in result
@patch.dict(os.environ, {"ENABLE_PYTHON_REPL": "true"})
@patch("src.tools.python_repl.repl")
@patch("src.tools.python_repl.logger")
def test_code_execution_with_exception_in_result(self, mock_logger, mock_repl):
# Arrange
code = "1/0"
exception_result = "ZeroDivisionError: division by zero"
mock_repl.run.return_value = exception_result
# Act
result = python_repl_tool.invoke({"code": code})
# Assert
mock_repl.run.assert_called_once_with(code)
mock_logger.error.assert_called_with(exception_result)
assert "Error executing code:" in result
assert code in result
assert exception_result in result
@patch.dict(os.environ, {"ENABLE_PYTHON_REPL": "true"})
@patch("src.tools.python_repl.repl")
@patch("src.tools.python_repl.logger")
def test_code_execution_raises_exception(self, mock_logger, mock_repl):
# Arrange
code = "print('test')"
exception = RuntimeError("REPL failed")
mock_repl.run.side_effect = exception
# Act
result = python_repl_tool.invoke({"code": code})
# Assert
mock_repl.run.assert_called_once_with(code)
mock_logger.error.assert_called_with(repr(exception))
assert "Error executing code:" in result
assert code in result
assert repr(exception) in result
@patch.dict(os.environ, {"ENABLE_PYTHON_REPL": "true"})
@patch("src.tools.python_repl.repl")
@patch("src.tools.python_repl.logger")
def test_successful_execution_with_calculation(self, mock_logger, mock_repl):
# Arrange
code = "result = 2 + 3\nprint(result)"
expected_output = "5\n"
mock_repl.run.return_value = expected_output
# Act
result = python_repl_tool.invoke({"code": code})
# Assert
mock_repl.run.assert_called_once_with(code)
mock_logger.info.assert_any_call("Executing Python code")
mock_logger.info.assert_any_call("Code execution successful")
assert "Successfully executed:" in result
assert code in result
assert expected_output in result
@patch.dict(os.environ, {"ENABLE_PYTHON_REPL": "true"})
@patch("src.tools.python_repl.repl")
@patch("src.tools.python_repl.logger")
def test_empty_string_code(self, mock_logger, mock_repl):
# Arrange
code = ""
mock_repl.run.return_value = ""
# Act
result = python_repl_tool.invoke({"code": code})
# Assert
mock_repl.run.assert_called_once_with(code)
mock_logger.info.assert_called_with("Code execution successful")
assert "Successfully executed:" in result
@patch.dict(os.environ, {"ENABLE_PYTHON_REPL": "true"})
@patch("src.tools.python_repl.repl")
@patch("src.tools.python_repl.logger")
def test_logging_calls(self, mock_logger, mock_repl):
# Arrange
code = "x = 1"
mock_repl.run.return_value = ""
# Act
python_repl_tool.invoke({"code": code})
# Assert
mock_logger.info.assert_any_call("Executing Python code")
mock_logger.info.assert_any_call("Code execution successful")
# New tests for configuration behavior
@patch.dict(os.environ, {"ENABLE_PYTHON_REPL": "false"})
@patch("src.tools.python_repl.logger")
def test_tool_disabled(self, mock_logger):
# Arrange
code = "print('test')"
# Act
result = python_repl_tool.invoke({"code": code})
# Assert
mock_logger.warning.assert_called_with(
"Python REPL tool is disabled. Please enable it in environment configuration."
)
assert "Tool disabled:" in result
assert "Python REPL tool is disabled" in result
@patch.dict(os.environ, {}, clear=True)
@patch("src.tools.python_repl.logger")
def test_tool_disabled_by_default(self, mock_logger):
# Arrange - remove any existing ENABLE_PYTHON_REPL variable
if "ENABLE_PYTHON_REPL" in os.environ:
del os.environ["ENABLE_PYTHON_REPL"]
code = "print('test')"
# Act
result = python_repl_tool.invoke({"code": code})
# Assert
mock_logger.warning.assert_called_with(
"Python REPL tool is disabled. Please enable it in environment configuration."
)
assert "Tool disabled:" in result
@pytest.mark.parametrize("env_value", ["true", "True", "TRUE", "1", "yes", "on"])
@patch("src.tools.python_repl.repl")
@patch("src.tools.python_repl.logger")
def test_tool_enabled_with_various_truthy_values(
self, mock_logger, mock_repl, env_value
):
# Arrange
with patch.dict(os.environ, {"ENABLE_PYTHON_REPL": env_value}):
code = "print('enabled')"
expected_output = "enabled\n"
mock_repl.run.return_value = expected_output
# Act
result = python_repl_tool.invoke({"code": code})
# Assert
mock_repl.run.assert_called_once_with(code)
assert "Successfully executed:" in result
@pytest.mark.parametrize(
"env_value", ["false", "False", "FALSE", "0", "no", "off", ""]
)
@patch("src.tools.python_repl.logger")
def test_tool_disabled_with_various_falsy_values(self, mock_logger, env_value):
# Arrange
with patch.dict(os.environ, {"ENABLE_PYTHON_REPL": env_value}):
code = "print('disabled')"
# Act
result = python_repl_tool.invoke({"code": code})
# Assert
mock_logger.warning.assert_called_with(
"Python REPL tool is disabled. Please enable it in environment configuration."
)
assert "Tool disabled:" in result
-291
View File
@@ -1,291 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import os
from unittest.mock import patch
import pytest
from pydantic import ValidationError
from src.config import SearchEngine
from src.tools.search import get_web_search_tool
class TestGetWebSearchTool:
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
def test_get_web_search_tool_tavily(self):
tool = get_web_search_tool(max_search_results=5)
assert tool.name == "web_search"
assert tool.max_results == 5
assert tool.include_raw_content is False
assert tool.include_images is True
assert tool.include_image_descriptions is True
assert tool.include_answer is False
assert tool.search_depth == "advanced"
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.DUCKDUCKGO.value)
def test_get_web_search_tool_duckduckgo(self):
tool = get_web_search_tool(max_search_results=3)
assert tool.name == "web_search"
assert tool.max_results == 3
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.BRAVE_SEARCH.value)
@patch.dict(os.environ, {"BRAVE_SEARCH_API_KEY": "test_api_key"})
def test_get_web_search_tool_brave(self):
tool = get_web_search_tool(max_search_results=4)
assert tool.name == "web_search"
assert tool.search_wrapper.api_key.get_secret_value() == "test_api_key"
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.ARXIV.value)
def test_get_web_search_tool_arxiv(self):
tool = get_web_search_tool(max_search_results=2)
assert tool.name == "web_search"
assert tool.api_wrapper.top_k_results == 2
assert tool.api_wrapper.load_max_docs == 2
assert tool.api_wrapper.load_all_available_meta is True
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", "unsupported_engine")
def test_get_web_search_tool_unsupported_engine(self):
with pytest.raises(
ValueError, match="Unsupported search engine: unsupported_engine"
):
get_web_search_tool(max_search_results=1)
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.BRAVE_SEARCH.value)
@patch.dict(os.environ, {}, clear=True)
def test_get_web_search_tool_brave_no_api_key(self):
tool = get_web_search_tool(max_search_results=1)
assert tool.search_wrapper.api_key.get_secret_value() == ""
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.SERPER.value)
@patch.dict(os.environ, {"SERPER_API_KEY": "test_serper_key"})
def test_get_web_search_tool_serper(self):
tool = get_web_search_tool(max_search_results=6)
assert tool.name == "web_search"
assert tool.api_wrapper.k == 6
assert tool.api_wrapper.serper_api_key == "test_serper_key"
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.SERPER.value)
@patch.dict(os.environ, {}, clear=True)
def test_get_web_search_tool_serper_no_api_key(self):
with pytest.raises(ValidationError):
get_web_search_tool(max_search_results=1)
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_get_web_search_tool_tavily_with_custom_config(self, mock_config):
"""Test Tavily tool with custom configuration values."""
mock_config.return_value = {
"SEARCH_ENGINE": {
"include_answer": True,
"search_depth": "basic",
"include_raw_content": True,
"include_images": False,
"include_image_descriptions": True,
"include_domains": ["example.com"],
"exclude_domains": ["spam.com"],
}
}
tool = get_web_search_tool(max_search_results=5)
assert tool.name == "web_search"
assert tool.max_results == 5
assert tool.include_answer is True
assert tool.search_depth == "basic"
assert tool.include_raw_content is True
assert tool.include_images is False
# include_image_descriptions should be False because include_images is False
assert tool.include_image_descriptions is False
assert tool.include_domains == ["example.com"]
assert tool.exclude_domains == ["spam.com"]
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_get_web_search_tool_tavily_with_empty_config(self, mock_config):
"""Test Tavily tool uses defaults when config is empty."""
mock_config.return_value = {"SEARCH_ENGINE": {}}
tool = get_web_search_tool(max_search_results=10)
assert tool.name == "web_search"
assert tool.max_results == 10
assert tool.include_answer is False
assert tool.search_depth == "advanced"
assert tool.include_raw_content is False
assert tool.include_images is True
assert tool.include_image_descriptions is True
assert tool.include_domains == []
assert tool.exclude_domains == []
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_get_web_search_tool_tavily_image_descriptions_disabled_when_images_disabled(
self, mock_config
):
"""Test that include_image_descriptions is False when include_images is False."""
mock_config.return_value = {
"SEARCH_ENGINE": {
"include_images": False,
"include_image_descriptions": True, # This should be ignored
}
}
tool = get_web_search_tool(max_search_results=5)
assert tool.include_images is False
assert tool.include_image_descriptions is False
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_get_web_search_tool_tavily_partial_config(self, mock_config):
"""Test Tavily tool with partial configuration."""
mock_config.return_value = {
"SEARCH_ENGINE": {
"include_answer": True,
"include_domains": ["trusted.com"],
}
}
tool = get_web_search_tool(max_search_results=3)
assert tool.include_answer is True
assert tool.search_depth == "advanced" # default
assert tool.include_raw_content is False # default
assert tool.include_domains == ["trusted.com"]
assert tool.exclude_domains == [] # default
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_get_web_search_tool_tavily_with_no_config_file(self, mock_config):
"""Test Tavily tool when config file doesn't exist."""
mock_config.return_value = {}
tool = get_web_search_tool(max_search_results=5)
assert tool.name == "web_search"
assert tool.max_results == 5
assert tool.include_answer is False
assert tool.search_depth == "advanced"
assert tool.include_raw_content is False
assert tool.include_images is True
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_get_web_search_tool_tavily_multiple_domains(self, mock_config):
"""Test Tavily tool with multiple domains in include/exclude lists."""
mock_config.return_value = {
"SEARCH_ENGINE": {
"include_domains": ["example.com", "trusted.com", "gov.cn"],
"exclude_domains": ["spam.com", "scam.org"],
}
}
tool = get_web_search_tool(max_search_results=5)
assert tool.include_domains == ["example.com", "trusted.com", "gov.cn"]
assert tool.exclude_domains == ["spam.com", "scam.org"]
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_tavily_with_no_search_engine_section(self, mock_config):
"""Test Tavily tool when SEARCH_ENGINE section doesn't exist in config."""
mock_config.return_value = {"OTHER_CONFIG": {}}
tool = get_web_search_tool(max_search_results=5)
assert tool.name == "web_search"
assert tool.max_results == 5
assert tool.include_answer is False
assert tool.search_depth == "advanced"
assert tool.include_raw_content is False
assert tool.include_images is True
assert tool.include_domains == []
assert tool.exclude_domains == []
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_tavily_with_completely_empty_config(self, mock_config):
"""Test Tavily tool with completely empty config."""
mock_config.return_value = {}
tool = get_web_search_tool(max_search_results=5)
assert tool.name == "web_search"
assert tool.max_results == 5
assert tool.include_answer is False
assert tool.search_depth == "advanced"
assert tool.include_raw_content is False
assert tool.include_images is True
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_tavily_with_only_include_answer_param(self, mock_config):
"""Test Tavily tool with only include_answer parameter specified."""
mock_config.return_value = {"SEARCH_ENGINE": {"include_answer": True}}
tool = get_web_search_tool(max_search_results=5)
assert tool.include_answer is True
assert tool.search_depth == "advanced"
assert tool.include_raw_content is False
assert tool.include_images is True
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_tavily_with_only_search_depth_param(self, mock_config):
"""Test Tavily tool with only search_depth parameter specified."""
mock_config.return_value = {"SEARCH_ENGINE": {"search_depth": "basic"}}
tool = get_web_search_tool(max_search_results=5)
assert tool.search_depth == "basic"
assert tool.include_answer is False
assert tool.include_raw_content is False
assert tool.include_images is True
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_tavily_with_only_include_domains_param(self, mock_config):
"""Test Tavily tool with only include_domains parameter specified."""
mock_config.return_value = {
"SEARCH_ENGINE": {"include_domains": ["example.com"]}
}
tool = get_web_search_tool(max_search_results=5)
assert tool.include_domains == ["example.com"]
assert tool.exclude_domains == []
assert tool.include_answer is False
assert tool.search_depth == "advanced"
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_tavily_with_explicit_false_boolean_values(self, mock_config):
"""Test that explicitly False boolean values are respected (not treated as missing)."""
mock_config.return_value = {
"SEARCH_ENGINE": {
"include_answer": False,
"include_raw_content": False,
"include_images": False,
}
}
tool = get_web_search_tool(max_search_results=5)
assert tool.include_answer is False
assert tool.include_raw_content is False
assert tool.include_images is False
assert tool.include_image_descriptions is False
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_tavily_with_empty_domain_lists(self, mock_config):
"""Test that empty domain lists are treated as optional."""
mock_config.return_value = {
"SEARCH_ENGINE": {
"include_domains": [],
"exclude_domains": [],
}
}
tool = get_web_search_tool(max_search_results=5)
assert tool.include_domains == []
assert tool.exclude_domains == []
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_tavily_all_parameters_optional_mix(self, mock_config):
"""Test that any combination of optional parameters works."""
mock_config.return_value = {
"SEARCH_ENGINE": {
"include_answer": True,
"include_images": False,
# Deliberately omit search_depth, include_raw_content, domains
}
}
tool = get_web_search_tool(max_search_results=5)
assert tool.include_answer is True
assert tool.include_images is False
assert (
tool.include_image_descriptions is False
) # should be False since include_images is False
assert tool.search_depth == "advanced" # default
assert tool.include_raw_content is False # default
assert tool.include_domains == [] # default
assert tool.exclude_domains == [] # default
@@ -1,263 +0,0 @@
import pytest
from src.tools.search_postprocessor import SearchResultPostProcessor
class TestSearchResultPostProcessor:
"""Test cases for SearchResultPostProcessor"""
@pytest.fixture
def post_processor(self):
"""Create a SearchResultPostProcessor instance for testing"""
return SearchResultPostProcessor(
min_score_threshold=0.5, max_content_length_per_page=100
)
def test_process_results_empty_input(self, post_processor):
"""Test processing empty results"""
results = []
processed = post_processor.process_results(results)
assert processed == []
def test_process_results_with_valid_page_results(self, post_processor):
"""Test processing valid page results"""
results = [
{
"type": "page",
"title": "Test Page",
"url": "https://example.com",
"content": "Test content",
"score": 0.8,
}
]
processed = post_processor.process_results(results)
assert len(processed) == 1
assert processed[0]["title"] == "Test Page"
assert processed[0]["url"] == "https://example.com"
assert processed[0]["content"] == "Test content"
assert processed[0]["score"] == 0.8
def test_process_results_filter_low_score(self, post_processor):
"""Test filtering out low score results"""
results = [
{
"type": "page",
"title": "Low Score Page",
"url": "https://example.com/low",
"content": "Low score content",
"score": 0.3, # Below threshold of 0.5
},
{
"type": "page",
"title": "High Score Page",
"url": "https://example.com/high",
"content": "High score content",
"score": 0.9,
},
]
processed = post_processor.process_results(results)
assert len(processed) == 1
assert processed[0]["title"] == "High Score Page"
def test_process_results_remove_duplicates(self, post_processor):
"""Test removing duplicate URLs"""
results = [
{
"type": "page",
"title": "Page 1",
"url": "https://example.com",
"content": "Content 1",
"score": 0.8,
},
{
"type": "page",
"title": "Page 2",
"url": "https://example.com", # Duplicate URL
"content": "Content 2",
"score": 0.7,
},
]
processed = post_processor.process_results(results)
assert len(processed) == 1
assert processed[0]["title"] == "Page 1" # First one should be kept
def test_process_results_sort_by_score(self, post_processor):
"""Test sorting results by score in descending order"""
results = [
{
"type": "page",
"title": "Low Score",
"url": "https://example.com/low",
"content": "Low score content",
"score": 0.3,
},
{
"type": "page",
"title": "High Score",
"url": "https://example.com/high",
"content": "High score content",
"score": 0.9,
},
{
"type": "page",
"title": "Medium Score",
"url": "https://example.com/medium",
"content": "Medium score content",
"score": 0.6,
},
]
processed = post_processor.process_results(results)
assert len(processed) == 2 # Low score filtered out
# Should be sorted by score descending
assert processed[0]["title"] == "High Score"
assert processed[1]["title"] == "Medium Score"
def test_process_results_truncate_long_content(self, post_processor):
"""Test truncating long content"""
long_content = "A" * 150 # Longer than max_content_length of 100
results = [
{
"type": "page",
"title": "Long Content Page",
"url": "https://example.com",
"content": long_content,
"score": 0.8,
}
]
processed = post_processor.process_results(results)
assert len(processed) == 1
assert len(processed[0]["content"]) == 103 # 100 + "..."
assert processed[0]["content"].endswith("...")
def test_process_results_remove_base64_images(self, post_processor):
"""Test removing base64 images from content"""
content_with_base64 = (
"Content with image "
+ "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
)
results = [
{
"type": "page",
"title": "Page with Base64",
"url": "https://example.com",
"content": content_with_base64,
"score": 0.8,
}
]
processed = post_processor.process_results(results)
assert len(processed) == 1
assert processed[0]["content"] == "Content with image "
def test_process_results_with_image_type(self, post_processor):
"""Test processing image type results"""
results = [
{
"type": "image",
"image_url": "https://example.com/image.jpg",
"image_description": "Test image",
}
]
processed = post_processor.process_results(results)
assert len(processed) == 1
assert processed[0]["type"] == "image"
assert processed[0]["image_url"] == "https://example.com/image.jpg"
assert processed[0]["image_description"] == "Test image"
def test_process_results_filter_base64_image_urls(self, post_processor):
"""Test filtering out image results with base64 URLs"""
results = [
{
"type": "image",
"image_url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==",
"image_description": "Base64 image",
},
{
"type": "image",
"image_url": "https://example.com/image.jpg",
"image_description": "Regular image",
},
]
processed = post_processor.process_results(results)
assert len(processed) == 1
assert processed[0]["image_url"] == "https://example.com/image.jpg"
def test_process_results_truncate_long_image_description(self, post_processor):
"""Test truncating long image descriptions"""
long_description = "A" * 150 # Longer than max_content_length of 100
results = [
{
"type": "image",
"image_url": "https://example.com/image.jpg",
"image_description": long_description,
}
]
processed = post_processor.process_results(results)
assert len(processed) == 1
assert len(processed[0]["image_description"]) == 103 # 100 + "..."
assert processed[0]["image_description"].endswith("...")
def test_process_results_other_types_passthrough(self, post_processor):
"""Test that other result types pass through unchanged"""
results = [
{
"type": "video",
"title": "Test Video",
"url": "https://example.com/video.mp4",
"score": 0.8,
}
]
processed = post_processor.process_results(results)
assert len(processed) == 1
assert processed[0]["type"] == "video"
assert processed[0]["title"] == "Test Video"
def test_process_results_truncate_long_content_with_no_config(self):
"""Test truncating long content"""
post_processor = SearchResultPostProcessor(None, None)
long_content = "A" * 150 # Longer than max_content_length of 100
results = [
{
"type": "page",
"title": "Long Content Page",
"url": "https://example.com",
"content": long_content,
"score": 0.8,
}
]
processed = post_processor.process_results(results)
assert len(processed) == 1
assert len(processed[0]["content"]) == len("A" * 150)
def test_process_results_truncate_long_content_with_max_content_length_config(self):
"""Test truncating long content"""
post_processor = SearchResultPostProcessor(None, 100)
long_content = "A" * 150 # Longer than max_content_length of 100
results = [
{
"type": "page",
"title": "Long Content Page",
"url": "https://example.com",
"content": long_content,
"score": 0.8,
}
]
processed = post_processor.process_results(results)
assert len(processed) == 1
assert len(processed[0]["content"]) == 103
assert processed[0]["content"].endswith("...")
def test_process_results_truncate_long_content_with_min_score_config(self):
"""Test truncating long content"""
post_processor = SearchResultPostProcessor(0.8, None)
long_content = "A" * 150 # Longer than max_content_length of 100
results = [
{
"type": "page",
"title": "Long Content Page",
"url": "https://example.com",
"content": long_content,
"score": 0.3,
}
]
processed = post_processor.process_results(results)
assert len(processed) == 0
@@ -1,207 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import json
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
import requests
from src.tools.tavily_search.tavily_search_api_wrapper import (
EnhancedTavilySearchAPIWrapper,
)
class TestEnhancedTavilySearchAPIWrapper:
@pytest.fixture
def wrapper(self):
with patch(
"src.tools.tavily_search.tavily_search_api_wrapper.OriginalTavilySearchAPIWrapper"
):
wrapper = EnhancedTavilySearchAPIWrapper(tavily_api_key="dummy-key")
# The parent class is mocked, so initialization won't fail
return wrapper
@pytest.fixture
def mock_response_data(self):
return {
"results": [
{
"title": "Test Title",
"url": "https://example.com",
"content": "Test content",
"score": 0.9,
"raw_content": "Raw test content",
}
],
"images": [
{
"url": "https://example.com/image.jpg",
"description": "Test image description",
}
],
}
@patch("src.tools.tavily_search.tavily_search_api_wrapper.requests.post")
def test_raw_results_success(self, mock_post, wrapper, mock_response_data):
mock_response = Mock()
mock_response.json.return_value = mock_response_data
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = wrapper.raw_results("test query", max_results=10)
assert result == mock_response_data
mock_post.assert_called_once()
call_args = mock_post.call_args
assert "json" in call_args.kwargs
assert call_args.kwargs["json"]["query"] == "test query"
assert call_args.kwargs["json"]["max_results"] == 10
@patch("src.tools.tavily_search.tavily_search_api_wrapper.requests.post")
def test_raw_results_with_all_parameters(
self, mock_post, wrapper, mock_response_data
):
mock_response = Mock()
mock_response.json.return_value = mock_response_data
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = wrapper.raw_results(
"test query",
max_results=3,
search_depth="basic",
include_domains=["example.com"],
exclude_domains=["spam.com"],
include_answer=True,
include_raw_content=True,
include_images=True,
include_image_descriptions=True,
)
assert result == mock_response_data
call_args = mock_post.call_args
params = call_args.kwargs["json"]
assert params["include_domains"] == ["example.com"]
assert params["exclude_domains"] == ["spam.com"]
assert params["include_answer"] is True
assert params["include_raw_content"] is True
@patch("src.tools.tavily_search.tavily_search_api_wrapper.requests.post")
def test_raw_results_http_error(self, mock_post, wrapper):
mock_response = Mock()
mock_response.raise_for_status.side_effect = requests.HTTPError("API Error")
mock_post.return_value = mock_response
with pytest.raises(requests.HTTPError):
wrapper.raw_results("test query")
@pytest.mark.asyncio
async def test_raw_results_async_success(self, wrapper, mock_response_data):
# Create a mock that acts as both the response and its context manager
mock_response_cm = AsyncMock()
mock_response_cm.__aenter__ = AsyncMock(return_value=mock_response_cm)
mock_response_cm.__aexit__ = AsyncMock(return_value=None)
mock_response_cm.status = 200
mock_response_cm.text = AsyncMock(return_value=json.dumps(mock_response_data))
# Create mock session that returns the context manager
mock_session = AsyncMock()
mock_session.post = MagicMock(
return_value=mock_response_cm
) # Use MagicMock, not AsyncMock
# Create mock session class
mock_session_cm = AsyncMock()
mock_session_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_cm.__aexit__ = AsyncMock(return_value=None)
with patch(
"src.tools.tavily_search.tavily_search_api_wrapper.aiohttp.ClientSession",
return_value=mock_session_cm,
):
result = await wrapper.raw_results_async("test query")
assert result == mock_response_data
@pytest.mark.asyncio
async def test_raw_results_async_error(self, wrapper):
# Create a mock that acts as both the response and its context manager
mock_response_cm = AsyncMock()
mock_response_cm.__aenter__ = AsyncMock(return_value=mock_response_cm)
mock_response_cm.__aexit__ = AsyncMock(return_value=None)
mock_response_cm.status = 400
mock_response_cm.reason = "Bad Request"
# Create mock session that returns the context manager
mock_session = AsyncMock()
mock_session.post = MagicMock(
return_value=mock_response_cm
) # Use MagicMock, not AsyncMock
# Create mock session class
mock_session_cm = AsyncMock()
mock_session_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_cm.__aexit__ = AsyncMock(return_value=None)
with patch(
"src.tools.tavily_search.tavily_search_api_wrapper.aiohttp.ClientSession",
return_value=mock_session_cm,
):
with pytest.raises(Exception, match="Error 400: Bad Request"):
await wrapper.raw_results_async("test query")
def test_clean_results_with_images(self, wrapper, mock_response_data):
result = wrapper.clean_results_with_images(mock_response_data)
assert len(result) == 2
# Test page result
page_result = result[0]
assert page_result["type"] == "page"
assert page_result["title"] == "Test Title"
assert page_result["url"] == "https://example.com"
assert page_result["content"] == "Test content"
assert page_result["score"] == 0.9
assert page_result["raw_content"] == "Raw test content"
# Test image result
image_result = result[1]
assert image_result["type"] == "image_url"
assert image_result["image_url"] == {"url": "https://example.com/image.jpg"}
assert image_result["image_description"] == "Test image description"
def test_clean_results_without_raw_content(self, wrapper):
data = {
"results": [
{
"title": "Test Title",
"url": "https://example.com",
"content": "Test content",
"score": 0.9,
}
],
"images": [],
}
result = wrapper.clean_results_with_images(data)
assert len(result) == 1
assert "raw_content" not in result[0]
def test_clean_results_empty_images(self, wrapper):
data = {
"results": [
{
"title": "Test Title",
"url": "https://example.com",
"content": "Test content",
"score": 0.9,
}
],
"images": [],
}
result = wrapper.clean_results_with_images(data)
assert len(result) == 1
assert result[0]["type"] == "page"
@@ -1,206 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import json
from unittest.mock import AsyncMock, Mock, patch
import pytest
from src.tools.tavily_search.tavily_search_api_wrapper import (
EnhancedTavilySearchAPIWrapper,
)
from src.tools.tavily_search.tavily_search_results_with_images import (
TavilySearchWithImages,
)
class TestTavilySearchWithImages:
@pytest.fixture
def mock_api_wrapper(self):
"""Create a mock API wrapper."""
wrapper = Mock(spec=EnhancedTavilySearchAPIWrapper)
return wrapper
@pytest.fixture
def search_tool(self, mock_api_wrapper):
"""Create a TavilySearchWithImages instance with mocked dependencies."""
tool = TavilySearchWithImages(
max_results=5,
include_answer=True,
include_raw_content=True,
include_images=True,
include_image_descriptions=True,
)
tool.api_wrapper = mock_api_wrapper
return tool
@pytest.fixture
def sample_raw_results(self):
"""Sample raw results from Tavily API."""
return {
"query": "test query",
"answer": "Test answer",
"images": ["https://example.com/image1.jpg"],
"results": [
{
"title": "Test Title",
"url": "https://example.com",
"content": "Test content",
"score": 0.95,
"raw_content": "Raw test content",
}
],
"response_time": 1.5,
}
@pytest.fixture
def sample_cleaned_results(self):
"""Sample cleaned results."""
return [
{
"title": "Test Title",
"url": "https://example.com",
"content": "Test content",
}
]
def test_init_default_values(self):
"""Test initialization with default values."""
tool = TavilySearchWithImages()
assert tool.include_image_descriptions is False
assert isinstance(tool.api_wrapper, EnhancedTavilySearchAPIWrapper)
def test_init_custom_values(self):
"""Test initialization with custom values."""
tool = TavilySearchWithImages(max_results=10, include_image_descriptions=True)
assert tool.max_results == 10
assert tool.include_image_descriptions is True
def test_run_success(
self,
search_tool,
mock_api_wrapper,
sample_raw_results,
sample_cleaned_results,
):
"""Test successful synchronous run."""
mock_api_wrapper.raw_results.return_value = sample_raw_results
mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results
result, raw = search_tool._run("test query")
assert result == json.dumps(sample_cleaned_results, ensure_ascii=False)
assert raw == sample_raw_results
mock_api_wrapper.raw_results.assert_called_once_with(
"test query",
search_tool.max_results,
search_tool.search_depth,
search_tool.include_domains,
search_tool.exclude_domains,
search_tool.include_answer,
search_tool.include_raw_content,
search_tool.include_images,
search_tool.include_image_descriptions,
)
mock_api_wrapper.clean_results_with_images.assert_called_once_with(
sample_raw_results
)
def test_run_exception(self, search_tool, mock_api_wrapper):
"""Test synchronous run with exception."""
mock_api_wrapper.raw_results.side_effect = Exception("API Error")
result, raw = search_tool._run("test query")
result_dict = json.loads(result)
assert "error" in result_dict
assert "API Error" in result_dict["error"]
assert raw == {}
mock_api_wrapper.clean_results_with_images.assert_not_called()
@pytest.mark.asyncio
async def test_arun_success(
self,
search_tool,
mock_api_wrapper,
sample_raw_results,
sample_cleaned_results,
):
"""Test successful asynchronous run."""
mock_api_wrapper.raw_results_async = AsyncMock(return_value=sample_raw_results)
mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results
result, raw = await search_tool._arun("test query")
assert result == json.dumps(sample_cleaned_results, ensure_ascii=False)
assert raw == sample_raw_results
mock_api_wrapper.raw_results_async.assert_called_once_with(
"test query",
search_tool.max_results,
search_tool.search_depth,
search_tool.include_domains,
search_tool.exclude_domains,
search_tool.include_answer,
search_tool.include_raw_content,
search_tool.include_images,
search_tool.include_image_descriptions,
)
mock_api_wrapper.clean_results_with_images.assert_called_once_with(
sample_raw_results
)
@pytest.mark.asyncio
async def test_arun_exception(self, search_tool, mock_api_wrapper):
"""Test asynchronous run with exception."""
mock_api_wrapper.raw_results_async = AsyncMock(
side_effect=Exception("Async API Error")
)
result, raw = await search_tool._arun("test query")
result_dict = json.loads(result)
assert "error" in result_dict
assert "Async API Error" in result_dict["error"]
assert raw == {}
mock_api_wrapper.clean_results_with_images.assert_not_called()
def test_run_with_run_manager(
self,
search_tool,
mock_api_wrapper,
sample_raw_results,
sample_cleaned_results,
):
"""Test run with callback manager."""
mock_run_manager = Mock()
mock_api_wrapper.raw_results.return_value = sample_raw_results
mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results
result, raw = search_tool._run("test query", run_manager=mock_run_manager)
assert result == json.dumps(sample_cleaned_results, ensure_ascii=False)
assert raw == sample_raw_results
@pytest.mark.asyncio
async def test_arun_with_run_manager(
self,
search_tool,
mock_api_wrapper,
sample_raw_results,
sample_cleaned_results,
):
"""Test async run with callback manager."""
mock_run_manager = Mock()
mock_api_wrapper.raw_results_async = AsyncMock(return_value=sample_raw_results)
mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results
result, raw = await search_tool._arun(
"test query", run_manager=mock_run_manager
)
assert result == json.dumps(sample_cleaned_results, ensure_ascii=False)
assert raw == sample_raw_results
-126
View File
@@ -1,126 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import Mock, patch
import pytest
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from src.rag import Chunk, Document, Resource, Retriever
from src.tools.retriever import RetrieverInput, RetrieverTool, get_retriever_tool
def test_retriever_input_model():
input_data = RetrieverInput(keywords="test keywords")
assert input_data.keywords == "test keywords"
def test_retriever_tool_init():
mock_retriever = Mock(spec=Retriever)
resources = [Resource(uri="test://uri", title="Test")]
tool = RetrieverTool(retriever=mock_retriever, resources=resources)
assert tool.name == "local_search_tool"
assert "retrieving information" in tool.description
assert tool.args_schema == RetrieverInput
assert tool.retriever == mock_retriever
assert tool.resources == resources
def test_retriever_tool_run_with_results():
mock_retriever = Mock(spec=Retriever)
chunk = Chunk(content="test content", similarity=0.9)
doc = Document(id="doc1", chunks=[chunk])
mock_retriever.query_relevant_documents.return_value = [doc]
resources = [Resource(uri="test://uri", title="Test")]
tool = RetrieverTool(retriever=mock_retriever, resources=resources)
result = tool._run("test keywords")
mock_retriever.query_relevant_documents.assert_called_once_with(
"test keywords", resources
)
assert isinstance(result, list)
assert len(result) == 1
assert result[0] == doc.to_dict()
def test_retriever_tool_run_no_results():
mock_retriever = Mock(spec=Retriever)
mock_retriever.query_relevant_documents.return_value = []
resources = [Resource(uri="test://uri", title="Test")]
tool = RetrieverTool(retriever=mock_retriever, resources=resources)
result = tool._run("test keywords")
assert result == "No results found from the local knowledge base."
@pytest.mark.asyncio
async def test_retriever_tool_arun():
mock_retriever = Mock(spec=Retriever)
chunk = Chunk(content="async content", similarity=0.8)
doc = Document(id="doc2", chunks=[chunk])
# Mock the async method
async def mock_async_query(*args, **kwargs):
return [doc]
mock_retriever.query_relevant_documents_async = mock_async_query
resources = [Resource(uri="test://uri", title="Test")]
tool = RetrieverTool(retriever=mock_retriever, resources=resources)
mock_run_manager = Mock(spec=AsyncCallbackManagerForToolRun)
result = await tool._arun("async keywords", mock_run_manager)
assert isinstance(result, list)
assert len(result) == 1
assert result[0] == doc.to_dict()
@patch("src.tools.retriever.build_retriever")
def test_get_retriever_tool_success(mock_build_retriever):
mock_retriever = Mock(spec=Retriever)
mock_build_retriever.return_value = mock_retriever
resources = [Resource(uri="test://uri", title="Test")]
tool = get_retriever_tool(resources)
assert isinstance(tool, RetrieverTool)
assert tool.retriever == mock_retriever
assert tool.resources == resources
def test_get_retriever_tool_empty_resources():
result = get_retriever_tool([])
assert result is None
@patch("src.tools.retriever.build_retriever")
def test_get_retriever_tool_no_retriever(mock_build_retriever):
mock_build_retriever.return_value = None
resources = [Resource(uri="test://uri", title="Test")]
result = get_retriever_tool(resources)
assert result is None
def test_retriever_tool_run_with_callback_manager():
mock_retriever = Mock(spec=Retriever)
mock_retriever.query_relevant_documents.return_value = []
resources = [Resource(uri="test://uri", title="Test")]
tool = RetrieverTool(retriever=mock_retriever, resources=resources)
mock_callback_manager = Mock(spec=CallbackManagerForToolRun)
result = tool._run("test keywords", mock_callback_manager)
assert result == "No results found from the local knowledge base."
-235
View File
@@ -1,235 +0,0 @@
import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from src.utils.context_manager import ContextManager
class TestContextManager:
"""Test cases for ContextManager"""
def test_count_tokens_with_empty_messages(self):
"""Test counting tokens with empty message list"""
context_manager = ContextManager(token_limit=1000)
messages = []
token_count = context_manager.count_tokens(messages)
assert token_count == 0
def test_count_tokens_with_system_message(self):
"""Test counting tokens with system message"""
context_manager = ContextManager(token_limit=1000)
messages = [SystemMessage(content="You are a helpful assistant.")]
token_count = context_manager.count_tokens(messages)
# System message has 28 characters, should be around 8 tokens (28/4 * 1.1)
assert token_count > 7
def test_count_tokens_with_human_message(self):
"""Test counting tokens with human message"""
context_manager = ContextManager(token_limit=1000)
messages = [HumanMessage(content="你好,这是一个测试消息。")]
token_count = context_manager.count_tokens(messages)
assert token_count > 12
def test_count_tokens_with_ai_message(self):
"""Test counting tokens with AI message"""
context_manager = ContextManager(token_limit=1000)
messages = [AIMessage(content="I'm doing well, thank you for asking!")]
token_count = context_manager.count_tokens(messages)
assert token_count >= 10
def test_count_tokens_with_tool_message(self):
"""Test counting tokens with tool message"""
context_manager = ContextManager(token_limit=1000)
messages = [
ToolMessage(content="Tool execution result data here", tool_call_id="test")
]
token_count = context_manager.count_tokens(messages)
# Tool message has about 32 characters, should be around 10 tokens (32/4 * 1.3)
assert token_count > 0
def test_count_tokens_with_multiple_messages(self):
"""Test counting tokens with multiple messages"""
context_manager = ContextManager(token_limit=1000)
messages = [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="Hello, how are you?"),
AIMessage(content="I'm doing well, thank you for asking!"),
]
token_count = context_manager.count_tokens(messages)
# Should be sum of all individual message tokens
assert token_count > 0
def test_is_over_limit_when_under_limit(self):
"""Test is_over_limit when messages are under token limit"""
context_manager = ContextManager(token_limit=1000)
short_messages = [HumanMessage(content="Short message")]
is_over = context_manager.is_over_limit(short_messages)
assert is_over is False
def test_is_over_limit_when_over_limit(self):
"""Test is_over_limit when messages exceed token limit"""
# Create a context manager with a very low limit
low_limit_cm = ContextManager(token_limit=5)
long_messages = [
HumanMessage(
content="This is a very long message that should exceed the limit"
)
]
is_over = low_limit_cm.is_over_limit(long_messages)
assert is_over is True
def test_compress_messages_when_not_over_limit(self):
"""Test compress_messages when messages are not over limit"""
context_manager = ContextManager(token_limit=1000)
messages = [HumanMessage(content="Short message")]
compressed = context_manager.compress_messages({"messages": messages})
# Should return the same messages when not over limit
assert len(compressed["messages"]) == len(messages)
def test_compress_messages_with_tool_message(self):
"""Test compress_messages preserves system message and compresses raw_content"""
# Create a context manager with limited token capacity
limited_cm = ContextManager(token_limit=200)
messages = [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="Hello"),
AIMessage(content="Hi there!"),
ToolMessage(
name="web_search",
content='[{"title": "Test Result", "url": "https://example.com", "raw_content": "' + ("This is a test content that should be compressed if it exceeds 1024 characters. " * 2000) + '"}]',
tool_call_id="test_search",
)
]
compressed = limited_cm.compress_messages({"messages": messages})
# Should preserve system message and some recent messages
assert len(compressed["messages"]) == 4
# Verify raw_content was compressed to 1024 characters
import json
for msg in compressed["messages"]:
if isinstance(msg, ToolMessage) and getattr(msg, "name", None) == "web_search":
content_data = json.loads(msg.content)
if isinstance(content_data, list):
for item in content_data:
if isinstance(item, dict) and "raw_content" in item:
assert len(item["raw_content"]) == 1024
def test_compress_messages_with_preserve_prefix_message(self):
"""Test compress_messages when no system message is present"""
# Create a context manager with limited token capacity
limited_cm = ContextManager(token_limit=100, preserve_prefix_message_count=2)
messages = [
HumanMessage(content="Hello"),
AIMessage(content="Hi there!"),
HumanMessage(
content="Can you tell me a very long story that would exceed token limits? "
* 10
),
]
compressed = limited_cm.compress_messages({"messages": messages})
# Should keep only the most recent messages that fit
assert len(compressed["messages"]) == 3
def test_compress_messages_without_config(self):
"""Test compress_messages preserves system message"""
# Create a context manager with limited token capacity
limited_cm = ContextManager(None)
messages = [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="Hello"),
AIMessage(content="Hi there!"),
HumanMessage(
content="Can you tell me a very long story that would exceed token limits? "
* 100
),
]
compressed = limited_cm.compress_messages({"messages": messages})
# return the original messages
assert len(compressed["messages"]) == 4
def test_count_message_tokens_with_additional_kwargs(self):
"""Test counting tokens for messages with additional kwargs"""
context_manager = ContextManager(token_limit=1000)
message = ToolMessage(
content="Tool result",
tool_call_id="test",
additional_kwargs={"tool_calls": [{"name": "test_function"}]},
)
token_count = context_manager._count_message_tokens(message)
assert token_count > 0
def test_count_message_tokens_minimum_one_token(self):
"""Test that message token count is at least 1"""
context_manager = ContextManager(token_limit=1000)
message = HumanMessage(content="") # Empty content
token_count = context_manager._count_message_tokens(message)
assert token_count == 1 # Should be at least 1
def test_count_text_tokens_english_only(self):
"""Test counting tokens for English text"""
context_manager = ContextManager(token_limit=1000)
# 16 English characters should result in 4 tokens (16/4)
text = "This is a test."
token_count = context_manager._count_text_tokens(text)
assert token_count > 0
def test_count_text_tokens_chinese_only(self):
"""Test counting tokens for Chinese text"""
context_manager = ContextManager(token_limit=1000)
# 8 Chinese characters should result in 8 tokens (1:1 ratio)
text = "这是一个测试文本"
token_count = context_manager._count_text_tokens(text)
assert token_count == 8
def test_count_text_tokens_mixed_content(self):
"""Test counting tokens for mixed English and Chinese text"""
context_manager = ContextManager(token_limit=1000)
text = "Hello world 这是一些中文"
token_count = context_manager._count_text_tokens(text)
assert token_count > 6
def test_compress_messages_with_runtime_when_not_over_limit(self):
"""compress_messages accepts runtime param when under limit"""
context_manager = ContextManager(token_limit=1000)
messages = [HumanMessage(content="Short message"), AIMessage(content="OK")]
compressed = context_manager.compress_messages({"messages": messages}, runtime=object())
assert isinstance(compressed, dict)
assert "messages" in compressed
assert len(compressed["messages"]) == len(messages)
def test_compress_messages_with_runtime_when_over_limit(self):
"""compress_messages accepts runtime param and still compresses"""
limited_cm = ContextManager(token_limit=200)
messages = [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="Hello"),
AIMessage(content="Hi there!"),
HumanMessage(
content="Can you tell me a very long story that would exceed token limits? " * 100
),
ToolMessage(
name="web_search",
content='[{"title": "Test Result", "url": "https://example.com", "raw_content": "' + ("This is a test content that should be compressed if it exceeds 1024 characters. " * 2000) + '"}]',
tool_call_id="test_search",
)
]
compressed = limited_cm.compress_messages({"messages": messages}, runtime=object())
assert isinstance(compressed, dict)
assert "messages" in compressed
# Should preserve only what fits; with this setup we expect heavy compression
assert len(compressed["messages"]) == 5
# Verify raw_content was compressed to 1024 characters
import json
for msg in compressed["messages"]:
if isinstance(msg, ToolMessage) and getattr(msg, "name", None) == "web_search":
content_data = json.loads(msg.content)
if isinstance(content_data, list):
for item in content_data:
if isinstance(item, dict) and "raw_content" in item:
assert len(item["raw_content"]) == 1024
-581
View File
@@ -1,581 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import json
from src.utils.json_utils import (
_extract_json_from_content,
repair_json_output,
sanitize_args,
sanitize_tool_response,
)
class TestRepairJsonOutput:
def test_valid_json_object(self):
"""Test with valid JSON object"""
content = '{"key": "value", "number": 123}'
result = repair_json_output(content)
expected = json.dumps({"key": "value", "number": 123}, ensure_ascii=False)
assert result == expected
def test_valid_json_array(self):
"""Test with valid JSON array"""
content = '[1, 2, 3, "test"]'
result = repair_json_output(content)
expected = json.dumps([1, 2, 3, "test"], ensure_ascii=False)
assert result == expected
def test_json_with_code_block_json(self):
"""Test JSON wrapped in ```json code block"""
content = '```json\n{"key": "value"}\n```'
result = repair_json_output(content)
expected = json.dumps({"key": "value"}, ensure_ascii=False)
assert result == expected
def test_json_with_code_block_ts(self):
"""Test JSON wrapped in ```ts code block"""
content = '```ts\n{"key": "value"}\n```'
result = repair_json_output(content)
expected = json.dumps({"key": "value"}, ensure_ascii=False)
assert result == expected
def test_json_with_code_block_uppercase_json(self):
"""Test JSON wrapped in ```JSON (uppercase) code block"""
content = '```JSON\n{"key": "value"}\n```'
result = repair_json_output(content)
expected = json.dumps({"key": "value"}, ensure_ascii=False)
assert result == expected
def test_json_with_code_block_uppercase_ts(self):
"""Test JSON wrapped in ```TS (uppercase) code block"""
content = '```TS\n{"key": "value"}\n```'
result = repair_json_output(content)
expected = json.dumps({"key": "value"}, ensure_ascii=False)
assert result == expected
def test_json_with_code_block_mixed_case_json(self):
"""Test JSON wrapped in ```Json (mixed case) code block"""
content = '```Json\n{"key": "value"}\n```'
result = repair_json_output(content)
expected = json.dumps({"key": "value"}, ensure_ascii=False)
assert result == expected
def test_json_with_code_block_uppercase_ts_with_prefix(self):
"""Test JSON wrapped in ```TS code block with prefix text"""
content = 'some prefix ```TS\n{"key": "value"}\n```'
result = repair_json_output(content)
expected = json.dumps({"key": "value"}, ensure_ascii=False)
assert result == expected
def test_json_with_code_block_uppercase_json_with_prefix(self):
"""Test JSON wrapped in ```JSON code block with prefix text - case sensitive fix"""
# This tests the fix for case-insensitive guard when fence is not at start
content = 'prefix ```JSON\n{"key": "value"}\n```'
result = repair_json_output(content)
expected = json.dumps({"key": "value"}, ensure_ascii=False)
assert result == expected
def test_json_with_plain_code_block_uppercase(self):
"""Test JSON wrapped in plain ``` code block (case insensitive)"""
content = '```\n{"key": "value"}\n```'
result = repair_json_output(content)
expected = json.dumps({"key": "value"}, ensure_ascii=False)
assert result == expected
def test_malformed_json_repair(self):
"""Test with malformed JSON that can be repaired"""
content = '{"key": "value", "incomplete":'
result = repair_json_output(content)
# Should return repaired JSON
assert result.startswith('{"key": "value"')
def test_non_json_content(self):
"""Test with non-JSON content"""
content = "This is just plain text"
result = repair_json_output(content)
assert result == content
def test_empty_string(self):
"""Test with empty string"""
content = ""
result = repair_json_output(content)
assert result == ""
def test_whitespace_only(self):
"""Test with whitespace only"""
content = " \n\t "
result = repair_json_output(content)
assert result == ""
def test_json_with_unicode(self):
"""Test JSON with unicode characters"""
content = '{"name": "测试", "emoji": "🎯"}'
result = repair_json_output(content)
expected = json.dumps({"name": "测试", "emoji": "🎯"}, ensure_ascii=False)
assert result == expected
def test_json_code_block_without_closing(self):
"""Test JSON code block without closing```"""
content = '```json\n{"key": "value"}'
result = repair_json_output(content)
expected = json.dumps({"key": "value"}, ensure_ascii=False)
assert result == expected
def test_json_repair_broken_json(self):
"""Test exception handling when JSON repair fails"""
content = '{"this": "is", "completely": broken and unparseable'
expect = '{"this": "is", "completely": "broken and unparseable"}'
result = repair_json_output(content)
assert result == expect
def test_nested_json_object(self):
"""Test with nested JSON object"""
content = '{"outer": {"inner": {"deep": "value"}}}'
result = repair_json_output(content)
expected = json.dumps(
{"outer": {"inner": {"deep": "value"}}}, ensure_ascii=False
)
assert result == expected
def test_json_array_with_objects(self):
"""Test JSON array containing objects"""
content = '[{"id": 1, "name": "test1"}, {"id": 2, "name": "test2"}]'
result = repair_json_output(content)
expected = json.dumps(
[{"id": 1, "name": "test1"}, {"id": 2, "name": "test2"}], ensure_ascii=False
)
assert result == expected
def test_content_with_json_in_middle(self):
"""Test content that contains ```json in the middle"""
content = 'Some text before ```json {"key": "value"} and after'
result = repair_json_output(content)
# Should attempt to process as JSON since it contains ```json
assert isinstance(result, str)
assert result == '{"key": "value"}'
class TestExtractJsonFromContent:
def test_json_with_extra_tokens_after_closing_brace(self):
"""Test extracting JSON with extra tokens after closing brace"""
content = '{"key": "value"} extra tokens here'
result = _extract_json_from_content(content)
assert result == '{"key": "value"}'
def test_json_with_extra_tokens_after_closing_bracket(self):
"""Test extracting JSON array with extra tokens"""
content = '[1, 2, 3] garbage data'
result = _extract_json_from_content(content)
assert result == '[1, 2, 3]'
def test_nested_json_with_extra_tokens(self):
"""Test nested JSON with extra tokens"""
content = '{"nested": {"inner": [1, 2, 3]}} invalid text'
result = _extract_json_from_content(content)
assert result == '{"nested": {"inner": [1, 2, 3]}}'
def test_json_with_string_containing_braces(self):
"""Test JSON with strings containing braces"""
content = '{"text": "this has {braces} in it"} extra'
result = _extract_json_from_content(content)
assert result == '{"text": "this has {braces} in it"}'
def test_json_with_escaped_quotes(self):
"""Test JSON with escaped quotes in strings"""
content = '{"text": "quote \\"here\\""} junk'
result = _extract_json_from_content(content)
assert result == '{"text": "quote \\"here\\""}'
def test_clean_json_no_extra_tokens(self):
"""Test clean JSON without extra tokens"""
content = '{"key": "value"}'
result = _extract_json_from_content(content)
assert result == '{"key": "value"}'
def test_empty_object(self):
"""Test empty object"""
content = '{} extra'
result = _extract_json_from_content(content)
assert result == '{}'
def test_empty_array(self):
"""Test empty array"""
content = '[] more stuff'
result = _extract_json_from_content(content)
assert result == '[]'
def test_extra_closing_brace_no_opening(self):
"""Test that extra closing brace without opening is not marked as valid end"""
content = '} garbage data'
result = _extract_json_from_content(content)
# Should return original content since no opening brace was seen
assert result == content
def test_extra_closing_bracket_no_opening(self):
"""Test that extra closing bracket without opening is not marked as valid end"""
content = '] garbage data'
result = _extract_json_from_content(content)
# Should return original content since no opening bracket was seen
assert result == content
class TestSanitizeToolResponse:
def test_basic_sanitization(self):
"""Test basic tool response sanitization"""
content = "normal response"
result = sanitize_tool_response(content)
assert result == "normal response"
def test_json_with_extra_tokens(self):
"""Test sanitizing JSON with extra tokens"""
content = '{"data": "value"} some garbage'
result = sanitize_tool_response(content)
assert result == '{"data": "value"}'
def test_very_long_response_truncation(self):
"""Test truncation of very long responses"""
long_content = "a" * 60000 # Exceeds default max of 50000
result = sanitize_tool_response(long_content)
assert len(result) <= 50003 # 50000 + "..."
assert result.endswith("...")
def test_custom_max_length(self):
"""Test custom maximum length"""
long_content = "a" * 1000
result = sanitize_tool_response(long_content, max_length=100)
assert len(result) <= 103 # 100 + "..."
assert result.endswith("...")
def test_control_character_removal(self):
"""Test removal of control characters"""
content = "text with \x00 null \x01 chars"
result = sanitize_tool_response(content)
assert "\x00" not in result
assert "\x01" not in result
def test_none_content(self):
"""Test handling of None content"""
result = sanitize_tool_response("")
assert result == ""
def test_whitespace_handling(self):
"""Test whitespace handling"""
content = " text with spaces "
result = sanitize_tool_response(content)
assert result == "text with spaces"
def test_json_array_with_extra_tokens(self):
"""Test JSON array with extra tokens"""
content = '[{"id": 1}, {"id": 2}] invalid stuff'
result = sanitize_tool_response(content)
assert result == '[{"id": 1}, {"id": 2}]'
class TestSanitizeArgs:
def test_sanitize_special_characters(self):
"""Test sanitization of special characters"""
args = '{"key": "value", "array": [1, 2, 3]}'
result = sanitize_args(args)
assert result == '&#123;"key": "value", "array": &#91;1, 2, 3&#93;&#125;'
def test_sanitize_square_brackets(self):
"""Test sanitization of square brackets"""
args = '[1, 2, 3]'
result = sanitize_args(args)
assert result == '&#91;1, 2, 3&#93;'
def test_sanitize_curly_braces(self):
"""Test sanitization of curly braces"""
args = '{key: value}'
result = sanitize_args(args)
assert result == '&#123;key: value&#125;'
def test_sanitize_mixed_brackets(self):
"""Test sanitization of mixed bracket types"""
args = '{[test]}'
result = sanitize_args(args)
assert result == '&#123;&#91;test&#93;&#125;'
def test_sanitize_non_string_input(self):
"""Test sanitization of non-string input returns empty string"""
assert sanitize_args(None) == ""
assert sanitize_args(123) == ""
assert sanitize_args([1, 2, 3]) == ""
assert sanitize_args({"key": "value"}) == ""
def test_sanitize_empty_string(self):
"""Test sanitization of empty string"""
result = sanitize_args("")
assert result == ""
def test_sanitize_plain_text(self):
"""Test sanitization of plain text without special characters"""
args = "plain text without brackets or braces"
result = sanitize_args(args)
assert result == "plain text without brackets or braces"
def test_sanitize_nested_structures(self):
"""Test sanitization of deeply nested structures"""
args = '{"outer": {"inner": [1, [2, 3]]}}'
result = sanitize_args(args)
assert result == '&#123;"outer": &#123;"inner": &#91;1, &#91;2, 3&#93;&#93;&#125;&#125;'
class TestRepairJsonOutputEdgeCases:
def test_code_block_with_leading_spaces(self):
"""Test code block with leading spaces"""
content = ' ```json\n{"key": "value"}\n```'
result = repair_json_output(content)
expected = json.dumps({"key": "value"}, ensure_ascii=False)
assert result == expected
def test_code_block_with_tabs(self):
"""Test code block with tabs"""
content = '\t```json\n{"key": "value"}\n```'
result = repair_json_output(content)
expected = json.dumps({"key": "value"}, ensure_ascii=False)
assert result == expected
def test_code_block_with_multiple_newlines(self):
"""Test code block with multiple newlines after opening fence"""
content = '```json\n\n\n{"key": "value"}\n```'
result = repair_json_output(content)
expected = json.dumps({"key": "value"}, ensure_ascii=False)
assert result == expected
def test_code_block_with_spaces_before_closing(self):
"""Test code block with spaces before closing fence"""
content = '```json\n{"key": "value"}\n ```'
result = repair_json_output(content)
expected = json.dumps({"key": "value"}, ensure_ascii=False)
assert result == expected
def test_json_with_newlines_in_values(self):
"""Test JSON with newlines in string values"""
content = '{"text": "line1\\nline2\\nline3"}'
result = repair_json_output(content)
expected = json.dumps({"text": "line1\nline2\nline3"}, ensure_ascii=False)
assert result == expected
def test_json_with_special_unicode(self):
"""Test JSON with special unicode characters"""
content = '{"emoji": "🔥💯", "chinese": "中文测试", "math": "∑∫"}'
result = repair_json_output(content)
expected = json.dumps({"emoji": "🔥💯", "chinese": "中文测试", "math": "∑∫"}, ensure_ascii=False)
assert result == expected
def test_json_boolean_values(self):
"""Test JSON with boolean values"""
content = '{"active": true, "disabled": false, "nullable": null}'
result = repair_json_output(content)
expected = json.dumps({"active": True, "disabled": False, "nullable": None}, ensure_ascii=False)
assert result == expected
def test_json_numeric_values(self):
"""Test JSON with various numeric values"""
content = '{"int": 42, "float": 3.14159, "negative": -123, "scientific": 1.23e10}'
result = repair_json_output(content)
parsed = json.loads(result)
assert parsed["int"] == 42
assert parsed["float"] == 3.14159
assert parsed["negative"] == -123
def test_plain_code_block_marker(self):
"""Test plain ``` code block without language specifier"""
content = '```\n{"key": "value"}\n```'
result = repair_json_output(content)
expected = json.dumps({"key": "value"}, ensure_ascii=False)
assert result == expected
def test_multiple_json_objects_takes_first_complete(self):
"""Test that multiple JSON objects are properly extracted"""
content = '{"first": "object"} {"second": "object"}'
result = repair_json_output(content)
# json_repair will combine multiple objects into an array
expected = json.dumps([{"first": "object"}, {"second": "object"}], ensure_ascii=False)
assert result == expected
def test_chinese_json_with_code_block(self):
"""Test JSON with Chinese content wrapped in markdown code block"""
content = '''```json
{
"locale": "en-US",
"has_enough_context": true,
"thought": "测试中文内容",
"title": "地月距离小报告",
"steps": []
}
```'''
result = repair_json_output(content)
parsed = json.loads(result)
assert parsed["locale"] == "en-US"
assert parsed["title"] == "地月距离小报告"
assert parsed["thought"] == "测试中文内容"
assert isinstance(parsed["steps"], list)
def test_code_block_uppercase_json_with_leading_spaces(self):
"""Test uppercase JSON code block with leading spaces"""
content = ' ```JSON\n{"key": "value"}\n```'
result = repair_json_output(content)
expected = json.dumps({"key": "value"}, ensure_ascii=False)
assert result == expected
def test_code_block_uppercase_json_with_tabs(self):
"""Test uppercase JSON code block with tabs"""
content = '\t```JSON\n{"key": "value"}\n```'
result = repair_json_output(content)
expected = json.dumps({"key": "value"}, ensure_ascii=False)
assert result == expected
def test_code_block_mixed_case_with_multiple_newlines(self):
"""Test mixed case code block with multiple newlines"""
content = '```JsOn\n\n\n{"key": "value"}\n```'
result = repair_json_output(content)
expected = json.dumps({"key": "value"}, ensure_ascii=False)
assert result == expected
def test_code_block_uppercase_with_spaces_before_closing(self):
"""Test uppercase code block with spaces before closing fence"""
content = '```TYPESCRIPT\n{"key": "value"}\n ```'
result = repair_json_output(content)
expected = json.dumps({"key": "value"}, ensure_ascii=False)
assert result == expected
def test_code_block_case_insensitive_various_languages(self):
"""Test code blocks with various language specifiers in different cases"""
test_cases = [
('```Python\n{"key": "value"}\n```', '{"key": "value"}'),
('```PYTHON\n{"key": "value"}\n```', '{"key": "value"}'),
('```pYtHoN\n{"key": "value"}\n```', '{"key": "value"}'),
('```sql\n{"key": "value"}\n```', '{"key": "value"}'),
('```SQL\n{"key": "value"}\n```', '{"key": "value"}'),
]
for content, expected_json_str in test_cases:
result = repair_json_output(content)
# Verify it's valid JSON
parsed = json.loads(result)
assert parsed["key"] == "value"
class TestExtractJsonFromContentEdgeCases:
def test_deeply_nested_json(self):
"""Test extraction of deeply nested JSON"""
content = '{"l1": {"l2": {"l3": {"l4": {"l5": "deep"}}}}} garbage'
result = _extract_json_from_content(content)
assert result == '{"l1": {"l2": {"l3": {"l4": {"l5": "deep"}}}}}'
def test_json_array_of_arrays(self):
"""Test extraction of nested arrays"""
content = '[[1, 2], [3, 4], [5, 6]] extra'
result = _extract_json_from_content(content)
assert result == '[[1, 2], [3, 4], [5, 6]]'
def test_json_with_backslashes_in_string(self):
"""Test JSON with backslashes in string values"""
content = r'{"path": "C:\\Users\\test\\file.txt"} garbage'
result = _extract_json_from_content(content)
assert result == r'{"path": "C:\\Users\\test\\file.txt"}'
def test_json_with_forward_slashes(self):
"""Test JSON with forward slashes in string values"""
content = '{"url": "https://example.com/path/to/resource"} extra'
result = _extract_json_from_content(content)
assert result == '{"url": "https://example.com/path/to/resource"}'
def test_mixed_object_and_array(self):
"""Test JSON with mixed objects and arrays"""
content = '{"items": [{"id": 1}, {"id": 2}], "count": 2} tail'
result = _extract_json_from_content(content)
assert result == '{"items": [{"id": 1}, {"id": 2}], "count": 2}'
def test_json_with_unicode_escape_sequences(self):
"""Test JSON with unicode escape sequences"""
content = r'{"text": "\u4E2D\u6587"} junk'
result = _extract_json_from_content(content)
assert result == r'{"text": "\u4E2D\u6587"}'
def test_no_json_structure(self):
"""Test content without JSON structure"""
content = 'just plain text without brackets'
result = _extract_json_from_content(content)
assert result == content
def test_unbalanced_braces_in_middle(self):
"""Test content with unbalanced braces doesn't extract invalid JSON"""
content = '{"incomplete": {"nested": } text'
result = _extract_json_from_content(content)
# Should not mark as valid end since braces are unbalanced
assert result == content
def test_json_with_comma_separated_values(self):
"""Test JSON object with multiple comma-separated values"""
content = '{"a": 1, "b": 2, "c": 3, "d": 4, "e": 5} more text'
result = _extract_json_from_content(content)
assert result == '{"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}'
class TestSanitizeToolResponseEdgeCases:
def test_json_object_with_extra_tokens(self):
"""Test sanitizing JSON object with trailing tokens"""
content = '{"status": "success", "data": {"id": 123}} trailing garbage'
result = sanitize_tool_response(content)
assert result == '{"status": "success", "data": {"id": 123}}'
def test_truncation_at_exact_boundary(self):
"""Test truncation behavior at exact max_length boundary"""
content = "x" * 50000
result = sanitize_tool_response(content, max_length=50000)
assert len(result) == 50000
assert not result.endswith("...")
def test_truncation_one_over_boundary(self):
"""Test truncation when content is one char over limit"""
content = "x" * 50001
result = sanitize_tool_response(content, max_length=50000)
assert len(result) <= 50003
assert result.endswith("...")
def test_multiple_control_characters(self):
"""Test removal of multiple types of control characters"""
content = "text\x00with\x01various\x02control\x1Fchars\x7F"
result = sanitize_tool_response(content)
# All control characters should be removed
assert "\x00" not in result
assert "\x01" not in result
assert "\x02" not in result
assert "\x1F" not in result
assert "\x7F" not in result
assert "textwithvariouscontrolchars" == result
def test_newline_and_tab_preservation(self):
"""Test that newlines and tabs are preserved (they are valid)"""
content = "line1\nline2\tindented"
result = sanitize_tool_response(content)
assert "\n" in result
assert "\t" in result
assert result == "line1\nline2\tindented"
def test_non_json_content_unchanged(self):
"""Test that non-JSON content is not modified"""
content = "This is plain text without any JSON structure"
result = sanitize_tool_response(content)
assert result == content
def test_json_array_at_start(self):
"""Test extraction of JSON array at start of content"""
content = '[1, 2, 3, 4, 5] followed by text'
result = sanitize_tool_response(content)
assert result == '[1, 2, 3, 4, 5]'
def test_empty_json_structures_preserved(self):
"""Test that empty JSON structures are preserved"""
content = '{"empty_obj": {}, "empty_arr": []} extra'
result = sanitize_tool_response(content)
assert result == '{"empty_obj": {}, "empty_arr": []}'
def test_whitespace_variations(self):
"""Test handling of various whitespace patterns"""
content = " \n\t content with spaces \t\n "
result = sanitize_tool_response(content)
assert result == "content with spaces"
-268
View File
@@ -1,268 +0,0 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for log sanitization utilities.
This test file verifies that the log sanitizer properly prevents log injection attacks
by escaping dangerous characters in user-controlled input before logging.
"""
import pytest
from src.utils.log_sanitizer import (
create_safe_log_message,
sanitize_agent_name,
sanitize_feedback,
sanitize_log_input,
sanitize_thread_id,
sanitize_tool_name,
sanitize_user_content,
)
class TestSanitizeLogInput:
"""Test the main sanitize_log_input function."""
def test_sanitize_normal_text(self):
"""Test that normal text is preserved."""
text = "normal text"
result = sanitize_log_input(text)
assert result == "normal text"
def test_sanitize_newline_injection(self):
"""Test prevention of newline injection attack."""
malicious = "abc\n[INFO] Forged log entry"
result = sanitize_log_input(malicious)
assert "\n" not in result
assert "[INFO]" in result # The attack text is preserved but escaped
assert "\\n" in result # Newline is escaped
def test_sanitize_carriage_return(self):
"""Test prevention of carriage return injection."""
malicious = "text\r[WARN] Forged entry"
result = sanitize_log_input(malicious)
assert "\r" not in result
assert "\\r" in result
def test_sanitize_tab_character(self):
"""Test prevention of tab character injection."""
malicious = "text\t[ERROR] Forged"
result = sanitize_log_input(malicious)
assert "\t" not in result
assert "\\t" in result
def test_sanitize_null_character(self):
"""Test prevention of null character injection."""
malicious = "text\x00[CRITICAL]"
result = sanitize_log_input(malicious)
assert "\x00" not in result
def test_sanitize_backslash(self):
"""Test that backslashes are properly escaped."""
text = "path\\to\\file"
result = sanitize_log_input(text)
assert result == "path\\\\to\\\\file"
def test_sanitize_escape_character(self):
"""Test prevention of ANSI escape sequence injection."""
malicious = "text\x1b[31mRED TEXT\x1b[0m"
result = sanitize_log_input(malicious)
assert "\x1b" not in result
assert "\\x1b" in result
def test_sanitize_max_length_truncation(self):
"""Test that long strings are truncated."""
long_text = "a" * 1000
result = sanitize_log_input(long_text, max_length=100)
assert len(result) <= 100
assert result.endswith("...")
def test_sanitize_none_value(self):
"""Test that None is handled properly."""
result = sanitize_log_input(None)
assert result == "None"
def test_sanitize_numeric_value(self):
"""Test that numeric values are converted to strings."""
result = sanitize_log_input(12345)
assert result == "12345"
def test_sanitize_complex_injection_attack(self):
"""Test complex multi-character injection attack."""
malicious = 'thread-123\n[WARNING] Unauthorized\r[ERROR] System failure\t[CRITICAL] Shutdown'
result = sanitize_log_input(malicious)
# All dangerous characters should be escaped
assert "\n" not in result
assert "\r" not in result
assert "\t" not in result
# But the text should still be there (escaped)
assert "WARNING" in result
assert "ERROR" in result
class TestSanitizeThreadId:
"""Test sanitization of thread IDs."""
def test_thread_id_normal(self):
"""Test normal thread ID."""
thread_id = "thread-123-abc"
result = sanitize_thread_id(thread_id)
assert result == "thread-123-abc"
def test_thread_id_with_newline(self):
"""Test thread ID with newline injection."""
malicious = "thread-1\n[INFO] Forged"
result = sanitize_thread_id(malicious)
assert "\n" not in result
assert "\\n" in result
def test_thread_id_max_length(self):
"""Test that thread ID truncation respects max length."""
long_id = "x" * 200
result = sanitize_thread_id(long_id)
assert len(result) <= 100
class TestSanitizeUserContent:
"""Test sanitization of user-provided message content."""
def test_user_content_normal(self):
"""Test normal user content."""
content = "What is the weather today?"
result = sanitize_user_content(content)
assert result == "What is the weather today?"
def test_user_content_with_newline(self):
"""Test user content with newline."""
malicious = "My question\n[ADMIN] Delete user"
result = sanitize_user_content(malicious)
assert "\n" not in result
assert "\\n" in result
def test_user_content_max_length(self):
"""Test that user content is truncated more aggressively."""
long_content = "x" * 500
result = sanitize_user_content(long_content)
assert len(result) <= 200
class TestSanitizeToolName:
"""Test sanitization of tool names."""
def test_tool_name_normal(self):
"""Test normal tool name."""
tool = "web_search"
result = sanitize_tool_name(tool)
assert result == "web_search"
def test_tool_name_injection(self):
"""Test tool name with injection attempt."""
malicious = "search\n[WARN] Forged"
result = sanitize_tool_name(malicious)
assert "\n" not in result
class TestSanitizeFeedback:
"""Test sanitization of user feedback."""
def test_feedback_normal(self):
"""Test normal feedback."""
feedback = "[accepted]"
result = sanitize_feedback(feedback)
assert result == "[accepted]"
def test_feedback_injection(self):
"""Test feedback with injection attempt."""
malicious = "[approved]\n[CRITICAL] System down"
result = sanitize_feedback(malicious)
assert "\n" not in result
assert "\\n" in result
def test_feedback_max_length(self):
"""Test that feedback is truncated."""
long_feedback = "x" * 500
result = sanitize_feedback(long_feedback)
assert len(result) <= 150
class TestCreateSafeLogMessage:
"""Test the create_safe_log_message helper function."""
def test_safe_message_normal(self):
"""Test normal message creation."""
msg = create_safe_log_message(
"[{thread_id}] Processing {tool_name}",
thread_id="thread-1",
tool_name="search",
)
assert "[thread-1] Processing search" == msg
def test_safe_message_with_injection(self):
"""Test message creation with injected values."""
msg = create_safe_log_message(
"[{thread_id}] Tool: {tool_name}",
thread_id="id\n[INFO] Forged",
tool_name="search\r[ERROR]",
)
# The dangerous characters should be escaped
assert "\n" not in msg
assert "\r" not in msg
assert "\\n" in msg
assert "\\r" in msg
def test_safe_message_multiple_values(self):
"""Test message with multiple values."""
msg = create_safe_log_message(
"[{id}] User: {user} Tool: {tool}",
id="123",
user="admin\t[WARN]",
tool="delete\x1b[31m",
)
assert "\t" not in msg
assert "\x1b" not in msg
class TestLogInjectionAttackPrevention:
"""Integration tests for log injection prevention."""
def test_classic_log_injection_newline(self):
"""Test the classic log injection attack using newlines."""
attacker_input = 'abc\n[WARNING] Unauthorized access detected'
result = sanitize_log_input(attacker_input)
# The output should not contain an actual newline that would create a new log entry
assert result.count("\n") == 0
# But the escaped version should be in there
assert "\\n" in result
def test_carriage_return_log_injection(self):
"""Test log injection via carriage return."""
attacker_input = "request_id\r\n[ERROR] CRITICAL FAILURE"
result = sanitize_log_input(attacker_input)
assert "\r" not in result
assert "\n" not in result
def test_html_injection_prevention(self):
"""Test prevention of HTML injection in logs."""
# While HTML tags themselves aren't dangerous in log files,
# escaping control characters helps prevent parsing attacks
malicious_html = "user\x1b[32m<script>alert('xss')</script>"
result = sanitize_log_input(malicious_html)
assert "\x1b" not in result
# HTML is preserved but with escaped control chars
assert "<script>" in result
def test_multiple_injection_techniques(self):
"""Test prevention of multiple injection techniques combined."""
attack = 'id_1\n\r\t[CRITICAL]\x1b[31m RED TEXT'
result = sanitize_log_input(attack)
# No actual control characters should exist
assert "\n" not in result
assert "\r" not in result
assert "\t" not in result
assert "\x1b" not in result
# But escaped versions should exist
assert "\\n" in result
assert "\\r" in result
assert "\\t" in result
assert "\\x1b" in result