Spaces:
Paused
Paused
Update lavila/models/prompt_tuning.py
Browse files
lavila/models/prompt_tuning.py
CHANGED
@@ -58,7 +58,7 @@ class PromptPoolLearner(nn.Module):
|
|
58 |
|
59 |
if istrain:
|
60 |
inv_freq = self.id_table.sum() / self.id_table.float()
|
61 |
-
weights =
|
62 |
idx = torch.multinomial(weights, k, replacement=False)
|
63 |
else:
|
64 |
idx = torch.argsort(similarity, dim=-1, descending=True)[:, :k]
|
@@ -74,8 +74,8 @@ class PromptPoolLearner(nn.Module):
|
|
74 |
out['prompts'] = prompts
|
75 |
sel_sim = similarity[torch.arange(BZ).view(-1, 1), idx]
|
76 |
sel_key = keys[idx.flatten(), ...].view(BZ, k, self.prompt_dim)
|
77 |
-
diff = F.mse_loss((sel_sim.unsqueeze(1) @ sel_key).squeeze(), query.detach(), reduction='sum') / BZ
|
78 |
-
ksim = torch.
|
79 |
out['ps_loss'] = diff + ksim
|
80 |
|
81 |
return out
|
|
|
58 |
|
59 |
if istrain:
|
60 |
inv_freq = self.id_table.sum() / self.id_table.float()
|
61 |
+
weights = (similarity + 1) / 2 * gamma + (1 - gamma) * torch.softmax(inv_freq, dim=-1)
|
62 |
idx = torch.multinomial(weights, k, replacement=False)
|
63 |
else:
|
64 |
idx = torch.argsort(similarity, dim=-1, descending=True)[:, :k]
|
|
|
74 |
out['prompts'] = prompts
|
75 |
sel_sim = similarity[torch.arange(BZ).view(-1, 1), idx]
|
76 |
sel_key = keys[idx.flatten(), ...].view(BZ, k, self.prompt_dim)
|
77 |
+
diff = F.mse_loss((sel_sim.unsqueeze(1) @ sel_key).squeeze(1), query.detach(), reduction='sum') / BZ
|
78 |
+
ksim = torch.sum(torch.abs(torch.matmul(keys, keys.t()) - torch.eye(self.size).to(keys.device))) / BZ
|
79 |
out['ps_loss'] = diff + ksim
|
80 |
|
81 |
return out
|