from typing import Dict, Optional

from omegaconf import DictConfig
import torch
import torch.nn as nn

from tracker.model.group_modules import GConv2d
from tracker.utils.tensor_utils import aggregate
from tracker.model.transformer.positional_encoding import PositionalEncoding
from tracker.model.transformer.transformer_layers import *


class QueryTransformerBlock(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()

        this_cfg = model_cfg.object_transformer
        self.embed_dim = this_cfg.embed_dim
        self.num_heads = this_cfg.num_heads
        self.num_queries = this_cfg.num_queries
        self.ff_dim = this_cfg.ff_dim

        self.read_from_pixel = CrossAttention(self.embed_dim,
                                              self.num_heads,
                                              add_pe_to_qkv=this_cfg.read_from_pixel.add_pe_to_qkv)
        self.self_attn = SelfAttention(self.embed_dim,
                                       self.num_heads,
                                       add_pe_to_qkv=this_cfg.query_self_attention.add_pe_to_qkv)
        self.ffn = FFN(self.embed_dim, self.ff_dim)
        self.read_from_query = CrossAttention(self.embed_dim,
                                              self.num_heads,
                                              add_pe_to_qkv=this_cfg.read_from_query.add_pe_to_qkv,
                                              norm=this_cfg.read_from_query.output_norm)
        self.pixel_ffn = PixelFFN(self.embed_dim)

    def forward(self,
                x: torch.Tensor,
                pixel: torch.Tensor,
                query_pe: torch.Tensor,
                pixel_pe: torch.Tensor,
                attn_mask: torch.Tensor,
                need_weights: bool = False
                ) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
        # x: (bs*num_objects)*num_queries*embed_dim
        # pixel: bs*num_objects*C*H*W
        # query_pe: (bs*num_objects)*num_queries*embed_dim
        # pixel_pe: (bs*num_objects)*(H*W)*C
        # attn_mask: (bs*num_objects*num_heads)*num_queries*(H*W)

        # bs*num_objects*C*H*W -> (bs*num_objects)*(H*W)*C
        pixel_flat = pixel.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
        x, q_weights = self.read_from_pixel(x,
                                            pixel_flat,
                                            query_pe,
                                            pixel_pe,
                                            attn_mask=attn_mask,
                                            need_weights=need_weights)
        x = self.self_attn(x, query_pe)
        x = self.ffn(x)

        pixel_flat, p_weights = self.read_from_query(pixel_flat,
                                                     x,
                                                     pixel_pe,
                                                     query_pe,
                                                     need_weights=need_weights)
        pixel = self.pixel_ffn(pixel, pixel_flat)

        if need_weights:
            bs, num_objects, _, h, w = pixel.shape
            q_weights = q_weights.view(bs, num_objects, self.num_heads, self.num_queries, h, w)
            p_weights = p_weights.transpose(2, 3).view(bs, num_objects, self.num_heads,
                                                       self.num_queries, h, w)

        return x, pixel, q_weights, p_weights


class QueryTransformer(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()

        this_cfg = model_cfg.object_transformer
        self.value_dim = model_cfg.value_dim
        self.embed_dim = this_cfg.embed_dim
        self.num_heads = this_cfg.num_heads
        self.num_queries = this_cfg.num_queries

        # query initialization and embedding
        self.query_init = nn.Embedding(self.num_queries, self.embed_dim)
        self.query_emb = nn.Embedding(self.num_queries, self.embed_dim)

        # projection from object summaries to query initialization and embedding
        self.summary_to_query_init = nn.Linear(self.embed_dim, self.embed_dim)
        self.summary_to_query_emb = nn.Linear(self.embed_dim, self.embed_dim)

        self.pixel_pe_scale = model_cfg.pixel_pe_scale
        self.pixel_pe_temperature = model_cfg.pixel_pe_temperature
        self.pixel_init_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
        self.pixel_emb_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
        self.spatial_pe = PositionalEncoding(self.embed_dim,
                                             scale=self.pixel_pe_scale,
                                             temperature=self.pixel_pe_temperature,
                                             channel_last=False,
                                             transpose_output=True)

        # transformer blocks
        self.num_blocks = this_cfg.num_blocks
        self.blocks = nn.ModuleList(
            QueryTransformerBlock(model_cfg) for _ in range(self.num_blocks))
        self.mask_pred = nn.ModuleList(
            nn.Sequential(nn.ReLU(), GConv2d(self.embed_dim, 1, kernel_size=1))
            for _ in range(self.num_blocks + 1))

        self.act = nn.ReLU(inplace=True)

    def forward(self,
                pixel: torch.Tensor,
                obj_summaries: torch.Tensor,
                selector: Optional[torch.Tensor] = None,
                need_weights: bool = False) -> (torch.Tensor, Dict[str, torch.Tensor]):
        # pixel: B*num_objects*embed_dim*H*W
        # obj_summaries: B*num_objects*T*num_queries*(embed_dim+1)
        T = obj_summaries.shape[2]
        bs, num_objects, _, H, W = pixel.shape

        # normalize object values
        # the last channel is the cumulative area of the object
        obj_summaries = obj_summaries.view(bs * num_objects, T, self.num_queries,
                                           self.embed_dim + 1)
        # sum over time
        # during inference, T=1 as we already did streaming average in memory_manager
        obj_sums = obj_summaries[:, :, :, :-1].sum(dim=1)
        obj_area = obj_summaries[:, :, :, -1:].sum(dim=1)
        obj_values = obj_sums / (obj_area + 1e-4)
        obj_init = self.summary_to_query_init(obj_values)
        obj_emb = self.summary_to_query_emb(obj_values)

        # positional embeddings for object queries
        query = self.query_init.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_init
        query_emb = self.query_emb.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_emb

        # positional embeddings for pixel features
        pixel_init = self.pixel_init_proj(pixel)
        pixel_emb = self.pixel_emb_proj(pixel)
        pixel_pe = self.spatial_pe(pixel.flatten(0, 1))
        pixel_emb = pixel_emb.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
        pixel_pe = pixel_pe.flatten(1, 2) + pixel_emb

        pixel = pixel_init

        # run the transformer
        aux_features = {'logits': []}

        # first aux output
        aux_logits = self.mask_pred[0](pixel).squeeze(2)
        attn_mask = self._get_aux_mask(aux_logits, selector)
        aux_features['logits'].append(aux_logits)
        for i in range(self.num_blocks):
            query, pixel, q_weights, p_weights = self.blocks[i](query,
                                                                pixel,
                                                                query_emb,
                                                                pixel_pe,
                                                                attn_mask,
                                                                need_weights=need_weights)

            if self.training or i <= self.num_blocks - 1 or need_weights:
                aux_logits = self.mask_pred[i + 1](pixel).squeeze(2)
                attn_mask = self._get_aux_mask(aux_logits, selector)
                aux_features['logits'].append(aux_logits)

        aux_features['q_weights'] = q_weights  # last layer only
        aux_features['p_weights'] = p_weights  # last layer only

        if self.training:
            # no need to save all heads
            aux_features['attn_mask'] = attn_mask.view(bs, num_objects, self.num_heads,
                                                       self.num_queries, H, W)[:, :, 0]

        return pixel, aux_features

    def _get_aux_mask(self, logits: torch.Tensor,
                      selector: Optional[torch.Tensor]) -> torch.Tensor:
        # logits: batch_size*num_objects*H*W
        # selector: batch_size*num_objects*1*1
        # returns a mask of shape (batch_size*num_objects*num_heads)*num_queries*(H*W)
        # where True means the attention is blocked

        if selector is None:
            prob = logits.sigmoid()
        else:
            prob = logits.sigmoid() * selector
        logits = aggregate(prob, dim=1)

        # logits[:, 1:] skips the aggregated background channel; a pixel is foreground
        # for an object iff that object attains the per-pixel maximum
        is_foreground = (logits[:, 1:] >= logits.max(dim=1, keepdim=True)[0])
        foreground_mask = is_foreground.bool().flatten(start_dim=2)
        inv_foreground_mask = ~foreground_mask
        inv_background_mask = foreground_mask

        # the first half of the queries attends only to foreground pixels,
        # the second half only to background pixels
        aux_foreground_mask = inv_foreground_mask.unsqueeze(2).unsqueeze(2).repeat(
            1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)
        aux_background_mask = inv_background_mask.unsqueeze(2).unsqueeze(2).repeat(
            1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)

        aux_mask = torch.cat([aux_foreground_mask, aux_background_mask], dim=1)

        # a query with every position blocked would break the attention softmax;
        # unblock such queries entirely
        aux_mask[torch.where(aux_mask.sum(-1) == aux_mask.shape[-1])] = False

        return aux_mask
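

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). It shows the configuration fields that
# QueryTransformer reads in __init__ and the input shapes its forward pass
# expects. All concrete values below are hypothetical placeholders, not the
# project's shipped configuration; running the forward pass also requires the
# companion modules imported at the top of this file.
if __name__ == '__main__':
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        'value_dim': 256,              # hypothetical
        'pixel_pe_scale': 32,          # hypothetical
        'pixel_pe_temperature': 128,   # hypothetical
        'object_transformer': {
            'embed_dim': 256,
            'ff_dim': 2048,
            'num_heads': 8,
            'num_queries': 16,
            'num_blocks': 3,
            'read_from_pixel': {'add_pe_to_qkv': [True, True, False]},
            'read_from_query': {'add_pe_to_qkv': [True, True, False], 'output_norm': False},
            'query_self_attention': {'add_pe_to_qkv': [True, True, False]},
        },
    })

    transformer = QueryTransformer(cfg).eval()

    bs, num_objects, H, W = 1, 2, 30, 52
    T = 1  # at inference time the memory manager keeps a single running summary
    pixel = torch.randn(bs, num_objects, cfg.object_transformer.embed_dim, H, W)
    # the last channel of each summary is the cumulative object area
    obj_summaries = torch.randn(bs, num_objects, T, cfg.object_transformer.num_queries,
                                cfg.object_transformer.embed_dim + 1)

    with torch.no_grad():
        out_pixel, aux = transformer(pixel, obj_summaries, selector=None)
    print(out_pixel.shape, list(aux.keys()))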