How to get the discrete codes correctly
I am trying to get the discrete codes the right way, but it seems the faiss index is somehow wrong. Could you take a look?
import torch
from transformers import HubertModel
from datasets import load_dataset
import faiss
import numpy as np
def load_index(index_path):
    """Read a faiss index from ``index_path`` and locate its IVF part.

    Args:
        index_path: filesystem path of a serialized faiss index.

    Returns:
        A tuple ``(index, index_ivf)``: the full index exactly as read from
        disk, and the IVF sub-index found inside it by
        ``faiss.extract_index_ivf``. The IVF sub-index is needed later to
        reach its coarse quantizer (the cluster centroids).
    """
    full_index: faiss.IndexPreTransform = faiss.read_index(index_path)
    ivf_part = faiss.extract_index_ivf(full_index)
    return full_index, ivf_part
def get_centroids_index(xq, index, index_ivf):
    """Assign each query vector to its nearest coarse (IVF) centroid.

    The queries are first passed through every VectorTransform in the
    pre-transform chain of ``index`` (e.g. OPQ), in order, and then searched
    against the coarse quantizer of ``index_ivf`` for the single nearest
    centroid.

    Args:
        xq: 2-D array of query vectors, one row per vector.
        index: the full index returned by ``load_index``. If it has a
            ``chain`` attribute (an IndexPreTransform), every transform in
            that chain is applied; otherwise the queries are used as-is.
        index_ivf: the IVF sub-index of ``index``.

    Returns:
        ``(DC, C)``: the distances to, and the ids of, the nearest centroid
        for each query (k=1 search on the coarse quantizer).
    """
    # faiss only accepts C-contiguous float32 matrices.
    xq_t = np.ascontiguousarray(xq, dtype="float32")
    # Apply the whole pre-transform chain, not only its first element, so the
    # queries land in the same space the quantizer was trained in even when
    # the chain holds more than one transform. Plain indexes have no chain.
    chain = getattr(index, "chain", None)
    if chain is not None:
        for i in range(chain.size()):
            transform = faiss.downcast_VectorTransform(chain.at(i))
            xq_t = transform.apply_py(xq_t)
    # Nearest centroid (k=1) and its distance on the pre-transformed queries.
    DC, C = index_ivf.quantizer.search(xq_t, 1)
    return DC, C
class Hubert2Unit(torch.nn.Module):
    """Convert batches of waveforms (expected at 16 kHz) into discrete units.

    Each waveform is optionally normalized, encoded by a HuBERT model, and
    every frame of ``hidden_states[FEATURE_LAYER]`` is mapped to the id of its
    nearest coarse centroid in a faiss IVF index.

    The released ``mhubert147_faiss.index`` was trained on layer-9 features of
    the 2nd-iteration model (utter-project/mHuBERT-147-base-2nd-iter), so that
    model is the default. Features from the 3rd-iteration model
    (utter-project/mHuBERT-147) do not match that index. To discretize
    3rd-iteration features, a new index must be trained.
    """

    DEFAULT_MODEL_NAME = "utter-project/mHuBERT-147-base-2nd-iter"
    DEFAULT_KMEANS_PATH = "mhubert147_faiss.index"
    # The HF ``hidden_states`` tuple has 13 entries for this model (the
    # pre-Transformer output plus 12 layers), so index 9 corresponds to
    # fairseq's ``output_layer=9``, the features the index was trained on.
    FEATURE_LAYER = 9

    def __init__(
        self,
        model_name="",
        kmean_path="",
        dtype=torch.float32,
        device="cuda:0",
    ):
        """Load the HuBERT model and the faiss index.

        Args:
            model_name: Hugging Face model id or local path. An empty string
                selects ``DEFAULT_MODEL_NAME``.
            kmean_path: path of the faiss index. An empty string selects
                ``DEFAULT_KMEANS_PATH``.
            dtype: accepted for backward compatibility and ignored. The model
                is always run in float32.
            device: torch device the model is moved to.
        """
        super().__init__()
        model_name = model_name or self.DEFAULT_MODEL_NAME
        kmean_path = kmean_path or self.DEFAULT_KMEANS_PATH
        self.model = HubertModel.from_pretrained(model_name).eval()
        # Always float32: the model was trained in float32.
        self.model.to(dtype=torch.float32, device=device)
        self.index, self.index_ivf = load_index(kmean_path)

    def zero_mean_unit_var_norm(
        self, input_values, wav_lengths, padding_value: float = 0.0
    ):
        """Normalize every row of the batch to zero mean and unit variance.

        Statistics are computed only over the valid prefix of each row.
        Positions at or beyond the row's length are overwritten with
        ``padding_value``. The population (biased) variance is used, matching
        the numpy-based ``zero_mean_unit_var_norm`` of Hugging Face feature
        extractors. The 1e-7 epsilon keeps constant or one-sample rows finite.

        Args:
            input_values: batch of waveforms, iterated row by row.
            wav_lengths: per-row valid lengths, or None to use whole rows.
            padding_value: value written into padded positions.

        Returns:
            A tensor stacking the normalized rows along dim 0.
        """
        if wav_lengths is None:
            normed_input_values = [
                (x - x.mean()) / torch.sqrt(x.var(unbiased=False) + 1e-7)
                for x in input_values
            ]
            return torch.stack(normed_input_values, dim=0)

        normed_input_values = []
        for vector, length in zip(input_values, wav_lengths):
            length = int(length)
            valid = vector[:length]
            normed_slice = (vector - valid.mean()) / torch.sqrt(
                valid.var(unbiased=False) + 1e-7
            )
            if length < normed_slice.shape[0]:
                normed_slice[length:] = padding_value
            normed_input_values.append(normed_slice)
        return torch.stack(normed_input_values, dim=0)

    def forward(self, wav, wav_lengths, do_normalize=True):
        """Compute discrete unit ids for a batch of waveforms.

        Args:
            wav: batch of waveforms, shape (batch, samples).
            wav_lengths: 1-D tensor (or sequence) of valid sample counts per
                row, or None when no row is padded.
            do_normalize: apply per-row zero-mean/unit-variance normalization.
                True is the setting used when the index was built.

        Returns:
            ``(C, n_unique_codes)``. ``C`` is an array of shape
            (batch, frames) holding centroid ids, padded frames included.
            ``n_unique_codes`` is the number of distinct ids in ``C``.
        """
        with torch.no_grad():
            if do_normalize:
                input_values = self.zero_mean_unit_var_norm(wav, wav_lengths)
            else:
                input_values = wav.clone()
            # The model runs in float32; cast so e.g. float64 input does not fail.
            input_values = input_values.to(torch.float32)

            # Mask out samples beyond each row's length. Without lengths every
            # sample is treated as valid.
            # NOTE(review): HF advises not passing attention_mask to models
            # whose feature extractor has return_attention_mask=False (e.g.
            # hubert-base); confirm against this model's preprocessor config.
            if wav_lengths is None:
                attention_mask = None
            else:
                lengths = torch.as_tensor(wav_lengths, device=input_values.device)
                positions = torch.arange(
                    input_values.size(1), device=input_values.device
                )
                attention_mask = (positions[None, :] < lengths[:, None]).long()

            hidden_states = self.model(
                input_values,
                attention_mask=attention_mask,
                output_hidden_states=True,
            ).hidden_states[self.FEATURE_LAYER]
            batch_size = hidden_states.size(0)
            # Flatten (batch, frames, dim) -> (batch * frames, dim) for faiss.
            features = hidden_states.reshape(
                hidden_states.size(0) * hidden_states.size(1), -1
            )
            features_np = features.float().cpu().numpy()
            _, C = get_centroids_index(features_np, self.index, self.index_ivf)
            C = C.reshape(batch_size, -1)
            n_unique_codes = len(np.unique(C))
        return C, n_unique_codes
# Demo: extract discrete units for the first utterance of a small
# LibriSpeech sample.
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True)
# Fall back to CPU so the demo also runs on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
hubert = Hubert2Unit(device=device)
sample = ds[0]["audio"]
# The pipeline expects 16 kHz audio; fail loudly instead of producing
# silently wrong units.
if sample["sampling_rate"] != 16000:
    raise ValueError(f"expected 16 kHz audio, got {sample['sampling_rate']} Hz")
# Build the (1, samples) float32 batch directly on the target device.
wav = torch.tensor(sample["array"], dtype=torch.float32, device=device).unsqueeze(0)
lengths = torch.tensor([wav.shape[1]], device=device)
C, n = hubert(wav, lengths)
@mzboito Thank you in advance.
Hi! Thanks for moving this thread to a dedicated issue.
The source of your issue is very likely the mismatch between the trained faiss index and the mhubert-147 model you are using.
This index here (https://huggingface.co/utter-project/mHuBERT-147/blob/main/mhubert147_faiss.index) was trained on the output of the 9th layer of the 2nd iteration mHuBERT-147 (https://huggingface.co/utter-project/mHuBERT-147-base-2nd-iter), in order to generate targets for the mHuBERT-147 3rd iteration training.
If you input mHuBERT-147 (3rd iteration) features into it, it will not know how to cluster it very well, as it was trained on the output of a different model.
Basically, there are two settings in which you might be interested in using faiss:
If you want to continue pretraining mHuBERT-147 (3rd iteration), you should extract features for your speech from the 9th layer of the 2nd-iteration model, and then assign cluster indices with the faiss index you are using (https://huggingface.co/utter-project/mHuBERT-147/blob/main/mhubert147_faiss.index). This should work.
If you want to discretize features from the 3rd iteration (mHuBERT-147) with faiss, you need to train a new index on your target data. You can check our training recommendations here: https://github.com/utter-project/mHuBERT-147-scripts
I hope it was understandable!
@mzboito Thank you so much, it seems to be fixed now. I just want to make sure everything matches: are do_normalize=True and hidden_states[9] correct (rather than do_normalize=False or hidden_states[10])? I ask because len(hidden_states) appears to be 13, not 12.
Yes, do_normalize=True for everything.
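Regarding the layer: I did the feature extraction in fairseq, not HF, so I'm not 100% sure, but it should be [9] if your hidden_states has length 13. In HF, the first entry is the initial embedding output before any Transformer layer, so index 9 is the output of the 9th layer.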
That is because the forward for feature extraction takes output_layer - 1: https://github.com/utter-project/fairseq/blob/3fb951a8658b81f09011fc2e9e5fe4c2e818a304/fairseq/models/hubert/hubert.py#L470