# workflow.py
import concurrent.futures
import os
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import Annotated, Any, Dict, List

# Add the project root to the Python path (must run before project-local imports).
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from pydantic import BaseModel, ConfigDict, Field

from conf.config import ModelConfig
from services.model.multimodal_embedding import MultimodalEmbedding
from services.model.qwen_vl import QWenVLParser
from services.pdf_parser.pdf_splitter import PDFSplitter
from services.ragflow.ragflow_service import RAGFlowService
from services.utils.vector_db import VectorDBFactory
# Workflow state definition
  18. class PDFParsingState(BaseModel):
  19. """PDF解析工作流状态"""
  20. model_config = ConfigDict(arbitrary_types_allowed=True)
  21. pdf_path: str = Field(..., description="PDF文件路径")
  22. dataset_id: str = Field(..., description="数据集ID")
  23. ragflow_service: RAGFlowService = Field(default_factory=RAGFlowService, description="RAGFLOW服务")
  24. vector_db: Any = Field(default_factory=VectorDBFactory.get_vector_db, description="向量数据库实例")
  25. embedding_model: MultimodalEmbedding = Field(default_factory=MultimodalEmbedding, description="多模态嵌入模型实例")
  26. document_id: str = Field(default="", description="上传后的文档ID")
  27. split_pages: List[Dict[str, Any]] = Field(default_factory=list, description="拆分后的页面列表")
  28. current_page: Dict[str, Any] = Field(default_factory=dict, description="当前处理的页面")
  29. parsed_results: List[Dict[str, Any]] = Field(default_factory=list, description="解析结果列表")
  30. vectorized_results: List[Dict[str, Any]] = Field(default_factory=list, description="向量化结果列表")
  31. processed_pages: int = Field(default=0, description="已处理的页面数量")
  32. vectorized_pages: int = Field(default=0, description="已向量化的页面数量")
  33. is_complete: bool = Field(default=False, description="是否处理完成")
# Workflow builder
  35. class PDFParsingWorkflow:
  36. """PDF扫描件拆分解析工作流"""
  37. def __init__(self, model_name: str = "Qwen/Qwen3-VL-8B-Instruct"):
  38. """
  39. 初始化PDF解析工作流
  40. Args:
  41. model_name: QWEN VL模型名称
  42. """
  43. self.model_name = model_name
  44. self.workflow = self._build_workflow()
  45. def _build_workflow(self):
  46. """构建langgraph工作流,实现基于条件路由的并行处理"""
  47. # 创建状态图
  48. graph = StateGraph(PDFParsingState)
  49. # 添加上传文档节点
  50. graph.add_node("upload_document", self._upload_document_node)
  51. # 添加解析文档节点
  52. graph.add_node("parse_document", self._parse_document_node)
  53. # 添加拆分PDF节点
  54. graph.add_node("split_pdf", self._split_pdf_node)
  55. # 添加解析图像节点
  56. graph.add_node("parse_image", self._parse_image_node)
  57. # 添加向量化入库节点
  58. graph.add_node("vectorize_store", self._vectorize_store_node)
  59. # 添加完成节点
  60. graph.add_node("complete", self._complete_node)
  61. # 定义边
  62. graph.add_edge(START, "upload_document")
  63. # 添加解析文档边
  64. graph.add_edge("upload_document", "parse_document")
  65. graph.add_edge("parse_document", "split_pdf")
  66. graph.add_edge("split_pdf", "parse_image")
  67. # 添加条件边:判断是否继续解析
  68. graph.add_conditional_edges(
  69. "parse_image",
  70. self._should_continue_parsing,
  71. {
  72. "continue": "parse_image",
  73. "complete": "vectorize_store"
  74. }
  75. )
  76. # 添加向量化入库边
  77. graph.add_edge("vectorize_store", "complete")
  78. graph.add_edge("complete", END)
  79. # 编译工作流
  80. return graph.compile()
  81. def _upload_document_node(self, state: PDFParsingState) -> Dict[str, Any]:
  82. """RAGFLOW上传文档节点"""
  83. print(f"开始上传文档到数据集 {state.dataset_id}: {state.pdf_path}")
  84. try:
  85. # 上传文档
  86. document_info_list = state.ragflow_service.upload_document(
  87. dataset_id=state.dataset_id,
  88. file_path=state.pdf_path
  89. )
  90. # 检查响应
  91. if document_info_list and len(document_info_list) > 0:
  92. document_id = document_info_list[0]["id"]
  93. print(f"文档上传成功,文档ID: {document_id}")
  94. return {
  95. "document_id": document_id
  96. }
  97. else:
  98. print("文档上传失败: 未返回有效的文档信息")
  99. raise Exception("文档上传失败: 未返回有效的文档信息")
  100. except Exception as e:
  101. print(f"上传文档时出错: {str(e)}")
  102. raise
  103. def _parse_document_node(self, state: PDFParsingState) -> Dict[str, Any]:
  104. """RAGFLOW文档解析节点"""
  105. print(f"开始解析文档 {state.dataset_id}: {state.document_id}")
  106. try:
  107. # 解析文档
  108. parsed_results = state.ragflow_service.parse_document(
  109. dataset_id=state.dataset_id,
  110. document_ids=[state.document_id]
  111. )
  112. # 检查响应parsed_results为bool
  113. if parsed_results:
  114. print(f"文档解析成功,文档ID: {state.document_id}")
  115. return {
  116. "parsed_results": parsed_results
  117. }
  118. else:
  119. print("文档解析失败: 未返回有效的解析结果")
  120. raise Exception("文档解析失败: 未返回有效的解析结果")
  121. except Exception as e:
  122. print(f"解析文档时出错: {str(e)}")
  123. raise
  124. def _split_pdf_node(self, state: PDFParsingState) -> Dict[str, Any]:
  125. """拆分PDF节点"""
  126. print(f"开始拆分PDF: {state.pdf_path}")
  127. # 拆分PDF
  128. splitter = PDFSplitter()
  129. split_pages = splitter.split_pdf(state.pdf_path)
  130. print(f"PDF拆分完成,共 {len(split_pages)} 页")
  131. return {
  132. "split_pages": split_pages,
  133. "parsed_results": [],
  134. "processed_pages": 0,
  135. "is_complete": False
  136. }
  137. def _parse_single_page(self, page: Dict[str, Any], model_name: str) -> Dict[str, Any]:
  138. """解析单个页面(用于并行处理)"""
  139. prompt = """
  140. 你是一个文本提取助手,你的任务是从图像中提取出文本内容。
  141. 注意:不要修改任何文本原始内容与标点符号,只进行提取。
  142. """
  143. page_number = page["page_number"]
  144. image = page["image"]
  145. print(f"开始解析第 {page_number} 页")
  146. # 使用QWEN VL模型解析图像
  147. parser = QWenVLParser(model_name)
  148. result = parser.parse_image(image, page_number, prompt)
  149. print(f"第 {page_number} 页解析完成")
  150. return result
  151. def _parse_image_node(self, state: PDFParsingState) -> Dict[str, Any]:
  152. """解析图像节点,使用并行处理"""
  153. if not state.split_pages:
  154. return state.dict()
  155. print(f"开始并行解析 {len(state.split_pages)} 页")
  156. parsed_results = []
  157. # 使用ThreadPoolExecutor实现并行处理
  158. with ThreadPoolExecutor(max_workers=4) as executor:
  159. # 提交所有页面解析任务
  160. future_to_page = {
  161. executor.submit(self._parse_single_page, page, self.model_name): page
  162. for page in state.split_pages
  163. }
  164. # 收集解析结果
  165. for future in concurrent.futures.as_completed(future_to_page):
  166. try:
  167. result = future.result()
  168. parsed_results.append(result)
  169. except Exception as e:
  170. page = future_to_page[future]
  171. print(f"解析第 {page['page_number']} 页时出错: {str(e)}")
  172. # 按页码排序结果
  173. parsed_results.sort(key=lambda x: x["page_number"])
  174. print(f"所有页面解析完成,共解析 {len(parsed_results)} 页")
  175. return {
  176. "split_pages": state.split_pages, # 保留split_pages,以便后续访问图片
  177. "parsed_results": parsed_results,
  178. "processed_pages": len(parsed_results),
  179. "is_complete": True
  180. }
  181. def _should_continue_parsing(self, state: PDFParsingState) -> str:
  182. """判断是否继续解析"""
  183. # 由于我们使用了并行处理,parse_image_node会一次性处理所有页面
  184. # 所以这里总是返回"complete"
  185. return "complete"
  186. def _vectorize_store_node(self, state: PDFParsingState) -> Dict[str, Any]:
  187. """向量化入库节点"""
  188. print(f"开始向量化入库,共 {len(state.parsed_results)} 页")
  189. # 创建索引(如果不存在)
  190. index_name = f"pdf_documents_{state.dataset_id}"
  191. state.vector_db.create_index(index_name)
  192. # 准备要入库的文档列表
  193. documents_to_store = []
  194. # 获取文件名和总页数
  195. file_name = os.path.basename(state.pdf_path)
  196. file_page_count = len(state.split_pages)
  197. # 遍历所有解析结果,生成向量化文档
  198. for i, parsed_result in enumerate(state.parsed_results):
  199. try:
  200. page_number = parsed_result.get("page_number")
  201. text = parsed_result.get("text", "")
  202. image = parsed_result.get("image")
  203. # 生成图片地址(假设图片已保存)
  204. image_path = parsed_result.get("image_path", f"temp/{file_name}_{page_number}.png")
  205. # 获取多模态嵌入向量
  206. print(f"正在生成第 {page_number} 页的多模态嵌入...")
  207. embedding = state.embedding_model.get_multimodal_embedding(text, image)
  208. # 生成1024维稠密向量(如果嵌入向量维度不是1024,这里需要处理)
  209. dense_vector_1024 = embedding[:1024] # 取前1024维
  210. # 创建文档
  211. document = {
  212. "file_name": file_name,
  213. "file_page_count": file_page_count,
  214. "page_number": page_number,
  215. "text": text,
  216. "image_path": image_path,
  217. "sparse_vector": [], # 稀疏向量,暂时为空
  218. "dense_vector_1024": dense_vector_1024,
  219. "dataset_id": state.dataset_id,
  220. "document_id": state.document_id
  221. }
  222. documents_to_store.append(document)
  223. print(f"第 {page_number} 页向量化完成")
  224. except Exception as e:
  225. print(f"第 {i+1} 页向量化失败: {str(e)}")
  226. # 批量入库
  227. if documents_to_store:
  228. print(f"开始批量入库,共 {len(documents_to_store)} 个文档")
  229. result = state.vector_db.bulk_insert(index_name, documents_to_store)
  230. print(f"批量入库结果: {result}")
  231. return {
  232. "vectorized_results": documents_to_store,
  233. "vectorized_pages": len(documents_to_store),
  234. "is_complete": True
  235. }
  236. def _complete_node(self, state: PDFParsingState) -> Dict[str, Any]:
  237. """完成节点"""
  238. print(f"PDF解析工作流完成,共解析 {len(state.parsed_results)} 页,向量化 {state.vectorized_pages} 页")
  239. return {
  240. "is_complete": True
  241. }
  242. def run(self, pdf_path: str, dataset_id: str, ragflow_api_url: str, rag_flow_api_key: str) -> Dict[str, Any]:
  243. """
  244. 运行PDF解析工作流
  245. Args:
  246. pdf_path: PDF文件路径
  247. dataset_id: 数据集ID
  248. ragflow_api_url: RAGFLOW API URL
  249. rag_flow_api_key: RAGFLOW API密钥
  250. Returns:
  251. Dict: 包含最终状态的字典
  252. """
  253. initial_state = PDFParsingState(
  254. pdf_path=pdf_path,
  255. dataset_id=dataset_id,
  256. ragflow_service=RAGFlowService(base_url=ragflow_api_url, api_key=rag_flow_api_key)
  257. )
  258. result = self.workflow.invoke(initial_state)
  259. # 检查结果类型,如果是字典直接返回,否则调用dict()方法
  260. if isinstance(result, dict):
  261. return result
  262. else:
  263. return result.dict()