from torchvision.transforms import ( | |
Normalize, | |
Compose, | |
RandomResizedCrop, | |
InterpolationMode, | |
ToTensor, | |
Resize, | |
CenterCrop, | |
) | |
def _convert_to_rgb(image): | |
return image.convert("RGB") | |
def image_transform(
    image_size: int,
    is_train: bool,
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
):
    """Build the image preprocessing pipeline for training or evaluation.

    Args:
        image_size: Size passed to the crop/resize transforms.
        is_train: If true, apply a random resized crop (area scale 0.9-1.0).
            Otherwise resize and then center-crop deterministically.
        mean: Per-channel mean handed to ``Normalize``.
        std: Per-channel standard deviation handed to ``Normalize``.

    Returns:
        A ``Compose`` that performs these steps in order:
        the geometric crop/resize, RGB conversion, ``ToTensor``
        and ``Normalize``.
    """
    if is_train:
        # Random crop covering 90-100% of the source area, resampled bicubically.
        geometric_steps = [
            RandomResizedCrop(
                image_size,
                scale=(0.9, 1.0),
                interpolation=InterpolationMode.BICUBIC,
            ),
        ]
    else:
        # Deterministic evaluation path: scale, then take the central crop.
        geometric_steps = [
            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
        ]

    # Both pipelines end identically. Note that RGB conversion happens after
    # resampling, so the resize operates on the image in its original mode.
    shared_tail = [
        _convert_to_rgb,
        ToTensor(),
        Normalize(mean=mean, std=std),
    ]
    return Compose(geometric_steps + shared_tail)