|
import os.path |
|
from io import BytesIO |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import torch |
|
|
|
from network.hubert.hubert_model import hubert_soft, get_units |
|
from network.hubert.vec_model import load_model, get_vec_units |
|
from utils.hparams import hparams |
|
|
|
|
|
class Hubertencoder:
    """Extract HuBERT-style unit features from audio, with an on-disk ``.npy`` cache.

    Two backends are supported, selected by ``hparams['use_vec']``:
      * ``use_vec`` true  -> model from ``load_model`` on the fixed vec checkpoint,
        units extracted with ``get_vec_units``;
      * ``use_vec`` false -> ``hubert_soft`` model, units extracted with ``get_units``.
    """

    def __init__(self, pt_path='.checkpoints/hubert/hubert_soft.pt'):
        """Load the unit-extraction model.

        Args:
            pt_path: Path to the hubert-soft checkpoint. Used only when
                ``hparams['use_vec']`` is false. If the file itself does not
                exist, the first ``*.pt`` file under its parent directory is used.

        Raises:
            FileNotFoundError: if no ``*.pt`` checkpoint can be found for the
                hubert-soft backend.
        """
        # Backfill the flag so that later readers of hparams see a defined value.
        if 'use_vec' not in hparams:
            hparams['use_vec'] = False
        # Snapshot the flag so encode() always matches the model loaded here, even
        # if hparams is modified or reloaded after construction.
        self.use_vec = hparams['use_vec']

        if self.use_vec:
            pt_path = ".checkpoints/vec/checkpoint_best_legacy_500.pt"
            # Previously hard-coded to "cuda", which broke CPU-only machines.
            # NOTE(review): 'hubert_gpu' is deliberately not consulted here; where
            # load_model places the model is decided inside load_model — confirm
            # it agrees with this device choice.
            self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.hbt_model = load_model(pt_path)
        else:
            pt_path = self._resolve_pt_path(pt_path)
            # GPU use defaults to on; 'hubert_gpu' in hparams can turn it off.
            self.use_gpu = hparams.get('hubert_gpu', True)
            self.dev = torch.device("cuda" if self.use_gpu and torch.cuda.is_available() else "cpu")
            self.hbt_model = hubert_soft(str(pt_path)).to(self.dev)

    @staticmethod
    def _resolve_pt_path(pt_path):
        """Return the checkpoint file to load for the requested ``pt_path``.

        The requested file is used when it exists. Otherwise the first ``*.pt``
        file found recursively under its parent directory is used, taken in
        sorted order so the choice does not depend on filesystem order.

        Raises:
            FileNotFoundError: if no ``*.pt`` file exists under the parent directory.
        """
        requested = Path(pt_path)
        if requested.is_file():
            return requested
        candidates = sorted(requested.parent.rglob('*.pt'))
        if not candidates:
            raise FileNotFoundError(
                f"No HuBERT checkpoint (*.pt) found under '{requested.parent}'"
            )
        return candidates[0]

    def encode(self, wav_path):
        """Return the unit features for one audio input as a numpy array.

        Args:
            wav_path: Path (str or ``Path``) to an audio file, or a ``BytesIO``
                holding audio data.

        Returns:
            For a path whose ``.npy`` sibling (same stem) exists, the array loaded
            from that file. Otherwise the first element along the leading axis of
            the extractor's output, moved to CPU and converted to numpy.
        """
        if isinstance(wav_path, BytesIO):
            # In-memory audio has no cache file; rewind so the extractor reads
            # the buffer from the start.
            wav_path.seek(0)
        else:
            npy_path = Path(wav_path).with_suffix('.npy')
            if npy_path.exists():
                return np.load(str(npy_path))

        extract = get_vec_units if self.use_vec else get_units
        return extract(self.hbt_model, wav_path, self.dev).cpu().numpy()[0]
|
|