sczhou's picture
init code
320e465
raw
history blame
3.28 kB
import numpy as np
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
import sys
sys.path.append('../')
from tracker.config import CONFIG
from tracker.model.cutie import CUTIE
from tracker.inference.inference_core import InferenceCore
from tracker.utils.mask_mapper import MaskMapper
from tools.painter import mask_painter
class BaseTracker:
def __init__(self, cutie_checkpoint, device) -> None:
"""
device: model device
cutie_checkpoint: checkpoint of XMem model
"""
config = OmegaConf.create(CONFIG)
# initialise XMem
network = CUTIE(config).to(device).eval()
model_weights = torch.load(cutie_checkpoint, map_location=device)
network.load_weights(model_weights)
# initialise IncerenceCore
self.tracker = InferenceCore(network, config)
self.device = device
# changable properties
self.mapper = MaskMapper()
self.initialised = False
@torch.no_grad()
def resize_mask(self, mask):
# mask transform is applied AFTER mapper, so we need to post-process it in eval.py
h, w = mask.shape[-2:]
min_hw = min(h, w)
return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)),
mode='nearest')
@torch.no_grad()
def image_to_torch(self, frame: np.ndarray, device: str = 'cuda'):
# frame: H*W*3 numpy array
frame = frame.transpose(2, 0, 1)
frame = torch.from_numpy(frame).float().to(device, non_blocking=True) / 255
return frame
@torch.no_grad()
def track(self, frame, first_frame_annotation=None):
"""
Input:
frames: numpy arrays (H, W, 3)
logit: numpy array (H, W), logit
Output:
mask: numpy arrays (H, W)
logit: numpy arrays, probability map (H, W)
painted_image: numpy array (H, W, 3)
"""
if first_frame_annotation is not None: # first frame mask
# initialisation
mask, labels = self.mapper.convert_mask(first_frame_annotation)
mask = torch.Tensor(mask).to(self.device)
else:
mask = None
labels = None
# prepare inputs
frame_tensor = self.image_to_torch(frame, self.device)
# track one frame
probs = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W
# convert to mask
out_mask = torch.argmax(probs, dim=0)
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
final_mask = np.zeros_like(out_mask)
# map back
for k, v in self.mapper.remappings.items():
final_mask[out_mask == v] = k
num_objs = final_mask.max()
painted_image = frame
for obj in range(1, num_objs+1):
if np.max(final_mask==obj) == 0:
continue
painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1)
return final_mask, final_mask, painted_image
@torch.no_grad()
def clear_memory(self):
self.tracker.clear_memory()
self.mapper.clear_labels()
torch.cuda.empty_cache()