import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from .clip import build_model, build_promptlearner, build_modified_model, PromptLearner, build_lclip_model
from torch.cuda.amp import autocast
from timm.models.layers import trunc_normal_ as __call_trunc_normal_
from timm.models.layers import variance_scaling_
from einops import rearrange, repeat
from loguru import logger
from transformers import AlignProcessor, AlignModel
from sklearn.metrics import classification_report
from huggingface_hub import PyTorchModelHubMixin
from .layers import FPN, TransformerDecoder, ViTFPN, AdaptiveSpatialFeatureFusion, Text_Projector, Image_Projector, Adapter, GAP
from cisen.model.clip import CLIP


def lecun_normal_(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")


def trunc_normal_(tensor, mean=0.0, std=1.0):
    __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
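

# The two helpers above wrap timm initialisers: lecun_normal_ applies fan-in,
# truncated-normal variance scaling and trunc_normal_ truncates at +/- one std.
# Tiny illustration (hypothetical layer, not used elsewhere in this module):
#   layer = nn.Linear(768, 768)
#   lecun_normal_(layer.weight)
#   nn.init.zeros_(layer.bias)
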
class CISEN_vit(nn.Module, PyTorchModelHubMixin): | |
def __init__(self, cfg): | |
super().__init__() | |
# Vision & Text Encoder & Label Encoder | |
clip_model = torch.jit.load(cfg.clip_pretrain, | |
map_location="cpu").eval() | |
backbone, image_resolution, vision_heads, embed_dim, vision_width, patch_size = build_model(clip_model.state_dict(), cfg.word_len) | |
self.backbone = backbone.float() | |
self.patch_emb = image_resolution // patch_size | |
cfg.image_resolution = image_resolution | |
cfg.input_size = image_resolution | |
cfg.heads = vision_heads // 32 | |
cfg.emb_dim = vision_width | |
cfg.output_dim = embed_dim | |
# multi-scale adapter | |
# Multi-Modal FPN | |
self.FPN = ViTFPN(image_resolution, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out) | |
        # Fine-grained Fusion
# self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers, | |
# d_model=cfg.vis_dim, | |
# nhead=cfg.num_head, | |
# dim_ffn=cfg.dim_ffn, | |
# dropout=cfg.dropout, | |
# return_intermediate=cfg.intermediate) | |
# image-text transformer | |
# self.trans = nn.Linear(1024, 1024) | |
self.ADP = Adapter(cfg.output_dim, 4) | |
# parameter | |
self.ratio = cfg.ratio | |
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) | |
self.share_temperature = True | |
self.ce = nn.CrossEntropyLoss() | |
self.ms_adaptor = nn.ModuleList( | |
[ | |
nn.Sequential( | |
nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2), | |
nn.GroupNorm(32, cfg.emb_dim), | |
nn.GELU(), | |
nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2), | |
), | |
nn.Sequential( | |
nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2), | |
), | |
nn.Sequential( | |
nn.Identity(), | |
), | |
nn.Sequential( | |
nn.MaxPool2d(2), | |
), | |
] | |
) | |
self.ms_adaptor.apply(self.init_adaptor) | |
def init_adaptor(self, m): | |
if isinstance(m, nn.Conv2d): | |
lecun_normal_(m.weight) | |
if m.bias is not None: | |
nn.init.constant_(m.bias, 0) | |
elif isinstance(m, nn.GroupNorm): | |
nn.init.constant_(m.bias, 0) | |
nn.init.constant_(m.weight, 1.0) | |
elif isinstance(m, nn.ConvTranspose2d): | |
lecun_normal_(m.weight) | |
if m.bias is not None: | |
nn.init.zeros_(m.bias) | |
# self.fc = nn.Linear(512, cfg.num_classes) | |
def IT_loss(self, image_features, text_features): | |
# b, 1024 / b, 1024 | |
batch = image_features.shape[0] | |
# # normalized features | |
image_features = image_features / image_features.norm(dim=-1, | |
keepdim=True) | |
text_features = text_features / text_features.norm(dim=-1, | |
keepdim=True) | |
# cosine similarity as logits | |
logit_scale = self.logit_scale.exp() | |
logits_per_image = logit_scale * image_features @ text_features.t() | |
logits_per_text = logits_per_image.t() | |
# shape = [global_batch_size, global_batch_size] | |
contrastive_labels = torch.arange(batch).to(logits_per_image.device) | |
contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5 | |
return contrastive_loss | |
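    # Worked sketch of the symmetric InfoNCE loss above (hypothetical tensors, shown
    # for shapes only): for a batch of B aligned pairs, logits_per_image is (B, B),
    # the matching pair sits on the diagonal, and the loss averages the image-to-text
    # and text-to-image cross-entropy terms.
    #   img_f, txt_f = torch.randn(4, 1024), torch.randn(4, 1024)
    #   loss = model.IT_loss(img_f, txt_f)   # scalar tensor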
def forward(self, img, txt, stage): | |
if stage == '1st': | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
stage: 1st or 2nd stage | |
''' | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# text: b, 1024 | |
# image: b, 1024 | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
x = self.ADP(image) | |
x = self.ratio * x + (1-self.ratio) * image | |
# b, 1024 | |
# fq_t = self.FPN(vis, x) | |
# | |
# fv_t = self.gap(fq_t) | |
loss1 = self.IT_loss(x, text) | |
loss = loss1 | |
ft = text | |
fi = x | |
fv = None | |
elif stage == '2nd': | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
stage: 1st or 2nd stage | |
''' | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# text: b, 1024 | |
# image: b, 1024 | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
x = self.ADP(image) | |
x = self.ratio * x + (1 - self.ratio) * image | |
# Construct multi-scale feats | |
vis_trans = [] | |
for i in range(len(self.ms_adaptor)): | |
x_ = rearrange( | |
vis[i], | |
"b (h w) c -> b c h w", | |
h=self.patch_emb, | |
w=self.patch_emb, | |
).contiguous() | |
feats = self.ms_adaptor[i](x_) | |
vis_trans.append(feats) | |
# fq = self.FPN(vis, x_t) | |
fv_t = self.FPN(vis_trans[1:], x, False) | |
# fv_t = self.gap(fq_t) | |
# b, 1024 | |
loss2 = self.IT_loss(fv_t, text) | |
loss = (loss2) | |
fv = fv_t | |
ft = text | |
fi = x | |
return loss, fv, fi, ft | |
def visualize(self, img, txt): | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
x = self.ADP(image) | |
x = self.ratio * x + (1 - self.ratio) * image | |
# Construct multi-scale feats | |
vis_trans = [] | |
for i in range(len(self.ms_adaptor)): | |
x_ = rearrange( | |
vis[i], | |
"b (h w) c -> b c h w", | |
h=self.patch_emb, | |
w=self.patch_emb, | |
).contiguous() | |
feats = self.ms_adaptor[i](x_) | |
vis_trans.append(feats) | |
# fq = self.FPN(vis, x_t) | |
fv_t = self.FPN(vis_trans[1:], x, True) | |
ft_t = self.FPN(vis_trans[1:], text, True) | |
return vis, fv_t, ft_t | |
class CISEN_rsvit(nn.Module, PyTorchModelHubMixin): | |
def __init__(self, cfg): | |
super().__init__() | |
# Vision & Text Encoder & Label Encoder | |
clip_model = torch.load(cfg.clip_pretrain, | |
map_location="cpu") | |
backbone, image_resolution, vision_heads, embed_dim, vision_width, patch_size = build_model(clip_model, cfg.word_len) | |
self.backbone = backbone.float() | |
self.patch_emb = image_resolution // patch_size | |
cfg.image_resolution = image_resolution | |
cfg.input_size = image_resolution | |
cfg.heads = vision_heads // 32 | |
cfg.emb_dim = vision_width | |
cfg.output_dim = embed_dim | |
# multi-scale adapter | |
# Multi-Modal FPN | |
self.FPN = ViTFPN(image_resolution, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out) | |
        # Fine-grained Fusion
# self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers, | |
# d_model=cfg.vis_dim, | |
# nhead=cfg.num_head, | |
# dim_ffn=cfg.dim_ffn, | |
# dropout=cfg.dropout, | |
# return_intermediate=cfg.intermediate) | |
# image-text transformer | |
# self.trans = nn.Linear(1024, 1024) | |
self.ADP = Adapter(cfg.output_dim, 4) | |
# parameter | |
self.ratio = cfg.ratio | |
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) | |
self.share_temperature = True | |
self.ce = nn.CrossEntropyLoss() | |
self.ms_adaptor = nn.ModuleList( | |
[ | |
nn.Sequential( | |
nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2), | |
nn.GroupNorm(32, cfg.emb_dim), | |
nn.GELU(), | |
nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2), | |
), | |
nn.Sequential( | |
nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2), | |
), | |
nn.Sequential( | |
nn.Identity(), | |
), | |
nn.Sequential( | |
nn.MaxPool2d(2), | |
), | |
] | |
) | |
self.ms_adaptor.apply(self.init_adaptor) | |
def init_adaptor(self, m): | |
if isinstance(m, nn.Conv2d): | |
lecun_normal_(m.weight) | |
if m.bias is not None: | |
nn.init.constant_(m.bias, 0) | |
elif isinstance(m, nn.GroupNorm): | |
nn.init.constant_(m.bias, 0) | |
nn.init.constant_(m.weight, 1.0) | |
elif isinstance(m, nn.ConvTranspose2d): | |
lecun_normal_(m.weight) | |
if m.bias is not None: | |
nn.init.zeros_(m.bias) | |
# self.fc = nn.Linear(512, cfg.num_classes) | |
def IT_loss(self, image_features, text_features): | |
# b, 1024 / b, 1024 | |
batch = image_features.shape[0] | |
# # normalized features | |
image_features = image_features / image_features.norm(dim=-1, | |
keepdim=True) | |
text_features = text_features / text_features.norm(dim=-1, | |
keepdim=True) | |
# cosine similarity as logits | |
logit_scale = self.logit_scale.exp() | |
logits_per_image = logit_scale * image_features @ text_features.t() | |
logits_per_text = logits_per_image.t() | |
# shape = [global_batch_size, global_batch_size] | |
contrastive_labels = torch.arange(batch).to(logits_per_image.device) | |
contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5 | |
return contrastive_loss | |
def image_encode(self, img): | |
vis, image = self.backbone.encode_image(img) | |
x = self.ADP(image) | |
x = self.ratio * x + (1 - self.ratio) * image | |
return x | |
def text_encode(self, txt): | |
word, text = self.backbone.encode_text(txt) | |
return text | |
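    # Retrieval sketch built on the two encoders above (hypothetical inputs): features
    # are L2-normalised and compared by dot product, mirroring IT_loss at inference.
    #   q = F.normalize(model.text_encode(txt), dim=-1)         # (B_t, D)
    #   g = F.normalize(model.image_encode(imgs), dim=-1)       # (B_i, D)
    #   ranks = (q @ g.t()).argsort(dim=-1, descending=True)    # text-to-image ranking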
def forward(self, img, txt, stage): | |
if stage == '1st': | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
stage: 1st or 2nd stage | |
''' | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# text: b, 1024 | |
# image: b, 1024 | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
x = self.ADP(image) | |
x = self.ratio * x + (1-self.ratio) * image | |
# b, 1024 | |
# fq_t = self.FPN(vis, x) | |
# | |
# fv_t = self.gap(fq_t) | |
loss1 = self.IT_loss(x, text) | |
loss = loss1 | |
ft = text | |
fi = x | |
fv = None | |
elif stage == '2nd': | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
stage: 1st or 2nd stage | |
''' | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# text: b, 1024 | |
# image: b, 1024 | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
x = self.ADP(image) | |
x = self.ratio * x + (1 - self.ratio) * image | |
# Construct multi-scale feats | |
vis_trans = [] | |
for i in range(len(self.ms_adaptor)): | |
x_ = rearrange( | |
vis[i], | |
"b (h w) c -> b c h w", | |
h=self.patch_emb, | |
w=self.patch_emb, | |
).contiguous() | |
feats = self.ms_adaptor[i](x_) | |
vis_trans.append(feats) | |
# fq = self.FPN(vis, x_t) | |
fv_t = self.FPN(vis_trans[1:], x, False) | |
# fv_t = self.gap(fq_t) | |
# b, 1024 | |
loss2 = self.IT_loss(fv_t, text) | |
loss = (loss2) | |
fv = fv_t | |
ft = text | |
fi = x | |
return loss, fv, fi, ft | |
def visualize(self, img): | |
vis, image = self.backbone.encode_image(img) | |
x = self.ADP(image) | |
x = self.ratio * x + (1 - self.ratio) * image | |
# Construct multi-scale feats | |
vis_trans = [] | |
for i in range(len(self.ms_adaptor)): | |
x_ = rearrange( | |
vis[i], | |
"b (h w) c -> b c h w", | |
h=self.patch_emb, | |
w=self.patch_emb, | |
).contiguous() | |
feats = self.ms_adaptor[i](x_) | |
vis_trans.append(feats) | |
fv_t = self.FPN(vis_trans[1:], x, True) | |
return vis, fv_t | |
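

# Hedged usage sketch for the two-stage CISEN training scheme above. Assumptions: the
# function name and the way losses are consumed are illustrative only; `img` and `txt`
# are already-batched image tensors and CLIP token ids, matching forward()'s inputs.
def demo_two_stage_losses(model, img, txt):
    # stage '1st' supervises the adapted global image feature against the text feature;
    # stage '2nd' supervises the multi-scale FPN feature fv against the same text feature.
    loss_stage1, _, fi, ft = model(img, txt, '1st')
    loss_stage2, fv, _, _ = model(img, txt, '2nd')
    return loss_stage1, loss_stage2, fv, fi, ft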


# NOTE: this second CISEN_vit definition shadows the class of the same name defined
# earlier in this module (differing only in how ViTFPN is constructed: cfg rather than
# image_resolution), so it is the one importers actually get.
class CISEN_vit(nn.Module, PyTorchModelHubMixin):
def __init__(self, cfg): | |
super().__init__() | |
# Vision & Text Encoder & Label Encoder | |
clip_model = torch.jit.load(cfg.clip_pretrain, | |
map_location="cpu").eval() | |
backbone, image_resolution, vision_heads, embed_dim, vision_width, patch_size = build_model(clip_model.state_dict(), cfg.word_len) | |
self.backbone = backbone.float() | |
self.patch_emb = image_resolution // patch_size | |
cfg.image_resolution = image_resolution | |
cfg.input_size = image_resolution | |
cfg.heads = vision_heads // 32 | |
cfg.emb_dim = vision_width | |
cfg.output_dim = embed_dim | |
# multi-scale adapter | |
# Multi-Modal FPN | |
self.FPN = ViTFPN(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out) | |
        # Fine-grained Fusion
# self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers, | |
# d_model=cfg.vis_dim, | |
# nhead=cfg.num_head, | |
# dim_ffn=cfg.dim_ffn, | |
# dropout=cfg.dropout, | |
# return_intermediate=cfg.intermediate) | |
# image-text transformer | |
# self.trans = nn.Linear(1024, 1024) | |
self.ADP = Adapter(cfg.output_dim, 4) | |
# parameter | |
self.ratio = cfg.ratio | |
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) | |
self.share_temperature = True | |
self.ce = nn.CrossEntropyLoss() | |
self.ms_adaptor = nn.ModuleList( | |
[ | |
nn.Sequential( | |
nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2), | |
nn.GroupNorm(32, cfg.emb_dim), | |
nn.GELU(), | |
nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2), | |
), | |
nn.Sequential( | |
nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2), | |
), | |
nn.Sequential( | |
nn.Identity(), | |
), | |
nn.Sequential( | |
nn.MaxPool2d(2), | |
), | |
] | |
) | |
self.ms_adaptor.apply(self.init_adaptor) | |
def init_adaptor(self, m): | |
if isinstance(m, nn.Conv2d): | |
lecun_normal_(m.weight) | |
if m.bias is not None: | |
nn.init.constant_(m.bias, 0) | |
elif isinstance(m, nn.GroupNorm): | |
nn.init.constant_(m.bias, 0) | |
nn.init.constant_(m.weight, 1.0) | |
elif isinstance(m, nn.ConvTranspose2d): | |
lecun_normal_(m.weight) | |
if m.bias is not None: | |
nn.init.zeros_(m.bias) | |
# self.fc = nn.Linear(512, cfg.num_classes) | |
def IT_loss(self, image_features, text_features): | |
# b, 1024 / b, 1024 | |
batch = image_features.shape[0] | |
# # normalized features | |
image_features = image_features / image_features.norm(dim=-1, | |
keepdim=True) | |
text_features = text_features / text_features.norm(dim=-1, | |
keepdim=True) | |
# cosine similarity as logits | |
logit_scale = self.logit_scale.exp() | |
logits_per_image = logit_scale * image_features @ text_features.t() | |
logits_per_text = logits_per_image.t() | |
# shape = [global_batch_size, global_batch_size] | |
contrastive_labels = torch.arange(batch).to(logits_per_image.device) | |
contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5 | |
return contrastive_loss | |
def forward(self, img, txt, stage): | |
if stage == '1st': | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
stage: 1st or 2nd stage | |
''' | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# text: b, 1024 | |
# image: b, 1024 | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
x = self.ADP(image) | |
x = self.ratio * x + (1-self.ratio) * image | |
# b, 1024 | |
# fq_t = self.FPN(vis, x) | |
# | |
# fv_t = self.gap(fq_t) | |
loss1 = self.IT_loss(x, text) | |
loss = loss1 | |
ft = text | |
fi = x | |
fv = None | |
elif stage == '2nd': | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
stage: 1st or 2nd stage | |
''' | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# text: b, 1024 | |
# image: b, 1024 | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
x = self.ADP(image) | |
x = self.ratio * x + (1 - self.ratio) * image | |
# Construct multi-scale feats | |
vis_trans = [] | |
for i in range(len(self.ms_adaptor)): | |
x_ = rearrange( | |
vis[i], | |
"b (h w) c -> b c h w", | |
h=self.patch_emb, | |
w=self.patch_emb, | |
).contiguous() | |
feats = self.ms_adaptor[i](x_) | |
vis_trans.append(feats) | |
# fq = self.FPN(vis, x_t) | |
fv_t = self.FPN(vis_trans[1:], x, False) | |
# fv_t = self.gap(fq_t) | |
# b, 1024 | |
loss2 = self.IT_loss(fv_t, text) | |
loss = (loss2) | |
fv = fv_t | |
ft = text | |
fi = x | |
return loss, fv, fi, ft | |
def visualize(self, img, txt): | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
x = self.ADP(image) | |
x = self.ratio * x + (1 - self.ratio) * image | |
# Construct multi-scale feats | |
vis_trans = [] | |
for i in range(len(self.ms_adaptor)): | |
x_ = rearrange( | |
vis[i], | |
"b (h w) c -> b c h w", | |
h=self.patch_emb, | |
w=self.patch_emb, | |
).contiguous() | |
feats = self.ms_adaptor[i](x_) | |
vis_trans.append(feats) | |
# fq = self.FPN(vis, x_t) | |
fv_t = self.FPN(vis_trans[1:], x, True) | |
ft_t = self.FPN(vis_trans[1:], text, True) | |
return vis, fv_t, ft_t | |
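

# Illustration of the multi-scale adaptor used by the ViT variants above. Assumptions:
# the 7x7 patch grid and emb_dim=768 are placeholder values, and the function name is
# hypothetical; the point is that the four branches resample the token map to strides
# of 4x, 2x, 1x and 0.5x of the patch grid before the FPN consumes them.
def _demo_ms_adaptor_shapes():
    emb_dim, grid = 768, 7
    tokens = torch.randn(2, grid * grid, emb_dim)                      # b, (h w), c
    fmap = rearrange(tokens, "b (h w) c -> b c h w", h=grid, w=grid)
    adaptor = nn.ModuleList([
        nn.Sequential(nn.ConvTranspose2d(emb_dim, emb_dim, 2, 2),
                      nn.GroupNorm(32, emb_dim), nn.GELU(),
                      nn.ConvTranspose2d(emb_dim, emb_dim, 2, 2)),     # 4x upsample
        nn.Sequential(nn.ConvTranspose2d(emb_dim, emb_dim, 2, 2)),     # 2x upsample
        nn.Sequential(nn.Identity()),                                  # identity
        nn.Sequential(nn.MaxPool2d(2)),                                # 2x downsample
    ])
    # [(2, 768, 28, 28), (2, 768, 14, 14), (2, 768, 7, 7), (2, 768, 3, 3)]
    return [branch(fmap).shape for branch in adaptor]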
class CISEN_rsvit_classification(nn.Module): | |
def __init__(self, cfg): | |
super().__init__() | |
# Vision & Text Encoder & Label Encoder | |
clip_model = torch.load(cfg.clip_pretrain, | |
map_location="cpu") | |
backbone, image_resolution, vision_heads, embed_dim, vision_width, patch_size = build_model(clip_model, cfg.word_len) | |
self.backbone = backbone.float() | |
self.patch_emb = image_resolution // patch_size | |
num_classes_fc = 512 | |
num_classes_output = 10 | |
self.num_classes_fc = num_classes_fc # Number of classes for fully connected layer | |
self.num_classes_output = num_classes_output # Number of classes for output layer | |
# Add a fully connected layer | |
self.fc = nn.Linear(in_features=cfg.vis_dim, out_features=num_classes_fc) | |
# Add an output layer for multi-label classification | |
self.output_layer = nn.Linear(in_features=num_classes_fc, out_features=num_classes_output) | |
self.criterion = nn.BCEWithLogitsLoss() | |
cfg.image_resolution = image_resolution | |
cfg.input_size = image_resolution | |
cfg.heads = vision_heads // 32 | |
cfg.emb_dim = vision_width | |
cfg.output_dim = embed_dim | |
def IT_loss(self, labels, labels_pre): | |
labels = labels.squeeze(1) | |
loss = self.criterion(labels_pre, labels) | |
return loss | |
def forward(self, img, labels): | |
_, image_features = self.backbone.encode_image(img) | |
# Fully connected layer | |
fc_output = self.fc(image_features) | |
# Apply ReLU activation function | |
fc_output = F.relu(fc_output) | |
# Output layer for multi-label classification | |
labels_pre = self.output_layer(fc_output) | |
loss2 = self.IT_loss(labels, labels_pre) | |
return labels_pre, loss2 | |
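

# Minimal multi-label training-step sketch for the classification head above.
# Assumptions: `optimiser` and the (img, labels) shapes -- labels as (b, 1, 10)
# multi-hot floats -- are illustrative and not part of the original code.
def demo_classification_step(model, img, labels, optimiser):
    labels_pre, loss = model(img, labels.float())
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    # BCEWithLogitsLoss is applied to raw logits, so probabilities need a sigmoid here
    return labels_pre.sigmoid(), loss.item()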
class CISEN_new(nn.Module): | |
def __init__(self, cfg): | |
super().__init__() | |
# Vision & Text Encoder & Label Encoder | |
clip_model = torch.jit.load(cfg.clip_pretrain, | |
map_location="cpu").eval() | |
backbone, image_resolution, vision_heads, embed_dim, vision_width, _ = build_model(clip_model.state_dict(), cfg.word_len) | |
self.backbone = backbone.float() | |
cfg.input_size = image_resolution | |
cfg.heads = vision_heads | |
cfg.emb_dim = vision_width * 32 | |
cfg.output_dim = embed_dim | |
# Multi-Modal FPN | |
self.FPN = FPN(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out) | |
        # Fine-grained Fusion
# self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers, | |
# d_model=cfg.vis_dim, | |
# nhead=cfg.num_head, | |
# dim_ffn=cfg.dim_ffn, | |
# dropout=cfg.dropout, | |
# return_intermediate=cfg.intermediate) | |
# image-text transformer | |
# self.trans = nn.Linear(1024, 1024) | |
self.ADP = Adapter(cfg.output_dim, 4) | |
self.gap = GAP((1,1)) | |
# parameter | |
self.ratio = cfg.ratio | |
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) | |
self.share_temperature = True | |
self.margin = 1 | |
self.eps = 1e-3 | |
self.ce = nn.CrossEntropyLoss() | |
#1st stage | |
self.lamda1 = cfg.lamda1 | |
self.lamda2 = cfg.lamda2 | |
self.avg = nn.AdaptiveAvgPool2d((1,1)) | |
# self.fc = nn.Linear(512, cfg.num_classes) | |
def IT_loss(self, image_features, text_features): | |
# b, 1024 / b, 1024 | |
batch = image_features.shape[0] | |
# # normalized features | |
image_features = image_features / image_features.norm(dim=-1, | |
keepdim=True) | |
text_features = text_features / text_features.norm(dim=-1, | |
keepdim=True) | |
# cosine similarity as logits | |
logit_scale = self.logit_scale.exp() | |
logits_per_image = logit_scale * image_features @ text_features.t() | |
logits_per_text = logits_per_image.t() | |
# shape = [global_batch_size, global_batch_size] | |
contrastive_labels = torch.arange(batch).to(logits_per_image.device) | |
contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5 | |
return contrastive_loss | |
def forward(self, img, txt, stage): | |
if stage == '1st': | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
stage: 1st or 2nd stage | |
''' | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# text: b, 1024 | |
# image: b, 1024 | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
x = self.ADP(image) | |
x = self.ratio * x + (1-self.ratio) * image | |
# b, 1024 | |
# fq_t = self.FPN(vis, x) | |
# | |
# fv_t = self.gap(fq_t) | |
loss1 = self.IT_loss(x, text) | |
loss = loss1 | |
ft = text | |
fi = x | |
fv = None | |
elif stage == '2nd': | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
stage: 1st or 2nd stage | |
''' | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# text: b, 1024 | |
# image: b, 1024 | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
x = self.ADP(image) | |
x = self.ratio * x + (1 - self.ratio) * image | |
# x_t = self.trans(x) | |
# fq = self.FPN(vis, x_t) | |
fq_t = self.FPN(vis, x) | |
fv_t = self.gap(fq_t) | |
# b, 1024 | |
loss2 = self.IT_loss(fv_t, text) | |
loss = (loss2) | |
fv = fv_t | |
ft = text | |
fi = x | |
elif stage == '3rd': | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
stage: 1st or 2nd stage | |
''' | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# text: b, 1024 | |
# image: b, 1024 | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
x = self.ADP(text) | |
ratio = 0.2 | |
x = ratio * x + (1 - ratio) * text | |
# x_t = self.trans(x) | |
# fq = self.FPN(vis, x_t) | |
# b, 1024 | |
loss1 = self.IT_loss(image, x) | |
loss = loss1 | |
fv = None | |
ft = x | |
fi = image | |
elif stage == '4th': | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
stage: 1st or 2nd stage | |
''' | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# text: b, 1024 | |
# image: b, 1024 | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
# x = self.ADP(image) | |
# ratio = 0.2 | |
# x = ratio * x + (1 - ratio) * text | |
fq_t = self.FPN(vis, image) | |
fv_t = self.gap(fq_t) | |
ratio_1 = 0.2 | |
# b, 1024 | |
loss2 = self.IT_loss(fv_t, text) | |
loss = loss2 | |
fv = fv_t | |
fi = None | |
ft = text | |
elif stage == '5th': | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
stage: 1st or 2nd stage | |
''' | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# text: b, 1024 | |
# image: b, 1024 | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
x = self.ADP(image) | |
ratio = 0.2 | |
x = ratio * x + (1 - ratio) * image | |
            # NOTE: ADP_t (a text-side adapter) is used here but is never created in
            # __init__, so stage '5th' assumes it is attached to the model elsewhere.
            y = self.ADP_t(text)
            ratio_1 = 0.2
            # residual blend of the adapted and raw text features
            y = ratio_1 * y + (1 - ratio_1) * text
fq_t = self.FPN(vis, image) | |
fv_t = self.gap(fq_t) | |
# b, 1024 | |
loss2 = self.IT_loss(fv_t, y) | |
loss = loss2 | |
fv = fv_t | |
fi = x | |
ft = y | |
return loss, fv, fi, ft | |
class CISEN_lclip(nn.Module): | |
def __init__(self, cfg): | |
super().__init__() | |
# Vision & Text Encoder & Label Encoder | |
clip_model = torch.load(cfg.clip_pretrain, | |
map_location="cpu") | |
# print(type(clip_model)) | |
backbone, image_resolution, vision_heads, embed_dim, vision_width, _ = build_lclip_model(clip_model, load_from_clip=True) | |
self.backbone = backbone.float() | |
cfg.input_size = image_resolution | |
cfg.heads = vision_heads // 32 | |
cfg.emb_dim = vision_width | |
cfg.output_dim = embed_dim | |
# Multi-Modal FPN | |
self.FPN = FPN(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out) | |
        # Fine-grained Fusion
# self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers, | |
# d_model=cfg.vis_dim, | |
# nhead=cfg.num_head, | |
# dim_ffn=cfg.dim_ffn, | |
# dropout=cfg.dropout, | |
# return_intermediate=cfg.intermediate) | |
# image-text transformer | |
# self.trans = nn.Linear(1024, 1024) | |
self.ADP = Adapter(cfg.output_dim, 4) | |
self.gap = GAP((1,1)) | |
# parameter | |
self.ratio = cfg.ratio | |
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) | |
self.share_temperature = True | |
self.margin = 1 | |
self.eps = 1e-3 | |
self.ce = nn.CrossEntropyLoss() | |
#1st stage | |
self.lamda1 = cfg.lamda1 | |
self.lamda2 = cfg.lamda2 | |
self.avg = nn.AdaptiveAvgPool2d((1,1)) | |
# self.fc = nn.Linear(512, cfg.num_classes) | |
def IT_loss(self, image_features, text_features): | |
# b, 1024 / b, 1024 | |
batch = image_features.shape[0] | |
# # normalized features | |
image_features = image_features / image_features.norm(dim=-1, | |
keepdim=True) | |
text_features = text_features / text_features.norm(dim=-1, | |
keepdim=True) | |
# cosine similarity as logits | |
logit_scale = self.logit_scale.exp() | |
logits_per_image = logit_scale * image_features @ text_features.t() | |
logits_per_text = logits_per_image.t() | |
# shape = [global_batch_size, global_batch_size] | |
contrastive_labels = torch.arange(batch).to(logits_per_image.device) | |
contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5 | |
return contrastive_loss | |
def forward(self, img, txt, stage): | |
if stage == '1st': | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
stage: 1st or 2nd stage | |
''' | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# text: b, 1024 | |
# image: b, 1024 | |
vis, image = self.backbone.encode_image(img) | |
text = self.backbone.encode_text(txt) | |
x = self.ADP(image) | |
x = self.ratio * x + (1-self.ratio) * image | |
# b, 1024 | |
# fq_t = self.FPN(vis, x) | |
# | |
# fv_t = self.gap(fq_t) | |
loss1 = self.IT_loss(x, text) | |
loss = loss1 | |
ft = text | |
fi = x | |
fv = None | |
elif stage == '2nd': | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
stage: 1st or 2nd stage | |
''' | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# text: b, 1024 | |
# image: b, 1024 | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
x = self.ADP(image) | |
x = self.ratio * x + (1 - self.ratio) * image | |
# x_t = self.trans(x) | |
# fq = self.FPN(vis, x_t) | |
fq_t = self.FPN(vis, x) | |
fv_t = self.gap(fq_t) | |
# b, 1024 | |
loss2 = self.IT_loss(fv_t, text) | |
loss = (loss2) | |
fv = fv_t | |
ft = text | |
fi = x | |
elif stage == '3rd': | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
stage: 1st or 2nd stage | |
''' | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# text: b, 1024 | |
# image: b, 1024 | |
vis, image = self.backbone.encode_image(img) | |
text = self.backbone.encode_text(txt) | |
x = self.ADP(text) | |
ratio = 0.2 | |
x = ratio * x + (1 - ratio) * text | |
# x_t = self.trans(x) | |
# fq = self.FPN(vis, x_t) | |
# b, 1024 | |
loss1 = self.IT_loss(image, x) | |
loss = loss1 | |
fv = None | |
ft = x | |
fi = image | |
elif stage == '4th': | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
stage: 1st or 2nd stage | |
''' | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# text: b, 1024 | |
# image: b, 1024 | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
# x = self.ADP(image) | |
# ratio = 0.2 | |
# x = ratio * x + (1 - ratio) * text | |
fq_t = self.FPN(vis, image) | |
fv_t = self.gap(fq_t) | |
ratio_1 = 0.2 | |
# b, 1024 | |
loss2 = self.IT_loss(fv_t, text) | |
loss = loss2 | |
fv = fv_t | |
fi = None | |
ft = text | |
elif stage == '5th': | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
stage: 1st or 2nd stage | |
''' | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# text: b, 1024 | |
# image: b, 1024 | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
x = self.ADP(image) | |
ratio = 0.2 | |
x = ratio * x + (1 - ratio) * image | |
            y = self.ADP_t(text)   # NOTE: ADP_t is assumed to be defined externally (see CISEN_new)
            ratio_1 = 0.2
            y = ratio_1 * y + (1 - ratio_1) * text
fq_t = self.FPN(vis, image) | |
fv_t = self.gap(fq_t) | |
# b, 1024 | |
loss2 = self.IT_loss(fv_t, y) | |
loss = loss2 | |
fv = fv_t | |
fi = x | |
ft = y | |
return loss, fv, fi, ft | |
class GeoRSCLIP(nn.Module): | |
def __init__(self, cfg): | |
super().__init__() | |
# Vision & Text Encoder & Label Encoder | |
clip_model = torch.load(cfg.clip_pretrain, | |
map_location="cpu") | |
backbone, image_resolution, vision_heads, embed_dim, vision_width, patch_size = build_model(clip_model, cfg.word_len) | |
self.backbone = backbone.float() | |
def forward(self, img, txt, stage): | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# text: b, 1024 | |
# image: b, 1024 | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
loss = None | |
ft = text | |
fi = image | |
fv = None | |
return loss, fv, fi, ft | |
class CISEN(nn.Module): | |
def __init__(self, cfg): | |
super().__init__() | |
# Vision & Text Encoder & Label Encoder | |
clip_model = torch.jit.load(cfg.clip_pretrain, | |
map_location="cpu").eval() | |
        # build_model returns a tuple whose first element is the CLIP backbone.
        backbone, *_ = build_model(clip_model.state_dict(), cfg.word_len)
        self.backbone = backbone.float()
# Multi-Modal FPN | |
self.FPN = FPN(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out) | |
        # Fine-grained Fusion
self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers, | |
d_model=cfg.vis_dim, | |
nhead=cfg.num_head, | |
dim_ffn=cfg.dim_ffn, | |
dropout=cfg.dropout, | |
return_intermediate=cfg.intermediate) | |
        # adaptive aggregation
self.ASFF = AdaptiveSpatialFeatureFusion(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out) | |
# text projector | |
self.projT = Text_Projector(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out) | |
# image projector | |
# self.projI = Image_Projector(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out) | |
# parameter | |
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) | |
self.multi_label_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) | |
self.share_temperature = True | |
self.margin = 1 | |
self.eps = 1e-3 | |
self.ce = nn.CrossEntropyLoss() | |
#1st stage | |
self.lamda1 = cfg.lamda1 | |
self.lamda2 = cfg.lamda2 | |
self.beta1 = cfg.beta1 | |
self.beta2 = cfg.beta2 | |
self.avg = nn.AdaptiveAvgPool2d((1,1)) | |
# self.fc = nn.Linear(512, cfg.num_classes) | |
#2nd stage | |
self.pos_samples = cfg.pos_samples | |
self.neg_samples = cfg.neg_samples | |
def IT_loss(self, image_features, text_features): | |
# b, 1024 / b, 1024 | |
batch = image_features.shape[0] | |
# # normalized features | |
image_features = image_features / image_features.norm(dim=-1, | |
keepdim=True) | |
text_features = text_features / text_features.norm(dim=-1, | |
keepdim=True) | |
# cosine similarity as logits | |
logit_scale = self.logit_scale.exp() | |
logits_per_image = logit_scale * image_features @ text_features.t() | |
logits_per_text = logits_per_image.t() | |
# shape = [global_batch_size, global_batch_size] | |
contrastive_labels = torch.arange(batch).to(logits_per_image.device) | |
contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5 | |
return contrastive_loss | |
def IET_loss(self, image_features, text_features, pos_samples, beta): | |
# b, 1024 / b, 1024 | |
# # normalized features | |
image_features = [image_feature / image_feature.norm(dim=-1, | |
keepdim=True) for image_feature in image_features] | |
text_features = text_features / text_features.norm(dim=-1, | |
keepdim=True) | |
# cosine similarity as logits | |
logit_scale = self.logit_scale.exp() | |
# logits_per_image = [logit_scale * image_feature @ text_features.t() for image_feature in image_features] | |
logits_per_image = [logit_scale * torch.sum(torch.mul(image_feature, text_features),1) for image_feature in image_features] | |
logits_per_image = torch.stack(logits_per_image).t() | |
b = logits_per_image.shape[0] | |
loss1 = torch.norm(text_features - image_features[0]) | |
positive_tagsT = torch.zeros(b,len(image_features)).to(text_features.device) | |
negative_tagsT = torch.zeros(b,len(image_features)).to(text_features.device) | |
positive_tagsT[:, 0 : pos_samples + 1] = 1 | |
negative_tagsT[:, pos_samples + 1 : -1] = 1 | |
maskT = positive_tagsT.unsqueeze(1) * negative_tagsT.unsqueeze(-1) | |
pos_score_matT = logits_per_image * positive_tagsT | |
neg_score_matT = logits_per_image * negative_tagsT | |
IW_pos3T = pos_score_matT.unsqueeze(1) | |
IW_neg3T = neg_score_matT.unsqueeze(-1) | |
OT = 1 + IW_neg3T - IW_pos3T | |
O_maskT = maskT * OT | |
diffT = torch.clamp(O_maskT, 0) | |
violationT = torch.sign(diffT).sum(1).sum(1) | |
diffT = diffT.sum(1).sum(1) | |
lossT = torch.mean(diffT / (violationT + self.eps)) | |
loss = beta * loss1 + lossT | |
return loss | |
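    # The method above is a pairwise ranking hinge: the first pos_samples+1 columns are
    # treated as positive captions and must score at least a margin of 1 above the
    # columns marked negative; diffT / (violationT + eps) averages the hinge only over
    # violated (negative, positive) pairs, and beta * loss1 additionally pulls the text
    # feature towards the first enhanced image feature.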
def test_IET_loss(self, image_features, text_features, pos_samples, beta1, beta2): | |
# text_features: enhanced_features | |
# b, 1024 / b, 1024 | |
# # normalized features | |
image_features = image_features / image_features.norm(dim=-1, | |
keepdim=True) | |
text_features = text_features / text_features.norm(dim=-1, | |
keepdim=True) | |
image_features = image_features.unsqueeze(1) | |
# cosine similarity as logits | |
logit_scale = self.logit_scale.exp() | |
# image_features = image_features.expand(-1, text_features.shape[1], -1) | |
logits_per_image = logit_scale * torch.matmul(image_features, text_features.transpose(1, 2)) | |
logits_per_image = logits_per_image.squeeze(1) | |
# logits_per_image = logit_scale * image_features @ text_features.t() | |
# logits_per_image = [logit_scale * image_feature @ text_features.t() for image_feature in image_features] | |
b = logits_per_image.shape[0] | |
# loss1 = torch.norm(text_features[:, 0, :] - image_features.squeeze(1)) | |
positive_tagsT = torch.zeros(b, text_features.shape[1]).to(text_features.device) | |
negative_tagsT = torch.zeros(b, text_features.shape[1]).to(text_features.device) | |
positive_tagsT[:, 0 : pos_samples + 1] = 1 | |
negative_tagsT[:, pos_samples + 1 : -1] = 1 | |
maskT = positive_tagsT.unsqueeze(1) * negative_tagsT.unsqueeze(-1) | |
pos_score_matT = logits_per_image * positive_tagsT | |
neg_score_matT = logits_per_image * negative_tagsT | |
IW_pos3T = pos_score_matT.unsqueeze(1) | |
IW_neg3T = neg_score_matT.unsqueeze(-1) | |
OT = 1 + IW_neg3T - IW_pos3T | |
O_maskT = maskT * OT | |
diffT = torch.clamp(O_maskT, 0) | |
violationT = torch.sign(diffT).sum(1).sum(1) | |
diffT = diffT.sum(1).sum(1) | |
lossT = torch.mean(diffT / (violationT + self.eps)) | |
# loss = beta1 * loss1 + beta2 * lossT | |
loss = lossT | |
return loss | |
def test_IT_loss(self, image_features, text_features): | |
# b, 1024 / b, 1024 | |
batch = image_features.shape[0] | |
# # normalized features | |
image_features = image_features / image_features.norm(dim=-1, | |
keepdim=True) | |
text_features = text_features / text_features.norm(dim=-1, | |
keepdim=True) | |
image_features = image_features.unsqueeze(1) | |
# cosine similarity as logits | |
logit_scale = self.logit_scale.exp() | |
logits_per_image = logit_scale * torch.matmul(image_features, text_features.transpose(1, 2)) | |
logits_per_image = logits_per_image.squeeze(1) | |
# shape = [global_batch_size, global_batch_size] | |
contrastive_labels = torch.arange(batch).to(logits_per_image.device) | |
contrastive_loss = self.ce(logits_per_image, contrastive_labels) | |
return contrastive_loss | |
def test_forward(self, img, txt): | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
stage: 1st or 2nd stage | |
''' | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# state: b, 1024 | |
# image: b, 512 | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
fq = self.FPN(vis, text) | |
b, c, h, w = fq.size() | |
# b, 512, 14, 14 | |
ff = self.FGFusion(fq, word, pad_mask) | |
ff = ff.reshape(b, c, h, w) | |
f2 = self.avg(ff) | |
fi = image.unsqueeze(-1).unsqueeze(-1) | |
fv = self.ASFF(fi, f2) | |
fi = fi.squeeze(-1).squeeze(-1) | |
# b, 1024 | |
ft = self.projT(text) | |
loss1 = self.IT_loss(fi, ft) | |
loss2 = self.IT_loss(fv, ft) | |
loss = self.lamda1 * loss1 + self.lamda2 * loss2 | |
return loss, fv, ft, fi | |
def forward(self, img, txt, stage): | |
if stage == '1st': | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
stage: 1st or 2nd stage | |
''' | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# state: b, 1024 | |
# image: b, 512 | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
fq = self.FPN(vis, text) | |
b, c, h, w = fq.size() | |
# b, 512, 14, 14 | |
ff = self.FGFusion(fq, word, pad_mask) | |
ff = ff.reshape(b, c, h, w) | |
f2 = self.avg(ff) | |
fi = image.unsqueeze(-1).unsqueeze(-1) | |
fv = self.ASFF(fi, f2) | |
fi = fi.squeeze(-1).squeeze(-1) | |
# b, 1024 | |
ft = self.projT(text) | |
loss1 = self.IT_loss(fi, ft) | |
loss2 = self.IT_loss(fv, ft) | |
loss = self.lamda1 * loss1 + self.lamda2 * loss2 | |
elif stage == '2nd': | |
""" | |
txt: b, num, words | |
img: b, 3, h, w | |
""" | |
# txt = b * num, word | |
b, num, l = txt.shape[0], txt.shape[1], txt.shape[2] | |
txt = txt.view(-1, txt.size(-1)) | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
b = img.shape[0] | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
fq = self.FPN(vis, text) | |
# b, 512, 14, 14 (C4) | |
b, c, h, w = fq.size() | |
# b, 512, 14, 14 | |
ff = self.FGFusion(fq, word, pad_mask) | |
ff = ff.reshape(b, c, h, w) | |
f2 = self.avg(ff) | |
fi = image.unsqueeze(-1).unsqueeze(-1) | |
fi_ = fi.repeat(int(f2.shape[0] / fi.shape[0]), 1, 1, 1) | |
fv = self.ASFF(fi_, f2) | |
fi = fi.squeeze(-1).squeeze(-1) | |
# fi_ = fi_.squeeze(-1).squeeze(-1) | |
# b, 1024 | |
ft = text.view(img.shape[0], int(text.shape[0] / img.shape[0]), -1)[:, 0, :] | |
fv = fv.view(ft.shape[0], int(text.shape[0] / ft.shape[0]), fv.shape[1]) | |
loss = self.test_IET_loss(fi, fv, self.pos_samples, self.beta1, self.beta2) | |
elif stage == 'test': | |
""" | |
txt: b, num, words | |
img: b, 3, h, w | |
""" | |
txt = txt.permute(1, 0, 2) | |
# txt = b * num, word | |
# txt = txt.view(-1, txt.size(-1)) | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# state: b, 1024 | |
# image: b, 512 | |
b = img.shape[0] | |
words = [] | |
texts = [] | |
vis, image = self.backbone.encode_image(img) | |
for i in range(txt.shape[0]): | |
word, text = self.backbone.encode_text(txt[i]) | |
words.append(word) | |
texts.append(text) | |
fvn = [] | |
# b, 512, 14, 14 (C4) | |
for i in range(txt.shape[0]): | |
fq = self.FPN(vis, texts[i]) | |
b, c, h, w = fq.size() | |
# b, 512, 14, 14 | |
ff = self.FGFusion(fq, words[i], pad_mask[i, :, :]) | |
ff = ff.reshape(b, c, h, w) | |
f2 = self.avg(ff) | |
fi = image.unsqueeze(-1).unsqueeze(-1) | |
fv = self.ASFF(fi, f2) | |
fi = fi.squeeze(-1).squeeze(-1) | |
fvn.append(fv) | |
# b, 1024 | |
ft = self.projT(texts[0]) | |
            # __init__ defines beta1/beta2 rather than beta; beta1 is assumed to be the
            # intended ranking-loss weight here.
            loss = self.IET_loss(fvn, ft, self.pos_samples, self.beta1)
fv = fvn | |
        else:
            raise ValueError("stage should be either '1st', '2nd' or 'test'")
# labels = torch.ones(image.shape[0], image.shape[0]).to(image.device) | |
# labels[:,-1] = 0 | |
# labels[3, :] = 0 | |
# out = self.avg(fq) | |
# out = out.squeeze(-1).squeeze(-1) | |
# out = self.fc(out) | |
return loss, fv, fi, ft | |
class CRIS(nn.Module): | |
def __init__(self, cfg): | |
super().__init__() | |
# Vision & Text Encoder & Label Encoder | |
clip_model = torch.jit.load(cfg.clip_pretrain, | |
map_location="cpu").eval() | |
        # build_model returns (backbone, image_resolution, vision_heads, embed_dim,
        # vision_width, patch_size); only the backbone is used here.
        backbone, *_ = build_model(clip_model.state_dict(), cfg.word_len)
        self.backbone = backbone.float()
self.Label_encoder = build_promptlearner(clip_model.state_dict()).float() | |
self.Label_encoder.init_label_emb(cfg.label_path) | |
# Multi-Modal FPN | |
self.FPN = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out) | |
        # Fine-grained Fusion
self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers, | |
d_model=cfg.vis_dim, | |
nhead=cfg.num_head, | |
dim_ffn=cfg.dim_ffn, | |
dropout=cfg.dropout, | |
return_intermediate=cfg.intermediate) | |
        # adaptive aggregation
self.ASFF = AdaptiveSpatialFeatureFusion(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out) | |
# text projector | |
self.projT = Text_Projector(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out) | |
# parameter | |
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) | |
self.multi_label_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) | |
self.share_temperature = True | |
self.margin = 1 | |
self.eps = 1e-3 | |
self.ce = nn.CrossEntropyLoss() | |
self.avg = nn.AdaptiveAvgPool2d((1,1)) | |
self.fc = nn.Linear(512, cfg.num_classes) | |
def IT_loss(self, image_features, text_features): | |
# b, 1024 / b, 1024 | |
batch = image_features.shape[0] | |
# # normalized features | |
image_features = image_features / image_features.norm(dim=-1, | |
keepdim=True) | |
text_features = text_features / text_features.norm(dim=-1, | |
keepdim=True) | |
# cosine similarity as logits | |
logit_scale = self.logit_scale.exp() | |
logits_per_image = logit_scale * image_features @ text_features.t() | |
logits_per_text = logits_per_image.t() | |
# shape = [global_batch_size, global_batch_size] | |
contrastive_labels = torch.arange(batch).to(logits_per_image.device) | |
contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5 | |
return contrastive_loss | |
def IL_loss(self, image_features, label_features, labels): | |
# b, 1024 / K, 1024/ b, K | |
positive_tagsT = torch.clamp(labels,0.,1.) | |
negative_tagsT = torch.clamp(-labels,0.,1.) | |
maskT = positive_tagsT.unsqueeze(1) * negative_tagsT.unsqueeze(-1) | |
# normalized features | |
image_features = image_features / image_features.norm(dim=-1, | |
keepdim=True) | |
label_features = label_features / label_features.norm(dim=-1, | |
keepdim=True) | |
# cosine similarity as logits | |
logit_scale = self.multi_label_logit_scale.exp() | |
logits_per_image = logit_scale * image_features @ label_features.t() | |
# logits_per_label = logit_scale * label_features @ image_features.t() | |
pos_score_matT = logits_per_image * positive_tagsT | |
neg_score_matT = logits_per_image * negative_tagsT | |
IW_pos3T = pos_score_matT.unsqueeze(1) | |
IW_neg3T = neg_score_matT.unsqueeze(-1) | |
OT = self.margin + IW_neg3T - IW_pos3T | |
O_maskT = maskT * OT | |
diffT = torch.clamp(O_maskT, 0) | |
violationT = torch.sign(diffT).sum(1).sum(1) | |
diffT = diffT.sum(1).sum(1) | |
lossT = torch.mean(diffT / (violationT + self.eps)) | |
return lossT | |
def margin_loss(self, image_features, label_features, labels): | |
# b, 1024 / K, 1024/ b, K | |
# normalized features | |
image_features = image_features / image_features.norm(dim=-1, | |
keepdim=True) | |
label_features = label_features / label_features.norm(dim=-1, | |
keepdim=True) | |
# cosine similarity as logits | |
logit_scale = self.multi_label_logit_scale.exp() | |
logits_per_image = logit_scale * image_features @ label_features.t() | |
# logits_per_label = logit_scale * label_features @ image_features.t() | |
image_label_positive_pairs = logits_per_image * labels | |
image_label_mean_positive = image_label_positive_pairs.sum() / labels.sum() | |
image_label_negative_pairs = logits_per_image * (1 - labels) | |
image_label_mean_negative = image_label_negative_pairs.sum() / (logits_per_image.numel() - labels.sum() + self.eps) | |
contrastive_loss = torch.relu(self.margin - image_label_mean_positive + image_label_mean_negative) | |
return contrastive_loss | |
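    # margin_loss above is a coarser alternative to IL_loss: instead of ranking every
    # (positive, negative) pair, it hinges the gap between the mean positive and mean
    # negative image-label similarity against self.margin.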
def forward(self, img, txt, target=None): | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
''' | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# state: b, 1024 | |
# image: b, 512 | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
fl = self.Label_encoder(image.device) | |
# b, 512, 14, 14 (C4) | |
fq = self.FPN(vis, text) | |
b, c, h, w = fq.size() | |
# b, 512, 14, 14 | |
ff = self.FGFusion(fq, word, pad_mask) | |
# b, 512, 196 | |
ff = ff.reshape(b, c, h, w) | |
f2 = self.avg(ff) | |
# b, 1024 | |
f1 = image.unsqueeze(-1).unsqueeze(-1) | |
fv = self.ASFF(f1, f2) | |
# b, 1024 | |
ft = self.projT(text) | |
# labels = torch.ones(image.shape[0], image.shape[0]).to(image.device) | |
# labels[:,-1] = 0 | |
# labels[3, :] = 0 | |
loss1 = self.IT_loss(fv, ft) | |
loss2 = self.IL_loss(fv, fl, target) | |
loss = loss1 + loss2 | |
# out = self.avg(fq) | |
# out = out.squeeze(-1).squeeze(-1) | |
# out = self.fc(out) | |
return loss, fv, ft, fl | |
class zh_clip(nn.Module): | |
def __init__(self, cfg): | |
super().__init__() | |
# Vision & Text Encoder | |
clip_model = torch.jit.load(cfg.clip_pretrain, | |
map_location="cpu").eval() | |
self.backbone = build_modified_model(clip_model.state_dict(), cfg.word_len).float() | |
self.text_encoder = AutoModelForSequenceClassification.from_pretrained(cfg.chinese) | |
self.text_lin = nn.Linear(512, 1024) | |
# Multi-Modal FPN | |
self.neck = ViTFPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out) | |
# Decoder | |
self.avg = nn.AdaptiveAvgPool2d((1,1)) | |
self.fc = nn.Linear(512, cfg.num_classes) | |
def forward(self, img, word): | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
''' | |
# padding mask used in decoder | |
# vis: v1 / v2 / b, 49, 1024/ b, 196, 512 | |
# state: b, 1024 | |
# feat: f1 / f2 / b, 1024, 7, 7/ b, 1024, 7, 7 | |
# cls: c1 / c2 / b, 1024/ b, 512 | |
vis, feat, cls = self.backbone.encode_image(img) | |
state = self.text_encoder(word.squeeze(1)).logits | |
state = self.text_lin(state) | |
# b, 1024, 7, 7 (C5) | |
fq = self.neck(feat, state) | |
out = self.avg(fq) | |
out = out.squeeze(-1).squeeze(-1) | |
out = self.fc(out) | |
return out | |
class poi_clip(nn.Module): | |
def __init__(self, cfg): | |
super().__init__() | |
# Vision & Text Encoder | |
clip_model = torch.jit.load(cfg.clip_pretrain, | |
map_location="cpu").eval() | |
self.backbone = build_modified_model(clip_model.state_dict(), cfg.word_len).float() | |
self.text_encoder = AutoModelForSequenceClassification.from_pretrained(cfg.chinese) | |
self.text_lin = nn.Linear(512, 1024) | |
# Multi-Modal FPN | |
self.neck = ViTFPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out) | |
# Decoder | |
self.avg = nn.AdaptiveAvgPool2d((1,1)) | |
self.fc = nn.Linear(512, cfg.num_classes) | |
def forward(self, img, word): | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
''' | |
# padding mask used in decoder | |
# vis: v1 / v2 / b, 49, 1024/ b, 196, 512 | |
# state: b, 1024 | |
# feat: f1 / f2 / b, 1024, 7, 7/ b, 1024, 7, 7 | |
# cls: c1 / c2 / b, 1024/ b, 512 | |
vis, feat, cls = self.backbone.encode_image(img) | |
state = self.text_encoder(word.squeeze(1)).logits | |
state = self.text_lin(state) | |
# b, 1024, 7, 7 (C5) | |
fq = self.neck(feat, state) | |
out = self.avg(fq) | |
out = out.squeeze(-1).squeeze(-1) | |
out = self.fc(out) | |
return out | |
class Clip_hash_model(nn.Module): | |
def __init__(self, cfg): | |
super().__init__() | |
# Vision & Text Encoder | |
clip_model = torch.jit.load(cfg.clip_pretrain, | |
map_location="cpu").eval() | |
        # build_model returns a tuple; keep only the backbone.
        backbone, *_ = build_model(clip_model.state_dict(), cfg.word_len)
        self.backbone = backbone.float()
# Multi-Modal FPN | |
self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out) | |
# Decoder | |
self.avg = nn.AdaptiveAvgPool2d((1, 1)) | |
self.classifier = nn.Sequential( | |
nn.Linear(cfg.fpn_out[1], cfg.hash_dim, bias=True), | |
nn.Tanh(), | |
) | |
self.classifier2 = nn.Sequential( | |
nn.Linear(cfg.hash_dim, cfg.num_classes) | |
) | |
# Hash Module | |
self.image_module = nn.Sequential( | |
nn.Linear(cfg.img_dim, cfg.hidden_dim, bias=True), | |
nn.BatchNorm1d(cfg.hidden_dim), | |
nn.ReLU(True), | |
nn.Linear(cfg.hidden_dim, cfg.hash_dim, bias=True), | |
nn.Tanh() | |
) | |
self.text_module = nn.Sequential( | |
nn.Linear(cfg.txt_dim, cfg.hidden_dim, bias=True), | |
nn.BatchNorm1d(cfg.hidden_dim), | |
nn.ReLU(True), | |
nn.Linear(cfg.hidden_dim, cfg.hash_dim, bias=True), | |
nn.Tanh() | |
) | |
def forward(self, img, word, mask=None): | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
''' | |
pad_mask = torch.zeros_like(word).masked_fill_(word == 0, 1).bool() | |
# vis: C3 / C4 / C5 | |
# word: b, length, 512 | |
# state: b, 1024 | |
vis, image = self.backbone.encode_image(img) | |
word, state = self.backbone.encode_text(word) | |
# b, 512, 26, 26 (C4) | |
fq = self.neck(vis, state) | |
# out_hash: b, code_length | |
# res: b, classes | |
out = self.avg(fq) | |
out = out.squeeze(-1).squeeze(-1) | |
out_hash = self.classifier(out) | |
res = self.classifier2(out_hash) | |
# img_hash: b, code_length | |
# txt_hash: b, code_length | |
img_hash = self.image_module(image) | |
txt_hash = self.text_module(state) | |
return img_hash, txt_hash, out_hash, res | |
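
    # At retrieval time the tanh outputs above are usually binarised; this is an
    # assumption about downstream use, not something done in this file:
    #   img_code = torch.sign(img_hash)    # b, hash_dim in {-1, +1}
    #   txt_code = torch.sign(txt_hash)
    #   hamming = 0.5 * (img_code.shape[1] - img_code @ txt_code.t())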
class Clip_model(nn.Module): | |
def __init__(self, cfg): | |
super().__init__() | |
# Vision & Text Encoder | |
clip_model = torch.jit.load(cfg.clip_pretrain, | |
map_location="cpu").eval() | |
self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out) | |
self.avg = nn.AdaptiveAvgPool2d((1, 1)) | |
        # build_model returns a tuple; keep only the backbone.
        backbone, *_ = build_model(clip_model.state_dict(), cfg.word_len)
        self.backbone = backbone.float()
def forward(self, img, word, mask=None): | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
''' | |
# vis: C3 / C4 / C5 | |
# word: b, length, 512 | |
# state: b, 1024 | |
pad_mask = torch.zeros_like(word).masked_fill_(word == 0, 1).bool() | |
vis, image = self.backbone.encode_image(img) | |
word, state = self.backbone.encode_text(word) | |
f = self.neck(vis, state) | |
out = self.avg(f) | |
out = out.squeeze(-1).squeeze(-1) | |
image_features = image / image.norm(dim=-1, keepdim=True) | |
text_features = state / state.norm(dim=-1, keepdim=True) | |
# cosine similarity as logits | |
logit_scale = self.backbone.logit_scale.exp() | |
logits_per_image = logit_scale * image_features @ text_features.t() | |
logits_per_text = logits_per_image.t() | |
# shape = [global_batch_size, global_batch_size] | |
return logits_per_image, logits_per_text | |
class CISEN_rsvit_hug(nn.Module, PyTorchModelHubMixin): | |
def __init__(self, embed_dim, image_resolution, vision_layers, vision_width, | |
vision_patch_size, context_length, txt_length, vocab_size, | |
transformer_width, transformer_heads, transformer_layers, patch_size, | |
output_dim, ratio, emb_dim, fpn_in, fpn_out): | |
super().__init__() | |
# Vision & Text Encoder & Label Encoder | |
vision_heads = vision_width * 32 // 64 | |
backbone = CLIP(embed_dim, image_resolution, vision_layers, vision_width, | |
vision_patch_size, context_length, txt_length, vocab_size, | |
transformer_width, transformer_heads, transformer_layers) | |
self.backbone = backbone.float() | |
self.patch_emb = image_resolution // patch_size | |
self.FPN = ViTFPN(image_resolution, in_channels=fpn_in, out_channels=fpn_out) | |
self.ADP = Adapter(output_dim, 4) | |
# parameter | |
self.ratio = ratio | |
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) | |
self.share_temperature = True | |
self.ce = nn.CrossEntropyLoss() | |
self.ms_adaptor = nn.ModuleList( | |
[ | |
nn.Sequential( | |
nn.ConvTranspose2d(emb_dim, emb_dim, 2, 2), | |
nn.GroupNorm(32, emb_dim), | |
nn.GELU(), | |
nn.ConvTranspose2d(emb_dim, emb_dim, 2, 2), | |
), | |
nn.Sequential( | |
nn.ConvTranspose2d(emb_dim, emb_dim, 2, 2), | |
), | |
nn.Sequential( | |
nn.Identity(), | |
), | |
nn.Sequential( | |
nn.MaxPool2d(2), | |
), | |
] | |
) | |
self.ms_adaptor.apply(self.init_adaptor) | |
def init_adaptor(self, m): | |
if isinstance(m, nn.Conv2d): | |
lecun_normal_(m.weight) | |
if m.bias is not None: | |
nn.init.constant_(m.bias, 0) | |
elif isinstance(m, nn.GroupNorm): | |
nn.init.constant_(m.bias, 0) | |
nn.init.constant_(m.weight, 1.0) | |
elif isinstance(m, nn.ConvTranspose2d): | |
lecun_normal_(m.weight) | |
if m.bias is not None: | |
nn.init.zeros_(m.bias) | |
# self.fc = nn.Linear(512, cfg.num_classes) | |
def image_encode(self, img): | |
vis, image = self.backbone.encode_image(img) | |
x = self.ADP(image) | |
x = self.ratio * x + (1 - self.ratio) * image | |
return x | |
def text_encode(self, txt): | |
word, text = self.backbone.encode_text(txt) | |
return text | |
def forward(self, img, txt): | |
''' | |
img: b, 3, h, w | |
word: b, words | |
word_mask: b, words | |
mask: b, 1, h, w | |
stage: 1st or 2nd stage | |
''' | |
# padding mask used in decoder | |
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool() | |
# vis: C3 / C4 / C5 / b, 512, 28, 28/ b, 1024, 14, 14/ b, 1024, 7, 7 | |
# word: b, length, 512 | |
# text: b, 1024 | |
# image: b, 1024 | |
vis, image = self.backbone.encode_image(img) | |
word, text = self.backbone.encode_text(txt) | |
x = self.ADP(image) | |
x = self.ratio * x + (1 - self.ratio) * image | |
# Construct multi-scale feats | |
vis_trans = [] | |
for i in range(len(self.ms_adaptor)): | |
x_ = rearrange( | |
vis[i], | |
"b (h w) c -> b c h w", | |
h=self.patch_emb, | |
w=self.patch_emb, | |
).contiguous() | |
feats = self.ms_adaptor[i](x_) | |
vis_trans.append(feats) | |
# fq = self.FPN(vis, x_t) | |
fv_t = self.FPN(vis_trans[1:], x, False) | |
# fv_t = self.gap(fq_t) | |
# b, 1024 | |
fv = fv_t | |
ft = text | |
fi = x | |
return fv, fi, ft |
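

# Hedged end-to-end sketch for the Hub-exported model above. Assumptions: the repo id,
# input shapes and vocabulary size are placeholders; from_pretrained is provided by
# PyTorchModelHubMixin and rebuilds the model from the constructor kwargs stored with
# the checkpoint.
if __name__ == "__main__":
    # model = CISEN_rsvit_hug.from_pretrained("some-user/cisen-rsvit")  # hypothetical id
    # img = torch.randn(1, 3, 224, 224)
    # txt = torch.randint(0, 49408, (1, 77))  # CLIP-style token ids
    # fv, fi, ft = model(img, txt)
    pass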