# Modified from PyTorch nn.Transformer

from typing import Callable, List, Optional, Tuple

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F

from tracker.model.channel_attn import CAResBlock


class SelfAttention(nn.Module):
    def __init__(self,
                 dim: int,
                 nhead: int,
                 dropout: float = 0.0,
                 batch_first: bool = True,
                 add_pe_to_qkv: List[bool] = [True, True, False]):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, nhead, dropout=dropout, batch_first=batch_first)
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        # add_pe_to_qkv selects which of query/key/value get the positional encoding added
        self.add_pe_to_qkv = add_pe_to_qkv

    def forward(self,
                x: torch.Tensor,
                pe: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None,
                key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # pre-norm self-attention with a residual connection
        x = self.norm(x)
        if any(self.add_pe_to_qkv):
            x_with_pe = x + pe
            q = x_with_pe if self.add_pe_to_qkv[0] else x
            k = x_with_pe if self.add_pe_to_qkv[1] else x
            v = x_with_pe if self.add_pe_to_qkv[2] else x
        else:
            q = k = v = x

        r = x
        x = self.self_attn(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0]
        return r + self.dropout(x)


# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
class CrossAttention(nn.Module):
    def __init__(self,
                 dim: int,
                 nhead: int,
                 dropout: float = 0.0,
                 batch_first: bool = True,
                 add_pe_to_qkv: List[bool] = [True, True, False],
                 residual: bool = True,
                 norm: bool = True):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim, nhead, dropout=dropout, batch_first=batch_first)
        if norm:
            self.norm = nn.LayerNorm(dim)
        else:
            self.norm = nn.Identity()
        self.dropout = nn.Dropout(dropout)
        # add_pe_to_qkv[0] applies to the query (from x); [1] and [2] apply to key/value (from mem)
        self.add_pe_to_qkv = add_pe_to_qkv
        self.residual = residual

    def forward(self,
                x: torch.Tensor,
                mem: torch.Tensor,
                x_pe: torch.Tensor,
                mem_pe: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None,
                *,
                need_weights: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # pre-norm cross-attention; queries come from x, keys/values from mem
        x = self.norm(x)
        if self.add_pe_to_qkv[0]:
            q = x + x_pe
        else:
            q = x

        if any(self.add_pe_to_qkv[1:]):
            mem_with_pe = mem + mem_pe
            k = mem_with_pe if self.add_pe_to_qkv[1] else mem
            v = mem_with_pe if self.add_pe_to_qkv[2] else mem
        else:
            k = v = mem

        r = x
        x, weights = self.cross_attn(q,
                                     k,
                                     v,
                                     attn_mask=attn_mask,
                                     need_weights=need_weights,
                                     average_attn_weights=False)

        if self.residual:
            return r + self.dropout(x), weights
        else:
            return self.dropout(x), weights


class FFN(nn.Module):
    def __init__(self, dim_in: int, dim_ff: int, activation=F.relu):
        super().__init__()
        self.linear1 = nn.Linear(dim_in, dim_ff)
        self.linear2 = nn.Linear(dim_ff, dim_in)
        self.norm = nn.LayerNorm(dim_in)

        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # pre-norm two-layer feed-forward block with a residual connection
        r = x
        x = self.norm(x)
        x = self.linear2(self.activation(self.linear1(x)))
        x = r + x
        return x


class PixelFFN(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.conv = CAResBlock(dim, dim)

    def forward(self, pixel: torch.Tensor, pixel_flat: torch.Tensor) -> torch.Tensor:
        # pixel: batch_size * num_objects * dim * H * W
        # pixel_flat: (batch_size*num_objects) * (H*W) * dim
        bs, num_objects, _, h, w = pixel.shape
        # restore the spatial layout, then move channels back in front for the conv block
        pixel_flat = pixel_flat.view(bs * num_objects, h, w, self.dim)
        pixel_flat = pixel_flat.permute(0, 3, 1, 2).contiguous()

        x = self.conv(pixel_flat)
        x = x.view(bs, num_objects, self.dim, h, w)
        return x


class OutputFFN(nn.Module):
    def __init__(self, dim_in: int, dim_out: int, activation=F.relu):
        super().__init__()
        self.linear1 = nn.Linear(dim_in, dim_out)
        self.linear2 = nn.Linear(dim_out, dim_out)

        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # projection head without norm or residual
        x = self.linear2(self.activation(self.linear1(x)))
        return x


def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
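

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). The dimensions
# and sequence lengths below are illustrative assumptions, chosen only to
# exercise the batch_first=True layout of (batch, sequence, dim) that these
# blocks expect; they do not reflect the configuration used by the tracker.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    dim, nhead = 256, 8
    bs, num_queries, num_mem = 2, 16, 1024

    x = torch.randn(bs, num_queries, dim)       # query tokens
    x_pe = torch.randn(bs, num_queries, dim)    # positional encoding for the queries
    mem = torch.randn(bs, num_mem, dim)         # memory tokens attended to by CrossAttention
    mem_pe = torch.randn(bs, num_mem, dim)      # positional encoding for the memory

    self_attn = SelfAttention(dim, nhead)
    cross_attn = CrossAttention(dim, nhead)
    ffn = FFN(dim, dim * 4, activation='gelu')

    x = self_attn(x, x_pe)
    x, attn_weights = cross_attn(x, mem, x_pe, mem_pe, need_weights=True)
    x = ffn(x)

    print(x.shape)             # (bs, num_queries, dim)
    # per-head weights because average_attn_weights=False:
    print(attn_weights.shape)  # (bs, nhead, num_queries, num_mem)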