Spaces:
Sleeping
Sleeping
import torch | |
from torch.autograd import Variable | |
import numpy as np | |
import torch.nn as nn | |
from torch.nn import functional as F | |
import math | |
def init_hash(dataloader, args, device="cuda"):
    """Allocate the hash-code buffers for every sample of ``dataloader.dataset``.

    :param dataloader: object whose ``.dataset`` supports ``len()``.
    :param args: namespace providing ``hash_dim`` (the code length).
    :param device: device the buffers are placed on. Defaults to ``"cuda"``,
        matching the previous hard-coded ``.cuda()`` behavior.
    :return: ``(B, H, Hi, Ht)``, each of shape ``(len(dataset), args.hash_dim)``.
        ``B`` holds the signs of standard-normal samples; ``H``, ``Hi`` and ``Ht``
        are all zeros.
    """
    dataset_size = len(dataloader.dataset)
    shape = (dataset_size, args.hash_dim)
    # Sample on the CPU and then move, so the CPU RNG stream (and therefore
    # seeding reproducibility) is the same as before.
    B = torch.randn(shape).sign().to(device, non_blocking=True)
    # sign() of an all-zero tensor is still all zeros, so it is not applied here.
    H = torch.zeros(shape, device=device)
    Hi = torch.zeros(shape, device=device)
    Ht = torch.zeros(shape, device=device)
    return B, H, Hi, Ht
def GenerateCode(model, data_loader, args, device="cuda"):
    """Binarize the model's hash outputs for every sample in ``data_loader``.

    Each batch yielded by ``data_loader`` is unpacked as
    ``(idx, image, text, label, target)``. ``model(image, text)`` must return a
    4-tuple ``(img_hash, txt_hash, output, output_s)``. The sign of each hash is
    written into the rows given by ``idx``.

    The forward passes run under ``torch.no_grad()`` so no autograd graph is built.
    The model's train/eval mode is not changed here.
    NOTE(review): callers presumably call ``model.eval()`` beforehand; confirm.

    :param model: callable ``(image, text) -> (img_hash, txt_hash, output, output_s)``.
    :param data_loader: iterable of batches that also exposes ``.dataset``
        (used for ``len()``).
    :param args: namespace providing ``hash_dim``.
    :param device: device inputs are moved to before the forward pass
        (default ``"cuda"``, as before).
    :return: ``(B, Bi, Bt)``, float32 NumPy arrays of shape
        ``(len(dataset), hash_dim)`` holding ``sign(output)``, ``sign(img_hash)``
        and ``sign(txt_hash)``. Rows not covered by any ``idx`` stay 0.
    """
    num_data = len(data_loader.dataset)
    B = np.zeros([num_data, args.hash_dim], dtype=np.float32)
    Bi = np.zeros([num_data, args.hash_dim], dtype=np.float32)
    Bt = np.zeros([num_data, args.hash_dim], dtype=np.float32)
    with torch.no_grad():
        for idx, image, text, _label, _target in data_loader:
            image = image.to(device, non_blocking=True)
            text = text.to(device, non_blocking=True)
            img_hash, txt_hash, output, _ = model(image, text)
            # Explicit NumPy index so fancy-indexing does not depend on
            # NumPy's handling of torch tensors.
            rows = np.asarray(idx)
            B[rows, :] = torch.sign(output).cpu().numpy()
            Bi[rows, :] = torch.sign(img_hash).cpu().numpy()
            Bt[rows, :] = torch.sign(txt_hash).cpu().numpy()
    return B, Bi, Bt
def CalcSim(batch_label, train_label):
    """Return a boolean matrix marking which sample pairs share a label.

    Entry ``[i, j]`` is True when row ``i`` of ``batch_label`` and row ``j`` of
    ``train_label`` have a positive inner product. For multi-hot 0/1 label
    vectors, that means they share at least one label.

    :param batch_label: 2-D tensor of shape (m, l).
    :param train_label: 2-D tensor of shape (n, l).
    :return: bool tensor of shape (m, n).
    """
    overlap = torch.mm(batch_label, train_label.t())
    return overlap > 0
# loss | |
def Logtrick(x):
    """Numerically stable softplus, ``log(1 + exp(x))``, computed element-wise.

    Uses the identity ``log(1 + e^x) = log(1 + e^{-|x|}) + max(x, 0)``. The
    exponent is never positive, so large ``x`` cannot overflow.

    The result stays on ``x``'s device. The previous implementation forced it
    onto CUDA and failed for CPU inputs.

    :param x: input tensor.
    :return: tensor with the same shape, device and dtype as ``x``.
    """
    return torch.log1p(torch.exp(-torch.abs(x))) + torch.clamp(x, min=0)
class NTXentLoss(nn.Module):
    """
    Normalized Temperature-scaled Cross-entropy Loss (NTXent Loss).

    Contains single-modal ("orig") and cross-modal ("cross") implementations.
    Both L2-normalize their inputs row-wise before computing similarities.

    :param temperature: divisor applied to every cosine similarity before exp().
    :param eps: small constant added to the denominator (and, in the "orig"
        variant, used as a clamp floor) to avoid division by zero / log(0).
    """

    def __init__(self, temperature=1, eps=1e-6):
        super().__init__()
        self.temperature = temperature
        self.eps = eps

    def forward(self, *args, type='orig'):
        """
        Dispatch to the requested NTXent variant.

        :param args: the two feature batches, forwarded positionally.
        :param type: 'orig' -> forward_orig, 'cross' -> forward_cross_modal,
            'both' -> tuple ``(orig_loss, cross_loss)``. The name shadows the
            builtin but is kept because callers pass it by keyword.
        :raises ValueError: for any other ``type`` (a subclass of ``Exception``,
            so existing ``except Exception`` handlers still catch it).
        """
        if type == 'cross':
            return self.forward_cross_modal(*args)
        if type == 'orig':
            return self.forward_orig(*args)
        if type == 'both':
            return self.forward_orig(*args), self.forward_cross_modal(*args)
        raise ValueError("Wrong NTXent loss type, must be: 'cross', 'orig' or 'both'")

    def forward_cross_modal(self, mod1, mod2):
        """
        Cross-modal case:
        p - positive pair
        n - negative pair
        sim - cosine similarity
        ix - image modality feature number x
        tx - text modality feature number x
        Cross-modal case of NTXent doesn't consider similarities inside of the same modality
                                                Similarities matrix: exp(sim(i, y))
                                                     +--+--+--+--+--+--+--+
                                                     |  |i1|i2|i3|t1|t2|t3|
         Modality                                    +--+--+--+--+--+--+--+
         Features                                    |i1|0 |0 |0 |p |n |n |
        +--+  +--+                                   +--+--+--+--+--+--+--+
        |i1|  |t1|                                   |i2|0 |0 |0 |n |p |n |
        +--+  +--+                                   +--+--+--+--+--+--+--+
        |i2|  |t2|  ------------------------------>  |i3|0 |0 |0 |n |n |p |
        +--+  +--+                                   +--+--+--+--+--+--+--+
        |i3|  |t3|                                   |t1|p |n |n |0 |0 |0 |
        +--+  +--+                                   +--+--+--+--+--+--+--+
                                                     |t2|n |p |n |0 |0 |0 |
                                                     +--+--+--+--+--+--+--+
                                                     |t3|n |n |p |0 |0 |0 |
                                                     +--+--+--+--+--+--+--+

        The denominator of each row sums every cross-modal entry, including
        the positive pair. The per-row losses are SUMMED (``forward_orig``
        averages them). This scale difference is kept as-is because callers
        weight the two variants separately.

        :param mod1: features of the 1st modality, shape (batch, dim)
        :param mod2: features of the 2nd modality, shape (batch, dim); row k is
            the positive partner of row k in ``mod1``
        :return: NTXent loss (scalar tensor)
        """
        # normalize for numerical stability
        mod1 = F.normalize(mod1)
        mod2 = F.normalize(mod2)
        out = torch.cat([mod1, mod2], dim=0)
        # cov and sim: [2 * batch_size, 2 * batch_size]
        cov = torch.mm(out, out.t().contiguous())  # cosine similarities matrix
        sim = torch.exp(cov / self.temperature)
        # Mask for the cross-modal case: zero out same-modality blocks (see docstring).
        n = mod1.shape[0]
        zeros = torch.zeros(n, n, device=sim.device)
        ones = torch.ones(n, n, device=sim.device)
        mask = torch.hstack([torch.vstack([zeros, ones]), torch.vstack([ones, zeros])])
        sim = sim * mask
        # neg: [2 * batch_size], sum of all cross-modal similarities per row
        neg = sim.sum(dim=1)
        # Positive similarity, pos becomes [2 * batch_size]
        pos = torch.exp(torch.sum(mod1 * mod2, dim=-1) / self.temperature)
        pos = torch.cat([pos, pos], dim=0)
        loss = -torch.log(pos / (neg + self.eps)).sum()
        return loss

    def forward_orig(self, out_1, out_2):
        """
        Adapted from:
        https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/models/self_supervised/simclr/simclr_module.py
        p - positive pair
        n - negative pair
        sim - cosine similarity
        e - Euler's number
        ix - value x of input feature vector i
        tx - value x of input feature vector t
                                                Similarities matrix: exp(sim(i, y))
                                                     +--+--+--+--+--+--+--+
                                                     |  |i1|i2|i3|t1|t2|t3|
         Modality                                    +--+--+--+--+--+--+--+
         Features                                    |i1|e |n |n |p |n |n |
        +--+  +--+                                   +--+--+--+--+--+--+--+
        |i1|  |t1|                                   |i2|n |e |n |n |p |n |
        +--+  +--+                                   +--+--+--+--+--+--+--+
        |i2|  |t2|  ------------------------------>  |i3|n |n |e |n |n |p |
        +--+  +--+                                   +--+--+--+--+--+--+--+
        |i3|  |t3|                                   |t1|p |n |n |e |n |n |
        +--+  +--+                                   +--+--+--+--+--+--+--+
                                                     |t2|n |p |n |n |e |n |
                                                     +--+--+--+--+--+--+--+
                                                     |t3|n |n |p |n |n |e |
                                                     +--+--+--+--+--+--+--+

        The diagonal ("e" above) is the self-similarity exp(1 / temperature)
        and is excluded from each row's denominator. The per-row losses are
        averaged.

        :param out_1: input feature vector i, shape (batch, dim)
        :param out_2: input feature vector t, shape (batch, dim)
        :return: NTXent loss (scalar tensor)
        """
        out_1 = F.normalize(out_1)
        out_2 = F.normalize(out_2)
        out = torch.cat([out_1, out_2], dim=0)
        # cov and sim: [2 * batch_size, 2 * batch_size]
        cov = torch.mm(out, out.t().contiguous())
        sim = torch.exp(cov / self.temperature)
        # neg: [2 * batch_size]. Remove each sample's similarity with itself.
        # The actual diagonal is subtracted instead of the constant e: the
        # diagonal equals exp(1 / temperature), which is only e when
        # temperature == 1.
        neg = sim.sum(dim=-1) - sim.diagonal()
        neg = torch.clamp(neg, min=self.eps)  # clamp for numerical stability
        # Positive similarity, pos becomes [2 * batch_size]
        pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / self.temperature)
        pos = torch.cat([pos, pos], dim=0)
        loss = -torch.log(pos / (neg + self.eps)).mean()
        return loss
def Calcloss(out_hash, H, Bbatch, S, num_train, num_batch, args):
    """Pairwise log-likelihood loss of a batch against all stored codes.

    With ``theta = out_hash @ H.T / 2``, computes:

    - ``logloss = sum(S * theta - softplus(theta)) / (num_train * num_batch)``,
      i.e. the normalized log-likelihood of ``S`` under ``sigmoid(theta)``
      (higher is better);
    - ``regterm = sum((Bbatch - out_hash) ** 2) / (num_train * num_batch)``;
    - ``loss_p = -logloss + args.lamda * regterm`` (the quantity to minimize).

    ``H`` and ``S`` are moved to ``out_hash``'s device and detached, so no
    gradient flows into them. This matches the previous
    ``Variable(x.cuda())`` wrapping, which also detached them.

    :param out_hash: real-valued codes of the current batch, shape (num_batch, q).
    :param H: real-valued codes of all training samples, shape (num_train, q).
    :param Bbatch: binary hash codes of the batch, same shape as ``out_hash``.
    :param S: batch-vs-train similarity matrix, shape (num_batch, num_train);
        presumably the boolean output of ``CalcSim`` — confirm at call sites.
    :param num_train: number of training samples.
    :param num_batch: batch size.
    :param args: namespace providing ``lamda`` (quantization weight).
    :return: ``(logloss, regterm, loss_p)``.
    """
    device = out_hash.device
    H = H.detach().to(device)
    S = S.detach().to(device)
    theta_x = out_hash.float().mm(H.t()) / 2
    norm = num_train * num_batch
    logloss = (S * theta_x - Logtrick(theta_x)).sum() / norm
    regterm = (Bbatch - out_hash).pow(2).sum() / norm
    loss_p = -logloss + args.lamda * regterm
    return logloss, regterm, loss_p
def CalcNTXentLoss(img_hash, txt_hash, out_hash, Criterion, args):
    """
    Weighted sum of three NTXent terms between the image, text and combined hashes.

    loss = args.contrastive[0] * Criterion(img_hash, txt_hash, type='cross')
         + args.contrastive[1] * Criterion(img_hash, out_hash, type='orig')
         + args.contrastive[2] * Criterion(out_hash, txt_hash, type='orig')

    :param img_hash: batch of image-branch hash outputs
    :param txt_hash: batch of text-branch hash outputs
    :param out_hash: batch of the model's combined hash outputs (presumably the
        joint image-text code; confirm against the model)
    :param Criterion: NTXent loss callable invoked as ``Criterion(a, b, type=...)``
    :param args: namespace providing ``contrastive``, a sequence of at least
        three weights
    :returns: the weighted NTXent loss
    """
    loss_img_txt = Criterion(img_hash, txt_hash, type='cross')
    loss_img_out = Criterion(img_hash, out_hash, type='orig')
    loss_out_txt = Criterion(out_hash, txt_hash, type='orig')
    weights = args.contrastive
    return (loss_img_txt * weights[0]
            + loss_img_out * weights[1]
            + loss_out_txt * weights[2])
def Calc_total_loss(H, B, S, num_train, args):
    """Unnormalized pairwise loss over all stored codes.

    With ``theta = H @ H.T / 2``, computes:

    - ``logloss = sum(-theta * S + softplus(theta))``, the negative
      log-likelihood of ``S`` under ``sigmoid(theta)`` (lower is better; note
      the opposite sign convention from ``Calcloss``);
    - ``regterm = sum((H - B) ** 2)``;
    - ``loss_p = logloss + args.lamda * regterm``.

    The softplus term is evaluated on a detached ``theta``, so it contributes to
    the value but not to gradients. This matches the previous
    ``Logtrick(Variable(theta)).data``. NOTE(review): confirm whether this
    function is only used for monitoring.

    :param H: real-valued codes, shape (n, q).
    :param B: binary codes, same shape as ``H``.
    :param S: similarity matrix, shape (n, n) — assumed; confirm at call sites.
    :param num_train: unused; kept for interface compatibility.
    :param args: namespace providing ``lamda``.
    :return: ``(logloss, regterm, loss_p)``.
    """
    theta = H.mm(H.t()) / 2
    logloss = (-theta * S + Logtrick(theta.detach())).sum()
    regterm = (H - B).pow(2).sum()
    loss_p = logloss + args.lamda * regterm
    return logloss, regterm, loss_p
def CalcHammingDist(B1, B2):
    """Hamming distance between {-1, +1} codes, computed via inner products.

    For length-q codes in {-1, +1}, ``<b1, b2> = q - 2 * hamming(b1, b2)``,
    so ``hamming = (q - <b1, b2>) / 2``.

    :param B1: a single code of shape (q,) or a code matrix of shape (m, q).
    :param B2: code matrix of shape (n, q).
    :return: distances of shape (n,) or (m, n).
    """
    bits = B2.shape[1]
    agreement = np.dot(B1, B2.transpose())
    return (bits - agreement) * 0.5
def CalcMap(qB, rB, queryL, retrievalL):
    """Mean average precision over the whole retrieval set, ranked by Hamming distance.

    A retrieval item is relevant to a query when their label vectors have a
    positive inner product (they share at least one label). Queries with no
    relevant item are skipped but still counted in the denominator, so they
    contribute 0.

    :param qB: query codes, shape (m, q), entries in {-1, +1}
    :param rB: retrieval (database) codes, shape (n, q)
    :param queryL: query labels, {0,1}^{m x l}
    :param retrievalL: retrieval labels, {0,1}^{n x l}
    :return: mAP averaged over all m queries. Raises ZeroDivisionError when
        ``queryL`` has no rows.
    """
    num_query = queryL.shape[0]
    total_ap = 0
    for q_idx in range(num_query):
        # Label matching: relevance of every database item to this query.
        relevant = (np.dot(queryL[q_idx, :], retrievalL.transpose()) > 0).astype(np.float32)
        num_relevant = int(np.sum(relevant))
        if num_relevant == 0:
            continue
        # Hamming distance between the query and every database item.
        hamm = CalcHammingDist(qB[q_idx, :], rB)
        # Reorder relevance flags by increasing Hamming distance.
        relevant = relevant[np.argsort(hamm)]
        # Precision at each relevant item: (k-th hit) / (its 1-based rank).
        hit_counts = np.arange(1, num_relevant + 1, dtype=np.float64)
        hit_ranks = np.flatnonzero(relevant == 1) + 1.0
        total_ap += np.mean(hit_counts / hit_ranks)
    return total_ap / num_query
def CalcTopMap(qB, rB, queryL, retrievalL, topk=20):
    """Mean average precision over the top-``topk`` Hamming-ranked results (mAP@topk).

    Relevance is a positive inner product between label vectors. For each
    query, the average precision is taken over the hits inside the top
    ``topk`` only. Queries with no hit in the top ``topk`` are skipped but still
    counted in the denominator, so they contribute 0.

    :param qB: query codes, {-1,+1}^{m x q}
    :param rB: retrieval codes, {-1,+1}^{n x q}
    :param queryL: query labels, {0,1}^{m x l}
    :param retrievalL: retrieval labels, {0,1}^{n x l}
    :param topk: number of top-ranked results considered per query.
    :return: mAP@topk averaged over all m queries. Raises ZeroDivisionError
        when ``queryL`` has no rows.
    """
    num_query = queryL.shape[0]
    total_ap = 0
    for q_idx in range(num_query):
        relevant = (np.dot(queryL[q_idx, :], retrievalL.transpose()) > 0).astype(np.float32)
        hamm = CalcHammingDist(qB[q_idx, :], rB)
        # Relevance flags of the top-k items by increasing Hamming distance.
        top_relevant = relevant[np.argsort(hamm)][:topk]
        num_hits = int(np.sum(top_relevant))
        if num_hits == 0:
            continue
        # Precision at each hit: (k-th hit) / (its 1-based rank).
        hit_counts = np.arange(1, num_hits + 1, dtype=np.float64)
        hit_ranks = np.flatnonzero(top_relevant == 1) + 1.0
        total_ap += np.mean(hit_counts / hit_ranks)
    return total_ap / num_query