本设计文档描述了 Python RAG(检索增强生成)系统的重构方案。该系统是一个基于 FastAPI 的生产环境应用,当前存在代码组织混乱、缺少清晰分层架构等问题。本次重构将采用领域驱动设计(DDD)思想和清晰的分层架构,提高系统的可测试性、可维护性和可扩展性。
决策 1:采用四层架构
决策 2:使用 FastAPI 原生依赖注入
Depends 机制足够强大,无需引入额外的 DI 容器。
决策 3:按领域组织模块,而非按技术层
┌─────────────────────────────────────────────────────────────┐
│ Presentation Layer │
│ (FastAPI Routers, Request/Response Models, Middleware) │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ Application Layer │
│ (Use Cases, Application Services, DTOs, Commands) │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ Domain Layer │
│ (Entities, Value Objects, Domain Services, Repositories) │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ Infrastructure Layer │
│ (Database, Vector DB, External APIs, File System) │
└─────────────────────────────────────────────────────────────┘
Presentation Layer(表现层)
Application Layer(应用层)
Domain Layer(领域层)
Infrastructure Layer(基础设施层)
rag_system/
├── src/
│ ├── domain/ # 领域层
│ │ ├── vector_search/ # 向量搜索领域
│ │ │ ├── entities.py # 实体(Document, Embedding)
│ │ │ ├── value_objects.py # 值对象(Vector, SearchQuery)
│ │ │ ├── repositories.py # 仓储接口
│ │ │ ├── services.py # 领域服务
│ │ │ └── exceptions.py # 领域异常
│ │ ├── document_parsing/ # 文档解析领域
│ │ │ ├── entities.py
│ │ │ ├── value_objects.py
│ │ │ ├── repositories.py
│ │ │ ├── services.py
│ │ │ └── exceptions.py
│ │ ├── knowledge_base/ # 知识库领域
│ │ │ ├── entities.py
│ │ │ ├── value_objects.py
│ │ │ ├── repositories.py
│ │ │ ├── services.py
│ │ │ └── exceptions.py
│ │ └── shared/ # 共享领域概念
│ │ ├── value_objects.py # 共享值对象(ID, Timestamp)
│ │ └── exceptions.py # 基础领域异常
│ │
│ ├── application/ # 应用层
│ │ ├── vector_search/
│ │ │ ├── commands.py # 命令(CreateIndex, DeleteDocument)
│ │ │ ├── queries.py # 查询(SearchDocuments, GetDocument)
│ │ │ ├── handlers.py # 命令/查询处理器
│ │ │ ├── dtos.py # 数据传输对象
│ │ │ └── services.py # 应用服务
│ │ ├── document_parsing/
│ │ │ ├── commands.py
│ │ │ ├── queries.py
│ │ │ ├── handlers.py
│ │ │ ├── dtos.py
│ │ │ └── services.py
│ │ ├── knowledge_base/
│ │ │ ├── commands.py
│ │ │ ├── queries.py
│ │ │ ├── handlers.py
│ │ │ ├── dtos.py
│ │ │ └── services.py
│ │ └── shared/
│ │ ├── interfaces.py # 共享接口
│ │ └── exceptions.py # 应用异常
│ │
│ ├── infrastructure/ # 基础设施层
│ │ ├── vector_db/
│ │ │ ├── infinity.py # Infinity 实现
│ │ │ ├── elasticsearch.py # Elasticsearch 实现
│ │ │ └── base.py # 向量数据库基类
│ │ ├── database/
│ │ │ ├── repositories.py # 仓储实现
│ │ │ ├── models.py # ORM 模型
│ │ │ └── session.py # 数据库会话管理
│ │ ├── external_services/
│ │ │ ├── ragflow.py # Ragflow 集成
│ │ │ ├── dify.py # Dify 集成
│ │ │ └── base.py # 外部服务基类
│ │ ├── file_storage/
│ │ │ ├── local.py # 本地文件存储
│ │ │ ├── s3.py # S3 存储
│ │ │ └── base.py # 文件存储基类
│ │ └── parsers/
│ │ ├── pdf_parser.py # PDF 解析器
│ │ ├── image_parser.py # 图片解析器
│ │ └── base.py # 解析器基类
│ │
│ ├── presentation/ # 表现层
│ │ ├── api/
│ │ │ ├── v1/
│ │ │ │ ├── vector_search.py # 向量搜索路由
│ │ │ │ ├── documents.py # 文档管理路由
│ │ │ │ ├── knowledge_base.py # 知识库路由
│ │ │ │ └── health.py # 健康检查路由
│ │ │ ├── dependencies.py # FastAPI 依赖
│ │ │ ├── middleware.py # 中间件
│ │ │ └── error_handlers.py # 错误处理器
│ │ └── schemas/
│ │ ├── requests.py # 请求模型
│ │ ├── responses.py # 响应模型
│ │ └── common.py # 共享模式
│ │
│ ├── config/ # 配置管理
│ │ ├── settings.py # 配置类
│ │ ├── database.py # 数据库配置
│ │ ├── vector_db.py # 向量数据库配置
│ │ └── logging.py # 日志配置
│ │
│ └── shared/ # 共享工具
│ ├── logging.py # 日志工具
│ ├── exceptions.py # 基础异常
│ └── utils.py # 通用工具函数
│
├── tests/ # 测试
│ ├── unit/ # 单元测试
│ │ ├── domain/
│ │ ├── application/
│ │ └── infrastructure/
│ ├── integration/ # 集成测试
│ │ ├── api/
│ │ └── database/
│ ├── e2e/ # 端到端测试
│ ├── fixtures/ # 测试夹具
│ └── conftest.py # pytest 配置
│
├── docs/ # 文档
│ ├── architecture.md # 架构文档
│ ├── api.md # API 文档
│ ├── deployment.md # 部署文档
│ └── development.md # 开发指南
│
├── scripts/ # 脚本
│ ├── migrate.py # 迁移脚本
│ └── seed.py # 数据种子脚本
│
├── docker/ # Docker 配置
│ ├── Dockerfile
│ └── docker-compose.yml
│
├── .env.example # 环境变量示例
├── pyproject.toml # 项目配置
├── requirements.txt # 依赖
└── README.md # 项目说明
核心实体:
# domain/vector_search/entities.py
from dataclasses import dataclass
from typing import List, Optional
from datetime import datetime


@dataclass
class Document:
    """Document entity: a piece of content plus its (optional) embedding."""

    id: str
    content: str
    embedding: Optional['Vector']
    metadata: dict
    created_at: datetime
    updated_at: datetime

    def update_content(self, new_content: str) -> None:
        """Replace the content and invalidate the now-stale embedding."""
        # Domain rule: an embedding computed for the old content is no
        # longer valid, so it must be regenerated after a content change.
        self.embedding = None
        self.content = new_content
        self.updated_at = datetime.now()


@dataclass
class SearchResult:
    """One scored, ranked hit produced by a search."""

    document: Document
    score: float
    rank: int
值对象:
# domain/vector_search/value_objects.py
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class Vector:
    """Immutable embedding vector value object.

    Raises:
        ValueError: if the vector is empty or contains non-numeric values.
    """

    dimensions: List[float]

    def __post_init__(self):
        if not self.dimensions:
            raise ValueError("Vector cannot be empty")
        if not all(isinstance(d, (int, float)) for d in self.dimensions):
            raise ValueError("All dimensions must be numeric")

    @property
    def dimension_count(self) -> int:
        """Number of components in the vector."""
        return len(self.dimensions)


@dataclass(frozen=True)
class SearchQuery:
    """Immutable search query value object.

    Raises:
        ValueError: if ``top_k`` is not positive.
    """

    text: str
    top_k: int = 10
    # Fixed: was annotated ``dict = None``, which is inconsistent with the
    # None default; Optional[dict] states the contract explicitly.
    filters: Optional[dict] = None

    def __post_init__(self):
        if self.top_k <= 0:
            raise ValueError("top_k must be positive")
仓储接口:
# domain/vector_search/repositories.py
from abc import ABC, abstractmethod
from typing import List, Optional


class DocumentRepository(ABC):
    """Repository port for Document aggregates.

    Fixed: the entity/value-object names (defined in ``entities.py`` /
    ``value_objects.py`` of this package) were referenced unquoted without
    an import, which makes the module fail with NameError the moment the
    class body is evaluated; quoting them defers resolution.
    """

    @abstractmethod
    async def save(self, document: "Document") -> None:
        """Persist a document (insert or update)."""

    @abstractmethod
    async def find_by_id(self, document_id: str) -> Optional["Document"]:
        """Return the document with the given id, or None when absent."""

    @abstractmethod
    async def delete(self, document_id: str) -> None:
        """Remove the document with the given id."""

    @abstractmethod
    async def search(self, query: "SearchQuery") -> List["SearchResult"]:
        """Run a search and return scored, ranked results."""
领域服务:
# domain/vector_search/services.py
from dataclasses import replace
from typing import Dict, List, Tuple


class HybridSearchService:
    """Domain service that merges vector-search and full-text results.

    Scores are combined as a weighted sum: a document found by both
    searches accumulates both weighted scores; a document found by only
    one search contributes a single weighted term.
    """

    def __init__(self, vector_weight: float = 0.7, text_weight: float = 0.3):
        self.vector_weight = vector_weight
        self.text_weight = text_weight

    def combine_scores(
        self,
        vector_results: List['SearchResult'],
        text_results: List['SearchResult']
    ) -> List['SearchResult']:
        """Merge both result lists into one list ranked by combined score.

        Args:
            vector_results: hits from the vector (embedding) search.
            text_results: hits from the full-text search.

        Returns:
            Fresh SearchResult objects ordered by descending combined
            score, with ``score`` set to the weighted sum and ``rank``
            re-assigned from 0.
        """
        combined: Dict[str, float] = {}
        for result in vector_results:
            combined[result.document.id] = result.score * self.vector_weight
        for result in text_results:
            doc_id = result.document.id
            combined[doc_id] = combined.get(doc_id, 0.0) + result.score * self.text_weight
        sorted_scores = sorted(
            combined.items(),
            key=lambda item: item[1],
            reverse=True
        )
        return self._build_results(sorted_scores, vector_results, text_results)

    def _build_results(
        self,
        sorted_scores: List[Tuple[str, float]],
        vector_results: List['SearchResult'],
        text_results: List['SearchResult']
    ) -> List['SearchResult']:
        """Materialize ranked results for the sorted (doc_id, score) pairs.

        Fixed: this helper was referenced but never implemented. Each
        output reuses the document from whichever source list saw it first
        (vector results take precedence) and carries the combined score.
        """
        by_id: Dict[str, 'SearchResult'] = {}
        for result in list(vector_results) + list(text_results):
            by_id.setdefault(result.document.id, result)
        # dataclasses.replace produces fresh result objects, so the input
        # lists are never mutated.
        return [
            replace(by_id[doc_id], score=score, rank=rank)
            for rank, (doc_id, score) in enumerate(sorted_scores)
        ]
核心实体:
# domain/document_parsing/entities.py
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional


class DocumentType(Enum):
    """Supported source document categories."""
    PDF = "pdf"
    IMAGE = "image"
    TEXT = "text"
    QA_PAIR = "qa_pair"


@dataclass
class ParsedDocument:
    """A source document after parsing: an ordered list of chunks."""

    id: str
    original_filename: str
    document_type: DocumentType
    chunks: List['DocumentChunk']
    metadata: dict

    def add_chunk(self, chunk: 'DocumentChunk') -> None:
        """Append one chunk to the document."""
        self.chunks.append(chunk)

    def validate(self) -> bool:
        """Return True when at least one chunk exists and none is blank."""
        if not self.chunks:
            return False
        return all(chunk.content.strip() for chunk in self.chunks)


@dataclass
class DocumentChunk:
    """One contiguous piece of parsed content."""

    id: str
    content: str
    page_number: Optional[int]
    position: int
    metadata: dict
领域服务:
# domain/document_parsing/services.py
from abc import ABC, abstractmethod
from typing import List  # Fixed: List was used below but never imported.


class DocumentParser(ABC):
    """Parser port (domain service): turns a file into a ParsedDocument.

    ``ParsedDocument``/``DocumentType`` live in ``entities.py`` of this
    package; they are quoted here so the module imports cleanly.
    """

    @abstractmethod
    async def parse(self, file_path: str) -> 'ParsedDocument':
        """Parse the file at ``file_path`` into a ParsedDocument."""

    @abstractmethod
    def supports(self, document_type: 'DocumentType') -> bool:
        """Return True if this parser can handle ``document_type``."""


class ChunkingStrategy(ABC):
    """Strategy port for splitting raw content into chunk strings."""

    @abstractmethod
    def chunk(self, content: str) -> List[str]:
        """Split ``content`` into an ordered list of chunks."""
核心实体:
# domain/knowledge_base/entities.py
from dataclasses import dataclass
from typing import List, Set
from datetime import datetime


@dataclass
class KnowledgeBase:
    """Knowledge-base aggregate: a named, tagged collection of documents."""

    id: str
    name: str
    description: str
    tags: Set[str]
    document_ids: List[str]
    created_at: datetime
    updated_at: datetime

    def _touch(self) -> None:
        """Record that the aggregate was modified now."""
        self.updated_at = datetime.now()

    def add_document(self, document_id: str) -> None:
        """Attach a document; duplicates are silently ignored."""
        if document_id in self.document_ids:
            return
        self.document_ids.append(document_id)
        self._touch()

    def remove_document(self, document_id: str) -> None:
        """Detach a document; unknown ids are silently ignored."""
        if document_id not in self.document_ids:
            return
        self.document_ids.remove(document_id)
        self._touch()

    def add_tag(self, tag: str) -> None:
        """Add a tag (idempotent on the set) and bump the timestamp."""
        self.tags.add(tag)
        self._touch()


@dataclass
class PromptDimension:
    """A prompt-template dimension with its substitutable variables."""

    id: str
    name: str
    description: str
    template: str
    variables: List[str]
# application/vector_search/commands.py
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class CreateDocumentCommand:
    """Command: create a new document from raw content."""

    content: str
    # Fixed: ``Dict[str, any]`` used the builtin function ``any`` as a
    # type; the intended annotation is ``typing.Any``.
    metadata: Dict[str, Any]


@dataclass
class UpdateDocumentCommand:
    """Command: replace the content/metadata of an existing document."""

    document_id: str
    content: str
    metadata: Dict[str, Any]


@dataclass
class DeleteDocumentCommand:
    """Command: delete a document by id."""

    document_id: str
# application/vector_search/queries.py
# Fixed: this module used ``@dataclass`` and ``Dict`` without importing
# them, and annotated ``filters`` as ``Dict[str, any] = None`` (builtin
# ``any`` is not a type; a None default needs Optional).
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class SearchDocumentsQuery:
    """Query: search documents by free text, optionally hybrid."""

    query_text: str
    top_k: int = 10
    filters: Optional[Dict[str, Any]] = None
    use_hybrid: bool = True


@dataclass
class GetDocumentQuery:
    """Query: fetch a single document by id."""

    document_id: str
命令处理器:
# application/vector_search/handlers.py
from typing import Protocol


class CommandHandler(Protocol):
    """Structural protocol implemented by all command handlers."""

    # NOTE(review): ``-> any`` uses the builtin function, not a type;
    # this should be ``typing.Any`` for type checkers to accept it.
    async def handle(self, command) -> any:
        ...


class CreateDocumentHandler:
    """Handles CreateDocumentCommand: embed content, build entity, persist."""

    def __init__(
        self,
        document_repository: DocumentRepository,
        embedding_service: 'EmbeddingService'
    ):
        # Dependencies are injected (design decision 2) rather than
        # constructed here, keeping the handler unit-testable.
        self.document_repository = document_repository
        self.embedding_service = embedding_service

    async def handle(self, command: CreateDocumentCommand) -> str:
        """Process the command and return the new document's id."""
        # 1. Generate the embedding vector for the raw content.
        embedding = await self.embedding_service.generate(command.content)
        # 2. Build the domain entity.
        # NOTE(review): ``generate_id``, ``datetime`` and ``Document`` are
        # not imported in this snippet — confirm their source modules.
        document = Document(
            id=generate_id(),
            content=command.content,
            embedding=embedding,
            metadata=command.metadata,
            created_at=datetime.now(),
            updated_at=datetime.now()
        )
        # 3. Persist through the repository abstraction.
        await self.document_repository.save(document)
        return document.id


class SearchDocumentsHandler:
    """Handles SearchDocumentsQuery: hybrid or plain repository search."""

    def __init__(
        self,
        document_repository: DocumentRepository,
        hybrid_search_service: HybridSearchService
    ):
        self.document_repository = document_repository
        self.hybrid_search_service = hybrid_search_service

    async def handle(self, query: SearchDocumentsQuery) -> List['DocumentDTO']:
        """Run the search and return DTOs ordered by relevance."""
        # Translate the application-layer query into a domain value object.
        search_query = SearchQuery(
            text=query.query_text,
            top_k=query.top_k,
            filters=query.filters
        )
        # Execute the search.
        if query.use_hybrid:
            results = await self._hybrid_search(search_query)
        else:
            results = await self.document_repository.search(search_query)
        # Convert to DTOs for the presentation boundary.
        # NOTE(review): ``_hybrid_search`` and ``_to_dto`` are referenced
        # but not defined anywhere in this design — they must be added.
        return [self._to_dto(result) for result in results]
数据传输对象(DTO):
# application/vector_search/dtos.py
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class DocumentDTO:
    """Flat, serializable view of a Document for the application boundary."""

    id: str
    content: str
    # Fixed: ``Dict[str, any]`` -> ``Dict[str, Any]`` (builtin ``any`` is
    # a function, not a type).
    metadata: Dict[str, Any]
    score: Optional[float] = None

    @classmethod
    def from_entity(cls, document: 'Document', score: Optional[float] = None) -> 'DocumentDTO':
        """Build a DTO from a domain Document, optionally with a score.

        Args:
            document: the domain entity to project.
            score: relevance score when the DTO comes from a search hit.
        """
        return cls(
            id=document.id,
            content=document.content,
            metadata=document.metadata,
            score=score
        )
# infrastructure/vector_db/base.py
from abc import ABC, abstractmethod
from typing import List


class VectorDatabase(ABC):
    """Abstract async interface every vector-store backend implements.

    Fixed: ``Vector`` (a domain value object not imported here) was
    referenced unquoted, which raises NameError when the class body is
    evaluated; quoting defers resolution.
    """

    @abstractmethod
    async def connect(self) -> None:
        """Open the connection to the backend."""

    @abstractmethod
    async def create_index(self, index_name: str, dimension: int) -> None:
        """Create an index/table for vectors of the given dimension."""

    @abstractmethod
    async def insert(self, index_name: str, vectors: List['Vector'], metadata: List[dict]) -> None:
        """Insert vectors with parallel per-vector metadata dicts."""

    @abstractmethod
    async def search(self, index_name: str, query_vector: 'Vector', top_k: int) -> List[dict]:
        """Return the ``top_k`` nearest entries as raw result dicts."""
# infrastructure/vector_db/infinity.py
import uuid  # Fixed: used below for row ids but was never imported.
from typing import List  # Fixed: used in signatures but never imported.


class InfinityVectorDB(VectorDatabase):
    """Infinity implementation of the VectorDatabase port.

    NOTE(review): ``InfinityClient`` is assumed to come from the Infinity
    SDK; its import is not shown in this design snippet.
    """

    def __init__(self, config: 'InfinityConfig'):
        self.config = config
        self.client = None  # set by connect()

    async def connect(self) -> None:
        """Connect to the Infinity server configured in ``self.config``."""
        self.client = InfinityClient(
            host=self.config.host,
            port=self.config.port
        )
        await self.client.connect()

    async def create_index(self, index_name: str, dimension: int) -> None:
        """Create an Infinity table holding id, vector and JSON metadata."""
        await self.client.create_table(
            table_name=index_name,
            columns={
                "id": "varchar",
                "vector": f"vector,{dimension},float",
                "metadata": "json"
            }
        )

    async def insert(self, index_name: str, vectors: List['Vector'], metadata: List[dict]) -> None:
        """Insert vectors; each row gets a fresh UUID id."""
        data = [
            {
                "id": str(uuid.uuid4()),
                "vector": vector.dimensions,
                "metadata": meta
            }
            for vector, meta in zip(vectors, metadata)
        ]
        await self.client.insert(index_name, data)

    async def search(self, index_name: str, query_vector: 'Vector', top_k: int) -> List[dict]:
        """Run a nearest-neighbour search and return raw result dicts."""
        results = await self.client.search(
            table_name=index_name,
            query_vector=query_vector.dimensions,
            top_k=top_k
        )
        return results
# infrastructure/vector_db/elasticsearch.py
class ElasticsearchVectorDB(VectorDatabase):
    """Elasticsearch implementation of the VectorDatabase port.

    NOTE(review): ``AsyncElasticsearch`` is not imported in this snippet,
    and the abstract ``insert``/``search`` methods are still missing —
    the class cannot be instantiated until they are implemented.
    """

    def __init__(self, config: 'ElasticsearchConfig'):
        self.config = config
        self.client = None  # set by connect()

    async def connect(self) -> None:
        """Open an async client against the configured host/port."""
        self.client = AsyncElasticsearch(
            hosts=[f"{self.config.host}:{self.config.port}"]
        )

    async def create_index(self, index_name: str, dimension: int) -> None:
        """Create an index with a dense_vector field of ``dimension`` dims."""
        mapping = {
            "mappings": {
                "properties": {
                    "vector": {
                        "type": "dense_vector",
                        "dims": dimension
                    },
                    "metadata": {
                        "type": "object"
                    }
                }
            }
        }
        await self.client.indices.create(index=index_name, body=mapping)
    # ... remaining methods (insert, search) to be implemented
# infrastructure/database/repositories.py
from domain.vector_search.repositories import DocumentRepository
from domain.vector_search.entities import Document, SearchResult


class SQLDocumentRepository(DocumentRepository):
    """SQL + vector-store implementation of DocumentRepository.

    Document rows live in the relational database; the embedding (when
    present) is mirrored into the vector database under index "documents".

    NOTE(review): several names are not imported in this snippet
    (``Optional``, ``List``, ``select``, ``DocumentModel``, ``SearchQuery``,
    ``VectorDatabase``), the abstract ``delete`` method is not implemented,
    and ``_to_entity`` / ``_fetch_documents`` are referenced but undefined.
    """

    def __init__(
        self,
        session_factory: 'SessionFactory',
        vector_db: VectorDatabase
    ):
        self.session_factory = session_factory
        self.vector_db = vector_db

    async def save(self, document: Document) -> None:
        """Persist the document row, then mirror its embedding (if any)."""
        async with self.session_factory() as session:
            # Relational part: one row per document.
            db_document = DocumentModel(
                id=document.id,
                content=document.content,
                metadata=document.metadata,
                created_at=document.created_at,
                updated_at=document.updated_at
            )
            session.add(db_document)
            await session.commit()
        # Vector part: only documents that carry an embedding are indexed.
        if document.embedding:
            await self.vector_db.insert(
                index_name="documents",
                vectors=[document.embedding],
                metadata=[{"doc_id": document.id}]
            )

    async def find_by_id(self, document_id: str) -> Optional[Document]:
        """Return the document with ``document_id`` or None when absent."""
        async with self.session_factory() as session:
            result = await session.execute(
                select(DocumentModel).where(DocumentModel.id == document_id)
            )
            db_document = result.scalar_one_or_none()
            if not db_document:
                return None
            return self._to_entity(db_document)

    async def search(self, query: SearchQuery) -> List[SearchResult]:
        """Vector search, then hydrate the full documents for the hits.

        NOTE(review): ``query.vector`` does not exist — SearchQuery only
        defines ``text``/``top_k``/``filters``; the query text must be
        embedded first (or SearchQuery extended) before this can work.
        """
        # Nearest-neighbour lookup in the vector store.
        vector_results = await self.vector_db.search(
            index_name="documents",
            query_vector=query.vector,
            top_k=query.top_k
        )
        # Fetch the full rows for the returned ids.
        doc_ids = [r["metadata"]["doc_id"] for r in vector_results]
        documents = await self._fetch_documents(doc_ids)
        # Preserve the vector store's ordering as the rank.
        return [
            SearchResult(
                document=documents[r["metadata"]["doc_id"]],
                score=r["score"],
                rank=idx
            )
            for idx, r in enumerate(vector_results)
        ]
# presentation/api/v1/vector_search.py
from fastapi import APIRouter, Depends, HTTPException
from presentation.schemas.requests import CreateDocumentRequest, SearchRequest
from presentation.schemas.responses import DocumentResponse, SearchResponse
from application.vector_search.handlers import CreateDocumentHandler, SearchDocumentsHandler
from application.vector_search.commands import CreateDocumentCommand
from application.vector_search.queries import SearchDocumentsQuery

router = APIRouter(prefix="/api/v1/documents", tags=["documents"])


@router.post("/", response_model=DocumentResponse, status_code=201)
async def create_document(
    request: CreateDocumentRequest,
    handler: CreateDocumentHandler = Depends(get_create_document_handler)
):
    """Create a document and return its generated id.

    Fixed: the previous blanket ``except Exception`` turned every error
    into a 500 whose ``detail`` leaked ``str(e)`` to clients, and it
    bypassed the centralized domain/application exception handlers in
    ``presentation/api/error_handlers.py``. Exceptions now propagate to
    those registered handlers, which map them to 400/422/500 uniformly.
    """
    command = CreateDocumentCommand(
        content=request.content,
        metadata=request.metadata
    )
    document_id = await handler.handle(command)
    return DocumentResponse(id=document_id, message="Document created successfully")


@router.post("/search", response_model=SearchResponse)
async def search_documents(
    request: SearchRequest,
    handler: SearchDocumentsHandler = Depends(get_search_documents_handler)
):
    """Search documents; errors propagate to the registered handlers."""
    query = SearchDocumentsQuery(
        query_text=request.query,
        top_k=request.top_k,
        filters=request.filters,
        use_hybrid=request.use_hybrid
    )
    results = await handler.handle(query)
    return SearchResponse(results=results, total=len(results))
# presentation/api/dependencies.py
from functools import lru_cache

from fastapi import Depends  # Fixed: Depends was used but never imported.

from application.vector_search.handlers import CreateDocumentHandler, SearchDocumentsHandler
from config.settings import get_settings
from domain.vector_search.repositories import DocumentRepository
from domain.vector_search.services import HybridSearchService
from infrastructure.database.repositories import SQLDocumentRepository
from infrastructure.database.session import get_session_factory
from infrastructure.vector_db.base import VectorDatabase
# Fixed: ElasticsearchVectorDB was instantiated below but never imported.
from infrastructure.vector_db.elasticsearch import ElasticsearchVectorDB
from infrastructure.vector_db.infinity import InfinityVectorDB


@lru_cache()
def get_vector_db() -> VectorDatabase:
    """Return the process-wide vector DB selected by configuration."""
    settings = get_settings()
    if settings.vector_db_type == "infinity":
        return InfinityVectorDB(settings.infinity_config)
    elif settings.vector_db_type == "elasticsearch":
        return ElasticsearchVectorDB(settings.elasticsearch_config)
    else:
        raise ValueError(f"Unsupported vector DB: {settings.vector_db_type}")


def get_document_repository(
    vector_db: VectorDatabase = Depends(get_vector_db)
) -> DocumentRepository:
    """Wire the SQL repository with its session factory and vector DB."""
    return SQLDocumentRepository(
        session_factory=get_session_factory(),
        vector_db=vector_db
    )


def get_create_document_handler(
    repository: DocumentRepository = Depends(get_document_repository)
) -> CreateDocumentHandler:
    """Provide the create-document command handler.

    NOTE(review): ``get_embedding_service`` is referenced but not defined
    or imported anywhere in this design — its provider must be added.
    """
    return CreateDocumentHandler(
        document_repository=repository,
        embedding_service=get_embedding_service()
    )


def get_search_documents_handler(
    repository: DocumentRepository = Depends(get_document_repository)
) -> SearchDocumentsHandler:
    """Provide the search handler with a default-weight hybrid service."""
    return SearchDocumentsHandler(
        document_repository=repository,
        hybrid_search_service=HybridSearchService()
    )
# presentation/schemas/requests.py
from pydantic import BaseModel, Field
from typing import Any, Dict, Optional


class CreateDocumentRequest(BaseModel):
    """Request body for creating a document."""

    content: str = Field(..., min_length=1, description="文档内容")
    # Fixed: ``Dict[str, any]`` used the builtin function ``any``, which
    # pydantic cannot build a validator from; ``typing.Any`` is intended.
    metadata: Dict[str, Any] = Field(default_factory=dict, description="文档元数据")


class SearchRequest(BaseModel):
    """Request body for searching documents."""

    query: str = Field(..., min_length=1, description="搜索查询")
    top_k: int = Field(default=10, ge=1, le=100, description="返回结果数量")
    filters: Optional[Dict[str, Any]] = Field(default=None, description="过滤条件")
    use_hybrid: bool = Field(default=True, description="是否使用混合搜索")
# presentation/schemas/responses.py
# Fixed: this module used BaseModel/Dict/List without importing them, and
# ``Dict[str, any]`` used the builtin function ``any`` instead of
# ``typing.Any``.
from typing import Any, Dict, List

from pydantic import BaseModel


class DocumentResponse(BaseModel):
    """Response returned after creating a document."""

    id: str
    message: str


class SearchResultItem(BaseModel):
    """One hit in a search response."""

    id: str
    content: str
    metadata: Dict[str, Any]
    score: float


class SearchResponse(BaseModel):
    """Search response: ranked hits plus their total count."""

    results: List[SearchResultItem]
    total: int
# presentation/api/error_handlers.py
import logging

from fastapi import Request, status
from fastapi.responses import JSONResponse
from domain.shared.exceptions import DomainException
from application.shared.exceptions import ApplicationException

# Fixed: ``logger`` was used in generic_exception_handler but never defined.
logger = logging.getLogger(__name__)


async def domain_exception_handler(request: Request, exc: DomainException):
    """Map domain-rule violations to 400 Bad Request responses."""
    return JSONResponse(
        status_code=status.HTTP_400_BAD_REQUEST,
        content={
            "error": "Domain Error",
            "message": str(exc),
            "type": exc.__class__.__name__
        }
    )


async def application_exception_handler(request: Request, exc: ApplicationException):
    """Map application-level failures to 422 Unprocessable Entity responses."""
    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content={
            "error": "Application Error",
            "message": str(exc),
            "type": exc.__class__.__name__
        }
    )


async def generic_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log with stack trace, hide details from clients."""
    logger.error(f"Unhandled exception: {exc}", exc_info=True)
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={
            "error": "Internal Server Error",
            "message": "An unexpected error occurred"
        }
    )
# config/settings.py
from functools import lru_cache

from pydantic import Field
from pydantic_settings import BaseSettings


class DatabaseSettings(BaseSettings):
    """Relational database connection settings (env prefix ``DB_``)."""
    host: str = "localhost"
    port: int = 5432
    database: str = "rag_system"
    username: str  # required: DB_USERNAME
    password: str  # required: DB_PASSWORD

    class Config:
        env_prefix = "DB_"


class InfinitySettings(BaseSettings):
    """Infinity vector DB settings (env prefix ``INFINITY_``)."""
    host: str = "localhost"
    port: int = 23817

    class Config:
        env_prefix = "INFINITY_"


class ElasticsearchSettings(BaseSettings):
    """Elasticsearch settings (env prefix ``ES_``)."""
    host: str = "localhost"
    port: int = 9200

    class Config:
        env_prefix = "ES_"


class Settings(BaseSettings):
    """Top-level application settings."""
    app_name: str = "RAG System"
    debug: bool = False
    vector_db_type: str = "infinity"  # "infinity" or "elasticsearch"
    # Fixed: sub-settings were instantiated at class-definition time
    # (``= DatabaseSettings()``), which crashes on module import whenever
    # required env vars (DB_USERNAME/DB_PASSWORD) are missing and freezes
    # values at import; default_factory defers construction to Settings().
    database: DatabaseSettings = Field(default_factory=DatabaseSettings)
    infinity: InfinitySettings = Field(default_factory=InfinitySettings)
    elasticsearch: ElasticsearchSettings = Field(default_factory=ElasticsearchSettings)

    class Config:
        env_file = ".env"
        case_sensitive = False


@lru_cache()
def get_settings() -> Settings:
    """Return the cached singleton Settings instance."""
    return Settings()
Document(文档):
Vector(向量):
SearchQuery(搜索查询):
KnowledgeBase(知识库):
# infrastructure/database/models.py
from sqlalchemy import Column, String, Text, JSON, DateTime, ForeignKey, Table
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()


class DocumentModel(Base):
    """ORM row for a document."""

    __tablename__ = "documents"

    id = Column(String(36), primary_key=True)
    content = Column(Text, nullable=False)
    # Fixed: the attribute was named ``metadata``, which SQLAlchemy
    # reserves on declarative classes (defining it raises
    # InvalidRequestError at class creation). The Python attribute is
    # renamed; the underlying DB column is still called "metadata".
    # Also ``default={}`` shared one mutable dict across all rows —
    # ``default=dict`` builds a fresh dict per insert.
    doc_metadata = Column("metadata", JSON, default=dict)
    created_at = Column(DateTime, nullable=False)
    updated_at = Column(DateTime, nullable=False)

    # Many-to-many with knowledge bases via the association table below.
    knowledge_bases = relationship(
        "KnowledgeBaseModel",
        secondary="knowledge_base_documents",
        back_populates="documents"
    )


class KnowledgeBaseModel(Base):
    """ORM row for a knowledge base."""

    __tablename__ = "knowledge_bases"

    id = Column(String(36), primary_key=True)
    name = Column(String(255), nullable=False)
    description = Column(Text)
    tags = Column(JSON, default=list)  # fresh list per row, not a shared []
    created_at = Column(DateTime, nullable=False)
    updated_at = Column(DateTime, nullable=False)

    documents = relationship(
        "DocumentModel",
        secondary="knowledge_base_documents",
        back_populates="knowledge_bases"
    )


# Association table backing the many-to-many relationship above.
knowledge_base_documents = Table(
    "knowledge_base_documents",
    Base.metadata,
    Column("knowledge_base_id", String(36), ForeignKey("knowledge_bases.id")),
    Column("document_id", String(36), ForeignKey("documents.id"))
)
属性是一个特征或行为,应该在系统的所有有效执行中保持为真——本质上是关于系统应该做什么的形式化陈述。属性作为人类可读规范和机器可验证正确性保证之间的桥梁。
*对于任何*模块,如果它位于领域层,那么它不应该导入来自应用层、基础设施层或表现层的任何内容。如果它位于应用层,那么它不应该导入来自基础设施层或表现层的任何内容。
Validates: Requirements 1.2
实现说明:使用静态分析工具(如 import-linter)检查模块间的导入依赖关系,确保依赖方向符合分层架构原则。
*对于任何*领域层模块中的文件,它不应该导入任何外部框架(FastAPI、SQLAlchemy、Pydantic 等),只能导入标准库和领域内部模块。
Validates: Requirements 1.3
实现说明:扫描 domain/ 目录下的所有 Python 文件,检查 import 语句,确保没有导入外部框架。
*对于任何*配置项,系统应该能够从环境变量、配置文件或默认值中加载该配置,并且环境变量的优先级应该高于配置文件,配置文件的优先级应该高于默认值。
Validates: Requirements 2.2
实现说明:为每个配置项创建测试,验证在不同来源提供配置时的加载顺序和优先级。
*对于任何*无效的配置值(类型错误、超出范围等),配置模块应该在加载时抛出验证错误,而不是在运行时才发现问题。
Validates: Requirements 2.4, 2.5
实现说明:使用 Pydantic 的验证机制,为每种配置类型创建测试用例,提供无效值并验证是否抛出预期的验证错误。
*对于任何*需要外部依赖的组件(服务、处理器、仓储等),该组件应该通过构造函数参数或 FastAPI 的 Depends 机制接收依赖,而不是在内部直接实例化依赖。
Validates: Requirements 3.2, 3.3
实现说明:通过代码审查或静态分析检查组件的构造函数和方法,确保没有直接实例化依赖对象(如 repository = SQLRepository()),而是通过参数注入。
*对于任何*定义的抽象接口(ABC 类),至少应该存在一个具体的实现类,并且该实现类应该实现接口的所有抽象方法。
Validates: Requirements 3.4
实现说明:扫描所有 ABC 类,检查是否有对应的实现类,并验证实现类是否实现了所有抽象方法。
*对于任何*环境(开发、测试、生产),依赖注入系统应该能够根据环境配置注入不同的依赖实现(如测试环境使用 Mock 数据库,生产环境使用真实数据库)。
Validates: Requirements 3.5
实现说明:创建测试用例,在不同环境配置下验证注入的依赖类型是否正确。
*对于任何*在重构前存在的 API 端点,重构后该端点的 URL 路径、HTTP 方法、请求参数和响应格式应该保持不变。
Validates: Requirements 4.1, 4.2
实现说明:在重构前记录所有 API 端点的规范(可以使用 OpenAPI 规范),重构后运行契约测试验证端点行为一致性。
*对于任何*API 响应,响应头中应该包含 API 版本信息(如 X-API-Version: 1.0)。
Validates: Requirements 4.5
实现说明:为所有 API 端点创建测试,验证响应头中是否包含版本信息。
*对于任何*迁移阶段完成后的系统状态,运行完整的测试套件应该全部通过,并且所有关键业务流程应该能够正常执行。
Validates: Requirements 5.2, 5.3
实现说明:在每个迁移阶段完成后,运行完整的集成测试和端到端测试套件,确保所有测试通过。
*对于任何*正在迁移的功能,在迁移期间应该同时存在旧代码路径和新代码路径,并且两个路径都应该能够正常工作并产生等价的结果。
Validates: Requirements 5.4
实现说明:在迁移期间,为同一功能创建测试用例,分别调用旧路径和新路径,验证结果的等价性。
*对于任何*在系统内部抛出的异常(领域异常、应用异常、基础设施异常),当它到达表现层时,应该被统一捕获并转换为适当的 HTTP 响应(包含正确的状态码和错误消息)。
Validates: Requirements 6.2
实现说明:为每种异常类型创建测试,模拟异常抛出,验证 API 响应的状态码和错误格式是否正确。
*对于任何*日志记录,日志输出应该是结构化的(如 JSON 格式),并且应该包含必需的上下文信息(请求 ID、时间戳、日志级别、消息),当记录错误时还应该包含堆栈跟踪和相关输入参数。
Validates: Requirements 6.3, 6.4, 6.5
实现说明:配置日志捕获,触发不同类型的日志记录(INFO、ERROR 等),验证日志输出格式和包含的字段。
*对于任何*日志记录,当配置了多个日志输出目标(控制台、文件、远程服务)时,日志应该同时输出到所有配置的目标。
Validates: Requirements 6.6
实现说明:配置多个日志处理器,触发日志记录,验证所有目标都收到了日志输出。
*对于任何*模块间的交互,调用方应该只导入和使用被调用模块的公共接口(在 __init__.py 中导出的内容),而不应该直接导入模块的内部实现细节。
Validates: Requirements 8.7
实现说明:使用静态分析工具检查跨模块的导入语句,确保只导入公共 API。
*对于任何*公共类、函数或方法,它应该有一个符合规范的文档字符串(docstring),包含描述、参数说明和返回值说明。
Validates: Requirements 9.2
实现说明:扫描所有公共 API(不以下划线开头的类和函数),检查是否有 docstring,并验证 docstring 的格式是否符合规范(如 Google 风格或 NumPy 风格)。
*对于任何*被迁移的代码模块,迁移前后运行相同的测试套件应该产生相同的结果,证明功能行为没有改变。
Validates: Requirements 10.2
实现说明:在迁移前为旧代码创建测试套件,迁移后对新代码运行相同的测试,比较测试结果。
*对于任何*被标记为废弃的代码路径,当该代码被调用时,系统应该在日志中记录一条警告消息,提示该代码已废弃。
Validates: Requirements 10.4
实现说明:使用 Python 的 warnings 模块或自定义装饰器标记废弃代码,创建测试调用废弃代码,验证日志中是否有警告记录。
*对于任何*关键 API 端点,重构后的响应时间(P95)应该不超过重构前的响应时间的 110%(允许 10% 的性能波动)。
Validates: Requirements 11.1
实现说明:在重构前后分别运行性能基准测试,记录每个端点的 P95 响应时间,比较并验证性能没有显著退化。
*对于任何*HTTP 请求,系统应该在日志中记录该请求的处理时间、内存使用情况和其他相关性能指标。
Validates: Requirements 11.4
实现说明:配置请求日志中间件,发送测试请求,验证日志中是否包含处理时间和资源使用信息。
# shared/exceptions.py
from typing import Optional


class RAGSystemException(Exception):
    """Base class for every exception raised by the system.

    Args:
        message: human-readable error description.
        details: optional structured context; defaults to an empty dict.
    """

    def __init__(self, message: str, details: Optional[dict] = None):
        # Fixed: was annotated ``details: dict = None`` — the None default
        # contradicts the annotation; Optional[dict] states the contract.
        self.message = message
        self.details = details or {}
        super().__init__(self.message)
# domain/shared/exceptions.py
class DomainException(RAGSystemException):
    """Base class for domain-layer exceptions.

    NOTE(review): ``RAGSystemException`` must be imported from
    ``shared.exceptions``; the import is missing in this snippet.
    """
    pass


class InvalidValueException(DomainException):
    """Raised when a value object receives an invalid value."""
    pass


class BusinessRuleViolationException(DomainException):
    """Raised when an operation would violate a business rule."""
    pass
# application/shared/exceptions.py
class ApplicationException(RAGSystemException):
    """Base class for application-layer exceptions.

    NOTE(review): ``RAGSystemException`` must be imported from
    ``shared.exceptions``; the import is missing in this snippet.
    """
    pass


class ResourceNotFoundException(ApplicationException):
    """Raised when a requested resource does not exist."""
    pass


class ValidationException(ApplicationException):
    """Raised when application-level input validation fails."""
    pass
# infrastructure/shared/exceptions.py
class InfrastructureException(RAGSystemException):
    """Base class for infrastructure-layer exceptions.

    NOTE(review): ``RAGSystemException`` must be imported from
    ``shared.exceptions``; the import is missing in this snippet.
    """
    pass


class DatabaseException(InfrastructureException):
    """Raised on relational/vector database failures."""
    pass


class ExternalServiceException(InfrastructureException):
    """Raised when an external service call fails."""
    pass
领域层:
应用层:
基础设施层:
表现层:
# presentation/schemas/responses.py
from datetime import datetime, timezone
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field


class ErrorResponse(BaseModel):
    """Uniform error payload returned by all error handlers."""

    error: str                                # error type/category
    message: str                              # human-readable message
    details: Optional[Dict[str, Any]] = None  # structured extra context
    request_id: Optional[str] = None          # correlation id
    # Fixed: naive ``datetime.now`` produced a zone-less timestamp, but the
    # documented example ("2025-01-15T10:30:00Z") is UTC — emit aware UTC.
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
# 示例错误响应
{
"error": "ValidationError",
"message": "Invalid document content",
"details": {
"field": "content",
"reason": "Content cannot be empty"
},
"request_id": "req_123456",
"timestamp": "2025-01-15T10:30:00Z"
}
日志级别使用指南:
结构化日志格式:
# config/logging.py
import structlog


def configure_logging():
    """Configure structlog to emit JSON logs through stdlib logging.

    Processor order matters: level filtering first, then enrichment
    (logger name, level, timestamp), then exception/stack rendering,
    and JSONRenderer last so it serializes the fully-built event dict.
    """
    structlog.configure(
        processors=[
            structlog.stdlib.filter_by_level,  # drop below-threshold events early
            structlog.stdlib.add_logger_name,
            structlog.stdlib.add_log_level,
            structlog.stdlib.PositionalArgumentsFormatter(),
            structlog.processors.TimeStamper(fmt="iso"),
            structlog.processors.StackInfoRenderer(),
            structlog.processors.format_exc_info,
            structlog.processors.UnicodeDecoder(),
            structlog.processors.JSONRenderer()  # must stay last
        ],
        context_class=dict,
        logger_factory=structlog.stdlib.LoggerFactory(),
        cache_logger_on_first_use=True,
    )
# 日志输出示例
{
"event": "document_created",
"level": "info",
"timestamp": "2025-01-15T10:30:00.123Z",
"request_id": "req_123456",
"user_id": "user_789",
"document_id": "doc_abc",
"duration_ms": 45
}
请求日志中间件:
# presentation/api/middleware.py
import time
import uuid

import structlog
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

# Fixed: ``logger`` was used throughout dispatch() but never defined.
logger = structlog.get_logger(__name__)


class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Logs every request with a correlation id, status and duration."""

    async def dispatch(self, request: Request, call_next):
        """Wrap the downstream app with start/complete/failed log events."""
        # Correlation id: generated here, exposed to handlers via
        # request.state and echoed back in the X-Request-ID header.
        request_id = str(uuid.uuid4())
        request.state.request_id = request_id

        # Fixed: time.time() is wall-clock and can jump on clock changes,
        # yielding wrong durations; perf_counter() is monotonic.
        start_time = time.perf_counter()
        logger.info(
            "request_started",
            request_id=request_id,
            method=request.method,
            path=request.url.path,
            client_ip=request.client.host
        )
        try:
            response = await call_next(request)
            duration_ms = (time.perf_counter() - start_time) * 1000
            logger.info(
                "request_completed",
                request_id=request_id,
                status_code=response.status_code,
                duration_ms=duration_ms
            )
            # Headers required by the API-versioning design property (4.5).
            response.headers["X-Request-ID"] = request_id
            response.headers["X-API-Version"] = "1.0"
            return response
        except Exception as e:
            duration_ms = (time.perf_counter() - start_time) * 1000
            logger.error(
                "request_failed",
                request_id=request_id,
                error=str(e),
                duration_ms=duration_ms,
                exc_info=True
            )
            raise  # re-raise so the registered exception handlers respond
/\
/ \
/ E2E \ 少量端到端测试(关键业务流程)
/______\
/ \
/Integration\ 中等数量集成测试(组件交互)
/____________\
/ \
/ Unit Tests \ 大量单元测试(领域逻辑、应用服务)
/__________________\
测试范围:
测试原则:
示例:
# tests/unit/domain/vector_search/test_entities.py
import time  # Fixed: time and datetime were used below but never imported.
from datetime import datetime

import pytest
from domain.vector_search.entities import Document
from domain.vector_search.value_objects import Vector


class TestDocument:
    """Unit tests for the Document entity."""

    def _make_doc(self, embedding=None, ts=None):
        """Build a Document with sensible defaults for these tests."""
        ts = ts or datetime.now()
        return Document(
            id="doc_1",
            content="original",
            embedding=embedding,
            metadata={},
            created_at=ts,
            updated_at=ts
        )

    def test_update_content_changes_content(self):
        """Updating content replaces the stored text."""
        doc = self._make_doc(embedding=Vector([1.0, 2.0]))
        doc.update_content("new content")
        assert doc.content == "new content"

    def test_update_content_clears_embedding(self):
        """Updating content invalidates the stale embedding."""
        doc = self._make_doc(embedding=Vector([1.0, 2.0]))
        doc.update_content("new content")
        assert doc.embedding is None

    def test_update_content_updates_timestamp(self):
        """Updating content bumps updated_at past its previous value."""
        original_time = datetime.now()
        doc = self._make_doc(ts=original_time)
        time.sleep(0.01)  # guarantee a measurable clock difference
        doc.update_content("new content")
        assert doc.updated_at > original_time
# tests/unit/domain/vector_search/test_value_objects.py
import pytest  # Fixed: pytest and Vector were used without being imported.
from domain.vector_search.value_objects import Vector


class TestVector:
    """Unit tests for the Vector value object."""

    def test_vector_creation_with_valid_dimensions(self):
        """A vector built from valid floats exposes them unchanged."""
        vector = Vector([1.0, 2.0, 3.0])
        assert vector.dimensions == [1.0, 2.0, 3.0]
        assert vector.dimension_count == 3

    def test_vector_creation_with_empty_dimensions_raises_error(self):
        """An empty dimension list is rejected at construction time."""
        with pytest.raises(ValueError, match="Vector cannot be empty"):
            Vector([])

    def test_vector_is_immutable(self):
        """Frozen dataclass: attribute assignment raises."""
        vector = Vector([1.0, 2.0])
        with pytest.raises(AttributeError):
            vector.dimensions = [3.0, 4.0]
测试范围:
测试原则:
示例:
# tests/integration/infrastructure/test_repositories.py
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from infrastructure.database.repositories import SQLDocumentRepository
from infrastructure.database.models import Base
from domain.vector_search.entities import Document


@pytest.fixture
def test_db():
    """In-memory SQLite database, created per test and dropped afterwards.

    NOTE(review): this yields a *sync* sessionmaker, but the repository
    uses ``async with self.session_factory()`` / ``await session.commit()``
    — it needs an async engine/sessionmaker (sqlalchemy.ext.asyncio).
    """
    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    SessionLocal = sessionmaker(bind=engine)
    yield SessionLocal
    # Teardown: drop all tables so tests stay isolated.
    Base.metadata.drop_all(engine)


@pytest.fixture
def repository(test_db):
    """Repository under test, with the vector store stubbed out.

    NOTE(review): ``MockVectorDB`` is not defined or imported in this
    snippet; ``Vector`` and ``datetime`` used below are also unimported.
    """
    return SQLDocumentRepository(
        session_factory=test_db,
        vector_db=MockVectorDB()
    )


class TestSQLDocumentRepository:
    """Integration tests for SQLDocumentRepository.

    NOTE(review): these async tests need the pytest-asyncio (or anyio)
    plugin with appropriate markers/auto mode to actually execute.
    """

    async def test_save_and_find_document(self, repository):
        """A saved document can be read back field-for-field."""
        doc = Document(
            id="doc_1",
            content="test content",
            embedding=Vector([1.0, 2.0]),
            metadata={"key": "value"},
            created_at=datetime.now(),
            updated_at=datetime.now()
        )
        await repository.save(doc)
        found_doc = await repository.find_by_id("doc_1")
        assert found_doc is not None
        assert found_doc.id == "doc_1"
        assert found_doc.content == "test content"
        assert found_doc.metadata == {"key": "value"}

    async def test_delete_document(self, repository):
        """After delete, find_by_id returns None."""
        doc = Document(
            id="doc_1",
            content="test content",
            embedding=None,
            metadata={},
            created_at=datetime.now(),
            updated_at=datetime.now()
        )
        await repository.save(doc)
        await repository.delete("doc_1")
        found_doc = await repository.find_by_id("doc_1")
        assert found_doc is None
测试范围:
测试原则:
示例:
# tests/e2e/test_document_workflow.py
import pytest
from httpx import AsyncClient
@pytest.fixture
async def client():
    """Yield an async HTTP client pointed at the locally running API."""
    async with AsyncClient(base_url="http://localhost:8000") as http:
        yield http
class TestDocumentWorkflow:
    """End-to-end tests covering the full document workflow."""

    async def test_complete_document_lifecycle(self, client):
        """Create, search, fetch, and delete a document, then verify deletion."""
        # Step 1: create a document.
        created = await client.post(
            "/api/v1/documents/",
            json={
                "content": "This is a test document",
                "metadata": {"source": "test"}
            }
        )
        assert created.status_code == 201
        doc_id = created.json()["id"]

        # Step 2: the new document is findable via search.
        searched = await client.post(
            "/api/v1/documents/search",
            json={
                "query": "test document",
                "top_k": 10
            }
        )
        assert searched.status_code == 200
        hits = searched.json()["results"]
        assert len(hits) > 0
        assert any(hit["id"] == doc_id for hit in hits)

        # Step 3: fetch it directly by id.
        fetched = await client.get(f"/api/v1/documents/{doc_id}")
        assert fetched.status_code == 200
        body = fetched.json()
        assert body["content"] == "This is a test document"

        # Step 4: delete it.
        deleted = await client.delete(f"/api/v1/documents/{doc_id}")
        assert deleted.status_code == 204

        # Step 5: a subsequent fetch must report 404.
        refetched = await client.get(f"/api/v1/documents/{doc_id}")
        assert refetched.status_code == 404
虽然本项目是重构项目,但对于核心领域逻辑,我们应该使用属性测试来验证通用规则:
使用 Hypothesis 库:
# tests/unit/domain/vector_search/test_properties.py
# Bug fix: pytest.raises is used below but `pytest` was never imported
# in this example file.
import pytest
from hypothesis import given, strategies as st

from domain.vector_search.value_objects import Vector, SearchQuery


class TestVectorProperties:
    """Property-based tests for the Vector value object."""

    @given(st.lists(st.floats(allow_nan=False, allow_infinity=False), min_size=1))
    def test_vector_dimension_count_equals_list_length(self, dimensions):
        """Property: a vector's dimension count equals the source list's length."""
        # Feature: rag-system-refactoring, Property: Vector dimension count
        vector = Vector(dimensions)
        assert vector.dimension_count == len(dimensions)

    @given(st.lists(st.floats(allow_nan=False, allow_infinity=False), min_size=1))
    def test_vector_immutability(self, dimensions):
        """Property: a vector cannot be mutated after construction."""
        # Feature: rag-system-refactoring, Property: Vector immutability
        vector = Vector(dimensions)
        original_dims = vector.dimensions
        # Reassignment must be rejected by the (frozen) value object.
        with pytest.raises(AttributeError):
            vector.dimensions = [1.0, 2.0]
        # And the stored dimensions must be unchanged.
        assert vector.dimensions == original_dims
class TestSearchQueryProperties:
    """Property-based tests for the SearchQuery value object."""

    @given(
        text=st.text(min_size=1),
        top_k=st.integers(min_value=1, max_value=1000)
    )
    def test_search_query_creation_with_valid_inputs(self, text, top_k):
        """Property: any non-empty text with top_k in [1, 1000] builds a query."""
        # Feature: rag-system-refactoring, Property: SearchQuery creation
        built = SearchQuery(text=text, top_k=top_k)
        assert built.text == text
        assert built.top_k == top_k
pytest 配置:
# tests/conftest.py
import pytest
import asyncio
@pytest.fixture(scope="session")
def event_loop():
    """Session-scoped event loop shared by all async tests."""
    # NOTE(review): pytest-asyncio >= 0.21 deprecates overriding this fixture;
    # confirm the pinned plugin version before relying on it.
    new_loop = asyncio.get_event_loop_policy().new_event_loop()
    yield new_loop
    new_loop.close()
@pytest.fixture(autouse=True)
def reset_singletons():
    """Reset module-level singletons after each test."""
    yield
    # Teardown hook: clear cached configuration / singleton state here.
# pytest.ini
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
asyncio_mode = auto
markers =
unit: Unit tests
integration: Integration tests
e2e: End-to-end tests
slow: Slow tests
# .github/workflows/test.yml
name: Tests

on: [push, pull_request]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      # v2 of checkout/setup-python runs on the retired Node 12 runtime and
      # fails on current GitHub-hosted runners; use the maintained majors.
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'
      - name: Install dependencies
        run: |
          pip install -r requirements.txt
          pip install -r requirements-dev.txt
      - name: Run unit tests
        run: pytest tests/unit -v --cov=src --cov-report=xml
      - name: Run integration tests
        run: pytest tests/integration -v
      - name: Run E2E tests
        run: pytest tests/e2e -v
      - name: Upload coverage
        uses: codecov/codecov-action@v4
        with:
          # `file` is deprecated in favor of `files` in recent versions.
          files: ./coverage.xml
Phase 1: 基础设施搭建(2 周)
Phase 2: 领域层重构(3 周)
Phase 3: 应用层重构(3 周)
Phase 4: 基础设施层迁移(4 周)
Phase 5: 表现层迁移(2 周)
Phase 6: 废弃代码清理(1 周)
风险 1: API 兼容性破坏
风险 2: 性能退化
风险 3: 数据迁移失败
风险 4: 团队学习曲线
本设计文档描述了 RAG 系统重构的完整方案,采用领域驱动设计和四层架构,通过分阶段迁移策略确保平滑过渡。关键设计决策包括:
通过这个重构方案,系统将获得更好的可维护性、可扩展性和可测试性,为未来的功能扩展奠定坚实的基础。