mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 17:35:57 +00:00
Prepare to merge deer-flow-2
This commit is contained in:
@@ -1,29 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from src.crawler import Crawler
|
||||
|
||||
|
||||
def test_crawler_initialization():
    """Constructing a Crawler with no arguments yields a Crawler instance."""
    instance = Crawler()
    assert isinstance(instance, Crawler)
|
||||
|
||||
|
||||
def test_crawler_crawl_valid_url():
    """Crawl a fixed article URL and check the result exposes ``to_markdown``."""
    # NOTE(review): this looks like it performs a real network request —
    # confirm whether the suite is expected to run with network access.
    article_url = "https://finance.sina.com.cn/stock/relnews/us/2024-08-15/doc-incitsya6536375.shtml"
    crawled = Crawler().crawl(article_url)
    assert crawled is not None
    assert hasattr(crawled, "to_markdown")
|
||||
|
||||
|
||||
def test_crawler_markdown_output():
    """Crawl a fixed article URL and check it converts to a non-empty string."""
    # NOTE(review): this looks like it performs a real network request —
    # confirm whether the suite is expected to run with network access.
    article_url = "https://finance.sina.com.cn/stock/relnews/us/2024-08-15/doc-incitsya6536375.shtml"
    crawled = Crawler().crawl(article_url)
    rendered = crawled.to_markdown()
    assert isinstance(rendered, str)
    assert len(rendered) > 0
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,144 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
|
||||
from src.prompts.template import apply_prompt_template, get_prompt_template
|
||||
|
||||
|
||||
def test_get_prompt_template_success():
    """Loading the "coder" template returns a non-empty string."""
    loaded = get_prompt_template("coder")
    assert loaded is not None
    assert isinstance(loaded, str)
    assert len(loaded) > 0
|
||||
|
||||
|
||||
def test_get_prompt_template_not_found():
    """An unknown template name raises ValueError with a loading-error message."""
    with pytest.raises(ValueError) as raised:
        get_prompt_template("non_existent_template")
    error_text = str(raised.value)
    assert "Error loading template" in error_text
|
||||
|
||||
|
||||
def test_apply_prompt_template():
    """Rendering "coder" yields a system message followed by the user message."""
    state = {
        "messages": [{"role": "user", "content": "test message"}],
        "task": "test task",
        "workspace_context": "test context",
    }

    rendered = apply_prompt_template("coder", state)

    assert isinstance(rendered, list)
    assert len(rendered) > 1

    # The system prompt comes first and must mention CURRENT_TIME.
    system_message = rendered[0]
    assert system_message["role"] == "system"
    assert "CURRENT_TIME" in system_message["content"]

    # The original user message follows unchanged.
    user_message = rendered[1]
    assert user_message["role"] == "user"
    assert user_message["content"] == "test message"
|
||||
|
||||
|
||||
def test_apply_prompt_template_empty_messages():
    """With no conversation messages, only the system message is produced."""
    state = {
        "messages": [],
        "task": "test task",
        "workspace_context": "test context",
    }

    rendered = apply_prompt_template("coder", state)

    assert len(rendered) == 1  # Only system message
    only_message = rendered[0]
    assert only_message["role"] == "system"
|
||||
|
||||
|
||||
def test_apply_prompt_template_multiple_messages():
    """A three-message conversation renders as system prompt plus three messages."""
    conversation = [
        {"role": "user", "content": "first message"},
        {"role": "assistant", "content": "response"},
        {"role": "user", "content": "second message"},
    ]
    state = {
        "messages": conversation,
        "task": "test task",
        "workspace_context": "test context",
    }

    rendered = apply_prompt_template("coder", state)

    assert len(rendered) == 4  # system + 3 messages
    assert rendered[0]["role"] == "system"
    allowed_roles = ["system", "user", "assistant"]
    assert all(entry["role"] in allowed_roles for entry in rendered)
|
||||
|
||||
|
||||
def test_apply_prompt_template_with_special_chars():
    """User content with quotes, braces and newlines passes through unchanged."""
    tricky_content = "test\nmessage\"with'special{chars}"
    state = {
        "messages": [{"role": "user", "content": tricky_content}],
        "task": "task with $pecial ch@rs",
        "workspace_context": "<html>context</html>",
    }

    rendered = apply_prompt_template("coder", state)

    assert rendered[1]["content"] == tricky_content
|
||||
|
||||
|
||||
# "coder" was listed twice in the original parameter list, which only re-ran
# the same case; each template name now appears exactly once.
@pytest.mark.parametrize("prompt_name", ["coder", "coordinator", "planner"])
def test_multiple_template_types(prompt_name):
    """Each listed template name loads as a non-empty string.

    Args:
        prompt_name: Template name supplied by the parametrize decorator.
    """
    template = get_prompt_template(prompt_name)
    assert template is not None
    assert isinstance(template, str)
    assert len(template) > 0
|
||||
|
||||
|
||||
def test_current_time_format():
    """The rendered system prompt has a line starting with ``CURRENT_TIME:``."""
    state = {
        "messages": [{"role": "user", "content": "test"}],
        "task": "test",
        "workspace_context": "test",
    }

    rendered = apply_prompt_template("coder", state)
    system_text = rendered[0]["content"]

    # Leading/trailing whitespace on a line is ignored when matching the prefix.
    stripped_lines = (raw_line.strip() for raw_line in system_text.split("\n"))
    assert any(entry.startswith("CURRENT_TIME:") for entry in stripped_lines)
|
||||
|
||||
|
||||
def test_apply_prompt_template_reporter():
    """Reporter system prompt contains a style/locale-specific marker string."""
    # (report_style, locale, text expected somewhere in the system prompt)
    scenarios = [
        ("news", "en-US", "NBC News"),
        ("social_media", "en-US", "Twitter/X"),
        ("social_media", "zh-CN", "小红书"),
    ]

    for report_style, locale, expected_marker in scenarios:
        state = {
            "messages": [],
            "task": "test reporter task",
            "workspace_context": "test reporter context",
            "report_style": report_style,
            "locale": locale,
        }
        rendered = apply_prompt_template("reporter", state)
        system_text = rendered[0]["content"]
        assert expected_marker in system_text
|
||||
@@ -1,473 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Integration tests for tool-specific interrupts feature (Issue #572).
|
||||
|
||||
Tests the complete flow of selective tool interrupts including:
|
||||
- Tool wrapping with interrupt logic
|
||||
- Agent creation with interrupt configuration
|
||||
- Tool execution with user feedback
|
||||
- Resume mechanism after interrupt
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, call, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from src.agents.agents import create_agent
|
||||
from src.agents.tool_interceptor import ToolInterceptor, wrap_tools_with_interceptor
|
||||
from src.config.configuration import Configuration
|
||||
from src.server.chat_request import ChatRequest
|
||||
|
||||
|
||||
class TestToolInterceptorIntegration:
    """Integration tests for tool interceptor with agent workflow.

    Tests that exercise interrupts patch ``src.agents.tool_interceptor.interrupt``
    and steer the flow through the mock's return value (the simulated user
    feedback).
    """

    def test_agent_creation_with_tool_interrupts(self):
        """Test creating an agent with tool interrupts configured."""

        @tool
        def search_tool(query: str) -> str:
            """Search the web."""
            return f"Search results for: {query}"

        @tool
        def db_tool(query: str) -> str:
            """Query database."""
            return f"DB results for: {query}"

        tools = [search_tool, db_tool]

        # Create agent with interrupts on db_tool only
        with patch("src.agents.agents.langchain_create_agent") as mock_create, \
                patch("src.agents.agents.get_llm_by_type") as mock_llm:
            mock_create.return_value = MagicMock()
            mock_llm.return_value = MagicMock()

            # The returned agent is a mock; only the call arguments matter.
            create_agent(
                agent_name="test_agent",
                agent_type="researcher",
                tools=tools,
                prompt_template="researcher",
                interrupt_before_tools=["db_tool"],
            )

            # Verify langchain_create_agent was called with wrapped tools
            assert mock_create.called
            call_args = mock_create.call_args
            wrapped_tools = call_args.kwargs["tools"]

            # Should have wrapped the tools
            assert len(wrapped_tools) == 2
            assert wrapped_tools[0].name == "search_tool"
            assert wrapped_tools[1].name == "db_tool"

    def test_configuration_with_tool_interrupts(self):
        """Test Configuration object with interrupt_before_tools."""
        config = Configuration(
            interrupt_before_tools=["db_tool", "api_write_tool"],
            max_step_num=3,
            max_search_results=5,
        )

        assert config.interrupt_before_tools == ["db_tool", "api_write_tool"]
        assert config.max_step_num == 3
        assert config.max_search_results == 5

    def test_configuration_default_no_interrupts(self):
        """Test Configuration defaults to no interrupts."""
        config = Configuration()
        assert config.interrupt_before_tools == []

    def test_chat_request_with_tool_interrupts(self):
        """Test ChatRequest with interrupt_before_tools."""
        request = ChatRequest(
            messages=[{"role": "user", "content": "Search for X"}],
            interrupt_before_tools=["db_tool", "payment_api"],
        )

        assert request.interrupt_before_tools == ["db_tool", "payment_api"]

    def test_chat_request_interrupt_feedback_with_tool_interrupts(self):
        """Test ChatRequest with both interrupt_before_tools and interrupt_feedback."""
        request = ChatRequest(
            messages=[{"role": "user", "content": "Research topic"}],
            interrupt_before_tools=["db_tool"],
            interrupt_feedback="approved",
        )

        assert request.interrupt_before_tools == ["db_tool"]
        assert request.interrupt_feedback == "approved"

    def test_multiple_tools_selective_interrupt(self):
        """Test that only specified tools trigger interrupts."""

        @tool
        def tool_a(x: str) -> str:
            """Tool A"""
            return f"A: {x}"

        @tool
        def tool_b(x: str) -> str:
            """Tool B"""
            return f"B: {x}"

        @tool
        def tool_c(x: str) -> str:
            """Tool C"""
            return f"C: {x}"

        tools = [tool_a, tool_b, tool_c]

        # Wrap all tools; only tool_b is configured to interrupt
        wrapped_tools = wrap_tools_with_interceptor(tools, ["tool_b"])

        with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
            mock_interrupt.return_value = "approved"

            # tool_a should not interrupt
            result_a = wrapped_tools[0].invoke("test")
            mock_interrupt.assert_not_called()

            # tool_b should interrupt exactly once
            result_b = wrapped_tools[1].invoke("test")
            mock_interrupt.assert_called_once()

            # tool_c should not interrupt
            mock_interrupt.reset_mock()
            result_c = wrapped_tools[2].invoke("test")
            mock_interrupt.assert_not_called()

            # Every tool still runs (tool_b after approval), mirroring the
            # result checks in test_multiple_interrupts_in_sequence.
            assert "A: test" in str(result_a)
            assert "B: test" in str(result_b)
            assert "C: test" in str(result_c)

    def test_interrupt_with_user_approval(self):
        """Test interrupt flow with user approval."""

        @tool
        def sensitive_tool(action: str) -> str:
            """A sensitive tool."""
            return f"Executed: {action}"

        with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
            mock_interrupt.return_value = "approved"

            interceptor = ToolInterceptor(["sensitive_tool"])
            wrapped = ToolInterceptor.wrap_tool(sensitive_tool, interceptor)

            result = wrapped.invoke("delete_data")

            mock_interrupt.assert_called()
            assert "Executed: delete_data" in str(result)

    def test_interrupt_with_user_rejection(self):
        """Test interrupt flow with user rejection."""

        @tool
        def sensitive_tool(action: str) -> str:
            """A sensitive tool."""
            return f"Executed: {action}"

        with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
            mock_interrupt.return_value = "rejected"

            interceptor = ToolInterceptor(["sensitive_tool"])
            wrapped = ToolInterceptor.wrap_tool(sensitive_tool, interceptor)

            result = wrapped.invoke("delete_data")

            mock_interrupt.assert_called()
            assert isinstance(result, dict)
            assert "error" in result
            assert result["status"] == "rejected"

    def test_interrupt_message_contains_tool_info(self):
        """Test that interrupt message contains tool name and input."""

        @tool
        def db_query_tool(query: str) -> str:
            """Database query tool."""
            return f"Query result: {query}"

        with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
            mock_interrupt.return_value = "approved"

            interceptor = ToolInterceptor(["db_query_tool"])
            wrapped = ToolInterceptor.wrap_tool(db_query_tool, interceptor)

            wrapped.invoke("SELECT * FROM users")

            # Verify interrupt was called with meaningful message
            mock_interrupt.assert_called()
            interrupt_message = mock_interrupt.call_args[0][0]
            assert "db_query_tool" in interrupt_message
            assert "SELECT * FROM users" in interrupt_message

    def test_tool_wrapping_preserves_functionality(self):
        """Test that tool wrapping preserves original tool functionality."""

        @tool
        def simple_tool(text: str) -> str:
            """Process text."""
            return f"Processed: {text}"

        interceptor = ToolInterceptor([])  # No interrupts
        wrapped = ToolInterceptor.wrap_tool(simple_tool, interceptor)

        result = wrapped.invoke({"text": "hello"})
        assert "hello" in str(result)

    def test_tool_wrapping_preserves_tool_metadata(self):
        """Test that tool wrapping preserves tool name and description."""

        @tool
        def my_special_tool(x: str) -> str:
            """This is my special tool description."""
            return f"Result: {x}"

        interceptor = ToolInterceptor([])
        wrapped = ToolInterceptor.wrap_tool(my_special_tool, interceptor)

        assert wrapped.name == "my_special_tool"
        assert "special tool" in wrapped.description.lower()

    def test_multiple_interrupts_in_sequence(self):
        """Test handling multiple tool interrupts in sequence."""

        @tool
        def tool_one(x: str) -> str:
            """Tool one."""
            return f"One: {x}"

        @tool
        def tool_two(x: str) -> str:
            """Tool two."""
            return f"Two: {x}"

        @tool
        def tool_three(x: str) -> str:
            """Tool three."""
            return f"Three: {x}"

        tools = [tool_one, tool_two, tool_three]
        wrapped_tools = wrap_tools_with_interceptor(
            tools, ["tool_one", "tool_two"]
        )

        with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
            mock_interrupt.return_value = "approved"

            # First interrupt
            result_one = wrapped_tools[0].invoke("first")
            assert mock_interrupt.call_count == 1

            # Second interrupt
            result_two = wrapped_tools[1].invoke("second")
            assert mock_interrupt.call_count == 2

            # Third (no interrupt)
            result_three = wrapped_tools[2].invoke("third")
            assert mock_interrupt.call_count == 2

            assert "One: first" in str(result_one)
            assert "Two: second" in str(result_two)
            assert "Three: third" in str(result_three)

    def test_empty_interrupt_list_no_interrupts(self):
        """Test that empty interrupt list doesn't trigger interrupts."""

        @tool
        def test_tool(x: str) -> str:
            """Test tool."""
            return f"Result: {x}"

        wrapped_tools = wrap_tools_with_interceptor([test_tool], [])

        with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
            wrapped_tools[0].invoke("test")
            mock_interrupt.assert_not_called()

    def test_none_interrupt_list_no_interrupts(self):
        """Test that None interrupt list doesn't trigger interrupts."""

        @tool
        def test_tool(x: str) -> str:
            """Test tool."""
            return f"Result: {x}"

        wrapped_tools = wrap_tools_with_interceptor([test_tool], None)

        with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
            wrapped_tools[0].invoke("test")
            mock_interrupt.assert_not_called()

    def test_case_sensitive_tool_name_matching(self):
        """Test that tool name matching is case-sensitive."""

        @tool
        def MyTool(x: str) -> str:
            """A tool."""
            return f"Result: {x}"

        interceptor_lower = ToolInterceptor(["mytool"])
        interceptor_exact = ToolInterceptor(["MyTool"])

        with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
            mock_interrupt.return_value = "approved"

            # Case mismatch - should NOT interrupt
            wrapped_lower = ToolInterceptor.wrap_tool(MyTool, interceptor_lower)
            wrapped_lower.invoke("test")
            mock_interrupt.assert_not_called()

            # Case match - should interrupt
            wrapped_exact = ToolInterceptor.wrap_tool(MyTool, interceptor_exact)
            wrapped_exact.invoke("test")
            mock_interrupt.assert_called()

    def test_tool_error_handling(self):
        """Test handling of tool errors during execution."""

        @tool
        def error_tool(x: str) -> str:
            """A tool that raises an error."""
            raise ValueError(f"Intentional error: {x}")

        with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
            mock_interrupt.return_value = "approved"

            interceptor = ToolInterceptor(["error_tool"])
            wrapped = ToolInterceptor.wrap_tool(error_tool, interceptor)

            with pytest.raises(ValueError) as exc_info:
                wrapped.invoke("test")

            assert "Intentional error: test" in str(exc_info.value)

    def test_approval_keywords_comprehensive(self):
        """Test all approved keywords are recognized."""
        approval_keywords = [
            "approved",
            "approve",
            "yes",
            "proceed",
            "continue",
            "ok",
            "okay",
            "accepted",
            "accept",
            "[approved]",
            "APPROVED",
            "Proceed with this action",
            "[ACCEPTED] I approve",
        ]

        for keyword in approval_keywords:
            result = ToolInterceptor._parse_approval(keyword)
            assert (
                result is True
            ), f"Keyword '{keyword}' should be approved but got {result}"

    def test_rejection_keywords_comprehensive(self):
        """Test that rejection keywords are recognized."""
        rejection_keywords = [
            "no",
            "reject",
            "cancel",
            "decline",
            "stop",
            "abort",
            "maybe",
            "later",
            "random text",
            "",
        ]

        for keyword in rejection_keywords:
            result = ToolInterceptor._parse_approval(keyword)
            assert (
                result is False
            ), f"Keyword '{keyword}' should be rejected but got {result}"

    def test_interrupt_with_complex_tool_input(self):
        """Test interrupt with complex tool input types."""

        @tool
        def complex_tool(data: str) -> str:
            """A tool with complex input."""
            return f"Processed: {data}"

        with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
            mock_interrupt.return_value = "approved"

            interceptor = ToolInterceptor(["complex_tool"])
            wrapped = ToolInterceptor.wrap_tool(complex_tool, interceptor)

            complex_input = {
                "data": "complex data with nested info"
            }

            result = wrapped.invoke(complex_input)

            mock_interrupt.assert_called()
            assert "Processed" in str(result)

    def test_configuration_from_runnable_config(self):
        """Test Configuration.from_runnable_config with interrupt_before_tools."""
        from langchain_core.runnables import RunnableConfig

        config = RunnableConfig(
            configurable={
                "interrupt_before_tools": ["db_tool"],
                "max_step_num": 5,
            }
        )

        configuration = Configuration.from_runnable_config(config)

        assert configuration.interrupt_before_tools == ["db_tool"]
        assert configuration.max_step_num == 5

    def test_tool_interceptor_initialization_logging(self):
        """Test that ToolInterceptor initialization is logged."""
        with patch("src.agents.tool_interceptor.logger") as mock_logger:
            # Only the logging done by the constructor is under test.
            ToolInterceptor(["tool_a", "tool_b"])
            mock_logger.info.assert_called()

    def test_wrap_tools_with_interceptor_logging(self):
        """Test that tool wrapping is logged."""

        @tool
        def test_tool(x: str) -> str:
            """Test."""
            return x

        with patch("src.agents.tool_interceptor.logger") as mock_logger:
            wrap_tools_with_interceptor([test_tool], ["test_tool"])
            # Check that at least one info or debug log was emitted
            assert mock_logger.info.called or mock_logger.debug.called

    def test_interrupt_resolution_with_empty_feedback(self):
        """Test interrupt resolution with empty feedback."""

        @tool
        def test_tool(x: str) -> str:
            """Test."""
            return f"Result: {x}"

        with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
            mock_interrupt.return_value = ""

            interceptor = ToolInterceptor(["test_tool"])
            wrapped = ToolInterceptor.wrap_tool(test_tool, interceptor)

            result = wrapped.invoke("test")

            # Empty feedback should be treated as rejection
            assert isinstance(result, dict)
            assert result["status"] == "rejected"

    def test_interrupt_resolution_with_none_feedback(self):
        """Test interrupt resolution with None feedback."""

        @tool
        def test_tool(x: str) -> str:
            """Test."""
            return f"Result: {x}"

        with patch("src.agents.tool_interceptor.interrupt") as mock_interrupt:
            mock_interrupt.return_value = None

            interceptor = ToolInterceptor(["test_tool"])
            wrapped = ToolInterceptor.wrap_tool(test_tool, interceptor)

            result = wrapped.invoke("test")

            # None feedback should be treated as rejection
            assert isinstance(result, dict)
            assert result["status"] == "rejected"
|
||||
@@ -1,247 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import base64
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from src.tools.tts import VolcengineTTS
|
||||
|
||||
|
||||
class TestVolcengineTTS:
    """Test suite for the VolcengineTTS class.

    The HTTP layer is replaced by patching ``src.tools.tts.requests.post``;
    the private helpers below build the client and fake responses that the
    request tests share.
    """

    @staticmethod
    def _make_client():
        """Return a VolcengineTTS built with test credentials only."""
        return VolcengineTTS(
            appid="test_appid",
            access_token="test_token",
        )

    @staticmethod
    def _mock_response(status_code, payload):
        """Return a MagicMock response with the given status and JSON payload."""
        response = MagicMock()
        response.status_code = status_code
        response.json.return_value = payload
        return response

    @staticmethod
    def _encoded_audio():
        """Return the base64 text used as fake audio data in responses."""
        return base64.b64encode(b"audio_data").decode()

    def test_initialization(self):
        """Test that VolcengineTTS can be properly initialized."""
        tts = VolcengineTTS(
            appid="test_appid",
            access_token="test_token",
            cluster="test_cluster",
            voice_type="test_voice",
            host="test.host.com",
        )

        assert tts.appid == "test_appid"
        assert tts.access_token == "test_token"
        assert tts.cluster == "test_cluster"
        assert tts.voice_type == "test_voice"
        assert tts.host == "test.host.com"
        assert tts.api_url == "https://test.host.com/api/v1/tts"
        assert tts.header == {"Authorization": "Bearer;test_token"}

    def test_initialization_with_defaults(self):
        """Test initialization with default values."""
        tts = self._make_client()

        assert tts.appid == "test_appid"
        assert tts.access_token == "test_token"
        assert tts.cluster == "volcano_tts"
        assert tts.voice_type == "BV700_V2_streaming"
        assert tts.host == "openspeech.bytedance.com"
        assert tts.api_url == "https://openspeech.bytedance.com/api/v1/tts"

    @patch("src.tools.tts.requests.post")
    def test_text_to_speech_success(self, mock_post):
        """Test successful text-to-speech conversion."""
        mock_audio_data = self._encoded_audio()
        mock_post.return_value = self._mock_response(
            200,
            {
                "code": 0,
                "message": "success",
                "data": mock_audio_data,
            },
        )

        tts = self._make_client()
        result = tts.text_to_speech("Hello, world!")

        # Verify the result
        assert result["success"] is True
        assert result["audio_data"] == mock_audio_data
        assert "response" in result

        # Verify the request
        mock_post.assert_called_once()
        args, _ = mock_post.call_args
        assert args[0] == "https://openspeech.bytedance.com/api/v1/tts"

        # Verify request JSON - the data is passed as the second positional argument
        request_json = json.loads(args[1])
        assert request_json["app"]["appid"] == "test_appid"
        assert request_json["app"]["token"] == "test_token"
        assert request_json["app"]["cluster"] == "volcano_tts"
        assert request_json["audio"]["voice_type"] == "BV700_V2_streaming"
        assert request_json["audio"]["encoding"] == "mp3"
        assert request_json["request"]["text"] == "Hello, world!"

    @patch("src.tools.tts.requests.post")
    def test_text_to_speech_api_error(self, mock_post):
        """Test error handling when API returns an error."""
        mock_post.return_value = self._mock_response(
            400,
            {
                "code": 400,
                "message": "Bad request",
            },
        )

        tts = self._make_client()
        result = tts.text_to_speech("Hello, world!")

        # Verify the result
        assert result["success"] is False
        assert result["error"] == {"code": 400, "message": "Bad request"}
        assert result["audio_data"] is None

    @patch("src.tools.tts.requests.post")
    def test_text_to_speech_no_data(self, mock_post):
        """Test error handling when API response doesn't contain data."""
        mock_post.return_value = self._mock_response(
            200,
            {
                "code": 0,
                "message": "success",
                # No data field
            },
        )

        tts = self._make_client()
        result = tts.text_to_speech("Hello, world!")

        # Verify the result
        assert result["success"] is False
        assert result["error"] == "No audio data returned"
        assert result["audio_data"] is None

    @patch("src.tools.tts.requests.post")
    def test_text_to_speech_with_custom_parameters(self, mock_post):
        """Test text_to_speech with custom parameters."""
        mock_audio_data = self._encoded_audio()
        mock_post.return_value = self._mock_response(
            200,
            {
                "code": 0,
                "message": "success",
                "data": mock_audio_data,
            },
        )

        tts = self._make_client()

        # Call the method with custom parameters
        result = tts.text_to_speech(
            text="Custom text",
            encoding="wav",
            speed_ratio=1.2,
            volume_ratio=0.8,
            pitch_ratio=1.1,
            text_type="ssml",
            with_frontend=0,
            frontend_type="custom",
            uid="custom-uid",
        )

        # Verify the result
        assert result["success"] is True
        assert result["audio_data"] == mock_audio_data

        # Verify request JSON - the data is passed as the second positional argument
        args, _ = mock_post.call_args
        request_json = json.loads(args[1])
        assert request_json["audio"]["encoding"] == "wav"
        assert request_json["audio"]["speed_ratio"] == 1.2
        assert request_json["audio"]["volume_ratio"] == 0.8
        assert request_json["audio"]["pitch_ratio"] == 1.1
        assert request_json["request"]["text"] == "Custom text"
        assert request_json["request"]["text_type"] == "ssml"
        assert request_json["request"]["with_frontend"] == 0
        assert request_json["request"]["frontend_type"] == "custom"
        assert request_json["user"]["uid"] == "custom-uid"

    # Decorators apply bottom-up, so the uuid4 mock is the first injected
    # argument and the requests.post mock is the second.
    @patch("src.tools.tts.requests.post")
    @patch("src.tools.tts.uuid.uuid4")
    def test_text_to_speech_auto_generated_uid(self, mock_uuid, mock_post):
        """Test that UUID is auto-generated if not provided."""
        # Mock UUID
        mock_uuid_value = "test-uuid-value"
        mock_uuid.return_value = mock_uuid_value

        mock_audio_data = self._encoded_audio()
        mock_post.return_value = self._mock_response(
            200,
            {
                "code": 0,
                "message": "success",
                "data": mock_audio_data,
            },
        )

        tts = self._make_client()

        # Call the method without providing a UID
        result = tts.text_to_speech("Hello, world!")

        # Verify the result
        assert result["success"] is True
        assert result["audio_data"] == mock_audio_data

        # Verify the request JSON - the data is passed as the second positional argument
        args, _ = mock_post.call_args
        request_json = json.loads(args[1])
        assert request_json["user"]["uid"] == str(mock_uuid_value)

    @patch("src.tools.tts.requests.post")
    def test_text_to_speech_request_exception(self, mock_post):
        """Test error handling when requests.post raises an exception."""
        # Mock requests.post to raise an exception
        mock_post.side_effect = Exception("Network error")

        tts = self._make_client()
        result = tts.text_to_speech("Hello, world!")

        # Verify the result
        assert result["success"] is False
        # The TTS error is caught and returned as a string
        assert result["error"] == "TTS API call error"
        assert result["audio_data"] is None
|
||||
@@ -1,135 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for PPT composer localization functionality.
|
||||
|
||||
These tests verify that the ppt_composer_node correctly passes locale information
|
||||
to get_prompt_template, allowing for locale-specific prompt selection.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class MockLLMResponse:
    """Minimal stand-in for an LLM reply; it only carries a ``content`` string."""

    def __init__(self, content: str = "Mock PPT content") -> None:
        # The text payload exposed to callers via the ``content`` attribute.
        self.content = content
|
||||
|
||||
|
||||
class MockLLM:
    """Stand-in LLM object that exposes only an ``invoke`` method."""

    def invoke(self, messages):
        """Ignore ``messages`` and return a default-constructed MockLLMResponse."""
        canned_reply = MockLLMResponse()
        return canned_reply
|
||||
|
||||
|
||||
class TestPPTLocalization:
    """Test suite for PPT composer locale handling."""

    @staticmethod
    def _run_composer_with_stubs(monkeypatch, state):
        """Run ``ppt_composer_node`` on ``state`` with its collaborators stubbed.

        ``get_prompt_template`` and ``get_llm_by_type`` are replaced on the
        ``ppt_composer_node`` module so no prompt file or real model is needed.

        Args:
            monkeypatch: The pytest ``monkeypatch`` fixture of the calling test.
            state: The state dict handed to ``ppt_composer_node``.

        Returns:
            A ``(result, captured_calls)`` tuple. ``captured_calls`` holds one
            ``{"prompt_name": ..., "locale": ...}`` dict per call made to the
            stubbed ``get_prompt_template``.
        """
        captured_calls = []

        def mock_get_prompt_template(prompt_name, locale="en-US"):
            """Capture the arguments passed to get_prompt_template."""
            captured_calls.append({"prompt_name": prompt_name, "locale": locale})
            return "Mock prompt template"

        def mock_get_llm_by_type(llm_type):
            """Return a mock LLM."""
            return MockLLM()

        # The module is imported first; the stubs are then installed as
        # attributes on the imported module object, so the node resolves them
        # at call time. monkeypatch restores the originals after the test.
        import src.ppt.graph.ppt_composer_node as ppt_module

        monkeypatch.setattr(ppt_module, "get_prompt_template", mock_get_prompt_template)
        monkeypatch.setattr(ppt_module, "get_llm_by_type", mock_get_llm_by_type)

        result = ppt_module.ppt_composer_node(state)
        return result, captured_calls

    def test_locale_passed_to_prompt_template(self, monkeypatch):
        """
        Test that when locale is provided in state, it is passed to get_prompt_template.

        This test verifies that ppt_composer_node correctly extracts the locale
        from the state dict and passes it to get_prompt_template.
        """
        # State carrying both the input and an explicit locale.
        state = {"input": "hello", "locale": "zh-CN"}

        result, captured_calls = self._run_composer_with_stubs(monkeypatch, state)

        # Verify get_prompt_template was called with the correct locale
        assert len(captured_calls) == 1, "get_prompt_template should be called once"
        assert captured_calls[0]["prompt_name"] == "ppt/ppt_composer"
        assert captured_calls[0]["locale"] == "zh-CN", \
            "get_prompt_template should be called with locale 'zh-CN'"

        # Verify result structure
        assert "ppt_content" in result
        assert "ppt_file_path" in result

    def test_default_locale_fallback(self, monkeypatch):
        """
        Test that when locale is missing from state, default locale 'en-US' is used.

        This test verifies that ppt_composer_node falls back to the default locale
        'en-US' when no locale is provided in the state dict.
        """
        # State without a locale key (only input).
        state = {"input": "hello"}

        result, captured_calls = self._run_composer_with_stubs(monkeypatch, state)

        # Verify get_prompt_template was called with the default locale
        assert len(captured_calls) == 1, "get_prompt_template should be called once"
        assert captured_calls[0]["prompt_name"] == "ppt/ppt_composer"
        assert captured_calls[0]["locale"] == "en-US", \
            "get_prompt_template should be called with default locale 'en-US'"

        # Verify result structure
        assert "ppt_content" in result
        assert "ppt_file_path" in result
|
||||
@@ -1,133 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Annotated
|
||||
|
||||
# Import MessagesState directly from langgraph rather than through our application
|
||||
from langgraph.graph import MessagesState
|
||||
|
||||
|
||||
# Create stub versions of Plan/Step/StepType to avoid dependencies
|
||||
class StepType:
    """Stub holding the two step-type string constants used by these tests."""

    RESEARCH, PROCESSING = "research", "processing"
|
||||
|
||||
|
||||
class Step:
    """Stub plan step: a plain attribute bag with no validation."""

    def __init__(self, need_search, title, description, step_type):
        # Store every constructor argument under the attribute of the same name.
        (
            self.need_search,
            self.title,
            self.description,
            self.step_type,
        ) = (need_search, title, description, step_type)
|
||||
|
||||
|
||||
class Plan:
    """Stub plan: a plain attribute bag with no validation."""

    def __init__(self, locale, has_enough_context, thought, title, steps):
        # Store every constructor argument under the attribute of the same name.
        (
            self.locale,
            self.has_enough_context,
            self.thought,
            self.title,
            self.steps,
        ) = (locale, has_enough_context, thought, title, steps)
|
||||
|
||||
|
||||
# Load the actual State class by executing src/graph/types.py directly.
# This avoids the cascade of imports that would normally happen.
def load_state_class():
    """Execute ``src/graph/types.py`` in a synthetic module and return its ``State``.

    The file is located relative to this test file (``../src/graph/types.py``).
    Before the source runs, the module namespace is seeded with ``operator``,
    ``Annotated``, ``MessagesState`` and the stub ``Plan`` defined in this file.

    Returns:
        The ``State`` attribute defined by the executed source.

    Raises:
        OSError: If ``types.py`` cannot be opened or read.
        Exception: Anything raised while executing the source; the partially
            initialised module is removed from ``sys.modules`` before the
            exception propagates.
    """
    import types

    # Get the absolute path to the types.py file
    src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))
    types_path = os.path.join(src_dir, "graph", "types.py")

    # Read the source first so a missing file leaves sys.modules untouched.
    # Python source is UTF-8 by default, so do not rely on the locale encoding.
    with open(types_path, "r", encoding="utf-8") as f:
        module_code = f.read()

    # Create a namespace for the module
    module_name = "src.graph.types_direct"
    module = types.ModuleType(module_name)

    # Set up the namespace with required imports
    module.__dict__["operator"] = __import__("operator")
    module.__dict__["Annotated"] = Annotated
    module.__dict__["MessagesState"] = MessagesState
    module.__dict__["Plan"] = Plan

    # Add the module to sys.modules to avoid import loops
    sys.modules[module_name] = module
    try:
        # Compile with the real path so tracebacks point at types.py.
        # NOTE: exec runs repository code; only use on trusted source files.
        exec(compile(module_code, types_path, "exec"), module.__dict__)
    except BaseException:
        # Do not leave a half-initialised module registered.
        sys.modules.pop(module_name, None)
        raise

    # Return the State class
    return module.State


# Load the actual State class
State = load_state_class()
|
||||
|
||||
|
||||
def test_state_initialization():
    """Check State's class-level defaults and what a minimal instance contains."""
    # Defaults compared by equality.
    equal_defaults = {
        "locale": "en-US",
        "observations": [],
        "plan_iterations": 0,
        "final_report": "",
    }
    # Defaults compared by identity (None / True / False singletons).
    identical_defaults = {
        "current_plan": None,
        "auto_accepted_plan": False,
        "enable_background_investigation": True,
        "background_investigation_results": None,
    }

    for attr, expected in equal_defaults.items():
        assert getattr(State, attr) == expected, attr
    for attr, expected in identical_defaults.items():
        assert getattr(State, attr) is expected, attr

    # A state built with only ``messages`` contains that key...
    minimal_state = State(messages=[])
    assert "messages" in minimal_state

    # ...and none of the keys that were not passed explicitly.
    for absent_key in ("locale", "observations"):
        assert absent_key not in minimal_state
|
||||
|
||||
|
||||
def test_state_with_custom_values():
    """Check that explicitly supplied values are readable back from State by key."""
    step = Step(
        need_search=True,
        title="Test Step",
        description="Step description",
        step_type=StepType.RESEARCH,
    )
    plan = Plan(
        locale="en-US",
        has_enough_context=False,
        thought="Test thought",
        title="Test Plan",
        steps=[step],
    )

    # Every non-default field, gathered in one mapping and splatted into State
    # together with the required ``messages`` field.
    overrides = {
        "locale": "fr-FR",
        "observations": ["Observation 1"],
        "plan_iterations": 2,
        "current_plan": plan,
        "final_report": "Test report",
        "auto_accepted_plan": True,
        "enable_background_investigation": False,
        "background_investigation_results": "Test results",
    }
    state = State(messages=[], **overrides)

    # Scalar and list fields come back unchanged.
    assert state["locale"] == "fr-FR"
    assert state["observations"] == ["Observation 1"]
    assert state["plan_iterations"] == 2

    # The plan object and its nested step are reachable through the state.
    stored_plan = state["current_plan"]
    assert stored_plan.title == "Test Plan"
    assert stored_plan.thought == "Test thought"
    assert len(stored_plan.steps) == 1
    assert stored_plan.steps[0].title == "Test Step"

    assert state["final_report"] == "Test report"
    assert state["auto_accepted_plan"] is True
    assert state["enable_background_investigation"] is False
    assert state["background_investigation_results"] == "Test results"
|
||||
@@ -1,335 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from src.agents.agents import DynamicPromptMiddleware, PreModelHookMiddleware
|
||||
|
||||
|
||||
@pytest.fixture
def mock_runtime():
    """Provide a MagicMock runtime whose ``config`` attribute is an empty dict."""
    return MagicMock(config={})
|
||||
|
||||
|
||||
@pytest.fixture
def mock_state():
    """Provide a state dict with one human message and a context string."""
    state = {}
    state["messages"] = [HumanMessage(content="Test message")]
    state["context"] = "Test context"
    return state
|
||||
|
||||
|
||||
@pytest.fixture
def mock_messages():
    """Provide the message list the patched apply_prompt_template returns."""
    system_message = SystemMessage(content="Test system prompt")
    human_message = HumanMessage(content="Test human message")
    return [system_message, human_message]
|
||||
|
||||
|
||||
class TestDynamicPromptMiddleware:
    """Behaviour checks for DynamicPromptMiddleware."""

    def test_init(self):
        """The template name and an explicit locale are stored on the instance."""
        instance = DynamicPromptMiddleware("test_template", locale="zh-CN")
        assert instance.prompt_template == "test_template"
        assert instance.locale == "zh-CN"

    def test_init_default_locale(self):
        """Omitting the locale leaves it at "en-US"."""
        instance = DynamicPromptMiddleware("test_template")
        assert instance.prompt_template == "test_template"
        assert instance.locale == "en-US"

    @patch("src.agents.agents.apply_prompt_template")
    def test_before_model_success(
        self, mock_apply_template, mock_state, mock_runtime, mock_messages
    ):
        """before_model returns only the first rendered message."""
        mock_apply_template.return_value = mock_messages
        instance = DynamicPromptMiddleware("test_template", locale="en-US")

        update = instance.before_model(mock_state, mock_runtime)

        # The template is rendered exactly once with the state and the locale.
        mock_apply_template.assert_called_once_with(
            "test_template", mock_state, locale="en-US"
        )

        # Only the leading (system) message is handed back.
        leading_message = mock_messages[0]
        assert update == {"messages": [leading_message]}
        assert update["messages"][0].content == "Test system prompt"

    @patch("src.agents.agents.apply_prompt_template")
    def test_before_model_empty_messages(
        self, mock_apply_template, mock_state, mock_runtime
    ):
        """An empty rendered message list makes before_model return None."""
        mock_apply_template.return_value = []
        instance = DynamicPromptMiddleware("test_template")

        assert instance.before_model(mock_state, mock_runtime) is None

    @patch("src.agents.agents.apply_prompt_template")
    def test_before_model_none_messages(
        self, mock_apply_template, mock_state, mock_runtime
    ):
        """A None result from apply_prompt_template makes before_model return None."""
        mock_apply_template.return_value = None
        instance = DynamicPromptMiddleware("test_template")

        assert instance.before_model(mock_state, mock_runtime) is None

    @patch("src.agents.agents.apply_prompt_template")
    @patch("src.agents.agents.logger")
    def test_before_model_exception_handling(
        self, mock_logger, mock_apply_template, mock_state, mock_runtime
    ):
        """A rendering error is logged and before_model returns None."""
        mock_apply_template.side_effect = ValueError("Template rendering failed")
        instance = DynamicPromptMiddleware("test_template")

        update = instance.before_model(mock_state, mock_runtime)

        # The exception does not propagate.
        assert update is None

        # One error is logged, with the expected text and exc_info=True.
        mock_logger.error.assert_called_once()
        logged = mock_logger.error.call_args
        assert "Failed to apply prompt template in before_model" in logged.args[0]
        assert logged.kwargs["exc_info"] is True

    @patch("src.agents.agents.apply_prompt_template")
    def test_before_model_with_different_locale(
        self, mock_apply_template, mock_state, mock_runtime, mock_messages
    ):
        """A non-default locale is forwarded to apply_prompt_template."""
        mock_apply_template.return_value = mock_messages
        instance = DynamicPromptMiddleware("test_template", locale="zh-CN")

        update = instance.before_model(mock_state, mock_runtime)

        mock_apply_template.assert_called_once_with(
            "test_template", mock_state, locale="zh-CN"
        )
        assert update == {"messages": [mock_messages[0]]}

    @pytest.mark.asyncio
    @patch("src.agents.agents.apply_prompt_template")
    async def test_abefore_model(
        self, mock_apply_template, mock_state, mock_runtime, mock_messages
    ):
        """abefore_model yields the same update as the synchronous variant."""
        mock_apply_template.return_value = mock_messages
        instance = DynamicPromptMiddleware("test_template")

        update = await instance.abefore_model(mock_state, mock_runtime)

        assert update == {"messages": [mock_messages[0]]}
        mock_apply_template.assert_called_once_with(
            "test_template", mock_state, locale="en-US"
        )
|
||||
|
||||
|
||||
class TestPreModelHookMiddleware:
    """Behaviour checks for PreModelHookMiddleware."""

    def test_init(self):
        """The hook passed to the constructor is kept in ``_pre_model_hook``."""
        callback = Mock()
        instance = PreModelHookMiddleware(callback)
        assert instance._pre_model_hook == callback

    def test_before_model_with_sync_hook(self, mock_state, mock_runtime):
        """before_model calls a synchronous hook and returns its result."""
        callback = Mock(return_value={"custom_data": "test"})
        instance = PreModelHookMiddleware(callback)

        update = instance.before_model(mock_state, mock_runtime)

        callback.assert_called_once_with(mock_state, mock_runtime)
        assert update == {"custom_data": "test"}

    def test_before_model_with_none_hook(self, mock_state, mock_runtime):
        """before_model returns None when no hook is configured."""
        instance = PreModelHookMiddleware(None)

        assert instance.before_model(mock_state, mock_runtime) is None

    def test_before_model_hook_returns_none(self, mock_state, mock_runtime):
        """A hook that returns None is still called and None is passed through."""
        callback = Mock(return_value=None)
        instance = PreModelHookMiddleware(callback)

        update = instance.before_model(mock_state, mock_runtime)

        callback.assert_called_once_with(mock_state, mock_runtime)
        assert update is None

    @patch("src.agents.agents.logger")
    def test_before_model_hook_exception(
        self, mock_logger, mock_state, mock_runtime
    ):
        """A raising hook is logged and before_model returns None."""
        callback = Mock(side_effect=RuntimeError("Hook execution failed"))
        instance = PreModelHookMiddleware(callback)

        update = instance.before_model(mock_state, mock_runtime)

        # The exception does not propagate.
        assert update is None

        # One error is logged, with the expected text and exc_info=True.
        mock_logger.error.assert_called_once()
        logged = mock_logger.error.call_args
        assert "Pre-model hook execution failed in before_model" in logged.args[0]
        assert logged.kwargs["exc_info"] is True

    @pytest.mark.asyncio
    async def test_abefore_model_with_async_hook(self, mock_state, mock_runtime):
        """abefore_model returns the result of a coroutine hook."""

        async def async_hook(state, runtime):
            await asyncio.sleep(0.001)  # Simulate async work
            return {"async_data": "test"}

        instance = PreModelHookMiddleware(async_hook)

        update = await instance.abefore_model(mock_state, mock_runtime)

        assert update == {"async_data": "test"}

    @pytest.mark.asyncio
    @patch("src.agents.agents.asyncio.to_thread")
    async def test_abefore_model_with_sync_hook(
        self, mock_to_thread, mock_state, mock_runtime
    ):
        """abefore_model hands a synchronous hook to asyncio.to_thread."""
        callback = Mock(return_value={"sync_data": "test"})
        mock_to_thread.return_value = {"sync_data": "test"}
        instance = PreModelHookMiddleware(callback)

        update = await instance.abefore_model(mock_state, mock_runtime)

        # The hook itself, plus state and runtime, is what to_thread receives.
        mock_to_thread.assert_called_once_with(callback, mock_state, mock_runtime)
        assert update == {"sync_data": "test"}

    @pytest.mark.asyncio
    async def test_abefore_model_with_none_hook(self, mock_state, mock_runtime):
        """abefore_model returns None when no hook is configured."""
        instance = PreModelHookMiddleware(None)

        assert await instance.abefore_model(mock_state, mock_runtime) is None

    @pytest.mark.asyncio
    @patch("src.agents.agents.logger")
    async def test_abefore_model_async_hook_exception(
        self, mock_logger, mock_state, mock_runtime
    ):
        """A raising coroutine hook is logged and abefore_model returns None."""

        async def failing_hook(state, runtime):
            raise ValueError("Async hook failed")

        instance = PreModelHookMiddleware(failing_hook)

        update = await instance.abefore_model(mock_state, mock_runtime)

        # The exception does not propagate.
        assert update is None

        # One error is logged, with the expected text and exc_info=True.
        mock_logger.error.assert_called_once()
        logged = mock_logger.error.call_args
        assert "Pre-model hook execution failed in abefore_model" in logged.args[0]
        assert logged.kwargs["exc_info"] is True

    @pytest.mark.asyncio
    @patch("src.agents.agents.asyncio.to_thread")
    @patch("src.agents.agents.logger")
    async def test_abefore_model_sync_hook_exception(
        self, mock_logger, mock_to_thread, mock_state, mock_runtime
    ):
        """A failure from asyncio.to_thread is logged and None is returned."""
        callback = Mock()
        mock_to_thread.side_effect = RuntimeError("Thread execution failed")
        instance = PreModelHookMiddleware(callback)

        update = await instance.abefore_model(mock_state, mock_runtime)

        # The exception does not propagate.
        assert update is None

        # One error is logged, with the expected text and exc_info=True.
        mock_logger.error.assert_called_once()
        logged = mock_logger.error.call_args
        assert "Pre-model hook execution failed in abefore_model" in logged.args[0]
        assert logged.kwargs["exc_info"] is True

    @pytest.mark.asyncio
    async def test_abefore_model_sync_hook_actual_execution(
        self, mock_state, mock_runtime
    ):
        """Without patching, a plain function hook really runs exactly once."""
        invocations = []

        def sync_hook(state, runtime):
            invocations.append(True)
            return {"data": "from_sync_hook"}

        instance = PreModelHookMiddleware(sync_hook)

        update = await instance.abefore_model(mock_state, mock_runtime)

        assert len(invocations) == 1
        assert update == {"data": "from_sync_hook"}

    @pytest.mark.asyncio
    async def test_abefore_model_detects_coroutine_function(
        self, mock_state, mock_runtime
    ):
        """Both a coroutine hook and a plain function hook produce their result."""

        async def async_hook(state, runtime):
            return {"type": "async"}

        def sync_hook(state, runtime):
            return {"type": "sync"}

        from_async = PreModelHookMiddleware(async_hook)
        from_sync = PreModelHookMiddleware(sync_hook)

        # Each flavour is awaited through the same entry point.
        async_update = await from_async.abefore_model(mock_state, mock_runtime)
        sync_update = await from_sync.abefore_model(mock_state, mock_runtime)

        assert async_update == {"type": "async"}
        assert sync_update == {"type": "sync"}
|
||||
@@ -1,434 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.tools import BaseTool, tool
|
||||
|
||||
from src.agents.tool_interceptor import (
|
||||
ToolInterceptor,
|
||||
wrap_tools_with_interceptor,
|
||||
)
|
||||
|
||||
|
||||
class TestToolInterceptor:
    """Behaviour checks for ToolInterceptor and wrap_tools_with_interceptor."""

    @staticmethod
    def _make_test_tool():
        """Build the ``test_tool`` LangChain tool shared by several tests.

        The function name, parameter name and docstring are what ``@tool``
        turns into the tool's name, schema and description, so they are fixed.
        """

        @tool
        def test_tool(input_text: str) -> str:
            """Test tool."""
            return f"Result: {input_text}"

        return test_tool

    def test_init_with_tools(self):
        """The tool-name list given to the constructor is stored as-is."""
        names = ["db_tool", "api_tool"]
        assert ToolInterceptor(names).interrupt_before_tools == names

    def test_init_without_tools(self):
        """With no argument the interrupt list is empty."""
        assert ToolInterceptor().interrupt_before_tools == []

    def test_should_interrupt_with_matching_tool(self):
        """should_interrupt is True for every configured tool name."""
        interceptor = ToolInterceptor(["db_tool", "api_tool"])
        for tool_name in ("db_tool", "api_tool"):
            assert interceptor.should_interrupt(tool_name) is True, tool_name

    def test_should_interrupt_with_non_matching_tool(self):
        """should_interrupt is False for names that are not configured."""
        interceptor = ToolInterceptor(["db_tool", "api_tool"])
        for tool_name in ("search_tool", "crawl_tool"):
            assert interceptor.should_interrupt(tool_name) is False, tool_name

    def test_should_interrupt_empty_list(self):
        """An empty interrupt list never interrupts."""
        assert ToolInterceptor([]).should_interrupt("db_tool") is False

    def test_parse_approval_with_approval_keywords(self):
        """Each recognised approval keyword parses as True."""
        approving_inputs = (
            "approved",
            "approve",
            "yes",
            "proceed",
            "continue",
            "ok",
            "okay",
            "accepted",
            "accept",
            "[approved]",
        )
        for feedback in approving_inputs:
            assert ToolInterceptor._parse_approval(feedback) is True, feedback

    def test_parse_approval_case_insensitive(self):
        """Keyword matching ignores letter case."""
        for feedback in ("APPROVED", "Approved", "PROCEED"):
            assert ToolInterceptor._parse_approval(feedback) is True, feedback

    def test_parse_approval_with_surrounding_text(self):
        """A keyword embedded in a longer sentence still counts as approval."""
        for feedback in ("Sure, proceed with the tool", "[ACCEPTED] I approve this"):
            assert ToolInterceptor._parse_approval(feedback) is True, feedback

    def test_parse_approval_rejection(self):
        """Feedback without an approval keyword parses as False."""
        for feedback in ("no", "reject", "cancel", "random feedback"):
            assert ToolInterceptor._parse_approval(feedback) is False, feedback

    def test_parse_approval_empty_string(self):
        """The empty string is not an approval."""
        assert ToolInterceptor._parse_approval("") is False

    def test_parse_approval_none(self):
        """None is not an approval."""
        assert ToolInterceptor._parse_approval(None) is False

    @patch("src.agents.tool_interceptor.interrupt")
    def test_wrap_tool_with_interrupt(self, mock_interrupt):
        """Invoking a wrapped tool on the interrupt list triggers interrupt()."""
        mock_interrupt.return_value = "approved"

        interceptor = ToolInterceptor(["test_tool"])
        wrapped = ToolInterceptor.wrap_tool(self._make_test_tool(), interceptor)

        wrapped.invoke("hello")

        # interrupt() is called once and its first argument mentions the tool.
        mock_interrupt.assert_called_once()
        assert "test_tool" in mock_interrupt.call_args[0][0]

    @patch("src.agents.tool_interceptor.interrupt")
    def test_wrap_tool_without_interrupt(self, mock_interrupt):
        """A wrapped tool that is not on the list runs without interrupt()."""
        interceptor = ToolInterceptor(["other_tool"])
        wrapped = ToolInterceptor.wrap_tool(self._make_test_tool(), interceptor)

        output = wrapped.invoke("hello")

        mock_interrupt.assert_not_called()
        assert "Result: hello" in str(output)

    @patch("src.agents.tool_interceptor.interrupt")
    def test_wrap_tool_user_rejects(self, mock_interrupt):
        """A non-approving answer yields a rejection dict instead of tool output."""
        mock_interrupt.return_value = "no"

        interceptor = ToolInterceptor(["test_tool"])
        wrapped = ToolInterceptor.wrap_tool(self._make_test_tool(), interceptor)

        output = wrapped.invoke("hello")

        assert isinstance(output, dict)
        assert "error" in output
        assert output["status"] == "rejected"

    def test_wrap_tools_with_interceptor_empty_list(self):
        """An empty interrupt list still returns one tool with its name intact."""
        wrapped_tools = wrap_tools_with_interceptor([self._make_test_tool()], [])

        assert len(wrapped_tools) == 1
        assert wrapped_tools[0].name == "test_tool"

    def test_wrap_tools_with_interceptor_none(self):
        """A None interrupt list still returns one tool."""
        wrapped_tools = wrap_tools_with_interceptor([self._make_test_tool()], None)

        assert len(wrapped_tools) == 1

    @patch("src.agents.tool_interceptor.interrupt")
    def test_wrap_tools_with_interceptor_multiple(self, mock_interrupt):
        """Only the listed tool among several triggers interrupt()."""
        mock_interrupt.return_value = "approved"

        @tool
        def db_tool(query: str) -> str:
            """DB tool."""
            return f"Query result: {query}"

        @tool
        def search_tool(query: str) -> str:
            """Search tool."""
            return f"Search result: {query}"

        wrapped_db, wrapped_search = wrap_tools_with_interceptor(
            [db_tool, search_tool], ["db_tool"]
        )

        # The listed tool triggers exactly one interrupt...
        wrapped_db.invoke("test query")
        assert mock_interrupt.call_count == 1

        # ...and the unlisted tool adds none.
        wrapped_search.invoke("test query")
        assert mock_interrupt.call_count == 1

    def test_wrap_tool_preserves_tool_properties(self):
        """Wrapping keeps the tool's name and description."""

        @tool
        def my_tool(input_text: str) -> str:
            """My tool description."""
            return f"Result: {input_text}"

        wrapped = ToolInterceptor.wrap_tool(my_tool, ToolInterceptor([]))

        assert wrapped.name == "my_tool"
        assert wrapped.description == "My tool description."
|
||||
|
||||
|
||||
class TestFormatToolInput:
|
||||
"""Tests for tool input formatting functionality."""
|
||||
|
||||
def test_format_tool_input_none(self):
    """None is rendered as the literal text "No input"."""
    assert ToolInterceptor._format_tool_input(None) == "No input"
|
||||
|
||||
def test_format_tool_input_string(self):
    """A string input is returned unchanged."""
    raw_text = "SELECT * FROM users"
    rendered = ToolInterceptor._format_tool_input(raw_text)
    assert rendered == raw_text
|
||||
|
||||
def test_format_tool_input_simple_dict(self):
    """A flat dict is rendered as multi-line JSON that round-trips."""
    import json

    payload = {"query": "test", "limit": 10}
    rendered = ToolInterceptor._format_tool_input(payload)

    # Parsing the output gives back the original mapping.
    assert json.loads(rendered) == payload
    # The output spans more than one line.
    assert "\n" in rendered
|
||||
|
||||
def test_format_tool_input_nested_dict(self):
|
||||
"""Test formatting nested dictionary."""
|
||||
input_dict = {
|
||||
"query": "SELECT * FROM users",
|
||||
"config": {
|
||||
"timeout": 30,
|
||||
"retry": True
|
||||
}
|
||||
}
|
||||
result = ToolInterceptor._format_tool_input(input_dict)
|
||||
|
||||
import json
|
||||
parsed = json.loads(result)
|
||||
assert parsed == input_dict
|
||||
assert "timeout" in result
|
||||
assert "retry" in result
|
||||
|
||||
def test_format_tool_input_list(self):
|
||||
"""Test formatting list input."""
|
||||
input_list = ["item1", "item2", 123]
|
||||
result = ToolInterceptor._format_tool_input(input_list)
|
||||
|
||||
import json
|
||||
parsed = json.loads(result)
|
||||
assert parsed == input_list
|
||||
|
||||
def test_format_tool_input_complex_list(self):
|
||||
"""Test formatting list with mixed types."""
|
||||
input_list = ["text", 42, 3.14, True, {"key": "value"}]
|
||||
result = ToolInterceptor._format_tool_input(input_list)
|
||||
|
||||
import json
|
||||
parsed = json.loads(result)
|
||||
assert parsed == input_list
|
||||
|
||||
def test_format_tool_input_tuple(self):
|
||||
"""Test formatting tuple input."""
|
||||
input_tuple = ("item1", "item2", 123)
|
||||
result = ToolInterceptor._format_tool_input(input_tuple)
|
||||
|
||||
import json
|
||||
parsed = json.loads(result)
|
||||
# JSON converts tuples to lists
|
||||
assert parsed == list(input_tuple)
|
||||
|
||||
def test_format_tool_input_integer(self):
|
||||
"""Test formatting integer input."""
|
||||
result = ToolInterceptor._format_tool_input(42)
|
||||
assert result == "42"
|
||||
|
||||
def test_format_tool_input_float(self):
|
||||
"""Test formatting float input."""
|
||||
result = ToolInterceptor._format_tool_input(3.14)
|
||||
assert result == "3.14"
|
||||
|
||||
def test_format_tool_input_boolean(self):
|
||||
"""Test formatting boolean input."""
|
||||
result_true = ToolInterceptor._format_tool_input(True)
|
||||
result_false = ToolInterceptor._format_tool_input(False)
|
||||
assert result_true == "True"
|
||||
assert result_false == "False"
|
||||
|
||||
def test_format_tool_input_deeply_nested(self):
|
||||
"""Test formatting deeply nested structure."""
|
||||
input_dict = {
|
||||
"level1": {
|
||||
"level2": {
|
||||
"level3": {
|
||||
"level4": ["a", "b", "c"],
|
||||
"data": {"key": "value"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result = ToolInterceptor._format_tool_input(input_dict)
|
||||
|
||||
import json
|
||||
parsed = json.loads(result)
|
||||
assert parsed == input_dict
|
||||
|
||||
def test_format_tool_input_empty_dict(self):
|
||||
"""Test formatting empty dictionary."""
|
||||
result = ToolInterceptor._format_tool_input({})
|
||||
assert result == "{}"
|
||||
|
||||
def test_format_tool_input_empty_list(self):
|
||||
"""Test formatting empty list."""
|
||||
result = ToolInterceptor._format_tool_input([])
|
||||
assert result == "[]"
|
||||
|
||||
def test_format_tool_input_special_characters(self):
|
||||
"""Test formatting dict with special characters."""
|
||||
input_dict = {
|
||||
"query": 'SELECT * FROM users WHERE name = "John"',
|
||||
"path": "/usr/local/bin",
|
||||
"unicode": "你好世界"
|
||||
}
|
||||
result = ToolInterceptor._format_tool_input(input_dict)
|
||||
|
||||
import json
|
||||
parsed = json.loads(result)
|
||||
assert parsed == input_dict
|
||||
|
||||
def test_format_tool_input_numbers_as_strings(self):
|
||||
"""Test formatting with various number types."""
|
||||
input_dict = {
|
||||
"int": 42,
|
||||
"float": 3.14159,
|
||||
"negative": -100,
|
||||
"zero": 0,
|
||||
"scientific": 1e-5
|
||||
}
|
||||
result = ToolInterceptor._format_tool_input(input_dict)
|
||||
|
||||
import json
|
||||
parsed = json.loads(result)
|
||||
assert parsed["int"] == 42
|
||||
assert abs(parsed["float"] - 3.14159) < 0.00001
|
||||
assert parsed["negative"] == -100
|
||||
assert parsed["zero"] == 0
|
||||
|
||||
def test_format_tool_input_with_none_values(self):
|
||||
"""Test formatting dict with None values."""
|
||||
input_dict = {
|
||||
"key1": "value1",
|
||||
"key2": None,
|
||||
"key3": {"nested": None}
|
||||
}
|
||||
result = ToolInterceptor._format_tool_input(input_dict)
|
||||
|
||||
import json
|
||||
parsed = json.loads(result)
|
||||
assert parsed == input_dict
|
||||
|
||||
def test_format_tool_input_indentation(self):
|
||||
"""Test that output uses proper indentation (2 spaces)."""
|
||||
input_dict = {"outer": {"inner": "value"}}
|
||||
result = ToolInterceptor._format_tool_input(input_dict)
|
||||
|
||||
# Should have indented lines
|
||||
assert " " in result # 2-space indentation
|
||||
lines = result.split("\n")
|
||||
# Check that indentation increases with nesting
|
||||
assert any(line.startswith(" ") for line in lines)
|
||||
|
||||
def test_format_tool_input_preserves_order_insertion(self):
|
||||
"""Test that dict order is preserved in output."""
|
||||
input_dict = {
|
||||
"first": 1,
|
||||
"second": 2,
|
||||
"third": 3
|
||||
}
|
||||
result = ToolInterceptor._format_tool_input(input_dict)
|
||||
|
||||
import json
|
||||
parsed = json.loads(result)
|
||||
# Verify all keys are present
|
||||
assert set(parsed.keys()) == {"first", "second", "third"}
|
||||
|
||||
def test_format_tool_input_long_strings(self):
|
||||
"""Test formatting with long string values."""
|
||||
long_string = "x" * 1000
|
||||
input_dict = {"long": long_string}
|
||||
result = ToolInterceptor._format_tool_input(input_dict)
|
||||
|
||||
import json
|
||||
parsed = json.loads(result)
|
||||
assert parsed["long"] == long_string
|
||||
|
||||
def test_format_tool_input_mixed_types_in_list(self):
|
||||
"""Test formatting list with mixed complex types."""
|
||||
input_list = [
|
||||
"string",
|
||||
42,
|
||||
{"dict": "value"},
|
||||
[1, 2, 3],
|
||||
True,
|
||||
None
|
||||
]
|
||||
result = ToolInterceptor._format_tool_input(input_list)
|
||||
|
||||
import json
|
||||
parsed = json.loads(result)
|
||||
assert len(parsed) == 6
|
||||
assert parsed[0] == "string"
|
||||
assert parsed[1] == 42
|
||||
assert parsed[2] == {"dict": "value"}
|
||||
assert parsed[3] == [1, 2, 3]
|
||||
assert parsed[4] is True
|
||||
assert parsed[5] is None
|
||||
@@ -1,69 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
from langchain_core.tools import Tool
|
||||
from src.agents.tool_interceptor import ToolInterceptor
|
||||
|
||||
class TestToolInterceptorFix(unittest.TestCase):
    """Regression tests: interception must also apply when a tool is called via ``.run()``."""

    # Patch target for the interrupt primitive used by the interceptor module.
    INTERRUPT_TARGET = "src.agents.tool_interceptor.interrupt"

    def test_interceptor_patches_run_method(self):
        """An intercepted tool called via ``.run()`` triggers interrupt and still runs."""
        # Create a mock tool
        mock_func = MagicMock(return_value="Original Result")
        tool = Tool(name="resolve_company_name", func=mock_func, description="test tool")

        # Interceptor that always interrupts 'resolve_company_name'
        interceptor = ToolInterceptor(interrupt_before_tools=["resolve_company_name"])

        # Wrap the tool
        wrapped_tool = ToolInterceptor.wrap_tool(tool, interceptor)

        # Mock interrupt to avoid actual suspension
        with unittest.mock.patch(
            self.INTERRUPT_TARGET, return_value="approved"
        ) as mock_interrupt:
            # Standard BaseTool execution flow is invoke -> run -> _run.
            # If only `func` were patched, run() would reach the original func
            # through the original _run and bypass interception.
            result = wrapped_tool.run("some input")

        # Without this check the test would pass even if interception were bypassed.
        mock_interrupt.assert_called()

        # Verify result
        self.assertEqual(result, "Original Result")

        # Verify the original function was called after approval
        mock_func.assert_called_once()

    def test_run_method_without_interrupt(self):
        """Test that tools not in interrupt list work normally via .run()"""
        mock_func = MagicMock(return_value="Result")
        tool = Tool(name="other_tool", func=mock_func, description="test")

        interceptor = ToolInterceptor(interrupt_before_tools=["resolve_company_name"])
        wrapped_tool = ToolInterceptor.wrap_tool(tool, interceptor)

        with unittest.mock.patch(self.INTERRUPT_TARGET) as mock_interrupt:
            result = wrapped_tool.run("input")

        # Verify interrupt was NOT called for non-intercepted tool
        mock_interrupt.assert_not_called()
        self.assertEqual(result, "Result")
        mock_func.assert_called_once()

    def test_interceptor_resolve_company_name_example(self):
        """Interception works for a ``resolve_company_name`` tool returning a JSON payload via ``.run()``."""
        mock_func = MagicMock(
            return_value='{"code": 0, "data": [{"companyName": "A"}, {"companyName": "B"}]}'
        )
        tool = Tool(name="resolve_company_name", func=mock_func, description="resolve company")

        interceptor = ToolInterceptor(interrupt_before_tools=["resolve_company_name"])
        wrapped_tool = ToolInterceptor.wrap_tool(tool, interceptor)

        # The patched interrupt simply answers "approved"; no company-selection
        # logic is exercised here, only the interception mechanism itself.
        with unittest.mock.patch(
            self.INTERRUPT_TARGET, return_value="approved"
        ) as mock_interrupt:
            wrapped_tool.run("query")

        mock_interrupt.assert_called()
        mock_func.assert_called_once()
|
||||
@@ -1,153 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import psycopg
|
||||
import pytest
|
||||
|
||||
|
||||
class PostgreSQLMockInstance:
    """In-memory stand-in for a PostgreSQL connection, built on ``MagicMock``.

    ``connect()`` returns a mock connection whose cursor routes ``execute``
    and ``fetchone`` to this object. Rows of the ``chat_streams`` table are
    kept in ``mock_data["chat_streams"]``, keyed by thread id.
    """

    def __init__(self, database_name: str = "test_db"):
        """Create the instance with empty mock data; no connection yet."""
        self.database_name = database_name
        self.temp_dir: Optional[Path] = None
        self.mock_connection: Optional[MagicMock] = None
        self.mock_data: Dict[str, Any] = {}
        self._setup_mock_data()

    def _setup_mock_data(self):
        """Initialize mock data storage."""
        self.mock_data = {
            "chat_streams": {},  # thread_id -> record
            "table_exists": False,
            "connection_active": True,
        }

    def connect(self) -> MagicMock:
        """Create a mock PostgreSQL connection."""
        self.mock_connection = MagicMock()
        self._setup_mock_methods()
        return self.mock_connection

    def _setup_mock_methods(self):
        """Setup mock methods for PostgreSQL operations."""
        if not self.mock_connection:
            return

        # Mock cursor context manager
        mock_cursor = MagicMock()
        mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
        mock_cursor.__exit__ = MagicMock(return_value=False)

        # Setup cursor operations
        mock_cursor.execute = MagicMock(side_effect=self._mock_execute)
        mock_cursor.fetchone = MagicMock(side_effect=self._mock_fetchone)
        mock_cursor.rowcount = 0
        # MagicMock auto-creates missing attributes, so without this explicit
        # default fetchone() would return a MagicMock before any SELECT ran.
        mock_cursor._fetch_result = None

        # Setup connection operations
        self.mock_connection.cursor = MagicMock(return_value=mock_cursor)
        self.mock_connection.commit = MagicMock()
        self.mock_connection.rollback = MagicMock()
        self.mock_connection.close = MagicMock()

        # Store cursor for external access
        self._mock_cursor = mock_cursor

    def _mock_execute(self, sql: str, params=None):
        """Mock SQL execution.

        Dispatches on keywords found in the statement. Matching is done on
        the upper-cased SQL, so the table name is compared in upper case
        too; comparing against lowercase ``"chat_streams"`` could never match.
        """
        sql_upper = sql.upper().strip()
        on_chat_streams = "CHAT_STREAMS" in sql_upper
        streams = self.mock_data["chat_streams"]

        if "CREATE TABLE" in sql_upper:
            self.mock_data["table_exists"] = True
            self._mock_cursor.rowcount = 0

        elif "SELECT" in sql_upper and on_chat_streams:
            # SELECT: params[0] is the thread id; a missing id yields None.
            thread_id = params[0] if params else None
            self._mock_cursor._fetch_result = (
                streams.get(thread_id) if params else None
            )

        elif "UPDATE" in sql_upper and on_chat_streams:
            # UPDATE: params are (messages, thread_id)
            if params and len(params) >= 2:
                messages, thread_id = params[0], params[1]
                if thread_id in streams:
                    streams[thread_id] = {
                        "id": thread_id,
                        "thread_id": thread_id,
                        "messages": messages,
                    }
                    self._mock_cursor.rowcount = 1
                else:
                    self._mock_cursor.rowcount = 0

        elif "INSERT" in sql_upper and on_chat_streams:
            # INSERT: params are (thread_id, messages)
            if params and len(params) >= 2:
                thread_id, messages = params[0], params[1]
                streams[thread_id] = {
                    "id": thread_id,
                    "thread_id": thread_id,
                    "messages": messages,
                }
                self._mock_cursor.rowcount = 1

    def _mock_fetchone(self):
        """Mock fetchone operation: return the result stored by the last SELECT."""
        return getattr(self._mock_cursor, "_fetch_result", None)

    def disconnect(self):
        """Cleanup mock connection."""
        if self.mock_connection:
            self.mock_connection.close()
        self._setup_mock_data()  # Reset data

    def reset_data(self):
        """Reset all mock data."""
        self._setup_mock_data()

    def get_table_count(self, table_name: str) -> int:
        """Get record count in a table (0 for any table other than chat_streams)."""
        if table_name == "chat_streams":
            return len(self.mock_data["chat_streams"])
        return 0

    def create_test_data(self, table_name: str, records: list):
        """Insert test data into a table; records without a thread_id are skipped."""
        if table_name == "chat_streams":
            for record in records:
                thread_id = record.get("thread_id")
                if thread_id:
                    self.mock_data["chat_streams"][thread_id] = record
|
||||
|
||||
|
||||
@pytest.fixture
def mock_postgresql():
    """Yield a connected PostgreSQLMockInstance and disconnect it afterwards."""
    pg_mock = PostgreSQLMockInstance()
    pg_mock.connect()
    yield pg_mock
    # Teardown: closes the mock connection and resets its data.
    pg_mock.disconnect()
|
||||
|
||||
|
||||
@pytest.fixture
def clean_mock_postgresql():
    """Yield a connected PostgreSQLMockInstance whose data was explicitly reset first."""
    pg_mock = PostgreSQLMockInstance()
    pg_mock.connect()
    pg_mock.reset_data()
    yield pg_mock
    # Teardown: closes the mock connection and resets its data.
    pg_mock.disconnect()
|
||||
@@ -1,685 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import mongomock
|
||||
import pytest
|
||||
from postgres_mock_utils import PostgreSQLMockInstance
|
||||
|
||||
import src.graph.checkpoint as checkpoint
|
||||
|
||||
POSTGRES_URL = "postgresql://postgres:postgres@localhost:5432/checkpointing_db"
|
||||
MONGO_URL = "mongodb://admin:admin@localhost:27017/checkpointing_db?authSource=admin"
|
||||
|
||||
|
||||
def has_real_db_connection():
    """Report whether tests that need a real database server are enabled.

    Returns True only when the ``DB_TESTS_ENABLED`` environment variable is
    set to "true" (case-insensitive); an unset variable counts as "false".
    """
    flag = os.getenv("DB_TESTS_ENABLED", "false")
    return flag.lower() == "true"
|
||||
|
||||
|
||||
def test_with_local_postgres_db():
    """ChatStreamManager sets up its PostgreSQL side from a mocked psycopg connection."""
    fake_pg = PostgreSQLMockInstance()
    with patch("psycopg.connect") as connect_stub:
        # psycopg.connect() hands back the mock connection instead of dialing a server.
        connect_stub.return_value = fake_pg.connect()
        manager = checkpoint.ChatStreamManager(
            db_uri=POSTGRES_URL,
            checkpoint_saver=True,
        )
        assert manager.postgres_conn is not None
        assert manager.mongo_client is None
|
||||
|
||||
|
||||
def test_with_local_mongo_db():
    """ChatStreamManager sets up its MongoDB side when MongoClient is replaced by mongomock."""
    with patch("src.graph.checkpoint.MongoClient") as client_factory:
        # Every MongoClient(...) call inside checkpoint returns the mongomock client.
        client_factory.return_value = mongomock.MongoClient()

        manager = checkpoint.ChatStreamManager(
            db_uri=MONGO_URL,
            checkpoint_saver=True,
        )
        assert manager.mongo_db is not None
        assert manager.postgres_conn is None
|
||||
|
||||
|
||||
def test_init_without_checkpoint_saver():
    """With checkpoint_saver=False the manager exposes no DB client handles."""
    disabled = checkpoint.ChatStreamManager(checkpoint_saver=False)

    assert disabled.checkpoint_saver is False
    # Neither backend handle is populated when the saver is off.
    assert disabled.mongo_client is None
    assert disabled.postgres_conn is None
|
||||
|
||||
|
||||
def test_process_stream_partial_buffer_postgres(monkeypatch):
    """A 'partial' chunk is buffered in the store; PostgreSQL init is stubbed out."""

    def _skip_postgres_init(self):
        # Replacement for _init_postgresql: leave the manager without a connection.
        self.postgres_conn = None

    monkeypatch.setattr(
        checkpoint.ChatStreamManager,
        "_init_postgresql",
        _skip_postgres_init,
        raising=True,
    )
    manager = checkpoint.ChatStreamManager(
        db_uri=POSTGRES_URL,
        checkpoint_saver=True,
    )

    accepted = manager.process_stream_message("t1", "hello", finish_reason="partial")
    assert accepted is True

    # The chunk must be retrievable from the in-memory store under the thread's namespace.
    stored = manager.store.search(("messages", "t1"), limit=10)
    stored_values = [entry.dict()["value"] for entry in stored]
    assert "hello" in stored_values
|
||||
|
||||
|
||||
def test_process_stream_partial_buffer_mongo():
    """A 'partial' chunk is buffered in the store; MongoDB is backed by mongomock."""
    with patch("src.graph.checkpoint.MongoClient") as client_factory:
        client_factory.return_value = mongomock.MongoClient()

        manager = checkpoint.ChatStreamManager(
            db_uri=MONGO_URL,
            checkpoint_saver=True,
        )
        accepted = manager.process_stream_message(
            "t2", "hello", finish_reason="partial"
        )
        assert accepted is True

        # The chunk must be retrievable from the in-memory store under the thread's namespace.
        stored = manager.store.search(("messages", "t2"), limit=10)
        stored_values = [entry.dict()["value"] for entry in stored]
        assert "hello" in stored_values
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not has_real_db_connection(), reason="PostgreSQL Server is not available"
)
def test_persist_postgresql_local_db():
    """Persist twice to a real local PostgreSQL and confirm the second write appends."""
    manager = checkpoint.ChatStreamManager(
        db_uri=POSTGRES_URL,
        checkpoint_saver=True,
    )
    assert manager.postgres_conn is not None
    assert manager.mongo_client is None

    thread_id = "test_thread"

    # First write for the thread.
    first_ok = manager._persist_to_postgresql(thread_id, ["This is a test message."])
    assert first_ok is True

    # Second write targets the same thread and is expected to append, not overwrite.
    second_ok = manager._persist_to_postgresql(thread_id, ["Another message."])
    assert second_ok is True

    # Read the row back and check both messages are present in order.
    with manager.postgres_conn.cursor() as cur:
        cur.execute(
            "SELECT messages FROM chat_streams WHERE thread_id = %s", (thread_id,)
        )
        row = cur.fetchone()
        assert row is not None
        assert row["messages"] == ["This is a test message.", "Another message."]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not has_real_db_connection(), reason="PostgreSQL Server is not available"
)
def test_persist_postgresql_called_with_aggregated_chunks():
    """A 'partial' chunk followed by a 'stop' chunk is persisted to PostgreSQL as one list."""
    manager = checkpoint.ChatStreamManager(
        db_uri=POSTGRES_URL,
        checkpoint_saver=True,
    )

    buffered = manager.process_stream_message("thd3", "Hello", finish_reason="partial")
    assert buffered is True
    flushed = manager.process_stream_message("thd3", " World", finish_reason="stop")
    assert flushed is True

    # Read the row back: both chunks must be stored in arrival order.
    with manager.postgres_conn.cursor() as cur:
        cur.execute(
            "SELECT messages FROM chat_streams WHERE thread_id = %s", ("thd3",)
        )
        row = cur.fetchone()
        assert row is not None
        assert row["messages"] == ["Hello", " World"]
|
||||
|
||||
|
||||
def test_persist_not_attempted_when_saver_disabled():
    """With the saver disabled, a 'stop' message returns False."""
    disabled = checkpoint.ChatStreamManager(checkpoint_saver=False)
    # 'stop' is the finish reason that would normally trigger persistence.
    outcome = disabled.process_stream_message("t4", "hello", finish_reason="stop")
    assert outcome is False
|
||||
|
||||
|
||||
def test_persist_mongodb_local_db():
    """Persist twice to a mongomock-backed MongoDB and confirm the second write appends."""
    with patch("src.graph.checkpoint.MongoClient") as client_factory:
        client_factory.return_value = mongomock.MongoClient()

        manager = checkpoint.ChatStreamManager(
            db_uri=MONGO_URL,
            checkpoint_saver=True,
        )
        assert manager.mongo_db is not None
        assert manager.postgres_conn is None

        thread_id = "test_thread"
        first_batch = ["This is a test message."]

        # First write for the thread.
        assert manager._persist_to_mongodb(thread_id, first_batch) is True

        streams = manager.mongo_db.chat_streams
        stored = streams.find_one({"thread_id": thread_id})
        assert stored is not None
        assert stored["messages"] == first_batch

        # Second write targets the same thread and is expected to append, not overwrite.
        assert manager._persist_to_mongodb(thread_id, ["Another message."]) is True

        stored = streams.find_one({"thread_id": thread_id})
        assert stored["messages"] == ["This is a test message.", "Another message."]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not has_real_db_connection(), reason="MongoDB server is not available"
)
def test_persist_mongodb_called_with_aggregated_chunks():
    """A 'partial' chunk followed by a 'stop' chunk is persisted to MongoDB as one list."""
    manager = checkpoint.ChatStreamManager(
        db_uri=MONGO_URL,
        checkpoint_saver=True,
    )

    buffered = manager.process_stream_message("thd5", "Hello", finish_reason="partial")
    assert buffered is True
    flushed = manager.process_stream_message("thd5", " World", finish_reason="stop")
    assert flushed is True

    # Read the document back: both chunks must be stored in arrival order.
    stored = manager.mongo_db.chat_streams.find_one({"thread_id": "thd5"})
    assert stored is not None
    assert stored["messages"] == ["Hello", " World"]
|
||||
|
||||
|
||||
def test_invalid_inputs_return_false(monkeypatch):
    """An empty thread_id or an empty message makes process_stream_message return False."""

    def _skip_mongo_init(self):
        # Replacement for _init_mongodb: leave the manager without any Mongo handles.
        self.mongo_client = None
        self.mongo_db = None

    monkeypatch.setattr(
        checkpoint.ChatStreamManager,
        "_init_mongodb",
        _skip_mongo_init,
        raising=True,
    )

    manager = checkpoint.ChatStreamManager(
        db_uri=MONGO_URL,
        checkpoint_saver=True,
    )
    empty_thread = manager.process_stream_message("", "msg", finish_reason="partial")
    assert empty_thread is False
    empty_message = manager.process_stream_message("tid", "", finish_reason="partial")
    assert empty_message is False
|
||||
|
||||
|
||||
def test_unsupported_db_uri_scheme():
    """A URI with an unsupported scheme (redis://) leaves every DB handle unset."""
    manager = checkpoint.ChatStreamManager(
        db_uri="redis://localhost:6379/0",
        checkpoint_saver=True,
    )
    # No backend may have been initialised for the unknown scheme.
    assert manager.mongo_client is None
    assert manager.postgres_conn is None
    assert manager.mongo_db is None
|
||||
|
||||
|
||||
def test_process_stream_with_interrupt_finish_reason():
    """An 'interrupt' finish_reason flushes buffered chunks to the database, like 'stop'."""
    with patch("src.graph.checkpoint.MongoClient") as client_factory:
        client_factory.return_value = mongomock.MongoClient()

        manager = checkpoint.ChatStreamManager(
            db_uri=MONGO_URL,
            checkpoint_saver=True,
        )

        # Buffer a first chunk.
        buffered = manager.process_stream_message(
            "int_test", "Interrupted", finish_reason="partial"
        )
        assert buffered is True

        # The interrupting chunk should trigger persistence.
        flushed = manager.process_stream_message(
            "int_test", " message", finish_reason="interrupt"
        )
        assert flushed is True

        # Both chunks must now be in the persisted document.
        stored = manager.mongo_db.chat_streams.find_one({"thread_id": "int_test"})
        assert stored is not None
        assert stored["messages"] == ["Interrupted", " message"]
|
||||
|
||||
|
||||
def test_postgresql_connection_failure(monkeypatch):
    """If psycopg.connect raises, the manager ends up with no PostgreSQL connection."""

    def failing_connect(dsn, **kwargs):
        raise RuntimeError("Connection failed")

    monkeypatch.setattr("psycopg.connect", failing_connect)

    manager = checkpoint.ChatStreamManager(
        db_uri=POSTGRES_URL,
        checkpoint_saver=True,
    )
    # The constructor is expected to absorb the failure rather than propagate it.
    assert manager.postgres_conn is None
|
||||
|
||||
|
||||
def test_mongodb_ping_failure(monkeypatch):
    """If the Mongo admin command raises during init, mongo_db stays unset."""

    class FakeAdmin:
        """Admin stub whose command() always fails."""

        def command(self, name):
            raise RuntimeError("Ping failed")

    class FakeClient:
        """Client stub exposing only the failing admin object."""

        def __init__(self, uri):
            self.admin = FakeAdmin()

    def _make_client(uri):
        return FakeClient(uri)

    monkeypatch.setattr(checkpoint, "MongoClient", _make_client)

    manager = checkpoint.ChatStreamManager(
        db_uri=MONGO_URL,
        checkpoint_saver=True,
    )
    # getattr with a default tolerates the attribute being absent altogether.
    assert getattr(manager, "mongo_db", None) is None
|
||||
|
||||
|
||||
def test_store_namespace_consistency():
    """The per-thread cursor lives under ("messages", thread_id) and advances per chunk."""
    manager = checkpoint.ChatStreamManager(checkpoint_saver=False)
    namespace = ("messages", "ns_test")

    # First chunk: cursor is created at index 0.
    first = manager.process_stream_message("ns_test", "chunk1", finish_reason="partial")
    assert first is True

    cursor = manager.store.get(namespace, "cursor")
    assert cursor is not None
    assert cursor.value["index"] == 0

    # Second chunk: cursor moves on to index 1.
    second = manager.process_stream_message("ns_test", "chunk2", finish_reason="partial")
    assert second is True

    cursor = manager.store.get(namespace, "cursor")
    assert cursor.value["index"] == 1
|
||||
|
||||
|
||||
def test_cursor_initialization_edge_cases():
    """A pre-existing cursor record without an 'index' key is tolerated."""
    manager = checkpoint.ChatStreamManager(checkpoint_saver=False)

    # Seed the store with a cursor record that lacks the 'index' key.
    namespace = ("messages", "edge_test")
    manager.store.put(namespace, "cursor", {})

    # Processing a chunk must still succeed.
    accepted = manager.process_stream_message(
        "edge_test", "test", finish_reason="partial"
    )
    assert accepted is True

    # The missing index is treated as 0 and then incremented to 1.
    cursor = manager.store.get(namespace, "cursor")
    assert cursor.value["index"] == 1
|
||||
|
||||
|
||||
def test_multiple_threads_isolation():
    """Chunks buffered for one thread_id never show up under another thread_id."""
    manager = checkpoint.ChatStreamManager(checkpoint_saver=False)

    # Interleave chunks for two threads.
    for thread_id, chunk in (
        ("thread1", "msg1"),
        ("thread2", "msg2"),
        ("thread1", "msg3"),
    ):
        accepted = manager.process_stream_message(
            thread_id, chunk, finish_reason="partial"
        )
        assert accepted is True

    def string_values(thread_id):
        # Keep only string-valued items; other store entries hold non-string values.
        items = manager.store.search(("messages", thread_id), limit=10)
        values = [item.dict()["value"] for item in items]
        return [value for value in values if isinstance(value, str)]

    thread1_values = string_values("thread1")
    thread2_values = string_values("thread2")

    assert "msg1" in thread1_values
    assert "msg3" in thread1_values
    assert "msg2" in thread2_values
    assert "msg1" not in thread2_values
    assert "msg2" not in thread1_values
|
||||
|
||||
|
||||
def test_mongodb_insert_and_update_paths():
    """Exercise MongoDB insert, update, and exception branches using mongomock."""
    with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
        # Setup mongomock
        mock_client = mongomock.MongoClient()
        mock_mongo_client.return_value = mock_client

        manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)

        # Insert success (new thread)
        assert manager._persist_to_mongodb("th1", ["message1"]) is True

        # Verify insert worked
        collection = manager.mongo_db.chat_streams
        doc = collection.find_one({"thread_id": "th1"})
        assert doc is not None
        assert doc["messages"] == ["message1"]

        # Update success (existing thread - should append, not overwrite)
        assert manager._persist_to_mongodb("th1", ["message2"]) is True

        # Verify update worked - messages should be appended
        doc = collection.find_one({"thread_id": "th1"})
        assert doc["messages"] == ["message1", "message2"]

        # Error case: make find_one raise. patch.object restores the original
        # attribute on exit even when the assertion inside the block fails,
        # which the previous manual save/restore did not guarantee.
        with patch.object(
            collection, "find_one", side_effect=RuntimeError("Database error")
        ):
            assert manager._persist_to_mongodb("th2", ["message"]) is False
|
||||
|
||||
|
||||
def test_postgresql_insert_update_and_error_paths():
    """Drive _persist_to_postgresql through its update, insert, and error/rollback branches."""
    # Log of the leading SQL keyword of every executed statement.
    calls = {"executed": []}

    class FakeCursor:
        """Cursor stub whose SELECT behaviour depends on the configured mode."""

        def __init__(self, mode):
            self.mode = mode
            self.rowcount = 0

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def execute(self, sql, params=None):
            keyword = sql.strip().split()[0]
            calls["executed"].append(keyword)
            if "SELECT" not in sql:
                # UPDATE or INSERT: report one affected row.
                self.rowcount = 1
                return
            if self.mode == "update":
                # An existing row is found.
                self._fetch = {"id": "x"}
            elif self.mode == "error":
                raise RuntimeError("sql error")
            else:
                # No existing row.
                self._fetch = None

        def fetchone(self):
            return getattr(self, "_fetch", None)

    class FakeConn:
        """Connection stub that records whether commit/rollback were called."""

        def __init__(self, mode):
            self.mode = mode
            self.commit_called = False
            self.rollback_called = False

        def cursor(self):
            return FakeCursor(self.mode)

        def commit(self):
            self.commit_called = True

        def rollback(self):
            self.rollback_called = True

    manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=POSTGRES_URL)

    # Update path: SELECT finds a row.
    manager.postgres_conn = FakeConn("update")
    assert manager._persist_to_postgresql("t", ["m"]) is True
    assert manager.postgres_conn.commit_called is True

    # Insert path: SELECT finds nothing.
    manager.postgres_conn = FakeConn("insert")
    assert manager._persist_to_postgresql("t", ["m"]) is True
    assert manager.postgres_conn.commit_called is True

    # Error path: SELECT raises, so the call reports failure and rolls back.
    manager.postgres_conn = FakeConn("error")
    assert manager._persist_to_postgresql("t", ["m"]) is False
    assert manager.postgres_conn.rollback_called is True
|
||||
|
||||
|
||||
def test_create_chat_streams_table_success_and_error():
    """Table creation commits once on success and rolls back once on failure."""

    class FakeCursor:
        """Cursor double whose ``execute`` optionally raises."""

        def __init__(self, should_fail=False):
            self.should_fail = should_fail

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def execute(self, sql):
            if not self.should_fail:
                return
            raise RuntimeError("ddl fail")

    class FakeConn:
        """Connection double counting commits and rollbacks."""

        def __init__(self, should_fail=False):
            self.should_fail = should_fail
            self.commits = 0
            self.rollbacks = 0

        def cursor(self):
            return FakeCursor(self.should_fail)

        def commit(self):
            self.commits += 1

        def rollback(self):
            self.rollbacks += 1

    manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=POSTGRES_URL)

    # Successful DDL is committed.
    manager.postgres_conn = FakeConn(False)
    manager._create_chat_streams_table()
    assert manager.postgres_conn.commits == 1

    # Failing DDL is rolled back; the call itself must not raise.
    manager.postgres_conn = FakeConn(True)
    manager._create_chat_streams_table()
    assert manager.postgres_conn.rollbacks == 1
|
||||
|
||||
|
||||
def test_close_closes_resources_and_handles_errors():
    """``close()`` closes both backends and does not let a close error escape."""
    flags = {"mongo": 0, "pg": 0}

    class FakeMongoClient:
        def close(self):
            flags["mongo"] += 1

    class FakePgConn:
        def __init__(self, raise_on_close=False):
            self.raise_on_close = raise_on_close

        def close(self):
            if self.raise_on_close:
                raise RuntimeError("close fail")
            flags["pg"] += 1

    manager = checkpoint.ChatStreamManager(checkpoint_saver=False)

    # Happy path: each resource is closed exactly once.
    manager.mongo_client = FakeMongoClient()
    manager.postgres_conn = FakePgConn()
    manager.close()
    assert flags == {"mongo": 1, "pg": 1}

    # Error path: no Mongo client, and a PostgreSQL connection whose close()
    # raises. The test passes as long as nothing propagates out of close().
    manager.mongo_client = None
    manager.postgres_conn = FakePgConn(True)
    manager.close()
|
||||
|
||||
|
||||
def test_context_manager_calls_close(monkeypatch):
    """Leaving a ``with manager:`` block invokes ``close()`` exactly once."""
    close_calls = []

    def _skip_mongo_init(self):
        # Stand-in for _init_mongodb: leave the manager with no Mongo handles.
        self.mongo_client = None
        self.mongo_db = None

    monkeypatch.setattr(
        checkpoint.ChatStreamManager, "_init_mongodb", _skip_mongo_init, raising=True
    )

    manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)

    # Shadow close() on the instance so each invocation is recorded.
    manager.close = lambda: close_calls.append(1)

    with manager:
        pass

    assert len(close_calls) == 1
|
||||
|
||||
|
||||
def test_init_mongodb_success_and_failure(monkeypatch):
    """MongoDB init yields a database with mongomock and none when the client raises."""

    # Success: the patched client constructor returns an in-memory client.
    with patch(
        "src.graph.checkpoint.MongoClient", return_value=mongomock.MongoClient()
    ):
        manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
        assert manager.mongo_db is not None

    # Failure: the patched client constructor raises.
    with patch(
        "src.graph.checkpoint.MongoClient",
        side_effect=RuntimeError("Connection failed"),
    ):
        manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
        # mongo_db must be absent or None after a failed init.
        assert getattr(manager, "mongo_db", None) is None
|
||||
|
||||
|
||||
def test_init_postgresql_calls_connect_and_create_table(monkeypatch):
    """A PostgreSQL URI makes the constructor call ``_init_postgresql`` once.

    NOTE(review): ``_init_postgresql`` itself is replaced by the stub below, so
    both counters are incremented by the stub rather than by real connect /
    create-table code — confirm this is the intended coverage.
    """
    counters = {"connected": 0, "created": 0}

    class FakeConn:
        def __init__(self):
            pass

        def close(self):
            pass

    def fake_connect(self):
        for key in ("connected", "created"):
            counters[key] += 1
        return FakeConn()

    monkeypatch.setattr(
        checkpoint.ChatStreamManager, "_init_postgresql", fake_connect, raising=True
    )

    manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=POSTGRES_URL)

    # The stub never assigns postgres_conn, so it is still None here.
    assert manager.postgres_conn is None
    assert counters == {"connected": 1, "created": 1}
|
||||
|
||||
|
||||
def test_chat_stream_message_wrapper(monkeypatch):
    """``chat_stream_message`` delegates when the saver is on, else returns False."""
    called = {"args": None}

    def saver_enabled(k, d=False):
        return True

    def saver_disabled(k, d=False):
        return False

    def fake_process(tid, msg, fr):
        called["args"] = (tid, msg, fr)
        return True

    # Saver enabled: the call is forwarded to the default manager.
    monkeypatch.setattr(checkpoint, "get_bool_env", saver_enabled, raising=True)
    monkeypatch.setattr(
        checkpoint._default_manager,
        "process_stream_message",
        fake_process,
        raising=True,
    )
    assert checkpoint.chat_stream_message("tid", "msg", "stop") is True
    assert called["args"] == ("tid", "msg", "stop")

    # Saver disabled: False is returned and the manager is not touched.
    monkeypatch.setattr(checkpoint, "get_bool_env", saver_disabled, raising=True)
    called["args"] = None
    assert checkpoint.chat_stream_message("tid", "msg", "stop") is False
    assert called["args"] is None
|
||||
@@ -1,46 +0,0 @@
|
||||
|
||||
from unittest.mock import patch
|
||||
import mongomock
|
||||
import src.graph.checkpoint as checkpoint
|
||||
|
||||
# URI passed to ChatStreamManager in the test below; MongoClient is patched with mongomock there.
MONGO_URL = "mongodb://admin:admin@localhost:27017/checkpointing_db?authSource=admin"
|
||||
|
||||
def test_memory_leak_check_memory_cleared_after_persistence():
    """
    The in-memory store for a thread is emptied once its messages persist.

    Guards against memory growth in long-running processes.
    """
    with patch(
        "src.graph.checkpoint.MongoClient", return_value=mongomock.MongoClient()
    ):
        manager = checkpoint.ChatStreamManager(
            checkpoint_saver=True,
            db_uri=MONGO_URL,
        )

        thread_id = "leak_test_thread"
        namespace = ("messages", thread_id)

        # 1. Stream two partial chunks.
        for chunk in ("Hello", " World"):
            manager.process_stream_message(thread_id, chunk, "partial")

        # While streaming, the chunks must be buffered in the store.
        buffered = manager.store.search(namespace)
        assert len(buffered) > 0, "Store should contain items during streaming"

        # 2. A "stop" finish reason ends the conversation and triggers persistence.
        manager.process_stream_message(thread_id, "!", "stop")

        # 3. The store must hold nothing for this thread afterwards.
        remaining = manager.store.search(namespace)
        assert len(remaining) == 0, "Memory should be cleared after successful persistence"

        # The full conversation must have been written to the fake database.
        stored = manager.mongo_db.chat_streams.find_one({"thread_id": thread_id})
        assert stored is not None
        assert stored["messages"] == ["Hello", " World", "!"]
|
||||
@@ -1,136 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
from src.citations.collector import CitationCollector
|
||||
from src.citations.extractor import (
|
||||
_extract_domain,
|
||||
citations_to_markdown_references,
|
||||
extract_citations_from_messages,
|
||||
merge_citations,
|
||||
)
|
||||
from src.citations.formatter import CitationFormatter
|
||||
from src.citations.models import Citation, CitationMetadata
|
||||
|
||||
|
||||
class TestCitationMetadata:
    """Construction, id generation, and dict export of ``CitationMetadata``."""

    def test_initialization(self):
        """Constructor arguments are stored and ``domain`` is derived from the URL."""
        fields = {
            "url": "https://example.com/page",
            "title": "Example Page",
            "description": "An example description",
        }
        meta = CitationMetadata(**fields)

        assert meta.url == "https://example.com/page"
        assert meta.title == "Example Page"
        assert meta.description == "An example description"
        # domain was not passed in; it is filled from the URL.
        assert meta.domain == "example.com"

    def test_id_generation(self):
        """``id`` is a 12-character string."""
        citation_id = CitationMetadata(url="https://example.com", title="Test").id
        assert len(citation_id) == 12
        assert isinstance(citation_id, str)

    def test_to_dict(self):
        """``to_dict`` exposes url, title, relevance_score and an id entry."""
        exported = CitationMetadata(
            url="https://example.com", title="Test", relevance_score=0.8
        ).to_dict()

        assert exported["url"] == "https://example.com"
        assert exported["title"] == "Test"
        assert exported["relevance_score"] == 0.8
        assert "id" in exported
|
||||
|
||||
|
||||
class TestCitation:
    """``Citation`` wraps metadata and renders reference strings."""

    def test_citation_wrapper(self):
        """Number, URL, and title are exposed; both reference formats render."""
        citation = Citation(
            number=1,
            metadata=CitationMetadata(url="https://example.com", title="Test"),
        )

        assert citation.number == 1
        assert citation.url == "https://example.com"
        assert citation.title == "Test"

        markdown_ref = citation.to_markdown_reference()
        assert markdown_ref == "[Test](https://example.com)"

        numbered_ref = citation.to_numbered_reference()
        assert numbered_ref == "[1] Test - https://example.com"
|
||||
|
||||
|
||||
class TestExtractor:
    """Tests for the citation helpers imported from ``src.citations.extractor``."""

    def test_extract_from_tool_message_web_search(self):
        """A ``web_search`` tool message carrying JSON results yields one citation."""
        import json

        search_result = {
            "results": [
                {
                    "url": "https://example.com/1",
                    "title": "Result 1",
                    "content": "Content 1",
                    "score": 0.9,
                }
            ]
        }

        # Build the message with real JSON up front; the earlier quote-replacement
        # placeholder was overwritten before use and is not valid JSON in general.
        msg = ToolMessage(
            content=json.dumps(search_result),
            tool_call_id="call_1",
            name="web_search",
        )

        citations = extract_citations_from_messages([msg])
        assert len(citations) == 1
        assert citations[0]["url"] == "https://example.com/1"
        assert citations[0]["title"] == "Result 1"

    def test_extract_domain(self):
        """The host part of a URL is returned, including any ``www.`` prefix."""
        assert _extract_domain("https://www.example.com/path") == "www.example.com"
        assert _extract_domain("http://example.org") == "example.org"

    def test_merge_citations(self):
        """Merging de-duplicates by URL and keeps the higher relevance score."""
        existing = [{"url": "https://a.com", "title": "A", "relevance_score": 0.5}]
        new = [
            {"url": "https://b.com", "title": "B", "relevance_score": 0.6},
            # Same URL as the existing entry, but with a better score.
            {"url": "https://a.com", "title": "A New", "relevance_score": 0.7},
        ]

        merged = merge_citations(existing, new)
        assert len(merged) == 2

        # The duplicate URL must now carry the higher score.
        a_citation = next(c for c in merged if c["url"] == "https://a.com")
        assert a_citation["relevance_score"] == 0.7

        # The previously unseen URL must be present unchanged.
        b_citation = next(c for c in merged if c["url"] == "https://b.com")
        assert b_citation["title"] == "B"

    def test_citations_to_markdown(self):
        """The markdown output has a "Key Citations" heading and a bullet link."""
        citations = [{"url": "https://a.com", "title": "A", "description": "Desc A"}]
        md = citations_to_markdown_references(citations)
        assert "## Key Citations" in md
        assert "- [A](https://a.com)" in md
|
||||
|
||||
|
||||
class TestCollector:
    """Basic ``CitationCollector`` behaviour."""

    def test_add_citations(self):
        """One search result produces one returned citation and a count of one."""
        collector = CitationCollector()
        hit = {"url": "https://example.com", "title": "Example", "content": "Test"}

        added = collector.add_from_search_results([hit], query="test")

        assert len(added) == 1
        assert added[0].url == "https://example.com"
        assert collector.count == 1
|
||||
|
||||
|
||||
class TestFormatter:
    """Inline marker rendering of ``CitationFormatter``."""

    def test_format_inline(self):
        """The superscript style maps each digit to its superscript glyph."""
        superscript = CitationFormatter(style="superscript")
        for number, expected in ((1, "¹"), (12, "¹²")):
            assert superscript.format_inline_marker(number) == expected
|
||||
@@ -1,289 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for CitationCollector optimization with reverse index cache.
|
||||
|
||||
Tests the O(1) URL lookup performance optimization via _url_to_index cache.
|
||||
"""
|
||||
|
||||
from src.citations.collector import CitationCollector
|
||||
|
||||
|
||||
class TestCitationCollectorOptimization:
    """Checks that ``CitationCollector`` keeps its ``_url_to_index`` cache in sync."""

    @staticmethod
    def _hit(url, title, content, score=0.9):
        """Build one search-result dict in the shape used throughout these tests."""
        return {"url": url, "title": title, "content": content, "score": score}

    def test_url_to_index_cache_initialization(self):
        """A fresh collector has an empty dict as ``_url_to_index``."""
        collector = CitationCollector()
        assert hasattr(collector, "_url_to_index")
        assert isinstance(collector._url_to_index, dict)
        assert len(collector._url_to_index) == 0

    def test_add_single_citation_updates_cache(self):
        """The first added URL is cached at index 0."""
        collector = CitationCollector()
        collector.add_from_search_results(
            [self._hit("https://example.com", "Example", "Content")]
        )

        assert "https://example.com" in collector._url_to_index
        assert collector._url_to_index["https://example.com"] == 0

    def test_add_multiple_citations_updates_cache_correctly(self):
        """Several URLs added together are indexed in insertion order."""
        collector = CitationCollector()
        collector.add_from_search_results(
            [
                self._hit(f"https://example.com/{i}", f"Page {i}", f"Content {i}")
                for i in range(5)
            ]
        )

        assert len(collector._url_to_index) == 5
        for position in range(5):
            assert collector._url_to_index[f"https://example.com/{position}"] == position

    def test_get_number_uses_cache_for_o1_lookup(self):
        """``get_number`` returns the 1-based position, or None for unknown URLs."""
        collector = CitationCollector()
        urls = [f"https://example.com/{i}" for i in range(100)]
        collector.add_from_search_results(
            [self._hit(url, f"Title {i}", f"Content {i}") for i, url in enumerate(urls)]
        )

        # First, middle, and last entries.
        for suffix, expected in ((0, 1), (50, 51), (99, 100)):
            assert collector.get_number(f"https://example.com/{suffix}") == expected

        assert collector.get_number("https://nonexistent.com") is None

    def test_add_from_crawl_result_updates_cache(self):
        """A crawl result is cached just like a search result."""
        collector = CitationCollector()
        crawled_url = "https://crawled.com/page"

        collector.add_from_crawl_result(
            url=crawled_url,
            title="Crawled Page",
            content="Crawled content",
        )

        assert crawled_url in collector._url_to_index
        assert collector._url_to_index[crawled_url] == 0

    def test_duplicate_url_does_not_change_cache(self):
        """Re-adding a URL keeps its index while the stored score is refreshed."""
        collector = CitationCollector()
        url = "https://example.com"

        collector.add_from_search_results([self._hit(url, "Title 1", "Content 1", 0.8)])
        assert collector._url_to_index[url] == 0

        # Same URL again, this time with a better score.
        collector.add_from_search_results(
            [self._hit(url, "Title 1 Updated", "Content 1 Updated", 0.95)]
        )

        assert collector._url_to_index[url] == 0
        assert collector._citations[url].relevance_score == 0.95

    def test_merge_with_updates_cache_correctly(self):
        """``merge_with`` appends the other collector's new URLs to the cache."""
        target = CitationCollector()
        other = CitationCollector()

        target.add_from_search_results([self._hit("https://a.com", "A", "Content A")])
        other.add_from_search_results([self._hit("https://b.com", "B", "Content B")])

        target.merge_with(other)

        assert "https://a.com" in target._url_to_index
        assert "https://b.com" in target._url_to_index
        assert target._url_to_index["https://a.com"] == 0
        assert target._url_to_index["https://b.com"] == 1

    def test_from_dict_rebuilds_cache(self):
        """A ``to_dict`` / ``from_dict`` round trip reproduces the cache."""
        original = CitationCollector()
        original.add_from_search_results(
            [
                self._hit(f"https://example.com/{i}", f"Page {i}", f"Content {i}")
                for i in range(3)
            ]
        )

        restored = CitationCollector.from_dict(original.to_dict())

        assert len(restored._url_to_index) == 3
        for position in range(3):
            url = f"https://example.com/{position}"
            assert url in restored._url_to_index
            assert restored._url_to_index[url] == position

    def test_clear_resets_cache(self):
        """``clear()`` empties the cache, the citations, and the order list."""
        collector = CitationCollector()
        collector.add_from_search_results(
            [self._hit("https://example.com", "Example", "Content")]
        )
        assert len(collector._url_to_index) > 0

        collector.clear()

        assert len(collector._url_to_index) == 0
        assert len(collector._citations) == 0
        assert len(collector._citation_order) == 0

    def test_cache_consistency_with_order_list(self):
        """Each cached index equals the URL's position in ``_citation_order``."""
        collector = CitationCollector()
        urls = [f"https://example.com/{i}" for i in range(10)]
        collector.add_from_search_results(
            [self._hit(url, f"Title {i}", f"Content {i}") for i, url in enumerate(urls)]
        )

        for position, url in enumerate(collector._citation_order):
            assert collector._url_to_index[url] == position

    def test_mark_used_with_cache(self):
        """``mark_used`` returns the citation number and records the URL as used."""
        collector = CitationCollector()
        collector.add_from_search_results(
            [
                self._hit("https://example.com/1", "Page 1", "Content 1"),
                self._hit("https://example.com/2", "Page 2", "Content 2"),
            ]
        )

        assert collector.mark_used("https://example.com/2") == 2
        assert "https://example.com/2" in collector._used_citations

    def test_large_collection_cache_performance(self):
        """Lookups stay correct with a thousand citations in the collector."""
        collector = CitationCollector()
        num_citations = 1000
        collector.add_from_search_results(
            [
                self._hit(f"https://example.com/{i}", f"Title {i}", f"Content {i}")
                for i in range(num_citations)
            ]
        )

        assert len(collector._url_to_index) == num_citations

        # Spot-check the start, two interior points, and the end.
        for idx in (0, 100, 500, 999):
            assert collector.get_number(f"https://example.com/{idx}") == idx + 1
|
||||
@@ -1,251 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for extractor optimizations.
|
||||
|
||||
Tests the enhanced domain extraction and title extraction functions.
|
||||
"""
|
||||
|
||||
from src.citations.extractor import (
|
||||
_extract_domain,
|
||||
extract_title_from_content,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractDomainOptimization:
    """Behaviour of ``_extract_domain`` across URL shapes."""

    def test_extract_domain_standard_urls(self):
        """Plain http/https URLs yield their host."""
        cases = (
            ("https://www.example.com/path", "www.example.com"),
            ("http://example.org", "example.org"),
            ("https://github.com/user/repo", "github.com"),
        )
        for url, expected in cases:
            assert _extract_domain(url) == expected

    def test_extract_domain_with_port(self):
        """An explicit port is kept as part of the result."""
        cases = (
            ("http://localhost:8080/api", "localhost:8080"),
            ("https://example.com:3000/page", "example.com:3000"),
        )
        for url, expected in cases:
            assert _extract_domain(url) == expected

    def test_extract_domain_with_subdomain(self):
        """Subdomains are preserved."""
        cases = (
            ("https://api.github.com/repos", "api.github.com"),
            ("https://docs.python.org/en/", "docs.python.org"),
        )
        for url, expected in cases:
            assert _extract_domain(url) == expected

    def test_extract_domain_invalid_url(self):
        """A non-URL input still returns a string instead of raising."""
        assert isinstance(_extract_domain("not a url"), str)

    def test_extract_domain_empty_url(self):
        """An empty URL maps to an empty domain."""
        assert _extract_domain("") == ""

    def test_extract_domain_without_scheme(self):
        """A scheme-less URL returns a string; the exact value is not pinned."""
        assert isinstance(_extract_domain("example.com/path"), str)

    def test_extract_domain_complex_urls(self):
        """Credentials and ports stay in the result; query and fragment do not."""
        cases = (
            # Credentials are expected to remain in the extracted value.
            ("https://user:pass@example.com/path", "user:pass@example.com"),
            ("https://example.com:443/path?query=value#hash", "example.com:443"),
        )
        for url, expected in cases:
            assert _extract_domain(url) == expected

    def test_extract_domain_ipv4(self):
        """An IPv4 host with a port returns a string."""
        assert isinstance(_extract_domain("http://192.168.1.1:8080/"), str)

    def test_extract_domain_query_params(self):
        """A query string does not change the extracted domain."""
        with_query = _extract_domain("https://example.com/page?query=value")
        without_query = _extract_domain("https://example.com/page")
        assert with_query == without_query

    def test_extract_domain_url_fragments(self):
        """A fragment does not change the extracted domain."""
        with_fragment = _extract_domain("https://example.com/page#section")
        without_fragment = _extract_domain("https://example.com/page")
        assert with_fragment == without_fragment
|
||||
|
||||
|
||||
class TestExtractTitleFromContent:
|
||||
"""Test intelligent title extraction with 5-tier priority system."""
|
||||
|
||||
def test_extract_title_html_title_tag(self):
|
||||
"""Test priority 1: HTML <title> tag extraction."""
|
||||
content = "<html><head><title>HTML Title</title></head><body>Content</body></html>"
|
||||
assert extract_title_from_content(content) == "HTML Title"
|
||||
|
||||
def test_extract_title_html_title_case_insensitive(self):
|
||||
"""Test that HTML title extraction is case-insensitive."""
|
||||
content = "<html><head><TITLE>HTML Title</TITLE></head><body></body></html>"
|
||||
assert extract_title_from_content(content) == "HTML Title"
|
||||
|
||||
def test_extract_title_markdown_h1(self):
|
||||
"""Test priority 2: Markdown h1 extraction."""
|
||||
content = "# Main Title\n\nSome content here"
|
||||
assert extract_title_from_content(content) == "Main Title"
|
||||
|
||||
def test_extract_title_markdown_h1_with_spaces(self):
|
||||
"""Test markdown h1 with extra spaces."""
|
||||
content = "# Title with Spaces \n\nContent"
|
||||
assert extract_title_from_content(content) == "Title with Spaces"
|
||||
|
||||
def test_extract_title_markdown_h2_fallback(self):
|
||||
"""Test priority 3: Markdown h2 as fallback when no h1."""
|
||||
content = "## Second Level Title\n\nSome content"
|
||||
assert extract_title_from_content(content) == "Second Level Title"
|
||||
|
||||
def test_extract_title_markdown_h6_fallback(self):
|
||||
"""Test markdown h6 as fallback."""
|
||||
content = "###### Small Heading\n\nContent"
|
||||
assert extract_title_from_content(content) == "Small Heading"
|
||||
|
||||
def test_extract_title_prefers_h1_over_h2(self):
|
||||
"""Test that h1 is preferred over h2."""
|
||||
content = "# H1 Title\n## H2 Title\n\nContent"
|
||||
assert extract_title_from_content(content) == "H1 Title"
|
||||
|
||||
def test_extract_title_json_field(self):
|
||||
"""Test priority 4: JSON title field extraction."""
|
||||
content = '{"title": "JSON Title", "content": "Some data"}'
|
||||
assert extract_title_from_content(content) == "JSON Title"
|
||||
|
||||
def test_extract_title_yaml_field(self):
|
||||
"""Test YAML title field extraction."""
|
||||
content = 'title: "YAML Title"\ncontent: "Some data"'
|
||||
assert extract_title_from_content(content) == "YAML Title"
|
||||
|
||||
def test_extract_title_first_substantial_line(self):
|
||||
"""Test priority 5: First substantial non-empty line."""
|
||||
content = "\n\n\nThis is the first substantial line\n\nMore content"
|
||||
assert extract_title_from_content(content) == "This is the first substantial line"
|
||||
|
||||
def test_extract_title_skips_short_lines(self):
|
||||
"""Test that short lines are skipped."""
|
||||
content = "abc\nThis is a longer first substantial line\nContent"
|
||||
assert extract_title_from_content(content) == "This is a longer first substantial line"
|
||||
|
||||
def test_extract_title_skips_code_blocks(self):
|
||||
"""Test that code blocks are skipped."""
|
||||
content = "```\ncode here\n```\nThis is the title\n\nContent"
|
||||
result = extract_title_from_content(content)
|
||||
# Should skip the code block and find the actual title
|
||||
assert "title" in result.lower() or "code" not in result
|
||||
|
||||
def test_extract_title_skips_list_items(self):
|
||||
"""Test that list items are skipped."""
|
||||
content = "- Item 1\n- Item 2\nThis is the actual first substantial line\n\nContent"
|
||||
result = extract_title_from_content(content)
|
||||
assert "actual" in result or "Item" not in result
|
||||
|
||||
def test_extract_title_skips_separators(self):
|
||||
"""Test that separator lines are skipped."""
|
||||
content = "---\n\n***\n\nThis is the real title\n\nContent"
|
||||
result = extract_title_from_content(content)
|
||||
assert "---" not in result and "***" not in result
|
||||
|
||||
def test_extract_title_max_length(self):
|
||||
"""Test that title respects max_length parameter."""
|
||||
long_title = "A" * 300
|
||||
content = f"# {long_title}"
|
||||
result = extract_title_from_content(content, max_length=100)
|
||||
assert len(result) <= 100
|
||||
assert result == long_title[:100]
|
||||
|
||||
def test_extract_title_empty_content(self):
|
||||
"""Test handling of empty content."""
|
||||
assert extract_title_from_content("") == "Untitled"
|
||||
assert extract_title_from_content(None) == "Untitled"
|
||||
|
||||
def test_extract_title_no_title_found(self):
|
||||
"""Test fallback to 'Untitled' when no title can be extracted."""
|
||||
content = "a\nb\nc\n" # Only short lines
|
||||
result = extract_title_from_content(content)
|
||||
# May return Untitled or one of the short lines
|
||||
assert isinstance(result, str)
|
||||
|
||||
def test_extract_title_whitespace_handling(self):
|
||||
"""Test that whitespace is properly handled."""
|
||||
content = "# Title with extra spaces \n\nContent"
|
||||
result = extract_title_from_content(content)
|
||||
# Should normalize spaces
|
||||
assert "Title with extra spaces" in result or len(result) > 5
|
||||
|
||||
def test_extract_title_multiline_html(self):
|
||||
"""Test HTML title extraction across multiple lines."""
|
||||
content = """
|
||||
<html>
|
||||
<head>
|
||||
<title>
|
||||
Multiline Title
|
||||
</title>
|
||||
</head>
|
||||
<body>Content</body>
|
||||
</html>
|
||||
"""
|
||||
result = extract_title_from_content(content)
|
||||
# Should handle multiline titles
|
||||
assert "Title" in result
|
||||
|
||||
def test_extract_title_mixed_formats(self):
|
||||
"""Test content with mixed formats (h1 should win)."""
|
||||
content = """
|
||||
<title>HTML Title</title>
|
||||
# Markdown H1
|
||||
## Markdown H2
|
||||
|
||||
Some paragraph content
|
||||
"""
|
||||
# HTML title comes first in priority
|
||||
assert extract_title_from_content(content) == "HTML Title"
|
||||
|
||||
def test_extract_title_real_world_example(self):
    """A full HTML document yields the text of its ``<title>`` element."""
    # The page also carries an og:title meta tag and an <h1>; the expected
    # answer is the <title> text.
    page = """
<!DOCTYPE html>
<html>
<head>
<title>GitHub: Where the world builds software</title>
<meta property="og:title" content="GitHub">
</head>
<body>
<h1>Let's build from here</h1>
<p>The complete developer platform...</p>
</body>
</html>
"""
    extracted = extract_title_from_content(page)
    assert extracted == "GitHub: Where the world builds software"
|
||||
|
||||
def test_extract_title_json_with_nested_title(self):
    """JSON with both a nested and a top-level "title" yields some title."""
    payload = '{"meta": {"title": "Should not match"}, "title": "JSON Title"}'
    title = extract_title_from_content(payload)
    # Whichever "title" field is matched first (possibly the nested one) is
    # accepted; the check is only that a non-placeholder title came back.
    assert title and title != "Untitled"
|
||||
|
||||
def test_extract_title_preserves_special_characters(self):
    """A heading containing ``@#$%`` keeps a special character or the word "Title"."""
    heading = "# Title with Special Characters: @#$%"
    title = extract_title_from_content(heading)
    # NOTE(review): the final alternative lets this pass whenever "Title"
    # survives, even if every special character was dropped.
    assert "@" in title or "$" in title or "%" in title or "Title" in title
|
||||
@@ -1,423 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for citation formatter enhancements.
|
||||
|
||||
Tests the multi-format citation parsing and extraction capabilities.
|
||||
"""
|
||||
|
||||
from src.citations.formatter import (
|
||||
parse_citations_from_report,
|
||||
_extract_markdown_links,
|
||||
_extract_numbered_citations,
|
||||
_extract_footnote_citations,
|
||||
_extract_html_links,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractMarkdownLinks:
    """Exercise ``_extract_markdown_links`` on ``[title](url)`` links."""

    def test_extract_single_markdown_link(self):
        """One link yields one citation with title, URL and format tag."""
        found = _extract_markdown_links("[Example Article](https://example.com)")
        assert len(found) == 1
        only = found[0]
        assert only["title"] == "Example Article"
        assert only["url"] == "https://example.com"
        assert only["format"] == "markdown"

    def test_extract_multiple_markdown_links(self):
        """Two links on one line are returned in document order."""
        sample = "[Link 1](https://example.com/1) and [Link 2](https://example.com/2)"
        found = _extract_markdown_links(sample)
        assert len(found) == 2
        first, second = found[0], found[1]
        assert first["title"] == "Link 1"
        assert second["title"] == "Link 2"

    def test_extract_markdown_link_with_spaces(self):
        """A multi-word link title is captured in full."""
        found = _extract_markdown_links(
            "[Article Title With Spaces](https://example.com)"
        )
        assert len(found) == 1
        assert found[0]["title"] == "Article Title With Spaces"

    def test_extract_markdown_link_ignore_non_http(self):
        """A relative link is skipped; the HTTPS link next to it is kept."""
        sample = "[Relative Link](./relative/path) [HTTP Link](https://example.com)"
        found = _extract_markdown_links(sample)
        assert len(found) == 1
        assert found[0]["url"] == "https://example.com"

    def test_extract_markdown_link_with_query_params(self):
        """The query string stays part of the extracted URL."""
        found = _extract_markdown_links(
            "[Search Result](https://example.com/search?q=test&page=1)"
        )
        assert len(found) == 1
        assert "q=test" in found[0]["url"]

    def test_extract_markdown_link_empty_text(self):
        """Plain text without links yields no citations."""
        found = _extract_markdown_links("Just plain text with no links")
        assert len(found) == 0

    def test_extract_markdown_link_strip_whitespace(self):
        """A plain link comes back with exact title and URL values."""
        # Markdown links with spaces in the URL are not valid and would not be
        # extracted, so the sample link here contains no padding.
        found = _extract_markdown_links("[Title](https://example.com)")
        assert len(found) >= 1
        head = found[0]
        assert head["title"] == "Title"
        assert head["url"] == "https://example.com"
|
||||
|
||||
|
||||
class TestExtractNumberedCitations:
    """Exercise ``_extract_numbered_citations`` on ``[1] Title - URL`` entries."""

    def test_extract_single_numbered_citation(self):
        """One entry yields one citation with title, URL and format tag."""
        found = _extract_numbered_citations("[1] Example Article - https://example.com")
        assert len(found) == 1
        only = found[0]
        assert only["title"] == "Example Article"
        assert only["url"] == "https://example.com"
        assert only["format"] == "numbered"

    def test_extract_multiple_numbered_citations(self):
        """Entries on separate lines are returned in document order."""
        sample = "[1] First - https://example.com/1\n[2] Second - https://example.com/2"
        found = _extract_numbered_citations(sample)
        assert len(found) == 2
        first, second = found[0], found[1]
        assert first["title"] == "First"
        assert second["title"] == "Second"

    def test_extract_numbered_citation_with_long_title(self):
        """A multi-word title is captured from a numbered entry."""
        found = _extract_numbered_citations(
            "[5] A Comprehensive Guide to Python Programming - https://example.com"
        )
        assert len(found) == 1
        assert "Comprehensive Guide" in found[0]["title"]

    def test_extract_numbered_citation_requires_valid_format(self):
        """An entry without the closing bracket is not extracted."""
        malformed = "[1 Title - https://example.com"  # Missing closing bracket
        found = _extract_numbered_citations(malformed)
        assert len(found) == 0

    def test_extract_numbered_citation_empty_text(self):
        """Plain text without numbered entries yields no citations."""
        found = _extract_numbered_citations("Just plain text")
        assert len(found) == 0

    def test_extract_numbered_citation_various_numbers(self):
        """Two- and three-digit citation numbers are both accepted."""
        sample = "[10] Title Ten - https://example.com/10\n[999] Title 999 - https://example.com/999"
        found = _extract_numbered_citations(sample)
        assert len(found) == 2

    def test_extract_numbered_citation_ignore_non_http(self):
        """A ``file://`` entry must not add to the extracted citations."""
        sample = "[1] Invalid - file://path [2] Valid - https://example.com"
        found = _extract_numbered_citations(sample)
        # At most the valid entry may be returned. NOTE(review): this upper
        # bound is also satisfied when nothing is extracted.
        assert len(found) <= 1
|
||||
|
||||
|
||||
class TestExtractFootnoteCitations:
    """Exercise ``_extract_footnote_citations`` on ``[^1]: Title - URL`` entries."""

    def test_extract_single_footnote_citation(self):
        """One footnote yields one citation with title, URL and format tag."""
        found = _extract_footnote_citations(
            "[^1]: Example Article - https://example.com"
        )
        assert len(found) == 1
        only = found[0]
        assert only["title"] == "Example Article"
        assert only["url"] == "https://example.com"
        assert only["format"] == "footnote"

    def test_extract_multiple_footnote_citations(self):
        """Footnotes on separate lines are each extracted."""
        sample = "[^1]: First - https://example.com/1\n[^2]: Second - https://example.com/2"
        found = _extract_footnote_citations(sample)
        assert len(found) == 2

    def test_extract_footnote_with_complex_number(self):
        """A multi-digit footnote number is accepted."""
        found = _extract_footnote_citations("[^123]: Title - https://example.com")
        assert len(found) == 1
        assert found[0]["title"] == "Title"

    def test_extract_footnote_citation_with_spaces(self):
        """Whitespace around the title is not part of the extracted title."""
        sample = "[^1]: Title with spaces - https://example.com "
        found = _extract_footnote_citations(sample)
        assert len(found) == 1
        # The title is expected back stripped.
        assert found[0]["title"] == "Title with spaces"

    def test_extract_footnote_citation_empty_text(self):
        """Text without footnotes yields no citations."""
        found = _extract_footnote_citations("No footnotes here")
        assert len(found) == 0

    def test_extract_footnote_requires_caret(self):
        """An entry without the ``^`` marker is not treated as a footnote."""
        no_caret = "[1]: Title - https://example.com"  # Missing ^
        found = _extract_footnote_citations(no_caret)
        assert len(found) == 0
|
||||
|
||||
|
||||
class TestExtractHtmlLinks:
    """Exercise ``_extract_html_links`` on ``<a href="url">title</a>`` anchors."""

    def test_extract_single_html_link(self):
        """One anchor yields one citation with title, URL and format tag."""
        found = _extract_html_links('<a href="https://example.com">Example Article</a>')
        assert len(found) == 1
        only = found[0]
        assert only["title"] == "Example Article"
        assert only["url"] == "https://example.com"
        assert only["format"] == "html"

    def test_extract_multiple_html_links(self):
        """Two anchors on one line are both extracted."""
        markup = '<a href="https://a.com">Link A</a> <a href="https://b.com">Link B</a>'
        found = _extract_html_links(markup)
        assert len(found) == 2

    def test_extract_html_link_single_quotes(self):
        """A single-quoted ``href`` value is accepted."""
        found = _extract_html_links("<a href='https://example.com'>Title</a>")
        assert len(found) == 1
        assert found[0]["url"] == "https://example.com"

    def test_extract_html_link_with_attributes(self):
        """Attributes before and after ``href`` do not prevent extraction."""
        markup = '<a class="link" href="https://example.com" target="_blank">Title</a>'
        found = _extract_html_links(markup)
        assert len(found) == 1
        assert found[0]["url"] == "https://example.com"

    def test_extract_html_link_ignore_non_http(self):
        """A ``mailto:`` anchor is skipped; the HTTPS anchor is kept."""
        markup = '<a href="mailto:test@example.com">Email</a> <a href="https://example.com">Web</a>'
        found = _extract_html_links(markup)
        assert len(found) == 1
        assert found[0]["url"] == "https://example.com"

    def test_extract_html_link_case_insensitive(self):
        """Upper-case tag and attribute names are matched too."""
        found = _extract_html_links('<A HREF="https://example.com">Title</A>')
        assert len(found) == 1

    def test_extract_html_link_empty_text(self):
        """Text without anchors yields no citations."""
        found = _extract_html_links("No links here")
        assert len(found) == 0

    def test_extract_html_link_strip_whitespace(self):
        """Whitespace around the anchor text is not part of the title."""
        markup = '<a href="https://example.com"> Title with spaces </a>'
        found = _extract_html_links(markup)
        assert found[0]["title"] == "Title with spaces"
|
||||
|
||||
|
||||
class TestParseCitationsFromReport:
    """Exercise ``parse_citations_from_report`` on whole report documents."""

    def test_parse_markdown_links_from_report(self):
        """Markdown links under "Key Citations" are counted and returned."""
        document = """
## Key Citations

[GitHub](https://github.com)
[Python Docs](https://python.org)
"""
        parsed = parse_citations_from_report(document)
        assert parsed["count"] >= 2
        found_urls = [entry["url"] for entry in parsed["citations"]]
        assert "https://github.com" in found_urls

    def test_parse_numbered_citations_from_report(self):
        """Numbered entries under "References" are counted."""
        document = """
## References

[1] GitHub - https://github.com
[2] Python - https://python.org
"""
        parsed = parse_citations_from_report(document)
        assert parsed["count"] >= 2

    def test_parse_mixed_format_citations(self):
        """A section mixing all four citation formats is parsed."""
        document = """
## Key Citations

[GitHub](https://github.com)
[^1]: Python - https://python.org
[2] Wikipedia - https://wikipedia.org
<a href="https://stackoverflow.com">Stack Overflow</a>
"""
        parsed = parse_citations_from_report(document)
        # Four citations are present in the section; the check requires at
        # least three of them.
        assert parsed["count"] >= 3

    def test_parse_citations_deduplication(self):
        """Three links to the same URL collapse into one citation."""
        document = """
## Key Citations

[GitHub 1](https://github.com)
[GitHub 2](https://github.com)
[GitHub](https://github.com)
"""
        parsed = parse_citations_from_report(document)
        # Exactly one unique citation is expected.
        assert parsed["count"] == 1
        assert parsed["citations"][0]["url"] == "https://github.com"

    def test_parse_citations_various_section_patterns(self):
        """"References", "Sources" and "Bibliography" headers are recognized."""
        documents = [
            f"\n## {heading}\n[GitHub](https://github.com)\n"
            for heading in ("References", "Sources", "Bibliography")
        ]
        for document in documents:
            assert parse_citations_from_report(document)["count"] >= 1

    def test_parse_citations_custom_patterns(self):
        """A caller-supplied section pattern selects a custom header."""
        document = """
## My Custom Sources
[GitHub](https://github.com)
"""
        custom_patterns = [r"##\s*My Custom Sources"]
        parsed = parse_citations_from_report(
            document,
            section_patterns=custom_patterns,
        )
        assert parsed["count"] >= 1

    def test_parse_citations_empty_report(self):
        """An empty report yields a zero count and an empty list."""
        parsed = parse_citations_from_report("")
        assert parsed["count"] == 0
        assert parsed["citations"] == []

    def test_parse_citations_no_section(self):
        """A report without a citation section yields a zero count."""
        parsed = parse_citations_from_report(
            "This is a report with no citations section"
        )
        assert parsed["count"] == 0

    def test_parse_citations_complex_report(self):
        """A multi-section report with mixed citation formats is parsed."""
        document = """
# Research Report

## Introduction

This report summarizes findings from multiple sources.

## Key Findings

Some important discoveries were made based on research [GitHub](https://github.com).

## Key Citations

1. Primary sources:
[GitHub](https://github.com) - A collaborative platform
[^1]: Python - https://python.org

2. Secondary sources:
[2] Wikipedia - https://wikipedia.org

3. Web resources:
<a href="https://stackoverflow.com">Stack Overflow</a>

## Methodology

[Additional](https://example.com) details about methodology.

---

[^1]: The Python programming language official site
"""

        parsed = parse_citations_from_report(document)
        # Several citations are expected from the Key Citations section.
        assert parsed["count"] >= 3
        found_urls = [entry["url"] for entry in parsed["citations"]]
        # At least one of the key URLs must be among the results.
        assert any("github.com" in url or "python.org" in url for url in found_urls)

    def test_parse_citations_stops_at_next_section(self):
        """Citations are found when another section follows the citation section."""
        document = """
## Key Citations

[Cite 1](https://example.com/1)
[Cite 2](https://example.com/2)

## Next Section

Some other content
"""
        parsed = parse_citations_from_report(document)
        # Note: the section regex stops at the next "##" header; only the
        # presence of the first citation is asserted here.
        assert parsed["count"] >= 1
        assert any("example.com/1" in entry["url"] for entry in parsed["citations"])

    def test_parse_citations_preserves_metadata(self):
        """Each parsed citation carries "title", "url" and "format" keys."""
        document = """
## Key Citations

[Python Documentation](https://python.org)
"""
        parsed = parse_citations_from_report(document)
        assert len(parsed["citations"]) >= 1
        head = parsed["citations"][0]
        assert "title" in head
        assert "url" in head
        assert "format" in head

    def test_parse_citations_whitespace_handling(self):
        """Blank lines around a link do not prevent extraction."""
        document = """
## Key Citations

[Link](https://example.com)

"""
        parsed = parse_citations_from_report(document)
        assert parsed["count"] >= 1

    def test_parse_citations_multiline_links(self):
        """A link embedded in the middle of a sentence is extracted."""
        document = """
## Key Citations

Some paragraph with a [link to example](https://example.com) in the middle.
"""
        parsed = parse_citations_from_report(document)
        assert parsed["count"] >= 1
|
||||
@@ -1,467 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for citation models.
|
||||
|
||||
Tests the Pydantic BaseModel implementation of CitationMetadata and Citation classes.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from src.citations.models import Citation, CitationMetadata
|
||||
|
||||
|
||||
class TestCitationMetadata:
    """Exercise the ``CitationMetadata`` Pydantic model."""

    def test_create_basic_metadata(self):
        """URL and title alone give a derived domain and empty defaults."""
        meta = CitationMetadata(
            title="Example Article",
            url="https://example.com/article",
        )
        assert meta.url == "https://example.com/article"
        assert meta.title == "Example Article"
        assert meta.domain == "example.com"  # Auto-extracted from URL
        assert meta.description is None
        assert meta.images == []
        assert meta.extra == {}

    def test_metadata_with_all_fields(self):
        """Explicitly supplied fields are stored on the model."""
        fields = {
            "url": "https://github.com/example/repo",
            "title": "Example Repository",
            "description": "A great repository",
            "content_snippet": "This is a snippet",
            "raw_content": "Full content here",
            "author": "John Doe",
            "published_date": "2025-01-24",
            "language": "en",
            "relevance_score": 0.95,
            "credibility_score": 0.88,
        }
        meta = CitationMetadata(**fields)
        assert meta.url == "https://github.com/example/repo"
        assert meta.domain == "github.com"
        assert meta.author == "John Doe"
        assert meta.relevance_score == 0.95
        assert meta.credibility_score == 0.88

    def test_metadata_domain_auto_extraction(self):
        """The domain is derived from the URL, keeping "www." and any port."""
        expected_by_url = {
            "https://www.example.com/path": "www.example.com",
            "http://github.com/user/repo": "github.com",
            "https://api.github.com:443/repos": "api.github.com:443",
        }

        for address, domain in expected_by_url.items():
            meta = CitationMetadata(url=address, title="Test")
            assert meta.domain == domain

    def test_metadata_id_generation(self):
        """Equal URLs give equal IDs; a different URL gives a different ID."""
        shared_url = "https://example.com/article"
        first = CitationMetadata(url=shared_url, title="Article")
        twin = CitationMetadata(url=shared_url, title="Article")
        # Same URL should produce same ID
        assert first.id == twin.id

        other = CitationMetadata(
            url="https://different.com/article",
            title="Article",
        )
        # Different URL should produce different ID
        assert first.id != other.id

    def test_metadata_id_length(self):
        """The ID is 12 characters long and alphanumeric (or hexadecimal)."""
        meta = CitationMetadata(url="https://example.com", title="Test")
        identifier = meta.id
        assert len(identifier) == 12
        assert identifier.isalnum() or all(c in "0123456789abcdef" for c in identifier)

    def test_metadata_from_dict(self):
        """``from_dict`` maps dictionary keys onto model fields."""
        payload = {
            "url": "https://example.com",
            "title": "Example",
            "description": "A description",
            "author": "John Doe",
        }
        meta = CitationMetadata.from_dict(payload)
        assert meta.url == "https://example.com"
        assert meta.title == "Example"
        assert meta.description == "A description"
        assert meta.author == "John Doe"

    def test_metadata_from_dict_removes_id(self):
        """An "id" key passed to ``from_dict`` does not become the model ID."""
        payload = {
            "url": "https://example.com",
            "title": "Example",
            "id": "some_old_id",  # Should be ignored
        }
        meta = CitationMetadata.from_dict(payload)
        # The freshly computed ID is expected, not the supplied one.
        assert meta.id != "some_old_id"

    def test_metadata_to_dict(self):
        """``to_dict`` includes the given fields plus the ID and domain."""
        meta = CitationMetadata(
            url="https://example.com",
            title="Example",
            author="John Doe",
        )
        dumped = meta.to_dict()

        assert dumped["url"] == "https://example.com"
        assert dumped["title"] == "Example"
        assert dumped["author"] == "John Doe"
        assert dumped["id"] == meta.id
        assert dumped["domain"] == "example.com"

    def test_metadata_from_search_result(self):
        """``from_search_result`` maps a search hit and query onto the model."""
        hit = {
            "url": "https://example.com/article",
            "title": "Article Title",
            "content": "Article content here",
            "score": 0.92,
            "type": "page",
        }
        meta = CitationMetadata.from_search_result(hit, query="test query")

        assert meta.url == "https://example.com/article"
        assert meta.title == "Article Title"
        # "content" is exposed as the description and "score" as relevance.
        assert meta.description == "Article content here"
        assert meta.relevance_score == 0.92
        # The query and the hit's "type" are kept under ``extra``.
        assert meta.extra["query"] == "test query"
        assert meta.extra["result_type"] == "page"

    def test_metadata_pydantic_validation(self):
        """Omitting required fields raises ``ValidationError``."""
        # URL and title are required
        with pytest.raises(ValidationError):
            CitationMetadata()  # Missing required fields

        with pytest.raises(ValidationError):
            CitationMetadata(url="https://example.com")  # Missing title

    def test_metadata_model_dump(self):
        """``model_dump`` returns a dict holding the field values."""
        meta = CitationMetadata(
            url="https://example.com",
            title="Example",
            author="John Doe",
        )
        dumped = meta.model_dump()

        assert isinstance(dumped, dict)
        assert dumped["url"] == "https://example.com"
        assert dumped["title"] == "Example"

    def test_metadata_model_dump_json(self):
        """``model_dump_json`` returns a JSON string that parses back."""
        meta = CitationMetadata(url="https://example.com", title="Example")
        serialized = meta.model_dump_json()

        assert isinstance(serialized, str)
        decoded = json.loads(serialized)
        assert decoded["url"] == "https://example.com"
        assert decoded["title"] == "Example"

    def test_metadata_with_images_and_extra(self):
        """List-valued and dict-valued fields are stored as given."""
        image_urls = ["https://example.com/image1.jpg", "https://example.com/image2.jpg"]
        extra_fields = {"custom_field": "value", "tags": ["tag1", "tag2"]}
        meta = CitationMetadata(
            url="https://example.com",
            title="Example",
            images=image_urls,
            favicon="https://example.com/favicon.ico",
            extra=extra_fields,
        )

        assert len(meta.images) == 2
        assert meta.favicon == "https://example.com/favicon.ico"
        assert meta.extra["custom_field"] == "value"
|
||||
|
||||
|
||||
class TestCitation:
    """Exercise the ``Citation`` Pydantic model."""

    @staticmethod
    def _metadata(title="Example", **fields):
        """Build metadata for https://example.com with a title and extra fields."""
        return CitationMetadata(url="https://example.com", title=title, **fields)

    def test_create_basic_citation(self):
        """Number and metadata alone leave context and cited text unset."""
        meta = self._metadata()
        cite = Citation(number=1, metadata=meta)

        assert cite.number == 1
        assert cite.metadata == meta
        assert cite.context is None
        assert cite.cited_text is None

    def test_citation_properties(self):
        """``id``, ``url`` and ``title`` mirror the wrapped metadata."""
        meta = self._metadata("Example Title")
        cite = Citation(number=1, metadata=meta)

        assert cite.id == meta.id
        assert cite.url == "https://example.com"
        assert cite.title == "Example Title"

    def test_citation_to_markdown_reference(self):
        """The Markdown reference is rendered as ``[title](url)``."""
        cite = Citation(number=1, metadata=self._metadata())

        rendered = cite.to_markdown_reference()
        assert rendered == "[Example](https://example.com)"

    def test_citation_to_numbered_reference(self):
        """The numbered reference is rendered as ``[n] title - url``."""
        cite = Citation(number=5, metadata=self._metadata("Example Article"))

        rendered = cite.to_numbered_reference()
        assert rendered == "[5] Example Article - https://example.com"

    def test_citation_to_inline_marker(self):
        """The inline marker is rendered as ``[^n]``."""
        cite = Citation(number=3, metadata=self._metadata())

        rendered = cite.to_inline_marker()
        assert rendered == "[^3]"

    def test_citation_to_footnote(self):
        """The footnote is rendered as ``[^n]: title - url``."""
        cite = Citation(number=2, metadata=self._metadata("Example Article"))

        rendered = cite.to_footnote()
        assert rendered == "[^2]: Example Article - https://example.com"

    def test_citation_with_context_and_text(self):
        """Context and cited text passed to the constructor are stored."""
        cite = Citation(
            number=1,
            metadata=self._metadata(),
            context="This is important context",
            cited_text="Important quote from the source",
        )

        assert cite.context == "This is important context"
        assert cite.cited_text == "Important quote from the source"

    def test_citation_from_dict(self):
        """``from_dict`` builds the citation and its nested metadata."""
        payload = {
            "number": 1,
            "metadata": {
                "url": "https://example.com",
                "title": "Example",
                "author": "John Doe",
            },
            "context": "Test context",
        }
        cite = Citation.from_dict(payload)

        assert cite.number == 1
        nested = cite.metadata
        assert nested.url == "https://example.com"
        assert nested.title == "Example"
        assert nested.author == "John Doe"
        assert cite.context == "Test context"

    def test_citation_to_dict(self):
        """``to_dict`` nests the metadata under the "metadata" key."""
        cite = Citation(
            number=1,
            metadata=self._metadata(author="John Doe"),
            context="Test context",
        )
        dumped = cite.to_dict()

        assert dumped["number"] == 1
        assert dumped["metadata"]["url"] == "https://example.com"
        assert dumped["metadata"]["author"] == "John Doe"
        assert dumped["context"] == "Test context"

    def test_citation_round_trip(self):
        """``to_dict`` followed by ``from_dict`` keeps the checked fields."""
        meta = self._metadata(author="John Doe", relevance_score=0.95)
        source = Citation(number=1, metadata=meta, context="Test")

        # Convert to dict and back
        rebuilt = Citation.from_dict(source.to_dict())

        assert rebuilt.number == source.number
        assert rebuilt.metadata.url == source.metadata.url
        assert rebuilt.metadata.title == source.metadata.title
        assert rebuilt.metadata.author == source.metadata.author
        assert rebuilt.metadata.relevance_score == source.metadata.relevance_score

    def test_citation_model_dump(self):
        """``model_dump`` returns a dict with nested metadata."""
        cite = Citation(number=1, metadata=self._metadata())
        dumped = cite.model_dump()

        assert isinstance(dumped, dict)
        assert dumped["number"] == 1
        assert dumped["metadata"]["url"] == "https://example.com"

    def test_citation_model_dump_json(self):
        """``model_dump_json`` returns a JSON string that parses back."""
        cite = Citation(number=1, metadata=self._metadata())
        serialized = cite.model_dump_json()

        assert isinstance(serialized, str)
        decoded = json.loads(serialized)
        assert decoded["number"] == 1
        assert decoded["metadata"]["url"] == "https://example.com"

    def test_citation_pydantic_validation(self):
        """Omitting required fields raises ``ValidationError``."""
        # Number and metadata are required
        with pytest.raises(ValidationError):
            Citation()  # Missing required fields

        meta = self._metadata()
        with pytest.raises(ValidationError):
            Citation(metadata=meta)  # Missing number
|
||||
|
||||
|
||||
class TestCitationIntegration:
    """End-to-end checks that use ``CitationMetadata`` and ``Citation`` together."""

    def test_search_result_to_citation_workflow(self):
        """Build a citation from a raw search hit and verify what it exposes."""
        hit = {
            "url": "https://example.com/article",
            "title": "Great Article",
            "content": "This is a great article about testing",
            "score": 0.92,
        }

        # Search hit -> metadata -> numbered citation.
        meta = CitationMetadata.from_search_result(hit, query="testing")
        cite = Citation(number=1, metadata=meta, context="Important source")

        assert cite.number == 1
        assert cite.url == "https://example.com/article"
        assert cite.title == "Great Article"
        assert cite.metadata.relevance_score == 0.92
        expected_reference = "[Great Article](https://example.com/article)"
        assert cite.to_markdown_reference() == expected_reference

    def test_multiple_citations_with_different_formats(self):
        """Render two citations, one per reference format."""
        entries = (
            (1, "https://example.com/1", "First Article"),
            (2, "https://example.com/2", "Second Article"),
        )
        citations = [
            Citation(number=number, metadata=CitationMetadata(url=url, title=title))
            for number, url, title in entries
        ]
        first, second = citations

        assert first.to_markdown_reference() == "[First Article](https://example.com/1)"
        assert second.to_numbered_reference() == "[2] Second Article - https://example.com/2"

    def test_citation_json_serialization_roundtrip(self):
        """Round-trip a citation: dict -> model -> JSON -> model."""
        source_meta = {
            "url": "https://example.com",
            "title": "Example",
            "author": "John Doe",
            "relevance_score": 0.95,
        }
        original_data = {
            "number": 1,
            "metadata": source_meta,
            "context": "Test context",
            "cited_text": "Important quote",
        }

        # dict -> model -> JSON string -> model
        as_json = Citation.from_dict(original_data).model_dump_json()
        restored = Citation.model_validate_json(as_json)

        # The restored model must still carry the original values.
        assert restored.number == original_data["number"]
        assert restored.metadata.url == source_meta["url"]
        assert restored.metadata.relevance_score == source_meta["relevance_score"]
        assert restored.context == original_data["context"]
|
||||
@@ -1,183 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import sys
|
||||
import types
|
||||
|
||||
from src.config.configuration import Configuration
|
||||
|
||||
# Register a stand-in ``src.rag.retriever`` module that exposes a bare
# ``Resource`` class, but only when nothing is registered under that name yet.
# NOTE(review): ``Configuration`` is imported above, before this runs; if that
# import already loads ``src.rag.retriever`` the stub is never installed -- confirm.
module_name = "src.rag.retriever"
mock_resource = type("Resource", (), {})

if module_name not in sys.modules:
    retriever_mod = types.ModuleType(module_name)
    setattr(retriever_mod, "Resource", mock_resource)
    sys.modules[module_name] = retriever_mod
|
||||
|
||||
|
||||
def test_default_configuration():
    """A bare ``Configuration()`` exposes the expected default values."""
    cfg = Configuration()

    assert cfg.resources == []
    assert cfg.mcp_settings is None
    # Numeric limits and their defaults.
    for field, expected in (
        ("max_plan_iterations", 1),
        ("max_step_num", 3),
        ("max_search_results", 3),
    ):
        assert getattr(cfg, field) == expected
|
||||
|
||||
|
||||
def test_from_runnable_config_with_config_dict(monkeypatch):
    """Values under ``configurable`` are copied onto the resulting Configuration.

    Args:
        monkeypatch: pytest fixture, used to keep ambient environment
            variables from leaking into the test.
    """
    configurable = {
        "max_plan_iterations": 5,
        "max_step_num": 7,
        "max_search_results": 10,
        "mcp_settings": {"foo": "bar"},
    }

    # Environment variables named after the upper-cased field take precedence
    # over ``configurable`` (the env-override test in this file relies on that
    # for MAX_PLAN_ITERATIONS / MAX_STEP_NUM). Clear them so a developer's
    # shell cannot change the outcome. Presumably the remaining fields follow
    # the same naming; removing an unset variable is a no-op.
    for field in configurable:
        monkeypatch.delenv(field.upper(), raising=False)

    config = Configuration.from_runnable_config({"configurable": configurable})

    assert config.max_plan_iterations == 5
    assert config.max_step_num == 7
    assert config.max_search_results == 10
    assert config.mcp_settings == {"foo": "bar"}
|
||||
|
||||
|
||||
def test_from_runnable_config_with_env_override(monkeypatch):
    """Environment variables win over ``configurable`` entries.

    The overriding values are asserted as raw strings, i.e. exactly as they
    were placed in the environment.

    Args:
        monkeypatch: pytest fixture used to set and clear environment variables.
    """
    monkeypatch.setenv("MAX_PLAN_ITERATIONS", "9")
    monkeypatch.setenv("MAX_STEP_NUM", "11")
    # Keep the "not overridden" field independent of the caller's shell.
    # Presumably its variable follows the same upper-cased naming.
    monkeypatch.delenv("MAX_SEARCH_RESULTS", raising=False)

    config_dict = {
        "configurable": {
            "max_plan_iterations": 2,
            "max_step_num": 3,
            "max_search_results": 4,
        }
    }
    config = Configuration.from_runnable_config(config_dict)

    # Environment variables take precedence and are strings
    assert config.max_plan_iterations == "9"
    assert config.max_step_num == "11"
    assert config.max_search_results == 4  # not overridden

    # No manual cleanup needed: monkeypatch restores the environment at teardown.
|
||||
|
||||
|
||||
def test_from_runnable_config_with_none_and_falsy(monkeypatch):
    """Test that None values are skipped but falsy values (0, empty string) are preserved.

    Args:
        monkeypatch: pytest fixture, used to keep ambient environment
            variables from leaking into the test.
    """
    configurable = {
        "max_plan_iterations": None,  # None should be skipped, use default
        "max_step_num": 0,  # 0 is valid, should be preserved
        "max_search_results": "",  # Empty string should be preserved
    }

    # Environment variables named after the upper-cased field override
    # ``configurable`` (see the env-override test in this file); clear them so
    # the assertions only depend on the dict above.
    for field in configurable:
        monkeypatch.delenv(field.upper(), raising=False)

    config = Configuration.from_runnable_config({"configurable": configurable})

    # None values should fall back to defaults
    assert config.max_plan_iterations == 1
    # Falsy but valid values should be preserved
    assert config.max_step_num == 0
    assert config.max_search_results == ""
|
||||
|
||||
|
||||
def test_from_runnable_config_with_no_config():
    """Calling ``from_runnable_config`` with no argument yields the defaults."""
    cfg = Configuration.from_runnable_config()

    expected_limits = {
        "max_plan_iterations": 1,
        "max_step_num": 3,
        "max_search_results": 3,
    }
    for field, expected in expected_limits.items():
        assert getattr(cfg, field) == expected
    assert cfg.resources == []
    assert cfg.mcp_settings is None
|
||||
|
||||
|
||||
def test_from_runnable_config_with_boolean_false_values():
    """Explicit ``False`` flags must survive ``from_runnable_config``.

    Regression test: ``False`` used to be treated like a missing value and
    dropped, so the affected fields silently reverted to their defaults.
    """
    flags_off = {
        "enable_web_search": False,
        "enable_deep_thinking": False,
        "enforce_web_search": False,
        "enforce_researcher_search": False,
    }
    # ``max_plan_iterations`` acts as a control: a truthy value next to the flags.
    config_dict = {"configurable": {**flags_off, "max_plan_iterations": 5}}

    cfg = Configuration.from_runnable_config(config_dict)

    # Every flag must come back as the literal ``False``.
    assert cfg.enable_web_search is False, "enable_web_search should be False, not default True"
    assert cfg.enable_deep_thinking is False, "enable_deep_thinking should be False"
    assert cfg.enforce_web_search is False, "enforce_web_search should be False"
    assert cfg.enforce_researcher_search is False, "enforce_researcher_search should be False, not default True"

    # Control value is still applied.
    assert cfg.max_plan_iterations == 5
|
||||
|
||||
|
||||
def test_from_runnable_config_with_boolean_true_values():
    """Control test: explicit ``True`` flags are passed through unchanged."""
    flags_on = {
        "enable_web_search": True,
        "enable_deep_thinking": True,
        "enforce_web_search": True,
    }

    cfg = Configuration.from_runnable_config({"configurable": flags_on})

    assert cfg.enable_web_search is True
    assert cfg.enable_deep_thinking is True
    assert cfg.enforce_web_search is True
|
||||
|
||||
def test_get_recursion_limit_default(monkeypatch):
    """With AGENT_RECURSION_LIMIT unset and no argument, the result is 25."""
    from src.config.configuration import get_recursion_limit

    # Make sure the variable is absent whatever the caller's shell exports.
    monkeypatch.delenv("AGENT_RECURSION_LIMIT", raising=False)

    assert get_recursion_limit() == 25
|
||||
|
||||
|
||||
def test_get_recursion_limit_custom_default(monkeypatch):
    """With AGENT_RECURSION_LIMIT unset, the value passed in is returned."""
    from src.config.configuration import get_recursion_limit

    monkeypatch.delenv("AGENT_RECURSION_LIMIT", raising=False)

    assert get_recursion_limit(50) == 50
|
||||
|
||||
|
||||
def test_get_recursion_limit_from_env(monkeypatch):
    """A numeric AGENT_RECURSION_LIMIT is returned as an integer."""
    from src.config.configuration import get_recursion_limit

    monkeypatch.setenv("AGENT_RECURSION_LIMIT", "100")

    assert get_recursion_limit() == 100
|
||||
|
||||
|
||||
def test_get_recursion_limit_invalid_env_value(monkeypatch):
    """A non-numeric AGENT_RECURSION_LIMIT falls back to 25."""
    from src.config.configuration import get_recursion_limit

    monkeypatch.setenv("AGENT_RECURSION_LIMIT", "invalid")

    assert get_recursion_limit() == 25
|
||||
|
||||
|
||||
def test_get_recursion_limit_negative_env_value(monkeypatch):
    """A negative AGENT_RECURSION_LIMIT falls back to 25."""
    from src.config.configuration import get_recursion_limit

    monkeypatch.setenv("AGENT_RECURSION_LIMIT", "-5")

    assert get_recursion_limit() == 25
|
||||
|
||||
|
||||
def test_get_recursion_limit_zero_env_value(monkeypatch):
    """An AGENT_RECURSION_LIMIT of zero falls back to 25."""
    from src.config.configuration import get_recursion_limit

    monkeypatch.setenv("AGENT_RECURSION_LIMIT", "0")

    assert get_recursion_limit() == 25
|
||||
@@ -1,82 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from src.config.loader import load_yaml_config, process_dict, replace_env_vars
|
||||
|
||||
|
||||
def test_replace_env_vars_with_env(monkeypatch):
    """``$NAME`` is replaced by the value of a set environment variable."""
    monkeypatch.setenv("TEST_ENV", "env_value")

    resolved = replace_env_vars("$TEST_ENV")

    assert resolved == "env_value"
|
||||
|
||||
|
||||
def test_replace_env_vars_without_env(monkeypatch):
    """``$NAME`` for an unset variable resolves to the bare name, without ``$``."""
    monkeypatch.delenv("NOT_SET_ENV", raising=False)

    resolved = replace_env_vars("$NOT_SET_ENV")

    assert resolved == "NOT_SET_ENV"
|
||||
|
||||
|
||||
def test_replace_env_vars_non_string():
    """Non-string input is handed back unchanged."""
    value = 123

    assert replace_env_vars(value) == 123
|
||||
|
||||
|
||||
def test_replace_env_vars_regular_string():
    """A string without a ``$`` prefix is handed back unchanged."""
    plain = "no_env"

    assert replace_env_vars(plain) == "no_env"
|
||||
|
||||
|
||||
def test_process_dict_nested(monkeypatch):
    """``process_dict`` resolves ``$NAME`` strings at every nesting level."""
    monkeypatch.setenv("FOO", "bar")
    raw = {"a": "$FOO", "b": {"c": "$FOO", "d": 42, "e": "$NOT_SET_ENV"}}

    resolved = process_dict(raw)
    inner = resolved["b"]

    assert resolved["a"] == "bar"
    assert inner["c"] == "bar"
    assert inner["d"] == 42  # non-string values pass through
    assert inner["e"] == "NOT_SET_ENV"
|
||||
|
||||
|
||||
def test_process_dict_empty():
    """An empty mapping is processed into an empty mapping."""
    processed = process_dict({})

    assert processed == {}
|
||||
|
||||
|
||||
def test_load_yaml_config_file_not_exist():
    """A missing file yields an empty config instead of raising."""
    loaded = load_yaml_config("non_existent_file.yaml")

    assert loaded == {}
|
||||
|
||||
|
||||
def test_load_yaml_config(monkeypatch):
    """Load a YAML file from disk and check ``$NAME`` values are resolved.

    Args:
        monkeypatch: pytest fixture used to define ``MY_ENV``.
    """
    monkeypatch.setenv("MY_ENV", "my_value")

    # Written line by line so the indentation under ``nested:`` is explicit;
    # without it ``nested`` would parse as null and the nested assertions
    # below could not hold.
    yaml_content = (
        "\n"
        "key1: value1\n"
        "key2: $MY_ENV\n"
        "nested:\n"
        "  key3: $MY_ENV\n"
        "  key4: 123\n"
    )

    # delete=False: the file must outlive the ``with`` so the loader can reopen
    # it by path; it is removed in the ``finally`` below.
    with tempfile.NamedTemporaryFile("w+", delete=False) as tmp:
        tmp.write(yaml_content)
        tmp_path = tmp.name

    try:
        config = load_yaml_config(tmp_path)
        assert config["key1"] == "value1"
        assert config["key2"] == "my_value"
        assert config["nested"]["key3"] == "my_value"
        assert config["nested"]["key4"] == 123
    finally:
        os.remove(tmp_path)
|
||||
|
||||
|
||||
def test_load_yaml_config_cache(monkeypatch):
    """Loading the same path twice returns the very same object."""
    monkeypatch.setenv("CACHE_ENV", "cache_value")

    with tempfile.NamedTemporaryFile("w+", delete=False) as handle:
        handle.write("foo: $CACHE_ENV")
        yaml_path = handle.name

    try:
        first = load_yaml_config(yaml_path)
        second = load_yaml_config(yaml_path)
        # Identity, not just equality: the second call must reuse the first result.
        assert first is second
        assert first["foo"] == "cache_value"
    finally:
        os.remove(yaml_path)
|
||||
@@ -1,113 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
from src.crawler.article import Article
|
||||
|
||||
|
||||
class DummyMarkdownify:
    """Identity stand-in for markdownify, available for patching if needed."""

    @staticmethod
    def markdownify(html):
        """Return *html* exactly as received."""
        return html
|
||||
|
||||
|
||||
def test_to_markdown_includes_title(monkeypatch):
    """With ``including_title=True`` the markdown starts with an H1 title."""
    markdown = Article("Test Title", "<p>Hello <b>world</b>!</p>").to_markdown(
        including_title=True
    )

    assert markdown.startswith("# Test Title")
    assert "Hello" in markdown
|
||||
|
||||
|
||||
def test_to_markdown_excludes_title():
    """With ``including_title=False`` the markdown does not start with the title."""
    markdown = Article("Test Title", "<p>Hello <b>world</b>!</p>").to_markdown(
        including_title=False
    )

    assert not markdown.startswith("# Test Title")
    assert "Hello" in markdown
|
||||
|
||||
|
||||
def test_to_message_with_text_only():
    """Text-only HTML produces a list of typed parts with at least one text part."""
    article = Article("Test Title", "<p>Hello world!</p>")
    article.url = "https://example.com/"

    parts = article.to_message()

    assert isinstance(parts, list)
    assert any(part["type"] == "text" for part in parts)
    assert all("type" in part for part in parts)
|
||||
|
||||
|
||||
def test_to_message_with_image(monkeypatch):
    """A relative ``<img>`` source comes out as an absolute ``image_url`` part."""
    article = Article("Title", '<p>Intro</p><img src="img/pic.png"/>')
    article.url = "https://host.com/path/"

    parts = article.to_message()

    # Both kinds of part must be present.
    kinds = [part["type"] for part in parts]
    assert "image_url" in kinds
    assert "text" in kinds

    # The image source must have been joined onto the article URL.
    images = [part for part in parts if part["type"] == "image_url"]
    assert images
    assert images[0]["image_url"]["url"] == "https://host.com/path/img/pic.png"
|
||||
|
||||
|
||||
def test_to_message_multiple_images():
    """Several images interleaved with text yield both image and text parts."""
    article = Article(
        "Title", '<p>Start</p><img src="a.png"/><p>Mid</p><img src="b.jpg"/>End'
    )
    article.url = "http://x/"

    parts = article.to_message()

    urls = [part["image_url"]["url"] for part in parts if part["type"] == "image_url"]
    for expected_url in ("http://x/a.png", "http://x/b.jpg"):
        assert expected_url in urls

    texts = [part for part in parts if part["type"] == "text"]
    for fragment in ("Start", "Mid"):
        assert any(fragment in part["text"] for part in texts)
|
||||
|
||||
|
||||
def test_to_message_handles_empty_html():
    """Empty HTML still yields a list whose first part is text."""
    article = Article("Empty", "")
    article.url = "http://test/"

    parts = article.to_message()

    assert isinstance(parts, list)
    first = parts[0]
    assert first["type"] == "text"
|
||||
|
||||
|
||||
def test_to_markdown_handles_none_content():
    """``None`` content renders the title plus a placeholder message."""
    markdown = Article("Test Title", None).to_markdown(including_title=True)

    for expected in ("# Test Title", "No content available"):
        assert expected in markdown
|
||||
|
||||
|
||||
def test_to_markdown_handles_empty_string():
    """Empty-string content renders the title plus a placeholder message."""
    markdown = Article("Test Title", "").to_markdown(including_title=True)

    for expected in ("# Test Title", "No content available"):
        assert expected in markdown
|
||||
|
||||
|
||||
def test_to_markdown_handles_whitespace_only():
    """Whitespace-only content renders the title plus a placeholder message."""
    markdown = Article("Test Title", " \n \t ").to_markdown(including_title=True)

    for expected in ("# Test Title", "No content available"):
        assert expected in markdown
|
||||
|
||||
|
||||
def test_to_message_handles_none_content():
    """``None`` content yields a non-empty list that starts with a placeholder text part."""
    article = Article("Title", None)
    article.url = "http://test/"

    parts = article.to_message()

    assert isinstance(parts, list)
    assert len(parts) > 0
    head = parts[0]
    assert head["type"] == "text"
    assert "No content available" in head["text"]
|
||||
|
||||
|
||||
def test_to_message_handles_whitespace_only_content():
    """Whitespace-only content yields a placeholder text part first."""
    article = Article("Title", " \n ")
    article.url = "http://test/"

    parts = article.to_message()

    assert isinstance(parts, list)
    head = parts[0]
    assert head["type"] == "text"
    assert "No content available" in head["text"]
|
||||
@@ -1,675 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import src.crawler as crawler_module
|
||||
from src.crawler.crawler import safe_truncate
|
||||
from src.crawler.infoquest_client import InfoQuestClient
|
||||
|
||||
|
||||
def test_crawler_sets_article_url(monkeypatch):
    """The article returned by ``Crawler.crawl`` carries the crawled URL."""
    raw_html = "<html>dummy</html>"

    class DummyArticle:
        def __init__(self):
            self.url = None

        def to_markdown(self):
            return "# Dummy"

    class DummyJinaClient:
        def crawl(self, url, return_format=None):
            return raw_html

    class DummyInfoQuestClient:
        # Accepts the constructor keywords but ignores them.
        def __init__(self, fetch_time=None, timeout=None, navi_timeout=None):
            pass

        def crawl(self, url, return_format=None):
            return raw_html

    class DummyReadabilityExtractor:
        def extract_article(self, html):
            return DummyArticle()

    def mock_load_config(*args, **kwargs):
        return {"CRAWLER_ENGINE": {"engine": "jina"}}

    # Replace every collaborator the crawler module looks up by name.
    replacements = {
        "JinaClient": DummyJinaClient,
        "InfoQuestClient": DummyInfoQuestClient,
        "ReadabilityExtractor": DummyReadabilityExtractor,
        "load_yaml_config": mock_load_config,
    }
    for attr, stub in replacements.items():
        monkeypatch.setattr(f"src.crawler.crawler.{attr}", stub)

    target = "http://example.com"
    article = crawler_module.crawler.Crawler().crawl(target)

    assert article.url == target
    assert article.to_markdown() == "# Dummy"
|
||||
|
||||
|
||||
def test_crawler_calls_dependencies(monkeypatch):
    """With the jina engine, ``crawl`` calls JinaClient.crawl and then the extractor."""
    recorded = {}
    raw_html = "<html>dummy</html>"

    class DummyArticle:
        url = None

        def to_markdown(self):
            return "# Dummy"

    class DummyJinaClient:
        def crawl(self, url, return_format=None):
            recorded["jina"] = (url, return_format)
            return raw_html

    class DummyInfoQuestClient:
        # The constructor keywords are accepted but not used.
        def __init__(self, fetch_time=None, timeout=None, navi_timeout=None):
            pass

        def crawl(self, url, return_format=None):
            recorded["infoquest"] = (url, return_format)
            return raw_html

    class DummyReadabilityExtractor:
        def extract_article(self, html):
            recorded["extractor"] = html
            return DummyArticle()

    # Config stub that selects the jina engine.
    def mock_load_config(*args, **kwargs):
        return {"CRAWLER_ENGINE": {"engine": "jina"}}

    for attr, stub in (
        ("JinaClient", DummyJinaClient),
        ("InfoQuestClient", DummyInfoQuestClient),  # in case InfoQuest is used
        ("ReadabilityExtractor", DummyReadabilityExtractor),
        ("load_yaml_config", mock_load_config),
    ):
        monkeypatch.setattr(f"src.crawler.crawler.{attr}", stub)

    target = "http://example.com"
    crawler_module.crawler.Crawler().crawl(target)

    assert "jina" in recorded
    requested_url, requested_format = recorded["jina"]
    assert requested_url == target
    assert requested_format == "html"
    assert "extractor" in recorded
    assert recorded["extractor"] == "<html>dummy</html>"
|
||||
|
||||
|
||||
def test_crawler_handles_empty_content(monkeypatch):
    """An empty client response yields an "Empty Content" placeholder article.

    Args:
        monkeypatch: pytest fixture used to swap the crawler's collaborators.
    """

    class DummyJinaClient:
        def crawl(self, url, return_format=None):
            return ""  # Empty content

    class DummyReadabilityExtractor:
        def extract_article(self, html):
            # Explicit raise rather than ``assert False`` so the guard is not
            # stripped when Python runs with ``-O``.
            raise AssertionError(
                "ReadabilityExtractor should not be called for empty content"
            )

    def mock_load_config(*args, **kwargs):
        return {"CRAWLER_ENGINE": {"engine": "jina"}}

    monkeypatch.setattr("src.crawler.crawler.JinaClient", DummyJinaClient)
    monkeypatch.setattr(
        "src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor
    )
    monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)

    crawler = crawler_module.crawler.Crawler()
    url = "http://example.com"
    article = crawler.crawl(url)

    assert article.url == url
    assert article.title == "Empty Content"
    assert "No content could be extracted from this page" in article.html_content
|
||||
|
||||
|
||||
def test_crawler_handles_error_response_from_client(monkeypatch):
    """An "Error: ..." client response becomes a placeholder article quoting it.

    Args:
        monkeypatch: pytest fixture used to swap the crawler's collaborators.
    """

    class DummyJinaClient:
        def crawl(self, url, return_format=None):
            return "Error: API returned status 500"

    class DummyReadabilityExtractor:
        def extract_article(self, html):
            # Explicit raise rather than ``assert False`` so the guard is not
            # stripped when Python runs with ``-O``.
            # NOTE(review): the extraction-failure test in this file shows the
            # crawler turning extractor exceptions into a "Content Extraction
            # Failed" article, and that title is accepted below -- so this
            # guard alone may not fail the test. Confirm the intent.
            raise AssertionError(
                "ReadabilityExtractor should not be called for error responses"
            )

    def mock_load_config(*args, **kwargs):
        return {"CRAWLER_ENGINE": {"engine": "jina"}}

    monkeypatch.setattr("src.crawler.crawler.JinaClient", DummyJinaClient)
    monkeypatch.setattr(
        "src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor
    )
    monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)

    crawler = crawler_module.crawler.Crawler()
    url = "http://example.com"
    article = crawler.crawl(url)

    assert article.url == url
    assert article.title in ["Non-HTML Content", "Content Extraction Failed"]
    assert "Error: API returned status 500" in article.html_content
|
||||
|
||||
|
||||
def test_crawler_handles_non_html_content(monkeypatch):
    """A plain-text client response becomes a placeholder article with a snippet.

    Args:
        monkeypatch: pytest fixture used to swap the crawler's collaborators.
    """

    class DummyJinaClient:
        def crawl(self, url, return_format=None):
            return "This is plain text content, not HTML"

    class DummyReadabilityExtractor:
        def extract_article(self, html):
            # Explicit raise rather than ``assert False`` so the guard is not
            # stripped when Python runs with ``-O``.
            # NOTE(review): the "Content Extraction Failed" title is accepted
            # below, so a swallowed extractor error would still pass -- confirm.
            raise AssertionError(
                "ReadabilityExtractor should not be called for non-HTML content"
            )

    def mock_load_config(*args, **kwargs):
        return {"CRAWLER_ENGINE": {"engine": "jina"}}

    monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)
    monkeypatch.setattr("src.crawler.crawler.JinaClient", DummyJinaClient)
    monkeypatch.setattr(
        "src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor
    )

    crawler = crawler_module.crawler.Crawler()
    url = "http://example.com"
    article = crawler.crawl(url)

    assert article.url == url
    assert article.title in ["Non-HTML Content", "Content Extraction Failed"]
    assert (
        "cannot be parsed as HTML" in article.html_content
        or "Content extraction failed" in article.html_content
    )
    # Should include a snippet of the original content
    assert "plain text content" in article.html_content
|
||||
|
||||
|
||||
def test_crawler_handles_extraction_failure(monkeypatch):
    """An extractor exception becomes a "Content Extraction Failed" article.

    Args:
        monkeypatch: pytest fixture used to swap the crawler's collaborators.
    """

    class DummyJinaClient:
        def crawl(self, url, return_format=None):
            return "<html><body>Valid HTML but extraction will fail</body></html>"

    class DummyReadabilityExtractor:
        def extract_article(self, html):
            # Simulate the extractor blowing up on otherwise valid HTML.
            raise Exception("Extraction failed")

    def mock_load_config(*args, **kwargs):
        return {"CRAWLER_ENGINE": {"engine": "jina"}}

    monkeypatch.setattr("src.crawler.crawler.JinaClient", DummyJinaClient)
    monkeypatch.setattr(
        "src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor
    )
    monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)

    crawler = crawler_module.crawler.Crawler()
    url = "http://example.com"
    article = crawler.crawl(url)

    assert article.url == url
    assert article.title == "Content Extraction Failed"
    assert "Content extraction failed" in article.html_content
    # Should include a snippet of the HTML
    assert "Valid HTML but extraction will fail" in article.html_content
|
||||
|
||||
|
||||
def test_crawler_with_json_like_content(monkeypatch):
    """A JSON client response becomes a placeholder article with a snippet.

    Args:
        monkeypatch: pytest fixture used to swap the crawler's collaborators.
    """

    class DummyJinaClient:
        def crawl(self, url, return_format=None):
            return '{"title": "Some JSON", "content": "This is JSON content"}'

    class DummyReadabilityExtractor:
        def extract_article(self, html):
            # Explicit raise rather than ``assert False`` so the guard is not
            # stripped when Python runs with ``-O``.
            # NOTE(review): the "Content Extraction Failed" title is accepted
            # below, so a swallowed extractor error would still pass -- confirm.
            raise AssertionError(
                "ReadabilityExtractor should not be called for JSON content"
            )

    def mock_load_config(*args, **kwargs):
        return {"CRAWLER_ENGINE": {"engine": "jina"}}

    monkeypatch.setattr("src.crawler.crawler.JinaClient", DummyJinaClient)
    monkeypatch.setattr(
        "src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor
    )
    monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)

    crawler = crawler_module.crawler.Crawler()
    url = "http://example.com/api/data"
    article = crawler.crawl(url)

    assert article.url == url
    assert article.title in ["Non-HTML Content", "Content Extraction Failed"]
    assert (
        "cannot be parsed as HTML" in article.html_content
        or "Content extraction failed" in article.html_content
    )
    # Should include a snippet of the JSON
    assert '{"title": "Some JSON"' in article.html_content
|
||||
|
||||
|
||||
def test_crawler_with_various_html_formats(monkeypatch):
    """Several HTML shapes must all reach the extractor.

    Each case swaps in a client that returns a different HTML variant, then
    checks that the extractor's article is what ``crawl`` hands back.

    Args:
        monkeypatch: pytest fixture used to swap the crawler's collaborators.
    """

    class DummyArticle:
        def __init__(self, title, html_content):
            self.title = title
            self.html_content = html_content
            self.url = None

        def to_markdown(self):
            return f"# {self.title}"

    # Test case 1: HTML with DOCTYPE
    class DummyJinaClient1:
        def crawl(self, url, return_format=None):
            return "<!DOCTYPE html><html><body><p>Test content</p></body></html>"

    # Test case 2: HTML with leading whitespace
    class DummyJinaClient2:
        def crawl(self, url, return_format=None):
            return "\n\n <html><body><p>Test content</p></body></html>"

    # Test case 3: HTML with comments
    class DummyJinaClient3:
        def crawl(self, url, return_format=None):
            return "<!-- HTML comment --><html><body><p>Test content</p></body></html>"

    # Test case 4: HTML with self-closing tags
    class DummyJinaClient4:
        def crawl(self, url, return_format=None):
            return '<img src="test.jpg" alt="test" /><p>Test content</p>'

    class DummyReadabilityExtractor:
        def extract_article(self, html):
            return DummyArticle("Extracted Article", "<p>Extracted content</p>")

    def mock_load_config(*args, **kwargs):
        return {"CRAWLER_ENGINE": {"engine": "jina"}}

    test_cases = [
        (DummyJinaClient1, "HTML with DOCTYPE"),
        (DummyJinaClient2, "HTML with leading whitespace"),
        (DummyJinaClient3, "HTML with comments"),
        (DummyJinaClient4, "HTML with self-closing tags"),
    ]

    # These two stubs are the same for every case, so install them once.
    monkeypatch.setattr(
        "src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor
    )
    monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)

    url = "http://example.com"
    for JinaClientClass, description in test_cases:
        monkeypatch.setattr("src.crawler.crawler.JinaClient", JinaClientClass)

        article = crawler_module.crawler.Crawler().crawl(url)

        # The description goes in every message so a failure names its case.
        assert article.url == url, description
        assert article.title == "Extracted Article", description
        assert "Extracted content" in article.html_content, description
|
||||
|
||||
|
||||
def test_safe_truncate_function():
    """Exercise ``safe_truncate`` on empty, ASCII, emoji and CJK input."""
    # Inputs that come back untouched.
    assert safe_truncate(None) is None
    assert safe_truncate("") == ""
    assert safe_truncate("Short text") == "Short text"

    # Plain ASCII that exceeds the limit.
    shortened = safe_truncate("This is a longer text that needs truncation", 20)
    assert len(shortened) <= 20
    assert "..." in shortened

    # Emoji must not be split into invalid text.
    shortened = safe_truncate("Hello! 🌍 Welcome to the world 🚀", 20)
    assert len(shortened) <= 20
    assert "..." in shortened
    assert shortened.encode('utf-8').decode('utf-8') == shortened

    # Limits of 1 to 3 leave room for dots only.
    for limit in (1, 2, 3):
        assert safe_truncate("Long text", limit) == "." * limit

    # Chinese characters.
    shortened = safe_truncate("这是一个中文测试文本", 10)
    assert len(shortened) <= 10
    assert shortened.encode('utf-8').decode('utf-8') == shortened
|
||||
|
||||
# ========== InfoQuest Client Tests ==========
|
||||
|
||||
def test_crawler_selects_infoquest_engine(monkeypatch):
    """With engine "infoquest" the crawler builds and uses InfoQuestClient only."""
    recorded = {}

    class DummyArticle:
        url = None

        def to_markdown(self):
            return "# Dummy"

    class DummyJinaClient:
        def crawl(self, url, return_format=None):
            recorded["jina"] = True
            return "<html>dummy</html>"

    class DummyInfoQuestClient:
        def __init__(self, fetch_time=None, timeout=None, navi_timeout=None):
            recorded["infoquest_init"] = (fetch_time, timeout, navi_timeout)

        def crawl(self, url, return_format=None):
            recorded["infoquest"] = (url, return_format)
            return "<html>dummy from infoquest</html>"

    class DummyReadabilityExtractor:
        def extract_article(self, html):
            recorded["extractor"] = html
            return DummyArticle()

    # Config stub: InfoQuest engine with custom timing parameters.
    engine_settings = {
        "engine": "infoquest",
        "fetch_time": 30,
        "timeout": 60,
        "navi_timeout": 45,
    }

    def mock_load_config(*args, **kwargs):
        return {"CRAWLER_ENGINE": dict(engine_settings)}

    for attr, stub in (
        ("JinaClient", DummyJinaClient),
        ("InfoQuestClient", DummyInfoQuestClient),
        ("ReadabilityExtractor", DummyReadabilityExtractor),
        ("load_yaml_config", mock_load_config),
    ):
        monkeypatch.setattr(f"src.crawler.crawler.{attr}", stub)

    target = "http://example.com"
    crawler_module.crawler.Crawler().crawl(target)

    # InfoQuestClient was built with the configured parameters ...
    assert "infoquest_init" in recorded
    assert recorded["infoquest_init"] == (30, 60, 45)
    # ... and asked for the page as HTML ...
    assert "infoquest" in recorded
    assert recorded["infoquest"][0] == target
    assert recorded["infoquest"][1] == "html"
    # ... and its output was handed to the extractor.
    assert "extractor" in recorded
    assert recorded["extractor"] == "<html>dummy from infoquest</html>"
    # JinaClient must not have been touched.
    assert "jina" not in recorded
|
||||
|
||||
|
||||
def test_crawler_with_infoquest_empty_content(monkeypatch):
    """Test that the crawler handles empty content from InfoQuest client gracefully.

    The InfoQuest client is stubbed to return an empty string. The crawler is
    expected to answer with an "Empty Content" placeholder article for the
    requested URL.
    """

    class DummyInfoQuestClient:
        def __init__(self, fetch_time=None, timeout=None, navi_timeout=None):
            pass

        def crawl(self, url, return_format=None):
            return ""  # Empty content

    class DummyReadabilityExtractor:
        def extract_article(self, html):
            # This should not be called for empty content. Raise explicitly
            # rather than `assert False`: assert statements are removed when
            # Python runs with -O, which would silently disable this guard.
            raise AssertionError(
                "ReadabilityExtractor should not be called for empty content"
            )

    # Mock configuration to use InfoQuest engine
    def mock_load_config(*args, **kwargs):
        return {"CRAWLER_ENGINE": {"engine": "infoquest"}}

    monkeypatch.setattr("src.crawler.crawler.InfoQuestClient", DummyInfoQuestClient)
    monkeypatch.setattr(
        "src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor
    )
    monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)

    crawler = crawler_module.crawler.Crawler()
    url = "http://example.com"
    article = crawler.crawl(url)

    assert article.url == url
    assert article.title == "Empty Content"
    assert "No content could be extracted from this page" in article.html_content
|
||||
|
||||
|
||||
def test_crawler_with_infoquest_non_html_content(monkeypatch):
    """Test that the crawler handles non-HTML content from InfoQuest client gracefully.

    The InfoQuest client is stubbed to return plain text. The crawler is
    expected to answer with a fallback article that still carries the text.
    """

    class DummyInfoQuestClient:
        def __init__(self, fetch_time=None, timeout=None, navi_timeout=None):
            pass

        def crawl(self, url, return_format=None):
            return "This is plain text content from InfoQuest, not HTML"

    class DummyReadabilityExtractor:
        def extract_article(self, html):
            # This should not be called for non-HTML content. Raise explicitly
            # rather than `assert False`: assert statements are removed when
            # Python runs with -O, which would silently disable this guard.
            raise AssertionError(
                "ReadabilityExtractor should not be called for non-HTML content"
            )

    # Mock configuration to use InfoQuest engine
    def mock_load_config(*args, **kwargs):
        return {"CRAWLER_ENGINE": {"engine": "infoquest"}}

    monkeypatch.setattr("src.crawler.crawler.InfoQuestClient", DummyInfoQuestClient)
    monkeypatch.setattr(
        "src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor
    )
    monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)

    crawler = crawler_module.crawler.Crawler()
    url = "http://example.com"
    article = crawler.crawl(url)

    assert article.url == url
    assert article.title in ["Non-HTML Content", "Content Extraction Failed"]
    assert (
        "cannot be parsed as HTML" in article.html_content
        or "Content extraction failed" in article.html_content
    )
    assert "plain text content from InfoQuest" in article.html_content
|
||||
|
||||
|
||||
def test_crawler_with_infoquest_error_response(monkeypatch):
    """Test that the crawler handles error responses from InfoQuest client gracefully.

    The InfoQuest client is stubbed to return an "Error: ..." string. The
    crawler is expected to answer with a fallback article that still carries
    the error text.
    """

    class DummyInfoQuestClient:
        def __init__(self, fetch_time=None, timeout=None, navi_timeout=None):
            pass

        def crawl(self, url, return_format=None):
            return "Error: InfoQuest API returned status 403: Forbidden"

    class DummyReadabilityExtractor:
        def extract_article(self, html):
            # This should not be called for error responses. Raise explicitly
            # rather than `assert False`: assert statements are removed when
            # Python runs with -O, which would silently disable this guard.
            raise AssertionError(
                "ReadabilityExtractor should not be called for error responses"
            )

    # Mock configuration to use InfoQuest engine
    def mock_load_config(*args, **kwargs):
        return {"CRAWLER_ENGINE": {"engine": "infoquest"}}

    monkeypatch.setattr("src.crawler.crawler.InfoQuestClient", DummyInfoQuestClient)
    monkeypatch.setattr(
        "src.crawler.crawler.ReadabilityExtractor", DummyReadabilityExtractor
    )
    monkeypatch.setattr("src.crawler.crawler.load_yaml_config", mock_load_config)

    crawler = crawler_module.crawler.Crawler()
    url = "http://example.com"
    article = crawler.crawl(url)

    assert article.url == url
    assert article.title in ["Non-HTML Content", "Content Extraction Failed"]
    assert "Error: InfoQuest API returned status 403: Forbidden" in article.html_content
|
||||
|
||||
|
||||
def test_crawler_with_infoquest_json_response(monkeypatch):
    """Check that HTML returned by the InfoQuest stub reaches the extractor and the article."""

    class StubArticle:
        def __init__(self, title, html_content):
            self.title = title
            self.html_content = html_content
            self.url = None

        def to_markdown(self):
            return f"# {self.title}"

    class StubInfoQuest:
        def __init__(self, fetch_time=None, timeout=None, navi_timeout=None):
            pass

        def crawl(self, url, return_format=None):
            return "<html><body>Content from InfoQuest JSON</body></html>"

    class StubExtractor:
        def extract_article(self, html):
            # Wrap whatever HTML arrives so the test can inspect it afterwards.
            return StubArticle("Extracted from JSON", html)

    def fake_load_config(*args, **kwargs):
        # Select the InfoQuest engine without any custom parameters.
        return {"CRAWLER_ENGINE": {"engine": "infoquest"}}

    replacements = {
        "src.crawler.crawler.InfoQuestClient": StubInfoQuest,
        "src.crawler.crawler.ReadabilityExtractor": StubExtractor,
        "src.crawler.crawler.load_yaml_config": fake_load_config,
    }
    for target, stand_in in replacements.items():
        monkeypatch.setattr(target, stand_in)

    target_url = "http://example.com"
    fetched = crawler_module.crawler.Crawler().crawl(target_url)

    assert fetched.url == target_url
    assert fetched.title == "Extracted from JSON"
    assert "Content from InfoQuest JSON" in fetched.html_content
|
||||
|
||||
|
||||
def test_infoquest_client_initialization_params():
    """Check the timing attributes InfoQuestClient exposes after construction."""

    def timing_fields(client):
        return (client.fetch_time, client.timeout, client.navi_timeout)

    # Without arguments all three attributes are expected to be -1.
    assert timing_fields(InfoQuestClient()) == (-1, -1, -1)

    # Explicit keyword arguments are expected to be stored as given.
    configured = InfoQuestClient(fetch_time=30, timeout=60, navi_timeout=45)
    assert timing_fields(configured) == (30, 60, 45)
|
||||
|
||||
|
||||
def test_crawler_with_infoquest_default_parameters(monkeypatch):
    """Check which parameters reach InfoQuestClient when the config only names the engine."""
    recorded = {}

    class StubArticle:
        url = None

        def to_markdown(self):
            return "# Dummy"

    class StubInfoQuest:
        def __init__(self, fetch_time=None, timeout=None, navi_timeout=None):
            recorded["infoquest_init"] = (fetch_time, timeout, navi_timeout)

        def crawl(self, url, return_format=None):
            return "<html>dummy</html>"

    class StubExtractor:
        def extract_article(self, html):
            return StubArticle()

    def fake_load_config(*args, **kwargs):
        # Engine selection only; no fetch_time / timeout / navi_timeout keys.
        return {"CRAWLER_ENGINE": {"engine": "infoquest"}}

    replacements = {
        "src.crawler.crawler.InfoQuestClient": StubInfoQuest,
        "src.crawler.crawler.ReadabilityExtractor": StubExtractor,
        "src.crawler.crawler.load_yaml_config": fake_load_config,
    }
    for target, stand_in in replacements.items():
        monkeypatch.setattr(target, stand_in)

    crawler_module.crawler.Crawler().crawl("http://example.com")

    # The crawler is expected to pass -1 for every unset parameter.
    assert "infoquest_init" in recorded
    assert recorded["infoquest_init"] == (-1, -1, -1)
|
||||
@@ -1,230 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
import json
|
||||
|
||||
|
||||
|
||||
from src.crawler.infoquest_client import InfoQuestClient
|
||||
|
||||
|
||||
class TestInfoQuestClient:
    """Tests for InfoQuestClient.crawl with requests.post patched out."""

    @staticmethod
    def _respond_with(mock_post, status_code, text):
        """Make the patched requests.post return a response with this status and body."""
        response = Mock()
        response.status_code = status_code
        response.text = text
        mock_post.return_value = response

    @patch("src.crawler.infoquest_client.requests.post")
    def test_crawl_success(self, mock_post):
        page = "<html><body>Test Content</body></html>"
        self._respond_with(mock_post, 200, page)

        fetched = InfoQuestClient().crawl("https://example.com")

        assert fetched == page
        mock_post.assert_called_once()

    @patch("src.crawler.infoquest_client.requests.post")
    def test_crawl_json_response_with_reader_result(self, mock_post):
        payload = {
            "reader_result": "<p>Extracted content from JSON</p>",
            "err_code": 0,
            "err_msg": "success",
        }
        self._respond_with(mock_post, 200, json.dumps(payload))

        fetched = InfoQuestClient().crawl("https://example.com")

        assert fetched == "<p>Extracted content from JSON</p>"

    @patch("src.crawler.infoquest_client.requests.post")
    def test_crawl_json_response_with_content_fallback(self, mock_post):
        payload = {
            "content": "<p>Content fallback from JSON</p>",
            "err_code": 0,
            "err_msg": "success",
        }
        self._respond_with(mock_post, 200, json.dumps(payload))

        fetched = InfoQuestClient().crawl("https://example.com")

        assert fetched == "<p>Content fallback from JSON</p>"

    @patch("src.crawler.infoquest_client.requests.post")
    def test_crawl_json_response_without_expected_fields(self, mock_post):
        payload = {
            "unexpected_field": "some value",
            "err_code": 0,
            "err_msg": "success",
        }
        self._respond_with(mock_post, 200, json.dumps(payload))

        fetched = InfoQuestClient().crawl("https://example.com")

        # With neither known field present, the raw JSON text is expected back.
        assert fetched == json.dumps(payload)

    @patch("src.crawler.infoquest_client.requests.post")
    def test_crawl_http_error(self, mock_post):
        self._respond_with(mock_post, 500, "Internal Server Error")

        fetched = InfoQuestClient().crawl("https://example.com")

        assert fetched.startswith("Error:")
        assert "status 500" in fetched

    @patch("src.crawler.infoquest_client.requests.post")
    def test_crawl_empty_response(self, mock_post):
        self._respond_with(mock_post, 200, "")

        fetched = InfoQuestClient().crawl("https://example.com")

        assert fetched.startswith("Error:")
        assert "empty response" in fetched

    @patch("src.crawler.infoquest_client.requests.post")
    def test_crawl_whitespace_only_response(self, mock_post):
        self._respond_with(mock_post, 200, " \n \t ")

        fetched = InfoQuestClient().crawl("https://example.com")

        assert fetched.startswith("Error:")
        assert "empty response" in fetched

    @patch("src.crawler.infoquest_client.requests.post")
    def test_crawl_not_found(self, mock_post):
        self._respond_with(mock_post, 404, "Not Found")

        fetched = InfoQuestClient().crawl("https://example.com")

        assert fetched.startswith("Error:")
        assert "status 404" in fetched

    @patch.dict("os.environ", {}, clear=True)
    @patch("src.crawler.infoquest_client.requests.post")
    def test_crawl_without_api_key_logs_warning(self, mock_post):
        # The environment is cleared for this test; the crawl must still succeed.
        self._respond_with(mock_post, 200, "<html>Test</html>")

        fetched = InfoQuestClient().crawl("https://example.com")

        assert fetched == "<html>Test</html>"

    @patch("src.crawler.infoquest_client.requests.post")
    def test_crawl_with_timeout_parameters(self, mock_post):
        self._respond_with(mock_post, 200, "<html>Test</html>")
        configured = InfoQuestClient(fetch_time=10, timeout=20, navi_timeout=30)

        fetched = configured.crawl("https://example.com")

        assert fetched == "<html>Test</html>"
        # The timing values must appear in the JSON body given to requests.post.
        _, sent_kwargs = mock_post.call_args
        sent_body = sent_kwargs["json"]
        assert sent_body["fetch_time"] == 10
        assert sent_body["timeout"] == 20
        assert sent_body["navi_timeout"] == 30

    @patch("src.crawler.infoquest_client.requests.post")
    def test_crawl_with_markdown_format(self, mock_post):
        self._respond_with(mock_post, 200, "# Markdown Content")

        fetched = InfoQuestClient().crawl(
            "https://example.com", return_format="markdown"
        )

        assert fetched == "# Markdown Content"
        # The requested format must appear in the JSON body given to requests.post.
        _, sent_kwargs = mock_post.call_args
        assert sent_kwargs["json"]["format"] == "markdown"

    @patch("src.crawler.infoquest_client.requests.post")
    def test_crawl_exception_handling(self, mock_post):
        mock_post.side_effect = Exception("Network error")

        fetched = InfoQuestClient().crawl("https://example.com")

        # The exception is expected to surface as an error string, not propagate.
        assert fetched.startswith("Error:")
        assert "Network error" in fetched
|
||||
@@ -1,126 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.crawler.jina_client import JinaClient
|
||||
|
||||
|
||||
class TestJinaClient:
    """Tests for JinaClient.crawl with requests.post patched out."""

    @staticmethod
    def _respond_with(mock_post, status_code, text):
        """Make the patched requests.post return a response with this status and body."""
        response = Mock()
        response.status_code = status_code
        response.text = text
        mock_post.return_value = response

    @patch("src.crawler.jina_client.requests.post")
    def test_crawl_success(self, mock_post):
        page = "<html><body>Test</body></html>"
        self._respond_with(mock_post, 200, page)

        fetched = JinaClient().crawl("https://example.com")

        assert fetched == page
        mock_post.assert_called_once()

    @patch("src.crawler.jina_client.requests.post")
    def test_crawl_http_error(self, mock_post):
        self._respond_with(mock_post, 500, "Internal Server Error")

        fetched = JinaClient().crawl("https://example.com")

        assert fetched.startswith("Error:")
        assert "status 500" in fetched

    @patch("src.crawler.jina_client.requests.post")
    def test_crawl_empty_response(self, mock_post):
        self._respond_with(mock_post, 200, "")

        fetched = JinaClient().crawl("https://example.com")

        assert fetched.startswith("Error:")
        assert "empty response" in fetched

    @patch("src.crawler.jina_client.requests.post")
    def test_crawl_whitespace_only_response(self, mock_post):
        self._respond_with(mock_post, 200, " \n \t ")

        fetched = JinaClient().crawl("https://example.com")

        assert fetched.startswith("Error:")
        assert "empty response" in fetched

    @patch("src.crawler.jina_client.requests.post")
    def test_crawl_not_found(self, mock_post):
        self._respond_with(mock_post, 404, "Not Found")

        fetched = JinaClient().crawl("https://example.com")

        assert fetched.startswith("Error:")
        assert "status 404" in fetched

    @patch.dict("os.environ", {}, clear=True)
    @patch("src.crawler.jina_client.requests.post")
    def test_crawl_without_api_key_logs_warning(self, mock_post):
        # The environment is cleared for this test; the crawl must still succeed.
        self._respond_with(mock_post, 200, "<html>Test</html>")

        fetched = JinaClient().crawl("https://example.com")

        assert fetched == "<html>Test</html>"

    @patch("src.crawler.jina_client.requests.post")
    def test_crawl_exception_handling(self, mock_post):
        mock_post.side_effect = Exception("Network error")

        fetched = JinaClient().crawl("https://example.com")

        # The exception is expected to surface as an error string, not propagate.
        assert fetched.startswith("Error:")
        assert "Network error" in fetched
|
||||
@@ -1,104 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from src.crawler.readability_extractor import ReadabilityExtractor
|
||||
|
||||
|
||||
class TestReadabilityExtractor:
    """Tests for ReadabilityExtractor.extract_article with the readability call patched."""

    # Placeholder the extractor is expected to use when no content is available.
    _NO_CONTENT = "<p>No content could be extracted from this page</p>"

    @staticmethod
    def _extract_with(mock_simple_json, title, content):
        """Run extract_article while the patched parser reports this title and content."""
        mock_simple_json.return_value = {"title": title, "content": content}
        return ReadabilityExtractor().extract_article("<html>test</html>")

    @patch("src.crawler.readability_extractor.simple_json_from_html_string")
    def test_extract_article_with_valid_content(self, mock_simple_json):
        extracted = self._extract_with(
            mock_simple_json, "Test Article", "<p>Article content</p>"
        )

        assert extracted.title == "Test Article"
        assert extracted.html_content == "<p>Article content</p>"

    @patch("src.crawler.readability_extractor.simple_json_from_html_string")
    def test_extract_article_with_none_content(self, mock_simple_json):
        extracted = self._extract_with(mock_simple_json, "Test Article", None)

        assert extracted.title == "Test Article"
        assert extracted.html_content == self._NO_CONTENT

    @patch("src.crawler.readability_extractor.simple_json_from_html_string")
    def test_extract_article_with_empty_content(self, mock_simple_json):
        extracted = self._extract_with(mock_simple_json, "Test Article", "")

        assert extracted.title == "Test Article"
        assert extracted.html_content == self._NO_CONTENT

    @patch("src.crawler.readability_extractor.simple_json_from_html_string")
    def test_extract_article_with_whitespace_only_content(self, mock_simple_json):
        extracted = self._extract_with(mock_simple_json, "Test Article", " \n \t ")

        assert extracted.title == "Test Article"
        assert extracted.html_content == self._NO_CONTENT

    @patch("src.crawler.readability_extractor.simple_json_from_html_string")
    def test_extract_article_with_none_title(self, mock_simple_json):
        extracted = self._extract_with(
            mock_simple_json, None, "<p>Article content</p>"
        )

        # A missing title is expected to fall back to "Untitled".
        assert extracted.title == "Untitled"
        assert extracted.html_content == "<p>Article content</p>"

    @patch("src.crawler.readability_extractor.simple_json_from_html_string")
    def test_extract_article_with_empty_title(self, mock_simple_json):
        extracted = self._extract_with(mock_simple_json, "", "<p>Article content</p>")

        # An empty title is expected to fall back to "Untitled" as well.
        assert extracted.title == "Untitled"
        assert extracted.html_content == "<p>Article content</p>"
|
||||
@@ -1,2 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
@@ -1,489 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""Unit tests for the combined report evaluator."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from src.eval.evaluator import CombinedEvaluation, ReportEvaluator, score_to_grade
|
||||
from src.eval.llm_judge import (
|
||||
EVALUATION_CRITERIA,
|
||||
MAX_REPORT_LENGTH,
|
||||
EvaluationResult,
|
||||
LLMJudge,
|
||||
)
|
||||
from src.eval.metrics import ReportMetrics
|
||||
|
||||
|
||||
class TestScoreToGrade:
    """Table-driven checks of the score_to_grade mapping."""

    @staticmethod
    def _check(expectations):
        """Assert that each (score, grade) pair matches score_to_grade."""
        for score, grade in expectations:
            assert score_to_grade(score) == grade

    def test_excellent_scores(self):
        self._check(
            [(9.5, "A+"), (9.0, "A+"), (8.7, "A"), (8.5, "A"), (8.2, "A-")]
        )

    def test_good_scores(self):
        self._check(
            [(7.8, "B+"), (7.5, "B+"), (7.2, "B"), (7.0, "B"), (6.7, "B-")]
        )

    def test_average_scores(self):
        self._check([(6.2, "C+"), (5.8, "C"), (5.5, "C"), (5.2, "C-")])

    def test_poor_scores(self):
        self._check([(4.5, "D"), (4.0, "D"), (3.0, "F"), (1.0, "F")])
|
||||
|
||||
|
||||
class TestReportEvaluator:
    """Metrics-only tests for ReportEvaluator and CombinedEvaluation."""

    @pytest.fixture
    def evaluator(self):
        """Evaluator built with use_llm=False, so only metrics are involved."""
        return ReportEvaluator(use_llm=False)

    @pytest.fixture
    def sample_report(self):
        """Markdown report with headed sections, links, an image and citations."""
        return """
# Comprehensive Research Report

## Key Points
- Important finding number one with significant implications
- Critical discovery that changes our understanding
- Key insight that provides actionable recommendations
- Notable observation from the research data

## Overview
This report presents a comprehensive analysis of the research topic.
The findings are based on extensive data collection and analysis.

## Detailed Analysis

### Section 1: Background
The background of this research involves multiple factors.
[Source 1](https://example.com/source1) provides foundational context.

### Section 2: Methodology
Our methodology follows established research practices.
[Source 2](https://research.org/methods) outlines the approach.

### Section 3: Findings
The key findings include several important discoveries.


[Source 3](https://academic.edu/paper) supports these conclusions.

## Key Citations
- [Example Source](https://example.com/source1)
- [Research Methods](https://research.org/methods)
- [Academic Paper](https://academic.edu/paper)
- [Additional Reference](https://reference.com/doc)
"""

    def test_evaluate_metrics_only(self, evaluator, sample_report):
        """The metrics-only result exposes metrics, a positive score and a known grade."""
        outcome = evaluator.evaluate_metrics_only(sample_report)

        for key in ("metrics", "score", "grade"):
            assert key in outcome
        assert outcome["score"] > 0
        known_grades = ["A+", "A", "A-", "B+", "B", "B-", "C+", "C", "C-", "D", "F"]
        assert outcome["grade"] in known_grades

    def test_evaluate_metrics_only_structure(self, evaluator, sample_report):
        """The metrics mapping carries every expected field."""
        reported = evaluator.evaluate_metrics_only(sample_report)["metrics"]

        expected_fields = (
            "word_count",
            "citation_count",
            "unique_sources",
            "image_count",
            "section_coverage_score",
        )
        for field_name in expected_fields:
            assert field_name in reported

    def test_evaluate_minimal_report(self, evaluator):
        """A one-sentence report scores below 5.0 and grades D or F."""
        outcome = evaluator.evaluate_metrics_only("Just some text.")

        assert outcome["score"] < 5.0
        assert outcome["grade"] in ["D", "F"]

    def test_metrics_score_calculation(self, evaluator):
        """A report with sections, citations and an image scores above 5.0."""
        structured_report = """
# Title

## Key Points
- Point 1
- Point 2

## Overview
Overview content here.

## Detailed Analysis
Analysis with [cite](https://a.com) and [cite2](https://b.com)
and [cite3](https://c.com) and more [refs](https://d.com).



## Key Citations
- [A](https://a.com)
- [B](https://b.com)
"""
        outcome = evaluator.evaluate_metrics_only(structured_report)
        assert outcome["score"] > 5.0

    def test_combined_evaluation_to_dict(self):
        """CombinedEvaluation.to_dict exposes score, grade and nested metrics."""
        combined = CombinedEvaluation(
            metrics=ReportMetrics(
                word_count=1000,
                citation_count=5,
                unique_sources=3,
            ),
            llm_evaluation=None,
            final_score=7.5,
            grade="B+",
            summary="Test summary",
        )

        as_dict = combined.to_dict()

        assert as_dict["final_score"] == 7.5
        assert as_dict["grade"] == "B+"
        assert as_dict["metrics"]["word_count"] == 1000
|
||||
|
||||
|
||||
class TestReportEvaluatorIntegration:
    """End-to-end checks of ReportEvaluator.evaluate."""

    @pytest.mark.asyncio
    async def test_full_evaluation_without_llm(self):
        """With use_llm=False, evaluate() still returns a scored result and no LLM part."""
        short_report = """
# Test Report

## Key Points
- Key point 1

## Overview
Test overview.

## Key Citations
- [Test](https://test.com)
"""
        metrics_only = ReportEvaluator(use_llm=False)

        outcome = await metrics_only.evaluate(short_report, "test query")

        assert isinstance(outcome, CombinedEvaluation)
        assert outcome.final_score > 0
        assert outcome.grade is not None
        assert outcome.summary is not None
        # No LLM ran, so there is no LLM evaluation attached.
        assert outcome.llm_evaluation is None
|
||||
|
||||
|
||||
class TestLLMJudgeParseResponse:
    """Exercise LLMJudge._parse_response with well-formed and broken replies."""

    @pytest.fixture
    def judge(self):
        """Provide an LLMJudge whose LLM is a MagicMock."""
        stub_llm = MagicMock()
        return LLMJudge(llm=stub_llm)

    @pytest.fixture
    def valid_response_data(self):
        """Provide a complete, well-formed evaluation payload."""
        per_criterion = {
            "factual_accuracy": 8,
            "completeness": 7,
            "coherence": 9,
            "relevance": 8,
            "citation_quality": 6,
            "writing_quality": 8,
        }
        return {
            "scores": per_criterion,
            "overall_score": 8,
            "strengths": ["Well researched", "Clear structure"],
            "weaknesses": ["Could use more citations"],
            "suggestions": ["Add more sources"],
        }

    def test_parse_valid_json(self, judge, valid_response_data):
        """A bare JSON document is parsed as-is."""
        parsed = judge._parse_response(json.dumps(valid_response_data))

        assert parsed["scores"]["factual_accuracy"] == 8
        assert parsed["overall_score"] == 8
        assert "Well researched" in parsed["strengths"]

    def test_parse_json_in_markdown_block(self, judge, valid_response_data):
        """JSON inside a ```json fenced block is extracted and parsed."""
        fenced = "```json\n" + json.dumps(valid_response_data) + "\n```"
        parsed = judge._parse_response(fenced)

        assert parsed["scores"]["coherence"] == 9
        assert parsed["overall_score"] == 8

    def test_parse_json_in_generic_code_block(self, judge, valid_response_data):
        """JSON inside an untagged fenced block is extracted and parsed."""
        fenced = "```\n" + json.dumps(valid_response_data) + "\n```"
        parsed = judge._parse_response(fenced)

        assert parsed["scores"]["relevance"] == 8

    def test_parse_malformed_json_returns_defaults(self, judge):
        """Text that is not JSON at all falls back to the default evaluation."""
        parsed = judge._parse_response("This is not valid JSON at all")

        fallback_scores = parsed["scores"]
        assert fallback_scores["factual_accuracy"] == 5
        assert fallback_scores["completeness"] == 5
        assert parsed["overall_score"] == 5
        assert "Unable to parse evaluation" in parsed["strengths"]
        assert "Evaluation parsing failed" in parsed["weaknesses"]

    def test_parse_incomplete_json(self, judge):
        """Truncated JSON (closing braces missing) falls back to the defaults."""
        truncated = '{"scores": {"factual_accuracy": 8}'
        parsed = judge._parse_response(truncated)

        # A parse failure is reported through the default overall score.
        assert parsed["overall_score"] == 5

    def test_parse_json_with_extra_text(self, judge, valid_response_data):
        """Prose before and after the fenced JSON does not prevent parsing."""
        chatty = (
            "Here is my evaluation:\n```json\n"
            + json.dumps(valid_response_data)
            + "\n```\nHope this helps!"
        )
        parsed = judge._parse_response(chatty)

        assert parsed["scores"]["factual_accuracy"] == 8
|
||||
|
||||
|
||||
class TestLLMJudgeCalculateWeightedScore:
    """Tests for LLMJudge._calculate_weighted_score method.

    Expected values that come out of floating-point weighted sums are compared
    with ``pytest.approx`` rather than ``==`` so the tests do not depend on the
    exact rounding of the implementation.
    """

    @pytest.fixture
    def judge(self):
        """Create LLMJudge with mock LLM."""
        return LLMJudge(llm=MagicMock())

    def test_calculate_with_all_scores(self, judge):
        """Test weighted score calculation with all criteria."""
        scores = {
            "factual_accuracy": 10,  # weight 0.25
            "completeness": 10,  # weight 0.20
            "coherence": 10,  # weight 0.20
            "relevance": 10,  # weight 0.15
            "citation_quality": 10,  # weight 0.10
            "writing_quality": 10,  # weight 0.10
        }
        result = judge._calculate_weighted_score(scores)
        assert result == pytest.approx(10.0)

    def test_calculate_with_varied_scores(self, judge):
        """Test weighted score with varied scores."""
        scores = {
            "factual_accuracy": 8,  # 8 * 0.25 = 2.0
            "completeness": 6,  # 6 * 0.20 = 1.2
            "coherence": 7,  # 7 * 0.20 = 1.4
            "relevance": 9,  # 9 * 0.15 = 1.35
            "citation_quality": 5,  # 5 * 0.10 = 0.5
            "writing_quality": 8,  # 8 * 0.10 = 0.8
        }
        # Total: 7.25 (the individual products are not exactly representable
        # in binary floating point, hence the approximate comparison).
        result = judge._calculate_weighted_score(scores)
        assert result == pytest.approx(7.25)

    def test_calculate_with_partial_scores(self, judge):
        """Test weighted score with only some criteria."""
        scores = {
            "factual_accuracy": 8,  # weight 0.25
            "completeness": 6,  # weight 0.20
        }
        # (8 * 0.25 + 6 * 0.20) / (0.25 + 0.20) = 3.2 / 0.45 = 7.11
        result = judge._calculate_weighted_score(scores)
        assert abs(result - 7.11) < 0.01

    def test_calculate_with_unknown_criteria(self, judge):
        """Test that unknown criteria are ignored."""
        scores = {
            "factual_accuracy": 10,
            "unknown_criterion": 1,  # Should be ignored
        }
        result = judge._calculate_weighted_score(scores)
        assert result == pytest.approx(10.0)

    def test_calculate_with_empty_scores(self, judge):
        """Test with empty scores dict."""
        result = judge._calculate_weighted_score({})
        assert result == 0.0

    def test_weights_sum_to_one(self):
        """Verify that all criteria weights sum to 1.0."""
        total_weight = sum(c["weight"] for c in EVALUATION_CRITERIA.values())
        assert abs(total_weight - 1.0) < 0.001
|
||||
|
||||
|
||||
class TestLLMJudgeEvaluate:
    """Tests for LLMJudge.evaluate method with mocked LLM."""

    @staticmethod
    def _llm_returning(content):
        """Build an AsyncMock LLM whose ``ainvoke`` resolves to a reply with *content*."""
        llm = AsyncMock()
        reply = MagicMock()
        reply.content = content
        llm.ainvoke.return_value = reply
        return llm

    @staticmethod
    def _uniform_scores_reply():
        """Return a JSON reply that scores every known criterion as 7."""
        return json.dumps(
            {
                "scores": {k: 7 for k in EVALUATION_CRITERIA.keys()},
                "overall_score": 7,
                "strengths": [],
                "weaknesses": [],
                "suggestions": [],
            }
        )

    @pytest.fixture
    def valid_llm_response(self):
        """Create a valid LLM response."""
        return json.dumps(
            {
                "scores": {
                    "factual_accuracy": 8,
                    "completeness": 7,
                    "coherence": 9,
                    "relevance": 8,
                    "citation_quality": 7,
                    "writing_quality": 8,
                },
                "overall_score": 8,
                "strengths": ["Comprehensive coverage", "Well structured"],
                "weaknesses": ["Some claims need more support"],
                "suggestions": ["Add more academic sources"],
            }
        )

    @pytest.mark.asyncio
    async def test_successful_evaluation(self, valid_llm_response):
        """Test successful LLM evaluation."""
        judge = LLMJudge(llm=self._llm_returning(valid_llm_response))
        result = await judge.evaluate("Test report", "Test query")

        assert isinstance(result, EvaluationResult)
        assert result.scores["factual_accuracy"] == 8
        assert result.overall_score == 8
        assert result.weighted_score > 0
        assert "Comprehensive coverage" in result.strengths
        assert result.raw_response == valid_llm_response

    @pytest.mark.asyncio
    async def test_evaluation_with_llm_failure(self):
        """Test that LLM failures are handled gracefully."""
        mock_llm = AsyncMock()
        mock_llm.ainvoke.side_effect = Exception("LLM service unavailable")

        judge = LLMJudge(llm=mock_llm)
        result = await judge.evaluate("Test report", "Test query")

        assert isinstance(result, EvaluationResult)
        assert result.overall_score == 0
        assert result.weighted_score == 0
        assert all(score == 0 for score in result.scores.values())
        assert any("failed" in w.lower() for w in result.weaknesses)

    @pytest.mark.asyncio
    async def test_evaluation_with_malformed_response(self):
        """Test handling of malformed LLM response."""
        judge = LLMJudge(
            llm=self._llm_returning("I cannot evaluate this report properly.")
        )
        result = await judge.evaluate("Test report", "Test query")

        # Should return default scores
        assert result.scores["factual_accuracy"] == 5
        assert result.overall_score == 5

    @pytest.mark.asyncio
    async def test_evaluation_passes_report_style(self):
        """Test that report_style is passed to LLM."""
        mock_llm = self._llm_returning(self._uniform_scores_reply())

        judge = LLMJudge(llm=mock_llm)
        await judge.evaluate("Test report", "Test query", report_style="academic")

        # Verify the prompt contains the report style
        call_args = mock_llm.ainvoke.call_args
        messages = call_args[0][0]
        user_message_content = messages[1].content
        assert "academic" in user_message_content

    @pytest.mark.asyncio
    async def test_evaluation_truncates_long_reports(self):
        """Test that very long reports are truncated."""
        mock_llm = self._llm_returning(self._uniform_scores_reply())

        judge = LLMJudge(llm=mock_llm)
        long_report = "x" * (MAX_REPORT_LENGTH + 5000)
        await judge.evaluate(long_report, "Test query")

        call_args = mock_llm.ainvoke.call_args
        messages = call_args[0][0]
        user_message_content = messages[1].content
        # Loose upper bound kept from the original test.
        assert len(user_message_content) < len(long_report) + 500
        # The bound above also holds for an un-truncated report whenever the
        # prompt scaffolding is shorter than 500 characters, so additionally
        # require that no run of filler longer than MAX_REPORT_LENGTH reached
        # the prompt.
        assert "x" * (MAX_REPORT_LENGTH + 1) not in user_message_content
|
||||
|
||||
|
||||
class TestEvaluationResult:
    """Serialization checks for the EvaluationResult dataclass."""

    def test_to_dict(self):
        """to_dict() keeps scores and feedback but leaves out the raw LLM reply."""
        field_values = {
            "scores": {"factual_accuracy": 8, "completeness": 7},
            "overall_score": 7.5,
            "weighted_score": 7.6,
            "strengths": ["Good research"],
            "weaknesses": ["Needs more detail"],
            "suggestions": ["Expand section 2"],
            "raw_response": "test response",
        }
        exported = EvaluationResult(**field_values).to_dict()

        assert exported["scores"]["factual_accuracy"] == 8
        assert exported["overall_score"] == 7.5
        assert exported["weighted_score"] == 7.6
        assert "Good research" in exported["strengths"]
        # The raw response is intentionally absent from the exported dict.
        assert "raw_response" not in exported
|
||||
@@ -1,207 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""Unit tests for report evaluation metrics."""
|
||||
|
||||
from src.eval.metrics import (
|
||||
compute_metrics,
|
||||
count_citations,
|
||||
count_images,
|
||||
count_words,
|
||||
detect_sections,
|
||||
extract_domains,
|
||||
get_word_count_target,
|
||||
)
|
||||
|
||||
|
||||
class TestCountWords:
    """Word-count checks for English, Chinese, mixed and empty input."""

    def test_english_words(self):
        """Six space-separated English words count as six."""
        assert count_words("This is a simple test sentence.") == 6

    def test_chinese_characters(self):
        """Each of the six Chinese characters counts as one word."""
        assert count_words("这是一个测试") == 6

    def test_mixed_content(self):
        """English words and Chinese characters are added together."""
        english_words = 2
        chinese_chars = 4
        assert count_words("Hello 你好 World 世界") == english_words + chinese_chars

    def test_empty_string(self):
        """An empty string contains no words."""
        assert count_words("") == 0
|
||||
|
||||
|
||||
class TestCountCitations:
    """Citation-count checks for markdown links."""

    def test_markdown_citations(self):
        """Two http(s) markdown links are counted as two citations."""
        passage = (
            "\n"
            "Check out [Google](https://google.com) and [GitHub](https://github.com).\n"
        )
        assert count_citations(passage) == 2

    def test_no_citations(self):
        """Plain text yields zero citations."""
        assert count_citations("This is plain text without any links.") == 0

    def test_invalid_urls(self):
        """Links without an http(s) URL are not counted."""
        passage = "[Link](not-a-url) [Another](ftp://ftp.example.com)"
        assert count_citations(passage) == 0

    def test_complex_urls(self):
        """A URL with a path and query string still counts once."""
        passage = "[Article](https://example.com/path/to/article?id=123&ref=test)"
        assert count_citations(passage) == 1
|
||||
|
||||
|
||||
class TestExtractDomains:
    """Domain-extraction checks for URLs embedded in text."""

    def test_extract_multiple_domains(self):
        """Three URLs on distinct hosts yield three domains, with 'www.' dropped."""
        urls = (
            "\n"
            "https://google.com/search\n"
            "https://www.github.com/user/repo\n"
            "https://docs.python.org/3/\n"
        )
        found = extract_domains(urls)

        assert len(found) == 3
        for expected in ("google.com", "github.com", "docs.python.org"):
            assert expected in found

    def test_deduplicate_domains(self):
        """Several URLs on the same host collapse into one domain."""
        urls = (
            "\n"
            "https://example.com/page1\n"
            "https://example.com/page2\n"
            "https://www.example.com/page3\n"
        )
        found = extract_domains(urls)

        assert len(found) == 1
        assert "example.com" in found

    def test_no_urls(self):
        """Text without URLs yields an empty list."""
        assert extract_domains("Plain text without URLs") == []
|
||||
|
||||
|
||||
class TestCountImages:
    """Tests for image counting function."""

    def test_markdown_images(self):
        """Two markdown image references are counted as two images."""
        # The fixture must actually contain two `![alt](url)` references;
        # without them the expected count of 2 cannot be produced.
        text = (
            "\n"
            "![Alt text](https://example.com/image1.png)\n"
            "![Another image](https://example.com/image2.jpg)\n"
        )
        assert count_images(text) == 2

    def test_no_images(self):
        """A regular markdown link (no leading '!') is not an image."""
        text = "Text without images [link](url)"
        assert count_images(text) == 0
|
||||
|
||||
|
||||
class TestDetectSections:
    """Section-detection checks for English and Chinese report headings."""

    def test_detect_title(self):
        """A level-1 heading is reported as the title."""
        found = detect_sections("# My Report Title\n\nSome content here.")
        assert found.get("title") is True

    def test_detect_key_points(self):
        """A 'Key Points' heading is reported as key_points."""
        found = detect_sections("## Key Points\n- Point 1\n- Point 2")
        assert found.get("key_points") is True

    def test_detect_chinese_sections(self):
        """Chinese headings map onto the same section keys."""
        # Headings start at column 0, exactly as in a real report.
        chinese_report = (
            "# 报告标题\n"
            "## 要点\n"
            "- 要点1\n"
            "## 概述\n"
            "这是概述内容\n"
        )
        found = detect_sections(chinese_report)

        for section in ("title", "key_points", "overview"):
            assert found.get(section) is True

    def test_detect_citations_section(self):
        """A 'Key Citations' heading is reported as key_citations."""
        citations_only = (
            "\n"
            "## Key Citations\n"
            "- [Source 1](https://example.com)\n"
        )
        found = detect_sections(citations_only)
        assert found.get("key_citations") is True
|
||||
|
||||
|
||||
class TestComputeMetrics:
    """Tests for the main compute_metrics function."""

    def test_complete_report(self):
        """A fully structured report sets every structural flag and count."""
        # The report embeds exactly one markdown image so that image_count can
        # be pinned to 1; headings start at column 0.
        report = (
            "\n"
            "# Research Report Title\n"
            "\n"
            "## Key Points\n"
            "- Point 1\n"
            "- Point 2\n"
            "- Point 3\n"
            "\n"
            "## Overview\n"
            "This is an overview of the research topic.\n"
            "\n"
            "## Detailed Analysis\n"
            "Here is the detailed analysis with [source](https://example.com).\n"
            "\n"
            "![image](https://example.com/image.png)\n"
            "\n"
            "## Key Citations\n"
            "- [Source 1](https://example.com)\n"
            "- [Source 2](https://another.com)\n"
        )
        metrics = compute_metrics(report)

        assert metrics.has_title is True
        assert metrics.has_key_points is True
        assert metrics.has_overview is True
        assert metrics.has_citations_section is True
        assert metrics.citation_count >= 2
        assert metrics.image_count == 1
        assert metrics.unique_sources >= 1
        assert metrics.section_coverage_score > 0.5

    def test_minimal_report(self):
        """Unstructured text has no title, no citations and low section coverage."""
        report = "Just some text without structure."
        metrics = compute_metrics(report)

        assert metrics.has_title is False
        assert metrics.citation_count == 0
        assert metrics.section_coverage_score < 0.5

    def test_metrics_to_dict(self):
        """to_dict() returns a dict that includes the headline metric keys."""
        report = "# Title\n\nSome content"
        metrics = compute_metrics(report)
        result = metrics.to_dict()

        assert isinstance(result, dict)
        assert "word_count" in result
        assert "citation_count" in result
        assert "section_coverage_score" in result
|
||||
|
||||
|
||||
class TestGetWordCountTarget:
    """Word-count target lookups per report style."""

    def test_strategic_investment_target(self):
        """The strategic_investment style targets 10000-15000 words."""
        bounds = get_word_count_target("strategic_investment")
        assert (bounds["min"], bounds["max"]) == (10000, 15000)

    def test_news_target(self):
        """The news style targets 800-2000 words."""
        bounds = get_word_count_target("news")
        assert (bounds["min"], bounds["max"]) == (800, 2000)

    def test_default_target(self):
        """An unrecognized style falls back to 1000-5000 words."""
        bounds = get_word_count_target("unknown_style")
        assert (bounds["min"], bounds["max"]) == (1000, 5000)
|
||||
@@ -1,241 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for agent locale restoration after agent execution.
|
||||
|
||||
Tests that meta fields (especially locale) are properly restored after
|
||||
agent.ainvoke() returns, since the agent creates a MessagesState
|
||||
subgraph that filters out custom fields.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.graph.nodes import preserve_state_meta_fields
|
||||
from src.graph.types import State
|
||||
|
||||
|
||||
class TestAgentLocaleRestoration:
    """Test suite for locale restoration after agent execution."""

    def test_locale_lost_in_agent_subgraph(self):
        """
        Demonstrate the problem: agent subgraph filters out locale.

        When the agent creates a subgraph with MessagesState,
        it only returns messages, not custom fields.
        """
        # Simulate agent behavior: only returns messages
        initial_state = State(messages=[], locale="zh-CN")

        # Agent subgraph returns (like MessagesState would)
        agent_result = {
            "messages": ["agent response"],
        }

        # The input carried a locale ...
        assert initial_state.get("locale") == "zh-CN"
        # ... but the agent output does not: that is the problem being shown.
        assert "locale" not in agent_result
        assert agent_result.get("locale") is None

    def test_locale_restoration_after_agent(self):
        """Test that locale can be restored after agent.ainvoke() returns."""
        initial_state = State(
            messages=[],
            locale="zh-CN",
            research_topic="test",
        )

        # Simulate agent returning (MessagesState only)
        agent_result = {
            "messages": ["agent response"],
        }

        # Apply restoration
        preserved = preserve_state_meta_fields(initial_state)
        agent_result.update(preserved)

        # Verify restoration worked
        assert agent_result["locale"] == "zh-CN"
        assert agent_result["research_topic"] == "test"
        assert "messages" in agent_result

    def test_all_meta_fields_restored(self):
        """Test that all meta fields are restored, not just locale."""
        initial_state = State(
            messages=[],
            locale="en-US",
            research_topic="Original Topic",
            clarified_research_topic="Clarified Topic",
            clarification_history=["Q1", "A1"],
            enable_clarification=True,
            max_clarification_rounds=5,
            clarification_rounds=2,
            resources=["resource1"],
        )

        # Agent result
        agent_result = {"messages": ["response"]}
        agent_result.update(preserve_state_meta_fields(initial_state))

        # All fields should be restored
        assert agent_result["locale"] == "en-US"
        assert agent_result["research_topic"] == "Original Topic"
        assert agent_result["clarified_research_topic"] == "Clarified Topic"
        assert agent_result["clarification_history"] == ["Q1", "A1"]
        assert agent_result["enable_clarification"] is True
        assert agent_result["max_clarification_rounds"] == 5
        assert agent_result["clarification_rounds"] == 2
        assert agent_result["resources"] == ["resource1"]

    def test_locale_preservation_through_agent_cycle(self):
        """Test the complete cycle: state in → agent → state out."""
        # Initial state with zh-CN locale
        initial_state = State(messages=[], locale="zh-CN")

        # Step 1: Extract meta fields
        preserved = preserve_state_meta_fields(initial_state)
        assert preserved["locale"] == "zh-CN"

        # Step 2: Agent runs and returns only messages
        agent_result = {"messages": ["agent output"]}
        assert "locale" not in agent_result  # Missing!

        # Step 3: Restore meta fields
        agent_result.update(preserved)

        # Step 4: Verify locale is restored
        assert agent_result["locale"] == "zh-CN"

        # Step 5: Create new state with restored fields
        final_state = State(messages=agent_result["messages"], **preserved)
        assert final_state.get("locale") == "zh-CN"

    def test_locale_not_auto_after_restoration(self):
        """
        Test that locale is NOT "auto" after restoration.

        This tests the specific bug: locale was becoming "auto"
        instead of the preserved "zh-CN" value.
        """
        state = State(messages=[], locale="zh-CN")

        # Agent returns without locale
        agent_result = {"messages": []}

        # Before fix: locale would be "auto" (default behavior)
        # After fix: locale is preserved
        agent_result.update(preserve_state_meta_fields(state))

        assert agent_result.get("locale") == "zh-CN"
        assert agent_result.get("locale") != "auto"
        assert agent_result.get("locale") is not None

    def test_chinese_locale_preserved(self):
        """Test that Chinese locale specifically is preserved."""
        locales_to_test = ["zh-CN", "zh", "zh-Hans", "zh-Hant"]

        for locale_value in locales_to_test:
            state = State(messages=[], locale=locale_value)
            agent_result = {"messages": []}

            agent_result.update(preserve_state_meta_fields(state))

            assert agent_result["locale"] == locale_value, f"Failed for locale: {locale_value}"

    def test_restoration_with_new_messages(self):
        """Test that restoration works even when agent adds new messages."""
        state = State(messages=[], locale="zh-CN", research_topic="research")

        # Agent processes and returns new messages
        agent_result = {
            "messages": ["message1", "message2", "message3"],
        }

        # Restore meta fields
        agent_result.update(preserve_state_meta_fields(state))

        # Should have both new messages AND preserved meta fields
        assert len(agent_result["messages"]) == 3
        assert agent_result["locale"] == "zh-CN"
        assert agent_result["research_topic"] == "research"

    def test_restoration_idempotent(self):
        """Test that restoring meta fields multiple times doesn't cause issues."""
        state = State(messages=[], locale="en-US")
        preserved = preserve_state_meta_fields(state)

        agent_result = {"messages": []}

        # Apply restoration multiple times
        agent_result.update(preserved)
        agent_result.update(preserved)
        agent_result.update(preserved)

        # Should still have correct locale (not corrupted)
        assert agent_result["locale"] == "en-US"
        assert len(agent_result) == 9  # messages + 8 meta fields
|
||||
|
||||
|
||||
class TestAgentLocaleRestorationScenarios:
    """Scenario-style checks that mirror how agent nodes restore the locale."""

    def test_researcher_agent_preserves_locale(self):
        """
        Walk through a researcher-node call that keeps the locale.

        Scenario:
        1. The node receives a state with locale="zh-CN".
        2. The agent call hands back only messages.
        3. The meta fields are merged back in before returning.
        """
        incoming = State(
            messages=[],
            locale="zh-CN",
            research_topic="生产1公斤牛肉需要多少升水?",
        )

        # What the agent hands back, followed by the restoration step.
        node_output = {"messages": ["Researcher analysis of water usage..."]}
        node_output.update(preserve_state_meta_fields(incoming))

        # The next node sees the original locale, not the "auto" default.
        assert node_output["locale"] == "zh-CN"
        assert node_output.get("locale") != "auto"

    def test_coder_agent_preserves_locale(self):
        """The coder node restores the locale the same way."""
        incoming = State(messages=[], locale="en-US")

        node_output = {"messages": ["Code generation result"]}
        node_output.update(preserve_state_meta_fields(incoming))

        assert node_output["locale"] == "en-US"

    def test_locale_persists_across_multiple_agents(self):
        """The locale survives two consecutive agent hops."""
        for locale in ("zh-CN", "en-US", "fr-FR"):
            # First hop: extract the meta fields and merge them into the output.
            carried_first = preserve_state_meta_fields(State(messages=[], locale=locale))
            first_output = {"messages": ["agent1"], **carried_first}

            # Second hop starts from a state rebuilt out of the first output.
            second_input = State(messages=first_output["messages"], **carried_first)
            carried_second = preserve_state_meta_fields(second_input)
            second_output = {
                "messages": first_output["messages"] + ["agent2"],
                **carried_second,
            }

            assert second_output["locale"] == locale
|
||||
@@ -1,134 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import src.graph.builder as builder_mod
|
||||
|
||||
|
||||
@pytest.fixture
def mock_state():
    """Provide minimal Step and Plan stand-in classes, keyed by name."""

    class Plan:
        # Only the `steps` attribute is needed by the tests.
        def __init__(self, steps):
            self.steps = steps

    class Step:
        # A step carries its execution result and its type, both optional.
        def __init__(self, execution_res=None, step_type=None):
            self.step_type = step_type
            self.execution_res = execution_res

    return dict(Step=Step, Plan=Plan)
|
||||
|
||||
|
||||
def test_continue_to_running_research_team_no_plan(mock_state):
    """Without a current plan the router returns 'planner'."""
    next_node = builder_mod.continue_to_running_research_team({"current_plan": None})
    assert next_node == "planner"
|
||||
|
||||
|
||||
def test_continue_to_running_research_team_no_steps(mock_state):
    """A plan with an empty step list routes to 'planner'."""
    empty_plan = mock_state["Plan"](steps=[])
    next_node = builder_mod.continue_to_running_research_team({"current_plan": empty_plan})
    assert next_node == "planner"
|
||||
|
||||
|
||||
def test_continue_to_running_research_team_all_executed(mock_state):
    """When every step already has a result the router returns 'planner'."""
    make_step, make_plan = mock_state["Step"], mock_state["Plan"]
    finished_plan = make_plan(steps=[make_step(execution_res=True), make_step(execution_res=True)])
    next_node = builder_mod.continue_to_running_research_team({"current_plan": finished_plan})
    assert next_node == "planner"
|
||||
|
||||
|
||||
def test_continue_to_running_research_team_next_researcher(mock_state):
    """A pending RESEARCH step routes to 'researcher'."""
    make_step, make_plan = mock_state["Step"], mock_state["Plan"]
    done = make_step(execution_res=True)
    pending = make_step(execution_res=None, step_type=builder_mod.StepType.RESEARCH)
    routed_to = builder_mod.continue_to_running_research_team(
        {"current_plan": make_plan(steps=[done, pending])}
    )
    assert routed_to == "researcher"
|
||||
|
||||
|
||||
def test_continue_to_running_research_team_next_coder(mock_state):
    """A pending PROCESSING step routes to 'coder'."""
    make_step, make_plan = mock_state["Step"], mock_state["Plan"]
    done = make_step(execution_res=True)
    pending = make_step(execution_res=None, step_type=builder_mod.StepType.PROCESSING)
    routed_to = builder_mod.continue_to_running_research_team(
        {"current_plan": make_plan(steps=[done, pending])}
    )
    assert routed_to == "coder"
|
||||
|
||||
|
||||
def test_continue_to_running_research_team_next_coder_withresult(mock_state):
    """A PROCESSING step that already has a result leaves nothing pending: 'planner'."""
    make_step, make_plan = mock_state["Step"], mock_state["Plan"]
    first = make_step(execution_res=True)
    second = make_step(execution_res=True, step_type=builder_mod.StepType.PROCESSING)
    routed_to = builder_mod.continue_to_running_research_team(
        {"current_plan": make_plan(steps=[first, second])}
    )
    assert routed_to == "planner"
|
||||
|
||||
|
||||
def test_continue_to_running_research_team_default_planner(mock_state):
    """A pending step with no step type falls through to 'planner'."""
    make_step, make_plan = mock_state["Step"], mock_state["Plan"]
    untyped_pending = make_step(execution_res=None, step_type=None)
    plan = make_plan(steps=[make_step(execution_res=True), untyped_pending])
    routed_to = builder_mod.continue_to_running_research_team({"current_plan": plan})
    assert routed_to == "planner"
|
||||
|
||||
|
||||
@patch("src.graph.builder.StateGraph")
def test_build_base_graph_adds_nodes_and_edges(MockStateGraph):
    """_build_base_graph registers nodes, edges and one conditional edge on the StateGraph."""
    # Pin an explicit mock as the constructed graph so its calls can be inspected.
    graph_under_construction = MagicMock()
    MockStateGraph.return_value = graph_under_construction

    builder_mod._build_base_graph()

    recorded = graph_under_construction
    assert recorded.add_edge.call_count >= 2
    assert recorded.add_node.call_count >= 8
    # Exactly one conditional edge is expected (research_team).
    assert recorded.add_conditional_edges.call_count == 1
|
||||
|
||||
|
||||
@patch("src.graph.builder._build_base_graph")
@patch("src.graph.builder.MemorySaver")
def test_build_graph_with_memory_uses_memory(MockMemorySaver, mock_build_base_graph):
    """build_graph_with_memory compiles the base graph with a MemorySaver checkpointer."""
    saver_instance = MagicMock()
    base_graph = MagicMock()
    MockMemorySaver.return_value = saver_instance
    mock_build_base_graph.return_value = base_graph

    builder_mod.build_graph_with_memory()

    # The saver instance must be passed through as the checkpointer.
    base_graph.compile.assert_called_once_with(checkpointer=saver_instance)
|
||||
|
||||
|
||||
@patch("src.graph.builder._build_base_graph")
def test_build_graph_without_memory(mock_build_base_graph):
    """build_graph compiles the base graph with no arguments."""
    base_graph = MagicMock()
    mock_build_base_graph.return_value = base_graph

    builder_mod.build_graph()

    # No checkpointer (or any other argument) is passed to compile().
    base_graph.compile.assert_called_once_with()
|
||||
|
||||
|
||||
def test_graph_is_compiled():
    """The module-level `graph` is populated when the builder module is (re)imported."""
    with patch("src.graph.builder._build_base_graph") as patched_base:
        stand_in_builder = MagicMock()
        patched_base.return_value = stand_in_builder
        stand_in_builder.compile.return_value = "compiled_graph"
        # Re-execute the module body so the `graph = ...` assignment runs again.
        target_module = sys.modules["src.graph.builder"]
        importlib.reload(target_module)
        assert builder_mod.graph is not None
|
||||
@@ -1,317 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for the human_feedback_node locale fix.
|
||||
|
||||
Tests that the duplicate locale assignment issue is resolved:
|
||||
- Locale is safely retrieved from new_plan using .get() with fallback
|
||||
- If new_plan['locale'] doesn't exist, it doesn't cause a KeyError
|
||||
- If new_plan['locale'] is None or empty, the preserved state locale is used
|
||||
- If new_plan['locale'] has a valid value, it properly overrides the state locale
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.graph.nodes import preserve_state_meta_fields
|
||||
from src.graph.types import State
|
||||
from src.prompts.planner_model import Plan
|
||||
|
||||
|
||||
class TestHumanFeedbackLocaleFixture:
|
||||
"""Test suite for human_feedback_node locale safe handling."""
|
||||
|
||||
def test_preserve_state_meta_fields_no_keyerror(self):
    """preserve_state_meta_fields returns the locale without raising KeyError."""
    carried = preserve_state_meta_fields(State(messages=[], locale="zh-CN"))

    assert carried["locale"] == "zh-CN"
    assert "locale" in carried
|
||||
|
||||
def test_new_plan_without_locale_override(self):
|
||||
"""
|
||||
Test scenario: Plan doesn't override locale when not provided in override dict.
|
||||
|
||||
Before fix: Would set locale twice (duplicate assignment)
|
||||
After fix: Uses .get() safely and only overrides if value is truthy
|
||||
"""
|
||||
state = State(messages=[], locale="zh-CN")
|
||||
|
||||
# Simulate a plan that doesn't want to override the locale
|
||||
# (locale is in the plan for validation, but not in override dict)
|
||||
new_plan_dict = {"title": "Test", "thought": "Test", "steps": [], "locale": "zh-CN", "has_enough_context": False}
|
||||
|
||||
# Get preserved fields
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
# Build update dict like the fixed code does
|
||||
update_dict = {
|
||||
"current_plan": Plan.model_validate(new_plan_dict),
|
||||
**preserved,
|
||||
}
|
||||
|
||||
# Simulate a dict that doesn't have locale override (like when plan_dict is empty for override)
|
||||
plan_override = {} # No locale in override dict
|
||||
|
||||
# Only override locale if override dict provides a valid value
|
||||
if plan_override.get("locale"):
|
||||
update_dict["locale"] = plan_override["locale"]
|
||||
|
||||
# The preserved locale should be used when override doesn't provide one
|
||||
assert update_dict["locale"] == "zh-CN"
|
||||
|
||||
def test_new_plan_with_none_locale(self):
|
||||
"""
|
||||
Test scenario: new_plan has locale=None.
|
||||
|
||||
Before fix: Would try to set locale to None (but Plan requires it)
|
||||
After fix: Uses preserved state locale since new_plan.get("locale") is falsy
|
||||
"""
|
||||
state = State(messages=[], locale="zh-CN")
|
||||
|
||||
# new_plan with None locale (won't validate, but test the logic)
|
||||
new_plan_attempt = {"title": "Test", "thought": "Test", "steps": [], "locale": "en-US", "has_enough_context": False}
|
||||
|
||||
# Get preserved fields
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
# Build update dict like the fixed code does
|
||||
update_dict = {
|
||||
"current_plan": Plan.model_validate(new_plan_attempt),
|
||||
**preserved,
|
||||
}
|
||||
|
||||
# Simulate checking for None locale (if it somehow got set)
|
||||
new_plan_with_none = {"locale": None}
|
||||
# Only override if new_plan provides a VALID value
|
||||
if new_plan_with_none.get("locale"):
|
||||
update_dict["locale"] = new_plan_with_none["locale"]
|
||||
|
||||
# Should use preserved locale (zh-CN), not None
|
||||
assert update_dict["locale"] == "zh-CN"
|
||||
assert update_dict["locale"] is not None
|
||||
|
||||
def test_new_plan_with_empty_string_locale(self):
|
||||
"""
|
||||
Test scenario: new_plan has locale="" (empty string).
|
||||
|
||||
Before fix: Would try to set locale to "" (but Plan requires valid value)
|
||||
After fix: Uses preserved state locale since empty string is falsy
|
||||
"""
|
||||
state = State(messages=[], locale="zh-CN")
|
||||
|
||||
# new_plan with valid locale
|
||||
new_plan = {"title": "Test", "thought": "Test", "steps": [], "locale": "en-US", "has_enough_context": False}
|
||||
|
||||
# Get preserved fields
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
# Build update dict like the fixed code does
|
||||
update_dict = {
|
||||
"current_plan": Plan.model_validate(new_plan),
|
||||
**preserved,
|
||||
}
|
||||
|
||||
# Simulate checking for empty string locale
|
||||
new_plan_empty = {"locale": ""}
|
||||
# Only override if new_plan provides a VALID (truthy) value
|
||||
if new_plan_empty.get("locale"):
|
||||
update_dict["locale"] = new_plan_empty["locale"]
|
||||
|
||||
# Should use preserved locale (zh-CN), not empty string
|
||||
assert update_dict["locale"] == "zh-CN"
|
||||
assert update_dict["locale"] != ""
|
||||
|
||||
def test_new_plan_with_valid_locale_overrides(self):
|
||||
"""
|
||||
Test scenario: new_plan has valid locale="en-US".
|
||||
|
||||
Before fix: Would override with new_plan locale ✓ (worked)
|
||||
After fix: Should still properly override with valid locale
|
||||
"""
|
||||
state = State(messages=[], locale="zh-CN")
|
||||
|
||||
# new_plan has a different valid locale
|
||||
new_plan = {"title": "Test", "thought": "Test", "steps": [], "locale": "en-US", "has_enough_context": False}
|
||||
|
||||
# Get preserved fields
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
# Build update dict like the fixed code does
|
||||
update_dict = {
|
||||
"current_plan": Plan.model_validate(new_plan),
|
||||
**preserved,
|
||||
}
|
||||
|
||||
# Override if new_plan provides a VALID value
|
||||
if new_plan.get("locale"):
|
||||
update_dict["locale"] = new_plan["locale"]
|
||||
|
||||
# Should override with new_plan locale
|
||||
assert update_dict["locale"] == "en-US"
|
||||
assert update_dict["locale"] != "zh-CN"
|
||||
|
||||
def test_locale_field_not_duplicated(self):
|
||||
"""
|
||||
Test that locale field is not duplicated in the update dict.
|
||||
|
||||
Before fix: locale was set twice in the same dict
|
||||
After fix: locale is only set once
|
||||
"""
|
||||
state = State(messages=[], locale="zh-CN")
|
||||
new_plan = {"title": "Test", "thought": "Test", "steps": [], "locale": "en-US", "has_enough_context": False}
|
||||
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
# Count how many times 'locale' is set
|
||||
update_dict = {
|
||||
"current_plan": Plan.model_validate(new_plan),
|
||||
**preserved, # Sets locale once
|
||||
}
|
||||
|
||||
# Override locale only if new_plan provides valid value
|
||||
if new_plan.get("locale"):
|
||||
update_dict["locale"] = new_plan["locale"]
|
||||
|
||||
# Verify locale is in dict exactly once
|
||||
locale_count = sum(1 for k in update_dict.keys() if k == "locale")
|
||||
assert locale_count == 1
|
||||
assert update_dict["locale"] == "en-US" # Should be overridden
|
||||
|
||||
def test_all_meta_fields_preserved(self):
|
||||
"""
|
||||
Test that all 8 meta fields are preserved along with locale fix.
|
||||
|
||||
Ensures the fix doesn't break other meta field preservation.
|
||||
"""
|
||||
state = State(
|
||||
messages=[],
|
||||
locale="zh-CN",
|
||||
research_topic="Research",
|
||||
clarified_research_topic="Clarified",
|
||||
clarification_history=["Q1"],
|
||||
enable_clarification=True,
|
||||
max_clarification_rounds=5,
|
||||
clarification_rounds=1,
|
||||
resources=["resource1"],
|
||||
)
|
||||
|
||||
new_plan = {"title": "Test", "thought": "Test", "steps": [], "locale": "en-US", "has_enough_context": False}
|
||||
preserved = preserve_state_meta_fields(state)
|
||||
|
||||
# All 8 meta fields should be in preserved
|
||||
meta_fields = [
|
||||
"locale",
|
||||
"research_topic",
|
||||
"clarified_research_topic",
|
||||
"clarification_history",
|
||||
"enable_clarification",
|
||||
"max_clarification_rounds",
|
||||
"clarification_rounds",
|
||||
"resources",
|
||||
]
|
||||
|
||||
for field in meta_fields:
|
||||
assert field in preserved
|
||||
|
||||
# Build update dict
|
||||
update_dict = {
|
||||
"current_plan": Plan.model_validate(new_plan),
|
||||
**preserved,
|
||||
}
|
||||
|
||||
# Override locale if new_plan provides valid value
|
||||
if new_plan.get("locale"):
|
||||
update_dict["locale"] = new_plan["locale"]
|
||||
|
||||
# All meta fields should be in update_dict
|
||||
for field in meta_fields:
|
||||
assert field in update_dict
|
||||
|
||||
|
||||
class TestHumanFeedbackLocaleScenarios:
    """Real-world scenarios for human_feedback_node locale handling."""

    def test_scenario_chinese_locale_preserved_when_plan_has_no_locale(self):
        """A zh-CN session whose plan also declares zh-CN ends up with zh-CN.

        Note: despite the test name, the plan payload here does carry a locale.
        """
        session = State(messages=[], locale="zh-CN")
        research_step = {
            "title": "Step 1",
            "description": "...",
            "need_search": True,
            "step_type": "research",
        }
        planner_output = {
            "title": "Research Plan",
            "thought": "...",
            "steps": [research_step],
            "locale": "zh-CN",
            "has_enough_context": False,
        }

        kept = preserve_state_meta_fields(session)
        update = {"current_plan": Plan.model_validate(planner_output)}
        update.update(kept)

        plan_locale = planner_output.get("locale")
        if plan_locale:
            update["locale"] = plan_locale

        assert update["locale"] == "zh-CN"

    def test_scenario_en_us_restored_even_if_plan_minimal(self):
        """A plan without a ``thought`` entry still yields the en-US locale."""
        session = State(messages=[], locale="en-US")
        planner_output = {
            "title": "Quick Plan",
            "steps": [],
            "locale": "en-US",
            "has_enough_context": False,
        }

        kept = preserve_state_meta_fields(session)
        update = {"current_plan": Plan.model_validate(planner_output)}
        update.update(kept)

        plan_locale = planner_output.get("locale")
        if plan_locale:
            update["locale"] = plan_locale

        assert update["locale"] == "en-US"

    def test_scenario_multiple_locale_updates_safe(self):
        """Repeating the update for several locales keeps each one intact."""
        for expected in ("zh-CN", "en-US", "fr-FR"):
            session = State(messages=[], locale=expected)
            planner_output = {
                "title": "Plan",
                "steps": [],
                "locale": expected,
                "has_enough_context": False,
            }

            kept = preserve_state_meta_fields(session)
            update = {"current_plan": Plan.model_validate(planner_output)}
            update.update(kept)

            plan_locale = planner_output.get("locale")
            if plan_locale:
                update["locale"] = plan_locale

            # State and plan agree on the locale in every iteration.
            assert update["locale"] == expected
|
||||
@@ -1,623 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for recursion limit fallback functionality in graph nodes.
|
||||
|
||||
Tests the graceful fallback behavior when agents hit the recursion limit,
|
||||
including the _handle_recursion_limit_fallback function and the
|
||||
enable_recursion_fallback configuration option.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from src.config.configuration import Configuration
|
||||
from src.graph.nodes import _handle_recursion_limit_fallback
|
||||
from src.graph.types import State
|
||||
|
||||
|
||||
class TestHandleRecursionLimitFallback:
    """Test suite for _handle_recursion_limit_fallback() function."""

    @staticmethod
    async def _run_fallback(
        messages,
        state,
        *,
        agent_name="researcher",
        current_step=None,
        reply_text="Summary",
        sanitized=None,
        template="",
    ):
        """Call the fallback with its LLM, prompt and sanitizer lookups patched.

        Args:
            messages: Partial agent messages handed to the fallback.
            state: Graph state handed to the fallback.
            agent_name: Agent name handed to the fallback.
            current_step: Step object handed to the fallback; a fresh
                ``MagicMock`` when omitted.
            reply_text: ``content`` of the mocked LLM reply.
            sanitized: Return value of the patched ``sanitize_tool_response``;
                defaults to ``reply_text``.
            template: Return value of the patched ``get_system_prompt_template``.

        Returns:
            Tuple ``(result, get_llm_mock, get_prompt_mock)``.
        """
        if current_step is None:
            current_step = MagicMock()
        reply = MagicMock()
        reply.content = reply_text
        cleaned = reply_text if sanitized is None else sanitized

        with patch("src.graph.nodes.get_llm_by_type") as get_llm, \
                patch("src.graph.nodes.get_system_prompt_template", return_value=template) as get_prompt, \
                patch("src.graph.nodes.sanitize_tool_response", return_value=cleaned):
            llm = MagicMock()
            llm.invoke = MagicMock(return_value=reply)
            get_llm.return_value = llm

            result = await _handle_recursion_limit_fallback(
                messages=messages,
                agent_name=agent_name,
                current_step=current_step,
                state=state,
            )
        # The mocks keep their call records after the patches are undone.
        return result, get_llm, get_prompt

    @pytest.mark.asyncio
    async def test_fallback_generates_summary_from_observations(self):
        """Test that fallback generates summary using accumulated agent messages."""
        from langchain_core.messages import ToolCall

        state = State(
            messages=[HumanMessage(content="Research topic: AI safety")],
            locale="en-US",
        )
        current_step = MagicMock()
        current_step.execution_res = None

        # Messages accumulated by the agent before the limit was hit.
        tool_call = ToolCall(name="web_search", args={"query": "AI safety"}, id="123")
        partial_agent_messages = [
            HumanMessage(content="# Research Topic\n\nAI safety\n\n# Current Step\n\n## Title\n\nAnalyze AI safety"),
            AIMessage(content="", tool_calls=[tool_call]),
            HumanMessage(content="Tool result: Found 5 articles about AI safety"),
        ]
        summary = "# Summary\n\nBased on the research, AI safety is important."

        result, _, _ = await self._run_fallback(
            partial_agent_messages,
            state,
            current_step=current_step,
            reply_text=summary,
            template="Fallback instructions",
        )

        assert isinstance(result, list)
        # The summary is also written onto the step.
        assert current_step.execution_res == summary

        # Partial messages come back unchanged, followed by one new AI response.
        assert len(result) == len(partial_agent_messages) + 1
        assert isinstance(result[-1], AIMessage)
        assert result[-1].content == summary
        assert result[-1].name == "researcher"
        assert result[0] == partial_agent_messages[0]
        assert result[1] == partial_agent_messages[1]
        assert result[2] == partial_agent_messages[2]

    @pytest.mark.asyncio
    async def test_fallback_applies_prompt_template(self):
        """Test that fallback applies the recursion_fallback prompt template."""
        state = State(messages=[], locale="zh-CN")
        # Non-empty messages, so the early return is not taken.
        partial_agent_messages = [HumanMessage(content="Test")]

        _, _, get_prompt = await self._run_fallback(
            partial_agent_messages,
            state,
            reply_text="Summary in Chinese",
            template="Template rendered",
        )

        # One lookup for the agent prompt, one for the fallback prompt.
        assert get_prompt.call_count == 2

        agent_call = get_prompt.call_args_list[0]
        assert agent_call[0][0] == "researcher"
        assert agent_call[0][1]["locale"] == "zh-CN"

        fallback_call = get_prompt.call_args_list[1]
        assert fallback_call[0][0] == "recursion_fallback"
        assert fallback_call[0][1]["locale"] == "zh-CN"

    @pytest.mark.asyncio
    async def test_fallback_gets_llm_without_tools(self):
        """Empty history for the coder agent returns ``[]`` without an LLM lookup.

        Despite its name, this test does not inspect tool binding; it only
        exercises the empty-messages early return.
        """
        state = State(messages=[], locale="en-US")

        result, get_llm, _ = await self._run_fallback(
            [],
            state,
            agent_name="coder",
            template="Template",
        )

        assert result == []
        get_llm.assert_not_called()

    @pytest.mark.asyncio
    async def test_fallback_sanitizes_response(self):
        """Test that fallback response is sanitized."""
        state = State(messages=[], locale="en-US")
        current_step = MagicMock()
        partial_agent_messages = [HumanMessage(content="Test")]
        sanitized_content = "Summary content"

        result, _, _ = await self._run_fallback(
            partial_agent_messages,
            state,
            current_step=current_step,
            reply_text="<extra_tokens>Summary content</extra_tokens>",
            sanitized=sanitized_content,
        )

        # Both the returned message and the step carry the sanitizer's output.
        assert result[-1].content == sanitized_content
        assert current_step.execution_res == sanitized_content

    @pytest.mark.asyncio
    async def test_fallback_preserves_meta_fields(self):
        """Test that fallback uses state locale correctly."""
        state = State(
            messages=[],
            locale="zh-CN",
            research_topic="原始研究主题",
            clarification_rounds=2,
        )
        partial_agent_messages = [HumanMessage(content="Test")]

        _, _, get_prompt = await self._run_fallback(
            partial_agent_messages,
            state,
            template="Template",
        )

        # The most recent template lookup received the state's locale.
        assert get_prompt.call_args[0][1]["locale"] == "zh-CN"

    @pytest.mark.asyncio
    async def test_fallback_raises_on_llm_failure(self):
        """Test that fallback raises exception when LLM call fails."""
        state = State(messages=[], locale="en-US")
        current_step = MagicMock()
        partial_agent_messages = [HumanMessage(content="Test")]

        failing_llm = MagicMock()
        failing_llm.invoke = MagicMock(side_effect=Exception("LLM API error"))

        # The sanitizer is left unpatched here, as in the original test.
        with patch("src.graph.nodes.get_llm_by_type", return_value=failing_llm), \
                patch("src.graph.nodes.get_system_prompt_template", return_value=""):
            with pytest.raises(Exception, match="LLM API error"):
                await _handle_recursion_limit_fallback(
                    messages=partial_agent_messages,
                    agent_name="researcher",
                    current_step=current_step,
                    state=state,
                )

    @pytest.mark.asyncio
    async def test_fallback_handles_different_agent_types(self):
        """Test that fallback works with different agent types."""
        state = State(messages=[], locale="en-US")
        partial_agent_messages = [HumanMessage(content="Test")]

        for agent_name in ["researcher", "coder", "analyst"]:
            result, _, _ = await self._run_fallback(
                partial_agent_messages,
                state,
                agent_name=agent_name,
                reply_text="Agent summary",
            )

            # The appended AI message is named after the agent.
            assert result[-1].name == agent_name

    @pytest.mark.asyncio
    async def test_fallback_uses_partial_agent_messages(self):
        """Test that fallback includes partial agent messages in result."""
        from langchain_core.messages import ToolCall

        state = State(messages=[], locale="en-US")
        tool_call = ToolCall(name="web_search", args={"query": "test query"}, id="123")
        partial_agent_messages = [
            HumanMessage(content="Input message"),
            AIMessage(content="", tool_calls=[tool_call]),
            HumanMessage(content="Tool result: Search completed"),
        ]

        result_messages, _, _ = await self._run_fallback(
            partial_agent_messages,
            state,
            reply_text="Fallback summary",
        )

        assert len(result_messages) == len(partial_agent_messages) + 1
        assert result_messages[0] == partial_agent_messages[0]
        assert result_messages[1] == partial_agent_messages[1]
        assert result_messages[2] == partial_agent_messages[2]
        # The final entry is the fallback AI response.
        assert isinstance(result_messages[3], AIMessage)
        assert result_messages[3].content == "Fallback summary"

    @pytest.mark.asyncio
    async def test_fallback_handles_empty_partial_messages(self):
        """Test that fallback handles empty partial_agent_messages."""
        state = State(messages=[], locale="en-US")

        result, get_llm, _ = await self._run_fallback(
            [],
            state,
            reply_text="Fallback summary",
        )

        # Early return: nothing produced and no LLM requested.
        assert result == []
        get_llm.assert_not_called()
|
||||
|
||||
|
||||
class TestRecursionFallbackConfiguration:
    """Test suite for enable_recursion_fallback configuration."""

    def test_config_default_is_enabled(self):
        """A default-constructed Configuration has the fallback enabled."""
        assert Configuration().enable_recursion_fallback is True

    def test_config_from_env_variable_true(self):
        """With ENABLE_RECURSION_FALLBACK="true" the flag is True."""
        env = {"ENABLE_RECURSION_FALLBACK": "true"}
        with patch.dict("os.environ", env):
            assert Configuration().enable_recursion_fallback is True

    def test_config_from_env_variable_false(self):
        """With ENABLE_RECURSION_FALLBACK="false" the flag is still True.

        NOTE: This pins the current behavior. The original author attributes
        it to a known issue where Configuration.from_runnable_config does not
        convert boolean strings such as "false" to ``False``, and asks that
        the test be updated once boolean fields are read via get_bool_env.
        """
        env = {"ENABLE_RECURSION_FALLBACK": "false"}
        with patch.dict("os.environ", env):
            # Expected to become False once the Configuration issue is fixed.
            assert Configuration().enable_recursion_fallback is True

    def test_config_from_env_variable_1(self):
        """With ENABLE_RECURSION_FALLBACK="1" the flag is True."""
        env = {"ENABLE_RECURSION_FALLBACK": "1"}
        with patch.dict("os.environ", env):
            assert Configuration().enable_recursion_fallback is True

    def test_config_from_env_variable_0(self):
        """With ENABLE_RECURSION_FALLBACK="0" the flag is still True.

        NOTE: This pins the current behavior; the original author describes it
        as a known issue where the string "0" is not converted to ``False``.
        """
        env = {"ENABLE_RECURSION_FALLBACK": "0"}
        with patch.dict("os.environ", env):
            assert Configuration().enable_recursion_fallback is True

    def test_config_from_env_variable_yes(self):
        """With ENABLE_RECURSION_FALLBACK="yes" the flag is True."""
        env = {"ENABLE_RECURSION_FALLBACK": "yes"}
        with patch.dict("os.environ", env):
            assert Configuration().enable_recursion_fallback is True

    def test_config_from_env_variable_no(self):
        """With ENABLE_RECURSION_FALLBACK="no" the flag is still True.

        NOTE: This pins the current behavior; the original author describes it
        as a known issue where the string "no" is not converted to ``False``.
        """
        env = {"ENABLE_RECURSION_FALLBACK": "no"}
        with patch.dict("os.environ", env):
            assert Configuration().enable_recursion_fallback is True

    def test_config_from_runnable_config(self):
        """A boolean passed through RunnableConfig's ``configurable`` is honoured."""
        from langchain_core.runnables import RunnableConfig

        for flag in (False, True):
            runnable = RunnableConfig(configurable={"enable_recursion_fallback": flag})
            configuration = Configuration.from_runnable_config(runnable)
            assert configuration.enable_recursion_fallback is flag

    def test_config_field_exists(self):
        """Configuration exposes ``enable_recursion_fallback`` as a bool."""
        config = Configuration()

        assert hasattr(config, "enable_recursion_fallback")
        assert isinstance(config.enable_recursion_fallback, bool)
|
||||
|
||||
|
||||
class TestRecursionFallbackIntegration:
    """Integration tests for recursion fallback in agent execution."""

    @pytest.mark.asyncio
    async def test_fallback_function_signature_returns_list(self):
        """The fallback returns a list when given a non-empty message history."""
        from src.graph.nodes import _handle_recursion_limit_fallback

        llm_reply = MagicMock()
        llm_reply.content = "Summary"
        fake_llm = MagicMock()
        fake_llm.invoke = MagicMock(return_value=llm_reply)

        with patch("src.graph.nodes.get_llm_by_type") as get_llm, \
                patch("src.graph.nodes.get_system_prompt_template", return_value=""), \
                patch("src.graph.nodes.sanitize_tool_response", return_value=llm_reply.content):
            get_llm.return_value = fake_llm

            # Non-empty messages, so the early return is not taken.
            outcome = await _handle_recursion_limit_fallback(
                messages=[HumanMessage(content="Test")],
                agent_name="researcher",
                current_step=MagicMock(),
                state=State(messages=[], locale="en-US"),
            )

        assert isinstance(outcome, list)

    @pytest.mark.asyncio
    async def test_configuration_enables_disables_fallback(self):
        """The constructor keyword sets ``enable_recursion_fallback`` either way."""
        for flag in (True, False):
            configured = Configuration(enable_recursion_fallback=flag)
            assert configured.enable_recursion_fallback is flag
|
||||
|
||||
|
||||
class TestRecursionFallbackEdgeCases:
|
||||
"""Test edge cases and boundary conditions for recursion fallback."""
|
||||
|
||||
@pytest.mark.asyncio
async def test_fallback_with_empty_observations(self):
    """An empty message history yields an empty result list."""
    llm_reply = MagicMock()
    llm_reply.content = "No observations available"
    fake_llm = MagicMock()
    fake_llm.invoke = MagicMock(return_value=llm_reply)

    with patch("src.graph.nodes.get_llm_by_type") as get_llm, \
            patch("src.graph.nodes.get_system_prompt_template", return_value=""), \
            patch("src.graph.nodes.sanitize_tool_response", return_value=llm_reply.content):
        get_llm.return_value = fake_llm

        outcome = await _handle_recursion_limit_fallback(
            messages=[],
            agent_name="researcher",
            current_step=MagicMock(),
            state=State(messages=[], locale="en-US"),
        )

    # No messages in, no messages out.
    assert outcome == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_with_very_long_recursion_limit(self):
|
||||
"""Test fallback with very large recursion limit value."""
|
||||
state = State(messages=[], locale="en-US")
|
||||
current_step = MagicMock()
|
||||
partial_agent_messages = []
|
||||
|
||||
mock_llm_response = MagicMock()
|
||||
mock_llm_response.content = "Summary"
|
||||
|
||||
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
|
||||
patch("src.graph.nodes.get_system_prompt_template", return_value=""), \
|
||||
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
|
||||
mock_get_llm.return_value = mock_llm
|
||||
|
||||
result = await _handle_recursion_limit_fallback(
|
||||
messages=partial_agent_messages,
|
||||
agent_name="researcher",
|
||||
current_step=current_step,
|
||||
state=state,
|
||||
)
|
||||
|
||||
# With empty messages, should return empty list
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_with_unicode_locale(self):
|
||||
"""Test fallback with various locale formats including unicode."""
|
||||
for locale in ["zh-CN", "ja-JP", "ko-KR", "en-US", "pt-BR"]:
|
||||
state = State(messages=[], locale=locale)
|
||||
current_step = MagicMock()
|
||||
# Create non-empty messages to avoid early return
|
||||
partial_agent_messages = [HumanMessage(content="Test")]
|
||||
|
||||
mock_llm_response = MagicMock()
|
||||
mock_llm_response.content = f"Summary for {locale}"
|
||||
|
||||
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
|
||||
patch("src.graph.nodes.get_system_prompt_template") as mock_get_system_prompt, \
|
||||
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_get_system_prompt.return_value = "Template"
|
||||
|
||||
await _handle_recursion_limit_fallback(
|
||||
messages=partial_agent_messages,
|
||||
agent_name="researcher",
|
||||
current_step=current_step,
|
||||
state=state,
|
||||
)
|
||||
|
||||
# Verify locale was passed to template
|
||||
call_args = mock_get_system_prompt.call_args
|
||||
assert call_args[0][1]["locale"] == locale
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_with_none_locale(self):
|
||||
"""Test fallback handles None locale gracefully."""
|
||||
state = State(messages=[], locale=None)
|
||||
current_step = MagicMock()
|
||||
# Create non-empty messages to avoid early return
|
||||
partial_agent_messages = [HumanMessage(content="Test")]
|
||||
|
||||
mock_llm_response = MagicMock()
|
||||
mock_llm_response.content = "Summary"
|
||||
|
||||
with patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, \
|
||||
patch("src.graph.nodes.get_system_prompt_template") as mock_get_system_prompt, \
|
||||
patch("src.graph.nodes.sanitize_tool_response", return_value=mock_llm_response.content):
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.invoke = MagicMock(return_value=mock_llm_response)
|
||||
mock_get_llm.return_value = mock_llm
|
||||
mock_get_system_prompt.return_value = "Template"
|
||||
|
||||
# Should not raise, should use default locale
|
||||
await _handle_recursion_limit_fallback(
|
||||
messages=partial_agent_messages,
|
||||
agent_name="researcher",
|
||||
current_step=current_step,
|
||||
state=state,
|
||||
)
|
||||
|
||||
# Verify default locale "en-US" was used
|
||||
call_args = mock_get_system_prompt.call_args
|
||||
assert call_args[0][1]["locale"] is None or call_args[0][1]["locale"] == "en-US"
|
||||
@@ -1,491 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.graph.nodes import validate_and_fix_plan
|
||||
|
||||
|
||||
class TestValidateAndFixPlanStepTypeRepair:
    """Checks how ``validate_and_fix_plan`` repairs the step_type field (Issue #650 fix)."""

    def test_repair_missing_step_type_with_need_search_true(self):
        """A step without step_type is repaired to 'research' when need_search is True."""
        search_step = {
            "need_search": True,
            "title": "Research Step",
            "description": "Gather data",
            # no step_type key at all
        }

        fixed = validate_and_fix_plan({"steps": [search_step]})

        assert fixed["steps"][0]["step_type"] == "research"

    def test_repair_missing_step_type_with_need_search_false(self):
        """A step without step_type is repaired to 'analysis' when need_search is False (Issue #677)."""
        non_search_step = {
            "need_search": False,
            "title": "Processing Step",
            "description": "Analyze data",
            # no step_type key at all
        }

        fixed = validate_and_fix_plan({"steps": [non_search_step]})

        # Issue #677: the fallback for non-search steps is 'analysis', not 'processing'.
        assert fixed["steps"][0]["step_type"] == "analysis"

    def test_repair_missing_step_type_default_to_analysis(self):
        """A step with neither need_search nor step_type is repaired to 'analysis' (Issue #677)."""
        bare_step = {
            "title": "Unknown Step",
            "description": "Do something",
            # neither need_search nor step_type is present
        }

        fixed = validate_and_fix_plan({"steps": [bare_step]})

        # Issue #677: the fallback for non-search steps is 'analysis', not 'processing'.
        assert fixed["steps"][0]["step_type"] == "analysis"

    def test_repair_empty_step_type_field(self):
        """An empty-string step_type is treated like a missing one and repaired."""
        step_with_blank_type = {
            "need_search": True,
            "title": "Research Step",
            "description": "Gather data",
            "step_type": "",  # blank value
        }

        fixed = validate_and_fix_plan({"steps": [step_with_blank_type]})

        assert fixed["steps"][0]["step_type"] == "research"

    def test_repair_null_step_type_field(self):
        """A None step_type is treated like a missing one and repaired."""
        step_with_null_type = {
            "need_search": False,
            "title": "Processing Step",
            "description": "Analyze data",
            "step_type": None,
        }

        fixed = validate_and_fix_plan({"steps": [step_with_null_type]})

        # Issue #677: the fallback for non-search steps is 'analysis', not 'processing'.
        assert fixed["steps"][0]["step_type"] == "analysis"

    def test_multiple_steps_with_mixed_missing_step_types(self):
        """Only the steps lacking step_type are repaired; an existing value is kept."""
        first_research = {
            "need_search": True,
            "title": "Research 1",
            "description": "Gather",
            # step_type absent
        }
        typed_processing = {
            "need_search": False,
            "title": "Processing 1",
            "description": "Analyze",
            "step_type": "processing",  # already typed
        }
        second_research = {
            "need_search": True,
            "title": "Research 2",
            "description": "More gathering",
            # step_type absent
        }

        fixed = validate_and_fix_plan(
            {"steps": [first_research, typed_processing, second_research]}
        )

        steps = fixed["steps"]
        assert steps[0]["step_type"] == "research"
        assert steps[1]["step_type"] == "processing"  # untouched
        assert steps[2]["step_type"] == "research"

    def test_preserve_explicit_step_type(self):
        """Steps that already carry a step_type keep it."""
        typed_research = {
            "need_search": True,
            "title": "Research Step",
            "description": "Gather",
            "step_type": "research",
        }
        typed_processing = {
            "need_search": False,
            "title": "Processing Step",
            "description": "Analyze",
            "step_type": "processing",
        }

        fixed = validate_and_fix_plan({"steps": [typed_research, typed_processing]})

        # Neither value may be rewritten.
        steps = fixed["steps"]
        assert steps[0]["step_type"] == "research"
        assert steps[1]["step_type"] == "processing"

    def test_repair_logs_warning(self):
        """A repair is reported through the module logger's info channel."""
        untyped_step = {
            "need_search": True,
            "title": "Missing Type Step",
            "description": "Gather",
        }
        plan = {"steps": [untyped_step]}

        with patch("src.graph.nodes.logger") as mock_logger:
            validate_and_fix_plan(plan)

        # At least one info call must have happened ...
        mock_logger.info.assert_called()
        # ... and one of them must mention the repair.
        logged_calls = [str(entry) for entry in mock_logger.info.call_args_list]
        assert any("Repaired missing step_type" in text for text in logged_calls)

    def test_non_dict_plan_returns_unchanged(self):
        """A plan that is not a dict comes back equal to the input."""
        not_a_plan = "not a dict"

        returned = validate_and_fix_plan(not_a_plan)

        assert returned == not_a_plan

    def test_plan_with_non_dict_step_skipped(self):
        """Non-dict entries in steps are left alone while dict entries are repaired."""
        valid_step = {
            "need_search": True,
            "title": "Valid Step",
            "description": "Gather",
        }
        plan = {
            "steps": [
                "not a dict step",  # must be passed over
                valid_step,
            ]
        }

        fixed = validate_and_fix_plan(plan)

        steps = fixed["steps"]
        assert steps[0] == "not a dict step"
        assert steps[1]["step_type"] == "research"

    def test_empty_steps_list(self):
        """An empty steps list stays empty."""
        fixed = validate_and_fix_plan({"steps": []})

        assert fixed["steps"] == []

    def test_missing_steps_key(self):
        """A plan without a steps key does not gain one."""
        plan_without_steps = {"locale": "en-US", "title": "Test"}

        fixed = validate_and_fix_plan(plan_without_steps)

        assert "steps" not in fixed
|
||||
|
||||
|
||||
class TestValidateAndFixPlanWebSearchEnforcement:
    """Checks the ``enforce_web_search=True`` behaviour of ``validate_and_fix_plan``."""

    def test_enforce_web_search_sets_first_research_step(self):
        """With enforcement on, the first research step gets need_search switched to True."""
        research_step = {
            "need_search": False,
            "title": "Research Step",
            "description": "Gather",
            "step_type": "research",
        }
        processing_step = {
            "need_search": False,
            "title": "Processing Step",
            "description": "Analyze",
            "step_type": "processing",
        }

        fixed = validate_and_fix_plan(
            {"steps": [research_step, processing_step]},
            enforce_web_search=True,
        )

        # Only the research step is switched to searching.
        steps = fixed["steps"]
        assert steps[0]["need_search"] is True
        assert steps[1]["need_search"] is False

    def test_enforce_web_search_converts_first_step(self):
        """With enforcement on and no research step, the first step becomes a searching research step."""
        only_step = {
            "need_search": False,
            "title": "First Step",
            "description": "Do something",
            "step_type": "processing",
        }

        fixed = validate_and_fix_plan({"steps": [only_step]}, enforce_web_search=True)

        converted = fixed["steps"][0]
        assert converted["step_type"] == "research"
        assert converted["need_search"] is True

    def test_enforce_web_search_with_existing_search_step(self):
        """With enforcement on, a plan that already has a searching step keeps its need_search flags."""
        searching_step = {
            "need_search": True,
            "title": "Research Step",
            "description": "Gather",
            "step_type": "research",
        }
        processing_step = {
            "need_search": False,
            "title": "Processing Step",
            "description": "Analyze",
            "step_type": "processing",
        }

        fixed = validate_and_fix_plan(
            {"steps": [searching_step, processing_step]},
            enforce_web_search=True,
        )

        # Flags are the same as in the input.
        steps = fixed["steps"]
        assert steps[0]["need_search"] is True
        assert steps[1]["need_search"] is False

    def test_enforce_web_search_adds_default_step(self):
        """With enforcement on, an empty steps list receives one default research step."""
        fixed = validate_and_fix_plan({"steps": []}, enforce_web_search=True)

        steps = fixed["steps"]
        assert len(steps) == 1
        assert steps[0]["step_type"] == "research"
        assert steps[0]["need_search"] is True
        assert "title" in steps[0]
        assert "description" in steps[0]

    def test_enforce_web_search_without_steps_key(self):
        """With enforcement on, a plan lacking the steps key ends up with a research step first."""
        plan_without_steps = {"locale": "en-US"}

        fixed = validate_and_fix_plan(plan_without_steps, enforce_web_search=True)

        assert len(fixed.get("steps", [])) > 0
        assert fixed["steps"][0]["step_type"] == "research"
|
||||
|
||||
|
||||
class TestValidateAndFixPlanIntegration:
    """Checks step_type repair and web search enforcement when both apply to one plan."""

    def test_repair_and_enforce_together(self):
        """Both untyped steps get repaired while the already-searching first step keeps its flag."""
        searching_untyped = {
            "need_search": True,
            "title": "Research Step",
            "description": "Gather",
            # step_type absent
        }
        non_searching_untyped = {
            "need_search": False,
            "title": "Processing Step",
            "description": "Analyze",
            # step_type absent; enforcement is not expected to touch this step
        }

        fixed = validate_and_fix_plan(
            {"steps": [searching_untyped, non_searching_untyped]},
            enforce_web_search=True,
        )

        steps = fixed["steps"]
        # Repair outcome for each step.
        assert steps[0]["step_type"] == "research"
        # Issue #677: the fallback for non-search steps is 'analysis', not 'processing'.
        assert steps[1]["step_type"] == "analysis"

        # The first step was already searching.
        assert steps[0]["need_search"] is True

    def test_repair_then_enforce_cascade(self):
        """Two untyped non-search steps: the first is converted by enforcement, the second is only repaired."""
        step_one = {
            "need_search": False,
            "title": "Step 1",
            "description": "Do something",
            # step_type absent
        }
        step_two = {
            "need_search": False,
            "title": "Step 2",
            "description": "Do something else",
            # step_type absent
        }

        fixed = validate_and_fix_plan(
            {"steps": [step_one, step_two]},
            enforce_web_search=True,
        )

        steps = fixed["steps"]
        # Step 1 ends up as a searching research step.
        assert steps[0]["step_type"] == "research"
        assert steps[0]["need_search"] is True

        # Step 2 keeps the repaired default and stays non-searching.
        # Issue #677: the fallback for non-search steps is 'analysis', not 'processing'.
        assert steps[1]["step_type"] == "analysis"
        assert steps[1]["need_search"] is False
|
||||
|
||||
class TestValidateAndFixPlanIssue650:
    """Specific tests for Issue #650 scenarios."""

    def test_issue_650_water_footprint_scenario_fixed(self):
        """Test the exact scenario from issue #650 - water footprint query with missing step_type."""
        # This is a simplified version of the actual error from issue #650
        plan = {
            "locale": "en-US",
            "has_enough_context": False,
            "title": "Research Plan — Water Footprint of 1 kg of Beef",
            "thought": "You asked: 'How many liters of water are required to produce 1 kg of beef?'",
            "steps": [
                {
                    "need_search": True,
                    "title": "Authoritative estimates",
                    "description": "Collect peer-reviewed estimates",
                    # MISSING step_type - this caused the error in issue #650
                },
                {
                    "need_search": True,
                    "title": "System-specific data",
                    "description": "Gather system-level data",
                    # MISSING step_type
                },
                {
                    "need_search": False,
                    "title": "Processing and analysis",
                    "description": "Compute scenario-based estimates",
                    # MISSING step_type
                },
            ],
        }

        result = validate_and_fix_plan(plan)

        # All steps should now have step_type
        assert result["steps"][0]["step_type"] == "research"
        assert result["steps"][1]["step_type"] == "research"
        # Issue #677: non-search steps now default to 'analysis' instead of 'processing'
        assert result["steps"][2]["step_type"] == "analysis"

    def test_issue_650_scenario_passes_pydantic_validation(self):
        """Test that fixed plan can be validated by Pydantic schema."""
        from src.prompts.planner_model import Plan as PlanModel

        plan = {
            "locale": "en-US",
            "has_enough_context": False,
            "title": "Research Plan",
            "thought": "Test thought",
            "steps": [
                {
                    "need_search": True,
                    "title": "Research",
                    "description": "Gather data",
                    # MISSING step_type
                },
            ],
        }

        # First validate and fix
        fixed_plan = validate_and_fix_plan(plan)

        # Then try Pydantic validation (should not raise)
        validated = PlanModel.model_validate(fixed_plan)

        assert validated.steps[0].step_type == "research"
        assert validated.steps[0].need_search is True

    def test_issue_650_multiple_validation_errors_fixed(self):
        """Test that plan with multiple missing step_types (like in issue #650) all get fixed."""
        plan = {
            "locale": "en-US",
            "has_enough_context": False,
            "title": "Complex Plan",
            "thought": "Research plan",
            "steps": [
                {
                    "need_search": True,
                    "title": "Step 0",
                    "description": "Data gathering",
                },
                {
                    "need_search": True,
                    "title": "Step 1",
                    "description": "More gathering",
                },
                {
                    "need_search": False,
                    "title": "Step 2",
                    "description": "Processing",
                },
            ],
        }

        result = validate_and_fix_plan(plan)

        # All steps should have step_type now
        for step in result["steps"]:
            assert "step_type" in step
            # Issue #677: 'analysis' is now a valid step_type
            assert step["step_type"] in ["research", "analysis", "processing"]

    @pytest.mark.parametrize(
        "plan",
        [
            {"steps": []},
            {"steps": [{"need_search": True}]},
            {"steps": [None, {}]},
            {"steps": ["invalid"]},
            {"steps": [{"need_search": True, "step_type": ""}]},
            "not a dict",
        ],
    )
    def test_issue_650_no_exceptions_raised(self, plan):
        """Test that validate_and_fix_plan handles all edge cases without raising exceptions.

        Each edge-case plan is its own parametrized case, so one failing input
        does not hide the others.
        """
        # No assertion needed: any exception propagates and fails this case
        # with its full traceback.
        validate_and_fix_plan(plan)
|
||||
@@ -1,355 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for state preservation functionality in graph nodes.
|
||||
|
||||
Tests the preserve_state_meta_fields() function and verifies that
|
||||
critical state fields (especially locale) are properly preserved
|
||||
across node state transitions.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from langgraph.types import Command
|
||||
|
||||
from src.graph.nodes import preserve_state_meta_fields
|
||||
from src.graph.types import State
|
||||
|
||||
|
||||
class TestPreserveStateMetaFields:
    """Checks on the dict returned by ``preserve_state_meta_fields()``."""

    def test_preserve_all_fields_with_defaults(self):
        """A state holding only messages yields every meta field at its default."""
        minimal_state = State(messages=[])

        preserved = preserve_state_meta_fields(minimal_state)

        # Presence of every expected key.
        for field_name in (
            "locale",
            "research_topic",
            "clarified_research_topic",
            "clarification_history",
            "enable_clarification",
            "max_clarification_rounds",
            "clarification_rounds",
            "resources",
        ):
            assert field_name in preserved

        # Default values compared by equality.
        defaults_by_equality = {
            "locale": "en-US",
            "research_topic": "",
            "clarified_research_topic": "",
            "clarification_history": [],
            "max_clarification_rounds": 3,
            "clarification_rounds": 0,
            "resources": [],
        }
        for field_name, default in defaults_by_equality.items():
            assert preserved[field_name] == default

        # The flag must be the False singleton itself.
        assert preserved["enable_clarification"] is False

    def test_preserve_locale_from_state(self):
        """A locale set on the state is returned as-is."""
        preserved = preserve_state_meta_fields(State(messages=[], locale="zh-CN"))

        assert preserved["locale"] == "zh-CN"

    def test_preserve_locale_english(self):
        """An explicit en-US locale is returned as-is."""
        preserved = preserve_state_meta_fields(State(messages=[], locale="en-US"))

        assert preserved["locale"] == "en-US"

    def test_preserve_locale_with_custom_value(self):
        """A locale other than the ones above (fr-FR) is returned as-is."""
        preserved = preserve_state_meta_fields(State(messages=[], locale="fr-FR"))

        assert preserved["locale"] == "fr-FR"

    def test_preserve_research_topic(self):
        """research_topic is carried over from the state."""
        topic = "How to build sustainable cities"

        preserved = preserve_state_meta_fields(State(messages=[], research_topic=topic))

        assert preserved["research_topic"] == topic

    def test_preserve_clarified_research_topic(self):
        """clarified_research_topic is carried over from the state."""
        topic = "Sustainable urban development with focus on green spaces"

        preserved = preserve_state_meta_fields(
            State(messages=[], clarified_research_topic=topic)
        )

        assert preserved["clarified_research_topic"] == topic

    def test_preserve_clarification_history(self):
        """clarification_history is carried over from the state."""
        qa_history = ["Q: What aspects?", "A: Architecture and planning"]

        preserved = preserve_state_meta_fields(
            State(messages=[], clarification_history=qa_history)
        )

        assert preserved["clarification_history"] == qa_history

    def test_preserve_clarification_flags(self):
        """The three clarification settings are carried over from the state."""
        flagged_state = State(
            messages=[],
            enable_clarification=True,
            max_clarification_rounds=5,
            clarification_rounds=2,
        )

        preserved = preserve_state_meta_fields(flagged_state)

        assert preserved["enable_clarification"] is True
        assert preserved["max_clarification_rounds"] == 5
        assert preserved["clarification_rounds"] == 2

    def test_preserve_resources(self):
        """The resources list is carried over from the state."""
        resource_list = [{"id": "1", "name": "Resource 1"}]

        preserved = preserve_state_meta_fields(
            State(messages=[], resources=resource_list)
        )

        assert preserved["resources"] == resource_list

    def test_preserve_all_fields_together(self):
        """All eight meta fields set at once are all carried over."""
        full_state = State(
            messages=[],
            locale="zh-CN",
            research_topic="原始查询",
            clarified_research_topic="澄清后的查询",
            clarification_history=["Q1", "A1", "Q2", "A2"],
            enable_clarification=True,
            max_clarification_rounds=5,
            clarification_rounds=2,
            resources=["resource1"],
        )

        preserved = preserve_state_meta_fields(full_state)

        assert preserved["locale"] == "zh-CN"
        assert preserved["research_topic"] == "原始查询"
        assert preserved["clarified_research_topic"] == "澄清后的查询"
        assert preserved["clarification_history"] == ["Q1", "A1", "Q2", "A2"]
        assert preserved["enable_clarification"] is True
        assert preserved["max_clarification_rounds"] == 5
        assert preserved["clarification_rounds"] == 2
        assert preserved["resources"] == ["resource1"]

    def test_preserve_returns_dict_not_state_object(self):
        """The return value is a dict containing the expected keys."""
        preserved = preserve_state_meta_fields(State(messages=[], locale="zh-CN"))

        assert isinstance(preserved, dict)
        # Spot-check two of the keys.
        assert "locale" in preserved
        assert "research_topic" in preserved

    def test_preserve_does_not_mutate_original_state(self):
        """Calling the function leaves the input state equal to a snapshot taken beforehand."""
        locale_before = "zh-CN"
        state = State(messages=[], locale=locale_before)
        snapshot = dict(state)

        preserve_state_meta_fields(state)

        assert state["locale"] == locale_before
        assert dict(state) == snapshot

    def test_preserve_with_none_values(self):
        """A None locale does not raise; the result is either None or 'en-US'."""
        preserved = preserve_state_meta_fields(State(messages=[], locale=None))

        # Either outcome is accepted by this test.
        locale = preserved["locale"]
        assert locale is None or locale == "en-US"

    def test_preserve_empty_lists_preserved(self):
        """Explicitly empty list fields come back as empty lists."""
        state = State(messages=[], clarification_history=[], resources=[])

        preserved = preserve_state_meta_fields(state)

        assert preserved["clarification_history"] == []
        assert preserved["resources"] == []

    def test_preserve_count_of_fields(self):
        """The returned dict has exactly 8 entries."""
        preserved = preserve_state_meta_fields(State(messages=[]))

        assert len(preserved) == 8

    def test_preserve_field_names(self):
        """The returned dict has exactly the eight expected keys."""
        preserved = preserve_state_meta_fields(State(messages=[]))

        assert set(preserved.keys()) == {
            "locale",
            "research_topic",
            "clarified_research_topic",
            "clarification_history",
            "enable_clarification",
            "max_clarification_rounds",
            "clarification_rounds",
            "resources",
        }
|
||||
|
||||
|
||||
class TestStatePreservationInCommand:
    """Checks that the preserved dict can be spread into an update mapping."""

    def test_command_update_with_preserved_fields(self):
        """Spreading the preserved dict into an update dict keeps locale and research_topic."""
        source_state = State(messages=[], locale="zh-CN", research_topic="测试")

        # Neither the call nor the unpacking below may raise.
        meta_fields = preserve_state_meta_fields(source_state)
        update = {"messages": [], **meta_fields}

        assert "locale" in update
        assert "research_topic" in update
        assert update["locale"] == "zh-CN"

    def test_command_unpacking_syntax(self):
        """A key written after the unpacked dict overrides the preserved value."""
        meta_fields = preserve_state_meta_fields(State(messages=[], locale="en-US"))

        # Same shape as a node's update: explicit keys, the spread, then an override.
        update = {
            "messages": [],
            "current_plan": None,
            **meta_fields,
            "locale": "zh-CN",
        }

        assert len(update) >= 10  # 2 explicit keys + 8 preserved ones
        assert update["locale"] == "zh-CN"  # the override wins
|
||||
|
||||
|
||||
class TestLocalePreservationSpecific:
    """Locale-focused checks: the value must survive repeated preserve/rebuild cycles."""

    def test_locale_not_lost_in_transition(self):
        """The locale is still zh-CN after two preserve/rebuild hops."""
        # Starting point: a state carrying a Chinese locale.
        first_state = State(messages=[], locale="zh-CN")

        # Hop 1: preserve, then rebuild a state from the preserved fields.
        first_hop = preserve_state_meta_fields(first_state)
        second_state = State(messages=[], **first_hop)

        # Hop 2: preserve again from the rebuilt state.
        second_hop = preserve_state_meta_fields(second_state)

        assert second_hop["locale"] == "zh-CN"

    def test_locale_chain_through_multiple_nodes(self):
        """The locale is unchanged across five preserve/rebuild cycles."""
        expected_locale = "zh-CN"
        current = State(messages=[], locale=expected_locale)

        hops = 5
        for _hop in range(hops):
            carried = preserve_state_meta_fields(current)
            assert carried["locale"] == expected_locale

            # The rebuilt state feeds the next cycle.
            current = State(messages=[], **carried)

        # The final rebuilt state still has the locale.
        assert current.get("locale") == expected_locale

    def test_locale_with_other_fields_preserved_together(self):
        """Locale, research_topic and clarification_rounds all survive a preserve/rebuild cycle."""
        source_state = State(
            messages=[],
            locale="zh-CN",
            research_topic="Original",
            clarification_rounds=0,
        )

        carried = preserve_state_meta_fields(source_state)

        # Values in the preserved dict.
        assert carried["locale"] == "zh-CN"
        assert carried["research_topic"] == "Original"
        assert carried["clarification_rounds"] == 0

        # Values after rebuilding a state from the preserved dict.
        rebuilt = State(messages=[], **carried)
        assert rebuilt.get("locale") == "zh-CN"
        assert rebuilt.get("research_topic") == "Original"
        assert rebuilt.get("clarification_rounds") == 0
|
||||
|
||||
|
||||
class TestEdgeCases:
    """Test edge cases and boundary conditions."""

    def test_very_long_research_topic(self):
        """Test preservation with very long research_topic."""
        long_topic = "a" * 10000
        state = State(messages=[], research_topic=long_topic)
        preserved = preserve_state_meta_fields(state)

        assert preserved["research_topic"] == long_topic

    def test_unicode_characters_in_topic(self):
        """Test preservation with unicode characters."""
        unicode_topic = "中文测试 🌍 テスト 🧪"
        state = State(messages=[], research_topic=unicode_topic)
        preserved = preserve_state_meta_fields(state)

        assert preserved["research_topic"] == unicode_topic

    @pytest.mark.parametrize("locale", ["zh-CN", "en-US", "pt-BR", "es-ES", "ja-JP"])
    def test_special_characters_in_locale(self, locale):
        """Test preservation with special locale formats.

        Parametrized so each locale tag is reported as its own test case.
        """
        state = State(messages=[], locale=locale)
        preserved = preserve_state_meta_fields(state)

        assert preserved["locale"] == locale

    def test_large_clarification_history(self):
        """Test preservation with large clarification_history."""
        large_history = [f"Q{i}: Question {i}" for i in range(100)]
        state = State(messages=[], clarification_history=large_history)
        preserved = preserve_state_meta_fields(state)

        assert len(preserved["clarification_history"]) == 100
        assert preserved["clarification_history"] == large_history

    @pytest.mark.parametrize("value", [0, 1, 3, 10, 100, 999])
    def test_max_clarification_rounds_boundary(self, value):
        """Test preservation with boundary values for max_clarification_rounds.

        Parametrized so each value is reported as its own test case.
        """
        state = State(messages=[], max_clarification_rounds=value)
        preserved = preserve_state_meta_fields(state)

        assert preserved["max_clarification_rounds"] == value
|
||||
@@ -1,305 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import (
|
||||
AIMessageChunk,
|
||||
ChatMessageChunk,
|
||||
FunctionMessageChunk,
|
||||
HumanMessageChunk,
|
||||
SystemMessageChunk,
|
||||
ToolMessageChunk,
|
||||
)
|
||||
|
||||
from src.llms import llm as llm_module
|
||||
from src.llms.providers import dashscope as dashscope_module
|
||||
from src.llms.providers.dashscope import (
|
||||
ChatDashscope,
|
||||
_convert_chunk_to_generation_chunk,
|
||||
_convert_delta_to_message_chunk,
|
||||
)
|
||||
|
||||
|
||||
class DummyChatDashscope:
    """Stand-in chat model that records the keyword arguments it was built with."""

    def __init__(self, **options):
        # Stored under ``kwargs`` so tests can inspect what the factory passed in.
        self.kwargs = options
|
||||
|
||||
|
||||
@pytest.fixture
def dashscope_conf():
    """Return a config with BASIC_MODEL and REASONING_MODEL entries sharing one DashScope base URL."""
    endpoint = "https://dashscope.aliyuncs.com/v1"
    basic_model = {
        "api_key": "k",
        "base_url": endpoint,
        "model": "qwen3-235b-a22b-instruct-2507",
    }
    reasoning_model = {
        "api_key": "rk",
        "base_url": endpoint,
        "model": "qwen3-235b-a22b-thinking-2507",
    }
    return {"BASIC_MODEL": basic_model, "REASONING_MODEL": reasoning_model}
|
||||
|
||||
|
||||
def test_convert_delta_to_message_chunk_roles_and_extras():
    """Check the chunk class produced per role, plus the assistant-only extras."""
    # Assistant delta carrying content, reasoning text and one tool call.
    tool_call = {
        "id": "call_1",
        "index": 0,
        "function": {"name": "lookup", "arguments": '{\\"q\\":\\"x\\"}'},
    }
    assistant_delta = {
        "role": "assistant",
        "content": "Hello",
        "reasoning_content": "Think...",
        "tool_calls": [tool_call],
    }
    msg = _convert_delta_to_message_chunk(assistant_delta, AIMessageChunk)
    assert isinstance(msg, AIMessageChunk)
    assert msg.content == "Hello"
    assert msg.additional_kwargs.get("reasoning_content") == "Think..."
    # The tool call must surface as a non-empty tool_call_chunks attribute.
    assert getattr(msg, "tool_call_chunks", None)

    # Remaining roles, checked in order: user, system, function, tool.
    role_cases = [
        ({"role": "user", "content": "Hi"}, HumanMessageChunk),
        ({"role": "system", "content": "Rules"}, SystemMessageChunk),
        ({"role": "function", "name": "f", "content": "{}"}, FunctionMessageChunk),
        ({"role": "tool", "tool_call_id": "t1", "content": "ok"}, ToolMessageChunk),
    ]
    for payload, chunk_cls in role_cases:
        converted = _convert_delta_to_message_chunk(payload, chunk_cls)
        assert isinstance(converted, chunk_cls)
|
||||
|
||||
|
||||
def test_convert_chunk_to_generation_chunk_skip_and_usage():
    """Check that content.delta chunks are skipped and a full chunk keeps usage and finish info."""
    # A chunk typed "content.delta" yields no generation chunk at all.
    skipped = _convert_chunk_to_generation_chunk(
        {"type": "content.delta"}, AIMessageChunk, None
    )
    assert skipped is None

    # A regular chunk with one choice, model metadata and token usage.
    model_name = "qwen3-235b-a22b-instruct-2507"
    choice = {
        "delta": {"role": "assistant", "content": "Hi"},
        "finish_reason": "stop",
    }
    payload = {
        "choices": [choice],
        "model": model_name,
        "system_fingerprint": "fp",
        "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
    }
    gen = _convert_chunk_to_generation_chunk(payload, AIMessageChunk, None)
    assert gen is not None
    assert isinstance(gen.message, AIMessageChunk)
    assert gen.message.content == "Hi"
    # Usage figures are expected on the AI message itself.
    assert getattr(gen.message, "usage_metadata", None) is not None

    info = gen.generation_info
    assert info.get("finish_reason") == "stop"
    assert info.get("model_name") == model_name
    assert info.get("system_fingerprint") == "fp"
|
||||
|
||||
|
||||
def test_llm_selects_dashscope_and_sets_enable_thinking(monkeypatch, dashscope_conf):
    """Check that DashScope configs build ChatDashscope with the right enable_thinking flag.

    ``ChatDashscope`` is swapped for ``DummyChatDashscope`` so the keyword
    arguments passed at construction can be inspected. The "basic" model type
    must set ``extra_body["enable_thinking"]`` to False and the "reasoning"
    type must set it to True.
    """
    # Use dummy class to capture kwargs on construction
    monkeypatch.setattr(llm_module, "ChatDashscope", DummyChatDashscope)

    # basic -> enable_thinking False
    inst = llm_module._create_llm_use_conf("basic", dashscope_conf)
    assert isinstance(inst, DummyChatDashscope)
    assert inst.kwargs["extra_body"]["enable_thinking"] is False
    # Plain substring test: ``str.find(...) > 0`` would wrongly reject a
    # match located at index 0.
    assert "dashscope." in inst.kwargs["base_url"]

    # reasoning -> enable_thinking True
    inst2 = llm_module._create_llm_use_conf("reasoning", dashscope_conf)
    assert isinstance(inst2, DummyChatDashscope)
    assert inst2.kwargs["extra_body"]["enable_thinking"] is True
|
||||
|
||||
|
||||
def test_llm_verify_ssl_false_adds_http_clients(monkeypatch, dashscope_conf):
    """Check that verify_ssl=False makes the factory pass both HTTP client kwargs."""
    monkeypatch.setattr(llm_module, "ChatDashscope", DummyChatDashscope)

    # Build a copy of the config whose BASIC_MODEL entry disables SSL
    # verification; the fixture's own dictionaries are left untouched.
    insecure_basic = dict(dashscope_conf["BASIC_MODEL"], verify_ssl=False)
    insecure_conf = dict(dashscope_conf, BASIC_MODEL=insecure_basic)

    inst = llm_module._create_llm_use_conf("basic", insecure_conf)
    for client_key in ("http_client", "http_async_client"):
        assert client_key in inst.kwargs
|
||||
|
||||
|
||||
def test_convert_delta_to_message_chunk_developer_and_function_call_and_tool_calls():
    """Cover the developer role, a nameless function_call and a malformed tool call."""
    # A "developer" delta becomes a system chunk tagged with the original role.
    developer_msg = _convert_delta_to_message_chunk(
        {"role": "developer", "content": "dev rules"}, SystemMessageChunk
    )
    assert isinstance(developer_msg, SystemMessageChunk)
    assert developer_msg.additional_kwargs.get("__openai_role__") == "developer"

    # A function_call whose name is None is expected to end up with name "".
    nameless_call = {"name": None, "arguments": "{}"}
    function_msg = _convert_delta_to_message_chunk(
        {"role": "assistant", "function_call": nameless_call}, AIMessageChunk
    )
    assert isinstance(function_msg, AIMessageChunk)
    assert function_msg.additional_kwargs["function_call"]["name"] == ""

    # Two tool calls: the first is complete, the second lacks its "function"
    # key. Conversion must not raise and must produce exactly one chunk.
    complete_call = {"id": "t1", "index": 0, "function": {"name": "f", "arguments": "{}"}}
    broken_call = {"id": "t2", "index": 1}
    tool_msg = _convert_delta_to_message_chunk(
        {"role": "assistant", "tool_calls": [complete_call, broken_call]},
        AIMessageChunk,
    )
    assert isinstance(tool_msg, AIMessageChunk)
    # The raw tool_calls list is carried in additional_kwargs.
    assert tool_msg.additional_kwargs["tool_calls"][0]["id"] == "t1"
    # Only the complete call yields a tool_call_chunk.
    assert len(tool_msg.tool_call_chunks) == 1
|
||||
|
||||
|
||||
def test_convert_delta_to_message_chunk_default_class_and_unknown_role():
    """Check the fallback to the default class and the handling of an unrecognised role."""
    # With no "role" key, the supplied default class (human) is used.
    roleless = _convert_delta_to_message_chunk({"content": "hey"}, HumanMessageChunk)
    assert isinstance(roleless, HumanMessageChunk)

    # An unrecognised role yields a ChatMessageChunk that keeps the role text.
    observer = _convert_delta_to_message_chunk(
        {"role": "observer", "content": "hmm"}, ChatMessageChunk
    )
    assert isinstance(observer, ChatMessageChunk)
    assert observer.role == "observer"
|
||||
|
||||
|
||||
def test_convert_chunk_to_generation_chunk_empty_choices_and_usage():
    """Check that a usage-only chunk with no choices still yields an empty AI message."""
    usage = {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}
    gen = _convert_chunk_to_generation_chunk(
        {"choices": [], "usage": usage}, AIMessageChunk, None
    )

    assert gen is not None
    message = gen.message
    assert isinstance(message, AIMessageChunk)
    assert message.content == ""
    assert getattr(message, "usage_metadata", None) is not None
    # With no choices there is no generation info to report.
    assert gen.generation_info is None
|
||||
|
||||
|
||||
def test_convert_chunk_to_generation_chunk_includes_base_info_and_logprobs():
    """Check that base generation info and per-choice logprobs reach generation_info."""
    choice = {
        "delta": {"role": "assistant", "content": "T"},
        "logprobs": {"content": [{"token": "T", "logprob": -0.1}]},
    }
    base_info = {"headers": {"a": "b"}}

    gen = _convert_chunk_to_generation_chunk(
        {"choices": [choice]}, AIMessageChunk, base_info
    )

    assert gen is not None
    assert gen.message.content == "T"
    info = gen.generation_info
    assert info.get("headers") == {"a": "b"}
    assert "logprobs" in info
|
||||
|
||||
|
||||
def test_convert_chunk_to_generation_chunk_beta_stream_format():
    """Check that a payload nested under a top-level "chunk" key is converted."""
    inner = {
        "choices": [
            {"delta": {"role": "assistant", "content": "From beta stream format"}}
        ]
    }

    gen = _convert_chunk_to_generation_chunk({"chunk": inner}, AIMessageChunk, None)

    assert gen is not None
    assert gen.message.content == "From beta stream format"
|
||||
|
||||
|
||||
def test_chatdashscope_create_chat_result_adds_reasoning_content(monkeypatch):
    """Check that reasoning_content from a model-style response lands in additional_kwargs."""

    # Minimal stand-ins for what the parent _create_chat_result would return.
    class StubMessage:
        def __init__(self):
            self.additional_kwargs = {}

    class StubGeneration:
        def __init__(self):
            self.message = StubMessage()

    class StubChatResult:
        def __init__(self):
            self.generations = [StubGeneration()]

    # Replace the parent-class implementation so super() returns the stub.
    monkeypatch.setattr(
        dashscope_module.ChatOpenAI,
        "_create_chat_result",
        lambda self, response, generation_info=None: StubChatResult(),
    )

    # Swap openai.BaseModel inside the module under test for a plain class
    # that the fake response below can inherit from.
    class DummyBaseModel:
        pass

    monkeypatch.setattr(dashscope_module.openai, "BaseModel", DummyBaseModel)

    # Fake response shaped as response.choices[0].message.reasoning_content.
    class ReasoningMessage:
        def __init__(self, rc):
            self.reasoning_content = rc

    class ReasoningChoice:
        def __init__(self, rc):
            self.message = ReasoningMessage(rc)

    class FakeResponse(DummyBaseModel):
        def __init__(self):
            self.choices = [ReasoningChoice("Reasoning...")]

    llm = ChatDashscope(model="dummy", api_key="k")
    result = llm._create_chat_result(FakeResponse())

    extras = result.generations[0].message.additional_kwargs
    assert extras.get("reasoning_content") == "Reasoning..."
|
||||
|
||||
|
||||
def test_chatdashscope_create_chat_result_dict_passthrough(monkeypatch):
    """Check that a plain dict response gets no reasoning_content injected."""

    # Minimal stand-ins for what the parent _create_chat_result would return.
    class StubMessage:
        def __init__(self):
            self.additional_kwargs = {}

    class StubGeneration:
        def __init__(self):
            self.message = StubMessage()

    class StubChatResult:
        def __init__(self):
            self.generations = [StubGeneration()]

    # Replace the parent-class implementation so super() returns the stub.
    monkeypatch.setattr(
        dashscope_module.ChatOpenAI,
        "_create_chat_result",
        lambda self, response, generation_info=None: StubChatResult(),
    )

    llm = ChatDashscope(model="dummy", api_key="k")
    result = llm._create_chat_result({"raw": "dict"})

    # The stub message's additional_kwargs must stay free of reasoning_content.
    extras = result.generations[0].message.additional_kwargs
    assert "reasoning_content" not in extras
|
||||
@@ -1,127 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
|
||||
from src.llms import llm
|
||||
|
||||
|
||||
class DummyChatOpenAI:
    """Stand-in chat model that records its constructor kwargs and echoes input."""

    def __init__(self, **options):
        # Stored under ``kwargs`` so tests can inspect what the factory passed in.
        self.kwargs = options

    def invoke(self, msg):
        """Return *msg* formatted after an "Echo: " prefix."""
        echoed = f"Echo: {msg}"
        return echoed
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def patch_chat_openai(monkeypatch):
    """Replace ``llm.ChatOpenAI`` with the dummy class for every test (autouse)."""
    replacement = DummyChatOpenAI
    monkeypatch.setattr(llm, "ChatOpenAI", replacement)
|
||||
|
||||
|
||||
@pytest.fixture
def dummy_conf():
    """Return a config with BASIC_MODEL, REASONING_MODEL and VISION_MODEL entries."""
    basic_model = {"api_key": "test_key", "base_url": "http://test"}
    reasoning_model = {"api_key": "reason_key"}
    vision_model = {"api_key": "vision_key"}
    return {
        "BASIC_MODEL": basic_model,
        "REASONING_MODEL": reasoning_model,
        "VISION_MODEL": vision_model,
    }
|
||||
|
||||
|
||||
def test_get_env_llm_conf(monkeypatch):
    """Check that BASIC_MODEL__* environment variables are read into the config."""
    # Start from a clean slate so variables from the host shell cannot leak in.
    for name in (
        "BASIC_MODEL__API_KEY",
        "BASIC_MODEL__BASE_URL",
        "BASIC_MODEL__MODEL",
    ):
        monkeypatch.delenv(name, raising=False)

    monkeypatch.setenv("BASIC_MODEL__API_KEY", "env_key")
    monkeypatch.setenv("BASIC_MODEL__BASE_URL", "http://env")

    conf = llm._get_env_llm_conf("basic")

    assert conf["api_key"] == "env_key"
    assert conf["base_url"] == "http://env"
|
||||
|
||||
|
||||
def test_create_llm_use_conf_merges_env(monkeypatch, dummy_conf):
    """Check that an env api_key overrides the config while base_url comes from the config."""
    # Remove variables that would otherwise override the config values under test.
    for name in ("BASIC_MODEL__BASE_URL", "BASIC_MODEL__MODEL"):
        monkeypatch.delenv(name, raising=False)
    monkeypatch.setenv("BASIC_MODEL__API_KEY", "env_key")

    result = llm._create_llm_use_conf("basic", dummy_conf)

    assert isinstance(result, DummyChatOpenAI)
    built_with = result.kwargs
    assert built_with["api_key"] == "env_key"
    assert built_with["base_url"] == "http://test"
|
||||
|
||||
|
||||
def test_create_llm_use_conf_invalid_type(monkeypatch, dummy_conf):
    """Check that an unrecognised LLM type name raises ValueError."""
    # Drop BASIC_MODEL__* variables so the host environment cannot interfere.
    for name in (
        "BASIC_MODEL__API_KEY",
        "BASIC_MODEL__BASE_URL",
        "BASIC_MODEL__MODEL",
    ):
        monkeypatch.delenv(name, raising=False)

    with pytest.raises(ValueError):
        llm._create_llm_use_conf("unknown", dummy_conf)
|
||||
|
||||
|
||||
def test_create_llm_use_conf_empty_conf(monkeypatch):
    """Check that an empty config, with no env overrides, raises ValueError."""
    # Drop BASIC_MODEL__* variables so nothing can fill in for the empty config.
    for name in (
        "BASIC_MODEL__API_KEY",
        "BASIC_MODEL__BASE_URL",
        "BASIC_MODEL__MODEL",
    ):
        monkeypatch.delenv(name, raising=False)

    with pytest.raises(ValueError):
        llm._create_llm_use_conf("basic", {})
|
||||
|
||||
|
||||
def test_get_llm_by_type_caches(monkeypatch, dummy_conf):
    """Check that two lookups of the same type return the very same instance."""
    loader_calls = {}

    def fake_load_yaml_config(path):
        # Record that the loader ran, then hand back the fixture config.
        loader_calls["called"] = True
        return dummy_conf

    monkeypatch.setattr(llm, "load_yaml_config", fake_load_yaml_config)
    # Empty the module-level cache so the first lookup has to build an instance.
    llm._llm_cache.clear()

    first = llm.get_llm_by_type("basic")
    second = llm.get_llm_by_type("basic")

    assert first is second
    assert loader_calls["called"]
|
||||
|
||||
|
||||
def test_create_llm_filters_unexpected_keys(monkeypatch, caplog):
    """Check that stray config keys such as SEARCH_ENGINE are dropped with a warning (Issue #411)."""
    import logging

    # Drop BASIC_MODEL__* variables so the host environment cannot interfere.
    for name in (
        "BASIC_MODEL__API_KEY",
        "BASIC_MODEL__BASE_URL",
        "BASIC_MODEL__MODEL",
    ):
        monkeypatch.delenv(name, raising=False)

    # Three legitimate keys plus two that the factory is expected to discard.
    expected_kwargs = {
        "api_key": "test_key",
        "base_url": "http://test",
        "model": "gpt-4",
    }
    stray_keys = {
        "SEARCH_ENGINE": {"include_domains": ["example.com"]},
        "engine": "tavily",
    }
    conf_with_unexpected_keys = {"BASIC_MODEL": {**expected_kwargs, **stray_keys}}

    with caplog.at_level(logging.WARNING):
        result = llm._create_llm_use_conf("basic", conf_with_unexpected_keys)

    # The LLM is still created.
    assert isinstance(result, DummyChatOpenAI)

    # None of the stray keys reached the constructor.
    for stray in ("SEARCH_ENGINE", "engine"):
        assert stray not in result.kwargs

    # Every legitimate key was forwarded with its value.
    for key, value in expected_kwargs.items():
        assert result.kwargs[key] == value

    # A log record mentioning each stray key was emitted.
    logged = [record.message for record in caplog.records]
    assert any("SEARCH_ENGINE" in text for text in logged)
    assert any("engine" in text for text in logged)
|
||||
@@ -1,2 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
@@ -1,214 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from src.podcast.graph.script_writer_node import script_writer_node
|
||||
from src.podcast.types import Script, ScriptLine
|
||||
|
||||
|
||||
class TestScriptWriterNode:
    """Tests for the json_mode path and the prompt fallback of script_writer_node."""

    @staticmethod
    def _wire_models(mock_get_llm, mock_get_template):
        """Configure the patched factories; return (base model, structured model)."""
        mock_get_template.return_value = "Generate a podcast script."
        base_model = MagicMock()
        structured_model = MagicMock()
        base_model.with_structured_output.return_value = structured_model
        mock_get_llm.return_value = base_model
        return base_model, structured_model

    @staticmethod
    def _bad_request(message, body):
        """Build an openai.BadRequestError with a mocked 400 response."""
        return openai.BadRequestError(
            message=message,
            response=MagicMock(status_code=400),
            body=body,
        )

    @staticmethod
    def _reply(content):
        """Build a mock LLM reply whose ``content`` attribute is *content*."""
        reply = MagicMock()
        reply.content = content
        return reply

    @pytest.fixture
    def sample_state(self):
        """Podcast state holding only an ``input`` text."""
        return {"input": "Test content for podcast generation"}

    @pytest.fixture
    def sample_script(self):
        """Two-line English Script, one male and one female line."""
        opening = ScriptLine(speaker="male", paragraph="Hello, welcome to our podcast.")
        follow_up = ScriptLine(speaker="female", paragraph="Today we discuss testing.")
        return Script(locale="en", lines=[opening, follow_up])

    @pytest.fixture
    def sample_script_json(self, sample_script):
        """The sample script serialised with ``model_dump_json``."""
        return sample_script.model_dump_json()

    @patch("src.podcast.graph.script_writer_node.get_prompt_template")
    @patch("src.podcast.graph.script_writer_node.get_llm_by_type")
    def test_script_writer_with_json_mode_success(
        self, mock_get_llm, mock_get_template, sample_state, sample_script
    ):
        """The structured model's result is returned and json_mode is requested."""
        base_model, structured_model = self._wire_models(
            mock_get_llm, mock_get_template
        )
        structured_model.invoke.return_value = sample_script

        result = script_writer_node(sample_state)

        assert result["script"] == sample_script
        assert result["audio_chunks"] == []
        base_model.with_structured_output.assert_called_once_with(
            Script, method="json_mode"
        )

    @patch("src.podcast.graph.script_writer_node.get_prompt_template")
    @patch("src.podcast.graph.script_writer_node.get_llm_by_type")
    def test_script_writer_fallback_on_json_object_not_supported(
        self, mock_get_llm, mock_get_template, sample_state, sample_script_json
    ):
        """A "json_object is not supported" error triggers the plain-prompt fallback."""
        base_model, structured_model = self._wire_models(
            mock_get_llm, mock_get_template
        )
        # The structured call fails with the json_object-not-supported error.
        structured_model.invoke.side_effect = self._bad_request(
            "json_object is not supported by this model",
            {"error": {"message": "json_object is not supported"}},
        )
        # The plain model then answers with the serialised sample script.
        base_model.invoke.return_value = self._reply(sample_script_json)

        result = script_writer_node(sample_state)

        script = result["script"]
        assert script.locale == "en"
        assert len(script.lines) == 2
        assert result["audio_chunks"] == []
        # The plain model must have been invoked exactly once.
        base_model.invoke.assert_called_once()

    @patch("src.podcast.graph.script_writer_node.get_prompt_template")
    @patch("src.podcast.graph.script_writer_node.get_llm_by_type")
    def test_script_writer_reraises_other_bad_request_errors(
        self, mock_get_llm, mock_get_template, sample_state
    ):
        """A BadRequestError unrelated to json_object propagates to the caller."""
        _, structured_model = self._wire_models(mock_get_llm, mock_get_template)
        structured_model.invoke.side_effect = self._bad_request(
            "Invalid model parameter",
            {"error": {"message": "Invalid model parameter"}},
        )

        with pytest.raises(openai.BadRequestError) as exc_info:
            script_writer_node(sample_state)

        assert "Invalid model parameter" in str(exc_info.value)

    @patch("src.podcast.graph.script_writer_node.get_prompt_template")
    @patch("src.podcast.graph.script_writer_node.get_llm_by_type")
    def test_script_writer_fallback_with_markdown_wrapped_json(
        self, mock_get_llm, mock_get_template, sample_state
    ):
        """The fallback parses JSON that arrives inside a ```json fenced block."""
        base_model, structured_model = self._wire_models(
            mock_get_llm, mock_get_template
        )
        structured_model.invoke.side_effect = self._bad_request(
            "json_object is not supported", {}
        )
        # Reply wrapped in a markdown code fence.
        base_model.invoke.return_value = self._reply(
            """```json
{
"locale": "zh",
"lines": [
{"speaker": "male", "paragraph": "欢迎收听播客。"}
]
}
```"""
        )

        result = script_writer_node(sample_state)

        script = result["script"]
        assert script.locale == "zh"
        assert len(script.lines) == 1
        assert script.lines[0].speaker == "male"

    @patch("src.podcast.graph.script_writer_node.get_prompt_template")
    @patch("src.podcast.graph.script_writer_node.get_llm_by_type")
    def test_script_writer_fallback_raises_on_invalid_json(
        self, mock_get_llm, mock_get_template, sample_state
    ):
        """The fallback raises JSONDecodeError when the reply is not JSON."""
        base_model, structured_model = self._wire_models(
            mock_get_llm, mock_get_template
        )
        structured_model.invoke.side_effect = self._bad_request(
            "json_object is not supported", {}
        )
        # Reply that is ordinary prose rather than JSON.
        base_model.invoke.return_value = self._reply(
            "This is not JSON at all, just plain text response."
        )

        with pytest.raises(json.JSONDecodeError):
            script_writer_node(sample_state)

    @patch("src.podcast.graph.script_writer_node.get_prompt_template")
    @patch("src.podcast.graph.script_writer_node.get_llm_by_type")
    def test_script_writer_fallback_raises_on_invalid_schema(
        self, mock_get_llm, mock_get_template, sample_state
    ):
        """The fallback raises ValidationError when the JSON does not fit Script."""
        base_model, structured_model = self._wire_models(
            mock_get_llm, mock_get_template
        )
        structured_model.invoke.side_effect = self._bad_request(
            "json_object is not supported", {}
        )
        # Well-formed JSON whose values do not fit the schema: "lines" is a
        # string here, not a list.
        base_model.invoke.return_value = self._reply(
            '{"locale": "invalid_locale", "lines": "not_a_list"}'
        )

        # Pydantic reports the schema mismatch as ValidationError.
        from pydantic import ValidationError

        with pytest.raises(ValidationError):
            script_writer_node(sample_state)
|
||||
@@ -1,2 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
@@ -1,2 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
@@ -1,156 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.prompt_enhancer.graph.builder import build_graph
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
|
||||
|
||||
class TestBuildGraph:
|
||||
"""Test cases for build_graph function."""
|
||||
|
||||
@patch("src.prompt_enhancer.graph.builder.StateGraph")
|
||||
def test_build_graph_structure(self, mock_state_graph):
|
||||
"""Test that build_graph creates the correct graph structure."""
|
||||
mock_builder = MagicMock()
|
||||
mock_compiled_graph = MagicMock()
|
||||
|
||||
mock_state_graph.return_value = mock_builder
|
||||
mock_builder.compile.return_value = mock_compiled_graph
|
||||
|
||||
result = build_graph()
|
||||
|
||||
# Verify StateGraph was created with correct state type
|
||||
mock_state_graph.assert_called_once_with(PromptEnhancerState)
|
||||
|
||||
# Verify entry point was set
|
||||
mock_builder.set_entry_point.assert_called_once_with("enhancer")
|
||||
|
||||
# Verify finish point was set
|
||||
mock_builder.set_finish_point.assert_called_once_with("enhancer")
|
||||
|
||||
# Verify graph was compiled
|
||||
mock_builder.compile.assert_called_once()
|
||||
|
||||
# Verify return value
|
||||
assert result == mock_compiled_graph
|
||||
|
||||
@patch("src.prompt_enhancer.graph.builder.StateGraph")
|
||||
@patch("src.prompt_enhancer.graph.builder.prompt_enhancer_node")
|
||||
def test_build_graph_node_function(self, mock_enhancer_node, mock_state_graph):
|
||||
"""Test that the correct node function is added to the graph."""
|
||||
mock_builder = MagicMock()
|
||||
mock_compiled_graph = MagicMock()
|
||||
|
||||
mock_state_graph.return_value = mock_builder
|
||||
mock_builder.compile.return_value = mock_compiled_graph
|
||||
|
||||
build_graph()
|
||||
|
||||
# Verify the correct node function was added
|
||||
mock_builder.add_node.assert_called_once_with("enhancer", mock_enhancer_node)
|
||||
|
||||
def test_build_graph_returns_compiled_graph(self):
|
||||
"""Test that build_graph returns a compiled graph object."""
|
||||
with patch("src.prompt_enhancer.graph.builder.StateGraph") as mock_state_graph:
|
||||
mock_builder = MagicMock()
|
||||
mock_compiled_graph = MagicMock()
|
||||
|
||||
mock_state_graph.return_value = mock_builder
|
||||
mock_builder.compile.return_value = mock_compiled_graph
|
||||
|
||||
result = build_graph()
|
||||
|
||||
assert result is mock_compiled_graph
|
||||
|
||||
@patch("src.prompt_enhancer.graph.builder.StateGraph")
|
||||
def test_build_graph_call_sequence(self, mock_state_graph):
|
||||
"""Test that build_graph calls methods in the correct sequence."""
|
||||
mock_builder = MagicMock()
|
||||
mock_compiled_graph = MagicMock()
|
||||
|
||||
mock_state_graph.return_value = mock_builder
|
||||
mock_builder.compile.return_value = mock_compiled_graph
|
||||
|
||||
# Track call order
|
||||
call_order = []
|
||||
|
||||
def track_add_node(*args, **kwargs):
|
||||
call_order.append("add_node")
|
||||
|
||||
def track_set_entry_point(*args, **kwargs):
|
||||
call_order.append("set_entry_point")
|
||||
|
||||
def track_set_finish_point(*args, **kwargs):
|
||||
call_order.append("set_finish_point")
|
||||
|
||||
def track_compile(*args, **kwargs):
|
||||
call_order.append("compile")
|
||||
return mock_compiled_graph
|
||||
|
||||
mock_builder.add_node.side_effect = track_add_node
|
||||
mock_builder.set_entry_point.side_effect = track_set_entry_point
|
||||
mock_builder.set_finish_point.side_effect = track_set_finish_point
|
||||
mock_builder.compile.side_effect = track_compile
|
||||
|
||||
build_graph()
|
||||
|
||||
# Verify the correct call sequence
|
||||
expected_order = ["add_node", "set_entry_point", "set_finish_point", "compile"]
|
||||
assert call_order == expected_order
|
||||
|
||||
def test_build_graph_integration(self):
|
||||
"""Integration test to verify the graph can be built without mocking."""
|
||||
# This test verifies that all imports and dependencies are correct
|
||||
try:
|
||||
graph = build_graph()
|
||||
assert graph is not None
|
||||
# The graph should be a compiled LangGraph object
|
||||
assert hasattr(graph, "invoke") or hasattr(graph, "stream")
|
||||
except ImportError as e:
|
||||
pytest.skip(f"Skipping integration test due to missing dependencies: {e}")
|
||||
except Exception as e:
|
||||
# If there are configuration issues (like missing LLM config),
|
||||
# we still consider the test successful if the graph structure is built
|
||||
if "LLM" in str(e) or "configuration" in str(e).lower():
|
||||
pytest.skip(
|
||||
f"Skipping integration test due to configuration issues: {e}"
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
@patch("src.prompt_enhancer.graph.builder.StateGraph")
def test_build_graph_single_node_workflow(self, mock_state_graph):
    """Check that build_graph() wires exactly one node as both entry and finish."""
    builder_double = MagicMock()
    builder_double.compile.return_value = MagicMock()
    mock_state_graph.return_value = builder_double

    build_graph()

    # Exactly one node is registered on the builder ...
    assert builder_double.add_node.call_count == 1

    # ... and that same node name is used for both ends of the workflow.
    builder_double.set_entry_point.assert_called_once_with("enhancer")
    builder_double.set_finish_point.assert_called_once_with("enhancer")
|
||||
|
||||
@patch("src.prompt_enhancer.graph.builder.StateGraph")
def test_build_graph_state_type(self, mock_state_graph):
    """Check that StateGraph receives PromptEnhancerState as its first positional argument."""
    builder_double = MagicMock()
    builder_double.compile.return_value = MagicMock()
    mock_state_graph.return_value = builder_double

    build_graph()

    # call_args[0] is the positional-argument tuple of the StateGraph(...) call.
    positional = mock_state_graph.call_args[0]
    assert positional[0] == PromptEnhancerState
|
||||
@@ -1,526 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from src.config.report_style import ReportStyle
|
||||
from src.prompt_enhancer.graph.enhancer_node import prompt_enhancer_node
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
|
||||
|
||||
@pytest.fixture
def mock_llm():
    """Provide a fake LLM whose invoke() returns a tagged enhanced prompt."""
    reply = MagicMock(
        content=(
            "Thoughts: LLM thinks a lot\n"
            "<enhanced_prompt>\n"
            "Enhanced test prompt\n"
            "</enhanced_prompt>\n"
        )
    )
    fake_llm = MagicMock()
    fake_llm.invoke.return_value = reply
    return fake_llm
|
||||
|
||||
|
||||
@pytest.fixture
def mock_llm_xml_with_whitespace():
    """Provide a fake LLM whose tagged reply is padded with blank lines and extra text."""
    reply = MagicMock(
        content=(
            "\n"
            "Some thoughts here...\n"
            "\n"
            "<enhanced_prompt>\n"
            "\n"
            "Enhanced prompt with whitespace\n"
            "\n"
            "</enhanced_prompt>\n"
            "\n"
            "Additional content after XML\n"
        )
    )
    fake_llm = MagicMock()
    fake_llm.invoke.return_value = reply
    return fake_llm
|
||||
|
||||
|
||||
@pytest.fixture
def mock_llm_xml_multiline():
    """Provide a fake LLM whose tagged reply spans several lines, including a blank one."""
    reply = MagicMock(
        content=(
            "\n"
            "<enhanced_prompt>\n"
            "This is a multiline enhanced prompt\n"
            "that spans multiple lines\n"
            "and includes various formatting.\n"
            "\n"
            "It should preserve the structure.\n"
            "</enhanced_prompt>\n"
        )
    )
    fake_llm = MagicMock()
    fake_llm.invoke.return_value = reply
    return fake_llm
|
||||
|
||||
|
||||
@pytest.fixture
def mock_llm_no_xml():
    """Provide a fake LLM whose reply carries a text prefix instead of XML tags."""
    reply = MagicMock(
        content="Enhanced Prompt: This is an enhanced prompt without XML tags"
    )
    fake_llm = MagicMock()
    fake_llm.invoke.return_value = reply
    return fake_llm
|
||||
|
||||
|
||||
@pytest.fixture
def mock_llm_malformed_xml():
    """Provide a fake LLM whose reply opens the tag twice and never closes it."""
    reply = MagicMock(
        content=(
            "\n"
            "<enhanced_prompt>\n"
            "This XML tag is not properly closed\n"
            "<enhanced_prompt>\n"
        )
    )
    fake_llm = MagicMock()
    fake_llm.invoke.return_value = reply
    return fake_llm
|
||||
|
||||
|
||||
@pytest.fixture
def mock_messages():
    """Provide the message list the tests use as apply_prompt_template's return value."""
    system_part = SystemMessage(content="System prompt template")
    human_part = HumanMessage(content="Test human message")
    return [system_part, human_part]
|
||||
|
||||
|
||||
@patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
@patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
@patch(
    "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
    {"prompt_enhancer": "basic"},
)
class TestPromptEnhancerNode:
    """Test cases for prompt_enhancer_node function.

    The three patches are declared once on the class instead of being
    repeated on every method.  ``unittest.mock.patch`` used as a class
    decorator wraps each ``test*`` method, so every test still receives
    ``mock_get_llm`` and ``mock_apply_template`` (in that order) before its
    pytest fixtures.  ``AGENT_LLM_MAP`` is replaced by a literal dict and
    therefore injects no argument.
    """

    def test_basic_prompt_enhancement(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test basic prompt enhancement without context or report style."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        state = PromptEnhancerState(prompt="Write about AI")

        result = prompt_enhancer_node(state)

        # Verify LLM was called
        mock_get_llm.assert_called_once_with("basic")
        mock_llm.invoke.assert_called_once_with(mock_messages)

        # Verify apply_prompt_template was called correctly
        mock_apply_template.assert_called_once()
        call_args = mock_apply_template.call_args
        assert call_args[0][0] == "prompt_enhancer/prompt_enhancer"
        assert "messages" in call_args[0][1]
        assert "report_style" in call_args[0][1]

        # Verify result
        assert result == {"output": "Enhanced test prompt"}

    def test_prompt_enhancement_with_report_style(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test prompt enhancement with report style."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        state = PromptEnhancerState(
            prompt="Write about AI", report_style=ReportStyle.ACADEMIC
        )

        result = prompt_enhancer_node(state)

        # Verify apply_prompt_template was called with report_style
        mock_apply_template.assert_called_once()
        call_args = mock_apply_template.call_args
        assert call_args[0][0] == "prompt_enhancer/prompt_enhancer"
        assert call_args[0][1]["report_style"] == ReportStyle.ACADEMIC

        # Verify result
        assert result == {"output": "Enhanced test prompt"}

    def test_prompt_enhancement_with_context(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test prompt enhancement with additional context."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        state = PromptEnhancerState(
            prompt="Write about AI", context="Focus on machine learning applications"
        )

        result = prompt_enhancer_node(state)

        # Verify apply_prompt_template was called
        mock_apply_template.assert_called_once()
        call_args = mock_apply_template.call_args

        # Check that the context was included in the human message
        messages_arg = call_args[0][1]["messages"]
        assert len(messages_arg) == 1
        human_message = messages_arg[0]
        assert isinstance(human_message, HumanMessage)
        assert "Focus on machine learning applications" in human_message.content

        assert result == {"output": "Enhanced test prompt"}

    def test_error_handling(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test error handling when LLM call fails."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        # Mock LLM to raise an exception
        mock_llm.invoke.side_effect = Exception("LLM error")

        state = PromptEnhancerState(prompt="Test prompt")
        result = prompt_enhancer_node(state)

        # Should return original prompt on error
        assert result == {"output": "Test prompt"}

    def test_template_error_handling(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test error handling when template application fails."""
        mock_get_llm.return_value = mock_llm

        # Mock apply_prompt_template to raise an exception
        mock_apply_template.side_effect = Exception("Template error")

        state = PromptEnhancerState(prompt="Test prompt")
        result = prompt_enhancer_node(state)

        # Should return original prompt on error
        assert result == {"output": "Test prompt"}

    def test_prefix_removal(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test that common prefixes are removed from LLM response."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        # Test different prefixes that should be removed
        test_cases = [
            "Enhanced Prompt: This is the enhanced prompt",
            "Enhanced prompt: This is the enhanced prompt",
            "Here's the enhanced prompt: This is the enhanced prompt",
            "Here is the enhanced prompt: This is the enhanced prompt",
            "**Enhanced Prompt**: This is the enhanced prompt",
            "**Enhanced prompt**: This is the enhanced prompt",
        ]

        for response_with_prefix in test_cases:
            mock_llm.invoke.return_value = MagicMock(content=response_with_prefix)

            state = PromptEnhancerState(prompt="Test prompt")
            result = prompt_enhancer_node(state)

            assert result == {"output": "This is the enhanced prompt"}

    def test_whitespace_handling(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test that whitespace is properly stripped from LLM response."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        # Mock LLM response with extra whitespace
        mock_llm.invoke.return_value = MagicMock(
            content=" \n\n Enhanced prompt \n\n "
        )

        state = PromptEnhancerState(prompt="Test prompt")
        result = prompt_enhancer_node(state)

        assert result == {"output": "Enhanced prompt"}

    def test_xml_with_whitespace_handling(
        self,
        mock_get_llm,
        mock_apply_template,
        mock_llm_xml_with_whitespace,
        mock_messages,
    ):
        """Test XML extraction with extra whitespace inside tags."""
        mock_get_llm.return_value = mock_llm_xml_with_whitespace
        mock_apply_template.return_value = mock_messages

        state = PromptEnhancerState(prompt="Test prompt")
        result = prompt_enhancer_node(state)

        assert result == {"output": "Enhanced prompt with whitespace"}

    def test_xml_multiline_content(
        self, mock_get_llm, mock_apply_template, mock_llm_xml_multiline, mock_messages
    ):
        """Test XML extraction with multiline content."""
        mock_get_llm.return_value = mock_llm_xml_multiline
        mock_apply_template.return_value = mock_messages

        state = PromptEnhancerState(prompt="Test prompt")
        result = prompt_enhancer_node(state)

        expected_output = (
            "This is a multiline enhanced prompt\n"
            "that spans multiple lines\n"
            "and includes various formatting.\n"
            "\n"
            "It should preserve the structure."
        )
        assert result == {"output": expected_output}

    def test_fallback_to_prefix_removal(
        self, mock_get_llm, mock_apply_template, mock_llm_no_xml, mock_messages
    ):
        """Test fallback to prefix removal when no XML tags are found."""
        mock_get_llm.return_value = mock_llm_no_xml
        mock_apply_template.return_value = mock_messages

        state = PromptEnhancerState(prompt="Test prompt")
        result = prompt_enhancer_node(state)

        assert result == {"output": "This is an enhanced prompt without XML tags"}

    def test_malformed_xml_fallback(
        self, mock_get_llm, mock_apply_template, mock_llm_malformed_xml, mock_messages
    ):
        """Test handling of malformed XML tags."""
        mock_get_llm.return_value = mock_llm_malformed_xml
        mock_apply_template.return_value = mock_messages

        state = PromptEnhancerState(prompt="Test prompt")
        result = prompt_enhancer_node(state)

        # Should fall back to using the entire content since XML is malformed
        expected_content = (
            "<enhanced_prompt>\n"
            "This XML tag is not properly closed\n"
            "<enhanced_prompt>"
        )
        assert result == {"output": expected_content}

    def test_case_sensitive_prefix_removal(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test that prefix removal is case-sensitive."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        # Test case variations that should NOT be removed
        test_cases = [
            "ENHANCED PROMPT: This should not be removed",
            "enhanced prompt: This should not be removed",
            "Enhanced Prompt This should not be removed",  # Missing colon
            "Enhanced Prompt :: This should not be removed",  # Double colon
        ]

        for response_content in test_cases:
            mock_llm.invoke.return_value = MagicMock(content=response_content)

            state = PromptEnhancerState(prompt="Test prompt")
            result = prompt_enhancer_node(state)

            # Should return the full content since prefix doesn't match exactly
            assert result == {"output": response_content}

    def test_prefix_with_extra_whitespace(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test prefix removal with extra whitespace after colon."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        test_cases = [
            ("Enhanced Prompt: This has extra spaces", "This has extra spaces"),
            ("Enhanced prompt:\t\tThis has tabs", "This has tabs"),
            ("Here's the enhanced prompt:\n\nThis has newlines", "This has newlines"),
        ]

        for response_content, expected_output in test_cases:
            mock_llm.invoke.return_value = MagicMock(content=response_content)

            state = PromptEnhancerState(prompt="Test prompt")
            result = prompt_enhancer_node(state)

            assert result == {"output": expected_output}

    def test_xml_with_special_characters(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test XML extraction with special characters and symbols."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        special_content = """<enhanced_prompt>
Enhanced prompt with special chars: @#$%^&*()
Unicode: 🚀 ✨ 💡
Quotes: "double" and 'single'
Backslashes: \\n \\t \\r
</enhanced_prompt>"""

        mock_llm.invoke.return_value = MagicMock(content=special_content)

        state = PromptEnhancerState(prompt="Test prompt")
        result = prompt_enhancer_node(state)

        expected_output = """Enhanced prompt with special chars: @#$%^&*()
Unicode: 🚀 ✨ 💡
Quotes: "double" and 'single'
Backslashes: \\n \\t \\r"""
        assert result == {"output": expected_output}

    def test_very_long_response(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test handling of very long LLM responses."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        # Create a very long response
        long_content = "This is a very long enhanced prompt. " * 100
        xml_response = f"<enhanced_prompt>\n{long_content}\n</enhanced_prompt>"

        mock_llm.invoke.return_value = MagicMock(content=xml_response)

        state = PromptEnhancerState(prompt="Test prompt")
        result = prompt_enhancer_node(state)

        assert result == {"output": long_content.strip()}
        assert len(result["output"]) > 1000  # Verify it's actually long

    def test_empty_response_content(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test handling of empty response content."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        mock_llm.invoke.return_value = MagicMock(content="")

        state = PromptEnhancerState(prompt="Test prompt")
        result = prompt_enhancer_node(state)

        assert result == {"output": ""}

    def test_only_whitespace_response(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test handling of response with only whitespace."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        mock_llm.invoke.return_value = MagicMock(content=" \n\n\t\t ")

        state = PromptEnhancerState(prompt="Test prompt")
        result = prompt_enhancer_node(state)

        assert result == {"output": ""}
|
||||
@@ -1,107 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from src.config.report_style import ReportStyle
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_creation():
    """Create a state with every optional field explicitly set to None."""
    state = PromptEnhancerState(
        prompt="Test prompt", context=None, report_style=None, output=None
    )

    assert state["prompt"] == "Test prompt"
    # Each optional key is present and holds the None it was given.
    for optional_key in ("context", "report_style", "output"):
        assert state[optional_key] is None
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_with_all_fields():
    """Create a state with all four fields populated and read each one back."""
    populated = {
        "prompt": "Write about AI",
        "context": "Additional context about AI research",
        "report_style": ReportStyle.ACADEMIC,
        "output": "Enhanced prompt about AI research",
    }

    state = PromptEnhancerState(**populated)

    for field_name, field_value in populated.items():
        assert state[field_name] == field_value
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_minimal():
    """Create a state from the prompt alone and check no other key appears."""
    state = PromptEnhancerState(prompt="Minimal prompt")

    assert state["prompt"] == "Minimal prompt"
    # Keys that were not passed must be absent rather than defaulted.
    for omitted_key in ("context", "report_style", "output"):
        assert omitted_key not in state
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_with_different_report_styles():
    """Store each of four ReportStyle members in a state and read it back."""
    candidate_styles = (
        ReportStyle.ACADEMIC,
        ReportStyle.POPULAR_SCIENCE,
        ReportStyle.NEWS,
        ReportStyle.SOCIAL_MEDIA,
    )

    for candidate in candidate_styles:
        state = PromptEnhancerState(prompt="Test prompt", report_style=candidate)
        assert state["report_style"] == candidate
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_update():
    """Add fields to an existing state with update() and keep the prompt intact."""
    state = PromptEnhancerState(prompt="Original prompt")

    additions = {
        "context": "New context",
        "report_style": ReportStyle.NEWS,
        "output": "Enhanced output",
    }
    state.update(additions)

    # The original key survives the update ...
    assert state["prompt"] == "Original prompt"
    # ... and the newly added keys are readable.
    assert state["context"] == "New context"
    assert state["report_style"] == ReportStyle.NEWS
    assert state["output"] == "Enhanced output"
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_get_method():
    """Exercise get() on present keys, absent keys, and with an explicit default."""
    state = PromptEnhancerState(prompt="Test prompt", report_style=ReportStyle.ACADEMIC)

    # Present keys return their stored values.
    assert state.get("prompt") == "Test prompt"
    assert state.get("report_style") == ReportStyle.ACADEMIC

    # Absent keys return None unless a default is supplied.
    for absent_key in ("context", "output"):
        assert state.get(absent_key) is None
    assert state.get("nonexistent", "default") == "default"
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_type_annotations():
    """Check the runtime type of each value stored in a fully populated state."""
    state = PromptEnhancerState(
        prompt="Test prompt",
        context="Test context",
        report_style=ReportStyle.POPULAR_SCIENCE,
        output="Test output",
    )

    expected_types = (
        ("prompt", str),
        ("context", str),
        ("report_style", ReportStyle),
        ("output", str),
    )
    for field_name, field_type in expected_types:
        assert isinstance(state[field_name], field_type)
|
||||
@@ -1,154 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.rag.dify import DifyProvider, parse_uri
|
||||
|
||||
|
||||
# Dummy classes to mock dependencies
|
||||
class DummyResource:
    """Plain attribute holder used in place of the real resource type in these tests.

    Stores the ``uri``, ``title`` and ``description`` it is constructed with.
    """

    def __init__(self, uri, title="", description=""):
        self.uri = uri
        self.title = title
        self.description = description

    def __repr__(self):
        # Show the field values so a failed assertion is readable.
        return (
            f"{type(self).__name__}(uri={self.uri!r}, title={self.title!r}, "
            f"description={self.description!r})"
        )
|
||||
|
||||
|
||||
class DummyChunk:
    """Plain attribute holder used in place of the real chunk type in these tests.

    Stores the ``content`` and ``similarity`` it is constructed with.
    """

    def __init__(self, content, similarity):
        self.content = content
        self.similarity = similarity

    def __repr__(self):
        # Show the field values so a failed assertion is readable.
        return (
            f"{type(self).__name__}(content={self.content!r}, "
            f"similarity={self.similarity!r})"
        )
|
||||
|
||||
|
||||
class DummyDocument:
    """Plain attribute holder used in place of the real document type in these tests.

    Stores ``id`` and ``title``; ``chunks`` falls back to a new empty list
    when it is omitted or falsy.
    """

    def __init__(self, id, title, chunks=None):
        self.id = id
        self.title = title
        # A fresh list per instance, so documents never share a default.
        self.chunks = chunks or []

    def __repr__(self):
        # Show the field values so a failed assertion is readable.
        return (
            f"{type(self).__name__}(id={self.id!r}, title={self.title!r}, "
            f"chunks={self.chunks!r})"
        )
|
||||
|
||||
|
||||
# Patch imports in dify.py to use dummy classes
|
||||
@pytest.fixture(autouse=True)
def patch_imports(monkeypatch):
    """Swap the Resource/Chunk/Document names in ``src.rag.dify`` for the dummies.

    The replacement goes through ``monkeypatch`` so the module's original
    attributes are restored when each test finishes; assigning them directly
    would leave the dummies installed for every test that runs afterwards.
    """
    import src.rag.dify as dify

    # raising=False keeps the old behaviour of creating the attribute when
    # the module does not already define it.
    monkeypatch.setattr(dify, "Resource", DummyResource, raising=False)
    monkeypatch.setattr(dify, "Chunk", DummyChunk, raising=False)
    monkeypatch.setattr(dify, "Document", DummyDocument, raising=False)
    yield
|
||||
|
||||
|
||||
def test_parse_uri_valid():
    """parse_uri splits a rag:// URI into its dataset id and document id."""
    parsed = parse_uri("rag://dataset/123#abc")
    dataset_id, document_id = parsed
    assert dataset_id == "123"
    assert document_id == "abc"
|
||||
|
||||
|
||||
def test_parse_uri_invalid():
    """parse_uri raises ValueError for a URI with an http:// scheme."""
    bad_uri = "http://dataset/123#abc"
    with pytest.raises(ValueError):
        parse_uri(bad_uri)
|
||||
|
||||
|
||||
def test_init_env_vars(monkeypatch):
    """DifyProvider exposes the DIFY_API_URL / DIFY_API_KEY values it was started with."""
    for env_name, env_value in (("DIFY_API_URL", "http://api"), ("DIFY_API_KEY", "key")):
        monkeypatch.setenv(env_name, env_value)

    provider = DifyProvider()

    assert provider.api_url == "http://api"
    assert provider.api_key == "key"
|
||||
|
||||
|
||||
def test_init_missing_env(monkeypatch):
    """DifyProvider() raises ValueError when either required variable is unset."""
    required = {"DIFY_API_URL": "http://api", "DIFY_API_KEY": "key"}

    # First drop the URL (key present), then drop the key (URL present).
    for missing_name in ("DIFY_API_URL", "DIFY_API_KEY"):
        for env_name, env_value in required.items():
            monkeypatch.setenv(env_name, env_value)
        monkeypatch.delenv(missing_name, raising=False)

        with pytest.raises(ValueError):
            DifyProvider()
|
||||
|
||||
|
||||
@patch("src.rag.dify.requests.post")
def test_query_relevant_documents_success(mock_post, monkeypatch):
    """A 200 response with one record yields one document holding one chunk."""
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.setenv("DIFY_API_KEY", "key")
    provider = DifyProvider()
    resource = DummyResource("rag://dataset/123#doc456")

    # Canned payload: a single record whose segment names its document.
    record = {
        "segment": {
            "content": "chunk text",
            "document": {
                "id": "doc456",
                "name": "Doc Title",
            },
        },
        "score": 0.9,
    }
    fake_response = MagicMock()
    fake_response.status_code = 200
    fake_response.json.return_value = {"records": [record]}
    mock_post.return_value = fake_response

    docs = provider.query_relevant_documents("query", [resource])

    assert len(docs) == 1
    first_doc = docs[0]
    assert first_doc.id == "doc456"
    assert first_doc.title == "Doc Title"
    assert len(first_doc.chunks) == 1
    first_chunk = first_doc.chunks[0]
    assert first_chunk.content == "chunk text"
    assert first_chunk.similarity == 0.9
|
||||
|
||||
|
||||
@patch("src.rag.dify.requests.post")
def test_query_relevant_documents_error(mock_post, monkeypatch):
    """A 400 response makes query_relevant_documents raise."""
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.setenv("DIFY_API_KEY", "key")
    provider = DifyProvider()
    resource = DummyResource("rag://dataset/123#doc456")

    failed_response = MagicMock()
    failed_response.text = "error"
    failed_response.status_code = 400
    mock_post.return_value = failed_response

    with pytest.raises(Exception):
        provider.query_relevant_documents("query", [resource])
|
||||
|
||||
|
||||
@patch("src.rag.dify.requests.get")
def test_list_resources_success(mock_get, monkeypatch):
    """A 200 response listing two datasets yields two matching resources, in order."""
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.setenv("DIFY_API_KEY", "key")
    provider = DifyProvider()

    datasets = [
        {"id": "123", "name": "Dataset1", "description": "desc1"},
        {"id": "456", "name": "Dataset2", "description": "desc2"},
    ]
    fake_response = MagicMock()
    fake_response.status_code = 200
    fake_response.json.return_value = {"data": datasets}
    mock_get.return_value = fake_response

    resources = provider.list_resources()

    assert len(resources) == 2
    first, second = resources[0], resources[1]
    assert first.uri == "rag://dataset/123"
    assert first.title == "Dataset1"
    assert first.description == "desc1"
    assert second.uri == "rag://dataset/456"
    assert second.title == "Dataset2"
    assert second.description == "desc2"
|
||||
|
||||
|
||||
@patch("src.rag.dify.requests.get")
def test_list_resources_error(mock_get, monkeypatch):
    """A 500 response makes list_resources raise."""
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.setenv("DIFY_API_KEY", "key")
    provider = DifyProvider()

    failed_response = MagicMock()
    failed_response.text = "fail"
    failed_response.status_code = 500
    mock_get.return_value = failed_response

    with pytest.raises(Exception):
        provider.list_resources()
|
||||
@@ -1,930 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Tests for Milvus RAG provider.
|
||||
|
||||
IMPORTANT NOTE: This test file creates temporary directories for testing examples
|
||||
functionality. All temporary directories are automatically cleaned up using pytest
|
||||
fixtures. When adding new tests that create temporary directories:
|
||||
|
||||
1. Use the provided fixtures (temp_examples_dir, temp_error_examples_dir, etc.)
|
||||
2. Never create temporary directories without automatic cleanup
|
||||
3. Follow the pattern: fixture -> use -> automatic cleanup
|
||||
4. If you need a new directory pattern, create a corresponding fixture
|
||||
|
||||
This ensures tests don't leave behind temporary files that clutter the workspace.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
import src.rag.milvus as milvus_mod
|
||||
from src.rag.milvus import MilvusProvider
|
||||
from src.rag.retriever import Resource
|
||||
|
||||
|
||||
class DummyEmbedding:
    """Offline stand-in for an embedding client; always returns the same 3-d vector."""

    def __init__(self, **kwargs):
        # Keep whatever the caller passed so tests can inspect it later.
        self.kwargs = kwargs

    def embed_query(self, text: str):
        """Return the fixed vector, ignoring *text*."""
        return [0.1, 0.2, 0.3]

    def embed_documents(self, texts):
        """Return one freshly built fixed vector per item in *texts*."""
        vectors = []
        for _ in texts:
            vectors.append([0.1, 0.2, 0.3])
        return vectors
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def patch_embeddings(monkeypatch):
    """Set the Milvus env vars and stub both embedding classes for every test."""
    env = {
        "MILVUS_EMBEDDING_PROVIDER": "openai",
        "MILVUS_EMBEDDING_MODEL": "text-embedding-ada-002",
        "MILVUS_COLLECTION": "documents",
        "MILVUS_URI": "./milvus_demo.db",  # default lite
    }
    for var_name, var_value in env.items():
        monkeypatch.setenv(var_name, var_value)

    # Prevent network / external API usage during __init__ by swapping in the dummy.
    for class_name in ("OpenAIEmbeddings", "DashscopeEmbeddings"):
        monkeypatch.setattr(milvus_mod, class_name, DummyEmbedding)
    yield
|
||||
|
||||
|
||||
@pytest.fixture
def project_root():
    """Return the directory three levels above the milvus module's file."""
    # Mirrors the implementation's lookup: current_file.parent.parent.parent
    module_dir = Path(milvus_mod.__file__).parent
    package_dir = module_dir.parent
    return package_dir.parent
|
||||
|
||||
|
||||
@pytest.fixture
def temp_examples_dir(project_root):
    """Yield a uniquely named scratch directory under the project root; remove it on teardown."""
    scratch_dir = project_root / f"examples_test_{uuid4().hex}"
    scratch_dir.mkdir(parents=True, exist_ok=True)

    yield scratch_dir

    # Teardown: nothing to do if the test already removed the directory.
    if not scratch_dir.exists():
        return
    shutil.rmtree(scratch_dir)
|
||||
|
||||
|
||||
@pytest.fixture
def temp_error_examples_dir(project_root):
    """Yield a fresh ``examples_error_<hex>`` directory and delete it afterwards."""
    unique_suffix = uuid4().hex
    error_dir = project_root / f"examples_error_{unique_suffix}"
    error_dir.mkdir(parents=True, exist_ok=True)

    yield error_dir

    # Teardown: drop the directory together with anything the test wrote into it.
    if not error_dir.exists():
        return
    shutil.rmtree(error_dir)
|
||||
|
||||
|
||||
@pytest.fixture
def temp_load_skip_examples_dir(project_root):
    """Yield the fixed-name ``examples_test_load_skip`` directory; remove it on teardown."""
    load_skip_dir = project_root / "examples_test_load_skip"
    # exist_ok tolerates a directory that is already present.
    load_skip_dir.mkdir(parents=True, exist_ok=True)

    yield load_skip_dir

    # Teardown: remove the directory and all its contents.
    if not load_skip_dir.exists():
        return
    shutil.rmtree(load_skip_dir)
|
||||
|
||||
|
||||
@pytest.fixture
def temp_single_chunk_examples_dir(project_root):
    """Yield the fixed-name ``examples_test_single_chunk`` directory; remove it on teardown."""
    single_chunk_dir = project_root / "examples_test_single_chunk"
    # exist_ok tolerates a directory that is already present.
    single_chunk_dir.mkdir(parents=True, exist_ok=True)

    yield single_chunk_dir

    # Teardown: remove the directory and all its contents.
    if not single_chunk_dir.exists():
        return
    shutil.rmtree(single_chunk_dir)
|
||||
|
||||
|
||||
def _patch_init(monkeypatch):
    """Replace ``MilvusProvider._init_embedding_model`` with one that installs a DummyEmbedding."""

    def _install_dummy(self):
        # Same effect as the real initializer's assignment, minus any external client.
        self.embedding_model = DummyEmbedding()

    monkeypatch.setattr(MilvusProvider, "_init_embedding_model", _install_dummy)
|
||||
|
||||
|
||||
def test_list_local_markdown_resources_missing_dir(project_root):
    """Listing local markdown resources yields [] when examples_dir does not exist."""
    provider = MilvusProvider()
    # The random suffix makes a name collision with a real directory very unlikely.
    provider.examples_dir = f"missing_examples_{uuid4().hex}"
    assert provider._list_local_markdown_resources() == []
|
||||
|
||||
|
||||
def test_list_local_markdown_resources_populated(temp_examples_dir):
    """Markdown files are listed with a heading-derived or filename-derived title."""
    provider = MilvusProvider()
    # The provider is pointed at the temp directory by name.
    provider.examples_dir = temp_examples_dir.name

    sample_files = {
        "file1.md": "# Title One\n\nContent body.",  # file with heading
        "file_two.md": "No heading here.",  # no heading -> fallback title
        "ignore.txt": "Should not be picked up.",  # non-markdown, must be ignored
    }
    for file_name, text in sample_files.items():
        (temp_examples_dir / file_name).write_text(text, encoding="utf-8")

    resources = provider._list_local_markdown_resources()
    # Order not guaranteed; sort by uri for assertions
    resources.sort(key=lambda res: res.uri)

    # Expect two resources
    assert len(resources) == 2
    uri_one = f"milvus://{provider.collection_name}/file1.md"
    uri_two = f"milvus://{provider.collection_name}/file_two.md"
    assert {res.uri for res in resources} == {uri_one, uri_two}

    by_uri = {res.uri: res for res in resources}
    expected_description = "Local markdown example (not yet ingested)"

    with_heading = by_uri[uri_one]
    assert isinstance(with_heading, Resource)
    assert with_heading.title == "Title One"
    assert with_heading.description == expected_description

    without_heading = by_uri[uri_two]
    # Fallback logic: filename -> "file_two" -> "file two" -> title case -> "File Two"
    assert without_heading.title == "File Two"
    assert without_heading.description == expected_description
|
||||
|
||||
|
||||
def test_list_local_markdown_resources_read_error(monkeypatch, temp_error_examples_dir):
    """A file whose read raises OSError is left out; the readable file is still listed."""
    provider = MilvusProvider()
    provider.examples_dir = temp_error_examples_dir.name

    unreadable = temp_error_examples_dir / "bad.md"
    readable = temp_error_examples_dir / "good.md"
    readable.write_text("# Good Title\n\nBody.", encoding="utf-8")
    unreadable.write_text("Broken", encoding="utf-8")

    # Wrap Path.read_text so only bad.md fails; every other path delegates
    # to the real implementation.
    real_read_text = Path.read_text

    def failing_read_text(self, *args, **kwargs):
        if self != unreadable:
            return real_read_text(self, *args, **kwargs)
        raise OSError("Cannot read file")

    monkeypatch.setattr(Path, "read_text", failing_read_text)

    resources = provider._list_local_markdown_resources()

    # Only good.md should appear
    assert len(resources) == 1
    only_resource = resources[0]
    assert only_resource.title == "Good Title"
    assert only_resource.uri == f"milvus://{provider.collection_name}/good.md"
|
||||
|
||||
|
||||
def test_create_collection_schema_fields(monkeypatch):
    """The generated schema contains the core fields and enables dynamic fields."""
    _patch_init(monkeypatch)
    provider = MilvusProvider()
    schema = provider._create_collection_schema()

    present = {field.name for field in schema.fields}
    # Core fields must be present
    required = {provider.id_field, provider.vector_field, provider.content_field}
    assert required.issubset(present)
    # Dynamic field enabled for extra metadata
    assert schema.enable_dynamic_field is True
|
||||
|
||||
|
||||
def test_generate_doc_id_stable(monkeypatch, tmp_path):
    """Generating a doc id twice for an untouched file gives the same value."""
    _patch_init(monkeypatch)
    provider = MilvusProvider()

    sample = tmp_path / "example.md"
    sample.write_text("# Title\nBody", encoding="utf-8")

    first_id, second_id = (provider._generate_doc_id(sample) for _ in range(2))
    # deterministic given unchanged file metadata
    assert first_id == second_id
|
||||
|
||||
|
||||
def test_extract_title_from_markdown(monkeypatch):
    """A leading heading becomes the title; otherwise the filename is title-cased."""
    _patch_init(monkeypatch)
    provider = MilvusProvider()
    extract_title = provider._extract_title_from_markdown

    # Heading present: the filename argument is not used for the title.
    assert extract_title("# Heading\nBody", "ignored.md") == "Heading"
    # No heading: title is derived from the filename.
    assert extract_title("Body only", "my_file_name.md") == "My File Name"
|
||||
|
||||
|
||||
def test_split_content_chunking(monkeypatch):
    """With a small MILVUS_CHUNK_SIZE, three paragraphs split into several non-empty chunks."""
    monkeypatch.setenv("MILVUS_CHUNK_SIZE", "40")  # small to force split
    _patch_init(monkeypatch)
    provider = MilvusProvider()

    # 3 paragraphs separated by blank lines
    three_paragraphs = "Para1 text here.\n\nPara2 second block.\n\nPara3 final."
    pieces = provider._split_content(three_paragraphs)

    assert len(pieces) >= 2  # forced split
    assert all(pieces)  # no empty chunks
|
||||
|
||||
|
||||
def test_get_embedding_invalid_inputs(monkeypatch):
    """_get_embedding raises RuntimeError for a non-string and for whitespace-only text."""
    _patch_init(monkeypatch)
    provider = MilvusProvider()

    # First a non-string value, then a whitespace-only string.
    for bad_input in (123, " "):
        with pytest.raises(RuntimeError):
            provider._get_embedding(bad_input)  # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_list_resources_remote_success_and_dedup(monkeypatch):
    """Remote list_resources collapses two documents sharing an id into one resource."""
    monkeypatch.setenv("MILVUS_URI", "http://remote")
    _patch_init(monkeypatch)
    provider = MilvusProvider()

    class DocObj:
        def __init__(self, content: str, meta: dict):
            self.page_content, self.metadata = content, meta

    def _meta(title):
        # Both documents share id "d1" and url "u1"; only the title differs.
        return {
            provider.id_field: "d1",
            provider.title_field: title,
            provider.url_field: "u1",
        }

    search_queries = []

    class RemoteClient:
        def similarity_search(self, query, k, expr):  # noqa: D401
            search_queries.append(query)
            # Two docs with identical id to test dedup
            return [DocObj("c1", _meta("T1")), DocObj("c1_dup", _meta("T1_dup"))]

    provider.client = RemoteClient()
    resources = provider.list_resources("query text")

    assert len(resources) == 1  # dedup applied
    assert resources[0].title.startswith("T1")
    assert len(search_queries) == 1
|
||||
|
||||
|
||||
def test_list_resources_lite_success(monkeypatch):
    """In lite mode, list_resources turns each queried row into a resource."""
    _patch_init(monkeypatch)
    provider = MilvusProvider()

    def _row(doc_id, title, url):
        return {
            provider.id_field: doc_id,
            provider.title_field: title,
            provider.url_field: url,
        }

    class DummyMilvusLite:
        def query(self, collection_name, filter, output_fields, limit):  # noqa: D401
            return [_row("idA", "Alpha", "u://a"), _row("idB", "Beta", "u://b")]

    provider.client = DummyMilvusLite()
    titles = {res.title for res in provider.list_resources()}
    assert titles == {"Alpha", "Beta"}
|
||||
|
||||
|
||||
def test_query_relevant_documents_lite_success(monkeypatch):
    """Lite-mode search results are filtered down to the requested resource."""
    _patch_init(monkeypatch)
    provider = MilvusProvider()

    # Provide deterministic embedding output
    provider.embedding_model.embed_query = lambda text: [0.1, 0.2, 0.3]  # type: ignore

    def _hit(doc_id, content, title, url, distance):
        return {
            "entity": {
                provider.id_field: doc_id,
                provider.content_field: content,
                provider.title_field: title,
                provider.url_field: url,
            },
            "distance": distance,
        }

    class DummyMilvusLite:
        def search(
            self, collection_name, data, anns_field, param, limit, output_fields
        ):  # noqa: D401
            # Simulate two result entries inside a single outer list.
            return [
                [
                    _hit("d1", "c1", "T1", "u1", 0.9),
                    _hit("d2", "c2", "T2", "u2", 0.8),
                ]
            ]

    provider.client = DummyMilvusLite()

    # Filter for only d2 via resource list
    wanted = Resource(uri="milvus://d2", title="", description="")
    docs = provider.query_relevant_documents("question", resources=[wanted])

    assert len(docs) == 1
    assert docs[0].id == "d2"
    assert docs[0].chunks[0].similarity == 0.8
|
||||
|
||||
|
||||
def test_query_relevant_documents_remote_success(monkeypatch):
    """Remote-mode scored results are filtered down to the requested resource."""
    monkeypatch.setenv("MILVUS_URI", "http://remote")
    _patch_init(monkeypatch)
    provider = MilvusProvider()
    provider.embedding_model.embed_query = lambda text: [0.1, 0.2, 0.3]  # type: ignore

    class DocObj:
        def __init__(self, content: str, meta: dict):  # noqa: D401
            self.page_content, self.metadata = content, meta

    def _scored(content, doc_id, title, url, score):
        # (document, score) pair as returned by similarity_search_with_score.
        meta = {
            provider.id_field: doc_id,
            provider.title_field: title,
            provider.url_field: url,
        }
        return (DocObj(content, meta), score)

    class RemoteClient:
        def similarity_search_with_score(self, query, k):  # noqa: D401
            return [
                _scored("c1", "d1", "T1", "u1", 0.7),
                _scored("c2", "d2", "T2", "u2", 0.6),
            ]

    provider.client = RemoteClient()

    # Filter to only d1
    wanted = Resource(uri="milvus://d1", title="", description="")
    docs = provider.query_relevant_documents("q", resources=[wanted])

    assert len(docs) == 1
    assert docs[0].id == "d1"
    assert docs[0].chunks[0].similarity == 0.7
|
||||
|
||||
|
||||
def test_get_embedding_dimension_explicit(monkeypatch):
    """MILVUS_EMBEDDING_DIM is used as the provider's embedding dimension."""
    monkeypatch.setenv("MILVUS_EMBEDDING_DIM", "777")
    _patch_init(monkeypatch)
    provider = MilvusProvider()
    configured_dim = provider.embedding_dim
    assert configured_dim == 777
|
||||
|
||||
|
||||
def test_get_embedding_dimension_unknown_model(monkeypatch):
    """Without an explicit dimension, an unrecognized model name yields 1536."""
    monkeypatch.delenv("MILVUS_EMBEDDING_DIM", raising=False)
    monkeypatch.setenv("MILVUS_EMBEDDING_MODEL", "unknown-model-x")
    _patch_init(monkeypatch)
    provider = MilvusProvider()
    fallback_dim = provider.embedding_dim
    # falls back to default 1536
    assert fallback_dim == 1536
|
||||
|
||||
|
||||
def test_is_milvus_lite_variants(monkeypatch):
    """File-like and relative URIs count as Milvus Lite; an http URI does not."""
    _patch_init(monkeypatch)
    cases = (
        ("mydb.db", True),
        ("relative_path_store", True),
        ("http://host:19530", False),
    )
    for uri, expected in cases:
        monkeypatch.setenv("MILVUS_URI", uri)
        assert MilvusProvider()._is_milvus_lite() is expected
|
||||
|
||||
|
||||
def test_create_collection_lite(monkeypatch):
    """_ensure_collection_exists creates the collection when the client lists none."""
    _patch_init(monkeypatch)
    provider = MilvusProvider()
    created: dict = {}

    class DummyMilvusLite:
        def list_collections(self):  # noqa: D401
            return []  # empty triggers creation

        def create_collection(self, collection_name, schema, index_params):  # noqa: D401
            # Record the arguments so the test can check what was requested.
            created.update(name=collection_name, schema=schema, index=index_params)

    provider.client = DummyMilvusLite()
    provider._ensure_collection_exists()
    assert created["name"] == provider.collection_name
|
||||
|
||||
|
||||
def test_ensure_collection_exists_remote(monkeypatch):
    """With a remote URI, _ensure_collection_exists completes against a bare client."""
    _patch_init(monkeypatch)
    monkeypatch.setenv("MILVUS_URI", "http://remote:19530")
    provider = MilvusProvider()

    # remote path, nothing thrown: the client has no methods at all.
    bare_client = SimpleNamespace()
    provider.client = bare_client
    provider._ensure_collection_exists()
|
||||
|
||||
|
||||
def test_get_existing_document_ids_lite(monkeypatch):
    """Lite mode collects the id field of each row and ignores rows without it."""
    _patch_init(monkeypatch)
    provider = MilvusProvider()

    class DummyMilvusLite:
        def query(self, collection_name, filter, output_fields, limit):  # noqa: D401
            rows = [{provider.id_field: doc_id} for doc_id in ("a", "b")]
            # A row lacking the id field; the expected result below excludes it.
            rows.append({"other": "ignored"})
            return rows

    provider.client = DummyMilvusLite()
    existing_ids = provider._get_existing_document_ids()
    assert existing_ids == {"a", "b"}
|
||||
|
||||
|
||||
def test_get_existing_document_ids_remote(monkeypatch):
    """With a remote URI, _get_existing_document_ids returns an empty set."""
    _patch_init(monkeypatch)
    monkeypatch.setenv("MILVUS_URI", "http://x")
    provider = MilvusProvider()
    # A plain object: the client exposes no query method here.
    provider.client = object()
    existing_ids = provider._get_existing_document_ids()
    assert existing_ids == set()
|
||||
|
||||
|
||||
def test_insert_document_chunk_lite_and_error(monkeypatch):
    """Lite insert passes the doc id through; an embedding failure raises RuntimeError."""
    _patch_init(monkeypatch)
    provider = MilvusProvider()
    captured = {}

    class DummyMilvusLite:
        def insert(self, collection_name, data):  # noqa: D401
            captured["data"] = data

    provider.client = DummyMilvusLite()
    provider._insert_document_chunk(
        doc_id="id1", content="hello", title="T", url="u", metadata={"m": 1}
    )
    first_row = captured["data"][0]
    assert first_row[provider.id_field] == "id1"

    # error path: patch embedding to raise
    def bad_embed(text):  # noqa: D401
        raise RuntimeError("boom")

    provider.embedding_model.embed_query = bad_embed  # type: ignore[attr-defined]
    with pytest.raises(RuntimeError):
        provider._insert_document_chunk(
            doc_id="id2", content="err", title="T", url="u", metadata={}
        )
|
||||
|
||||
|
||||
def test_insert_document_chunk_remote(monkeypatch):
    """Remote insert goes through add_texts with the doc id in the metadata."""
    _patch_init(monkeypatch)
    monkeypatch.setenv("MILVUS_URI", "http://remote")
    provider = MilvusProvider()
    added = {}

    class RemoteClient:
        def add_texts(self, texts, metadatas):  # noqa: D401
            # Record both arguments for the assertion below.
            added.update(texts=texts, meta=metadatas)

    provider.client = RemoteClient()
    provider._insert_document_chunk(
        doc_id="idx", content="ct", title="Title", url="urlx", metadata={"k": 2}
    )
    first_meta = added["meta"][0]
    assert first_meta[provider.id_field] == "idx"
|
||||
|
||||
|
||||
def test_connect_lite_and_error(monkeypatch):
    """_connect stores a MilvusClient in lite mode and raises ConnectionError on failure."""

    # Stand-in for MilvusClient exposing only what this test's lite path needs.
    class FakeMilvusClient:
        def __init__(self, uri):  # noqa: D401
            self.uri = uri

        def list_collections(self):  # noqa: D401
            return []

        def create_collection(self, **kwargs):  # noqa: D401
            return None

    monkeypatch.setattr(milvus_mod, "MilvusClient", FakeMilvusClient)
    _patch_init(monkeypatch)
    provider = MilvusProvider()
    provider._connect()
    assert isinstance(provider.client, FakeMilvusClient)

    # error path: the client constructor itself fails
    class BadClient:
        def __init__(self, uri):  # noqa: D401
            raise RuntimeError("fail connect")

    monkeypatch.setattr(milvus_mod, "MilvusClient", BadClient)
    failing_provider = MilvusProvider()
    with pytest.raises(ConnectionError):
        failing_provider._connect()
|
||||
|
||||
|
||||
def test_connect_remote(monkeypatch):
    """With a remote URI, _connect builds LangchainMilvus with the collection name."""
    monkeypatch.setenv("MILVUS_URI", "http://remote")
    _patch_init(monkeypatch)
    constructor_kwargs = {}

    class FakeLangchainMilvus:
        def __init__(self, **kwargs):  # noqa: D401
            # Copy every keyword argument out for inspection.
            for key, value in kwargs.items():
                constructor_kwargs[key] = value

    monkeypatch.setattr(milvus_mod, "LangchainMilvus", FakeLangchainMilvus)
    provider = MilvusProvider()
    provider._connect()
    assert constructor_kwargs["collection_name"] == provider.collection_name
|
||||
|
||||
|
||||
def test_list_resources_remote_failure(monkeypatch):
    """A failing remote search makes list_resources return [] instead of raising."""
    monkeypatch.setenv("MILVUS_URI", "http://remote")
    _patch_init(monkeypatch)
    provider = MilvusProvider()

    # The local listing is stubbed to return nothing.
    def _no_local_resources():
        return []

    monkeypatch.setattr(provider, "_list_local_markdown_resources", _no_local_resources)

    # patch client to raise inside similarity_search to trigger fallback path
    class BadClient:
        def similarity_search(self, *args, **kwargs):  # noqa: D401
            raise RuntimeError("fail")

    provider.client = BadClient()
    # Should fallback to [] without raising
    listed = provider.list_resources()
    assert listed == []
|
||||
|
||||
|
||||
def test_list_local_markdown_resources_empty(monkeypatch):
    """Pointing examples_dir at a non-existent directory yields no local resources."""
    _patch_init(monkeypatch)
    provider = MilvusProvider()
    absent_dir = "nonexistent_dir"
    # The env var is set after construction; the attribute is assigned directly as well.
    monkeypatch.setenv("MILVUS_EXAMPLES_DIR", absent_dir)
    provider.examples_dir = absent_dir
    assert provider._list_local_markdown_resources() == []
|
||||
|
||||
|
||||
def test_query_relevant_documents_error(monkeypatch):
    """A RuntimeError from the embedding model surfaces from query_relevant_documents."""
    _patch_init(monkeypatch)
    provider = MilvusProvider()

    def failing_embed(text):
        raise RuntimeError("embed fail")

    provider.embedding_model.embed_query = failing_embed  # type: ignore
    with pytest.raises(RuntimeError):
        provider.query_relevant_documents("q")
|
||||
|
||||
|
||||
def test_create_collection_when_client_exists(monkeypatch):
    """create_collection does not raise when a client is already attached."""
    _patch_init(monkeypatch)
    provider = MilvusProvider()
    existing_client = SimpleNamespace(closed=False)
    provider.client = existing_client
    # remote vs lite path difference handled by _is_milvus_lite
    provider.create_collection()  # should no-op gracefully
|
||||
|
||||
|
||||
def test_load_examples_force_reload(monkeypatch):
    """load_examples(force_reload=True) both clears and reloads the example documents."""
    _patch_init(monkeypatch)
    provider = MilvusProvider()
    provider.client = SimpleNamespace()
    steps = {"clear": 0, "load": 0}

    def _marker(step_name):
        # Build a zero-argument stub that flags *step_name* as executed.
        def _mark():
            steps[step_name] = 1

        return _mark

    monkeypatch.setattr(provider, "_clear_example_documents", _marker("clear"))
    monkeypatch.setattr(provider, "_load_example_files", _marker("load"))

    provider.load_examples(force_reload=True)
    assert steps == {"clear": 1, "load": 1}
|
||||
|
||||
|
||||
def test_clear_example_documents_remote(monkeypatch):
    """With a remote URI, _clear_example_documents completes against a bare client."""
    monkeypatch.setenv("MILVUS_URI", "http://remote")
    _patch_init(monkeypatch)
    provider = MilvusProvider()
    bare_client = SimpleNamespace()
    provider.client = bare_client
    # Must not raise even though the client has no query/delete methods.
    provider._clear_example_documents()
|
||||
|
||||
|
||||
def test_clear_example_documents_lite(monkeypatch):
    """Lite mode deletes exactly the ids returned by the example query."""
    _patch_init(monkeypatch)
    provider = MilvusProvider()
    deleted = {}

    class DummyMilvusLite:
        def query(self, **kwargs):  # noqa: D401
            return [{provider.id_field: example_id} for example_id in ("ex1", "ex2")]

        def delete(self, collection_name, ids):  # noqa: D401
            deleted["ids"] = ids

    provider.client = DummyMilvusLite()
    provider._clear_example_documents()
    removed_ids = deleted["ids"]
    assert removed_ids == ["ex1", "ex2"]
|
||||
|
||||
|
||||
def test_get_loaded_examples_lite_and_error(monkeypatch):
    """get_loaded_examples maps query rows to dicts and returns [] when the query fails."""
    _patch_init(monkeypatch)
    provider = MilvusProvider()

    class DummyMilvusLite:
        def query(self, **kwargs):  # noqa: D401
            single_row = {
                provider.id_field: "id1",
                provider.title_field: "T1",
                provider.url_field: "u1",
                "file": "f1",
            }
            return [single_row]

    provider.client = DummyMilvusLite()
    examples = provider.get_loaded_examples()
    assert examples[0]["id"] == "id1"

    # error path: the query itself raises
    class BadClient:
        def query(self, **kwargs):  # noqa: D401
            raise RuntimeError("fail")

    provider.client = BadClient()
    assert provider.get_loaded_examples() == []
|
||||
|
||||
|
||||
def test_get_loaded_examples_remote(monkeypatch):
    """With a remote URI, get_loaded_examples returns an empty list."""
    monkeypatch.setenv("MILVUS_URI", "http://remote")
    _patch_init(monkeypatch)
    provider = MilvusProvider()
    provider.client = SimpleNamespace()
    examples = provider.get_loaded_examples()
    assert examples == []
|
||||
|
||||
|
||||
def test_close_lite_and_remote(monkeypatch):
    """close() calls the lite client's close once and tolerates a bare remote client."""
    _patch_init(monkeypatch)
    lite_provider = MilvusProvider()
    close_calls = {"c": 0}

    class DummyMilvusLite:
        def close(self):  # noqa: D401
            close_calls["c"] = close_calls["c"] + 1

        def list_collections(self):  # noqa: D401
            return []

        def create_collection(self, **kwargs):  # noqa: D401
            return None

    lite_provider.client = DummyMilvusLite()
    lite_provider.close()
    assert close_calls["c"] == 1

    # remote path: no close attr usage expected
    monkeypatch.setenv("MILVUS_URI", "http://remote")
    remote_provider = MilvusProvider()
    remote_provider.client = SimpleNamespace()
    remote_provider.close()  # should not raise
|
||||
|
||||
|
||||
def test_get_embedding_invalid_output(monkeypatch):
    """_get_embedding raises RuntimeError when the model returns an empty list."""
    _patch_init(monkeypatch)
    provider = MilvusProvider()

    def empty_embed(text):
        # invalid output: an embedding with no components
        return []

    provider.embedding_model.embed_query = empty_embed  # type: ignore
    with pytest.raises(RuntimeError):
        provider._get_embedding("text")
|
||||
|
||||
|
||||
def test_dashscope_embeddings_empty_inputs_short_circuit(monkeypatch):
    """embed_documents([]) returns [] without the client's create() being called."""
    # NOTE(review): the autouse patch_embeddings fixture in this file replaces
    # milvus_mod.DashscopeEmbeddings with DummyEmbedding, so this may be
    # exercising the dummy rather than the real class - confirm.
    embedder = milvus_mod.DashscopeEmbeddings(model="m")

    class _FailingEmbeddingsApi:
        def create(self, *a, **k):
            raise AssertionError("Should not be called for empty input")

    class FailingClient:
        embeddings = _FailingEmbeddingsApi()

    # Swap _client so any call into create() fails the test loudly.
    embedder._client = FailingClient()  # type: ignore
    assert embedder.embed_documents([]) == []
|
||||
|
||||
|
||||
# Tests for _init_embedding_model provider selection logic
|
||||
def test_init_embedding_model_openai(monkeypatch):
    """Provider "openai" builds OpenAIEmbeddings with model, encoding format and dimensions."""
    monkeypatch.setenv("MILVUS_EMBEDDING_PROVIDER", "openai")
    monkeypatch.setenv("MILVUS_EMBEDDING_MODEL", "text-embedding-ada-002")
    seen_kwargs = {}

    class CapturingOpenAI:
        def __init__(self, **kwargs):
            seen_kwargs.update(kwargs)

    monkeypatch.setattr(milvus_mod, "OpenAIEmbeddings", CapturingOpenAI)
    provider = MilvusProvider()
    assert isinstance(provider.embedding_model, CapturingOpenAI)

    # kwargs forwarded
    expected = {
        "model": "text-embedding-ada-002",
        "encoding_format": "float",
        "dimensions": provider.embedding_dim,
    }
    for key, value in expected.items():
        assert seen_kwargs[key] == value
|
||||
|
||||
|
||||
def test_init_embedding_model_dashscope(monkeypatch):
    """Provider "dashscope" builds DashscopeEmbeddings with model, encoding format and dimensions."""
    monkeypatch.setenv("MILVUS_EMBEDDING_PROVIDER", "dashscope")
    monkeypatch.setenv("MILVUS_EMBEDDING_MODEL", "text-embedding-ada-002")
    seen_kwargs = {}

    class CapturingDashscope:
        def __init__(self, **kwargs):
            seen_kwargs.update(kwargs)

    monkeypatch.setattr(milvus_mod, "DashscopeEmbeddings", CapturingDashscope)
    provider = MilvusProvider()
    assert isinstance(provider.embedding_model, CapturingDashscope)

    # The constructor must have received these keyword arguments.
    expected = {
        "model": "text-embedding-ada-002",
        "encoding_format": "float",
        "dimensions": provider.embedding_dim,
    }
    for key, value in expected.items():
        assert seen_kwargs[key] == value
|
||||
|
||||
|
||||
def test_init_embedding_model_invalid_provider(monkeypatch):
    """Constructing the provider with an unknown embedding provider raises ValueError."""
    unknown_provider = "not_a_provider"
    monkeypatch.setenv("MILVUS_EMBEDDING_PROVIDER", unknown_provider)
    with pytest.raises(ValueError):
        MilvusProvider()
|
||||
|
||||
|
||||
def test_load_example_files_directory_missing(monkeypatch):
    """_load_example_files must not insert anything when the examples dir is missing.

    The insert stub counts every call before raising, so the final assertion
    still catches an unwanted insertion even if the code under test were to
    swallow the stub's AssertionError.
    """
    _patch_init(monkeypatch)
    missing_dir = "examples_dir_does_not_exist_xyz"
    monkeypatch.setenv("MILVUS_EXAMPLES_DIR", missing_dir)
    retriever = MilvusProvider()
    retriever.examples_dir = missing_dir
    called = {"insert": 0}

    def forbid_insert(**kwargs):
        # Record the call first: the counter is what the final assert checks.
        called["insert"] += 1
        raise AssertionError("should not insert")

    monkeypatch.setattr(retriever, "_insert_document_chunk", forbid_insert)
    retriever._load_example_files()
    assert called["insert"] == 0  # no insertion attempted
|
||||
|
||||
|
||||
def test_load_example_files_loads_and_skips_existing(
    monkeypatch, temp_load_skip_examples_dir
):
    """A file whose doc id already exists is skipped; the other is inserted chunk by chunk."""
    _patch_init(monkeypatch)
    dir_name = temp_load_skip_examples_dir.name

    known_file = temp_load_skip_examples_dir / "file1.md"
    new_file = temp_load_skip_examples_dir / "file2.md"
    known_file.write_text("# Title One\nContent A", encoding="utf-8")
    new_file.write_text("# Title Two\nContent B", encoding="utf-8")

    monkeypatch.setenv("MILVUS_EXAMPLES_DIR", dir_name)
    provider = MilvusProvider()
    provider.examples_dir = dir_name

    # Compute doc ids using real method
    known_id = provider._generate_doc_id(known_file)
    new_id = provider._generate_doc_id(new_file)

    # file1 is reported as already stored so it is skipped; every file is
    # split into two chunks to exercise the chunk-suffix logic.
    monkeypatch.setattr(provider, "_get_existing_document_ids", lambda: {known_id})
    monkeypatch.setattr(provider, "_split_content", lambda content: ["part1", "part2"])

    inserts = []

    def record_insert(doc_id, content, title, url, metadata):
        inserts.append(
            dict(doc_id=doc_id, content=content, title=title, url=url, metadata=metadata)
        )

    monkeypatch.setattr(provider, "_insert_document_chunk", record_insert)

    provider._load_example_files()

    # Only file2 processed -> two chunk inserts
    assert len(inserts) == 2
    inserted_ids = {call["doc_id"] for call in inserts}
    assert inserted_ids == {f"{new_id}_chunk_0", f"{new_id}_chunk_1"}
    for call in inserts:
        assert call["metadata"]["file"] == "file2.md"
        assert call["metadata"]["source"] == "examples"
        assert call["title"] == "Title Two"
|
||||
|
||||
|
||||
def test_load_example_files_single_chunk_no_suffix(
    monkeypatch, temp_single_chunk_examples_dir
):
    """A file that splits into one chunk is inserted under its bare doc id."""
    _patch_init(monkeypatch)
    dir_name = temp_single_chunk_examples_dir.name

    single_file = temp_single_chunk_examples_dir / "single.md"
    single_file.write_text("# Single Title\nOnly one small paragraph.", encoding="utf-8")

    monkeypatch.setenv("MILVUS_EXAMPLES_DIR", dir_name)
    provider = MilvusProvider()
    provider.examples_dir = dir_name

    expected_id = provider._generate_doc_id(single_file)

    # Nothing pre-exists, and the splitter always yields exactly one chunk.
    monkeypatch.setattr(provider, "_get_existing_document_ids", lambda: set())
    monkeypatch.setattr(provider, "_split_content", lambda content: ["onlychunk"])

    captured = {}

    def capture(doc_id, content, title, url, metadata):
        captured.update(doc_id=doc_id, title=title, metadata=metadata)

    monkeypatch.setattr(provider, "_insert_document_chunk", capture)

    provider._load_example_files()

    assert captured["doc_id"] == expected_id  # no _chunk_ suffix
    assert captured["title"] == "Single Title"
    inserted_meta = captured["metadata"]
    assert inserted_meta["file"] == "single.md"
    assert inserted_meta["source"] == "examples"
|
||||
|
||||
|
||||
# Clean up test database file after tests
|
||||
import atexit
|
||||
|
||||
|
||||
def cleanup_test_database():
    """Delete ``milvus_demo.db`` from the current working directory, if present.

    Does nothing when the ``DISABLE_TEST_CLEANUP`` environment variable equals
    ``"true"`` (case-insensitive). Any error raised while removing the file is
    ignored.
    """
    import os
    from pathlib import Path

    # Skip cleanup if disabled
    cleanup_disabled = os.getenv("DISABLE_TEST_CLEANUP", "false").lower() == "true"
    if cleanup_disabled:
        return

    db_file = Path.cwd() / "milvus_demo.db"
    if not db_file.exists():
        return

    try:
        db_file.unlink()
        print("🧹 Cleaned up milvus_demo.db")
    except Exception:
        # Best-effort: cleanup failures are deliberately ignored.
        pass
|
||||
|
||||
|
||||
# Register cleanup to run when Python exits
|
||||
atexit.register(cleanup_test_database)
|
||||
@@ -1,333 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
import src.rag.qdrant as qdrant_mod
|
||||
from src.rag.qdrant import QdrantProvider
|
||||
|
||||
|
||||
class DummyEmbedding:
    """Stand-in embedding client that returns constant 1536-element vectors."""

    def __init__(self, **kwargs):
        # Keep whatever keyword arguments the caller supplied.
        self.kwargs = kwargs

    def embed_query(self, text: str):
        """Return a single vector of 1536 copies of 0.1; ``text`` is ignored."""
        vector = [0.1 for _ in range(1536)]
        return vector

    def embed_documents(self, texts):
        """Return one constant 1536-element vector per entry of ``texts``."""
        vectors = []
        for _ in texts:
            vectors.append([0.1 for _ in range(1536)])
        return vectors
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def patch_embeddings(monkeypatch):
    """Set the Qdrant environment variables and swap both embedding classes for DummyEmbedding."""
    environment = {
        "QDRANT_EMBEDDING_PROVIDER": "openai",
        "QDRANT_EMBEDDING_MODEL": "text-embedding-ada-002",
        "QDRANT_COLLECTION": "documents",
        "QDRANT_LOCATION": ":memory:",
    }
    for variable, value in environment.items():
        monkeypatch.setenv(variable, value)

    # Replace the embedding classes referenced by the qdrant module.
    for attribute in ("OpenAIEmbeddings", "DashscopeEmbeddings"):
        monkeypatch.setattr(qdrant_mod, attribute, DummyEmbedding)

    yield
|
||||
|
||||
|
||||
@pytest.fixture
def project_root():
    """Return the directory three levels above the qdrant module file."""
    module_file = Path(qdrant_mod.__file__)
    package_dir = module_file.parent
    source_dir = package_dir.parent
    return source_dir.parent
|
||||
|
||||
|
||||
@pytest.fixture
def temp_examples_dir(project_root):
    """Yield a uniquely named ``examples_test_*`` directory under the project root.

    The directory is created before the test and removed afterwards if it
    still exists.
    """
    scratch = project_root / f"examples_test_{uuid4().hex}"
    scratch.mkdir(parents=True, exist_ok=True)
    yield scratch
    # Teardown: nothing to do when the test already removed the directory.
    if not scratch.exists():
        return
    shutil.rmtree(scratch)
|
||||
|
||||
|
||||
@pytest.fixture
def temp_error_examples_dir(project_root):
    """Yield a uniquely named ``examples_error_*`` directory under the project root.

    The directory is created before the test and removed afterwards if it
    still exists.
    """
    scratch = project_root / f"examples_error_{uuid4().hex}"
    scratch.mkdir(parents=True, exist_ok=True)
    yield scratch
    # Teardown: nothing to do when the test already removed the directory.
    if not scratch.exists():
        return
    shutil.rmtree(scratch)
|
||||
|
||||
|
||||
@pytest.fixture
def temp_load_skip_examples_dir(project_root):
    """Yield a uniquely named ``examples_load_skip_*`` directory under the project root.

    The directory is created before the test and removed afterwards if it
    still exists.
    """
    scratch = project_root / f"examples_load_skip_{uuid4().hex}"
    scratch.mkdir(parents=True, exist_ok=True)
    yield scratch
    # Teardown: nothing to do when the test already removed the directory.
    if not scratch.exists():
        return
    shutil.rmtree(scratch)
|
||||
|
||||
|
||||
def test_init_openai_provider(monkeypatch):
    """With the provider set to "openai", the patched embedding class is used."""
    monkeypatch.setenv("QDRANT_EMBEDDING_PROVIDER", "openai")

    qdrant = QdrantProvider()

    assert qdrant.embedding_provider == "openai"
    model = qdrant.embedding_model
    assert isinstance(model, DummyEmbedding)
|
||||
|
||||
|
||||
def test_init_dashscope_provider(monkeypatch):
    """With the provider set to "dashscope", the patched embedding class is used."""
    monkeypatch.setenv("QDRANT_EMBEDDING_PROVIDER", "dashscope")

    qdrant = QdrantProvider()

    assert qdrant.embedding_provider == "dashscope"
    model = qdrant.embedding_model
    assert isinstance(model, DummyEmbedding)
|
||||
|
||||
|
||||
def test_init_invalid_provider(monkeypatch):
    """An unrecognised provider name is expected to raise ValueError on construction."""
    monkeypatch.setenv("QDRANT_EMBEDDING_PROVIDER", "invalid_provider")

    expectation = pytest.raises(ValueError, match="Unsupported embedding provider")
    with expectation:
        QdrantProvider()
|
||||
|
||||
|
||||
def test_get_embedding_dimension_explicit(monkeypatch):
    """QDRANT_EMBEDDING_DIM is expected to set the embedding dimension directly."""
    monkeypatch.setenv("QDRANT_EMBEDDING_DIM", "2048")

    dimension = QdrantProvider().embedding_dim

    assert dimension == 2048
|
||||
|
||||
|
||||
def test_get_embedding_dimension_default(monkeypatch):
    """Without an explicit dimension, text-embedding-ada-002 is expected to give 1536."""
    # Make sure no explicit dimension is configured.
    monkeypatch.delenv("QDRANT_EMBEDDING_DIM", raising=False)
    monkeypatch.setenv("QDRANT_EMBEDDING_MODEL", "text-embedding-ada-002")

    dimension = QdrantProvider().embedding_dim

    assert dimension == 1536
|
||||
|
||||
|
||||
def test_get_embedding_dimension_unknown_model(monkeypatch):
    """An unknown model name without explicit dimension is expected to give 1536."""
    # Make sure no explicit dimension is configured.
    monkeypatch.delenv("QDRANT_EMBEDDING_DIM", raising=False)
    monkeypatch.setenv("QDRANT_EMBEDDING_MODEL", "unknown-model")

    dimension = QdrantProvider().embedding_dim

    assert dimension == 1536
|
||||
|
||||
|
||||
def test_connect_memory_mode(monkeypatch):
    """Connecting with QDRANT_LOCATION=":memory:" is expected to set a client."""
    monkeypatch.setenv("QDRANT_LOCATION", ":memory:")

    qdrant = QdrantProvider()
    qdrant._connect()

    client = qdrant.client
    assert client is not None
|
||||
|
||||
|
||||
def test_create_collection(monkeypatch):
    """After create_collection() the provider is expected to hold a client."""
    qdrant = QdrantProvider()

    qdrant.create_collection()

    client = qdrant.client
    assert client is not None
|
||||
|
||||
|
||||
def test_extract_title_from_markdown():
    """A leading "# " heading is expected to be returned as the title."""
    markdown = "# Test Title\n\nSome content"

    extracted = QdrantProvider()._extract_title_from_markdown(markdown, "test.md")

    assert extracted == "Test Title"
|
||||
|
||||
|
||||
def test_extract_title_fallback():
    """Without a heading, the title is expected to be derived from the file name."""
    markdown = "No title here"

    extracted = QdrantProvider()._extract_title_from_markdown(
        markdown, "test_file.md"
    )

    assert extracted == "Test File"
|
||||
|
||||
|
||||
def test_split_content_short():
    """Short content is expected to come back as exactly one unchanged chunk."""
    text = "Short content"

    pieces = QdrantProvider()._split_content(text)

    assert len(pieces) == 1
    only_piece = pieces[0]
    assert only_piece == text
|
||||
|
||||
|
||||
def test_split_content_long(monkeypatch):
    """With QDRANT_CHUNK_SIZE=20, four paragraphs are expected to yield several chunks."""
    monkeypatch.setenv("QDRANT_CHUNK_SIZE", "20")
    qdrant = QdrantProvider()

    text = "Paragraph one here\n\nParagraph two here\n\nParagraph three here\n\nParagraph four here"
    pieces = qdrant._split_content(text)

    piece_count = len(pieces)
    assert piece_count > 1
|
||||
|
||||
|
||||
def test_string_to_uuid():
    """The same input string is expected to map to the same UUID on repeated calls."""
    qdrant = QdrantProvider()

    first = qdrant._string_to_uuid("test")
    second = qdrant._string_to_uuid("test")

    assert first == second
|
||||
|
||||
|
||||
def test_get_embedding():
    """_get_embedding is expected to return 1536 floats for a text input."""
    vector = QdrantProvider()._get_embedding("test text")

    assert len(vector) == 1536
    # No element may be anything other than a float.
    assert not any(not isinstance(value, float) for value in vector)
|
||||
|
||||
|
||||
def test_load_examples_no_directory(monkeypatch, project_root):
    """load_examples() must not raise when the examples directory name does not exist."""
    monkeypatch.setenv("QDRANT_EXAMPLES_DIR", "nonexistent_dir")

    qdrant = QdrantProvider()

    # The call completing without an exception is the whole check.
    qdrant.load_examples()
|
||||
|
||||
|
||||
def test_load_examples_empty_directory(monkeypatch, temp_examples_dir):
    """load_examples() must not raise when the examples directory is empty."""
    directory_name = temp_examples_dir.name
    monkeypatch.setenv("QDRANT_EXAMPLES_DIR", directory_name)

    qdrant = QdrantProvider()

    # The call completing without an exception is the whole check.
    qdrant.load_examples()
|
||||
|
||||
|
||||
def test_load_examples_with_files(monkeypatch, temp_examples_dir):
    """A single markdown file is expected to be loaded with its heading as title."""
    monkeypatch.setenv("QDRANT_EXAMPLES_DIR", temp_examples_dir.name)
    sample = temp_examples_dir / "test.md"
    sample.write_text("# Test\n\nContent", encoding="utf-8")

    qdrant = QdrantProvider()
    qdrant.load_examples()

    examples = qdrant.get_loaded_examples()
    assert len(examples) == 1
    first_example = examples[0]
    assert first_example["title"] == "Test"
|
||||
|
||||
|
||||
def test_load_examples_skip_existing(monkeypatch, temp_load_skip_examples_dir):
    """Loading the same examples twice is expected to leave exactly one entry."""
    monkeypatch.setenv("QDRANT_EXAMPLES_DIR", temp_load_skip_examples_dir.name)
    sample = temp_load_skip_examples_dir / "test.md"
    sample.write_text("# Test\n\nContent", encoding="utf-8")

    qdrant = QdrantProvider()
    # Two consecutive loads of the same directory.
    for _ in range(2):
        qdrant.load_examples()

    examples = qdrant.get_loaded_examples()
    assert len(examples) == 1
|
||||
|
||||
|
||||
def test_load_examples_force_reload(monkeypatch, temp_examples_dir):
    """A forced reload after a normal load is expected to leave exactly one entry."""
    monkeypatch.setenv("QDRANT_EXAMPLES_DIR", temp_examples_dir.name)
    sample = temp_examples_dir / "test.md"
    sample.write_text("# Test\n\nContent", encoding="utf-8")

    qdrant = QdrantProvider()
    qdrant.load_examples()
    qdrant.load_examples(force_reload=True)

    examples = qdrant.get_loaded_examples()
    example_count = len(examples)
    assert example_count == 1
|
||||
|
||||
|
||||
def test_load_examples_error_handling(monkeypatch, temp_error_examples_dir):
    """With one normal file and one heading-only file, at least one example must load."""
    monkeypatch.setenv("QDRANT_EXAMPLES_DIR", temp_error_examples_dir.name)

    # File with a heading and body text.
    with_body = temp_error_examples_dir / "good.md"
    with_body.write_text("# Good\n\nContent", encoding="utf-8")

    # File that contains only a heading.
    heading_only = temp_error_examples_dir / "bad.md"
    heading_only.write_text("# Bad\n\n", encoding="utf-8")

    qdrant = QdrantProvider()
    qdrant.load_examples()

    examples = qdrant.get_loaded_examples()
    assert len(examples) >= 1
|
||||
|
||||
|
||||
def test_list_resources_no_query(monkeypatch, temp_examples_dir):
    """list_resources() without a query is expected to return at least one resource."""
    monkeypatch.setenv("QDRANT_EXAMPLES_DIR", temp_examples_dir.name)
    sample = temp_examples_dir / "test.md"
    sample.write_text("# Test\n\nContent", encoding="utf-8")

    qdrant = QdrantProvider()
    qdrant.load_examples()

    found = qdrant.list_resources()
    assert len(found) >= 1
|
||||
|
||||
|
||||
def test_list_resources_with_query(monkeypatch, temp_examples_dir):
    """list_resources(query=...) is expected to return a list."""
    monkeypatch.setenv("QDRANT_EXAMPLES_DIR", temp_examples_dir.name)
    sample = temp_examples_dir / "test.md"
    sample.write_text("# Test\n\nContent", encoding="utf-8")

    qdrant = QdrantProvider()
    qdrant.load_examples()

    found = qdrant.list_resources(query="test")
    assert isinstance(found, list)
|
||||
|
||||
|
||||
def test_query_relevant_documents(monkeypatch, temp_examples_dir):
    """query_relevant_documents() is expected to return a list after loading examples."""
    monkeypatch.setenv("QDRANT_EXAMPLES_DIR", temp_examples_dir.name)
    sample = temp_examples_dir / "test.md"
    sample.write_text("# Test\n\nContent about testing", encoding="utf-8")

    qdrant = QdrantProvider()
    qdrant.load_examples()

    hits = qdrant.query_relevant_documents("testing")
    assert isinstance(hits, list)
|
||||
|
||||
|
||||
def test_query_relevant_documents_with_resources(monkeypatch, temp_examples_dir):
    """Querying with the resources from list_resources() is expected to return a list."""
    monkeypatch.setenv("QDRANT_EXAMPLES_DIR", temp_examples_dir.name)
    sample = temp_examples_dir / "test.md"
    sample.write_text("# Test\n\nContent", encoding="utf-8")

    qdrant = QdrantProvider()
    qdrant.load_examples()

    available = qdrant.list_resources()
    hits = qdrant.query_relevant_documents("test", resources=available)
    assert isinstance(hits, list)
|
||||
|
||||
|
||||
def test_close():
    """close() after _connect() is expected to reset the client to None."""
    qdrant = QdrantProvider()
    qdrant._connect()

    qdrant.close()

    client = qdrant.client
    assert client is None
|
||||
|
||||
|
||||
def test_del():
    """Dropping the last reference to a connected provider must not raise."""
    qdrant = QdrantProvider()
    qdrant._connect()

    # No assertion: the check is that deletion completes without an exception.
    del qdrant
|
||||
|
||||
|
||||
def test_top_k_configuration(monkeypatch):
    """QDRANT_TOP_K is expected to be parsed into the integer top_k."""
    monkeypatch.setenv("QDRANT_TOP_K", "20")

    configured = QdrantProvider().top_k

    assert configured == 20
|
||||
|
||||
|
||||
def test_top_k_invalid(monkeypatch):
    """A non-numeric QDRANT_TOP_K is expected to leave top_k at 10."""
    monkeypatch.setenv("QDRANT_TOP_K", "invalid")

    configured = QdrantProvider().top_k

    assert configured == 10
|
||||
|
||||
|
||||
def test_chunk_size_configuration(monkeypatch):
    """QDRANT_CHUNK_SIZE is expected to be parsed into the integer chunk_size."""
    monkeypatch.setenv("QDRANT_CHUNK_SIZE", "5000")

    configured = QdrantProvider().chunk_size

    assert configured == 5000
|
||||
|
||||
|
||||
def test_collection_name_configuration(monkeypatch):
    """QDRANT_COLLECTION is expected to become the provider's collection_name."""
    monkeypatch.setenv("QDRANT_COLLECTION", "custom_collection")

    configured = QdrantProvider().collection_name

    assert configured == "custom_collection"
|
||||
|
||||
|
||||
def test_auto_load_examples_configuration(monkeypatch):
    """QDRANT_AUTO_LOAD_EXAMPLES="false" is expected to disable auto_load_examples."""
    monkeypatch.setenv("QDRANT_AUTO_LOAD_EXAMPLES", "false")

    flag = QdrantProvider().auto_load_examples

    assert flag is False
|
||||
@@ -1,165 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.rag.ragflow import RAGFlowProvider, parse_uri
|
||||
|
||||
|
||||
# Dummy classes to mock dependencies
|
||||
class DummyResource:
    """Plain attribute holder with ``uri``, ``title`` and ``description``."""

    def __init__(self, uri, title="", description=""):
        # Store all three constructor arguments unchanged.
        self.uri, self.title, self.description = uri, title, description
|
||||
|
||||
|
||||
class DummyChunk:
    """Plain attribute holder with ``content`` and ``similarity``."""

    def __init__(self, content, similarity):
        # Store both constructor arguments unchanged.
        self.content, self.similarity = content, similarity
|
||||
|
||||
|
||||
class DummyDocument:
    """Plain attribute holder with ``id``, ``title`` and a ``chunks`` list."""

    def __init__(self, id, title, chunks=None):
        self.id, self.title = id, title
        # Any falsy ``chunks`` (None, empty list) is replaced by a new empty list.
        self.chunks = chunks if chunks else []
|
||||
|
||||
|
||||
# Patch imports in ragflow.py to use dummy classes
|
||||
@pytest.fixture(autouse=True)
def patch_imports(monkeypatch):
    """Replace Resource, Chunk and Document in src.rag.ragflow with the dummy classes.

    The substitutions are made through ``monkeypatch`` so that the module's
    original attributes are restored when each test finishes; assigning the
    attributes directly would leave the dummies in place for every test
    module that runs afterwards in the same session.
    """
    import src.rag.ragflow as ragflow

    # raising=False keeps the previous behaviour of setting the attribute even
    # if the module does not currently define it.
    monkeypatch.setattr(ragflow, "Resource", DummyResource, raising=False)
    monkeypatch.setattr(ragflow, "Chunk", DummyChunk, raising=False)
    monkeypatch.setattr(ragflow, "Document", DummyDocument, raising=False)
    yield
|
||||
|
||||
|
||||
def test_parse_uri_valid():
    """A rag:// URI is expected to split into dataset id (path) and document id (fragment)."""
    parsed = parse_uri("rag://dataset/123#abc")

    dataset_id, document_id = parsed
    assert dataset_id == "123"
    assert document_id == "abc"
|
||||
|
||||
|
||||
def test_parse_uri_invalid():
    """A URI with an http scheme is expected to be rejected with ValueError."""
    bad_uri = "http://dataset/123#abc"
    with pytest.raises(ValueError):
        parse_uri(bad_uri)
|
||||
|
||||
|
||||
def test_init_env_vars(monkeypatch):
    """URL and key come from the environment; page_size is expected to be 10 when unset."""
    for variable, value in (("RAGFLOW_API_URL", "http://api"), ("RAGFLOW_API_KEY", "key")):
        monkeypatch.setenv(variable, value)
    monkeypatch.delenv("RAGFLOW_PAGE_SIZE", raising=False)

    ragflow_provider = RAGFlowProvider()

    assert ragflow_provider.api_url == "http://api"
    assert ragflow_provider.api_key == "key"
    assert ragflow_provider.page_size == 10
|
||||
|
||||
|
||||
def test_init_page_size(monkeypatch):
    """RAGFLOW_PAGE_SIZE is expected to be parsed into the integer page_size."""
    settings = (
        ("RAGFLOW_API_URL", "http://api"),
        ("RAGFLOW_API_KEY", "key"),
        ("RAGFLOW_PAGE_SIZE", "5"),
    )
    for variable, value in settings:
        monkeypatch.setenv(variable, value)

    page_size = RAGFlowProvider().page_size

    assert page_size == 5
|
||||
|
||||
|
||||
def test_init_cross_language(monkeypatch):
    """RAGFLOW_CROSS_LANGUAGES is expected to be split on commas into a list."""
    settings = (
        ("RAGFLOW_API_URL", "http://api"),
        ("RAGFLOW_API_KEY", "key"),
        ("RAGFLOW_CROSS_LANGUAGES", "lang1,lang2"),
    )
    for variable, value in settings:
        monkeypatch.setenv(variable, value)

    languages = RAGFlowProvider().cross_languages

    assert languages == ["lang1", "lang2"]
|
||||
|
||||
|
||||
def test_init_missing_env(monkeypatch):
    """Construction is expected to raise ValueError when the URL or the key is absent."""
    # Case 1: key present, URL absent.
    monkeypatch.delenv("RAGFLOW_API_URL", raising=False)
    monkeypatch.setenv("RAGFLOW_API_KEY", "key")
    with pytest.raises(ValueError):
        RAGFlowProvider()

    # Case 2: URL present, key absent.
    monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
    monkeypatch.delenv("RAGFLOW_API_KEY", raising=False)
    with pytest.raises(ValueError):
        RAGFlowProvider()
|
||||
|
||||
|
||||
@patch("src.rag.ragflow.requests.post")
def test_query_relevant_documents_success(mock_post, monkeypatch):
    """A 200 response with one document and one chunk is mapped to one Document."""
    monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
    monkeypatch.setenv("RAGFLOW_API_KEY", "key")
    ragflow_provider = RAGFlowProvider()
    target = DummyResource("rag://dataset/123#doc456")

    # Canned HTTP response returned by the patched requests.post.
    payload = {
        "data": {
            "doc_aggs": [{"doc_id": "doc456", "doc_name": "Doc Title"}],
            "chunks": [
                {"document_id": "doc456", "content": "chunk text", "similarity": 0.9}
            ],
        }
    }
    reply = MagicMock()
    reply.status_code = 200
    reply.json.return_value = payload
    mock_post.return_value = reply

    found = ragflow_provider.query_relevant_documents("query", [target])

    assert len(found) == 1
    document = found[0]
    assert document.id == "doc456"
    assert document.title == "Doc Title"
    assert len(document.chunks) == 1
    chunk = document.chunks[0]
    assert chunk.content == "chunk text"
    assert chunk.similarity == 0.9
|
||||
|
||||
|
||||
@patch("src.rag.ragflow.requests.post")
def test_query_relevant_documents_error(mock_post, monkeypatch):
    """A 400 response from the API is expected to surface as an exception."""
    monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
    monkeypatch.setenv("RAGFLOW_API_KEY", "key")
    ragflow_provider = RAGFlowProvider()

    # Canned failing HTTP response.
    reply = MagicMock()
    reply.status_code = 400
    reply.text = "error"
    mock_post.return_value = reply

    with pytest.raises(Exception):
        ragflow_provider.query_relevant_documents("query", [])
|
||||
|
||||
|
||||
@patch("src.rag.ragflow.requests.get")
def test_list_resources_success(mock_get, monkeypatch):
    """Each dataset in a 200 response is expected to become a rag://dataset/<id> resource."""
    monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
    monkeypatch.setenv("RAGFLOW_API_KEY", "key")
    ragflow_provider = RAGFlowProvider()

    # Canned HTTP response listing two datasets.
    reply = MagicMock()
    reply.status_code = 200
    reply.json.return_value = {
        "data": [
            {"id": "123", "name": "Dataset1", "description": "desc1"},
            {"id": "456", "name": "Dataset2", "description": "desc2"},
        ]
    }
    mock_get.return_value = reply

    listed = ragflow_provider.list_resources()

    assert len(listed) == 2
    expected = [
        ("rag://dataset/123", "Dataset1", "desc1"),
        ("rag://dataset/456", "Dataset2", "desc2"),
    ]
    for resource, (uri, title, description) in zip(listed, expected):
        assert resource.uri == uri
        assert resource.title == title
        assert resource.description == description
|
||||
|
||||
|
||||
@patch("src.rag.ragflow.requests.get")
def test_list_resources_error(mock_get, monkeypatch):
    """A 500 response from the API is expected to surface as an exception."""
    monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
    monkeypatch.setenv("RAGFLOW_API_KEY", "key")
    ragflow_provider = RAGFlowProvider()

    # Canned failing HTTP response.
    reply = MagicMock()
    reply.status_code = 500
    reply.text = "fail"
    mock_get.return_value = reply

    with pytest.raises(Exception):
        ragflow_provider.list_resources()
|
||||
@@ -1,114 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
|
||||
from src.rag.retriever import Chunk, Document, Resource, Retriever
|
||||
|
||||
|
||||
def test_chunk_init():
    """Chunk keeps the content and similarity it was constructed with."""
    piece = Chunk(content="test content", similarity=0.9)

    observed = (piece.content, piece.similarity)
    assert observed[0] == "test content"
    assert observed[1] == 0.9
|
||||
|
||||
|
||||
def test_document_init_and_to_dict():
    """Document stores its fields, and to_dict() joins chunk contents with blank lines."""
    first = Chunk(content="chunk1", similarity=0.8)
    second = Chunk(content="chunk2", similarity=0.7)
    document = Document(
        id="doc1", url="http://example.com", title="Title", chunks=[first, second]
    )

    # Attribute access.
    assert document.id == "doc1"
    assert document.url == "http://example.com"
    assert document.title == "Title"
    assert document.chunks == [first, second]

    # Dictionary form.
    as_dict = document.to_dict()
    assert as_dict["id"] == "doc1"
    assert as_dict["content"] == "chunk1\n\nchunk2"
    assert as_dict["url"] == "http://example.com"
    assert as_dict["title"] == "Title"
|
||||
|
||||
|
||||
def test_document_to_dict_optional_fields():
    """to_dict() is expected to omit url and title when they were not supplied."""
    lone_chunk = Chunk(content="only chunk", similarity=1.0)

    as_dict = Document(id="doc2", chunks=[lone_chunk]).to_dict()

    assert as_dict["id"] == "doc2"
    assert as_dict["content"] == "only chunk"
    assert "url" not in as_dict
    assert "title" not in as_dict
|
||||
|
||||
|
||||
def test_resource_model():
    """Resource keeps uri and title; description is expected to default to ""."""
    item = Resource(uri="uri1", title="Resource Title")

    assert item.uri == "uri1"
    assert item.title == "Resource Title"
    description = item.description
    assert description == ""
|
||||
|
||||
|
||||
def test_resource_model_with_description():
    """An explicit description is stored on the Resource."""
    item = Resource(uri="uri2", title="Resource2", description="desc")

    description = item.description
    assert description == "desc"
|
||||
|
||||
|
||||
def test_retriever_abstract_methods():
    """A subclass implementing all four Retriever methods can be instantiated and called."""

    class DummyRetriever(Retriever):
        """Minimal concrete Retriever returning fixed values."""

        def list_resources(self, query=None):
            fixed = Resource(uri="uri", title="title")
            return [fixed]

        async def list_resources_async(self, query=None):
            fixed = Resource(uri="uri", title="title")
            return [fixed]

        def query_relevant_documents(self, query, resources=[]):
            fixed = Document(id="id", chunks=[])
            return [fixed]

        async def query_relevant_documents_async(self, query, resources=[]):
            fixed = Document(id="id", chunks=[])
            return [fixed]

    dummy = DummyRetriever()

    # Only the synchronous methods are exercised here.
    listed = dummy.list_resources()
    assert isinstance(listed, list)
    first_resource = listed[0]
    assert isinstance(first_resource, Resource)
    assert first_resource.uri == "uri"

    found = dummy.query_relevant_documents("query", listed)
    assert isinstance(found, list)
    first_document = found[0]
    assert isinstance(first_document, Document)
    assert first_document.id == "id"
|
||||
|
||||
|
||||
def test_retriever_cannot_instantiate():
    """Instantiating Retriever directly is expected to raise TypeError."""
    expectation = pytest.raises(TypeError)
    with expectation:
        Retriever()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_retriever_async_methods():
    """The async methods of a concrete Retriever subclass return their own values."""

    class DummyRetriever(Retriever):
        """Concrete Retriever whose async methods return distinct '_async' values."""

        def list_resources(self, query=None):
            fixed = Resource(uri="uri", title="title")
            return [fixed]

        async def list_resources_async(self, query=None):
            fixed = Resource(uri="uri_async", title="title_async")
            return [fixed]

        def query_relevant_documents(self, query, resources=[]):
            fixed = Document(id="id", chunks=[])
            return [fixed]

        async def query_relevant_documents_async(self, query, resources=[]):
            fixed = Document(id="id_async", chunks=[])
            return [fixed]

    dummy = DummyRetriever()

    # Async resource listing.
    listed = await dummy.list_resources_async()
    assert isinstance(listed, list)
    first_resource = listed[0]
    assert isinstance(first_resource, Resource)
    assert first_resource.uri == "uri_async"

    # Async document query.
    found = await dummy.query_relevant_documents_async("query", listed)
    assert isinstance(found, list)
    first_document = found[0]
    assert isinstance(first_document, Document)
    assert first_document.id == "id_async"
|
||||
@@ -1,540 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.rag.vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider, parse_uri
|
||||
|
||||
|
||||
# Dummy classes to mock dependencies
|
||||
class MockResource:
    """Plain attribute holder with ``uri``, ``title`` and ``description``."""

    def __init__(self, uri, title="", description=""):
        # Store all three constructor arguments unchanged.
        self.uri, self.title, self.description = uri, title, description
|
||||
|
||||
|
||||
class MockChunk:
    """Plain attribute holder with ``content`` and ``similarity``."""

    def __init__(self, content, similarity):
        # Store both constructor arguments unchanged.
        self.content, self.similarity = content, similarity
|
||||
|
||||
|
||||
class MockDocument:
    """Plain attribute holder with ``id``, ``title`` and a ``chunks`` list."""

    def __init__(self, id, title, chunks=None):
        self.id, self.title = id, title
        # Any falsy ``chunks`` (None, empty list) is replaced by a new empty list.
        self.chunks = chunks if chunks else []
|
||||
|
||||
|
||||
# Patch the imports to use mock classes
|
||||
@pytest.fixture(autouse=True)
def patch_imports():
    """Patch Resource, Chunk and Document in the vikingdb module with the mock classes.

    All three patches are active for the duration of each test and undone
    when the test finishes.
    """
    with patch("src.rag.vikingdb_knowledge_base.Resource", MockResource):
        with patch("src.rag.vikingdb_knowledge_base.Chunk", MockChunk):
            with patch("src.rag.vikingdb_knowledge_base.Document", MockDocument):
                yield
|
||||
|
||||
|
||||
@pytest.fixture
def env_vars():
    """Provide the five VIKINGDB_KNOWLEDGE_BASE_* environment variables for a test."""
    settings = {
        "VIKINGDB_KNOWLEDGE_BASE_API_URL": "api-test.example.com",
        "VIKINGDB_KNOWLEDGE_BASE_API_AK": "test_ak",
        "VIKINGDB_KNOWLEDGE_BASE_API_SK": "test_sk",
        "VIKINGDB_KNOWLEDGE_BASE_RETRIEVAL_SIZE": "10",
        "VIKINGDB_KNOWLEDGE_BASE_REGION": "cn-north-1",
    }
    # patch.dict restores os.environ when the context exits.
    with patch.dict(os.environ, settings):
        yield
|
||||
|
||||
|
||||
class TestParseUri:
    """Checks for parse_uri from the vikingdb knowledge base module."""

    def test_parse_uri_valid_with_fragment(self):
        """A URI with a fragment yields both the resource id and the document id."""
        parsed = parse_uri("rag://dataset/123#doc456")

        resource_id, document_id = parsed
        assert resource_id == "123"
        assert document_id == "doc456"

    def test_parse_uri_valid_without_fragment(self):
        """A URI without a fragment yields an empty document id."""
        parsed = parse_uri("rag://dataset/123")

        resource_id, document_id = parsed
        assert resource_id == "123"
        assert document_id == ""

    def test_parse_uri_invalid_scheme(self):
        """A non-rag scheme is rejected with an "Invalid URI" ValueError."""
        bad_uri = "http://dataset/123#abc"
        with pytest.raises(ValueError, match="Invalid URI"):
            parse_uri(bad_uri)

    def test_parse_uri_malformed(self):
        """A string that is not a URI is rejected with an "Invalid URI" ValueError."""
        bad_uri = "invalid_uri"
        with pytest.raises(ValueError, match="Invalid URI"):
            parse_uri(bad_uri)
|
||||
|
||||
|
||||
class TestVikingDBKnowledgeBaseProviderInit:
    """Constructor checks for VikingDBKnowledgeBaseProvider under different environments."""

    def test_init_success_with_all_env_vars(self, env_vars):
        """All configured environment variables are reflected on the provider."""
        instance = VikingDBKnowledgeBaseProvider()

        assert instance.api_url == "api-test.example.com"
        assert instance.api_ak == "test_ak"
        assert instance.api_sk == "test_sk"
        assert instance.retrieval_size == 10
        assert instance.region == "cn-north-1"
        assert instance.service == "air"

    def test_init_success_without_retrieval_size(self):
        """Without VIKINGDB_KNOWLEDGE_BASE_RETRIEVAL_SIZE the size is expected to be 10."""
        environment = {
            "VIKINGDB_KNOWLEDGE_BASE_API_URL": "api-test.example.com",
            "VIKINGDB_KNOWLEDGE_BASE_API_AK": "test_ak",
            "VIKINGDB_KNOWLEDGE_BASE_API_SK": "test_sk",
        }
        # clear=True removes every other variable for the duration of the block.
        with patch.dict(os.environ, environment, clear=True):
            instance = VikingDBKnowledgeBaseProvider()
            assert instance.retrieval_size == 10

    def test_init_custom_retrieval_size(self):
        """A configured retrieval size is parsed into an integer."""
        environment = {
            "VIKINGDB_KNOWLEDGE_BASE_API_URL": "api-test.example.com",
            "VIKINGDB_KNOWLEDGE_BASE_API_AK": "test_ak",
            "VIKINGDB_KNOWLEDGE_BASE_API_SK": "test_sk",
            "VIKINGDB_KNOWLEDGE_BASE_RETRIEVAL_SIZE": "5",
        }
        with patch.dict(os.environ, environment):
            instance = VikingDBKnowledgeBaseProvider()
            assert instance.retrieval_size == 5

    def test_init_custom_region(self):
        """A configured region is stored on the provider."""
        environment = {
            "VIKINGDB_KNOWLEDGE_BASE_API_URL": "api-test.example.com",
            "VIKINGDB_KNOWLEDGE_BASE_API_AK": "test_ak",
            "VIKINGDB_KNOWLEDGE_BASE_API_SK": "test_sk",
            "VIKINGDB_KNOWLEDGE_BASE_REGION": "us-east-1",
        }
        with patch.dict(os.environ, environment):
            instance = VikingDBKnowledgeBaseProvider()
            assert instance.region == "us-east-1"

    def test_init_missing_api_url(self):
        """Construction fails with ValueError when the API URL is not set."""
        environment = {
            "VIKINGDB_KNOWLEDGE_BASE_API_AK": "test_ak",
            "VIKINGDB_KNOWLEDGE_BASE_API_SK": "test_sk",
        }
        with patch.dict(os.environ, environment, clear=True), pytest.raises(
            ValueError, match="VIKINGDB_KNOWLEDGE_BASE_API_URL is not set"
        ):
            VikingDBKnowledgeBaseProvider()

    def test_init_missing_api_ak(self):
        """Construction fails with ValueError when the API AK is not set."""
        environment = {
            "VIKINGDB_KNOWLEDGE_BASE_API_URL": "api-test.example.com",
            "VIKINGDB_KNOWLEDGE_BASE_API_SK": "test_sk",
        }
        with patch.dict(os.environ, environment, clear=True), pytest.raises(
            ValueError, match="VIKINGDB_KNOWLEDGE_BASE_API_AK is not set"
        ):
            VikingDBKnowledgeBaseProvider()

    def test_init_missing_api_sk(self):
        """Construction fails with ValueError when the API SK is not set."""
        environment = {
            "VIKINGDB_KNOWLEDGE_BASE_API_URL": "api-test.example.com",
            "VIKINGDB_KNOWLEDGE_BASE_API_AK": "test_ak",
        }
        with patch.dict(os.environ, environment, clear=True), pytest.raises(
            ValueError, match="VIKINGDB_KNOWLEDGE_BASE_API_SK is not set"
        ):
            VikingDBKnowledgeBaseProvider()
|
||||
|
||||
|
||||
class TestVikingDBKnowledgeBaseProviderSignature:
    """Checks for the request-signing helpers of VikingDBKnowledgeBaseProvider."""

    @pytest.fixture
    def provider(self, env_vars):
        """Return a provider built from the variables set by the env_vars fixture."""
        return VikingDBKnowledgeBaseProvider()

    def test_hmac_sha256(self, provider):
        """_hmac_sha256 matches hmac.new(...).digest() over the UTF-8 encoded content."""
        key = b"test_key"
        content = "test_content"
        result = provider._hmac_sha256(key, content)
        expected = hmac.new(key, content.encode("utf-8"), hashlib.sha256).digest()
        assert result == expected

    def test_hash_sha256(self, provider):
        """_hash_sha256 matches hashlib.sha256(...).digest() of the input bytes."""
        data = b"test_data"
        result = provider._hash_sha256(data)
        expected = hashlib.sha256(data).digest()
        assert result == expected

    def test_get_signed_key(self, provider):
        """_get_signed_key returns a 32-byte value."""
        secret_key = "test_secret"
        date = "20250722"
        region = "cn-north-1"
        service = "air"

        result = provider._get_signed_key(secret_key, date, region, service)
        assert isinstance(result, bytes)
        assert len(result) == 32  # SHA256 digest is 32 bytes

    def test_create_canonical_request(self, provider):
        """The canonical request contains method, path, query string and lower-cased headers."""
        method = "POST"
        path = "/api/test"
        query_params = {"param1": "value1", "param2": "value2"}
        headers = {"Content-Type": "application/json", "Host": "example.com"}
        payload = b'{"test": "data"}'

        canonical_request, signed_headers = provider._create_canonical_request(
            method, path, query_params, headers, payload
        )

        assert "POST" in canonical_request
        assert "/api/test" in canonical_request
        # Query parameters are joined with "&".
        assert "param1=value1&param2=value2" in canonical_request
        assert "content-type:application/json" in canonical_request
        assert "host:example.com" in canonical_request
        assert signed_headers == "content-type;host"

    @patch("src.rag.vikingdb_knowledge_base.datetime")
    def test_create_signature(self, mock_datetime, provider):
        """_create_signature returns the signing headers, including an HMAC-SHA256 Authorization."""
        # Freeze the timestamp seen by the module under test.
        mock_now = datetime(2025, 7, 22, 10, 30, 45)
        mock_datetime.utcnow.return_value = mock_now

        method = "POST"
        path = "/api/test"
        query_params = {}
        headers = {}
        payload = b'{"test": "data"}'

        result = provider._create_signature(
            method, path, query_params, headers, payload
        )

        assert "X-Date" in result
        assert "Host" in result
        assert "X-Content-Sha256" in result
        assert "Content-Type" in result
        assert "Authorization" in result
        assert "HMAC-SHA256" in result["Authorization"]

    @patch("src.rag.vikingdb_knowledge_base.requests.request")
    def test_make_signed_request_success(self, mock_request, provider):
        """_make_signed_request returns the response and calls requests.request once."""
        mock_response = MagicMock()
        mock_response.json.return_value = {"code": 0, "data": {}}
        mock_request.return_value = mock_response

        result = provider._make_signed_request(
            "POST", "/api/test", data={"test": "data"}
        )

        assert result == mock_response
        mock_request.assert_called_once()

        # Verify the keyword arguments passed to requests.request.
        call_args = mock_request.call_args
        assert call_args[1]["method"] == "POST"
        assert call_args[1]["url"] == f"https://{provider.api_url}/api/test"
        assert call_args[1]["timeout"] == 30

    @patch("src.rag.vikingdb_knowledge_base.requests.request")
    def test_make_signed_request_with_exception(self, mock_request, provider):
        """An exception from requests.request is re-raised as ValueError("Request failed: ...")."""
        mock_request.side_effect = Exception("Network error")

        with pytest.raises(ValueError, match="Request failed: Network error"):
            provider._make_signed_request("GET", "/api/test")
|
||||
|
||||
|
||||
class TestVikingDBKnowledgeBaseProviderQueryRelevantDocuments:
    """Tests for VikingDBKnowledgeBaseProvider.query_relevant_documents."""

    @pytest.fixture
    def provider(self, env_vars):
        return VikingDBKnowledgeBaseProvider()

    @staticmethod
    def _json_response(payload):
        """Return a mock response whose ``json()`` yields ``payload``."""
        response = MagicMock()
        response.json.return_value = payload
        return response

    @staticmethod
    def _hit(doc_id, doc_name, content, score):
        """Build one ``result_list`` entry as used by the mocked responses."""
        return {
            "doc_info": {
                "doc_id": doc_id,
                "doc_name": doc_name,
            },
            "content": content,
            "score": score,
        }

    def test_query_relevant_documents_empty_resources(self, provider):
        """An empty resources list produces an empty result."""
        assert provider.query_relevant_documents("test query", []) == []

    @patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
    def test_query_relevant_documents_success(self, mock_request, provider):
        """A single hit is mapped onto one document holding one chunk."""
        payload = {
            "code": 0,
            "data": {
                "result_list": [
                    self._hit("doc123", "Test Document", "Test content", 0.95)
                ]
            },
        }
        mock_request.return_value = self._json_response(payload)

        documents = provider.query_relevant_documents(
            "test query", [MockResource("rag://dataset/123")]
        )

        assert len(documents) == 1
        document = documents[0]
        assert document.id == "doc123"
        assert document.title == "Test Document"
        assert len(document.chunks) == 1
        chunk = document.chunks[0]
        assert chunk.content == "Test content"
        assert chunk.similarity == 0.95

    @patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
    def test_query_relevant_documents_with_document_filter(
        self, mock_request, provider
    ):
        """A ``#doc`` fragment in the URI becomes a doc_id filter in the request."""
        mock_request.return_value = self._json_response(
            {"code": 0, "data": {"result_list": []}}
        )

        provider.query_relevant_documents(
            "test query", [MockResource("rag://dataset/123#doc456")]
        )

        # The request body must carry query_param.doc_filter.
        _, sent_kwargs = mock_request.call_args
        request_data = sent_kwargs["data"]
        assert "query_param" in request_data
        assert "doc_filter" in request_data["query_param"]

        doc_filter = request_data["query_param"]["doc_filter"]
        assert doc_filter["op"] == "must"
        assert doc_filter["field"] == "doc_id"
        assert doc_filter["conds"] == ["doc456"]

    @patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
    def test_query_relevant_documents_api_error(self, mock_request, provider):
        """A non-zero response code is raised as ValueError with the API message."""
        mock_request.return_value = self._json_response(
            {"code": 1, "message": "API Error"}
        )
        resources = [MockResource("rag://dataset/123")]

        with pytest.raises(
            ValueError, match="Failed to query documents from resource: API Error"
        ):
            provider.query_relevant_documents("test query", resources)

    @patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
    def test_query_relevant_documents_json_decode_error(self, mock_request, provider):
        """An undecodable response body is raised as ValueError."""
        broken_response = MagicMock()
        broken_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
        mock_request.return_value = broken_response
        resources = [MockResource("rag://dataset/123")]

        with pytest.raises(ValueError, match="Failed to parse JSON response"):
            provider.query_relevant_documents("test query", resources)

    @patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
    def test_query_relevant_documents_multiple_resources(self, mock_request, provider):
        """Hits from several resources are merged per document id."""
        # One payload per queried resource, returned in call order.
        payloads = [
            {
                "code": 0,
                "data": {
                    "result_list": [
                        self._hit("doc1", "Document 1", "Content 1", 0.9),
                    ]
                },
            },
            {
                "code": 0,
                "data": {
                    "result_list": [
                        self._hit("doc1", "Document 1", "Content 2", 0.8),
                        self._hit("doc2", "Document 2", "Content 3", 0.7),
                    ]
                },
            },
        ]
        mock_request.side_effect = [
            self._json_response(payload) for payload in payloads
        ]

        documents = provider.query_relevant_documents(
            "test query",
            [
                MockResource("rag://dataset/123"),
                MockResource("rag://dataset/456"),
            ],
        )

        # Expect 2 documents: doc1 (with 2 chunks) and doc2 (with 1 chunk).
        assert len(documents) == 2
        first = next(doc for doc in documents if doc.id == "doc1")
        second = next(doc for doc in documents if doc.id == "doc2")
        assert len(first.chunks) == 2
        assert len(second.chunks) == 1
|
||||
|
||||
|
||||
class TestVikingDBKnowledgeBaseProviderListResources:
    """Tests for VikingDBKnowledgeBaseProvider.list_resources."""

    @pytest.fixture
    def provider(self, env_vars):
        return VikingDBKnowledgeBaseProvider()

    @staticmethod
    def _json_response(payload):
        """Return a mock response whose ``json()`` yields ``payload``."""
        response = MagicMock()
        response.json.return_value = payload
        return response

    @staticmethod
    def _collection(resource_id, name, description):
        """Build one ``collection_list`` entry as used by the mocked responses."""
        return {
            "resource_id": resource_id,
            "collection_name": name,
            "description": description,
        }

    @patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
    def test_list_resources_success(self, mock_request, provider):
        """Each listed collection becomes a resource with uri, title, description."""
        mock_request.return_value = self._json_response(
            {
                "code": 0,
                "data": {
                    "collection_list": [
                        self._collection("123", "Dataset 1", "Description 1"),
                        self._collection("456", "Dataset 2", "Description 2"),
                    ]
                },
            }
        )

        resources = provider.list_resources()

        assert len(resources) == 2
        first, second = resources[0], resources[1]
        assert first.uri == "rag://dataset/123"
        assert first.title == "Dataset 1"
        assert first.description == "Description 1"
        assert second.uri == "rag://dataset/456"
        assert second.title == "Dataset 2"
        assert second.description == "Description 2"

    @patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
    def test_list_resources_with_query_filter(self, mock_request, provider):
        """Passing a query keeps only the collection whose name matches it."""
        mock_request.return_value = self._json_response(
            {
                "code": 0,
                "data": {
                    "collection_list": [
                        self._collection("123", "Test Dataset", "Description"),
                        self._collection("456", "Other Dataset", "Description"),
                    ]
                },
            }
        )

        resources = provider.list_resources("test")

        # Only the dataset with "test" in its name should come back.
        assert len(resources) == 1
        assert resources[0].title == "Test Dataset"

    @patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
    def test_list_resources_api_error(self, mock_request, provider):
        """A non-zero response code raises with the API message."""
        mock_request.return_value = self._json_response(
            {"code": 1, "message": "API Error"}
        )

        with pytest.raises(Exception, match="Failed to list resources: API Error"):
            provider.list_resources()

    @patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
    def test_list_resources_json_decode_error(self, mock_request, provider):
        """An undecodable response body is raised as ValueError."""
        broken_response = MagicMock()
        broken_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
        mock_request.return_value = broken_response

        with pytest.raises(ValueError, match="Failed to parse JSON response"):
            provider.list_resources()

    @patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
    def test_list_resources_empty_response(self, mock_request, provider):
        """An empty collection list yields an empty result."""
        mock_request.return_value = self._json_response(
            {"code": 0, "data": {"collection_list": []}}
        )

        assert provider.list_resources() == []
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,168 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from pydantic import ValidationError
|
||||
|
||||
import src.server.mcp_utils as mcp_utils # Assuming mcp_utils is the module to test
|
||||
from src.config.report_style import ReportStyle
|
||||
from src.rag.retriever import Resource
|
||||
from src.server.chat_request import (
|
||||
ChatMessage,
|
||||
ChatRequest,
|
||||
ContentItem,
|
||||
EnhancePromptRequest,
|
||||
GeneratePodcastRequest,
|
||||
GeneratePPTRequest,
|
||||
GenerateProseRequest,
|
||||
TTSRequest,
|
||||
)
|
||||
|
||||
|
||||
def test_content_item_text_and_image():
    """Text and image items populate only the field matching their type."""
    text_item = ContentItem(type="text", text="hello")
    assert text_item.type == "text"
    assert text_item.text == "hello"
    assert text_item.image_url is None

    image_item = ContentItem(type="image", image_url="http://img.com/1.png")
    assert image_item.type == "image"
    assert image_item.text is None
    assert image_item.image_url == "http://img.com/1.png"
|
||||
|
||||
|
||||
def test_chat_message_with_string_content():
    """A ChatMessage keeps a plain string as its content."""
    message = ChatMessage(role="user", content="Hello!")
    assert message.role == "user"
    assert message.content == "Hello!"
|
||||
|
||||
|
||||
def test_chat_message_with_content_items():
    """A ChatMessage accepts a list of ContentItem objects as content."""
    message = ChatMessage(
        role="assistant",
        content=[ContentItem(type="text", text="hi")],
    )
    assert message.role == "assistant"
    assert isinstance(message.content, list)
    assert message.content[0].type == "text"
|
||||
|
||||
|
||||
def test_chat_request_defaults():
    """A ChatRequest built without arguments exposes the expected defaults."""
    chat_request = ChatRequest()

    # Collections start out empty.
    assert chat_request.messages == []
    assert chat_request.resources == []

    # Flags and identifiers.
    assert chat_request.debug is False
    assert chat_request.thread_id == "__default__"

    # Planning and search limits.
    assert chat_request.max_plan_iterations == 1
    assert chat_request.max_step_num == 3
    assert chat_request.max_search_results == 3

    # Workflow options.
    assert chat_request.auto_accepted_plan is False
    assert chat_request.interrupt_feedback is None
    assert chat_request.mcp_settings is None
    assert chat_request.enable_background_investigation is True
    assert chat_request.report_style == ReportStyle.ACADEMIC
|
||||
|
||||
|
||||
def test_chat_request_with_values():
    """Explicitly supplied ChatRequest fields are reflected on the instance."""
    overrides = {
        "messages": [ChatMessage(role="user", content="hi")],
        "resources": [
            Resource(
                name="test",
                type="doc",
                uri="some-uri-value",
                title="some-title-value",
            )
        ],
        "debug": True,
        "thread_id": "tid",
        "max_plan_iterations": 2,
        "max_step_num": 5,
        "max_search_results": 10,
        "auto_accepted_plan": True,
        "interrupt_feedback": "stop",
        "mcp_settings": {"foo": "bar"},
        "enable_background_investigation": False,
        "report_style": "academic",
    }
    chat_request = ChatRequest(**overrides)

    assert chat_request.messages[0].role == "user"
    assert chat_request.debug is True
    assert chat_request.thread_id == "tid"
    assert chat_request.max_plan_iterations == 2
    assert chat_request.max_step_num == 5
    assert chat_request.max_search_results == 10
    assert chat_request.auto_accepted_plan is True
    assert chat_request.interrupt_feedback == "stop"
    assert chat_request.mcp_settings == {"foo": "bar"}
    assert chat_request.enable_background_investigation is False
    # The string "academic" is expected to compare equal to the enum member.
    assert chat_request.report_style == ReportStyle.ACADEMIC
|
||||
|
||||
|
||||
def test_tts_request_defaults():
    """A TTSRequest given only text fills every other field with its default."""
    tts_request = TTSRequest(text="hello")

    expected_fields = {
        "text": "hello",
        "voice_type": "BV700_V2_streaming",
        "encoding": "mp3",
        "speed_ratio": 1.0,
        "volume_ratio": 1.0,
        "pitch_ratio": 1.0,
        "text_type": "plain",
        "with_frontend": 1,
        "frontend_type": "unitTson",
    }
    for field_name, expected in expected_fields.items():
        assert getattr(tts_request, field_name) == expected
|
||||
|
||||
|
||||
def test_generate_podcast_request():
    """GeneratePodcastRequest stores the content it is given."""
    podcast_request = GeneratePodcastRequest(content="Podcast content")
    assert podcast_request.content == "Podcast content"
|
||||
|
||||
|
||||
def test_generate_ppt_request():
    """GeneratePPTRequest stores the content it is given."""
    ppt_request = GeneratePPTRequest(content="PPT content")
    assert ppt_request.content == "PPT content"
|
||||
|
||||
|
||||
def test_generate_prose_request():
    """GenerateProseRequest keeps its fields; ``command`` defaults to ""."""
    full_request = GenerateProseRequest(
        prompt="Write a poem",
        option="poet",
        command="rhyme",
    )
    assert full_request.prompt == "Write a poem"
    assert full_request.option == "poet"
    assert full_request.command == "rhyme"

    # Omitting ``command`` falls back to the empty string.
    minimal_request = GenerateProseRequest(prompt="Write", option="short")
    assert minimal_request.command == ""
|
||||
|
||||
|
||||
def test_enhance_prompt_request_defaults():
    """EnhancePromptRequest defaults: empty context and "academic" style."""
    enhance_request = EnhancePromptRequest(prompt="Improve this")
    assert enhance_request.prompt == "Improve this"
    assert enhance_request.context == ""
    assert enhance_request.report_style == "academic"
|
||||
|
||||
|
||||
def test_content_item_validation_error():
    """Constructing a ContentItem without ``type`` fails validation."""
    with pytest.raises(ValidationError):
        ContentItem()
|
||||
|
||||
|
||||
def test_chat_message_validation_error():
    """Constructing a ChatMessage without ``content`` fails validation."""
    with pytest.raises(ValidationError):
        ChatMessage(role="user")
|
||||
|
||||
|
||||
def test_tts_request_validation_error():
    """Constructing a TTSRequest without ``text`` fails validation."""
    with pytest.raises(ValidationError):
        TTSRequest()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("src.server.mcp_utils._get_tools_from_client_session", new_callable=AsyncMock)
@patch("src.server.mcp_utils.StdioServerParameters")
@patch("src.server.mcp_utils.stdio_client")
async def test_load_mcp_tools_exception_handling(
    mock_stdio_client, mock_StdioServerParameters, mock_get_tools
):
    """An unexpected failure while fetching tools surfaces as an HTTP 500."""
    mock_StdioServerParameters.return_value = MagicMock()
    mock_stdio_client.return_value = MagicMock()
    mock_get_tools.side_effect = Exception("unexpected error")

    with pytest.raises(HTTPException) as exc:
        await mcp_utils.load_mcp_tools(server_type="stdio", command="node")

    raised = exc.value
    assert raised.status_code == 500
    assert "unexpected error" in raised.detail
|
||||
@@ -1,77 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from src.server.mcp_request import MCPServerMetadataRequest, MCPServerMetadataResponse
|
||||
|
||||
|
||||
def test_mcp_server_metadata_request_required_fields():
    """With only ``transport`` supplied, every other request field is None."""
    request_model = MCPServerMetadataRequest(transport="stdio")
    assert request_model.transport == "stdio"

    unset_fields = (
        "command",
        "args",
        "url",
        "env",
        "timeout_seconds",
        "sse_read_timeout",
    )
    for field_name in unset_fields:
        assert getattr(request_model, field_name) is None
|
||||
|
||||
|
||||
def test_mcp_server_metadata_request_optional_fields():
    """Every field supplied to the request model is stored unchanged."""
    supplied = {
        "transport": "sse",
        "command": "run",
        "args": ["--foo", "bar"],
        "url": "http://localhost:8080",
        "env": {"FOO": "BAR"},
        "timeout_seconds": 30,
        "sse_read_timeout": 15,
    }
    request_model = MCPServerMetadataRequest(**supplied)

    for field_name, expected in supplied.items():
        assert getattr(request_model, field_name) == expected
|
||||
|
||||
|
||||
def test_mcp_server_metadata_request_missing_transport():
    """Omitting ``transport`` fails validation."""
    with pytest.raises(ValidationError):
        MCPServerMetadataRequest()
|
||||
|
||||
|
||||
def test_mcp_server_metadata_response_required_fields():
    """With only ``transport`` supplied, optionals are None and tools is empty."""
    response_model = MCPServerMetadataResponse(transport="stdio")
    assert response_model.transport == "stdio"

    for field_name in ("command", "args", "url", "env"):
        assert getattr(response_model, field_name) is None

    assert response_model.tools == []
|
||||
|
||||
|
||||
def test_mcp_server_metadata_response_optional_fields():
    """Every field supplied to the response model is stored unchanged."""
    supplied = {
        "transport": "sse",
        "command": "run",
        "args": ["--foo", "bar"],
        "url": "http://localhost:8080",
        "env": {"FOO": "BAR"},
        "tools": ["tool1", "tool2"],
    }
    response_model = MCPServerMetadataResponse(**supplied)

    for field_name, expected in supplied.items():
        assert getattr(response_model, field_name) == expected
|
||||
|
||||
|
||||
def test_mcp_server_metadata_response_tools_default_factory():
    """The default ``tools`` list must not be shared between instances."""
    mutated = MCPServerMetadataResponse(transport="stdio")
    untouched = MCPServerMetadataResponse(transport="stdio")

    mutated.tools.append("toolA")

    assert untouched.tools == []
|
||||
@@ -1,185 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
import src.server.mcp_utils as mcp_utils
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("src.server.mcp_utils.ClientSession")
async def test__get_tools_from_client_session_success(mock_ClientSession):
    """The session is initialized once and its listed tools are returned."""
    # Transport context manager: entering it yields a (read, write, callback) triple.
    transport = AsyncMock()
    transport.__aenter__.return_value = (AsyncMock(), AsyncMock(), AsyncMock())
    transport.__aexit__.return_value = None

    # Object returned by session.list_tools(); only ``.tools`` is read back.
    listing = MagicMock()
    listing.tools = ["tool1", "tool2"]

    # Session mock usable as an async context manager that yields itself.
    session = AsyncMock()
    session.__aenter__.return_value = session
    session.__aexit__.return_value = None
    session.initialize = AsyncMock()
    session.list_tools = AsyncMock(return_value=listing)
    mock_ClientSession.return_value = session

    tools = await mcp_utils._get_tools_from_client_session(
        transport, timeout_seconds=5
    )

    assert tools == ["tool1", "tool2"]
    session.initialize.assert_awaited_once()
    session.list_tools.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("src.server.mcp_utils._get_tools_from_client_session", new_callable=AsyncMock)
@patch("src.server.mcp_utils.StdioServerParameters")
@patch("src.server.mcp_utils.stdio_client")
async def test_load_mcp_tools_stdio_success(
    mock_stdio_client, mock_StdioServerParameters, mock_get_tools
):
    """The stdio path builds server parameters, opens the client, returns tools."""
    server_params = MagicMock()
    stdio_handle = MagicMock()
    mock_StdioServerParameters.return_value = server_params
    mock_stdio_client.return_value = stdio_handle
    mock_get_tools.return_value = ["toolA"]

    tools = await mcp_utils.load_mcp_tools(
        server_type="stdio",
        command="node",
        args=["server.js"],
        env={"API_KEY": "test123"},
        timeout_seconds=3,
    )

    assert tools == ["toolA"]
    mock_StdioServerParameters.assert_called_once_with(
        command="node",
        args=["server.js"],
        env={"API_KEY": "test123"},
    )
    mock_stdio_client.assert_called_once_with(server_params)
    mock_get_tools.assert_awaited_once_with(stdio_handle, 3)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_load_mcp_tools_stdio_missing_command():
    """The stdio transport without a command is rejected with HTTP 400."""
    with pytest.raises(HTTPException) as exc:
        await mcp_utils.load_mcp_tools(server_type="stdio")

    raised = exc.value
    assert raised.status_code == 400
    assert "Command is required" in raised.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("src.server.mcp_utils._get_tools_from_client_session", new_callable=AsyncMock)
@patch("src.server.mcp_utils.sse_client")
async def test_load_mcp_tools_sse_success(mock_sse_client, mock_get_tools):
    """The SSE path forwards url, headers and timeout, then returns the tools."""
    sse_handle = MagicMock()
    mock_sse_client.return_value = sse_handle
    mock_get_tools.return_value = ["toolB"]

    tools = await mcp_utils.load_mcp_tools(
        server_type="sse",
        url="http://localhost:1234",
        headers={"Authorization": "Bearer 1234567890"},
        timeout_seconds=7,
    )

    assert tools == ["toolB"]
    # sse_read_timeout was not supplied, so it must not be forwarded.
    mock_sse_client.assert_called_once_with(
        url="http://localhost:1234",
        headers={"Authorization": "Bearer 1234567890"},
        timeout=7,
    )
    mock_get_tools.assert_awaited_once_with(sse_handle, 7)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("src.server.mcp_utils._get_tools_from_client_session", new_callable=AsyncMock)
@patch("src.server.mcp_utils.sse_client")
async def test_load_mcp_tools_sse_with_sse_read_timeout(mock_sse_client, mock_get_tools):
    """A supplied sse_read_timeout is forwarded to the SSE client."""
    sse_handle = MagicMock()
    mock_sse_client.return_value = sse_handle
    mock_get_tools.return_value = ["toolC"]

    tools = await mcp_utils.load_mcp_tools(
        server_type="sse",
        url="http://localhost:1234",
        headers={"Authorization": "Bearer token"},
        timeout_seconds=10,
        sse_read_timeout=5,
    )

    assert tools == ["toolC"]
    # The client receives both timeout values.
    mock_sse_client.assert_called_once_with(
        url="http://localhost:1234",
        headers={"Authorization": "Bearer token"},
        timeout=10,
        sse_read_timeout=5,
    )
    # The session timeout is still timeout_seconds.
    mock_get_tools.assert_awaited_once_with(sse_handle, 10)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("src.server.mcp_utils._get_tools_from_client_session", new_callable=AsyncMock)
@patch("src.server.mcp_utils.sse_client")
async def test_load_mcp_tools_sse_without_sse_read_timeout(mock_sse_client, mock_get_tools):
    """Without sse_read_timeout only timeout_seconds reaches the SSE client."""
    sse_handle = MagicMock()
    mock_sse_client.return_value = sse_handle
    mock_get_tools.return_value = ["toolD"]

    tools = await mcp_utils.load_mcp_tools(
        server_type="sse",
        url="http://localhost:1234",
        timeout_seconds=20,
    )

    assert tools == ["toolD"]
    # No sse_read_timeout keyword is expected; headers default to None.
    mock_sse_client.assert_called_once_with(
        url="http://localhost:1234",
        headers=None,
        timeout=20,
    )
    mock_get_tools.assert_awaited_once_with(sse_handle, 20)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_load_mcp_tools_sse_missing_url():
    """The SSE transport without a URL is rejected with HTTP 400."""
    with pytest.raises(HTTPException) as exc:
        await mcp_utils.load_mcp_tools(server_type="sse")

    raised = exc.value
    assert raised.status_code == 400
    assert "URL is required" in raised.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_load_mcp_tools_unsupported_type():
    """An unknown server type is rejected with HTTP 400."""
    with pytest.raises(HTTPException) as exc:
        await mcp_utils.load_mcp_tools(server_type="unknown")

    raised = exc.value
    assert raised.status_code == 400
    # Either wording of the rejection message is accepted.
    assert (
        "Invalid transport type" in raised.detail
        or "Unsupported server type" in raised.detail
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("src.server.mcp_utils._get_tools_from_client_session", new_callable=AsyncMock)
@patch("src.server.mcp_utils.StdioServerParameters")
@patch("src.server.mcp_utils.stdio_client")
async def test_load_mcp_tools_exception_handling(
    mock_stdio_client, mock_StdioServerParameters, mock_get_tools
):
    """An unexpected failure while fetching tools surfaces as an HTTP 500."""
    mock_StdioServerParameters.return_value = MagicMock()
    mock_stdio_client.return_value = MagicMock()
    mock_get_tools.side_effect = Exception("unexpected error")

    with pytest.raises(HTTPException) as exc:
        await mcp_utils.load_mcp_tools(server_type="stdio", command="node")

    raised = exc.value
    assert raised.status_code == 500
    assert "unexpected error" in raised.detail
|
||||
@@ -1,450 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for MCP server configuration validators.
|
||||
|
||||
Tests cover:
|
||||
- Command validation (allowlist)
|
||||
- Argument validation (path traversal, command injection)
|
||||
- Environment variable validation
|
||||
- URL validation
|
||||
- Header validation
|
||||
- Full config validation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.server.mcp_validators import (
|
||||
ALLOWED_COMMANDS,
|
||||
MCPValidationError,
|
||||
validate_args_for_local_file_access,
|
||||
validate_command,
|
||||
validate_command_injection,
|
||||
validate_environment_variables,
|
||||
validate_headers,
|
||||
validate_mcp_server_config,
|
||||
validate_url,
|
||||
)
|
||||
|
||||
|
||||
class TestValidateCommand:
    """Tests for validate_command function."""

    def test_allowed_commands(self):
        """Every entry of ALLOWED_COMMANDS is accepted."""
        for allowed in ALLOWED_COMMANDS:
            validate_command(allowed)  # must not raise

    def test_allowed_command_with_path(self):
        """Commands given with a POSIX or Windows directory prefix are accepted."""
        qualified_commands = (
            "/usr/bin/python3",
            "/usr/local/bin/node",
            "C:\\Python\\python.exe",
        )
        for command in qualified_commands:
            validate_command(command)

    def test_disallowed_command(self):
        """A command outside the allowlist raises, naming the ``command`` field."""
        with pytest.raises(MCPValidationError) as exc_info:
            validate_command("bash")

        error = exc_info.value
        assert "not allowed" in error.message
        assert error.field == "command"

    def test_disallowed_dangerous_commands(self):
        """A set of dangerous commands is rejected one by one."""
        for command in ["rm", "sudo", "chmod", "chown", "curl", "wget", "sh"]:
            with pytest.raises(MCPValidationError):
                validate_command(command)

    def test_empty_command(self):
        """The empty string is rejected."""
        with pytest.raises(MCPValidationError):
            validate_command("")

    def test_none_command(self):
        """None is rejected."""
        with pytest.raises(MCPValidationError):
            validate_command(None)
|
||||
|
||||
|
||||
class TestValidateArgsForLocalFileAccess:
    """Tests for validate_args_for_local_file_access function."""

    def test_safe_args(self):
        """Ordinary flags, package names and relative file names are accepted."""
        accepted_arg_lists = [
            ["--help"],
            ["-v", "--verbose"],
            ["package-name"],
            ["--config", "config.json"],
            ["run", "script.py"],
        ]
        for arg_list in accepted_arg_lists:
            validate_args_for_local_file_access(arg_list)  # must not raise

    def test_directory_traversal(self):
        """Arguments containing ``..`` path components are rejected."""
        rejected_arg_lists = [
            ["../etc/passwd"],
            ["..\\windows\\system32"],
            ["../../secret"],
            ["foo/../bar/../../../etc/passwd"],
            ["foo/.."],  # ".." at the end, after a path separator
            ["bar\\.."],  # ".." at the end, after a Windows path separator
            ["path/to/foo/.."],  # longer path ending with ".."
        ]
        for arg_list in rejected_arg_lists:
            with pytest.raises(MCPValidationError) as exc_info:
                validate_args_for_local_file_access(arg_list)
            assert "traversal" in exc_info.value.message.lower()

    def test_absolute_path_with_dangerous_extension(self):
        """An absolute path ending in a dangerous extension is rejected."""
        with pytest.raises(MCPValidationError):
            validate_args_for_local_file_access(["/etc/passwd.sh"])

    def test_windows_absolute_path(self):
        """A Windows drive-letter absolute path is rejected."""
        with pytest.raises(MCPValidationError):
            validate_args_for_local_file_access(["C:\\Windows\\system32"])

    def test_home_directory_reference(self):
        """``~`` home references are rejected with either path separator."""
        for home_reference in ("~/secrets", "~\\secrets"):
            with pytest.raises(MCPValidationError):
                validate_args_for_local_file_access([home_reference])

    def test_null_byte(self):
        """An argument containing a null byte is rejected with a matching message."""
        with pytest.raises(MCPValidationError) as exc_info:
            validate_args_for_local_file_access(["file\x00.txt"])

        assert "null byte" in exc_info.value.message.lower()

    def test_excessively_long_argument(self):
        """A 1001-character argument is rejected for exceeding the maximum length."""
        oversized = "a" * 1001
        with pytest.raises(MCPValidationError) as exc_info:
            validate_args_for_local_file_access([oversized])

        assert "maximum length" in exc_info.value.message.lower()

    def test_dangerous_extensions(self):
        """File names with dangerous extensions are rejected."""
        rejected_arg_lists = [
            ["script.sh"],
            ["binary.exe"],
            ["library.dll"],
            ["secret.env"],
            ["key.pem"],
        ]
        for arg_list in rejected_arg_lists:
            with pytest.raises(MCPValidationError) as exc_info:
                validate_args_for_local_file_access(arg_list)
            assert "dangerous file type" in exc_info.value.message.lower()

    def test_empty_args(self):
        """Both an empty list and None are accepted."""
        for no_args in ([], None):
            validate_args_for_local_file_access(no_args)
|
||||
|
||||
|
||||
class TestValidateCommandInjection:
    """Tests for validate_command_injection function."""

    def test_safe_args(self):
        """Flags, package names, scoped packages and file names are accepted."""
        accepted_arg_lists = [
            ["--help"],
            ["package-name"],
            ["@scope/package"],
            ["file.json"],
        ]
        for arg_list in accepted_arg_lists:
            validate_command_injection(arg_list)  # must not raise

    def test_shell_metacharacters(self):
        """Arguments containing shell metacharacters are rejected on field ``args``."""
        rejected_arg_lists = [
            ["foo; rm -rf /"],
            ["foo & bar"],
            ["foo | cat /etc/passwd"],
            ["$(whoami)"],
            ["`id`"],
            ["foo > /etc/passwd"],
            ["foo < /etc/passwd"],
            ["${PATH}"],
        ]
        for arg_list in rejected_arg_lists:
            with pytest.raises(MCPValidationError) as exc_info:
                validate_command_injection(arg_list)
            assert "args" == exc_info.value.field

    def test_command_chaining(self):
        """Chaining and redirection operators are rejected."""
        rejected_arg_lists = [
            ["foo && bar"],
            ["foo || bar"],
            ["foo;; bar"],
            ["foo >> output"],
            ["foo << input"],
        ]
        for arg_list in rejected_arg_lists:
            with pytest.raises(MCPValidationError):
                validate_command_injection(arg_list)

    def test_backtick_injection(self):
        """Backtick command substitution is rejected."""
        with pytest.raises(MCPValidationError):
            validate_command_injection(["`whoami`"])

    def test_process_substitution(self):
        """Both ``<(...)`` and ``>(...)`` process substitution forms are rejected."""
        for substitution in ("<(cat /etc/passwd)", ">(tee /tmp/out)"):
            with pytest.raises(MCPValidationError):
                validate_command_injection([substitution])
|
||||
|
||||
|
||||
class TestValidateEnvironmentVariables:
    """Checks for validate_environment_variables."""

    def test_safe_env_vars(self):
        """Ordinary application variables are accepted without raising."""
        validate_environment_variables(
            {
                "API_KEY": "secret123",
                "DEBUG": "true",
                "MY_VARIABLE": "value",
            }
        )

    def test_dangerous_env_vars(self):
        """Loader/search-path variables are refused with a "not allowed" message."""
        blocked = {
            "PATH": "/malicious/path",
            "LD_LIBRARY_PATH": "/malicious/lib",
            "DYLD_LIBRARY_PATH": "/malicious/lib",
            "LD_PRELOAD": "/malicious/lib.so",
            "PYTHONPATH": "/malicious/python",
            "NODE_PATH": "/malicious/node",
        }
        # Each variable is validated on its own, one single-entry dict per call.
        for var_name, var_value in blocked.items():
            with pytest.raises(MCPValidationError) as exc_info:
                validate_environment_variables({var_name: var_value})
            assert "not allowed" in exc_info.value.message.lower()

    def test_null_byte_in_value(self):
        """A value containing a NUL byte is refused."""
        poisoned = {"KEY": "value\x00malicious"}
        with pytest.raises(MCPValidationError):
            validate_environment_variables(poisoned)

    def test_empty_env(self):
        """An empty dict and ``None`` are both accepted without raising."""
        for nothing in ({}, None):
            validate_environment_variables(nothing)
|
||||
|
||||
|
||||
class TestValidateUrl:
    """Checks for validate_url."""

    def test_valid_urls(self):
        """http/https URLs with hosts, ports and paths are accepted."""
        accepted = (
            "http://localhost:3000",
            "https://api.example.com",
            "http://192.168.1.1:8080/api",
            "https://mcp.example.com/sse",
        )
        for candidate in accepted:
            validate_url(candidate)  # must complete without raising

    def test_invalid_scheme(self):
        """ftp:// and file:// URLs are refused; the ftp error mentions "scheme"."""
        with pytest.raises(MCPValidationError) as exc_info:
            validate_url("ftp://example.com")
        ftp_message = exc_info.value.message
        assert "scheme" in ftp_message.lower()

        with pytest.raises(MCPValidationError):
            validate_url("file:///etc/passwd")

    def test_credentials_in_url(self):
        """A URL embedding user:password is refused with a "credentials" message."""
        with pytest.raises(MCPValidationError) as exc_info:
            validate_url("https://user:pass@example.com")
        credentials_message = exc_info.value.message
        assert "credentials" in credentials_message.lower()

    def test_null_byte_in_url(self):
        """A URL containing a NUL byte is refused."""
        poisoned_url = "https://example.com\x00/malicious"
        with pytest.raises(MCPValidationError):
            validate_url(poisoned_url)

    def test_empty_url(self):
        """The empty string is refused."""
        with pytest.raises(MCPValidationError):
            validate_url("")

    def test_no_host(self):
        """A URL with a scheme and path but no host is refused."""
        hostless_url = "http:///path"
        with pytest.raises(MCPValidationError):
            validate_url(hostless_url)
|
||||
|
||||
|
||||
class TestValidateHeaders:
    """Checks for validate_headers."""

    def test_valid_headers(self):
        """Typical request headers are accepted without raising."""
        validate_headers(
            {
                "Authorization": "Bearer token123",
                "Content-Type": "application/json",
                "X-Custom-Header": "value",
            }
        )

    def test_newline_in_header_name(self):
        """A newline inside a header name is refused (HTTP header injection)."""
        with pytest.raises(MCPValidationError) as exc_info:
            validate_headers({"X-Bad\nHeader": "value"})
        error_text = exc_info.value.message
        assert "newline" in error_text.lower()

    def test_newline_in_header_value(self):
        """A CRLF sequence inside a header value is refused (HTTP header injection)."""
        injected = {"X-Header": "value\r\nX-Injected: malicious"}
        with pytest.raises(MCPValidationError):
            validate_headers(injected)

    def test_null_byte_in_header(self):
        """A NUL byte inside a header value is refused."""
        poisoned = {"X-Header": "value\x00"}
        with pytest.raises(MCPValidationError):
            validate_headers(poisoned)

    def test_empty_headers(self):
        """An empty dict and ``None`` are both accepted without raising."""
        for nothing in ({}, None):
            validate_headers(nothing)
|
||||
|
||||
|
||||
class TestValidateMCPServerConfig:
    """Checks for the top-level validate_mcp_server_config entry point."""

    def test_valid_stdio_config(self):
        """A well-formed stdio configuration is accepted without raising."""
        stdio_config = {
            "transport": "stdio",
            "command": "npx",
            "args": ["@modelcontextprotocol/server-filesystem"],
            "env": {"API_KEY": "secret"},
        }
        validate_mcp_server_config(**stdio_config)

    def test_valid_sse_config(self):
        """A well-formed SSE configuration is accepted without raising."""
        sse_config = {
            "transport": "sse",
            "url": "https://api.example.com/sse",
            "headers": {"Authorization": "Bearer token"},
        }
        validate_mcp_server_config(**sse_config)

    def test_valid_http_config(self):
        """A well-formed streamable_http configuration is accepted without raising."""
        http_config = {
            "transport": "streamable_http",
            "url": "https://api.example.com/mcp",
        }
        validate_mcp_server_config(**http_config)

    def test_invalid_transport(self):
        """An unknown transport raises with "Invalid transport type" in the message."""
        with pytest.raises(MCPValidationError) as exc_info:
            validate_mcp_server_config(transport="invalid")
        assert "Invalid transport type" in exc_info.value.message

    def test_combined_validation_errors(self):
        """Several simultaneous violations surface together in one error message."""
        offending_config = {
            "transport": "stdio",
            "command": "bash",  # command expected to be refused
            "args": ["../etc/passwd"],  # path traversal
            "env": {"PATH": "/malicious"},  # dangerous environment variable
        }
        with pytest.raises(MCPValidationError) as exc_info:
            validate_mcp_server_config(**offending_config)

        # The single raised error must carry both kinds of complaint.
        combined_message = exc_info.value.message
        assert "not allowed" in combined_message
        assert "traversal" in combined_message.lower()

    def test_non_strict_mode(self):
        """With strict=False an otherwise refused command does not raise."""
        # Only the absence of an exception is checked; any logging is not inspected.
        validate_mcp_server_config(transport="stdio", command="bash", strict=False)

    def test_stdio_with_dangerous_args(self):
        """A stdio configuration whose args contain shell syntax is refused."""
        injected_config = {
            "transport": "stdio",
            "command": "node",
            "args": ["script.js; rm -rf /"],
        }
        with pytest.raises(MCPValidationError):
            validate_mcp_server_config(**injected_config)

    def test_sse_with_invalid_url(self):
        """An SSE configuration with an ftp:// URL is refused."""
        with pytest.raises(MCPValidationError):
            validate_mcp_server_config(transport="sse", url="ftp://example.com")
|
||||
|
||||
|
||||
class TestMCPServerMetadataRequest:
    """Checks that the MCPServerMetadataRequest Pydantic model enforces validation."""

    def test_valid_request(self):
        """A valid stdio request constructs and exposes its fields."""
        from src.server.mcp_request import MCPServerMetadataRequest

        fields = {
            "transport": "stdio",
            "command": "npx",
            "args": ["@modelcontextprotocol/server-filesystem"],
        }
        request = MCPServerMetadataRequest(**fields)

        assert request.transport == "stdio"
        assert request.command == "npx"

    def test_invalid_command_raises_validation_error(self):
        """A refused command surfaces as a Pydantic ValidationError mentioning "not allowed"."""
        from pydantic import ValidationError

        from src.server.mcp_request import MCPServerMetadataRequest

        with pytest.raises(ValidationError) as exc_info:
            MCPServerMetadataRequest(transport="stdio", command="bash")
        rendered_error = str(exc_info.value)
        assert "not allowed" in rendered_error.lower()

    def test_command_injection_raises_validation_error(self):
        """Shell syntax inside args surfaces as a Pydantic ValidationError."""
        from pydantic import ValidationError

        from src.server.mcp_request import MCPServerMetadataRequest

        fields = {
            "transport": "stdio",
            "command": "node",
            "args": ["script.js; rm -rf /"],
        }
        with pytest.raises(ValidationError):
            MCPServerMetadataRequest(**fields)

    def test_invalid_url_raises_validation_error(self):
        """An ftp:// URL surfaces as a Pydantic ValidationError."""
        from pydantic import ValidationError

        from src.server.mcp_request import MCPServerMetadataRequest

        with pytest.raises(ValidationError):
            MCPServerMetadataRequest(transport="sse", url="ftp://example.com")
|
||||
@@ -1,317 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for tool call chunk processing.
|
||||
|
||||
Tests for the fix of issue #523: Tool name concatenation in consecutive tool calls.
|
||||
This ensures that tool call chunks are properly segregated by index to prevent
|
||||
tool names from being concatenated when multiple tool calls happen in sequence.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
# Import the functions to test
|
||||
# Note: We need to import from the app module
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Add src directory to path for imports
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../"))
|
||||
|
||||
from src.server.app import _process_tool_call_chunks, _validate_tool_call_chunks
|
||||
|
||||
|
||||
class TestProcessToolCallChunks:
    """Checks for _process_tool_call_chunks."""

    def test_empty_tool_call_chunks(self):
        """An empty input list yields an empty result list."""
        assert _process_tool_call_chunks([]) == []

    def test_single_tool_call_single_chunk(self):
        """A lone chunk comes back with its name, id, index and args intact."""
        lone_chunk = {
            "name": "web_search",
            "args": '{"query": "test"}',
            "id": "call_1",
            "index": 0,
        }

        result = _process_tool_call_chunks([lone_chunk])

        assert len(result) == 1
        processed = result[0]
        assert processed["name"] == "web_search"
        assert processed["id"] == "call_1"
        assert processed["index"] == 0
        assert '"query": "test"' in processed["args"]

    def test_consecutive_tool_calls_different_indices(self):
        """Two calls to the same tool at different indices stay separate."""
        result = _process_tool_call_chunks(
            [
                {"name": "web_search", "args": '{"query": "test"}', "id": "call_1", "index": 0},
                {"name": "web_search", "args": '{"query": "test2"}', "id": "call_2", "index": 1},
            ]
        )

        assert len(result) == 2
        first, second = result
        assert first["name"] == "web_search"
        assert first["id"] == "call_1"
        assert first["index"] == 0
        assert second["name"] == "web_search"
        assert second["id"] == "call_2"
        assert second["index"] == 1
        # Explicit guard against the concatenated-name failure mode.
        assert first["name"] != "web_searchweb_search"
        assert second["name"] != "web_searchweb_search"

    def test_different_tools_different_indices(self):
        """Calls to two different tools at different indices keep their own names."""
        result = _process_tool_call_chunks(
            [
                {"name": "web_search", "args": '{"query": "test"}', "id": "call_1", "index": 0},
                {
                    "name": "crawl_tool",
                    "args": '{"url": "http://example.com"}',
                    "id": "call_2",
                    "index": 1,
                },
            ]
        )

        assert len(result) == 2
        first, second = result
        assert first["name"] == "web_search"
        assert second["name"] == "crawl_tool"
        # The concatenated form must appear in neither name.
        assert "web_searchcrawl_tool" not in first["name"]
        assert "web_searchcrawl_tool" not in second["name"]

    def test_streaming_chunks_same_index(self):
        """Chunks sharing one index collapse into a single call with accumulated args."""
        result = _process_tool_call_chunks(
            [
                {"name": "web_", "args": '{"query"', "id": "call_1", "index": 0},
                {"name": "search", "args": ': "test"}', "id": "call_1", "index": 0},
            ]
        )

        assert len(result) == 1
        merged = result[0]
        # Any of the three name outcomes is tolerated by this test.
        assert merged["name"] in ("web_", "search", "web_search")
        assert merged["id"] == "call_1"
        # Both argument fragments must be reflected in the merged args.
        assert "query" in merged["args"]
        assert "test" in merged["args"]

    def test_tool_call_index_collision_warning(self):
        """Two different tool names on one index produce a logged warning."""
        colliding = [
            {"name": "web_search", "args": "{}", "id": "call_1", "index": 0},
            {"name": "crawl_tool", "args": "{}", "id": "call_2", "index": 0},
        ]

        with patch("src.server.app.logger") as mock_logger:
            _process_tool_call_chunks(colliding)

        mock_logger.warning.assert_called()
        # Only the first positional argument of the most recent warning is inspected.
        warning_text = mock_logger.warning.call_args.args[0]
        assert "Tool name mismatch detected" in warning_text
        assert "web_search" in warning_text
        assert "crawl_tool" in warning_text

    def test_chunks_without_explicit_index(self):
        """A chunk lacking the ``index`` key is still processed."""
        unindexed = {"name": "web_search", "args": "{}", "id": "call_1"}

        result = _process_tool_call_chunks([unindexed])

        assert len(result) == 1
        assert result[0]["name"] == "web_search"

    def test_chunk_sorting_by_index(self):
        """Out-of-order chunks come back ordered by ascending index."""
        shuffled = [
            {"name": "tool_3", "args": "{}", "id": "call_3", "index": 2},
            {"name": "tool_1", "args": "{}", "id": "call_1", "index": 0},
            {"name": "tool_2", "args": "{}", "id": "call_2", "index": 1},
        ]

        result = _process_tool_call_chunks(shuffled)

        assert len(result) == 3
        assert result[0]["index"] == 0
        assert result[1]["index"] == 1
        assert result[2]["index"] == 2

    def test_args_accumulation(self):
        """Argument fragments on one index yield a single call with non-empty args."""
        fragments = [
            {"name": "web_search", "args": '{"q', "id": "call_1", "index": 0},
            {"name": "web_search", "args": 'uery": "test"}', "id": "call_1", "index": 0},
        ]

        result = _process_tool_call_chunks(fragments)

        assert len(result) == 1
        # Exact content is not pinned here; only that something was accumulated.
        assert len(result[0]["args"]) > 0

    def test_preserve_tool_id(self):
        """Each call keeps the id it arrived with."""
        result = _process_tool_call_chunks(
            [
                {"name": "web_search", "args": "{}", "id": "call_abc123", "index": 0},
                {"name": "web_search", "args": "{}", "id": "call_xyz789", "index": 1},
            ]
        )

        assert result[0]["id"] == "call_abc123"
        assert result[1]["id"] == "call_xyz789"

    def test_multiple_indices_detected(self):
        """Three distinct indices are either logged as such or yield three results."""
        three_calls = [
            {"name": "tool_a", "args": "{}", "id": "call_1", "index": 0},
            {"name": "tool_b", "args": "{}", "id": "call_2", "index": 1},
            {"name": "tool_c", "args": "{}", "id": "call_3", "index": 2},
        ]

        with patch("src.server.app.logger") as mock_logger:
            result = _process_tool_call_chunks(three_calls)

        debug_messages = [record.args[0] for record in mock_logger.debug.call_args_list]
        saw_multiple_indices = any("Multiple indices" in text for text in debug_messages)
        # NOTE(review): the `or` lets this pass without any debug log as long as
        # three results come back - confirm whether the log is meant to be required.
        assert saw_multiple_indices or len(result) == 3
|
||||
|
||||
|
||||
class TestValidateToolCallChunks:
    """Checks for _validate_tool_call_chunks."""

    def test_validate_empty_chunks(self):
        """An empty list is accepted without raising."""
        _validate_tool_call_chunks([])

    def test_validate_logs_chunk_info(self):
        """Validating one chunk emits at least one debug log call."""
        single = [{"name": "web_search", "args": "{}", "id": "call_1", "index": 0}]

        with patch("src.server.app.logger") as mock_logger:
            _validate_tool_call_chunks(single)

        assert mock_logger.debug.called

    def test_validate_detects_multiple_indices(self):
        """Chunks on two indices produce a debug message containing "Multiple indices"."""
        two_indices = [
            {"name": "tool_1", "args": "{}", "id": "call_1", "index": 0},
            {"name": "tool_2", "args": "{}", "id": "call_2", "index": 1},
        ]

        with patch("src.server.app.logger") as mock_logger:
            _validate_tool_call_chunks(two_indices)

        # Only the first positional argument of each debug call is inspected.
        debug_messages = [record.args[0] for record in mock_logger.debug.call_args_list]
        assert any("Multiple indices" in text for text in debug_messages)
|
||||
|
||||
|
||||
class TestRealWorldScenarios:
    """End-to-end style scenarios modelled on issue #523."""

    def test_issue_523_scenario_consecutive_web_search(self):
        """
        Reproduce issue #523: two consecutive streamed web_search calls.

        The failure mode was a tool name of "web_searchweb_search".
        """
        streamed = [
            # first call, index 0, name and args split across two chunks
            {"name": "web_", "args": '{"query', "id": "call_1", "index": 0},
            {"name": "search", "args": '": "first query"}', "id": "call_1", "index": 0},
            # second call, index 1, split the same way
            {"name": "web_", "args": '{"query', "id": "call_2", "index": 1},
            {"name": "search", "args": '": "second query"}', "id": "call_2", "index": 1},
        ]

        result = _process_tool_call_chunks(streamed)

        # NOTE(review): two calls are fed in but only "at least one" is required
        # here - confirm whether this should be an exact count.
        assert len(result) >= 1

        tool_names = [entry.get("name") for entry in result]

        # No single name may be the doubled form.
        assert "web_searchweb_search" not in tool_names

        # Across all results, the tool name (or its first fragment) must show up.
        joined_names = "".join(tool_names)
        assert "web_search" in joined_names or "web_" in joined_names

    def test_mixed_tools_consecutive_calls(self):
        """Alternating tools on three indices yield three calls with unmerged names."""
        mixed = [
            {"name": "web_search", "args": '{"query": "python"}', "id": "1", "index": 0},
            {
                "name": "crawl_tool",
                "args": '{"url": "http://example.com"}',
                "id": "2",
                "index": 1,
            },
            {"name": "web_search", "args": '{"query": "rust"}', "id": "3", "index": 2},
        ]

        result = _process_tool_call_chunks(mixed)

        assert len(result) == 3
        tool_names = [entry.get("name") for entry in result]

        # Neither concatenation order may appear as a name.
        assert "web_searchcrawl_tool" not in tool_names
        assert "crawl_toolweb_search" not in tool_names

    def test_long_sequence_tool_calls(self):
        """Ten alternating calls come back as ten calls with exact names, ids and indices."""
        # Even positions are web_search calls, odd positions are crawl_tool calls.
        chunks = [
            {
                "name": "web_search" if i % 2 == 0 else "crawl_tool",
                "args": '{"query": "test"}' if i % 2 == 0 else '{"url": "http://example.com"}',
                "id": f"call_{i}",
                "index": i,
            }
            for i in range(10)
        ]

        result = _process_tool_call_chunks(chunks)

        assert len(result) == 10

        for i, chunk in enumerate(result):
            expected_name = "crawl_tool" if i % 2 else "web_search"
            actual_name = chunk.get("name", "")

            assert actual_name == expected_name, (
                f"Tool call {i} has name '{actual_name}', expected '{expected_name}'. "
                f"This indicates concatenation with adjacent tool call."
            )

            assert chunk.get("id") == f"call_{i}"
            assert chunk.get("index") == i
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this module directly. pytest.main() returns the run's exit
    # code; raise SystemExit with it so a failing run is visible to the calling
    # shell or CI step instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))
|
||||
@@ -1,216 +0,0 @@
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from src.tools.crawl import crawl_tool, is_pdf_url
|
||||
|
||||
|
||||
class TestCrawlTool:
    """Checks for crawl_tool with the Crawler class replaced by mocks."""

    @patch("src.tools.crawl.Crawler")
    def test_crawl_tool_success(self, mock_crawler_class):
        """A long article yields JSON with the URL and content of at most 1000 characters."""
        target = "https://example.com"
        article = Mock(
            **{"to_markdown.return_value": "# Test Article\nThis is test content." * 100}
        )
        crawler = Mock(**{"crawl.return_value": article})
        mock_crawler_class.return_value = crawler

        raw = crawl_tool.invoke({"url": target})

        assert isinstance(raw, str)
        payload = json.loads(raw)
        assert payload["url"] == target
        assert "crawled_content" in payload
        assert len(payload["crawled_content"]) <= 1000
        mock_crawler_class.assert_called_once()
        crawler.crawl.assert_called_once_with(target)
        article.to_markdown.assert_called_once()

    @patch("src.tools.crawl.Crawler")
    def test_crawl_tool_short_content(self, mock_crawler_class):
        """Short content is returned unchanged."""
        short_content = "Short content"
        article = Mock(**{"to_markdown.return_value": short_content})
        mock_crawler_class.return_value = Mock(**{"crawl.return_value": article})

        raw = crawl_tool.invoke({"url": "https://example.com"})

        assert json.loads(raw)["crawled_content"] == short_content

    @patch("src.tools.crawl.Crawler")
    @patch("src.tools.crawl.logger")
    def test_crawl_tool_crawler_exception(self, mock_logger, mock_crawler_class):
        """An exception from crawl() becomes a failure string and one error log."""
        failing_crawler = Mock(**{"crawl.side_effect": Exception("Network error")})
        mock_crawler_class.return_value = failing_crawler

        raw = crawl_tool.invoke({"url": "https://example.com"})

        assert isinstance(raw, str)
        assert "Failed to crawl" in raw
        assert "Network error" in raw
        mock_logger.error.assert_called_once()

    @patch("src.tools.crawl.Crawler")
    @patch("src.tools.crawl.logger")
    def test_crawl_tool_crawler_instantiation_exception(
        self, mock_logger, mock_crawler_class
    ):
        """An exception from constructing Crawler becomes a failure string and one error log."""
        mock_crawler_class.side_effect = Exception("Crawler init error")

        raw = crawl_tool.invoke({"url": "https://example.com"})

        assert isinstance(raw, str)
        assert "Failed to crawl" in raw
        assert "Crawler init error" in raw
        mock_logger.error.assert_called_once()

    @patch("src.tools.crawl.Crawler")
    @patch("src.tools.crawl.logger")
    def test_crawl_tool_markdown_conversion_exception(
        self, mock_logger, mock_crawler_class
    ):
        """An exception from to_markdown() becomes a failure string and one error log."""
        broken_article = Mock(
            **{"to_markdown.side_effect": Exception("Markdown conversion error")}
        )
        mock_crawler_class.return_value = Mock(**{"crawl.return_value": broken_article})

        raw = crawl_tool.invoke({"url": "https://example.com"})

        assert isinstance(raw, str)
        assert "Failed to crawl" in raw
        assert "Markdown conversion error" in raw
        mock_logger.error.assert_called_once()

    @patch("src.tools.crawl.Crawler")
    def test_crawl_tool_with_none_content(self, mock_crawler_class):
        """The "no content" placeholder markdown is passed through in the JSON payload."""
        target = "https://example.com"
        placeholder_article = Mock(
            **{"to_markdown.return_value": "# Article\n\n*No content available*\n"}
        )
        mock_crawler_class.return_value = Mock(**{"crawl.return_value": placeholder_article})

        raw = crawl_tool.invoke({"url": target})

        assert isinstance(raw, str)
        payload = json.loads(raw)
        assert payload["url"] == target
        assert "crawled_content" in payload
        assert "No content available" in payload["crawled_content"]
|
||||
|
||||
|
||||
class TestPDFHandling:
    """PDF URL detection and the PDF short-circuit in crawl_tool (issue #701)."""

    def test_is_pdf_url_with_pdf_urls(self):
        """URLs whose path ends in .pdf (any case, with or without a query) are PDFs."""
        expectations = [
            ("https://example.com/document.pdf", True),
            ("https://example.com/file.PDF", True),  # upper-case extension
            ("https://example.com/path/to/report.pdf", True),
            ("https://pdf.dfcfw.com/pdf/H3_AP202503071644153386_1.pdf", True),  # from the issue
            ("http://site.com/path/document.pdf?param=value", True),  # query string present
        ]

        for url, expected in expectations:
            assert is_pdf_url(url) == expected, f"Failed for URL: {url}"

    def test_is_pdf_url_with_non_pdf_urls(self):
        """Other extensions, bare hosts, empty strings and None are not PDFs."""
        expectations = [
            ("https://example.com/page.html", False),
            ("https://example.com/article.php", False),
            ("https://example.com/", False),
            ("https://example.com/document.pdfx", False),  # extension only starts with .pdf
            ("https://example.com/document.doc", False),
            ("https://example.com/document.txt", False),
            ("https://example.com?file=document.pdf", False),  # .pdf in query, not path
            ("", False),
            (None, False),
        ]

        for url, expected in expectations:
            assert is_pdf_url(url) == expected, f"Failed for URL: {url}"

    def test_crawl_tool_with_pdf_url(self):
        """A PDF URL yields the JSON error payload instead of crawled content."""
        pdf_url = "https://example.com/document.pdf"

        raw = crawl_tool.invoke({"url": pdf_url})

        assert isinstance(raw, str)
        payload = json.loads(raw)

        # Shape of the PDF error response.
        assert payload["url"] == pdf_url
        assert "error" in payload
        assert payload["crawled_content"] is None
        assert payload["is_pdf"] is True
        assert "PDF files cannot be crawled directly" in payload["error"]

    def test_crawl_tool_with_issue_pdf_url(self):
        """The exact URL reported in issue #701 is handled as a PDF."""
        issue_pdf_url = "https://pdf.dfcfw.com/pdf/H3_AP202503071644153386_1.pdf"

        payload = json.loads(crawl_tool.invoke({"url": issue_pdf_url}))

        assert payload["url"] == issue_pdf_url
        assert payload["is_pdf"] is True
        assert "cannot be crawled directly" in payload["error"]

    @patch("src.tools.crawl.Crawler")
    @patch("src.tools.crawl.logger")
    def test_crawl_tool_skips_crawler_for_pdfs(self, mock_logger, mock_crawler_class):
        """For a PDF URL the Crawler is never constructed and one info line is logged."""
        pdf_url = "https://example.com/document.pdf"

        raw = crawl_tool.invoke({"url": pdf_url})

        mock_crawler_class.assert_not_called()
        mock_logger.info.assert_called_once_with(f"PDF URL detected, skipping crawling: {pdf_url}")

        # The PDF error payload is still returned.
        assert json.loads(raw)["is_pdf"] is True
|
||||
@@ -1,119 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import Mock, call, patch
|
||||
|
||||
from src.tools.decorators import create_logged_tool
|
||||
|
||||
|
||||
class MockBaseTool:
    """Bare-bones tool stand-in used as the base class in these tests."""

    def _run(self, *args, **kwargs):
        """Ignore all arguments and return the fixed string ``"base_result"``."""
        return "base_result"
|
||||
|
||||
|
||||
class TestLoggedToolMixin:
    """Checks for the ``_run`` wrapper of classes built by create_logged_tool."""

    def test_run_calls_log_operation(self):
        """_run forwards its name and all arguments to _log_operation exactly once."""
        tool = create_logged_tool(MockBaseTool)()
        tool._log_operation = Mock()  # intercept the logging hook

        positional = ("arg1", "arg2")
        keywords = {"key1": "value1", "key2": "value2"}
        tool._run(*positional, **keywords)

        tool._log_operation.assert_called_once_with("_run", *positional, **keywords)

    def test_run_calls_super_run(self):
        """_run delegates to the base class _run and returns its result."""
        tool = create_logged_tool(MockBaseTool)()
        positional = ("arg1", "arg2")
        keywords = {"key1": "value1"}

        # Replace the base implementation so the delegation can be observed.
        with patch.object(
            MockBaseTool, "_run", return_value="mocked_result"
        ) as mock_super_run:
            result = tool._run(*positional, **keywords)

        mock_super_run.assert_called_once_with(*positional, **keywords)
        assert result == "mocked_result"

    def test_run_logs_result(self):
        """_run emits a debug line for the call and another for the returned value."""
        tool = create_logged_tool(MockBaseTool)()

        with patch("src.tools.decorators.logger.debug") as mock_debug:
            tool._run("test_arg")

        expected_sequence = [
            call("Tool MockBaseTool._run called with parameters: test_arg"),
            call("Tool MockBaseTool returned: base_result"),
        ]
        mock_debug.assert_has_calls(expected_sequence)

    def test_run_returns_super_result(self):
        """Without any mocking, _run returns the base class value."""
        tool = create_logged_tool(MockBaseTool)()

        assert tool._run() == "base_result"

    def test_run_with_no_args(self):
        """With no arguments, _log_operation receives only the method name."""
        tool = create_logged_tool(MockBaseTool)()

        with patch("src.tools.decorators.logger.debug") as mock_debug:
            tool._log_operation = Mock()
            result = tool._run()

        tool._log_operation.assert_called_once_with("_run")
        # _log_operation is mocked out, so exactly one debug call remains.
        mock_debug.assert_called_once()
        assert result == "base_result"

    def test_run_with_mixed_args_kwargs(self):
        """Positional and keyword arguments are forwarded together to _log_operation."""
        tool = create_logged_tool(MockBaseTool)()
        tool._log_operation = Mock()

        positional = ("pos1", "pos2")
        keywords = {"kw1": "val1", "kw2": "val2"}
        result = tool._run(*positional, **keywords)

        tool._log_operation.assert_called_once_with("_run", *positional, **keywords)
        assert result == "base_result"

    def test_run_class_name_replacement(self):
        """The logged tool name is ``MockBaseTool``, never ``LoggedMockBaseTool``."""
        tool = create_logged_tool(MockBaseTool)()

        with patch("src.tools.decorators.logger.debug") as mock_debug:
            tool._run()

        # Inspect the first positional argument of the most recent debug call.
        last_message = mock_debug.call_args.args[0]
        assert "Tool MockBaseTool returned:" in last_message
        assert "LoggedMockBaseTool" not in last_message
|
||||
@@ -1,218 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from src.tools.infoquest_search.infoquest_search_api import InfoQuestAPIWrapper
|
||||
|
||||
class TestInfoQuestAPIWrapper:
    """Tests for ``InfoQuestAPIWrapper``: request payloads, errors and result cleaning."""

    @pytest.fixture
    def wrapper(self):
        """Return a wrapper instance built with a placeholder API key."""
        return InfoQuestAPIWrapper(infoquest_api_key="dummy-key")

    @pytest.fixture
    def mock_response_data(self):
        """Return a response payload with one organic, one news and one image entry."""
        return {
            "search_result": {
                "results": [
                    {
                        "content": {
                            "results": {
                                "organic": [
                                    {
                                        "title": "Test Title",
                                        "url": "https://example.com",
                                        "desc": "Test description",
                                    }
                                ],
                                "top_stories": {
                                    "items": [
                                        {
                                            "time_frame": "2 days ago",
                                            "title": "Test News",
                                            "url": "https://example.com/news",
                                            "source": "Test Source",
                                        }
                                    ]
                                },
                                "images": {
                                    "items": [
                                        {
                                            "url": "https://example.com/image.jpg",
                                            "alt": "Test image description",
                                        }
                                    ]
                                },
                            }
                        }
                    }
                ]
            }
        }

    @patch("src.tools.infoquest_search.infoquest_search_api.requests.post")
    def test_raw_results_success(self, mock_post, wrapper, mock_response_data):
        """Test successful synchronous search results."""
        mock_response = Mock()
        mock_response.json.return_value = mock_response_data
        mock_response.raise_for_status.return_value = None
        mock_post.return_value = mock_response

        result = wrapper.raw_results("test query", time_range=0, site="")

        assert result == mock_response_data["search_result"]
        mock_post.assert_called_once()
        payload = mock_post.call_args.kwargs
        assert "json" in payload
        assert payload["json"]["query"] == "test query"
        # time_range=0 and site="" are expected to be left out of the request body.
        assert "time_range" not in payload["json"]
        assert "site" not in payload["json"]

    @patch("src.tools.infoquest_search.infoquest_search_api.requests.post")
    def test_raw_results_with_time_range_and_site(
        self, mock_post, wrapper, mock_response_data
    ):
        """Test search with time range and site filtering."""
        mock_response = Mock()
        mock_response.json.return_value = mock_response_data
        mock_response.raise_for_status.return_value = None
        mock_post.return_value = mock_response

        result = wrapper.raw_results("test query", time_range=30, site="example.com")

        assert result == mock_response_data["search_result"]
        params = mock_post.call_args.kwargs["json"]
        assert params["time_range"] == 30
        assert params["site"] == "example.com"

    @patch("src.tools.infoquest_search.infoquest_search_api.requests.post")
    def test_raw_results_http_error(self, mock_post, wrapper):
        """Test that an HTTP error raised by the response propagates to the caller."""
        mock_response = Mock()
        mock_response.raise_for_status.side_effect = requests.HTTPError("API Error")
        mock_post.return_value = mock_response

        with pytest.raises(requests.HTTPError):
            wrapper.raw_results("test query", time_range=0, site="")

    # pytest-asyncio is optional. Record whether it can be imported so the
    # async tests below can be skipped through a marker: a skip placed inside
    # the coroutine body is never reached when no async plugin awaits it.
    try:
        import pytest_asyncio  # noqa: F401

        _asyncio_available = True
    except ImportError:
        _asyncio_available = False

    @pytest.mark.skipif(
        not _asyncio_available, reason="pytest-asyncio is not installed"
    )
    @pytest.mark.asyncio
    async def test_raw_results_async_success(self, wrapper, mock_response_data):
        """Test the async entry point with ``raw_results_async`` replaced by a stub."""

        async def mock_raw_results_async(
            self, query, time_range=0, site="", output_format="json"
        ):
            return mock_response_data["search_result"]

        # patch.object restores the original attribute even if the test fails.
        with patch.object(
            InfoQuestAPIWrapper, "raw_results_async", new=mock_raw_results_async
        ):
            result = await wrapper.raw_results_async(
                "test query", time_range=0, site=""
            )

        assert result == mock_response_data["search_result"]

    @pytest.mark.skipif(
        not _asyncio_available, reason="pytest-asyncio is not installed"
    )
    @pytest.mark.asyncio
    async def test_raw_results_async_error(self, wrapper):
        """Test that an exception raised by the async stub reaches the caller."""

        async def mock_raw_results_async_error(
            self, query, time_range=0, site="", output_format="json"
        ):
            raise Exception("Error 400: Bad Request")

        with patch.object(
            InfoQuestAPIWrapper,
            "raw_results_async",
            new=mock_raw_results_async_error,
        ):
            with pytest.raises(Exception, match="Error 400: Bad Request"):
                await wrapper.raw_results_async("test query", time_range=0, site="")

    def test_clean_results_with_images(self, wrapper, mock_response_data):
        """Test result cleaning: one page, one news and one image entry, in that order."""
        raw_results = mock_response_data["search_result"]["results"]
        cleaned_results = wrapper.clean_results_with_images(raw_results)

        assert len(cleaned_results) == 3

        # Test page result
        page_result = cleaned_results[0]
        assert page_result["type"] == "page"
        assert page_result["title"] == "Test Title"
        assert page_result["url"] == "https://example.com"
        assert page_result["desc"] == "Test description"

        # Test news result
        news_result = cleaned_results[1]
        assert news_result["type"] == "news"
        assert news_result["time_frame"] == "2 days ago"
        assert news_result["title"] == "Test News"
        assert news_result["url"] == "https://example.com/news"
        assert news_result["source"] == "Test Source"

        # Test image result
        image_result = cleaned_results[2]
        assert image_result["type"] == "image_url"
        assert image_result["image_url"] == "https://example.com/image.jpg"
        assert image_result["image_description"] == "Test image description"

    def test_clean_results_empty_categories(self, wrapper):
        """Test result cleaning with empty categories."""
        data = [
            {
                "content": {
                    "results": {
                        "organic": [],
                        "top_stories": {"items": []},
                        "images": {"items": []},
                    }
                }
            }
        ]

        result = wrapper.clean_results_with_images(data)
        assert len(result) == 0

    def test_clean_results_url_deduplication(self, wrapper):
        """Test that two organic entries sharing a URL collapse to the first one."""
        data = [
            {
                "content": {
                    "results": {
                        "organic": [
                            {
                                "title": "Test Title 1",
                                "url": "https://example.com",
                                "desc": "Description 1",
                            },
                            {
                                "title": "Test Title 2",
                                "url": "https://example.com",
                                "desc": "Description 2",
                            },
                        ]
                    }
                }
            }
        ]

        result = wrapper.clean_results_with_images(data)
        assert len(result) == 1
        assert result[0]["title"] == "Test Title 1"
|
||||
@@ -1,226 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
|
||||
|
||||
class TestInfoQuestSearchResults:
    """Tests around the InfoQuest search tool, mostly driven through a stand-in ``Mock``."""

    # Dotted path of the API wrapper as referenced by the tool module.
    _WRAPPER_TARGET = (
        "src.tools.infoquest_search.infoquest_search_results.InfoQuestAPIWrapper"
    )

    @pytest.fixture
    def search_tool(self, sample_raw_results, sample_cleaned_results):
        """Create a mock InfoQuestSearchResults instance.

        Its ``_run`` / ``_arun`` return the cleaned results serialised as JSON
        together with the raw results, reusing the sample fixtures so the
        payloads are defined in one place only.
        """
        mock_tool = Mock()

        mock_tool.time_range = 30
        mock_tool.site = "example.com"

        def mock_run(query, **kwargs):
            return (
                json.dumps(sample_cleaned_results, ensure_ascii=False),
                sample_raw_results,
            )

        async def mock_arun(query, **kwargs):
            return mock_run(query, **kwargs)

        mock_tool._run = mock_run
        mock_tool._arun = mock_arun

        return mock_tool

    @pytest.fixture
    def sample_raw_results(self):
        """Sample raw results from InfoQuest API."""
        return {
            "results": [
                {
                    "content": {
                        "results": {
                            "organic": [
                                {
                                    "title": "Test Title",
                                    "url": "https://example.com",
                                    "desc": "Test description",
                                }
                            ]
                        }
                    }
                }
            ]
        }

    @pytest.fixture
    def sample_cleaned_results(self):
        """Sample cleaned results."""
        return [
            {
                "type": "page",
                "title": "Test Title",
                "url": "https://example.com",
                "desc": "Test description",
            }
        ]

    def test_init_default_values(self):
        """Test initialization with default values using patch."""
        with patch(self._WRAPPER_TARGET) as mock_wrapper_class:
            mock_wrapper_class.return_value = Mock()

            from src.tools.infoquest_search.infoquest_search_results import (
                InfoQuestSearchResults,
            )

            with patch.object(
                InfoQuestSearchResults, "__init__", return_value=None
            ) as mock_init:
                InfoQuestSearchResults(infoquest_api_key="dummy-key")

                mock_init.assert_called_once()

    def test_init_custom_values(self):
        """Test initialization with custom values using patch."""
        with patch(self._WRAPPER_TARGET) as mock_wrapper_class:
            mock_wrapper_class.return_value = Mock()

            from src.tools.infoquest_search.infoquest_search_results import (
                InfoQuestSearchResults,
            )

            with patch.object(
                InfoQuestSearchResults, "__init__", return_value=None
            ) as mock_init:
                InfoQuestSearchResults(
                    time_range=10,
                    site="test.com",
                    infoquest_api_key="dummy-key",
                )

                mock_init.assert_called_once()

    def test_run_success(
        self,
        search_tool,
        sample_raw_results,
        sample_cleaned_results,
    ):
        """Test successful synchronous run."""
        result, raw = search_tool._run("test query")

        assert isinstance(result, str)
        assert isinstance(raw, dict)
        assert "results" in raw
        assert raw == sample_raw_results

        result_data = json.loads(result)
        assert isinstance(result_data, list)
        assert len(result_data) > 0
        assert result_data == sample_cleaned_results

    def test_run_exception(self, search_tool):
        """Test synchronous run that reports an error payload."""

        def mock_run_with_error(query, **kwargs):
            return json.dumps({"error": "API Error"}, ensure_ascii=False), {}

        # The fixture is rebuilt for every test, so no restore step is needed.
        search_tool._run = mock_run_with_error
        result, raw = search_tool._run("test query")

        result_dict = json.loads(result)
        assert "error" in result_dict
        assert "API Error" in result_dict["error"]
        assert raw == {}

    @pytest.mark.asyncio
    async def test_arun_success(
        self,
        search_tool,
        sample_raw_results,
        sample_cleaned_results,
    ):
        """Test successful asynchronous run."""
        result, raw = await search_tool._arun("test query")

        assert isinstance(result, str)
        assert isinstance(raw, dict)
        assert "results" in raw
        assert raw == sample_raw_results
        assert json.loads(result) == sample_cleaned_results

    @pytest.mark.asyncio
    async def test_arun_exception(self, search_tool):
        """Test asynchronous run that reports an error payload."""

        async def mock_arun_with_error(query, **kwargs):
            return json.dumps({"error": "Async API Error"}, ensure_ascii=False), {}

        # The fixture is rebuilt for every test, so no restore step is needed.
        search_tool._arun = mock_arun_with_error
        result, raw = await search_tool._arun("test query")

        result_dict = json.loads(result)
        assert "error" in result_dict
        assert "Async API Error" in result_dict["error"]
        assert raw == {}

    def test_run_with_run_manager(
        self,
        search_tool,
        sample_raw_results,
        sample_cleaned_results,
    ):
        """Test run with callback manager."""
        mock_run_manager = Mock()
        result, raw = search_tool._run("test query", run_manager=mock_run_manager)

        assert isinstance(result, str)
        assert isinstance(raw, dict)
        assert raw == sample_raw_results
        assert json.loads(result) == sample_cleaned_results

    @pytest.mark.asyncio
    async def test_arun_with_run_manager(
        self,
        search_tool,
        sample_raw_results,
        sample_cleaned_results,
    ):
        """Test async run with callback manager."""
        mock_run_manager = Mock()
        result, raw = await search_tool._arun(
            "test query", run_manager=mock_run_manager
        )

        assert isinstance(result, str)
        assert isinstance(raw, dict)
        assert raw == sample_raw_results
        assert json.loads(result) == sample_cleaned_results

    def test_api_wrapper_initialization_with_key(self):
        """Test API wrapper initialization with key."""
        with patch(self._WRAPPER_TARGET) as mock_wrapper_class:
            mock_wrapper_class.return_value = Mock()

            from src.tools.infoquest_search.infoquest_search_results import (
                InfoQuestSearchResults,
            )

            with patch.object(
                InfoQuestSearchResults, "__init__", return_value=None
            ) as mock_init:
                InfoQuestSearchResults(infoquest_api_key="test-key")

                mock_init.assert_called_once()
|
||||
@@ -1,222 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.tools.python_repl import python_repl_tool
|
||||
|
||||
|
||||
class TestPythonReplTool:
    """Tests for ``python_repl_tool``: execution results, logging and the enable switch."""

    @patch.dict(os.environ, {"ENABLE_PYTHON_REPL": "true"})
    @patch("src.tools.python_repl.repl")
    @patch("src.tools.python_repl.logger")
    def test_successful_code_execution(self, mock_logger, mock_repl):
        """Test that REPL output is reported as a successful execution."""
        # Arrange
        code = "print('Hello, World!')"
        expected_output = "Hello, World!\n"
        mock_repl.run.return_value = expected_output

        # Act
        result = python_repl_tool.invoke({"code": code})

        # Assert
        mock_repl.run.assert_called_once_with(code)
        mock_logger.info.assert_called_with("Code execution successful")
        assert "Successfully executed:" in result
        assert code in result
        assert expected_output in result

    @patch.dict(os.environ, {"ENABLE_PYTHON_REPL": "true"})
    @patch("src.tools.python_repl.repl")
    @patch("src.tools.python_repl.logger")
    def test_invalid_input_type(self, mock_logger, mock_repl):
        """Test that a non-string ``code`` value is rejected before the REPL runs."""
        # Arrange
        invalid_code = 123

        # Act & Assert - expect ValidationError when passing invalid input
        with pytest.raises(Exception):  # Could be ValidationError or similar
            python_repl_tool.invoke({"code": invalid_code})

        mock_repl.run.assert_not_called()

    @patch.dict(os.environ, {"ENABLE_PYTHON_REPL": "true"})
    @patch("src.tools.python_repl.repl")
    @patch("src.tools.python_repl.logger")
    def test_code_execution_with_error_in_result(self, mock_logger, mock_repl):
        """Test that REPL output containing an error is reported as a failure."""
        # Arrange
        code = "invalid_function()"
        error_result = "NameError: name 'invalid_function' is not defined"
        mock_repl.run.return_value = error_result

        # Act
        result = python_repl_tool.invoke({"code": code})

        # Assert
        mock_repl.run.assert_called_once_with(code)
        mock_logger.error.assert_called_with(error_result)
        assert "Error executing code:" in result
        assert code in result
        assert error_result in result

    @patch.dict(os.environ, {"ENABLE_PYTHON_REPL": "true"})
    @patch("src.tools.python_repl.repl")
    @patch("src.tools.python_repl.logger")
    def test_code_execution_with_exception_in_result(self, mock_logger, mock_repl):
        """Test that REPL output containing an exception name is reported as a failure."""
        # Arrange
        code = "1/0"
        exception_result = "ZeroDivisionError: division by zero"
        mock_repl.run.return_value = exception_result

        # Act
        result = python_repl_tool.invoke({"code": code})

        # Assert
        mock_repl.run.assert_called_once_with(code)
        mock_logger.error.assert_called_with(exception_result)
        assert "Error executing code:" in result
        assert code in result
        assert exception_result in result

    @patch.dict(os.environ, {"ENABLE_PYTHON_REPL": "true"})
    @patch("src.tools.python_repl.repl")
    @patch("src.tools.python_repl.logger")
    def test_code_execution_raises_exception(self, mock_logger, mock_repl):
        """Test that an exception raised by the REPL is logged and returned as an error."""
        # Arrange
        code = "print('test')"
        exception = RuntimeError("REPL failed")
        mock_repl.run.side_effect = exception

        # Act
        result = python_repl_tool.invoke({"code": code})

        # Assert
        mock_repl.run.assert_called_once_with(code)
        mock_logger.error.assert_called_with(repr(exception))
        assert "Error executing code:" in result
        assert code in result
        assert repr(exception) in result

    @patch.dict(os.environ, {"ENABLE_PYTHON_REPL": "true"})
    @patch("src.tools.python_repl.repl")
    @patch("src.tools.python_repl.logger")
    def test_successful_execution_with_calculation(self, mock_logger, mock_repl):
        """Test a multi-line snippet and both informational log messages."""
        # Arrange
        code = "result = 2 + 3\nprint(result)"
        expected_output = "5\n"
        mock_repl.run.return_value = expected_output

        # Act
        result = python_repl_tool.invoke({"code": code})

        # Assert
        mock_repl.run.assert_called_once_with(code)
        mock_logger.info.assert_any_call("Executing Python code")
        mock_logger.info.assert_any_call("Code execution successful")
        assert "Successfully executed:" in result
        assert code in result
        assert expected_output in result

    @patch.dict(os.environ, {"ENABLE_PYTHON_REPL": "true"})
    @patch("src.tools.python_repl.repl")
    @patch("src.tools.python_repl.logger")
    def test_empty_string_code(self, mock_logger, mock_repl):
        """Test that an empty snippet is still passed to the REPL and reported as success."""
        # Arrange
        code = ""
        mock_repl.run.return_value = ""

        # Act
        result = python_repl_tool.invoke({"code": code})

        # Assert
        mock_repl.run.assert_called_once_with(code)
        mock_logger.info.assert_called_with("Code execution successful")
        assert "Successfully executed:" in result

    @patch.dict(os.environ, {"ENABLE_PYTHON_REPL": "true"})
    @patch("src.tools.python_repl.repl")
    @patch("src.tools.python_repl.logger")
    def test_logging_calls(self, mock_logger, mock_repl):
        """Test that both the start and the success messages are logged."""
        # Arrange
        code = "x = 1"
        mock_repl.run.return_value = ""

        # Act
        python_repl_tool.invoke({"code": code})

        # Assert
        mock_logger.info.assert_any_call("Executing Python code")
        mock_logger.info.assert_any_call("Code execution successful")

    # New tests for configuration behavior
    @patch.dict(os.environ, {"ENABLE_PYTHON_REPL": "false"})
    @patch("src.tools.python_repl.logger")
    def test_tool_disabled(self, mock_logger):
        """Test that the tool reports itself disabled when the switch is "false"."""
        # Arrange
        code = "print('test')"

        # Act
        result = python_repl_tool.invoke({"code": code})

        # Assert
        mock_logger.warning.assert_called_with(
            "Python REPL tool is disabled. Please enable it in environment configuration."
        )
        assert "Tool disabled:" in result
        assert "Python REPL tool is disabled" in result

    @patch.dict(os.environ, {}, clear=True)
    @patch("src.tools.python_repl.logger")
    def test_tool_disabled_by_default(self, mock_logger):
        """Test that the tool is disabled when ENABLE_PYTHON_REPL is not set."""
        # Arrange - patch.dict(clear=True) already empties os.environ for the
        # duration of the test, so the variable is guaranteed to be absent.
        code = "print('test')"

        # Act
        result = python_repl_tool.invoke({"code": code})

        # Assert
        mock_logger.warning.assert_called_with(
            "Python REPL tool is disabled. Please enable it in environment configuration."
        )
        assert "Tool disabled:" in result

    @pytest.mark.parametrize("env_value", ["true", "True", "TRUE", "1", "yes", "on"])
    @patch("src.tools.python_repl.repl")
    @patch("src.tools.python_repl.logger")
    def test_tool_enabled_with_various_truthy_values(
        self, mock_logger, mock_repl, env_value
    ):
        """Test that each listed truthy value enables the tool."""
        # Arrange
        with patch.dict(os.environ, {"ENABLE_PYTHON_REPL": env_value}):
            code = "print('enabled')"
            expected_output = "enabled\n"
            mock_repl.run.return_value = expected_output

            # Act
            result = python_repl_tool.invoke({"code": code})

            # Assert
            mock_repl.run.assert_called_once_with(code)
            assert "Successfully executed:" in result

    @pytest.mark.parametrize(
        "env_value", ["false", "False", "FALSE", "0", "no", "off", ""]
    )
    @patch("src.tools.python_repl.logger")
    def test_tool_disabled_with_various_falsy_values(self, mock_logger, env_value):
        """Test that each listed falsy value (including empty) disables the tool."""
        # Arrange
        with patch.dict(os.environ, {"ENABLE_PYTHON_REPL": env_value}):
            code = "print('disabled')"

            # Act
            result = python_repl_tool.invoke({"code": code})

            # Assert
            mock_logger.warning.assert_called_with(
                "Python REPL tool is disabled. Please enable it in environment configuration."
            )
            assert "Tool disabled:" in result
|
||||
@@ -1,291 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from src.config import SearchEngine
|
||||
from src.tools.search import get_web_search_tool
|
||||
|
||||
|
||||
class TestGetWebSearchTool:
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
def test_get_web_search_tool_tavily(self):
    """Test the Tavily tool's name, result count and option flags."""
    tool = get_web_search_tool(max_search_results=5)
    assert tool.name == "web_search"
    assert tool.max_results == 5
    assert tool.include_raw_content is False
    assert tool.include_images is True
    assert tool.include_image_descriptions is True
    assert tool.include_answer is False
    assert tool.search_depth == "advanced"
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.DUCKDUCKGO.value)
def test_get_web_search_tool_duckduckgo(self):
    """Test that the DuckDuckGo tool carries the requested result count."""
    tool = get_web_search_tool(max_search_results=3)
    assert tool.name == "web_search"
    assert tool.max_results == 3
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.BRAVE_SEARCH.value)
@patch.dict(os.environ, {"BRAVE_SEARCH_API_KEY": "test_api_key"})
def test_get_web_search_tool_brave(self):
    """Test that the Brave tool picks up its API key from the environment."""
    tool = get_web_search_tool(max_search_results=4)
    assert tool.name == "web_search"
    assert tool.search_wrapper.api_key.get_secret_value() == "test_api_key"
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.ARXIV.value)
def test_get_web_search_tool_arxiv(self):
    """Test that the arXiv tool applies the result count to both wrapper limits."""
    tool = get_web_search_tool(max_search_results=2)
    assert tool.name == "web_search"
    assert tool.api_wrapper.top_k_results == 2
    assert tool.api_wrapper.load_max_docs == 2
    assert tool.api_wrapper.load_all_available_meta is True
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", "unsupported_engine")
def test_get_web_search_tool_unsupported_engine(self):
    """Test that an unknown engine name raises ValueError naming the engine."""
    with pytest.raises(
        ValueError, match="Unsupported search engine: unsupported_engine"
    ):
        get_web_search_tool(max_search_results=1)
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.BRAVE_SEARCH.value)
@patch.dict(os.environ, {}, clear=True)
def test_get_web_search_tool_brave_no_api_key(self):
    """Test that the Brave tool falls back to an empty API key when none is set."""
    tool = get_web_search_tool(max_search_results=1)
    assert tool.search_wrapper.api_key.get_secret_value() == ""
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.SERPER.value)
@patch.dict(os.environ, {"SERPER_API_KEY": "test_serper_key"})
def test_get_web_search_tool_serper(self):
    """Test that the Serper tool carries the result count and the environment API key."""
    tool = get_web_search_tool(max_search_results=6)
    assert tool.name == "web_search"
    assert tool.api_wrapper.k == 6
    assert tool.api_wrapper.serper_api_key == "test_serper_key"
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.SERPER.value)
@patch.dict(os.environ, {}, clear=True)
def test_get_web_search_tool_serper_no_api_key(self):
    """Test that building the Serper tool without an API key raises ValidationError."""
    with pytest.raises(ValidationError):
        get_web_search_tool(max_search_results=1)
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_get_web_search_tool_tavily_with_custom_config(self, mock_config):
    """Test Tavily tool with custom configuration values."""
    mock_config.return_value = {
        "SEARCH_ENGINE": {
            "include_answer": True,
            "search_depth": "basic",
            "include_raw_content": True,
            "include_images": False,
            "include_image_descriptions": True,
            "include_domains": ["example.com"],
            "exclude_domains": ["spam.com"],
        }
    }
    tool = get_web_search_tool(max_search_results=5)
    assert tool.name == "web_search"
    assert tool.max_results == 5
    assert tool.include_answer is True
    assert tool.search_depth == "basic"
    assert tool.include_raw_content is True
    assert tool.include_images is False
    # include_image_descriptions should be False because include_images is False
    assert tool.include_image_descriptions is False
    assert tool.include_domains == ["example.com"]
    assert tool.exclude_domains == ["spam.com"]
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_get_web_search_tool_tavily_with_empty_config(self, mock_config):
    """Test Tavily tool uses defaults when config is empty."""
    mock_config.return_value = {"SEARCH_ENGINE": {}}
    tool = get_web_search_tool(max_search_results=10)
    assert tool.name == "web_search"
    assert tool.max_results == 10
    assert tool.include_answer is False
    assert tool.search_depth == "advanced"
    assert tool.include_raw_content is False
    assert tool.include_images is True
    assert tool.include_image_descriptions is True
    assert tool.include_domains == []
    assert tool.exclude_domains == []
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_get_web_search_tool_tavily_image_descriptions_disabled_when_images_disabled(
    self, mock_config
):
    """Test that include_image_descriptions is False when include_images is False."""
    mock_config.return_value = {
        "SEARCH_ENGINE": {
            "include_images": False,
            "include_image_descriptions": True,  # This should be ignored
        }
    }
    tool = get_web_search_tool(max_search_results=5)
    assert tool.include_images is False
    assert tool.include_image_descriptions is False
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_get_web_search_tool_tavily_partial_config(self, mock_config):
    """Test Tavily tool with partial configuration."""
    mock_config.return_value = {
        "SEARCH_ENGINE": {
            "include_answer": True,
            "include_domains": ["trusted.com"],
        }
    }
    tool = get_web_search_tool(max_search_results=3)
    assert tool.include_answer is True
    assert tool.search_depth == "advanced"  # default
    assert tool.include_raw_content is False  # default
    assert tool.include_domains == ["trusted.com"]
    assert tool.exclude_domains == []  # default
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_get_web_search_tool_tavily_with_no_config_file(self, mock_config):
    """Test Tavily tool when config file doesn't exist."""
    # The missing file is simulated by having the loader return an empty mapping.
    mock_config.return_value = {}
    tool = get_web_search_tool(max_search_results=5)
    assert tool.name == "web_search"
    assert tool.max_results == 5
    assert tool.include_answer is False
    assert tool.search_depth == "advanced"
    assert tool.include_raw_content is False
    assert tool.include_images is True
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_get_web_search_tool_tavily_multiple_domains(self, mock_config):
    """Check that multi-entry include/exclude domain lists reach the tool."""
    configured_includes = ["example.com", "trusted.com", "gov.cn"]
    configured_excludes = ["spam.com", "scam.org"]
    mock_config.return_value = {
        "SEARCH_ENGINE": {
            "include_domains": configured_includes,
            "exclude_domains": configured_excludes,
        }
    }

    web_search = get_web_search_tool(max_search_results=5)

    # Compare against fresh literals rather than the objects handed to the mock.
    assert web_search.include_domains == ["example.com", "trusted.com", "gov.cn"]
    assert web_search.exclude_domains == ["spam.com", "scam.org"]
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_tavily_with_no_search_engine_section(self, mock_config):
    """Check the Tavily tool when the config has no SEARCH_ENGINE key at all."""
    config_without_search_engine = {"OTHER_CONFIG": {}}
    mock_config.return_value = config_without_search_engine

    web_search = get_web_search_tool(max_search_results=5)

    assert web_search.name == "web_search"
    assert web_search.max_results == 5
    # With no SEARCH_ENGINE section, every option should be at its default.
    assert web_search.include_answer is False
    assert web_search.search_depth == "advanced"
    assert web_search.include_raw_content is False
    assert web_search.include_images is True
    assert web_search.include_domains == []
    assert web_search.exclude_domains == []
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_tavily_with_completely_empty_config(self, mock_config):
    """Check the Tavily tool when the loaded config is an empty dict."""
    mock_config.return_value = {}

    built_tool = get_web_search_tool(max_search_results=5)

    assert built_tool.name == "web_search"
    assert built_tool.max_results == 5
    # Nothing was configured, so these are the expected defaults.
    assert built_tool.include_answer is False
    assert built_tool.search_depth == "advanced"
    assert built_tool.include_raw_content is False
    assert built_tool.include_images is True
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_tavily_with_only_include_answer_param(self, mock_config):
    """Check a config whose SEARCH_ENGINE section sets only ``include_answer``."""
    search_engine_section = {"include_answer": True}
    mock_config.return_value = {"SEARCH_ENGINE": search_engine_section}

    web_search = get_web_search_tool(max_search_results=5)

    # The single configured option.
    assert web_search.include_answer is True
    # The remaining options are expected to stay at their defaults.
    assert web_search.search_depth == "advanced"
    assert web_search.include_raw_content is False
    assert web_search.include_images is True
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_tavily_with_only_search_depth_param(self, mock_config):
    """Check a config whose SEARCH_ENGINE section sets only ``search_depth``."""
    search_engine_section = {"search_depth": "basic"}
    mock_config.return_value = {"SEARCH_ENGINE": search_engine_section}

    web_search = get_web_search_tool(max_search_results=5)

    # The single configured option overrides the "advanced" default.
    assert web_search.search_depth == "basic"
    # The remaining options are expected to stay at their defaults.
    assert web_search.include_answer is False
    assert web_search.include_raw_content is False
    assert web_search.include_images is True
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_tavily_with_only_include_domains_param(self, mock_config):
    """Check a config whose SEARCH_ENGINE section sets only ``include_domains``."""
    search_engine_section = {"include_domains": ["example.com"]}
    mock_config.return_value = {"SEARCH_ENGINE": search_engine_section}

    web_search = get_web_search_tool(max_search_results=5)

    # The single configured option.
    assert web_search.include_domains == ["example.com"]
    # The remaining options are expected to stay at their defaults.
    assert web_search.exclude_domains == []
    assert web_search.include_answer is False
    assert web_search.search_depth == "advanced"
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_tavily_with_explicit_false_boolean_values(self, mock_config):
    """Check that booleans explicitly set to False are honoured.

    A False value must not be confused with an absent key; note that
    ``include_images`` is asserted to be True when it is left unconfigured
    elsewhere in this suite, so False here proves the value was read.
    """
    disabled_flags = dict.fromkeys(
        ("include_answer", "include_raw_content", "include_images"), False
    )
    mock_config.return_value = {"SEARCH_ENGINE": disabled_flags}

    web_search = get_web_search_tool(max_search_results=5)

    assert web_search.include_answer is False
    assert web_search.include_raw_content is False
    assert web_search.include_images is False
    # Not configured here; expected to be False alongside include_images.
    assert web_search.include_image_descriptions is False
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_tavily_with_empty_domain_lists(self, mock_config):
    """Check that explicitly empty include/exclude domain lists are accepted."""
    search_engine_section = {
        "include_domains": [],
        "exclude_domains": [],
    }
    mock_config.return_value = {"SEARCH_ENGINE": search_engine_section}

    web_search = get_web_search_tool(max_search_results=5)

    assert web_search.include_domains == []
    assert web_search.exclude_domains == []
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
@patch("src.tools.search.load_yaml_config")
def test_tavily_all_parameters_optional_mix(self, mock_config):
    """Check a mix of configured and omitted optional Tavily parameters.

    ``search_depth``, ``include_raw_content`` and both domain lists are left
    out on purpose so that their defaults can be asserted.
    """
    search_engine_section = {
        "include_answer": True,
        "include_images": False,
    }
    mock_config.return_value = {"SEARCH_ENGINE": search_engine_section}

    web_search = get_web_search_tool(max_search_results=5)

    # Configured values.
    assert web_search.include_answer is True
    assert web_search.include_images is False
    # Expected to be False because include_images is False.
    assert web_search.include_image_descriptions is False
    # Omitted values fall back to defaults.
    assert web_search.search_depth == "advanced"
    assert web_search.include_raw_content is False
    assert web_search.include_domains == []
    assert web_search.exclude_domains == []
|
||||
@@ -1,263 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from src.tools.search_postprocessor import SearchResultPostProcessor
|
||||
|
||||
|
||||
class TestSearchResultPostProcessor:
    """Tests for ``SearchResultPostProcessor.process_results``."""

    @pytest.fixture
    def post_processor(self):
        """Provide a processor with a 0.5 score threshold and a 100-char cap."""
        processor = SearchResultPostProcessor(
            min_score_threshold=0.5, max_content_length_per_page=100
        )
        return processor

    def test_process_results_empty_input(self, post_processor):
        """An empty result list is returned as an empty list."""
        outcome = post_processor.process_results([])
        assert outcome == []

    def test_process_results_with_valid_page_results(self, post_processor):
        """A single page above the score threshold keeps its fields."""
        page = {
            "type": "page",
            "title": "Test Page",
            "url": "https://example.com",
            "content": "Test content",
            "score": 0.8,
        }

        outcome = post_processor.process_results([page])

        assert len(outcome) == 1
        kept = outcome[0]
        assert kept["title"] == "Test Page"
        assert kept["url"] == "https://example.com"
        assert kept["content"] == "Test content"
        assert kept["score"] == 0.8

    def test_process_results_filter_low_score(self, post_processor):
        """A page scoring below the fixture's 0.5 threshold is dropped."""
        below_threshold = {
            "type": "page",
            "title": "Low Score Page",
            "url": "https://example.com/low",
            "content": "Low score content",
            "score": 0.3,
        }
        above_threshold = {
            "type": "page",
            "title": "High Score Page",
            "url": "https://example.com/high",
            "content": "High score content",
            "score": 0.9,
        }

        outcome = post_processor.process_results([below_threshold, above_threshold])

        assert len(outcome) == 1
        assert outcome[0]["title"] == "High Score Page"

    def test_process_results_remove_duplicates(self, post_processor):
        """Of two pages sharing one URL, only the first is kept."""
        first_seen = {
            "type": "page",
            "title": "Page 1",
            "url": "https://example.com",
            "content": "Content 1",
            "score": 0.8,
        }
        same_url_again = {
            "type": "page",
            "title": "Page 2",
            "url": "https://example.com",
            "content": "Content 2",
            "score": 0.7,
        }

        outcome = post_processor.process_results([first_seen, same_url_again])

        assert len(outcome) == 1
        assert outcome[0]["title"] == "Page 1"

    def test_process_results_sort_by_score(self, post_processor):
        """Surviving pages are ordered from highest to lowest score."""
        unsorted_pages = [
            {
                "type": "page",
                "title": "Low Score",
                "url": "https://example.com/low",
                "content": "Low score content",
                "score": 0.3,
            },
            {
                "type": "page",
                "title": "High Score",
                "url": "https://example.com/high",
                "content": "High score content",
                "score": 0.9,
            },
            {
                "type": "page",
                "title": "Medium Score",
                "url": "https://example.com/medium",
                "content": "Medium score content",
                "score": 0.6,
            },
        ]

        outcome = post_processor.process_results(unsorted_pages)

        # The 0.3 entry is below the threshold, leaving two results.
        assert len(outcome) == 2
        assert outcome[0]["title"] == "High Score"
        assert outcome[1]["title"] == "Medium Score"

    def test_process_results_truncate_long_content(self, post_processor):
        """Content longer than the 100-char cap is cut and suffixed with '...'."""
        oversized = "A" * 150
        page = {
            "type": "page",
            "title": "Long Content Page",
            "url": "https://example.com",
            "content": oversized,
            "score": 0.8,
        }

        outcome = post_processor.process_results([page])

        assert len(outcome) == 1
        shortened = outcome[0]["content"]
        # 100 kept characters plus the three-character ellipsis.
        assert len(shortened) == 103
        assert shortened.endswith("...")

    def test_process_results_remove_base64_images(self, post_processor):
        """An inline base64 data URI is stripped out of page content."""
        data_uri = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
        page = {
            "type": "page",
            "title": "Page with Base64",
            "url": "https://example.com",
            "content": "Content with image " + data_uri,
            "score": 0.8,
        }

        outcome = post_processor.process_results([page])

        assert len(outcome) == 1
        assert outcome[0]["content"] == "Content with image "

    def test_process_results_with_image_type(self, post_processor):
        """An image result with a regular URL keeps its fields."""
        image = {
            "type": "image",
            "image_url": "https://example.com/image.jpg",
            "image_description": "Test image",
        }

        outcome = post_processor.process_results([image])

        assert len(outcome) == 1
        kept = outcome[0]
        assert kept["type"] == "image"
        assert kept["image_url"] == "https://example.com/image.jpg"
        assert kept["image_description"] == "Test image"

    def test_process_results_filter_base64_image_urls(self, post_processor):
        """An image result whose URL is a base64 data URI is dropped."""
        embedded_image = {
            "type": "image",
            "image_url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==",
            "image_description": "Base64 image",
        }
        linked_image = {
            "type": "image",
            "image_url": "https://example.com/image.jpg",
            "image_description": "Regular image",
        }

        outcome = post_processor.process_results([embedded_image, linked_image])

        assert len(outcome) == 1
        assert outcome[0]["image_url"] == "https://example.com/image.jpg"

    def test_process_results_truncate_long_image_description(self, post_processor):
        """An image description over the 100-char cap is cut and suffixed."""
        oversized = "A" * 150
        image = {
            "type": "image",
            "image_url": "https://example.com/image.jpg",
            "image_description": oversized,
        }

        outcome = post_processor.process_results([image])

        assert len(outcome) == 1
        shortened = outcome[0]["image_description"]
        # 100 kept characters plus the three-character ellipsis.
        assert len(shortened) == 103
        assert shortened.endswith("...")

    def test_process_results_other_types_passthrough(self, post_processor):
        """A result of a type other than page/image is still returned."""
        video = {
            "type": "video",
            "title": "Test Video",
            "url": "https://example.com/video.mp4",
            "score": 0.8,
        }

        outcome = post_processor.process_results([video])

        assert len(outcome) == 1
        assert outcome[0]["type"] == "video"
        assert outcome[0]["title"] == "Test Video"

    def test_process_results_truncate_long_content_with_no_config(self):
        """With both settings None, 150-char content keeps its length."""
        unconfigured = SearchResultPostProcessor(None, None)
        oversized = "A" * 150
        page = {
            "type": "page",
            "title": "Long Content Page",
            "url": "https://example.com",
            "content": oversized,
            "score": 0.8,
        }

        outcome = unconfigured.process_results([page])

        assert len(outcome) == 1
        # No length cap was configured, so nothing should be cut.
        assert len(outcome[0]["content"]) == 150

    def test_process_results_truncate_long_content_with_max_content_length_config(self):
        """With only the 100-char cap set, long content is still truncated."""
        length_only = SearchResultPostProcessor(None, 100)
        oversized = "A" * 150
        page = {
            "type": "page",
            "title": "Long Content Page",
            "url": "https://example.com",
            "content": oversized,
            "score": 0.8,
        }

        outcome = length_only.process_results([page])

        assert len(outcome) == 1
        shortened = outcome[0]["content"]
        assert len(shortened) == 103
        assert shortened.endswith("...")

    def test_process_results_truncate_long_content_with_min_score_config(self):
        """With only a 0.8 score threshold set, a 0.3-score page is dropped."""
        score_only = SearchResultPostProcessor(0.8, None)
        oversized = "A" * 150
        page = {
            "type": "page",
            "title": "Long Content Page",
            "url": "https://example.com",
            "content": oversized,
            "score": 0.3,
        }

        outcome = score_only.process_results([page])

        assert len(outcome) == 0
|
||||
@@ -1,207 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from src.tools.tavily_search.tavily_search_api_wrapper import (
|
||||
EnhancedTavilySearchAPIWrapper,
|
||||
)
|
||||
|
||||
|
||||
class TestEnhancedTavilySearchAPIWrapper:
    """Tests for ``EnhancedTavilySearchAPIWrapper`` with HTTP access mocked out."""

    @pytest.fixture
    def wrapper(self):
        """Provide a wrapper built with a dummy API key.

        ``OriginalTavilySearchAPIWrapper`` is patched while the instance is
        constructed -- presumably so construction does not depend on the real
        parent implementation.
        """
        with patch(
            "src.tools.tavily_search.tavily_search_api_wrapper.OriginalTavilySearchAPIWrapper"
        ):
            return EnhancedTavilySearchAPIWrapper(tavily_api_key="dummy-key")

    @pytest.fixture
    def mock_response_data(self):
        """Provide a response payload with one page result and one image."""
        page_entry = {
            "title": "Test Title",
            "url": "https://example.com",
            "content": "Test content",
            "score": 0.9,
            "raw_content": "Raw test content",
        }
        image_entry = {
            "url": "https://example.com/image.jpg",
            "description": "Test image description",
        }
        return {"results": [page_entry], "images": [image_entry]}

    @patch("src.tools.tavily_search.tavily_search_api_wrapper.requests.post")
    def test_raw_results_success(self, mock_post, wrapper, mock_response_data):
        """``raw_results`` returns the decoded JSON and posts query/max_results."""
        fake_response = Mock()
        fake_response.json.return_value = mock_response_data
        fake_response.raise_for_status.return_value = None
        mock_post.return_value = fake_response

        outcome = wrapper.raw_results("test query", max_results=10)

        assert outcome == mock_response_data
        mock_post.assert_called_once()
        sent_kwargs = mock_post.call_args.kwargs
        assert "json" in sent_kwargs
        payload = sent_kwargs["json"]
        assert payload["query"] == "test query"
        assert payload["max_results"] == 10

    @patch("src.tools.tavily_search.tavily_search_api_wrapper.requests.post")
    def test_raw_results_with_all_parameters(
        self, mock_post, wrapper, mock_response_data
    ):
        """Optional arguments of ``raw_results`` are forwarded in the JSON body."""
        fake_response = Mock()
        fake_response.json.return_value = mock_response_data
        fake_response.raise_for_status.return_value = None
        mock_post.return_value = fake_response

        outcome = wrapper.raw_results(
            "test query",
            max_results=3,
            search_depth="basic",
            include_domains=["example.com"],
            exclude_domains=["spam.com"],
            include_answer=True,
            include_raw_content=True,
            include_images=True,
            include_image_descriptions=True,
        )

        assert outcome == mock_response_data
        payload = mock_post.call_args.kwargs["json"]
        assert payload["include_domains"] == ["example.com"]
        assert payload["exclude_domains"] == ["spam.com"]
        assert payload["include_answer"] is True
        assert payload["include_raw_content"] is True

    @patch("src.tools.tavily_search.tavily_search_api_wrapper.requests.post")
    def test_raw_results_http_error(self, mock_post, wrapper):
        """An ``HTTPError`` from ``raise_for_status`` propagates to the caller."""
        failing_response = Mock()
        failing_response.raise_for_status.side_effect = requests.HTTPError("API Error")
        mock_post.return_value = failing_response

        with pytest.raises(requests.HTTPError):
            wrapper.raw_results("test query")

    @pytest.mark.asyncio
    async def test_raw_results_async_success(self, wrapper, mock_response_data):
        """``raw_results_async`` returns the parsed body of a 200 response."""
        # One mock plays both the response and the context manager yielding it.
        response_cm = AsyncMock()
        response_cm.__aenter__ = AsyncMock(return_value=response_cm)
        response_cm.__aexit__ = AsyncMock(return_value=None)
        response_cm.status = 200
        response_cm.text = AsyncMock(return_value=json.dumps(mock_response_data))

        # ``post`` is a plain MagicMock so that calling it returns the context
        # manager directly instead of a coroutine.
        session = AsyncMock()
        session.post = MagicMock(return_value=response_cm)

        # Stand-in for the object returned by ``aiohttp.ClientSession()``.
        session_cm = AsyncMock()
        session_cm.__aenter__ = AsyncMock(return_value=session)
        session_cm.__aexit__ = AsyncMock(return_value=None)

        client_session_target = (
            "src.tools.tavily_search.tavily_search_api_wrapper.aiohttp.ClientSession"
        )
        with patch(client_session_target, return_value=session_cm):
            outcome = await wrapper.raw_results_async("test query")

        assert outcome == mock_response_data

    @pytest.mark.asyncio
    async def test_raw_results_async_error(self, wrapper):
        """A 400 response makes ``raw_results_async`` raise with status and reason."""
        # One mock plays both the response and the context manager yielding it.
        response_cm = AsyncMock()
        response_cm.__aenter__ = AsyncMock(return_value=response_cm)
        response_cm.__aexit__ = AsyncMock(return_value=None)
        response_cm.status = 400
        response_cm.reason = "Bad Request"

        # ``post`` is a plain MagicMock so that calling it returns the context
        # manager directly instead of a coroutine.
        session = AsyncMock()
        session.post = MagicMock(return_value=response_cm)

        # Stand-in for the object returned by ``aiohttp.ClientSession()``.
        session_cm = AsyncMock()
        session_cm.__aenter__ = AsyncMock(return_value=session)
        session_cm.__aexit__ = AsyncMock(return_value=None)

        session_patch = patch(
            "src.tools.tavily_search.tavily_search_api_wrapper.aiohttp.ClientSession",
            return_value=session_cm,
        )
        expected_failure = pytest.raises(Exception, match="Error 400: Bad Request")
        with session_patch, expected_failure:
            await wrapper.raw_results_async("test query")

    def test_clean_results_with_images(self, wrapper, mock_response_data):
        """Cleaning yields the page entry followed by the image entry."""
        cleaned = wrapper.clean_results_with_images(mock_response_data)

        assert len(cleaned) == 2

        # First entry: the page result, including its raw content.
        page = cleaned[0]
        assert page["type"] == "page"
        assert page["title"] == "Test Title"
        assert page["url"] == "https://example.com"
        assert page["content"] == "Test content"
        assert page["score"] == 0.9
        assert page["raw_content"] == "Raw test content"

        # Second entry: the image, with its URL wrapped in a dict.
        image = cleaned[1]
        assert image["type"] == "image_url"
        assert image["image_url"] == {"url": "https://example.com/image.jpg"}
        assert image["image_description"] == "Test image description"

    def test_clean_results_without_raw_content(self, wrapper):
        """A page lacking ``raw_content`` does not gain that key when cleaned."""
        page_entry = {
            "title": "Test Title",
            "url": "https://example.com",
            "content": "Test content",
            "score": 0.9,
        }
        payload = {"results": [page_entry], "images": []}

        cleaned = wrapper.clean_results_with_images(payload)

        assert len(cleaned) == 1
        assert "raw_content" not in cleaned[0]

    def test_clean_results_empty_images(self, wrapper):
        """With an empty image list, only the page entry is returned."""
        page_entry = {
            "title": "Test Title",
            "url": "https://example.com",
            "content": "Test content",
            "score": 0.9,
        }
        payload = {"results": [page_entry], "images": []}

        cleaned = wrapper.clean_results_with_images(payload)

        assert len(cleaned) == 1
        assert cleaned[0]["type"] == "page"
|
||||
@@ -1,206 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.tools.tavily_search.tavily_search_api_wrapper import (
|
||||
EnhancedTavilySearchAPIWrapper,
|
||||
)
|
||||
from src.tools.tavily_search.tavily_search_results_with_images import (
|
||||
TavilySearchWithImages,
|
||||
)
|
||||
|
||||
|
||||
class TestTavilySearchWithImages:
    """Tests for the ``TavilySearchWithImages`` tool using a mocked API wrapper."""

    @pytest.fixture
    def mock_api_wrapper(self):
        """Provide a ``Mock`` constrained to the ``EnhancedTavilySearchAPIWrapper`` spec."""
        return Mock(spec=EnhancedTavilySearchAPIWrapper)

    @pytest.fixture
    def search_tool(self, mock_api_wrapper):
        """Provide a tool with all optional outputs enabled and a mocked wrapper."""
        configured_tool = TavilySearchWithImages(
            max_results=5,
            include_answer=True,
            include_raw_content=True,
            include_images=True,
            include_image_descriptions=True,
        )
        # Swap in the mock so the tests control what the wrapper returns.
        configured_tool.api_wrapper = mock_api_wrapper
        return configured_tool

    @pytest.fixture
    def sample_raw_results(self):
        """Provide a raw payload of the shape the wrapper is mocked to return."""
        page_entry = {
            "title": "Test Title",
            "url": "https://example.com",
            "content": "Test content",
            "score": 0.95,
            "raw_content": "Raw test content",
        }
        return {
            "query": "test query",
            "answer": "Test answer",
            "images": ["https://example.com/image1.jpg"],
            "results": [page_entry],
            "response_time": 1.5,
        }

    @pytest.fixture
    def sample_cleaned_results(self):
        """Provide the cleaned result list the wrapper is mocked to return."""
        cleaned_entry = {
            "title": "Test Title",
            "url": "https://example.com",
            "content": "Test content",
        }
        return [cleaned_entry]

    def test_init_default_values(self):
        """A tool built without arguments has descriptions off and a real wrapper."""
        default_tool = TavilySearchWithImages()
        assert default_tool.include_image_descriptions is False
        assert isinstance(default_tool.api_wrapper, EnhancedTavilySearchAPIWrapper)

    def test_init_custom_values(self):
        """Constructor arguments are stored on the tool."""
        custom_tool = TavilySearchWithImages(
            max_results=10, include_image_descriptions=True
        )
        assert custom_tool.max_results == 10
        assert custom_tool.include_image_descriptions is True

    def test_run_success(
        self,
        search_tool,
        mock_api_wrapper,
        sample_raw_results,
        sample_cleaned_results,
    ):
        """``_run`` returns the cleaned results as JSON together with the raw payload."""
        mock_api_wrapper.raw_results.return_value = sample_raw_results
        mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results

        serialized, raw_payload = search_tool._run("test query")

        expected_json = json.dumps(sample_cleaned_results, ensure_ascii=False)
        assert serialized == expected_json
        assert raw_payload == sample_raw_results

        # The tool's own settings are passed to the wrapper positionally.
        mock_api_wrapper.raw_results.assert_called_once_with(
            "test query",
            search_tool.max_results,
            search_tool.search_depth,
            search_tool.include_domains,
            search_tool.exclude_domains,
            search_tool.include_answer,
            search_tool.include_raw_content,
            search_tool.include_images,
            search_tool.include_image_descriptions,
        )
        mock_api_wrapper.clean_results_with_images.assert_called_once_with(
            sample_raw_results
        )

    def test_run_exception(self, search_tool, mock_api_wrapper):
        """A wrapper failure in ``_run`` yields an error JSON and an empty raw dict."""
        mock_api_wrapper.raw_results.side_effect = Exception("API Error")

        serialized, raw_payload = search_tool._run("test query")

        decoded = json.loads(serialized)
        assert "error" in decoded
        assert "API Error" in decoded["error"]
        assert raw_payload == {}
        # Cleaning must be skipped when fetching failed.
        mock_api_wrapper.clean_results_with_images.assert_not_called()

    @pytest.mark.asyncio
    async def test_arun_success(
        self,
        search_tool,
        mock_api_wrapper,
        sample_raw_results,
        sample_cleaned_results,
    ):
        """``_arun`` returns the cleaned results as JSON together with the raw payload."""
        mock_api_wrapper.raw_results_async = AsyncMock(return_value=sample_raw_results)
        mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results

        serialized, raw_payload = await search_tool._arun("test query")

        expected_json = json.dumps(sample_cleaned_results, ensure_ascii=False)
        assert serialized == expected_json
        assert raw_payload == sample_raw_results

        # The tool's own settings are passed to the wrapper positionally.
        mock_api_wrapper.raw_results_async.assert_called_once_with(
            "test query",
            search_tool.max_results,
            search_tool.search_depth,
            search_tool.include_domains,
            search_tool.exclude_domains,
            search_tool.include_answer,
            search_tool.include_raw_content,
            search_tool.include_images,
            search_tool.include_image_descriptions,
        )
        mock_api_wrapper.clean_results_with_images.assert_called_once_with(
            sample_raw_results
        )

    @pytest.mark.asyncio
    async def test_arun_exception(self, search_tool, mock_api_wrapper):
        """A wrapper failure in ``_arun`` yields an error JSON and an empty raw dict."""
        failing_call = AsyncMock(side_effect=Exception("Async API Error"))
        mock_api_wrapper.raw_results_async = failing_call

        serialized, raw_payload = await search_tool._arun("test query")

        decoded = json.loads(serialized)
        assert "error" in decoded
        assert "Async API Error" in decoded["error"]
        assert raw_payload == {}
        # Cleaning must be skipped when fetching failed.
        mock_api_wrapper.clean_results_with_images.assert_not_called()

    def test_run_with_run_manager(
        self,
        search_tool,
        mock_api_wrapper,
        sample_raw_results,
        sample_cleaned_results,
    ):
        """``_run`` gives the same output when a ``run_manager`` is supplied."""
        callback_manager = Mock()
        mock_api_wrapper.raw_results.return_value = sample_raw_results
        mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results

        serialized, raw_payload = search_tool._run(
            "test query", run_manager=callback_manager
        )

        assert serialized == json.dumps(sample_cleaned_results, ensure_ascii=False)
        assert raw_payload == sample_raw_results

    @pytest.mark.asyncio
    async def test_arun_with_run_manager(
        self,
        search_tool,
        mock_api_wrapper,
        sample_raw_results,
        sample_cleaned_results,
    ):
        """``_arun`` gives the same output when a ``run_manager`` is supplied."""
        callback_manager = Mock()
        mock_api_wrapper.raw_results_async = AsyncMock(return_value=sample_raw_results)
        mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results

        serialized, raw_payload = await search_tool._arun(
            "test query", run_manager=callback_manager
        )

        assert serialized == json.dumps(sample_cleaned_results, ensure_ascii=False)
        assert raw_payload == sample_raw_results
|
||||
@@ -1,126 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
|
||||
from src.rag import Chunk, Document, Resource, Retriever
|
||||
from src.tools.retriever import RetrieverInput, RetrieverTool, get_retriever_tool
|
||||
|
||||
|
||||
def test_retriever_input_model():
    """``RetrieverInput`` stores the ``keywords`` value it was built with."""
    payload = RetrieverInput(keywords="test keywords")
    assert payload.keywords == "test keywords"
|
||||
|
||||
|
||||
def test_retriever_tool_init():
    """A freshly built ``RetrieverTool`` exposes its metadata and dependencies."""
    fake_retriever = Mock(spec=Retriever)
    known_resources = [Resource(uri="test://uri", title="Test")]

    retriever_tool = RetrieverTool(retriever=fake_retriever, resources=known_resources)

    # Static tool metadata.
    assert retriever_tool.name == "local_search_tool"
    assert "retrieving information" in retriever_tool.description
    assert retriever_tool.args_schema == RetrieverInput
    # Dependencies handed to the constructor.
    assert retriever_tool.retriever == fake_retriever
    assert retriever_tool.resources == known_resources
|
||||
|
||||
|
||||
def test_retriever_tool_run_with_results():
    """``_run`` queries the retriever and returns each document as a dict."""
    matching_doc = Document(
        id="doc1", chunks=[Chunk(content="test content", similarity=0.9)]
    )
    fake_retriever = Mock(spec=Retriever)
    fake_retriever.query_relevant_documents.return_value = [matching_doc]

    known_resources = [Resource(uri="test://uri", title="Test")]
    retriever_tool = RetrieverTool(retriever=fake_retriever, resources=known_resources)

    outcome = retriever_tool._run("test keywords")

    # The keywords and the tool's resources are forwarded to the retriever.
    fake_retriever.query_relevant_documents.assert_called_once_with(
        "test keywords", known_resources
    )
    assert isinstance(outcome, list)
    assert len(outcome) == 1
    assert outcome[0] == matching_doc.to_dict()
|
||||
|
||||
|
||||
def test_retriever_tool_run_no_results():
    """_run returns the fixed "no results" message when the retriever yields nothing."""
    retriever_stub = Mock(spec=Retriever)
    retriever_stub.query_relevant_documents.return_value = []

    retriever_tool = RetrieverTool(
        retriever=retriever_stub,
        resources=[Resource(uri="test://uri", title="Test")],
    )

    outcome = retriever_tool._run("test keywords")

    assert outcome == "No results found from the local knowledge base."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_retriever_tool_arun():
    """_arun awaits the retriever's async query and returns documents as dicts."""
    retriever_stub = Mock(spec=Retriever)
    expected_doc = Document(
        id="doc2",
        chunks=[Chunk(content="async content", similarity=0.8)],
    )

    # Replace the async query with a coroutine function that ignores its arguments.
    async def fake_async_query(*args, **kwargs):
        return [expected_doc]

    retriever_stub.query_relevant_documents_async = fake_async_query

    retriever_tool = RetrieverTool(
        retriever=retriever_stub,
        resources=[Resource(uri="test://uri", title="Test")],
    )
    run_manager_stub = Mock(spec=AsyncCallbackManagerForToolRun)

    outcome = await retriever_tool._arun("async keywords", run_manager_stub)

    assert isinstance(outcome, list)
    assert len(outcome) == 1
    assert outcome[0] == expected_doc.to_dict()
|
||||
|
||||
|
||||
@patch("src.tools.retriever.build_retriever")
def test_get_retriever_tool_success(mock_build_retriever):
    """get_retriever_tool wraps the built retriever and the given resources in a RetrieverTool."""
    retriever_stub = Mock(spec=Retriever)
    mock_build_retriever.return_value = retriever_stub
    resource_list = [Resource(uri="test://uri", title="Test")]

    built_tool = get_retriever_tool(resource_list)

    assert isinstance(built_tool, RetrieverTool)
    assert built_tool.retriever == retriever_stub
    assert built_tool.resources == resource_list
|
||||
|
||||
|
||||
def test_get_retriever_tool_empty_resources():
    """get_retriever_tool returns None when given an empty resource list."""
    assert get_retriever_tool([]) is None
|
||||
|
||||
|
||||
@patch("src.tools.retriever.build_retriever")
def test_get_retriever_tool_no_retriever(mock_build_retriever):
    """get_retriever_tool returns None when build_retriever yields no retriever."""
    mock_build_retriever.return_value = None
    resource_list = [Resource(uri="test://uri", title="Test")]

    built_tool = get_retriever_tool(resource_list)

    assert built_tool is None
|
||||
|
||||
|
||||
def test_retriever_tool_run_with_callback_manager():
    """_run accepts a callback manager as second positional argument and still reports no results."""
    retriever_stub = Mock(spec=Retriever)
    retriever_stub.query_relevant_documents.return_value = []

    retriever_tool = RetrieverTool(
        retriever=retriever_stub,
        resources=[Resource(uri="test://uri", title="Test")],
    )
    callback_manager_stub = Mock(spec=CallbackManagerForToolRun)

    outcome = retriever_tool._run("test keywords", callback_manager_stub)

    assert outcome == "No results found from the local knowledge base."
|
||||
@@ -1,235 +0,0 @@
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
|
||||
from src.utils.context_manager import ContextManager
|
||||
|
||||
|
||||
class TestContextManager:
    """Test cases for ContextManager token counting, limit checks and message compression."""

    @staticmethod
    def _make_web_search_message():
        """Build a web_search ToolMessage whose JSON raw_content is far longer than 1024 characters."""
        return ToolMessage(
            name="web_search",
            content=(
                '[{"title": "Test Result", "url": "https://example.com", "raw_content": "'
                + ("This is a test content that should be compressed if it exceeds 1024 characters. " * 2000)
                + '"}]'
            ),
            tool_call_id="test_search",
        )

    @staticmethod
    def _assert_raw_content_truncated(messages):
        """Assert every raw_content found in web_search ToolMessages is exactly 1024 characters.

        Messages that are not web_search ToolMessages, payloads that are not
        JSON lists, and items without a raw_content key are skipped.
        """
        # Local import: the visible module header does not import json.
        import json

        for msg in messages:
            if not (isinstance(msg, ToolMessage) and getattr(msg, "name", None) == "web_search"):
                continue
            content_data = json.loads(msg.content)
            if not isinstance(content_data, list):
                continue
            for item in content_data:
                if isinstance(item, dict) and "raw_content" in item:
                    assert len(item["raw_content"]) == 1024

    def test_count_tokens_with_empty_messages(self):
        """Test counting tokens with empty message list"""
        context_manager = ContextManager(token_limit=1000)
        messages = []
        token_count = context_manager.count_tokens(messages)
        assert token_count == 0

    def test_count_tokens_with_system_message(self):
        """Test counting tokens with system message"""
        context_manager = ContextManager(token_limit=1000)
        messages = [SystemMessage(content="You are a helpful assistant.")]
        token_count = context_manager.count_tokens(messages)
        # The content is 28 characters; the test only requires more than 7 tokens.
        assert token_count > 7

    def test_count_tokens_with_human_message(self):
        """Test counting tokens with a Chinese human message"""
        context_manager = ContextManager(token_limit=1000)
        messages = [HumanMessage(content="你好,这是一个测试消息。")]
        token_count = context_manager.count_tokens(messages)
        assert token_count > 12

    def test_count_tokens_with_ai_message(self):
        """Test counting tokens with AI message"""
        context_manager = ContextManager(token_limit=1000)
        messages = [AIMessage(content="I'm doing well, thank you for asking!")]
        token_count = context_manager.count_tokens(messages)
        assert token_count >= 10

    def test_count_tokens_with_tool_message(self):
        """Test counting tokens with tool message"""
        context_manager = ContextManager(token_limit=1000)
        messages = [
            ToolMessage(content="Tool execution result data here", tool_call_id="test")
        ]
        token_count = context_manager.count_tokens(messages)
        # The content is 31 characters; only a positive count is required here.
        assert token_count > 0

    def test_count_tokens_with_multiple_messages(self):
        """Test counting tokens with multiple messages"""
        context_manager = ContextManager(token_limit=1000)
        messages = [
            SystemMessage(content="You are a helpful assistant."),
            HumanMessage(content="Hello, how are you?"),
            AIMessage(content="I'm doing well, thank you for asking!"),
        ]
        token_count = context_manager.count_tokens(messages)
        # Only a positive total is checked, not the per-message sum.
        assert token_count > 0

    def test_is_over_limit_when_under_limit(self):
        """Test is_over_limit when messages are under token limit"""
        context_manager = ContextManager(token_limit=1000)
        short_messages = [HumanMessage(content="Short message")]
        is_over = context_manager.is_over_limit(short_messages)
        assert is_over is False

    def test_is_over_limit_when_over_limit(self):
        """Test is_over_limit when messages exceed token limit"""
        # Create a context manager with a very low limit
        low_limit_cm = ContextManager(token_limit=5)
        long_messages = [
            HumanMessage(
                content="This is a very long message that should exceed the limit"
            )
        ]
        is_over = low_limit_cm.is_over_limit(long_messages)
        assert is_over is True

    def test_compress_messages_when_not_over_limit(self):
        """Test compress_messages when messages are not over limit"""
        context_manager = ContextManager(token_limit=1000)
        messages = [HumanMessage(content="Short message")]
        compressed = context_manager.compress_messages({"messages": messages})
        # The message count must be unchanged when not over limit
        assert len(compressed["messages"]) == len(messages)

    def test_compress_messages_with_tool_message(self):
        """Test compress_messages keeps all four messages and truncates raw_content"""
        # Create a context manager with limited token capacity
        limited_cm = ContextManager(token_limit=200)

        messages = [
            SystemMessage(content="You are a helpful assistant."),
            HumanMessage(content="Hello"),
            AIMessage(content="Hi there!"),
            self._make_web_search_message(),
        ]

        compressed = limited_cm.compress_messages({"messages": messages})
        # All four messages are expected to remain in the result
        assert len(compressed["messages"]) == 4

        # Verify raw_content was compressed to 1024 characters
        self._assert_raw_content_truncated(compressed["messages"])

    def test_compress_messages_with_preserve_prefix_message(self):
        """Test compress_messages with preserve_prefix_message_count=2 and no system message"""
        # Create a context manager with limited token capacity
        limited_cm = ContextManager(token_limit=100, preserve_prefix_message_count=2)

        messages = [
            HumanMessage(content="Hello"),
            AIMessage(content="Hi there!"),
            HumanMessage(
                content="Can you tell me a very long story that would exceed token limits? "
                * 10
            ),
        ]

        compressed = limited_cm.compress_messages({"messages": messages})
        # All three messages are expected to remain in the result
        assert len(compressed["messages"]) == 3

    def test_compress_messages_without_config(self):
        """Test compress_messages keeps every message when constructed with None"""
        # ContextManager receives None instead of a token limit
        limited_cm = ContextManager(None)

        messages = [
            SystemMessage(content="You are a helpful assistant."),
            HumanMessage(content="Hello"),
            AIMessage(content="Hi there!"),
            HumanMessage(
                content="Can you tell me a very long story that would exceed token limits? "
                * 100
            ),
        ]

        compressed = limited_cm.compress_messages({"messages": messages})
        # The number of messages must be unchanged
        assert len(compressed["messages"]) == 4

    def test_count_message_tokens_with_additional_kwargs(self):
        """Test counting tokens for messages with additional kwargs"""
        context_manager = ContextManager(token_limit=1000)
        message = ToolMessage(
            content="Tool result",
            tool_call_id="test",
            additional_kwargs={"tool_calls": [{"name": "test_function"}]},
        )
        token_count = context_manager._count_message_tokens(message)
        assert token_count > 0

    def test_count_message_tokens_minimum_one_token(self):
        """Test that an empty message still counts as exactly 1 token"""
        context_manager = ContextManager(token_limit=1000)
        message = HumanMessage(content="")  # Empty content
        token_count = context_manager._count_message_tokens(message)
        assert token_count == 1

    def test_count_text_tokens_english_only(self):
        """Test counting tokens for English text"""
        context_manager = ContextManager(token_limit=1000)
        # 15 English characters; only a positive count is required here.
        text = "This is a test."
        token_count = context_manager._count_text_tokens(text)
        assert token_count > 0

    def test_count_text_tokens_chinese_only(self):
        """Test counting tokens for Chinese text"""
        context_manager = ContextManager(token_limit=1000)
        # 8 Chinese characters should result in 8 tokens (1:1 ratio)
        text = "这是一个测试文本"
        token_count = context_manager._count_text_tokens(text)
        assert token_count == 8

    def test_count_text_tokens_mixed_content(self):
        """Test counting tokens for mixed English and Chinese text"""
        context_manager = ContextManager(token_limit=1000)
        text = "Hello world 这是一些中文"
        token_count = context_manager._count_text_tokens(text)
        assert token_count > 6

    def test_compress_messages_with_runtime_when_not_over_limit(self):
        """compress_messages accepts runtime param when under limit"""
        context_manager = ContextManager(token_limit=1000)
        messages = [HumanMessage(content="Short message"), AIMessage(content="OK")]
        compressed = context_manager.compress_messages({"messages": messages}, runtime=object())
        assert isinstance(compressed, dict)
        assert "messages" in compressed
        assert len(compressed["messages"]) == len(messages)

    def test_compress_messages_with_runtime_when_over_limit(self):
        """compress_messages accepts runtime param and still compresses"""
        limited_cm = ContextManager(token_limit=200)
        messages = [
            SystemMessage(content="You are a helpful assistant."),
            HumanMessage(content="Hello"),
            AIMessage(content="Hi there!"),
            HumanMessage(
                content="Can you tell me a very long story that would exceed token limits? " * 100
            ),
            self._make_web_search_message(),
        ]
        compressed = limited_cm.compress_messages({"messages": messages}, runtime=object())
        assert isinstance(compressed, dict)
        assert "messages" in compressed
        # All five messages are expected to remain in the result
        assert len(compressed["messages"]) == 5

        # Verify raw_content was compressed to 1024 characters
        self._assert_raw_content_truncated(compressed["messages"])
|
||||
@@ -1,581 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import json
|
||||
|
||||
from src.utils.json_utils import (
|
||||
_extract_json_from_content,
|
||||
repair_json_output,
|
||||
sanitize_args,
|
||||
sanitize_tool_response,
|
||||
)
|
||||
|
||||
|
||||
class TestRepairJsonOutput:
    """Tests for repair_json_output on plain, fenced, malformed and non-JSON input."""

    def test_valid_json_object(self):
        """A valid JSON object is re-serialised unchanged."""
        raw = '{"key": "value", "number": 123}'
        assert repair_json_output(raw) == json.dumps(
            {"key": "value", "number": 123}, ensure_ascii=False
        )

    def test_valid_json_array(self):
        """A valid JSON array is re-serialised unchanged."""
        raw = '[1, 2, 3, "test"]'
        assert repair_json_output(raw) == json.dumps([1, 2, 3, "test"], ensure_ascii=False)

    def test_json_with_code_block_json(self):
        """JSON inside a ```json fence is extracted."""
        fenced = '```json\n{"key": "value"}\n```'
        assert repair_json_output(fenced) == json.dumps({"key": "value"}, ensure_ascii=False)

    def test_json_with_code_block_ts(self):
        """JSON inside a ```ts fence is extracted."""
        fenced = '```ts\n{"key": "value"}\n```'
        assert repair_json_output(fenced) == json.dumps({"key": "value"}, ensure_ascii=False)

    def test_json_with_code_block_uppercase_json(self):
        """JSON inside an uppercase ```JSON fence is extracted."""
        fenced = '```JSON\n{"key": "value"}\n```'
        assert repair_json_output(fenced) == json.dumps({"key": "value"}, ensure_ascii=False)

    def test_json_with_code_block_uppercase_ts(self):
        """JSON inside an uppercase ```TS fence is extracted."""
        fenced = '```TS\n{"key": "value"}\n```'
        assert repair_json_output(fenced) == json.dumps({"key": "value"}, ensure_ascii=False)

    def test_json_with_code_block_mixed_case_json(self):
        """JSON inside a mixed-case ```Json fence is extracted."""
        fenced = '```Json\n{"key": "value"}\n```'
        assert repair_json_output(fenced) == json.dumps({"key": "value"}, ensure_ascii=False)

    def test_json_with_code_block_uppercase_ts_with_prefix(self):
        """JSON inside a ```TS fence preceded by prefix text is extracted."""
        fenced = 'some prefix ```TS\n{"key": "value"}\n```'
        assert repair_json_output(fenced) == json.dumps({"key": "value"}, ensure_ascii=False)

    def test_json_with_code_block_uppercase_json_with_prefix(self):
        """JSON inside a ```JSON fence preceded by prefix text is extracted."""
        # Covers an uppercase fence tag that does not sit at the start of the content.
        fenced = 'prefix ```JSON\n{"key": "value"}\n```'
        assert repair_json_output(fenced) == json.dumps({"key": "value"}, ensure_ascii=False)

    def test_json_with_plain_code_block_uppercase(self):
        """JSON inside a plain ``` fence with no language tag is extracted."""
        fenced = '```\n{"key": "value"}\n```'
        assert repair_json_output(fenced) == json.dumps({"key": "value"}, ensure_ascii=False)

    def test_malformed_json_repair(self):
        """Truncated JSON is repaired and keeps its leading key/value pair."""
        truncated = '{"key": "value", "incomplete":'
        repaired = repair_json_output(truncated)
        # Only the prefix is pinned; the rest of the repaired output is not checked.
        assert repaired.startswith('{"key": "value"')

    def test_non_json_content(self):
        """Plain text is returned unchanged."""
        plain = "This is just plain text"
        assert repair_json_output(plain) == plain

    def test_empty_string(self):
        """An empty string yields an empty string."""
        assert repair_json_output("") == ""

    def test_whitespace_only(self):
        """Whitespace-only content yields an empty string."""
        blank = " \n\t "
        assert repair_json_output(blank) == ""

    def test_json_with_unicode(self):
        """Non-ASCII characters survive the round trip unescaped."""
        raw = '{"name": "测试", "emoji": "🎯"}'
        assert repair_json_output(raw) == json.dumps(
            {"name": "测试", "emoji": "🎯"}, ensure_ascii=False
        )

    def test_json_code_block_without_closing(self):
        """A ```json fence that is never closed is still extracted."""
        fenced = '```json\n{"key": "value"}'
        assert repair_json_output(fenced) == json.dumps({"key": "value"}, ensure_ascii=False)

    def test_json_repair_broken_json(self):
        """An unquoted, unterminated value is repaired into a quoted string."""
        broken = '{"this": "is", "completely": broken and unparseable'
        wanted = '{"this": "is", "completely": "broken and unparseable"}'
        assert repair_json_output(broken) == wanted

    def test_nested_json_object(self):
        """Nested objects are re-serialised unchanged."""
        raw = '{"outer": {"inner": {"deep": "value"}}}'
        wanted = json.dumps({"outer": {"inner": {"deep": "value"}}}, ensure_ascii=False)
        assert repair_json_output(raw) == wanted

    def test_json_array_with_objects(self):
        """An array of objects is re-serialised unchanged."""
        raw = '[{"id": 1, "name": "test1"}, {"id": 2, "name": "test2"}]'
        wanted = json.dumps(
            [{"id": 1, "name": "test1"}, {"id": 2, "name": "test2"}],
            ensure_ascii=False,
        )
        assert repair_json_output(raw) == wanted

    def test_content_with_json_in_middle(self):
        """An inline ```json marker surrounded by text yields only the JSON object."""
        mixed = 'Some text before ```json {"key": "value"} and after'
        extracted = repair_json_output(mixed)
        # The content contains ```json, so it is expected to be processed as JSON.
        assert isinstance(extracted, str)
        assert extracted == '{"key": "value"}'
|
||||
|
||||
|
||||
class TestExtractJsonFromContent:
    """Tests for _extract_json_from_content trimming text after a complete JSON value."""

    def test_json_with_extra_tokens_after_closing_brace(self):
        """Text after an object's closing brace is dropped."""
        raw = '{"key": "value"} extra tokens here'
        assert _extract_json_from_content(raw) == '{"key": "value"}'

    def test_json_with_extra_tokens_after_closing_bracket(self):
        """Text after an array's closing bracket is dropped."""
        raw = '[1, 2, 3] garbage data'
        assert _extract_json_from_content(raw) == '[1, 2, 3]'

    def test_nested_json_with_extra_tokens(self):
        """Nested structures are kept whole and trailing text is dropped."""
        raw = '{"nested": {"inner": [1, 2, 3]}} invalid text'
        assert _extract_json_from_content(raw) == '{"nested": {"inner": [1, 2, 3]}}'

    def test_json_with_string_containing_braces(self):
        """Braces inside a string value do not end the object early."""
        raw = '{"text": "this has {braces} in it"} extra'
        assert _extract_json_from_content(raw) == '{"text": "this has {braces} in it"}'

    def test_json_with_escaped_quotes(self):
        """Escaped quotes inside a string value do not end the string early."""
        raw = '{"text": "quote \\"here\\""} junk'
        assert _extract_json_from_content(raw) == '{"text": "quote \\"here\\""}'

    def test_clean_json_no_extra_tokens(self):
        """Clean JSON is returned as-is."""
        raw = '{"key": "value"}'
        assert _extract_json_from_content(raw) == '{"key": "value"}'

    def test_empty_object(self):
        """An empty object followed by text yields just the object."""
        raw = '{} extra'
        assert _extract_json_from_content(raw) == '{}'

    def test_empty_array(self):
        """An empty array followed by text yields just the array."""
        raw = '[] more stuff'
        assert _extract_json_from_content(raw) == '[]'

    def test_extra_closing_brace_no_opening(self):
        """A closing brace with no matching opener leaves the content untouched."""
        raw = '} garbage data'
        # No opening brace was seen, so the original content is expected back.
        assert _extract_json_from_content(raw) == raw

    def test_extra_closing_bracket_no_opening(self):
        """A closing bracket with no matching opener leaves the content untouched."""
        raw = '] garbage data'
        # No opening bracket was seen, so the original content is expected back.
        assert _extract_json_from_content(raw) == raw
|
||||
|
||||
|
||||
class TestSanitizeToolResponse:
    """Tests for sanitize_tool_response: trimming, truncation and control-character removal."""

    def test_basic_sanitization(self):
        """An ordinary response passes through unchanged."""
        assert sanitize_tool_response("normal response") == "normal response"

    def test_json_with_extra_tokens(self):
        """Text after a complete JSON object is dropped."""
        raw = '{"data": "value"} some garbage'
        assert sanitize_tool_response(raw) == '{"data": "value"}'

    def test_very_long_response_truncation(self):
        """A 60000-character response is cut down and ends with an ellipsis."""
        oversized = "a" * 60000
        sanitized = sanitize_tool_response(oversized)
        # At most 50000 characters plus the 3-character "..." suffix.
        assert len(sanitized) <= 50003
        assert sanitized.endswith("...")

    def test_custom_max_length(self):
        """max_length=100 caps the response at 100 characters plus an ellipsis."""
        oversized = "a" * 1000
        sanitized = sanitize_tool_response(oversized, max_length=100)
        # At most 100 characters plus the 3-character "..." suffix.
        assert len(sanitized) <= 103
        assert sanitized.endswith("...")

    def test_control_character_removal(self):
        """NUL and SOH control characters are removed."""
        sanitized = sanitize_tool_response("text with \x00 null \x01 chars")
        assert "\x00" not in sanitized
        assert "\x01" not in sanitized

    def test_none_content(self):
        """An empty string yields an empty string (None itself is not exercised here)."""
        assert sanitize_tool_response("") == ""

    def test_whitespace_handling(self):
        """Leading and trailing whitespace is stripped."""
        padded = " text with spaces "
        assert sanitize_tool_response(padded) == "text with spaces"

    def test_json_array_with_extra_tokens(self):
        """Text after a complete JSON array is dropped."""
        raw = '[{"id": 1}, {"id": 2}] invalid stuff'
        assert sanitize_tool_response(raw) == '[{"id": 1}, {"id": 2}]'
|
||||
|
||||
|
||||
class TestSanitizeArgs:
    """Tests for sanitize_args: strings pass through, non-strings become an empty string."""

    def test_sanitize_special_characters(self):
        """A JSON object string with braces and brackets is returned unchanged."""
        raw = '{"key": "value", "array": [1, 2, 3]}'
        assert sanitize_args(raw) == '{"key": "value", "array": [1, 2, 3]}'

    def test_sanitize_square_brackets(self):
        """Square brackets are left as they are."""
        raw = '[1, 2, 3]'
        assert sanitize_args(raw) == '[1, 2, 3]'

    def test_sanitize_curly_braces(self):
        """Curly braces are left as they are."""
        raw = '{key: value}'
        assert sanitize_args(raw) == '{key: value}'

    def test_sanitize_mixed_brackets(self):
        """Mixed bracket types are left as they are."""
        raw = '{[test]}'
        assert sanitize_args(raw) == '{[test]}'

    def test_sanitize_non_string_input(self):
        """None, ints, lists and dicts all sanitise to an empty string."""
        for non_string in (None, 123, [1, 2, 3], {"key": "value"}):
            assert sanitize_args(non_string) == ""

    def test_sanitize_empty_string(self):
        """An empty string stays empty."""
        assert sanitize_args("") == ""

    def test_sanitize_plain_text(self):
        """Plain text without brackets or braces is returned unchanged."""
        raw = "plain text without brackets or braces"
        assert sanitize_args(raw) == "plain text without brackets or braces"

    def test_sanitize_nested_structures(self):
        """Deeply nested structures are returned unchanged."""
        raw = '{"outer": {"inner": [1, [2, 3]]}}'
        assert sanitize_args(raw) == '{"outer": {"inner": [1, [2, 3]]}}'
|
||||
|
||||
|
||||
class TestRepairJsonOutputEdgeCases:
|
||||
def test_code_block_with_leading_spaces(self):
|
||||
"""Test code block with leading spaces"""
|
||||
content = ' ```json\n{"key": "value"}\n```'
|
||||
result = repair_json_output(content)
|
||||
expected = json.dumps({"key": "value"}, ensure_ascii=False)
|
||||
assert result == expected
|
||||
|
||||
def test_code_block_with_tabs(self):
|
||||
"""Test code block with tabs"""
|
||||
content = '\t```json\n{"key": "value"}\n```'
|
||||
result = repair_json_output(content)
|
||||
expected = json.dumps({"key": "value"}, ensure_ascii=False)
|
||||
assert result == expected
|
||||
|
||||
def test_code_block_with_multiple_newlines(self):
|
||||
"""Test code block with multiple newlines after opening fence"""
|
||||
content = '```json\n\n\n{"key": "value"}\n```'
|
||||
result = repair_json_output(content)
|
||||
expected = json.dumps({"key": "value"}, ensure_ascii=False)
|
||||
assert result == expected
|
||||
|
||||
def test_code_block_with_spaces_before_closing(self):
|
||||
"""Test code block with spaces before closing fence"""
|
||||
content = '```json\n{"key": "value"}\n ```'
|
||||
result = repair_json_output(content)
|
||||
expected = json.dumps({"key": "value"}, ensure_ascii=False)
|
||||
assert result == expected
|
||||
|
||||
def test_json_with_newlines_in_values(self):
|
||||
"""Test JSON with newlines in string values"""
|
||||
content = '{"text": "line1\\nline2\\nline3"}'
|
||||
result = repair_json_output(content)
|
||||
expected = json.dumps({"text": "line1\nline2\nline3"}, ensure_ascii=False)
|
||||
assert result == expected
|
||||
|
||||
def test_json_with_special_unicode(self):
|
||||
"""Test JSON with special unicode characters"""
|
||||
content = '{"emoji": "🔥💯", "chinese": "中文测试", "math": "∑∫"}'
|
||||
result = repair_json_output(content)
|
||||
expected = json.dumps({"emoji": "🔥💯", "chinese": "中文测试", "math": "∑∫"}, ensure_ascii=False)
|
||||
assert result == expected
|
||||
|
||||
def test_json_boolean_values(self):
|
||||
"""Test JSON with boolean values"""
|
||||
content = '{"active": true, "disabled": false, "nullable": null}'
|
||||
result = repair_json_output(content)
|
||||
expected = json.dumps({"active": True, "disabled": False, "nullable": None}, ensure_ascii=False)
|
||||
assert result == expected
|
||||
|
||||
def test_json_numeric_values(self):
|
||||
"""Test JSON with various numeric values"""
|
||||
content = '{"int": 42, "float": 3.14159, "negative": -123, "scientific": 1.23e10}'
|
||||
result = repair_json_output(content)
|
||||
parsed = json.loads(result)
|
||||
assert parsed["int"] == 42
|
||||
assert parsed["float"] == 3.14159
|
||||
assert parsed["negative"] == -123
|
||||
|
||||
def test_plain_code_block_marker(self):
|
||||
"""Test plain ``` code block without language specifier"""
|
||||
content = '```\n{"key": "value"}\n```'
|
||||
result = repair_json_output(content)
|
||||
expected = json.dumps({"key": "value"}, ensure_ascii=False)
|
||||
assert result == expected
|
||||
|
||||
def test_multiple_json_objects_takes_first_complete(self):
|
||||
"""Test that multiple JSON objects are properly extracted"""
|
||||
content = '{"first": "object"} {"second": "object"}'
|
||||
result = repair_json_output(content)
|
||||
# json_repair will combine multiple objects into an array
|
||||
expected = json.dumps([{"first": "object"}, {"second": "object"}], ensure_ascii=False)
|
||||
assert result == expected
|
||||
|
||||
def test_chinese_json_with_code_block(self):
|
||||
"""Test JSON with Chinese content wrapped in markdown code block"""
|
||||
content = '''```json
|
||||
{
|
||||
"locale": "en-US",
|
||||
"has_enough_context": true,
|
||||
"thought": "测试中文内容",
|
||||
"title": "地月距离小报告",
|
||||
"steps": []
|
||||
}
|
||||
```'''
|
||||
result = repair_json_output(content)
|
||||
parsed = json.loads(result)
|
||||
assert parsed["locale"] == "en-US"
|
||||
assert parsed["title"] == "地月距离小报告"
|
||||
assert parsed["thought"] == "测试中文内容"
|
||||
assert isinstance(parsed["steps"], list)
|
||||
|
||||
def test_code_block_uppercase_json_with_leading_spaces(self):
|
||||
"""Test uppercase JSON code block with leading spaces"""
|
||||
content = ' ```JSON\n{"key": "value"}\n```'
|
||||
result = repair_json_output(content)
|
||||
expected = json.dumps({"key": "value"}, ensure_ascii=False)
|
||||
assert result == expected
|
||||
|
||||
def test_code_block_uppercase_json_with_tabs(self):
|
||||
"""Test uppercase JSON code block with tabs"""
|
||||
content = '\t```JSON\n{"key": "value"}\n```'
|
||||
result = repair_json_output(content)
|
||||
expected = json.dumps({"key": "value"}, ensure_ascii=False)
|
||||
assert result == expected
|
||||
|
||||
def test_code_block_mixed_case_with_multiple_newlines(self):
|
||||
"""Test mixed case code block with multiple newlines"""
|
||||
content = '```JsOn\n\n\n{"key": "value"}\n```'
|
||||
result = repair_json_output(content)
|
||||
expected = json.dumps({"key": "value"}, ensure_ascii=False)
|
||||
assert result == expected
|
||||
|
||||
def test_code_block_uppercase_with_spaces_before_closing(self):
|
||||
"""Test uppercase code block with spaces before closing fence"""
|
||||
content = '```TYPESCRIPT\n{"key": "value"}\n ```'
|
||||
result = repair_json_output(content)
|
||||
expected = json.dumps({"key": "value"}, ensure_ascii=False)
|
||||
assert result == expected
|
||||
|
||||
def test_code_block_case_insensitive_various_languages(self):
    """Test code blocks with various language specifiers in different cases.

    Each case pairs a fenced input with the JSON text expected once the
    fence has been stripped. The repaired output must parse to exactly
    that payload, whatever the casing of the language tag.
    """
    test_cases = [
        ('```Python\n{"key": "value"}\n```', '{"key": "value"}'),
        ('```PYTHON\n{"key": "value"}\n```', '{"key": "value"}'),
        ('```pYtHoN\n{"key": "value"}\n```', '{"key": "value"}'),
        ('```sql\n{"key": "value"}\n```', '{"key": "value"}'),
        ('```SQL\n{"key": "value"}\n```', '{"key": "value"}'),
    ]
    for content, expected_json_str in test_cases:
        result = repair_json_output(content)
        # json.loads doubles as the validity check: it raises on malformed output.
        parsed = json.loads(result)
        # Compare against the expected column instead of a hard-coded key,
        # so the case table is actually exercised and extra keys are caught.
        assert parsed == json.loads(expected_json_str), (
            f"unexpected payload for input {content!r}"
        )
|
||||
|
||||
|
||||
class TestExtractJsonFromContentEdgeCases:
    """Edge cases for ``_extract_json_from_content``."""

    def test_deeply_nested_json(self):
        """Five levels of nested objects are extracted without the trailing text."""
        json_part = '{"l1": {"l2": {"l3": {"l4": {"l5": "deep"}}}}}'
        assert _extract_json_from_content(json_part + ' garbage') == json_part

    def test_json_array_of_arrays(self):
        """An array of arrays is extracted without the trailing text."""
        json_part = '[[1, 2], [3, 4], [5, 6]]'
        assert _extract_json_from_content(json_part + ' extra') == json_part

    def test_json_with_backslashes_in_string(self):
        """Escaped backslashes inside a string value are kept as-is."""
        json_part = r'{"path": "C:\\Users\\test\\file.txt"}'
        assert _extract_json_from_content(json_part + ' garbage') == json_part

    def test_json_with_forward_slashes(self):
        """Forward slashes inside a string value are kept as-is."""
        json_part = '{"url": "https://example.com/path/to/resource"}'
        assert _extract_json_from_content(json_part + ' extra') == json_part

    def test_mixed_object_and_array(self):
        """An object containing an array of objects is extracted whole."""
        json_part = '{"items": [{"id": 1}, {"id": 2}], "count": 2}'
        assert _extract_json_from_content(json_part + ' tail') == json_part

    def test_json_with_unicode_escape_sequences(self):
        """Unicode escape sequences in a string value are left unexpanded."""
        json_part = r'{"text": "\u4E2D\u6587"}'
        assert _extract_json_from_content(json_part + ' junk') == json_part

    def test_no_json_structure(self):
        """Text with no brackets at all is returned unchanged."""
        plain = 'just plain text without brackets'
        assert _extract_json_from_content(plain) == plain

    def test_unbalanced_braces_in_middle(self):
        """Unbalanced braces never yield a valid end marker, so the input comes back unchanged."""
        broken = '{"incomplete": {"nested": } text'
        assert _extract_json_from_content(broken) == broken

    def test_json_with_comma_separated_values(self):
        """An object with several comma-separated members is extracted whole."""
        json_part = '{"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}'
        assert _extract_json_from_content(json_part + ' more text') == json_part
|
||||
|
||||
|
||||
class TestSanitizeToolResponseEdgeCases:
    """Edge cases for ``sanitize_tool_response``."""

    def test_json_object_with_extra_tokens(self):
        """Tokens trailing a JSON object are dropped."""
        json_part = '{"status": "success", "data": {"id": 123}}'
        assert sanitize_tool_response(json_part + ' trailing garbage') == json_part

    def test_truncation_at_exact_boundary(self):
        """Content exactly ``max_length`` long is returned untruncated."""
        limit = 50000
        sanitized = sanitize_tool_response("x" * limit, max_length=limit)
        assert len(sanitized) == limit
        assert not sanitized.endswith("...")

    def test_truncation_one_over_boundary(self):
        """Content one character over the limit is cut and marked with an ellipsis."""
        limit = 50000
        sanitized = sanitize_tool_response("x" * (limit + 1), max_length=limit)
        # The result may be up to three characters longer than the limit.
        assert len(sanitized) <= limit + 3
        assert sanitized.endswith("...")

    def test_multiple_control_characters(self):
        """Every kind of control character in the input is stripped."""
        raw = "text\x00with\x01various\x02control\x1Fchars\x7F"
        sanitized = sanitize_tool_response(raw)
        for control_char in ("\x00", "\x01", "\x02", "\x1F", "\x7F"):
            assert control_char not in sanitized
        assert sanitized == "textwithvariouscontrolchars"

    def test_newline_and_tab_preservation(self):
        """Newlines and tabs are valid content and must survive sanitizing."""
        raw = "line1\nline2\tindented"
        sanitized = sanitize_tool_response(raw)
        for kept_char in ("\n", "\t"):
            assert kept_char in sanitized
        assert sanitized == "line1\nline2\tindented"

    def test_non_json_content_unchanged(self):
        """Plain text with no JSON structure passes through untouched."""
        plain = "This is plain text without any JSON structure"
        assert sanitize_tool_response(plain) == plain

    def test_json_array_at_start(self):
        """A leading JSON array is kept and the text after it is dropped."""
        json_part = '[1, 2, 3, 4, 5]'
        assert sanitize_tool_response(json_part + ' followed by text') == json_part

    def test_empty_json_structures_preserved(self):
        """Empty objects and arrays nested in the payload are preserved."""
        json_part = '{"empty_obj": {}, "empty_arr": []}'
        assert sanitize_tool_response(json_part + ' extra') == json_part

    def test_whitespace_variations(self):
        """Surrounding spaces, newlines and tabs are trimmed from the result."""
        padded = " \n\t content with spaces \t\n "
        assert sanitize_tool_response(padded) == "content with spaces"
|
||||
@@ -1,268 +0,0 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for log sanitization utilities.
|
||||
|
||||
This test file verifies that the log sanitizer properly prevents log injection attacks
|
||||
by escaping dangerous characters in user-controlled input before logging.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.utils.log_sanitizer import (
|
||||
create_safe_log_message,
|
||||
sanitize_agent_name,
|
||||
sanitize_feedback,
|
||||
sanitize_log_input,
|
||||
sanitize_thread_id,
|
||||
sanitize_tool_name,
|
||||
sanitize_user_content,
|
||||
)
|
||||
|
||||
|
||||
class TestSanitizeLogInput:
    """Tests for the core ``sanitize_log_input`` function."""

    def test_sanitize_normal_text(self):
        """Harmless text comes back unchanged."""
        assert sanitize_log_input("normal text") == "normal text"

    def test_sanitize_newline_injection(self):
        """A raw newline is replaced by its escaped form."""
        sanitized = sanitize_log_input("abc\n[INFO] Forged log entry")
        assert "\n" not in sanitized
        # The forged text itself is kept; only the line break is neutralized.
        assert "[INFO]" in sanitized
        assert "\\n" in sanitized

    def test_sanitize_carriage_return(self):
        """A raw carriage return is replaced by its escaped form."""
        sanitized = sanitize_log_input("text\r[WARN] Forged entry")
        assert "\r" not in sanitized
        assert "\\r" in sanitized

    def test_sanitize_tab_character(self):
        """A raw tab is replaced by its escaped form."""
        sanitized = sanitize_log_input("text\t[ERROR] Forged")
        assert "\t" not in sanitized
        assert "\\t" in sanitized

    def test_sanitize_null_character(self):
        """A NUL byte never reaches the output."""
        sanitized = sanitize_log_input("text\x00[CRITICAL]")
        assert "\x00" not in sanitized

    def test_sanitize_backslash(self):
        """Each backslash in the input is doubled in the output."""
        windows_style = r"path\to\file"
        assert sanitize_log_input(windows_style) == r"path\\to\\file"

    def test_sanitize_escape_character(self):
        """The ESC byte of an ANSI sequence is replaced by its escaped form."""
        sanitized = sanitize_log_input("text\x1b[31mRED TEXT\x1b[0m")
        assert "\x1b" not in sanitized
        assert "\\x1b" in sanitized

    def test_sanitize_max_length_truncation(self):
        """Over-long input is cut to ``max_length`` and ends with an ellipsis."""
        limit = 100
        sanitized = sanitize_log_input("a" * 1000, max_length=limit)
        assert len(sanitized) <= limit
        assert sanitized.endswith("...")

    def test_sanitize_none_value(self):
        """``None`` is rendered as the string "None"."""
        assert sanitize_log_input(None) == "None"

    def test_sanitize_numeric_value(self):
        """A number is rendered as its string form."""
        assert sanitize_log_input(12345) == "12345"

    def test_sanitize_complex_injection_attack(self):
        """An attack mixing newline, carriage return and tab is fully neutralized."""
        attack = 'thread-123\n[WARNING] Unauthorized\r[ERROR] System failure\t[CRITICAL] Shutdown'
        sanitized = sanitize_log_input(attack)
        # None of the raw control characters may survive.
        for control_char in ("\n", "\r", "\t"):
            assert control_char not in sanitized
        # The visible text is kept.
        for keyword in ("WARNING", "ERROR"):
            assert keyword in sanitized
|
||||
|
||||
|
||||
class TestSanitizeThreadId:
    """Tests for ``sanitize_thread_id``."""

    def test_thread_id_normal(self):
        """An ordinary thread ID comes back unchanged."""
        assert sanitize_thread_id("thread-123-abc") == "thread-123-abc"

    def test_thread_id_with_newline(self):
        """A newline embedded in a thread ID is escaped."""
        sanitized = sanitize_thread_id("thread-1\n[INFO] Forged")
        assert "\n" not in sanitized
        assert "\\n" in sanitized

    def test_thread_id_max_length(self):
        """A 200-character thread ID is cut to at most 100 characters."""
        sanitized = sanitize_thread_id("x" * 200)
        assert len(sanitized) <= 100
|
||||
|
||||
|
||||
class TestSanitizeUserContent:
    """Tests for ``sanitize_user_content``."""

    def test_user_content_normal(self):
        """An ordinary question comes back unchanged."""
        question = "What is the weather today?"
        assert sanitize_user_content(question) == "What is the weather today?"

    def test_user_content_with_newline(self):
        """A newline embedded in user content is escaped."""
        sanitized = sanitize_user_content("My question\n[ADMIN] Delete user")
        assert "\n" not in sanitized
        assert "\\n" in sanitized

    def test_user_content_max_length(self):
        """A 500-character message is cut to at most 200 characters."""
        sanitized = sanitize_user_content("x" * 500)
        assert len(sanitized) <= 200
|
||||
|
||||
|
||||
class TestSanitizeToolName:
    """Tests for ``sanitize_tool_name``."""

    def test_tool_name_normal(self):
        """An ordinary tool name comes back unchanged."""
        assert sanitize_tool_name("web_search") == "web_search"

    def test_tool_name_injection(self):
        """A newline embedded in a tool name never reaches the output."""
        sanitized = sanitize_tool_name("search\n[WARN] Forged")
        assert "\n" not in sanitized
|
||||
|
||||
|
||||
class TestSanitizeFeedback:
    """Tests for ``sanitize_feedback``."""

    def test_feedback_normal(self):
        """An ordinary feedback token comes back unchanged."""
        assert sanitize_feedback("[accepted]") == "[accepted]"

    def test_feedback_injection(self):
        """A newline embedded in feedback is escaped."""
        sanitized = sanitize_feedback("[approved]\n[CRITICAL] System down")
        assert "\n" not in sanitized
        assert "\\n" in sanitized

    def test_feedback_max_length(self):
        """A 500-character feedback string is cut to at most 150 characters."""
        sanitized = sanitize_feedback("x" * 500)
        assert len(sanitized) <= 150
|
||||
|
||||
|
||||
class TestCreateSafeLogMessage:
    """Tests for the ``create_safe_log_message`` helper."""

    def test_safe_message_normal(self):
        """Harmless values are substituted into the template verbatim."""
        template = "[{thread_id}] Processing {tool_name}"
        rendered = create_safe_log_message(
            template, thread_id="thread-1", tool_name="search"
        )
        assert rendered == "[thread-1] Processing search"

    def test_safe_message_with_injection(self):
        """Line breaks in substituted values are escaped in the rendered message."""
        template = "[{thread_id}] Tool: {tool_name}"
        rendered = create_safe_log_message(
            template,
            thread_id="id\n[INFO] Forged",
            tool_name="search\r[ERROR]",
        )
        # Raw line breaks must be gone and their escaped forms present.
        for raw_char, escaped in (("\n", "\\n"), ("\r", "\\r")):
            assert raw_char not in rendered
        for raw_char, escaped in (("\n", "\\n"), ("\r", "\\r")):
            assert escaped in rendered

    def test_safe_message_multiple_values(self):
        """Tab and ESC characters spread over several values are all removed."""
        template = "[{id}] User: {user} Tool: {tool}"
        rendered = create_safe_log_message(
            template,
            id="123",
            user="admin\t[WARN]",
            tool="delete\x1b[31m",
        )
        for raw_char in ("\t", "\x1b"):
            assert raw_char not in rendered
|
||||
|
||||
|
||||
class TestLogInjectionAttackPrevention:
    """End-to-end style checks that ``sanitize_log_input`` defeats log injection."""

    def test_classic_log_injection_newline(self):
        """The textbook newline attack cannot start a new log line."""
        sanitized = sanitize_log_input('abc\n[WARNING] Unauthorized access detected')
        # No real newline may remain, otherwise a forged entry would be created.
        assert sanitized.count("\n") == 0
        # The escaped form of the newline is present instead.
        assert "\\n" in sanitized

    def test_carriage_return_log_injection(self):
        """A CRLF pair is fully removed from the output."""
        sanitized = sanitize_log_input("request_id\r\n[ERROR] CRITICAL FAILURE")
        for raw_char in ("\r", "\n"):
            assert raw_char not in sanitized

    def test_html_injection_prevention(self):
        """Control characters are removed while HTML markup is left readable."""
        # HTML tags are harmless in a log file by themselves; the concern is
        # the control characters that could confuse downstream log parsers.
        payload = "user\x1b[32m<script>alert('xss')</script>"
        sanitized = sanitize_log_input(payload)
        assert "\x1b" not in sanitized
        # The markup survives; only the control character was dealt with.
        assert "<script>" in sanitized

    def test_multiple_injection_techniques(self):
        """Newline, carriage return, tab and ESC combined are all escaped."""
        sanitized = sanitize_log_input('id_1\n\r\t[CRITICAL]\x1b[31m RED TEXT')
        raw_to_escaped = {
            "\n": "\\n",
            "\r": "\\r",
            "\t": "\\t",
            "\x1b": "\\x1b",
        }
        # No raw control character may remain ...
        for raw_char in raw_to_escaped:
            assert raw_char not in sanitized
        # ... and each one must appear in escaped form.
        for escaped in raw_to_escaped.values():
            assert escaped in sanitized
|
||||
Reference in New Issue
Block a user