update_imports.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. #!/usr/bin/env python3
  2. """
  3. 导入更新脚本
  4. 此脚本帮助识别和更新旧模块的导入引用。
  5. """
import os
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple
  10. # 导入映射:旧模块 -> 新模块
  11. IMPORT_MAPPINGS = {
  12. # 向量搜索相关
  13. 'src.api.sdk.search_infinity': 'src.application.vector_search.handlers',
  14. 'src.api.db.services.vector_search_service': 'src.infrastructure.vector_db',
  15. # 数据库相关
  16. 'src.api.db.repositories': 'src.infrastructure.database.repositories',
  17. 'src.api.db.models': 'src.presentation.schemas',
  18. # 文档解析相关
  19. 'src.api.dataset': 'src.application.document_parsing',
  20. # 其他
  21. 'src.api.sdk': 'src.application',
  22. }
  23. def find_python_files(root_dir: str, exclude_dirs: List[str] = None) -> List[Path]:
  24. """
  25. 查找所有 Python 文件
  26. Args:
  27. root_dir: 根目录
  28. exclude_dirs: 要排除的目录列表
  29. Returns:
  30. Python 文件路径列表
  31. """
  32. if exclude_dirs is None:
  33. exclude_dirs = ['.git', '__pycache__', '.pytest_cache', 'venv', '.venv', 'node_modules']
  34. python_files = []
  35. root_path = Path(root_dir)
  36. for file_path in root_path.rglob('*.py'):
  37. # 检查是否在排除目录中
  38. if any(excluded in file_path.parts for excluded in exclude_dirs):
  39. continue
  40. python_files.append(file_path)
  41. return python_files
  42. def find_old_imports(file_path: Path) -> List[Tuple[int, str, str]]:
  43. """
  44. 在文件中查找旧的导入语句
  45. Args:
  46. file_path: 文件路径
  47. Returns:
  48. (行号, 旧导入, 建议的新导入) 的列表
  49. """
  50. old_imports = []
  51. try:
  52. with open(file_path, 'r', encoding='utf-8') as f:
  53. lines = f.readlines()
  54. for line_num, line in enumerate(lines, 1):
  55. line = line.strip()
  56. # 跳过注释和空行
  57. if not line or line.startswith('#'):
  58. continue
  59. # 检查 import 语句
  60. for old_module, new_module in IMPORT_MAPPINGS.items():
  61. if old_module in line and ('import' in line or 'from' in line):
  62. suggested = line.replace(old_module, new_module)
  63. old_imports.append((line_num, line, suggested))
  64. except Exception as e:
  65. print(f"Error reading {file_path}: {e}")
  66. return old_imports
  67. def scan_project(root_dir: str = '.') -> dict:
  68. """
  69. 扫描整个项目,查找需要更新的导入
  70. Args:
  71. root_dir: 项目根目录
  72. Returns:
  73. 文件路径到旧导入列表的映射
  74. """
  75. print(f"Scanning project in {root_dir}...")
  76. python_files = find_python_files(root_dir)
  77. print(f"Found {len(python_files)} Python files")
  78. results = {}
  79. for file_path in python_files:
  80. old_imports = find_old_imports(file_path)
  81. if old_imports:
  82. results[str(file_path)] = old_imports
  83. return results
  84. def generate_report(results: dict) -> str:
  85. """
  86. 生成扫描报告
  87. Args:
  88. results: 扫描结果
  89. Returns:
  90. 报告文本
  91. """
  92. if not results:
  93. return "✅ No old imports found!"
  94. report_lines = [
  95. "# Import Update Report",
  96. "",
  97. f"Found {len(results)} files with old imports:",
  98. ""
  99. ]
  100. for file_path, old_imports in results.items():
  101. report_lines.append(f"## {file_path}")
  102. report_lines.append("")
  103. for line_num, old_line, suggested in old_imports:
  104. report_lines.append(f"Line {line_num}:")
  105. report_lines.append(f" Old: {old_line}")
  106. report_lines.append(f" New: {suggested}")
  107. report_lines.append("")
  108. report_lines.append("## Summary")
  109. report_lines.append("")
  110. report_lines.append(f"Total files: {len(results)}")
  111. report_lines.append(f"Total imports to update: {sum(len(imports) for imports in results.values())}")
  112. return "\n".join(report_lines)
  113. def main():
  114. """主函数"""
  115. print("=" * 60)
  116. print("Import Update Scanner")
  117. print("=" * 60)
  118. print()
  119. # 扫描项目
  120. results = scan_project()
  121. # 生成报告
  122. report = generate_report(results)
  123. # 打印报告
  124. print()
  125. print(report)
  126. # 保存报告到文件
  127. report_file = Path('import_update_report.md')
  128. with open(report_file, 'w', encoding='utf-8') as f:
  129. f.write(report)
  130. print()
  131. print(f"Report saved to: {report_file}")
  132. if results:
  133. print()
  134. print("⚠️ Please review and update the imports manually.")
  135. print(" The old modules are deprecated and will be removed in version 2.0.0")
  136. else:
  137. print()
  138. print("✅ All imports are up to date!")
  139. if __name__ == '__main__':
  140. main()