|
import torch |
|
import random |
|
import torchvision.transforms as T |
|
import numpy as np |
|
|
|
class RandAug:
    """Apply one randomly selected photometric augmentation to an image.

    Every call draws one name from ``self.trans`` with equal weights and
    records it in ``self.choice``:

    * ``'identity'``  -- the input is handed back untouched.
    * ``'color'``     -- ``T.ColorJitter`` with freshly drawn strengths.
    * ``'sharpness'`` -- ``T.RandomAdjustSharpness`` applied with ``p=1``.
    * ``'blur'``      -- ``T.GaussianBlur`` with a kernel size of 1, 3 or 5.
    """

    def __init__(self):
        # Candidate operations; __call__ pairs them with four equal weights.
        self.trans = ['identity', 'color', 'sharpness', 'blur']

    def __call__(self, img):
        """Return ``img`` after the augmentation drawn for this call."""
        drawn = random.choices(self.trans, weights=(25, 25, 25, 25))
        self.choice = drawn[0]

        if self.choice == 'color':
            # Dict literals evaluate in order, so the draws happen as
            # brightness, hue, contrast, saturation.
            jitter_kwargs = {
                'brightness': random.uniform(0, 0.3),
                'hue': random.uniform(0, 0.5),
                'contrast': random.uniform(0, 0.5),
                'saturation': random.uniform(0, 0.5),
            }
            return T.ColorJitter(**jitter_kwargs)(img)

        if self.choice == 'sharpness':
            # Factor is 1 plus half of an exponential draw, hence always >= 1.
            factor = 1 + (np.random.exponential() / 2)
            return T.RandomAdjustSharpness(factor, p=1)(img)

        if self.choice == 'blur':
            kernel_size = random.choice([1, 3, 5])
            return T.GaussianBlur(kernel_size, sigma=(0.1, 2.0))(img)

        # 'identity' (or any name without a branch above): nothing to do.
        return img
|
|
|
|
|
class RandRotate:
    """Rotate an image and its mask by the same angle, half of the time.

    Each call picks ``'identity'`` or ``'rotation'`` with equal weight. When
    ``'rotation'`` is picked, a fresh integer angle is drawn from
    ``[low, high)`` and applied to both inputs so they stay aligned.

    Args:
        low: Inclusive lower bound of the angle range.
        high: Exclusive upper bound of the angle range.

    Attributes:
        rotation: The most recently drawn angle. An initial value is drawn
            at construction so the attribute always exists.
        trans: Names of the candidate operations.
        choice: Name of the operation picked by the latest call.
    """

    def __init__(self, low=0, high=180):
        self.low = low
        self.high = high
        self.rotation = self._sample_angle()
        self.trans = ['identity', 'rotation']

    def _sample_angle(self):
        """Draw an integer angle uniformly from ``[self.low, self.high)``."""
        return torch.randint(low=self.low, high=self.high, size=(1,)).item()

    def __call__(self, img, mask):
        """Return ``(img, mask)``, either untouched or rotated together.

        Args:
            img: Image accepted by ``T.functional.rotate``.
            mask: Mask accepted by ``T.functional.rotate``.

        Returns:
            A 2-tuple ``(img, mask)``; both elements are rotated by the same
            angle when the rotation branch is taken.
        """
        self.choice = random.choices(self.trans, weights=(50, 50))[0]

        if self.choice == 'identity':
            return img, mask

        # Re-draw the angle on every rotating call. Drawing it only once in
        # __init__ made a long-lived instance reuse a single angle for every
        # sample it ever processed.
        self.rotation = self._sample_angle()
        rotated_img = T.functional.rotate(img=img, angle=self.rotation, expand=False)
        rotated_mask = T.functional.rotate(img=mask, angle=self.rotation, expand=False)
        return rotated_img, rotated_mask