#!/usr/bin/env python3
"""
Import update script.

Scans a project for imports of legacy modules and suggests the replacement
module for each one. The script only reports; it never rewrites files.
"""

import os
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# Import mapping: old module -> new module.
# Keys may be prefixes of one another (e.g. 'src.api.sdk' and
# 'src.api.sdk.search_infinity'); the most specific key wins when matching.
IMPORT_MAPPINGS = {
    # Vector search
    'src.api.sdk.search_infinity': 'src.application.vector_search.handlers',
    'src.api.db.services.vector_search_service': 'src.infrastructure.vector_db',

    # Database
    'src.api.db.repositories': 'src.infrastructure.database.repositories',
    'src.api.db.models': 'src.presentation.schemas',

    # Document parsing
    'src.api.dataset': 'src.application.document_parsing',

    # Everything else
    'src.api.sdk': 'src.application',
}


def find_python_files(root_dir: str, exclude_dirs: Optional[List[str]] = None) -> List[Path]:
    """
    Find all Python files below a directory.

    Args:
        root_dir: Directory to search recursively.
        exclude_dirs: Directory names to skip. Defaults to common VCS, cache
            and virtual-environment directories.

    Returns:
        Sorted list of paths to ``*.py`` files.
    """
    if exclude_dirs is None:
        exclude_dirs = ['.git', '__pycache__', '.pytest_cache', 'venv', '.venv', 'node_modules']

    excluded = set(exclude_dirs)
    root_path = Path(root_dir)
    python_files = []

    for file_path in root_path.rglob('*.py'):
        # Only look at the path below the root: the project itself may live
        # inside a directory whose name is on the exclusion list.
        relative_parts = file_path.relative_to(root_path).parts
        if excluded.intersection(relative_parts):
            continue

        # rglob also yields directories whose name happens to end in '.py'.
        if not file_path.is_file():
            continue

        python_files.append(file_path)

    # Sorted so that scan results and reports are reproducible.
    return sorted(python_files)


def _build_import_pattern(mappings: Dict[str, str]) -> 'Optional[re.Pattern]':
    """Compile a regex matching any old module name on module boundaries."""
    if not mappings:
        return None

    # Longest keys first so the most specific mapping is tried before any of
    # its prefixes.
    alternatives = '|'.join(
        re.escape(module) for module in sorted(mappings, key=len, reverse=True)
    )
    # The lookarounds keep 'src.api.dataset' from matching 'src.api.datasets'
    # or 'pkg.src.api.dataset'; a following '.' (submodule) is still allowed.
    return re.compile(r'(?<![\w.])(?:' + alternatives + r')(?!\w)')


def find_old_imports(file_path: Path) -> List[Tuple[int, str, str]]:
    """
    Find lines in a file that reference old modules.

    Each offending line is reported once, with every old module on it
    replaced by its most specific mapping.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        List of ``(line number, old line, suggested new line)`` tuples. The
        list is empty when the file cannot be read.
    """
    pattern = _build_import_pattern(IMPORT_MAPPINGS)
    if pattern is None:
        return []

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
    except (OSError, UnicodeError) as e:
        # Best effort: report the unreadable file and keep scanning.
        print(f"Error reading {file_path}: {e}")
        return []

    old_imports = []
    for line_num, raw_line in enumerate(lines, 1):
        line = raw_line.strip()

        # Skip blank lines and comments.
        if not line or line.startswith('#'):
            continue

        # Only consider lines that look like they import something.
        if 'import' not in line and 'from' not in line:
            continue

        suggested, replaced = pattern.subn(
            lambda match: IMPORT_MAPPINGS[match.group(0)], line
        )
        if replaced:
            old_imports.append((line_num, line, suggested))

    return old_imports


def scan_project(root_dir: str = '.') -> dict:
    """
    Scan a whole project for imports that need updating.

    Args:
        root_dir: Project root directory.

    Returns:
        Mapping of file path (as a string) to that file's list of old imports.
        Files without old imports are omitted.
    """
    print(f"Scanning project in {root_dir}...")

    python_files = find_python_files(root_dir)
    print(f"Found {len(python_files)} Python files")

    results = {}
    for file_path in python_files:
        old_imports = find_old_imports(file_path)
        if old_imports:
            results[str(file_path)] = old_imports

    return results


def generate_report(results: dict) -> str:
    """
    Build a Markdown report from scan results.

    Args:
        results: Mapping as returned by ``scan_project``.

    Returns:
        The report text.
    """
    if not results:
        return "✅ No old imports found!"

    report_lines = [
        "# Import Update Report",
        "",
        f"Found {len(results)} files with old imports:",
        "",
    ]

    for file_path, old_imports in results.items():
        report_lines.append(f"## {file_path}")
        report_lines.append("")

        for line_num, old_line, suggested in old_imports:
            report_lines.append(f"Line {line_num}:")
            report_lines.append(f" Old: {old_line}")
            report_lines.append(f" New: {suggested}")
            report_lines.append("")

    report_lines.append("## Summary")
    report_lines.append("")
    report_lines.append(f"Total files: {len(results)}")
    report_lines.append(
        f"Total imports to update: {sum(len(imports) for imports in results.values())}"
    )

    return "\n".join(report_lines)


def main():
    """Scan the current directory, print the report and save it to a file."""
    print("=" * 60)
    print("Import Update Scanner")
    print("=" * 60)
    print()

    # Scan the project.
    results = scan_project()

    # Build the report.
    report = generate_report(results)

    # Print the report.
    print()
    print(report)

    # Save the report to a file in the current working directory.
    report_file = Path('import_update_report.md')
    with open(report_file, 'w', encoding='utf-8') as f:
        f.write(report)

    print()
    print(f"Report saved to: {report_file}")

    if results:
        print()
        print("⚠️ Please review and update the imports manually.")
        print(" The old modules are deprecated and will be removed in version 2.0.0")
    else:
        print()
        print("✅ All imports are up to date!")


if __name__ == '__main__':
    main()