|
@@ -4,7 +4,6 @@ QA对生成节点
|
|
|
从文本块生成问答对。
|
|
从文本块生成问答对。
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
-import json
|
|
|
|
|
import concurrent.futures
|
|
import concurrent.futures
|
|
|
from typing import Dict, Any, List, Optional
|
|
from typing import Dict, Any, List, Optional
|
|
|
from src.datasets.parser.core.base import BaseNode, BaseState
|
|
from src.datasets.parser.core.base import BaseNode, BaseState
|
|
@@ -118,6 +117,140 @@ class QAGenerateNode(BaseNode):
|
|
|
return []
|
|
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+ def execute(self, state: BaseState) -> Dict[str, Any]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 执行QA对生成
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ state: 包含text_chunks的状态
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ 包含qa_pairs的更新字典
|
|
|
|
|
+ """
|
|
|
|
|
+ text_chunks = getattr(state, 'text_chunks', [])
|
|
|
|
|
+ qa_count_per_chunk = getattr(state, 'qa_count_per_chunk', self.default_qa_count)
|
|
|
|
|
+ qa_pairs = getattr(state, 'qa_pairs', [])
|
|
|
|
|
+ qa_count = getattr(state, 'qa_count', 0)
|
|
|
|
|
+
|
|
|
|
|
+ if not text_chunks:
|
|
|
|
|
+ logger.warning("没有待处理的文本块")
|
|
|
|
|
+ return {"qa_pairs": qa_pairs, "qa_count": qa_count}
|
|
|
|
|
+
|
|
|
|
|
+ logger.info(f"开始并行生成QA对,共 {len(text_chunks)} 块,每块 {qa_count_per_chunk} 个")
|
|
|
|
|
+
|
|
|
|
|
+ all_qa_pairs = qa_pairs
|
|
|
|
|
+
|
|
|
|
|
+ # 使用全局线程池
|
|
|
|
|
+ pool = ThreadPoolManager.get_pool("parser")
|
|
|
|
|
+ futures = [
|
|
|
|
|
+ pool.submit(self._generate_qa_for_chunk, chunk, qa_count_per_chunk, i)
|
|
|
|
|
+ for i, chunk in enumerate(text_chunks)
|
|
|
|
|
+ ]
|
|
|
|
|
+
|
|
|
|
|
+ for future in concurrent.futures.as_completed(futures):
|
|
|
|
|
+ try:
|
|
|
|
|
+ current_qa_pairs = future.result()
|
|
|
|
|
+ all_qa_pairs.extend(current_qa_pairs)
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"QA生成任务失败: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
+ logger.info(f"QA对生成完成,共 {len(all_qa_pairs)} 个")
|
|
|
|
|
+
|
|
|
|
|
+ return {
|
|
|
|
|
+ "qa_pairs": all_qa_pairs,
|
|
|
|
|
+ "qa_count": len(all_qa_pairs)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
@register_node()
class QABaseGenerateNode(BaseNode):
    """
    Base QA-pair generation node.

    Generates question-answer pairs from text chunks in parallel.

    Required state fields:
        - text_chunks: list of text chunks
        - qa_count_per_chunk: number of QA pairs per chunk (optional)

    Updated state fields:
        - qa_pairs: list of QA pairs
        - qa_count: total number of generated QA pairs
    """

    def __init__(
        self,
        model_name: Optional[str] = None,
        default_qa_count: int = 50,
        max_workers: int = 5
    ):
        """
        Initialize the QA generation node.

        Args:
            model_name: Chat model name; falls back to the configured
                default chat model when omitted.
            default_qa_count: Default number of QA pairs per chunk.
            max_workers: Maximum worker threads for parallel processing.
        """
        self.model_name = model_name or model_settings.chat_model_name
        self.default_qa_count = default_qa_count
        self.max_workers = max_workers

    @property
    def name(self) -> str:
        # Registry key under which this node is looked up.
        return "qa_base_generate"

    def _generate_base_qa_for_chunk(self, chunk: str, count: int, chunk_index: int) -> List[Dict[str, str]]:
        """
        Generate QA pairs for a single chunk.

        Args:
            chunk: Chunk text to generate from.
            count: Number of QA pairs to request.
            chunk_index: Zero-based chunk index (used only for logging).

        Returns:
            List of ``{"question": ..., "answer": ...}`` dicts; empty list
            when generation or parsing fails.
        """
        system_prompt = """你是一个专业的问答对生成专家。你的任务是根据给定的文本内容,生成高质量的问答对。

要求:
1. 问题应该覆盖内容的不同方面和细节
2. 答案应该准确、简洁,直接来自文本
3. 问答对应该有助于知识检索和理解
4. 避免生成过于简单或重复的问题
5. 确保问题是自包含的,不需要额外上下文即可理解"""

        user_prompt = f"""请根据以下内容,生成 {count} 个高质量的问答对。

内容:
{chunk}

请严格以JSON格式输出,格式如下:
[
{{"question": "问题1", "answer": "答案1"}},
{{"question": "问题2", "answer": "答案2"}}
]

只输出JSON数组,不要添加任何其他内容。"""

        logger.debug(f"开始为第 {chunk_index + 1} 块生成QA对")

        try:
            llm = QWenVLParser(self.model_name)
            response = llm.chat(prompt=user_prompt, system_prompt=system_prompt)

            # Parse the JSON response through the shared helper; parsing
            # errors are handled by the same except as the model call.
            parsed_pairs = parse_qa_response(response)
            logger.debug(f"第 {chunk_index + 1} 块生成 {len(parsed_pairs)} 个QA对")
            return parsed_pairs
        except Exception as e:
            logger.error(f"第 {chunk_index + 1} 块QA生成失败: {str(e)}")
            return []
|
|
|
def execute(self, state: BaseState) -> Dict[str, Any]:
|
|
def execute(self, state: BaseState) -> Dict[str, Any]:
|
|
|
"""
|
|
"""
|
|
|
执行QA对生成
|
|
执行QA对生成
|
|
@@ -142,7 +275,7 @@ class QAGenerateNode(BaseNode):
|
|
|
# 使用全局线程池
|
|
# 使用全局线程池
|
|
|
pool = ThreadPoolManager.get_pool("parser")
|
|
pool = ThreadPoolManager.get_pool("parser")
|
|
|
futures = [
|
|
futures = [
|
|
|
- pool.submit(self._generate_qa_for_chunk, chunk, qa_count_per_chunk, i)
|
|
|
|
|
|
|
+ pool.submit(self._generate_base_qa_for_chunk, chunk, qa_count_per_chunk, i)
|
|
|
for i, chunk in enumerate(text_chunks)
|
|
for i, chunk in enumerate(text_chunks)
|
|
|
]
|
|
]
|
|
|
|
|
|
|
@@ -151,11 +284,11 @@ class QAGenerateNode(BaseNode):
|
|
|
qa_pairs = future.result()
|
|
qa_pairs = future.result()
|
|
|
all_qa_pairs.extend(qa_pairs)
|
|
all_qa_pairs.extend(qa_pairs)
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
- logger.error(f"QA生成任务失败: {str(e)}")
|
|
|
|
|
|
|
+ logger.error(f"QA基础生成任务失败: {str(e)}")
|
|
|
|
|
|
|
|
- logger.info(f"QA对生成完成,共 {len(all_qa_pairs)} 个")
|
|
|
|
|
|
|
+ logger.info(f"QA对基础生成完成,共 {len(all_qa_pairs)} 个")
|
|
|
|
|
|
|
|
return {
|
|
return {
|
|
|
"qa_pairs": all_qa_pairs,
|
|
"qa_pairs": all_qa_pairs,
|
|
|
"qa_count": len(all_qa_pairs)
|
|
"qa_count": len(all_qa_pairs)
|
|
|
- }
|
|
|
|
|
|
|
+ }
|