sczhou's picture
init code
320e465
raw
history blame
9.66 kB
from typing import Tuple, Optional, Dict
import logging
import os
import shutil
from os import path
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
import pycocotools.mask as mask_util
from threading import Thread
from queue import Queue
from dataclasses import dataclass
import copy
from tracker.utils.pano_utils import ID2RGBConverter
from tracker.utils.palette import davis_palette_np
from tracker.inference.object_manager import ObjectManager
from tracker.inference.object_info import ObjectInfo
log = logging.getLogger()
try:
import hickle as hkl
except ImportError:
log.warning('Failed to import hickle. Fine if not using multi-scale testing.')
class ResultSaver:
    """Asynchronously saves segmentation results (masks, scores, visualizations).

    Disk I/O is offloaded to a background daemon thread fed through a bounded
    queue so that inference is not blocked by saving.
    """

    def __init__(self,
                 output_root,
                 video_name,
                 *,
                 dataset,
                 object_manager: ObjectManager,
                 use_long_id,
                 palette=None,
                 save_mask=True,
                 save_scores=False,
                 score_output_root=None,
                 visualize_output_root=None,
                 visualize=False,
                 init_json=None):
        self.output_root = output_root
        self.video_name = video_name
        self.dataset = dataset.lower()
        self.use_long_id = use_long_id
        self.palette = palette
        self.object_manager = object_manager
        self.save_mask = save_mask
        self.save_scores = save_scores
        self.score_output_root = score_output_root
        self.visualize_output_root = visualize_output_root
        self.visualize = visualize

        if self.visualize:
            # Colors used to render overlays; fall back to the DAVIS palette.
            if self.palette is not None:
                self.colors = np.array(self.palette, dtype=np.uint8).reshape(-1, 3)
            else:
                self.colors = davis_palette_np

        self.need_remapping = True
        self.json_style = None
        self.id2rgb_converter = ID2RGBConverter()

        if 'burst' in self.dataset:
            # BURST-style datasets carry per-frame RLE segmentations in a json.
            # Raise explicitly instead of asserting (asserts vanish under -O).
            if init_json is None:
                raise ValueError('init_json is required for BURST-style datasets')
            self.input_segmentations = init_json['segmentations']
            self.segmentations = [{} for _ in init_json['segmentations']]
            self.annotated_frames = init_json['annotated_image_paths']
            # Copy all metadata but replace the segmentations with our outputs.
            self.video_json = {k: v for k, v in init_json.items() if k != 'segmentations'}
            self.video_json['segmentations'] = self.segmentations
            self.json_style = 'burst'

        # Bounded queue: applies back-pressure if saving falls behind inference.
        self.queue = Queue(maxsize=10)
        self.thread = Thread(target=save_result, args=(self.queue, ))
        self.thread.daemon = True
        self.thread.start()

    def process(self,
                prob: torch.Tensor,
                frame_name: str,
                resize_needed: bool = False,
                shape: Optional[Tuple[int, int]] = None,
                last_frame: bool = False,
                path_to_image: Optional[str] = None):
        """Convert a probability map to an index mask and enqueue it for saving.

        Args:
            prob: (num_objects, H, W) probability tensor; argmax over dim 0
                yields the index mask. -- assumes channel 0 is background
            frame_name: file name of the frame; its 3-char extension is
                replaced when saving.
            resize_needed: if True, bilinearly resize `prob` to `shape` first.
            shape: (H, W) target size, used only when resize_needed is True.
            last_frame: signals the worker to also dump the id mapping when
                saving scores.
            path_to_image: source image path; required when visualize is on.
        """
        if resize_needed:
            prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear',
                                 align_corners=False)[:, 0]
        # Probability mask -> index mask
        mask = torch.argmax(prob, dim=0)

        if self.save_scores:
            # also need to pass prob
            prob = prob.cpu()
        else:
            prob = None

        # Remap temporary (contiguous) ids to the objects' true ids.
        if self.need_remapping:
            new_mask = torch.zeros_like(mask)
            for tmp_id, obj in self.object_manager.tmp_id_to_obj.items():
                new_mask[mask == tmp_id] = obj.id
            mask = new_mask

        # Deep-copy the mappings so the worker thread sees a stable snapshot.
        args = ResultArgs(saver=self,
                          prob=prob,
                          mask=mask.cpu(),
                          frame_name=frame_name,
                          path_to_image=path_to_image,
                          tmp_id_to_obj=copy.deepcopy(self.object_manager.tmp_id_to_obj),
                          obj_to_tmp_id=copy.deepcopy(self.object_manager.obj_to_tmp_id),
                          last_frame=last_frame)
        self.queue.put(args)

    def end(self):
        """Flush pending work and shut down the saver thread."""
        self.queue.put(None)  # sentinel: tells the worker loop to exit
        self.queue.join()
        self.thread.join()
@dataclass
class ResultArgs:
    """Snapshot of everything the saver worker thread needs for one frame."""
    saver: ResultSaver
    prob: Optional[torch.Tensor]  # per-object probabilities on CPU; None unless saving scores
    mask: torch.Tensor  # (H, W) index mask on CPU, already remapped to true object ids
    frame_name: str  # frame file name; extension replaced when writing outputs
    path_to_image: Optional[str]  # source image path; only needed for visualization
    tmp_id_to_obj: Dict[int, ObjectInfo]  # deep-copied snapshot (safe across threads)
    obj_to_tmp_id: Dict[ObjectInfo, int]  # deep-copied inverse mapping snapshot
    last_frame: bool  # True on the final frame (triggers id-mapping dump for scores)
def save_result(queue: Queue):
    """Worker loop: consumes ResultArgs from `queue` and writes results to disk.

    Runs in a daemon thread (started by ResultSaver.__init__). A `None` item is
    the shutdown sentinel. Every item is acknowledged with `task_done` so that
    `queue.join()` in ResultSaver.end() unblocks.
    """
    while True:
        args: ResultArgs = queue.get()
        if args is None:
            queue.task_done()
            break

        saver = args.saver
        prob = args.prob
        mask = args.mask
        frame_name = args.frame_name
        path_to_image = args.path_to_image
        tmp_id_to_obj = args.tmp_id_to_obj
        obj_to_tmp_id = args.obj_to_tmp_id
        last_frame = args.last_frame
        all_obj_ids = [k.id for k in obj_to_tmp_id]

        # record output in the json file
        if saver.json_style == 'burst':
            if frame_name in saver.annotated_frames:
                frame_index = saver.annotated_frames.index(frame_name)
                input_segments = saver.input_segmentations[frame_index]
                frame_segments = saver.segmentations[frame_index]
                for obj_id in all_obj_ids:  # renamed from `id` (shadowed builtin)
                    if obj_id in input_segments:
                        # if this frame has been given as input, just copy
                        frame_segments[obj_id] = input_segments[obj_id]
                        continue
                    segment = {}
                    segment_mask = (mask == obj_id)
                    if segment_mask.sum() > 0:
                        coco_mask = mask_util.encode(np.asfortranarray(segment_mask.numpy()))
                        segment['rle'] = coco_mask['counts'].decode('utf-8')
                    frame_segments[obj_id] = segment

        # Initialize here: the visualization branch below also needs these even
        # when mask saving is disabled (previously a NameError in that case).
        out_mask = None
        rgb_mask = None

        # save the mask to disk
        if saver.save_mask:
            if saver.use_long_id:
                out_mask = mask.numpy().astype(np.uint32)
                rgb_mask = np.zeros((*out_mask.shape[-2:], 3), dtype=np.uint8)
                for obj_id in all_obj_ids:
                    _, image = saver.id2rgb_converter.convert(obj_id)
                    obj_mask = (out_mask == obj_id)
                    rgb_mask[obj_mask] = image
                out_img = Image.fromarray(rgb_mask)
            else:
                out_mask = mask.numpy().astype(np.uint8)
                out_img = Image.fromarray(out_mask)
                if saver.palette is not None:
                    out_img.putpalette(saver.palette)

            this_out_path = path.join(saver.output_root, saver.video_name)
            os.makedirs(this_out_path, exist_ok=True)
            out_img.save(os.path.join(this_out_path, frame_name[:-4] + '.png'))

        # save scores for multi-scale testing
        if saver.save_scores:
            this_out_path = path.join(saver.score_output_root, saver.video_name)
            os.makedirs(this_out_path, exist_ok=True)
            prob = (prob.detach().numpy() * 255).astype(np.uint8)
            if last_frame:
                tmp_to_obj_mapping = {obj.id: tmp_id for obj, tmp_id in tmp_id_to_obj.items()}
                # plain string: the original f-string had no placeholder
                hkl.dump(tmp_to_obj_mapping, path.join(this_out_path, 'backward.hkl'), mode='w')
            hkl.dump(prob,
                     path.join(this_out_path, f'{frame_name[:-4]}.hkl'),
                     mode='w',
                     compression='lzf')

        if saver.visualize:
            if path_to_image is not None:
                image_np = np.array(Image.open(path_to_image))
            else:
                raise ValueError('Cannot visualize without path_to_image')
            if out_mask is None:
                # mask saving was skipped above; derive the index mask here
                out_mask = mask.numpy()
            if rgb_mask is None:
                # we need to apply a palette
                rgb_mask = np.zeros((*out_mask.shape, 3), dtype=np.uint8)
                for obj_id in all_obj_ids:
                    image = saver.colors[obj_id]
                    obj_mask = (out_mask == obj_id)
                    rgb_mask[obj_mask] = image
            # 50% overlay on object pixels; background keeps the original image
            alpha = (out_mask == 0).astype(np.float32) * 0.5 + 0.5
            alpha = alpha[:, :, None]
            blend = (image_np * alpha + rgb_mask * (1 - alpha)).astype(np.uint8)

            # find a place to save the visualization
            this_vis_path = path.join(saver.visualize_output_root, saver.video_name)
            os.makedirs(this_vis_path, exist_ok=True)
            Image.fromarray(blend).save(path.join(this_vis_path, frame_name[:-4] + '.jpg'))

        queue.task_done()
def make_zip(dataset, run_dir, exp_id, mask_output_root):
    """Create a submission zip for datasets that require one.

    The archive layout differs per benchmark: YouTubeVOS and LVOS zip the
    'Annotations' folder from inside the run directory, while DAVIS test-dev
    and MOSE validation zip the mask output root directly. Other datasets
    are skipped.
    """
    archive_base = path.join(run_dir, f'{exp_id}_{dataset}')

    # YouTubeVOS-family: archive only the 'Annotations' subfolder of run_dir
    if dataset.startswith('y'):
        log.info('Making zip for YouTubeVOS...')
        shutil.make_archive(archive_base, 'zip', run_dir, 'Annotations')
        return

    # DAVIS-style: archive the mask output root from within
    flat_zip_messages = {
        'd17-test-dev': 'Making zip for DAVIS test-dev...',
        'mose-val': 'Making zip for MOSE validation...',
    }
    if dataset in flat_zip_messages:
        log.info(flat_zip_messages[dataset])
        shutil.make_archive(archive_base, 'zip', mask_output_root)
    elif dataset == 'lvos-test':
        # same layout as YouTubeVOS
        log.info('Making zip for LVOS test...')
        shutil.make_archive(archive_base, 'zip', run_dir, 'Annotations')
    else:
        log.info(f'Not making zip for {dataset}.')