| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- """
- 图片拼接节点
- 将拆分的PDF页面图片按页码顺序拼接成一张长图。
- """
- from typing import Dict, Any
- from io import BytesIO
- from PIL import Image
- from src.datasets.parser.core.base import BaseNode, BaseState
- from src.datasets.parser.core.registry import register_node
- from src.utils.file.image_util import ImageUtil
- from src.common.logging_config import get_logger
# Module-level logger obtained from the project's logging configuration helper,
# named after this module.
logger = get_logger(__name__)
@register_node()
class PictureStitchingNode(BaseNode):
    """
    Picture stitching node.

    Stitches the split PDF page images vertically, in page-number order,
    into one long image, then compresses it.

    Required state fields:
        - split_pages: list of split pages; each element is a mapping with:
            - page_number: page number (sort key; missing or None sorts as 0)
            - image: PIL image object

    Updated state fields:
        - book_image: the stitched and compressed book image
    """

    # libjpeg cannot encode images wider or taller than this many pixels;
    # Pillow fails when saving a larger image as JPEG.
    _JPEG_MAX_DIMENSION = 65500

    # When the stitched image exceeds Image.MAX_IMAGE_PIXELS, shrink it to this
    # fraction of the limit so re-opening the compressed bytes stays safely under it.
    _PIXEL_LIMIT_SAFETY_RATIO = 0.8

    @property
    def name(self) -> str:
        """Registry name of this node."""
        return "picture_stitching"

    def execute(self, state: BaseState) -> Dict[str, Any]:
        """
        Stitch the page images and return them as a single compressed image.

        Args:
            state: state object carrying ``split_pages``.

        Returns:
            ``{"book_image": <PIL image>}`` holding the compressed stitched image.

        Raises:
            ValueError: if ``split_pages`` is missing, empty, or contains no
                valid PIL images.
        """
        split_pages = getattr(state, 'split_pages', None)

        # Distinguish "field absent" from "field present but empty"; previously
        # both cases hit the first check and the second was unreachable.
        if split_pages is None:
            raise ValueError("State must contain 'split_pages' field with image data")

        if not split_pages:
            raise ValueError("split_pages is empty, no images to stitch")

        logger.info(f"开始拼接图片,共 {len(split_pages)} 页")

        images = self._collect_images(split_pages)
        if not images:
            raise ValueError("没有有效的图片可以拼接")

        logger.info(f"有效图片数量: {len(images)}")

        stitched_image = self._stitch_vertically(images)
        logger.info(f"图片拼接完成,最终尺寸: {stitched_image.size}")

        # Compress the stitched image.
        image_util = ImageUtil()

        # Shrink first so the JPEG encoder and Image.open below both accept it.
        stitched_image = self._fit_within_limits(stitched_image)

        image_stream = BytesIO()
        stitched_image.save(image_stream, format='JPEG')
        image_stream.seek(0)
        compressed_bytes = image_util._compress_image_to_bytes(image_stream)
        compressed_image = Image.open(BytesIO(compressed_bytes))

        return {
            "book_image": compressed_image
        }

    @staticmethod
    def _collect_images(split_pages: Any) -> list:
        """Return the valid PIL images from ``split_pages``, ordered by page number.

        Pages whose ``image`` is None or not a ``PIL.Image.Image`` are logged and
        skipped. A missing or None ``page_number`` sorts as 0 (sorting is stable).
        """
        def page_key(page):
            page_number = page.get('page_number')
            # An explicit None would otherwise be compared against ints and raise TypeError.
            return 0 if page_number is None else page_number

        images = []
        for page in sorted(split_pages, key=page_key):
            image = page.get('image')
            if image is None:
                logger.warning(f"页码 {page.get('page_number')} 的图片为空,跳过")
                continue
            if not isinstance(image, Image.Image):
                logger.warning(f"页码 {page.get('page_number')} 的图片类型不正确: {type(image)},跳过")
                continue
            images.append(image)
        return images

    @staticmethod
    def _stitch_vertically(images: list) -> Image.Image:
        """Paste ``images`` top to bottom onto a white RGB canvas.

        The canvas is as wide as the widest image and as tall as all images
        combined; narrower images are centred horizontally.
        """
        max_width = max(img.width for img in images)
        total_height = sum(img.height for img in images)

        logger.info(f"拼接后图片尺寸: {max_width}x{total_height}")

        stitched_image = Image.new('RGB', (max_width, total_height), color='white')

        current_y = 0
        for idx, img in enumerate(images, start=1):
            x_offset = (max_width - img.width) // 2

            if img.mode in ('RGBA', 'LA') or 'transparency' in img.info:
                # Use the alpha channel as the paste mask so transparent regions
                # show the white background rather than the underlying colour data.
                rgba = img.convert('RGBA')
                stitched_image.paste(rgba, (x_offset, current_y), rgba)
            else:
                stitched_image.paste(img, (x_offset, current_y))

            current_y += img.height
            logger.debug(f"已拼接第 {idx}/{len(images)} 页,当前高度: {current_y}")

        return stitched_image

    @classmethod
    def _fit_within_limits(cls, image: Image.Image) -> Image.Image:
        """Downscale ``image`` (aspect ratio preserved) so it can be JPEG-encoded and re-opened.

        Two limits are enforced:
          - Pillow's decompression-bomb threshold ``Image.MAX_IMAGE_PIXELS``.
            It is skipped when set to None (limit disabled). When exceeded, the
            image is shrunk to ``_PIXEL_LIMIT_SAFETY_RATIO`` of the threshold.
          - libjpeg's maximum of ``_JPEG_MAX_DIMENSION`` pixels per side.

        Returns the original image unchanged when no limit is exceeded.
        """
        scale_ratio = 1.0

        max_pixels = Image.MAX_IMAGE_PIXELS
        total_pixels = image.width * image.height
        if max_pixels is not None and total_pixels > max_pixels:
            logger.warning(f"图片像素数 ({total_pixels}) 超过安全限制 ({max_pixels}),进行缩放处理")
            target_pixels = max_pixels * cls._PIXEL_LIMIT_SAFETY_RATIO
            # Area scales with the square of the linear factor.
            scale_ratio = (target_pixels / total_pixels) ** 0.5

        longest_side = max(image.width, image.height)
        if longest_side * scale_ratio > cls._JPEG_MAX_DIMENSION:
            logger.warning(f"图片边长 ({longest_side}) 超过JPEG最大尺寸 ({cls._JPEG_MAX_DIMENSION}),进行缩放处理")
            scale_ratio = cls._JPEG_MAX_DIMENSION / longest_side

        if scale_ratio >= 1.0:
            return image

        # max(1, ...) keeps extremely narrow/short images from collapsing to 0 px.
        new_width = max(1, int(image.width * scale_ratio))
        new_height = max(1, int(image.height * scale_ratio))
        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
        logger.info(f"图片已缩放至: {image.size}")
        return image
|