benchmark.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. #!/usr/bin/env python3
  2. """
  3. 性能基准测试脚本
  4. 对关键 API 端点进行性能测试,测量响应时间、吞吐量等指标。
  5. """
import asyncio
import json
import statistics
import time
from datetime import datetime
from typing import Any, Dict, List, Optional

import httpx
  13. class BenchmarkResult:
  14. """基准测试结果"""
  15. def __init__(self, endpoint: str, method: str):
  16. self.endpoint = endpoint
  17. self.method = method
  18. self.response_times: List[float] = []
  19. self.status_codes: List[int] = []
  20. self.errors: List[str] = []
  21. def add_result(self, response_time: float, status_code: int, error: str = None):
  22. """添加测试结果"""
  23. self.response_times.append(response_time)
  24. self.status_codes.append(status_code)
  25. if error:
  26. self.errors.append(error)
  27. def get_statistics(self) -> Dict[str, Any]:
  28. """计算统计数据"""
  29. if not self.response_times:
  30. return {
  31. "endpoint": self.endpoint,
  32. "method": self.method,
  33. "error": "No data collected"
  34. }
  35. return {
  36. "endpoint": self.endpoint,
  37. "method": self.method,
  38. "requests": len(self.response_times),
  39. "success_rate": (len([c for c in self.status_codes if 200 <= c < 300]) / len(self.status_codes)) * 100,
  40. "response_time": {
  41. "min": min(self.response_times) * 1000, # 转换为毫秒
  42. "max": max(self.response_times) * 1000,
  43. "mean": statistics.mean(self.response_times) * 1000,
  44. "median": statistics.median(self.response_times) * 1000,
  45. "p95": statistics.quantiles(self.response_times, n=20)[18] * 1000 if len(self.response_times) > 1 else self.response_times[0] * 1000,
  46. "p99": statistics.quantiles(self.response_times, n=100)[98] * 1000 if len(self.response_times) > 1 else self.response_times[0] * 1000,
  47. },
  48. "errors": len(self.errors),
  49. "error_messages": self.errors[:5] if self.errors else [] # 只显示前5个错误
  50. }
  51. class APIBenchmark:
  52. """API 性能基准测试"""
  53. def __init__(self, base_url: str = "http://localhost:8000"):
  54. self.base_url = base_url
  55. self.results: List[BenchmarkResult] = []
  56. async def benchmark_endpoint(
  57. self,
  58. endpoint: str,
  59. method: str = "GET",
  60. data: Dict = None,
  61. iterations: int = 100,
  62. concurrent: int = 10
  63. ) -> BenchmarkResult:
  64. """对单个端点进行基准测试"""
  65. result = BenchmarkResult(endpoint, method)
  66. async with httpx.AsyncClient(timeout=30.0) as client:
  67. # 分批并发请求
  68. for batch in range(0, iterations, concurrent):
  69. batch_size = min(concurrent, iterations - batch)
  70. tasks = []
  71. for _ in range(batch_size):
  72. tasks.append(self._make_request(client, endpoint, method, data, result))
  73. await asyncio.gather(*tasks)
  74. self.results.append(result)
  75. return result
  76. async def _make_request(
  77. self,
  78. client: httpx.AsyncClient,
  79. endpoint: str,
  80. method: str,
  81. data: Dict,
  82. result: BenchmarkResult
  83. ):
  84. """执行单个请求"""
  85. url = f"{self.base_url}{endpoint}"
  86. start_time = time.time()
  87. try:
  88. if method == "GET":
  89. response = await client.get(url)
  90. elif method == "POST":
  91. response = await client.post(url, json=data)
  92. elif method == "PUT":
  93. response = await client.put(url, json=data)
  94. elif method == "DELETE":
  95. response = await client.delete(url)
  96. else:
  97. raise ValueError(f"Unsupported method: {method}")
  98. elapsed = time.time() - start_time
  99. result.add_result(elapsed, response.status_code)
  100. except Exception as e:
  101. elapsed = time.time() - start_time
  102. result.add_result(elapsed, 0, str(e))
  103. def print_results(self):
  104. """打印测试结果"""
  105. print("\n" + "=" * 80)
  106. print("性能基准测试结果")
  107. print("=" * 80)
  108. print(f"测试时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
  109. print(f"基础 URL: {self.base_url}")
  110. print("=" * 80)
  111. for result in self.results:
  112. stats = result.get_statistics()
  113. if "error" in stats:
  114. print(f"\n❌ {stats['method']} {stats['endpoint']}")
  115. print(f" 错误: {stats['error']}")
  116. continue
  117. print(f"\n✓ {stats['method']} {stats['endpoint']}")
  118. print(f" 请求数: {stats['requests']}")
  119. print(f" 成功率: {stats['success_rate']:.2f}%")
  120. print(f" 响应时间 (ms):")
  121. print(f" 最小值: {stats['response_time']['min']:.2f}")
  122. print(f" 最大值: {stats['response_time']['max']:.2f}")
  123. print(f" 平均值: {stats['response_time']['mean']:.2f}")
  124. print(f" 中位数: {stats['response_time']['median']:.2f}")
  125. print(f" P95: {stats['response_time']['p95']:.2f}")
  126. print(f" P99: {stats['response_time']['p99']:.2f}")
  127. if stats['errors'] > 0:
  128. print(f" ⚠️ 错误数: {stats['errors']}")
  129. if stats['error_messages']:
  130. print(f" 错误示例: {stats['error_messages'][0]}")
  131. print("\n" + "=" * 80)
  132. def save_results(self, filename: str = "benchmark_results.json"):
  133. """保存结果到文件"""
  134. results_data = {
  135. "timestamp": datetime.now().isoformat(),
  136. "base_url": self.base_url,
  137. "results": [r.get_statistics() for r in self.results]
  138. }
  139. with open(filename, 'w', encoding='utf-8') as f:
  140. json.dump(results_data, f, indent=2, ensure_ascii=False)
  141. print(f"\n结果已保存到: {filename}")
  142. async def run_benchmarks():
  143. """运行所有基准测试"""
  144. benchmark = APIBenchmark()
  145. print("开始性能基准测试...")
  146. print("注意: 确保应用正在运行 (python main.py)")
  147. # 等待用户确认
  148. try:
  149. input("\n按 Enter 继续,或 Ctrl+C 取消...")
  150. except KeyboardInterrupt:
  151. print("\n测试已取消")
  152. return
  153. # 测试健康检查端点
  154. print("\n测试 1/5: 健康检查端点...")
  155. await benchmark.benchmark_endpoint(
  156. "/health",
  157. method="GET",
  158. iterations=100,
  159. concurrent=10
  160. )
  161. # 测试指标端点
  162. print("测试 2/5: 指标端点...")
  163. await benchmark.benchmark_endpoint(
  164. "/metrics",
  165. method="GET",
  166. iterations=100,
  167. concurrent=10
  168. )
  169. # 测试文档创建(需要有效的数据)
  170. print("测试 3/5: 创建文档端点...")
  171. document_data = {
  172. "content": "This is a test document for benchmarking",
  173. "metadata": {
  174. "source": "benchmark",
  175. "type": "test"
  176. }
  177. }
  178. await benchmark.benchmark_endpoint(
  179. "/api/v1/documents/",
  180. method="POST",
  181. data=document_data,
  182. iterations=50, # 减少迭代次数,因为会创建实际数据
  183. concurrent=5
  184. )
  185. # 测试文档搜索
  186. print("测试 4/5: 搜索文档端点...")
  187. search_data = {
  188. "query_text": "test document",
  189. "top_k": 10
  190. }
  191. await benchmark.benchmark_endpoint(
  192. "/api/v1/documents/search",
  193. method="POST",
  194. data=search_data,
  195. iterations=100,
  196. concurrent=10
  197. )
  198. # 测试知识库列表(假设端点存在)
  199. print("测试 5/5: 知识库端点...")
  200. await benchmark.benchmark_endpoint(
  201. "/api/v1/knowledge-bases/",
  202. method="GET",
  203. iterations=100,
  204. concurrent=10
  205. )
  206. # 打印和保存结果
  207. benchmark.print_results()
  208. benchmark.save_results()
  209. # 性能评估
  210. print("\n" + "=" * 80)
  211. print("性能评估")
  212. print("=" * 80)
  213. for result in benchmark.results:
  214. stats = result.get_statistics()
  215. if "error" in stats:
  216. continue
  217. mean_time = stats['response_time']['mean']
  218. endpoint = stats['endpoint']
  219. # 根据端点类型设置不同的阈值
  220. if endpoint == "/health" or endpoint == "/metrics":
  221. threshold = 50 # 50ms
  222. target = "< 50ms"
  223. elif "search" in endpoint:
  224. threshold = 500 # 500ms
  225. target = "< 500ms"
  226. else:
  227. threshold = 200 # 200ms
  228. target = "< 200ms"
  229. status = "✓ 通过" if mean_time < threshold else "⚠️ 需要优化"
  230. print(f"{stats['method']} {endpoint}")
  231. print(f" 平均响应时间: {mean_time:.2f}ms (目标: {target}) - {status}")
  232. print("=" * 80)
  233. def main():
  234. """主函数"""
  235. try:
  236. asyncio.run(run_benchmarks())
  237. except KeyboardInterrupt:
  238. print("\n\n测试被用户中断")
  239. except Exception as e:
  240. print(f"\n\n错误: {e}")
  241. import traceback
  242. traceback.print_exc()
  243. if __name__ == "__main__":
  244. main()