# LuojiaHOG / cisen / model / segmenter.py
# Uploaded by aleo1 ("Upload 41 files"), commit bb6012a (verified), 73.1 kB.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from .clip import build_model, build_promptlearner, build_modified_model, PromptLearner, build_lclip_model
from torch.cuda.amp import autocast as autocast
from timm.models.layers import trunc_normal_ as __call_trunc_normal_
from timm.models.layers import variance_scaling_
from einops import rearrange, repeat
from loguru import logger
from transformers import AlignProcessor, AlignModel
from sklearn.metrics import classification_report
from huggingface_hub import PyTorchModelHubMixin
from .layers import FPN, TransformerDecoder, ViTFPN, AdaptiveSpatialFeatureFusion, Text_Projector, Image_Projector, Adapter, GAP
from cisen.model.clip import CLIP
def lecun_normal_(tensor):
    """Initialise ``tensor`` with LeCun-normal scaling.

    Thin wrapper over timm's ``variance_scaling_`` using fan-in scaling and a
    truncated-normal distribution. Returns ``None``.
    """
    variance_scaling_(
        tensor,
        distribution="truncated_normal",
        mode="fan_in",
    )
def trunc_normal_(tensor, mean=0.0, std=1.0):
    """Initialise ``tensor`` from N(mean, std**2) truncated to one std around the mean.

    timm's ``trunc_normal_`` takes *absolute* cutoffs ``a``/``b``, so the window
    is centred on ``mean``. With the default ``mean=0.0`` the window is
    ``[-std, std]``, as before. Returns ``None``.
    """
    # Centring the cutoffs on `mean` keeps a non-zero mean from being truncated
    # into a window that lies away from the distribution's centre.
    __call_trunc_normal_(tensor, mean=mean, std=std, a=mean - std, b=mean + std)
class CISEN_vit(nn.Module, PyTorchModelHubMixin):
    """CLIP-ViT image-text model with a residual image adapter and a multi-scale ViT FPN head.

    Stage ``'1st'`` applies a symmetric contrastive loss between the adapted
    image embedding and the text embedding. Stage ``'2nd'`` builds a feature
    pyramid from the ViT token maps, fuses it with the adapted image embedding
    through ``ViTFPN``, and applies the contrastive loss to the fused feature.

    NOTE(review): this file defines ``CISEN_vit`` a second time further down.
    That later ``class`` statement rebinds the module-level name, so this
    definition is shadowed once the module has been imported.
    """

    _STAGES = ('1st', '2nd')

    def __init__(self, cfg):
        super().__init__()
        # Vision & text encoders, built from a TorchScript CLIP checkpoint.
        clip_model = torch.jit.load(cfg.clip_pretrain, map_location="cpu").eval()
        (backbone, image_resolution, vision_heads,
         embed_dim, vision_width, patch_size) = build_model(clip_model.state_dict(), cfg.word_len)
        self.backbone = backbone.float()
        # Side length, in patches, of the square ViT token grid.
        self.patch_emb = image_resolution // patch_size
        # Backbone-derived sizes are written back onto cfg for later consumers.
        cfg.image_resolution = image_resolution
        cfg.input_size = image_resolution
        cfg.heads = vision_heads // 32
        cfg.emb_dim = vision_width
        cfg.output_dim = embed_dim
        # Multi-modal FPN over the multi-scale ViT features.
        self.FPN = ViTFPN(image_resolution, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Adapter on the global image embedding, blended residually using `ratio`.
        self.ADP = Adapter(cfg.output_dim, 4)
        # parameters
        self.ratio = cfg.ratio
        # Learnable temperature, stored as a log and initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.share_temperature = True
        self.ce = nn.CrossEntropyLoss()
        # One adaptor per pyramid level, applied to the (b, c, h, w) token grid:
        # x4 upsampling, x2 upsampling, identity, x2 downsampling.
        self.ms_adaptor = nn.ModuleList(
            [
                nn.Sequential(
                    nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
                    nn.GroupNorm(32, cfg.emb_dim),
                    nn.GELU(),
                    nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
                ),
                nn.Sequential(
                    nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
                ),
                nn.Sequential(
                    nn.Identity(),
                ),
                nn.Sequential(
                    nn.MaxPool2d(2),
                ),
            ]
        )
        self.ms_adaptor.apply(self.init_adaptor)

    def init_adaptor(self, m):
        """Initialise one adaptor submodule; intended for ``nn.Module.apply``.

        (Transposed) convolutions get LeCun-normal weights and zero bias.
        GroupNorm gets unit weight and zero bias.
        """
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            lecun_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.GroupNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)

    def IT_loss(self, image_features, text_features):
        """Symmetric contrastive (InfoNCE) loss between paired embeddings.

        Row ``i`` of ``image_features`` is the positive for row ``i`` of
        ``text_features``; all other rows in the batch act as negatives.

        Args:
            image_features: (b, d) image embeddings.
            text_features: (b, d) text embeddings.

        Returns:
            Scalar tensor: mean of the image->text and text->image cross-entropies.
        """
        # F.normalize clamps the norm, so an all-zero row cannot produce NaN.
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)
        logits_per_image = self.logit_scale.exp() * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        labels = torch.arange(image_features.shape[0], device=logits_per_image.device)
        return (self.ce(logits_per_image, labels) + self.ce(logits_per_text, labels)) * 0.5

    def _encode(self, img, txt):
        """Encode both modalities and blend the adapted image embedding.

        Returns ``(vis, x, text)``:
        - ``vis``: the backbone's per-layer visual token maps;
        - ``x``: ``ratio * ADP(image) + (1 - ratio) * image``;
        - ``text``: the global text embedding.
        """
        vis, image = self.backbone.encode_image(img)
        _, text = self.backbone.encode_text(txt)
        x = self.ratio * self.ADP(image) + (1 - self.ratio) * image
        return vis, x, text

    def _pyramid(self, vis):
        """Build the pyramid levels the FPN consumes from the ViT token maps.

        Each ``vis[level]`` is reshaped from (b, h*w, c) tokens to a
        (b, c, h, w) grid and passed through ``ms_adaptor[level]``. Level 0 is
        not computed because the FPN is only ever given levels 1..3.
        """
        feats = []
        for level in range(1, len(self.ms_adaptor)):
            grid = rearrange(
                vis[level],
                "b (h w) c -> b c h w",
                h=self.patch_emb,
                w=self.patch_emb,
            ).contiguous()
            feats.append(self.ms_adaptor[level](grid))
        return feats

    def forward(self, img, txt, stage):
        """Run one training stage.

        Args:
            img: (b, 3, h, w) image batch.
            txt: (b, words) token ids for ``backbone.encode_text``.
            stage: ``'1st'`` (adapter only) or ``'2nd'`` (adapter + FPN head).

        Returns:
            ``(loss, fv, fi, ft)``: the contrastive loss; the FPN-fused feature
            (``None`` in stage ``'1st'``); the blended image embedding; and the
            text embedding.

        Raises:
            ValueError: if ``stage`` is not one of ``'1st'`` / ``'2nd'``.
        """
        if stage not in self._STAGES:
            raise ValueError(f"unknown stage {stage!r}; expected one of {self._STAGES}")
        vis, fi, ft = self._encode(img, txt)
        if stage == '1st':
            return self.IT_loss(fi, ft), None, fi, ft
        fv = self.FPN(self._pyramid(vis), fi, False)
        return self.IT_loss(fv, ft), fv, fi, ft

    def visualize(self, img, txt):
        """Return ``(vis, fv_t, ft_t)`` for inspection.

        ``vis`` are the raw backbone token maps. ``fv_t`` and ``ft_t`` are the
        FPN outputs when the pyramid is fused with the blended image embedding
        and with the text embedding, respectively. Both FPN calls pass ``True``
        as the third argument.
        """
        vis, x, text = self._encode(img, txt)
        feats = self._pyramid(vis)
        fv_t = self.FPN(feats, x, True)
        ft_t = self.FPN(feats, text, True)
        return vis, fv_t, ft_t
class CISEN_rsvit(nn.Module, PyTorchModelHubMixin):
    """CLIP-ViT image-text model (weights from a ``torch.load`` state dict) with an
    image adapter and a multi-scale ViT FPN head.

    Stage ``'1st'`` applies a symmetric contrastive loss between the adapted
    image embedding and the text embedding. Stage ``'2nd'`` builds a feature
    pyramid from the ViT token maps, fuses it through ``ViTFPN``, and applies
    the loss to the fused feature. ``image_encode`` / ``text_encode`` expose
    the per-modality embeddings.
    """

    _STAGES = ('1st', '2nd')

    def __init__(self, cfg):
        super().__init__()
        # NOTE(review): torch.load unpickles arbitrary objects - only load trusted
        # checkpoints (weights_only=True would be safer if this is a plain state dict).
        clip_model = torch.load(cfg.clip_pretrain, map_location="cpu")
        (backbone, image_resolution, vision_heads,
         embed_dim, vision_width, patch_size) = build_model(clip_model, cfg.word_len)
        self.backbone = backbone.float()
        # Side length, in patches, of the square ViT token grid.
        self.patch_emb = image_resolution // patch_size
        # Backbone-derived sizes are written back onto cfg for later consumers.
        cfg.image_resolution = image_resolution
        cfg.input_size = image_resolution
        cfg.heads = vision_heads // 32
        cfg.emb_dim = vision_width
        cfg.output_dim = embed_dim
        # Multi-modal FPN over the multi-scale ViT features.
        self.FPN = ViTFPN(image_resolution, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Adapter on the global image embedding, blended residually using `ratio`.
        self.ADP = Adapter(cfg.output_dim, 4)
        # parameters
        self.ratio = cfg.ratio
        # Learnable temperature, stored as a log and initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.share_temperature = True
        self.ce = nn.CrossEntropyLoss()
        # One adaptor per pyramid level, applied to the (b, c, h, w) token grid:
        # x4 upsampling, x2 upsampling, identity, x2 downsampling.
        self.ms_adaptor = nn.ModuleList(
            [
                nn.Sequential(
                    nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
                    nn.GroupNorm(32, cfg.emb_dim),
                    nn.GELU(),
                    nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
                ),
                nn.Sequential(
                    nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
                ),
                nn.Sequential(
                    nn.Identity(),
                ),
                nn.Sequential(
                    nn.MaxPool2d(2),
                ),
            ]
        )
        self.ms_adaptor.apply(self.init_adaptor)

    def init_adaptor(self, m):
        """Initialise one adaptor submodule; intended for ``nn.Module.apply``.

        (Transposed) convolutions get LeCun-normal weights and zero bias.
        GroupNorm gets unit weight and zero bias.
        """
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            lecun_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.GroupNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)

    def IT_loss(self, image_features, text_features):
        """Symmetric contrastive (InfoNCE) loss between paired embeddings.

        Row ``i`` of ``image_features`` is the positive for row ``i`` of
        ``text_features``; all other rows in the batch act as negatives.

        Args:
            image_features: (b, d) image embeddings.
            text_features: (b, d) text embeddings.

        Returns:
            Scalar tensor: mean of the image->text and text->image cross-entropies.
        """
        # F.normalize clamps the norm, so an all-zero row cannot produce NaN.
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)
        logits_per_image = self.logit_scale.exp() * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        labels = torch.arange(image_features.shape[0], device=logits_per_image.device)
        return (self.ce(logits_per_image, labels) + self.ce(logits_per_text, labels)) * 0.5

    def _blend(self, image):
        """Return ``ratio * ADP(image) + (1 - ratio) * image``."""
        return self.ratio * self.ADP(image) + (1 - self.ratio) * image

    def _pyramid(self, vis):
        """Build the pyramid levels the FPN consumes from the ViT token maps.

        Each ``vis[level]`` is reshaped from (b, h*w, c) tokens to a
        (b, c, h, w) grid and passed through ``ms_adaptor[level]``. Level 0 is
        not computed because the FPN is only ever given levels 1..3.
        """
        feats = []
        for level in range(1, len(self.ms_adaptor)):
            grid = rearrange(
                vis[level],
                "b (h w) c -> b c h w",
                h=self.patch_emb,
                w=self.patch_emb,
            ).contiguous()
            feats.append(self.ms_adaptor[level](grid))
        return feats

    def image_encode(self, img):
        """Return the blended global image embedding for ``img``."""
        _, image = self.backbone.encode_image(img)
        return self._blend(image)

    def text_encode(self, txt):
        """Return the global text embedding for the token ids ``txt``."""
        _, text = self.backbone.encode_text(txt)
        return text

    def forward(self, img, txt, stage):
        """Run one training stage.

        Args:
            img: (b, 3, h, w) image batch.
            txt: (b, words) token ids for ``backbone.encode_text``.
            stage: ``'1st'`` (adapter only) or ``'2nd'`` (adapter + FPN head).

        Returns:
            ``(loss, fv, fi, ft)``: the contrastive loss; the FPN-fused feature
            (``None`` in stage ``'1st'``); the blended image embedding; and the
            text embedding.

        Raises:
            ValueError: if ``stage`` is not one of ``'1st'`` / ``'2nd'``.
        """
        if stage not in self._STAGES:
            raise ValueError(f"unknown stage {stage!r}; expected one of {self._STAGES}")
        vis, image = self.backbone.encode_image(img)
        _, ft = self.backbone.encode_text(txt)
        fi = self._blend(image)
        if stage == '1st':
            return self.IT_loss(fi, ft), None, fi, ft
        fv = self.FPN(self._pyramid(vis), fi, False)
        return self.IT_loss(fv, ft), fv, fi, ft

    def visualize(self, img):
        """Return ``(vis, fv_t)``.

        ``vis`` are the raw backbone token maps and ``fv_t`` is the FPN output
        for the blended image embedding (FPN called with ``True`` as its third
        argument).
        """
        vis, image = self.backbone.encode_image(img)
        x = self._blend(image)
        fv_t = self.FPN(self._pyramid(vis), x, True)
        return vis, fv_t
class CISEN_vit(nn.Module, PyTorchModelHubMixin):
    """CLIP-ViT image-text model with a residual image adapter and a multi-scale ViT FPN head.

    Stage ``'1st'`` applies a symmetric contrastive loss between the adapted
    image embedding and the text embedding. Stage ``'2nd'`` builds a feature
    pyramid from the ViT token maps, fuses it with the adapted image embedding
    through ``ViTFPN``, and applies the contrastive loss to the fused feature.

    NOTE(review): an earlier class in this file has the same name. This later
    definition is the one the module-level name ``CISEN_vit`` refers to after
    import.
    """

    _STAGES = ('1st', '2nd')

    def __init__(self, cfg):
        super().__init__()
        # Vision & text encoders, built from a TorchScript CLIP checkpoint.
        clip_model = torch.jit.load(cfg.clip_pretrain, map_location="cpu").eval()
        (backbone, image_resolution, vision_heads,
         embed_dim, vision_width, patch_size) = build_model(clip_model.state_dict(), cfg.word_len)
        self.backbone = backbone.float()
        # Side length, in patches, of the square ViT token grid.
        self.patch_emb = image_resolution // patch_size
        # Backbone-derived sizes are written back onto cfg; ViTFPN below receives cfg.
        cfg.image_resolution = image_resolution
        cfg.input_size = image_resolution
        cfg.heads = vision_heads // 32
        cfg.emb_dim = vision_width
        cfg.output_dim = embed_dim
        # Multi-modal FPN over the multi-scale ViT features.
        self.FPN = ViTFPN(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Adapter on the global image embedding, blended residually using `ratio`.
        self.ADP = Adapter(cfg.output_dim, 4)
        # parameters
        self.ratio = cfg.ratio
        # Learnable temperature, stored as a log and initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.share_temperature = True
        self.ce = nn.CrossEntropyLoss()
        # One adaptor per pyramid level, applied to the (b, c, h, w) token grid:
        # x4 upsampling, x2 upsampling, identity, x2 downsampling.
        self.ms_adaptor = nn.ModuleList(
            [
                nn.Sequential(
                    nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
                    nn.GroupNorm(32, cfg.emb_dim),
                    nn.GELU(),
                    nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
                ),
                nn.Sequential(
                    nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
                ),
                nn.Sequential(
                    nn.Identity(),
                ),
                nn.Sequential(
                    nn.MaxPool2d(2),
                ),
            ]
        )
        self.ms_adaptor.apply(self.init_adaptor)

    def init_adaptor(self, m):
        """Initialise one adaptor submodule; intended for ``nn.Module.apply``.

        (Transposed) convolutions get LeCun-normal weights and zero bias.
        GroupNorm gets unit weight and zero bias.
        """
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            lecun_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.GroupNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)

    def IT_loss(self, image_features, text_features):
        """Symmetric contrastive (InfoNCE) loss between paired embeddings.

        Row ``i`` of ``image_features`` is the positive for row ``i`` of
        ``text_features``; all other rows in the batch act as negatives.

        Args:
            image_features: (b, d) image embeddings.
            text_features: (b, d) text embeddings.

        Returns:
            Scalar tensor: mean of the image->text and text->image cross-entropies.
        """
        # F.normalize clamps the norm, so an all-zero row cannot produce NaN.
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)
        logits_per_image = self.logit_scale.exp() * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        labels = torch.arange(image_features.shape[0], device=logits_per_image.device)
        return (self.ce(logits_per_image, labels) + self.ce(logits_per_text, labels)) * 0.5

    def _encode(self, img, txt):
        """Encode both modalities and blend the adapted image embedding.

        Returns ``(vis, x, text)``:
        - ``vis``: the backbone's per-layer visual token maps;
        - ``x``: ``ratio * ADP(image) + (1 - ratio) * image``;
        - ``text``: the global text embedding.
        """
        vis, image = self.backbone.encode_image(img)
        _, text = self.backbone.encode_text(txt)
        x = self.ratio * self.ADP(image) + (1 - self.ratio) * image
        return vis, x, text

    def _pyramid(self, vis):
        """Build the pyramid levels the FPN consumes from the ViT token maps.

        Each ``vis[level]`` is reshaped from (b, h*w, c) tokens to a
        (b, c, h, w) grid and passed through ``ms_adaptor[level]``. Level 0 is
        not computed because the FPN is only ever given levels 1..3.
        """
        feats = []
        for level in range(1, len(self.ms_adaptor)):
            grid = rearrange(
                vis[level],
                "b (h w) c -> b c h w",
                h=self.patch_emb,
                w=self.patch_emb,
            ).contiguous()
            feats.append(self.ms_adaptor[level](grid))
        return feats

    def forward(self, img, txt, stage):
        """Run one training stage.

        Args:
            img: (b, 3, h, w) image batch.
            txt: (b, words) token ids for ``backbone.encode_text``.
            stage: ``'1st'`` (adapter only) or ``'2nd'`` (adapter + FPN head).

        Returns:
            ``(loss, fv, fi, ft)``: the contrastive loss; the FPN-fused feature
            (``None`` in stage ``'1st'``); the blended image embedding; and the
            text embedding.

        Raises:
            ValueError: if ``stage`` is not one of ``'1st'`` / ``'2nd'``.
        """
        if stage not in self._STAGES:
            raise ValueError(f"unknown stage {stage!r}; expected one of {self._STAGES}")
        vis, fi, ft = self._encode(img, txt)
        if stage == '1st':
            return self.IT_loss(fi, ft), None, fi, ft
        fv = self.FPN(self._pyramid(vis), fi, False)
        return self.IT_loss(fv, ft), fv, fi, ft

    def visualize(self, img, txt):
        """Return ``(vis, fv_t, ft_t)`` for inspection.

        ``vis`` are the raw backbone token maps. ``fv_t`` and ``ft_t`` are the
        FPN outputs when the pyramid is fused with the blended image embedding
        and with the text embedding, respectively. Both FPN calls pass ``True``
        as the third argument.
        """
        vis, x, text = self._encode(img, txt)
        feats = self._pyramid(vis)
        fv_t = self.FPN(feats, x, True)
        ft_t = self.FPN(feats, text, True)
        return vis, fv_t, ft_t
class CISEN_rsvit_classification(nn.Module):
    """Multi-label classifier on top of the CLIP visual encoder.

    Pipeline: global image embedding -> Linear(cfg.vis_dim, 512) -> ReLU ->
    Linear(512, 10) logits. Training uses ``BCEWithLogitsLoss``.
    """

    def __init__(self, cfg):
        super().__init__()
        # NOTE(review): torch.load unpickles arbitrary objects - only load trusted checkpoints.
        clip_model = torch.load(cfg.clip_pretrain, map_location="cpu")
        (backbone, image_resolution, vision_heads,
         embed_dim, vision_width, patch_size) = build_model(clip_model, cfg.word_len)
        self.backbone = backbone.float()
        self.patch_emb = image_resolution // patch_size
        num_classes_fc = 512
        num_classes_output = 10
        self.num_classes_fc = num_classes_fc  # width of the hidden fully connected layer
        self.num_classes_output = num_classes_output  # number of output labels
        # Hidden fully connected layer on the global image embedding.
        self.fc = nn.Linear(in_features=cfg.vis_dim, out_features=num_classes_fc)
        # Output layer producing one logit per label (multi-label classification).
        self.output_layer = nn.Linear(in_features=num_classes_fc, out_features=num_classes_output)
        self.criterion = nn.BCEWithLogitsLoss()
        # Backbone-derived sizes are written back onto cfg for later consumers.
        cfg.image_resolution = image_resolution
        cfg.input_size = image_resolution
        cfg.heads = vision_heads // 32
        cfg.emb_dim = vision_width
        cfg.output_dim = embed_dim

    def IT_loss(self, labels, labels_pre):
        """Binary cross-entropy-with-logits loss for multi-label targets.

        Args:
            labels: multi-hot targets. Dim 1 is squeezed, so (b, 1, C) becomes (b, C).
            labels_pre: (b, C) logits.

        Returns:
            Scalar loss tensor.
        """
        # BCEWithLogitsLoss expects floating-point targets; cast so integer
        # (e.g. long) multi-hot labels are accepted as well.
        labels = labels.squeeze(1).to(labels_pre.dtype)
        return self.criterion(labels_pre, labels)

    def forward(self, img, labels):
        """Classify ``img`` and score the prediction against ``labels``.

        Returns:
            ``(labels_pre, loss)``: per-label logits and the BCE-with-logits loss.
        """
        _, image_features = self.backbone.encode_image(img)
        hidden = F.relu(self.fc(image_features))
        labels_pre = self.output_layer(hidden)
        return labels_pre, self.IT_loss(labels, labels_pre)
class CISEN_new(nn.Module):
    """CLIP image-text model with an embedding adapter and an ``FPN`` + ``GAP`` head.

    Behaviour per ``forward`` stage (all stages use a symmetric contrastive loss):
      '1st': adapted image embedding vs. text embedding.
      '2nd': visual maps fused with the adapted image embedding by ``FPN``,
             pooled by ``gap``, vs. text embedding.
      '3rd': raw image embedding vs. text embedding adapted by ``ADP`` (ratio 0.2).
      '4th': visual maps fused with the raw image embedding, pooled, vs. text.
      '5th': like '4th' but against a text embedding adapted by ``self.ADP_t``
             (ratio 0.2). ``__init__`` does not create ``ADP_t``, so it must be
             assigned before this stage is used.
    """

    _STAGES = ('1st', '2nd', '3rd', '4th', '5th')
    # Blend ratio hard-coded for stages '3rd' and '5th'.
    _FIXED_RATIO = 0.2

    def __init__(self, cfg):
        super().__init__()
        # Vision & text encoders, built from a TorchScript CLIP checkpoint.
        clip_model = torch.jit.load(cfg.clip_pretrain, map_location="cpu").eval()
        backbone, image_resolution, vision_heads, embed_dim, vision_width, _ = build_model(
            clip_model.state_dict(), cfg.word_len)
        self.backbone = backbone.float()
        cfg.input_size = image_resolution
        # NOTE(review): the ViT classes in this file use `vision_heads // 32` and
        # `vision_width`; the values here differ, presumably for another backbone
        # family - confirm against build_model.
        cfg.heads = vision_heads
        cfg.emb_dim = vision_width * 32
        cfg.output_dim = embed_dim
        # Multi-modal FPN
        self.FPN = FPN(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        self.ADP = Adapter(cfg.output_dim, 4)
        self.gap = GAP((1, 1))
        # parameters
        self.ratio = cfg.ratio
        # Learnable temperature, stored as a log and initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.share_temperature = True
        self.margin = 1
        self.eps = 1e-3
        self.ce = nn.CrossEntropyLoss()
        # 1st stage
        self.lamda1 = cfg.lamda1
        self.lamda2 = cfg.lamda2
        self.avg = nn.AdaptiveAvgPool2d((1, 1))

    def IT_loss(self, image_features, text_features):
        """Symmetric contrastive (InfoNCE) loss between paired embeddings.

        Row ``i`` of ``image_features`` is the positive for row ``i`` of
        ``text_features``; all other rows in the batch act as negatives.

        Args:
            image_features: (b, d) image embeddings.
            text_features: (b, d) text embeddings.

        Returns:
            Scalar tensor: mean of the image->text and text->image cross-entropies.
        """
        # F.normalize clamps the norm, so an all-zero row cannot produce NaN.
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)
        logits_per_image = self.logit_scale.exp() * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        labels = torch.arange(image_features.shape[0], device=logits_per_image.device)
        return (self.ce(logits_per_image, labels) + self.ce(logits_per_text, labels)) * 0.5

    @staticmethod
    def _blend(adapter, feat, ratio):
        """Return ``ratio * adapter(feat) + (1 - ratio) * feat``."""
        return ratio * adapter(feat) + (1 - ratio) * feat

    def forward(self, img, txt, stage):
        """Run one training stage (see the class docstring).

        Args:
            img: (b, 3, h, w) image batch.
            txt: (b, words) token ids for ``backbone.encode_text``.
            stage: one of ``'1st'`` ... ``'5th'``.

        Returns:
            ``(loss, fv, fi, ft)``: the loss; the pooled FPN feature (``None``
            in '1st'/'3rd'); the image embedding used (``None`` in '4th'); and
            the text embedding used.

        Raises:
            ValueError: for an unknown ``stage``.
            AttributeError: in stage '5th' when ``self.ADP_t`` has not been set.
        """
        if stage not in self._STAGES:
            raise ValueError(f"unknown stage {stage!r}; expected one of {self._STAGES}")
        vis, image = self.backbone.encode_image(img)
        _, text = self.backbone.encode_text(txt)
        if stage == '1st':
            fi = self._blend(self.ADP, image, self.ratio)
            return self.IT_loss(fi, text), None, fi, text
        if stage == '2nd':
            fi = self._blend(self.ADP, image, self.ratio)
            fv = self.gap(self.FPN(vis, fi))
            return self.IT_loss(fv, text), fv, fi, text
        if stage == '3rd':
            # Here the (image) adapter is applied to the text embedding.
            ft = self._blend(self.ADP, text, self._FIXED_RATIO)
            return self.IT_loss(image, ft), None, image, ft
        if stage == '4th':
            fv = self.gap(self.FPN(vis, image))
            return self.IT_loss(fv, text), fv, None, text
        # stage == '5th'
        fi = self._blend(self.ADP, image, self._FIXED_RATIO)
        text_adapter = getattr(self, "ADP_t", None)
        if text_adapter is None:
            raise AttributeError(
                "stage '5th' needs a text adapter `ADP_t`, which __init__ does not "
                "create; assign one before using this stage"
            )
        ft = self._blend(text_adapter, text, self._FIXED_RATIO)
        # NOTE(review): the FPN is fed the raw `image`, not the adapted `fi`; kept as-is.
        fv = self.gap(self.FPN(vis, image))
        return self.IT_loss(fv, ft), fv, fi, ft
class CISEN_lclip(nn.Module):
    """L-CLIP image-text model with an embedding adapter and an ``FPN`` + ``GAP`` head.

    Behaviour per ``forward`` stage (all stages use a symmetric contrastive loss):
      '1st': adapted image embedding vs. text embedding.
      '2nd': visual maps fused with the adapted image embedding by ``FPN``,
             pooled by ``gap``, vs. text embedding.
      '3rd': raw image embedding vs. text embedding adapted by ``ADP`` (ratio 0.2).
      '4th': visual maps fused with the raw image embedding, pooled, vs. text.
      '5th': like '4th' but against a text embedding adapted by ``self.ADP_t``
             (ratio 0.2). ``__init__`` does not create ``ADP_t``, so it must be
             assigned before this stage is used.
    """

    _STAGES = ('1st', '2nd', '3rd', '4th', '5th')
    # Blend ratio hard-coded for stages '3rd' and '5th'.
    _FIXED_RATIO = 0.2

    def __init__(self, cfg):
        super().__init__()
        # NOTE(review): torch.load unpickles arbitrary objects - only load trusted checkpoints.
        clip_model = torch.load(cfg.clip_pretrain, map_location="cpu")
        backbone, image_resolution, vision_heads, embed_dim, vision_width, _ = build_lclip_model(
            clip_model, load_from_clip=True)
        self.backbone = backbone.float()
        # Backbone-derived sizes are written back onto cfg; FPN below receives cfg.
        cfg.input_size = image_resolution
        cfg.heads = vision_heads // 32
        cfg.emb_dim = vision_width
        cfg.output_dim = embed_dim
        # Multi-modal FPN
        self.FPN = FPN(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        self.ADP = Adapter(cfg.output_dim, 4)
        self.gap = GAP((1, 1))
        # parameters
        self.ratio = cfg.ratio
        # Learnable temperature, stored as a log and initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.share_temperature = True
        self.margin = 1
        self.eps = 1e-3
        self.ce = nn.CrossEntropyLoss()
        # 1st stage
        self.lamda1 = cfg.lamda1
        self.lamda2 = cfg.lamda2
        self.avg = nn.AdaptiveAvgPool2d((1, 1))

    def IT_loss(self, image_features, text_features):
        """Symmetric contrastive (InfoNCE) loss between paired embeddings.

        Row ``i`` of ``image_features`` is the positive for row ``i`` of
        ``text_features``; all other rows in the batch act as negatives.

        Args:
            image_features: (b, d) image embeddings.
            text_features: (b, d) text embeddings.

        Returns:
            Scalar tensor: mean of the image->text and text->image cross-entropies.
        """
        # F.normalize clamps the norm, so an all-zero row cannot produce NaN.
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)
        logits_per_image = self.logit_scale.exp() * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        labels = torch.arange(image_features.shape[0], device=logits_per_image.device)
        return (self.ce(logits_per_image, labels) + self.ce(logits_per_text, labels)) * 0.5

    @staticmethod
    def _blend(adapter, feat, ratio):
        """Return ``ratio * adapter(feat) + (1 - ratio) * feat``."""
        return ratio * adapter(feat) + (1 - ratio) * feat

    def _encode_text(self, txt):
        """Return the global text embedding from ``backbone.encode_text``.

        Accepts either a bare embedding or a ``(word, text)`` pair, keeping the
        last element of a pair. NOTE(review): confirm which form
        ``build_lclip_model``'s backbone actually returns.
        """
        out = self.backbone.encode_text(txt)
        if isinstance(out, (tuple, list)):
            return out[-1]
        return out

    def forward(self, img, txt, stage):
        """Run one training stage (see the class docstring).

        Args:
            img: (b, 3, h, w) image batch.
            txt: (b, words) token ids for ``backbone.encode_text``.
            stage: one of ``'1st'`` ... ``'5th'``.

        Returns:
            ``(loss, fv, fi, ft)``: the loss; the pooled FPN feature (``None``
            in '1st'/'3rd'); the image embedding used (``None`` in '4th'); and
            the text embedding used.

        Raises:
            ValueError: for an unknown ``stage``.
            AttributeError: in stage '5th' when ``self.ADP_t`` has not been set.
        """
        if stage not in self._STAGES:
            raise ValueError(f"unknown stage {stage!r}; expected one of {self._STAGES}")
        vis, image = self.backbone.encode_image(img)
        text = self._encode_text(txt)
        if stage == '1st':
            fi = self._blend(self.ADP, image, self.ratio)
            return self.IT_loss(fi, text), None, fi, text
        if stage == '2nd':
            fi = self._blend(self.ADP, image, self.ratio)
            fv = self.gap(self.FPN(vis, fi))
            return self.IT_loss(fv, text), fv, fi, text
        if stage == '3rd':
            # Here the (image) adapter is applied to the text embedding.
            ft = self._blend(self.ADP, text, self._FIXED_RATIO)
            return self.IT_loss(image, ft), None, image, ft
        if stage == '4th':
            fv = self.gap(self.FPN(vis, image))
            return self.IT_loss(fv, text), fv, None, text
        # stage == '5th'
        fi = self._blend(self.ADP, image, self._FIXED_RATIO)
        text_adapter = getattr(self, "ADP_t", None)
        if text_adapter is None:
            raise AttributeError(
                "stage '5th' needs a text adapter `ADP_t`, which __init__ does not "
                "create; assign one before using this stage"
            )
        ft = self._blend(text_adapter, text, self._FIXED_RATIO)
        # NOTE(review): the FPN is fed the raw `image`, not the adapted `fi`; kept as-is.
        fv = self.gap(self.FPN(vis, image))
        return self.IT_loss(fv, ft), fv, fi, ft
class GeoRSCLIP(nn.Module):
    """Thin wrapper exposing a CLIP backbone's global image and text embeddings.

    ``forward`` follows the ``(loss, fv, fi, ft)`` return convention of the
    other models in this module but computes no loss and no fused feature.
    """

    def __init__(self, cfg):
        super().__init__()
        # NOTE(review): torch.load unpickles arbitrary objects - only load trusted checkpoints.
        clip_model = torch.load(cfg.clip_pretrain, map_location="cpu")
        # Only the backbone is used; the remaining build_model outputs are ignored.
        backbone, *_ = build_model(clip_model, cfg.word_len)
        self.backbone = backbone.float()

    def forward(self, img, txt, stage):
        """Encode ``img`` and ``txt``.

        ``stage`` is accepted for interface compatibility with the other models
        and is ignored.

        Returns:
            ``(None, None, image, text)``: the global image and text embeddings
            in the ``fi`` / ``ft`` slots.
        """
        _, image = self.backbone.encode_image(img)
        _, text = self.backbone.encode_text(txt)
        return None, None, image, text
class CISEN(nn.Module):
def __init__(self, cfg):
super().__init__()
# Vision & Text Encoder & Label Encoder
clip_model = torch.jit.load(cfg.clip_pretrain,
map_location="cpu").eval()
self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
# Multi-Modal FPN
self.FPN = FPN(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
# Fined-grained Fusion
self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers,
d_model=cfg.vis_dim,
nhead=cfg.num_head,
dim_ffn=cfg.dim_ffn,
dropout=cfg.dropout,
return_intermediate=cfg.intermediate)
# adaptively aggretation
self.ASFF = AdaptiveSpatialFeatureFusion(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
# text projector
self.projT = Text_Projector(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
# image projector
# self.projI = Image_Projector(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
# parameter
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.multi_label_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.share_temperature = True
self.margin = 1
self.eps = 1e-3
self.ce = nn.CrossEntropyLoss()
#1st stage
self.lamda1 = cfg.lamda1
self.lamda2 = cfg.lamda2
self.beta1 = cfg.beta1
self.beta2 = cfg.beta2
self.avg = nn.AdaptiveAvgPool2d((1,1))
# self.fc = nn.Linear(512, cfg.num_classes)
#2nd stage
self.pos_samples = cfg.pos_samples
self.neg_samples = cfg.neg_samples
def IT_loss(self, image_features, text_features):
# b, 1024 / b, 1024
batch = image_features.shape[0]
# # normalized features
image_features = image_features / image_features.norm(dim=-1,
keepdim=True)
text_features = text_features / text_features.norm(dim=-1,
keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
# shape = [global_batch_size, global_batch_size]
contrastive_labels = torch.arange(batch).to(logits_per_image.device)
contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5
return contrastive_loss
def IET_loss(self, image_features, text_features, pos_samples, beta):
# b, 1024 / b, 1024
# # normalized features
image_features = [image_feature / image_feature.norm(dim=-1,
keepdim=True) for image_feature in image_features]
text_features = text_features / text_features.norm(dim=-1,
keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
# logits_per_image = [logit_scale * image_feature @ text_features.t() for image_feature in image_features]
logits_per_image = [logit_scale * torch.sum(torch.mul(image_feature, text_features),1) for image_feature in image_features]
logits_per_image = torch.stack(logits_per_image).t()
b = logits_per_image.shape[0]
loss1 = torch.norm(text_features - image_features[0])
positive_tagsT = torch.zeros(b,len(image_features)).to(text_features.device)
negative_tagsT = torch.zeros(b,len(image_features)).to(text_features.device)
positive_tagsT[:, 0 : pos_samples + 1] = 1
negative_tagsT[:, pos_samples + 1 : -1] = 1
maskT = positive_tagsT.unsqueeze(1) * negative_tagsT.unsqueeze(-1)
pos_score_matT = logits_per_image * positive_tagsT
neg_score_matT = logits_per_image * negative_tagsT
IW_pos3T = pos_score_matT.unsqueeze(1)
IW_neg3T = neg_score_matT.unsqueeze(-1)
OT = 1 + IW_neg3T - IW_pos3T
O_maskT = maskT * OT
diffT = torch.clamp(O_maskT, 0)
violationT = torch.sign(diffT).sum(1).sum(1)
diffT = diffT.sum(1).sum(1)
lossT = torch.mean(diffT / (violationT + self.eps))
loss = beta * loss1 + lossT
return loss
def test_IET_loss(self, image_features, text_features, pos_samples, beta1, beta2):
# text_features: enhanced_features
# b, 1024 / b, 1024
# # normalized features
image_features = image_features / image_features.norm(dim=-1,
keepdim=True)
text_features = text_features / text_features.norm(dim=-1,
keepdim=True)
image_features = image_features.unsqueeze(1)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
# image_features = image_features.expand(-1, text_features.shape[1], -1)
logits_per_image = logit_scale * torch.matmul(image_features, text_features.transpose(1, 2))
logits_per_image = logits_per_image.squeeze(1)
# logits_per_image = logit_scale * image_features @ text_features.t()
# logits_per_image = [logit_scale * image_feature @ text_features.t() for image_feature in image_features]
b = logits_per_image.shape[0]
# loss1 = torch.norm(text_features[:, 0, :] - image_features.squeeze(1))
positive_tagsT = torch.zeros(b, text_features.shape[1]).to(text_features.device)
negative_tagsT = torch.zeros(b, text_features.shape[1]).to(text_features.device)
positive_tagsT[:, 0 : pos_samples + 1] = 1
negative_tagsT[:, pos_samples + 1 : -1] = 1
maskT = positive_tagsT.unsqueeze(1) * negative_tagsT.unsqueeze(-1)
pos_score_matT = logits_per_image * positive_tagsT
neg_score_matT = logits_per_image * negative_tagsT
IW_pos3T = pos_score_matT.unsqueeze(1)
IW_neg3T = neg_score_matT.unsqueeze(-1)
OT = 1 + IW_neg3T - IW_pos3T
O_maskT = maskT * OT
diffT = torch.clamp(O_maskT, 0)
violationT = torch.sign(diffT).sum(1).sum(1)
diffT = diffT.sum(1).sum(1)
lossT = torch.mean(diffT / (violationT + self.eps))
# loss = beta1 * loss1 + beta2 * lossT
loss = lossT
return loss
def test_IT_loss(self, image_features, text_features):
# b, 1024 / b, 1024
batch = image_features.shape[0]
# # normalized features
image_features = image_features / image_features.norm(dim=-1,
keepdim=True)
text_features = text_features / text_features.norm(dim=-1,
keepdim=True)
image_features = image_features.unsqueeze(1)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * torch.matmul(image_features, text_features.transpose(1, 2))
logits_per_image = logits_per_image.squeeze(1)
# shape = [global_batch_size, global_batch_size]
contrastive_labels = torch.arange(batch).to(logits_per_image.device)
contrastive_loss = self.ce(logits_per_image, contrastive_labels)
return contrastive_loss
def test_forward(self, img, txt):
    """Run the single-caption (first-stage) pipeline and its loss.

    Args:
        img: image batch passed to ``self.backbone.encode_image``.
        txt: token ids passed to ``self.backbone.encode_text``; id 0 is
            treated as padding.

    Returns:
        ``(loss, fv, ft, fi)``:
        - ``fi`` is the backbone's global image embedding.
        - ``fv`` is the ASFF fusion of ``fi`` with the pooled fused feature map.
        - ``ft`` is ``self.projT(text)``.
        - ``loss`` is ``lamda1 * IT_loss(fi, ft) + lamda2 * IT_loss(fv, ft)``.

        NOTE: ``forward(stage='1st')`` returns the same tensors in the order
        ``(loss, fv, fi, ft)``.
    """
    # Boolean key-padding mask for the fusion decoder: True where id == 0.
    pad_mask = txt == 0
    vis, image = self.backbone.encode_image(img)
    word, text = self.backbone.encode_text(txt)

    fq = self.FPN(vis, text)
    grid_shape = fq.size()
    fused_map = self.FGFusion(fq, word, pad_mask).reshape(grid_shape)
    pooled = self.avg(fused_map)

    # ASFF takes 4-D inputs, so the global embedding gets two trailing axes.
    fv = self.ASFF(image[..., None, None], pooled)
    fi = image
    ft = self.projT(text)

    weighted_global = self.lamda1 * self.IT_loss(fi, ft)
    weighted_fused = self.lamda2 * self.IT_loss(fv, ft)
    return weighted_global + weighted_fused, fv, ft, fi
def forward(self, img, txt, stage):
    """Dispatch to the pipeline selected by ``stage``.

    Args:
        img: image batch passed to ``self.backbone.encode_image``.
        txt: token ids (id 0 is padding). Shape is ``(b, words)`` for
            ``'1st'`` and ``(b, num, words)`` for ``'2nd'`` and ``'test'``
            (as indexed by the stage helpers).
        stage: one of ``'1st'``, ``'2nd'`` or ``'test'``.

    Returns:
        ``(loss, fv, fi, ft)``; see the stage helpers for their contents.

    Raises:
        ValueError: if ``stage`` is not one of the supported values.
    """
    if stage == '1st':
        return self._forward_1st(img, txt)
    if stage == '2nd':
        return self._forward_2nd(img, txt)
    if stage == 'test':
        return self._forward_test(img, txt)
    raise ValueError(f"stage should be either 1st or 2nd or test, got {stage!r}")

def _forward_1st(self, img, txt):
    """Stage '1st': contrastive loss on the global and fused image embeddings.

    Returns ``(loss, fv, fi, ft)`` where ``loss`` is
    ``lamda1 * IT_loss(fi, ft) + lamda2 * IT_loss(fv, ft)``.
    """
    pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
    vis, image = self.backbone.encode_image(img)
    word, text = self.backbone.encode_text(txt)
    fq = self.FPN(vis, text)
    b, c, h, w = fq.size()
    ff = self.FGFusion(fq, word, pad_mask).reshape(b, c, h, w)
    f2 = self.avg(ff)
    fi = image.unsqueeze(-1).unsqueeze(-1)
    fv = self.ASFF(fi, f2)
    fi = fi.squeeze(-1).squeeze(-1)
    ft = self.projT(text)
    loss1 = self.IT_loss(fi, ft)
    loss2 = self.IT_loss(fv, ft)
    loss = self.lamda1 * loss1 + self.lamda2 * loss2
    return loss, fv, fi, ft

def _forward_2nd(self, img, txt):
    """Stage '2nd': rank each image against fusions with its ``num`` captions.

    ``txt`` is (b, num, words). It is flattened to (b * num, words) in
    image-major order, i.e. rows ``k * num .. k * num + num - 1`` belong to
    image ``k``.

    Returns ``(loss, fv, fi, ft)`` with ``fv`` reshaped to (b, num, dim) and
    ``ft`` the raw text embedding of each image's first caption.
    """
    txt = txt.view(-1, txt.size(-1))
    pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
    vis, image = self.backbone.encode_image(img)
    word, text = self.backbone.encode_text(txt)
    fq = self.FPN(vis, text)
    b, c, h, w = fq.size()
    ff = self.FGFusion(fq, word, pad_mask).reshape(b, c, h, w)
    f2 = self.avg(ff)
    fi = image.unsqueeze(-1).unsqueeze(-1)
    # NOTE(review): ``repeat`` tiles images as [i0, i1, ..., i0, i1, ...],
    # while the flattened captions (and the ``view`` calls below) are
    # image-major. For b > 1 this pairs captions with another image's
    # embedding unless FPN/ASFF compensate; ``fi.repeat_interleave(num, 0)``
    # may be what was intended. Confirm before changing.
    fi_ = fi.repeat(int(f2.shape[0] / fi.shape[0]), 1, 1, 1)
    fv = self.ASFF(fi_, f2)
    fi = fi.squeeze(-1).squeeze(-1)
    n_img = img.shape[0]
    ft = text.view(n_img, int(text.shape[0] / n_img), -1)[:, 0, :]
    fv = fv.view(ft.shape[0], int(text.shape[0] / ft.shape[0]), fv.shape[1])
    loss = self.test_IET_loss(fi, fv, self.pos_samples, self.beta1, self.beta2)
    return loss, fv, fi, ft

def _forward_test(self, img, txt):
    """Stage 'test': fuse the image separately with each caption set.

    ``txt`` is (b, num, words). Caption set ``i`` is ``txt[:, i]``.

    Returns ``(loss, fv, fi, ft)`` where ``fv`` is a list of ``num`` fused
    embeddings and ``ft`` is the projected embedding of caption set 0.
    """
    txt = txt.permute(1, 0, 2)  # -> (num, b, words)
    pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
    vis, image = self.backbone.encode_image(img)
    # Encode every caption set first, then fuse; this keeps the original
    # call order (relevant for e.g. dropout RNG consumption).
    words, texts = [], []
    for caption_ids in txt:
        word, text = self.backbone.encode_text(caption_ids)
        words.append(word)
        texts.append(text)
    fi_map = image.unsqueeze(-1).unsqueeze(-1)
    fvn = []
    for word, text, mask in zip(words, texts, pad_mask):
        fq = self.FPN(vis, text)
        b, c, h, w = fq.size()
        ff = self.FGFusion(fq, word, mask).reshape(b, c, h, w)
        fvn.append(self.ASFF(fi_map, self.avg(ff)))
    fi = fi_map.squeeze(-1).squeeze(-1)
    ft = self.projT(texts[0])
    loss = self.IET_loss(fvn, ft, self.pos_samples, self.beta)
    return loss, fvn, fi, ft
class CRIS(nn.Module):
    """CLIP-based image/text/label alignment model.

    Combines the following, trained with an image-text contrastive loss plus
    (when targets are given) an image-label ranking loss:
    - a CLIP backbone;
    - a prompt-learned label encoder;
    - a multi-modal FPN;
    - a transformer fusion decoder;
    - an adaptive spatial fusion head.
    """

    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder & Label Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain, map_location="cpu").eval()
        # build_model returns a tuple (backbone, image_resolution, vision_heads,
        # embed_dim, vision_width, patch_size); only the backbone is needed.
        # Unpacking it into five names raised ValueError.
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len)[0].float()
        self.Label_encoder = build_promptlearner(clip_model.state_dict()).float()
        self.Label_encoder.init_label_emb(cfg.label_path)
        # Multi-Modal FPN
        self.FPN = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Fine-grained fusion
        self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers,
                                           d_model=cfg.vis_dim,
                                           nhead=cfg.num_head,
                                           dim_ffn=cfg.dim_ffn,
                                           dropout=cfg.dropout,
                                           return_intermediate=cfg.intermediate)
        # Adaptive aggregation of global and fused features
        self.ASFF = AdaptiveSpatialFeatureFusion(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Text projector
        self.projT = Text_Projector(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Learnable temperatures, initialised to log(1 / 0.07) as in CLIP.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.multi_label_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.share_temperature = True
        self.margin = 1
        self.eps = 1e-3
        self.ce = nn.CrossEntropyLoss()
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        # Not used by forward(); kept so the state_dict keys stay unchanged.
        self.fc = nn.Linear(512, cfg.num_classes)

    def IT_loss(self, image_features, text_features):
        """Symmetric CLIP contrastive loss; row ``i`` of each input is a pair.

        Args:
            image_features: (b, d) tensor.
            text_features: (b, d) tensor.

        Returns:
            Mean of the image->text and text->image cross-entropies.
        """
        batch = image_features.shape[0]
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        contrastive_labels = torch.arange(batch).to(logits_per_image.device)
        return (self.ce(logits_per_image, contrastive_labels)
                + self.ce(logits_per_text, contrastive_labels)) * 0.5

    def IL_loss(self, image_features, label_features, labels):
        """Ranking hinge between images and label embeddings.

        Args:
            image_features: (b, d) tensor.
            label_features: (K, d) tensor.
            labels: (b, K) tensor; +1 marks a positive label, -1 a negative,
                and 0 is ignored (see the clamps below).

        Returns:
            Mean over the batch of the ``self.margin`` hinge, averaged over
            violating negative/positive label pairs.
        """
        positive_tagsT = torch.clamp(labels, 0., 1.)
        negative_tagsT = torch.clamp(-labels, 0., 1.)
        # maskT[:, i, j] == 1 iff label i is negative and label j positive.
        maskT = positive_tagsT.unsqueeze(1) * negative_tagsT.unsqueeze(-1)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        label_features = label_features / label_features.norm(dim=-1, keepdim=True)
        logit_scale = self.multi_label_logit_scale.exp()
        logits_per_image = logit_scale * image_features @ label_features.t()
        IW_pos3T = (logits_per_image * positive_tagsT).unsqueeze(1)
        IW_neg3T = (logits_per_image * negative_tagsT).unsqueeze(-1)
        diffT = torch.clamp(maskT * (self.margin + IW_neg3T - IW_pos3T), 0)
        violationT = torch.sign(diffT).sum(1).sum(1)
        return torch.mean(diffT.sum(1).sum(1) / (violationT + self.eps))

    def margin_loss(self, image_features, label_features, labels):
        """Hinge between mean positive and mean negative image-label scores.

        Args:
            image_features: (b, d) tensor.
            label_features: (K, d) tensor.
            labels: (b, K) tensor, treated here as 0/1 (``1 - labels`` selects
                negatives). This differs from ``IL_loss``, which reads +1/-1.

        Returns:
            ``relu(margin - mean_pos + mean_neg)``.

        NOTE(review): the positive mean divides by ``labels.sum()`` without
        eps, so a batch with no positive labels yields NaN.
        """
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        label_features = label_features / label_features.norm(dim=-1, keepdim=True)
        logit_scale = self.multi_label_logit_scale.exp()
        logits_per_image = logit_scale * image_features @ label_features.t()
        image_label_positive_pairs = logits_per_image * labels
        image_label_mean_positive = image_label_positive_pairs.sum() / labels.sum()
        image_label_negative_pairs = logits_per_image * (1 - labels)
        image_label_mean_negative = image_label_negative_pairs.sum() / (
            logits_per_image.numel() - labels.sum() + self.eps)
        return torch.relu(self.margin - image_label_mean_positive + image_label_mean_negative)

    def forward(self, img, txt, target=None):
        """Compute the training loss and the fused/text/label embeddings.

        Args:
            img: image batch passed to ``self.backbone.encode_image``.
            txt: token ids (b, words); id 0 is padding.
            target: optional (b, K) label tensor in the +1/-1/0 encoding used
                by ``IL_loss``. When ``None`` the label ranking term is omitted.

        Returns:
            ``(loss, fv, ft, fl)``: the loss, the fused image embedding, the
            projected text embedding, and the label embeddings.
        """
        pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
        vis, image = self.backbone.encode_image(img)
        word, text = self.backbone.encode_text(txt)
        fl = self.Label_encoder(image.device)
        fq = self.FPN(vis, text)
        b, c, h, w = fq.size()
        ff = self.FGFusion(fq, word, pad_mask).reshape(b, c, h, w)
        f2 = self.avg(ff)
        f1 = image.unsqueeze(-1).unsqueeze(-1)
        fv = self.ASFF(f1, f2)
        ft = self.projT(text)
        loss = self.IT_loss(fv, ft)
        # The label term needs targets; previously target=None crashed in IL_loss.
        if target is not None:
            loss = loss + self.IL_loss(fv, fl, target)
        return loss, fv, ft, fl
class zh_clip(nn.Module):
    """Image classifier conditioned on a text encoding.

    Image features come from a modified CLIP backbone. Text goes through a
    Hugging Face sequence-classification model loaded from ``cfg.chinese``
    (presumably a Chinese model — confirm). Its logits are projected to 1024-d
    and condition a ViT FPN neck, whose pooled output feeds a linear head.
    """

    def __init__(self, cfg):
        super().__init__()
        jit_clip = torch.jit.load(cfg.clip_pretrain, map_location="cpu").eval()
        self.backbone = build_modified_model(jit_clip.state_dict(), cfg.word_len).float()
        self.text_encoder = AutoModelForSequenceClassification.from_pretrained(cfg.chinese)
        # Projects the text model's 512 logits to the 1024-d conditioning vector.
        self.text_lin = nn.Linear(512, 1024)
        self.neck = ViTFPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, cfg.num_classes)

    def forward(self, img, word):
        """Return class logits from ``self.fc``.

        Args:
            img: image batch passed to ``self.backbone.encode_image``.
            word: token ids; dim 1 is squeezed before the text encoder.
        """
        # encode_image yields (vis, feat, cls); only the feature maps are used.
        _, grid_feats, _ = self.backbone.encode_image(img)
        text_logits = self.text_encoder(word.squeeze(1)).logits
        condition = self.text_lin(text_logits)
        fused = self.neck(grid_feats, condition)
        # Global-average-pool, then drop the two singleton spatial axes.
        pooled = self.avg(fused)[..., 0, 0]
        return self.fc(pooled)
class poi_clip(nn.Module):
    """Text-conditioned image classifier; structurally identical to ``zh_clip``.

    Pipeline:
    1. A modified CLIP backbone produces image feature maps.
    2. A Hugging Face sequence-classification model (``cfg.chinese``) encodes
       the text; its logits go through a 512->1024 linear layer.
    3. The image maps and the projected text condition a ViT FPN neck.
    4. The neck output is average-pooled and classified.
    """

    def __init__(self, cfg):
        super().__init__()
        scripted = torch.jit.load(cfg.clip_pretrain, map_location="cpu").eval()
        weights = scripted.state_dict()
        self.backbone = build_modified_model(weights, cfg.word_len).float()
        self.text_encoder = AutoModelForSequenceClassification.from_pretrained(cfg.chinese)
        self.text_lin = nn.Linear(512, 1024)
        self.neck = ViTFPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, cfg.num_classes)

    def forward(self, img, word):
        """Return ``self.fc`` logits for a batch of images and token ids.

        ``word`` has dim 1 squeezed away before it reaches the text encoder.
        """
        image_out = self.backbone.encode_image(img)
        feature_maps = image_out[1]
        conditioning = self.text_lin(self.text_encoder(word.squeeze(1)).logits)
        neck_out = self.neck(feature_maps, conditioning)
        return self.fc(self.avg(neck_out)[..., 0, 0])
class Clip_hash_model(nn.Module):
    """CLIP backbone + FPN neck with hashing heads for cross-modal retrieval.

    ``forward`` returns four outputs:
    - tanh-bounded hash codes computed separately from the global image and
      text embeddings;
    - a hash code from the pooled fused feature map;
    - class logits computed from that fused hash code.
    """

    def __init__(self, cfg):
        super().__init__()
        clip_model = torch.jit.load(cfg.clip_pretrain, map_location="cpu").eval()
        # build_model returns a tuple with the backbone first; calling
        # ``.float()`` on the tuple itself raised AttributeError.
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len)[0].float()
        # Multi-Modal FPN
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(cfg.fpn_out[1], cfg.hash_dim, bias=True),
            nn.Tanh(),
        )
        self.classifier2 = nn.Sequential(
            nn.Linear(cfg.hash_dim, cfg.num_classes)
        )
        # Hash modules for the global image / text embeddings.
        self.image_module = self._hash_head(cfg.img_dim, cfg.hidden_dim, cfg.hash_dim)
        self.text_module = self._hash_head(cfg.txt_dim, cfg.hidden_dim, cfg.hash_dim)

    @staticmethod
    def _hash_head(in_dim, hidden_dim, hash_dim):
        """Build Linear -> BatchNorm1d -> ReLU -> Linear -> Tanh (values in [-1, 1])."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim, bias=True),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, hash_dim, bias=True),
            nn.Tanh(),
        )

    def forward(self, img, word, mask=None):
        """Compute hash codes and class logits.

        Args:
            img: image batch passed to ``self.backbone.encode_image``.
            word: token ids passed to ``self.backbone.encode_text``.
            mask: unused; accepted for call-site compatibility.

        Returns:
            ``(img_hash, txt_hash, out_hash, res)``: image hash, text hash,
            fused-feature hash, and class logits from ``out_hash``.
        """
        vis, image = self.backbone.encode_image(img)
        _, state = self.backbone.encode_text(word)
        fq = self.neck(vis, state)
        out = self.avg(fq).squeeze(-1).squeeze(-1)
        out_hash = self.classifier(out)
        res = self.classifier2(out_hash)
        img_hash = self.image_module(image)
        txt_hash = self.text_module(state)
        return img_hash, txt_hash, out_hash, res
class Clip_model(nn.Module):
    """Plain CLIP image-text similarity model.

    ``forward`` returns CLIP-style scaled cosine-similarity logits between
    every image and every text in the batch.
    """

    def __init__(self, cfg):
        super().__init__()
        clip_model = torch.jit.load(cfg.clip_pretrain, map_location="cpu").eval()
        # Not used by forward(); kept registered so the state_dict keys stay unchanged.
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        # build_model returns a tuple with the backbone first; calling
        # ``.float()`` on the tuple itself raised AttributeError.
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len)[0].float()

    def forward(self, img, word, mask=None):
        """Return ``(logits_per_image, logits_per_text)``.

        Args:
            img: image batch passed to ``self.backbone.encode_image``.
            word: token ids passed to ``self.backbone.encode_text``.
            mask: unused; accepted for call-site compatibility.

        ``logits_per_image[i, j]`` is ``exp(backbone.logit_scale)`` times the
        cosine similarity of image ``i`` and text ``j``; ``logits_per_text``
        is its transpose.
        """
        # The FPN neck output was previously computed here and then discarded;
        # it does not affect the returned logits, so it is no longer evaluated.
        _, image = self.backbone.encode_image(img)
        _, state = self.backbone.encode_text(word)
        image_features = image / image.norm(dim=-1, keepdim=True)
        text_features = state / state.norm(dim=-1, keepdim=True)
        logit_scale = self.backbone.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        return logits_per_image, logits_per_text
class CISEN_rsvit_hug(nn.Module, PyTorchModelHubMixin):
    """CLIP ViT backbone with an image adapter and a multi-scale FPN head.

    All hyperparameters are explicit constructor arguments rather than a cfg
    object, presumably so PyTorchModelHubMixin can record them as the hub
    config. ``forward`` returns ``(fv, fi, ft)``:
    - ``fv``: the FPN output;
    - ``fi``: the adapter-mixed global image embedding;
    - ``ft``: the CLIP text embedding.
    """

    def __init__(self, embed_dim, image_resolution, vision_layers, vision_width,
                 vision_patch_size, context_length, txt_length, vocab_size,
                 transformer_width, transformer_heads, transformer_layers, patch_size,
                 output_dim, ratio, emb_dim, fpn_in, fpn_out):
        super().__init__()
        # Vision & Text Encoder
        backbone = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
                        vision_patch_size, context_length, txt_length, vocab_size,
                        transformer_width, transformer_heads, transformer_layers)
        self.backbone = backbone.float()
        # Side length of the patch grid, used to fold token sequences back
        # into (b, c, h, w) maps in forward().
        self.patch_emb = image_resolution // patch_size
        self.FPN = ViTFPN(image_resolution, in_channels=fpn_in, out_channels=fpn_out)
        self.ADP = Adapter(output_dim, 4)
        # Mixing weight between the adapter output and the raw image embedding.
        self.ratio = ratio
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.share_temperature = True
        self.ce = nn.CrossEntropyLoss()
        # One resampler per backbone level: x4 up, x2 up, identity, x2 down.
        self.ms_adaptor = nn.ModuleList(
            [
                nn.Sequential(
                    nn.ConvTranspose2d(emb_dim, emb_dim, 2, 2),
                    nn.GroupNorm(32, emb_dim),
                    nn.GELU(),
                    nn.ConvTranspose2d(emb_dim, emb_dim, 2, 2),
                ),
                nn.Sequential(
                    nn.ConvTranspose2d(emb_dim, emb_dim, 2, 2),
                ),
                nn.Sequential(
                    nn.Identity(),
                ),
                nn.Sequential(
                    nn.MaxPool2d(2),
                ),
            ]
        )
        self.ms_adaptor.apply(self.init_adaptor)

    def init_adaptor(self, m):
        """Initialise one sub-module of ``ms_adaptor`` (used via ``Module.apply``).

        Conv / transposed-conv layers get LeCun-normal weights and zero bias.
        GroupNorm gets unit weight and zero bias. Other modules are untouched.
        """
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            lecun_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.GroupNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _adapt(self, image):
        """Blend the adapter output with the raw embedding: ratio * ADP(x) + (1 - ratio) * x."""
        adapted = self.ADP(image)
        return self.ratio * adapted + (1 - self.ratio) * image

    def image_encode(self, img):
        """Return the adapter-mixed global image embedding for ``img``."""
        _, image = self.backbone.encode_image(img)
        return self._adapt(image)

    def text_encode(self, txt):
        """Return the backbone's global text embedding for token ids ``txt``."""
        _, text = self.backbone.encode_text(txt)
        return text

    def forward(self, img, txt):
        """Encode an image/text batch.

        Args:
            img: image batch passed to ``self.backbone.encode_image``.
            txt: token ids passed to ``self.backbone.encode_text``.

        Returns:
            ``(fv, fi, ft)``: the FPN output built from the multi-scale visual
            maps conditioned on ``fi``, the adapter-mixed image embedding, and
            the text embedding.
        """
        vis, image = self.backbone.encode_image(img)
        _, text = self.backbone.encode_text(txt)
        fi = self._adapt(image)
        # Fold each level's token sequence into a (b, c, h, w) grid and
        # resample it. Only levels 1..3 feed the FPN, so ms_adaptor[0] is not
        # evaluated (it stays registered for checkpoint compatibility).
        multi_scale = []
        for level in range(1, len(self.ms_adaptor)):
            grid = rearrange(
                vis[level],
                "b (h w) c -> b c h w",
                h=self.patch_emb,
                w=self.patch_emb,
            ).contiguous()
            multi_scale.append(self.ms_adaptor[level](grid))
        fv = self.FPN(multi_scale, fi, False)
        return fv, fi, text