# NOTE(review): the following header lines are residue from a Hugging Face
# file-viewer page scrape, converted to a comment so the file parses:
#   author: sczhou | commit: "init code" (320e465) | raw / history / blame | 3.35 kB
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import DictConfig

from tracker.model.transformer.positional_encoding import PositionalEncoding
# @torch.jit.script
def _weighted_pooling(masks: torch.Tensor, value: torch.Tensor,
logits: torch.Tensor) -> (torch.Tensor, torch.Tensor):
# value: B*num_objects*H*W*value_dim
# logits: B*num_objects*H*W*num_summaries
# masks: B*num_objects*H*W*num_summaries: 1 if allowed
weights = logits.sigmoid() * masks
# B*num_objects*num_summaries*value_dim
sums = torch.einsum('bkhwq,bkhwc->bkqc', weights, value)
# B*num_objects*H*W*num_summaries -> B*num_objects*num_summaries*1
area = weights.flatten(start_dim=2, end_dim=3).sum(2).unsqueeze(-1)
# B*num_objects*num_summaries*value_dim
return sums, area
class ObjectSummarizer(nn.Module):
    """Summarize per-pixel object features into a fixed number of
    per-object summary vectors via mask-gated soft pooling.

    The summaries are split evenly: the first half may only attend
    inside the object mask (foreground), the second half only outside
    it (background context).
    """

    def __init__(self, model_cfg: DictConfig):
        super().__init__()

        this_cfg = model_cfg.object_summarizer
        self.value_dim = model_cfg.value_dim
        self.embed_dim = this_cfg.embed_dim
        self.num_summaries = this_cfg.num_summaries
        self.add_pe = this_cfg.add_pe
        self.pixel_pe_scale = model_cfg.pixel_pe_scale
        self.pixel_pe_temperature = model_cfg.pixel_pe_temperature

        # forward() splits summaries as num_summaries // 2 foreground +
        # num_summaries // 2 background; an odd count would silently drop
        # one channel and only fail later with a shape mismatch.
        if self.num_summaries % 2 != 0:
            raise ValueError(
                f'num_summaries must be even, got {self.num_summaries}')

        if self.add_pe:
            self.pos_enc = PositionalEncoding(self.embed_dim,
                                              scale=self.pixel_pe_scale,
                                              temperature=self.pixel_pe_temperature)

        self.input_proj = nn.Linear(self.value_dim, self.embed_dim)
        # Per-pixel features to be pooled.
        self.feature_pred = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dim, self.embed_dim),
        )
        # Per-pixel, per-summary attention logits.
        self.weights_pred = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dim, self.num_summaries),
        )

    def forward(self,
                masks: torch.Tensor,
                value: torch.Tensor,
                need_weights: bool = False
                ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute object summaries.

        Args:
            masks: B*num_objects*(H0)*(W0) soft object masks; resized
                here to the feature resolution.
            value: B*num_objects*value_dim*H*W feature maps.
            need_weights: if True, also return the raw attention logits.

        Returns:
            summaries: B*num_objects*num_summaries*(embed_dim+1);
                the pooled features concatenated with the pooled area.
            logits: B*num_objects*H*W*num_summaries if need_weights,
                else None.
        """
        # Bring the masks to the feature resolution.
        h, w = value.shape[-2:]
        masks = F.interpolate(masks, size=(h, w), mode='area')
        masks = masks.unsqueeze(-1)
        inv_masks = 1 - masks
        # First half of summaries: foreground; second half: background.
        repeated_masks = torch.cat([
            masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
            inv_masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
        ],
                                   dim=-1)

        # B*num_objects*value_dim*H*W -> B*num_objects*H*W*value_dim
        value = value.permute(0, 1, 3, 4, 2)
        value = self.input_proj(value)
        if self.add_pe:
            pe = self.pos_enc(value)
            value = value + pe

        # Pool in fp32 for numerical stability under mixed precision.
        with torch.cuda.amp.autocast(enabled=False):
            value = value.float()
            feature = self.feature_pred(value)
            logits = self.weights_pred(value)
            sums, area = _weighted_pooling(repeated_masks, feature, logits)

        # B*num_objects*num_summaries*(embed_dim+1)
        summaries = torch.cat([sums, area], dim=-1)

        if need_weights:
            return summaries, logits
        else:
            return summaries, None