gina9726 commited on
Commit
4d22e10
1 Parent(s): 475b162

Update lavila/models/prompt_tuning.py

Browse files
Files changed (1) hide show
  1. lavila/models/prompt_tuning.py +3 -3
lavila/models/prompt_tuning.py CHANGED
@@ -58,7 +58,7 @@ class PromptPoolLearner(nn.Module):
58
 
59
  if istrain:
60
  inv_freq = self.id_table.sum() / self.id_table.float()
61
- weights = torch.softmax((1 + similarity) / 2 + 0.5 * (1 - gamma) * inv_freq / inv_freq.sum(), dim=1)
62
  idx = torch.multinomial(weights, k, replacement=False)
63
  else:
64
  idx = torch.argsort(similarity, dim=-1, descending=True)[:, :k]
@@ -74,8 +74,8 @@ class PromptPoolLearner(nn.Module):
74
  out['prompts'] = prompts
75
  sel_sim = similarity[torch.arange(BZ).view(-1, 1), idx]
76
  sel_key = keys[idx.flatten(), ...].view(BZ, k, self.prompt_dim)
77
- diff = F.mse_loss((sel_sim.unsqueeze(1) @ sel_key).squeeze(), query.detach(), reduction='sum') / BZ
78
- ksim = torch.mean(torch.matmul(keys, keys.t()) - torch.eye(self.size).to(keys.device))
79
  out['ps_loss'] = diff + ksim
80
 
81
  return out
 
58
 
59
  if istrain:
60
  inv_freq = self.id_table.sum() / self.id_table.float()
61
+ weights = (similarity + 1) / 2 * gamma + (1 - gamma) * torch.softmax(inv_freq, dim=-1)
62
  idx = torch.multinomial(weights, k, replacement=False)
63
  else:
64
  idx = torch.argsort(similarity, dim=-1, descending=True)[:, :k]
 
74
  out['prompts'] = prompts
75
  sel_sim = similarity[torch.arange(BZ).view(-1, 1), idx]
76
  sel_key = keys[idx.flatten(), ...].view(BZ, k, self.prompt_dim)
77
+ diff = F.mse_loss((sel_sim.unsqueeze(1) @ sel_key).squeeze(1), query.detach(), reduction='sum') / BZ
78
+ ksim = torch.sum(torch.abs(torch.matmul(keys, keys.t()) - torch.eye(self.size).to(keys.device))) / BZ
79
  out['ps_loss'] = diff + ksim
80
 
81
  return out