Yeonchan Ahn
commited on
Commit
•
d40e838
1
Parent(s):
a9ecc32
added non negative uniform loss
Browse files- Alignment-and-Uniformity.py +18 -2
Alignment-and-Uniformity.py
CHANGED
@@ -19,6 +19,7 @@ Returns:
|
|
19 |
"x_unif_loss": float(x_unif_loss_v),
|
20 |
"y_unif_loss": float(y_unif_loss_v),
|
21 |
"unif_loss": float(unif_loss)
|
|
|
22 |
|
23 |
Examples:
|
24 |
|
@@ -40,9 +41,17 @@ def uniform_loss(x, t=2):
|
|
40 |
return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()
|
41 |
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
44 |
class AlignmentandUniformity(evaluate.Metric):
|
45 |
-
def __init__(self, align_alpha: float = 2.0, unif_t: float = 2.0,
|
|
|
46 |
super(AlignmentandUniformity, self).__init__(*args, **kwargs)
|
47 |
self.align_alpha = align_alpha
|
48 |
self.unif_t = unif_t
|
@@ -81,10 +90,17 @@ class AlignmentandUniformity(evaluate.Metric):
|
|
81 |
x_unif_loss_v = uniform_loss(xs, t=self.unif_t)
|
82 |
y_unif_loss_v = uniform_loss(ys, t=self.unif_t)
|
83 |
unif_loss = (x_unif_loss_v + y_unif_loss_v) / 2
|
|
|
|
|
|
|
|
|
84 |
|
85 |
return {
|
86 |
"align_loss": float(align_loss_val),
|
87 |
"x_unif_loss": float(x_unif_loss_v),
|
88 |
"y_unif_loss": float(y_unif_loss_v),
|
89 |
-
"unif_loss": float(unif_loss)
|
|
|
|
|
|
|
90 |
}
|
|
|
19 |
"x_unif_loss": float(x_unif_loss_v),
|
20 |
"y_unif_loss": float(y_unif_loss_v),
|
21 |
"unif_loss": float(unif_loss)
|
22 |
+
|
23 |
|
24 |
Examples:
|
25 |
|
|
|
41 |
return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()
|
42 |
|
43 |
|
44 |
+
def nonneg_uniform_loss(x, t=2):
|
45 |
+
tmp = torch.pdist(x, p=2).pow(2)
|
46 |
+
original = tmp.mul(-t).exp().mean().log()
|
47 |
+
boundary = -t * tmp.mean()
|
48 |
+
return original - boundary
|
49 |
+
|
50 |
+
|
51 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
52 |
class AlignmentandUniformity(evaluate.Metric):
|
53 |
+
def __init__(self, align_alpha: float = 2.0, unif_t: float = 2.0,
|
54 |
+
*args, **kwargs):
|
55 |
super(AlignmentandUniformity, self).__init__(*args, **kwargs)
|
56 |
self.align_alpha = align_alpha
|
57 |
self.unif_t = unif_t
|
|
|
90 |
x_unif_loss_v = uniform_loss(xs, t=self.unif_t)
|
91 |
y_unif_loss_v = uniform_loss(ys, t=self.unif_t)
|
92 |
unif_loss = (x_unif_loss_v + y_unif_loss_v) / 2
|
93 |
+
|
94 |
+
nn_x_unif_loss_v = nonneg_uniform_loss(xs, t=self.unif_t)
|
95 |
+
nn_y_unif_loss_v = nonneg_uniform_loss(ys, t=self.unif_t)
|
96 |
+
nn_unif_loss = (nonneg_uniform_loss + nonneg_uniform_loss) / 2
|
97 |
|
98 |
return {
|
99 |
"align_loss": float(align_loss_val),
|
100 |
"x_unif_loss": float(x_unif_loss_v),
|
101 |
"y_unif_loss": float(y_unif_loss_v),
|
102 |
+
"unif_loss": float(unif_loss),
|
103 |
+
"nonneg_x_unif_loss": float(nn_x_unif_loss_v),
|
104 |
+
"nonneg_y_unif_loss": float(nn_y_unif_loss_v),
|
105 |
+
"nonneg_unif_loss": float(nn_unif_loss)
|
106 |
}
|