# diffsingerkr/Modules/Diffusion.py
# Repository page header: codejin, "initial commit", revision 67d041f
import torch
import math
from argparse import Namespace
from typing import Optional, List, Dict, Union
from tqdm import tqdm
from .Layer import Conv1d, Lambda
class Diffusion(torch.nn.Module):
    '''
    Gaussian diffusion process wrapped around a ``Denoiser`` network.

    A linear beta schedule with ``hp.Diffusion.Max_Step`` steps is built at
    construction time and the derived coefficients are stored as buffers.
    The module offers a training pass (noise a target at a random step and let
    the denoiser predict the noise), full ancestral sampling (``Sampling``),
    and shortened DDIM sampling (``DDIM``).
    '''
    def __init__(
        self,
        hyper_parameters: Namespace
        ):
        '''
        hyper_parameters: reads Feature_Type, Sound.Mel_Dim or Sound.N_FFT,
            and Diffusion.Max_Step; the whole namespace is forwarded to Denoiser.

        Raises ValueError when Feature_Type is neither 'Mel' nor 'Spectrogram'.
        '''
        super().__init__()
        self.hp = hyper_parameters

        if self.hp.Feature_Type == 'Mel':
            self.feature_size = self.hp.Sound.Mel_Dim
        elif self.hp.Feature_Type == 'Spectrogram':
            self.feature_size = self.hp.Sound.N_FFT // 2 + 1
        else:
            # Without this branch feature_size stayed undefined and the failure
            # surfaced much later, at sampling time.
            raise ValueError(f'Unsupported feature type: "{self.hp.Feature_Type}"')

        self.denoiser = Denoiser(
            hyper_parameters= self.hp
            )

        self.timesteps = self.hp.Diffusion.Max_Step
        betas = torch.linspace(1e-4, 0.06, self.timesteps)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim= 0)
        alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('alphas_cumprod', alphas_cumprod)  # [Diffusion_t]
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)  # [Diffusion_t]
        self.register_buffer('sqrt_alphas_cumprod', alphas_cumprod.sqrt())
        self.register_buffer('sqrt_one_minus_alphas_cumprod', (1.0 - alphas_cumprod).sqrt())
        self.register_buffer('sqrt_recip_alphas_cumprod', (1.0 / alphas_cumprod).sqrt())
        self.register_buffer('sqrt_recipm1_alphas_cumprod', (1.0 / alphas_cumprod - 1.0).sqrt())

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        # The posterior variance is 0 at the beginning of the diffusion chain,
        # so it is clipped before taking the logarithm.
        self.register_buffer('posterior_log_variance', posterior_variance.clamp(min= 1e-20).log())
        self.register_buffer('posterior_mean_coef1', betas * alphas_cumprod_prev.sqrt() / (1.0 - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2', (1.0 - alphas_cumprod_prev) * alphas.sqrt() / (1.0 - alphas_cumprod))

    def forward(
        self,
        encodings: torch.Tensor,
        features: Optional[torch.Tensor]= None
        ):
        '''
        encodings: [Batch, Enc_d, Enc_t]
        features: [Batch, Feature_d, Feature_t] or None

        With features (training) returns (None, noises, epsilons): the noise
        that was injected and the denoiser's prediction of it.
        Without features (inference) returns (features, None, None), where the
        features come from the full ancestral sampling loop.
        '''
        if features is not None:    # train
            diffusion_steps = torch.randint(
                low= 0,
                high= self.timesteps,
                size= (encodings.size(0),),
                dtype= torch.long,
                device= encodings.device
                )    # random single step
            noises, epsilons = self.Get_Noise_Epsilon_for_Train(
                features= features,
                encodings= encodings,
                diffusion_steps= diffusion_steps,
                )
            return None, noises, epsilons

        # inference
        features = self.Sampling(
            encodings= encodings,
            )
        return features, None, None

    def Sampling(
        self,
        encodings: torch.Tensor,
        ):
        '''
        Run the whole reverse chain (every step from timesteps - 1 down to 0),
        starting from standard normal noise.

        encodings: [Batch, Enc_d, Enc_t]
        returns: [Batch, Feature_d, Enc_t]
        '''
        features = torch.randn(
            size= (encodings.size(0), self.feature_size, encodings.size(2)),
            device= encodings.device
            )
        for diffusion_step in reversed(range(self.timesteps)):
            features = self.P_Sampling(
                features= features,
                encodings= encodings,
                diffusion_steps= torch.full(
                    size= (encodings.size(0), ),
                    fill_value= diffusion_step,
                    dtype= torch.long,
                    device= encodings.device
                    ),
                )

        return features

    def P_Sampling(
        self,
        features: torch.Tensor,
        encodings: torch.Tensor,
        diffusion_steps: torch.Tensor,
        ):
        '''
        Draw one reverse-step sample from the posterior at diffusion_steps.

        features: [Batch, Feature_d, Feature_t]
        encodings: [Batch, Enc_d, Feature_t]
        diffusion_steps: [Batch]
        '''
        posterior_means, posterior_log_variances = self.Get_Posterior(
            features= features,
            encodings= encodings,
            diffusion_steps= diffusion_steps,
            )

        noises = torch.randn_like(features)    # [Batch, Feature_d, Feature_t]
        # No noise is added at step 0, so the final output is the posterior mean.
        masks = (diffusion_steps > 0).float().unsqueeze(1).unsqueeze(1)    # [Batch, 1, 1]

        return posterior_means + masks * (0.5 * posterior_log_variances).exp() * noises

    def Get_Posterior(
        self,
        features: torch.Tensor,
        encodings: torch.Tensor,
        diffusion_steps: torch.Tensor
        ):
        '''
        Compute the mean and log-variance of q(x_{t-1} | x_t, x_0), where x_0 is
        estimated from the denoiser's noise prediction.

        returns: (posterior_means, posterior_log_variances)
        '''
        noised_predictions = self.denoiser(
            features= features,
            encodings= encodings,
            diffusion_steps= diffusion_steps
            )

        # Estimate of the clean features (x_0) recovered from the predicted noise.
        feature_starts = \
            features * self.sqrt_recip_alphas_cumprod[diffusion_steps][:, None, None] - \
            noised_predictions * self.sqrt_recipm1_alphas_cumprod[diffusion_steps][:, None, None]
        feature_starts = feature_starts.clamp(-1.0, 1.0)    # clipped

        posterior_means = \
            feature_starts * self.posterior_mean_coef1[diffusion_steps][:, None, None] + \
            features * self.posterior_mean_coef2[diffusion_steps][:, None, None]
        posterior_log_variances = \
            self.posterior_log_variance[diffusion_steps][:, None, None]

        return posterior_means, posterior_log_variances

    def Get_Noise_Epsilon_for_Train(
        self,
        features: torch.Tensor,
        encodings: torch.Tensor,
        diffusion_steps: torch.Tensor,
        ):
        '''
        Noise the targets to the given steps and predict that noise.

        returns: (noises, epsilons), the injected noise and the denoiser output.
        '''
        noises = torch.randn_like(features)

        noised_features = \
            features * self.sqrt_alphas_cumprod[diffusion_steps][:, None, None] + \
            noises * self.sqrt_one_minus_alphas_cumprod[diffusion_steps][:, None, None]

        epsilons = self.denoiser(
            features= noised_features,
            encodings= encodings,
            diffusion_steps= diffusion_steps
            )

        return noises, epsilons

    def DDIM(
        self,
        encodings: torch.Tensor,
        ddim_steps: int,
        eta: float= 0.0,
        temperature: float= 1.0,
        use_tqdm: bool= False
        ):
        '''
        DDIM sampling over a shortened sub-sequence of the diffusion timesteps.

        encodings: [Batch, Enc_d, Enc_t]
        ddim_steps: requested number of denoising steps.
        eta: scale of the stochastic term; 0.0 gives deterministic sampling.
        temperature: multiplier applied to the per-step random noise.
        use_tqdm: show a progress bar over the denoising steps.
        returns: [Batch, Feature_d, Enc_t]
        '''
        ddim_timesteps = self.Get_DDIM_Steps(
            ddim_steps= ddim_steps
            )
        sigmas, alphas, alphas_prev = self.Get_DDIM_Sampling_Parameters(
            ddim_timesteps= ddim_timesteps,
            eta= eta
            )
        sqrt_one_minus_alphas = (1.0 - alphas).sqrt()

        features = torch.randn(
            size= (encodings.size(0), self.feature_size, encodings.size(2)),
            device= encodings.device
            )

        timesteps = ddim_timesteps.tolist()
        step_range = reversed(range(len(timesteps)))
        if use_tqdm:
            # The wrapper must be the object that is iterated, otherwise no
            # progress bar is ever shown.
            step_range = tqdm(
                step_range,
                desc= '[Diffusion]',
                total= len(timesteps)
                )

        for index in step_range:
            # The denoiser is conditioned on the real diffusion timestep, not on
            # the position inside the shortened schedule.
            noised_predictions = self.denoiser(
                features= features,
                encodings= encodings,
                diffusion_steps= torch.full(
                    size= (encodings.size(0), ),
                    fill_value= timesteps[index],
                    dtype= torch.long,
                    device= encodings.device
                    )
                )

            feature_starts = (features - sqrt_one_minus_alphas[index] * noised_predictions) / alphas[index].sqrt()
            # sqrt(1 - alpha_prev - sigma^2) * eps; clamped against tiny negative
            # values caused by floating point rounding.
            direction_pointings = \
                (1.0 - alphas_prev[index] - sigmas[index].pow(2.0)).clamp(min= 0.0).sqrt() * noised_predictions
            noises = sigmas[index] * torch.randn_like(features) * temperature

            features = alphas_prev[index].sqrt() * feature_starts + direction_pointings + noises

        return features

    # https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py
    def Get_DDIM_Steps(
        self,
        ddim_steps: int,
        ddim_discr_method: str= 'uniform'
        ):
        '''
        Pick the increasing sub-sequence of timesteps visited by DDIM.

        ddim_steps: requested number of steps; values above self.timesteps
            yield at most self.timesteps entries with the 'uniform' method.
        ddim_discr_method: 'uniform' or 'quad'.
        returns: 1D long tensor whose last entry is always self.timesteps - 1.

        Raises ValueError for ddim_steps < 1 and NotImplementedError for an
        unknown discretization method.
        '''
        if ddim_steps < 1:
            raise ValueError(f'ddim_steps must be a positive integer, but got {ddim_steps}')

        if ddim_discr_method == 'uniform':
            # A stride of 0 (ddim_steps > timesteps) is not a valid arange step.
            stride = max(self.timesteps // ddim_steps, 1)
            # When timesteps is not a multiple of ddim_steps, arange yields extra
            # entries; keep only the requested amount.
            ddim_timesteps = torch.arange(0, self.timesteps, stride).long()[:ddim_steps]
        elif ddim_discr_method == 'quad':
            ddim_timesteps = torch.linspace(0, (torch.tensor(self.timesteps) * 0.8).sqrt(), ddim_steps).pow(2.0).long()
        else:
            raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

        # Sampling always starts from the noisiest timestep.
        ddim_timesteps[-1] = self.timesteps - 1

        return ddim_timesteps

    def Get_DDIM_Sampling_Parameters(self, ddim_timesteps, eta):
        '''
        Gather the per-step DDIM coefficients for the selected timesteps.

        ddim_timesteps: 1D long tensor of increasing timesteps.
        eta: scale of the stochastic term.
        returns: (sigmas, alphas, alphas_prev), each shaped like ddim_timesteps.
        '''
        alphas = self.alphas_cumprod[ddim_timesteps.to(self.alphas_cumprod.device)]
        # alphas_prev[i] is the cumulative alpha of the timestep DDIM moves to
        # from ddim_timesteps[i]: the previous entry of the sub-sequence, and 1.0
        # (no noise left) for the very first entry. When every timestep is
        # visited this coincides with the alphas_cumprod_prev buffer.
        alphas_prev = torch.cat([torch.ones_like(alphas[:1]), alphas[:-1]])
        sigmas = eta * ((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)).sqrt()

        return sigmas, alphas, alphas_prev
class Denoiser(torch.nn.Module):
    '''
    Noise prediction network: a feature prenet, a diffusion-step embedding FFN,
    a stack of gated residual blocks whose skip outputs are summed, and an
    output projection back to the feature dimension.
    '''
    def __init__(
        self,
        hyper_parameters: Namespace
        ):
        '''
        hyper_parameters: reads Feature_Type, Sound.Mel_Dim or Sound.N_FFT,
            Encoder.Size, and Diffusion.Size / Kernel_Size / Stack.

        Raises ValueError when Feature_Type is neither 'Mel' nor 'Spectrogram'.
        '''
        super().__init__()
        self.hp = hyper_parameters

        if self.hp.Feature_Type == 'Mel':
            feature_size = self.hp.Sound.Mel_Dim
        elif self.hp.Feature_Type == 'Spectrogram':
            feature_size = self.hp.Sound.N_FFT // 2 + 1
        else:
            # Previously feature_size stayed unbound here and the constructor
            # failed with an unrelated UnboundLocalError.
            raise ValueError(f'Unsupported feature type: "{self.hp.Feature_Type}"')

        self.prenet = torch.nn.Sequential(
            Conv1d(
                in_channels= feature_size,
                out_channels= self.hp.Diffusion.Size,
                kernel_size= 1,
                w_init_gain= 'relu'
                ),
            torch.nn.Mish()
            )

        self.step_ffn = torch.nn.Sequential(
            Diffusion_Embedding(
                channels= self.hp.Diffusion.Size
                ),
            Lambda(lambda x: x.unsqueeze(2)),    # add a time axis of length 1
            Conv1d(
                in_channels= self.hp.Diffusion.Size,
                out_channels= self.hp.Diffusion.Size * 4,
                kernel_size= 1,
                w_init_gain= 'relu'
                ),
            torch.nn.Mish(),
            Conv1d(
                in_channels= self.hp.Diffusion.Size * 4,
                out_channels= self.hp.Diffusion.Size,
                kernel_size= 1,
                w_init_gain= 'linear'
                )
            )

        self.residual_blocks = torch.nn.ModuleList([
            Residual_Block(
                in_channels= self.hp.Diffusion.Size,
                kernel_size= self.hp.Diffusion.Kernel_Size,
                condition_channels= self.hp.Encoder.Size + feature_size
                )
            for _ in range(self.hp.Diffusion.Stack)
            ])

        self.projection = torch.nn.Sequential(
            Conv1d(
                in_channels= self.hp.Diffusion.Size,
                out_channels= self.hp.Diffusion.Size,
                kernel_size= 1,
                w_init_gain= 'relu'
                ),
            torch.nn.ReLU(),
            Conv1d(
                in_channels= self.hp.Diffusion.Size,
                out_channels= feature_size,
                kernel_size= 1
                ),
            )
        # The weights of the last projection start at zero.
        torch.nn.init.zeros_(self.projection[-1].weight)    # This is key factor....

    def forward(
        self,
        features: torch.Tensor,
        encodings: torch.Tensor,
        diffusion_steps: torch.Tensor
        ):
        '''
        features: [Batch, Feature_d, Feature_t]
        encodings: [Batch, Enc_d, Feature_t]
            NOTE(review): the residual blocks' condition convolution is built
            for Encoder.Size + feature_size input channels, so Enc_d presumably
            has to equal that sum - confirm against the caller.
        diffusion_steps: [Batch]
        returns: [Batch, Feature_d, Feature_t]
        '''
        x = self.prenet(features)

        diffusion_steps = self.step_ffn(diffusion_steps)    # [Batch, Res_d, 1]

        skips_list = []
        for residual_block in self.residual_blocks:
            x, skips = residual_block(
                x= x,
                conditions= encodings,
                diffusion_steps= diffusion_steps
                )
            skips_list.append(skips)

        # Only the skip outputs reach the projection; the sum is scaled by
        # 1 / sqrt(number of blocks).
        x = torch.stack(skips_list, dim= 0).sum(dim= 0) / math.sqrt(self.hp.Diffusion.Stack)
        x = self.projection(x)

        return x
class Diffusion_Embedding(torch.nn.Module):
    '''
    Sinusoidal embedding of diffusion steps.

    The first half of the output holds sines and the second half cosines of
    the step multiplied by geometrically spaced frequencies from 1 down to
    1 / 10000.
    '''
    def __init__(
        self,
        channels: int
        ):
        '''
        channels: size of the embedding produced for each step.
        '''
        super().__init__()
        self.channels = channels

    def forward(self, x: torch.Tensor):
        '''
        x: [Batch], diffusion steps
        returns: [Batch, channels]
        '''
        half_channels = self.channels // 2    # sine and cosine
        # With a single frequency (channels of 2 or 3) the original divisor
        # half_channels - 1 is zero; max(..., 1) keeps that frequency at 1.0.
        scale = math.log(10000.0) / max(half_channels - 1, 1)
        frequencies = torch.exp(torch.arange(half_channels, device= x.device) * -scale)
        embeddings = x.unsqueeze(1) * frequencies.unsqueeze(0)
        embeddings = torch.cat([embeddings.sin(), embeddings.cos()], dim= -1)
        if self.channels % 2 == 1:
            # An odd size would otherwise come out one channel short; pad the
            # last channel with zeros so the output width always is `channels`.
            embeddings = torch.nn.functional.pad(embeddings, (0, 1))

        return embeddings
class Residual_Block(torch.nn.Module):
    '''
    Gated residual block conditioned on an external sequence and on the
    diffusion-step embedding. It returns the updated residual stream together
    with a skip output.
    '''
    def __init__(
        self,
        in_channels: int,
        kernel_size: int,
        condition_channels: int
        ):
        '''
        in_channels: channel size of the residual stream and of the step embedding.
        kernel_size: kernel size of the main convolution (padding is kernel_size // 2).
        condition_channels: channel size of the conditioning sequence.
        '''
        super().__init__()
        self.in_channels = in_channels
        doubled_channels = in_channels * 2    # gate/filter pair, or residual/skip pair

        self.condition = Conv1d(
            in_channels= condition_channels,
            out_channels= doubled_channels,
            kernel_size= 1
            )
        self.diffusion_step = Conv1d(
            in_channels= in_channels,
            out_channels= in_channels,
            kernel_size= 1
            )
        self.conv = Conv1d(
            in_channels= in_channels,
            out_channels= doubled_channels,
            kernel_size= kernel_size,
            padding= kernel_size // 2
            )
        self.projection = Conv1d(
            in_channels= in_channels,
            out_channels= doubled_channels,
            kernel_size= 1
            )

    def forward(
        self,
        x: torch.Tensor,
        conditions: torch.Tensor,
        diffusion_steps: torch.Tensor
        ):
        '''
        x: residual stream, channel size in_channels
        conditions: conditioning sequence, channel size condition_channels
        diffusion_steps: step embedding, channel size in_channels
        returns: (residual output, skip output)
        '''
        projected_conditions = self.condition(conditions)
        projected_steps = self.diffusion_step(diffusion_steps)

        # The step embedding is added before the convolution, the condition after it.
        hidden = self.conv(x + projected_steps) + projected_conditions

        # Gated activation: the first half of the channels goes through a
        # sigmoid, the second half through a tanh.
        sigmoid_part, tanh_part = torch.chunk(hidden, chunks= 2, dim= 1)
        hidden = self.projection(torch.sigmoid(sigmoid_part) * torch.tanh(tanh_part))

        # The first half feeds the residual connection, the second half is the skip.
        outputs, skips = torch.chunk(hidden, chunks= 2, dim= 1)

        return (outputs + x) / math.sqrt(2.0), skips