Kangarroar commited on
Commit
433b568
1 Parent(s): 524341f

Update infer_tools/infer_tool.py

Browse files
Files changed (1) hide show
  1. infer_tools/infer_tool.py +6 -5
infer_tools/infer_tool.py CHANGED
@@ -20,7 +20,8 @@ from preprocessing.data_gen_utils import get_pitch_parselmouth, get_pitch_crepe,
20
  from preprocessing.hubertinfer import Hubertencoder
21
  from utils.hparams import hparams, set_hparams
22
  from utils.pitch_utils import denorm_f0, norm_interp_f0
23
-
 
24
  if os.path.exists("chunks_temp.json"):
25
  os.remove("chunks_temp.json")
26
 
@@ -127,10 +128,10 @@ class Svc:
127
  spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
128
  )
129
  self.load_ckpt()
130
- self.model.cuda()
131
  hparams['hubert_gpu'] = hubert_gpu
132
  self.hubert = Hubertencoder(hparams['hubert_path'])
133
- self.pe = PitchExtractor().cuda()
134
  utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
135
  self.pe.eval()
136
  self.vocoder = get_vocoder_cls(hparams)()
@@ -152,8 +153,8 @@ class Svc:
152
  @timeit
153
  def diff_infer():
154
  outputs = self.model(
155
- hubert.cuda(), spk_embed=spk_embed, mel2ph=mel2ph.cuda(), f0=f0.cuda(), uv=uv.cuda(),energy=energy.cuda(),
156
- ref_mels=ref_mels.cuda(),
157
  infer=True, **kwargs)
158
  return outputs
159
  outputs=diff_infer()
 
20
  from preprocessing.hubertinfer import Hubertencoder
21
  from utils.hparams import hparams, set_hparams
22
  from utils.pitch_utils import denorm_f0, norm_interp_f0
23
+ defice = 'cpu'
24
+ map_location=torch.device('cpu')
25
  if os.path.exists("chunks_temp.json"):
26
  os.remove("chunks_temp.json")
27
 
 
128
  spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
129
  )
130
  self.load_ckpt()
131
+ self.model.to(device)
132
  hparams['hubert_gpu'] = hubert_gpu
133
  self.hubert = Hubertencoder(hparams['hubert_path'])
134
+ self.pe = PitchExtractor().to(device)
135
  utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
136
  self.pe.eval()
137
  self.vocoder = get_vocoder_cls(hparams)()
 
153
  @timeit
154
  def diff_infer():
155
  outputs = self.model(
156
+ hubert.to(device), spk_embed=spk_embed, mel2ph=mel2ph.to(device), f0=f0.to(device), uv=uv.to(device),energy=energy.to(device),
157
+ ref_mels=ref_mels.to(device),
158
  infer=True, **kwargs)
159
  return outputs
160
  outputs=diff_infer()