Compare commits

...

3 Commits

Author SHA1 Message Date
Henry Li 03df43feb1 docs: add VolcEngine introduction. 2025-06-12 13:28:41 +08:00
Willem Jiang ee1af78767 test: added unit tests for rag (#298)
* test: added unit tests for rag

* reformate the code
2025-06-11 19:46:08 +08:00
Willem Jiang 2554e4ba63 test: add unit tests of llms (#299) 2025-06-11 19:46:01 +08:00
6 changed files with 330 additions and 15 deletions
+7
View File
@@ -31,6 +31,13 @@ https://github.com/user-attachments/assets/f3786598-1f2a-4d07-919e-8b99dfa1de3e
- [如何装饰租赁公寓?](https://deerflow.tech/chat?replay=rental-apartment-decoration)
- [访问我们的官方网站探索更多回放示例。](https://deerflow.tech/#case-studies)
### 火山引擎
目前,DeerFlow 已正式入驻[火山引擎的 FaaS 应用中心](https://console.volcengine.com/vefaas/region:vefaas+cn-beijing/market),用户可通过体验链接进行在线体验,直观感受其强大功能与便捷操作;同时,为满足不同用户的部署需求,DeerFlow 支持基于火山引擎一键部署,点击部署链接即可快速完成部署流程,开启高效研究之旅。[快来看看吧](https://console.volcengine.com/vefaas/region:vefaas+cn-beijing/market)~
<img width="1800" alt="截屏2025-06-12 13 25 12" src="https://github.com/user-attachments/assets/73c15966-6b79-4dc0-8803-efdaf7c4015e" />
---
## 📑 目录
-6
View File
@@ -70,9 +70,3 @@ def get_llm_by_type(
# In the future, we will use reasoning_llm and vl_llm for different purposes
# reasoning_llm = get_llm_by_type("reasoning")
# vl_llm = get_llm_by_type("vision")
if __name__ == "__main__":
# Initialize LLMs for different purposes - now these will be cached
basic_llm = get_llm_by_type("basic")
print(basic_llm.invoke("Hello"))
-9
View File
@@ -122,12 +122,3 @@ def parse_uri(uri: str) -> tuple[str, str]:
if parsed.scheme != "rag":
raise ValueError(f"Invalid URI: {uri}")
return parsed.path.split("/")[1], parsed.fragment
if __name__ == "__main__":
uri = "rag://dataset/123#abc"
parsed = urlparse(uri)
print(parsed.scheme)
print(parsed.netloc)
print(parsed.path)
print(parsed.fragment)
+70
View File
@@ -0,0 +1,70 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import os
import types
import pytest
from src.llms import llm
class DummyChatOpenAI:
def __init__(self, **kwargs):
self.kwargs = kwargs
def invoke(self, msg):
return f"Echo: {msg}"
@pytest.fixture(autouse=True)
def patch_chat_openai(monkeypatch):
monkeypatch.setattr(llm, "ChatOpenAI", DummyChatOpenAI)
@pytest.fixture
def dummy_conf():
return {
"BASIC_MODEL": {"api_key": "test_key", "base_url": "http://test"},
"REASONING_MODEL": {"api_key": "reason_key"},
"VISION_MODEL": {"api_key": "vision_key"},
}
def test_get_env_llm_conf(monkeypatch):
monkeypatch.setenv("BASIC_MODEL__API_KEY", "env_key")
monkeypatch.setenv("BASIC_MODEL__BASE_URL", "http://env")
conf = llm._get_env_llm_conf("basic")
assert conf["api_key"] == "env_key"
assert conf["base_url"] == "http://env"
def test_create_llm_use_conf_merges_env(monkeypatch, dummy_conf):
monkeypatch.setenv("BASIC_MODEL__API_KEY", "env_key")
result = llm._create_llm_use_conf("basic", dummy_conf)
assert isinstance(result, DummyChatOpenAI)
assert result.kwargs["api_key"] == "env_key"
assert result.kwargs["base_url"] == "http://test"
def test_create_llm_use_conf_invalid_type(dummy_conf):
with pytest.raises(ValueError):
llm._create_llm_use_conf("unknown", dummy_conf)
def test_create_llm_use_conf_empty_conf(monkeypatch):
with pytest.raises(ValueError):
llm._create_llm_use_conf("basic", {})
def test_get_llm_by_type_caches(monkeypatch, dummy_conf):
called = {}
def fake_load_yaml_config(path):
called["called"] = True
return dummy_conf
monkeypatch.setattr(llm, "load_yaml_config", fake_load_yaml_config)
llm._llm_cache.clear()
inst1 = llm.get_llm_by_type("basic")
inst2 = llm.get_llm_by_type("basic")
assert inst1 is inst2
assert called["called"]
+181
View File
@@ -0,0 +1,181 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import os
import pytest
import requests
from unittest.mock import patch, MagicMock
from src.rag.ragflow import RAGFlowProvider, parse_uri
# Dummy classes to mock dependencies
class DummyResource:
def __init__(self, uri, title="", description=""):
self.uri = uri
self.title = title
self.description = description
class DummyChunk:
def __init__(self, content, similarity):
self.content = content
self.similarity = similarity
class DummyDocument:
def __init__(self, id, title, chunks=None):
self.id = id
self.title = title
self.chunks = chunks or []
# Patch imports in ragflow.py to use dummy classes
@pytest.fixture(autouse=True)
def patch_imports(monkeypatch):
import src.rag.ragflow as ragflow
ragflow.Resource = DummyResource
ragflow.Chunk = DummyChunk
ragflow.Document = DummyDocument
yield
def test_parse_uri_valid():
uri = "rag://dataset/123#abc"
dataset_id, document_id = parse_uri(uri)
assert dataset_id == "123"
assert document_id == "abc"
def test_parse_uri_invalid():
with pytest.raises(ValueError):
parse_uri("http://dataset/123#abc")
def test_init_env_vars(monkeypatch):
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
monkeypatch.delenv("RAGFLOW_PAGE_SIZE", raising=False)
provider = RAGFlowProvider()
assert provider.api_url == "http://api"
assert provider.api_key == "key"
assert provider.page_size == 10
def test_init_page_size(monkeypatch):
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
monkeypatch.setenv("RAGFLOW_PAGE_SIZE", "5")
provider = RAGFlowProvider()
assert provider.page_size == 5
def test_init_missing_env(monkeypatch):
monkeypatch.delenv("RAGFLOW_API_URL", raising=False)
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
with pytest.raises(ValueError):
RAGFlowProvider()
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
monkeypatch.delenv("RAGFLOW_API_KEY", raising=False)
with pytest.raises(ValueError):
RAGFlowProvider()
@patch("src.rag.ragflow.requests.post")
def test_query_relevant_documents_success(mock_post, monkeypatch):
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
provider = RAGFlowProvider()
resource = DummyResource("rag://dataset/123#doc456")
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"data": {
"doc_aggs": [{"doc_id": "doc456", "doc_name": "Doc Title"}],
"chunks": [
{"document_id": "doc456", "content": "chunk text", "similarity": 0.9}
],
}
}
mock_post.return_value = mock_response
docs = provider.query_relevant_documents("query", [resource])
assert len(docs) == 1
assert docs[0].id == "doc456"
assert docs[0].title == "Doc Title"
assert len(docs[0].chunks) == 1
assert docs[0].chunks[0].content == "chunk text"
assert docs[0].chunks[0].similarity == 0.9
@patch("src.rag.ragflow.requests.post")
def test_query_relevant_documents_error(mock_post, monkeypatch):
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
provider = RAGFlowProvider()
mock_response = MagicMock()
mock_response.status_code = 400
mock_response.text = "error"
mock_post.return_value = mock_response
with pytest.raises(Exception):
provider.query_relevant_documents("query", [])
@patch("src.rag.ragflow.requests.get")
def test_list_resources_success(mock_get, monkeypatch):
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
provider = RAGFlowProvider()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"data": [
{"id": "123", "name": "Dataset1", "description": "desc1"},
{"id": "456", "name": "Dataset2", "description": "desc2"},
]
}
mock_get.return_value = mock_response
resources = provider.list_resources()
assert len(resources) == 2
assert resources[0].uri == "rag://dataset/123"
assert resources[0].title == "Dataset1"
assert resources[0].description == "desc1"
assert resources[1].uri == "rag://dataset/456"
assert resources[1].title == "Dataset2"
assert resources[1].description == "desc2"
@patch("src.rag.ragflow.requests.get")
def test_list_resources_success(mock_get, monkeypatch):
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
provider = RAGFlowProvider()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"data": [
{"id": "123", "name": "Dataset1", "description": "desc1"},
{"id": "456", "name": "Dataset2", "description": "desc2"},
]
}
mock_get.return_value = mock_response
resources = provider.list_resources()
assert len(resources) == 2
assert resources[0].uri == "rag://dataset/123"
assert resources[0].title == "Dataset1"
assert resources[0].description == "desc1"
assert resources[1].uri == "rag://dataset/456"
assert resources[1].title == "Dataset2"
assert resources[1].description == "desc2"
@patch("src.rag.ragflow.requests.get")
def test_list_resources_error(mock_get, monkeypatch):
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
provider = RAGFlowProvider()
mock_response = MagicMock()
mock_response.status_code = 500
mock_response.text = "fail"
mock_get.return_value = mock_response
with pytest.raises(Exception):
provider.list_resources()
+72
View File
@@ -0,0 +1,72 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import pytest
from src.rag.retriever import Chunk, Document, Resource, Retriever
def test_chunk_init():
chunk = Chunk(content="test content", similarity=0.9)
assert chunk.content == "test content"
assert chunk.similarity == 0.9
def test_document_init_and_to_dict():
chunk1 = Chunk(content="chunk1", similarity=0.8)
chunk2 = Chunk(content="chunk2", similarity=0.7)
doc = Document(
id="doc1", url="http://example.com", title="Title", chunks=[chunk1, chunk2]
)
assert doc.id == "doc1"
assert doc.url == "http://example.com"
assert doc.title == "Title"
assert doc.chunks == [chunk1, chunk2]
d = doc.to_dict()
assert d["id"] == "doc1"
assert d["content"] == "chunk1\n\nchunk2"
assert d["url"] == "http://example.com"
assert d["title"] == "Title"
def test_document_to_dict_optional_fields():
chunk = Chunk(content="only chunk", similarity=1.0)
doc = Document(id="doc2", chunks=[chunk])
d = doc.to_dict()
assert d["id"] == "doc2"
assert d["content"] == "only chunk"
assert "url" not in d
assert "title" not in d
def test_resource_model():
resource = Resource(uri="uri1", title="Resource Title")
assert resource.uri == "uri1"
assert resource.title == "Resource Title"
assert resource.description == ""
def test_resource_model_with_description():
resource = Resource(uri="uri2", title="Resource2", description="desc")
assert resource.description == "desc"
def test_retriever_abstract_methods():
class DummyRetriever(Retriever):
def list_resources(self, query=None):
return [Resource(uri="uri", title="title")]
def query_relevant_documents(self, query, resources=[]):
return [Document(id="id", chunks=[])]
retriever = DummyRetriever()
resources = retriever.list_resources()
assert isinstance(resources, list)
assert isinstance(resources[0], Resource)
docs = retriever.query_relevant_documents("query", resources)
assert isinstance(docs, list)
assert isinstance(docs[0], Document)
def test_retriever_cannot_instantiate():
with pytest.raises(TypeError):
Retriever()