# LuojiaHOG / cisen / model / layers.py
# Uploaded by aleo1 ("Upload 41 files", commit bb6012a, 26.3 kB).
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# import open_clip
def conv_layer(in_dim, out_dim, kernel_size=1, padding=0, stride=1):
    """Build a ``Conv2d -> BatchNorm2d -> ReLU`` block.

    The convolution is created without a bias term; the BatchNorm2d that
    follows it (affine by default) provides the learned shift.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        kernel_size: convolution kernel size.
        padding: zero padding added to each spatial side.
        stride: convolution stride.

    Returns:
        nn.Sequential: the three layers, indexed 0 (conv), 1 (norm), 2 (ReLU).
    """
    conv = nn.Conv2d(in_dim, out_dim,
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=padding,
                     bias=False)
    norm = nn.BatchNorm2d(out_dim)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, activation)
# return nn.Sequential(
# nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False),
# nn.LayerNorm(out_dim), nn.ReLU(True))
# def conv_layer_1(in_dim, out_dim, kernel_size=1, padding=0, stride=1):
# return nn.Sequential(
# nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False),
# nn.LayerNorm(out_dim), nn.ReLU(True))
def linear_layer(in_dim, out_dim, bias=False):
    """Build a ``Linear -> BatchNorm1d -> ReLU`` block.

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        bias: whether the Linear layer has a bias (off by default).

    Returns:
        nn.Sequential: the three layers, indexed 0 (linear), 1 (norm), 2 (ReLU).
    """
    layers = [
        nn.Linear(in_dim, out_dim, bias=bias),
        nn.BatchNorm1d(out_dim),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
# return nn.Sequential(nn.Linear(in_dim, out_dim, bias),
# nn.LayerNorm(out_dim), nn.ReLU(True))
class AttentionPool2d(nn.Module):
    """CLIP-style attention pooling of a feature map into one vector per sample.

    The spatial mean of the input is prepended as an extra token and a learned
    positional embedding is added to every token. Only that prepended token is
    used as the attention query, and its attended output (projected by
    ``c_proj``) is returned.

    The positional embedding has ``spacial_dim ** 2 + 1`` rows, so the input's
    ``H * W`` must equal ``spacial_dim ** 2``.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        # One embedding per spatial position, plus one for the prepended mean token.
        n_tokens = spacial_dim ** 2 + 1
        self.positional_embedding = nn.Parameter(torch.randn(n_tokens, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim if output_dim else embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        """Pool ``x`` of shape (N, C, H, W) into a tensor of shape (N, out_dim)."""
        # NCHW -> (HW)NC: one token per spatial position.
        tokens = torch.flatten(x, start_dim=2).permute(2, 0, 1)
        # Prepend the spatial mean as a global token -> (HW+1)NC.
        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)
        pos = self.positional_embedding[:, None, :].to(tokens.dtype)
        tokens = tokens + pos
        # With separate q/k/v weights, the biases are still passed packed.
        packed_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        pooled, _ = F.multi_head_attention_forward(
            query=tokens[:1],
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=packed_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )
        # (1, N, out_dim) -> (N, out_dim)
        return pooled.squeeze(0)
# class AttentionPool2d(nn.Module):
# def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
# super().__init__()
# self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
# self.k_proj = nn.Linear(embed_dim, embed_dim)
# self.q_proj = nn.Linear(embed_dim, embed_dim)
# self.v_proj = nn.Linear(embed_dim, embed_dim)
# self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
# self.num_heads = num_heads
#
# def forward(self, x):
# x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
# x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
# x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
# x, _ = F.multi_head_attention_forward(
# query=x, key=x, value=x,
# embed_dim_to_check=x.shape[-1],
# num_heads=self.num_heads,
# q_proj_weight=self.q_proj.weight,
# k_proj_weight=self.k_proj.weight,
# v_proj_weight=self.v_proj.weight,
# in_proj_weight=None,
# in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
# bias_k=None,
# bias_v=None,
# add_zero_attn=False,
# dropout_p=0,
# out_proj_weight=self.c_proj.weight,
# out_proj_bias=self.c_proj.bias,
# use_separate_proj_weight=True,
# training=self.training,
# need_weights=False
# )
#
# return x[0]
class CoordConv(nn.Module):
    """Convolution over the input concatenated with normalised x/y coordinates.

    Two extra channels are appended before ``conv_layer``: an x map whose
    value varies linearly from -1 to 1 across the width, followed by a y map
    that varies from -1 to 1 across the height.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 padding=1,
                 stride=1):
        super().__init__()
        # +2 input channels for the appended x and y coordinate maps.
        self.conv1 = conv_layer(in_channels + 2, out_channels, kernel_size,
                                padding, stride)

    def add_coord(self, input):
        """Return ``input`` (B, C, H, W) with x and y maps appended -> (B, C + 2, H, W)."""
        batch, _, height, width = input.size()
        xs = torch.linspace(-1, 1, width, device=input.device)
        ys = torch.linspace(-1, 1, height, device=input.device)
        # Broadcast the 1-D ranges to full (B, 1, H, W) maps without copying.
        x_map = xs.view(1, 1, 1, width).expand(batch, 1, height, width)
        y_map = ys.view(1, 1, height, 1).expand(batch, 1, height, width)
        return torch.cat([input, x_map, y_map], dim=1)

    def forward(self, x):
        return self.conv1(self.add_coord(x))
class TransformerDecoder(nn.Module):
    """Stack of ``TransformerDecoderLayer`` blocks refining visual tokens with text.

    The visual feature map (B, C, H, W) is flattened into H*W tokens. At every
    layer these tokens self-attend and then cross-attend to the text tokens.
    Fixed sin/cos positional encodings (``pos2d`` for vision, ``pos1d`` for
    text) are passed to each layer. A final LayerNorm is applied to the output.
    """

    def __init__(self,
                 num_layers,
                 d_model,
                 nhead,
                 dim_ffn,
                 dropout,
                 return_intermediate=False):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerDecoderLayer(d_model=d_model,
                                    nhead=nhead,
                                    dim_feedforward=dim_ffn,
                                    dropout=dropout) for _ in range(num_layers)
        ])
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.return_intermediate = return_intermediate

    @staticmethod
    def pos1d(d_model, length):
        """1-D sin/cos positional encoding.

        :param d_model: encoding dimension (must be even)
        :param length: number of positions
        :return: tensor of shape (length, 1, d_model); even channels hold sin,
            odd channels cos
        :raises ValueError: if ``d_model`` is odd
        """
        if d_model % 2 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dim (got dim={:d})".format(d_model))
        pe = torch.zeros(length, d_model)
        position = torch.arange(0, length).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
                              -(math.log(10000.0) / d_model)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        return pe.unsqueeze(1)  # length, 1, d_model

    @staticmethod
    def pos2d(d_model, height, width):
        """2-D sin/cos positional encoding.

        The first half of the channels encodes the column (width) index and the
        second half the row (height) index. Within each half, even channels
        hold sin and odd channels hold cos.

        :param d_model: encoding dimension (must be divisible by 4)
        :param height: number of rows
        :param width: number of columns
        :return: tensor of shape (height * width, 1, d_model), positions in
            row-major (h, w) order
        :raises ValueError: if ``d_model`` is not divisible by 4
        """
        if d_model % 4 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "dimension not divisible by 4 "
                             "(got dim={:d})".format(d_model))
        pe = torch.zeros(d_model, height, width)
        # Each spatial axis uses half of d_model.
        d_model = int(d_model / 2)
        div_term = torch.exp(
            torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
        pos_w = torch.arange(0., width).unsqueeze(1)
        pos_h = torch.arange(0., height).unsqueeze(1)
        pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(
            0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(
            0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(
            0, 1).unsqueeze(2).repeat(1, 1, width)
        pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(
            0, 1).unsqueeze(2).repeat(1, 1, width)
        return pe.reshape(-1, 1, height * width).permute(2, 1, 0)  # hw, 1, d_model

    def forward(self, vis, txt, pad_mask):
        '''
        vis: b, C, h, w   (C must be divisible by 4 for pos2d)
        txt: b, L, D
        pad_mask: b, L    (passed to each layer as key_padding_mask)

        Returns (b, C, h*w), or a list with one such tensor per layer when
        return_intermediate is set.
        '''
        B, C, H, W = vis.size()
        _, L, D = txt.size()
        # The encodings are built on the CPU; move them once here instead of
        # once per layer.
        vis_pos = self.pos2d(C, H, W).to(vis.device)
        txt_pos = self.pos1d(D, L).to(txt.device)
        # Sequence-first layout: (HW, b, C) and (L, b, D).
        vis = vis.reshape(B, C, -1).permute(2, 0, 1)
        txt = txt.permute(1, 0, 2)
        output = vis
        intermediate = []
        for layer in self.layers:
            output = layer(output, txt, vis_pos, txt_pos, pad_mask)
            if self.return_intermediate:
                # HW, b, C -> b, C, HW
                intermediate.append(self.norm(output).permute(1, 2, 0))
        # self.norm is always a LayerNorm, so normalise unconditionally.
        # HW, b, C -> b, C, HW
        output = self.norm(output).permute(1, 2, 0)
        if not self.return_intermediate:
            return output
        # The last list entry is recomputed identically as `output`; replace it
        # so the list ends with the returned tensor. With zero layers the list
        # is empty, so `output` becomes its only entry.
        if intermediate:
            intermediate[-1] = output
        else:
            intermediate.append(output)
        # [output1, output2, ..., output_n]
        return intermediate
class TransformerDecoderLayer(nn.Module):
    """Pre-norm decoder layer used by ``TransformerDecoder``.

    The layer applies, in order:
      1. self-attention over the visual tokens,
      2. cross-attention from the visual tokens (queries) to the text tokens
         (keys/values),
      3. a feed-forward block.
    Each sub-layer reads a LayerNorm-ed copy of ``vis`` and is added back
    residually. The two attention outputs are also LayerNorm-ed before the
    residual add.
    """

    def __init__(self,
                 d_model=512,
                 nhead=8,
                 dim_feedforward=2048,
                 dropout=0.1):
        # nn.MultiheadAttention requires d_model % nhead == 0. The previous
        # default (nhead=9) failed for the default d_model=512, so a
        # default-constructed layer could never be built.
        super().__init__()
        # Normalization Layer
        self.self_attn_norm = nn.LayerNorm(d_model)
        self.cross_attn_norm = nn.LayerNorm(d_model)
        # Attention Layer
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model,
                                                    nhead,
                                                    dropout=dropout,
                                                    kdim=d_model,
                                                    vdim=d_model)
        # FFN
        self.ffn = nn.Sequential(nn.Linear(d_model, dim_feedforward),
                                 nn.ReLU(True), nn.Dropout(dropout),
                                 nn.LayerNorm(dim_feedforward),
                                 nn.Linear(dim_feedforward, d_model))
        # LayerNorm & Dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def with_pos_embed(self, tensor, pos):
        """Add ``pos`` (moved to ``tensor``'s device) to ``tensor``; no-op when ``pos`` is None."""
        return tensor if pos is None else tensor + pos.to(tensor.device)

    def forward(self, vis, txt, vis_pos, txt_pos, pad_mask):
        '''
        vis: HW, b, d_model       (e.g. 26*26, b, 512)
        txt: L, b, d_model
        vis_pos: HW, 1, d_model   (or None)
        txt_pos: L, 1, d_model    (or None)
        pad_mask: b, L            key_padding_mask for the text keys
        '''
        # Self-Attention: positions are added to queries/keys, not values.
        vis2 = self.norm1(vis)
        q = k = self.with_pos_embed(vis2, vis_pos)
        vis2 = self.self_attn(q, k, value=vis2)[0]
        vis2 = self.self_attn_norm(vis2)
        vis = vis + self.dropout1(vis2)
        # Cross-Attention: visual queries attend to text keys/values.
        vis2 = self.norm2(vis)
        vis2 = self.multihead_attn(query=self.with_pos_embed(vis2, vis_pos),
                                   key=self.with_pos_embed(txt, txt_pos),
                                   value=txt,
                                   key_padding_mask=pad_mask)[0]
        vis2 = self.cross_attn_norm(vis2)
        vis = vis + self.dropout2(vis2)
        # FFN
        vis2 = self.norm3(vis)
        vis2 = self.ffn(vis2)
        vis = vis + self.dropout3(vis2)
        return vis
class Text_Projector(nn.Module):
    """Residual projection of a text embedding: ``relu(text + proj(text))``.

    ``proj`` maps ``in_channels[2]`` to ``out_channels[2]`` features. The
    residual sum therefore requires these two sizes to be equal (both 1024 by
    default).
    """

    def __init__(self, args, in_channels=(512, 1024, 1024),
                 out_channels=(256, 512, 1024)):
        super(Text_Projector, self).__init__()
        # `args` is accepted for call-site compatibility but not used here.
        # linear_layer's signature is (in_dim, out_dim, bias). Passing `args`
        # first made nn.Linear receive it as in_features and fail.
        self.proj = linear_layer(in_channels[2], out_channels[2])
        self.ReLU = nn.ReLU(True)

    def forward(self, text):
        text = self.ReLU(text + self.proj(text))
        return text
class Image_Projector(nn.Module):
    """Residual projection of an image embedding: ``relu(image + proj(image))``.

    ``proj`` maps ``in_channels[0]`` to ``out_channels[2]`` features.
    """

    def __init__(self, args, in_channels=(512, 1024, 1024),
                 out_channels=(256, 512, 1024)):
        super(Image_Projector, self).__init__()
        # `args` is accepted for call-site compatibility but not used here.
        # linear_layer's signature is (in_dim, out_dim, bias). Passing `args`
        # first made nn.Linear receive it as in_features and fail.
        self.proj = linear_layer(in_channels[0], out_channels[2])
        self.ReLU = nn.ReLU(True)

    def forward(self, image):
        # NOTE(review): the residual sum needs in_channels[0] == out_channels[2].
        # With the defaults (512 vs 1024) this add cannot broadcast, so callers
        # presumably pass matching sizes — confirm.
        image = self.ReLU(image + self.proj(image))
        return image
class Adapter(nn.Module):
    """Bottleneck MLP ``c_in -> c_in // reduction -> c_in``.

    Both Linear layers are bias-free, and each is followed by an in-place ReLU,
    so the output is non-negative.
    """

    def __init__(self, c_in, reduction=4):
        super(Adapter, self).__init__()
        hidden = c_in // reduction
        # Kept as one nn.Sequential named `fc` so parameter names stay fc.0 / fc.2.
        self.fc = nn.Sequential(
            nn.Linear(c_in, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, c_in, bias=False),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.fc(x)
class GAP(nn.Module):
    """Adaptive average pooling to ``kernel`` output size.

    After pooling, the last two dimensions are squeezed, but only where they
    have size 1. With ``kernel=1`` this yields a (B, C) tensor.
    """

    def __init__(self, kernel):
        super(GAP, self).__init__()
        self.k = kernel

    def forward(self, x):
        pooled = F.adaptive_avg_pool2d(x, self.k)
        # Drop trailing singleton spatial dims, one at a time.
        for _ in range(2):
            pooled = pooled.squeeze(-1)
        return pooled
class AdaptiveSpatialFeatureFusion(nn.Module):
    """Fuse two pooled feature vectors with weights proportional to their L2 norms.

    ``feature_map2`` is first projected from ``in_channels[0]`` to
    ``out_channels[2]`` features. For each sample the weights are
    ``w1 = |f1| / (|f1| + |f2|)`` and ``w2 = 1 - w1``, and the result is
    ``w1 * f1 + w2 * f2``.
    """

    def __init__(self, args, in_channels=(512, 1024, 1024),
                 out_channels=(256, 512, 1024)):
        super(AdaptiveSpatialFeatureFusion, self).__init__()
        # NOTE: not used in forward. It is kept because removing it would
        # change the module's state_dict keys.
        self.weight = nn.LayerNorm(out_channels[2])
        # `args` is accepted for call-site compatibility but not used here.
        # linear_layer's signature is (in_dim, out_dim, bias). Passing `args`
        # first made nn.Linear receive it as in_features and fail.
        self.proj = linear_layer(in_channels[0], out_channels[2])

    def forward(self, feature_map1, feature_map2):
        # feature_map1 : b, 1024, 1, 1
        # feature_map2 : b, 512, 1, 1
        feature_map2 = self.proj(feature_map2.squeeze(-1).squeeze(-1))
        feature_map1 = feature_map1.squeeze(-1).squeeze(-1)
        norm1 = torch.norm(feature_map1, dim=1).unsqueeze(-1)
        norm2 = torch.norm(feature_map2, dim=1).unsqueeze(-1)
        # Clamp the denominator so two all-zero vectors give weights (0, 1)
        # instead of NaN from 0/0. Positive sums above the smallest normal
        # float are unaffected.
        total = (norm1 + norm2).clamp_min(torch.finfo(norm1.dtype).tiny)
        weights1 = norm1 / total
        weights2 = 1 - weights1
        fused_feature_map = weights1 * feature_map1 + weights2 * feature_map2
        # b, 1024
        return fused_feature_map
class ModifiedAttentionPool2d(nn.Module):
    """CLIP-style attention over a feature map that keeps the spatial layout.

    Unlike ``AttentionPool2d``, no mean token is prepended. Every spatial
    position attends to every other position, with the learned positional
    embedding resized to the input's H x W. The attended map (channels
    projected by ``c_proj``) is added to a 1x1-conv + BatchNorm projection of
    the input (``connect``) and passed through ReLU.
    """

    def __init__(self,
                 spacial_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 output_dim: int = None):
        super().__init__()
        # Same fallback as c_proj. The residual branch must produce the same
        # channel count, and nn.Conv2d cannot take None.
        out_dim = output_dim or embed_dim
        self.spacial_dim = spacial_dim
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, out_dim)
        self.num_heads = num_heads
        # residual
        self.connect = nn.Sequential(
            nn.Conv2d(embed_dim, out_dim, 1, stride=1, bias=False),
            nn.BatchNorm2d(out_dim))

    def resize_pos_embed(self, pos_embed, input_shpae):
        """Resize the spatial part of a positional embedding to a new grid.

        The first (cls) row of ``pos_embed`` is ignored. The last
        ``spacial_dim ** 2`` rows are reshaped to a ``spacial_dim x spacial_dim``
        grid and resized with bicubic interpolation (``align_corners=False``).

        Args:
            pos_embed (torch.Tensor): embedding of shape [1, L, C]. The batch
                dimension must be 1, because the grid is reshaped with a
                leading dimension of 1.
            input_shpae (tuple): target (height, width).

        Return:
            torch.Tensor: the resized embedding of shape [1, C, height * width].
        """
        assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
        pos_h = pos_w = self.spacial_dim
        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
        pos_embed_weight = pos_embed_weight.reshape(
            1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
        pos_embed_weight = F.interpolate(pos_embed_weight,
                                         size=input_shpae,
                                         align_corners=False,
                                         mode='bicubic')
        # (1, C, H, W) -> (1, C, H*W)
        return torch.flatten(pos_embed_weight, 2)

    def forward(self, x):
        """Attend over the spatial positions of ``x``.

        Args:
            x (torch.Tensor): feature map of shape (B, embed_dim, H, W).

        Returns:
            tuple: ``(x, xt)``.
                ``x``: (B, out_dim, H, W) after the residual add and ReLU.
                ``xt``: (B, out_dim), the attention output at the first
                flattened position (row 0, column 0).
        """
        B, C, H, W = x.size()
        res = self.connect(x)
        x = x.reshape(B, C, -1)  # NC(HW)
        pos_embed = self.positional_embedding.unsqueeze(0)
        pos_embed = self.resize_pos_embed(pos_embed, (H, W))  # 1C(HW)
        x = x + pos_embed.to(x.dtype)  # NC(HW)
        x = x.permute(2, 0, 1)  # (HW)NC
        x, _ = F.multi_head_attention_forward(
            query=x,
            key=x,
            value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat(
                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False)
        # NOTE(review): no mean/cls token is prepended here, so x[0] is the
        # top-left position, not a global summary — confirm this is intended.
        xt = x[0]
        x = x.permute(1, 2, 0).reshape(B, -1, H, W)
        x = x + res
        x = F.relu(x, True)
        return x, xt
# modified
class FPN(nn.Module):
    """Text-conditioned feature pyramid fusion over three visual levels.

    The coarsest level ``v5`` is attention-pooled and gated channel-wise by the
    projected text vector. It is then merged top-down with ``v4`` and ``v3``.
    The three fused levels are aggregated at ``v4``'s resolution and finished
    with a CoordConv block.

    Required ``args`` attributes: ``input_size``, ``heads``, ``output_dim``,
    ``emb_dim``.
    """

    def __init__(self, args,
                 in_channels=(512, 1024, 1024),
                 out_channels=(256, 512, 1024, 1024)):
        super(FPN, self).__init__()
        input_resolution = args.input_size
        heads = args.heads
        output_dim = args.output_dim
        embed_dim = args.emb_dim
        # image projection
        self.attn = ModifiedAttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
        # text projection. linear_layer's signature is (in_dim, out_dim, bias);
        # `args` must not be passed to it.
        self.txt_proj = linear_layer(in_channels[2], out_channels[2])
        # fusion 1: v5 & seq -> f_5
        self.f1_v_proj = conv_layer(in_channels[2], out_channels[2], 1, 0)
        self.norm_layer = nn.Sequential(nn.BatchNorm2d(out_channels[2]),
                                        nn.ReLU(True))
        # fusion 2: v4 & f_5 -> f_4
        self.f2_v_proj = conv_layer(in_channels[1], out_channels[1], 3, 1)
        self.f2_cat = conv_layer(out_channels[2] + out_channels[1],
                                 out_channels[1], 1, 0)
        # fusion 3: v3 & f_4 -> f_3
        self.f3_v_proj = conv_layer(in_channels[0], out_channels[0], 3, 1)
        self.f3_cat = conv_layer(out_channels[0] + out_channels[1],
                                 out_channels[1], 1, 0)
        # fusion 4: f_3 & f_4 & f_5 -> fq
        self.f4_proj5 = conv_layer(out_channels[2], out_channels[1], 3, 1)
        self.f4_proj4 = conv_layer(out_channels[1], out_channels[1], 3, 1)
        self.f4_proj3 = conv_layer(out_channels[1], out_channels[1], 3, 1)
        # aggregation
        self.aggr = conv_layer(3 * out_channels[1], out_channels[1], 1, 0)
        self.coordconv = nn.Sequential(
            CoordConv(out_channels[1], out_channels[1], 3, 1),
            conv_layer(out_channels[1], out_channels[3], 3, 1))

    def forward(self, imgs, text):
        """Fuse the visual pyramid ``imgs = (v3, v4, v5)`` with the text vector ``text``.

        ``v3`` is average-pooled by 2, so its spatial size is assumed to be twice
        ``v4``'s — TODO confirm with callers. Returns a map of ``out_channels[3]``
        channels at ``v4``'s spatial size.
        """
        # v3, v4, v5: e.g. 256, 52, 52 / 512, 26, 26 / 1024, 13, 13
        v3, v4, v5 = imgs
        # fusion 1: attention-pool v5, then gate it channel-wise by the text.
        v5, _ = self.attn(v5)
        state = self.txt_proj(text).unsqueeze(-1).unsqueeze(-1)  # b, C, 1, 1
        f5 = self.f1_v_proj(v5)
        f5 = self.norm_layer(f5 * state)
        # fusion 2: upsample f5 to exactly f4's resolution (no fixed x2 factor)
        # and merge.
        f4 = self.f2_v_proj(v4)
        f5_ = F.interpolate(f5, size=tuple(f4.shape[-2:]), mode='bilinear')
        f4 = self.f2_cat(torch.cat([f4, f5_], dim=1))
        # fusion 3: downsample v3 by 2 and merge with f4.
        f3 = self.f3_v_proj(v3)
        f3 = F.avg_pool2d(f3, 2, 2)
        f3 = self.f3_cat(torch.cat([f3, f4], dim=1))
        # fusion 4: project each level to a common width.
        fq5 = self.f4_proj5(f5)
        fq4 = self.f4_proj4(f4)
        fq3 = self.f4_proj3(f3)
        # query: bring fq5 to fq4's resolution, then aggregate all three.
        fq5 = F.interpolate(fq5, size=tuple(fq4.shape[-2:]), mode='bilinear')
        fq = torch.cat([fq3, fq4, fq5], dim=1)
        fq = self.aggr(fq)
        fq = self.coordconv(fq)
        return fq
class ViTFPN(nn.Module):
    """Text-conditioned fusion of three ViT feature maps.

    The projected text vector gates the ``v5`` map channel-wise. The gated map
    is merged top-down with ``v4`` and then with a 2x average-pooled ``v3``.
    The three levels are aggregated at ``v4``'s resolution.

    With ``vis`` truthy, ``forward`` returns the aggregated map. Otherwise the
    map is passed through a CoordConv block and attention-pooled to a
    (b, out_channels[-1]) vector. The attention pool's positional embedding
    requires the map's H * W to equal ``(image_resolution // 32) ** 2``.
    """

    def __init__(self, image_resolution,
                 in_channels=[512, 768, 768],
                 out_channels=[768, 768, 768, 512]):
        super(ViTFPN, self).__init__()
        # Module creation order is kept so seeded initialisation and
        # state_dict layout are unchanged.
        # text projection
        self.txt_proj = linear_layer(in_channels[0], out_channels[1])
        # fusion 1: text-gated v5
        self.f1_v_proj = conv_layer(in_channels[1], out_channels[1], 1, 0)
        self.norm_layer = nn.Sequential(nn.BatchNorm2d(out_channels[1]),
                                        nn.ReLU(True))
        # fusion 2: v4 merged with upsampled f5
        self.f2_v_proj = conv_layer(in_channels[1], out_channels[1], 3, 1)
        self.f2_cat = conv_layer(out_channels[0] + out_channels[0],
                                 out_channels[0], 1, 0)
        # fusion 3: pooled v3 merged with f4
        self.f3_v_proj = conv_layer(in_channels[1], out_channels[1], 3, 1)
        self.f3_cat = conv_layer(out_channels[0] + out_channels[1],
                                 out_channels[1], 1, 0)
        # fusion 4: per-level projections before aggregation
        self.f4_proj5 = conv_layer(out_channels[1], out_channels[0], 3, 1)
        self.f4_proj4 = conv_layer(out_channels[0], out_channels[0], 3, 1)
        self.f4_proj3 = conv_layer(out_channels[1], out_channels[1], 3, 1)
        # aggregation and output heads
        self.aggr = conv_layer(3 * out_channels[0], out_channels[0], 1, 0)
        self.coordconv = nn.Sequential(
            CoordConv(out_channels[0], out_channels[0], 3, 1),
            conv_layer(out_channels[0], out_channels[-1], 3, 1))
        self.attnpool = AttentionPool2d(image_resolution // 32, out_channels[-1],
                                        8, out_channels[-1])

    def forward(self, imgs, state, vis):
        v3, v4, v5 = imgs
        # Text gate broadcast over the spatial dims: (..., C) -> (..., C, 1, 1).
        gate = self.txt_proj(state)[..., None, None]
        f5 = self.norm_layer(self.f1_v_proj(v5) * gate)

        f4 = self.f2_v_proj(v4)
        _, _, h, w = f4.size()
        f4 = self.f2_cat(
            torch.cat([f4, F.interpolate(f5, (h, w), mode='bilinear')], dim=1))

        f3 = F.avg_pool2d(self.f3_v_proj(v3), 2, 2)
        f3 = self.f3_cat(torch.cat([f3, f4], dim=1))

        level5 = self.f4_proj5(f5)
        level4 = self.f4_proj4(f4)
        level3 = self.f4_proj3(f3)
        level5 = F.interpolate(level5, (h, w), mode='bilinear')
        fused = self.aggr(torch.cat([level3, level4, level5], dim=1))

        if vis:
            return fused
        return self.attnpool(self.coordconv(fused))