test_services.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
"""
Unit tests for the hybrid search service.

Exercises the core behavior of HybridSearchService, including:
- the score-combination algorithm (weighted merge of vector and full-text results)
- result ordering
- score normalization
- threshold filtering and de-duplication
plus constructor validation and the string representations.
"""
  9. import pytest
  10. from src.domain.shared.value_objects import EntityId, Timestamp
  11. from src.domain.vector_search.entities import Document, SearchResult
  12. from src.domain.vector_search.exceptions import VectorSearchException
  13. from src.domain.vector_search.services import HybridSearchService
  14. from src.domain.vector_search.value_objects import Vector
  15. class TestHybridSearchServiceInitialization:
  16. """测试混合搜索服务的初始化"""
  17. def test_default_initialization(self):
  18. """测试使用默认权重初始化"""
  19. service = HybridSearchService()
  20. assert service.vector_weight == 0.7
  21. assert service.text_weight == 0.3
  22. def test_custom_initialization(self):
  23. """测试使用自定义权重初始化"""
  24. service = HybridSearchService(vector_weight=0.6, text_weight=0.4)
  25. assert service.vector_weight == 0.6
  26. assert service.text_weight == 0.4
  27. def test_initialization_with_negative_vector_weight(self):
  28. """测试使用负向量权重初始化应该失败"""
  29. with pytest.raises(ValueError, match="vector_weight cannot be negative"):
  30. HybridSearchService(vector_weight=-0.1, text_weight=0.3)
  31. def test_initialization_with_negative_text_weight(self):
  32. """测试使用负全文权重初始化应该失败"""
  33. with pytest.raises(ValueError, match="text_weight cannot be negative"):
  34. HybridSearchService(vector_weight=0.7, text_weight=-0.1)
  35. def test_initialization_with_both_weights_zero(self):
  36. """测试两个权重都为零应该失败"""
  37. with pytest.raises(ValueError, match="At least one weight must be greater than zero"):
  38. HybridSearchService(vector_weight=0.0, text_weight=0.0)
  39. def test_initialization_with_non_numeric_vector_weight(self):
  40. """测试使用非数值向量权重应该失败"""
  41. with pytest.raises(ValueError, match="vector_weight must be a number"):
  42. HybridSearchService(vector_weight="0.7", text_weight=0.3)
  43. def test_initialization_with_non_numeric_text_weight(self):
  44. """测试使用非数值全文权重应该失败"""
  45. with pytest.raises(ValueError, match="text_weight must be a number"):
  46. HybridSearchService(vector_weight=0.7, text_weight="0.3")
  47. def test_get_weights(self):
  48. """测试获取权重配置"""
  49. service = HybridSearchService(vector_weight=0.6, text_weight=0.4)
  50. vector_w, text_w = service.get_weights()
  51. assert vector_w == 0.6
  52. assert text_w == 0.4
  53. class TestHybridSearchServiceCombineScores:
  54. """测试混合搜索服务的分数合并功能"""
  55. @pytest.fixture
  56. def service(self):
  57. """创建混合搜索服务实例"""
  58. return HybridSearchService(vector_weight=0.7, text_weight=0.3)
  59. @pytest.fixture
  60. def sample_documents(self):
  61. """创建示例文档"""
  62. doc1 = Document(
  63. id=EntityId("doc_1"),
  64. content="Machine learning is a subset of AI",
  65. embedding=Vector([1.0, 2.0, 3.0]),
  66. metadata={"category": "AI"},
  67. created_at=Timestamp.now(),
  68. updated_at=Timestamp.now()
  69. )
  70. doc2 = Document(
  71. id=EntityId("doc_2"),
  72. content="Deep learning uses neural networks",
  73. embedding=Vector([2.0, 3.0, 4.0]),
  74. metadata={"category": "AI"},
  75. created_at=Timestamp.now(),
  76. updated_at=Timestamp.now()
  77. )
  78. doc3 = Document(
  79. id=EntityId("doc_3"),
  80. content="Natural language processing",
  81. embedding=Vector([3.0, 4.0, 5.0]),
  82. metadata={"category": "NLP"},
  83. created_at=Timestamp.now(),
  84. updated_at=Timestamp.now()
  85. )
  86. return doc1, doc2, doc3
  87. def test_combine_scores_with_overlapping_results(self, service, sample_documents):
  88. """测试合并有重叠文档的搜索结果"""
  89. doc1, doc2, doc3 = sample_documents
  90. # 向量搜索结果
  91. vector_results = [
  92. SearchResult(document=doc1, score=0.9, rank=0),
  93. SearchResult(document=doc2, score=0.8, rank=1)
  94. ]
  95. # 全文搜索结果
  96. text_results = [
  97. SearchResult(document=doc1, score=0.7, rank=0),
  98. SearchResult(document=doc3, score=0.6, rank=1)
  99. ]
  100. # 合并结果
  101. combined = service.combine_scores(vector_results, text_results)
  102. # 验证结果数量
  103. assert len(combined) == 3
  104. # 验证分数计算
  105. # doc1: 0.9 * 0.7 + 0.7 * 0.3 = 0.63 + 0.21 = 0.84
  106. # doc2: 0.8 * 0.7 = 0.56
  107. # doc3: 0.6 * 0.3 = 0.18
  108. assert combined[0].document.id == EntityId("doc_1")
  109. assert abs(combined[0].score - 0.84) < 0.001
  110. assert combined[0].rank == 0
  111. assert combined[1].document.id == EntityId("doc_2")
  112. assert abs(combined[1].score - 0.56) < 0.001
  113. assert combined[1].rank == 1
  114. assert combined[2].document.id == EntityId("doc_3")
  115. assert abs(combined[2].score - 0.18) < 0.001
  116. assert combined[2].rank == 2
  117. def test_combine_scores_with_no_overlap(self, service, sample_documents):
  118. """测试合并没有重叠文档的搜索结果"""
  119. doc1, doc2, doc3 = sample_documents
  120. vector_results = [
  121. SearchResult(document=doc1, score=0.9, rank=0)
  122. ]
  123. text_results = [
  124. SearchResult(document=doc2, score=0.8, rank=0)
  125. ]
  126. combined = service.combine_scores(vector_results, text_results)
  127. assert len(combined) == 2
  128. assert combined[0].document.id == EntityId("doc_1")
  129. assert abs(combined[0].score - 0.63) < 0.001 # 0.9 * 0.7
  130. assert combined[1].document.id == EntityId("doc_2")
  131. assert abs(combined[1].score - 0.24) < 0.001 # 0.8 * 0.3
  132. def test_combine_scores_with_empty_vector_results(self, service, sample_documents):
  133. """测试向量搜索结果为空的情况"""
  134. doc1, doc2, doc3 = sample_documents
  135. vector_results = []
  136. text_results = [
  137. SearchResult(document=doc1, score=0.8, rank=0)
  138. ]
  139. combined = service.combine_scores(vector_results, text_results)
  140. assert len(combined) == 1
  141. assert combined[0].document.id == EntityId("doc_1")
  142. assert abs(combined[0].score - 0.24) < 0.001 # 0.8 * 0.3
  143. def test_combine_scores_with_empty_text_results(self, service, sample_documents):
  144. """测试全文搜索结果为空的情况"""
  145. doc1, doc2, doc3 = sample_documents
  146. vector_results = [
  147. SearchResult(document=doc1, score=0.9, rank=0)
  148. ]
  149. text_results = []
  150. combined = service.combine_scores(vector_results, text_results)
  151. assert len(combined) == 1
  152. assert combined[0].document.id == EntityId("doc_1")
  153. assert abs(combined[0].score - 0.63) < 0.001 # 0.9 * 0.7
  154. def test_combine_scores_with_both_empty(self, service):
  155. """测试两个搜索结果都为空的情况"""
  156. combined = service.combine_scores([], [])
  157. assert len(combined) == 0
  158. def test_combine_scores_with_invalid_vector_results(self, service):
  159. """测试无效的向量搜索结果"""
  160. with pytest.raises(ValueError, match="vector_results must be a list"):
  161. service.combine_scores("not a list", [])
  162. def test_combine_scores_with_invalid_text_results(self, service):
  163. """测试无效的全文搜索结果"""
  164. with pytest.raises(ValueError, match="text_results must be a list"):
  165. service.combine_scores([], "not a list")
  166. def test_combine_scores_with_non_search_result_items(self, service, sample_documents):
  167. """测试列表中包含非 SearchResult 对象"""
  168. doc1, doc2, doc3 = sample_documents
  169. vector_results = [
  170. SearchResult(document=doc1, score=0.9, rank=0),
  171. "not a search result"
  172. ]
  173. with pytest.raises(ValueError, match="must be a SearchResult instance"):
  174. service.combine_scores(vector_results, [])
  175. class TestHybridSearchServiceNormalizeScores:
  176. """测试混合搜索服务的分数归一化功能"""
  177. @pytest.fixture
  178. def service(self):
  179. """创建混合搜索服务实例"""
  180. return HybridSearchService()
  181. @pytest.fixture
  182. def sample_documents(self):
  183. """创建示例文档"""
  184. doc1 = Document(
  185. id=EntityId("doc_1"),
  186. content="Test 1",
  187. embedding=None,
  188. metadata={},
  189. created_at=Timestamp.now(),
  190. updated_at=Timestamp.now()
  191. )
  192. doc2 = Document(
  193. id=EntityId("doc_2"),
  194. content="Test 2",
  195. embedding=None,
  196. metadata={},
  197. created_at=Timestamp.now(),
  198. updated_at=Timestamp.now()
  199. )
  200. doc3 = Document(
  201. id=EntityId("doc_3"),
  202. content="Test 3",
  203. embedding=None,
  204. metadata={},
  205. created_at=Timestamp.now(),
  206. updated_at=Timestamp.now()
  207. )
  208. return doc1, doc2, doc3
  209. def test_normalize_scores_basic(self, service, sample_documents):
  210. """测试基本的分数归一化"""
  211. doc1, doc2, doc3 = sample_documents
  212. results = [
  213. SearchResult(document=doc1, score=10.0, rank=0),
  214. SearchResult(document=doc2, score=5.0, rank=1),
  215. SearchResult(document=doc3, score=0.0, rank=2)
  216. ]
  217. normalized = service.normalize_scores(results, min_score=0.0, max_score=1.0)
  218. assert len(normalized) == 3
  219. assert abs(normalized[0].score - 1.0) < 0.001
  220. assert abs(normalized[1].score - 0.5) < 0.001
  221. assert abs(normalized[2].score - 0.0) < 0.001
  222. def test_normalize_scores_custom_range(self, service, sample_documents):
  223. """测试自定义范围的归一化"""
  224. doc1, doc2, doc3 = sample_documents
  225. results = [
  226. SearchResult(document=doc1, score=100.0, rank=0),
  227. SearchResult(document=doc2, score=50.0, rank=1)
  228. ]
  229. normalized = service.normalize_scores(results, min_score=0.5, max_score=1.0)
  230. assert abs(normalized[0].score - 1.0) < 0.001
  231. assert abs(normalized[1].score - 0.5) < 0.001
  232. def test_normalize_scores_single_result(self, service, sample_documents):
  233. """测试单个结果的归一化"""
  234. doc1, doc2, doc3 = sample_documents
  235. results = [
  236. SearchResult(document=doc1, score=5.0, rank=0)
  237. ]
  238. normalized = service.normalize_scores(results, min_score=0.0, max_score=1.0)
  239. assert len(normalized) == 1
  240. assert normalized[0].score == 1.0 # 单个结果应该得到最大分数
  241. def test_normalize_scores_all_same(self, service, sample_documents):
  242. """测试所有分数相同的情况"""
  243. doc1, doc2, doc3 = sample_documents
  244. results = [
  245. SearchResult(document=doc1, score=5.0, rank=0),
  246. SearchResult(document=doc2, score=5.0, rank=1),
  247. SearchResult(document=doc3, score=5.0, rank=2)
  248. ]
  249. normalized = service.normalize_scores(results, min_score=0.0, max_score=1.0)
  250. # 所有分数相同时,应该得到平均分数
  251. for result in normalized:
  252. assert abs(result.score - 0.5) < 0.001
  253. def test_normalize_scores_empty_list(self, service):
  254. """测试空列表的归一化"""
  255. normalized = service.normalize_scores([], min_score=0.0, max_score=1.0)
  256. assert len(normalized) == 0
  257. def test_normalize_scores_invalid_range(self, service, sample_documents):
  258. """测试无效的归一化范围"""
  259. doc1, doc2, doc3 = sample_documents
  260. results = [
  261. SearchResult(document=doc1, score=5.0, rank=0)
  262. ]
  263. with pytest.raises(ValueError, match="min_score must be less than max_score"):
  264. service.normalize_scores(results, min_score=1.0, max_score=0.0)
  265. class TestHybridSearchServiceFilterByScoreThreshold:
  266. """测试混合搜索服务的分数阈值过滤功能"""
  267. @pytest.fixture
  268. def service(self):
  269. """创建混合搜索服务实例"""
  270. return HybridSearchService()
  271. @pytest.fixture
  272. def sample_documents(self):
  273. """创建示例文档"""
  274. doc1 = Document(
  275. id=EntityId("doc_1"),
  276. content="Test 1",
  277. embedding=None,
  278. metadata={},
  279. created_at=Timestamp.now(),
  280. updated_at=Timestamp.now()
  281. )
  282. doc2 = Document(
  283. id=EntityId("doc_2"),
  284. content="Test 2",
  285. embedding=None,
  286. metadata={},
  287. created_at=Timestamp.now(),
  288. updated_at=Timestamp.now()
  289. )
  290. doc3 = Document(
  291. id=EntityId("doc_3"),
  292. content="Test 3",
  293. embedding=None,
  294. metadata={},
  295. created_at=Timestamp.now(),
  296. updated_at=Timestamp.now()
  297. )
  298. return doc1, doc2, doc3
  299. def test_filter_by_score_threshold_basic(self, service, sample_documents):
  300. """测试基本的分数过滤"""
  301. doc1, doc2, doc3 = sample_documents
  302. results = [
  303. SearchResult(document=doc1, score=0.9, rank=0),
  304. SearchResult(document=doc2, score=0.7, rank=1),
  305. SearchResult(document=doc3, score=0.5, rank=2)
  306. ]
  307. filtered = service.filter_by_score_threshold(results, threshold=0.6)
  308. assert len(filtered) == 2
  309. assert filtered[0].document.id == EntityId("doc_1")
  310. assert filtered[0].rank == 0
  311. assert filtered[1].document.id == EntityId("doc_2")
  312. assert filtered[1].rank == 1
  313. def test_filter_by_score_threshold_all_pass(self, service, sample_documents):
  314. """测试所有结果都通过阈值"""
  315. doc1, doc2, doc3 = sample_documents
  316. results = [
  317. SearchResult(document=doc1, score=0.9, rank=0),
  318. SearchResult(document=doc2, score=0.8, rank=1)
  319. ]
  320. filtered = service.filter_by_score_threshold(results, threshold=0.5)
  321. assert len(filtered) == 2
  322. def test_filter_by_score_threshold_none_pass(self, service, sample_documents):
  323. """测试没有结果通过阈值"""
  324. doc1, doc2, doc3 = sample_documents
  325. results = [
  326. SearchResult(document=doc1, score=0.4, rank=0),
  327. SearchResult(document=doc2, score=0.3, rank=1)
  328. ]
  329. filtered = service.filter_by_score_threshold(results, threshold=0.5)
  330. assert len(filtered) == 0
  331. def test_filter_by_score_threshold_exact_match(self, service, sample_documents):
  332. """测试分数正好等于阈值的情况"""
  333. doc1, doc2, doc3 = sample_documents
  334. results = [
  335. SearchResult(document=doc1, score=0.6, rank=0)
  336. ]
  337. filtered = service.filter_by_score_threshold(results, threshold=0.6)
  338. assert len(filtered) == 1
  339. class TestHybridSearchServiceDeduplicateResults:
  340. """测试混合搜索服务的去重功能"""
  341. @pytest.fixture
  342. def service(self):
  343. """创建混合搜索服务实例"""
  344. return HybridSearchService()
  345. @pytest.fixture
  346. def sample_documents(self):
  347. """创建示例文档"""
  348. doc1 = Document(
  349. id=EntityId("doc_1"),
  350. content="Test 1",
  351. embedding=None,
  352. metadata={},
  353. created_at=Timestamp.now(),
  354. updated_at=Timestamp.now()
  355. )
  356. doc2 = Document(
  357. id=EntityId("doc_2"),
  358. content="Test 2",
  359. embedding=None,
  360. metadata={},
  361. created_at=Timestamp.now(),
  362. updated_at=Timestamp.now()
  363. )
  364. return doc1, doc2
  365. def test_deduplicate_results_keep_highest(self, service, sample_documents):
  366. """测试保留最高分数的去重"""
  367. doc1, doc2 = sample_documents
  368. results = [
  369. SearchResult(document=doc1, score=0.9, rank=0),
  370. SearchResult(document=doc2, score=0.8, rank=1),
  371. SearchResult(document=doc1, score=0.7, rank=2) # 重复
  372. ]
  373. deduped = service.deduplicate_results(results, keep="highest")
  374. assert len(deduped) == 2
  375. assert deduped[0].document.id == EntityId("doc_1")
  376. assert deduped[0].score == 0.9 # 保留最高分
  377. assert deduped[1].document.id == EntityId("doc_2")
  378. def test_deduplicate_results_keep_lowest(self, service, sample_documents):
  379. """测试保留最低分数的去重"""
  380. doc1, doc2 = sample_documents
  381. results = [
  382. SearchResult(document=doc1, score=0.9, rank=0),
  383. SearchResult(document=doc2, score=0.8, rank=1),
  384. SearchResult(document=doc1, score=0.7, rank=2) # 重复
  385. ]
  386. deduped = service.deduplicate_results(results, keep="lowest")
  387. assert len(deduped) == 2
  388. assert deduped[0].document.id == EntityId("doc_1")
  389. assert deduped[0].score == 0.7 # 保留最低分
  390. assert deduped[1].document.id == EntityId("doc_2")
  391. def test_deduplicate_results_keep_first(self, service, sample_documents):
  392. """测试保留第一次出现的去重"""
  393. doc1, doc2 = sample_documents
  394. results = [
  395. SearchResult(document=doc1, score=0.9, rank=0),
  396. SearchResult(document=doc2, score=0.8, rank=1),
  397. SearchResult(document=doc1, score=0.7, rank=2) # 重复
  398. ]
  399. deduped = service.deduplicate_results(results, keep="first")
  400. assert len(deduped) == 2
  401. assert deduped[0].document.id == EntityId("doc_1")
  402. assert deduped[0].score == 0.9 # 保留第一次出现的
  403. assert deduped[1].document.id == EntityId("doc_2")
  404. def test_deduplicate_results_no_duplicates(self, service, sample_documents):
  405. """测试没有重复的情况"""
  406. doc1, doc2 = sample_documents
  407. results = [
  408. SearchResult(document=doc1, score=0.9, rank=0),
  409. SearchResult(document=doc2, score=0.8, rank=1)
  410. ]
  411. deduped = service.deduplicate_results(results, keep="highest")
  412. assert len(deduped) == 2
  413. def test_deduplicate_results_invalid_keep_strategy(self, service, sample_documents):
  414. """测试无效的保留策略"""
  415. doc1, doc2 = sample_documents
  416. results = [
  417. SearchResult(document=doc1, score=0.9, rank=0)
  418. ]
  419. with pytest.raises(ValueError, match="keep must be one of"):
  420. service.deduplicate_results(results, keep="invalid")
  421. class TestHybridSearchServiceStringRepresentation:
  422. """测试混合搜索服务的字符串表示"""
  423. def test_str_representation(self):
  424. """测试 __str__ 方法"""
  425. service = HybridSearchService(vector_weight=0.6, text_weight=0.4)
  426. str_repr = str(service)
  427. assert "HybridSearchService" in str_repr
  428. assert "0.6" in str_repr
  429. assert "0.4" in str_repr
  430. def test_repr_representation(self):
  431. """测试 __repr__ 方法"""
  432. service = HybridSearchService(vector_weight=0.6, text_weight=0.4)
  433. repr_str = repr(service)
  434. assert "HybridSearchService" in repr_str
  435. assert "0.6" in repr_str
  436. assert "0.4" in repr_str