import torch
import torchaudio
import random
import itertools
import numpy as np
from tools.mix import mix
from PIL import Image
import cv2
from moviepy.editor import VideoFileClip
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode, RandomResizedCrop
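
# The fbank helpers below all derive the STFT hop size from the sample rate
# (hop = sr / 100 for every rate this module supports). The repeated
# if/elif mapping is factored out here so the call sites cannot drift apart.
def _hop_size(sample_rate):
    if sample_rate not in (16000, 24000, 32000, 48000):
        raise ValueError(f"Unsupported sample_rate: {sample_rate}")
    return sample_rate // 100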

def normalize_wav(waveform):
    # Remove the DC offset, then peak-normalize to a maximum amplitude of 0.5.
    waveform = waveform - torch.mean(waveform)
    waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
    return waveform * 0.5

def sinusoidal_positional_embedding(token_sequence_size, token_embedding_dim, n=10000.0):
    if token_embedding_dim % 2 != 0:
        raise ValueError("Sinusoidal positional embedding requires an even embedding dim (got dim={:d})".format(token_embedding_dim))
    T = token_sequence_size
    d = token_embedding_dim  # d_model = head_num * d_k (the full model dim, not per-head d_q/d_k/d_v)
    positions = torch.arange(0, T).unsqueeze(1)
    embeddings = torch.zeros(T, d)
    denominators = torch.pow(n, 2 * torch.arange(0, d // 2) / d)  # 10000^(2i/d_model), i indexing the sin/cos pairs
    embeddings[:, 0::2] = torch.sin(positions / denominators)  # sin(pos / 10000^(2i/d_model))
    embeddings[:, 1::2] = torch.cos(positions / denominators)  # cos(pos / 10000^(2i/d_model))
    return embeddings
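
# Usage sketch (shapes only): the table is added to a (T, d) token sequence.
#   pe = sinusoidal_positional_embedding(1000, 512)  # -> torch.Size([1000, 512])
#   tokens = tokens + pe.to(tokens.device)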

def pad_wav(waveform, segment_length):
    # Truncate or zero-pad a 1-D waveform to exactly segment_length samples.
    waveform_length = len(waveform)
    if segment_length is None or waveform_length == segment_length:
        return waveform
    elif waveform_length > segment_length:
        return waveform[:segment_length]
    else:
        pad = torch.zeros(segment_length - waveform_length).to(waveform.device)
        return torch.cat([waveform, pad])

def _pad_spec(fbank, target_length=1000):
    # Zero-pad or truncate along the frame axis to exactly target_length frames.
    batch, n_frames, channels = fbank.shape
    p = target_length - n_frames
    if p > 0:
        pad = torch.zeros(batch, p, channels).to(fbank.device)
        fbank = torch.cat([fbank, pad], 1)
    elif p < 0:
        fbank = fbank[:, :target_length, :]
    # Drop the last channel if the channel count is odd, keeping it even.
    if channels % 2 != 0:
        fbank = fbank[:, :, :-1]
    return fbank
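
# Sketch: a (B, 998, 64) spectrogram is padded to (B, 1000, 64) and a
# (B, 1002, 64) one is truncated; the channel trim only fires for odd widths.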

def read_wav_file(filename, segment_length, tgt_sr=48000):
    waveform, sr = torchaudio.load(filename)  # Faster!!!
    if sr != tgt_sr:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=tgt_sr)[0]
    else:
        waveform = waveform.squeeze()
    try:
        waveform = normalize_wav(waveform)
    except Exception:
        print("Exception normalizing:", filename)
        waveform = torch.ones(tgt_sr * 10)
    waveform = pad_wav(waveform, segment_length).unsqueeze(0)
    # Re-peak-normalize to 0.5, since truncation can change the peak.
    waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
    waveform = 0.5 * waveform
    return waveform
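
# Usage sketch (the path is hypothetical): 10 s at 48 kHz, peak 0.5.
#   wav = read_wav_file("example.wav", segment_length=480000)  # -> (1, 480000)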

def get_mel_from_wav(audio, _stft):
    # Clamp to [-1, 1] and scrub NaNs. torch.autograd.Variable is deprecated,
    # so a detached tensor is used directly.
    audio = torch.nan_to_num(torch.clip(audio, -1, 1)).detach()
    melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
    return melspec, log_magnitudes_stft, energy

def wav_to_fbank(paths, target_length=1000, sample_rate=16000, fn_STFT=None):
    assert fn_STFT is not None
    hop_size = _hop_size(sample_rate)
    waveform = torch.cat([read_wav_file(path, target_length * hop_size, tgt_sr=sample_rate) for path in paths], 0)
    fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
    fbank = fbank.transpose(1, 2)
    log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
    fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
        log_magnitudes_stft, target_length
    )
    return fbank, log_magnitudes_stft, waveform
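
# Usage sketch, assuming `stft` is an initialized STFT module exposing
# .mel_spectrogram() and configured consistently (e.g. 16 kHz, hop 160):
#   fbank, log_stft, wav = wav_to_fbank(["a.wav", "b.wav"], 1000, 16000, stft)
#   # fbank: (2, 1000, n_mels), wav: (2, 160000)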

def get_wav_from_video(video_path, segment_length, tgt_sr=48000):
    video = VideoFileClip(video_path)
    audio = video.audio
    sr = audio.fps
    audio_data = audio.to_soundarray()  # e.g. (441882, 2) stereo array
    waveform = torch.mean(torch.tensor(audio_data, dtype=torch.float), dim=1).unsqueeze(0)  # average the two channels to mono
    if sr != tgt_sr:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=tgt_sr)[0]
    else:
        waveform = waveform.squeeze()
    try:
        waveform = normalize_wav(waveform)
    except Exception:
        print("Exception normalizing:", video_path)
        waveform = torch.ones(tgt_sr * 10)
    waveform = pad_wav(waveform, segment_length).unsqueeze(0)
    waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
    waveform = 0.5 * waveform
    return waveform

def get_wavs_from_videos(video_paths, segment_length, tgt_sr=48000):
    wavs = [get_wav_from_video(video_path, segment_length, tgt_sr) for video_path in video_paths]
    return torch.cat(wavs, 0)

def wav_in_video_to_fbank(input, target_length=1000, sample_rate=16000, fn_STFT=None, waveform=False):
    # `input` is either a list of video paths (waveform=False) or an already
    # extracted waveform tensor (waveform=True).
    assert fn_STFT is not None
    hop_size = _hop_size(sample_rate)
    if not waveform:
        paths = input
        waveform = get_wavs_from_videos(paths, target_length * hop_size, tgt_sr=sample_rate)
    else:
        waveform = input
    fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
    fbank = fbank.transpose(1, 2)
    log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
    fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
        log_magnitudes_stft, target_length
    )
    return fbank, log_magnitudes_stft, waveform

def uncapitalize(s):
    # Lower-case only the first character, e.g. "A dog barks" -> "a dog barks".
    if s:
        return s[:1].lower() + s[1:]
    else:
        return ""

def mix_wavs_and_captions(path1, path2, caption1, caption2, target_length=1000, sample_rate=16000):
    hop_size = _hop_size(sample_rate)
    # Pass tgt_sr explicitly: read_wav_file defaults to 48 kHz otherwise.
    sound1 = read_wav_file(path1, target_length * hop_size, tgt_sr=sample_rate)[0].numpy()
    sound2 = read_wav_file(path2, target_length * hop_size, tgt_sr=sample_rate)[0].numpy()
    mixed_sound = mix(sound1, sound2, 0.5, sample_rate).reshape(1, -1)
    mixed_caption = "{} and {}".format(caption1, uncapitalize(caption2))
    return mixed_sound, mixed_caption

def augment(paths, texts, num_items=4, target_length=1000, sample_rate=16000):
    mixed_sounds, mixed_captions = [], []
    # Sample up to num_items random (path, caption) pairs to mix together;
    # the slice below also handles the case of fewer than num_items pairs.
    combinations = list(itertools.combinations(range(len(texts)), 2))
    random.shuffle(combinations)
    selected_combinations = combinations[:num_items]
    for (i, j) in selected_combinations:
        new_sound, new_caption = mix_wavs_and_captions(paths[i], paths[j], texts[i], texts[j], target_length, sample_rate)
        mixed_sounds.append(new_sound)
        mixed_captions.append(new_caption)
    waveform = torch.tensor(np.concatenate(mixed_sounds, 0))
    waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
    waveform = 0.5 * waveform
    return waveform, mixed_captions
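
# Sketch: three clips yield C(3,2) = 3 candidate pairs, so with num_items=4
# all three are mixed; each caption reads "First caption and second caption".
#   wavs, caps = augment(paths, texts, num_items=4)  # wavs: (len(caps), T)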

def augment_wav_to_fbank(paths, texts, num_items=4, target_length=1000, sample_rate=16000, fn_STFT=None):
    assert fn_STFT is not None
    waveform, captions = augment(paths, texts, num_items=num_items, target_length=target_length, sample_rate=sample_rate)
    fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
    fbank = fbank.transpose(1, 2)
    log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
    fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
        log_magnitudes_stft, target_length
    )
    return fbank, log_magnitudes_stft, waveform, captions

def load_image(impaths, crop_size=384):
    # Random-resized-crop each image, then normalize with ImageNet statistics.
    imgs = []
    RGB_mean = [0.485, 0.456, 0.406]
    RGB_std = [0.229, 0.224, 0.225]
    image_resize_and_crop = Compose([RandomResizedCrop(crop_size), ToTensor()])
    image_normalize = Normalize(mean=RGB_mean, std=RGB_std)
    for impath in impaths:
        img = Image.open(impath).convert('RGB')
        img = image_resize_and_crop(img)
        img = image_normalize(img)
        imgs.append(img)
    return torch.stack(imgs)
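
# Sketch: load_image(["a.jpg", "b.jpg"]) -> tensor of shape (2, 3, 384, 384).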

def load_video(video_path, frame_rate=1.0, size=224):
    def preprocess(size, n_px):
        return Compose([
            Resize(size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(size),
            lambda image: image.convert("RGB"),
            ToTensor(),
            # Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])(n_px)
    cap = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG)
    frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    if fps < 1:
        # Unreadable file: return a single black frame so the result stays 4-D.
        images = np.zeros([1, 3, size, size], dtype=np.float32)
        print("ERROR: problem reading video file: ", video_path)
    else:
        total_duration = (frameCount + fps - 1) // fps  # duration in seconds, rounded up
        start_sec, end_sec = 0, total_duration
        interval = fps / frame_rate  # source frames between consecutive samples
        frames_idx = np.floor(np.arange(start_sec * fps, end_sec * fps, interval))
        images = np.zeros([len(frames_idx), 3, size, size], dtype=np.float32)
        last_frame = 0  # initialized in case the very first read fails
        for i, idx in enumerate(frames_idx):
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            last_frame = i
            images[i, :, :, :] = preprocess(size, Image.fromarray(frame))
        images = images[:last_frame + 1]
    cap.release()
    return torch.tensor(images)
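
# Usage sketch (the path is hypothetical): sample one frame per second.
#   frames = load_video("clip.mp4", frame_rate=1.0, size=224)
#   # frames: (n_sampled_frames, 3, 224, 224), float32 in [0, 1]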