""" big_modules.py - This file stores higher-level network blocks. x - usually denotes features that are shared between objects. g - usually denotes features that are not shared between objects with an extra "num_objects" dimension (batch_size * num_objects * num_channels * H * W). The trailing number of a variable usually denotes the stride """ from omegaconf import DictConfig import torch import torch.nn as nn import torch.nn.functional as F from tracker.model.group_modules import * from tracker.model.utils import resnet from tracker.model.modules import * class PixelEncoder(nn.Module): def __init__(self, model_cfg: DictConfig): super().__init__() self.is_resnet = 'resnet' in model_cfg.pixel_encoder.type if self.is_resnet: if model_cfg.pixel_encoder.type == 'resnet18': network = resnet.resnet18(pretrained=True) elif model_cfg.pixel_encoder.type == 'resnet50': network = resnet.resnet50(pretrained=True) else: raise NotImplementedError self.conv1 = network.conv1 self.bn1 = network.bn1 self.relu = network.relu self.maxpool = network.maxpool self.res2 = network.layer1 self.layer2 = network.layer2 self.layer3 = network.layer3 else: raise NotImplementedError def forward(self, x: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) f4 = self.res2(x) f8 = self.layer2(f4) f16 = self.layer3(f8) return f16, f8, f4 # override the default train() to freeze BN statistics def train(self, mode=True): self.training = False for module in self.children(): module.train(False) return self class KeyProjection(nn.Module): def __init__(self, model_cfg: DictConfig): super().__init__() in_dim = model_cfg.pixel_encoder.ms_dims[0] mid_dim = model_cfg.pixel_dim key_dim = model_cfg.key_dim self.pix_feat_proj = nn.Conv2d(in_dim, mid_dim, kernel_size=1) self.key_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1) # shrinkage self.d_proj = nn.Conv2d(mid_dim, 1, kernel_size=3, padding=1) # selection self.e_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1) nn.init.orthogonal_(self.key_proj.weight.data) nn.init.zeros_(self.key_proj.bias.data) def forward(self, x: torch.Tensor, *, need_s: bool, need_e: bool) -> (torch.Tensor, torch.Tensor, torch.Tensor): x = self.pix_feat_proj(x) shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None return self.key_proj(x), shrinkage, selection class MaskEncoder(nn.Module): def __init__(self, model_cfg: DictConfig, single_object=False): super().__init__() pixel_dim = model_cfg.pixel_dim value_dim = model_cfg.value_dim sensory_dim = model_cfg.sensory_dim final_dim = model_cfg.mask_encoder.final_dim self.single_object = single_object extra_dim = 1 if single_object else 2 if model_cfg.mask_encoder.type == 'resnet18': network = resnet.resnet18(pretrained=True, extra_dim=extra_dim) elif model_cfg.mask_encoder.type == 'resnet50': network = resnet.resnet50(pretrained=True, extra_dim=extra_dim) else: raise NotImplementedError self.conv1 = network.conv1 self.bn1 = network.bn1 self.relu = network.relu self.maxpool = network.maxpool self.layer1 = network.layer1 self.layer2 = network.layer2 self.layer3 = network.layer3 self.distributor = MainToGroupDistributor() self.fuser = GroupFeatureFusionBlock(pixel_dim, final_dim, value_dim) self.sensory_update = SensoryDeepUpdater(value_dim, sensory_dim) def forward(self, image: torch.Tensor, pix_feat: torch.Tensor, sensory: torch.Tensor, masks: torch.Tensor, others: torch.Tensor, *, deep_update: bool 

class MaskEncoder(nn.Module):
    def __init__(self, model_cfg: DictConfig, single_object=False):
        super().__init__()
        pixel_dim = model_cfg.pixel_dim
        value_dim = model_cfg.value_dim
        sensory_dim = model_cfg.sensory_dim
        final_dim = model_cfg.mask_encoder.final_dim

        self.single_object = single_object
        extra_dim = 1 if single_object else 2

        if model_cfg.mask_encoder.type == 'resnet18':
            network = resnet.resnet18(pretrained=True, extra_dim=extra_dim)
        elif model_cfg.mask_encoder.type == 'resnet50':
            network = resnet.resnet50(pretrained=True, extra_dim=extra_dim)
        else:
            raise NotImplementedError
        self.conv1 = network.conv1
        self.bn1 = network.bn1
        self.relu = network.relu
        self.maxpool = network.maxpool

        self.layer1 = network.layer1
        self.layer2 = network.layer2
        self.layer3 = network.layer3

        self.distributor = MainToGroupDistributor()
        self.fuser = GroupFeatureFusionBlock(pixel_dim, final_dim, value_dim)

        self.sensory_update = SensoryDeepUpdater(value_dim, sensory_dim)

    def forward(self,
                image: torch.Tensor,
                pix_feat: torch.Tensor,
                sensory: torch.Tensor,
                masks: torch.Tensor,
                others: torch.Tensor,
                *,
                deep_update: bool = True,
                chunk_size: int = -1) -> (torch.Tensor, torch.Tensor):
        # ms_features are from the key encoder
        # we only use the first one (lowest resolution), following XMem
        if self.single_object:
            g = masks.unsqueeze(2)
        else:
            g = torch.stack([masks, others], dim=2)

        g = self.distributor(image, g)

        batch_size, num_objects = g.shape[:2]
        if chunk_size < 1 or chunk_size >= num_objects:
            chunk_size = num_objects
            fast_path = True
            new_sensory = sensory
        else:
            if deep_update:
                new_sensory = torch.empty_like(sensory)
            else:
                new_sensory = sensory
            fast_path = False

        # chunk-by-chunk inference
        all_g = []
        for i in range(0, num_objects, chunk_size):
            if fast_path:
                g_chunk = g
            else:
                g_chunk = g[:, i:i + chunk_size]
            actual_chunk_size = g_chunk.shape[1]
            g_chunk = g_chunk.flatten(start_dim=0, end_dim=1)
            g_chunk = self.conv1(g_chunk)
            g_chunk = self.bn1(g_chunk)  # 1/2, 64
            g_chunk = self.maxpool(g_chunk)  # 1/4, 64
            g_chunk = self.relu(g_chunk)

            g_chunk = self.layer1(g_chunk)  # 1/4
            g_chunk = self.layer2(g_chunk)  # 1/8
            g_chunk = self.layer3(g_chunk)  # 1/16

            g_chunk = g_chunk.view(batch_size, actual_chunk_size, *g_chunk.shape[1:])
            g_chunk = self.fuser(pix_feat, g_chunk)
            all_g.append(g_chunk)
            if deep_update:
                if fast_path:
                    new_sensory = self.sensory_update(g_chunk, sensory)
                else:
                    new_sensory[:, i:i + chunk_size] = self.sensory_update(
                        g_chunk, sensory[:, i:i + chunk_size])
        g = torch.cat(all_g, dim=1)

        return g, new_sensory

    # override the default train() to freeze BN statistics
    def train(self, mode=True):
        self.training = False
        for module in self.children():
            module.train(False)
        return self


class PixelFeatureFuser(nn.Module):
    def __init__(self, model_cfg: DictConfig, single_object=False):
        super().__init__()
        value_dim = model_cfg.value_dim
        sensory_dim = model_cfg.sensory_dim
        pixel_dim = model_cfg.pixel_dim
        embed_dim = model_cfg.embed_dim
        self.single_object = single_object

        self.fuser = GroupFeatureFusionBlock(pixel_dim, value_dim, embed_dim)
        if self.single_object:
            self.sensory_compress = GConv2d(sensory_dim + 1, value_dim, kernel_size=1)
        else:
            self.sensory_compress = GConv2d(sensory_dim + 2, value_dim, kernel_size=1)

    def forward(self,
                pix_feat: torch.Tensor,
                pixel_memory: torch.Tensor,
                sensory_memory: torch.Tensor,
                last_mask: torch.Tensor,
                last_others: torch.Tensor,
                *,
                chunk_size: int = -1) -> torch.Tensor:
        batch_size, num_objects = pixel_memory.shape[:2]

        if self.single_object:
            last_mask = last_mask.unsqueeze(2)
        else:
            last_mask = torch.stack([last_mask, last_others], dim=2)

        if chunk_size < 1:
            chunk_size = num_objects

        # chunk-by-chunk inference
        all_p16 = []
        for i in range(0, num_objects, chunk_size):
            sensory_readout = self.sensory_compress(
                torch.cat([sensory_memory[:, i:i + chunk_size], last_mask[:, i:i + chunk_size]],
                          2))
            p16 = pixel_memory[:, i:i + chunk_size] + sensory_readout
            p16 = self.fuser(pix_feat, p16)
            all_p16.append(p16)
        p16 = torch.cat(all_p16, dim=1)

        return p16

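
# ---------------------------------------------------------------------------
# Note (added commentary): MaskEncoder and PixelFeatureFuser above and
# MaskDecoder below share one chunking convention -- `chunk_size` bounds how
# many objects go through the network per pass, trading throughput for peak
# memory, and `chunk_size < 1` (the default -1) processes all objects at
# once. A minimal sketch of the pattern, with `process` as a hypothetical
# stand-in for the per-chunk computation:
#
#     outputs = []
#     for i in range(0, num_objects, chunk_size):
#         outputs.append(process(g[:, i:i + chunk_size]))
#     g = torch.cat(outputs, dim=1)  # re-assemble along the object axis
# ---------------------------------------------------------------------------
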

class MaskDecoder(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()
        embed_dim = model_cfg.embed_dim
        sensory_dim = model_cfg.sensory_dim
        ms_image_dims = model_cfg.pixel_encoder.ms_dims
        up_dims = model_cfg.mask_decoder.up_dims

        assert embed_dim == up_dims[0]

        self.sensory_update = SensoryUpdater([up_dims[0], up_dims[1], up_dims[2] + 1], sensory_dim,
                                             sensory_dim)

        self.decoder_feat_proc = DecoderFeatureProcessor(ms_image_dims[1:], up_dims[:-1])
        self.up_16_8 = MaskUpsampleBlock(up_dims[0], up_dims[1])
        self.up_8_4 = MaskUpsampleBlock(up_dims[1], up_dims[2])

        self.pred = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1)

    def forward(self,
                ms_image_feat: Iterable[torch.Tensor],
                memory_readout: torch.Tensor,
                sensory: torch.Tensor,
                *,
                chunk_size: int = -1,
                update_sensory: bool = True) -> (torch.Tensor, torch.Tensor):

        batch_size, num_objects = memory_readout.shape[:2]
        f8, f4 = self.decoder_feat_proc(ms_image_feat[1:])
        if chunk_size < 1 or chunk_size >= num_objects:
            chunk_size = num_objects
            fast_path = True
            new_sensory = sensory
        else:
            if update_sensory:
                new_sensory = torch.empty_like(sensory)
            else:
                new_sensory = sensory
            fast_path = False

        # chunk-by-chunk inference
        all_logits = []
        for i in range(0, num_objects, chunk_size):
            if fast_path:
                p16 = memory_readout
            else:
                p16 = memory_readout[:, i:i + chunk_size]
            actual_chunk_size = p16.shape[1]

            p8 = self.up_16_8(p16, f8)
            p4 = self.up_8_4(p8, f4)
            # predict the final logits in fp32 even under AMP
            with torch.cuda.amp.autocast(enabled=False):
                logits = self.pred(F.relu(p4.flatten(start_dim=0, end_dim=1).float()))

            if update_sensory:
                p4 = torch.cat(
                    [p4, logits.view(batch_size, actual_chunk_size, 1, *logits.shape[-2:])], 2)
                if fast_path:
                    new_sensory = self.sensory_update([p16, p8, p4], sensory)
                else:
                    new_sensory[:, i:i + chunk_size] = self.sensory_update(
                        [p16, p8, p4], sensory[:, i:i + chunk_size])
            all_logits.append(logits)
        # per-chunk logits are flattened along batch*objects, hence dim=0; the
        # view below restores the object axis (exact for a single chunk or batch_size == 1)
        logits = torch.cat(all_logits, dim=0)
        logits = logits.view(batch_size, num_objects, *logits.shape[-2:])

        return new_sensory, logits
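

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch added for illustration, not part of the
# original module). It builds PixelEncoder + KeyProjection from a
# hand-written config and pushes one dummy frame through them. The config
# values are assumptions that merely mirror the attributes accessed above;
# the real values live in the project's model config.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        'pixel_encoder': {
            'type': 'resnet18',
            # resnet18 channel widths at strides 16 / 8 / 4
            'ms_dims': [256, 128, 64],
        },
        'pixel_dim': 256,
        'key_dim': 64,
    })

    encoder = PixelEncoder(cfg)  # note: pretrained=True downloads ImageNet weights
    key_proj = KeyProjection(cfg)

    image = torch.randn(1, 3, 384, 384)  # one RGB frame
    f16, f8, f4 = encoder(image)
    key, shrinkage, selection = key_proj(f16, need_s=True, need_e=True)
    # expected: key (1, 64, 24, 24), shrinkage (1, 1, 24, 24)
    print(key.shape, shrinkage.shape, selection.shape)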