# Modified from PyTorch nn.Transformer
from typing import Callable, List, Optional, Tuple
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from tracker.model.channel_attn import CAResBlock
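
# Pre-norm residual self-attention block. Positional encodings are added to
# the query/key/value inputs according to the add_pe_to_qkv flags
# (default: query and key only).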
class SelfAttention(nn.Module):
    def __init__(self,
                 dim: int,
                 nhead: int,
                 dropout: float = 0.0,
                 batch_first: bool = True,
                 add_pe_to_qkv: List[bool] = [True, True, False]):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, nhead, dropout=dropout, batch_first=batch_first)
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        self.add_pe_to_qkv = add_pe_to_qkv

    def forward(self,
                x: torch.Tensor,
                pe: torch.Tensor,
                attn_mask: Optional[Tensor] = None,
                key_padding_mask: Optional[Tensor] = None) -> torch.Tensor:
        x = self.norm(x)
        if any(self.add_pe_to_qkv):
            x_with_pe = x + pe
            q = x_with_pe if self.add_pe_to_qkv[0] else x
            k = x_with_pe if self.add_pe_to_qkv[1] else x
            v = x_with_pe if self.add_pe_to_qkv[2] else x
        else:
            q = k = v = x
        r = x
        x = self.self_attn(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0]
        return r + self.dropout(x)

# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
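# Pre-norm cross-attention block: queries come from x, keys/values from mem.
# Positional encodings are added per add_pe_to_qkv; the residual connection
# and the LayerNorm can be disabled via the constructor flags.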
class CrossAttention(nn.Module):
    def __init__(self,
                 dim: int,
                 nhead: int,
                 dropout: float = 0.0,
                 batch_first: bool = True,
                 add_pe_to_qkv: List[bool] = [True, True, False],
                 residual: bool = True,
                 norm: bool = True):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim,
                                                nhead,
                                                dropout=dropout,
                                                batch_first=batch_first)
        if norm:
            self.norm = nn.LayerNorm(dim)
        else:
            self.norm = nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.add_pe_to_qkv = add_pe_to_qkv
        self.residual = residual

    def forward(self,
                x: torch.Tensor,
                mem: torch.Tensor,
                x_pe: torch.Tensor,
                mem_pe: torch.Tensor,
                attn_mask: Optional[Tensor] = None,
                *,
                need_weights: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.norm(x)
        if self.add_pe_to_qkv[0]:
            q = x + x_pe
        else:
            q = x
        if any(self.add_pe_to_qkv[1:]):
            mem_with_pe = mem + mem_pe
            k = mem_with_pe if self.add_pe_to_qkv[1] else mem
            v = mem_with_pe if self.add_pe_to_qkv[2] else mem
        else:
            k = v = mem
        r = x
        x, weights = self.cross_attn(q,
                                     k,
                                     v,
                                     attn_mask=attn_mask,
                                     need_weights=need_weights,
                                     average_attn_weights=False)
        if self.residual:
            return r + self.dropout(x), weights
        else:
            return self.dropout(x), weights

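# Standard pre-norm residual feed-forward block: LayerNorm, two linear layers
# with an activation in between, and a skip connection.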
class FFN(nn.Module):
    def __init__(self, dim_in: int, dim_ff: int, activation=F.relu):
        super().__init__()
        self.linear1 = nn.Linear(dim_in, dim_ff)
        self.linear2 = nn.Linear(dim_ff, dim_in)
        self.norm = nn.LayerNorm(dim_in)
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r = x
        x = self.norm(x)
        x = self.linear2(self.activation(self.linear1(x)))
        x = r + x
        return x

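# Feed-forward block that operates on the spatial (pixel) feature map: the
# flattened tokens are reshaped back to (bs*num_objects, dim, H, W) and passed
# through a channel-attention residual conv block (CAResBlock).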
class PixelFFN(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.conv = CAResBlock(dim, dim)

    def forward(self, pixel: torch.Tensor, pixel_flat: torch.Tensor) -> torch.Tensor:
        # pixel: batch_size * num_objects * dim * H * W
        # pixel_flat: (batch_size*num_objects) * (H*W) * dim
        bs, num_objects, _, h, w = pixel.shape
        pixel_flat = pixel_flat.view(bs * num_objects, h, w, self.dim)
        pixel_flat = pixel_flat.permute(0, 3, 1, 2).contiguous()
        x = self.conv(pixel_flat)
        x = x.view(bs, num_objects, self.dim, h, w)
        return x

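# Output projection: two linear layers mapping dim_in -> dim_out, with no
# normalization and no residual connection.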
class OutputFFN(nn.Module):
    def __init__(self, dim_in: int, dim_out: int, activation=F.relu):
        super().__init__()
        self.linear1 = nn.Linear(dim_in, dim_out)
        self.linear2 = nn.Linear(dim_out, dim_out)
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear2(self.activation(self.linear1(x)))
        return x

def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
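

# ---------------------------------------------------------------------------
# Minimal shape-check sketch (not part of the original module). It assumes
# dim is divisible by nhead and uses random inputs purely to illustrate the
# expected tensor shapes of the attention blocks above.
if __name__ == '__main__':
    dim, nhead = 256, 8
    bs, n_query, n_mem = 2, 100, 50

    x = torch.randn(bs, n_query, dim)
    pe = torch.randn(bs, n_query, dim)
    mem = torch.randn(bs, n_mem, dim)
    mem_pe = torch.randn(bs, n_mem, dim)

    self_attn = SelfAttention(dim, nhead)
    cross_attn = CrossAttention(dim, nhead)

    y = self_attn(x, pe)
    z, w = cross_attn(x, mem, pe, mem_pe, need_weights=True)
    # Expected: y (bs, n_query, dim); z (bs, n_query, dim); w (bs, nhead, n_query, n_mem)
    print(y.shape, z.shape, w.shape)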