#!/usr/bin/env python3
"""
Performance benchmark script.

Runs timed requests against key API endpoints and reports response-time
statistics (min/max/mean/median/p95/p99), success rate, and error samples.
"""
import asyncio
import json
import statistics
import time
from datetime import datetime
from typing import Any, Dict, List, Optional

import httpx


class BenchmarkResult:
    """Collects per-endpoint samples and computes summary statistics."""

    def __init__(self, endpoint: str, method: str):
        self.endpoint = endpoint
        self.method = method
        self.response_times: List[float] = []  # elapsed seconds per request
        self.status_codes: List[int] = []      # HTTP status (0 = request raised)
        self.errors: List[str] = []            # error messages for failures

    def add_result(self, response_time: float, status_code: int,
                   error: Optional[str] = None) -> None:
        """Record one request's outcome.

        Args:
            response_time: Elapsed time in seconds.
            status_code: HTTP status code; 0 when the request raised.
            error: Error message for failed requests, if any.
        """
        self.response_times.append(response_time)
        self.status_codes.append(status_code)
        if error:
            self.errors.append(error)

    def get_statistics(self) -> Dict[str, Any]:
        """Compute summary statistics; times are reported in milliseconds."""
        if not self.response_times:
            return {
                "endpoint": self.endpoint,
                "method": self.method,
                "error": "No data collected"
            }

        # statistics.quantiles needs at least two samples; with a single
        # sample, that sample is the only available percentile estimate.
        if len(self.response_times) > 1:
            p95 = statistics.quantiles(self.response_times, n=20)[18]
            p99 = statistics.quantiles(self.response_times, n=100)[98]
        else:
            p95 = p99 = self.response_times[0]

        successes = len([c for c in self.status_codes if 200 <= c < 300])
        return {
            "endpoint": self.endpoint,
            "method": self.method,
            "requests": len(self.response_times),
            "success_rate": (successes / len(self.status_codes)) * 100,
            "response_time": {
                "min": min(self.response_times) * 1000,  # seconds -> ms
                "max": max(self.response_times) * 1000,
                "mean": statistics.mean(self.response_times) * 1000,
                "median": statistics.median(self.response_times) * 1000,
                "p95": p95 * 1000,
                "p99": p99 * 1000,
            },
            "errors": len(self.errors),
            # Only surface the first 5 error messages to keep reports short.
            "error_messages": self.errors[:5] if self.errors else []
        }


class APIBenchmark:
    """Drives concurrent benchmark runs against an API server."""

    def __init__(self, base_url: str = "http://localhost:8000"):
        self.base_url = base_url
        self.results: List[BenchmarkResult] = []

    async def benchmark_endpoint(
        self,
        endpoint: str,
        method: str = "GET",
        data: Optional[Dict] = None,
        iterations: int = 100,
        concurrent: int = 10
    ) -> BenchmarkResult:
        """Benchmark a single endpoint.

        Fires `iterations` requests in batches of at most `concurrent`
        simultaneous requests and records every outcome.

        Args:
            endpoint: Path appended to the base URL.
            method: HTTP method (GET/POST/PUT/DELETE).
            data: JSON body for POST/PUT requests.
            iterations: Total number of requests to issue.
            concurrent: Maximum in-flight requests per batch.

        Returns:
            The populated BenchmarkResult (also appended to self.results).
        """
        result = BenchmarkResult(endpoint, method)

        async with httpx.AsyncClient(timeout=30.0) as client:
            # Issue requests in concurrent batches.
            for batch in range(0, iterations, concurrent):
                batch_size = min(concurrent, iterations - batch)
                tasks = [
                    self._make_request(client, endpoint, method, data, result)
                    for _ in range(batch_size)
                ]
                await asyncio.gather(*tasks)

        self.results.append(result)
        return result

    async def _make_request(
        self,
        client: httpx.AsyncClient,
        endpoint: str,
        method: str,
        data: Optional[Dict],
        result: BenchmarkResult
    ) -> None:
        """Issue one HTTP request and record its timing in `result`."""
        url = f"{self.base_url}{endpoint}"
        # perf_counter is monotonic and high-resolution; time.time() can
        # jump when the system clock is adjusted, corrupting measurements.
        start_time = time.perf_counter()

        try:
            if method == "GET":
                response = await client.get(url)
            elif method == "POST":
                response = await client.post(url, json=data)
            elif method == "PUT":
                response = await client.put(url, json=data)
            elif method == "DELETE":
                response = await client.delete(url)
            else:
                raise ValueError(f"Unsupported method: {method}")

            elapsed = time.perf_counter() - start_time
            result.add_result(elapsed, response.status_code)
        except Exception as e:
            # Record failures with status code 0 so they count against the
            # success rate instead of aborting the whole batch.
            elapsed = time.perf_counter() - start_time
            result.add_result(elapsed, 0, str(e))

    def print_results(self) -> None:
        """Print a formatted report of all collected results."""
        print("\n" + "=" * 80)
        print("性能基准测试结果")
        print("=" * 80)
        print(f"测试时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"基础 URL: {self.base_url}")
        print("=" * 80)

        for result in self.results:
            stats = result.get_statistics()

            if "error" in stats:
                print(f"\n❌ {stats['method']} {stats['endpoint']}")
                print(f"   错误: {stats['error']}")
                continue

            print(f"\n✓ {stats['method']} {stats['endpoint']}")
            print(f"  请求数: {stats['requests']}")
            print(f"  成功率: {stats['success_rate']:.2f}%")
            print(f"  响应时间 (ms):")
            print(f"    最小值: {stats['response_time']['min']:.2f}")
            print(f"    最大值: {stats['response_time']['max']:.2f}")
            print(f"    平均值: {stats['response_time']['mean']:.2f}")
            print(f"    中位数: {stats['response_time']['median']:.2f}")
            print(f"    P95: {stats['response_time']['p95']:.2f}")
            print(f"    P99: {stats['response_time']['p99']:.2f}")

            if stats['errors'] > 0:
                print(f"  ⚠️ 错误数: {stats['errors']}")
                if stats['error_messages']:
                    print(f"  错误示例: {stats['error_messages'][0]}")

        print("\n" + "=" * 80)

    def save_results(self, filename: str = "benchmark_results.json") -> None:
        """Save all results as pretty-printed JSON to `filename`."""
        results_data = {
            "timestamp": datetime.now().isoformat(),
            "base_url": self.base_url,
            "results": [r.get_statistics() for r in self.results]
        }

        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(results_data, f, indent=2, ensure_ascii=False)

        # Fix: the message previously never interpolated the actual filename.
        print(f"\n结果已保存到: {filename}")


async def run_benchmarks() -> None:
    """Run the full benchmark suite, then print and save the results."""
    benchmark = APIBenchmark()

    print("开始性能基准测试...")
    print("注意: 确保应用正在运行 (python main.py)")

    # Wait for explicit user confirmation before hammering the server.
    try:
        input("\n按 Enter 继续,或 Ctrl+C 取消...")
    except KeyboardInterrupt:
        print("\n测试已取消")
        return

    # 1/5: health check endpoint.
    print("\n测试 1/5: 健康检查端点...")
    await benchmark.benchmark_endpoint(
        "/health",
        method="GET",
        iterations=100,
        concurrent=10
    )

    # 2/5: metrics endpoint.
    print("测试 2/5: 指标端点...")
    await benchmark.benchmark_endpoint(
        "/metrics",
        method="GET",
        iterations=100,
        concurrent=10
    )

    # 3/5: document creation (requires valid payload).
    print("测试 3/5: 创建文档端点...")
    document_data = {
        "content": "This is a test document for benchmarking",
        "metadata": {
            "source": "benchmark",
            "type": "test"
        }
    }
    await benchmark.benchmark_endpoint(
        "/api/v1/documents/",
        method="POST",
        data=document_data,
        iterations=50,  # fewer iterations since real data is created
        concurrent=5
    )

    # 4/5: document search.
    print("测试 4/5: 搜索文档端点...")
    search_data = {
        "query_text": "test document",
        "top_k": 10
    }
    await benchmark.benchmark_endpoint(
        "/api/v1/documents/search",
        method="POST",
        data=search_data,
        iterations=100,
        concurrent=10
    )

    # 5/5: knowledge-base listing (endpoint assumed to exist).
    print("测试 5/5: 知识库端点...")
    await benchmark.benchmark_endpoint(
        "/api/v1/knowledge-bases/",
        method="GET",
        iterations=100,
        concurrent=10
    )

    # Print and persist results.
    benchmark.print_results()
    benchmark.save_results()

    # Evaluate mean latency against per-endpoint targets.
    print("\n" + "=" * 80)
    print("性能评估")
    print("=" * 80)

    for result in benchmark.results:
        stats = result.get_statistics()
        if "error" in stats:
            continue

        mean_time = stats['response_time']['mean']
        endpoint = stats['endpoint']

        # Thresholds differ by endpoint type: cheap probes, searches,
        # and everything else.
        if endpoint == "/health" or endpoint == "/metrics":
            threshold = 50  # 50ms
            target = "< 50ms"
        elif "search" in endpoint:
            threshold = 500  # 500ms
            target = "< 500ms"
        else:
            threshold = 200  # 200ms
            target = "< 200ms"

        status = "✓ 通过" if mean_time < threshold else "⚠️ 需要优化"
        print(f"{stats['method']} {endpoint}")
        print(f"  平均响应时间: {mean_time:.2f}ms (目标: {target}) - {status}")

    print("=" * 80)


def main() -> None:
    """Entry point: run the async benchmark suite, handling interrupts."""
    try:
        asyncio.run(run_benchmarks())
    except KeyboardInterrupt:
        print("\n\n测试被用户中断")
    except Exception as e:
        print(f"\n\n错误: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()