""" 混合搜索服务单元测试 测试 HybridSearchService 的核心功能,包括: - 分数合并算法 - 结果排序 - 归一化 - 过滤和去重 """ import pytest from src.domain.shared.value_objects import EntityId, Timestamp from src.domain.vector_search.entities import Document, SearchResult from src.domain.vector_search.exceptions import VectorSearchException from src.domain.vector_search.services import HybridSearchService from src.domain.vector_search.value_objects import Vector class TestHybridSearchServiceInitialization: """测试混合搜索服务的初始化""" def test_default_initialization(self): """测试使用默认权重初始化""" service = HybridSearchService() assert service.vector_weight == 0.7 assert service.text_weight == 0.3 def test_custom_initialization(self): """测试使用自定义权重初始化""" service = HybridSearchService(vector_weight=0.6, text_weight=0.4) assert service.vector_weight == 0.6 assert service.text_weight == 0.4 def test_initialization_with_negative_vector_weight(self): """测试使用负向量权重初始化应该失败""" with pytest.raises(ValueError, match="vector_weight cannot be negative"): HybridSearchService(vector_weight=-0.1, text_weight=0.3) def test_initialization_with_negative_text_weight(self): """测试使用负全文权重初始化应该失败""" with pytest.raises(ValueError, match="text_weight cannot be negative"): HybridSearchService(vector_weight=0.7, text_weight=-0.1) def test_initialization_with_both_weights_zero(self): """测试两个权重都为零应该失败""" with pytest.raises(ValueError, match="At least one weight must be greater than zero"): HybridSearchService(vector_weight=0.0, text_weight=0.0) def test_initialization_with_non_numeric_vector_weight(self): """测试使用非数值向量权重应该失败""" with pytest.raises(ValueError, match="vector_weight must be a number"): HybridSearchService(vector_weight="0.7", text_weight=0.3) def test_initialization_with_non_numeric_text_weight(self): """测试使用非数值全文权重应该失败""" with pytest.raises(ValueError, match="text_weight must be a number"): HybridSearchService(vector_weight=0.7, text_weight="0.3") def test_get_weights(self): """测试获取权重配置""" service = HybridSearchService(vector_weight=0.6, text_weight=0.4) vector_w, text_w = service.get_weights() assert vector_w == 0.6 assert text_w == 0.4 class TestHybridSearchServiceCombineScores: """测试混合搜索服务的分数合并功能""" @pytest.fixture def service(self): """创建混合搜索服务实例""" return HybridSearchService(vector_weight=0.7, text_weight=0.3) @pytest.fixture def sample_documents(self): """创建示例文档""" doc1 = Document( id=EntityId("doc_1"), content="Machine learning is a subset of AI", embedding=Vector([1.0, 2.0, 3.0]), metadata={"category": "AI"}, created_at=Timestamp.now(), updated_at=Timestamp.now() ) doc2 = Document( id=EntityId("doc_2"), content="Deep learning uses neural networks", embedding=Vector([2.0, 3.0, 4.0]), metadata={"category": "AI"}, created_at=Timestamp.now(), updated_at=Timestamp.now() ) doc3 = Document( id=EntityId("doc_3"), content="Natural language processing", embedding=Vector([3.0, 4.0, 5.0]), metadata={"category": "NLP"}, created_at=Timestamp.now(), updated_at=Timestamp.now() ) return doc1, doc2, doc3 def test_combine_scores_with_overlapping_results(self, service, sample_documents): """测试合并有重叠文档的搜索结果""" doc1, doc2, doc3 = sample_documents # 向量搜索结果 vector_results = [ SearchResult(document=doc1, score=0.9, rank=0), SearchResult(document=doc2, score=0.8, rank=1) ] # 全文搜索结果 text_results = [ SearchResult(document=doc1, score=0.7, rank=0), SearchResult(document=doc3, score=0.6, rank=1) ] # 合并结果 combined = service.combine_scores(vector_results, text_results) # 验证结果数量 assert len(combined) == 3 # 验证分数计算 # doc1: 0.9 * 0.7 + 0.7 * 0.3 = 0.63 + 0.21 = 0.84 # doc2: 0.8 * 0.7 = 0.56 # doc3: 0.6 * 0.3 = 0.18 assert combined[0].document.id == EntityId("doc_1") assert abs(combined[0].score - 0.84) < 0.001 assert combined[0].rank == 0 assert combined[1].document.id == EntityId("doc_2") assert abs(combined[1].score - 0.56) < 0.001 assert combined[1].rank == 1 assert combined[2].document.id == EntityId("doc_3") assert abs(combined[2].score - 0.18) < 0.001 assert combined[2].rank == 2 def test_combine_scores_with_no_overlap(self, service, sample_documents): """测试合并没有重叠文档的搜索结果""" doc1, doc2, doc3 = sample_documents vector_results = [ SearchResult(document=doc1, score=0.9, rank=0) ] text_results = [ SearchResult(document=doc2, score=0.8, rank=0) ] combined = service.combine_scores(vector_results, text_results) assert len(combined) == 2 assert combined[0].document.id == EntityId("doc_1") assert abs(combined[0].score - 0.63) < 0.001 # 0.9 * 0.7 assert combined[1].document.id == EntityId("doc_2") assert abs(combined[1].score - 0.24) < 0.001 # 0.8 * 0.3 def test_combine_scores_with_empty_vector_results(self, service, sample_documents): """测试向量搜索结果为空的情况""" doc1, doc2, doc3 = sample_documents vector_results = [] text_results = [ SearchResult(document=doc1, score=0.8, rank=0) ] combined = service.combine_scores(vector_results, text_results) assert len(combined) == 1 assert combined[0].document.id == EntityId("doc_1") assert abs(combined[0].score - 0.24) < 0.001 # 0.8 * 0.3 def test_combine_scores_with_empty_text_results(self, service, sample_documents): """测试全文搜索结果为空的情况""" doc1, doc2, doc3 = sample_documents vector_results = [ SearchResult(document=doc1, score=0.9, rank=0) ] text_results = [] combined = service.combine_scores(vector_results, text_results) assert len(combined) == 1 assert combined[0].document.id == EntityId("doc_1") assert abs(combined[0].score - 0.63) < 0.001 # 0.9 * 0.7 def test_combine_scores_with_both_empty(self, service): """测试两个搜索结果都为空的情况""" combined = service.combine_scores([], []) assert len(combined) == 0 def test_combine_scores_with_invalid_vector_results(self, service): """测试无效的向量搜索结果""" with pytest.raises(ValueError, match="vector_results must be a list"): service.combine_scores("not a list", []) def test_combine_scores_with_invalid_text_results(self, service): """测试无效的全文搜索结果""" with pytest.raises(ValueError, match="text_results must be a list"): service.combine_scores([], "not a list") def test_combine_scores_with_non_search_result_items(self, service, sample_documents): """测试列表中包含非 SearchResult 对象""" doc1, doc2, doc3 = sample_documents vector_results = [ SearchResult(document=doc1, score=0.9, rank=0), "not a search result" ] with pytest.raises(ValueError, match="must be a SearchResult instance"): service.combine_scores(vector_results, []) class TestHybridSearchServiceNormalizeScores: """测试混合搜索服务的分数归一化功能""" @pytest.fixture def service(self): """创建混合搜索服务实例""" return HybridSearchService() @pytest.fixture def sample_documents(self): """创建示例文档""" doc1 = Document( id=EntityId("doc_1"), content="Test 1", embedding=None, metadata={}, created_at=Timestamp.now(), updated_at=Timestamp.now() ) doc2 = Document( id=EntityId("doc_2"), content="Test 2", embedding=None, metadata={}, created_at=Timestamp.now(), updated_at=Timestamp.now() ) doc3 = Document( id=EntityId("doc_3"), content="Test 3", embedding=None, metadata={}, created_at=Timestamp.now(), updated_at=Timestamp.now() ) return doc1, doc2, doc3 def test_normalize_scores_basic(self, service, sample_documents): """测试基本的分数归一化""" doc1, doc2, doc3 = sample_documents results = [ SearchResult(document=doc1, score=10.0, rank=0), SearchResult(document=doc2, score=5.0, rank=1), SearchResult(document=doc3, score=0.0, rank=2) ] normalized = service.normalize_scores(results, min_score=0.0, max_score=1.0) assert len(normalized) == 3 assert abs(normalized[0].score - 1.0) < 0.001 assert abs(normalized[1].score - 0.5) < 0.001 assert abs(normalized[2].score - 0.0) < 0.001 def test_normalize_scores_custom_range(self, service, sample_documents): """测试自定义范围的归一化""" doc1, doc2, doc3 = sample_documents results = [ SearchResult(document=doc1, score=100.0, rank=0), SearchResult(document=doc2, score=50.0, rank=1) ] normalized = service.normalize_scores(results, min_score=0.5, max_score=1.0) assert abs(normalized[0].score - 1.0) < 0.001 assert abs(normalized[1].score - 0.5) < 0.001 def test_normalize_scores_single_result(self, service, sample_documents): """测试单个结果的归一化""" doc1, doc2, doc3 = sample_documents results = [ SearchResult(document=doc1, score=5.0, rank=0) ] normalized = service.normalize_scores(results, min_score=0.0, max_score=1.0) assert len(normalized) == 1 assert normalized[0].score == 1.0 # 单个结果应该得到最大分数 def test_normalize_scores_all_same(self, service, sample_documents): """测试所有分数相同的情况""" doc1, doc2, doc3 = sample_documents results = [ SearchResult(document=doc1, score=5.0, rank=0), SearchResult(document=doc2, score=5.0, rank=1), SearchResult(document=doc3, score=5.0, rank=2) ] normalized = service.normalize_scores(results, min_score=0.0, max_score=1.0) # 所有分数相同时,应该得到平均分数 for result in normalized: assert abs(result.score - 0.5) < 0.001 def test_normalize_scores_empty_list(self, service): """测试空列表的归一化""" normalized = service.normalize_scores([], min_score=0.0, max_score=1.0) assert len(normalized) == 0 def test_normalize_scores_invalid_range(self, service, sample_documents): """测试无效的归一化范围""" doc1, doc2, doc3 = sample_documents results = [ SearchResult(document=doc1, score=5.0, rank=0) ] with pytest.raises(ValueError, match="min_score must be less than max_score"): service.normalize_scores(results, min_score=1.0, max_score=0.0) class TestHybridSearchServiceFilterByScoreThreshold: """测试混合搜索服务的分数阈值过滤功能""" @pytest.fixture def service(self): """创建混合搜索服务实例""" return HybridSearchService() @pytest.fixture def sample_documents(self): """创建示例文档""" doc1 = Document( id=EntityId("doc_1"), content="Test 1", embedding=None, metadata={}, created_at=Timestamp.now(), updated_at=Timestamp.now() ) doc2 = Document( id=EntityId("doc_2"), content="Test 2", embedding=None, metadata={}, created_at=Timestamp.now(), updated_at=Timestamp.now() ) doc3 = Document( id=EntityId("doc_3"), content="Test 3", embedding=None, metadata={}, created_at=Timestamp.now(), updated_at=Timestamp.now() ) return doc1, doc2, doc3 def test_filter_by_score_threshold_basic(self, service, sample_documents): """测试基本的分数过滤""" doc1, doc2, doc3 = sample_documents results = [ SearchResult(document=doc1, score=0.9, rank=0), SearchResult(document=doc2, score=0.7, rank=1), SearchResult(document=doc3, score=0.5, rank=2) ] filtered = service.filter_by_score_threshold(results, threshold=0.6) assert len(filtered) == 2 assert filtered[0].document.id == EntityId("doc_1") assert filtered[0].rank == 0 assert filtered[1].document.id == EntityId("doc_2") assert filtered[1].rank == 1 def test_filter_by_score_threshold_all_pass(self, service, sample_documents): """测试所有结果都通过阈值""" doc1, doc2, doc3 = sample_documents results = [ SearchResult(document=doc1, score=0.9, rank=0), SearchResult(document=doc2, score=0.8, rank=1) ] filtered = service.filter_by_score_threshold(results, threshold=0.5) assert len(filtered) == 2 def test_filter_by_score_threshold_none_pass(self, service, sample_documents): """测试没有结果通过阈值""" doc1, doc2, doc3 = sample_documents results = [ SearchResult(document=doc1, score=0.4, rank=0), SearchResult(document=doc2, score=0.3, rank=1) ] filtered = service.filter_by_score_threshold(results, threshold=0.5) assert len(filtered) == 0 def test_filter_by_score_threshold_exact_match(self, service, sample_documents): """测试分数正好等于阈值的情况""" doc1, doc2, doc3 = sample_documents results = [ SearchResult(document=doc1, score=0.6, rank=0) ] filtered = service.filter_by_score_threshold(results, threshold=0.6) assert len(filtered) == 1 class TestHybridSearchServiceDeduplicateResults: """测试混合搜索服务的去重功能""" @pytest.fixture def service(self): """创建混合搜索服务实例""" return HybridSearchService() @pytest.fixture def sample_documents(self): """创建示例文档""" doc1 = Document( id=EntityId("doc_1"), content="Test 1", embedding=None, metadata={}, created_at=Timestamp.now(), updated_at=Timestamp.now() ) doc2 = Document( id=EntityId("doc_2"), content="Test 2", embedding=None, metadata={}, created_at=Timestamp.now(), updated_at=Timestamp.now() ) return doc1, doc2 def test_deduplicate_results_keep_highest(self, service, sample_documents): """测试保留最高分数的去重""" doc1, doc2 = sample_documents results = [ SearchResult(document=doc1, score=0.9, rank=0), SearchResult(document=doc2, score=0.8, rank=1), SearchResult(document=doc1, score=0.7, rank=2) # 重复 ] deduped = service.deduplicate_results(results, keep="highest") assert len(deduped) == 2 assert deduped[0].document.id == EntityId("doc_1") assert deduped[0].score == 0.9 # 保留最高分 assert deduped[1].document.id == EntityId("doc_2") def test_deduplicate_results_keep_lowest(self, service, sample_documents): """测试保留最低分数的去重""" doc1, doc2 = sample_documents results = [ SearchResult(document=doc1, score=0.9, rank=0), SearchResult(document=doc2, score=0.8, rank=1), SearchResult(document=doc1, score=0.7, rank=2) # 重复 ] deduped = service.deduplicate_results(results, keep="lowest") assert len(deduped) == 2 assert deduped[0].document.id == EntityId("doc_1") assert deduped[0].score == 0.7 # 保留最低分 assert deduped[1].document.id == EntityId("doc_2") def test_deduplicate_results_keep_first(self, service, sample_documents): """测试保留第一次出现的去重""" doc1, doc2 = sample_documents results = [ SearchResult(document=doc1, score=0.9, rank=0), SearchResult(document=doc2, score=0.8, rank=1), SearchResult(document=doc1, score=0.7, rank=2) # 重复 ] deduped = service.deduplicate_results(results, keep="first") assert len(deduped) == 2 assert deduped[0].document.id == EntityId("doc_1") assert deduped[0].score == 0.9 # 保留第一次出现的 assert deduped[1].document.id == EntityId("doc_2") def test_deduplicate_results_no_duplicates(self, service, sample_documents): """测试没有重复的情况""" doc1, doc2 = sample_documents results = [ SearchResult(document=doc1, score=0.9, rank=0), SearchResult(document=doc2, score=0.8, rank=1) ] deduped = service.deduplicate_results(results, keep="highest") assert len(deduped) == 2 def test_deduplicate_results_invalid_keep_strategy(self, service, sample_documents): """测试无效的保留策略""" doc1, doc2 = sample_documents results = [ SearchResult(document=doc1, score=0.9, rank=0) ] with pytest.raises(ValueError, match="keep must be one of"): service.deduplicate_results(results, keep="invalid") class TestHybridSearchServiceStringRepresentation: """测试混合搜索服务的字符串表示""" def test_str_representation(self): """测试 __str__ 方法""" service = HybridSearchService(vector_weight=0.6, text_weight=0.4) str_repr = str(service) assert "HybridSearchService" in str_repr assert "0.6" in str_repr assert "0.4" in str_repr def test_repr_representation(self): """测试 __repr__ 方法""" service = HybridSearchService(vector_weight=0.6, text_weight=0.4) repr_str = repr(service) assert "HybridSearchService" in repr_str assert "0.6" in repr_str assert "0.4" in repr_str