# Uploaded by ljsabc in commit 395d300 ("Initial commit."); original file size 8.2 kB.
import numpy as np
from random import choice as rchoice
from random import randint
import random
import cv2, traceback, imageio
import os.path as osp
from typing import Optional, List, Union, Tuple, Dict
from utils.io_utils import imread_nogrey_rgb, json2dict
from .transforms import rotate_image
from utils.logger import LOGGER
class NameSampler:
    """Sample names according to a ``{name: probability}`` table.

    The table is expanded once into a lookup list of roughly ``sample_num``
    ids, so each draw is a single ``random.choice``. When the probabilities
    sum to less than 1, the remaining mass is assigned to the placeholder
    name ``'_'``.
    """

    # Tolerance for floating-point error when summing probabilities
    # (e.g. 0.1 + 0.2 + 0.7 need not add up to exactly 1.0).
    _PROB_EPS = 1e-6

    def __init__(self, name_prob_dict, sample_num=2048) -> None:
        """
        Args:
            name_prob_dict: mapping from name to its sampling probability.
                The probabilities must sum to at most 1.
            sample_num: size of the lookup table; larger values approximate
                the requested probabilities more closely.

        Raises:
            ValueError: if the probabilities sum to more than 1.
        """
        self.name_prob_dict = name_prob_dict
        self._id2name = list(name_prob_dict.keys())
        self.sample_ids = []

        total_prob = 0.
        for ii, prob in enumerate(name_prob_dict.values()):
            tgt_num = int(prob * sample_num)
            total_prob += prob
            if tgt_num > 0:
                self.sample_ids += [ii] * tgt_num

        # Validation and padding are driven by the *total* probability mass,
        # not by any single entry.
        if total_prob > 1 + self._PROB_EPS:
            raise ValueError(f'name probabilities sum to {total_prob}, which exceeds 1')
        nsamples = len(self.sample_ids)
        if total_prob < 1 - self._PROB_EPS and nsamples < sample_num:
            # The unassigned probability mass maps to the placeholder name '_'.
            self.sample_ids += [len(self._id2name)] * (sample_num - nsamples)
            self._id2name.append('_')

    def sample(self) -> str:
        """Return a randomly drawn name ('_' when the draw falls in unassigned mass)."""
        return self._id2name[rchoice(self.sample_ids)]
class PossionSampler:
    """Draw integers from a Poisson distribution restricted to [min_val, max_val].

    1024 values are pre-drawn. Any value outside the inclusive range is
    replaced with a uniform integer draw from that same inclusive range.
    """

    def __init__(self, lam=3, min_val=1, max_val=8) -> None:
        """
        Args:
            lam: Poisson rate parameter.
            min_val: smallest value that may be sampled (inclusive).
            max_val: largest value that may be sampled (inclusive).
        """
        self._distr = np.random.poisson(lam, 1024)
        invalid = np.where(np.logical_or(self._distr < min_val, self._distr > max_val))
        # np.random.randint's upper bound is exclusive; +1 keeps max_val reachable,
        # matching the inclusive validity check above, and avoids a ValueError
        # when min_val == max_val.
        self._distr[invalid] = np.random.randint(min_val, max_val + 1, len(invalid[0]))

    def sample(self) -> int:
        """Return one of the pre-drawn values (a NumPy integer scalar)."""
        return rchoice(self._distr)
class NormalSampler:
    """Draw values from a truncated, optionally scaled, normal distribution.

    4096 normal samples are pre-drawn. Only those strictly inside
    (min_scale, max_scale) are kept; they are multiplied by ``scalar`` and,
    when ``to_int`` is set, truncated to ``np.int32``.
    """

    def __init__(self, loc=0.33, std=0.2, min_scale=0.15, max_scale=0.85, scalar=1, to_int=True):
        """
        Args:
            loc: mean of the normal distribution.
            std: standard deviation of the normal distribution.
            min_scale: exclusive lower bound, applied before scaling.
            max_scale: exclusive upper bound, applied before scaling.
            scalar: factor applied to every kept sample.
            to_int: cast the scaled samples to ``np.int32`` (truncating).
        """
        raw = np.random.normal(loc, std, 4096)
        keep = (raw > min_scale) & (raw < max_scale)
        scaled = raw[keep] * scalar
        self._distr = scaled.astype(np.int32) if to_int else scaled

    def sample(self) -> int:
        """Return one of the pre-drawn values."""
        return rchoice(self._distr)
class PersonBBoxSampler:
    """Sample person bounding-box layouts and fit foreground images into them.

    Box layouts are loaded from JSON files. ``sample`` scales a random layout
    onto a square canvas of side ``tgt_size``; ``sample_matchfg`` loads one
    foreground per box, resizes it toward the box and centres it on the box.
    """

    def __init__(self, sample_path: Union[str, List]='data/cocoperson_bbox_samples.json', fg_info_list: Optional[List] = None, fg_transform=None, is_train=True) -> None:
        """
        Args:
            sample_path: one JSON path or a list of paths. Each file is read
                with ``json2dict`` and iterated; each item is either a box list
                or a dict holding one under ``'bboxes'``. Boxes are treated as
                (x, y, w, h) rows; their coordinates are presumably normalised
                to the canvas size (``sample`` multiplies them by tgt_size) —
                TODO confirm against the data files.
            fg_info_list: foreground descriptors passed to ``random_load_nfg``.
            fg_transform: optional callable invoked as
                ``fg_transform(image=img)['image']`` on every resized foreground.
            is_train: selects the random-rotation range used in
                ``sample_matchfg`` (20 if True, else 15).
        """
        if isinstance(sample_path, str):
            sample_path = [sample_path]
        self.bbox_list = []
        for sp in sample_path:
            bboxlist = json2dict(sp)
            for bboxes in bboxlist:
                if isinstance(bboxes, dict):
                    bboxes = bboxes['bboxes']
                bboxes = np.array(bboxes)
                # Translate the layout so the minimum x and minimum y are both 0.
                bboxes[:, [0, 1]] -= bboxes[:, [0, 1]].min(axis=0)
                self.bbox_list.append(bboxes)
        self.fg_info_list = fg_info_list
        self.fg_transform = fg_transform
        self.is_train = is_train

    def sample(self, tgt_size: int, scale_range=(1, 1), size_thres=(0.02, 0.85)) -> List[np.ndarray]:
        """Return a random box layout scaled onto a ``tgt_size`` x ``tgt_size`` canvas.

        Args:
            tgt_size: canvas side length in pixels.
            scale_range: range of an extra scale factor, drawn uniformly unless
                the range is exactly (1, 1).
            size_thres: (min, max) fractions of ``tgt_size``. A box is kept only
                if its canvas-clipped shorter side is above ``min * tgt_size``
                and its clipped longer side is below ``max * tgt_size``.

        Returns:
            List of int32 (x, y, w, h) boxes. The boxes themselves are not
            clipped, so they may extend past the right/bottom canvas edge.
        """
        bboxes_normalized = rchoice(self.bbox_list)
        if scale_range[0] != 1 or scale_range[1] != 1:
            bbox_scale = random.uniform(scale_range[0], scale_range[1])
        else:
            bbox_scale = 1
        bboxes = (bboxes_normalized * tgt_size * bbox_scale).astype(np.int32)
        # Convert a copy to xyxy to find the layout's right/bottom extent.
        xyxy_array = np.copy(bboxes)
        xyxy_array[:, [2, 3]] += xyxy_array[:, [0, 1]]
        x_max, y_max = xyxy_array[:, 2].max(), xyxy_array[:, 3].max()
        # Randomly translate the whole layout within the room left on the canvas.
        x_shift = tgt_size - x_max
        x_shift = randint(0, x_shift) if x_shift > 0 else 0
        y_shift = tgt_size - y_max
        y_shift = randint(0, y_shift) if y_shift > 0 else 0
        bboxes[:, [0, 1]] += [x_shift, y_shift]
        valid_bboxes = []
        max_size = size_thres[1] * tgt_size
        min_size = size_thres[0] * tgt_size
        for bbox in bboxes:
            # Size of the part of the box that lies inside the canvas.
            w = min(bbox[2], tgt_size - bbox[0])
            h = min(bbox[3], tgt_size - bbox[1])
            if max(h, w) < max_size and min(h, w) > min_size:
                valid_bboxes.append(bbox)
        return valid_bboxes

    def sample_matchfg(self, tgt_size: int) -> List[Dict]:
        """Load foregrounds and resize/position them to fit a sampled box layout.

        Repeatedly calls ``sample`` with an extra 1.1-1.8x scale until at least
        one box survives. NOTE(review): this retry loop has no limit, so it
        never returns if no layout can yield a valid box for this tgt_size.
        Boxes and foregrounds are paired by width/height rank.

        Returns:
            The dicts from ``random_load_nfg`` with ``'image'`` replaced by the
            resized (and optionally transformed) image and a new ``'pos'`` key
            holding the integer top-left (x, y) that centres the image on its
            box. Foregrounds whose longer side exceeds ``int(0.55 * tgt_size)``
            come first; order within each group is random.
        """
        while True:
            bboxes = self.sample(tgt_size, (1.1, 1.8))
            if len(bboxes) > 0:
                break
        MIN_FG_SIZE = 20
        num_fg = len(bboxes)
        rotate = 20 if self.is_train else 15
        fgs = random_load_nfg(num_fg, self.fg_info_list, random_rotate_prob=0.33, random_rotate=rotate)
        assert len(fgs) == num_fg
        # Sort both lists by width/height so the k-th narrowest box receives
        # the k-th narrowest foreground.
        bboxes.sort(key=lambda x: x[2] / x[3])
        fgs.sort(key=lambda x: x['asp_ratio'])
        for fg, bbox in zip(fgs, bboxes):
            x, y, w, h = bbox
            img = fg['image']
            im_h, im_w = img.shape[:2]
            if im_h < h and im_w < w:
                # Smaller than the box on both axes: upscale by the largest
                # aspect-preserving factor that still fits the box.
                scale = min(h / im_h, w / im_w)
                new_h, new_w = int(scale * im_h), int(scale * im_w)
                img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
            else:
                # Otherwise shrink by the mean of the per-axis ratios (each capped
                # at 1); resized sides are floored at MIN_FG_SIZE pixels.
                scale_h, scale_w = min(1, h / im_h), min(1, w / im_w)
                scale = (scale_h + scale_w) / 2
                if scale < 1:
                    new_h, new_w = max(int(scale * im_h), MIN_FG_SIZE), max(int(scale * im_w), MIN_FG_SIZE)
                    img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
            if self.fg_transform is not None:
                img = self.fg_transform(image=img)['image']
            # Re-read the size after resize/transform so the centring below uses
            # the final image dimensions.
            im_h, im_w = img.shape[:2]
            fg['image'] = img
            # Top-left position that centres the image on the box centre.
            px = int(x + w / 2 - im_w / 2)
            py = int(y + h / 2 - im_h / 2)
            fg['pos'] = (px, py)
        random.shuffle(fgs)
        slist, llist = [], []
        large_size = int(tgt_size * 0.55)
        for fg in fgs:
            if max(fg['image'].shape[:2]) > large_size:
                llist.append(fg)
            else:
                slist.append(fg)
        return llist + slist
def random_load_nfg(num_fg: int, fg_info_list: List[Union[Dict, str]], random_rotate=0, random_rotate_prob=0.):
    """Load ``num_fg`` foreground records, possibly repeating an image.

    Each freshly loaded image is rotated by a random integer angle in
    [-random_rotate, random_rotate] with probability ``random_rotate_prob``.
    After each load, a copy of the same record is appended again for as long
    as a 12% coin flip succeeds and the quota is not yet met.

    Returns:
        A list of ``num_fg`` dicts (empty if num_fg <= 0) with keys
        ``'image'``, ``'asp_ratio'`` (width / height) and ``'fginfo'``.
        Repeated entries are distinct dicts that share the same image array
        and info dict.
    """
    result = []
    while len(result) < num_fg:
        image, info = random_load_valid_fg(fg_info_list)
        if random.random() < random_rotate_prob:
            angle = randint(-random_rotate, random_rotate)
            image = rotate_image(image, angle, alpha_crop=True)
        ratio = image.shape[1] / image.shape[0]

        record = {'image': image, 'asp_ratio': ratio, 'fginfo': info}
        result.append(record)
        # Shallow copies: a new dict per repeat, same underlying values.
        while len(result) < num_fg and random.random() < 0.12:
            result.append(dict(record))
    return result
def random_load_valid_fg(fg_info_list: List[Union[Dict, str]]) -> Tuple[np.ndarray, Dict]:
    """Pick a random readable RGBA foreground and crop it to its visible content.

    Entries that cannot be read, lack a 4th (alpha) channel, or are fully
    transparent are logged and removed from ``fg_info_list`` in place; another
    entry is then drawn.

    Args:
        fg_info_list: foreground descriptors. Each entry is indexed like a dict
            with a ``'file_path'`` key and optional ``'root_dir'`` and cached
            ``'xyxy'`` keys.

    Returns:
        ``(fg, fginfo)``: the cropped image and its descriptor. If the
        descriptor had no ``'xyxy'``, the computed crop box is stored in it.

    Raises:
        IndexError: if ``fg_info_list`` is, or becomes, empty.
    """
    while True:
        if not fg_info_list:
            raise IndexError('no valid foreground image left in fg_info_list')
        fginfo = rchoice(fg_info_list)
        file_path = fginfo['file_path']
        if fginfo.get('root_dir'):
            file_path = osp.join(fginfo['root_dir'], file_path)
        try:
            fg = imageio.imread(file_path)
        except Exception:
            LOGGER.error(traceback.format_exc())
            LOGGER.error(f'invalid fg: {file_path}')
            fg_info_list.remove(fginfo)
            continue

        c = fg.shape[-1] if len(fg.shape) == 3 else 1
        if c != 4:
            LOGGER.warning(f'fg {file_path} doesnt have alpha channel')
            fg_info_list.remove(fginfo)
            continue

        if 'xyxy' in fginfo:
            x1, y1, x2, y2 = fginfo['xyxy']
        else:
            oh, ow = fg.shape[:2]
            ksize = 5
            # Blur then threshold the alpha channel, presumably so isolated
            # low-alpha pixels do not enlarge the bounding box.
            mask = cv2.blur(fg[..., 3], (ksize, ksize))
            _, mask = cv2.threshold(mask, 20, 255, cv2.THRESH_BINARY)
            nonzero = cv2.findNonZero(mask)
            if nonzero is None:
                # Nothing visible: cv2.boundingRect(None) would raise, so drop
                # this entry like any other unusable foreground.
                LOGGER.warning(f'fg {file_path} has no visible pixels')
                fg_info_list.remove(fginfo)
                continue
            x1, y1, w, h = cv2.boundingRect(nonzero)
            x2, y2 = x1 + w, y1 + h
            if not (oh - h > 15 or ow - w > 15):
                # Only crop when it trims more than 15 px along some axis.
                x1 = y1 = 0
                x2, y2 = ow, oh
            fginfo['xyxy'] = [x1, y1, x2, y2]
        fg = fg[y1: y2, x1: x2]
        return fg, fginfo
def random_load_valid_bg(bg_list: List[str]) -> np.ndarray:
    """Load a random readable background image.

    Paths that fail to load are logged and removed from ``bg_list`` in place,
    and another path is drawn.

    Args:
        bg_list: candidate background image paths.

    Returns:
        The image returned by ``imread_nogrey_rgb``.

    Raises:
        IndexError: if ``bg_list`` is, or becomes, empty.
    """
    while True:
        if not bg_list:
            raise IndexError('no valid background image left in bg_list')
        # The draw stays outside the try so that only load failures are treated
        # as bad files (and ``bgp`` is always bound in the handler).
        bgp = rchoice(bg_list)
        try:
            return imread_nogrey_rgb(bgp)
        except Exception:
            LOGGER.error(traceback.format_exc())
            LOGGER.error(f'invalid bg: {bgp}')
            bg_list.remove(bgp)