kevinwang676 committed
Commit • 9016314
1 Parent(s): b40bf00
Upload 93 files
This view is limited to 50 files because it contains too many changes.
- .gitattributes +4 -0
- Phoneme_Hallucinator_v2/.gitignore +6 -0
- Phoneme_Hallucinator_v2/.vscode/launch.json +26 -0
- Phoneme_Hallucinator_v2/Phoneme Hallucinator DEMO.ipynb +0 -0
- Phoneme_Hallucinator_v2/README.md +36 -0
- Phoneme_Hallucinator_v2/__pycache__/__init__.cpython-310.pyc +0 -0
- Phoneme_Hallucinator_v2/datasets/__init__.py +27 -0
- Phoneme_Hallucinator_v2/datasets/__pycache__/__init__.cpython-36.pyc +0 -0
- Phoneme_Hallucinator_v2/datasets/__pycache__/speech.cpython-36.pyc +0 -0
- Phoneme_Hallucinator_v2/datasets/speech.py +278 -0
- Phoneme_Hallucinator_v2/evaluation/ASR-Eval.ipynb +0 -0
- Phoneme_Hallucinator_v2/evaluation/ASR.ipynb +0 -0
- Phoneme_Hallucinator_v2/evaluation/init +1 -0
- Phoneme_Hallucinator_v2/exp/speech_XXL_cond/params.json +33 -0
- Phoneme_Hallucinator_v2/exp/speech_XXL_cond/weights/checkpoint +3 -0
- Phoneme_Hallucinator_v2/exp/speech_XXL_cond/weights/last.ckpt.data-00000-of-00001 +3 -0
- Phoneme_Hallucinator_v2/exp/speech_XXL_cond/weights/last.ckpt.index +0 -0
- Phoneme_Hallucinator_v2/exp/speech_XXL_cond/weights/last.ckpt.meta +3 -0
- Phoneme_Hallucinator_v2/exp/speech_XXL_cond/weights/params.ckpt.data-00000-of-00001 +3 -0
- Phoneme_Hallucinator_v2/exp/speech_XXL_cond/weights/params.ckpt.index +0 -0
- Phoneme_Hallucinator_v2/exp/speech_XXL_cond/weights/params.ckpt.meta +3 -0
- Phoneme_Hallucinator_v2/models/__init__.py +9 -0
- Phoneme_Hallucinator_v2/models/__pycache__/__init__.cpython-310.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/__init__.cpython-36.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/__init__.cpython-37.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/base.cpython-310.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/base.cpython-36.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/base.cpython-37.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/cVAE.cpython-310.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/cVAE.cpython-36.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/cVAE.cpython-37.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/networks.cpython-310.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/networks.cpython-36.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/networks.cpython-37.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/pc_acset.cpython-36.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/pc_acset_vae.cpython-310.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/pc_acset_vae.cpython-36.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/pc_acset_vae.cpython-37.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/pc_acset_vae_06.cpython-36.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/pc_encoder.cpython-310.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/pc_encoder.cpython-36.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/pc_encoder.cpython-37.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/runner.cpython-36.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/set_transformer.cpython-310.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/set_transformer.cpython-36.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/set_transformer.cpython-37.pyc +0 -0
- Phoneme_Hallucinator_v2/models/__pycache__/utils.cpython-36.pyc +0 -0
- Phoneme_Hallucinator_v2/models/base.py +103 -0
- Phoneme_Hallucinator_v2/models/cVAE.py +45 -0
- Phoneme_Hallucinator_v2/models/flow/__init__.py +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+Phoneme_Hallucinator_v2/exp/speech_XXL_cond/weights/last.ckpt.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
+Phoneme_Hallucinator_v2/exp/speech_XXL_cond/weights/last.ckpt.meta filter=lfs diff=lfs merge=lfs -text
+Phoneme_Hallucinator_v2/exp/speech_XXL_cond/weights/params.ckpt.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
+Phoneme_Hallucinator_v2/exp/speech_XXL_cond/weights/params.ckpt.meta filter=lfs diff=lfs merge=lfs -text
Phoneme_Hallucinator_v2/.gitignore
ADDED
@@ -0,0 +1,6 @@
+.vscode
+exp/speech_XXL_cond
+*__pycache__
+*.pt
+*.npy
+*.wav
Phoneme_Hallucinator_v2/.vscode/launch.json
ADDED
@@ -0,0 +1,26 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python: Current File",
+            "type": "python",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal",
+            "justMyCode": true,
+            "args": [
+                "--cfg_file",
+                "exp/speech_XXL_cond/params.json",
+                "--num_samples",
+                "5000",
+                "--path",
+                "matching_set/target.pt",
+                "--out_path",
+                "matching_set/target_expanded_5k.npy"
+            ]
+        }
+    ]
+}
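The launch configuration above simply debugs whichever file is currently open (`${file}`), passing it set-expansion arguments. As a reference for what those flags carry, here is a hypothetical argparse sketch mirroring them; the repo's actual expansion script may name or validate them differently:

```
import argparse

# Hypothetical parser mirroring the flags in the launch configuration above;
# not copied from the repo's own scripts.
parser = argparse.ArgumentParser(description="Expand a speaker's matching set")
parser.add_argument("--cfg_file", type=str, help="experiment config, e.g. exp/speech_XXL_cond/params.json")
parser.add_argument("--num_samples", type=int, help="number of feature vectors to hallucinate, e.g. 5000")
parser.add_argument("--path", type=str, help="input matching set tensor (.pt)")
parser.add_argument("--out_path", type=str, help="output file for the expanded set (.npy)")
args = parser.parse_args()
print(args)
```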
Phoneme_Hallucinator_v2/Phoneme Hallucinator DEMO.ipynb
ADDED
The diff for this file is too large to render.
Phoneme_Hallucinator_v2/README.md
ADDED
@@ -0,0 +1,36 @@
+# Phoneme_Hallucinator
+This is the repository for the paper "Phoneme Hallucinator: One-shot Voice Conversion via Set Expansion," accepted at AAAI-2024. Audio samples are provided [here](https://phonemehallucinator.github.io/).
+
+## Inference Tutorial
+1. If you only want to run our VC pipeline, download `Phoneme Hallucinator DEMO.ipynb` from this repo and run it in Google Colab.
+
+## Training Tutorial
+1. Prepare the environment. Requires `Python 3.6.3` and the following packages:
+```
+pillow == 8.0.1
+torch == 1.10.2
+tensorflow == 1.15.5
+tensorflow-probability == 0.7.0
+tensorpack == 0.9.8
+h5py == 2.10.0
+numpy == 1.19.5
+pathlib == 1.0.1
+tqdm == 4.64.1
+easydict == 1.10
+matplotlib == 3.3.4
+scikit-learn == 0.24.2
+scipy == 1.5.4
+seaborn == 0.11.2
+```
+2. To prepare the training set, we need to use WavLM to extract speech representations. Go to the [kNN-VC repo](https://github.com/bshall/knn-vc) and follow its instructions to extract speech representations. Namely, after placing the LibriSpeech dataset in the correct location, run:
+
+`python prematch_dataset.py --librispeech_path /path/to/librispeech/root --out_path /path/where/you/want/outputs/to/go --topk 4 --matching_layer 6 --synthesis_layer 6`
+
+Note that we don't use the `--prematch` option, because we only need to extract representations, not to extract and then perform kNN regression.
+
+3. After the above step, you get an `--out_path` folder with three subfolders, `train-clean-100`, `test-clean` and `dev-clean`, each containing the speech representation files (`.pt`).
+4. In our repo, go to `./datasets/speech.py` and change the variables `path_to_wavlm_feat` and `tfrecord_path` accordingly. `path_to_wavlm_feat` must point to where the speech representations were stored in the previous step.
+5. Start training with the following command:
+`python scripts/run.py --cfg_file=./exp/speech_XXL_cond/params.json --mode=train`
+
+If `tfrecord_path` doesn't exist, our code will create tfrecords and save them to `tfrecord_path` before training starts. Note that if you encounter numerical issues ("NaN, INF") when training starts, just re-run the command a few times. Training logs will be saved to `./exp/speech_XXL_cond/`.
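As a quick sanity check after step 3, every extracted `.pt` file should hold a `[num_frames, 1024]` WavLM feature tensor; files shorter than 200 frames are later discarded by the length filter in `datasets/speech.py`. A minimal check (the file path below is a placeholder):

```
import torch

# Placeholder path; point it at any file under your --out_path folder.
feats = torch.load("/path/to/out_path/train-clean-100/some_utterance.pt")
print(feats.shape)  # expected: torch.Size([T, 1024]); only T >= 200 survives the filter
```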
Phoneme_Hallucinator_v2/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (173 Bytes)
Phoneme_Hallucinator_v2/datasets/__init__.py
ADDED
@@ -0,0 +1,27 @@
+import os
+import pickle
+
+def get_dataset(args, split):
+    if args.dataset == 'speech':
+        from .speech import Dataset
+        dataset = Dataset(split, args.batch_size, args.set_size, args.mask_type)
+    else:
+        raise ValueError()
+
+    return dataset
+
+def cache(args, split, fname):
+    if os.path.isfile(fname):
+        with open(fname, 'rb') as f:
+            batches = pickle.load(f)
+    else:
+        batches = []
+        dataset = get_dataset(args, split)
+        dataset.initialize()
+        for _ in range(dataset.num_batches):
+            batch = dataset.next_batch()
+            batches.append(batch)
+        with open(fname, 'wb') as f:
+            pickle.dump(batches, f)
+
+    return batches
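A minimal usage sketch for the factory above; it assumes the tfrecords already exist and that you run from the repo root. `args` only needs the four fields `get_dataset` actually reads, and the values shown follow `exp/speech_XXL_cond/params.json`:

```
from types import SimpleNamespace
from datasets import get_dataset, cache

# Illustrative hyperparameters matching params.json.
args = SimpleNamespace(dataset='speech', batch_size=50, set_size=200, mask_type='arb_expand')

train_data = get_dataset(args, 'train')
train_data.initialize()
batch = train_data.next_batch()  # dict with keys 'x', 'b', 'm', 'f'

# Or materialize every batch once and pickle it for later runs:
valid_batches = cache(args, 'valid', 'valid_batches.pkl')
```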
Phoneme_Hallucinator_v2/datasets/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (863 Bytes)
Phoneme_Hallucinator_v2/datasets/__pycache__/speech.cpython-36.pyc
ADDED
Binary file (7.36 kB)
Phoneme_Hallucinator_v2/datasets/speech.py
ADDED
@@ -0,0 +1,278 @@
+import h5py
+import numpy as np
+from pathlib import Path
+import tensorflow as tf
+import torch
+from tqdm import tqdm
+import os
+import sys
+import glob
+import random
+import pdb
+np.random.seed(0)
+generate_tf_record = False
+
+tfrecord_path = "/path/to/save/your/tfrecord/"
+path_to_wavlm_feat = "/path/to/your/wavlm/feat"
+
+if not os.path.exists(tfrecord_path):
+    generate_tf_record = True
+    os.makedirs(tfrecord_path, exist_ok=True)
+train_filename = tfrecord_path + 'train'
+valid_filename = tfrecord_path + 'valid'
+test_filename = tfrecord_path + 'test'
+train_path = Path(os.path.join(path_to_wavlm_feat, "train-clean-100"))
+valid_path = Path(os.path.join(path_to_wavlm_feat, "dev-clean"))
+test_path = Path(os.path.join(path_to_wavlm_feat, "test-clean"))
+
+train_size = 27269
+valid_size = 1940
+test_size = 1850
+
+def get_filenames(path):
+    all_files = []
+    all_files.extend(list(path.rglob("**/*.pt")))
+    return all_files
+
+def length_filter(paths):
+    # Keep only utterances with at least 200 feature frames.
+    filtered_paths = []
+    print("filter short files")
+    for each in tqdm(paths):
+        data = torch.load(each).numpy().astype(np.float32)
+        if data.shape[0] < 200:
+            continue
+        filtered_paths.append(each)
+    return filtered_paths
+
+
+def generate_mask(x, mask_type):
+    if mask_type == b'expand':
+        m = np.zeros_like(x)
+        N = np.random.randint(x.shape[0]//8, x.shape[0])
+        ind = np.random.choice(x.shape[0], N, replace=False)
+        m[ind] = 1.
+    elif mask_type == b'few_expand':
+        m = np.zeros_like(x)
+        N = np.random.randint(x.shape[0]//8)
+        ind = np.random.choice(x.shape[0], N, replace=False)
+        m[ind] = 1.
+    elif mask_type == b'arb_expand':
+        m = np.zeros_like(x)
+        N = np.random.randint(x.shape[0])
+        ind = np.random.choice(x.shape[0], N, replace=False)
+        m[ind] = 1.
+    elif mask_type == b'det_expand':
+        m = np.zeros_like(x)
+        ind = np.random.choice(x.shape[0], 100, replace=False)
+        m[ind] = 1.
+    elif mask_type == b'complete':
+        m = np.zeros_like(x)
+        while np.sum(m[:,0]) < x.shape[0] // 8:
+            p = np.random.uniform(-0.5, 0.5, size=4)
+            xa = np.concatenate([x, np.ones([x.shape[0],1])], axis=1)
+            m = (np.dot(xa, p) > 0).astype(np.float32)
+        m = np.repeat(np.expand_dims(m, axis=1), 3, axis=1)
+    else:
+        raise ValueError()
+
+    return m
+
+
+def wrap_int64(value):
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+def wrap_bytes(value):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+def print_progress(count, total):
+    # Percentage completion.
+    pct_complete = float(count) / total
+
+    # Status message.
+    # Note the \r which means the line should overwrite itself.
+    msg = "\r- Progress: {0:.1%}".format(pct_complete)
+
+    sys.stdout.write(msg)
+    sys.stdout.flush()
+
+def convert(image_paths, out_path, max_files=1000):
+    # Args:
+    #   image_paths  List of file paths for the feature tensors.
+    #   out_path     File path for the TFRecords output file.
+
+    print("Converting: " + out_path)
+
+    # Number of files. Used when printing the progress.
+    num_images = len(image_paths)
+    splits = (num_images//max_files) + 1
+    if num_images%max_files == 0:
+        splits -= 1
+    print(f"\nUsing {splits} shard(s) for {num_images} files, with up to {max_files} samples per shard")
+    file_count = 0
+    for i in tqdm(range(splits)):
+        # Open a TFRecordWriter for the output file.
+        with tf.io.TFRecordWriter("{}_{}_{}.tfrecords".format(out_path, i+1, splits)) as writer:
+
+            # Iterate over the paths assigned to this shard.
+            current_shard_count = 0
+            while current_shard_count < max_files:
+                index = i*max_files+current_shard_count
+                if index == len(image_paths):
+                    break
+                current_image = image_paths[index]
+
+                # Load the WavLM feature tensor saved by the extraction step.
+                img = torch.load(current_image).numpy().astype(np.float32)
+
+                # Convert the features to raw bytes.
+                img_bytes = img.tostring()
+
+                # Create a dict with the data we want to save in the
+                # TFRecords file. You can add more relevant data here.
+                data = \
+                    {
+                        'image': wrap_bytes(img_bytes),
+                        'length': wrap_int64(img.shape[0]),
+                        "filename": wrap_bytes(bytes(os.path.splitext(current_image.name)[0], 'utf-8'))
+                    }
+
+                # Wrap the data as TensorFlow Features.
+                feature = tf.train.Features(feature=data)
+
+                # Wrap again as a TensorFlow Example.
+                example = tf.train.Example(features=feature)
+
+                # Serialize the data.
+                serialized = example.SerializeToString()
+
+                # Write the serialized data to the TFRecords file.
+                writer.write(serialized)
+                current_shard_count += 1
+                file_count += 1
+    print(f"\nWrote {file_count} elements to TFRecord")
+
+
+if generate_tf_record:
+    train_image_paths = length_filter(get_filenames(train_path))
+    valid_image_paths = length_filter(get_filenames(valid_path))
+    test_image_paths = length_filter(get_filenames(test_path))
+    print(f"Number of training data after length filtering: {len(train_image_paths)}")
+    print(f"Number of valid data after length filtering: {len(valid_image_paths)}")
+    print(f"Number of testing data after length filtering: {len(test_image_paths)}")
+    random.Random(4).shuffle(train_image_paths)
+
+    train_size = len(train_image_paths)
+    valid_size = len(valid_image_paths)
+    test_size = len(test_image_paths)
+    convert(image_paths=train_image_paths,
+            out_path=train_filename)
+
+    convert(image_paths=valid_image_paths,
+            out_path=valid_filename)
+
+    convert(image_paths=test_image_paths,
+            out_path=test_filename)
+
+
+def parse(serialized):
+    # Define a dict with the data names and types we expect to
+    # find in the TFRecords file. It is a bit awkward that this
+    # needs to be specified again, because it could have been
+    # written in the header of the TFRecords file instead.
+    features = \
+        {
+            'image': tf.io.FixedLenFeature([], tf.string),
+            'length': tf.io.FixedLenFeature([], tf.int64),
+            'filename': tf.io.FixedLenFeature([], tf.string),
+        }
+
+    # Parse the serialized data so we get a dict with our data.
+    parsed_example = tf.io.parse_single_example(serialized=serialized,
+                                                features=features)
+
+    # Get the features as raw bytes.
+    image_raw = parsed_example['image']
+
+    # Decode the raw bytes so they become a typed tensor.
+    image = tf.io.decode_raw(image_raw, tf.float32)
+
+    # Get the frame count and restore the [length, 1024] shape.
+    length = parsed_example['length']
+    image = tf.reshape(image, [length, 1024])
+    filename = parsed_example['filename']
+
+    # The features and filename are now correct TensorFlow types.
+    return image, filename
+
+def process(x, filename, set_size, mask_type):
+    x = x/10
+    ind = np.random.choice(x.shape[0], set_size, replace=False)
+    x = x[ind]
+    m = generate_mask(x, mask_type)
+    #N = np.random.randint(set_size)
+    #S = np.random.randint(x.shape[0] - set_size + 1)
+    #x = x[S:S+set_size]
+    #m = np.zeros_like(x)
+    #S = np.random.randint(set_size - N + 1)
+    #m[S:S+N] = 1.0
+    return x, m, filename
+
+
+def get_dst(split, set_size, mask_type):
+    if split == 'train':
+        files = glob.glob(train_filename+"*.tfrecords", recursive=False)
+        dst = tf.data.TFRecordDataset(files)
+        size = train_size
+        dst = dst.map(parse)
+        dst = dst.shuffle(256)
+        dst = dst.map(lambda x, y: tuple(tf.compat.v1.py_func(process, [x, y, set_size, mask_type], [tf.float32, tf.float32, tf.string])), num_parallel_calls=8)
+    elif split == 'valid':
+        files = glob.glob(valid_filename+"*.tfrecords", recursive=False)
+        dst = tf.data.TFRecordDataset(files)
+        size = valid_size
+        dst = dst.map(parse)
+        dst = dst.map(lambda x, y: tuple(tf.compat.v1.py_func(process, [x, y, set_size, mask_type], [tf.float32, tf.float32, tf.string])), num_parallel_calls=8)
+    else:
+        files = glob.glob(test_filename+"*.tfrecords", recursive=False)
+        dst = tf.data.TFRecordDataset(files)
+        size = test_size
+        dst = dst.map(parse)
+        dst = dst.map(lambda x, y: tuple(tf.compat.v1.py_func(process, [x, y, set_size, mask_type], [tf.float32, tf.float32, tf.string])), num_parallel_calls=8)
+    return dst, size
+
+class Dataset(object):
+    def __init__(self, split, batch_size, set_size, mask_type):
+        g = tf.Graph()
+        with g.as_default():
+            # open a session
+            config = tf.compat.v1.ConfigProto()
+            config.log_device_placement = True
+            config.allow_soft_placement = True
+            config.gpu_options.allow_growth = True
+            self.sess = tf.compat.v1.Session(config=config, graph=g)
+            # build dataset
+            dst, size = get_dst(split, set_size, mask_type)
+            self.size = size
+            self.num_batches = self.size // batch_size
+            dst = dst.batch(batch_size, drop_remainder=False)
+            dst = dst.prefetch(1)
+
+            dst_it = tf.compat.v1.data.make_initializable_iterator(dst)
+            x, b, filename = dst_it.get_next()
+            self.x = x
+            self.b = b
+            self.filename = filename
+            #self.x = tf.reshape(x, [batch_size, set_size, 1024])
+            #self.b = tf.reshape(b, [batch_size, set_size, 1024])
+            self.dimension = 1024
+            self.initializer = dst_it.initializer
+
+    def initialize(self):
+        self.sess.run(self.initializer)
+
+    def next_batch(self):
+        x, b, filename = self.sess.run([self.x, self.b, self.filename])
+        m = np.ones_like(b)
+        return {'x':x, 'b':b, 'm':m, "f":filename}
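To make the masking semantics of `generate_mask` concrete: a mask `m` marks the observed subset with ones, and the model must hallucinate the remaining rows. A standalone toy reenactment of the `'arb_expand'` branch (reimplemented here for illustration, not imported from the module):

```
import numpy as np

x = np.random.randn(200, 1024).astype(np.float32)  # a toy set of 200 WavLM-sized vectors

# 'arb_expand': observe a uniformly random number of elements, hide the rest.
m = np.zeros_like(x)
N = np.random.randint(x.shape[0])                   # how many elements are observed
ind = np.random.choice(x.shape[0], N, replace=False)
m[ind] = 1.

observed = x[m[:, 0] == 1]   # conditioning subset fed to the model
hidden = x[m[:, 0] == 0]     # elements the set expander must generate
print(observed.shape, hidden.shape)
```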
Phoneme_Hallucinator_v2/evaluation/ASR-Eval.ipynb
ADDED
The diff for this file is too large to render.
Phoneme_Hallucinator_v2/evaluation/ASR.ipynb
ADDED
The diff for this file is too large to render.
Phoneme_Hallucinator_v2/evaluation/init
ADDED
@@ -0,0 +1 @@
+
Phoneme_Hallucinator_v2/exp/speech_XXL_cond/params.json
ADDED
@@ -0,0 +1,33 @@
+{
+    "dataset": "speech",
+    "dimension": 1024,
+    "batch_size": 50,
+    "set_size": 200,
+    "mask_type": "arb_expand",
+    "model": "pc_acset_vae",
+    "latent_encoder_hidden": [256,256,256,256],
+    "latent_dim": 256,
+    "trans_params": {
+        "transform": ["L","LR","CP","R","L","LR","CP","R","L","LR","CP","R","L","LR","CP"],
+        "dimension": 256,
+        "coupling_hids": [256,256]
+    },
+    "vae_params": {
+        "hid_dimensions": 256,
+        "dimension": 1024,
+        "enc_dense_hids": [512,512,512,512],
+        "dec_dense_hids": [512,512,512,512]
+    },
+    "use_peq_embed": 1,
+    "set_xformer_hids": [256,256,256,256],
+    "epochs": 1000,
+    "optimizer": "adam",
+    "lr": 0.0001,
+    "decay_steps": 100000,
+    "decay_rate": 0.5,
+    "clip_gradient": 1,
+    "exp_dir": "Phoneme_Hallucinator_v2/exp/speech_XXL_cond",
+    "summ_freq": 100,
+    "eval_metrics": ["sam"]
+}
+
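Since `easydict` is in the requirements and the models read both attribute-style fields (`hps.model`, `hps.lr`) and dict-style sub-configs (`hps['vae_params']`), the config is presumably loaded along these lines (a sketch, not the repo's literal loader):

```
import json
from easydict import EasyDict

# Load the experiment config; EasyDict allows both hps.model and hps['vae_params'].
with open("exp/speech_XXL_cond/params.json") as f:
    hps = EasyDict(json.load(f))

print(hps.model, hps.set_size, hps.mask_type)  # pc_acset_vae 200 arb_expand
print(hps.vae_params['hid_dimensions'])        # 256
```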
Phoneme_Hallucinator_v2/exp/speech_XXL_cond/weights/checkpoint
ADDED
@@ -0,0 +1,3 @@
+model_checkpoint_path: "last.ckpt"
+all_model_checkpoint_paths: "params.ckpt"
+all_model_checkpoint_paths: "last.ckpt"
Phoneme_Hallucinator_v2/exp/speech_XXL_cond/weights/last.ckpt.data-00000-of-00001
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f48a8699b9c871c3e4d8fa92c1f4e4c58c3c054d7f3620577286307b1cee9c22
+size 228403264
Phoneme_Hallucinator_v2/exp/speech_XXL_cond/weights/last.ckpt.index
ADDED
Binary file (33.9 kB)
Phoneme_Hallucinator_v2/exp/speech_XXL_cond/weights/last.ckpt.meta
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fc65091d5349de08462ecf43b05ff3ed9e07c4de97160bc81c858f99a2979411
+size 7340118
Phoneme_Hallucinator_v2/exp/speech_XXL_cond/weights/params.ckpt.data-00000-of-00001
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a10bbba728a64b8c3697e3fe546a1fe1a1a666c609e021acb24a15ccf615c740
+size 228403264
Phoneme_Hallucinator_v2/exp/speech_XXL_cond/weights/params.ckpt.index
ADDED
Binary file (33.9 kB)
Phoneme_Hallucinator_v2/exp/speech_XXL_cond/weights/params.ckpt.meta
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:22bf81dcaf6d4079a9151a9e797c1d4f6e7e209d493dba1f1f29b7d1ba5c2f59
+size 7340118
Phoneme_Hallucinator_v2/models/__init__.py
ADDED
@@ -0,0 +1,9 @@
+
+def get_model(hps):
+    if hps.model == 'pc_acset_vae':
+        from .pc_acset_vae import ACSetVAE
+        model = ACSetVAE(hps)
+    else:
+        raise ValueError()
+
+    return model
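Combined with the config loading sketched above, constructing and restoring the pretrained model is then two calls (relying on `BaseModel.load` in `models/base.py`, which reads `{exp_dir}/weights/{filename}.ckpt`; run from a directory where `exp_dir` resolves):

```
from models import get_model

# `hps` is the EasyDict loaded from params.json above.
model = get_model(hps)   # builds the graph, session, saver and summary writer
model.load('params')     # restores exp/speech_XXL_cond/weights/params.ckpt
```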
Phoneme_Hallucinator_v2/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (392 Bytes)
Phoneme_Hallucinator_v2/models/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (354 Bytes)
Phoneme_Hallucinator_v2/models/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (358 Bytes)
Phoneme_Hallucinator_v2/models/__pycache__/base.cpython-310.pyc
ADDED
Binary file (4.06 kB)
Phoneme_Hallucinator_v2/models/__pycache__/base.cpython-36.pyc
ADDED
Binary file (3.89 kB)
Phoneme_Hallucinator_v2/models/__pycache__/base.cpython-37.pyc
ADDED
Binary file (3.87 kB)
Phoneme_Hallucinator_v2/models/__pycache__/cVAE.cpython-310.pyc
ADDED
Binary file (2.09 kB)
Phoneme_Hallucinator_v2/models/__pycache__/cVAE.cpython-36.pyc
ADDED
Binary file (1.92 kB)
Phoneme_Hallucinator_v2/models/__pycache__/cVAE.cpython-37.pyc
ADDED
Binary file (1.93 kB)
Phoneme_Hallucinator_v2/models/__pycache__/networks.cpython-310.pyc
ADDED
Binary file (5 kB)
Phoneme_Hallucinator_v2/models/__pycache__/networks.cpython-36.pyc
ADDED
Binary file (4.64 kB)
Phoneme_Hallucinator_v2/models/__pycache__/networks.cpython-37.pyc
ADDED
Binary file (4.6 kB)
Phoneme_Hallucinator_v2/models/__pycache__/pc_acset.cpython-36.pyc
ADDED
Binary file (2.77 kB)
Phoneme_Hallucinator_v2/models/__pycache__/pc_acset_vae.cpython-310.pyc
ADDED
Binary file (3.07 kB)
Phoneme_Hallucinator_v2/models/__pycache__/pc_acset_vae.cpython-36.pyc
ADDED
Binary file (3.01 kB)
Phoneme_Hallucinator_v2/models/__pycache__/pc_acset_vae.cpython-37.pyc
ADDED
Binary file (2.92 kB)
Phoneme_Hallucinator_v2/models/__pycache__/pc_acset_vae_06.cpython-36.pyc
ADDED
Binary file (3.73 kB)
Phoneme_Hallucinator_v2/models/__pycache__/pc_encoder.cpython-310.pyc
ADDED
Binary file (1.94 kB)
Phoneme_Hallucinator_v2/models/__pycache__/pc_encoder.cpython-36.pyc
ADDED
Binary file (1.84 kB)
Phoneme_Hallucinator_v2/models/__pycache__/pc_encoder.cpython-37.pyc
ADDED
Binary file (1.85 kB)
Phoneme_Hallucinator_v2/models/__pycache__/runner.cpython-36.pyc
ADDED
Binary file (5.79 kB)
Phoneme_Hallucinator_v2/models/__pycache__/set_transformer.cpython-310.pyc
ADDED
Binary file (2.49 kB)
Phoneme_Hallucinator_v2/models/__pycache__/set_transformer.cpython-36.pyc
ADDED
Binary file (2.18 kB)
Phoneme_Hallucinator_v2/models/__pycache__/set_transformer.cpython-37.pyc
ADDED
Binary file (2.16 kB)
Phoneme_Hallucinator_v2/models/__pycache__/utils.cpython-36.pyc
ADDED
Binary file (4.92 kB)
Phoneme_Hallucinator_v2/models/base.py
ADDED
@@ -0,0 +1,103 @@
+import logging
+from pprint import pformat
+import numpy as np
+import tensorflow as tf
+class BaseModel(object):
+    def __init__(self, hps):
+        super(BaseModel, self).__init__()
+
+        self.hps = hps
+        g = tf.Graph()
+        with g.as_default():
+            # open a session
+            config = tf.compat.v1.ConfigProto()
+            config.log_device_placement = True
+            config.allow_soft_placement = True
+            config.gpu_options.allow_growth = True
+            self.sess = tf.compat.v1.Session(config=config, graph=g)
+            # build model
+            self.build_net()
+            self.build_ops()
+            # initialize
+            self.sess.run(tf.compat.v1.global_variables_initializer())
+            self.saver = tf.compat.v1.train.Saver()
+            self.writer = tf.compat.v1.summary.FileWriter(self.hps.exp_dir + '/summary')
+            # logging
+            total_params = 0
+            trainable_variables = tf.compat.v1.trainable_variables()
+            logging.info('=' * 20)
+            logging.info("Variables:")
+            logging.info(pformat(trainable_variables))
+            for v in trainable_variables:
+                num_params = np.prod(v.get_shape().as_list())
+                total_params += num_params
+
+            logging.info("TOTAL TENSORS: %d TOTAL PARAMS: %f[M]" % (
+                len(trainable_variables), total_params / 1e6))
+            logging.info('=' * 20)
+
+    def save(self, filename='params'):
+        fname = f'{self.hps.exp_dir}/weights/{filename}.ckpt'
+        self.saver.save(self.sess, fname)
+
+    def load(self, filename='params'):
+        fname = f'{self.hps.exp_dir}/weights/{filename}.ckpt'
+        self.saver.restore(self.sess, fname)
+
+    def build_net(self):
+        raise NotImplementedError()
+
+    def build_ops(self):
+        # optimizer
+        self.global_step = tf.compat.v1.train.get_or_create_global_step()
+        learning_rate = tf.compat.v1.train.inverse_time_decay(
+            self.hps.lr, self.global_step,
+            self.hps.decay_steps, self.hps.decay_rate,
+            staircase=True)
+        warmup_lr = tf.compat.v1.train.inverse_time_decay(
+            0.001 * self.hps.lr, self.global_step,
+            self.hps.decay_steps, self.hps.decay_rate,
+            staircase=True)
+        learning_rate = tf.cond(pred=tf.less(self.global_step, 1000), true_fn=lambda: warmup_lr, false_fn=lambda: learning_rate)
+        tf.compat.v1.summary.scalar('lr', learning_rate)
+        if self.hps.optimizer == 'adam':
+            optimizer = tf.compat.v1.train.AdamOptimizer(
+                learning_rate=learning_rate,
+                beta1=0.9, beta2=0.999, epsilon=1e-08,
+                use_locking=False, name="Adam")
+        elif self.hps.optimizer == 'rmsprop':
+            optimizer = tf.compat.v1.train.RMSPropOptimizer(
+                learning_rate=learning_rate)
+        elif self.hps.optimizer == 'mom':
+            optimizer = tf.compat.v1.train.MomentumOptimizer(
+                learning_rate=learning_rate,
+                momentum=0.9)
+        else:
+            optimizer = tf.compat.v1.train.GradientDescentOptimizer(
+                learning_rate=learning_rate)
+
+        # regularization
+        l2_reg = sum(
+            [tf.reduce_sum(input_tensor=tf.square(v)) for v in tf.compat.v1.trainable_variables()
+             if ("magnitude" in v.name) or ("rescaling_scale" in v.name)])
+        reg_loss = 0.00005 * l2_reg
+
+        # train
+        grads_and_vars = optimizer.compute_gradients(
+            self.loss+reg_loss, tf.compat.v1.trainable_variables())
+        grads, vars_ = zip(*grads_and_vars)
+        if self.hps.clip_gradient > 0:
+            grads, gradient_norm = tf.clip_by_global_norm(
+                grads, clip_norm=self.hps.clip_gradient)
+            gradient_norm = tf.debugging.check_numerics(
+                gradient_norm, "Gradient norm is NaN or Inf.")
+            tf.compat.v1.summary.scalar('gradient_norm', gradient_norm)
+        capped_grads_and_vars = zip(grads, vars_)
+        self.train_op = optimizer.apply_gradients(
+            capped_grads_and_vars, global_step=self.global_step)
+
+        # summary
+        self.summ_op = tf.compat.v1.summary.merge_all()
+
+    def execute(self, cmd, batch):
+        return self.sess.run(cmd, {self.x:batch['x'], self.b:batch['b'], self.m:batch['m']})
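The learning-rate logic in `build_ops` is easy to misread: both branches are staircase inverse-time decays, and `tf.cond` selects the 1000x-smaller warmup branch for the first 1000 steps. A plain-Python sketch of the effective schedule using the params.json values:

```
def effective_lr(step, lr=1e-4, decay_steps=100000, decay_rate=0.5):
    # inverse_time_decay with staircase=True: lr / (1 + rate * floor(step / decay_steps))
    scale = 1.0 + decay_rate * (step // decay_steps)
    return (0.001 * lr if step < 1000 else lr) / scale

print(effective_lr(0))       # 1e-07: constant warmup plateau
print(effective_lr(1000))    # 1e-04: full rate after 1000 steps
print(effective_lr(200000))  # 5e-05: after two decay periods
```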
Phoneme_Hallucinator_v2/models/cVAE.py
ADDED
@@ -0,0 +1,45 @@
+import numpy as np
+import tensorflow as tf
+import tensorflow_probability as tfp
+tfd = tfp.distributions
+
+from .networks import dense_nn, cond_dense_nn
+
+class CondVAE(object):
+    def __init__(self, hps, name="cvae"):
+        self.hps = hps
+        self.name = name
+
+    def enc(self, x, cond=None):
+        '''
+        x: [B, C]
+        cond: [B, C]
+        '''
+        B, C = tf.shape(input=x)[0], tf.shape(input=x)[1]
+        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
+            prior_dist = tfd.MultivariateNormalDiag(tf.zeros(self.hps['hid_dimensions']), tf.ones(self.hps['hid_dimensions']))
+            if cond is None:
+                x = dense_nn(x, self.hps['enc_dense_hids'], 2 * self.hps['hid_dimensions'], False, "enc")
+            else:
+                x = cond_dense_nn(x, cond, self.hps['enc_dense_hids'], 2 * self.hps['hid_dimensions'], False, "enc")
+            m, s = x[:, :self.hps['hid_dimensions']], tf.nn.softplus(x[:, self.hps['hid_dimensions']:])
+            posterior_dist = tfd.MultivariateNormalDiag(m, s)
+            #kl = 0.5 * tf.reduce_sum(s + m ** 2 - 1.0 - tf.log(s), axis=-1)
+            kl = - tfd.kl_divergence(posterior_dist, prior_dist)
+            eps = prior_dist.sample(B)
+            posterior_sample = m + eps * s
+            return kl, posterior_sample
+
+    def dec(self, x, cond=None):
+        '''
+        x: [B, C]
+        '''
+        B, C = tf.shape(input=x)[0], tf.shape(input=x)[1]
+        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
+            if cond is None:
+                x = dense_nn(x, self.hps['dec_dense_hids'], 2 * self.hps['dimension'], False, "dec")
+            else:
+                x = cond_dense_nn(x, cond, self.hps['dec_dense_hids'], 2 * self.hps['dimension'], False, "dec")
+            m, s = x[:, :self.hps['dimension']], tf.nn.softplus(x[:, self.hps['dimension']:])
+            sample_dist = tfd.MultivariateNormalDiag(loc=m, scale_diag=s)
+            return sample_dist
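Two details in `enc` are worth noting: the returned `kl` is the negative KL divergence (so it can be added directly to a log-likelihood objective), and the posterior sample is drawn with the reparameterization trick, `m + eps * s` with `eps ~ N(0, I)`. A numpy check of the standard closed form that `tfd.kl_divergence` computes for diagonal Gaussians against a standard-normal prior:

```
import numpy as np

rng = np.random.default_rng(0)
d = 256                                   # hid_dimensions in params.json
mu = rng.normal(size=d)
s = np.log1p(np.exp(rng.normal(size=d)))  # softplus scale, as in enc()

# KL( N(mu, diag(s^2)) || N(0, I) ) in closed form.
kl = 0.5 * np.sum(s**2 + mu**2 - 1.0 - np.log(s**2))

# Reparameterized posterior sample: deterministic in (mu, s) given the noise.
eps = rng.normal(size=d)
sample = mu + eps * s
print(kl, sample.shape)
```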
Phoneme_Hallucinator_v2/models/flow/__init__.py
ADDED
File without changes