jhj0517 commited on
Commit
4e55844
2 Parent(s): 0b0f426 15b3a25

Merge pull request #244 from jhj0517/fix/device

Browse files
modules/diarize/diarizer.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  import torch
3
- from typing import List, Union, BinaryIO
4
  import numpy as np
5
  import time
6
  import logging
@@ -24,7 +24,7 @@ class Diarizer:
24
  audio: Union[str, BinaryIO, np.ndarray],
25
  transcribed_result: List[dict],
26
  use_auth_token: str,
27
- device: str
28
  ):
29
  """
30
  Diarize transcribed result as a post-processing
@@ -38,7 +38,7 @@ class Diarizer:
38
  use_auth_token: str
39
  Huggingface token with READ permission. This is only needed the first time you download the model.
40
  You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model.
41
- device: str
42
  Device for diarization.
43
 
44
  Returns
@@ -50,8 +50,10 @@ class Diarizer:
50
  """
51
  start_time = time.time()
52
 
53
- if (device != self.device
54
- or self.pipe is None):
 
 
55
  self.update_pipe(
56
  device=device,
57
  use_auth_token=use_auth_token
@@ -89,6 +91,7 @@ class Diarizer:
89
  device: str
90
  Device for diarization.
91
  """
 
92
 
93
  os.makedirs(self.model_dir, exist_ok=True)
94
 
 
1
  import os
2
  import torch
3
+ from typing import List, Union, BinaryIO, Optional
4
  import numpy as np
5
  import time
6
  import logging
 
24
  audio: Union[str, BinaryIO, np.ndarray],
25
  transcribed_result: List[dict],
26
  use_auth_token: str,
27
+ device: Optional[str] = None
28
  ):
29
  """
30
  Diarize transcribed result as a post-processing
 
38
  use_auth_token: str
39
  Huggingface token with READ permission. This is only needed the first time you download the model.
40
  You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model.
41
+ device: Optional[str]
42
  Device for diarization.
43
 
44
  Returns
 
50
  """
51
  start_time = time.time()
52
 
53
+ if device is None:
54
+ device = self.device
55
+
56
+ if device != self.device or self.pipe is None:
57
  self.update_pipe(
58
  device=device,
59
  use_auth_token=use_auth_token
 
91
  device: str
92
  Device for diarization.
93
  """
94
+ self.device = device
95
 
96
  os.makedirs(self.model_dir, exist_ok=True)
97
 
modules/whisper/whisper_base.py CHANGED
@@ -130,7 +130,6 @@ class WhisperBase(ABC):
130
  audio=audio,
131
  use_auth_token=params.hf_token,
132
  transcribed_result=result,
133
- device=self.device
134
  )
135
  elapsed_time += elapsed_time_diarization
136
  return result, elapsed_time
 
130
  audio=audio,
131
  use_auth_token=params.hf_token,
132
  transcribed_result=result,
 
133
  )
134
  elapsed_time += elapsed_time_diarization
135
  return result, elapsed_time