"""
Use torchvision instead of transformers to perform resize and center crop.

This is because transformers' version is sometimes 1-pixel off.

For example, if the image size is 640x480, both results are consistent.
(e.g., "http://images.cocodataset.org/val2017/000000039769.jpg")

However, if the image size is 500x334, the following happens:
(e.g., "http://images.cocodataset.org/val2014/COCO_val2014_000000324158.jpg")

>>> # Results' shape: (h, w)
>>> torch.allclose(torchvision_result[:, :-1], transformers_result[:, 1:])
... True

Note that if only resize is performed with torchvision,
the inconsistency remains.
Therefore, center crop must also be done with torchvision.
"""

from torchvision.transforms import CenterCrop, InterpolationMode, Resize
from transformers import AutoImageProcessor, BatchFeature, CLIPImageProcessor
from transformers.image_processing_utils import get_size_dict
from transformers.image_utils import ImageInput, PILImageResampling, make_list_of_images


def PILImageResampling_to_InterpolationMode(
    resample: PILImageResampling,
) -> InterpolationMode:
    """Map a PILImageResampling member to the torchvision InterpolationMode
    member with the same name (e.g. BICUBIC -> BICUBIC)."""
    return getattr(InterpolationMode, PILImageResampling(resample).name)


class CustomCLIPImageProcessor(CLIPImageProcessor):
    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: dict[str, int] = None,
        resample: PILImageResampling = None,
        do_center_crop: bool = None,
        crop_size: int = None,
        **kwargs,
    ) -> BatchFeature:
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        resample = resample if resample is not None else self.resample
        do_center_crop = (
            do_center_crop if do_center_crop is not None else self.do_center_crop
        )
        crop_size = crop_size if crop_size is not None else self.crop_size

        images = make_list_of_images(images)

        if do_resize:
            # TODO input_data_format is ignored
            _size = get_size_dict(
                size,
                param_name="size",
                default_to_square=getattr(self, "use_square_size", False),
            )
            if set(_size) == {"shortest_edge"}:
                # Corresponds to `image_transform.transforms[0]`
                resize = Resize(
                    size=_size["shortest_edge"],
                    interpolation=PILImageResampling_to_InterpolationMode(resample),
                )
                images = [resize(image) for image in images]
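                # Mark the resize as done so that super().preprocess skips
                # the transformers resize (which can be off by one pixel).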
                do_resize = False

        if do_center_crop:
            # TODO input_data_format is ignored
            _crop_size = get_size_dict(
                crop_size, param_name="crop_size", default_to_square=True
            )
            # Corresponds to `image_transform.transforms[1]`
            center_crop = CenterCrop(
                size=tuple(map(_crop_size.get, ["height", "width"]))
            )
            images = [center_crop(image) for image in images]
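            # Mark the center crop as done so that super().preprocess skips it.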
            do_center_crop = False

        return super().preprocess(
            images=images,
            do_resize=do_resize,
            size=size,
            resample=resample,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            **kwargs,
        )


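# Register under the class name so that AutoImageProcessor.from_pretrained can
# resolve "image_processor_type": "CustomCLIPImageProcessor" from a checkpoint's
# preprocessor_config.json. (AutoImageProcessor.register normally takes a config
# class as its first argument; passing the name as a string relies on the
# name-based lookup transformers performs over registered processors.)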
AutoImageProcessor.register("CustomCLIPImageProcessor", CustomCLIPImageProcessor)
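

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint name and image URL below are
    # illustrative assumptions, not requirements of this module; any CLIP-style
    # checkpoint with a preprocessor_config.json should work the same way.
    import requests
    from PIL import Image

    processor = CustomCLIPImageProcessor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )
    url = "http://images.cocodataset.org/val2014/COCO_val2014_000000324158.jpg"
    image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

    # Resize and center crop run through torchvision; rescale and normalize
    # are still handled by CLIPImageProcessor.preprocess.
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    print(pixel_values.shape)  # expected: torch.Size([1, 3, 224, 224])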