| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551 |
- """
- 混合搜索服务单元测试
- 测试 HybridSearchService 的核心功能,包括:
- - 分数合并算法
- - 结果排序
- - 归一化
- - 过滤和去重
- """
- import pytest
- from src.domain.shared.value_objects import EntityId, Timestamp
- from src.domain.vector_search.entities import Document, SearchResult
- from src.domain.vector_search.exceptions import VectorSearchException
- from src.domain.vector_search.services import HybridSearchService
- from src.domain.vector_search.value_objects import Vector
class TestHybridSearchServiceInitialization:
    """Tests for constructing HybridSearchService and validating its weight arguments."""

    def test_default_initialization(self):
        """Without arguments the service uses vector_weight=0.7 and text_weight=0.3."""
        svc = HybridSearchService()
        assert (svc.vector_weight, svc.text_weight) == (0.7, 0.3)

    def test_custom_initialization(self):
        """Explicitly supplied weights are stored unchanged on the instance."""
        svc = HybridSearchService(vector_weight=0.6, text_weight=0.4)
        assert (svc.vector_weight, svc.text_weight) == (0.6, 0.4)

    def test_initialization_with_negative_vector_weight(self):
        """A negative vector_weight is rejected with ValueError."""
        with pytest.raises(ValueError, match="vector_weight cannot be negative"):
            HybridSearchService(vector_weight=-0.1, text_weight=0.3)

    def test_initialization_with_negative_text_weight(self):
        """A negative text_weight is rejected with ValueError."""
        with pytest.raises(ValueError, match="text_weight cannot be negative"):
            HybridSearchService(vector_weight=0.7, text_weight=-0.1)

    def test_initialization_with_both_weights_zero(self):
        """Setting both weights to zero is rejected: at least one must be positive."""
        with pytest.raises(ValueError, match="At least one weight must be greater than zero"):
            HybridSearchService(vector_weight=0.0, text_weight=0.0)

    def test_initialization_with_non_numeric_vector_weight(self):
        """A non-numeric vector_weight (here a string) is rejected with ValueError."""
        with pytest.raises(ValueError, match="vector_weight must be a number"):
            HybridSearchService(vector_weight="0.7", text_weight=0.3)

    def test_initialization_with_non_numeric_text_weight(self):
        """A non-numeric text_weight (here a string) is rejected with ValueError."""
        with pytest.raises(ValueError, match="text_weight must be a number"):
            HybridSearchService(vector_weight=0.7, text_weight="0.3")

    def test_get_weights(self):
        """get_weights() yields the (vector_weight, text_weight) pair in that order."""
        vec_part, text_part = HybridSearchService(
            vector_weight=0.6, text_weight=0.4
        ).get_weights()
        assert vec_part == 0.6
        assert text_part == 0.4
class TestHybridSearchServiceCombineScores:
    """Tests for HybridSearchService.combine_scores (weighted fusion of two result lists)."""

    # Same absolute tolerance the original hand-written ``abs(a - b) < 0.001`` checks used.
    SCORE_TOLERANCE = 1e-3

    @pytest.fixture
    def service(self):
        """Create a service with explicit 0.7 / 0.3 weights so expected scores are easy to derive."""
        return HybridSearchService(vector_weight=0.7, text_weight=0.3)

    @staticmethod
    def _make_document(doc_id, content, embedding, category):
        """Build a Document with the given id, content, embedding values and category metadata."""
        return Document(
            id=EntityId(doc_id),
            content=content,
            embedding=Vector(embedding),
            metadata={"category": category},
            created_at=Timestamp.now(),
            updated_at=Timestamp.now(),
        )

    @pytest.fixture
    def sample_documents(self):
        """Create three sample documents: doc_1 and doc_2 ("AI") and doc_3 ("NLP")."""
        return (
            self._make_document(
                "doc_1", "Machine learning is a subset of AI", [1.0, 2.0, 3.0], "AI"
            ),
            self._make_document(
                "doc_2", "Deep learning uses neural networks", [2.0, 3.0, 4.0], "AI"
            ),
            self._make_document(
                "doc_3", "Natural language processing", [3.0, 4.0, 5.0], "NLP"
            ),
        )

    @classmethod
    def _assert_ranked(cls, results, expected):
        """Assert *results* match the *expected* ``(doc_id, score)`` pairs, in order.

        Also checks that ranks are reassigned as 0..n-1 in output order, as
        test_combine_scores_with_overlapping_results requires of combine_scores.
        """
        assert len(results) == len(expected)
        for position, (result, (doc_id, score)) in enumerate(zip(results, expected)):
            assert result.document.id == EntityId(doc_id)
            assert result.score == pytest.approx(score, abs=cls.SCORE_TOLERANCE)
            assert result.rank == position

    def test_combine_scores_with_overlapping_results(self, service, sample_documents):
        """A document found by both searches gets the sum of both weighted scores."""
        doc1, doc2, doc3 = sample_documents

        vector_results = [
            SearchResult(document=doc1, score=0.9, rank=0),
            SearchResult(document=doc2, score=0.8, rank=1),
        ]
        text_results = [
            SearchResult(document=doc1, score=0.7, rank=0),
            SearchResult(document=doc3, score=0.6, rank=1),
        ]

        combined = service.combine_scores(vector_results, text_results)

        # doc_1: 0.9 * 0.7 + 0.7 * 0.3 = 0.63 + 0.21 = 0.84 (in both lists)
        # doc_2: 0.8 * 0.7 = 0.56 (vector results only)
        # doc_3: 0.6 * 0.3 = 0.18 (text results only)
        self._assert_ranked(
            combined, [("doc_1", 0.84), ("doc_2", 0.56), ("doc_3", 0.18)]
        )

    def test_combine_scores_with_no_overlap(self, service, sample_documents):
        """Disjoint result lists: each document keeps only its own weighted score."""
        doc1, doc2, _ = sample_documents

        vector_results = [SearchResult(document=doc1, score=0.9, rank=0)]
        text_results = [SearchResult(document=doc2, score=0.8, rank=0)]

        combined = service.combine_scores(vector_results, text_results)

        # doc_1: 0.9 * 0.7 = 0.63; doc_2: 0.8 * 0.3 = 0.24
        self._assert_ranked(combined, [("doc_1", 0.63), ("doc_2", 0.24)])

    def test_combine_scores_with_empty_vector_results(self, service, sample_documents):
        """With no vector results, scores come from the text side only."""
        doc1 = sample_documents[0]

        text_results = [SearchResult(document=doc1, score=0.8, rank=0)]

        combined = service.combine_scores([], text_results)

        # doc_1: 0.8 * 0.3 = 0.24
        self._assert_ranked(combined, [("doc_1", 0.24)])

    def test_combine_scores_with_empty_text_results(self, service, sample_documents):
        """With no text results, scores come from the vector side only."""
        doc1 = sample_documents[0]

        vector_results = [SearchResult(document=doc1, score=0.9, rank=0)]

        combined = service.combine_scores(vector_results, [])

        # doc_1: 0.9 * 0.7 = 0.63
        self._assert_ranked(combined, [("doc_1", 0.63)])

    def test_combine_scores_with_both_empty(self, service):
        """Two empty inputs produce an empty result."""
        combined = service.combine_scores([], [])
        assert len(combined) == 0

    def test_combine_scores_with_invalid_vector_results(self, service):
        """A non-list vector_results argument is rejected with ValueError."""
        with pytest.raises(ValueError, match="vector_results must be a list"):
            service.combine_scores("not a list", [])

    def test_combine_scores_with_invalid_text_results(self, service):
        """A non-list text_results argument is rejected with ValueError."""
        with pytest.raises(ValueError, match="text_results must be a list"):
            service.combine_scores([], "not a list")

    def test_combine_scores_with_non_search_result_items(self, service, sample_documents):
        """A list element that is not a SearchResult is rejected with ValueError."""
        doc1 = sample_documents[0]

        vector_results = [
            SearchResult(document=doc1, score=0.9, rank=0),
            "not a search result",
        ]

        with pytest.raises(ValueError, match="must be a SearchResult instance"):
            service.combine_scores(vector_results, [])
class TestHybridSearchServiceNormalizeScores:
    """Tests for HybridSearchService.normalize_scores (rescaling scores into a target range)."""

    # Same absolute tolerance the original hand-written ``abs(a - b) < 0.001`` checks used.
    SCORE_TOLERANCE = 1e-3

    @pytest.fixture
    def service(self):
        """Create a service with its default weights."""
        return HybridSearchService()

    @staticmethod
    def _make_document(doc_id, content):
        """Build a Document with no embedding and empty metadata."""
        return Document(
            id=EntityId(doc_id),
            content=content,
            embedding=None,
            metadata={},
            created_at=Timestamp.now(),
            updated_at=Timestamp.now(),
        )

    @pytest.fixture
    def sample_documents(self):
        """Create three minimal documents (doc_1, doc_2, doc_3)."""
        return (
            self._make_document("doc_1", "Test 1"),
            self._make_document("doc_2", "Test 2"),
            self._make_document("doc_3", "Test 3"),
        )

    @staticmethod
    def _ranked(documents, scores):
        """Wrap the first len(scores) *documents* in SearchResults with ranks 0..n-1."""
        return [
            SearchResult(document=doc, score=score, rank=rank)
            for rank, (doc, score) in enumerate(zip(documents, scores))
        ]

    @classmethod
    def _assert_scores(cls, results, expected):
        """Assert *results* carry exactly the *expected* scores, in order."""
        assert len(results) == len(expected)
        for result, score in zip(results, expected):
            assert result.score == pytest.approx(score, abs=cls.SCORE_TOLERANCE)

    def test_normalize_scores_basic(self, service, sample_documents):
        """Scores 10 / 5 / 0 map linearly onto [0, 1]: max -> 1.0, min -> 0.0."""
        results = self._ranked(sample_documents, [10.0, 5.0, 0.0])

        normalized = service.normalize_scores(results, min_score=0.0, max_score=1.0)

        self._assert_scores(normalized, [1.0, 0.5, 0.0])

    def test_normalize_scores_custom_range(self, service, sample_documents):
        """With target range [0.5, 1.0] the highest score maps to 1.0 and the lowest to 0.5."""
        results = self._ranked(sample_documents, [100.0, 50.0])

        normalized = service.normalize_scores(results, min_score=0.5, max_score=1.0)

        self._assert_scores(normalized, [1.0, 0.5])

    def test_normalize_scores_single_result(self, service, sample_documents):
        """A single result is expected to receive max_score.

        NOTE(review): a single result also has min == max, yet
        test_normalize_scores_all_same expects the midpoint of the range in
        that situation. Both tests can only pass if normalize_scores
        special-cases a one-element list - confirm against the implementation.
        """
        results = self._ranked(sample_documents, [5.0])

        normalized = service.normalize_scores(results, min_score=0.0, max_score=1.0)

        assert len(normalized) == 1
        assert normalized[0].score == 1.0

    def test_normalize_scores_all_same(self, service, sample_documents):
        """When all scores are equal, each result gets the average of min_score and max_score."""
        results = self._ranked(sample_documents, [5.0, 5.0, 5.0])

        normalized = service.normalize_scores(results, min_score=0.0, max_score=1.0)

        self._assert_scores(normalized, [0.5, 0.5, 0.5])

    def test_normalize_scores_empty_list(self, service):
        """Normalizing an empty list returns an empty result."""
        normalized = service.normalize_scores([], min_score=0.0, max_score=1.0)
        assert len(normalized) == 0

    def test_normalize_scores_invalid_range(self, service, sample_documents):
        """min_score >= max_score is rejected with ValueError."""
        results = self._ranked(sample_documents, [5.0])

        with pytest.raises(ValueError, match="min_score must be less than max_score"):
            service.normalize_scores(results, min_score=1.0, max_score=0.0)
class TestHybridSearchServiceFilterByScoreThreshold:
    """Tests for HybridSearchService.filter_by_score_threshold."""

    @pytest.fixture
    def service(self):
        """Provide a service built with its default weights."""
        return HybridSearchService()

    @staticmethod
    def _blank_document(doc_id, text):
        """Return a Document carrying only an id and content (no embedding, empty metadata)."""
        return Document(
            id=EntityId(doc_id),
            content=text,
            embedding=None,
            metadata={},
            created_at=Timestamp.now(),
            updated_at=Timestamp.now(),
        )

    @pytest.fixture
    def sample_documents(self):
        """Provide the documents doc_1, doc_2 and doc_3 as a tuple."""
        specs = (("doc_1", "Test 1"), ("doc_2", "Test 2"), ("doc_3", "Test 3"))
        return tuple(self._blank_document(doc_id, text) for doc_id, text in specs)

    @staticmethod
    def _ranked(documents, scores):
        """Pair the leading *documents* with *scores*, assigning ranks 0, 1, 2, ... in order."""
        return [
            SearchResult(document=document, score=score, rank=position)
            for position, (document, score) in enumerate(zip(documents, scores))
        ]

    def test_filter_by_score_threshold_basic(self, service, sample_documents):
        """Results below the threshold are dropped; survivors keep their order and rank."""
        candidates = self._ranked(sample_documents, (0.9, 0.7, 0.5))

        kept = service.filter_by_score_threshold(candidates, threshold=0.6)

        assert len(kept) == 2
        for expected_rank, expected_id in enumerate(("doc_1", "doc_2")):
            assert kept[expected_rank].document.id == EntityId(expected_id)
            assert kept[expected_rank].rank == expected_rank

    def test_filter_by_score_threshold_all_pass(self, service, sample_documents):
        """Every result at or above the threshold is retained."""
        candidates = self._ranked(sample_documents, (0.9, 0.8))

        kept = service.filter_by_score_threshold(candidates, threshold=0.5)

        assert len(kept) == 2

    def test_filter_by_score_threshold_none_pass(self, service, sample_documents):
        """When every score is below the threshold nothing is returned."""
        candidates = self._ranked(sample_documents, (0.4, 0.3))

        kept = service.filter_by_score_threshold(candidates, threshold=0.5)

        assert len(kept) == 0

    def test_filter_by_score_threshold_exact_match(self, service, sample_documents):
        """A score exactly equal to the threshold is kept (the bound is inclusive)."""
        candidates = self._ranked(sample_documents, (0.6,))

        kept = service.filter_by_score_threshold(candidates, threshold=0.6)

        assert len(kept) == 1
class TestHybridSearchServiceDeduplicateResults:
    """Tests for HybridSearchService.deduplicate_results and its ``keep`` strategies."""

    @pytest.fixture
    def service(self):
        """Provide a service built with its default weights."""
        return HybridSearchService()

    @pytest.fixture
    def sample_documents(self):
        """Provide two minimal documents, doc_1 and doc_2."""

        def build(doc_id, text):
            return Document(
                id=EntityId(doc_id),
                content=text,
                embedding=None,
                metadata={},
                created_at=Timestamp.now(),
                updated_at=Timestamp.now(),
            )

        return build("doc_1", "Test 1"), build("doc_2", "Test 2")

    @staticmethod
    def _with_repeat(first, second):
        """Return three results in which *first* occurs twice (0.9, then 0.7) around *second*."""
        return [
            SearchResult(document=first, score=0.9, rank=0),
            SearchResult(document=second, score=0.8, rank=1),
            SearchResult(document=first, score=0.7, rank=2),  # repeat of `first`
        ]

    @staticmethod
    def _check_two_survivors(deduped, doc1_score):
        """Assert only doc_1 (carrying *doc1_score*) followed by doc_2 remain."""
        assert len(deduped) == 2
        assert deduped[0].document.id == EntityId("doc_1")
        assert deduped[0].score == doc1_score
        assert deduped[1].document.id == EntityId("doc_2")

    def test_deduplicate_results_keep_highest(self, service, sample_documents):
        """keep="highest" retains the higher-scoring (0.9) copy of doc_1."""
        candidates = self._with_repeat(*sample_documents)
        deduped = service.deduplicate_results(candidates, keep="highest")
        self._check_two_survivors(deduped, 0.9)

    def test_deduplicate_results_keep_lowest(self, service, sample_documents):
        """keep="lowest" retains the lower-scoring (0.7) copy of doc_1."""
        candidates = self._with_repeat(*sample_documents)
        deduped = service.deduplicate_results(candidates, keep="lowest")
        self._check_two_survivors(deduped, 0.7)

    def test_deduplicate_results_keep_first(self, service, sample_documents):
        """keep="first" retains the earliest occurrence of doc_1 (score 0.9)."""
        candidates = self._with_repeat(*sample_documents)
        deduped = service.deduplicate_results(candidates, keep="first")
        self._check_two_survivors(deduped, 0.9)

    def test_deduplicate_results_no_duplicates(self, service, sample_documents):
        """A list without repeated documents passes through with the same length."""
        first, second = sample_documents
        unique = [
            SearchResult(document=first, score=0.9, rank=0),
            SearchResult(document=second, score=0.8, rank=1),
        ]

        deduped = service.deduplicate_results(unique, keep="highest")

        assert len(deduped) == 2

    def test_deduplicate_results_invalid_keep_strategy(self, service, sample_documents):
        """An unknown ``keep`` value is rejected with ValueError."""
        first, _ = sample_documents
        lone = [SearchResult(document=first, score=0.9, rank=0)]

        with pytest.raises(ValueError, match="keep must be one of"):
            service.deduplicate_results(lone, keep="invalid")
class TestHybridSearchServiceStringRepresentation:
    """Tests for the str() and repr() forms of HybridSearchService."""

    # Substrings both renderings must contain for a 0.6 / 0.4 service.
    _EXPECTED_TOKENS = ("HybridSearchService", "0.6", "0.4")

    @staticmethod
    def _make_service():
        """Build the service whose string forms are inspected."""
        return HybridSearchService(vector_weight=0.6, text_weight=0.4)

    def test_str_representation(self):
        """str(service) names the class and shows both weights."""
        rendered = str(self._make_service())
        for token in self._EXPECTED_TOKENS:
            assert token in rendered

    def test_repr_representation(self):
        """repr(service) names the class and shows both weights."""
        rendered = repr(self._make_service())
        for token in self._EXPECTED_TOKENS:
            assert token in rendered
|