vinthony commited on
Commit
8662725
1 Parent(s): 50dbb74

Update src/audio2pose_models/audio_encoder.py

Browse files
src/audio2pose_models/audio_encoder.py CHANGED
@@ -19,7 +19,7 @@ class Conv2d(nn.Module):
19
  return self.act(out)
20
 
21
  class AudioEncoder(nn.Module):
22
- def __init__(self, wav2lip_checkpoint):
23
  super(AudioEncoder, self).__init__()
24
 
25
  self.audio_encoder = nn.Sequential(
@@ -42,7 +42,7 @@ class AudioEncoder(nn.Module):
42
  Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
43
 
44
  #### load the pre-trained audio_encoder\
45
- wav2lip_state_dict = torch.load(wav2lip_checkpoint)['state_dict']
46
  state_dict = self.audio_encoder.state_dict()
47
 
48
  for k,v in wav2lip_state_dict.items():
 
19
  return self.act(out)
20
 
21
  class AudioEncoder(nn.Module):
22
+ def __init__(self, wav2lip_checkpoint, device='cpu'):
23
  super(AudioEncoder, self).__init__()
24
 
25
  self.audio_encoder = nn.Sequential(
 
42
  Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
43
 
44
  #### load the pre-trained audio_encoder\
45
+ wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=device)['state_dict']
46
  state_dict = self.audio_encoder.state_dict()
47
 
48
  for k,v in wav2lip_state_dict.items():