Yeonchan Ahn commited on
Commit
d40e838
1 Parent(s): a9ecc32

added non negative uniform loss

Browse files
Files changed (1) hide show
  1. Alignment-and-Uniformity.py +18 -2
Alignment-and-Uniformity.py CHANGED
@@ -19,6 +19,7 @@ Returns:
19
  "x_unif_loss": float(x_unif_loss_v),
20
  "y_unif_loss": float(y_unif_loss_v),
21
  "unif_loss": float(unif_loss)
 
22
 
23
  Examples:
24
 
@@ -40,9 +41,17 @@ def uniform_loss(x, t=2):
40
  return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()
41
 
42
 
 
 
 
 
 
 
 
43
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
44
  class AlignmentandUniformity(evaluate.Metric):
45
- def __init__(self, align_alpha: float = 2.0, unif_t: float = 2.0, *args, **kwargs):
 
46
  super(AlignmentandUniformity, self).__init__(*args, **kwargs)
47
  self.align_alpha = align_alpha
48
  self.unif_t = unif_t
@@ -81,10 +90,17 @@ class AlignmentandUniformity(evaluate.Metric):
81
  x_unif_loss_v = uniform_loss(xs, t=self.unif_t)
82
  y_unif_loss_v = uniform_loss(ys, t=self.unif_t)
83
  unif_loss = (x_unif_loss_v + y_unif_loss_v) / 2
 
 
 
 
84
 
85
  return {
86
  "align_loss": float(align_loss_val),
87
  "x_unif_loss": float(x_unif_loss_v),
88
  "y_unif_loss": float(y_unif_loss_v),
89
- "unif_loss": float(unif_loss)
 
 
 
90
  }
 
19
  "x_unif_loss": float(x_unif_loss_v),
20
  "y_unif_loss": float(y_unif_loss_v),
21
  "unif_loss": float(unif_loss)
22
+
23
 
24
  Examples:
25
 
 
41
  return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()
42
 
43
 
44
+ def nonneg_uniform_loss(x, t=2):
45
+ tmp = torch.pdist(x, p=2).pow(2)
46
+ original = tmp.mul(-t).exp().mean().log()
47
+ boundary = -t * tmp.mean()
48
+ return original - boundary
49
+
50
+
51
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
52
  class AlignmentandUniformity(evaluate.Metric):
53
+ def __init__(self, align_alpha: float = 2.0, unif_t: float = 2.0,
54
+ *args, **kwargs):
55
  super(AlignmentandUniformity, self).__init__(*args, **kwargs)
56
  self.align_alpha = align_alpha
57
  self.unif_t = unif_t
 
90
  x_unif_loss_v = uniform_loss(xs, t=self.unif_t)
91
  y_unif_loss_v = uniform_loss(ys, t=self.unif_t)
92
  unif_loss = (x_unif_loss_v + y_unif_loss_v) / 2
93
+
94
+ nn_x_unif_loss_v = nonneg_uniform_loss(xs, t=self.unif_t)
95
+ nn_y_unif_loss_v = nonneg_uniform_loss(ys, t=self.unif_t)
96
+ nn_unif_loss = (nonneg_uniform_loss + nonneg_uniform_loss) / 2
97
 
98
  return {
99
  "align_loss": float(align_loss_val),
100
  "x_unif_loss": float(x_unif_loss_v),
101
  "y_unif_loss": float(y_unif_loss_v),
102
+ "unif_loss": float(unif_loss),
103
+ "nonneg_x_unif_loss": float(nn_x_unif_loss_v),
104
+ "nonneg_y_unif_loss": float(nn_y_unif_loss_v),
105
+ "nonneg_unif_loss": float(nn_unif_loss)
106
  }