"""Samplers for names, counts, scales and person bounding boxes, plus helpers
that randomly load foreground (RGBA) and background images from lists."""

import numpy as np
from random import choice as rchoice
from random import randint
import random
import cv2
import traceback
import imageio
import os.path as osp
from typing import Optional, List, Union, Tuple, Dict

from utils.io_utils import imread_nogrey_rgb, json2dict
from .transforms import rotate_image
from utils.logger import LOGGER


class NameSampler:
    """Draw names from a discrete distribution via a pre-expanded id table.

    Every name gets ``int(prob * sample_num)`` slots in ``sample_ids``. When
    the probabilities sum to less than 1, the remaining slots map to the
    placeholder name ``'_'``.
    """

    def __init__(self, name_prob_dict, sample_num=2048) -> None:
        """Build the lookup table.

        Args:
            name_prob_dict: mapping ``name -> probability``; each value must
                be ``<= 1``.
            sample_num: nominal size of the id table.
        """
        self.name_prob_dict = name_prob_dict
        self._id2name = list(name_prob_dict.keys())
        self.sample_ids = []

        total_prob = 0.
        for ii, prob in enumerate(name_prob_dict.values()):
            assert prob <= 1
            total_prob += prob
            tgt_num = int(prob * sample_num)
            if tgt_num > 0:
                self.sample_ids += [ii] * tgt_num

        nsamples = len(self.sample_ids)
        # Decide on the *sum* of probabilities (with a small tolerance for
        # float round-off), not on the last entry: probabilities that already
        # add up to 1 must not receive placeholder slots just because int()
        # truncation left the table a few entries short.
        has_leftover_mass = (1.0 - total_prob) > 1e-6
        # If no name earned a single slot, keep the table non-empty so that
        # sample() still returns something instead of raising.
        if (has_leftover_mass or nsamples == 0) and nsamples < sample_num:
            self.sample_ids += [len(self._id2name)] * (sample_num - nsamples)
            self._id2name.append('_')

    def sample(self) -> str:
        """Return one name (or ``'_'``) chosen uniformly from the id table."""
        return self._id2name[rchoice(self.sample_ids)]


class PossionSampler:
    """Pre-drawn pool of Poisson-distributed integers limited to a range."""

    def __init__(self, lam=3, min_val=1, max_val=8) -> None:
        """Draw 1024 Poisson(``lam``) values and replace out-of-range ones.

        Values below ``min_val`` or above ``max_val`` are replaced by
        ``np.random.randint(min_val, max_val)`` draws (upper bound exclusive).
        """
        self._distr = np.random.poisson(lam, 1024)
        # NOTE(review): the comparison operands were lost in the damaged
        # source; this is the reconstruction that matches what remained.
        invalid = np.where(np.logical_or(self._distr < min_val, self._distr > max_val))
        self._distr[invalid] = np.random.randint(min_val, max_val, len(invalid[0]))

    def sample(self) -> int:
        """Return one value picked uniformly from the pre-drawn pool."""
        return rchoice(self._distr)


class NormalSampler:
    """Pre-drawn pool of normally distributed values kept inside a range."""

    def __init__(self, loc=0.33, std=0.2, min_scale=0.15, max_scale=0.85, scalar=1, to_int = True):
        """Draw 4096 normal values and keep those strictly inside the range.

        NOTE(review): everything after the range test was missing from the
        damaged source. Scaling by ``scalar`` and the optional integer cast
        are reconstructed from the otherwise unused parameters -- confirm
        against the upstream file.
        """
        s = np.random.normal(loc, std, 4096)
        valid = np.where(np.logical_and(s > min_scale, s < max_scale))
        self._distr = s[valid] * scalar
        if to_int:
            self._distr = self._distr.astype(np.int32)

    def sample(self) -> int:
        """Return one value picked uniformly from the pre-drawn pool."""
        return rchoice(self._distr)


class PersonBBoxSampler:
    """Sample groups of ``(x, y, w, h)`` boxes and fit foregrounds to them."""

    def __init__(self, sample_path: Union[str, List]='data/cocoperson_bbox_samples.json', fg_info_list: Optional[List] = None, fg_transform=None, is_train=True) -> None:
        """Load box groups from one or several JSON files.

        Args:
            sample_path: a path or list of paths read with ``json2dict``.
                Each file yields a sequence of box groups; a group is either
                a list of boxes or a dict holding them under ``'bboxes'``.
            fg_info_list: foreground records forwarded to ``random_load_nfg``
                by ``sample_matchfg``.
            fg_transform: optional callable invoked as
                ``fg_transform(image=img)``; its ``['image']`` entry replaces
                the foreground image.
            is_train: selects the maximum rotation used by ``sample_matchfg``.
        """
        if isinstance(sample_path, str):
            sample_path = [sample_path]
        self.bbox_list = []
        for sp in sample_path:
            bboxlist = json2dict(sp)
            for bboxes in bboxlist:
                if isinstance(bboxes, dict):
                    bboxes = bboxes['bboxes']
                bboxes = np.array(bboxes)
                # Shift each group so its smallest x and y become 0.
                bboxes[:, [0, 1]] -= bboxes[:, [0, 1]].min(axis=0)
                self.bbox_list.append(bboxes)

        self.fg_info_list = fg_info_list
        self.fg_transform = fg_transform
        self.is_train = is_train

    def sample(self, tgt_size: int, scale_range=(1, 1), size_thres=(0.02, 0.85)) -> List[np.ndarray]:
        """Pick a random box group, scale and shift it into a square canvas.

        Args:
            tgt_size: canvas side length; stored boxes are multiplied by it
                (presumably they are normalised to [0, 1] -- verify against
                the sample files).
            scale_range: ``(low, high)`` for an extra uniform scale factor;
                ``(1, 1)`` disables it.
            size_thres: ``(min, max)`` fractions of ``tgt_size``; a box is
                kept only if its visible width and height both lie strictly
                between the two limits.

        Returns:
            The kept boxes as int32 ``(x, y, w, h)`` arrays; may be empty.
        """
        bboxes_normalized = rchoice(self.bbox_list)
        if scale_range[0] != 1 or scale_range[1] != 1:
            bbox_scale = random.uniform(scale_range[0], scale_range[1])
        else:
            bbox_scale = 1
        bboxes = (bboxes_normalized * tgt_size * bbox_scale).astype(np.int32)

        # Work out the group's extent in xyxy form to see how far it can be
        # shifted while staying inside the canvas.
        xyxy_array = np.copy(bboxes)
        xyxy_array[:, [2, 3]] += xyxy_array[:, [0, 1]]
        x_max, y_max = xyxy_array[:, 2].max(), xyxy_array[:, 3].max()

        x_shift = tgt_size - x_max
        x_shift = randint(0, x_shift) if x_shift > 0 else 0
        y_shift = tgt_size - y_max
        y_shift = randint(0, y_shift) if y_shift > 0 else 0

        bboxes[:, [0, 1]] += [x_shift, y_shift]

        valid_bboxes = []
        max_size = size_thres[1] * tgt_size
        min_size = size_thres[0] * tgt_size
        for bbox in bboxes:
            # Only the part of the box that lies inside the canvas counts.
            w = min(bbox[2], tgt_size - bbox[0])
            h = min(bbox[3], tgt_size - bbox[1])
            if max(h, w) < max_size and min(h, w) > min_size:
                valid_bboxes.append(bbox)
        return valid_bboxes

    def sample_matchfg(self, tgt_size: int):
        """Sample boxes and load one foreground per box, resized to fit.

        Boxes and foregrounds are paired after sorting both by aspect ratio.
        Each returned dict carries ``'image'``, ``'asp_ratio'``, ``'fginfo'``
        and ``'pos'`` (top-left corner that centres the image on its box).
        Foregrounds whose longer side exceeds 55% of ``tgt_size`` come first
        in the returned list.

        Note: retries ``sample`` until at least one box survives filtering,
        so this never returns if no stored group can pass the size limits.
        """
        while True:
            bboxes = self.sample(tgt_size, (1.1, 1.8))
            if len(bboxes) > 0:
                break

        MIN_FG_SIZE = 20
        num_fg = len(bboxes)
        rotate = 20 if self.is_train else 15
        fgs = random_load_nfg(num_fg, self.fg_info_list, random_rotate_prob=0.33, random_rotate=rotate)
        assert len(fgs) == num_fg

        bboxes.sort(key=lambda x: x[2] / x[3])
        fgs.sort(key=lambda x: x['asp_ratio'])

        for fg, bbox in zip(fgs, bboxes):
            x, y, w, h = bbox
            img = fg['image']
            im_h, im_w = img.shape[:2]
            if im_h < h and im_w < w:
                # Image smaller than the box in both dimensions: enlarge it
                # until one side reaches the box.
                scale = min(h / im_h, w / im_w)
                new_h, new_w = int(scale * im_h), int(scale * im_w)
                img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
            else:
                # Otherwise shrink by the mean of the per-axis factors, never
                # going below MIN_FG_SIZE pixels on either side.
                scale_h, scale_w = min(1, h / im_h), min(1, w / im_w)
                scale = (scale_h + scale_w) / 2
                if scale < 1:
                    new_h, new_w = max(int(scale * im_h), MIN_FG_SIZE), max(int(scale * im_w), MIN_FG_SIZE)
                    img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

            if self.fg_transform is not None:
                img = self.fg_transform(image=img)['image']

            im_h, im_w = img.shape[:2]
            fg['image'] = img
            px = int(x + w / 2 - im_w / 2)
            py = int(y + h / 2 - im_h / 2)
            fg['pos'] = (px, py)

        random.shuffle(fgs)

        slist, llist = [], []
        large_size = int(tgt_size * 0.55)
        for fg in fgs:
            if max(fg['image'].shape[:2]) > large_size:
                llist.append(fg)
            else:
                slist.append(fg)
        return llist + slist


def random_load_nfg(num_fg: int, fg_info_list: List[Union[Dict, str]], random_rotate=0, random_rotate_prob=0.):
    """Load ``num_fg`` foreground records, optionally rotated.

    Args:
        num_fg: number of records to return.
        fg_info_list: candidates passed to ``random_load_valid_fg`` (which
            may remove invalid entries from it).
        random_rotate: maximum absolute rotation in whole degrees.
        random_rotate_prob: probability of rotating a freshly loaded image.

    Returns:
        A list of dicts with keys ``'image'``, ``'asp_ratio'`` (width divided
        by height) and ``'fginfo'``. With probability 0.12 per extra slot the
        same loaded image is reused for additional entries.
    """
    fgs = []
    while len(fgs) < num_fg:
        fg, fginfo = random_load_valid_fg(fg_info_list)
        if random.random() < random_rotate_prob:
            rotate_deg = randint(-random_rotate, random_rotate)
            fg = rotate_image(fg, rotate_deg, alpha_crop=True)

        asp_ratio = fg.shape[1] / fg.shape[0]
        fgs.append({'image': fg, 'asp_ratio': asp_ratio, 'fginfo': fginfo})
        while len(fgs) < num_fg and random.random() < 0.12:
            fgs.append({'image': fg, 'asp_ratio': asp_ratio, 'fginfo': fginfo})

    return fgs


def random_load_valid_fg(fg_info_list: List[Union[Dict, str]]) -> Tuple[np.ndarray, Dict]:
    """Load a random four-channel foreground, cropped to its stored box.

    Entries that cannot be read, have no fourth channel, or are completely
    transparent are logged and removed from ``fg_info_list`` in place, then
    another entry is tried.

    When a record has no ``'xyxy'`` key, a box is derived from the blurred and
    thresholded fourth channel and cached on the record; the full image
    extent is used instead when the tight box would trim 15 px or less in
    both dimensions.

    Raises:
        IndexError: from ``random.choice`` once ``fg_info_list`` is empty.
    """
    while True:
        item = fginfo = rchoice(fg_info_list)

        file_path = fginfo['file_path']
        if 'root_dir' in fginfo and fginfo['root_dir']:
            file_path = osp.join(fginfo['root_dir'], file_path)

        try:
            fg = imageio.imread(file_path)
        except Exception:
            # Not a bare ``except`` so that Ctrl-C / SystemExit still stop
            # the retry loop.
            LOGGER.error(traceback.format_exc())
            LOGGER.error(f'invalid fg: {file_path}')
            fg_info_list.remove(item)
            continue

        c = 1
        if len(fg.shape) == 3:
            c = fg.shape[-1]
        if c != 4:
            LOGGER.warning(f'fg {file_path} doesnt have alpha channel')
            fg_info_list.remove(item)
            continue

        if 'xyxy' in fginfo:
            x1, y1, x2, y2 = fginfo['xyxy']
        else:
            oh, ow = fg.shape[:2]
            ksize = 5
            mask = cv2.blur(fg[..., 3], (ksize, ksize))
            _, mask = cv2.threshold(mask, 20, 255, cv2.THRESH_BINARY)
            nonzero = cv2.findNonZero(mask)
            if nonzero is None:
                # Nothing survives the threshold, so there is no region to
                # crop to; treat the entry like any other invalid one.
                LOGGER.warning(f'fg {file_path} is fully transparent')
                fg_info_list.remove(item)
                continue
            x1, y1, w, h = cv2.boundingRect(nonzero)
            x2, y2 = x1 + w, y1 + h
            if not (oh - h > 15 or ow - w > 15):
                # The tight box trims too little to be worth it.
                x1 = y1 = 0
                x2, y2 = ow, oh
            fginfo['xyxy'] = [x1, y1, x2, y2]

        fg = fg[y1: y2, x1: x2]
        return fg, fginfo


def random_load_valid_bg(bg_list: List[str]) -> np.ndarray:
    """Load a random background with ``imread_nogrey_rgb``.

    Paths that fail to load are logged and removed from ``bg_list`` in place,
    then another path is tried.

    Raises:
        IndexError: from ``random.choice`` once ``bg_list`` is empty.
    """
    while True:
        # Chosen outside the ``try`` so that an exhausted list surfaces as a
        # plain IndexError rather than as an error inside the handler.
        bgp = rchoice(bg_list)
        try:
            return imread_nogrey_rgb(bgp)
        except Exception:
            LOGGER.error(traceback.format_exc())
            LOGGER.error(f'invalid bg: {bgp}')
            bg_list.remove(bgp)
            continue