ymzhang319's picture
init
7f2690b
raw
history blame contribute delete
No virus
66.5 kB
import torch
import torchvision
import torchaudio
import torchvision.transforms as transforms
from diffusers import UNet2DConditionModel, ControlNetModel
from foleycrafter.pipelines.pipeline_controlnet import StableDiffusionControlNetPipeline
from foleycrafter.pipelines.auffusion_pipeline import AuffusionNoAdapterPipeline, Generator
from foleycrafter.models.auffusion_unet import UNet2DConditionModel as af_UNet2DConditionModel
from diffusers.models import AutoencoderKLTemporalDecoder, AutoencoderKL
from diffusers.schedulers import EulerDiscreteScheduler, DDIMScheduler, PNDMScheduler, KarrasDiffusionSchedulers
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection,\
SpeechT5HifiGan, ClapTextModelWithProjection, RobertaTokenizer, RobertaTokenizerFast,\
CLIPTextModel, CLIPTokenizer
import glob
from moviepy.editor import ImageSequenceClip, AudioFileClip, VideoFileClip, VideoClip
from moviepy.audio.AudioClip import AudioArrayClip
import numpy as np
from safetensors import safe_open
import random
from typing import Union, Optional
import decord
import os
import os.path as osp
import imageio
import soundfile as sf
from PIL import Image, ImageOps
import torch.distributed as dist
import io
from omegaconf import OmegaConf
import json
from dataclasses import dataclass
from enum import Enum
import typing as T
import warnings
import pydub
from scipy.io import wavfile
from einops import rearrange
def zero_rank_print(s):
    """Print ``s`` prefixed with ``"### "`` from the rank-0 process only.

    When ``torch.distributed`` has not been initialised the message is always
    printed; otherwise every process whose rank is not 0 stays silent.
    """
    if dist.is_initialized() and dist.get_rank() != 0:
        return
    print("### " + s, flush=True)
def build_foleycrafter(
    pretrained_model_name_or_path: str="auffusion/auffusion-full-no-adapter",
) -> StableDiffusionControlNetPipeline:
    """Assemble a ``StableDiffusionControlNetPipeline`` from a pretrained checkpoint.

    The VAE, UNet, scheduler, tokenizer and text encoder are each loaded from
    the sub-folder of the same name under ``pretrained_model_name_or_path``.
    The ControlNet is derived from the loaded UNet with one conditioning
    channel. No feature extractor or safety checker is attached.

    Args:
        pretrained_model_name_or_path: identifier forwarded to every
            ``from_pretrained`` call.

    Returns:
        The constructed pipeline.
    """
    source = pretrained_model_name_or_path

    def _load(component_cls, subfolder):
        # every component lives in its own sub-folder of the same checkpoint
        return component_cls.from_pretrained(source, subfolder=subfolder)

    components = {
        "vae": _load(AutoencoderKL, 'vae'),
        "unet": _load(af_UNet2DConditionModel, 'unet'),
        "scheduler": _load(PNDMScheduler, 'scheduler'),
        "tokenizer": _load(CLIPTokenizer, 'tokenizer'),
        "text_encoder": _load(CLIPTextModel, 'text_encoder'),
    }
    components["controlnet"] = ControlNetModel.from_unet(
        components["unet"], conditioning_channels=1
    )

    return StableDiffusionControlNetPipeline(
        feature_extractor=None,
        safety_checker=None,
        requires_safety_checker=False,
        **components,
    )
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
    """Tile a batch of videos into one grid per time step and save it with imageio.

    Args:
        videos: tensor rearranged as ``"b c t h w"``; a 4-D tensor first gets a
            leading batch axis.
        path: output file path. Its parent directory is created when the path
            has one.
        rescale: when True, values are mapped from [-1, 1] to [0, 1] first.
        n_rows: forwarded to ``torchvision.utils.make_grid`` as ``nrow``.
        fps: frame rate handed to ``imageio.mimsave``.
    """
    if videos.ndim == 4:
        videos = videos.unsqueeze(0)
    videos = rearrange(videos, "b c t h w -> t b c h w")

    outputs = []
    for frame_batch in videos:
        grid = torchvision.utils.make_grid(frame_batch, nrow=n_rows)
        # c h w -> h w c
        grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1
        # detach + cpu so tensors on an accelerator or tracking gradients can
        # still be converted to numpy
        grid = torch.clamp(grid * 255, 0, 255).detach().cpu().numpy().astype(np.uint8)
        outputs.append(grid)

    # os.makedirs("") raises, so only create a directory when path names one
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)
def save_videos_from_pil_list(videos: list, path: str, fps=7):
    """Write a list of PIL images to ``path`` as a video through imageio.

    Every image is passed through ``ImageOps.scale(image, 255)`` before it is
    written. The processed frames are collected in a new list, so the
    caller's ``videos`` list is no longer overwritten in place.

    Args:
        videos: images to write, in playback order.
        path: output file path.
        fps: frame rate handed to ``imageio.mimwrite``.
    """
    # NOTE(review): the second argument of ImageOps.scale is a size factor, so
    # 255 enlarges each frame rather than rescaling pixel values - confirm
    # that this is really intended.
    frames = [ImageOps.scale(frame, 255) for frame in videos]
    imageio.mimwrite(path, frames, fps=fps)
def seed_everything(seed: int) -> None:
    r"""Seed the random number generators of Python, NumPy and PyTorch.

    Args:
        seed (int): The value handed to ``random.seed``, ``np.random.seed``,
            ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def get_video_frames(video: np.ndarray, num_frames: int=200):
    """Pick ``num_frames`` evenly spaced entries along the first axis of ``video``.

    The indices run from the first to the last entry (inclusive) and are
    truncated to integers, so entries repeat when ``num_frames`` exceeds the
    length of the first axis.
    """
    last_index = video.shape[0] - 1
    picks = np.linspace(0, last_index, num_frames, dtype=int)
    return video[picks, ...]
def random_audio_video_clip(audio: np.ndarray, video: np.ndarray, fps: float,
                            sample_rate: int = 16000, duration: int = 5, num_frames: int = 20):
    """
    Randomly sample an aligned audio/video clip of ``duration`` seconds.

    Args:
        audio: array whose last axis is indexed as audio samples.
        video: array whose first axis is indexed as video frames.
        fps: frame rate used to derive the media duration from the frame count.
        sample_rate: the returned audio has ``sample_rate * duration`` samples.
        duration: clip length in seconds.
        num_frames: number of video frames in the returned clip.

    Returns:
        ``(audio_clip, video_clip)`` resampled by integer index lookup.

    Raises:
        AssertionError: if the (integer) media duration is shorter than ``duration``.
    """
    video_length = video.shape[0]
    audio_length = audio.shape[-1]
    av_duration = int(video_length / fps)
    assert av_duration >= duration,\
        f"video duration {av_duration} is less than {duration}"

    # random sample start time
    start_time = random.uniform(0, av_duration - duration)
    end_time = start_time + duration
    start_frac, end_frac = start_time / av_duration, end_time / av_duration

    audio_idx = np.linspace(audio_length * start_frac, audio_length * end_frac,
                            sample_rate * duration, dtype=int)
    video_idx = np.linspace(video_length * start_frac, video_length * end_frac,
                            num_frames, dtype=int)

    # When the clip ends exactly at the end of the media the last index equals
    # the length, which is one past the final valid position: clamp it.
    audio_idx = np.minimum(audio_idx, audio_length - 1)
    video_idx = np.minimum(video_idx, video_length - 1)

    audio = audio[..., audio_idx]
    video = video[video_idx, ...]
    return audio, video
def get_full_indices(reader: Union[decord.VideoReader, decord.AudioReader]) -> np.ndarray:
    """Return the integer indices ``0 .. n-1`` covering everything in ``reader``.

    For a ``decord.VideoReader`` ``n`` is ``len(reader)``; for a
    ``decord.AudioReader`` it is the size of the last axis of ``reader.shape``.
    Any other object yields ``None``.
    """
    if isinstance(reader, decord.VideoReader):
        count = len(reader)
    elif isinstance(reader, decord.AudioReader):
        count = reader.shape[-1]
    else:
        return None
    return np.linspace(0, count - 1, count, dtype=int)
def get_frames(video_path:str, onset_list, frame_nums=1024):
    """Read the video frames that correspond to a list of onset intervals.

    Each ``(start, end)`` pair in ``onset_list`` is expressed on a scale of
    ``frame_nums`` steps and is mapped proportionally onto the frame count of
    the video at ``video_path``.

    Returns:
        The selected frames as a numpy array, in interval order.
    """
    reader = decord.VideoReader(video_path)
    total_frames = len(reader)

    selected = []
    for onset_start, onset_end in onset_list:
        first = int(onset_start / frame_nums * total_frames)
        stop = int(onset_end / frame_nums * total_frames)
        selected += range(first, stop)

    return reader.get_batch(selected).asnumpy()
def get_frames_in_video(video_path:str, onset_list, frame_nums=1024, audio_length_in_s=10):
    """Read the frames for a list of onset intervals, taking the video length into account.

    Args:
        video_path: path of the video opened with ``decord.VideoReader``.
        onset_list: iterable of ``(start, end)`` pairs on a scale of
            ``frame_nums`` steps.
        frame_nums: number of steps spanned by the onset scale.
        audio_length_in_s: audio length (in seconds) the onset scale refers to.

    Returns:
        ``(frames, video_onset_list)``: the selected frames as a numpy array,
        and one integer ``[start, end]`` pair for each interval that was kept.
    """
    # this function consider the video length
    video = decord.VideoReader(video_path)
    video_frame = len(video)
    duration = video_frame / video.get_avg_fps()
    frames_list = []
    video_onset_list = []
    for start, end in onset_list:
        if int(start / frame_nums * duration) >= audio_length_in_s:
            continue
        video_start = int(start / audio_length_in_s * duration / frame_nums * video_frame)
        if video_start >= video_frame:
            continue
        video_end = int(end / audio_length_in_s * duration / frame_nums * video_frame)
        # the start is checked against the frame count above but the end was
        # not: clamp it so no index past the last frame is requested
        video_end = min(video_end, video_frame)
        video_onset_list.append([int(start / audio_length_in_s * duration), int(end / audio_length_in_s * duration)])
        frames_list.extend(range(video_start, video_end))
    frames = video.get_batch(frames_list).asnumpy()
    return frames, video_onset_list
def save_multimodal(video, audio, output_path, audio_fps:int=16000, video_fps:int=8, remove_audio:bool=True):
    """Mux a frame sequence and an audio array into one video file.

    The audio is first written to ``audio.wav`` next to ``output_path``
    (``scipy.io.wavfile`` first, ``soundfile`` as fallback), attached to an
    ``ImageSequenceClip`` built from ``video`` and encoded to ``output_path``.

    Args:
        video: iterable of frames accepted by ``ImageSequenceClip``.
        audio: audio samples accepted by the wav writers.
        output_path: destination video file.
        audio_fps: audio sample rate.
        video_fps: video frame rate.
        remove_audio: delete the intermediate ``audio.wav`` after encoding.
    """
    imgs = list(video)
    audio_file = osp.join(osp.dirname(output_path), "audio.wav")

    try:
        wavfile.write(audio_file, audio_fps, audio)
    except Exception:
        # scipy could not write this data; retry with soundfile. A bare
        # ``except`` would also have trapped KeyboardInterrupt/SystemExit.
        sf.write(audio_file, audio, audio_fps)

    audio_clip = AudioFileClip(audio_file)
    try:
        video_clip = ImageSequenceClip(imgs, fps=video_fps)
        video_clip = video_clip.set_audio(audio_clip)
        video_clip.write_videofile(output_path, video_fps, audio=True, audio_fps=audio_fps)
    finally:
        # release the audio reader so the wav file can be removed afterwards
        audio_clip.close()

    if remove_audio:
        os.remove(audio_file)
    return
def save_multimodal_by_frame(video, audio, output_path, audio_fps:int=16000):
    """Dump a clip as ``audio.wav`` plus one JPEG per frame under ``output_path``.

    ``output_path`` is treated as a directory. It is created (together with its
    ``frames`` sub-directory) before anything is written, so the call also
    works when the directory does not exist yet.

    Args:
        video: iterable of frames; numpy arrays are converted to PIL images
            (cast to uint8), anything else must provide ``save(path)``.
        audio: samples written with ``scipy.io.wavfile.write``.
        output_path: output directory.
        audio_fps: audio sample rate.
    """
    output_dir = output_path
    frames_dir = osp.join(output_dir, 'frames')
    # create the directories first: the wav file is written into output_dir
    os.makedirs(frames_dir, exist_ok=True)

    wavfile.write(osp.join(output_dir, "audio.wav"), audio_fps, audio)

    for num, img in enumerate(video):
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img.astype(np.uint8))
        img.save(osp.join(frames_dir, f"{num}.jpg"))
    return
def sanity_check(data: dict, save_path: str="sanity_check", batch_size: int=4, sample_rate: int=16000):
    """Dump the first ``batch_size`` samples of a batch to disk for inspection.

    For every sample a ``<index>.wav`` is written under ``<save_path>/audio``
    and a muxed ``<index>.mp4`` under ``<save_path>/av``.

    Args:
        data: mapping with the keys ``'pixel_values'``, ``'audio'`` and ``'text'``.
            Pixel values are mapped with ``x / 2 + 0.5`` and clamped to [0, 1].
        save_path: root output directory.
        batch_size: number of leading samples to save.
        sample_rate: audio sample rate, used for both the wav and the mp4.
    """
    video_path = osp.join(save_path, 'video')
    audio_path = osp.join(save_path, 'audio')
    av_path = osp.join(save_path, 'av')

    video, audio, text = data['pixel_values'], data['audio'], data['text']
    video = (video / 2 + 0.5).clamp(0, 1)

    zero_rank_print(f"Saving {text} audio: {audio[0].shape} video: {video[0].shape}")

    # the output directories do not depend on the sample: create them once
    for directory in (video_path, audio_path, av_path):
        os.makedirs(directory, exist_ok=True)

    for bsz in range(batch_size):
        # swap the two audio axes before writing
        # (presumably channels-first -> samples-first; verify against caller)
        bsz_audio = audio[bsz,...].permute(1, 0).cpu().numpy()
        bsz_video = video_tensor_to_np(video[bsz, ...])
        sf.write(osp.join(audio_path, str(bsz) + '.wav'), bsz_audio, sample_rate)
        # pass the sample rate on; otherwise the mp4 is muxed at the default
        # 16 kHz even when the caller asked for a different rate
        save_multimodal(bsz_video, bsz_audio, osp.join(av_path, str(bsz) + '.mp4'), audio_fps=sample_rate)
def video_tensor_to_np(video: torch.Tensor, rescale: bool=True, scale: bool=False):
    """Convert a video tensor to a channels-last numpy array.

    Args:
        video: tensor whose axis 0 or axis 1 has size 3 (checked in that
            order); that axis is moved to the end. If neither has size 3 the
            tensor is passed through without conversion.
        rescale: multiply the result by 255.
        scale: first map values with ``x / 2 + 0.5`` and clamp to [0, 1].
    """
    if scale:
        video = (video / 2 + 0.5).clamp(0, 1)

    if video.shape[0] == 3:
        axes = (1, 2, 3, 0)  # c f h w -> f h w c
    elif video.shape[1] == 3:
        axes = (0, 2, 3, 1)  # f c h w -> f h w c
    else:
        axes = None

    if axes is not None:
        video = video.permute(*axes).detach().cpu().numpy()

    return video * 255 if rescale else video
def composite_audio_video(video: str, audio: str, path:str, video_fps:int=7, audio_sample_rate:int=16000):
    """Combine a video file and an audio file into one file at ``path``.

    Both inputs are decoded completely with decord (the audio at
    ``audio_sample_rate``) and handed to ``save_multimodal``.

    Args:
        video: path of the video file.
        audio: path of the audio file.
        path: output file path.
        video_fps: frame rate of the written video.
        audio_sample_rate: sample rate used for decoding and writing the audio.
    """
    video_reader = decord.VideoReader(video)
    audio_reader = decord.AudioReader(audio, sample_rate=audio_sample_rate)

    samples = audio_reader.get_batch(get_full_indices(audio_reader)).asnumpy()
    frames = video_reader.get_batch(get_full_indices(video_reader)).asnumpy()

    save_multimodal(frames, samples, path, audio_fps=audio_sample_rate, video_fps=video_fps)
    return
# for video pipeline
def append_dims(x, target_dims):
    """Pad ``x`` with trailing singleton axes until it has ``target_dims`` axes.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` axes.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    # Ellipsis keeps every existing axis; each None adds one new trailing axis
    index = (Ellipsis,) + tuple(None for _ in range(missing))
    return x[index]
def resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
    """Gaussian-blur ``input`` and then resize it to ``size``.

    The blur strength is derived from the per-axis resize factor of the last
    two axes of ``input``; the resize itself is done by
    ``torch.nn.functional.interpolate``.

    Args:
        input: tensor whose last two axes are height and width.
        size: target ``(height, width)``.
        interpolation: ``mode`` argument for ``interpolate``.
        align_corners: ``align_corners`` argument for ``interpolate``.
    """
    src_h, src_w = input.shape[-2:]
    ratios = (src_h / size[0], src_w / size[1])

    # Sigma per axis, following skimage:
    # https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
    sigmas = tuple(max((ratio - 1.0) / 2.0, 0.001) for ratio in ratios)

    # Kernel size: 3 sigma gives good results but is slow, Pillow uses 1 sigma
    # (in two passes); 2 sigma is used here, with a minimum of 3 taps.
    # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
    kernel = []
    for sigma in sigmas:
        taps = int(max(2.0 * 2 * sigma, 3))
        # the kernel size has to be odd
        kernel.append(taps if taps % 2 else taps + 1)
    ks = tuple(kernel)

    blurred = _gaussian_blur2d(input, ks, sigmas)
    return torch.nn.functional.interpolate(
        blurred, size=size, mode=interpolation, align_corners=align_corners
    )
def _gaussian_blur2d(input, kernel_size, sigma):
    """Blur ``input`` with a separable Gaussian (one pass along x, one along y).

    Args:
        input: tensor to blur; it is handed to ``_filter2d``.
        kernel_size: ``(ky, kx)`` kernel sizes.
        sigma: either a ``(sigma_y, sigma_x)`` tuple or a tensor whose
            columns 0 and 1 hold the y and x sigmas.
    """
    sigma = (
        torch.tensor([sigma], dtype=input.dtype)
        if isinstance(sigma, tuple)
        else sigma.to(dtype=input.dtype)
    )

    ky = int(kernel_size[0])
    kx = int(kernel_size[1])
    rows = sigma.shape[0]

    weights_x = _gaussian(kx, sigma[:, 1].view(rows, 1))
    weights_y = _gaussian(ky, sigma[:, 0].view(rows, 1))

    # horizontal pass with a 1 x kx kernel, then vertical pass with ky x 1
    blurred_x = _filter2d(input, weights_x[..., None, :])
    return _filter2d(blurred_x, weights_y[..., None])
def _filter2d(input, kernel):
    """Convolve every channel of a ``(b, c, h, w)`` tensor with ``kernel``.

    The input is reflect-padded by the amounts ``_compute_padding`` returns
    for the kernel's height and width, and the result is viewed back as
    ``(b, c, h, w)``.
    """
    b, c, h, w = input.shape

    # give the kernel a channel axis and repeat it for every input channel
    weights = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
    weights = weights.expand(-1, c, -1, -1)
    k_h, k_w = weights.shape[-2:]

    padded = torch.nn.functional.pad(input, _compute_padding([k_h, k_w]), mode="reflect")

    # one kernel per group, so the kernels are applied channel-wise
    weights = weights.reshape(-1, 1, k_h, k_w)
    groups = weights.size(0)
    padded = padded.view(-1, groups, padded.size(-2), padded.size(-1))

    convolved = torch.nn.functional.conv2d(padded, weights, groups=groups, padding=0, stride=1)
    return convolved.view(b, c, h, w)
def _gaussian(window_size: int, sigma):
if isinstance(sigma, float):
sigma = torch.tensor([[sigma]])
batch_size = sigma.shape[0]
x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
if window_size % 2 == 0:
x = x + 0.5
gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
return gauss / gauss.sum(-1, keepdim=True)
def _compute_padding(kernel_size):
"""Compute padding tuple."""
# 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
# https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
if len(kernel_size) < 2:
raise AssertionError(kernel_size)
computed = [k - 1 for k in kernel_size]
# for even kernels we need to do asymmetric padding :(
out_padding = 2 * len(kernel_size) * [0]
for i in range(len(kernel_size)):
computed_tmp = computed[-(i + 1)]
pad_front = computed_tmp // 2
pad_rear = computed_tmp - pad_front
out_padding[2 * i + 0] = pad_front
out_padding[2 * i + 1] = pad_rear
return out_padding
def print_gpu_memory_usage(info: str, cuda_id:int=0):
    """Print the reserved and allocated CUDA memory of device ``cuda_id``.

    Values are divided by ``1024 ** 3`` and labelled "total" (reserved),
    "used" (allocated) and "available" (their difference), below a header
    line containing ``info``.
    """
    print(f">>> {info} <<<")
    bytes_per_unit = 1024 ** 3
    reserved = torch.cuda.memory_reserved(cuda_id) / bytes_per_unit
    used = torch.cuda.memory_allocated(cuda_id) / bytes_per_unit
    report = (
        ("total: ", reserved),
        ("used: ", used),
        ("available: ", reserved - used),
    )
    for label, amount in report:
        print(label, amount, "G")
# use for dsp mel2spec
@dataclass(frozen=True)
class SpectrogramParams:
    """
    Settings that control the audio <-> spectrogram <-> image conversions.

    The derived STFT sizes (``n_fft``, ``win_length``, ``hop_length``) are
    exposed as properties, and ``to_exif`` turns the settings into a dict
    keyed by custom EXIF tag numbers so they can travel with a spectrogram
    image.
    """

    # Whether the audio is stereo or mono
    stereo: bool = False

    # FFT parameters
    sample_rate: int = 44100
    step_size_ms: int = 10
    window_duration_ms: int = 100
    padded_duration_ms: int = 400

    # Mel scale parameters
    num_frequencies: int = 200
    # TODO(hayk): Set these to [20, 20000] for newer models
    min_frequency: int = 0
    max_frequency: int = 10000
    mel_scale_norm: T.Optional[str] = None
    mel_scale_type: str = "htk"
    max_mel_iters: int = 200

    # Griffin Lim parameters
    num_griffin_lim_iters: int = 32

    # Image parameterization
    power_for_image: float = 0.25

    class ExifTags(Enum):
        """
        Tag numbers under which the parameters are stored in image EXIF data.
        """

        SAMPLE_RATE = 11000
        STEREO = 11005
        STEP_SIZE_MS = 11010
        WINDOW_DURATION_MS = 11020
        PADDED_DURATION_MS = 11030

        NUM_FREQUENCIES = 11040
        MIN_FREQUENCY = 11050
        MAX_FREQUENCY = 11060

        POWER_FOR_IMAGE = 11070
        MAX_VALUE = 11080

    def _samples_in(self, milliseconds) -> int:
        """Convert a duration in milliseconds to a whole number of samples."""
        return int(milliseconds / 1000.0 * self.sample_rate)

    @property
    def n_fft(self) -> int:
        """
        Samples per STFT window including padding.
        """
        return self._samples_in(self.padded_duration_ms)

    @property
    def win_length(self) -> int:
        """
        Samples per STFT window.
        """
        return self._samples_in(self.window_duration_ms)

    @property
    def hop_length(self) -> int:
        """
        Samples between the starts of consecutive STFT windows.
        """
        return self._samples_in(self.step_size_ms)

    def to_exif(self) -> T.Dict[int, T.Any]:
        """
        Build the EXIF tag-number -> value mapping for the current settings.
        """
        tags = self.ExifTags
        entries = (
            (tags.SAMPLE_RATE, self.sample_rate),
            (tags.STEREO, self.stereo),
            (tags.STEP_SIZE_MS, self.step_size_ms),
            (tags.WINDOW_DURATION_MS, self.window_duration_ms),
            (tags.PADDED_DURATION_MS, self.padded_duration_ms),
            (tags.NUM_FREQUENCIES, self.num_frequencies),
            (tags.MIN_FREQUENCY, self.min_frequency),
            (tags.MAX_FREQUENCY, self.max_frequency),
            (tags.POWER_FOR_IMAGE, float(self.power_for_image)),
        )
        return {tag.value: value for tag, value in entries}
class SpectrogramImageConverter:
    """
    Translate between audio segments and spectrogram images.

    The audio <-> spectrogram step is delegated to a ``SpectrogramConverter``;
    this class adds the spectrogram <-> image step on top of it.
    """

    def __init__(self, params: SpectrogramParams, device: str = "cuda"):
        self.p = params
        self.device = device
        self.converter = SpectrogramConverter(params=params, device=device)

    def spectrogram_image_from_audio(
        self,
        segment: pydub.AudioSegment,
    ) -> Image.Image:
        """
        Render an audio segment as a spectrogram image.

        The segment's channel count is adjusted to match ``params.stereo``
        (with a printed warning) before conversion, and the conversion
        parameters are written into the image's EXIF data.

        Args:
            segment: Audio segment whose frame rate equals ``params.sample_rate``

        Returns:
            Spectrogram image (in pillow format)
        """
        assert int(segment.frame_rate) == self.p.sample_rate, "Sample rate mismatch"

        # Decide whether the channel layout has to change, and why
        channels = segment.channels
        warning, wanted_channels = None, None
        if self.p.stereo:
            if channels == 1:
                warning, wanted_channels = "WARNING: Mono audio but stereo=True, cloning channel", 2
            elif channels > 2:
                warning, wanted_channels = "WARNING: Multi channel audio, reducing to stereo", 2
        elif channels > 1:
            warning, wanted_channels = "WARNING: Stereo audio but stereo=False, setting to mono", 1

        if wanted_channels is not None:
            print(warning)
            segment = segment.set_channels(wanted_channels)

        spectrogram = self.converter.spectrogram_from_audio(segment)
        image = image_from_spectrogram(spectrogram, power=self.p.power_for_image)

        # Record the conversion parameters, plus the spectrogram maximum, in
        # the EXIF object obtained from the image
        metadata = self.p.to_exif()
        metadata[SpectrogramParams.ExifTags.MAX_VALUE.value] = float(np.max(spectrogram))
        image.getexif().update(metadata.items())

        return image

    def audio_from_spectrogram_image(
        self,
        image: Image.Image,
        apply_filters: bool = True,
        max_value: float = 30e6,
    ) -> pydub.AudioSegment:
        """
        Turn a spectrogram image back into an audio segment.

        Args:
            image: Spectrogram image (in pillow format)
            apply_filters: Apply post-processing to improve the reconstructed audio
            max_value: Scaled max amplitude of the spectrogram. Shouldn't matter.
        """
        magnitudes = spectrogram_from_image(
            image,
            max_value=max_value,
            power=self.p.power_for_image,
            stereo=self.p.stereo,
        )
        return self.converter.audio_from_spectrogram(
            magnitudes,
            apply_filters=apply_filters,
        )
def image_from_spectrogram(spectrogram: np.ndarray, power: float = 0.25) -> Image.Image:
    """
    Compute a spectrogram image from a spectrogram magnitude array.

    This is the inverse of spectrogram_from_image, except for discretization error from
    quantizing to uint8.

    Args:
        spectrogram: (channels, frequency, time)
        power: A power curve to apply to the spectrogram to preserve contrast

    Returns:
        image: (frequency, time, channels)

    Raises:
        NotImplementedError: if the spectrogram has neither 1 nor 2 channels.
    """
    # Rescale to 0-1. An all-zero (silent) spectrogram has max 0; dividing by
    # it would turn every value into NaN, so leave the zeros as they are.
    max_value = np.max(spectrogram)
    scale = max_value if max_value != 0 else 1
    data = spectrogram / scale

    # Apply the power curve
    data = np.power(data, power)

    # Rescale to 0-255
    data = data * 255

    # Invert
    data = 255 - data

    # Convert to uint8
    data = data.astype(np.uint8)

    # Munge channels into a PIL image
    if data.shape[0] == 1:
        # TODO(hayk): Do we want to write single channel to disk instead?
        image = Image.fromarray(data[0], mode="L").convert("RGB")
    elif data.shape[0] == 2:
        # stereo: first channel is left empty, the two audio channels go to
        # the second and third image channels
        data = np.array([np.zeros_like(data[0]), data[0], data[1]]).transpose(1, 2, 0)
        image = Image.fromarray(data, mode="RGB")
    else:
        raise NotImplementedError(f"Unsupported number of channels: {data.shape[0]}")

    # Flip Y
    image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)

    return image
def spectrogram_from_image(
    image: Image.Image,
    power: float = 0.25,
    stereo: bool = False,
    max_value: float = 30e6,
) -> np.ndarray:
    """
    Recover a spectrogram magnitude array from a spectrogram image.

    Undoes image_from_spectrogram step by step (flip, channel selection,
    inversion, power curve, scaling), up to the uint8 quantization error.

    Args:
        image: (frequency, time, channels)
        power: The power curve applied to the spectrogram
        stereo: Whether the spectrogram encodes stereo data
        max_value: The max value of the original spectrogram. In practice doesn't matter.

    Returns:
        spectrogram: (channels, frequency, time)
    """
    # Palette / greyscale images are expanded to RGB first
    if image.mode in ("P", "L"):
        image = image.convert("RGB")

    # Undo the vertical flip and move the channel axis to the front
    flipped = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)
    channels_first = np.array(flipped).transpose(2, 0, 1)

    # Stereo data lives in the G and B channels (see image_from_spectrogram);
    # mono data is read from the first channel
    selected = channels_first[[1, 2], :, :] if stereo else channels_first[0:1, :, :]

    # Invert, bring back to 0-1, reverse the power curve, restore the scale
    inverted = 255 - selected.astype(np.float32)
    unit_range = inverted / 255
    return np.power(unit_range, 1 / power) * max_value
class SpectrogramConverter:
    """
    Convert between audio segments and Mel-scale amplitude spectrograms with torchaudio.

    The methods of this class exchange spectrograms shaped
    (batch, frequency, time), holding the amplitude of each Mel frequency
    bucket over time ("mel amplitudes"). The batch axis has one entry per
    audio channel; channels are processed independently.

    `spectrogram_from_audio` computes a complex STFT but returns only its
    magnitude. `audio_from_spectrogram` is an approximate inverse: it maps the
    Mel amplitudes back to a linear frequency scale and estimates the missing
    phase with the Griffin-Lim algorithm. Both steps are lossy.

    For more information, see https://pytorch.org/audio/stable/transforms.html
    """

    def __init__(self, params: SpectrogramParams, device: str = "cuda"):
        self.p = params

        self.device = check_device(device)
        if device.lower().startswith("mps"):
            warnings.warn(
                "WARNING: MPS does not support audio operations, falling back to CPU for them",
                stacklevel=2,
            )
            self.device = "cpu"

        # Settings shared by the forward STFT and by Griffin-Lim
        stft_settings = dict(
            n_fft=params.n_fft,
            hop_length=params.hop_length,
            win_length=params.win_length,
            window_fn=torch.hann_window,
            wkwargs=None,
        )
        # Settings shared by the Mel scaling and its inverse
        mel_settings = dict(
            n_mels=params.num_frequencies,
            sample_rate=params.sample_rate,
            f_min=params.min_frequency,
            f_max=params.max_frequency,
            n_stft=params.n_fft // 2 + 1,
            norm=params.mel_scale_norm,
            mel_scale=params.mel_scale_type,
        )

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.Spectrogram.html
        self.spectrogram_func = torchaudio.transforms.Spectrogram(
            pad=0,
            power=None,
            normalized=False,
            center=True,
            pad_mode="reflect",
            onesided=True,
            **stft_settings,
        ).to(self.device)

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.GriffinLim.html
        self.inverse_spectrogram_func = torchaudio.transforms.GriffinLim(
            n_iter=params.num_griffin_lim_iters,
            power=1.0,
            momentum=0.99,
            length=None,
            rand_init=True,
            **stft_settings,
        ).to(self.device)

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.MelScale.html
        self.mel_scaler = torchaudio.transforms.MelScale(**mel_settings).to(self.device)

        # https://pytorch.org/audio/stable/generated/torchaudio.transforms.InverseMelScale.html
        # max_iter, tolerance_loss, tolerance_change and sgdargs are deliberately
        # not passed; the original author marked them as only applicable to a
        # higher version of torchaudio.
        self.inverse_mel_scaler = torchaudio.transforms.InverseMelScale(**mel_settings).to(self.device)

    def spectrogram_from_audio(
        self,
        audio: pydub.AudioSegment,
    ) -> np.ndarray:
        """
        Compute the Mel amplitude spectrogram of an audio segment.

        Args:
            audio: Audio segment which must match the sample rate of the params

        Returns:
            spectrogram: (channel, frequency, time)
        """
        assert int(audio.frame_rate) == self.p.sample_rate, "Audio sample rate must match params"

        # One row of samples per channel: (batch, samples), as float32
        per_channel = [mono.get_array_of_samples() for mono in audio.split_to_mono()]
        waveform = np.array(per_channel).astype(np.float32, copy=False)

        on_device = torch.from_numpy(waveform).to(self.device)
        return self.mel_amplitudes_from_waveform(on_device).cpu().numpy()

    def audio_from_spectrogram(
        self,
        spectrogram: np.ndarray,
        apply_filters: bool = True,
    ) -> pydub.AudioSegment:
        """
        Approximately rebuild an audio segment from Mel amplitudes.

        Args:
            spectrogram: (batch, frequency, time)
            apply_filters: Post-process the result with `apply_filters_func`
                (called with compression disabled)

        Returns:
            audio: Audio segment with channels equal to the batch dimension
        """
        mel_on_device = torch.from_numpy(spectrogram).to(self.device)
        waveform = self.waveform_from_mel_amplitudes(mel_on_device)

        # normalize=True rescales the waveform to the int16 range
        segment = audio_from_waveform(
            samples=waveform.cpu().numpy(),
            sample_rate=self.p.sample_rate,
            normalize=True,
        )

        if not apply_filters:
            return segment
        return apply_filters_func(segment, compression=False)

    def mel_amplitudes_from_waveform(
        self,
        waveform: torch.Tensor,
    ) -> torch.Tensor:
        """
        Torch-only computation of Mel-scale amplitudes from a waveform.

        Args:
            waveform: (batch, samples)

        Returns:
            amplitudes_mel: (batch, frequency, time)
        """
        # complex STFT -> magnitude -> Mel scale
        magnitudes = torch.abs(self.spectrogram_func(waveform))
        return self.mel_scaler(magnitudes)

    def waveform_from_mel_amplitudes(
        self,
        amplitudes_mel: torch.Tensor,
    ) -> torch.Tensor:
        """
        Torch-only approximate reconstruction of a waveform from Mel-scale amplitudes.

        Args:
            amplitudes_mel: (batch, frequency, time)

        Returns:
            waveform: (batch, samples)
        """
        # Mel scale -> linear scale, then Griffin-Lim to estimate the phase
        linear = self.inverse_mel_scaler(amplitudes_mel)
        return self.inverse_spectrogram_func(linear)
def check_device(device: str, backup: str = "cpu") -> str:
    """
    Return ``device`` when it can be used, otherwise warn and return ``backup``.

    Only names starting with "cuda" or "mps" (case-insensitive) are checked
    for availability; every other name is returned unchanged.
    """
    lowered = device.lower()
    if lowered.startswith("cuda"):
        missing = not torch.cuda.is_available()
    elif lowered.startswith("mps"):
        missing = not torch.backends.mps.is_available()
    else:
        missing = False

    if not missing:
        return device

    warnings.warn(f"WARNING: {device} is not available, using {backup} instead.", stacklevel=3)
    return backup
def audio_from_waveform(
    samples: np.ndarray, sample_rate: int, normalize: bool = False
) -> pydub.AudioSegment:
    """
    Convert a numpy array of samples of a waveform to an audio segment.

    The samples are cast to int16, written to an in-memory WAV file and read
    back through pydub. The caller's array is never modified.

    Args:
        samples: (channels, samples) array
        sample_rate: sample rate written to the WAV header
        normalize: scale the samples so the peak magnitude becomes the int16
            maximum; an all-zero waveform is left as is
    """
    # Normalize volume to fit in int16
    if normalize:
        peak = np.max(np.abs(samples))
        # A silent waveform has peak 0: skip scaling instead of dividing by
        # zero. The multiplication builds a new array, so the input is not
        # changed in place and integer input does not fail.
        if peak > 0:
            samples = samples * (np.iinfo(np.int16).max / peak)

    # Transpose and convert to int16
    samples = samples.transpose(1, 0)
    samples = samples.astype(np.int16)

    # Write to the bytes of a WAV file
    wav_bytes = io.BytesIO()
    wavfile.write(wav_bytes, sample_rate, samples)
    wav_bytes.seek(0)

    # Read into pydub
    return pydub.AudioSegment.from_wav(wav_bytes)
def apply_filters_func(segment: pydub.AudioSegment, compression: bool = False) -> pydub.AudioSegment:
    """
    Post-process an audio segment: optional dynamic range compression, then a
    gain adjustment to -12 dBFS followed by peak normalization.

    Args:
        segment: audio to process
        compression: when True, first normalize, set the level to -10 dBFS and
            run pydub's dynamic range compressor
    """
    # TODO(hayk): Come up with a principled strategy for these filters and experiment end-to-end.
    # TODO(hayk): Is this going to make audio unbalanced between sequential clips?

    def _peak_normalize(audio):
        return pydub.effects.normalize(audio, headroom=0.1)

    if compression:
        segment = _peak_normalize(segment)
        segment = segment.apply_gain(-10 - segment.dBFS)

        # TODO(hayk): This is quite slow, ~1.7 seconds on a beefy CPU
        compressor_settings = dict(threshold=-20.0, ratio=4.0, attack=5.0, release=50.0)
        segment = pydub.effects.compress_dynamic_range(segment, **compressor_settings)

    desired_db = -12
    leveled = segment.apply_gain(desired_db - segment.dBFS)
    return _peak_normalize(leveled)
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Drop dot-separated segments from ``path``.

    A non-negative count removes that many leading segments; a negative count
    removes that many trailing segments.
    """
    segments = path.split(".")
    if n_shave_prefix_segments >= 0:
        kept = segments[n_shave_prefix_segments:]
    else:
        kept = segments[:n_shave_prefix_segments]
    return ".".join(kept)
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Map resnet parameter paths to the new naming scheme (local renaming).

    Returns a list of ``{"old": ..., "new": ...}`` dicts, one per input path.
    """
    # applied in this order to every path
    renames = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )

    mapping = []
    for old_path in old_list:
        new_path = old_path
        for before, after in renames:
            new_path = new_path.replace(before, after)
        new_path = shave_segments(new_path, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_path, "new": new_path})
    return mapping
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Map VAE resnet parameter paths to the new naming scheme (local renaming).

    Returns a list of ``{"old": ..., "new": ...}`` dicts, one per input path.
    """

    def _renamed(old_path):
        updated = old_path.replace("nin_shortcut", "conv_shortcut")
        return shave_segments(updated, n_shave_prefix_segments=n_shave_prefix_segments)

    return [{"old": old_path, "new": _renamed(old_path)} for old_path in old_list]
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Map attention parameter paths to the new naming scheme (local renaming).

    No renaming is currently applied: every path maps to itself and
    ``n_shave_prefix_segments`` is not used.
    """
    return [{"old": path, "new": path} for path in old_list]
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Map VAE attention parameter paths to the new naming scheme (local renaming).

    Returns a list of ``{"old": ..., "new": ...}`` dicts, one per input path.
    """
    # applied in this order to every path
    renames = (
        ("norm.weight", "group_norm.weight"),
        ("norm.bias", "group_norm.bias"),
        ("q.weight", "to_q.weight"),
        ("q.bias", "to_q.bias"),
        ("k.weight", "to_k.weight"),
        ("k.bias", "to_k.bias"),
        ("v.weight", "to_v.weight"),
        ("v.bias", "to_v.bias"),
        ("proj_out.weight", "to_out.0.weight"),
        ("proj_out.bias", "to_out.0.bias"),
    )

    mapping = []
    for old_path in old_list:
        new_path = old_path
        for before, after in renames:
            new_path = new_path.replace(before, after)
        new_path = shave_segments(new_path, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_path, "new": new_path})
    return mapping
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    Final conversion step: copy locally-renamed weights into ``checkpoint`` under their global names.

    Args:
        paths: list of ``{"old": ..., "new": ...}`` dicts (as produced by the ``renew_*`` helpers).
        checkpoint: destination dict; it is filled in place.
        old_checkpoint: source dict, indexed by the ``"old"`` names.
        attention_paths_to_split: optional dict mapping a fused attention key in ``old_checkpoint``
            to a dict with ``"query"``, ``"key"`` and ``"value"`` destination names.
        additional_replacements: optional list of ``{"old": ..., "new": ...}`` substring
            replacements applied to every destination name after the built-in ones.
        config: only read (``config["num_head_channels"]``) when ``attention_paths_to_split`` is given.

    Returns:
        None. ``checkpoint`` is modified in place.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Fused attention tensors are split into query / key / value entries first.
    if attention_paths_to_split is not None:
        for fused_name, targets in attention_paths_to_split.items():
            fused = old_checkpoint[fused_name]
            channels = fused.shape[0] // 3

            target_shape = (-1, channels) if len(fused.shape) == 3 else (-1)

            num_heads = fused.shape[0] // config["num_head_channels"] // 3

            fused = fused.reshape((num_heads, 3 * channels // num_heads) + fused.shape[1:])
            query, key, value = fused.split(channels // num_heads, dim=1)

            checkpoint[targets["query"]] = query.reshape(target_shape)
            checkpoint[targets["key"]] = key.reshape(target_shape)
            checkpoint[targets["value"]] = value.reshape(target_shape)

    # Built-in global renames, applied in this order before any additional replacements.
    global_renames = (
        ("middle_block.0", "mid_block.resnets.0"),
        ("middle_block.1", "mid_block.attentions.0"),
        ("middle_block.2", "mid_block.resnets.1"),
    )
    squeeze_markers = ("to_q", "to_k", "to_v")

    for entry in paths:
        new_path = entry["new"]

        # Entries handled by the split above are not assigned again.
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        for needle, replacement in global_renames:
            new_path = new_path.replace(needle, replacement)

        if additional_replacements is not None:
            for extra in additional_replacements:
                new_path = new_path.replace(extra["old"], extra["new"])

        tensor = old_checkpoint[entry["old"]]
        if "proj_attn.weight" in new_path:
            # proj_attn.weight is converted from conv 1D to linear by dropping the last axis.
            checkpoint[new_path] = tensor[:, :, 0]
        elif "to_out.0.weight" in new_path or any(marker in new_path for marker in squeeze_markers):
            # Output-projection weights and q/k/v entries have their size-1 axes removed.
            checkpoint[new_path] = tensor.squeeze()
        else:
            checkpoint[new_path] = tensor
def conv_attn_to_linear(checkpoint):
    """
    Collapse conv-shaped attention weights in ``checkpoint`` to 2-D, in place.

    * Keys whose last two dotted segments are ``query.weight``, ``key.weight`` or ``value.weight``:
      tensors with more than two dims are replaced by ``tensor[:, :, 0, 0]``.
    * Otherwise, keys containing ``proj_attn.weight``: tensors with more than two dims are
      replaced by ``tensor[:, :, 0]``.

    All other entries are left untouched. Returns None.
    """
    attn_suffixes = ("query.weight", "key.weight", "value.weight")

    for name in list(checkpoint.keys()):
        suffix = ".".join(name.split(".")[-2:])
        if suffix in attn_suffixes:
            trailing = (0, 0)
        elif "proj_attn.weight" in name:
            trailing = (0,)
        else:
            continue

        tensor = checkpoint[name]
        if tensor.ndim > 2:
            # Equivalent to tensor[:, :, 0, 0] / tensor[:, :, 0] depending on the key kind.
            checkpoint[name] = tensor[(slice(None), slice(None)) + trailing]
def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
    """
    Derive a diffusers UNet (or ControlNet) config dict from an LDM config.

    Args:
        original_config: LDM config object supporting attribute access and ``in`` checks.
            UNet parameters come from ``model.params.control_stage_config.params`` when
            ``controlnet`` is true, otherwise from ``model.params.unet_config.params``.
        image_size: pixel size; divided by the VAE scale factor to obtain ``sample_size``.
        controlnet: when true, ``out_channels`` and ``up_block_types`` are left out of the result.

    Returns:
        dict of diffusers config entries.

    Raises:
        NotImplementedError: if ``num_classes`` is present and is not ``"sequential"``.
    """
    model_params = original_config.model.params
    if controlnet:
        unet_params = model_params.control_stage_config.params
    else:
        unet_params = model_params.unet_config.params
    vae_params = model_params.first_stage_config.params.ddconfig

    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
    num_levels = len(block_out_channels)

    # Walk down the levels, doubling the resolution marker between levels (not after the last).
    down_block_types = []
    resolution = 1
    for level in range(num_levels):
        has_attention = resolution in unet_params.attention_resolutions
        down_block_types.append("CrossAttnDownBlock2D" if has_attention else "DownBlock2D")
        if level < num_levels - 1:
            resolution *= 2

    # Walk back up, halving the marker after every level.
    up_block_types = []
    for _ in range(num_levels):
        has_attention = resolution in unet_params.attention_resolutions
        up_block_types.append("CrossAttnUpBlock2D" if has_attention else "UpBlock2D")
        resolution //= 2

    vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)

    head_dim = unet_params.num_heads if "num_heads" in unet_params else None

    use_linear_projection = False
    if "use_linear_in_transformer" in unet_params:
        use_linear_projection = unet_params.use_linear_in_transformer
    if use_linear_projection and head_dim is None:
        # stable diffusion 2-base-512 and 2-768
        head_dim = [5, 10, 20, 20]

    class_embed_type = None
    projection_class_embeddings_input_dim = None
    if "num_classes" in unet_params:
        if unet_params.num_classes != "sequential":
            raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
        class_embed_type = "projection"
        assert "adm_in_channels" in unet_params
        projection_class_embeddings_input_dim = unet_params.adm_in_channels

    config = {
        "sample_size": image_size // vae_scale_factor,
        "in_channels": unet_params.in_channels,
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": unet_params.num_res_blocks,
        "cross_attention_dim": unet_params.context_dim,
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
    }

    if not controlnet:
        config["out_channels"] = unet_params.out_channels
        config["up_block_types"] = tuple(up_block_types)

    return config
def create_vae_diffusers_config(original_config, image_size: int):
    """
    Derive a diffusers VAE config dict from an LDM config.

    Values are read from ``original_config.model.params.first_stage_config.params.ddconfig``.
    Every level uses ``DownEncoderBlock2D`` / ``UpDecoderBlock2D``; the number of levels equals
    ``len(ddconfig.ch_mult)``. ``image_size`` is passed through unchanged as ``sample_size``.
    """
    first_stage_params = original_config.model.params.first_stage_config.params
    vae_params = first_stage_params.ddconfig
    # embed_dim is looked up but its value is not used in the resulting config.
    _ = first_stage_params.embed_dim

    block_out_channels = tuple(vae_params.ch * mult for mult in vae_params.ch_mult)
    num_levels = len(block_out_channels)

    return {
        "sample_size": image_size,
        "in_channels": vae_params.in_channels,
        "out_channels": vae_params.out_ch,
        "down_block_types": ("DownEncoderBlock2D",) * num_levels,
        "up_block_types": ("UpDecoderBlock2D",) * num_levels,
        "block_out_channels": block_out_channels,
        "latent_channels": vae_params.z_channels,
        "layers_per_block": vae_params.num_res_blocks,
    }
def create_diffusers_schedular(original_config):
    """
    Build a ``DDIMScheduler`` from an LDM config.

    ``timesteps``, ``linear_start`` and ``linear_end`` are read from ``original_config.model.params``;
    the beta schedule is fixed to ``"scaled_linear"``.
    """
    model_params = original_config.model.params
    return DDIMScheduler(
        num_train_timesteps=model_params.timesteps,
        beta_start=model_params.linear_start,
        beta_end=model_params.linear_end,
        beta_schedule="scaled_linear",
    )
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
    """
    Takes a state dict and a config, and returns a converted checkpoint.

    Args:
        checkpoint: dict mapping LDM parameter names to weights. Entries that are extracted for the
            UNet are removed from this dict (they are taken with ``pop``).
        config: diffusers-style config dict. This function reads ``config["class_embed_type"]`` and
            ``config["layers_per_block"]`` and forwards the dict to ``assign_to_checkpoint``.
        path: only used in the message printed when EMA weights are extracted.
        extract_ema: if true and more than 100 keys start with ``model_ema``, the EMA weights are
            extracted instead of the regular ones.
        controlnet: if true, weights are read from the ``control_model.`` prefix, the
            ``conv_norm_out`` / ``conv_out`` entries are not written, and the ControlNet
            conditioning-embedding, down-block and mid-block entries are converted at the end.

    Returns:
        A new dict mapping diffusers parameter names to weights.

    Raises:
        NotImplementedError: if ``config["class_embed_type"]`` is not ``None``, ``"timestep"`` or
            ``"projection"``.
        KeyError: if an expected source key is missing (keys are indexed / popped directly).
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())

    if controlnet:
        unet_key = "control_model."
    else:
        unet_key = "model.diffusion_model."

    # At least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
        print(
            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
        )
        for key in keys:
            # NOTE(review): this test uses the literal "model.diffusion_model" prefix rather than
            # `unet_key` - confirm this is intended when controlnet=True.
            if key.startswith("model.diffusion_model"):
                # The EMA key is "model_ema." followed by the original key with its first dotted
                # segment dropped and the remaining segments concatenated without dots.
                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
    else:
        if sum(k.startswith("model_ema") for k in keys) > 100:
            print(
                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
            )

        for key in keys:
            if key.startswith(unet_key):
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    # Time embedding: entries 0 and 2 of `time_embed` become linear_1 / linear_2.
    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    if config["class_embed_type"] is None:
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    # The output head is only converted for a regular UNet.
    if not controlnet:
        new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
        new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
        new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
        new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    # (the block count is the number of distinct two-segment prefixes; keys are grouped by
    # substring match on "input_blocks.<id>")
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    # Input blocks 1..N-1 -> down_blocks (block 0 was handled as conv_in above).
    for i in range(1, num_input_blocks):
        # Each down block spans `layers_per_block + 1` consecutive input blocks.
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        # An `op` entry in sub-module 0 is converted to the block's downsampler conv.
        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    # Middle block: sub-modules 0 / 1 / 2 are treated as resnet / attention / resnet.
    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    # Output blocks -> up_blocks.
    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        # Group the remaining key suffixes by their leading sub-module id.
        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            # NOTE(review): `resnet_0_paths` is not read again in this branch; `paths` below
            # holds the same mapping.
            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            # A sub-module holding exactly conv.bias / conv.weight is converted to the upsampler conv.
            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            # Only one sub-module: every key is copied directly as a resnet entry.
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    if controlnet:
        # conditioning embedding
        # `input_hint_block` entries are read at even indices: 0 -> conv_in, 2..12 -> blocks.0..5,
        # 14 -> conv_out.

        orig_index = 0

        new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.weight"
        )
        new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.bias"
        )

        orig_index += 2

        diffusers_index = 0

        while diffusers_index < 6:
            new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
                f"input_hint_block.{orig_index}.weight"
            )
            new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
                f"input_hint_block.{orig_index}.bias"
            )
            diffusers_index += 1
            orig_index += 2

        new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.weight"
        )
        new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.bias"
        )

        # down blocks
        for i in range(num_input_blocks):
            new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
            new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")

        # mid block
        new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
        new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")

    return new_checkpoint
def convert_ldm_vae_checkpoint(checkpoint, config, only_decoder=False, only_encoder=False):
    """
    Convert the VAE part of an LDM state dict to diffusers parameter names.

    Args:
        checkpoint: dict mapping LDM parameter names to weights. Only keys starting with
            ``first_stage_model.`` are used; they are read with ``get``, so ``checkpoint`` itself
            is not modified.
        config: forwarded unchanged to ``assign_to_checkpoint``.
        only_decoder: if true, only keys starting with ``decoder`` or ``post_quant`` are returned.
        only_encoder: if true (and ``only_decoder`` is false), only keys starting with ``encoder``
            or ``quant`` are returned.

    Returns:
        A new dict mapping diffusers parameter names to weights.

    Raises:
        KeyError: if one of the fixed encoder / decoder / quant keys is missing. These are indexed
            directly before any filtering, so both halves must be present even when
            ``only_decoder`` or ``only_encoder`` is set.
    """
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    # Fixed-name entries: only `norm_out` is renamed (to `conv_norm_out`); the rest keep their names.
    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    # (the block count is the number of distinct three-segment prefixes; keys are grouped by
    # substring match on "down.<id>")
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    # Encoder down blocks: the downsample conv is moved explicitly, everything else is a resnet key.
    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Encoder mid block: `block_1` / `block_2` become `resnets.0` / `resnets.1`.
    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    # Decoder up blocks: source block `num_up_blocks - 1 - i` is written to target block `i`,
    # i.e. the block order is reversed.
    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Decoder mid block, handled the same way as the encoder mid block above.
    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    # Optionally keep only one half of the autoencoder; `only_decoder` takes precedence.
    if only_decoder:
        new_checkpoint = {k: v for k, v in new_checkpoint.items() if k.startswith('decoder') or k.startswith('post_quant')}
    elif only_encoder:
        new_checkpoint = {k: v for k, v in new_checkpoint.items() if k.startswith('encoder') or k.startswith('quant')}

    return new_checkpoint
def convert_ldm_clip_checkpoint(checkpoint):
    """
    Extract the text-encoder weights from an LDM state dict.

    Every key starting with ``cond_stage_model.transformer`` is kept, with the first
    ``len("cond_stage_model.transformer.")`` characters removed from its name.

    Returns a new dict; ``checkpoint`` is not modified.
    """
    prefix = "cond_stage_model.transformer"
    strip_len = len("cond_stage_model.transformer.")
    return {
        name[strip_len:]: checkpoint[name]
        for name in list(checkpoint.keys())
        if name.startswith(prefix)
    }
def convert_lora_model_level(state_dict, unet, text_encoder=None, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
    """
    Merge LoRA weights into the model weights (model level instead of pipeline level).

    For every ``lora_up`` / ``lora_down`` pair in ``state_dict`` the target layer is located by
    walking the attribute path encoded in the key, and ``alpha * (up @ down)`` is added in place
    to that layer's ``weight``.

    Args:
        state_dict: LoRA state dict. Keys look like
            ``"lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"``.
            Keys containing ``".alpha"`` are skipped.
        unet: model that receives all keys which do not contain ``"text"``.
        text_encoder: model that receives keys containing ``"text"``; must be given if such keys exist.
        LORA_PREFIX_UNET: key prefix stripped before resolving UNet layer names.
        LORA_PREFIX_TEXT_ENCODER: key prefix stripped before resolving text-encoder layer names.
        alpha: scale applied to the merged LoRA delta.

    Returns:
        ``(unet, text_encoder)`` - the same objects that were passed in, modified in place.

    Raises:
        AssertionError: if a text-encoder key is found but ``text_encoder`` is None.
        IndexError: if a key's layer path cannot be resolved (the name segments run out).
    """
    # A set gives O(1) membership tests; the pair of every processed key is recorded so that
    # each up/down pair is merged exactly once.
    visited = set()

    # directly update weight in diffusers model
    for key in state_dict:
        # as we have set the alpha beforehand, so just skip
        if ".alpha" in key or key in visited:
            continue

        if "text" in key:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            assert text_encoder is not None, (
                'text_encoder must be passed since lora contains text encoder layers')
            curr_layer = text_encoder
        else:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = unet

        # Find the target layer. Attribute names may themselves contain underscores, so when a
        # lookup fails the next segment is appended to the current name and the lookup is retried.
        temp_name = layer_infos.pop(0)
        while True:
            try:
                # `__getattr__` is called directly on purpose: unlike `getattr`, it skips normal
                # attribute lookup, so ordinary methods such as `to` do not shadow a partial
                # layer name (assumes the models are torch.nn.Module instances - TODO confirm).
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        # pair_keys is always ordered (up, down), whichever half `key` is.
        pair_keys = []
        if "lora_down" in key:
            pair_keys.append(key.replace("lora_down", "lora_up"))
            pair_keys.append(key)
        else:
            pair_keys.append(key)
            pair_keys.append(key.replace("lora_up", "lora_down"))

        # Every branch moves the delta to the layer's device before the in-place add, so a
        # CPU-loaded LoRA can be merged into a model that already lives on an accelerator.
        device = curr_layer.weight.data.device

        # update weight
        # NOTE: the two conv branches handle LyCORIS-style conv weights and may still have bugs.
        if 'conv_in' in pair_keys[0]:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            weight_up = weight_up.view(weight_up.size(0), -1)
            weight_down = weight_down.view(weight_down.size(0), -1)
            shape = list(curr_layer.weight.data.shape)
            # Only the first 4 input channels of conv_in receive the delta.
            shape[1] = 4
            curr_layer.weight.data[:, :4, ...] += alpha * (weight_up @ weight_down).view(*shape).to(device)
        elif 'conv' in pair_keys[0]:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            weight_up = weight_up.view(weight_up.size(0), -1)
            weight_down = weight_down.view(weight_down.size(0), -1)
            shape = list(curr_layer.weight.data.shape)
            curr_layer.weight.data += alpha * (weight_up @ weight_down).view(*shape).to(device)
        elif len(state_dict[pair_keys[0]].shape) == 4:
            # 1x1-conv shaped LoRA weights: drop the two trailing axes, multiply, restore them.
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(device)
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(device)

        # update visited set
        visited.update(pair_keys)

    return unet, text_encoder
def denormalize_spectrogram(
    data: torch.Tensor,
    max_value: float = 200,
    min_value: float = 1e-5,
    power: float = 1,
    inverse: bool = False,
) -> torch.Tensor:
    """
    Map a normalized spectrogram image back to the log-value range.

    Steps, in order: flip along dim 1, require a 3-D input with 1 or 3 channels (a single channel
    is repeated to 3), keep channel 0, raise to ``1 / power``, optionally replace the values by
    ``1 - values``, and finally rescale linearly to ``[log(min_value), log(max_value)]``.

    Args:
        data: 3-D tensor whose dim 0 has size 1 or 3.
        max_value: upper bound before taking the log.
        min_value: lower bound before taking the log.
        power: exponent of the power curve to undo.
        inverse: if true, values are inverted (``1 - x``) before rescaling.

    Returns:
        2-D tensor holding the rescaled first channel.
    """
    log_max = np.log(max_value)
    log_min = np.log(min_value)

    # Flip Y axis: image origin at the top-left corner, spectrogram origin at the bottom-left corner
    flipped = torch.flip(data, [1])

    assert len(flipped.shape) == 3, "Expected 3 dimensions, got {}".format(len(flipped.shape))

    if flipped.shape[0] == 1:
        flipped = flipped.repeat(3, 1, 1)

    assert flipped.shape[0] == 3, "Expected 3 channels, got {}".format(flipped.shape[0])
    channel = flipped[0]

    # Undo the power curve.
    channel = torch.pow(channel, 1 / power)

    if inverse:
        channel = 1 - channel

    # Rescale from [0, 1] to [log_min, log_max].
    return channel * (log_max - log_min) + log_min
class ToTensor1D(torchvision.transforms.ToTensor):
    """``ToTensor`` variant that adds a trailing axis to its input and drops dim 0 of the result."""

    def __call__(self, tensor: np.ndarray):
        # Append a new last axis, delegate to ToTensor.__call__, then squeeze dim 0 in place.
        with_trailing_axis = tensor[..., np.newaxis]
        converted = super().__call__(with_trailing_axis)
        return converted.squeeze_(0)
def scale(old_value, old_min, old_max, new_min, new_max):
    """Linearly remap ``old_value`` from ``[old_min, old_max]`` to ``[new_min, new_max]``."""
    source_span = old_max - old_min
    target_span = new_max - new_min
    offset = old_value - old_min
    return offset * target_span / source_span + new_min
def read_frames_with_moviepy(video_path, max_frame_nums=None):
    """
    Read the frames of a video with moviepy.

    Args:
        video_path: path passed to ``VideoFileClip``.
        max_frame_nums: if given, that many frames are picked at evenly spaced indices
            (``np.linspace`` over the frame range, truncated to int). If None, all frames
            are returned.

    Returns:
        ``(frames, duration)`` where ``frames`` is a numpy array stacking the selected frames
        along axis 0 and ``duration`` is ``clip.duration``.
    """
    clip = VideoFileClip(video_path)
    try:
        duration = clip.duration
        frames = np.array(list(clip.iter_frames()))
    finally:
        # All frames are materialized above, so the reader can be released right away.
        clip.close()

    # Without a frame budget there is nothing to subsample (previously this path raised
    # NameError because the index array was never created).
    if max_frame_nums is None:
        return frames, duration

    frames_idx = np.linspace(0, len(frames) - 1, max_frame_nums, dtype=int)
    return frames[frames_idx, ...], duration
def read_frames_with_moviepy_resample(video_path, save_path):
    """
    Dump a video to JPEG frames with ffmpeg and load them as one normalized tensor.

    ffmpeg writes ``frame%06d.jpg`` files into ``save_path`` using the filter
    ``scale=-1:360,fps=15``. Every ``*.jpg`` found in ``save_path`` (sorted by name) is then
    converted with ``ToTensor``, stacked along dim 0, resized to 128x128, center-cropped to
    112x112 and normalized with the mean/std constants below.

    Args:
        video_path: input video file.
        save_path: directory for the extracted frames; created if missing. JPEG files already
            present in it are picked up as well.

    Returns:
        The stacked frames after ``permute(1, 0, 2, 3)``
        (presumably (channels, frames, height, width) - TODO confirm against callers).

    Raises:
        RuntimeError: if ffmpeg cannot be launched or no frames are found in ``save_path``.
    """
    import subprocess

    video_transform = transforms.Compose(
        [
            transforms.Resize((128, 128)),
            transforms.CenterCrop((112, 112)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    os.makedirs(save_path, exist_ok=True)

    # Argument list instead of a shell string: paths containing quotes, spaces or `$` are
    # passed to ffmpeg verbatim and cannot be interpreted by a shell.
    command = [
        "ffmpeg", "-v", "quiet", "-y",
        "-i", os.fspath(video_path),
        "-f", "image2",
        "-vf", "scale=-1:360,fps=15",
        "-qscale:v", "3",
        os.path.join(os.fspath(save_path), "frame%06d.jpg"),
    ]
    try:
        # The exit status is not checked (as before); a failed run surfaces below as "no frames".
        subprocess.run(command, check=False)
    except OSError as err:
        raise RuntimeError(f"Could not launch ffmpeg to extract frames from {video_path!r}") from err

    # Escape the directory so glob metacharacters in save_path are matched literally.
    frame_paths = sorted(glob.glob(os.path.join(glob.escape(os.fspath(save_path)), "*.jpg")))
    if not frame_paths:
        raise RuntimeError(f"No frames were extracted from {video_path!r} into {save_path!r}")

    convert_tensor = transforms.ToTensor()
    frame_tensors = []
    for frame_path in frame_paths:
        # Close each image file as soon as its pixels are copied into an array.
        with Image.open(frame_path) as image:
            frame_tensors.append(convert_tensor(np.array(image)))

    imgs = torch.stack(frame_tensors, dim=0)
    imgs = video_transform(imgs)
    imgs = imgs.permute(1, 0, 2, 3)
    return imgs