picture_stitching_node.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. """
  2. 图片拼接节点
  3. 将拆分的PDF页面图片按页码顺序拼接成一张长图。
  4. """
  5. from typing import Dict, Any
  6. from io import BytesIO
  7. from PIL import Image
  8. from src.datasets.parser.core.base import BaseNode, BaseState
  9. from src.datasets.parser.core.registry import register_node
  10. from src.utils.file.image_util import ImageUtil
  11. from src.common.logging_config import get_logger
  12. logger = get_logger(__name__)
  13. @register_node()
  14. class PictureStitchingNode(BaseNode):
  15. """
  16. 图片拼接节点
  17. 将拆分后的PDF页面图片按页码顺序垂直拼接成一张长图。
  18. 需要的状态字段:
  19. - split_pages: 拆分后的页面列表,每个元素包含:
  20. - page_number: 页码
  21. - image: PIL图像对象
  22. 更新的状态字段:
  23. - book_image: 拼接后的完整书本图片
  24. """
  25. @property
  26. def name(self) -> str:
  27. return "picture_stitching"
  28. def execute(self, state: BaseState) -> Dict[str, Any]:
  29. """
  30. 执行图片拼接
  31. Args:
  32. state: 包含split_pages的状态
  33. Returns:
  34. 包含book_image的更新字典
  35. """
  36. split_pages = getattr(state, 'split_pages', None)
  37. if not split_pages:
  38. raise ValueError("State must contain 'split_pages' field with image data")
  39. if not split_pages:
  40. raise ValueError("split_pages is empty, no images to stitch")
  41. logger.info(f"开始拼接图片,共 {len(split_pages)} 页")
  42. # 按页码排序
  43. sorted_pages = sorted(split_pages, key=lambda x: x.get('page_number', 0))
  44. # 提取所有图片
  45. images = []
  46. for page in sorted_pages:
  47. image = page.get('image')
  48. if image is None:
  49. logger.warning(f"页码 {page.get('page_number')} 的图片为空,跳过")
  50. continue
  51. if not isinstance(image, Image.Image):
  52. logger.warning(f"页码 {page.get('page_number')} 的图片类型不正确: {type(image)},跳过")
  53. continue
  54. images.append(image)
  55. if not images:
  56. raise ValueError("没有有效的图片可以拼接")
  57. logger.info(f"有效图片数量: {len(images)}")
  58. # 计算拼接后图片的尺寸
  59. # 宽度取所有图片的最大宽度
  60. max_width = max(img.width for img in images)
  61. # 高度为所有图片高度之和
  62. total_height = sum(img.height for img in images)
  63. logger.info(f"拼接后图片尺寸: {max_width}x{total_height}")
  64. # 创建新的空白图片
  65. stitched_image = Image.new('RGB', (max_width, total_height), color='white')
  66. # 垂直拼接所有图片
  67. current_y = 0
  68. for idx, img in enumerate(images):
  69. # 如果图片宽度小于最大宽度,将其居中放置
  70. x_offset = (max_width - img.width) // 2
  71. # 将图片粘贴到目标位置
  72. stitched_image.paste(img, (x_offset, current_y))
  73. # 更新当前y坐标
  74. current_y += img.height
  75. logger.debug(f"已拼接第 {idx + 1}/{len(images)} 页,当前高度: {current_y}")
  76. logger.info(f"图片拼接完成,最终尺寸: {stitched_image.size}")
  77. # 将合成后的图片进行压缩
  78. image_util = ImageUtil()
  79. # 检查像素数量是否超过Pillow安全限制
  80. max_pixels = Image.MAX_IMAGE_PIXELS
  81. total_pixels = stitched_image.width * stitched_image.height
  82. if total_pixels > max_pixels:
  83. logger.warning(f"图片像素数 ({total_pixels}) 超过安全限制 ({max_pixels}),进行缩放处理")
  84. # 计算缩放比例,将像素数降到安全限制的80%
  85. target_pixels = max_pixels * 0.8
  86. scale_ratio = (target_pixels / total_pixels) ** 0.5
  87. new_width = int(stitched_image.width * scale_ratio)
  88. new_height = int(stitched_image.height * scale_ratio)
  89. stitched_image = stitched_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
  90. logger.info(f"图片已缩放至: {stitched_image.size}")
  91. image_stream = BytesIO()
  92. stitched_image.save(image_stream, format='JPEG')
  93. image_stream.seek(0)
  94. compressed_bytes = image_util._compress_image_to_bytes(image_stream)
  95. compressed_image = Image.open(BytesIO(compressed_bytes))
  96. return {
  97. "book_image": compressed_image
  98. }