tablecell-htr / augments.py
MikkoLipsanen's picture
Upload 2 files
8713ab2 verified
import torch
import random
import torchvision.transforms as T
import numpy as np
class RandAug:
    """Apply one augmentation picked at random from a fixed menu.

    Every call picks one of ``'identity'``, ``'color'``, ``'sharpness'``
    or ``'blur'`` with equal weight and stores the pick in ``self.choice``.
    It then applies that augmentation to ``img`` with freshly sampled
    strength parameters.
    """

    def __init__(self):
        # Names of the available augmentations.
        self.trans = ['identity', 'color', 'sharpness', 'blur']

    def __call__(self, img):
        self.choice = random.choices(self.trans, weights=[25] * 4)[0]
        if self.choice == 'color':
            return self._jitter_colors(img)
        if self.choice == 'sharpness':
            return self._sharpen(img)
        if self.choice == 'blur':
            return self._blur(img)
        # 'identity': hand the input back untouched.
        return img

    @staticmethod
    def _jitter_colors(img):
        """Apply ColorJitter whose jitter magnitudes are themselves random."""
        # Draw order matters for reproducibility: brightness, hue,
        # contrast, saturation.
        brightness = random.uniform(0, 0.3)
        hue = random.uniform(0, 0.5)
        contrast = random.uniform(0, 0.5)
        saturation = random.uniform(0, 0.5)
        jitter = T.ColorJitter(
            brightness=brightness,
            contrast=contrast,
            saturation=saturation,
            hue=hue,
        )
        return jitter(img)

    @staticmethod
    def _sharpen(img):
        """Always sharpen, with a factor of 1 + Exponential(1) / 2."""
        factor = 1 + np.random.exponential() / 2
        return T.RandomAdjustSharpness(factor, p=1)(img)

    @staticmethod
    def _blur(img):
        """Gaussian-blur with a random odd kernel size (1, 3 or 5)."""
        kernel_size = random.choice([1, 3, 5])
        return T.GaussianBlur(kernel_size, sigma=(0.1, 2.0))(img)
class RandRotate:
    """Randomly rotate an image together with its mask.

    Each call picks ``'identity'`` or ``'rotation'`` with equal weight and
    records the pick in ``self.choice``:

    * ``'identity'`` returns ``(img, mask)`` unchanged.
    * ``'rotation'`` draws a fresh integer angle in ``[low, high)`` degrees
      and stores it in ``self.rotation``. It then applies that same angle
      to both image and mask via ``T.functional.rotate`` with
      ``expand=False``.

    Args:
        low: Smallest rotation angle in degrees (inclusive).
        high: Upper bound for the rotation angle in degrees (exclusive,
            as with ``torch.randint``).
    """

    def __init__(self, low=0, high=180):
        self.low = low
        self.high = high
        # Augmentation options
        self.trans = ['identity', 'rotation']
        # Initial draw keeps ``self.rotation`` available before the first
        # call. It also makes an empty range (low >= high) fail here, at
        # construction, as before.
        self.rotation = self._sample_angle()

    def _sample_angle(self):
        """Draw an integer angle uniformly from ``[self.low, self.high)``."""
        return torch.randint(low=self.low, high=self.high, size=(1,)).item()

    def __call__(self, img, mask):
        self.choice = random.choices(self.trans, weights=(50, 50))[0]
        if self.choice == 'identity':
            return img, mask
        elif self.choice == 'rotation':
            # Re-sample per call. Reusing the angle drawn in __init__ would
            # rotate every sample by the same amount.
            self.rotation = self._sample_angle()
            # The same angle goes to image and mask so they stay aligned.
            # NOTE: torchvision's rotate defaults to nearest-neighbour
            # interpolation, which keeps mask values discrete.
            rotated_img = T.functional.rotate(img=img, angle=self.rotation, expand=False)
            rotated_mask = T.functional.rotate(img=mask, angle=self.rotation, expand=False)
            return rotated_img, rotated_mask