Segment_and_track_Anything / aot_tracker.py
Zeeshan01's picture
Upload folder using huggingface_hub
04daa95
from statistics import mode
import torch
import torch.nn.functional as F
import os
import sys
sys.path.append("./aot")
from aot.networks.engines.aot_engine import AOTEngine,AOTInferEngine
from aot.networks.engines.deaot_engine import DeAOTEngine,DeAOTInferEngine
import importlib
import numpy as np
from PIL import Image
from skimage.morphology.binary import binary_dilation
# Colour palette as a flat [r, g, b, r, g, b, ...] list of 768 ints
# (256 colours): entry 0 is black, the other 255 colours have each channel
# drawn from [0.3, 1.0) * 255 so they stay fairly bright.
# NOTE(review): np.random.seed reseeds NumPy's *global* RNG at import time;
# other code may depend on that, so it is kept as-is.
np.random.seed(200)
_palette = [0, 0, 0] + (
    (np.random.random(3 * 255) * 0.7 + 0.3) * 255
).astype(np.uint8).tolist()
import aot.dataloaders.video_transforms as tr
from aot.utils.checkpoint import load_network
from aot.networks.models import build_vos_model
from aot.networks.engines import build_engine
from torchvision import transforms
class AOTTracker:
    """Tracker that wraps an AOT-family VOS model and its evaluation engine.

    ``cfg`` is an engine config object; the fields read here are
    ``MODEL_VOS``, ``TEST_CKPT_PATH``, ``MODEL_ENGINE``,
    ``TEST_LONG_TERM_MEM_GAP``, ``MAX_LEN_LONG_TERM``,
    ``TEST_MAX_SHORT_EDGE``, ``TEST_MAX_LONG_EDGE``, ``TEST_FLIP``,
    ``TEST_MULTISCALE`` and ``MODEL_ALIGN_CORNERS``.
    """

    def __init__(self, cfg, gpu_id=0):
        self.gpu_id = gpu_id

        # Build the network on the requested GPU, then load checkpoint weights.
        net = build_vos_model(cfg.MODEL_VOS, cfg).cuda(gpu_id)
        net, _ = load_network(net, cfg.TEST_CKPT_PATH, gpu_id)
        self.model = net

        self.engine = build_engine(
            cfg.MODEL_ENGINE,
            phase='eval',
            aot_model=self.model,
            gpu_id=gpu_id,
            short_term_mem_skip=1,
            long_term_mem_gap=cfg.TEST_LONG_TERM_MEM_GAP,
            max_len_long_term=cfg.MAX_LEN_LONG_TERM,
        )

        restrict_size = tr.MultiRestrictSize(
            cfg.TEST_MAX_SHORT_EDGE,
            cfg.TEST_MAX_LONG_EDGE,
            cfg.TEST_FLIP,
            cfg.TEST_MULTISCALE,
            cfg.MODEL_ALIGN_CORNERS,
        )
        self.transform = transforms.Compose([restrict_size, tr.MultiToTensor()])

        self.model.eval()

    def _to_gpu_batch(self, tensor):
        """Add a leading batch dim, cast to float and move to ``self.gpu_id``."""
        return tensor.unsqueeze(0).float().cuda(self.gpu_id)

    @torch.no_grad()
    def add_reference_frame(self, frame, mask, obj_nums, frame_step, incremental=False):
        """Register ``frame``/``mask`` as a reference frame with the engine.

        Both inputs go through ``self.transform`` (keys ``'current_img'`` and
        ``'current_label'``), and only the first transformed sample is used.
        The mask is resized with nearest-neighbour interpolation to the
        frame's spatial size. When ``incremental`` is true the engine's
        ``add_reference_frame_incremental`` is called; otherwise its
        ``add_reference_frame`` is called. ``obj_nums`` and ``frame_step``
        are forwarded unchanged.
        """
        transformed = self.transform({
            'current_img': frame,
            'current_label': mask,
        })[0]
        img = self._to_gpu_batch(transformed['current_img'])
        label = self._to_gpu_batch(transformed['current_label'])
        label = F.interpolate(label, size=img.shape[-2:], mode='nearest')

        register = (self.engine.add_reference_frame_incremental
                    if incremental else self.engine.add_reference_frame)
        register(img, label, obj_nums=obj_nums, frame_step=frame_step)

    @torch.no_grad()
    def track(self, image):
        """Propagate the tracked objects to ``image`` and return its label map.

        ``image.shape[0]`` and ``image.shape[1]`` are passed to the engine's
        ``decode_current_logits`` as the output (height, width). The return
        value is the argmax over dim 1 of the decoded logits, with that dim
        kept and the result cast to float.
        """
        out_size = (image.shape[0], image.shape[1])
        img = self._to_gpu_batch(
            self.transform({'current_img': image})[0]['current_img'])
        self.engine.match_propogate_one_frame(img)
        logits = self.engine.decode_current_logits(out_size)
        return logits.argmax(dim=1, keepdim=True).float()

    @torch.no_grad()
    def update_memory(self, pred_label):
        """Forward ``pred_label`` to the engine's ``update_memory``."""
        self.engine.update_memory(pred_label)

    @torch.no_grad()
    def restart(self):
        """Reset the engine via its ``restart_engine``."""
        self.engine.restart_engine()

    @torch.no_grad()
    def build_tracker_engine(self, name, **kwargs):
        """Construct a tracker inference engine by name.

        ``'aotengine'`` -> ``AOTTrackerInferEngine`` and ``'deaotengine'`` ->
        ``DeAOTTrackerInferEngine``; ``kwargs`` are passed to the constructor.
        Any other name raises ``NotImplementedError``. Note that ``__init__``
        does not call this; it uses ``build_engine`` instead.
        """
        if name == 'deaotengine':
            return DeAOTTrackerInferEngine(**kwargs)
        if name == 'aotengine':
            return AOTTrackerInferEngine(**kwargs)
        raise NotImplementedError
def _add_reference_frame_incremental(infer_engine, engine_cls, img, mask,
                                     obj_nums, frame_step):
    """Shared body of ``add_reference_frame_incremental`` for both engines.

    ``AOTTrackerInferEngine`` and ``DeAOTTrackerInferEngine`` previously
    carried verbatim copies of this logic, differing only in the sub-engine
    class they instantiate, which is now passed as ``engine_cls``.

    Steps:
      * unwraps ``obj_nums`` if it is a list and stores it on ``infer_engine``;
      * creates ``engine_cls`` sub-engines until there are at least
        ``max(ceil(obj_nums / max_aot_obj_num), 1)`` of them;
      * splits ``mask`` across sub-engines with ``separate_mask``;
      * for each sub-engine, either registers a new reference frame (when it
        has no recorded object count yet, or fewer objects than now assigned)
        or only updates its short-term memory with its mask part;
      * calls ``update_size`` on ``infer_engine``.
    """
    if isinstance(obj_nums, list):
        obj_nums = obj_nums[0]
    infer_engine.obj_nums = obj_nums

    # Each sub-engine handles at most max_aot_obj_num objects; keep at least one.
    aot_num = max(np.ceil(obj_nums / infer_engine.max_aot_obj_num), 1)
    while aot_num > len(infer_engine.aot_engines):
        new_engine = engine_cls(infer_engine.AOT, infer_engine.gpu_id,
                                infer_engine.long_term_mem_gap,
                                infer_engine.short_term_mem_skip)
        new_engine.eval()
        infer_engine.aot_engines.append(new_engine)

    separated_masks, separated_obj_nums = infer_engine.separate_mask(
        mask, obj_nums)
    img_embs = None
    for aot_engine, separated_mask, separated_obj_num in zip(
            infer_engine.aot_engines, separated_masks, separated_obj_nums):
        if aot_engine.obj_nums is None or aot_engine.obj_nums[0] < separated_obj_num:
            # New sub-engine, or it gained objects: register a reference frame.
            aot_engine.add_reference_frame(img,
                                           separated_mask,
                                           obj_nums=[separated_obj_num],
                                           frame_step=frame_step,
                                           img_embs=img_embs)
        else:
            # Object count did not increase: only refresh short-term memory.
            aot_engine.update_short_term_memory(separated_mask)
        # Reuse the first sub-engine's image embeddings for the others.
        # NOTE(review): in the else-branch this is whatever curr_enc_embs the
        # sub-engine holds (presumably from the last propagation) — confirm.
        if img_embs is None:
            img_embs = aot_engine.curr_enc_embs
    infer_engine.update_size()


class AOTTrackerInferEngine(AOTInferEngine):
    """``AOTInferEngine`` that can add reference objects incrementally."""

    def __init__(self, aot_model, gpu_id=0, long_term_mem_gap=9999, short_term_mem_skip=1, max_aot_obj_num=None):
        super().__init__(aot_model, gpu_id, long_term_mem_gap, short_term_mem_skip, max_aot_obj_num)

    def add_reference_frame_incremental(self, img, mask, obj_nums, frame_step=-1):
        """Add/refresh reference objects, creating ``AOTEngine`` sub-engines as needed."""
        _add_reference_frame_incremental(self, AOTEngine, img, mask,
                                         obj_nums, frame_step)


class DeAOTTrackerInferEngine(DeAOTInferEngine):
    """``DeAOTInferEngine`` that can add reference objects incrementally."""

    def __init__(self, aot_model, gpu_id=0, long_term_mem_gap=9999, short_term_mem_skip=1, max_aot_obj_num=None):
        super().__init__(aot_model, gpu_id, long_term_mem_gap, short_term_mem_skip, max_aot_obj_num)

    def add_reference_frame_incremental(self, img, mask, obj_nums, frame_step=-1):
        """Add/refresh reference objects, creating ``DeAOTEngine`` sub-engines as needed."""
        _add_reference_frame_incremental(self, DeAOTEngine, img, mask,
                                         obj_nums, frame_step)
def get_aot(args):
    """Build an ``AOTTracker`` from a dict of tracker arguments.

    Args:
        args: mapping with keys ``'phase'`` and ``'model'`` (passed to
            ``EngineConfig``), ``'model_path'`` (stored as
            ``cfg.TEST_CKPT_PATH``), ``'long_term_mem_gap'``,
            ``'max_len_long_term'`` and ``'gpu_id'``. The optional key
            ``'config'`` names the module under the ``configs`` package that
            provides ``EngineConfig``. It defaults to ``'pre_ytb_dav'``, which
            is the previously hard-coded value.

    Returns:
        The constructed ``AOTTracker``.
    """
    # build vos engine
    config_name = args.get('config', 'pre_ytb_dav')
    engine_config = importlib.import_module('configs.' + config_name)
    cfg = engine_config.EngineConfig(args['phase'], args['model'])
    cfg.TEST_CKPT_PATH = args['model_path']
    cfg.TEST_LONG_TERM_MEM_GAP = args['long_term_mem_gap']
    cfg.MAX_LEN_LONG_TERM = args['max_len_long_term']
    # init AOTTracker
    return AOTTracker(cfg, args['gpu_id'])