NTT123 commited on
Commit
7911639
1 Parent(s): 41ba53f

new model, new sampling method.

Browse files
Files changed (4) hide show
  1. inference.py +1 -1
  2. wavegru.ckpt +1 -1
  3. wavegru_mod.cc +15 -1
  4. wavegru_mod.so +1 -1
inference.py CHANGED
@@ -73,7 +73,7 @@ def mel_to_wav(net, netcpp, mel, config):
73
  )
74
  ft = wavegru_inference(net, mel)
75
  ft = jax.device_get(ft[0])
76
- wav = netcpp.inference(ft, 0.9)
77
  wav = np.array(wav)
78
  wav = librosa.mu_expand(wav - 127, mu=255)
79
  wav = librosa.effects.deemphasis(wav, coef=0.86)
 
73
  )
74
  ft = wavegru_inference(net, mel)
75
  ft = jax.device_get(ft[0])
76
+ wav = netcpp.inference(ft, 1.0)
77
  wav = np.array(wav)
78
  wav = librosa.mu_expand(wav - 127, mu=255)
79
  wav = librosa.effects.deemphasis(wav, coef=0.86)
wavegru.ckpt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c06310d989fd524359d5f3ecf8ea1dc146980bb594b7b90553d0d42a64c512d8
3
  size 58039876
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64b1ce6558cfe09b95bb29d48c34900ed0c3490d17c81fd6190969b226f4617a
3
  size 58039876
wavegru_mod.cc CHANGED
@@ -39,6 +39,7 @@ struct WaveGRU {
39
  vec o1b, o2b;
40
  vec t;
41
  vec h;
 
42
  mat o1, o2;
43
  std::vector<vec> embed;
44
 
@@ -55,7 +56,8 @@ struct WaveGRU {
55
  fco2(256),
56
  h(hidden_dim),
57
  o1b(hidden_dim),
58
- o2b(256) {
 
59
  m = create_mat(hidden_dim, 3*hidden_dim);
60
  o1 = create_mat(hidden_dim, hidden_dim);
61
  o2 = create_mat(hidden_dim, 256);
@@ -120,6 +122,18 @@ struct WaveGRU {
120
  }
121
  o1.SpMM_bias(h, o1b, &fco1, true);
122
  o2.SpMM_bias(fco1, o2b, &fco2, false);
 
 
 
 
 
 
 
 
 
 
 
 
123
  value = fco2.Sample(temperature);
124
  signal[index] = value;
125
  }
 
39
  vec o1b, o2b;
40
  vec t;
41
  vec h;
42
+ vec logits;
43
  mat o1, o2;
44
  std::vector<vec> embed;
45
 
 
56
  fco2(256),
57
  h(hidden_dim),
58
  o1b(hidden_dim),
59
+ o2b(256),
60
+ logits(256) {
61
  m = create_mat(hidden_dim, 3*hidden_dim);
62
  o1 = create_mat(hidden_dim, hidden_dim);
63
  o2 = create_mat(hidden_dim, 256);
 
122
  }
123
  o1.SpMM_bias(h, o1b, &fco1, true);
124
  o2.SpMM_bias(fco1, o2b, &fco2, false);
125
+ auto max_logit = fco2[0];
126
+ for (int i = 1; i <= 255; ++i) {
127
+ max_logit = max(max_logit, fco2[i]);
128
+ }
129
+ float total = 0.0;
130
+ for (int i = 0; i <= 255; ++i) {
131
+ logits[i] = csrblocksparse::fast_exp(fco2[i] - max_logit);
132
+ total += logits[i];
133
+ }
134
+ for (int i = 0; i <= 255; ++i) {
135
+ if (logits[i] < total / 256.0) fco2[i] = -1e9;
136
+ }
137
  value = fco2.Sample(temperature);
138
  signal[index] = value;
139
  }
wavegru_mod.so CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:12c27f0ea07f8da3a3ab48bc01bb0f68971ce7d57b19ada87669eab138623a9c
3
  size 525536
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b65c5312f24f8ab9cfa51e8340a24ac1165b247046a331386d636fba9036c19c
3
  size 525536