#!/usr/bin/env python3
"""
Performance benchmark script.

Runs load tests against key API endpoints and measures response time,
throughput, and related metrics.
"""
import asyncio
import json
import statistics
import time
from datetime import datetime
from typing import Any, Dict, List, Optional

import httpx
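
# Usage: start the API server first (``python main.py``), then run this
# script directly. The filename ``benchmark.py`` below is illustrative;
# substitute whatever this file is actually saved as:
#
#     python benchmark.py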


class BenchmarkResult:
    """Collected results for one benchmarked endpoint."""

    def __init__(self, endpoint: str, method: str):
        self.endpoint = endpoint
        self.method = method
        self.response_times: List[float] = []
        self.status_codes: List[int] = []
        self.errors: List[str] = []

    def add_result(self, response_time: float, status_code: int, error: Optional[str] = None):
        """Record the outcome of a single request."""
        self.response_times.append(response_time)
        self.status_codes.append(status_code)
        if error:
            self.errors.append(error)

    def get_statistics(self) -> Dict[str, Any]:
        """Compute summary statistics for the collected samples."""
        if not self.response_times:
            return {
                "endpoint": self.endpoint,
                "method": self.method,
                "error": "No data collected"
            }

        return {
            "endpoint": self.endpoint,
            "method": self.method,
            "requests": len(self.response_times),
            "success_rate": (len([c for c in self.status_codes if 200 <= c < 300]) / len(self.status_codes)) * 100,
            "response_time": {
                "min": min(self.response_times) * 1000,  # convert seconds to milliseconds
                "max": max(self.response_times) * 1000,
                "mean": statistics.mean(self.response_times) * 1000,
                "median": statistics.median(self.response_times) * 1000,
                "p95": statistics.quantiles(self.response_times, n=20)[18] * 1000 if len(self.response_times) > 1 else self.response_times[0] * 1000,
                "p99": statistics.quantiles(self.response_times, n=100)[98] * 1000 if len(self.response_times) > 1 else self.response_times[0] * 1000,
            },
            "errors": len(self.errors),
            "error_messages": self.errors[:5]  # show at most the first five errors
        }
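
# Note on the percentiles computed in BenchmarkResult.get_statistics():
# statistics.quantiles(data, n=20) returns 19 cut points, so index 18 is the
# 95th percentile; with n=100 it returns 99 cut points and index 98 is the
# 99th percentile. The default "exclusive" method extrapolates on very small
# samples, so treat p95/p99 from short runs with care.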


class APIBenchmark:
    """API performance benchmark runner."""

    def __init__(self, base_url: str = "http://localhost:8000"):
        self.base_url = base_url
        self.results: List[BenchmarkResult] = []

    async def benchmark_endpoint(
        self,
        endpoint: str,
        method: str = "GET",
        data: Optional[Dict] = None,
        iterations: int = 100,
        concurrent: int = 10
    ) -> BenchmarkResult:
        """Benchmark a single endpoint."""
        result = BenchmarkResult(endpoint, method)

        async with httpx.AsyncClient(timeout=30.0) as client:
            # Issue requests in batches: each batch of up to ``concurrent``
            # requests runs in parallel and is awaited before the next batch
            # starts, which bounds concurrency without a sliding window.
            for batch in range(0, iterations, concurrent):
                batch_size = min(concurrent, iterations - batch)
                tasks = []

                for _ in range(batch_size):
                    tasks.append(self._make_request(client, endpoint, method, data, result))

                await asyncio.gather(*tasks)

        self.results.append(result)
        return result

    async def _make_request(
        self,
        client: httpx.AsyncClient,
        endpoint: str,
        method: str,
        data: Optional[Dict],
        result: BenchmarkResult
    ):
        """Execute a single request and record its outcome."""
        url = f"{self.base_url}{endpoint}"
        start_time = time.perf_counter()  # monotonic clock, immune to wall-clock jumps

        try:
            if method == "GET":
                response = await client.get(url)
            elif method == "POST":
                response = await client.post(url, json=data)
            elif method == "PUT":
                response = await client.put(url, json=data)
            elif method == "DELETE":
                response = await client.delete(url)
            else:
                raise ValueError(f"Unsupported method: {method}")

            elapsed = time.perf_counter() - start_time
            result.add_result(elapsed, response.status_code)

        except Exception as e:
            # Failed requests are recorded with status 0, so they count
            # against the success rate.
            elapsed = time.perf_counter() - start_time
            result.add_result(elapsed, 0, str(e))

    def print_results(self):
        """Print benchmark results."""
        print("\n" + "=" * 80)
        print("Performance Benchmark Results")
        print("=" * 80)
        print(f"Run at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"Base URL: {self.base_url}")
        print("=" * 80)

        for result in self.results:
            stats = result.get_statistics()

            if "error" in stats:
                print(f"\n❌ {stats['method']} {stats['endpoint']}")
                print(f"   Error: {stats['error']}")
                continue

            print(f"\n✓ {stats['method']} {stats['endpoint']}")
            print(f"  Requests:     {stats['requests']}")
            print(f"  Success rate: {stats['success_rate']:.2f}%")
            print("  Response time (ms):")
            print(f"    min:    {stats['response_time']['min']:.2f}")
            print(f"    max:    {stats['response_time']['max']:.2f}")
            print(f"    mean:   {stats['response_time']['mean']:.2f}")
            print(f"    median: {stats['response_time']['median']:.2f}")
            print(f"    p95:    {stats['response_time']['p95']:.2f}")
            print(f"    p99:    {stats['response_time']['p99']:.2f}")

            if stats['errors'] > 0:
                print(f"  ⚠️ Errors: {stats['errors']}")
                if stats['error_messages']:
                    print(f"  Sample error: {stats['error_messages'][0]}")

        print("\n" + "=" * 80)

    def save_results(self, filename: str = "benchmark_results.json"):
        """Save results to a JSON file."""
        results_data = {
            "timestamp": datetime.now().isoformat(),
            "base_url": self.base_url,
            "results": [r.get_statistics() for r in self.results]
        }

        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(results_data, f, indent=2, ensure_ascii=False)

        print(f"\nResults saved to: {filename}")


async def run_benchmarks():
    """Run the full benchmark suite."""
    benchmark = APIBenchmark()

    print("Starting performance benchmarks...")
    print("Note: make sure the application is running (python main.py)")

    # Wait for user confirmation
    try:
        input("\nPress Enter to continue, or Ctrl+C to cancel...")
    except KeyboardInterrupt:
        print("\nBenchmark cancelled")
        return

    # Test the health-check endpoint
    print("\nTest 1/5: health-check endpoint...")
    await benchmark.benchmark_endpoint(
        "/health",
        method="GET",
        iterations=100,
        concurrent=10
    )

    # Test the metrics endpoint
    print("Test 2/5: metrics endpoint...")
    await benchmark.benchmark_endpoint(
        "/metrics",
        method="GET",
        iterations=100,
        concurrent=10
    )

    # Test document creation (requires a valid payload)
    print("Test 3/5: create-document endpoint...")
    document_data = {
        "content": "This is a test document for benchmarking",
        "metadata": {
            "source": "benchmark",
            "type": "test"
        }
    }
    await benchmark.benchmark_endpoint(
        "/api/v1/documents/",
        method="POST",
        data=document_data,
        iterations=50,  # fewer iterations, since each request creates real data
        concurrent=5
    )

    # Test document search
    print("Test 4/5: search-documents endpoint...")
    search_data = {
        "query_text": "test document",
        "top_k": 10
    }
    await benchmark.benchmark_endpoint(
        "/api/v1/documents/search",
        method="POST",
        data=search_data,
        iterations=100,
        concurrent=10
    )

    # Test listing knowledge bases (assumes the endpoint exists)
    print("Test 5/5: knowledge-bases endpoint...")
    await benchmark.benchmark_endpoint(
        "/api/v1/knowledge-bases/",
        method="GET",
        iterations=100,
        concurrent=10
    )
    # Print and save results
    benchmark.print_results()
    benchmark.save_results()

    # Performance assessment against per-endpoint targets
    print("\n" + "=" * 80)
    print("Performance Assessment")
    print("=" * 80)

    for result in benchmark.results:
        stats = result.get_statistics()
        if "error" in stats:
            continue

        mean_time = stats['response_time']['mean']
        endpoint = stats['endpoint']

        # Thresholds differ by endpoint type
        if endpoint in ("/health", "/metrics"):
            threshold = 50  # ms
            target = "< 50ms"
        elif "search" in endpoint:
            threshold = 500  # ms
            target = "< 500ms"
        else:
            threshold = 200  # ms
            target = "< 200ms"

        status = "✓ PASS" if mean_time < threshold else "⚠️ needs optimization"
        print(f"{stats['method']} {endpoint}")
        print(f"  Mean response time: {mean_time:.2f}ms (target: {target}) - {status}")

    print("=" * 80)


def main():
    """Entry point."""
    try:
        asyncio.run(run_benchmarks())
    except KeyboardInterrupt:
        print("\n\nBenchmark interrupted by user")
    except Exception as e:
        print(f"\n\nError: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()