aadnk's picture
Ensure GPU memory in diarization can be cleaned up
18bb72f
raw
history blame
7.47 kB
import argparse
import gc
import json
import os
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Callable, List, TextIO

import ffmpeg
import torch
class DiarizationEntry:
    """A single speaker turn: a time span plus the label of the speaker talking in it."""

    def __init__(self, start, end, speaker):
        self.start, self.end, self.speaker = start, end, speaker

    def __repr__(self):
        fields = f"start={self.start} end={self.end} speaker={self.speaker}"
        return f"<DiarizationEntry {fields}>"

    def toJson(self):
        """Return this entry as a plain dict with keys start, end and speaker."""
        return dict(start=self.start, end=self.end, speaker=self.speaker)
class Diarization:
    """Speaker diarization backed by the ``pyannote/speaker-diarization@2.1`` pipeline.

    The pipeline is loaded lazily on first use (see ``initialize``) and can be
    released again with ``cleanup`` to free GPU memory; a later ``run`` reloads it.
    """

    def __init__(self, auth_token=None):
        """Store the HuggingFace token used to download the pipeline.

        Args:
            auth_token: HuggingFace API token. When None, the ``HK_ACCESS_TOKEN``
                environment variable is used instead.

        Raises:
            ValueError: if no token is given and the environment variable is unset.
        """
        if auth_token is None:
            auth_token = os.environ.get("HK_ACCESS_TOKEN")
            if auth_token is None:
                raise ValueError("No HuggingFace API Token provided - please use the --auth_token argument or set the HK_ACCESS_TOKEN environment variable")

        self.auth_token = auth_token
        self.initialized = False
        self.pipeline = None

    @staticmethod
    def has_libraries():
        """Return True if the optional ``pyannote.audio`` and ``intervaltree`` packages import."""
        try:
            import pyannote.audio
            import intervaltree
            return True
        except ImportError:
            return False

    def initialize(self):
        """Load the pyannote pipeline once and move it to the GPU if CUDA is available."""
        if self.initialized:
            return
        from pyannote.audio import Pipeline

        pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1", use_auth_token=self.auth_token)
        if pipeline is None:
            # NOTE(review): from_pretrained may return None instead of raising (e.g. the
            # token lacks access to the gated model) - confirm for the pinned pyannote version.
            raise RuntimeError("Could not load pyannote/speaker-diarization@2.1 - check the HuggingFace token and model access")

        # Load GPU mode if available
        if torch.cuda.is_available():
            print("Diarization - using GPU")
            pipeline = pipeline.to(torch.device(0))
        else:
            print("Diarization - using CPU")

        # Publish only after setup succeeded, so a failure above leaves the
        # instance uninitialized and a later call can retry.
        self.pipeline = pipeline
        self.initialized = True

    def cleanup(self):
        """Release the loaded pipeline and return cached GPU memory to the driver.

        Safe to call when nothing is loaded. The next ``run`` reloads the pipeline.
        """
        if self.pipeline is None:
            return
        self.pipeline = None
        self.initialized = False
        # Collect unreachable tensors first so empty_cache can actually free them
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def run(self, audio_file, **kwargs):
        """Diarize ``audio_file`` and yield one DiarizationEntry per speaker turn.

        Files soundfile can read directly are passed through. Anything else is
        first converted to a temporary mono WAV with ffmpeg; that temp file is
        deleted (even on error) before any entry is yielded.

        Args:
            audio_file: path to the input audio file.
            **kwargs: forwarded to the pyannote pipeline (e.g. num_speakers,
                min_speakers, max_speakers).

        Yields:
            DiarizationEntry with the turn's start/end times and speaker label.

        Raises:
            ffmpeg.Error: if converting the input to WAV fails.
        """
        self.initialize()

        # Supported file types in soundfile is WAV, FLAC, OGG and MAT
        if Path(audio_file).suffix.lower() in (".wav", ".flac", ".ogg", ".mat"):
            diarization = self.pipeline(audio_file, **kwargs)
        else:
            diarization = self._run_converted(audio_file, **kwargs)

        for turn, _, speaker in diarization.itertracks(yield_label=True):
            yield DiarizationEntry(turn.start, turn.end, speaker)

    def _run_converted(self, audio_file, **kwargs):
        """Convert ``audio_file`` to a temp mono WAV, run the pipeline on it, always delete the temp file."""
        # mkstemp creates the file atomically, unlike the deprecated and race-prone mktemp
        fd, target_file = tempfile.mkstemp(prefix="diarization_", suffix=".wav")
        os.close(fd)
        try:
            try:
                # The temp file already exists, so ffmpeg must be told to overwrite it
                # instead of prompting; stderr is captured so the error below has content.
                ffmpeg.input(audio_file).output(target_file, ac=1).overwrite_output().run(capture_stderr=True)
            except ffmpeg.Error as e:
                details = e.stderr.decode("utf-8", errors="replace") if e.stderr else e
                print(f"Error occurred during audio conversion: {details}")
                raise
            return self.pipeline(target_file, **kwargs)
        finally:
            if os.path.exists(target_file):
                os.remove(target_file)

    def mark_speakers(self, diarization_result: "List[DiarizationEntry]", whisper_result: dict):
        """Attach speaker information to every Whisper segment a diarization turn overlaps.

        Args:
            diarization_result: entries as produced by ``run``.
            whisper_result: Whisper transcript dict with a ``"segments"`` list whose
                items carry ``"start"`` and ``"end"`` keys.

        Returns:
            A copy of ``whisper_result``. Each segment with at least one overlapping
            turn gains ``"longest_speaker"`` (speaker with the largest overlap) and
            ``"speakers"`` (``toJson()`` of every overlapping entry, ordered by start
            time). Neither ``whisper_result`` nor its segment dicts are modified.
        """
        from intervaltree import IntervalTree

        # Copy the segment dicts too: a shallow copy of the outer dict alone would
        # let the loop below write into the caller's segments.
        result = whisper_result.copy()
        result["segments"] = [dict(segment) for segment in whisper_result["segments"]]

        # Create an interval tree from the diarization results
        tree = IntervalTree()
        for entry in diarization_result:
            # intervaltree rejects null (zero- or negative-length) intervals with ValueError
            if entry.end > entry.start:
                tree[entry.start:entry.end] = entry

        for segment in result["segments"]:
            segment_start = segment["start"]
            segment_end = segment["end"]

            # The tree returns an unordered set; sort for deterministic output and tie-breaking
            overlapping = sorted(
                tree[segment_start:segment_end],
                key=lambda iv: (iv.begin, iv.end, str(iv.data.speaker)),
            )
            if not overlapping:
                continue

            # If multiple speakers overlap with this segment, choose the one with the longest overlap
            longest_speaker = None
            longest_duration = 0
            for interval in overlapping:
                overlap = min(interval.end, segment_end) - max(interval.begin, segment_start)
                if overlap > longest_duration:
                    longest_speaker = interval.data.speaker
                    longest_duration = overlap

            segment["longest_speaker"] = longest_speaker
            segment["speakers"] = [interval.data.toJson() for interval in overlapping]

        # write_srt uses longest_speaker, if present, and adds it to the text field
        return result
def _write_file(input_file: str, output_path: str, output_extension: str, file_writer: lambda f: None):
if input_file is None:
raise ValueError("input_file is required")
if file_writer is None:
raise ValueError("file_writer is required")
# Write file
if output_path is None:
effective_path = os.path.splitext(input_file)[0] + "_output" + output_extension
else:
effective_path = output_path
with open(effective_path, 'w+', encoding="utf-8") as f:
file_writer(f)
print(f"Output saved to {effective_path}")
def _build_arg_parser():
    """Create the command-line parser used by ``main``."""
    parser = argparse.ArgumentParser(description='Add speakers to a SRT file or Whisper JSON file using pyannote/speaker-diarization.')
    parser.add_argument('audio_file', type=str, help='Input audio file')
    parser.add_argument('whisper_file', type=str, help='Input Whisper JSON/SRT file')

    # Optional string options, all defaulting to None
    for flag, help_text in (
        ('--output_json_file', 'Output JSON file (optional)'),
        ('--output_srt_file', 'Output SRT file (optional)'),
        ('--auth_token', 'HuggingFace API Token (optional)'),
    ):
        parser.add_argument(flag, type=str, default=None, help=help_text)

    parser.add_argument("--max_line_width", type=int, default=40, help="Maximum line width for SRT file (default: 40)")

    # Optional speaker-count hints passed straight to the diarization pipeline
    for flag, help_text in (
        ("--num_speakers", "Number of speakers"),
        ("--min_speakers", "Minimum number of speakers"),
        ("--max_speakers", "Maximum number of speakers"),
    ):
        parser.add_argument(flag, type=int, default=None, help=help_text)

    return parser


def main():
    """CLI entry point: diarize an audio file, tag its transcript with speakers, write JSON and SRT."""
    # Project-local helpers, imported only when the CLI actually runs
    from src.utils import write_srt
    from src.diarization.transcriptLoader import load_transcript

    args = _build_arg_parser().parse_args()

    print("\nReading whisper JSON from " + args.whisper_file)
    # The transcript may be either Whisper JSON or SRT
    whisper_result = load_transcript(args.whisper_file)

    diarization = Diarization(auth_token=args.auth_token)
    speaker_hints = dict(
        num_speakers=args.num_speakers,
        min_speakers=args.min_speakers,
        max_speakers=args.max_speakers,
    )
    diarization_result = list(diarization.run(args.audio_file, **speaker_hints))

    print("Diarization result:")
    for entry in diarization_result:
        print(f" start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")

    marked = diarization.mark_speakers(diarization_result, whisper_result)

    def dump_json(f):
        json.dump(marked, f, indent=4, ensure_ascii=False)

    def dump_srt(f):
        write_srt(marked["segments"], f, maxLineWidth=args.max_line_width)

    _write_file(args.whisper_file, args.output_json_file, ".json", dump_json)
    _write_file(args.whisper_file, args.output_srt_file, ".srt", dump_srt)
if __name__ == "__main__":
    # Run the diarization CLI when executed as a script
    main()