Spaces:
Running
on
A10G
Running
on
A10G
Update src/audio2pose_models/audio_encoder.py
Browse files
src/audio2pose_models/audio_encoder.py
CHANGED
@@ -19,7 +19,7 @@ class Conv2d(nn.Module):
|
|
19 |
return self.act(out)
|
20 |
|
21 |
class AudioEncoder(nn.Module):
|
22 |
-
def __init__(self, wav2lip_checkpoint):
|
23 |
super(AudioEncoder, self).__init__()
|
24 |
|
25 |
self.audio_encoder = nn.Sequential(
|
@@ -42,7 +42,7 @@ class AudioEncoder(nn.Module):
|
|
42 |
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
|
43 |
|
44 |
#### load the pre-trained audio_encoder\
|
45 |
-
wav2lip_state_dict = torch.load(wav2lip_checkpoint)['state_dict']
|
46 |
state_dict = self.audio_encoder.state_dict()
|
47 |
|
48 |
for k,v in wav2lip_state_dict.items():
|
|
|
19 |
return self.act(out)
|
20 |
|
21 |
class AudioEncoder(nn.Module):
|
22 |
+
def __init__(self, wav2lip_checkpoint, device='cpu'):
|
23 |
super(AudioEncoder, self).__init__()
|
24 |
|
25 |
self.audio_encoder = nn.Sequential(
|
|
|
42 |
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
|
43 |
|
44 |
#### load the pre-trained audio_encoder\
|
45 |
+
wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=device)['state_dict']
|
46 |
state_dict = self.audio_encoder.state_dict()
|
47 |
|
48 |
for k,v in wav2lip_state_dict.items():
|