diff --git a/cluster/__init__.py b/cluster/__init__.py index f1b9bde04e73e9218a5d534227caa4c25332f424..ae00ea692643e69c8c8c60e392f456ab0adcdd93 100644 --- a/cluster/__init__.py +++ b/cluster/__init__.py @@ -1,7 +1,7 @@ -import numpy as np import torch from sklearn.cluster import KMeans + def get_cluster_model(ckpt_path): checkpoint = torch.load(ckpt_path) kmeans_dict = {} diff --git a/cluster/__pycache__/__init__.cpython-38.pyc b/cluster/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83ecf62be95af14ec74e39e359b2add2510638fe Binary files /dev/null and b/cluster/__pycache__/__init__.cpython-38.pyc differ diff --git a/cluster/__pycache__/kmeans.cpython-38.pyc b/cluster/__pycache__/kmeans.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4990f9348d795cd5ea173561da762e4823c54624 Binary files /dev/null and b/cluster/__pycache__/kmeans.cpython-38.pyc differ diff --git a/cluster/km_train.py b/cluster/km_train.py new file mode 100644 index 0000000000000000000000000000000000000000..917b2da181dda7b2a918cad9961838ad23c5f0b4 --- /dev/null +++ b/cluster/km_train.py @@ -0,0 +1,80 @@ +import time,pdb +import tqdm +from time import time as ttime +import os +from pathlib import Path +import logging +import argparse +from cluster.kmeans import KMeansGPU +import torch +import numpy as np +from sklearn.cluster import KMeans,MiniBatchKMeans + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +from time import time as ttime +import pynvml,torch + +def train_cluster(in_dir, n_clusters, use_minibatch=True, verbose=False,use_gpu=False):#gpu_minibatch真拉,虽然库支持但是也不考虑 + logger.info(f"Loading features from {in_dir}") + features = [] + nums = 0 + for path in tqdm.tqdm(in_dir.glob("*.soft.pt")): + # for name in os.listdir(in_dir): + # path="%s/%s"%(in_dir,name) + features.append(torch.load(path,map_location="cpu").squeeze(0).numpy().T) + # print(features[-1].shape) + features = np.concatenate(features, axis=0) + print(nums, features.nbytes/ 1024**2, "MB , shape:",features.shape, features.dtype) + features = features.astype(np.float32) + logger.info(f"Clustering features of shape: {features.shape}") + t = time.time() + if(use_gpu==False): + if use_minibatch: + kmeans = MiniBatchKMeans(n_clusters=n_clusters,verbose=verbose, batch_size=4096, max_iter=80).fit(features) + else: + kmeans = KMeans(n_clusters=n_clusters,verbose=verbose).fit(features) + else: + kmeans = KMeansGPU(n_clusters=n_clusters, mode='euclidean', verbose=2 if verbose else 0,max_iter=500,tol=1e-2)# + features=torch.from_numpy(features)#.to(device) + labels = kmeans.fit_predict(features)# + + print(time.time()-t, "s") + + x = { + "n_features_in_": kmeans.n_features_in_ if use_gpu==False else features.shape[0], + "_n_threads": kmeans._n_threads if use_gpu==False else 4, + "cluster_centers_": kmeans.cluster_centers_ if use_gpu==False else kmeans.centroids.cpu().numpy(), + } + print("end") + + return x + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--dataset', type=Path, default="./dataset/44k", + help='path of training data directory') + parser.add_argument('--output', type=Path, default="logs/44k", + help='path of model output directory') + + args = parser.parse_args() + + checkpoint_dir = args.output + dataset = args.dataset + n_clusters = 1000 + + ckpt = {} + for spk in os.listdir(dataset): + if os.path.isdir(dataset/spk): + print(f"train kmeans for {spk}...") + in_dir = dataset/spk + x = train_cluster(in_dir, n_clusters,use_minibatch=False,verbose=False,use_gpu=True) + ckpt[spk] = x + + checkpoint_path = checkpoint_dir / f"kmeans_{n_clusters}.pt" + checkpoint_path.parent.mkdir(exist_ok=True, parents=True) + torch.save( + ckpt, + checkpoint_path, + ) + diff --git a/cluster/kmeans.py b/cluster/kmeans.py new file mode 100644 index 0000000000000000000000000000000000000000..0b6dd40876570190a9361aa0f9f4f609af042f23 --- /dev/null +++ b/cluster/kmeans.py @@ -0,0 +1,204 @@ +from time import time + +import numpy as np +import pynvml +import torch +from torch.nn.functional import normalize + + +# device=torch.device("cuda:0") +def _kpp(data: torch.Tensor, k: int, sample_size: int = -1): + """ Picks k points in the data based on the kmeans++ method. + + Parameters + ---------- + data : torch.Tensor + Expect a rank 1 or 2 array. Rank 1 is assumed to describe 1-D + data, rank 2 multidimensional data, in which case one + row is one observation. + k : int + Number of samples to generate. + sample_size : int + sample data to avoid memory overflow during calculation + + Returns + ------- + init : ndarray + A 'k' by 'N' containing the initial centroids. + + References + ---------- + .. [1] D. Arthur and S. Vassilvitskii, "k-means++: the advantages of + careful seeding", Proceedings of the Eighteenth Annual ACM-SIAM Symposium + on Discrete Algorithms, 2007. + .. [2] scipy/cluster/vq.py: _kpp + """ + batch_size=data.shape[0] + if batch_size>sample_size: + data = data[torch.randint(0, batch_size,[sample_size], device=data.device)] + dims = data.shape[1] if len(data.shape) > 1 else 1 + init = torch.zeros((k, dims)).to(data.device) + r = torch.distributions.uniform.Uniform(0, 1) + for i in range(k): + if i == 0: + init[i, :] = data[torch.randint(data.shape[0], [1])] + else: + D2 = torch.cdist(init[:i, :][None, :], data[None, :], p=2)[0].amin(dim=0) + probs = D2 / torch.sum(D2) + cumprobs = torch.cumsum(probs, dim=0) + init[i, :] = data[torch.searchsorted(cumprobs, r.sample([1]).to(data.device))] + return init +class KMeansGPU: + ''' + Kmeans clustering algorithm implemented with PyTorch + + Parameters: + n_clusters: int, + Number of clusters + + max_iter: int, default: 100 + Maximum number of iterations + + tol: float, default: 0.0001 + Tolerance + + verbose: int, default: 0 + Verbosity + + mode: {'euclidean', 'cosine'}, default: 'euclidean' + Type of distance measure + + init_method: {'random', 'point', '++'} + Type of initialization + + minibatch: {None, int}, default: None + Batch size of MinibatchKmeans algorithm + if None perform full KMeans algorithm + + Attributes: + centroids: torch.Tensor, shape: [n_clusters, n_features] + cluster centroids + ''' + def __init__(self, n_clusters, max_iter=200, tol=1e-4, verbose=0, mode="euclidean",device=torch.device("cuda:0")): + self.n_clusters = n_clusters + self.max_iter = max_iter + self.tol = tol + self.verbose = verbose + self.mode = mode + self.device=device + pynvml.nvmlInit() + gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(device.index) + info = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle) + self.minibatch=int(33e6/self.n_clusters*info.free/ 1024 / 1024 / 1024) + print("free_mem/GB:",info.free/ 1024 / 1024 / 1024,"minibatch:",self.minibatch) + + @staticmethod + def cos_sim(a, b): + """ + Compute cosine similarity of 2 sets of vectors + + Parameters: + a: torch.Tensor, shape: [m, n_features] + + b: torch.Tensor, shape: [n, n_features] + """ + return normalize(a, dim=-1) @ normalize(b, dim=-1).transpose(-2, -1) + + @staticmethod + def euc_sim(a, b): + """ + Compute euclidean similarity of 2 sets of vectors + Parameters: + a: torch.Tensor, shape: [m, n_features] + b: torch.Tensor, shape: [n, n_features] + """ + return 2 * a @ b.transpose(-2, -1) -(a**2).sum(dim=1)[..., :, None] - (b**2).sum(dim=1)[..., None, :] + + def max_sim(self, a, b): + """ + Compute maximum similarity (or minimum distance) of each vector + in a with all of the vectors in b + Parameters: + a: torch.Tensor, shape: [m, n_features] + b: torch.Tensor, shape: [n, n_features] + """ + if self.mode == 'cosine': + sim_func = self.cos_sim + elif self.mode == 'euclidean': + sim_func = self.euc_sim + sim = sim_func(a, b) + max_sim_v, max_sim_i = sim.max(dim=-1) + return max_sim_v, max_sim_i + + def fit_predict(self, X): + """ + Combination of fit() and predict() methods. + This is faster than calling fit() and predict() seperately. + Parameters: + X: torch.Tensor, shape: [n_samples, n_features] + centroids: {torch.Tensor, None}, default: None + if given, centroids will be initialized with given tensor + if None, centroids will be randomly chosen from X + Return: + labels: torch.Tensor, shape: [n_samples] + + mini_=33kk/k*remain + mini=min(mini_,fea_shape) + offset=log2(k/1000)*1.5 + kpp_all=min(mini_*10/offset,fea_shape) + kpp_sample=min(mini_/12/offset,fea_shape) + """ + assert isinstance(X, torch.Tensor), "input must be torch.Tensor" + assert X.dtype in [torch.half, torch.float, torch.double], "input must be floating point" + assert X.ndim == 2, "input must be a 2d tensor with shape: [n_samples, n_features] " + # print("verbose:%s"%self.verbose) + + offset = np.power(1.5,np.log(self.n_clusters / 1000))/np.log(2) + with torch.no_grad(): + batch_size= X.shape[0] + # print(self.minibatch, int(self.minibatch * 10 / offset), batch_size) + start_time = time() + if (self.minibatch*10//offset< batch_size): + x = X[torch.randint(0, batch_size,[int(self.minibatch*10/offset)])].to(self.device) + else: + x = X.to(self.device) + # print(x.device) + self.centroids = _kpp(x, self.n_clusters, min(int(self.minibatch/12/offset),batch_size)) + del x + torch.cuda.empty_cache() + # self.centroids = self.centroids.to(self.device) + num_points_in_clusters = torch.ones(self.n_clusters, device=self.device, dtype=X.dtype)#全1 + closest = None#[3098036]#int64 + if(self.minibatch>=batch_size//2 and self.minibatch=batch_size): + X=X.to(self.device) + for i in range(self.max_iter): + iter_time = time() + if self.minibatch= 2: + print('iter:', i, 'error:', error.item(), 'time spent:', round(time()-iter_time, 4)) + if error <= self.tol: + break + + if self.verbose >= 1: + print(f'used {i+1} iterations ({round(time()-start_time, 4)}s) to cluster {batch_size} items into {self.n_clusters} clusters') + return closest diff --git a/cluster/train_cluster.py b/cluster/train_cluster.py index 4ac025d400414226e66849407f477ae786c3d5d3..135f179a389804afc0266873ae31a1cf107ebcf8 100644 --- a/cluster/train_cluster.py +++ b/cluster/train_cluster.py @@ -1,67 +1,79 @@ +import argparse +import logging import os -from glob import glob +import time from pathlib import Path -import torch -import logging -import argparse -import torch + import numpy as np -from sklearn.cluster import KMeans, MiniBatchKMeans +import torch import tqdm +from kmeans import KMeansGPU +from sklearn.cluster import KMeans, MiniBatchKMeans + logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -import time -import random -def train_cluster(in_dir, n_clusters, use_minibatch=True, verbose=False): +def train_cluster(in_dir, n_clusters, use_minibatch=True, verbose=False,use_gpu=False):#gpu_minibatch真拉,虽然库支持但是也不考虑 + if str(in_dir).endswith(".ipynb_checkpoints"): + logger.info(f"Ignore {in_dir}") logger.info(f"Loading features from {in_dir}") features = [] nums = 0 for path in tqdm.tqdm(in_dir.glob("*.soft.pt")): - features.append(torch.load(path).squeeze(0).numpy().T) + # for name in os.listdir(in_dir): + # path="%s/%s"%(in_dir,name) + features.append(torch.load(path,map_location="cpu").squeeze(0).numpy().T) # print(features[-1].shape) features = np.concatenate(features, axis=0) print(nums, features.nbytes/ 1024**2, "MB , shape:",features.shape, features.dtype) features = features.astype(np.float32) logger.info(f"Clustering features of shape: {features.shape}") t = time.time() - if use_minibatch: - kmeans = MiniBatchKMeans(n_clusters=n_clusters,verbose=verbose, batch_size=4096, max_iter=80).fit(features) + if(use_gpu is False): + if use_minibatch: + kmeans = MiniBatchKMeans(n_clusters=n_clusters,verbose=verbose, batch_size=4096, max_iter=80).fit(features) + else: + kmeans = KMeans(n_clusters=n_clusters,verbose=verbose).fit(features) else: - kmeans = KMeans(n_clusters=n_clusters,verbose=verbose).fit(features) + kmeans = KMeansGPU(n_clusters=n_clusters, mode='euclidean', verbose=2 if verbose else 0,max_iter=500,tol=1e-2)# + features=torch.from_numpy(features)#.to(device) + kmeans.fit_predict(features)# + print(time.time()-t, "s") x = { - "n_features_in_": kmeans.n_features_in_, - "_n_threads": kmeans._n_threads, - "cluster_centers_": kmeans.cluster_centers_, + "n_features_in_": kmeans.n_features_in_ if use_gpu is False else features.shape[1], + "_n_threads": kmeans._n_threads if use_gpu is False else 4, + "cluster_centers_": kmeans.cluster_centers_ if use_gpu is False else kmeans.centroids.cpu().numpy(), } print("end") return x - if __name__ == "__main__": - parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=Path, default="./dataset/44k", help='path of training data directory') parser.add_argument('--output', type=Path, default="logs/44k", help='path of model output directory') + parser.add_argument('--gpu',action='store_true', default=False , + help='to use GPU') + args = parser.parse_args() checkpoint_dir = args.output dataset = args.dataset + use_gpu = args.gpu n_clusters = 10000 - + ckpt = {} for spk in os.listdir(dataset): if os.path.isdir(dataset/spk): print(f"train kmeans for {spk}...") in_dir = dataset/spk - x = train_cluster(in_dir, n_clusters, verbose=False) + x = train_cluster(in_dir, n_clusters,use_minibatch=False,verbose=False,use_gpu=use_gpu) ckpt[spk] = x checkpoint_path = checkpoint_dir / f"kmeans_{n_clusters}.pt" @@ -70,20 +82,4 @@ if __name__ == "__main__": ckpt, checkpoint_path, ) - - - # import cluster - # for spk in tqdm.tqdm(os.listdir("dataset")): - # if os.path.isdir(f"dataset/{spk}"): - # print(f"start kmeans inference for {spk}...") - # for feature_path in tqdm.tqdm(glob(f"dataset/{spk}/*.discrete.npy", recursive=True)): - # mel_path = feature_path.replace(".discrete.npy",".mel.npy") - # mel_spectrogram = np.load(mel_path) - # feature_len = mel_spectrogram.shape[-1] - # c = np.load(feature_path) - # c = utils.tools.repeat_expand_2d(torch.FloatTensor(c), feature_len).numpy() - # feature = c.T - # feature_class = cluster.get_cluster_result(feature, spk) - # np.save(feature_path.replace(".discrete.npy", ".discrete_class.npy"), feature_class) - - + diff --git a/diffusion/__pycache__/__init__.cpython-38.pyc b/diffusion/__pycache__/__init__.cpython-38.pyc index 5b432ba7b3ad97ba7a95f20e4f35749c4a18d9ee..973c8a552f812ca2682e68fc5a8eb71710e6b3d6 100644 Binary files a/diffusion/__pycache__/__init__.cpython-38.pyc and b/diffusion/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusion/__pycache__/data_loaders.cpython-38.pyc b/diffusion/__pycache__/data_loaders.cpython-38.pyc index 50d14c4389d3046b5eeb21bcbfe7fb0d29fe05d7..10e219ec3b4a85b3ddd75dfbc12180aaf131408e 100644 Binary files a/diffusion/__pycache__/data_loaders.cpython-38.pyc and b/diffusion/__pycache__/data_loaders.cpython-38.pyc differ diff --git a/diffusion/__pycache__/diffusion.cpython-38.pyc b/diffusion/__pycache__/diffusion.cpython-38.pyc index 0a3f4cf66eb31eedb98267b61f5f3c090d15577c..2d10ba9a26b26bd5ebe8c7f90095eecb13c9c724 100644 Binary files a/diffusion/__pycache__/diffusion.cpython-38.pyc and b/diffusion/__pycache__/diffusion.cpython-38.pyc differ diff --git a/diffusion/__pycache__/dpm_solver_pytorch.cpython-38.pyc b/diffusion/__pycache__/dpm_solver_pytorch.cpython-38.pyc index cf365c57143708dde7fb2a91676fe723d9fdbbee..458a2487cb70b12428cdd855a9729547f8469358 100644 Binary files a/diffusion/__pycache__/dpm_solver_pytorch.cpython-38.pyc and b/diffusion/__pycache__/dpm_solver_pytorch.cpython-38.pyc differ diff --git a/diffusion/__pycache__/solver.cpython-38.pyc b/diffusion/__pycache__/solver.cpython-38.pyc index 3d1471df003afeb6dc52861af4c463d39f865c07..9b2cf935b8c06468f46528dd819e1ead5c95e87e 100644 Binary files a/diffusion/__pycache__/solver.cpython-38.pyc and b/diffusion/__pycache__/solver.cpython-38.pyc differ diff --git a/diffusion/__pycache__/unit2mel.cpython-38.pyc b/diffusion/__pycache__/unit2mel.cpython-38.pyc index ddb06174707715126cf352fddd64aea558a15e03..4a1ee4d9c013b48628e694ccad6dc29aa030e696 100644 Binary files a/diffusion/__pycache__/unit2mel.cpython-38.pyc and b/diffusion/__pycache__/unit2mel.cpython-38.pyc differ diff --git a/diffusion/__pycache__/vocoder.cpython-38.pyc b/diffusion/__pycache__/vocoder.cpython-38.pyc index b472d6f8c82102aa4fa8e30d07008c455d93bd46..bc7ee3827bec6280193b98c9ab547ee0ed0bbbff 100644 Binary files a/diffusion/__pycache__/vocoder.cpython-38.pyc and b/diffusion/__pycache__/vocoder.cpython-38.pyc differ diff --git a/diffusion/__pycache__/wavenet.cpython-38.pyc b/diffusion/__pycache__/wavenet.cpython-38.pyc index f66eb6cd424a17a284c53d0d753d63cfd9ed773d..69d4791aba62311659821ef1e12bb5d39bfc3640 100644 Binary files a/diffusion/__pycache__/wavenet.cpython-38.pyc and b/diffusion/__pycache__/wavenet.cpython-38.pyc differ diff --git a/diffusion/data_loaders.py b/diffusion/data_loaders.py index bf18572329019d7a8f1df01799eda207c16dd7ff..9f00b9afd01565e568e5315dfa49b82dd2ec68ed 100644 --- a/diffusion/data_loaders.py +++ b/diffusion/data_loaders.py @@ -1,13 +1,14 @@ import os import random -import re -import numpy as np + import librosa +import numpy as np import torch -import random -from utils import repeat_expand_2d -from tqdm import tqdm from torch.utils.data import Dataset +from tqdm import tqdm + +from utils import repeat_expand_2d + def traverse_dir( root_dir, @@ -63,6 +64,7 @@ def get_data_loaders(args, whole_audio=False): spk=args.spk, device=args.train.cache_device, fp16=args.train.cache_fp16, + unit_interpolate_mode = args.data.unit_interpolate_mode, use_aug=True) loader_train = torch.utils.data.DataLoader( data_train , @@ -81,6 +83,7 @@ def get_data_loaders(args, whole_audio=False): whole_audio=True, spk=args.spk, extensions=args.data.extensions, + unit_interpolate_mode = args.data.unit_interpolate_mode, n_spk=args.model.n_spk) loader_valid = torch.utils.data.DataLoader( data_valid, @@ -107,6 +110,7 @@ class AudioDataset(Dataset): device='cpu', fp16=False, use_aug=False, + unit_interpolate_mode = 'left' ): super().__init__() @@ -118,6 +122,7 @@ class AudioDataset(Dataset): self.use_aug = use_aug self.data_buffer={} self.pitch_aug_dict = {} + self.unit_interpolate_mode = unit_interpolate_mode # np.load(os.path.join(self.path_root, 'pitch_aug_dict.npy'), allow_pickle=True).item() if load_all_data: print('Load all the data filelists:', filelists) @@ -126,7 +131,6 @@ class AudioDataset(Dataset): with open(filelists,"r") as f: self.paths = f.read().splitlines() for name_ext in tqdm(self.paths, total=len(self.paths)): - name = os.path.splitext(name_ext)[0] path_audio = name_ext duration = librosa.get_duration(filename = path_audio, sr = self.sample_rate) @@ -171,7 +175,7 @@ class AudioDataset(Dataset): path_units = name_ext + ".soft.pt" units = torch.load(path_units).to(device) units = units[0] - units = repeat_expand_2d(units,f0.size(0)).transpose(0,1) + units = repeat_expand_2d(units,f0.size(0),unit_interpolate_mode).transpose(0,1) if fp16: mel = mel.half() @@ -263,7 +267,7 @@ class AudioDataset(Dataset): path_units = name_ext + ".soft.pt" units = torch.load(path_units) units = units[0] - units = repeat_expand_2d(units,f0.size(0)).transpose(0,1) + units = repeat_expand_2d(units,f0.size(0),self.unit_interpolate_mode).transpose(0,1) units = units[start_frame : start_frame + units_frame_len] diff --git a/diffusion/diffusion.py b/diffusion/diffusion.py index decc1d31503e93e6611b02ced7b9c6f00b95db58..646234b5f8d3161ba126055c628a89162f16b0cf 100644 --- a/diffusion/diffusion.py +++ b/diffusion/diffusion.py @@ -1,10 +1,10 @@ from collections import deque from functools import partial from inspect import isfunction -import torch.nn.functional as F -import librosa.sequence + import numpy as np import torch +import torch.nn.functional as F from torch import nn from tqdm import tqdm @@ -26,8 +26,10 @@ def extract(a, t, x_shape): def noise_like(shape, device, repeat=False): - repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) - noise = lambda: torch.randn(shape, device=device) + def repeat_noise(): + return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + def noise(): + return torch.randn(shape, device=device) return repeat_noise() if repeat else noise() @@ -67,6 +69,7 @@ class GaussianDiffusion(nn.Module): max_beta=0.02, spec_min=-12, spec_max=2): + super().__init__() self.denoise_fn = denoise_fn self.out_dims = out_dims @@ -78,7 +81,7 @@ class GaussianDiffusion(nn.Module): timesteps, = betas.shape self.num_timesteps = int(timesteps) - self.k_step = k_step + self.k_step = k_step if k_step>0 and k_step 1: - if method == 'dpm-solver': - from .dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver + if method == 'dpm-solver' or method == 'dpm-solver++': + from .dpm_solver_pytorch import ( + DPM_Solver, + NoiseScheduleVP, + model_wrapper, + ) # 1. Define the noise schedule. noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t]) @@ -267,17 +286,20 @@ class GaussianDiffusion(nn.Module): # (We recommend singlestep DPM-Solver for unconditional sampling) # You can adjust the `steps` to balance the computation # costs and the sample quality. - dpm_solver = DPM_Solver(model_fn, noise_schedule) - + if method == 'dpm-solver': + dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") + elif method == 'dpm-solver++': + dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + steps = t // infer_speedup if use_tqdm: self.bar = tqdm(desc="sample time step", total=steps) x = dpm_solver.sample( x, steps=steps, - order=3, + order=2, skip_type="time_uniform", - method="singlestep", + method="multistep", ) if use_tqdm: self.bar.close() @@ -298,6 +320,63 @@ class GaussianDiffusion(nn.Module): x, torch.full((b,), i, device=device, dtype=torch.long), infer_speedup, cond=cond ) + elif method == 'ddim': + if use_tqdm: + for i in tqdm( + reversed(range(0, t, infer_speedup)), desc='sample time step', + total=t // infer_speedup, + ): + x = self.p_sample_ddim( + x, torch.full((b,), i, device=device, dtype=torch.long), + infer_speedup, cond=cond + ) + else: + for i in reversed(range(0, t, infer_speedup)): + x = self.p_sample_ddim( + x, torch.full((b,), i, device=device, dtype=torch.long), + infer_speedup, cond=cond + ) + elif method == 'unipc': + from .uni_pc import NoiseScheduleVP, UniPC, model_wrapper + # 1. Define the noise schedule. + noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t]) + + # 2. Convert your discrete-time `model` to the continuous-time + # noise prediction model. Here is an example for a diffusion model + # `model` with the noise prediction type ("noise") . + def my_wrapper(fn): + def wrapped(x, t, **kwargs): + ret = fn(x, t, **kwargs) + if use_tqdm: + self.bar.update(1) + return ret + + return wrapped + + model_fn = model_wrapper( + my_wrapper(self.denoise_fn), + noise_schedule, + model_type="noise", # or "x_start" or "v" or "score" + model_kwargs={"cond": cond} + ) + + # 3. Define uni_pc and sample by multistep UniPC. + # You can adjust the `steps` to balance the computation + # costs and the sample quality. + uni_pc = UniPC(model_fn, noise_schedule, variant='bh2') + + steps = t // infer_speedup + if use_tqdm: + self.bar = tqdm(desc="sample time step", total=steps) + x = uni_pc.sample( + x, + steps=steps, + order=2, + skip_type="time_uniform", + method="multistep", + ) + if use_tqdm: + self.bar.close() else: raise NotImplementedError(method) else: diff --git a/diffusion/diffusion_onnx.py b/diffusion/diffusion_onnx.py index 1c1e80321de162b5233801efa3423739f7f92bdc..f01e463515bd6dccd02fe49b1db1f5af64fc746b 100644 --- a/diffusion/diffusion_onnx.py +++ b/diffusion/diffusion_onnx.py @@ -1,15 +1,14 @@ +import math from collections import deque from functools import partial from inspect import isfunction -import torch.nn.functional as F -import librosa.sequence + import numpy as np -from torch.nn import Conv1d -from torch.nn import Mish import torch +import torch.nn.functional as F from torch import nn +from torch.nn import Conv1d, Mish from tqdm import tqdm -import math def exists(x): @@ -27,8 +26,10 @@ def extract(a, t): def noise_like(shape, device, repeat=False): - repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) - noise = lambda: torch.randn(shape, device=device) + def repeat_noise(): + return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + def noise(): + return torch.randn(shape, device=device) return repeat_noise() if repeat else noise() @@ -389,7 +390,11 @@ class GaussianDiffusion(nn.Module): if method is not None and infer_speedup > 1: if method == 'dpm-solver': - from .dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver + from .dpm_solver_pytorch import ( + DPM_Solver, + NoiseScheduleVP, + model_wrapper, + ) # 1. Define the noise schedule. noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t]) @@ -576,9 +581,6 @@ class GaussianDiffusion(nn.Module): plms_noise_stage = torch.tensor(0, dtype=torch.long, device=device) noise_list = torch.zeros((0, 1, 1, self.mel_bins, n_frames), device=device) - ot = step_range[0] - ot_1 = torch.full((1,), ot, device=device, dtype=torch.long) - for t in step_range: t_1 = torch.full((1,), t, device=device, dtype=torch.long) noise_pred = self.denoise_fn(x, t_1, cond) diff --git a/diffusion/dpm_solver_pytorch.py b/diffusion/dpm_solver_pytorch.py index dee5e280661b61e0a99038ce0bd240db51344ead..83ed73e22d37cb8ef224425dcfd6bb3dcba74578 100644 --- a/diffusion/dpm_solver_pytorch.py +++ b/diffusion/dpm_solver_pytorch.py @@ -1,5 +1,3 @@ -import math - import torch @@ -11,7 +9,8 @@ class NoiseScheduleVP: alphas_cumprod=None, continuous_beta_0=0.1, continuous_beta_1=20., - ): + dtype=torch.float32, + ): """Create a wrapper class for the forward SDE (VP type). *** @@ -46,7 +45,7 @@ class NoiseScheduleVP: betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) - Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. **Important**: Please pay special attention for the args for `alphas_cumprod`: The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that @@ -59,21 +58,19 @@ class NoiseScheduleVP: 2. For continuous-time DPMs: - We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise - schedule are the default settings in DDPM and improved-DDPM: + We support the linear VPSDE for the continuous time setting. The hyperparameters for the noise + schedule are the default settings in Yang Song's ScoreSDE: Args: beta_min: A `float` number. The smallest beta for the linear schedule. beta_max: A `float` number. The largest beta for the linear schedule. - cosine_s: A `float` number. The hyperparameter in the cosine schedule. - cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. T: A `float` number. The ending time of the forward process. =============================================================== Args: schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, - 'linear' or 'cosine' for continuous-time DPMs. + 'linear' for continuous-time DPMs. Returns: A wrapper object of the forward SDE (VP type). @@ -92,10 +89,8 @@ class NoiseScheduleVP: """ - if schedule not in ['discrete', 'linear', 'cosine']: - raise ValueError( - "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( - schedule)) + if schedule not in ['discrete', 'linear']: + raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear'".format(schedule)) self.schedule = schedule if schedule == 'discrete': @@ -104,40 +99,37 @@ class NoiseScheduleVP: else: assert alphas_cumprod is not None log_alphas = 0.5 * torch.log(alphas_cumprod) - self.total_N = len(log_alphas) self.T = 1. - self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) - self.log_alpha_array = log_alphas.reshape((1, -1,)) + self.log_alpha_array = self.numerical_clip_alpha(log_alphas).reshape((1, -1,)).to(dtype=dtype) + self.total_N = self.log_alpha_array.shape[1] + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) else: + self.T = 1. self.total_N = 1000 self.beta_0 = continuous_beta_0 self.beta_1 = continuous_beta_1 - self.cosine_s = 0.008 - self.cosine_beta_max = 999. - self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( - 1. + self.cosine_s) / math.pi - self.cosine_s - self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) - self.schedule = schedule - if schedule == 'cosine': - # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. - # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. - self.T = 0.9946 - else: - self.T = 1. + + def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1): + """ + For some beta schedules such as cosine schedule, the log-SNR has numerical isssues. + We clip the log-SNR near t=T within -5.1 to ensure the stability. + Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE. + """ + log_sigmas = 0.5 * torch.log(1. - torch.exp(2. * log_alphas)) + lambs = log_alphas - log_sigmas + idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda) + if idx > 0: + log_alphas = log_alphas[:-idx] + return log_alphas def marginal_log_mean_coeff(self, t): """ Compute log(alpha_t) of a given continuous-time label t in [0, T]. """ if self.schedule == 'discrete': - return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), - self.log_alpha_array.to(t.device)).reshape((-1)) + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) elif self.schedule == 'linear': return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 - elif self.schedule == 'cosine': - log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) - log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 - return log_alpha_t def marginal_alpha(self, t): """ @@ -165,32 +157,25 @@ class NoiseScheduleVP: """ if self.schedule == 'linear': tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - Delta = self.beta_0 ** 2 + tmp + Delta = self.beta_0**2 + tmp return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) elif self.schedule == 'discrete': log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) - t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), - torch.flip(self.t_array.to(lamb.device), [1])) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1])) return t.reshape((-1,)) - else: - log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * ( - 1. + self.cosine_s) / math.pi - self.cosine_s - t = t_fn(log_alpha) - return t def model_wrapper( - model, - noise_schedule, - model_type="noise", - model_kwargs={}, - guidance_type="uncond", - condition=None, - unconditional_condition=None, - guidance_scale=1., - classifier_fn=None, - classifier_kwargs={}, + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, ): """Create a wrapper function for the noise prediction model. @@ -293,8 +278,6 @@ def model_wrapper( return t_continuous def noise_pred_fn(x, t_continuous, cond=None): - if t_continuous.reshape((-1,)).shape[0] == 1: - t_continuous = t_continuous.expand((x.shape[0])) t_input = get_model_input_time(t_continuous) if cond is None: output = model(x, t_input, **model_kwargs) @@ -304,16 +287,13 @@ def model_wrapper( return output elif model_type == "x_start": alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) - dims = x.dim() - return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) + return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim()) elif model_type == "v": alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) - dims = x.dim() - return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x + return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x elif model_type == "score": sigma_t = noise_schedule.marginal_std(t_continuous) - dims = x.dim() - return -expand_dims(sigma_t, dims) * output + return -expand_dims(sigma_t, x.dim()) * output def cond_grad_fn(x, t_input): """ @@ -328,8 +308,6 @@ def model_wrapper( """ The noise predicition model function that is used for DPM-Solver. """ - if t_continuous.reshape((-1,)).shape[0] == 1: - t_continuous = t_continuous.expand((x.shape[0])) if guidance_type == "uncond": return noise_pred_fn(x, t_continuous) elif guidance_type == "classifier": @@ -338,7 +316,7 @@ def model_wrapper( cond_grad = cond_grad_fn(x, t_input) sigma_t = noise_schedule.marginal_std(t_continuous) noise = noise_pred_fn(x, t_continuous) - return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad + return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad elif guidance_type == "classifier-free": if guidance_scale == 1. or unconditional_condition is None: return noise_pred_fn(x, t_continuous, cond=condition) @@ -349,20 +327,34 @@ def model_wrapper( noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) return noise_uncond + guidance_scale * (noise - noise_uncond) - assert model_type in ["noise", "x_start", "v"] + assert model_type in ["noise", "x_start", "v", "score"] assert guidance_type in ["uncond", "classifier", "classifier-free"] return model_fn class DPM_Solver: - def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): + def __init__( + self, + model_fn, + noise_schedule, + algorithm_type="dpmsolver++", + correcting_x0_fn=None, + correcting_xt_fn=None, + thresholding_max_val=1., + dynamic_thresholding_ratio=0.995, + ): """Construct a DPM-Solver. - We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). - If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). - If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++). - In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True. - The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales. + We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`). + + We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you + can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the + dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space + DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space + DPMs (such as stable-diffusion). + + To support advanced algorithms in image-to-image applications, we also support corrector functions for + both x0 and xt. Args: model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): @@ -370,18 +362,65 @@ class DPM_Solver: def model_fn(x, t_continuous): return noise `` + The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`. noise_schedule: A noise schedule object, such as NoiseScheduleVP. - predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model. - thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1]. - max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding. - - [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++". + correcting_x0_fn: A `str` or a function with the following format: + ``` + def correcting_x0_fn(x0, t): + x0_new = ... + return x0_new + ``` + This function is to correct the outputs of the data prediction model at each sampling step. e.g., + ``` + x0_pred = data_pred_model(xt, t) + if correcting_x0_fn is not None: + x0_pred = correcting_x0_fn(x0_pred, t) + xt_1 = update(x0_pred, xt, t) + ``` + If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1]. + correcting_xt_fn: A function with the following format: + ``` + def correcting_xt_fn(xt, t, step): + x_new = ... + return x_new + ``` + This function is to correct the intermediate samples xt at each sampling step. e.g., + ``` + xt = ... + xt = correcting_xt_fn(xt, t, step) + ``` + thresholding_max_val: A `float`. The max value for thresholding. + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details). + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, + Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models + with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. """ - self.model = model_fn + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) self.noise_schedule = noise_schedule - self.predict_x0 = predict_x0 - self.thresholding = thresholding - self.max_val = max_val + assert algorithm_type in ["dpmsolver", "dpmsolver++"] + self.algorithm_type = algorithm_type + if correcting_x0_fn == "dynamic_thresholding": + self.correcting_x0_fn = self.dynamic_thresholding_fn + else: + self.correcting_x0_fn = correcting_x0_fn + self.correcting_xt_fn = correcting_xt_fn + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.thresholding_max_val = thresholding_max_val + + def dynamic_thresholding_fn(self, x0, t): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 def noise_prediction_fn(self, x, t): """ @@ -391,24 +430,20 @@ class DPM_Solver: def data_prediction_fn(self, x, t): """ - Return the data prediction model (with thresholding). + Return the data prediction model (with corrector). """ noise = self.noise_prediction_fn(x, t) - dims = x.dim() alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) - x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) - if self.thresholding: - p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. - s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) - s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) - x0 = torch.clamp(x0, -s, s) / s + x0 = (x - sigma_t * noise) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0, t) return x0 def model_fn(self, x, t): """ Convert the model to the noise prediction model or the data prediction model. """ - if self.predict_x0: + if self.algorithm_type == "dpmsolver++": return self.data_prediction_fn(x, t) else: return self.noise_prediction_fn(x, t) @@ -437,11 +472,10 @@ class DPM_Solver: return torch.linspace(t_T, t_0, N + 1).to(device) elif skip_type == 'time_quadratic': t_order = 2 - t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device) + t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) return t else: - raise ValueError( - "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): """ @@ -478,32 +512,31 @@ class DPM_Solver: if order == 3: K = steps // 3 + 1 if steps % 3 == 0: - orders = [3, ] * (K - 2) + [2, 1] + orders = [3,] * (K - 2) + [2, 1] elif steps % 3 == 1: - orders = [3, ] * (K - 1) + [1] + orders = [3,] * (K - 1) + [1] else: - orders = [3, ] * (K - 1) + [2] + orders = [3,] * (K - 1) + [2] elif order == 2: if steps % 2 == 0: K = steps // 2 - orders = [2, ] * K + orders = [2,] * K else: K = steps // 2 + 1 - orders = [2, ] * (K - 1) + [1] + orders = [2,] * (K - 1) + [1] elif order == 1: K = 1 - orders = [1, ] * steps + orders = [1,] * steps else: raise ValueError("'order' must be '1' or '2' or '3'.") if skip_type == 'logSNR': # To reproduce the results in DPM-Solver paper timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) else: - timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ - torch.cumsum(torch.tensor([0, ] + orders), dim=0).to(device)] + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)] return timesteps_outer, orders - def denoise_fn(self, x, s): + def denoise_to_zero_fn(self, x, s): """ Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. """ @@ -515,8 +548,8 @@ class DPM_Solver: Args: x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). model_s: A pytorch tensor. The model function evaluated at time `s`. If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. return_intermediate: A `bool`. If true, also return the model value at time `s`. @@ -524,20 +557,19 @@ class DPM_Solver: x_t: A pytorch tensor. The approximated solution at time `t`. """ ns = self.noise_schedule - dims = x.dim() lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) h = lambda_t - lambda_s log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) alpha_t = torch.exp(log_alpha_t) - if self.predict_x0: + if self.algorithm_type == "dpmsolver++": phi_1 = torch.expm1(-h) if model_s is None: model_s = self.model_fn(x, s) x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s + sigma_t / sigma_s * x + - alpha_t * phi_1 * model_s ) if return_intermediate: return x_t, {'model_s': model_s} @@ -548,70 +580,66 @@ class DPM_Solver: if model_s is None: model_s = self.model_fn(x, s) x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s ) if return_intermediate: return x_t, {'model_s': model_s} else: return x_t - def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, - solver_type='dpm_solver'): + def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpmsolver'): """ Singlestep solver DPM-Solver-2 from time `s` to time `t`. Args: x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). r1: A `float`. The hyperparameter of the second-order solver. model_s: A pytorch tensor. The model function evaluated at time `s`. If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. Returns: x_t: A pytorch tensor. The approximated solution at time `t`. """ - if solver_type not in ['dpm_solver', 'taylor']: - raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) if r1 is None: r1 = 0.5 ns = self.noise_schedule - dims = x.dim() lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) h = lambda_t - lambda_s lambda_s1 = lambda_s + r1 * h s1 = ns.inverse_lambda(lambda_s1) - log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff( - s1), ns.marginal_log_mean_coeff(t) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t) sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) - if self.predict_x0: + if self.algorithm_type == "dpmsolver++": phi_11 = torch.expm1(-r1 * h) phi_1 = torch.expm1(-h) if model_s is None: model_s = self.model_fn(x, s) x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s + (sigma_s1 / sigma_s) * x + - (alpha_s1 * phi_11) * model_s ) model_s1 = self.model_fn(x_s1, s1) - if solver_type == 'dpm_solver': + if solver_type == 'dpmsolver': x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s) ) elif solver_type == 'taylor': x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * ( - model_s1 - model_s) + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1. / r1) * (alpha_t * (phi_1 / h + 1.)) * (model_s1 - model_s) ) else: phi_11 = torch.expm1(r1 * h) @@ -620,36 +648,35 @@ class DPM_Solver: if model_s is None: model_s = self.model_fn(x, s) x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s + torch.exp(log_alpha_s1 - log_alpha_s) * x + - (sigma_s1 * phi_11) * model_s ) model_s1 = self.model_fn(x_s1, s1) - if solver_type == 'dpm_solver': + if solver_type == 'dpmsolver': x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s) ) elif solver_type == 'taylor': x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (1. / r1) * (sigma_t * (phi_1 / h - 1.)) * (model_s1 - model_s) ) if return_intermediate: return x_t, {'model_s': model_s, 'model_s1': model_s1} else: return x_t - def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None, - return_intermediate=False, solver_type='dpm_solver'): + def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpmsolver'): """ Singlestep solver DPM-Solver-3 from time `s` to time `t`. Args: x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). r1: A `float`. The hyperparameter of the third-order solver. r2: A `float`. The hyperparameter of the third-order solver. model_s: A pytorch tensor. The model function evaluated at time `s`. @@ -657,32 +684,29 @@ class DPM_Solver: model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. Returns: x_t: A pytorch tensor. The approximated solution at time `t`. """ - if solver_type not in ['dpm_solver', 'taylor']: - raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) if r1 is None: r1 = 1. / 3. if r2 is None: r2 = 2. / 3. ns = self.noise_schedule - dims = x.dim() lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) h = lambda_t - lambda_s lambda_s1 = lambda_s + r1 * h lambda_s2 = lambda_s + r2 * h s1 = ns.inverse_lambda(lambda_s1) s2 = ns.inverse_lambda(lambda_s2) - log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( - s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std( - s2), ns.marginal_std(t) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t) alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) - if self.predict_x0: + if self.algorithm_type == "dpmsolver++": phi_11 = torch.expm1(-r1 * h) phi_12 = torch.expm1(-r2 * h) phi_1 = torch.expm1(-h) @@ -694,21 +718,21 @@ class DPM_Solver: model_s = self.model_fn(x, s) if model_s1 is None: x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s + (sigma_s1 / sigma_s) * x + - (alpha_s1 * phi_11) * model_s ) model_s1 = self.model_fn(x_s1, s1) x_s2 = ( - expand_dims(sigma_s2 / sigma_s, dims) * x - - expand_dims(alpha_s2 * phi_12, dims) * model_s - + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) + (sigma_s2 / sigma_s) * x + - (alpha_s2 * phi_12) * model_s + + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s) ) model_s2 = self.model_fn(x_s2, s2) - if solver_type == 'dpm_solver': + if solver_type == 'dpmsolver': x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1. / r2) * (alpha_t * phi_2) * (model_s2 - model_s) ) elif solver_type == 'taylor': D1_0 = (1. / r1) * (model_s1 - model_s) @@ -716,10 +740,10 @@ class DPM_Solver: D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) D2 = 2. * (D1_1 - D1_0) / (r2 - r1) x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + expand_dims(alpha_t * phi_2, dims) * D1 - - expand_dims(alpha_t * phi_3, dims) * D2 + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 ) else: phi_11 = torch.expm1(r1 * h) @@ -733,21 +757,21 @@ class DPM_Solver: model_s = self.model_fn(x, s) if model_s1 is None: x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s + (torch.exp(log_alpha_s1 - log_alpha_s)) * x + - (sigma_s1 * phi_11) * model_s ) model_s1 = self.model_fn(x_s1, s1) x_s2 = ( - expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x - - expand_dims(sigma_s2 * phi_12, dims) * model_s - - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) + (torch.exp(log_alpha_s2 - log_alpha_s)) * x + - (sigma_s2 * phi_12) * model_s + - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s) ) model_s2 = self.model_fn(x_s2, s2) - if solver_type == 'dpm_solver': + if solver_type == 'dpmsolver': x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (1. / r2) * (sigma_t * phi_2) * (model_s2 - model_s) ) elif solver_type == 'taylor': D1_0 = (1. / r1) * (model_s1 - model_s) @@ -755,10 +779,10 @@ class DPM_Solver: D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) D2 = 2. * (D1_1 - D1_0) / (r2 - r1) x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - expand_dims(sigma_t * phi_2, dims) * D1 - - expand_dims(sigma_t * phi_3, dims) * D2 + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 ) if return_intermediate: @@ -766,28 +790,26 @@ class DPM_Solver: else: return x_t - def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): """ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. Args: x: A pytorch tensor. The initial value at time `s`. model_prev_list: A list of pytorch tensor. The previous computed model values. - t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. Returns: x_t: A pytorch tensor. The approximated solution at time `t`. """ - if solver_type not in ['dpm_solver', 'taylor']: - raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) ns = self.noise_schedule - dims = x.dim() - model_prev_1, model_prev_0 = model_prev_list - t_prev_1, t_prev_0 = t_prev_list - lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda( - t_prev_0), ns.marginal_lambda(t) + model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] + t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) alpha_t = torch.exp(log_alpha_t) @@ -795,55 +817,55 @@ class DPM_Solver: h_0 = lambda_prev_0 - lambda_prev_1 h = lambda_t - lambda_prev_0 r0 = h_0 / h - D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) - if self.predict_x0: - if solver_type == 'dpm_solver': + D1_0 = (1. / r0) * (model_prev_0 - model_prev_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if solver_type == 'dpmsolver': x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + - 0.5 * (alpha_t * phi_1) * D1_0 ) elif solver_type == 'taylor': x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * (phi_1 / h + 1.)) * D1_0 ) else: - if solver_type == 'dpm_solver': + phi_1 = torch.expm1(h) + if solver_type == 'dpmsolver': x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - 0.5 * (sigma_t * phi_1) * D1_0 ) elif solver_type == 'taylor': x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * (phi_1 / h - 1.)) * D1_0 ) return x_t - def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpmsolver'): """ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. Args: x: A pytorch tensor. The initial value at time `s`. model_prev_list: A list of pytorch tensor. The previous computed model values. - t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. Returns: x_t: A pytorch tensor. The approximated solution at time `t`. """ ns = self.noise_schedule - dims = x.dim() model_prev_2, model_prev_1, model_prev_0 = model_prev_list t_prev_2, t_prev_1, t_prev_0 = t_prev_list - lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda( - t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) alpha_t = torch.exp(log_alpha_t) @@ -852,39 +874,44 @@ class DPM_Solver: h_0 = lambda_prev_0 - lambda_prev_1 h = lambda_t - lambda_prev_0 r0, r1 = h_0 / h, h_1 / h - D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) - D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) - D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) - D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) - if self.predict_x0: + D1_0 = (1. / r0) * (model_prev_0 - model_prev_1) + D1_1 = (1. / r1) * (model_prev_1 - model_prev_2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1. / (r0 + r1)) * (D1_0 - D1_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 - - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2 + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 ) else: + phi_1 = torch.expm1(h) + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 - - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2 + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 ) return x_t - def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, - r2=None): + def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpmsolver', r1=None, r2=None): """ Singlestep DPM-Solver with the order `order` from time `s` to time `t`. Args: x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. r1: A `float`. The hyperparameter of the second-order or third-order solver. r2: A `float`. The hyperparameter of the third-order solver. Returns: @@ -893,26 +920,24 @@ class DPM_Solver: if order == 1: return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) elif order == 2: - return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, - solver_type=solver_type, r1=r1) + return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1) elif order == 3: - return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, - solver_type=solver_type, r1=r1, r2=r2) + return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2) else: raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpmsolver'): """ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. Args: x: A pytorch tensor. The initial value at time `s`. model_prev_list: A list of pytorch tensor. The previous computed model values. - t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. Returns: x_t: A pytorch tensor. The approximated solution at time `t`. """ @@ -925,8 +950,7 @@ class DPM_Solver: else: raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, - solver_type='dpm_solver'): + def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpmsolver'): """ The adaptive step size solver based on singlestep DPM-Solver. @@ -941,15 +965,15 @@ class DPM_Solver: theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the current time and `t_0` is less than `t_err`. The default setting is 1e-5. - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. Returns: x_0: A pytorch tensor. The approximated solution at time `t_0`. [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. """ ns = self.noise_schedule - s = t_T * torch.ones((x.shape[0],)).to(x) + s = t_T * torch.ones((1,)).to(x) lambda_s = ns.marginal_lambda(s) lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) h = h_init * torch.ones_like(s).to(x) @@ -957,18 +981,16 @@ class DPM_Solver: nfe = 0 if order == 2: r1 = 0.5 - lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, - solver_type=solver_type, - **kwargs) + def lower_update(x, s, t): + return self.dpm_solver_first_update(x, s, t, return_intermediate=True) + def higher_update(x, s, t, **kwargs): + return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) elif order == 3: r1, r2 = 1. / 3., 2. / 3. - lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, - return_intermediate=True, - solver_type=solver_type) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, - solver_type=solver_type, - **kwargs) + def lower_update(x, s, t): + return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type) + def higher_update(x, s, t, **kwargs): + return self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) else: raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) while torch.abs((s - t_0)).mean() > t_err: @@ -976,7 +998,8 @@ class DPM_Solver: x_lower, lower_noise_kwargs = lower_update(x, s, t) x_higher = higher_update(x, s, t, **lower_noise_kwargs) delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) - norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + def norm_fn(v): + return torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) E = norm_fn((x_higher - x_lower) / delta).max() if torch.all(E <= 1.): x = x_higher @@ -988,10 +1011,45 @@ class DPM_Solver: print('adaptive solver nfe', nfe) return x - def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', - method='singlestep', denoise=False, solver_type='dpm_solver', atol=0.0078, - rtol=0.05, - ): + def add_noise(self, x, t, noise=None): + """ + Compute the noised input xt = alpha_t * x + sigma_t * noise. + + Args: + x: A `torch.Tensor` with shape `(batch_size, *shape)`. + t: A `torch.Tensor` with shape `(t_size,)`. + Returns: + xt with shape `(t_size, batch_size, *shape)`. + """ + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + if noise is None: + noise = torch.randn((t.shape[0], *x.shape), device=x.device) + x = x.reshape((-1, *x.shape)) + xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise + if t.shape[0] == 1: + return xt.squeeze(0) + else: + return xt + + def inverse(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', + method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver', + atol=0.0078, rtol=0.05, return_intermediate=False, + ): + """ + Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver. + For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training. + """ + t_0 = 1. / self.noise_schedule.total_N if t_start is None else t_start + t_T = self.noise_schedule.T if t_end is None else t_end + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + return self.sample(x, steps=steps, t_start=t_0, t_end=t_T, order=order, skip_type=skip_type, + method=method, lower_order_final=lower_order_final, denoise_to_zero=denoise_to_zero, solver_type=solver_type, + atol=atol, rtol=rtol, return_intermediate=return_intermediate) + + def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', + method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver', + atol=0.0078, rtol=0.05, return_intermediate=False, + ): """ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. @@ -1040,15 +1098,19 @@ class DPM_Solver: Some advices for choosing the algorithm: - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: - Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`. - e.g. - >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False) + Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g., DPM-Solver: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + e.g., DPM-Solver++: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, skip_type='time_uniform', method='singlestep') - For **guided sampling with large guidance scale** by DPMs: - Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`. + Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`. e.g. - >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True) + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, skip_type='time_uniform', method='multistep') @@ -1074,72 +1136,116 @@ class DPM_Solver: order: A `int`. The order of DPM-Solver. skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. - denoise: A `bool`. Whether to denoise at the final step. Default is False. - If `denoise` is True, the total NFE is (`steps` + 1). - solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`. atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + return_intermediate: A `bool`. Whether to save the xt at each step. + When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0. Returns: x_end: A pytorch tensor. The approximated solution at time `t_end`. """ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end t_T = self.noise_schedule.T if t_start is None else t_start + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + if return_intermediate: + assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values" + if self.correcting_xt_fn is not None: + assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None" device = x.device - if method == 'adaptive': - with torch.no_grad(): - x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, - solver_type=solver_type) - elif method == 'multistep': - assert steps >= order - timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) - assert timesteps.shape[0] - 1 == steps - with torch.no_grad(): - vec_t = timesteps[0].expand((x.shape[0])) - model_prev_list = [self.model_fn(x, vec_t)] - t_prev_list = [vec_t] + intermediates = [] + with torch.no_grad(): + if method == 'adaptive': + x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + t_prev_list = [t] + model_prev_list = [self.model_fn(x, t)] + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) # Init the first `order` values by lower order multistep DPM-Solver. - for init_order in range(1, order): - vec_t = timesteps[init_order].expand(x.shape[0]) - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, - solver_type=solver_type) - model_prev_list.append(self.model_fn(x, vec_t)) - t_prev_list.append(vec_t) + for step in range(1, order): + t = timesteps[step] + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step, solver_type=solver_type) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + t_prev_list.append(t) + model_prev_list.append(self.model_fn(x, t)) # Compute the remaining values by `order`-th order multistep DPM-Solver. for step in range(order, steps + 1): - vec_t = timesteps[step].expand(x.shape[0]) - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, order, - solver_type=solver_type) + t = timesteps[step] + # We only use lower order for steps < 10 + if lower_order_final and steps < 10: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) for i in range(order - 1): t_prev_list[i] = t_prev_list[i + 1] model_prev_list[i] = model_prev_list[i + 1] - t_prev_list[-1] = vec_t + t_prev_list[-1] = t # We do not need to evaluate the final model value. if step < steps: - model_prev_list[-1] = self.model_fn(x, vec_t) - elif method in ['singlestep', 'singlestep_fixed']: - if method == 'singlestep': - timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, - skip_type=skip_type, - t_T=t_T, t_0=t_0, - device=device) - elif method == 'singlestep_fixed': - K = steps // order - orders = [order, ] * K - timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) - for i, order in enumerate(orders): - t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] - timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), - N=order, device=device) - lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) - vec_s, vec_t = t_T_inner.repeat(x.shape[0]), t_0_inner.repeat(x.shape[0]) - h = lambda_inner[-1] - lambda_inner[0] - r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h - r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h - x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2) - if denoise: - x = self.denoise_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) - return x + model_prev_list[-1] = self.model_fn(x, t) + elif method in ['singlestep', 'singlestep_fixed']: + if method == 'singlestep': + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device) + elif method == 'singlestep_fixed': + K = steps // order + orders = [order,] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for step, order in enumerate(orders): + s, t = timesteps_outer[step], timesteps_outer[step + 1] + timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + else: + raise ValueError("Got wrong method {}".format(method)) + if denoise_to_zero: + t = torch.ones((1,)).to(device) * t_0 + x = self.denoise_to_zero_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step + 1) + if return_intermediate: + intermediates.append(x) + if return_intermediate: + return x, intermediates + else: + return x + ############################################################# @@ -1198,4 +1304,4 @@ def expand_dims(v, dims): Returns: a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. """ - return v[(...,) + (None,) * (dims - 1)] + return v[(...,) + (None,)*(dims - 1)] \ No newline at end of file diff --git a/diffusion/infer_gt_mel.py b/diffusion/infer_gt_mel.py index 033b821a5d21a1232f1786bce5616b12e01488ad..0bdf1fed45a7df3eef954c24fd0d4b0dcb0fc1b8 100644 --- a/diffusion/infer_gt_mel.py +++ b/diffusion/infer_gt_mel.py @@ -1,6 +1,6 @@ -import numpy as np import torch import torch.nn.functional as F + from diffusion.unit2mel import load_model_vocoder diff --git a/diffusion/logger/__pycache__/__init__.cpython-38.pyc b/diffusion/logger/__pycache__/__init__.cpython-38.pyc index ac53ae184cabbf830a31b095b0c9fa0011ff65c1..06c4ce5ccebc3e8c3f6028b00f2de5a17761fa64 100644 Binary files a/diffusion/logger/__pycache__/__init__.cpython-38.pyc and b/diffusion/logger/__pycache__/__init__.cpython-38.pyc differ diff --git a/diffusion/logger/__pycache__/saver.cpython-38.pyc b/diffusion/logger/__pycache__/saver.cpython-38.pyc index 4cf0ddecb8d7947f648bfccf67f21a2332502c80..6185f7cb8f62dd0bac9780d4f78fff4d31d3579a 100644 Binary files a/diffusion/logger/__pycache__/saver.cpython-38.pyc and b/diffusion/logger/__pycache__/saver.cpython-38.pyc differ diff --git a/diffusion/logger/__pycache__/utils.cpython-38.pyc b/diffusion/logger/__pycache__/utils.cpython-38.pyc index 768bb5c08216c20be44b605c8e2372ce9d6f39a3..00f94f0cbb521fd62badd833a06b832336412b1f 100644 Binary files a/diffusion/logger/__pycache__/utils.cpython-38.pyc and b/diffusion/logger/__pycache__/utils.cpython-38.pyc differ diff --git a/diffusion/logger/saver.py b/diffusion/logger/saver.py index ef78b52b6bcd32106f962b731d3784d72d5f0cce..954ce99b37f6c983999d4f7e1b08dcc5b7d99bc4 100644 --- a/diffusion/logger/saver.py +++ b/diffusion/logger/saver.py @@ -2,16 +2,16 @@ author: wayn391@mastertones ''' +import datetime import os -import json import time -import yaml -import datetime -import torch + import matplotlib.pyplot as plt -from . import utils +import torch +import yaml from torch.utils.tensorboard import SummaryWriter + class Saver(object): def __init__( self, @@ -125,12 +125,7 @@ class Saver(object): torch.save({ 'global_step': self.global_step, 'model': model.state_dict()}, path_pt) - - # to json - if to_json: - path_json = os.path.join( - self.expdir , name+'.json') - utils.to_json(path_params, path_json) + def delete_model(self, name='model', postfix=''): # path diff --git a/diffusion/logger/utils.py b/diffusion/logger/utils.py index 485681ced897980dc0bf5b149308245bbd708de9..a907de7dc4ece50746f87f92ee1985d50d34fcbb 100644 --- a/diffusion/logger/utils.py +++ b/diffusion/logger/utils.py @@ -1,8 +1,9 @@ -import os -import yaml import json -import pickle +import os + import torch +import yaml + def traverse_dir( root_dir, @@ -121,6 +122,6 @@ def load_model( ckpt = torch.load(path_pt, map_location=torch.device(device)) global_step = ckpt['global_step'] model.load_state_dict(ckpt['model'], strict=False) - if ckpt.get('optimizer') != None: + if ckpt.get("optimizer") is not None: optimizer.load_state_dict(ckpt['optimizer']) return global_step, model, optimizer diff --git a/diffusion/onnx_export.py b/diffusion/onnx_export.py index 5deda785cf22b341f7d2e6399ef5fcdad6fe129e..053cb46b4359a0acd5d95d879e1bcf3a047542cf 100644 --- a/diffusion/onnx_export.py +++ b/diffusion/onnx_export.py @@ -1,12 +1,12 @@ -from diffusion_onnx import GaussianDiffusion import os -import yaml + +import numpy as np import torch import torch.nn as nn -import numpy as np -from wavenet import WaveNet import torch.nn.functional as F -import diffusion +import yaml +from diffusion_onnx import GaussianDiffusion + class DotDict(dict): def __getattr__(*args): @@ -33,7 +33,9 @@ def load_model_vocoder( 128, args.model.n_layers, args.model.n_chans, - args.model.n_hidden) + args.model.n_hidden, + args.model.timesteps, + args.model.k_step_max) print(' [Loading] ' + model_path) ckpt = torch.load(model_path, map_location=torch.device(device)) @@ -52,8 +54,11 @@ class Unit2Mel(nn.Module): out_dims=128, n_layers=20, n_chans=384, - n_hidden=256): + n_hidden=256, + timesteps=1000, + k_step_max=1000): super().__init__() + self.unit_embed = nn.Linear(input_channel, n_hidden) self.f0_embed = nn.Linear(1, n_hidden) self.volume_embed = nn.Linear(1, n_hidden) @@ -64,9 +69,13 @@ class Unit2Mel(nn.Module): self.n_spk = n_spk if n_spk is not None and n_spk > 1: self.spk_embed = nn.Embedding(n_spk, n_hidden) - + + self.timesteps = timesteps if timesteps is not None else 1000 + self.k_step_max = k_step_max if k_step_max is not None and k_step_max>0 and k_step_max>> ns = NoiseScheduleVP('discrete', betas=betas) + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) + self.log_alpha_array = log_alphas.reshape((1, -1,)).to(dtype=dtype) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + def log_alpha_fn(s): + return torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + def t_fn(log_alpha_t): + return torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2.0 * (1.0 + self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * noise_schedule.total_N + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return (x - alpha_t * output) / sigma_t + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return alpha_t * output + sigma_t * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + return -sigma_t * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * sigma_t * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class UniPC: + def __init__( + self, + model_fn, + noise_schedule, + algorithm_type="data_prediction", + correcting_x0_fn=None, + correcting_xt_fn=None, + thresholding_max_val=1., + dynamic_thresholding_ratio=0.995, + variant='bh1' + ): + """Construct a UniPC. + + We support both data_prediction and noise_prediction. + """ + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) + self.noise_schedule = noise_schedule + assert algorithm_type in ["data_prediction", "noise_prediction"] + + if correcting_x0_fn == "dynamic_thresholding": + self.correcting_x0_fn = self.dynamic_thresholding_fn + else: + self.correcting_x0_fn = correcting_x0_fn + + self.correcting_xt_fn = correcting_xt_fn + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.thresholding_max_val = thresholding_max_val + + self.variant = variant + self.predict_x0 = algorithm_type == "data_prediction" + + def dynamic_thresholding_fn(self, x0, t=None): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with corrector). + """ + noise = self.noise_prediction_fn(x, t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - sigma_t * noise) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3,] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3,] * (K - 1) + [1] + else: + orders = [3,] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2,] * K + else: + K = steps // 2 + 1 + orders = [2,] * (K - 1) + [1] + elif order == 1: + K = steps + orders = [1,] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs): + if len(t.shape) == 0: + t = t.view(-1) + if 'bh' in self.variant: + return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs) + else: + assert self.variant == 'vary_coeff' + return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs) + + def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True): + #print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)') + ns = self.noise_schedule + assert order <= len(model_prev_list) + + # first compute rks + t_prev_0 = t_prev_list[-1] + lambda_prev_0 = ns.marginal_lambda(t_prev_0) + lambda_t = ns.marginal_lambda(t) + model_prev_0 = model_prev_list[-1] + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + log_alpha_t = ns.marginal_log_mean_coeff(t) + alpha_t = torch.exp(log_alpha_t) + + h = lambda_t - lambda_prev_0 + + rks = [] + D1s = [] + for i in range(1, order): + t_prev_i = t_prev_list[-(i + 1)] + model_prev_i = model_prev_list[-(i + 1)] + lambda_prev_i = ns.marginal_lambda(t_prev_i) + rk = (lambda_prev_i - lambda_prev_0) / h + rks.append(rk) + D1s.append((model_prev_i - model_prev_0) / rk) + + rks.append(1.) + rks = torch.tensor(rks, device=x.device) + + K = len(rks) + # build C matrix + C = [] + + col = torch.ones_like(rks) + for k in range(1, K + 1): + C.append(col) + col = col * rks / (k + 1) + C = torch.stack(C, dim=1) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + C_inv_p = torch.linalg.inv(C[:-1, :-1]) + A_p = C_inv_p + + if use_corrector: + #print('using corrector') + C_inv = torch.linalg.inv(C) + A_c = C_inv + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) + h_phi_ks = [] + factorial_k = 1 + h_phi_k = h_phi_1 + for k in range(1, K + 2): + h_phi_ks.append(h_phi_k) + h_phi_k = h_phi_k / hh - 1 / factorial_k + factorial_k *= (k + 1) + + model_t = None + if self.predict_x0: + x_t_ = ( + sigma_t / sigma_prev_0 * x + - alpha_t * h_phi_1 * model_prev_0 + ) + # now predictor + x_t = x_t_ + if len(D1s) > 0: + # compute the residuals for predictor + for k in range(K - 1): + x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k]) + # now corrector + if use_corrector: + model_t = self.model_fn(x_t, t) + D1_t = (model_t - model_prev_0) + x_t = x_t_ + k = 0 + for k in range(K - 1): + x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1]) + x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1]) + else: + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + x_t_ = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * h_phi_1) * model_prev_0 + ) + # now predictor + x_t = x_t_ + if len(D1s) > 0: + # compute the residuals for predictor + for k in range(K - 1): + x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k]) + # now corrector + if use_corrector: + model_t = self.model_fn(x_t, t) + D1_t = (model_t - model_prev_0) + x_t = x_t_ + k = 0 + for k in range(K - 1): + x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1]) + x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1]) + return x_t, model_t + + def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True): + #print(f'using unified predictor-corrector with order {order} (solver type: B(h))') + ns = self.noise_schedule + assert order <= len(model_prev_list) + + # first compute rks + t_prev_0 = t_prev_list[-1] + lambda_prev_0 = ns.marginal_lambda(t_prev_0) + lambda_t = ns.marginal_lambda(t) + model_prev_0 = model_prev_list[-1] + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + alpha_t = torch.exp(log_alpha_t) + + h = lambda_t - lambda_prev_0 + + rks = [] + D1s = [] + for i in range(1, order): + t_prev_i = t_prev_list[-(i + 1)] + model_prev_i = model_prev_list[-(i + 1)] + lambda_prev_i = ns.marginal_lambda(t_prev_i) + rk = (lambda_prev_i - lambda_prev_0) / h + rks.append(rk) + D1s.append((model_prev_i - model_prev_0) / rk) + + rks.append(1.) + rks = torch.tensor(rks, device=x.device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.variant == 'bh1': + B_h = hh + elif self.variant == 'bh2': + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= (i + 1) + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.cat(b) + + # now predictor + use_predictor = len(D1s) > 0 and x_t is None + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + if x_t is None: + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], device=b.device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) + else: + D1s = None + + if use_corrector: + #print('using corrector') + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], device=b.device) + else: + rhos_c = torch.linalg.solve(R, b) + + model_t = None + if self.predict_x0: + x_t_ = ( + sigma_t / sigma_prev_0 * x + - alpha_t * h_phi_1 * model_prev_0 + ) + + if x_t is None: + if use_predictor: + pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + + if use_corrector: + model_t = self.model_fn(x_t, t) + if D1s is not None: + corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = (model_t - model_prev_0) + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = ( + torch.exp(log_alpha_t - log_alpha_prev_0) * x + - sigma_t * h_phi_1 * model_prev_0 + ) + if x_t is None: + if use_predictor: + pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + if use_corrector: + model_t = self.model_fn(x_t, t) + if D1s is not None: + corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = (model_t - model_prev_0) + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + return x_t, model_t + + def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', + method='multistep', lower_order_final=True, denoise_to_zero=False, atol=0.0078, rtol=0.05, return_intermediate=False, + ): + """ + Compute the sample at time `t_end` by UniPC, given the initial `x` at time `t_start`. + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + if return_intermediate: + assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values" + if self.correcting_xt_fn is not None: + assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None" + device = x.device + intermediates = [] + with torch.no_grad(): + if method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + t_prev_list = [t] + model_prev_list = [self.model_fn(x, t)] + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + + # Init the first `order` values by lower order multistep UniPC. + for step in range(1, order): + t = timesteps[step] + x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, t, step, use_corrector=True) + if model_x is None: + model_x = self.model_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + t_prev_list.append(t) + model_prev_list.append(model_x) + + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in range(order, steps + 1): + t = timesteps[step] + if lower_order_final: + step_order = min(order, steps + 1 - step) + else: + step_order = order + if step == steps: + #print('do not run corrector at the last step') + use_corrector = False + else: + use_corrector = True + x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, t, step_order, use_corrector=use_corrector) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = t + # We do not need to evaluate the final model value. + if step < steps: + if model_x is None: + model_x = self.model_fn(x, t) + model_prev_list[-1] = model_x + else: + raise ValueError("Got wrong method {}".format(method)) + + if denoise_to_zero: + t = torch.ones((1,)).to(device) * t_0 + x = self.denoise_to_zero_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step + 1) + if return_intermediate: + intermediates.append(x) + if return_intermediate: + return x, intermediates + else: + return x + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,)*(dims - 1)] \ No newline at end of file diff --git a/diffusion/unit2mel.py b/diffusion/unit2mel.py index 52293b13da8e1afeef6fa5586aeaf01cbcc27fb7..5087f2a512aba1d265d82c644ac6c9859a34d422 100644 --- a/diffusion/unit2mel.py +++ b/diffusion/unit2mel.py @@ -1,11 +1,14 @@ import os -import yaml + +import numpy as np import torch import torch.nn as nn -import numpy as np +import yaml + from .diffusion import GaussianDiffusion -from .wavenet import WaveNet from .vocoder import Vocoder +from .wavenet import WaveNet + class DotDict(dict): def __getattr__(*args): @@ -21,9 +24,11 @@ def load_model_vocoder( device='cpu', config_path = None ): - if config_path is None: config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml') - else: config_file = config_path - + if config_path is None: + config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml') + else: + config_file = config_path + with open(config_file, "r") as config: args = yaml.safe_load(config) args = DotDict(args) @@ -39,13 +44,17 @@ def load_model_vocoder( vocoder.dimension, args.model.n_layers, args.model.n_chans, - args.model.n_hidden) + args.model.n_hidden, + args.model.timesteps, + args.model.k_step_max + ) print(' [Loading] ' + model_path) ckpt = torch.load(model_path, map_location=torch.device(device)) model.to(device) model.load_state_dict(ckpt['model']) model.eval() + print(f'Loaded diffusion model, sampler is {args.infer.method}, speedup: {args.infer.speedup} ') return model, vocoder, args @@ -58,7 +67,10 @@ class Unit2Mel(nn.Module): out_dims=128, n_layers=20, n_chans=384, - n_hidden=256): + n_hidden=256, + timesteps=1000, + k_step_max=1000 + ): super().__init__() self.unit_embed = nn.Linear(input_channel, n_hidden) self.f0_embed = nn.Linear(1, n_hidden) @@ -71,9 +83,12 @@ class Unit2Mel(nn.Module): if n_spk is not None and n_spk > 1: self.spk_embed = nn.Embedding(n_spk, n_hidden) + self.timesteps = timesteps if timesteps is not None else 1000 + self.k_step_max = k_step_max if k_step_max is not None and k_step_max>0 and k_step_maxself.k_step_max: + raise Exception("The shallow diffusion k_step is greater than the maximum diffusion k_step(k_step_max)!") + + if not self.training and gt_spec is None and self.k_step_max!=self.timesteps: + raise Exception("This model can only be used for shallow diffusion and can not infer alone!") + x = self.unit_embed(units) + self.f0_embed((1+ f0 / 700).log()) + self.volume_embed(volume) if self.n_spk is not None and self.n_spk > 1: if spk_mix_dict is not None: diff --git a/diffusion/vocoder.py b/diffusion/vocoder.py index bbaa47f64fd5a3191a24dfaa054c423fa86e5bae..ec9c80e65d160cd9364a7e2106fd0a94a814e3c9 100644 --- a/diffusion/vocoder.py +++ b/diffusion/vocoder.py @@ -1,9 +1,10 @@ import torch -from vdecoder.nsf_hifigan.nvSTFT import STFT -from vdecoder.nsf_hifigan.models import load_model,load_config from torchaudio.transforms import Resample - +from vdecoder.nsf_hifigan.models import load_config, load_model +from vdecoder.nsf_hifigan.nvSTFT import STFT + + class Vocoder: def __init__(self, vocoder_type, vocoder_ckpt, device = None): if device is None: diff --git a/modules/DSConv.py b/modules/DSConv.py new file mode 100644 index 0000000000000000000000000000000000000000..44c2bf60e9cd2b837ca95fb6436768782057014a --- /dev/null +++ b/modules/DSConv.py @@ -0,0 +1,76 @@ +import torch.nn as nn +from torch.nn.utils import remove_weight_norm, weight_norm + + +class Depthwise_Separable_Conv1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride = 1, + padding = 0, + dilation = 1, + bias = True, + padding_mode = 'zeros', # TODO: refine this type + device=None, + dtype=None + ): + super().__init__() + self.depth_conv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels,stride = stride,padding=padding,dilation=dilation,bias=bias,padding_mode=padding_mode,device=device,dtype=dtype) + self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device,dtype=dtype) + + def forward(self, input): + return self.point_conv(self.depth_conv(input)) + + def weight_norm(self): + self.depth_conv = weight_norm(self.depth_conv, name = 'weight') + self.point_conv = weight_norm(self.point_conv, name = 'weight') + + def remove_weight_norm(self): + self.depth_conv = remove_weight_norm(self.depth_conv, name = 'weight') + self.point_conv = remove_weight_norm(self.point_conv, name = 'weight') + +class Depthwise_Separable_TransposeConv1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride = 1, + padding = 0, + output_padding = 0, + bias = True, + dilation = 1, + padding_mode = 'zeros', # TODO: refine this type + device=None, + dtype=None + ): + super().__init__() + self.depth_conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels,stride = stride,output_padding=output_padding,padding=padding,dilation=dilation,bias=bias,padding_mode=padding_mode,device=device,dtype=dtype) + self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device,dtype=dtype) + + def forward(self, input): + return self.point_conv(self.depth_conv(input)) + + def weight_norm(self): + self.depth_conv = weight_norm(self.depth_conv, name = 'weight') + self.point_conv = weight_norm(self.point_conv, name = 'weight') + + def remove_weight_norm(self): + remove_weight_norm(self.depth_conv, name = 'weight') + remove_weight_norm(self.point_conv, name = 'weight') + + +def weight_norm_modules(module, name = 'weight', dim = 0): + if isinstance(module,Depthwise_Separable_Conv1D) or isinstance(module,Depthwise_Separable_TransposeConv1D): + module.weight_norm() + return module + else: + return weight_norm(module,name,dim) + +def remove_weight_norm_modules(module, name = 'weight'): + if isinstance(module,Depthwise_Separable_Conv1D) or isinstance(module,Depthwise_Separable_TransposeConv1D): + module.remove_weight_norm() + else: + remove_weight_norm(module,name) \ No newline at end of file diff --git a/modules/F0Predictor/CrepeF0Predictor.py b/modules/F0Predictor/CrepeF0Predictor.py index e0052881b9b7b3aa373ebf69eb553815a564f610..086ca1079193d5cffb64f3ce5538402ae423b6d3 100644 --- a/modules/F0Predictor/CrepeF0Predictor.py +++ b/modules/F0Predictor/CrepeF0Predictor.py @@ -1,7 +1,9 @@ -from modules.F0Predictor.F0Predictor import F0Predictor -from modules.F0Predictor.crepe import CrepePitchExtractor import torch +from modules.F0Predictor.crepe import CrepePitchExtractor +from modules.F0Predictor.F0Predictor import F0Predictor + + class CrepeF0Predictor(F0Predictor): def __init__(self,hop_length=512,f0_min=50,f0_max=1100,device=None,sampling_rate=44100,threshold=0.05,model="full"): self.F0Creper = CrepePitchExtractor(hop_length=hop_length,f0_min=f0_min,f0_max=f0_max,device=device,threshold=threshold,model=model) diff --git a/modules/F0Predictor/DioF0Predictor.py b/modules/F0Predictor/DioF0Predictor.py index 4ab27de23cae4dbc282e30f84501afebd1a37518..ef470a4c09e424004b119c642edec898f8ea1431 100644 --- a/modules/F0Predictor/DioF0Predictor.py +++ b/modules/F0Predictor/DioF0Predictor.py @@ -1,6 +1,8 @@ -from modules.F0Predictor.F0Predictor import F0Predictor -import pyworld import numpy as np +import pyworld + +from modules.F0Predictor.F0Predictor import F0Predictor + class DioF0Predictor(F0Predictor): def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100): @@ -13,39 +15,25 @@ class DioF0Predictor(F0Predictor): ''' 对F0进行插值处理 ''' + vuv_vector = np.zeros_like(f0, dtype=np.float32) + vuv_vector[f0 > 0.0] = 1.0 + vuv_vector[f0 <= 0.0] = 0.0 - data = np.reshape(f0, (f0.size, 1)) - - vuv_vector = np.zeros((data.size, 1), dtype=np.float32) - vuv_vector[data > 0.0] = 1.0 - vuv_vector[data <= 0.0] = 0.0 - - ip_data = data - - frame_number = data.size - last_value = 0.0 - for i in range(frame_number): - if data[i] <= 0.0: - j = i + 1 - for j in range(i + 1, frame_number): - if data[j] > 0.0: - break - if j < frame_number - 1: - if last_value > 0.0: - step = (data[j] - data[i - 1]) / float(j - i) - for k in range(i, j): - ip_data[k] = data[i - 1] + step * (k - i + 1) - else: - for k in range(i, j): - ip_data[k] = data[j] - else: - for k in range(i, frame_number): - ip_data[k] = last_value - else: - ip_data[i] = data[i] #这里可能存在一个没有必要的拷贝 - last_value = data[i] - - return ip_data[:,0], vuv_vector[:,0] + nzindex = np.nonzero(f0)[0] + data = f0[nzindex] + nzindex = nzindex.astype(np.float32) + time_org = self.hop_length / self.sampling_rate * nzindex + time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate + + if data.shape[0] <= 0: + return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector + + if data.shape[0] == 1: + return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector + + f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1]) + + return f0,vuv_vector def resize_f0(self,x, target_len): source = np.array(x) diff --git a/modules/F0Predictor/HarvestF0Predictor.py b/modules/F0Predictor/HarvestF0Predictor.py index 122bdbb4c736feb4a8d974eca03df71aede76f69..fe279f67add9504d7044cd77a4500fbb39c10b80 100644 --- a/modules/F0Predictor/HarvestF0Predictor.py +++ b/modules/F0Predictor/HarvestF0Predictor.py @@ -1,6 +1,8 @@ -from modules.F0Predictor.F0Predictor import F0Predictor -import pyworld import numpy as np +import pyworld + +from modules.F0Predictor.F0Predictor import F0Predictor + class HarvestF0Predictor(F0Predictor): def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100): @@ -13,40 +15,25 @@ class HarvestF0Predictor(F0Predictor): ''' 对F0进行插值处理 ''' + vuv_vector = np.zeros_like(f0, dtype=np.float32) + vuv_vector[f0 > 0.0] = 1.0 + vuv_vector[f0 <= 0.0] = 0.0 - data = np.reshape(f0, (f0.size, 1)) - - vuv_vector = np.zeros((data.size, 1), dtype=np.float32) - vuv_vector[data > 0.0] = 1.0 - vuv_vector[data <= 0.0] = 0.0 - - ip_data = data - - frame_number = data.size - last_value = 0.0 - for i in range(frame_number): - if data[i] <= 0.0: - j = i + 1 - for j in range(i + 1, frame_number): - if data[j] > 0.0: - break - if j < frame_number - 1: - if last_value > 0.0: - step = (data[j] - data[i - 1]) / float(j - i) - for k in range(i, j): - ip_data[k] = data[i - 1] + step * (k - i + 1) - else: - for k in range(i, j): - ip_data[k] = data[j] - else: - for k in range(i, frame_number): - ip_data[k] = last_value - else: - ip_data[i] = data[i] #这里可能存在一个没有必要的拷贝 - last_value = data[i] - - return ip_data[:,0], vuv_vector[:,0] + nzindex = np.nonzero(f0)[0] + data = f0[nzindex] + nzindex = nzindex.astype(np.float32) + time_org = self.hop_length / self.sampling_rate * nzindex + time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate + if data.shape[0] <= 0: + return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector + + if data.shape[0] == 1: + return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector + + f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1]) + + return f0,vuv_vector def resize_f0(self,x, target_len): source = np.array(x) source[source<0.001] = np.nan diff --git a/modules/F0Predictor/PMF0Predictor.py b/modules/F0Predictor/PMF0Predictor.py index ccf4128436c5b7e5a3e720d4597bad0c622d0920..cb7355f65bb29f68607e027c316038f22c7edddb 100644 --- a/modules/F0Predictor/PMF0Predictor.py +++ b/modules/F0Predictor/PMF0Predictor.py @@ -1,6 +1,8 @@ -from modules.F0Predictor.F0Predictor import F0Predictor -import parselmouth import numpy as np +import parselmouth + +from modules.F0Predictor.F0Predictor import F0Predictor + class PMF0Predictor(F0Predictor): def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100): @@ -14,39 +16,26 @@ class PMF0Predictor(F0Predictor): ''' 对F0进行插值处理 ''' + vuv_vector = np.zeros_like(f0, dtype=np.float32) + vuv_vector[f0 > 0.0] = 1.0 + vuv_vector[f0 <= 0.0] = 0.0 - data = np.reshape(f0, (f0.size, 1)) - - vuv_vector = np.zeros((data.size, 1), dtype=np.float32) - vuv_vector[data > 0.0] = 1.0 - vuv_vector[data <= 0.0] = 0.0 - - ip_data = data - - frame_number = data.size - last_value = 0.0 - for i in range(frame_number): - if data[i] <= 0.0: - j = i + 1 - for j in range(i + 1, frame_number): - if data[j] > 0.0: - break - if j < frame_number - 1: - if last_value > 0.0: - step = (data[j] - data[i - 1]) / float(j - i) - for k in range(i, j): - ip_data[k] = data[i - 1] + step * (k - i + 1) - else: - for k in range(i, j): - ip_data[k] = data[j] - else: - for k in range(i, frame_number): - ip_data[k] = last_value - else: - ip_data[i] = data[i] #这里可能存在一个没有必要的拷贝 - last_value = data[i] + nzindex = np.nonzero(f0)[0] + data = f0[nzindex] + nzindex = nzindex.astype(np.float32) + time_org = self.hop_length / self.sampling_rate * nzindex + time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate + + if data.shape[0] <= 0: + return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector + + if data.shape[0] == 1: + return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector + + f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1]) + + return f0,vuv_vector - return ip_data[:,0], vuv_vector[:,0] def compute_f0(self,wav,p_len=None): x = wav diff --git a/modules/F0Predictor/RMVPEF0Predictor.py b/modules/F0Predictor/RMVPEF0Predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..755f61cbe2782fb5ab3b655871edeb3dc3f575e5 --- /dev/null +++ b/modules/F0Predictor/RMVPEF0Predictor.py @@ -0,0 +1,106 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn.functional as F + +from modules.F0Predictor.F0Predictor import F0Predictor + +from .rmvpe import RMVPE + + +class RMVPEF0Predictor(F0Predictor): + def __init__(self,hop_length=512,f0_min=50,f0_max=1100, dtype=torch.float32, device=None,sampling_rate=44100,threshold=0.05): + self.rmvpe = RMVPE(model_path="pretrain/rmvpe.pt",dtype=dtype,device=device) + self.hop_length = hop_length + self.f0_min = f0_min + self.f0_max = f0_max + if device is None: + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + else: + self.device = device + self.threshold = threshold + self.sampling_rate = sampling_rate + self.dtype = dtype + + def repeat_expand( + self, content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest" + ): + ndim = content.ndim + + if content.ndim == 1: + content = content[None, None] + elif content.ndim == 2: + content = content[None] + + assert content.ndim == 3 + + is_np = isinstance(content, np.ndarray) + if is_np: + content = torch.from_numpy(content) + + results = torch.nn.functional.interpolate(content, size=target_len, mode=mode) + + if is_np: + results = results.numpy() + + if ndim == 1: + return results[0, 0] + elif ndim == 2: + return results[0] + + def post_process(self, x, sampling_rate, f0, pad_to): + if isinstance(f0, np.ndarray): + f0 = torch.from_numpy(f0).float().to(x.device) + + if pad_to is None: + return f0 + + f0 = self.repeat_expand(f0, pad_to) + + vuv_vector = torch.zeros_like(f0) + vuv_vector[f0 > 0.0] = 1.0 + vuv_vector[f0 <= 0.0] = 0.0 + + # 去掉0频率, 并线性插值 + nzindex = torch.nonzero(f0).squeeze() + f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy() + time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy() + time_frame = np.arange(pad_to) * self.hop_length / sampling_rate + + vuv_vector = F.interpolate(vuv_vector[None,None,:],size=pad_to)[0][0] + + if f0.shape[0] <= 0: + return torch.zeros(pad_to, dtype=torch.float, device=x.device),vuv_vector.cpu().numpy() + if f0.shape[0] == 1: + return torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0],vuv_vector.cpu().numpy() + + # 大概可以用 torch 重写? + f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1]) + #vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0)) + + return f0,vuv_vector.cpu().numpy() + + def compute_f0(self,wav,p_len=None): + x = torch.FloatTensor(wav).to(self.dtype).to(self.device) + if p_len is None: + p_len = x.shape[0]//self.hop_length + else: + assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error" + f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold) + if torch.all(f0 == 0): + rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len) + return rtn,rtn + return self.post_process(x,self.sampling_rate,f0,p_len)[0] + + def compute_f0_uv(self,wav,p_len=None): + x = torch.FloatTensor(wav).to(self.dtype).to(self.device) + if p_len is None: + p_len = x.shape[0]//self.hop_length + else: + assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error" + f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold) + if torch.all(f0 == 0): + rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len) + return rtn,rtn + return self.post_process(x,self.sampling_rate,f0,p_len) \ No newline at end of file diff --git a/modules/F0Predictor/__pycache__/CrepeF0Predictor.cpython-38.pyc b/modules/F0Predictor/__pycache__/CrepeF0Predictor.cpython-38.pyc index 36eeae3d42ce629baa24c429edf1a18b607c78be..3c7cc4079deb650be3b85b2e5306264b88a0d824 100644 Binary files a/modules/F0Predictor/__pycache__/CrepeF0Predictor.cpython-38.pyc and b/modules/F0Predictor/__pycache__/CrepeF0Predictor.cpython-38.pyc differ diff --git a/modules/F0Predictor/__pycache__/F0Predictor.cpython-38.pyc b/modules/F0Predictor/__pycache__/F0Predictor.cpython-38.pyc index 99cf916f678e0415f3e7cbc125ff769abb5d7a2c..7aa6f894af610362df9f828e836a227d8e650ce6 100644 Binary files a/modules/F0Predictor/__pycache__/F0Predictor.cpython-38.pyc and b/modules/F0Predictor/__pycache__/F0Predictor.cpython-38.pyc differ diff --git a/modules/F0Predictor/__pycache__/HarvestF0Predictor.cpython-38.pyc b/modules/F0Predictor/__pycache__/HarvestF0Predictor.cpython-38.pyc index c7dbadab31a99648ff751358a3aeb50080798fb2..fd8e8107611713da5b4214335ccd75cba1224036 100644 Binary files a/modules/F0Predictor/__pycache__/HarvestF0Predictor.cpython-38.pyc and b/modules/F0Predictor/__pycache__/HarvestF0Predictor.cpython-38.pyc differ diff --git a/modules/F0Predictor/__pycache__/PMF0Predictor.cpython-38.pyc b/modules/F0Predictor/__pycache__/PMF0Predictor.cpython-38.pyc index 2b910b23ec2825458cf18ca1c874c6996cbc7ec8..3c60990a7838c5fcc279482188b72a3d2a4ccb96 100644 Binary files a/modules/F0Predictor/__pycache__/PMF0Predictor.cpython-38.pyc and b/modules/F0Predictor/__pycache__/PMF0Predictor.cpython-38.pyc differ diff --git a/modules/F0Predictor/__pycache__/RMVPEF0Predictor.cpython-38.pyc b/modules/F0Predictor/__pycache__/RMVPEF0Predictor.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e53c9818364c152c6101a96790efc58ddd9146be Binary files /dev/null and b/modules/F0Predictor/__pycache__/RMVPEF0Predictor.cpython-38.pyc differ diff --git a/modules/F0Predictor/__pycache__/__init__.cpython-38.pyc b/modules/F0Predictor/__pycache__/__init__.cpython-38.pyc index 3af86b5539399d7054467a138b10809211537452..2c232cdfd9b79f01456edfdaf11fd074078026e1 100644 Binary files a/modules/F0Predictor/__pycache__/__init__.cpython-38.pyc and b/modules/F0Predictor/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/F0Predictor/__pycache__/crepe.cpython-38.pyc b/modules/F0Predictor/__pycache__/crepe.cpython-38.pyc index ce4d2539c6a5e2fd366fcfdae9b7efd80988d4d4..3410f2ff049d17da1fd6d1f218e95629139fba24 100644 Binary files a/modules/F0Predictor/__pycache__/crepe.cpython-38.pyc and b/modules/F0Predictor/__pycache__/crepe.cpython-38.pyc differ diff --git a/modules/F0Predictor/crepe.py b/modules/F0Predictor/crepe.py index c6fb45c79bcd306202a2c0282b3d73a8074ced5d..e68f19cb39eb79931926ffd312fb61e30bf39d72 100644 --- a/modules/F0Predictor/crepe.py +++ b/modules/F0Predictor/crepe.py @@ -1,14 +1,14 @@ -from typing import Optional,Union +from typing import Optional, Union + try: from typing import Literal -except Exception as e: +except Exception: from typing_extensions import Literal import numpy as np import torch import torchcrepe from torch import nn from torch.nn import functional as F -import scipy #from:https://github.com/fishaudio/fish-diffusion @@ -97,19 +97,19 @@ class BasePitchExtractor: f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy() time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy() time_frame = np.arange(pad_to) * self.hop_length / sampling_rate + + vuv_vector = F.interpolate(vuv_vector[None,None,:],size=pad_to)[0][0] if f0.shape[0] <= 0: - return torch.zeros(pad_to, dtype=torch.float, device=x.device),torch.zeros(pad_to, dtype=torch.float, device=x.device) - + return torch.zeros(pad_to, dtype=torch.float, device=x.device),vuv_vector.cpu().numpy() if f0.shape[0] == 1: - return torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0],torch.ones(pad_to, dtype=torch.float, device=x.device) + return torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0],vuv_vector.cpu().numpy() # 大概可以用 torch 重写? f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1]) - vuv_vector = vuv_vector.cpu().numpy() - vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0)) + #vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0)) - return f0,vuv_vector + return f0,vuv_vector.cpu().numpy() class MaskedAvgPool1d(nn.Module): @@ -323,7 +323,7 @@ class CrepePitchExtractor(BasePitchExtractor): else: pd = torchcrepe.filter.median(pd, 3) - pd = torchcrepe.threshold.Silence(-60.0)(pd, x, sampling_rate, 512) + pd = torchcrepe.threshold.Silence(-60.0)(pd, x, sampling_rate, self.hop_length) f0 = torchcrepe.threshold.At(self.threshold)(f0, pd) if self.use_fast_filters: @@ -334,7 +334,7 @@ class CrepePitchExtractor(BasePitchExtractor): f0 = torch.where(torch.isnan(f0), torch.full_like(f0, 0), f0)[0] if torch.all(f0 == 0): - rtn = f0.cpu().numpy() if pad_to==None else np.zeros(pad_to) + rtn = f0.cpu().numpy() if pad_to is None else np.zeros(pad_to) return rtn,rtn return self.post_process(x, sampling_rate, f0, pad_to) diff --git a/modules/F0Predictor/rmvpe/__init__.py b/modules/F0Predictor/rmvpe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e2dcf9e971ac4fcea29fe2e312d591fd0447f95d --- /dev/null +++ b/modules/F0Predictor/rmvpe/__init__.py @@ -0,0 +1,10 @@ +from .constants import * # noqa: F403 +from .inference import RMVPE # noqa: F401 +from .model import E2E, E2E0 # noqa: F401 +from .spec import MelSpectrogram # noqa: F401 +from .utils import ( # noqa: F401 + cycle, + summary, + to_local_average_cents, + to_viterbi_cents, +) diff --git a/modules/F0Predictor/rmvpe/__pycache__/__init__.cpython-38.pyc b/modules/F0Predictor/rmvpe/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25af78105f306be398edcbd191031417e1837971 Binary files /dev/null and b/modules/F0Predictor/rmvpe/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/F0Predictor/rmvpe/__pycache__/constants.cpython-38.pyc b/modules/F0Predictor/rmvpe/__pycache__/constants.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d96cc67b443c88b0af72824ae374b8831e9f5676 Binary files /dev/null and b/modules/F0Predictor/rmvpe/__pycache__/constants.cpython-38.pyc differ diff --git a/modules/F0Predictor/rmvpe/__pycache__/deepunet.cpython-38.pyc b/modules/F0Predictor/rmvpe/__pycache__/deepunet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f26b0d71641c2cf5a3ef7011bb5c1e511c3f5aa Binary files /dev/null and b/modules/F0Predictor/rmvpe/__pycache__/deepunet.cpython-38.pyc differ diff --git a/modules/F0Predictor/rmvpe/__pycache__/inference.cpython-38.pyc b/modules/F0Predictor/rmvpe/__pycache__/inference.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..712c7b1b667996dbe1677631dd5ba71019b4627f Binary files /dev/null and b/modules/F0Predictor/rmvpe/__pycache__/inference.cpython-38.pyc differ diff --git a/modules/F0Predictor/rmvpe/__pycache__/model.cpython-38.pyc b/modules/F0Predictor/rmvpe/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..309d590c622f7885aaffedc1172f1311a6dff2a4 Binary files /dev/null and b/modules/F0Predictor/rmvpe/__pycache__/model.cpython-38.pyc differ diff --git a/modules/F0Predictor/rmvpe/__pycache__/seq.cpython-38.pyc b/modules/F0Predictor/rmvpe/__pycache__/seq.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e23ccb86b40606f22d83bdc7bf79d6a21b7a384d Binary files /dev/null and b/modules/F0Predictor/rmvpe/__pycache__/seq.cpython-38.pyc differ diff --git a/modules/F0Predictor/rmvpe/__pycache__/spec.cpython-38.pyc b/modules/F0Predictor/rmvpe/__pycache__/spec.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c4ebabcc44e30190d27f620907a2dd1e7992688 Binary files /dev/null and b/modules/F0Predictor/rmvpe/__pycache__/spec.cpython-38.pyc differ diff --git a/modules/F0Predictor/rmvpe/__pycache__/utils.cpython-38.pyc b/modules/F0Predictor/rmvpe/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc629f7acc2008564401ab15f3ebceac9eb94a4b Binary files /dev/null and b/modules/F0Predictor/rmvpe/__pycache__/utils.cpython-38.pyc differ diff --git a/modules/F0Predictor/rmvpe/constants.py b/modules/F0Predictor/rmvpe/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..c5f52efc9b40f49bb746dae6807a817bffce4375 --- /dev/null +++ b/modules/F0Predictor/rmvpe/constants.py @@ -0,0 +1,9 @@ +SAMPLE_RATE = 16000 + +N_CLASS = 360 + +N_MELS = 128 +MEL_FMIN = 30 +MEL_FMAX = SAMPLE_RATE // 2 +WINDOW_LENGTH = 1024 +CONST = 1997.3794084376191 diff --git a/modules/F0Predictor/rmvpe/deepunet.py b/modules/F0Predictor/rmvpe/deepunet.py new file mode 100644 index 0000000000000000000000000000000000000000..b0171d562ac58526c7693a15124e181c78ad0a18 --- /dev/null +++ b/modules/F0Predictor/rmvpe/deepunet.py @@ -0,0 +1,190 @@ +import torch +import torch.nn as nn + +from .constants import N_MELS + + +class ConvBlockRes(nn.Module): + def __init__(self, in_channels, out_channels, momentum=0.01): + super(ConvBlockRes, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + + nn.Conv2d(in_channels=out_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + ) + if in_channels != out_channels: + self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1)) + self.is_shortcut = True + else: + self.is_shortcut = False + + def forward(self, x): + if self.is_shortcut: + return self.conv(x) + self.shortcut(x) + else: + return self.conv(x) + x + + +class ResEncoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01): + super(ResEncoderBlock, self).__init__() + self.n_blocks = n_blocks + self.conv = nn.ModuleList() + self.conv.append(ConvBlockRes(in_channels, out_channels, momentum)) + for i in range(n_blocks - 1): + self.conv.append(ConvBlockRes(out_channels, out_channels, momentum)) + self.kernel_size = kernel_size + if self.kernel_size is not None: + self.pool = nn.AvgPool2d(kernel_size=kernel_size) + + def forward(self, x): + for i in range(self.n_blocks): + x = self.conv[i](x) + if self.kernel_size is not None: + return x, self.pool(x) + else: + return x + + +class ResDecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01): + super(ResDecoderBlock, self).__init__() + out_padding = (0, 1) if stride == (1, 2) else (1, 1) + self.n_blocks = n_blocks + self.conv1 = nn.Sequential( + nn.ConvTranspose2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=stride, + padding=(1, 1), + output_padding=out_padding, + bias=False), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + ) + self.conv2 = nn.ModuleList() + self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum)) + for i in range(n_blocks-1): + self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum)) + + def forward(self, x, concat_tensor): + x = self.conv1(x) + x = torch.cat((x, concat_tensor), dim=1) + for i in range(self.n_blocks): + x = self.conv2[i](x) + return x + + +class Encoder(nn.Module): + def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01): + super(Encoder, self).__init__() + self.n_encoders = n_encoders + self.bn = nn.BatchNorm2d(in_channels, momentum=momentum) + self.layers = nn.ModuleList() + self.latent_channels = [] + for i in range(self.n_encoders): + self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum)) + self.latent_channels.append([out_channels, in_size]) + in_channels = out_channels + out_channels *= 2 + in_size //= 2 + self.out_size = in_size + self.out_channel = out_channels + + def forward(self, x): + concat_tensors = [] + x = self.bn(x) + for i in range(self.n_encoders): + _, x = self.layers[i](x) + concat_tensors.append(_) + return x, concat_tensors + + +class Intermediate(nn.Module): + def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01): + super(Intermediate, self).__init__() + self.n_inters = n_inters + self.layers = nn.ModuleList() + self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)) + for i in range(self.n_inters-1): + self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)) + + def forward(self, x): + for i in range(self.n_inters): + x = self.layers[i](x) + return x + + +class Decoder(nn.Module): + def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01): + super(Decoder, self).__init__() + self.layers = nn.ModuleList() + self.n_decoders = n_decoders + for i in range(self.n_decoders): + out_channels = in_channels // 2 + self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)) + in_channels = out_channels + + def forward(self, x, concat_tensors): + for i in range(self.n_decoders): + x = self.layers[i](x, concat_tensors[-1-i]) + return x + + +class TimbreFilter(nn.Module): + def __init__(self, latent_rep_channels): + super(TimbreFilter, self).__init__() + self.layers = nn.ModuleList() + for latent_rep in latent_rep_channels: + self.layers.append(ConvBlockRes(latent_rep[0], latent_rep[0])) + + def forward(self, x_tensors): + out_tensors = [] + for i, layer in enumerate(self.layers): + out_tensors.append(layer(x_tensors[i])) + return out_tensors + + +class DeepUnet(nn.Module): + def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): + super(DeepUnet, self).__init__() + self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels) + self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks) + self.tf = TimbreFilter(self.encoder.latent_channels) + self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks) + + def forward(self, x): + x, concat_tensors = self.encoder(x) + x = self.intermediate(x) + concat_tensors = self.tf(concat_tensors) + x = self.decoder(x, concat_tensors) + return x + + +class DeepUnet0(nn.Module): + def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): + super(DeepUnet0, self).__init__() + self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels) + self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks) + self.tf = TimbreFilter(self.encoder.latent_channels) + self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks) + + def forward(self, x): + x, concat_tensors = self.encoder(x) + x = self.intermediate(x) + x = self.decoder(x, concat_tensors) + return x diff --git a/modules/F0Predictor/rmvpe/inference.py b/modules/F0Predictor/rmvpe/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..6beac87414ef19f47c7af6e730797aef5aea8503 --- /dev/null +++ b/modules/F0Predictor/rmvpe/inference.py @@ -0,0 +1,57 @@ +import torch +import torch.nn.functional as F +from torchaudio.transforms import Resample + +from .constants import * # noqa: F403 +from .model import E2E0 +from .spec import MelSpectrogram +from .utils import to_local_average_cents, to_viterbi_cents + + +class RMVPE: + def __init__(self, model_path, device=None, dtype = torch.float32, hop_length=160): + self.resample_kernel = {} + if device is None: + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + else: + self.device = device + model = E2E0(4, 1, (2, 2)) + ckpt = torch.load(model_path) + model.load_state_dict(ckpt['model']) + model = model.to(dtype).to(self.device) + model.eval() + self.model = model + self.dtype = dtype + self.mel_extractor = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX) # noqa: F405 + self.resample_kernel = {} + + def mel2hidden(self, mel): + with torch.no_grad(): + n_frames = mel.shape[-1] + mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='reflect') + hidden = self.model(mel) + return hidden[:, :n_frames] + + def decode(self, hidden, thred=0.03, use_viterbi=False): + if use_viterbi: + cents_pred = to_viterbi_cents(hidden, thred=thred) + else: + cents_pred = to_local_average_cents(hidden, thred=thred) + f0 = torch.Tensor([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred]).to(self.device) + return f0 + + def infer_from_audio(self, audio, sample_rate=16000, thred=0.05, use_viterbi=False): + audio = audio.unsqueeze(0).to(self.dtype).to(self.device) + if sample_rate == 16000: + audio_res = audio + else: + key_str = str(sample_rate) + if key_str not in self.resample_kernel: + self.resample_kernel[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128) + self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.dtype).to(self.device) + audio_res = self.resample_kernel[key_str](audio) + mel_extractor = self.mel_extractor.to(self.device) + mel = mel_extractor(audio_res, center=True).to(self.dtype) + hidden = self.mel2hidden(mel) + f0 = self.decode(hidden.squeeze(0), thred=thred, use_viterbi=use_viterbi) + return f0 \ No newline at end of file diff --git a/modules/F0Predictor/rmvpe/model.py b/modules/F0Predictor/rmvpe/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f1b6b643b113a0eee9a9142016c15444273002c5 --- /dev/null +++ b/modules/F0Predictor/rmvpe/model.py @@ -0,0 +1,67 @@ +from torch import nn + +from .constants import * # noqa: F403 +from .deepunet import DeepUnet, DeepUnet0 +from .seq import BiGRU +from .spec import MelSpectrogram + + +class E2E(nn.Module): + def __init__(self, hop_length, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, + en_out_channels=16): + super(E2E, self).__init__() + self.mel = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX) # noqa: F405 + self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels) + self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) + if n_gru: + self.fc = nn.Sequential( + BiGRU(3 * N_MELS, 256, n_gru), # noqa: F405 + nn.Linear(512, N_CLASS), # noqa: F405 + nn.Dropout(0.25), + nn.Sigmoid() + ) + else: + self.fc = nn.Sequential( + nn.Linear(3 * N_MELS, N_CLASS), # noqa: F405 + nn.Dropout(0.25), + nn.Sigmoid() + ) + + def forward(self, x): + mel = self.mel(x.reshape(-1, x.shape[-1])).transpose(-1, -2).unsqueeze(1) + x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) + # x = self.fc(x) + hidden_vec = 0 + if len(self.fc) == 4: + for i in range(len(self.fc)): + x = self.fc[i](x) + if i == 0: + hidden_vec = x + return hidden_vec, x + + +class E2E0(nn.Module): + def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, + en_out_channels=16): + super(E2E0, self).__init__() + self.unet = DeepUnet0(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels) + self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) + if n_gru: + self.fc = nn.Sequential( + BiGRU(3 * N_MELS, 256, n_gru), # noqa: F405 + nn.Linear(512, N_CLASS), # noqa: F405 + nn.Dropout(0.25), + nn.Sigmoid() + ) + else: + self.fc = nn.Sequential( + nn.Linear(3 * N_MELS, N_CLASS), # noqa: F405 + nn.Dropout(0.25), + nn.Sigmoid() + ) + + def forward(self, mel): + mel = mel.transpose(-1, -2).unsqueeze(1) + x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) + x = self.fc(x) + return x diff --git a/modules/F0Predictor/rmvpe/seq.py b/modules/F0Predictor/rmvpe/seq.py new file mode 100644 index 0000000000000000000000000000000000000000..0d48e49d72e14d34f048ca0b5824ea1f335e9a0d --- /dev/null +++ b/modules/F0Predictor/rmvpe/seq.py @@ -0,0 +1,20 @@ +import torch.nn as nn + + +class BiGRU(nn.Module): + def __init__(self, input_features, hidden_features, num_layers): + super(BiGRU, self).__init__() + self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True) + + def forward(self, x): + return self.gru(x)[0] + + +class BiLSTM(nn.Module): + def __init__(self, input_features, hidden_features, num_layers): + super(BiLSTM, self).__init__() + self.lstm = nn.LSTM(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True) + + def forward(self, x): + return self.lstm(x)[0] + diff --git a/modules/F0Predictor/rmvpe/spec.py b/modules/F0Predictor/rmvpe/spec.py new file mode 100644 index 0000000000000000000000000000000000000000..349d05e4541ccad31cbbb24372a89cda7c0aacc0 --- /dev/null +++ b/modules/F0Predictor/rmvpe/spec.py @@ -0,0 +1,67 @@ +import numpy as np +import torch +import torch.nn.functional as F +from librosa.filters import mel + + +class MelSpectrogram(torch.nn.Module): + def __init__( + self, + n_mel_channels, + sampling_rate, + win_length, + hop_length, + n_fft=None, + mel_fmin=0, + mel_fmax=None, + clamp = 1e-5 + ): + super().__init__() + n_fft = win_length if n_fft is None else n_fft + self.hann_window = {} + mel_basis = mel( + sr=sampling_rate, + n_fft=n_fft, + n_mels=n_mel_channels, + fmin=mel_fmin, + fmax=mel_fmax, + htk=True) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer("mel_basis", mel_basis) + self.n_fft = win_length if n_fft is None else n_fft + self.hop_length = hop_length + self.win_length = win_length + self.sampling_rate = sampling_rate + self.n_mel_channels = n_mel_channels + self.clamp = clamp + + def forward(self, audio, keyshift=0, speed=1, center=True): + factor = 2 ** (keyshift / 12) + n_fft_new = int(np.round(self.n_fft * factor)) + win_length_new = int(np.round(self.win_length * factor)) + hop_length_new = int(np.round(self.hop_length * speed)) + + keyshift_key = str(keyshift)+'_'+str(audio.device) + if keyshift_key not in self.hann_window: + self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device) + + fft = torch.stft( + audio, + n_fft=n_fft_new, + hop_length=hop_length_new, + win_length=win_length_new, + window=self.hann_window[keyshift_key], + center=center, + return_complex=True) + magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2)) + + if keyshift != 0: + size = self.n_fft // 2 + 1 + resize = magnitude.size(1) + if resize < size: + magnitude = F.pad(magnitude, (0, 0, 0, size-resize)) + magnitude = magnitude[:, :size, :] * self.win_length / win_length_new + + mel_output = torch.matmul(self.mel_basis, magnitude) + log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp)) + return log_mel_spec \ No newline at end of file diff --git a/modules/F0Predictor/rmvpe/utils.py b/modules/F0Predictor/rmvpe/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4395255f8608da2bce0b1f15d6bd2b2bd02c1fe7 --- /dev/null +++ b/modules/F0Predictor/rmvpe/utils.py @@ -0,0 +1,107 @@ +import sys +from functools import reduce + +import librosa +import numpy as np +import torch +from torch.nn.modules.module import _addindent + +from .constants import * # noqa: F403 + + +def cycle(iterable): + while True: + for item in iterable: + yield item + + +def summary(model, file=sys.stdout): + def repr(model): + # We treat the extra repr like the sub-module, one item per line + extra_lines = [] + extra_repr = model.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split('\n') + child_lines = [] + total_params = 0 + for key, module in model._modules.items(): + mod_str, num_params = repr(module) + mod_str = _addindent(mod_str, 2) + child_lines.append('(' + key + '): ' + mod_str) + total_params += num_params + lines = extra_lines + child_lines + + for name, p in model._parameters.items(): + if hasattr(p, 'shape'): + total_params += reduce(lambda x, y: x * y, p.shape) + + main_str = model._get_name() + '(' + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += '\n ' + '\n '.join(lines) + '\n' + + main_str += ')' + if file is sys.stdout: + main_str += ', \033[92m{:,}\033[0m params'.format(total_params) + else: + main_str += ', {:,} params'.format(total_params) + return main_str, total_params + + string, count = repr(model) + if file is not None: + if isinstance(file, str): + file = open(file, 'w') + print(string, file=file) + file.flush() + + return count + + +def to_local_average_cents(salience, center=None, thred=0.05): + """ + find the weighted average cents near the argmax bin + """ + + if not hasattr(to_local_average_cents, 'cents_mapping'): + # the bin number-to-cents mapping + to_local_average_cents.cents_mapping = ( + 20 * torch.arange(N_CLASS) + CONST).to(salience.device) # noqa: F405 + + if salience.ndim == 1: + if center is None: + center = int(torch.argmax(salience)) + start = max(0, center - 4) + end = min(len(salience), center + 5) + salience = salience[start:end] + product_sum = torch.sum( + salience * to_local_average_cents.cents_mapping[start:end]) + weight_sum = torch.sum(salience) + return product_sum / weight_sum if torch.max(salience) > thred else 0 + if salience.ndim == 2: + return torch.Tensor([to_local_average_cents(salience[i, :], None, thred) for i in + range(salience.shape[0])]).to(salience.device) + + raise Exception("label should be either 1d or 2d ndarray") + +def to_viterbi_cents(salience, thred=0.05): + # Create viterbi transition matrix + if not hasattr(to_viterbi_cents, 'transition'): + xx, yy = torch.meshgrid(range(N_CLASS), range(N_CLASS)) # noqa: F405 + transition = torch.maximum(30 - abs(xx - yy), 0) + transition = transition / transition.sum(axis=1, keepdims=True) + to_viterbi_cents.transition = transition + + # Convert to probability + prob = salience.T + prob = prob / prob.sum(axis=0) + + # Perform viterbi decoding + path = librosa.sequence.viterbi(prob.detach().cpu().numpy(), to_viterbi_cents.transition).astype(np.int64) + + return torch.Tensor([to_local_average_cents(salience[i, :], path[i], thred) for i in + range(len(path))]).to(salience.device) + \ No newline at end of file diff --git a/modules/__pycache__/DSConv.cpython-38.pyc b/modules/__pycache__/DSConv.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..875399162a47dedbe5626d98efbfc64d65d8f8eb Binary files /dev/null and b/modules/__pycache__/DSConv.cpython-38.pyc differ diff --git a/modules/__pycache__/__init__.cpython-38.pyc b/modules/__pycache__/__init__.cpython-38.pyc index edc72da2db11334f14dfa33aa038fe6b5a927308..d091ce988ce17c7ef56089c355519c7c9ef8e365 100644 Binary files a/modules/__pycache__/__init__.cpython-38.pyc and b/modules/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/__pycache__/attentions.cpython-38.pyc b/modules/__pycache__/attentions.cpython-38.pyc index 01b358fe74b8677e530154d2ceab92228e850042..675e61ad251a31e45d97dd2a3f6f4682484fdea2 100644 Binary files a/modules/__pycache__/attentions.cpython-38.pyc and b/modules/__pycache__/attentions.cpython-38.pyc differ diff --git a/modules/__pycache__/commons.cpython-38.pyc b/modules/__pycache__/commons.cpython-38.pyc index 662499ad13055febd7204dadf96ffdc881b3dc51..47055f9075887cd3ebc8f33f36078dc12b222c86 100644 Binary files a/modules/__pycache__/commons.cpython-38.pyc and b/modules/__pycache__/commons.cpython-38.pyc differ diff --git a/modules/__pycache__/enhancer.cpython-38.pyc b/modules/__pycache__/enhancer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81dc2c4cc439aea823ec2a621cc69d7189ee6bc8 Binary files /dev/null and b/modules/__pycache__/enhancer.cpython-38.pyc differ diff --git a/modules/__pycache__/losses.cpython-38.pyc b/modules/__pycache__/losses.cpython-38.pyc index 948f24596e00f5b257dbb6476655806c72040441..7551add1e5d4f393990871f3949cde7c43158a94 100644 Binary files a/modules/__pycache__/losses.cpython-38.pyc and b/modules/__pycache__/losses.cpython-38.pyc differ diff --git a/modules/__pycache__/mel_processing.cpython-38.pyc b/modules/__pycache__/mel_processing.cpython-38.pyc index 024e33b96c199c64537713a7a47482bed5a8e7a8..f53ec02ee571276085e4f6cf2300aa9fbcdd8326 100644 Binary files a/modules/__pycache__/mel_processing.cpython-38.pyc and b/modules/__pycache__/mel_processing.cpython-38.pyc differ diff --git a/modules/__pycache__/modules.cpython-38.pyc b/modules/__pycache__/modules.cpython-38.pyc index 01b18a639b6b6e785daa34fb18ae0d0d6d4742cf..238f73019ca68c5905baf542d49bc0b216bbdf8e 100644 Binary files a/modules/__pycache__/modules.cpython-38.pyc and b/modules/__pycache__/modules.cpython-38.pyc differ diff --git a/modules/attentions.py b/modules/attentions.py index f9c11ca4a3acb86bf1abc04d9dcfa82a4ed4061f..9086e0ed5944fbf096429e1ee37dc26eec81f9a3 100644 --- a/modules/attentions.py +++ b/modules/attentions.py @@ -1,12 +1,10 @@ -import copy import math -import numpy as np + import torch from torch import nn from torch.nn import functional as F import modules.commons as commons -import modules.modules as modules from modules.modules import LayerNorm @@ -243,7 +241,7 @@ class MultiHeadAttention(nn.Module): return ret def _get_relative_embeddings(self, relative_embeddings, length): - max_relative_position = 2 * self.window_size + 1 + 2 * self.window_size + 1 # Pad first before slice to avoid using cond ops. pad_length = max(length - (self.window_size + 1), 0) slice_start_position = max((self.window_size + 1) - length, 0) diff --git a/modules/commons.py b/modules/commons.py index 074888006392e956ce204d8368362dbb2cd4e304..761379da55793b7f2eca1c9ba511ec767ac1d90e 100644 --- a/modules/commons.py +++ b/modules/commons.py @@ -1,9 +1,9 @@ import math -import numpy as np + import torch -from torch import nn from torch.nn import functional as F + def slice_pitch_segments(x, ids_str, segment_size=4): ret = torch.zeros_like(x[:, :segment_size]) for i in range(x.size(0)): @@ -24,10 +24,12 @@ def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4): def init_weights(m, mean=0.0, std=0.01): classname = m.__class__.__name__ - if classname.find("Conv") != -1: + if "Depthwise_Separable" in classname: + m.depth_conv.weight.data.normal_(mean, std) + m.point_conv.weight.data.normal_(mean, std) + elif classname.find("Conv") != -1: m.weight.data.normal_(mean, std) - def get_padding(kernel_size, dilation=1): return int((kernel_size*dilation - dilation)/2) @@ -134,12 +136,6 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): return acts -def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape - - def shift_1d(x): x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] return x @@ -157,7 +153,6 @@ def generate_path(duration, mask): duration: [b, 1, t_x] mask: [b, 1, t_y, t_x] """ - device = duration.device b, _, t_y, t_x = mask.shape cum_duration = torch.cumsum(duration, -1) diff --git a/modules/enhancer.py b/modules/enhancer.py index 37676311f7d8dc4ddc2a5244dedc27b2437e04f5..a3f0dd0460ff6d6153f9277dfa90763bc03861db 100644 --- a/modules/enhancer.py +++ b/modules/enhancer.py @@ -1,10 +1,12 @@ import numpy as np import torch import torch.nn.functional as F -from vdecoder.nsf_hifigan.nvSTFT import STFT -from vdecoder.nsf_hifigan.models import load_model from torchaudio.transforms import Resample +from vdecoder.nsf_hifigan.models import load_model +from vdecoder.nsf_hifigan.nvSTFT import STFT + + class Enhancer: def __init__(self, enhancer_type, enhancer_ckpt, device=None): if device is None: diff --git a/modules/losses.py b/modules/losses.py index cd21799eccde350c3aac0bdd661baf96ed220147..494e979a60ba069114cac609bf6454a99c1019e3 100644 --- a/modules/losses.py +++ b/modules/losses.py @@ -1,7 +1,4 @@ -import torch -from torch.nn import functional as F - -import modules.commons as commons +import torch def feature_loss(fmap_r, fmap_g): diff --git a/modules/mel_processing.py b/modules/mel_processing.py index 99c5b35beb83f3b288af0fac5b49ebf2c69f062c..c21e4bffb6d9f5fd7b45a84176b3e6206f7d29db 100644 --- a/modules/mel_processing.py +++ b/modules/mel_processing.py @@ -1,16 +1,5 @@ -import math -import os -import random import torch -from torch import nn -import torch.nn.functional as F import torch.utils.data -import numpy as np -import librosa -import librosa.util as librosa_util -from librosa.util import normalize, pad_center, tiny -from scipy.signal import get_window -from scipy.io.wavfile import read from librosa.filters import mel as librosa_mel_fn MAX_WAV_VALUE = 32768.0 @@ -62,9 +51,14 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False) y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') y = y.squeeze(1) + + y_dtype = y.dtype + if y.dtype == torch.bfloat16: + y = y.to(torch.float32) spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], - center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) + spec = torch.view_as_real(spec).to(y_dtype) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) return spec @@ -83,30 +77,7 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): - if torch.min(y) < -1.: - print('min value is ', torch.min(y)) - if torch.max(y) > 1.: - print('max value is ', torch.max(y)) - - global mel_basis, hann_window - dtype_device = str(y.dtype) + '_' + str(y.device) - fmax_dtype_device = str(fmax) + '_' + dtype_device - wnsize_dtype_device = str(win_size) + '_' + dtype_device - if fmax_dtype_device not in mel_basis: - mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) - mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) - if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) - - y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') - y = y.squeeze(1) - - spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], - center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) - - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - - spec = torch.matmul(mel_basis[fmax_dtype_device], spec) - spec = spectral_normalize_torch(spec) - + spec = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center) + spec = spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax) + return spec diff --git a/modules/modules.py b/modules/modules.py index 54290fd207b25e93831bd21005990ea137e6b50e..2b9ad903027de09c0f2393ca3f8341bbba11c9a5 100644 --- a/modules/modules.py +++ b/modules/modules.py @@ -1,20 +1,23 @@ -import copy -import math -import numpy as np -import scipy import torch from torch import nn from torch.nn import functional as F -from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d -from torch.nn.utils import weight_norm, remove_weight_norm - import modules.commons as commons -from modules.commons import init_weights, get_padding - +from modules.commons import get_padding, init_weights +from modules.DSConv import ( + Depthwise_Separable_Conv1D, + remove_weight_norm_modules, + weight_norm_modules, +) LRELU_SLOPE = 0.1 +Conv1dModel = nn.Conv1d + +def set_Conv1dModel(use_depthwise_conv): + global Conv1dModel + Conv1dModel = Depthwise_Separable_Conv1D if use_depthwise_conv else nn.Conv1d + class LayerNorm(nn.Module): def __init__(self, channels, eps=1e-5): @@ -44,13 +47,13 @@ class ConvReluNorm(nn.Module): self.conv_layers = nn.ModuleList() self.norm_layers = nn.ModuleList() - self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) + self.conv_layers.append(Conv1dModel(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) self.norm_layers.append(LayerNorm(hidden_channels)) self.relu_drop = nn.Sequential( nn.ReLU(), nn.Dropout(p_dropout)) for _ in range(n_layers-1): - self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)) + self.conv_layers.append(Conv1dModel(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)) self.norm_layers.append(LayerNorm(hidden_channels)) self.proj = nn.Conv1d(hidden_channels, out_channels, 1) self.proj.weight.data.zero_() @@ -66,47 +69,6 @@ class ConvReluNorm(nn.Module): return x * x_mask -class DDSConv(nn.Module): - """ - Dialted and Depth-Separable Convolution - """ - def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): - super().__init__() - self.channels = channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.p_dropout = p_dropout - - self.drop = nn.Dropout(p_dropout) - self.convs_sep = nn.ModuleList() - self.convs_1x1 = nn.ModuleList() - self.norms_1 = nn.ModuleList() - self.norms_2 = nn.ModuleList() - for i in range(n_layers): - dilation = kernel_size ** i - padding = (kernel_size * dilation - dilation) // 2 - self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, - groups=channels, dilation=dilation, padding=padding - )) - self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) - self.norms_1.append(LayerNorm(channels)) - self.norms_2.append(LayerNorm(channels)) - - def forward(self, x, x_mask, g=None): - if g is not None: - x = x + g - for i in range(self.n_layers): - y = self.convs_sep[i](x * x_mask) - y = self.norms_1[i](y) - y = F.gelu(y) - y = self.convs_1x1[i](y) - y = self.norms_2[i](y) - y = F.gelu(y) - y = self.drop(y) - x = x + y - return x * x_mask - - class WN(torch.nn.Module): def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): super(WN, self).__init__() @@ -124,14 +86,14 @@ class WN(torch.nn.Module): if gin_channels != 0: cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1) - self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') + self.cond_layer = weight_norm_modules(cond_layer, name='weight') for i in range(n_layers): dilation = dilation_rate ** i padding = int((kernel_size * dilation - dilation) / 2) - in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size, + in_layer = Conv1dModel(hidden_channels, 2*hidden_channels, kernel_size, dilation=dilation, padding=padding) - in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') + in_layer = weight_norm_modules(in_layer, name='weight') self.in_layers.append(in_layer) # last one is not necessary @@ -141,7 +103,7 @@ class WN(torch.nn.Module): res_skip_channels = hidden_channels res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) - res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') + res_skip_layer = weight_norm_modules(res_skip_layer, name='weight') self.res_skip_layers.append(res_skip_layer) def forward(self, x, x_mask, g=None, **kwargs): @@ -176,32 +138,32 @@ class WN(torch.nn.Module): def remove_weight_norm(self): if self.gin_channels != 0: - torch.nn.utils.remove_weight_norm(self.cond_layer) + remove_weight_norm_modules(self.cond_layer) for l in self.in_layers: - torch.nn.utils.remove_weight_norm(l) + remove_weight_norm_modules(l) for l in self.res_skip_layers: - torch.nn.utils.remove_weight_norm(l) + remove_weight_norm_modules(l) class ResBlock1(torch.nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): super(ResBlock1, self).__init__() self.convs1 = nn.ModuleList([ - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[2], padding=get_padding(kernel_size, dilation[2]))) ]) self.convs1.apply(init_weights) self.convs2 = nn.ModuleList([ - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))) ]) self.convs2.apply(init_weights) @@ -223,18 +185,18 @@ class ResBlock1(torch.nn.Module): def remove_weight_norm(self): for l in self.convs1: - remove_weight_norm(l) + remove_weight_norm_modules(l) for l in self.convs2: - remove_weight_norm(l) + remove_weight_norm_modules(l) class ResBlock2(torch.nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3)): super(ResBlock2, self).__init__() self.convs = nn.ModuleList([ - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))) ]) self.convs.apply(init_weights) @@ -252,7 +214,7 @@ class ResBlock2(torch.nn.Module): def remove_weight_norm(self): for l in self.convs: - remove_weight_norm(l) + remove_weight_norm_modules(l) class Log(nn.Module): @@ -303,7 +265,9 @@ class ResidualCouplingLayer(nn.Module): n_layers, p_dropout=0, gin_channels=0, - mean_only=False): + mean_only=False, + wn_sharing_parameter=None + ): assert channels % 2 == 0, "channels should be divisible by 2" super().__init__() self.channels = channels @@ -315,7 +279,7 @@ class ResidualCouplingLayer(nn.Module): self.mean_only = mean_only self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) - self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) + self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) if wn_sharing_parameter is None else wn_sharing_parameter self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) self.post.weight.data.zero_() self.post.bias.data.zero_() diff --git a/vdecoder/__pycache__/__init__.cpython-38.pyc b/vdecoder/__pycache__/__init__.cpython-38.pyc index 529baf2ccabbcb12ad03341aea38e07fcdda8354..6b35a44716498f7b4913f63d510a0e44f648e623 100644 Binary files a/vdecoder/__pycache__/__init__.cpython-38.pyc and b/vdecoder/__pycache__/__init__.cpython-38.pyc differ diff --git a/vdecoder/hifigan/__pycache__/env.cpython-38.pyc b/vdecoder/hifigan/__pycache__/env.cpython-38.pyc index 537322df5131060b599754f440572416fe41ba6e..742037b44b9e59a61ed41825e240a60c4c7e275f 100644 Binary files a/vdecoder/hifigan/__pycache__/env.cpython-38.pyc and b/vdecoder/hifigan/__pycache__/env.cpython-38.pyc differ diff --git a/vdecoder/hifigan/__pycache__/models.cpython-38.pyc b/vdecoder/hifigan/__pycache__/models.cpython-38.pyc index 678b6a647f2525f856217d01baec0be0b0b44b59..9ddda2666f2e725a3f8e48aa4c1550cb5831b888 100644 Binary files a/vdecoder/hifigan/__pycache__/models.cpython-38.pyc and b/vdecoder/hifigan/__pycache__/models.cpython-38.pyc differ diff --git a/vdecoder/hifigan/__pycache__/utils.cpython-38.pyc b/vdecoder/hifigan/__pycache__/utils.cpython-38.pyc index 229c666d83a30473700f9ca14e5d493be9e2f422..051ff7f16d27cf4f97a2108a3813451f78f76abf 100644 Binary files a/vdecoder/hifigan/__pycache__/utils.cpython-38.pyc and b/vdecoder/hifigan/__pycache__/utils.cpython-38.pyc differ diff --git a/vdecoder/hifigan/models.py b/vdecoder/hifigan/models.py index 9747301f350bb269e62601017fe4633ce271b27e..8e79752b6fc1f06376bcb6000d2658f34d15a913 100644 --- a/vdecoder/hifigan/models.py +++ b/vdecoder/hifigan/models.py @@ -1,13 +1,15 @@ -import os import json -from .env import AttrDict +import os + import numpy as np import torch -import torch.nn.functional as F import torch.nn as nn -from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d -from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm -from .utils import init_weights, get_padding +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from .env import AttrDict +from .utils import get_padding, init_weights LRELU_SLOPE = 0.1 @@ -199,8 +201,6 @@ class SineGen(torch.nn.Module): output uv: tensor(batchsize=1, length, 1) """ with torch.no_grad(): - f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, - device=f0.device) # fundamental component fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)) @@ -266,7 +266,7 @@ class SourceModuleHnNSF(torch.nn.Module): """ # source for harmonic branch sine_wavs, uv, _ = self.l_sin_gen(x) - sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(self.l_linear.weight.dtype))) # source for noise branch, in the same shape as uv noise = torch.randn_like(uv) * self.sine_amp / 3 @@ -292,11 +292,11 @@ class Generator(torch.nn.Module): c_cur = h["upsample_initial_channel"] // (2 ** (i + 1)) self.ups.append(weight_norm( ConvTranspose1d(h["upsample_initial_channel"] // (2 ** i), h["upsample_initial_channel"] // (2 ** (i + 1)), - k, u, padding=(k - u) // 2))) + k, u, padding=(k - u +1 ) // 2))) if i + 1 < len(h["upsample_rates"]): # stride_f0 = np.prod(h["upsample_rates"][i + 1:]) self.noise_convs.append(Conv1d( - 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2)) + 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2)) else: self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) self.resblocks = nn.ModuleList() @@ -353,7 +353,7 @@ class DiscriminatorP(torch.nn.Module): def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): super(DiscriminatorP, self).__init__() self.period = period - norm_f = weight_norm if use_spectral_norm == False else spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm self.convs = nn.ModuleList([ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), @@ -412,7 +412,7 @@ class MultiPeriodDiscriminator(torch.nn.Module): class DiscriminatorS(torch.nn.Module): def __init__(self, use_spectral_norm=False): super(DiscriminatorS, self).__init__() - norm_f = weight_norm if use_spectral_norm == False else spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm self.convs = nn.ModuleList([ norm_f(Conv1d(1, 128, 15, 1, padding=7)), norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), diff --git a/vdecoder/hifigan/nvSTFT.py b/vdecoder/hifigan/nvSTFT.py index 88597d62a505715091f9ba62d38bf0a85a31b95a..b3321b2ee3da28f43c2650ea011e14d5e1cdcc94 100644 --- a/vdecoder/hifigan/nvSTFT.py +++ b/vdecoder/hifigan/nvSTFT.py @@ -1,15 +1,13 @@ -import math import os -os.environ["LRU_CACHE_CAPACITY"] = "3" -import random + +import librosa +import numpy as np +import soundfile as sf import torch import torch.utils.data -import numpy as np -import librosa -from librosa.util import normalize from librosa.filters import mel as librosa_mel_fn -from scipy.io.wavfile import read -import soundfile as sf + +os.environ["LRU_CACHE_CAPACITY"] = "3" def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False): sampling_rate = None diff --git a/vdecoder/hifigan/utils.py b/vdecoder/hifigan/utils.py index 9c93c996d3cc73c30d71c1fc47056e4230f35c0f..e519e2b7ed8fe5f93266d21d727a30173699f88b 100644 --- a/vdecoder/hifigan/utils.py +++ b/vdecoder/hifigan/utils.py @@ -1,10 +1,10 @@ import glob import os -import matplotlib -import torch -from torch.nn.utils import weight_norm + # matplotlib.use("Agg") import matplotlib.pylab as plt +import torch +from torch.nn.utils import weight_norm def plot_spectrogram(spectrogram): diff --git a/vdecoder/hifiganwithsnake/alias/__init__.py b/vdecoder/hifiganwithsnake/alias/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..be97a33248ae6378c6736586774abda11cfbdeba --- /dev/null +++ b/vdecoder/hifiganwithsnake/alias/__init__.py @@ -0,0 +1,6 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +from .act import * # noqa: F403 +from .filter import * # noqa: F403 +from .resample import * # noqa: F403 diff --git a/vdecoder/hifiganwithsnake/alias/act.py b/vdecoder/hifiganwithsnake/alias/act.py new file mode 100644 index 0000000000000000000000000000000000000000..e46b3467b73b90df51c1d19032b90d26595aca6e --- /dev/null +++ b/vdecoder/hifiganwithsnake/alias/act.py @@ -0,0 +1,130 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import pow, sin +from torch.nn import Parameter + +from .resample import DownSample1d, UpSample1d + + +class Activation1d(nn.Module): + def __init__(self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x + + +class SnakeBeta(nn.Module): + ''' + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + ''' + super(SnakeBeta, self).__init__() + self.in_features = in_features + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta = x + 1/b * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze( + 0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + return x + + +class Mish(nn.Module): + """ + Mish activation function is proposed in "Mish: A Self + Regularized Non-Monotonic Neural Activation Function" + paper, https://arxiv.org/abs/1908.08681. + """ + + def __init__(self): + super().__init__() + + def forward(self, x): + return x * torch.tanh(F.softplus(x)) + + +class SnakeAlias(nn.Module): + def __init__(self, + channels, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + C = None): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = SnakeBeta(channels, alpha_logscale=True) + self.upsample = UpSample1d(up_ratio, up_kernel_size, C) + self.downsample = DownSample1d(down_ratio, down_kernel_size, C) + + # x: [B,C,T] + def forward(self, x, C=None): + x = self.upsample(x, C) + x = self.act(x) + x = self.downsample(x) + + return x \ No newline at end of file diff --git a/vdecoder/hifiganwithsnake/alias/filter.py b/vdecoder/hifiganwithsnake/alias/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..3942eb3ae547a2f500d5c47defdd70cd29ea4655 --- /dev/null +++ b/vdecoder/hifiganwithsnake/alias/filter.py @@ -0,0 +1,110 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +if 'sinc' in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where(x == 0, + torch.tensor(1., device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +# LICENSE is in incl_licenses directory. +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] + even = (kernel_size % 2 == 0) + half_size = kernel_size // 2 + + #For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.: + beta = 0.1102 * (A - 8.7) + elif A >= 21.: + beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) + else: + beta = 0. + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = (torch.arange(-half_size, half_size) + 0.5) + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__(self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = 'replicate', + kernel_size: int = 12, + C=None): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. + super().__init__() + if cutoff < -0.: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = (kernel_size % 2 == 0) + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + self.conv1d_block = None + if C is not None: + self.conv1d_block = [nn.Conv1d(C,C,kernel_size,stride=self.stride, groups=C, bias=False),] + self.conv1d_block[0].weight = nn.Parameter(self.filter.expand(C, -1, -1)) + self.conv1d_block[0].requires_grad_(False) + + #input [B, C, T] + def forward(self, x): + if self.conv1d_block[0].weight.device != x.device: + self.conv1d_block[0] = self.conv1d_block[0].to(x.device) + if self.conv1d_block is None: + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), + mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), + stride=self.stride, groups=C) + else: + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), + mode=self.padding_mode) + out = self.conv1d_block[0](x) + + return out \ No newline at end of file diff --git a/vdecoder/hifiganwithsnake/alias/resample.py b/vdecoder/hifiganwithsnake/alias/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..a364403f0977bc8bcffbb4764081e4bd3619467a --- /dev/null +++ b/vdecoder/hifiganwithsnake/alias/resample.py @@ -0,0 +1,72 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch.nn as nn +from torch.nn import functional as F + +from .filter import LowPassFilter1d, kaiser_sinc_filter1d + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None, C=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size) + self.register_buffer("filter", filter) + self.conv_transpose1d_block = None + if C is not None: + self.conv_transpose1d_block = [nn.ConvTranspose1d(C, + C, + kernel_size=self.kernel_size, + stride=self.stride, + groups=C, + bias=False + ),] + self.conv_transpose1d_block[0].weight = nn.Parameter(self.filter.expand(C, -1, -1).clone()) + self.conv_transpose1d_block[0].requires_grad_(False) + + + + # x: [B, C, T] + def forward(self, x, C=None): + if self.conv_transpose1d_block[0].weight.device != x.device: + self.conv_transpose1d_block[0] = self.conv_transpose1d_block[0].to(x.device) + if self.conv_transpose1d_block is None: + if C is None: + _, C, _ = x.shape + # print("snake.conv_t.in:",x.shape) + x = F.pad(x, (self.pad, self.pad), mode='replicate') + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + # print("snake.conv_t.out:",x.shape) + x = x[..., self.pad_left:-self.pad_right] + else: + x = F.pad(x, (self.pad, self.pad), mode='replicate') + x = self.ratio * self.conv_transpose1d_block[0](x) + x = x[..., self.pad_left:-self.pad_right] + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None, C=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size, + C=C) + + + def forward(self, x): + xx = self.lowpass(x) + + return xx \ No newline at end of file diff --git a/vdecoder/hifiganwithsnake/env.py b/vdecoder/hifiganwithsnake/env.py new file mode 100644 index 0000000000000000000000000000000000000000..2bdbc95d4f7a8bad8fd4f5eef657e2b51d946056 --- /dev/null +++ b/vdecoder/hifiganwithsnake/env.py @@ -0,0 +1,15 @@ +import os +import shutil + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) diff --git a/vdecoder/hifiganwithsnake/models.py b/vdecoder/hifiganwithsnake/models.py new file mode 100644 index 0000000000000000000000000000000000000000..9b64f9c0ced677a66341491ad951b7019c8ee0fa --- /dev/null +++ b/vdecoder/hifiganwithsnake/models.py @@ -0,0 +1,521 @@ +import json +import os + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from vdecoder.hifiganwithsnake.alias.act import SnakeAlias + +from .env import AttrDict +from .utils import get_padding, init_weights + +LRELU_SLOPE = 0.1 + + +def load_model(model_path, device='cuda'): + config_file = os.path.join(os.path.split(model_path)[0], 'config.json') + with open(config_file) as f: + data = f.read() + + global h + json_config = json.loads(data) + h = AttrDict(json_config) + + generator = Generator(h).to(device) + + cp_dict = torch.load(model_path) + generator.load_state_dict(cp_dict['generator']) + generator.eval() + generator.remove_weight_norm() + del cp_dict + return generator, h + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), C=None): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + self.num_layers = len(self.convs1) + len(self.convs2) + self.activations = nn.ModuleList([ + SnakeAlias(channels, C=C) for _ in range(self.num_layers) + ]) + + def forward(self, x, DIM=None): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + xt = a1(x, DIM) + xt = c1(xt) + xt = a2(xt, DIM) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), C=None): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + self.num_layers = len(self.convs) + self.activations = nn.ModuleList([ + SnakeAlias(channels, C=C) for _ in range(self.num_layers) + ]) + + def forward(self, x, DIM=None): + for c,a in zip(self.convs, self.activations): + xt = a(x, DIM) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +def padDiff(x): + return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0) + +class SineGen(torch.nn.Module): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0, + flag_for_pulse=False): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + self.flag_for_pulse = flag_for_pulse + + def _f02uv(self, f0): + # generate uv signal + uv = (f0 > self.voiced_threshold).type(torch.float32) + return uv + + def _f02sine(self, f0_values): + """ f0_values: (batchsize, length, dim) + where dim indicates fundamental tone and overtones + """ + # convert to F0 in rad. The interger part n can be ignored + # because 2 * np.pi * n doesn't affect phase + rad_values = (f0_values / self.sampling_rate) % 1 + + # initial phase noise (no noise for fundamental component) + rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \ + device=f0_values.device) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + + # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad) + if not self.flag_for_pulse: + # for normal case + + # To prevent torch.cumsum numerical overflow, + # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1. + # Buffer tmp_over_one_idx indicates the time step to add -1. + # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi + tmp_over_one = torch.cumsum(rad_values, 1) % 1 + tmp_over_one_idx = (padDiff(tmp_over_one)) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + + sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) + * 2 * np.pi) + else: + # If necessary, make sure that the first time step of every + # voiced segments is sin(pi) or cos(0) + # This is used for pulse-train generation + + # identify the last time step in unvoiced segments + uv = self._f02uv(f0_values) + uv_1 = torch.roll(uv, shifts=-1, dims=1) + uv_1[:, -1, :] = 1 + u_loc = (uv < 1) * (uv_1 > 0) + + # get the instantanouse phase + tmp_cumsum = torch.cumsum(rad_values, dim=1) + # different batch needs to be processed differently + for idx in range(f0_values.shape[0]): + temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :] + temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :] + # stores the accumulation of i.phase within + # each voiced segments + tmp_cumsum[idx, :, :] = 0 + tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum + + # rad_values - tmp_cumsum: remove the accumulation of i.phase + # within the previous voiced segment. + i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1) + + # get the sines + sines = torch.cos(i_phase * 2 * np.pi) + return sines + + def forward(self, f0): + """ sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + with torch.no_grad(): + # fundamental component + fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)) + + # generate sine waveforms + sine_waves = self._f02sine(fn) * self.sine_amp + + # generate uv signal + # uv = torch.ones(f0.shape) + # uv = uv * (f0 > self.voiced_threshold) + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF(torch.nn.Module): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen(sampling_rate, harmonic_num, + sine_amp, add_noise_std, voiced_threshod) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + # source for harmonic branch + sine_wavs, uv, _ = self.l_sin_gen(x) + sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(self.l_linear.weight.dtype))) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + + self.num_kernels = len(h["resblock_kernel_sizes"]) + self.num_upsamples = len(h["upsample_rates"]) + self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h["upsample_rates"])) + self.m_source = SourceModuleHnNSF( + sampling_rate=h["sampling_rate"], + harmonic_num=8) + self.noise_convs = nn.ModuleList() + self.conv_pre = weight_norm(Conv1d(h["inter_channels"], h["upsample_initial_channel"], 7, 1, padding=3)) + resblock = ResBlock1 if h["resblock"] == '1' else ResBlock2 + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h["upsample_rates"], h["upsample_kernel_sizes"])): + c_cur = h["upsample_initial_channel"] // (2 ** (i + 1)) + self.ups.append(weight_norm( + ConvTranspose1d(h["upsample_initial_channel"] // (2 ** i), h["upsample_initial_channel"] // (2 ** (i + 1)), + k, u, padding=(k - u + 1) // 2))) + if i + 1 < len(h["upsample_rates"]): # + stride_f0 = np.prod(h["upsample_rates"][i + 1:]) + self.noise_convs.append(Conv1d( + 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+ 1) // 2)) + else: + self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) + self.resblocks = nn.ModuleList() + self.snakes = nn.ModuleList() + for i in range(len(self.ups)): + ch = h["upsample_initial_channel"] // (2 ** (i + 1)) + self.snakes.append(SnakeAlias(h["upsample_initial_channel"] // (2 ** (i)), C = h["upsample_initial_channel"] >> i)) + for j, (k, d) in enumerate(zip(h["resblock_kernel_sizes"], h["resblock_dilation_sizes"])): + self.resblocks.append(resblock(h, ch, k, d, C = h["upsample_initial_channel"] >> (i + 1))) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.snake_post = SnakeAlias(ch, C = h["upsample_initial_channel"] >> len(self.ups)) + self.cond = nn.Conv1d(h['gin_channels'], h['upsample_initial_channel'], 1) + + def forward(self, x, f0, g=None): + # print(1,x.shape,f0.shape,f0[:, None].shape) + f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + # print(2,f0.shape) + har_source, noi_source, uv = self.m_source(f0) + har_source = har_source.transpose(1, 2) + x = self.conv_pre(x) + x = x + self.cond(g) + # print(124,x.shape,har_source.shape) + for i in range(self.num_upsamples): + # print(f"self.snakes.{i}.pre:", x.shape) + x = self.snakes[i](x) + # print(f"self.snakes.{i}.after:", x.shape) + x = self.ups[i](x) + x_source = self.noise_convs[i](har_source) + # print(4,x_source.shape,har_source.shape,x.shape) + x = x + x_source + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + # print(f"self.resblocks.{i}.after:", xs.shape) + x = xs / self.num_kernels + x = self.snake_post(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, periods=None): + super(MultiPeriodDiscriminator, self).__init__() + self.periods = periods if periods is not None else [2, 3, 5, 7, 11] + self.discriminators = nn.ModuleList() + for period in self.periods: + self.discriminators.append(DiscriminatorP(period)) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ]) + self.meanpools = nn.ModuleList([ + AvgPool1d(4, 2, padding=2), + AvgPool1d(4, 2, padding=2) + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg ** 2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses diff --git a/vdecoder/hifiganwithsnake/nvSTFT.py b/vdecoder/hifiganwithsnake/nvSTFT.py new file mode 100644 index 0000000000000000000000000000000000000000..b3321b2ee3da28f43c2650ea011e14d5e1cdcc94 --- /dev/null +++ b/vdecoder/hifiganwithsnake/nvSTFT.py @@ -0,0 +1,109 @@ +import os + +import librosa +import numpy as np +import soundfile as sf +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +os.environ["LRU_CACHE_CAPACITY"] = "3" + +def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False): + sampling_rate = None + try: + data, sampling_rate = sf.read(full_path, always_2d=True)# than soundfile. + except Exception as ex: + print(f"'{full_path}' failed to load.\nException:") + print(ex) + if return_empty_on_exception: + return [], sampling_rate or target_sr or 32000 + else: + raise Exception(ex) + + if len(data.shape) > 1: + data = data[:, 0] + assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension) + + if np.issubdtype(data.dtype, np.integer): # if audio data is type int + max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX + else: # if audio data is type fp32 + max_mag = max(np.amax(data), -np.amin(data)) + max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32 + + data = torch.FloatTensor(data.astype(np.float32))/max_mag + + if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except + return [], sampling_rate or target_sr or 32000 + if target_sr is not None and sampling_rate != target_sr: + data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr)) + sampling_rate = target_sr + + return data, sampling_rate + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + +class STFT(): + def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5): + self.target_sr = sr + + self.n_mels = n_mels + self.n_fft = n_fft + self.win_size = win_size + self.hop_length = hop_length + self.fmin = fmin + self.fmax = fmax + self.clip_val = clip_val + self.mel_basis = {} + self.hann_window = {} + + def get_mel(self, y, center=False): + sampling_rate = self.target_sr + n_mels = self.n_mels + n_fft = self.n_fft + win_size = self.win_size + hop_length = self.hop_length + fmin = self.fmin + fmax = self.fmax + clip_val = self.clip_val + + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + if fmax not in self.mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) + self.mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) + self.hann_window[str(y.device)] = torch.hann_window(self.win_size).to(y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_length)/2), int((n_fft-hop_length)/2)), mode='reflect') + y = y.squeeze(1) + + spec = torch.stft(y, n_fft, hop_length=hop_length, win_length=win_size, window=self.hann_window[str(y.device)], + center=center, pad_mode='reflect', normalized=False, onesided=True) + # print(111,spec) + spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) + # print(222,spec) + spec = torch.matmul(self.mel_basis[str(fmax)+'_'+str(y.device)], spec) + # print(333,spec) + spec = dynamic_range_compression_torch(spec, clip_val=clip_val) + # print(444,spec) + return spec + + def __call__(self, audiopath): + audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr) + spect = self.get_mel(audio.unsqueeze(0)).squeeze(0) + return spect + +stft = STFT() diff --git a/vdecoder/hifiganwithsnake/utils.py b/vdecoder/hifiganwithsnake/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e519e2b7ed8fe5f93266d21d727a30173699f88b --- /dev/null +++ b/vdecoder/hifiganwithsnake/utils.py @@ -0,0 +1,68 @@ +import glob +import os + +# matplotlib.use("Agg") +import matplotlib.pylab as plt +import torch +from torch.nn.utils import weight_norm + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print("Saving checkpoint to {}".format(filepath)) + torch.save(obj, filepath) + print("Complete.") + + +def del_old_checkpoints(cp_dir, prefix, n_models=2): + pattern = os.path.join(cp_dir, prefix + '????????') + cp_list = glob.glob(pattern) # get checkpoint paths + cp_list = sorted(cp_list)# sort by iter + if len(cp_list) > n_models: # if more than n_models models are found + for cp in cp_list[:-n_models]:# delete the oldest models other than lastest n_models + open(cp, 'w').close()# empty file contents + os.unlink(cp)# delete file (move to trash when using Colab) + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + '????????') + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] + diff --git a/vdecoder/nsf_hifigan/__pycache__/env.cpython-38.pyc b/vdecoder/nsf_hifigan/__pycache__/env.cpython-38.pyc index 8bc0218aa29dc35ee7d0eed7df3281f479bc0c44..0c8f6b87dba6b3ec0ff16b192305400f22d1cc56 100644 Binary files a/vdecoder/nsf_hifigan/__pycache__/env.cpython-38.pyc and b/vdecoder/nsf_hifigan/__pycache__/env.cpython-38.pyc differ diff --git a/vdecoder/nsf_hifigan/__pycache__/models.cpython-38.pyc b/vdecoder/nsf_hifigan/__pycache__/models.cpython-38.pyc index 003a7f9126f5719aea0005a3ad0e3c9ab4b53919..f25aff06b55d97475038f84eca6affc78d84de48 100644 Binary files a/vdecoder/nsf_hifigan/__pycache__/models.cpython-38.pyc and b/vdecoder/nsf_hifigan/__pycache__/models.cpython-38.pyc differ diff --git a/vdecoder/nsf_hifigan/__pycache__/nvSTFT.cpython-38.pyc b/vdecoder/nsf_hifigan/__pycache__/nvSTFT.cpython-38.pyc index 14b5693befbb2300619950ca3ef1422e8fd5d1f8..3dd97d5be33c36ab7ab04719603f1444dd7f8506 100644 Binary files a/vdecoder/nsf_hifigan/__pycache__/nvSTFT.cpython-38.pyc and b/vdecoder/nsf_hifigan/__pycache__/nvSTFT.cpython-38.pyc differ diff --git a/vdecoder/nsf_hifigan/__pycache__/utils.cpython-38.pyc b/vdecoder/nsf_hifigan/__pycache__/utils.cpython-38.pyc index d488da647daf67ceb8b08bd9941ca73b656c076d..7bfebb47c4c610db2bf8a1b9ee55759918f06eff 100644 Binary files a/vdecoder/nsf_hifigan/__pycache__/utils.cpython-38.pyc and b/vdecoder/nsf_hifigan/__pycache__/utils.cpython-38.pyc differ diff --git a/vdecoder/nsf_hifigan/models.py b/vdecoder/nsf_hifigan/models.py index c2c889ec2fbd215702298ba2b7c411c6f5630d80..8a35b134d814008c3990d019d1de502ff10dd86f 100644 --- a/vdecoder/nsf_hifigan/models.py +++ b/vdecoder/nsf_hifigan/models.py @@ -1,13 +1,15 @@ -import os import json -from .env import AttrDict +import os + import numpy as np import torch -import torch.nn.functional as F import torch.nn as nn -from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d -from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm -from .utils import init_weights, get_padding +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from .env import AttrDict +from .utils import get_padding, init_weights LRELU_SLOPE = 0.1 @@ -289,7 +291,7 @@ class DiscriminatorP(torch.nn.Module): def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): super(DiscriminatorP, self).__init__() self.period = period - norm_f = weight_norm if use_spectral_norm == False else spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm self.convs = nn.ModuleList([ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), @@ -348,7 +350,7 @@ class MultiPeriodDiscriminator(torch.nn.Module): class DiscriminatorS(torch.nn.Module): def __init__(self, use_spectral_norm=False): super(DiscriminatorS, self).__init__() - norm_f = weight_norm if use_spectral_norm == False else spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm self.convs = nn.ModuleList([ norm_f(Conv1d(1, 128, 15, 1, padding=7)), norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), diff --git a/vdecoder/nsf_hifigan/nvSTFT.py b/vdecoder/nsf_hifigan/nvSTFT.py index 62bd5a008f81929054f036c81955d5d73377f772..e756cca561a45bde435f36447e6681bfa17e34aa 100644 --- a/vdecoder/nsf_hifigan/nvSTFT.py +++ b/vdecoder/nsf_hifigan/nvSTFT.py @@ -1,16 +1,14 @@ -import math import os -os.environ["LRU_CACHE_CAPACITY"] = "3" -import random -import torch -import torch.utils.data -import numpy as np + import librosa -from librosa.util import normalize -from librosa.filters import mel as librosa_mel_fn -from scipy.io.wavfile import read +import numpy as np import soundfile as sf +import torch import torch.nn.functional as F +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +os.environ["LRU_CACHE_CAPACITY"] = "3" def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False): sampling_rate = None diff --git a/vdecoder/nsf_hifigan/utils.py b/vdecoder/nsf_hifigan/utils.py index 84bff024f4d2e2de194b2a88ee7bbe5f0d33f67c..58d0e701d377e318fe0302743c27bdb4d6e089ec 100644 --- a/vdecoder/nsf_hifigan/utils.py +++ b/vdecoder/nsf_hifigan/utils.py @@ -1,10 +1,12 @@ import glob import os + import matplotlib +import matplotlib.pylab as plt import torch from torch.nn.utils import weight_norm + matplotlib.use("Agg") -import matplotlib.pylab as plt def plot_spectrogram(spectrogram): diff --git a/vencoder/CNHubertLarge.py b/vencoder/CNHubertLarge.py new file mode 100644 index 0000000000000000000000000000000000000000..f43694762f92c5d839d358825f157f5d1a4ff6f6 --- /dev/null +++ b/vencoder/CNHubertLarge.py @@ -0,0 +1,36 @@ +import torch +from fairseq import checkpoint_utils + +from vencoder.encoder import SpeechEncoder + + +class CNHubertLarge(SpeechEncoder): + def __init__(self, vec_path="pretrain/chinese-hubert-large-fairseq-ckpt.pt", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + self.hidden_dim = 1024 + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( + [vec_path], + suffix="", + ) + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + self.model = models[0].to(self.dev) + self.model.eval() + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats.view(1, -1) + padding_mask = torch.BoolTensor(feats.shape).fill_(False) + inputs = { + "source": feats.to(wav.device), + "padding_mask": padding_mask.to(wav.device) + } + with torch.no_grad(): + logits = self.model.extract_features(**inputs) + return logits[0].transpose(1, 2) \ No newline at end of file diff --git a/vencoder/ContentVec256L12_Onnx.py b/vencoder/ContentVec256L12_Onnx.py index 9ad5085e02654fd1fcfbdad7d476bfa9b763d2c6..466e6c128b88acdfb94392662086e6752d503a27 100644 --- a/vencoder/ContentVec256L12_Onnx.py +++ b/vencoder/ContentVec256L12_Onnx.py @@ -1,25 +1,30 @@ -from vencoder.encoder import SpeechEncoder import onnxruntime import torch +from vencoder.encoder import SpeechEncoder + + class ContentVec256L12_Onnx(SpeechEncoder): - def __init__(self,vec_path = "pretrain/vec-256-layer-12.onnx",device=None): + def __init__(self, vec_path="pretrain/vec-256-layer-12.onnx", device=None): + super().__init__() print("load model(s) from {}".format(vec_path)) self.hidden_dim = 256 if device is None: self.dev = torch.device("cpu") else: self.dev = torch.device(device) - if device == 'cpu' or device == torch.device("cpu") or device is None: - providers = ['CPUExecutionProvider'] - elif device == 'cuda' or device == torch.device("cuda"): + + if device == 'cuda' or device == torch.device("cuda"): providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + providers = ['CPUExecutionProvider'] + self.model = onnxruntime.InferenceSession(vec_path, providers=providers) def encoder(self, wav): feats = wav if feats.dim() == 2: # double channels - feats = feats.mean(-1) + feats = feats.mean(-1) assert feats.dim() == 1, feats.dim() feats = feats.view(1, -1) feats = feats.unsqueeze(0).cpu().detach().numpy() diff --git a/vencoder/ContentVec256L9.py b/vencoder/ContentVec256L9.py index b0089c789cd87cfd3b1badb2fc45cb1b88041eab..c973090dd4cdaa3d8ca07d9007c26633883c36a7 100644 --- a/vencoder/ContentVec256L9.py +++ b/vencoder/ContentVec256L9.py @@ -1,9 +1,12 @@ -from vencoder.encoder import SpeechEncoder import torch from fairseq import checkpoint_utils +from vencoder.encoder import SpeechEncoder + + class ContentVec256L9(SpeechEncoder): - def __init__(self,vec_path = "pretrain/checkpoint_best_legacy_500.pt",device=None): + def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None): + super().__init__() print("load model(s) from {}".format(vec_path)) models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( [vec_path], @@ -20,7 +23,7 @@ class ContentVec256L9(SpeechEncoder): def encoder(self, wav): feats = wav if feats.dim() == 2: # double channels - feats = feats.mean(-1) + feats = feats.mean(-1) assert feats.dim() == 1, feats.dim() feats = feats.view(1, -1) padding_mask = torch.BoolTensor(feats.shape).fill_(False) @@ -30,6 +33,6 @@ class ContentVec256L9(SpeechEncoder): "output_layer": 9, # layer 9 } with torch.no_grad(): - logits = self.model.extract_features(**inputs) - feats = self.model.final_proj(logits[0]) + logits = self.model.extract_features(**inputs) + feats = self.model.final_proj(logits[0]) return feats.transpose(1, 2) diff --git a/vencoder/ContentVec256L9_Onnx.py b/vencoder/ContentVec256L9_Onnx.py index fae2b928252801795b038f51451b234e007f6f03..a27e1f76655d9dc9fcc41d05d11b4a1ac5d85b90 100644 --- a/vencoder/ContentVec256L9_Onnx.py +++ b/vencoder/ContentVec256L9_Onnx.py @@ -1,9 +1,12 @@ -from vencoder.encoder import SpeechEncoder import onnxruntime import torch +from vencoder.encoder import SpeechEncoder + + class ContentVec256L9_Onnx(SpeechEncoder): - def __init__(self,vec_path = "pretrain/vec-256-layer-9.onnx",device=None): + def __init__(self, vec_path="pretrain/vec-256-layer-9.onnx", device=None): + super().__init__() print("load model(s) from {}".format(vec_path)) self.hidden_dim = 256 if device is None: @@ -19,10 +22,11 @@ class ContentVec256L9_Onnx(SpeechEncoder): def encoder(self, wav): feats = wav if feats.dim() == 2: # double channels - feats = feats.mean(-1) + feats = feats.mean(-1) assert feats.dim() == 1, feats.dim() feats = feats.view(1, -1) feats = feats.unsqueeze(0).cpu().detach().numpy() onnx_input = {self.model.get_inputs()[0].name: feats} logits = self.model.run(None, onnx_input) - return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) \ No newline at end of file + return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) + \ No newline at end of file diff --git a/vencoder/ContentVec768L12.py b/vencoder/ContentVec768L12.py index 0d1591c8843b920d5685e822354e8e6adc9a9e19..066b824b68447b5c860730c9f11b7be415068b46 100644 --- a/vencoder/ContentVec768L12.py +++ b/vencoder/ContentVec768L12.py @@ -1,9 +1,12 @@ -from vencoder.encoder import SpeechEncoder import torch from fairseq import checkpoint_utils +from vencoder.encoder import SpeechEncoder + + class ContentVec768L12(SpeechEncoder): - def __init__(self,vec_path = "pretrain/checkpoint_best_legacy_500.pt",device=None): + def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None): + super().__init__() print("load model(s) from {}".format(vec_path)) self.hidden_dim = 768 models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( @@ -20,7 +23,7 @@ class ContentVec768L12(SpeechEncoder): def encoder(self, wav): feats = wav if feats.dim() == 2: # double channels - feats = feats.mean(-1) + feats = feats.mean(-1) assert feats.dim() == 1, feats.dim() feats = feats.view(1, -1) padding_mask = torch.BoolTensor(feats.shape).fill_(False) @@ -30,5 +33,5 @@ class ContentVec768L12(SpeechEncoder): "output_layer": 12, # layer 12 } with torch.no_grad(): - logits = self.model.extract_features(**inputs) - return logits[0].transpose(1, 2) \ No newline at end of file + logits = self.model.extract_features(**inputs) + return logits[0].transpose(1, 2) diff --git a/vencoder/ContentVec768L12_Onnx.py b/vencoder/ContentVec768L12_Onnx.py index 8dde0f173ed60169282128cc51eb1c200c5d82c5..e737594526fd09f19353b85c11d4c357a325af48 100644 --- a/vencoder/ContentVec768L12_Onnx.py +++ b/vencoder/ContentVec768L12_Onnx.py @@ -1,28 +1,33 @@ -from vencoder.encoder import SpeechEncoder import onnxruntime import torch +from vencoder.encoder import SpeechEncoder + + class ContentVec768L12_Onnx(SpeechEncoder): - def __init__(self,vec_path = "pretrain/vec-768-layer-12.onnx",device=None): + def __init__(self, vec_path="pretrain/vec-768-layer-12.onnx", device=None): + super().__init__() print("load model(s) from {}".format(vec_path)) self.hidden_dim = 768 if device is None: self.dev = torch.device("cpu") else: self.dev = torch.device(device) - if device == 'cpu' or device == torch.device("cpu") or device is None: - providers = ['CPUExecutionProvider'] - elif device == 'cuda' or device == torch.device("cuda"): + + if device == 'cuda' or device == torch.device("cuda"): providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + providers = ['CPUExecutionProvider'] + self.model = onnxruntime.InferenceSession(vec_path, providers=providers) def encoder(self, wav): feats = wav if feats.dim() == 2: # double channels - feats = feats.mean(-1) + feats = feats.mean(-1) assert feats.dim() == 1, feats.dim() feats = feats.view(1, -1) feats = feats.unsqueeze(0).cpu().detach().numpy() onnx_input = {self.model.get_inputs()[0].name: feats} logits = self.model.run(None, onnx_input) - return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) \ No newline at end of file + return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) diff --git a/vencoder/ContentVec768L9_Onnx.py b/vencoder/ContentVec768L9_Onnx.py index 7cdac4cd93478d3ddddb4b76dd9d9ccc5d1af2d4..3bd0f337bbf5fa261ea43adfab2377fced7c9e7c 100644 --- a/vencoder/ContentVec768L9_Onnx.py +++ b/vencoder/ContentVec768L9_Onnx.py @@ -1,28 +1,33 @@ -from vencoder.encoder import SpeechEncoder import onnxruntime import torch +from vencoder.encoder import SpeechEncoder + + class ContentVec768L9_Onnx(SpeechEncoder): def __init__(self,vec_path = "pretrain/vec-768-layer-9.onnx",device=None): + super().__init__() print("load model(s) from {}".format(vec_path)) self.hidden_dim = 768 if device is None: self.dev = torch.device("cpu") else: self.dev = torch.device(device) - if device == 'cpu' or device == torch.device("cpu") or device is None: - providers = ['CPUExecutionProvider'] - elif device == 'cuda' or device == torch.device("cuda"): + + if device == 'cuda' or device == torch.device("cuda"): providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + providers = ['CPUExecutionProvider'] + self.model = onnxruntime.InferenceSession(vec_path, providers=providers) def encoder(self, wav): feats = wav if feats.dim() == 2: # double channels - feats = feats.mean(-1) + feats = feats.mean(-1) assert feats.dim() == 1, feats.dim() feats = feats.view(1, -1) feats = feats.unsqueeze(0).cpu().detach().numpy() onnx_input = {self.model.get_inputs()[0].name: feats} logits = self.model.run(None, onnx_input) - return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) \ No newline at end of file + return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) diff --git a/vencoder/DPHubert.py b/vencoder/DPHubert.py new file mode 100644 index 0000000000000000000000000000000000000000..130064ff3ea5c24017be2f0faa204fc4c7dbd078 --- /dev/null +++ b/vencoder/DPHubert.py @@ -0,0 +1,29 @@ +import torch + +from vencoder.dphubert.model import wav2vec2_model +from vencoder.encoder import SpeechEncoder + + +class DPHubert(SpeechEncoder): + def __init__(self, vec_path="pretrain/DPHuBERT-sp0.75.pth", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + ckpt = torch.load(vec_path) + self.hidden_dim = 768 + self.model = wav2vec2_model(**ckpt["config"]).to(self.dev) + self.model.load_state_dict(ckpt["state_dict"], strict=False) + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats[None, :] + with torch.no_grad(): + with torch.inference_mode(): + units = self.model(feats)[0] + return units.transpose(1,2) diff --git a/vencoder/HubertSoft.py b/vencoder/HubertSoft.py index e540775d9b6336953ab8642fa424a5e7e3e38c3f..423c159c44f0e5cb820a911a47b71ae1478d725d 100644 --- a/vencoder/HubertSoft.py +++ b/vencoder/HubertSoft.py @@ -1,8 +1,12 @@ -from vencoder.encoder import SpeechEncoder import torch + +from vencoder.encoder import SpeechEncoder from vencoder.hubert import hubert_model + + class HubertSoft(SpeechEncoder): - def __init__(self,vec_path = "pretrain/hubert-soft-0d54a1f4.pt",device=None): + def __init__(self, vec_path="pretrain/hubert-soft-0d54a1f4.pt", device=None): + super().__init__() print("load model(s) from {}".format(vec_path)) hubert_soft = hubert_model.hubert_soft(vec_path) if device is None: @@ -15,9 +19,10 @@ class HubertSoft(SpeechEncoder): def encoder(self, wav): feats = wav if feats.dim() == 2: # double channels - feats = feats.mean(-1) + feats = feats.mean(-1) assert feats.dim() == 1, feats.dim() - feats = feats[None,None,:] - with torch.inference_mode(): - units = self.model.units(feats) - return units.transpose(1,2) + feats = feats[None,None,:] + with torch.no_grad(): + with torch.inference_mode(): + units = self.model.units(feats) + return units.transpose(1,2) diff --git a/vencoder/HubertSoft_Onnx.py b/vencoder/HubertSoft_Onnx.py index 06f10a4ca79c429ed59ab9743578128e8db506cc..038d78e8ffa0804cb63b146f8122b3f2bba2f637 100644 --- a/vencoder/HubertSoft_Onnx.py +++ b/vencoder/HubertSoft_Onnx.py @@ -1,28 +1,33 @@ -from vencoder.encoder import SpeechEncoder import onnxruntime import torch +from vencoder.encoder import SpeechEncoder + + class HubertSoft_Onnx(SpeechEncoder): - def __init__(self,vec_path = "pretrain/hubert-soft.onnx",device=None): + def __init__(self, vec_path="pretrain/hubert-soft.onnx", device=None): + super().__init__() print("load model(s) from {}".format(vec_path)) self.hidden_dim = 256 if device is None: self.dev = torch.device("cpu") else: self.dev = torch.device(device) - if device == 'cpu' or device == torch.device("cpu") or device is None: - providers = ['CPUExecutionProvider'] - elif device == 'cuda' or device == torch.device("cuda"): + + if device == 'cuda' or device == torch.device("cuda"): providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + providers = ['CPUExecutionProvider'] + self.model = onnxruntime.InferenceSession(vec_path, providers=providers) def encoder(self, wav): feats = wav if feats.dim() == 2: # double channels - feats = feats.mean(-1) + feats = feats.mean(-1) assert feats.dim() == 1, feats.dim() feats = feats.view(1, -1) feats = feats.unsqueeze(0).cpu().detach().numpy() onnx_input = {self.model.get_inputs()[0].name: feats} logits = self.model.run(None, onnx_input) - return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) \ No newline at end of file + return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) diff --git a/vencoder/WavLMBasePlus.py b/vencoder/WavLMBasePlus.py new file mode 100644 index 0000000000000000000000000000000000000000..99df15be73c0c4774cea83a376f79fb68405bfa1 --- /dev/null +++ b/vencoder/WavLMBasePlus.py @@ -0,0 +1,32 @@ +import torch + +from vencoder.encoder import SpeechEncoder +from vencoder.wavlm.WavLM import WavLM, WavLMConfig + + +class WavLMBasePlus(SpeechEncoder): + def __init__(self, vec_path="pretrain/WavLM-Base+.pt", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + checkpoint = torch.load(vec_path) + self.cfg = WavLMConfig(checkpoint['cfg']) + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + self.hidden_dim = self.cfg.encoder_embed_dim + self.model = WavLM(self.cfg) + self.model.load_state_dict(checkpoint['model']) + self.model.to(self.dev).eval() + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + if self.cfg.normalize: + feats = torch.nn.functional.layer_norm(feats, feats.shape) + with torch.no_grad(): + with torch.inference_mode(): + units = self.model.extract_features(feats[None, :])[0] + return units.transpose(1, 2) diff --git a/vencoder/WhisperPPG.py b/vencoder/WhisperPPG.py index aa988b0a6d05696ea519d1652e5801302ba8a6c6..86af53e69b5f60f143a4acce0949c24812e327d1 100644 --- a/vencoder/WhisperPPG.py +++ b/vencoder/WhisperPPG.py @@ -1,12 +1,13 @@ -from vencoder.encoder import SpeechEncoder import torch -from vencoder.whisper.model import Whisper, ModelDimensions -from vencoder.whisper.audio import pad_or_trim, log_mel_spectrogram +from vencoder.encoder import SpeechEncoder +from vencoder.whisper.audio import log_mel_spectrogram, pad_or_trim +from vencoder.whisper.model import ModelDimensions, Whisper class WhisperPPG(SpeechEncoder): - def __init__(self,vec_path = "pretrain/medium.pt",device=None): + def __init__(self, vec_path="pretrain/medium.pt", device=None): + super().__init__() if device is None: self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: @@ -26,5 +27,5 @@ class WhisperPPG(SpeechEncoder): mel = log_mel_spectrogram(audio).to(self.dev) with torch.no_grad(): ppg = self.model.encoder(mel.unsqueeze(0)).squeeze().data.cpu().float().numpy() - ppg = torch.FloatTensor(ppg[:ppgln,]).to(self.dev) - return ppg[None,:,:].transpose(1, 2) + ppg = torch.FloatTensor(ppg[:ppgln, ]).to(self.dev) + return ppg[None, :, :].transpose(1, 2) diff --git a/vencoder/WhisperPPGLarge.py b/vencoder/WhisperPPGLarge.py new file mode 100644 index 0000000000000000000000000000000000000000..e1d3ea212bff50c11c2711077c67800b06318e3a --- /dev/null +++ b/vencoder/WhisperPPGLarge.py @@ -0,0 +1,31 @@ +import torch + +from vencoder.encoder import SpeechEncoder +from vencoder.whisper.audio import log_mel_spectrogram, pad_or_trim +from vencoder.whisper.model import ModelDimensions, Whisper + + +class WhisperPPGLarge(SpeechEncoder): + def __init__(self, vec_path="pretrain/large-v2.pt", device=None): + super().__init__() + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + checkpoint = torch.load(vec_path, map_location=device) + dims = ModelDimensions(**checkpoint["dims"]) + model = Whisper(dims) + model.load_state_dict(checkpoint["model_state_dict"]) + self.hidden_dim = dims + self.model = model.to(self.dev) + + def encoder(self, wav): + audio = wav + audln = audio.shape[0] + ppgln = audln // 320 + audio = pad_or_trim(audio) + mel = log_mel_spectrogram(audio).to(self.dev) + with torch.no_grad(): + ppg = self.model.encoder(mel.unsqueeze(0)).squeeze().data.cpu().float().numpy() + ppg = torch.FloatTensor(ppg[:ppgln, ]).to(self.dev) + return ppg[None, :, :].transpose(1, 2) diff --git a/vencoder/__pycache__/ContentVec256L9.cpython-38.pyc b/vencoder/__pycache__/ContentVec256L9.cpython-38.pyc index abb0a9e7ce8ac2d9626250a7184497ddbab4ee64..122d12c51400151a2e8ef413463ad3868ca9a881 100644 Binary files a/vencoder/__pycache__/ContentVec256L9.cpython-38.pyc and b/vencoder/__pycache__/ContentVec256L9.cpython-38.pyc differ diff --git a/vencoder/__pycache__/ContentVec768L12.cpython-38.pyc b/vencoder/__pycache__/ContentVec768L12.cpython-38.pyc index f3c7eaf2247f7fe3a59596cd4fb636db8ce0fec8..041200a67a51bf44ba4ba5fd424d7588a6c57887 100644 Binary files a/vencoder/__pycache__/ContentVec768L12.cpython-38.pyc and b/vencoder/__pycache__/ContentVec768L12.cpython-38.pyc differ diff --git a/vencoder/__pycache__/__init__.cpython-38.pyc b/vencoder/__pycache__/__init__.cpython-38.pyc index 1daf19bfe3df4eb6c729e1d781e43ce73b6c298a..45580b22e4f17be4f7f545a52b2010933ff5cecf 100644 Binary files a/vencoder/__pycache__/__init__.cpython-38.pyc and b/vencoder/__pycache__/__init__.cpython-38.pyc differ diff --git a/vencoder/__pycache__/encoder.cpython-38.pyc b/vencoder/__pycache__/encoder.cpython-38.pyc index f6d2ed9b0ee939d543a7007378dfaa2475914188..5d0092417e4f737427eeafe767f9eb5d03553e51 100644 Binary files a/vencoder/__pycache__/encoder.cpython-38.pyc and b/vencoder/__pycache__/encoder.cpython-38.pyc differ diff --git a/vencoder/dphubert/__init__.py b/vencoder/dphubert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vencoder/dphubert/components.py b/vencoder/dphubert/components.py new file mode 100644 index 0000000000000000000000000000000000000000..be5cc8ce28f11f4f1339578a9d2658740f103283 --- /dev/null +++ b/vencoder/dphubert/components.py @@ -0,0 +1,1410 @@ +"""Building blocks for speech SSL models supporting pruning. + +Originally from: +https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/components.py + +""" + +import math +from collections import defaultdict +from typing import List, Optional, Tuple + +import torch +from torch import Tensor, nn +from torch.nn import Module + +from .hardconcrete import HardConcrete +from .pruning_utils import ( + prune_conv1d_layer, + prune_layer_norm, + prune_linear_layer, +) + + +def _init_transformer_params(module): + """ + Initialize the weights of Transformer module in Wav2Vec2/HuBERT. + + If the module is ``nn.Linear``, normalize the weight with mean 0 and standard deviation 0.02. + If ``bias`` is set to ``True`` in the module, set ``bias`` to 0. + + If the module is ``nn.Embedding``, normalize the weight with mean 0 and standard deviation 0.02. + If ``padding_idx`` is not None, set the weight of padding to 0. + + Note: + Ths method corresponds to + `init_bert_params + `__ + in the original ``fairseq`` implementation. + """ + + def normal_(data): + data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class LayerNorm(nn.LayerNorm): + """Layer norm with transpose""" + + def forward(self, input: Tensor) -> Tensor: + x = input.transpose(-2, -1) + x = nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = x.transpose(-2, -1) + return x + + +class ConvLayerBlock(Module): + """Convolution unit of FeatureExtractor""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + bias: bool, + layer_norm: Optional[Module], + prune_conv_channels: bool = False, + ): + super().__init__() + self.kernel_size = kernel_size + self.stride = stride + self.layer_norm = layer_norm + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + bias=bias, + ) + + if prune_conv_channels: + self.hard_concrete = HardConcrete(n_in=out_channels, init_mean=0.01) + else: + self.hard_concrete = None + + def forward( + self, + x: Tensor, + length: Optional[Tensor], + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x (Tensor): Shape: ``[batch, in_channels, in_frame]``. + length (Tensor or None, optional): Shape ``[batch, ]``. + Returns: + Tensor: Shape ``[batch, out_channels, out_frames]``. + Optional[Tensor]: Shape ``[batch, ]``. + """ + x = self.conv(x) + if self.layer_norm is not None: + x = self.layer_norm(x) + x = nn.functional.gelu(x) + + if self.hard_concrete is not None: + channel_mask = self.hard_concrete() # hard concrete mask, (out_channels,) + x = x * channel_mask.unsqueeze(-1) + + if length is not None: + length = torch.div(length - self.kernel_size, self.stride, rounding_mode="floor") + 1 + # When input length is 0, the resulting length can be negative. So fix it here. + length = torch.max(torch.zeros_like(length), length) + return x, length + + def get_num_params_and_out_channels(self, in_channels): + if self.hard_concrete is not None: + out_channels = self.hard_concrete.l0_norm() + else: + out_channels = self.conv.out_channels + + num_params = in_channels * out_channels * self.kernel_size + if self.conv.bias is not None: + num_params += out_channels + if self.layer_norm is not None: + num_params += out_channels * 2 + + return num_params, out_channels + + +class FeatureExtractor(Module): + """Extract features from audio + + Args: + conv_layers (nn.ModuleList): + convolution layers + """ + + def __init__( + self, + conv_layers: nn.ModuleList, + ): + super().__init__() + self.conv_layers = conv_layers + + # NOTE: a dummy weight used to save the soft mask of the last conv layer + self.dummy_weight = nn.Parameter( + torch.ones(conv_layers[-1].conv.out_channels, dtype=torch.float32), + requires_grad=False + ) + + def forward( + self, + x: Tensor, + length: Optional[Tensor], + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x (Tensor): + Input Tensor representing a batch of audio, + shape: ``[batch, time]``. + length (Tensor or None, optional): + Valid length of each input sample. shape: ``[batch, ]``. + + Returns: + Tensor: + The resulting feature, shape: ``[batch, frame, feature]`` + Optional[Tensor]: + Valid length of each output sample. shape: ``[batch, ]``. + """ + if x.ndim != 2: + raise ValueError("Expected the input Tensor to be 2D (batch, time), " "but received {list(x.shape)}") + + x = x.unsqueeze(1) # (batch, channel==1, frame) + for layer in self.conv_layers: + x, length = layer(x, length) # (batch, feature, frame) + x = x.transpose(1, 2) # (batch, frame, feature) + x = x * self.dummy_weight + return x, length + + def get_num_params_and_final_out_channels(self): + in_channels = 1 + num_params = 0 + for layer in self.conv_layers: + layer_params, in_channels = layer.get_num_params_and_out_channels(in_channels) + num_params += layer_params + + num_params += in_channels # dummy weight + + return num_params, in_channels + + def prune(self): + """"Prune conv layers and dummy weight based on hardconcrete parameters. + This is an in-place operation. + """ + new_config = [] # [(output_channel, kernel_size, stride), ...] + for idx, layer in enumerate(self.conv_layers): + if layer.hard_concrete is not None: + assert not layer.hard_concrete.training + mask = layer.hard_concrete() # (out_features,) + index = mask.nonzero().squeeze(-1) # 2D -> 1D + assert len(index) > 0, f"Conv channels pruned to zero at index {idx}" + new_config.append( + (len(index), layer.kernel_size, layer.stride) + ) + + # prune the current layer + prune_conv1d_layer(layer.conv, index, "output") + if layer.layer_norm is not None: + prune_layer_norm(layer.layer_norm, index) + + # prune the next layer + if idx == len(self.conv_layers) - 1: + self.dummy_weight.data *= mask + self.dummy_weight = nn.Parameter( + self.dummy_weight.index_select(0, index).clone().detach(), requires_grad=False + ) + else: + self.conv_layers[idx+1].conv.weight.data *= mask.unsqueeze(-1) + prune_conv1d_layer(self.conv_layers[idx+1].conv, index, dim="input") + + layer.hard_concrete = None + else: + new_config.append( + (layer.conv.out_channels, layer.kernel_size, layer.stride) + ) + index = torch.arange(layer.conv.out_channels, dtype=torch.long) + + return new_config, index + + +class FeatureProjection(Module): + """Layer that connects FeatureExtractor and Encoder + + Projects features to encoder dimension. + + Args: + in_features (int): Input feature dim. + out_features (int): Output feature dim. + dropout (float): Dropout probability. + """ + + def __init__( + self, + in_features: int, + out_features: int, + dropout: float, + ): + super().__init__() + self.layer_norm = nn.LayerNorm(in_features) + self.projection = nn.Linear( + in_features, + out_features, + ) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + """ + Args: + x (Tensor): + Feature Tensor. shape: ``[batch, frame, in_feature]`` + Returns: + Tensor: Projected features. ``[batch, frame, out_feature]``. + """ + x = self.layer_norm(x) + x = self.projection(x) + x = self.dropout(x) + return x + + def get_num_params(self, in_features): + return in_features * 2 + (in_features + 1) * self.projection.out_features + + +class ConvolutionalPositionalEmbedding(Module): + """Positional embedding which is placed at the beginning of Transformer. + + Args: + embed_dim (int): Feature dimension of the input Tensor. + kernel_size (int): The number of frames to be use. + groups (int): The number of groups in feature dimensions. + """ + + def __init__( + self, + embed_dim: int, + kernel_size: int, + groups: int, + ): + super().__init__() + self.embed_dim = embed_dim + self.kernel_size = kernel_size + self.conv = nn.Conv1d( + in_channels=embed_dim, + out_channels=embed_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + groups=groups, + ) + + self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + self.num_remove: int = 1 if kernel_size % 2 == 0 else 0 + + def __prepare_scriptable__(self): + for hook in self.conv._forward_pre_hooks.values(): + # The hook we want to remove is an instance of WeightNorm class, so + # normally we would do `if isinstance(...)` but this class is not accessible + # because of shadowing, so we check the module name directly. + # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3 + if hook.__module__ == "torch.nn.utils.weight_norm" and hook.__class__.__name__ == "WeightNorm": + torch.nn.utils.remove_weight_norm(self.conv) + return self + + def forward(self, x): + """ + Args: + x (Tensor): shape ``[batch, frame, feature]``. + + Returns: + Tensor: The resulting feature. Shape ``[batch, frame, feature]``. + """ + x = x.transpose(-2, -1) + x = self.conv(x) + if self.num_remove > 0: + x = x[..., : -self.num_remove] + x = torch.nn.functional.gelu(x) + x = x.transpose(-2, -1) + return x + + +class SelfAttention(Module): + """Multihead Self Attention module + + Args: + embed_dim (int): Total dimension of the model. + num_heads (int): The number of heads. + dropout (float, optional): + Dropout probability on attn_output_weights. Default: ``0.0`` + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + head_dim: int, + dropout: float = 0.0, + prune_heads: bool = False, # whether to prune attention heads + prune_layer: bool = False, # whether to prune entire attention layers + ): + super().__init__() + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.dropout = torch.nn.Dropout(dropout) + + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=True) + self.v_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=True) + self.q_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=True) + self.out_proj = nn.Linear(num_heads * head_dim, embed_dim, bias=True) + + if prune_heads: + self.hard_concrete_for_heads = HardConcrete(n_in=num_heads, init_mean=0.01) + else: + self.hard_concrete_for_heads = None + + if prune_layer: + self.hard_concrete_for_layer = HardConcrete(n_in=1, init_mean=0.01) + else: + self.hard_concrete_for_layer = None + + def forward( + self, + x: Tensor, + attention_mask: Optional[Tensor] = None, + position_bias: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x (Tensor): shape: ``[batch_size, sequence_length, embed_dim]``. + attention_mask (Tensor or ``None``, optional): + shape: ``[batch_size, 1, sequence_length, sequence_length]`` + position_bias: Not used. Only for the compatibility with :py:class:`WavLMSelfAttention`. + key_padding_mask (Tensor or ``None``): Not used. Only for the compatibility with + :py:class:`WavLMSelfAttention`. + Returns: + (Tensor, ``None``): The resulting attention output and ``None`` (necessary for compatibility + with :py:class:`WavLMSelAttention`). + Attention output shape: ``[batch, sequence_length, embed_dim]``. + """ + if x.ndim != 3 or x.shape[2] != self.embed_dim: + raise ValueError( + f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). " f"Found {x.shape}." + ) + batch_size, length, embed_dim = x.size() + + shape = (batch_size, length, self.num_heads, self.head_dim) + q = self.q_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd + k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1) # B, nH, Hd, L + v = self.v_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd + + # scale down q to avoid value overflow. + weights = (self.scaling * q) @ k # B, nH, L, L + if attention_mask is not None: + weights += attention_mask + # subtracting a constant value from the tensor won't change the output of softmax. + # apply the subtraction to avoid value overflow in torch.nn.functional.softmax. + # for more details, please see Equation 7 in https://arxiv.org/abs/2112.08778 + weights = weights - weights.max(dim=-1, keepdim=True)[0] + + weights = torch.nn.functional.softmax(weights, dim=-1) + weights = self.dropout(weights) + + output = weights @ v # B, nH, L, Hd + + if self.hard_concrete_for_heads is not None: + head_mask = self.hard_concrete_for_heads() # (nH,) + output = output * head_mask.unsqueeze(-1).unsqueeze(-1) + + output = output.transpose(2, 1).reshape(batch_size, length, self.num_heads * self.head_dim) + + output = self.out_proj(output) + + if self.hard_concrete_for_layer is not None: + layer_mask = self.hard_concrete_for_layer() # (1,) + output = output * layer_mask + + return output, None # Necessary for compatibility with WavLMSelAttention + + def get_num_params(self): + if self.hard_concrete_for_heads is not None: + num_heads = self.hard_concrete_for_heads.l0_norm() + else: + num_heads = self.num_heads + num_params = (self.embed_dim + 1) * num_heads * self.head_dim * 3 \ + + (num_heads * self.head_dim + 1) * self.embed_dim + + if self.hard_concrete_for_layer is not None: + num_params *= self.hard_concrete_for_layer.l0_norm() + + return num_params + + def prune(self): + new_config = { + "use_attention": True, + "num_heads": self.num_heads, + } + if self.hard_concrete_for_layer is not None: + assert not self.hard_concrete_for_layer.training + layer_mask = self.hard_concrete_for_layer() # (1,) + self.out_proj.weight.data *= layer_mask + self.out_proj.bias.data *= layer_mask + if layer_mask == 0: + new_config["use_attention"] = False + self.hard_concrete_for_layer = None + + if self.hard_concrete_for_heads is not None: + assert not self.hard_concrete_for_heads.training + head_mask = self.hard_concrete_for_heads() # (num_heads,) + new_config["num_heads"] = len(head_mask.nonzero()) + if new_config["num_heads"] == 0: + new_config["use_attention"] = False + else: + full_mask = head_mask.repeat_interleave(self.head_dim) + full_index = full_mask.nonzero().squeeze(-1) # 1D + + prune_linear_layer(self.k_proj, full_index, "output") + prune_linear_layer(self.v_proj, full_index, "output") + prune_linear_layer(self.q_proj, full_index, "output") + + self.out_proj.weight.data *= full_mask + prune_linear_layer(self.out_proj, full_index, "input") + self.hard_concrete_for_heads = None + + return new_config + + +class WavLMSelfAttention(SelfAttention): + """Multi-headed self-attention for WavLM model :cite:`chen2022wavlm`. + + Args: + embed_dim (int): Total dimension of the model. + num_heads (int): The number of heads. + dropout (float, optional): Dropout probability on attn_output_weights. (Default: to ``0.0``) + bias (bool, optional): If ``True``, add bias to input / output projection layers. (Default: ``True``) + has_relative_attention_bias (bool, optional): If ``True``, apply relative position embedding. + Necessary in the first encoder layer, but not in the subsequent ones. (Default: ``False``) + num_buckets (int, optional): Number of buckets for relative position embedding. (Default: ``32``) + max_distance (int, optional): Naximum distance for relative position embedding. (Default: ``128``) + gru_rel_pos (bool, optional): If ``True``, apply gated relative position embedding. (Default: ``False``) + """ + + def __init__( + self, + embed_dim: int, + total_num_heads: int, + remaining_heads: Optional[List[int]] = None, + dropout: float = 0.0, + bias: bool = True, + has_relative_attention_bias: bool = False, + num_buckets: int = 32, + max_distance: int = 128, + gru_rel_pos: bool = True, + prune_heads: bool = False, + prune_layer: bool = False, + ): + self.total_num_heads = total_num_heads + if remaining_heads is None: + self.remaining_heads = list(range(total_num_heads)) + else: + self.remaining_heads = remaining_heads # list of indices + + self.head_dim = embed_dim // total_num_heads + + super().__init__(embed_dim, len(self.remaining_heads), self.head_dim, dropout, prune_heads, prune_layer) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + + if has_relative_attention_bias: + self.rel_attn_embed = nn.Embedding(num_buckets, total_num_heads) + else: + self.rel_attn_embed = None + + # override linear layers to customize bias + self.k_proj = nn.Linear(embed_dim, len(self.remaining_heads) * self.head_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, len(self.remaining_heads) * self.head_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, len(self.remaining_heads) * self.head_dim, bias=bias) + self.out_proj = nn.Linear(len(self.remaining_heads) * self.head_dim, embed_dim, bias=bias) + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8) + self.gru_rel_pos_const = nn.Parameter(torch.ones(1, total_num_heads, 1, 1)) + self.has_position_bias = True + + def compute_bias(self, query_length: int, key_length: int) -> Tensor: + """Compute relative position embeddings for WavLM model. + Args: + query_length (int): Query position can take values between 0 and ``query_length - 1``. + key_length (int): Key position can take values between 0 and ``key_length - 1``. + Returns: + Tensor of shape `(num_heads, query_length, key_length)`, relative positions embeddings + """ + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position # Shape (query_length, key_length) + relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True) + relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device) + values = self.rel_attn_embed(relative_position_bucket) # Shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]) + return values + + def _relative_positions_bucket(self, relative_positions: Tensor, bidirectional: bool = True): + """Compute relative position buckets for WavLM model. Computation similar to formula (5) in WavLM + paper :cite:`chen2022wavlm`. + Args: + relative_positions (Tensor): Relative offsets between query and key positions, + of shape ``(query_length, key_length)``. + bidirectional (bool): If ``True``, values will be filled both above and below the diagonal in the resulting + matrix. If ``False``, the elements above the diagonal (i.e. with negative relative offsets) will be set + to zero. (Default ``True``) + Returns: + Tensor of shape ``(query_length, key_length)`` filled bucketed values of with relative positions. + """ + num_buckets = self.num_buckets + max_distance = self.max_distance + # Shape (query_length, key_length) + relative_buckets = torch.zeros_like(relative_positions, dtype=torch.long) + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def forward( + self, + query: Tensor, + attention_mask: Optional[Tensor] = None, + position_bias: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + query (Tensor): Input of shape ``(batch_size, src_len, embed_dim)``. + key_padding_mask (Tensor or None, optional): Mask to exclude keys that are pads, of shape + `(batch, src_len)`, where padding elements are indicated by 1s. (Default: ``None``) + attn_mask: Needs to be ``None``. The argument exists for compatibility with + ``EncoderLayer``. (Default: ``None``) + position_bias (Tensor or None, optional): Position bias of shape + ``(batch_size * num_heads, src_len, src_len)``. When used inside WavLM model encoder, will be + generated in the first layer and then passed from each encoder layer to the next one. + (Default: ``None``) + Returns: + attn_output (Tensor): Attention output of shape ``(batch_size, src_len, embed_dim)``. + position_bias (Tensor or None): Position bias of shape ``(batch_size * num_heads, src_len, src_len)``. + """ + bsz, seq_len, embed_dim = query.size() + assert embed_dim == self.embed_dim + assert key_padding_mask is None + + # only for the first layer + if self.rel_attn_embed is not None and position_bias is None: + position_bias = self.compute_bias(seq_len, seq_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.total_num_heads, seq_len, seq_len) + + attn_mask_rel_pos: Optional[Tensor] = None + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos: # Apply gating on relative position bias + query_layer = query.view(bsz, seq_len, self.total_num_heads, -1) + query_layer = query_layer.permute(0, 2, 1, 3) + + gate_a, gate_b = torch.sigmoid( + self.gru_rel_pos_linear(query_layer).view(bsz, self.total_num_heads, seq_len, 2, 4).sum(-1, keepdim=False) + ).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.total_num_heads, -1, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view((-1, seq_len, seq_len)) + attn_mask_rel_pos = attn_mask_rel_pos.reshape(bsz, self.total_num_heads, seq_len, seq_len)[:, self.remaining_heads, :, :] + + attn_mask = attn_mask_rel_pos + if attention_mask is not None: + attn_mask = attn_mask + attention_mask + if key_padding_mask is not None: + attn_mask = attn_mask.masked_fill( + key_padding_mask.reshape(bsz, 1, 1, seq_len), + float("-inf") + ) + attn_output, _ = super().forward(query, attention_mask=attn_mask) + + return attn_output, position_bias + + def prune(self): + new_config = { + "use_attention": True, + "remaining_heads": self.remaining_heads, + } + if self.hard_concrete_for_layer is not None: + assert not self.hard_concrete_for_layer.training + layer_mask = self.hard_concrete_for_layer() # (1,) + self.out_proj.weight.data *= layer_mask + self.out_proj.bias.data *= layer_mask + if layer_mask == 0: + new_config["use_attention"] = False + self.hard_concrete_for_layer = None + + if self.hard_concrete_for_heads is not None: + assert not self.hard_concrete_for_heads.training + head_mask = self.hard_concrete_for_heads() # (num_heads,) + new_config["remaining_heads"] = head_mask.nonzero().squeeze(-1).tolist() + if len(new_config["remaining_heads"]) == 0: + new_config["use_attention"] = False + else: + full_mask = head_mask.repeat_interleave(self.head_dim) + full_index = full_mask.nonzero().squeeze(-1) # 1D + + prune_linear_layer(self.k_proj, full_index, "output") + prune_linear_layer(self.v_proj, full_index, "output") + prune_linear_layer(self.q_proj, full_index, "output") + + self.out_proj.weight.data *= full_mask + prune_linear_layer(self.out_proj, full_index, "input") + self.hard_concrete_for_heads = None + + return new_config + + +class FeedForward(Module): + """Layer that follows attention layer in encoder layer.""" + + def __init__( + self, + io_features: int, + intermediate_features: int, + intermediate_dropout: float, + output_dropout: float, + prune_intermediate: bool = False, + prune_layer: bool = False, + ): + super().__init__() + self.intermediate_dense = nn.Linear(io_features, intermediate_features) + self.intermediate_dropout = nn.Dropout(intermediate_dropout) + self.output_dense = nn.Linear(intermediate_features, io_features) + self.output_dropout = nn.Dropout(output_dropout) + + if prune_intermediate: + self.hard_concrete_for_intermediate = HardConcrete( + n_in=intermediate_features, init_mean=0.5 + ) + else: + self.hard_concrete_for_intermediate = None + + if prune_layer: + self.hard_concrete_for_layer = HardConcrete(n_in=1, init_mean=0.01) + else: + self.hard_concrete_for_layer = None + + def forward(self, x): + """ + Args: + x (Tensor): shape: `(batch, sequence_length, io_features)` + Returns: + x (Tensor): shape: `(batch, sequence_length, io_features)` + """ + x = self.intermediate_dense(x) + x = torch.nn.functional.gelu(x) + x = self.intermediate_dropout(x) + + if self.hard_concrete_for_intermediate is not None: + intermediate_mask = self.hard_concrete_for_intermediate() # (intermediate_features,) + x = x * intermediate_mask + + x = self.output_dense(x) + x = self.output_dropout(x) + + if self.hard_concrete_for_layer is not None: + layer_mask = self.hard_concrete_for_layer() # (1,) + x = x * layer_mask + + return x + + def get_num_params(self): + io_features = self.intermediate_dense.in_features + if self.hard_concrete_for_intermediate is not None: + intermediate_features = self.hard_concrete_for_intermediate.l0_norm() + else: + intermediate_features = self.intermediate_dense.out_features + num_params = (io_features + 1) * intermediate_features + (intermediate_features + 1) * io_features + + if self.hard_concrete_for_layer is not None: + num_params *= self.hard_concrete_for_layer.l0_norm() + + return num_params + + def prune(self): + new_config = { + "use_feed_forward": True, + "ff_interm_features": self.intermediate_dense.out_features + } + if self.hard_concrete_for_layer is not None: + assert not self.hard_concrete_for_layer.training + layer_mask = self.hard_concrete_for_layer() + self.output_dense.weight.data *= layer_mask + self.output_dense.bias.data *= layer_mask + if layer_mask == 0: + new_config["use_feed_forward"] = False + self.hard_concrete_for_layer = None + + if self.hard_concrete_for_intermediate is not None: + assert not self.hard_concrete_for_intermediate.training + interm_mask = self.hard_concrete_for_intermediate() + interm_index = interm_mask.nonzero().squeeze(-1) # NOTE: must specify dim=-1 + new_config["ff_interm_features"] = len(interm_index) + if new_config["ff_interm_features"] == 0: + new_config["use_feed_forward"] = False + else: + prune_linear_layer(self.intermediate_dense, interm_index, "output") + + self.output_dense.weight.data *= interm_mask + prune_linear_layer(self.output_dense, interm_index, "input") + self.hard_concrete_for_intermediate = None + + return new_config + + +class EncoderLayer(Module): + """A layer unit in encoder. Combines multihead self attention and feed forward.""" + + def __init__( + self, + attention: Optional[Module], # can be None if the entire layer is pruned + dropout: float, + layer_norm_first: bool, + feed_forward: Optional[Module], # can be None if the entire layer is pruned + embed_dim: int, + ): + super().__init__() + self.attention = attention + self.dropout = nn.Dropout(dropout) + self.layer_norm = nn.LayerNorm(embed_dim) + self.layer_norm_first = layer_norm_first + self.feed_forward = feed_forward + self.final_layer_norm = nn.LayerNorm(embed_dim) + self.embed_dim = embed_dim + + def forward( + self, + x: Tensor, + attention_mask: Optional[Tensor] = None, + position_bias: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x (Tensor): Input of shape ``(batch, sequence_length, embed_dim)``. + attention_mask (Tensor or ``None``, optional): attention mask + of shape ``(batch, 1, sequence_length, sequence_length)``. (Default: ``None``) + position_bias (Tensor or ``None``, optional): position bias of shape + ``(batch_size * num_heads, src_len, src_len)``. + Only necessary for WavLM model, ``None`` otherwise. (Default: ``None``) + key_padding_mask (Tensor or ``None``, optional): key padding mask of shape ``(batch_size, src_len)``. + Only used for WavLM model, ignored otherwise. (Default: ``None``) + Returns: + (x, position_bias): Shapes are the same as in the input. Position bias is only relevant for WaLM model, + ``None`` otherwise. + """ + if self.attention is not None: + residual = x + + if self.layer_norm_first: + x = self.layer_norm(x) + + x, position_bias = self.attention( + x, attention_mask=attention_mask, position_bias=position_bias, key_padding_mask=key_padding_mask + ) + + x = self.dropout(x) + x = residual + x + + if self.layer_norm_first: + if self.feed_forward is not None: + x = x + self.feed_forward(self.final_layer_norm(x)) + else: + # NOTE: for post norm, the layer norms should always be applied even if the layers are pruned. + x = self.layer_norm(x) + if self.feed_forward is not None: + x = x + self.feed_forward(x) + x = self.final_layer_norm(x) + return x, position_bias + + def get_num_params(self): + num_params = self.embed_dim * 2 * 2 # two layer norms + if self.attention is not None: + num_params += self.attention.get_num_params() + if self.feed_forward is not None: + num_params += self.feed_forward.get_num_params() + return num_params + + +class Transformer(Module): + def __init__( + self, + pos_conv_embed: Module, + dropout: float, + layers: Module, + layer_norm_first: bool, + layer_drop: float, + ): + super().__init__() + self.pos_conv_embed = pos_conv_embed + self.layer_norm = nn.LayerNorm(pos_conv_embed.embed_dim) + self.layer_norm_first = layer_norm_first + self.layer_drop = layer_drop + self.dropout = nn.Dropout(dropout) + self.layers = layers + + def _preprocess(self, x: Tensor): + x = x + self.pos_conv_embed(x) + + if self.layer_norm_first: + x = self.layer_norm(x) + + x = self.dropout(x) + return x + + def forward( + self, + x: Tensor, + attention_mask: Optional[Tensor] = None, + position_bias: Optional[Tensor] = None, + ) -> Tensor: + x = self._preprocess(x) + for layer in self.layers: + if not (self.training and torch.rand(1).item() <= self.layer_drop): + x, position_bias = layer(x, attention_mask, position_bias=position_bias) + + if not self.layer_norm_first: + x = self.layer_norm(x) + return x + + def get_intermediate_outputs( + self, + x: Tensor, + attention_mask: Optional[Tensor] = None, + num_layers: Optional[int] = None, + position_bias: Optional[Tensor] = None, + ) -> List[Tensor]: + if num_layers is not None: + if not 0 < num_layers <= len(self.layers): + raise ValueError(f"`num_layers` must be between [1, {len(self.layers)}]") + + ret: List[Tensor] = [] + x = self._preprocess(x) + for layer in self.layers: + x, position_bias = layer(x, attention_mask, position_bias=position_bias) + ret.append(x) + if num_layers is not None and len(ret) >= num_layers: + return ret + return ret + + def get_num_params(self): + # pos_conv_embed and layer_norm + num_params = sum(p.numel() for p in self.pos_conv_embed.parameters()) + self.pos_conv_embed.embed_dim * 2 + for layer in self.layers: + num_params += layer.get_num_params() + return num_params + + def prune(self): + new_config = defaultdict(list) + for layer in self.layers: + attention_config = layer.attention.prune() + new_config["use_attention"].append(attention_config["use_attention"]) + if "remaining_heads" in attention_config: + new_config["remaining_heads"].append(attention_config["remaining_heads"]) + else: + new_config["num_heads"].append(attention_config["num_heads"]) + + if not attention_config["use_attention"]: + layer.attention = None + + ff_config = layer.feed_forward.prune() + new_config["use_feed_forward"].append(ff_config["use_feed_forward"]) + new_config["ff_interm_features"].append(ff_config["ff_interm_features"]) + if not ff_config["use_feed_forward"]: + layer.feed_forward = None + + return new_config + + +class Encoder(Module): + def __init__( + self, + feature_projection: Module, + transformer: Module, + ): + super().__init__() + self.feature_projection = feature_projection + self.transformer = transformer + + def _preprocess( + self, + features: Tensor, + lengths: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + x = self.feature_projection(features) + + mask: Optional[Tensor] = None + if lengths is not None: + batch_size, max_len, _ = x.shape + # create mask for padded elements and zero-out them + mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None] + x[mask] = 0.0 + # extend the mask to attention shape and set weight + mask = -10000.0 * mask[:, None, None, :].to(dtype=features.dtype) + mask = mask.expand(batch_size, 1, max_len, max_len) + return x, mask + + def forward( + self, + features: Tensor, + lengths: Optional[Tensor] = None, + ) -> Tensor: + x, mask = self._preprocess(features, lengths) + x = self.transformer(x, attention_mask=mask) + return x + + def extract_features( + self, + features: Tensor, + lengths: Optional[Tensor] = None, + num_layers: Optional[int] = None, + ) -> List[Tensor]: + x, masks = self._preprocess(features, lengths) + interm = self.transformer.get_intermediate_outputs(x, attention_mask=masks, num_layers=num_layers) + return [x] + interm + + def get_num_params(self, in_features): + """Calculate the current model size.""" + feature_projection_size = self.feature_projection.get_num_params(in_features) + transformer_size = self.transformer.get_num_params() + return feature_projection_size + transformer_size + + def prune(self, conv_out_index): + """In-place pruning of submodules.""" + prune_layer_norm(self.feature_projection.layer_norm, conv_out_index) + prune_linear_layer(self.feature_projection.projection, conv_out_index, "input") + transformer_config = self.transformer.prune() + return transformer_config + + +################################################################################ +def _get_feature_extractor( + norm_mode: str, + shapes: List[Tuple[int, int, int]], + bias: bool, + prune_conv_channels: bool = False, +) -> FeatureExtractor: + """ + Args: + norm_mode (str): + Either "group_norm" or "layer_norm". + If "group_norm", then a single normalization is applied + in the first convolution block. Otherwise, all the convolution + blocks will have layer normalization. + This option corresponds to "extractor_mode" from fairseq. + Expected values are "group_norm" for Base arch, and + "layer_norm" for Large arch. + shapes (list of tuple of int): + Configuration of convolution layers. List of convolution configuration, + i.e. ``[(output_channel, kernel_size, stride), ...]`` + This option corresponds to "conv_feature_layers" from fairseq. + Expected values are + ``[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2`` + for all the architectures. + bias (bool): + Whether to include bias term to each convolution operation. + This option corresponds to "conv_bias" from fairseq. + Expected values are False for Base arch, and True for Large arch. + + See Also: + * Original implementation + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L666-L733 + * "extractor_mode" + - Def and base: + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L38-L45 + - Large: + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L52 + * "conv_feature_layers" + - Def, base and large: + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L94-L100 + * "conv_bias" + - Def and base: + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L101-L103 + - Large: + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L61 + """ + if norm_mode not in ["group_norm", "layer_norm"]: + raise ValueError("Invalid norm mode") + blocks = [] + in_channels = 1 + for i, (out_channels, kernel_size, stride) in enumerate(shapes): + normalization = None + if norm_mode == "group_norm" and i == 0: + normalization = nn.GroupNorm( + num_groups=out_channels, + num_channels=out_channels, + affine=True, + ) + elif norm_mode == "layer_norm": + normalization = LayerNorm( + normalized_shape=out_channels, + elementwise_affine=True, + ) + blocks.append( + ConvLayerBlock( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + bias=bias, + layer_norm=normalization, + prune_conv_channels=prune_conv_channels, + ) + ) + in_channels = out_channels + return FeatureExtractor(nn.ModuleList(blocks)) + + +def _get_encoder( + in_features: int, + embed_dim: int, + dropout_input: float, + pos_conv_kernel: int, + pos_conv_groups: int, + num_layers: int, + use_attention: List[bool], + use_feed_forward: List[bool], + num_heads: List[int], + head_dim: int, + attention_dropout: float, + ff_interm_features: List[int], + ff_interm_dropout: float, + dropout: float, + layer_norm_first: bool, + layer_drop: float, + prune_attention_heads: bool = False, + prune_attention_layer: bool = False, + prune_feed_forward_intermediate: bool = False, + prune_feed_forward_layer: bool = False, +) -> Encoder: + """ + Args: + in_features (int): The number of input features. + embed_dim (int): + The dimension of embedding. + This option corresponds to "encoder_embed_dim" from fairseq. + Expected values are 768 for Base arch, and 1024 for Large arch. + dropout_input (float): + The dropout probability applied after the input feature is projected + to ``embed_dim``. + This option corresponds to "dropout_input" from fairseq. + Expected values are 0.1 for both Base and Large arch. + pos_conv_kernel (int): + The kernel size of convolutional positional embeddings. + This option corresponds to "conv_pos" from fairseq. + Expected values are 128 for both Base and Large arch. + pos_conv_groups (int): + The number of groups of convolutional positional embeddings. + This option corresponds to "conv_pos_groups" from fairseq. + Expected values are 16 for both Base and Large arch. + num_layers (int): + The number of self attention layers in transformer block. + This option corresponds to "encoder_layers" from fairseq. + Expected values are 12 for Base and 24 for Large arch. + num_heads (int): + The number of heads in self attention layers. + This option corresponds to "encoder_attention_heads" from fairseq. + Expected values are 12 for Base and 16 for Large arch. + attention_dropout (float): + The dropout probability applied after softmax in self-attention layer. + This option corresponds to "attention_dropout" from fairseq. + Expected values are 0.1 for Base and 0.0 for Large arch. + ff_interm_features (int): + The dimension of hidden features in feed forward layer. + This option corresponds to "encoder_ffn_embed_dim" from fairseq. + Expected values are 3072 for Base and 4096 for Large arch. + ff_interm_dropout (float): + The dropout probability applied in feedforward layer. + This option correspinds to "activation_dropout" from fairseq. + Expected values are 0.1 for both Base and Large arch. + dropout (float): + The dropout probability applied at the end of feed forward layer. + This option corresponds to "dropout" from fairseq. + Expected values are 0.1 for Base and 0.0 for Large arch. + layer_norm_first (bool): + Control the order of layer norm in transformer layer and each encoder layer. + If True, in transformer layer, layer norm is applied before features are fed + to encoder layers. In encoder layer, two layer norms are applied before and after + self attention. + If False, in transformer layer, layer norm is applied after features are fed + to encoder layers. In encoder layer, two layer norms are applied after self + attention, before and after feed forward. + This option corresponds to "layer_norm_first" from fairseq. + Expected values are False for Base and True for Large arch. + layer_drop (float): + Probability to drop each encoder layer during training. + This option corresponds to "layerdrop" from fairseq. + Expected values are 0.1 for both Base and Large arch. + + See Also: + * "encoder_embed_dim" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L49-L51 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L64 + * "dropout_input" + - Def, base and large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L75-L78 + * "conv_pos" + - Def, base and large + NOTE: The description is wrong. + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L204-L207 + - Usage + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L756 + * "conv_pos_groups" + - Def, base and large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L208-L211 + * "encoder_layers" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L46-L48 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L63 + * "encoder_attention_heads" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L55-L57 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L66 + * "attention_dropout" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L66-L68 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L60 + * "encoder_ffn_embed_dim" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L52-L54 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L65 + * "activation_dropout" + - Def + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L69-L71 + - Base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L55 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L55 + * "dropout" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L63-L65 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L59 + * "layer_norm_first" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L91-L93 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L53 + * "layerdrop" + - Def + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L72-L74 + - Base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L54 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L54 + """ + feature_projection = FeatureProjection(in_features, embed_dim, dropout_input) + pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups) + + # Original impl + # https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782 + encoder_layers = nn.ModuleList() + for idx in range(num_layers): + if use_attention[idx]: + attention = SelfAttention( + embed_dim=embed_dim, + num_heads=num_heads[idx], + head_dim=head_dim, + dropout=attention_dropout, + prune_heads=prune_attention_heads, + prune_layer=prune_attention_layer, + ) + else: + attention = None + if use_feed_forward[idx]: + feed_forward = FeedForward( + io_features=embed_dim, + intermediate_features=ff_interm_features[idx], + intermediate_dropout=ff_interm_dropout, + output_dropout=dropout, + prune_intermediate=prune_feed_forward_intermediate, + prune_layer=prune_feed_forward_layer, + ) + else: + feed_forward = None + encoder_layers.append( + EncoderLayer( + attention=attention, + dropout=dropout, + layer_norm_first=layer_norm_first, + feed_forward=feed_forward, + embed_dim=embed_dim, + ) + ) + transformer = Transformer( + pos_conv_embed=pos_conv, + dropout=dropout, + layers=encoder_layers, + layer_norm_first=not layer_norm_first, + layer_drop=layer_drop, + ) + return Encoder(feature_projection, transformer) + + +def _get_wavlm_encoder( + in_features: int, + embed_dim: int, + dropout_input: float, + pos_conv_kernel: int, + pos_conv_groups: int, + num_layers: int, + use_attention: List[bool], + use_feed_forward: List[bool], + total_num_heads: List[int], + remaining_heads: List[List[int]], + num_buckets: int, + max_distance: int, + attention_dropout: float, + ff_interm_features: List[int], + ff_interm_dropout: float, + dropout: float, + layer_norm_first: bool, + layer_drop: float, + prune_attention_heads: bool = False, + prune_attention_layer: bool = False, + prune_feed_forward_intermediate: bool = False, + prune_feed_forward_layer: bool = False, +) -> Encoder: + """ + Construct encoder for WavLM model :cite:`chen2022wavlm`. The structure of the encoder and most of the argments are + the same as in :py:func:`_get_encoder` so refer there for documentation. The only difference from Wav2Vec2 encoder + is usage of `WavLMSelfAttention` instead of `SelfAttention` and two additional parameters: `num_buckets` and + `max_distance`. + Args: + in_features (int): See :py:func:`_get_encoder`. + embed_dim (int): See :py:func:`_get_encoder`. + dropout_input (float): See :py:func:`_get_encoder`. + pos_conv_kernel (int): See :py:func:`_get_encoder`. + pos_conv_groups (int): See :py:func:`_get_encoder`. + num_layers (int): See :py:func:`_get_encoder`. + num_heads (int): See :py:func:`_get_encoder`. + num_buckets (int): Number of buckets for relative position embedding. + max_distance (int): Maximum distance for relative position embedding. + attention_dropout (float): See :py:func:`_get_encoder`. + ff_interm_features (int): See :py:func:`_get_encoder`. + ff_interm_dropout (float): See :py:func:`_get_encoder`. + dropout (float): See :py:func:`_get_encoder`. + layer_norm_first (bool): See :py:func:`_get_encoder`. + layer_drop (float): See :py:func:`_get_encoder`. + + """ + feature_projection = FeatureProjection(in_features, embed_dim, dropout_input) + pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups) + + # Original impl + # https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782 + encoder_layers = nn.ModuleList() + for i in range(num_layers): + if use_attention[i]: + attention = WavLMSelfAttention( + embed_dim=embed_dim, + total_num_heads=total_num_heads[i], + remaining_heads=remaining_heads[i], + dropout=attention_dropout, + has_relative_attention_bias=(i == 0), # Position embedding is only necessary in the first layer. + num_buckets=num_buckets, + max_distance=max_distance, + prune_heads=prune_attention_heads, + prune_layer=prune_attention_layer, + ) + else: + attention = None + if use_feed_forward[i]: + feed_forward = FeedForward( + io_features=embed_dim, + intermediate_features=ff_interm_features[i], + intermediate_dropout=ff_interm_dropout, + output_dropout=dropout, + prune_intermediate=prune_feed_forward_intermediate, + prune_layer=prune_feed_forward_layer, + ) + else: + feed_forward = None + encoder_layers.append( + EncoderLayer( + attention=attention, + dropout=dropout, + layer_norm_first=layer_norm_first, + feed_forward=feed_forward, + embed_dim=embed_dim, + ) + ) + transformer = Transformer( + pos_conv_embed=pos_conv, + dropout=dropout, + layers=encoder_layers, + layer_norm_first=not layer_norm_first, + layer_drop=layer_drop, + ) + return Encoder(feature_projection, transformer) + + +def _get_padding_mask(input: Tensor, lengths: Tensor) -> Tensor: + """Generate the padding mask given the padded input and the lengths Tensors. + Args: + input (Tensor): The padded Tensor of dimension `[batch, max_len, frequency]`. + lengths (Tensor): The lengths Tensor of dimension `[batch,]`. + + Returns: + (Tensor): The padding mask. + """ + batch_size, max_len, _ = input.shape + mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None] + return mask + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None diff --git a/vencoder/dphubert/hardconcrete.py b/vencoder/dphubert/hardconcrete.py new file mode 100644 index 0000000000000000000000000000000000000000..468a30d1eccdf20ee7493e71792c46e48449c4e3 --- /dev/null +++ b/vencoder/dphubert/hardconcrete.py @@ -0,0 +1,122 @@ +"""Implementation of the hard Concrete distribution. + +Originally from: +https://github.com/asappresearch/flop/blob/master/flop/hardconcrete.py + +""" + +import math + +import torch +import torch.nn as nn + + +class HardConcrete(nn.Module): + """A HarcConcrete module. + Use this module to create a mask of size N, which you can + then use to perform L0 regularization. + + To obtain a mask, simply run a forward pass through the module + with no input data. The mask is sampled in training mode, and + fixed during evaluation mode, e.g.: + + >>> module = HardConcrete(n_in=100) + >>> mask = module() + >>> norm = module.l0_norm() + """ + + def __init__( + self, + n_in: int, + init_mean: float = 0.5, + init_std: float = 0.01, + temperature: float = 2/3, # from CoFi + stretch: float = 0.1, + eps: float = 1e-6 + ) -> None: + """Initialize the HardConcrete module. + Parameters + ---------- + n_in : int + The number of hard concrete variables in this mask. + init_mean : float, optional + Initial drop rate for hard concrete parameter, + by default 0.5., + init_std: float, optional + Used to initialize the hard concrete parameters, + by default 0.01. + temperature : float, optional + Temperature used to control the sharpness of the + distribution, by default 1.0 + stretch : float, optional + Stretch the sampled value from [0, 1] to the interval + [-stretch, 1 + stretch], by default 0.1. + """ + super().__init__() + + self.n_in = n_in + self.limit_l = -stretch + self.limit_r = 1.0 + stretch + self.log_alpha = nn.Parameter(torch.zeros(n_in)) + self.beta = temperature + self.init_mean = init_mean + self.init_std = init_std + self.bias = -self.beta * math.log(-self.limit_l / self.limit_r) + + self.eps = eps + self.compiled_mask = None + self.reset_parameters() + + def reset_parameters(self): + """Reset the parameters of this module.""" + self.compiled_mask = None + mean = math.log(1 - self.init_mean) - math.log(self.init_mean) + self.log_alpha.data.normal_(mean, self.init_std) + + def l0_norm(self) -> torch.Tensor: + """Compute the expected L0 norm of this mask. + Returns + ------- + torch.Tensor + The expected L0 norm. + """ + return (self.log_alpha + self.bias).sigmoid().sum() + + def forward(self) -> torch.Tensor: + """Sample a hard concrete mask. + Returns + ------- + torch.Tensor + The sampled binary mask + """ + if self.training: + # Reset the compiled mask + self.compiled_mask = None + # Sample mask dynamically + u = self.log_alpha.new(self.n_in).uniform_(self.eps, 1 - self.eps) + s = torch.sigmoid((torch.log(u / (1 - u)) + self.log_alpha) / self.beta) + s = s * (self.limit_r - self.limit_l) + self.limit_l + mask = s.clamp(min=0., max=1.) + + else: + # Compile new mask if not cached + if self.compiled_mask is None: + # Get expected sparsity + expected_num_zeros = self.n_in - self.l0_norm().item() + num_zeros = round(expected_num_zeros) + # Approximate expected value of each mask variable z; + # We use an empirically validated magic number 0.8 + soft_mask = torch.sigmoid(self.log_alpha / self.beta * 0.8) + # Prune small values to set to 0 + _, indices = torch.topk(soft_mask, k=num_zeros, largest=False) + soft_mask[indices] = 0. + self.compiled_mask = soft_mask + mask = self.compiled_mask + + return mask + + def extra_repr(self) -> str: + return str(self.n_in) + + def __repr__(self) -> str: + return "{}({})".format(self.__class__.__name__, self.extra_repr()) diff --git a/vencoder/dphubert/model.py b/vencoder/dphubert/model.py new file mode 100644 index 0000000000000000000000000000000000000000..348ede2c3edc3e5588ee75760085dee9eafd9d68 --- /dev/null +++ b/vencoder/dphubert/model.py @@ -0,0 +1,966 @@ +"""Speech SSL models supporting pruning. + +Originally from: +https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/model.py + +""" + +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Module + +from . import components + + +class Wav2Vec2Model(Module): + """Acoustic model used in *wav2vec 2.0* :cite:`baevski2020wav2vec`. + + Note: + To build the model, please use one of the factory functions. + :py:func:`wav2vec2_model`, :py:func:`wav2vec2_base`, :py:func:`wav2vec2_large`, + :py:func:`wav2vec2_large_lv60k`, :py:func:`hubert_base`, :py:func:`hubert_large`, + and :py:func:`hubert_xlarge`. + + See Also: + * :class:`torchaudio.pipelines.Wav2Vec2Bundle`: Pretrained models (without fine-tuning) + * :class:`torchaudio.pipelines.Wav2Vec2ASRBundle`: ASR pipelines with pretrained models. + + Args: + feature_extractor (torch.nn.Module): + Feature extractor that extracts feature vectors from raw audio Tensor. + + encoder (torch.nn.Module): + Encoder that converts the audio features into the sequence of probability + distribution (in negative log-likelihood) over labels. + + aux (torch.nn.Module or None, optional): + Auxiliary module. If provided, the output from encoder is passed to this module. + """ # noqa: E501 + + def __init__( + self, + normalize_waveform: bool, + feature_extractor: Module, + encoder: Module, + aux: Optional[Module] = None, + ): + super().__init__() + self.normalize_waveform = normalize_waveform + self.feature_extractor = feature_extractor + self.encoder = encoder + self.aux = aux + + @torch.jit.export + def extract_features( + self, + waveforms: Tensor, + lengths: Optional[Tensor] = None, + num_layers: Optional[int] = None, + ) -> Tuple[List[Tensor], Optional[Tensor]]: + """Extract feature vectors from raw waveforms + + This returns the list of outputs from the intermediate layers of + transformer block in encoder. + + Args: + waveforms (Tensor): Audio tensor of shape `(batch, frames)`. + lengths (Tensor or None, optional): + Indicates the valid length of each audio in the batch. + Shape: `(batch, )`. + When the ``waveforms`` contains audios with different durations, + by providing ``lengths`` argument, the model will compute + the corresponding valid output lengths and apply proper mask in + transformer attention layer. + If ``None``, it is assumed that the entire audio waveform + length is valid. + num_layers (int or None, optional): + If given, limit the number of intermediate layers to go through. + Providing `1` will stop the computation after going through one + intermediate layers. If not given, the outputs from all the + intermediate layers are returned. + + Returns: + (List[Tensor], Optional[Tensor]): + List of Tensors + Features from requested layers. + Each Tensor is of shape: `(batch, time frame, feature dimension)` + Tensor or None + If ``lengths`` argument was provided, a Tensor of shape `(batch, )` + is returned. + It indicates the valid length in time axis of each feature Tensor. + """ + if self.normalize_waveform: + if lengths is not None: + waveforms = [ + F.layer_norm(wave[:length], (length,)) for wave, length in zip(waveforms, lengths) + ] + waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True) + else: + waveforms = F.layer_norm(waveforms, waveforms.shape[-1:]) + + x, lengths = self.feature_extractor(waveforms, lengths) + x = self.encoder.extract_features(x, lengths, num_layers) # (num_layers+1,), including the input + return x, lengths + + def get_num_params(self): + """Calculate the current size.""" + feature_extractor_size, encoder_in_features = self.feature_extractor.get_num_params_and_final_out_channels() + encoder_size = self.encoder.get_num_params(encoder_in_features) + return feature_extractor_size + encoder_size + + def prune(self): + self.eval() # must be in eval mode + conv_config, conv_out_index = self.feature_extractor.prune() # [(output_channel, kernel_size, stride), ...] + transformer_config = self.encoder.prune(conv_out_index) # NOTE: this is a defaultdict(list) + use_attention = transformer_config["use_attention"] + use_feed_forward = transformer_config["use_feed_forward"] + num_heads = transformer_config["num_heads"] # can be [] + remaining_heads = transformer_config["remaining_heads"] # can be [] + ff_interm_features = transformer_config["ff_interm_features"] + + return conv_config, use_attention, use_feed_forward, num_heads, remaining_heads, ff_interm_features + + def forward( + self, + waveforms: Tensor, + lengths: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Compute the sequence of probability distribution over labels. + + Args: + waveforms (Tensor): Audio tensor of shape `(batch, frames)`. + lengths (Tensor or None, optional): + Indicates the valid length of each audio in the batch. + Shape: `(batch, )`. + When the ``waveforms`` contains audios with different durations, + by providing ``lengths`` argument, the model will compute + the corresponding valid output lengths and apply proper mask in + transformer attention layer. + If ``None``, it is assumed that all the audio in ``waveforms`` + have valid length. Default: ``None``. + + Returns: + (Tensor, Optional[Tensor]): + Tensor + The sequences of probability distribution (in logit) over labels. + Shape: `(batch, frames, num labels)`. + Tensor or None + If ``lengths`` argument was provided, a Tensor of shape `(batch, )` + is returned. + It indicates the valid length in time axis of the output Tensor. + """ + if self.normalize_waveform: + if lengths is not None: + waveforms = [ + F.layer_norm(wave[:length], (length,)) for wave, length in zip(waveforms, lengths) + ] + waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True) + else: + waveforms = F.layer_norm(waveforms, waveforms.shape[-1:]) + + x, lengths = self.feature_extractor(waveforms, lengths) + x = self.encoder(x, lengths) + if self.aux is not None: + x = self.aux(x) + return x, lengths + + +def wav2vec2_model(**configs) -> Wav2Vec2Model: + """Wraps the original wav2vec2_model and wavlm_model.""" + + if "encoder_remaining_heads" in configs: + return wavlm_model(**configs) + + return wav2vec2_model_original(**configs) + + +def wav2vec2_model_original( + extractor_mode: str, + extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], + extractor_conv_bias: bool, + encoder_embed_dim: int, + encoder_projection_dropout: float, + encoder_pos_conv_kernel: int, + encoder_pos_conv_groups: int, + encoder_num_layers: int, + encoder_use_attention: List[bool], + encoder_use_feed_forward: List[bool], + encoder_num_heads: List[int], + encoder_head_dim: int, + encoder_attention_dropout: float, + encoder_ff_interm_features: List[int], + encoder_ff_interm_dropout: float, + encoder_dropout: float, + encoder_layer_norm_first: bool, + encoder_layer_drop: float, + aux_num_out: Optional[int], + normalize_waveform: bool, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds custom :class:`~torchaudio.models.Wav2Vec2Model`. + + Note: + The "feature extractor" below corresponds to + `ConvFeatureExtractionModel `__ + in the original ``fairseq`` implementation. + This is referred as "(convolutional) feature encoder" in the *wav2vec 2.0* + :cite:`baevski2020wav2vec` paper. + + The "encoder" below corresponds to `TransformerEncoder `__, + and this is referred as "Transformer" in the paper. + + Args: + extractor_mode (str): Operation mode of feature extractor. + Valid values are ``"group_norm"`` or ``"layer_norm"``. + If ``"group_norm"``, then a single normalization is applied + in the first convolution block. Otherwise, all the convolution + blocks will have layer normalization. + + This option corresponds to ``extractor_mode`` from ``fairseq``. + extractor_conv_layer_config (list of integer tuples or None): + Configuration of convolution layers in feature extractor. + List of convolution configuration, + i.e. ``[(output_channel, kernel_size, stride), ...]`` + + If ``None`` is provided, then the following default value is used. + + .. code-block:: python + + [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ] + + This option corresponds to ``conv_feature_layers`` from ``fairseq``. + + extractor_conv_bias (bool): + Whether to include bias term to each convolution operation. + + This option corresponds to ``conv_bias`` from ``fairseq``. + + encoder_embed_dim (int): + The dimension of embedding in encoder. + + This option corresponds to ``encoder_embed_dim`` from ``fairseq``. + + encoder_projection_dropout (float): + The dropout probability applied after the input feature is projected + to ``encoder_embed_dim``. + + This option corresponds to ``dropout_input`` from ``fairseq``. + + encoder_pos_conv_kernel (int): + The kernel size of convolutional positional embeddings. + + This option corresponds to ``conv_pos`` from ``fairseq``. + + encoder_pos_conv_groups (int): + The number of groups of convolutional positional embeddings. + + This option corresponds to ``conv_pos_groups`` from ``fairseq``. + + encoder_num_layers (int): + The number of self attention layers in transformer block. + + This option corresponds to ``encoder_layers`` from ``fairseq``. + + encoder_num_heads (int): + The number of heads in self attention layers. + + This option corresponds to ``encoder_attention_heads`` from ``fairseq``. + + encoder_attention_dropout (float): + The dropout probability applied after softmax in self-attention layer. + + This option corresponds to ``attention_dropout`` from ``fairseq``. + + encoder_ff_interm_features (int): + The dimension of hidden features in feed forward layer. + + This option corresponds to ``encoder_ffn_embed_dim`` from ``fairseq``. + + encoder_ff_interm_dropout (float): + The dropout probability applied in feedforward layer. + + This option correspinds to ``activation_dropout`` from ``fairseq``. + + encoder_dropout (float): + The dropout probability applied at the end of feed forward layer. + + This option corresponds to ``dropout`` from ``fairseq``. + + encoder_layer_norm_first (bool): + Control the order of layer norm in transformer layer and each encoder layer. + If True, in transformer layer, layer norm is applied before features are fed + to encoder layers. In encoder layer, two layer norms are applied before and after + self attention. + If False, in transformer layer, layer norm is applied after features are fed + to encoder layers. In encoder layer, two layer norms are applied after self + attention, before and after feed forward. + + This option corresponds to ``layer_norm_first`` from ``fairseq``. + + encoder_layer_drop (float): + Probability to drop each encoder layer during training. + + This option corresponds to ``layerdrop`` from ``fairseq``. + + aux_num_out (int or None): + When provided, attach an extra linear layer on top of encoder, which can be + used for fine-tuning. + + Returns: + Wav2Vec2Model: + The resulting model. + """ # noqa: E501 + if extractor_conv_layer_config is None: + extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2 + + feature_extractor = components._get_feature_extractor( + extractor_mode, extractor_conv_layer_config, extractor_conv_bias, + prune_conv_channels=extractor_prune_conv_channels, + ) + encoder = components._get_encoder( + in_features=extractor_conv_layer_config[-1][0], + embed_dim=encoder_embed_dim, + dropout_input=encoder_projection_dropout, + pos_conv_kernel=encoder_pos_conv_kernel, + pos_conv_groups=encoder_pos_conv_groups, + num_layers=encoder_num_layers, + use_attention=encoder_use_attention, + use_feed_forward=encoder_use_feed_forward, + num_heads=encoder_num_heads, + head_dim=encoder_head_dim, + attention_dropout=encoder_attention_dropout, + ff_interm_features=encoder_ff_interm_features, + ff_interm_dropout=encoder_ff_interm_dropout, + dropout=encoder_dropout, + layer_norm_first=encoder_layer_norm_first, + layer_drop=encoder_layer_drop, + prune_attention_heads=encoder_prune_attention_heads, + prune_attention_layer=encoder_prune_attention_layer, + prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + aux = None + if aux_num_out is not None: + aux = torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out) + return Wav2Vec2Model(normalize_waveform, feature_extractor, encoder, aux) + + +def wav2vec2_base( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.1, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds "base" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ # noqa: E501 + return wav2vec2_model( + extractor_mode="group_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=768, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=12, + encoder_num_heads=12, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=3072, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=False, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + extractor_prune_conv_channels=extractor_prune_conv_channels, + encoder_prune_attention_heads=encoder_prune_attention_heads, + encoder_prune_attention_layer=encoder_prune_attention_layer, + encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + + +def wav2vec2_large( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.1, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds "large" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ # noqa: E501 + return wav2vec2_model( + extractor_mode="group_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=1024, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=24, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=4096, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=False, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + extractor_prune_conv_channels=extractor_prune_conv_channels, + encoder_prune_attention_heads=encoder_prune_attention_heads, + encoder_prune_attention_layer=encoder_prune_attention_layer, + encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + + +def wav2vec2_large_lv60k( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.0, + encoder_ff_interm_dropout: float = 0.1, + encoder_dropout: float = 0.0, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds "large lv-60k" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ # noqa: E501 + return wav2vec2_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=True, + encoder_embed_dim=1024, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=24, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=4096, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + extractor_prune_conv_channels=extractor_prune_conv_channels, + encoder_prune_attention_heads=encoder_prune_attention_heads, + encoder_prune_attention_layer=encoder_prune_attention_layer, + encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + + +def hubert_base( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.05, + aux_num_out: Optional[int] = None, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds "base" :class:`HuBERT ` from *HuBERT* :cite:`hsu2021hubert` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ # noqa: E501 + return wav2vec2_model( + extractor_mode="group_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=768, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=12, + encoder_use_attention=[True] * 12, + encoder_use_feed_forward=[True] * 12, + encoder_num_heads=[12] * 12, + encoder_head_dim=64, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=[3072] * 12, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=False, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + extractor_prune_conv_channels=extractor_prune_conv_channels, + encoder_prune_attention_heads=encoder_prune_attention_heads, + encoder_prune_attention_layer=encoder_prune_attention_layer, + encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + + +def hubert_large( + encoder_projection_dropout: float = 0.0, + encoder_attention_dropout: float = 0.0, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.0, + encoder_layer_drop: float = 0.0, + aux_num_out: Optional[int] = None, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds "large" :class:`HuBERT ` from *HuBERT* :cite:`hsu2021hubert` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ # noqa: E501 + return wav2vec2_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=1024, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=24, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=4096, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + extractor_prune_conv_channels=extractor_prune_conv_channels, + encoder_prune_attention_heads=encoder_prune_attention_heads, + encoder_prune_attention_layer=encoder_prune_attention_layer, + encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + + +def hubert_xlarge( + encoder_projection_dropout: float = 0.0, + encoder_attention_dropout: float = 0.0, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.0, + encoder_layer_drop: float = 0.0, + aux_num_out: Optional[int] = None, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds "extra large" :class:`HuBERT ` from *HuBERT* :cite:`hsu2021hubert` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ # noqa: E501 + return wav2vec2_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=1280, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=48, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=5120, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + extractor_prune_conv_channels=extractor_prune_conv_channels, + encoder_prune_attention_heads=encoder_prune_attention_heads, + encoder_prune_attention_layer=encoder_prune_attention_layer, + encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + + +def _init_hubert_pretrain_model(module): + if isinstance(module, components.LayerNorm): + torch.nn.init.kaiming_normal_(module.conv.weight) + elif isinstance(module, components.ConvolutionalPositionalEmbedding): + # normalize the weight to normal distribution. + std = math.sqrt(4.0 / (module.embed_dim * module.kernel_size)) + torch.nn.init.normal_(module.conv.weight, mean=0.0, std=std) + torch.nn.init.constant_(module.conv.bias, 0.0) + elif isinstance(module, components.SelfAttention): + # normalize the query, key, value, and out_proj parameters in self attention module. + torch.nn.init.xavier_uniform_(module.k_proj.weight, gain=1 / math.sqrt(2)) + torch.nn.init.xavier_uniform_(module.v_proj.weight, gain=1 / math.sqrt(2)) + torch.nn.init.xavier_uniform_(module.q_proj.weight, gain=1 / math.sqrt(2)) + torch.nn.init.xavier_uniform_(module.out_proj.weight) + torch.nn.init.constant_(module.out_proj.bias, 0.0) + elif isinstance(module, components.Transformer): + module.apply(components._init_transformer_params) + else: + pass + + +def wavlm_model( + extractor_mode: str, + extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], + extractor_conv_bias: bool, + encoder_embed_dim: int, + encoder_projection_dropout: float, + encoder_pos_conv_kernel: int, + encoder_pos_conv_groups: int, + encoder_num_layers: int, + encoder_use_attention: List[bool], + encoder_use_feed_forward: List[bool], + encoder_total_num_heads: List[int], + encoder_remaining_heads: List[List[int]], + encoder_num_buckets: int, + encoder_max_distance: int, + encoder_attention_dropout: float, + encoder_ff_interm_features: List[int], + encoder_ff_interm_dropout: float, + encoder_dropout: float, + encoder_layer_norm_first: bool, + encoder_layer_drop: float, + aux_num_out: Optional[int], + normalize_waveform: bool, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds custom WaveLM model :cite:`chen2022wavlm`. The architecture is compatible + with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output object is + :class:`~torchaudio.models.Wav2Vec2Model`. Most of the arguments have the same meaning + as in :py:func:`wav2vec2_model` so please refer there for documentation. + + Args: + extractor_mode (str): Operation mode of feature extractor. + See :py:func:`wav2vec2_model`. + + extractor_conv_layer_config (list of integer tuples or None): + See :py:func:`wav2vec2_model`. + + extractor_conv_bias (bool): + See :py:func:`wav2vec2_model`. + + encoder_embed_dim (int): + See :py:func:`wav2vec2_model`. + + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + + encoder_pos_conv_kernel (int): + See :py:func:`wav2vec2_model`. + + encoder_pos_conv_groups (int): + See :py:func:`wav2vec2_model`. + + encoder_num_layers (int): + See :py:func:`wav2vec2_model`. + + encoder_num_heads (int): + See :py:func:`wav2vec2_model`. + + encoder_num_buckets (int): + Number of buckets for relative position embedding. + encoder_max_distance (int): + Maximum distance for relative position embedding. + + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + + encoder_ff_interm_features (int): + See :py:func:`wav2vec2_model`. + + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + + encoder_layer_norm_first (bool): + See :py:func:`wav2vec2_model`. + + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + + aux_num_out (int or None): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ + if extractor_conv_layer_config is None: + extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2 + + feature_extractor = components._get_feature_extractor( + extractor_mode, extractor_conv_layer_config, extractor_conv_bias, + prune_conv_channels=extractor_prune_conv_channels, + ) + encoder = components._get_wavlm_encoder( + in_features=extractor_conv_layer_config[-1][0], + embed_dim=encoder_embed_dim, + dropout_input=encoder_projection_dropout, + pos_conv_kernel=encoder_pos_conv_kernel, + pos_conv_groups=encoder_pos_conv_groups, + num_layers=encoder_num_layers, + use_attention=encoder_use_attention, + use_feed_forward=encoder_use_feed_forward, + total_num_heads=encoder_total_num_heads, + remaining_heads=encoder_remaining_heads, + num_buckets=encoder_num_buckets, + max_distance=encoder_max_distance, + attention_dropout=encoder_attention_dropout, + ff_interm_features=encoder_ff_interm_features, + ff_interm_dropout=encoder_ff_interm_dropout, + dropout=encoder_dropout, + layer_norm_first=encoder_layer_norm_first, + layer_drop=encoder_layer_drop, + prune_attention_heads=encoder_prune_attention_heads, + prune_attention_layer=encoder_prune_attention_layer, + prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + aux = None + if aux_num_out is not None: + aux = torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out) + return Wav2Vec2Model(normalize_waveform, feature_extractor, encoder, aux) + + +def wavlm_base( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.1, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + """Builds "base" WaveLM model :cite:`chen2022wavlm`. The architecture is compatible + with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is + :class:`~torchaudio.models.Wav2Vec2Model`. + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ + return wavlm_model( + extractor_mode="group_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=768, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=12, + encoder_num_heads=12, + encoder_num_buckets=320, + encoder_max_distance=800, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=3072, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=False, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) + + +def wavlm_large( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + """Builds "large" WaveLM model :cite:`chen2022wavlm`. The architecture is compatible + with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is + :class:`~torchaudio.models.Wav2Vec2Model`. + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ + return wavlm_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=1024, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=24, + encoder_num_heads=16, + encoder_num_buckets=320, + encoder_max_distance=800, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=4096, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) diff --git a/vencoder/dphubert/pruning_utils.py b/vencoder/dphubert/pruning_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ac185980c2c3da716bf3ce402a541ffe70776acf --- /dev/null +++ b/vencoder/dphubert/pruning_utils.py @@ -0,0 +1,51 @@ +"""Utility functions for pruning.""" + +from typing import Union + +import torch +import torch.nn as nn + + +def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: str): + "Prune linear layer in place." + # NOTE: weight: (out_features, in_features), bias: (out_features,) + if dim == "input": + dim = 1 + layer.in_features = len(index) + elif dim == "output": + dim = 0 + layer.out_features = len(index) + else: + raise ValueError + + layer.weight = nn.Parameter(layer.weight.index_select(dim, index).clone().detach()) + if layer.bias is not None and dim == 0: + layer.bias = nn.Parameter(layer.bias.index_select(0, index).clone().detach()) + + +def prune_conv1d_layer(layer: nn.Conv1d, index: torch.LongTensor, dim: str): + """Prune conv1d in place.""" + # NOTE: weight: (out_channels, in_channels, kernel_size), bias: (out_channels,) + if dim == "input": + dim = 1 + layer.in_channels = len(index) + elif dim == "output": + dim = 0 + layer.out_channels = len(index) + else: + raise ValueError + + layer.weight = nn.Parameter(layer.weight.index_select(dim, index).clone().detach()) + if layer.bias is not None and dim == 0: + layer.bias = nn.Parameter(layer.bias.index_select(0, index).clone().detach()) + + +def prune_layer_norm(layernorm: Union[nn.LayerNorm, nn.GroupNorm], index: torch.LongTensor): + """Prune layer norm or group norm in place.""" + layernorm.weight = nn.Parameter(layernorm.weight.index_select(0, index).clone().detach()) + layernorm.bias = nn.Parameter(layernorm.bias.index_select(0, index).clone().detach()) + if isinstance(layernorm, nn.LayerNorm): + layernorm.normalized_shape = (len(index),) + elif isinstance(layernorm, nn.GroupNorm): + layernorm.num_groups = len(index) + layernorm.num_channels = len(index) diff --git a/vencoder/dphubert/utils/__init__.py b/vencoder/dphubert/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vencoder/dphubert/utils/import_huggingface_wavlm.py b/vencoder/dphubert/utils/import_huggingface_wavlm.py new file mode 100644 index 0000000000000000000000000000000000000000..24a3f38ae9cc08e19010b2876b19dc9082873377 --- /dev/null +++ b/vencoder/dphubert/utils/import_huggingface_wavlm.py @@ -0,0 +1,129 @@ +"""Import Hugging Face transformers's wav2vec2.0 pretrained weights to torchaudios's format. + +Originally from: +https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/utils/import_huggingface.py + +""" + +import logging +from typing import Any, Dict + +from torch.nn import Module + +from ..model import Wav2Vec2Model, wav2vec2_model, wavlm_model + +_LG = logging.getLogger(__name__) + + +def _get_config(cfg): + config = { + "extractor_mode": f"{cfg.feat_extract_norm}_norm", + "extractor_conv_layer_config": list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)), + "extractor_conv_bias": cfg.conv_bias, + "encoder_embed_dim": cfg.hidden_size, + "encoder_projection_dropout": cfg.feat_proj_dropout, + "encoder_pos_conv_kernel": cfg.num_conv_pos_embeddings, + "encoder_pos_conv_groups": cfg.num_conv_pos_embedding_groups, + "encoder_num_layers": cfg.num_hidden_layers, + "encoder_num_heads": cfg.num_attention_heads, + "encoder_attention_dropout": cfg.attention_dropout, + "encoder_ff_interm_features": cfg.intermediate_size, + "encoder_ff_interm_dropout": cfg.activation_dropout, + "encoder_dropout": cfg.hidden_dropout, + "encoder_layer_norm_first": cfg.do_stable_layer_norm, + "encoder_layer_drop": cfg.layerdrop, + } + return config + + +def _get_config_wavlm(cfg): + config = { + "extractor_mode": f"{cfg.feat_extract_norm}_norm", + "extractor_conv_layer_config": list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)), + "extractor_conv_bias": cfg.conv_bias, + "encoder_embed_dim": cfg.hidden_size, + "encoder_projection_dropout": cfg.feat_proj_dropout, + "encoder_pos_conv_kernel": cfg.num_conv_pos_embeddings, + "encoder_pos_conv_groups": cfg.num_conv_pos_embedding_groups, + "encoder_num_layers": cfg.num_hidden_layers, + "encoder_use_attention": [True] * cfg.num_hidden_layers, + "encoder_use_feed_forward": [True] * cfg.num_hidden_layers, + "encoder_total_num_heads": [cfg.num_attention_heads for _ in range(cfg.num_hidden_layers)], + "encoder_remaining_heads": [list(range(cfg.num_attention_heads)) for _ in range(cfg.num_hidden_layers)], + "encoder_num_buckets": cfg.num_buckets, + "encoder_max_distance": cfg.max_bucket_distance, + "encoder_attention_dropout": cfg.attention_dropout, + "encoder_ff_interm_features": [cfg.intermediate_size for _ in range(cfg.num_hidden_layers)], + "encoder_ff_interm_dropout": cfg.activation_dropout, + "encoder_dropout": cfg.hidden_dropout, + "encoder_layer_norm_first": cfg.do_stable_layer_norm, + "encoder_layer_drop": cfg.layerdrop, + "normalize_waveform": cfg.feat_extract_norm == "layer", + } + return config + + +def _build(config, original): + is_for_ctc = original.__class__.__name__ in ["Wav2Vec2ForCTC", "WavLMForCTC"] + if is_for_ctc: + aux_num_out = original.config.vocab_size + wav2vec2 = original.wav2vec2 + else: + _LG.warning( + "The model is not an instance of Wav2Vec2ForCTC or WavLMForCTC. " '"lm_head" module is not imported.' + ) + aux_num_out = None + wav2vec2 = original + is_wavlm = original.__class__.__name__ in ["WavLMModel", "WavLMForCTC"] + if is_wavlm: + imported = wavlm_model(**config, aux_num_out=aux_num_out) + else: + imported = wav2vec2_model(**config, aux_num_out=aux_num_out) + print(imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict(), strict=False)) + print(imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict(), strict=False)) + encoder_state_dict = wav2vec2.encoder.state_dict() + if is_wavlm: # Rename paramaters of linear transformations for compatibility with the HF model + transform_wavlm_encoder_state(encoder_state_dict, config["encoder_num_layers"]) + print(imported.encoder.transformer.load_state_dict(encoder_state_dict, strict=False)) + if is_for_ctc: + imported.aux.load_state_dict(original.lm_head.state_dict()) + return imported + + +def transform_wavlm_encoder_state(state: Dict[str, Any], encoder_num_layers: int): + """Converts WavLM encoder state from HuggingFace format. In particular, concatenates linear projection weights and + biases to align with the structure of ``torch.nn.MultiheadAttention``. + """ + pass + + +def import_huggingface_model(original: Module) -> Wav2Vec2Model: + """Builds :class:`Wav2Vec2Model` from the corresponding model object of + `Transformers `_. + + Args: + original (torch.nn.Module): An instance of ``Wav2Vec2ForCTC`` from ``transformers``. + + Returns: + Wav2Vec2Model: Imported model. + + Example + >>> from torchaudio.models.wav2vec2.utils import import_huggingface_model + >>> + >>> original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") + >>> model = import_huggingface_model(original) + >>> + >>> waveforms, _ = torchaudio.load("audio.wav") + >>> logits, _ = model(waveforms) + """ + _LG.info("Importing model.") + _LG.info("Loading model configuration.") + is_wavlm = original.__class__.__name__ in ["WavLMModel", "WavLMForCTC"] + if is_wavlm: + config = _get_config_wavlm(original.config) + else: + config = _get_config(original.config) + _LG.debug(" - config: %s", config) + _LG.info("Building model.") + imported = _build(config, original) + return imported diff --git a/vencoder/encoder.py b/vencoder/encoder.py index 670b5bb7682b16bea1644d036eddc0466cfefd9b..9ad120da34893d64b47b8ebeeaaed1f822a2e0be 100644 --- a/vencoder/encoder.py +++ b/vencoder/encoder.py @@ -1,12 +1,13 @@ class SpeechEncoder(object): - def __init__(self,vec_path = "pretrain/checkpoint_best_legacy_500.pt",device=None): - self.model = None #This is Model + def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None): + self.model = None # This is Model self.hidden_dim = 768 pass - def encoder(self,wav): - ''' - input: wav:[batchsize,signal_length] - output: embedding:[batchsize,wav_frame,hidden_dim] - ''' - pass \ No newline at end of file + + def encoder(self, wav): + """ + input: wav:[signal_length] + output: embedding:[batchsize,hidden_dim,wav_frame] + """ + pass diff --git a/vencoder/wavlm/WavLM.py b/vencoder/wavlm/WavLM.py new file mode 100644 index 0000000000000000000000000000000000000000..5a3986fdcc00033a9e8f1bfcd25df3799f40ed90 --- /dev/null +++ b/vencoder/wavlm/WavLM.py @@ -0,0 +1,741 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import logging +import math +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm + +from vencoder.wavlm.modules import ( + Fp32GroupNorm, + Fp32LayerNorm, + GLU_Linear, + GradMultiply, + MultiheadAttention, + SamePad, + TransposeLast, + get_activation_fn, + init_bert_params, +) + +logger = logging.getLogger(__name__) + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + + return mask + + +class WavLMConfig: + def __init__(self, cfg=None): + self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) + self.encoder_layers: int = 12 # num encoder layers in the transformer + + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] + self.conv_bias: bool = False # include bias in conv encoder + self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this + + self.normalize: bool = False # normalize input to have 0 mean and unit variance during training + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr) + + # masking + self.mask_length: int = 10 # mask length + self.mask_prob: float = 0.65 # probability of replacing a token with mask + self.mask_selection: str = "static" # how to choose mask length + self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh + self.no_mask_overlap: bool = False # whether to allow masks to overlap + self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # channel masking + self.mask_channel_length: int = 10 # length of the mask for features (channels) + self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0 + self.mask_channel_selection: str = "static" # how to choose mask length for channel masking + self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices + self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap + self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class WavLM(nn.Module): + def __init__( + self, + cfg: WavLMConfig, + ) -> None: + super().__init__() + logger.info(f"WavLM Config: {cfg.__dict__}") + + self.cfg = cfg + feature_enc_layers = eval(cfg.conv_feature_layers) + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def forward_padding_mask( + self, features: torch.Tensor, padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ret_layer_results: bool = False, + ): + + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + + if mask: + x, mask_indices = self.apply_mask( + features, padding_mask + ) + else: + x = features + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), bool + # mask_indices: (B, T), bool + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1 + ) + + res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} + + feature = res["features"] if ret_conv else res["x"] + if ret_layer_results: + feature = (feature, res["layer_results"]) + return feature, res["padding_mask"] + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + conv_type: str = "default" + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert (is_layer_norm and is_group_norm) is False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + self.conv_type = conv_type + if self.conv_type == "default": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + elif self.conv_type == "conv2d": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + elif self.conv_type == "custom": + in_d = 1 + idim = 80 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride, padding=1) + ) + self.conv_layers.append( + torch.nn.LayerNorm([dim, idim]) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + if (i + 1) % 2 == 0: + self.conv_layers.append( + torch.nn.MaxPool2d(2, stride=2, ceil_mode=True) + ) + idim = int(math.ceil(idim / 2)) + else: + pass + + def forward(self, x, mask=None): + + # BxT -> BxCxT + x = x.unsqueeze(1) + if self.conv_type == "custom": + for conv in self.conv_layers: + if isinstance(conv, nn.LayerNorm): + x = x.transpose(1, 2) + x = conv(x).transpose(1, 2) + else: + x = conv(x) + x = x.transpose(2, 3).contiguous() + x = x.view(x.size(0), -1, x.size(-1)) + else: + for conv in self.conv_layers: + x = conv(x) + if self.conv_type == "conv2d": + b, c, t, f = x.size() + x = x.transpose(2, 3).contiguous().view(b, c * f, t) + return x + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + if hasattr(args, "relative_position_embedding"): + self.relative_position_embedding = args.relative_position_embedding + self.num_buckets = args.num_buckets + self.max_distance = args.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + has_relative_attention_bias=(self.relative_position_embedding and i == 0), + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=args.gru_rel_pos, + ) + for i in range(args.encoder_layers) + ] + ) + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None, streaming_mask=None, layer=None): + x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None): + + if padding_mask is not None: + x[padding_mask] = 0 + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, + self_attn_mask=streaming_mask, pos_bias=pos_bias) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. + """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + ) -> None: + + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. + """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, attn, pos_bias + diff --git a/vencoder/wavlm/modules.py b/vencoder/wavlm/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..add4a1aa0042cbcbf5c3b28d4d72f017b507717d --- /dev/null +++ b/vencoder/wavlm/modules.py @@ -0,0 +1,828 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import warnings +from typing import Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn import Parameter + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None): + super().__init__() + self.deconstruct_idx = deconstruct_idx + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(-2, -1) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class Swish(nn.Module): + """Swish function + """ + + def __init__(self): + """Construct an MultiHeadedAttention object.""" + super(Swish, self).__init__() + self.act = torch.nn.Sigmoid() + + def forward(self, x): + return x * self.act(x) + + +class GLU_Linear(nn.Module): + def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): + super(GLU_Linear, self).__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + + if glu_type == "sigmoid": + self.glu_act = torch.nn.Sigmoid() + elif glu_type == "swish": + self.glu_act = Swish() + elif glu_type == "relu": + self.glu_act = torch.nn.ReLU() + elif glu_type == "gelu": + self.glu_act = torch.nn.GELU() + + if bias_in_glu: + self.linear = nn.Linear(input_dim, output_dim * 2, True) + else: + self.linear = nn.Linear(input_dim, output_dim * 2, False) + + def forward(self, x): + # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = self.linear(x) + + if self.glu_type == "bilinear": + x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2]) + else: + x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2])) + + return x + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + +def get_activation_fn(activation: str): + """Returns the activation function corresponding to `activation`""" + + if activation == "relu": + return F.relu + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + warnings.warn( + "--activation-fn=gelu_fast has been renamed to gelu_accurate" + ) + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "glu": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). + """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise( + nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = self.num_buckets + max_distance = self.max_distance + relative_buckets = 0 + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket( + relative_position, + bidirectional=True + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if self.has_relative_attention_bias and position_bias is None: + position_bias = self.compute_bias(tgt_len, src_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) + + if ( + not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + and self.q_head_dim == self.head_dim + ): + assert key is not None and value is not None + assert attn_mask is None + + attn_mask_rel_pos = None + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos: + query_layer = query.transpose(0, 1) + new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1) + query_layer = query_layer.view(*new_x_shape) + query_layer = query_layer.permute(0, 2, 1, 3) + _B, _H, _L, __ = query_layer.size() + + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len)) + k_proj_bias = self.k_proj.bias + if k_proj_bias is None: + k_proj_bias = torch.zeros_like(self.q_proj.bias) + + x, attn = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training, + # self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask_rel_pos, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + return x, attn, position_bias + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.q_head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.k_head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v, position_bias + + if position_bias is not None: + if self.gru_rel_pos == 1: + query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + position_bias = position_bias.view(attn_weights.size()) + + attn_weights = attn_weights + position_bias + + attn_weights_float = F.softmax( + attn_weights, dim=-1 + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights diff --git a/vencoder/whisper/audio.py b/vencoder/whisper/audio.py index 3bdb70ba9357e95ff05853dcc06437c3401ef3be..05890dc195a376181c21072eb0a8af24cf29928a 100644 --- a/vencoder/whisper/audio.py +++ b/vencoder/whisper/audio.py @@ -1,4 +1,3 @@ -import os from functools import lru_cache from typing import Union @@ -6,11 +5,10 @@ import ffmpeg import numpy as np import torch import torch.nn.functional as F +from librosa.filters import mel as librosa_mel_fn from .utils import exact_div -from librosa.filters import mel as librosa_mel_fn - # hard-coded audio hyperparameters SAMPLE_RATE = 16000 N_FFT = 400 diff --git a/vencoder/whisper/decoding.py b/vencoder/whisper/decoding.py index 603546d4c9ff67514d2567576935b974fe373bef..45e50b1c33c2c8f9ca6572e6175b8d6051ae02ee 100644 --- a/vencoder/whisper/decoding.py +++ b/vencoder/whisper/decoding.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -32,7 +32,7 @@ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) if tokenizer is None: tokenizer = get_tokenizer(model.is_multilingual) if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence: - raise ValueError(f"This model doesn't have language tokens so it can't perform lang id") + raise ValueError("This model doesn't have language tokens so it can't perform lang id") single = mel.ndim == 2 if single: diff --git a/vencoder/whisper/model.py b/vencoder/whisper/model.py index cb3781c17a1e78a33bf62246e5134e8512206d0d..f3de4d32cb9646964074401aad176dbef9ef2125 100644 --- a/vencoder/whisper/model.py +++ b/vencoder/whisper/model.py @@ -1,14 +1,13 @@ from dataclasses import dataclass -from typing import Dict -from typing import Iterable, Optional +from typing import Dict, Iterable, Optional import numpy as np import torch import torch.nn.functional as F -from torch import Tensor -from torch import nn +from torch import Tensor, nn -from .decoding import detect_language as detect_language_function, decode as decode_function +from .decoding import decode as decode_function +from .decoding import detect_language as detect_language_function @dataclass diff --git a/vencoder/whisper/tokenizer.py b/vencoder/whisper/tokenizer.py index a27cb359ee891590d3f793624f9f8ec768a26cc3..b15645dc7e15ca9f601413076299b362293eae6d 100644 --- a/vencoder/whisper/tokenizer.py +++ b/vencoder/whisper/tokenizer.py @@ -196,7 +196,7 @@ class Tokenizer: def language_token(self) -> int: """Returns the token id corresponding to the value of the `language` field""" if self.language is None: - raise ValueError(f"This tokenizer does not have language token configured") + raise ValueError("This tokenizer does not have language token configured") additional_tokens = dict( zip(