Transformers SigLip2 processor的一个坑

3 分钟阅读时长

发布时间:


Siglip2 processor 的一个坑

使用 Hugging Face 上模型的通用流程(假设你不想用 pipeline):假如你需要推理一批数据,又不想把预处理写进 DataLoader(DataLoader 可以用多个 worker 并行做预处理),就需要自己手动组 batch:

from transformers import AutoModel, AutoProcessor
from transformers.image_utils import load_image

# SigLIP 2 checkpoints are published under the `google/` organisation on the Hub.
# The NaFlex variant is (to my knowledge) the one that is paired with the
# patch-based Siglip2 image processor discussed in this post.
MODEL_ID = "google/siglip2-base-patch16-naflex"

# Number of (image, text) pairs pushed through the model in a single call.
batch_size = 32

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModel.from_pretrained(MODEL_ID).to("cuda")

# Build the batch by hand (no DataLoader): the same image / caption repeated
# `batch_size` times. `load_image` returns a PIL image for a local path or URL.
image_batch = [load_image("image.jpg")] * batch_size
text_batch = ["A photo of a cat"] * batch_size
inputs = processor(text=text_batch, images=image_batch, return_tensors="pt").to("cuda")
outputs = model(**inputs)

假如 batch_size 比较大,理论上能提高模型推理的并行度。但实际用起来傻眼了,还是慢得要死。打印出每个阶段的耗时:

Preprocess Time: 0.16 seconds
Model Forward Time: 0.009 seconds

大量时间都花在预处理上了

扒siglip2 image processor的代码(google开源的哦)

def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        patch_size: int,
        max_num_patches: int,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess the images one by one and collect the results into a batch.

        Each image is optionally resized, rescaled/normalized, converted to
        patches and padded to `max_num_patches` patches; the per-image results
        are stacked at the end.

        Args:
            images: Images to process. Each one is indexed as
                (num_channels, height, width).
            do_resize: Whether to resize each image to the size returned by
                `get_image_size_for_max_num_patches` for that image.
            patch_size: Side length of a patch, in pixels.
            max_num_patches: Patch budget per image; also the length every
                patch sequence is padded to.
            interpolation: Interpolation mode forwarded to `self.resize`.
            do_rescale, rescale_factor, do_normalize, image_mean, image_std:
                Forwarded unchanged to `self.rescale_and_normalize`.
            return_tensors: Forwarded to `BatchFeature` as `tensor_type`.

        Returns:
            A `BatchFeature` with three entries: `pixel_values` (stacked padded
            patches), `pixel_attention_mask` (stacked masks returned by
            `pad_along_first_dim`) and `spatial_shapes` (one
            `(num_patches_height, num_patches_width)` pair per image).
        """
        pixel_masks = []
        pixel_values = []
        spatial_shapes = []

        ## A very naive for loop: the images of the batch are processed serially, one at a time
        for image in images:
            if do_resize:
                height, width = get_image_size_for_max_num_patches(
                    image_height=image.shape[1],
                    image_width=image.shape[2],
                    patch_size=patch_size,
                    max_num_patches=max_num_patches,
                )
                side_dict = SizeDict(height=height, width=width)
                image = self.resize(image=image, size=side_dict, interpolation=interpolation)

            image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)

            # (num_channels, height, width) -> (num_patches, patch_size * patch_size * num_channels)
            patches = convert_image_to_patches(image, patch_size)
            patches, mask = pad_along_first_dim(patches, max_num_patches)

            # Patch grid of the (possibly resized) image, recorded per image.
            num_patches_height = image.shape[1] // patch_size
            num_patches_width = image.shape[2] // patch_size

            spatial_shapes.append((num_patches_height, num_patches_width))
            pixel_values.append(patches)
            pixel_masks.append(mask)

        # Stacking is possible because every image was padded to max_num_patches.
        pixel_values = torch.stack(pixel_values)
        pixel_masks = torch.stack(pixel_masks)
        spatial_shapes = torch.tensor(spatial_shapes)

        batch_feature = BatchFeature(
            data={
                "pixel_values": pixel_values,
                "pixel_attention_mask": pixel_masks,
                "spatial_shapes": spatial_shapes,
            },
            tensor_type=return_tensors,
        )
        return batch_feature

用了一个非常愚蠢的 for 循环串行处理 batch 内的每张图片,而循环里无非也就是 resize、rescale/normalize、切 patch、padding 之类的操作

在外面套了一个并行,速度快了很多,但还是不优美

让 Cursor 帮忙实现了一个用 torch 对整个 batch 做批量处理的版本(假设 batch 内图像大小一致)

class Siglip2ImageProcessorFaster(Siglip2ImageProcessorFast):
    """
    Batched variant of ``Siglip2ImageProcessorFast``.

    The parent class preprocesses the images of a batch one at a time in a
    Python loop. When every image in the batch has the same shape, this class
    runs resize, rescale/normalize, patch extraction and padding once on a
    single ``(batch_size, num_channels, height, width)`` tensor. Batches with
    mixed image shapes are delegated to the parent's per-image implementation.
    """

    def __init__(self, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]):
        super().__init__(**kwargs)

    def process_image(self, image, **kwargs):
        """
        Run the parent's (per-image) ``_preprocess`` on ``image``.

        Args:
            image: A single ``(num_channels, height, width)`` tensor, or a
                list / batch of such images.
            **kwargs: The preprocessing options required by the parent's
                ``_preprocess`` (``do_resize``, ``patch_size``,
                ``max_num_patches``, ...). They are forwarded unchanged.

        Returns:
            The ``BatchFeature`` produced by the parent implementation.
        """
        # The parent iterates over its input, so a lone 3-D image has to be
        # wrapped; otherwise it would be iterated channel by channel.
        if isinstance(image, torch.Tensor) and image.ndim == 3:
            image = [image]
        return super()._preprocess(image, **kwargs)

    def _preprocess(
        self,
        images: Union["torch.Tensor", list["torch.Tensor"]],
        do_resize: bool,
        patch_size: int,
        max_num_patches: int,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess a whole batch of images with batched tensor operations.

        Args:
            images: Either a ``(batch_size, num_channels, height, width)``
                tensor or a list/tuple of ``(num_channels, height, width)``
                tensors.
            do_resize: Whether to resize to the size returned by
                ``get_image_size_for_max_num_patches``.
            patch_size: Side length of a patch, in pixels.
            max_num_patches: Length every patch sequence is padded to.
            interpolation: Interpolation mode forwarded to ``self.resize``.
            do_rescale, rescale_factor, do_normalize, image_mean, image_std:
                Forwarded unchanged to ``self.rescale_and_normalize``.
            return_tensors: Forwarded to ``BatchFeature`` as ``tensor_type``.

        Returns:
            A ``BatchFeature`` with ``pixel_values``, ``pixel_attention_mask``
            and ``spatial_shapes``.
        """
        # Accept a list/tuple of images for compatibility with the parent interface.
        if isinstance(images, (list, tuple)):
            if len({tuple(image.shape) for image in images}) > 1:
                # Images of different shapes cannot be stacked into one tensor;
                # use the parent's per-image path instead of failing in torch.stack.
                return super()._preprocess(
                    images,
                    do_resize=do_resize,
                    patch_size=patch_size,
                    max_num_patches=max_num_patches,
                    interpolation=interpolation,
                    do_rescale=do_rescale,
                    rescale_factor=rescale_factor,
                    do_normalize=do_normalize,
                    image_mean=image_mean,
                    image_std=image_std,
                    return_tensors=return_tensors,
                    **kwargs,
                )
            images = torch.stack(images)

        # From here on the input is a (batch_size, num_channels, height, width) tensor.
        batch_size, num_channels, image_height, image_width = images.shape

        if do_resize:
            # All images share one shape, so the target size is computed only once.
            height, width = get_image_size_for_max_num_patches(
                image_height=image_height,
                image_width=image_width,
                patch_size=patch_size,
                max_num_patches=max_num_patches,
            )
            side_dict = SizeDict(height=height, width=width)
            # Batched resize: the whole 4-D tensor is resized in one call.
            images = self.resize(
                image=images, size=side_dict, interpolation=interpolation
            )

        # Batched rescale and normalize.
        images = self.rescale_and_normalize(
            images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
        )

        # (batch_size, num_channels, height, width)
        #   -> (batch_size, num_patches, patch_size * patch_size * num_channels)
        patches = self._convert_batch_to_patches(images, patch_size)

        # (batch_size, num_patches, ...) -> (batch_size, max_num_patches, ...)
        patches, pixel_masks = self._pad_batch_along_patch_dim(patches, max_num_patches)

        # Every image has the same patch grid, so one pair is repeated per image.
        # NOTE(review): the parent's per-image path builds this tensor without
        # an explicit dtype; int32 is kept here as in the original batched code.
        num_patches_height = images.shape[2] // patch_size
        num_patches_width = images.shape[3] // patch_size
        spatial_shapes = torch.tensor(
            [(num_patches_height, num_patches_width)] * batch_size, dtype=torch.int32
        )

        batch_feature = BatchFeature(
            data={
                "pixel_values": patches,
                "pixel_attention_mask": pixel_masks,
                "spatial_shapes": spatial_shapes,
            },
            tensor_type=return_tensors,
        )
        return batch_feature

    def _convert_batch_to_patches(
        self, images: "torch.Tensor", patch_size: int
    ) -> "torch.Tensor":
        """
        Convert a batch of images into flattened patches.

        Args:
            images: Tensor of shape (batch_size, num_channels, height, width).
                Height and width must be multiples of ``patch_size`` for the
                reshape below to succeed.
            patch_size: Side length of a patch, in pixels.

        Returns:
            Tensor of shape (batch_size, num_patches_height * num_patches_width,
            patch_size * patch_size * num_channels).
        """
        batch_size, num_channels, image_height, image_width = images.shape
        num_patches_height = image_height // patch_size
        num_patches_width = image_width // patch_size

        # (batch_size, num_channels, num_patches_height, patch_size, num_patches_width, patch_size)
        patched_images = images.reshape(
            batch_size,
            num_channels,
            num_patches_height,
            patch_size,
            num_patches_width,
            patch_size,
        )
        # (batch_size, num_patches_height, num_patches_width, patch_size, patch_size, num_channels)
        patched_images = patched_images.permute(0, 2, 4, 3, 5, 1)
        # (batch_size, num_patches_height * num_patches_width, patch_size * patch_size * num_channels)
        patched_images = patched_images.reshape(
            batch_size, num_patches_height * num_patches_width, -1
        )
        return patched_images

    def _pad_batch_along_patch_dim(
        self, patches: "torch.Tensor", target_length: int, pad_value: int = 0
    ) -> tuple["torch.Tensor", "torch.Tensor"]:
        """
        Pad a batch of patch sequences along the patch dimension (dim 1).

        Args:
            patches: Tensor of shape (batch_size, num_patches, patch_features).
            target_length: Desired number of patches (``max_num_patches``).
            pad_value: Value used to fill the padded patches.

        Returns:
            A ``(padded_patches, masks)`` tuple.
            - padded_patches: (batch_size, target_length, patch_features) when
              padding was needed; otherwise ``patches`` is returned unchanged.
            - masks: int32 tensor of shape (batch_size, target_length) on the
              same device as ``patches``; 1 marks a real patch, 0 marks padding.
        """
        batch_size, current_length, *patch_dims = patches.shape
        padding_length = target_length - current_length

        masks = torch.ones(
            (batch_size, target_length), dtype=torch.int32, device=patches.device
        )
        if padding_length <= 0:
            # Nothing to pad: every position of the mask stays valid.
            return patches, masks

        # F.pad consumes (left, right) pairs starting from the LAST dimension,
        # so the trailing feature dims get (0, 0) and the patch dimension gets
        # `padding_length` entries appended at its end.
        padding = [0, 0] * len(patch_dims) + [0, padding_length]
        padded_patches = torch.nn.functional.pad(
            patches, padding, mode="constant", value=pad_value
        )
        masks[:, -padding_length:] = 0

        return padded_patches, masks

速度终于正常了(预处理的时间不到推理的时间的1/8)

阅读量: 加载中...