"""Convert an audio file into discrete SSL units, mapped characters and BPE pieces.

Non-streaming pipeline implemented by this script:

1. read the waveform with ``soundfile``;
2. run it through the S3PRL ``hf_hubert_custom`` upstream;
3. assign every row of hidden-state index 20 to its nearest k-means centroid;
4. map each unit id to a character using the ``distinct_cjk_token_lists`` file;
5. encode the resulting string with a SentencePiece model.

Streaming mode (``-s``) is not implemented yet.
"""
import os
from argparse import ArgumentParser

import torch
import numpy as np
import joblib
from s3prl.nn import S3PRLUpstream
import soundfile as sf
import sentencepiece as spm
from tqdm import tqdm

# Index into the list of hidden states returned by the upstream model.
HIDDEN_STATE_INDEX = 20

# ``streaming_extract`` multiplies ``window_size`` by this factor to get a
# chunk length in samples (presumably milliseconds at 16 kHz -- TODO confirm).
SAMPLES_PER_WINDOW_UNIT = 16

# Whitespace-separated "<unit id> <character>" table, one pair per line.
UNIT_TO_CHAR_PATH = 'distinct_cjk_token_lists'


class ApplyKmeans(object):
    """Assign feature vectors to the index of their nearest k-means centroid.

    The squared Euclidean distance is expanded as
    ``||x||^2 - 2 * x . c + ||c||^2`` so that all centroids are scored with a
    single matrix multiplication.
    """

    def __init__(self, km_path, use_gpu):
        """Load the k-means model and cache its centroids.

        Args:
            km_path: Path to a joblib-serialized model that exposes a
                ``cluster_centers_`` array.
            use_gpu: When true and CUDA is available, the torch copies of the
                centroids are moved to the GPU. The NumPy copies stay on CPU.
        """
        # NOTE: joblib.load unpickles arbitrary objects; only load model files
        # that come from a trusted source.
        self.km_model = joblib.load(km_path)

        # Transposed centroids, so ``x @ C`` scores every centroid at once.
        self.C_np = self.km_model.cluster_centers_.transpose()
        # Squared norm of every centroid (one row, one column per centroid).
        self.Cnorm_np = (self.C_np**2).sum(0, keepdims=True)

        self.C = torch.from_numpy(self.C_np)
        self.Cnorm = torch.from_numpy(self.Cnorm_np)
        if use_gpu and torch.cuda.is_available():
            self.C = self.C.cuda()
            self.Cnorm = self.Cnorm.cuda()

    def __call__(self, x):
        """Return the nearest-centroid index for every row of ``x``.

        Args:
            x: 2-D ``torch.Tensor`` or NumPy array with one feature vector
                per row.

        Returns:
            A NumPy array holding one centroid index per row of ``x``.
        """
        if isinstance(x, torch.Tensor):
            x = x.to(self.C.device)
            dist = (
                x.pow(2).sum(1, keepdim=True)
                - 2 * torch.matmul(x, self.C)
                + self.Cnorm
            )
            return dist.argmin(dim=1).cpu().numpy()

        dist = (
            (x**2).sum(1, keepdims=True)
            - 2 * np.matmul(x, self.C_np)
            + self.Cnorm_np
        )
        return np.argmin(dist, axis=1)


def streaming_extract(wav, window_size=60):
    """Split ``wav`` into consecutive fixed-size chunks for streaming inference.

    Streaming mode is still under development: this helper only performs the
    chunking step and does not run any model.

    Args:
        wav: Waveform indexed along its first axis (e.g. the array returned by
            ``soundfile.read``).
        window_size: Chunk size; each chunk spans ``window_size * 16`` samples.

    Returns:
        A list of ``(chunk, length)`` pairs, where ``chunk`` is the chunk with
        a leading batch axis of size 1 and ``length`` is a one-element tensor
        holding the number of samples in that chunk. The last chunk may be
        shorter than the others.
    """
    hop = window_size * SAMPLES_PER_WINDOW_UNIT
    chunk_audios = []
    for start in tqdm(range(0, wav.shape[0], hop)):
        segment = torch.tensor(wav[start : start + hop])
        # Report the real length so a short final chunk is not over-stated.
        length = torch.tensor([segment.shape[0]])
        chunk_audios.append((segment.unsqueeze(0), length))
    return chunk_audios

    # TODO: original sketch of the per-chunk forward pass, kept for reference:
    # all_hs = []
    # for i in tqdm(range(0, wav.shape[0], window_size * 16)):
    #     hs, _ = model(torch.tensor(wav[i : i+window_size*16]).unsqueeze(0), torch.tensor([window_size*16]))
    #     all_hs.append(hs[20])
    # return torch.concat(all_hs, dim=1)


def _load_unit_to_char(path):
    """Read the unit-id -> character table stored at ``path``.

    Each non-blank line holds a unit id followed by its character, separated
    by whitespace. The file is read as UTF-8 (assumed encoding of the CJK
    table -- TODO confirm).
    """
    mapping = {}
    with open(path, encoding="utf-8") as table:
        for line in table:
            fields = line.split()
            if not fields:
                # Tolerate blank lines such as a trailing newline at EOF.
                continue
            mapping[int(fields[0])] = fields[1]
    return mapping


def main():
    """Parse the command line and print units, characters and BPE pieces."""
    parser = ArgumentParser()
    parser.add_argument("--km_model", default='./km_2000.mdl', help="Path to the kmeans model")
    parser.add_argument("--bpe_model", default='./bpe.model', help="Path to the bpe model")
    parser.add_argument("--audio", required=True, help="Path to the audio file")
    parser.add_argument("-s", action='store_true', help="Streaming mode")
    args = parser.parse_args()

    if args.s:
        # Fail before loading any model. A real exception is used instead of
        # ``assert`` because assertions are removed under ``python -O``.
        # TODO: all_hs = streaming_extract(wav)
        raise NotImplementedError("streaming mode is still developing")

    apply_kmeans = ApplyKmeans(args.km_model, use_gpu=True)
    ssl_model = S3PRLUpstream(
        "hf_hubert_custom",
        path_or_url='TencentGameMate/chinese-hubert-large',
    )
    ssl_model.eval()
    sp = spm.SentencePieceProcessor(model_file=args.bpe_model)
    unit_to_char = _load_unit_to_char(UNIT_TO_CHAR_PATH)

    # NOTE(review): the sample rate returned by sf.read is not checked;
    # presumably the model expects 16 kHz audio -- confirm.
    wav, _ = sf.read(args.audio)

    with torch.no_grad():
        # torch.tensor(wav).unsqueeze(0) adds the batch axis without going
        # through a Python list of ndarrays, which PyTorch converts slowly.
        all_hs, _ = ssl_model(
            torch.tensor(wav).unsqueeze(0),
            torch.tensor([wav.shape[0]]),
        )
        ssl_units = apply_kmeans(all_hs[HIDDEN_STATE_INDEX][0, :, :].numpy())

    print(ssl_units)
    ssl_char = "".join(unit_to_char[unit] for unit in ssl_units)
    ssl_char_bpe = sp.encode(ssl_char, out_type=str)
    print(ssl_char)
    print(ssl_char_bpe)


if __name__ == "__main__":
    main()