""" Use torchvision instead of transformers to perform resize and center crop. This is because transformers' version is sometimes 1-pixel off. For example, if the image size is 640x480, both results are consistent. (e.g., "http://images.cocodataset.org/val2017/000000039769.jpg") However, if the image size is 500x334, the following happens: (e.g., "http://images.cocodataset.org/val2014/COCO_val2014_000000324158.jpg") >>> # Results' shape: (h, w) >>> torch.allclose(torchvision_result[:, :-1], transformers_result[:, 1:]) ... True Note that if only resize is performed with torchvision, the inconsistency remains. Therefore, center crop must also be done with torchvision. """ import PIL from torchvision.transforms import CenterCrop, InterpolationMode, Resize from transformers import AutoImageProcessor, CLIPImageProcessor from transformers.image_processing_utils import get_size_dict from transformers.image_utils import ImageInput, PILImageResampling, make_list_of_images def PILImageResampling_to_InterpolationMode( resample: PILImageResampling, ) -> InterpolationMode: return getattr(InterpolationMode, PILImageResampling(resample).name) class CustomCLIPImageProcessor(CLIPImageProcessor): def preprocess( self, images: ImageInput, do_resize: bool = None, size: dict[str, int] = None, resample: PILImageResampling = None, do_center_crop: bool = None, crop_size: int = None, **kwargs, ) -> PIL.Image.Image: do_resize = do_resize if do_resize is not None else self.do_resize size = size if size is not None else self.size resample = resample if resample is not None else self.resample do_center_crop = ( do_center_crop if do_center_crop is not None else self.do_center_crop ) crop_size = crop_size if crop_size is not None else self.crop_size images = make_list_of_images(images) if do_resize: # TODO input_data_format is ignored _size = get_size_dict( size, param_name="size", default_to_square=getattr(self, "use_square_size", False), ) if set(_size) == {"shortest_edge"}: # Corresponds to `image_transform.transforms[0]` resize = Resize( size=_size["shortest_edge"], interpolation=PILImageResampling_to_InterpolationMode(resample), ) images = [resize(image) for image in images] do_resize = False if do_center_crop: # TODO input_data_format is ignored _crop_size = get_size_dict( crop_size, param_name="crop_size", default_to_square=True ) # Corresponds to `image_transform.transforms[1]` center_crop = CenterCrop( size=tuple(map(_crop_size.get, ["height", "width"])) ) images = [center_crop(image) for image in images] do_center_crop = False return super().preprocess( images=images, do_resize=do_resize, size=size, resample=resample, do_center_crop=do_center_crop, crop_size=crop_size, **kwargs, ) AutoImageProcessor.register("CustomCLIPImageProcessor", CustomCLIPImageProcessor)