test_mcp_hybrid_search.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. #!/usr/bin/env python3
  2. """
  3. 测试MCP服务的混合检索功能
  4. """
  5. import sys
  6. import os
  7. import json
  8. import unittest
  9. import requests
  10. from typing import Dict, Any
  11. # 添加项目根目录到Python路径
  12. sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  13. class TestMCPHybridSearch(unittest.TestCase):
  14. """测试MCP服务的混合检索功能"""
  15. def setUp(self):
  16. """设置测试环境"""
  17. # MCP服务的基础URL
  18. self.base_url = "http://localhost:18000"
  19. # 测试数据
  20. # 测试图片
  21. self.test_image = r"D:\project\work\ragflow_plugs\book\output\temp\2.png"
  22. # 测试文本查询
  23. self.test_text_query = "卡梅拉"
  24. def test_hybrid_search(self):
  25. """测试混合检索API"""
  26. print("测试混合检索API...")
  27. # 构建请求数据
  28. data = {
  29. "text_query": self.test_text_query,
  30. "image": self.test_image,
  31. "topn": 2
  32. }
  33. # 发送请求
  34. response = requests.post(
  35. f"{self.base_url}/tools/hybrid_search",
  36. json=data
  37. )
  38. # 验证响应
  39. self.assertEqual(response.status_code, 200, f"请求失败: {response.text}")
  40. result = response.json()
  41. self.assertTrue(result["success"], f"API调用失败: {result.get('message', '未知错误')}")
  42. self.assertIn("output", result, "响应中缺少hits字段")
  43. self.assertIn("total", result, "响应中缺少total字段")
  44. self.assertIsInstance(result["output"], list, "hits字段应该是一个列表")
  45. self.assertIsInstance(result["total"], int, "total字段应该是一个整数")
  46. print(f"✓ 混合检索API测试通过,命中数量: {result['total']}")
  47. if __name__ == "__main__":
  48. """运行测试"""
  49. print("开始测试MCP服务的混合检索功能...\n")
  50. # 创建测试套件
  51. suite = unittest.TestLoader().loadTestsFromTestCase(TestMCPHybridSearch)
  52. # 运行测试
  53. runner = unittest.TextTestRunner(verbosity=2)
  54. result = runner.run(suite)
  55. print(f"\n测试完成,共运行 {result.testsRun} 个测试,成功 {result.testsRun - len(result.failures) - len(result.errors)} 个,失败 {len(result.failures)} 个,错误 {len(result.errors)} 个")