File size: 1,331 Bytes
b7920e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import time
import typing as tp
from audiocraft.models import MusicGen
from audiocraft.modules.conditioners import ConditioningAttributes

class HijackedMusicGen(MusicGen):
    """MusicGen subclass that reports generation progress through a socket.io object.

    When constructed with a non-None ``socketio``, the instance installs
    ``_timed_progress_callback`` as its progress callback. That callback emits
    ``'progress'`` events at most once every ``progress_interval`` seconds,
    and always emits the final step. Without ``socketio`` no progress
    callback is installed.
    """

    # Minimum number of seconds between two throttled 'progress' emits.
    # Override on a subclass or an instance to change the rate.
    progress_interval: float = 0.1

    def __init__(self, socketio=None, *args, **kwargs):
        """Create the model and, if ``socketio`` is given, wire up progress reporting.

        Args:
            socketio: Object exposing ``emit(event, data)``. Presumably a
                Flask-SocketIO ``SocketIO`` server; confirm against caller.
                If ``None``, no progress callback is installed.
            *args: Forwarded unchanged to ``MusicGen.__init__``.
            **kwargs: Forwarded unchanged to ``MusicGen.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.socketio = socketio
        # Set after super().__init__ so this value is the one that sticks.
        self._progress_callback = (
            self._timed_progress_callback if socketio is not None else None
        )
        # Monotonic clock: interval measurement must not jump when the wall
        # clock is adjusted (NTP sync, manual changes).
        self._last_update_time = time.monotonic()

    def _timed_progress_callback(self, generated_tokens: int, tokens_to_generate: int) -> None:
        """Emit a throttled ``'progress'`` socket.io event.

        An event is emitted when at least ``progress_interval`` seconds have
        elapsed since the previous emit. The final step
        (``generated_tokens >= tokens_to_generate``) is always emitted, so
        clients are not left short of completion when the last step lands
        inside the throttle window.

        Args:
            generated_tokens: Number of tokens generated so far.
            tokens_to_generate: Total number of tokens expected.
        """
        now = time.monotonic()
        is_final = generated_tokens >= tokens_to_generate
        if is_final or now - self._last_update_time >= self.progress_interval:
            self.socketio.emit(
                'progress',
                {'generated_tokens': generated_tokens, 'tokens_to_generate': tokens_to_generate},
            )
            self._last_update_time = now

    @classmethod
    def get_pretrained(cls, socketio, name: str = 'melody', device='cuda'):
        """Load a pretrained MusicGen checkpoint and rewrap it in this class.

        Args:
            socketio: Passed to the constructor; ``None`` disables progress events.
            name: Checkpoint name forwarded to ``MusicGen.get_pretrained``.
            device: Device forwarded to ``MusicGen.get_pretrained``.

        Returns:
            A new instance of ``cls`` built from the loaded model's ``name``,
            ``compression_model`` and ``lm``.
        """
        music_gen = MusicGen.get_pretrained(name, device)
        # Only name / compression_model / lm are carried over. Any other
        # state on the temporary ``music_gen`` is discarded.
        return cls(socketio, music_gen.name, music_gen.compression_model, music_gen.lm)

    @property
    def progress_callback(self):
        """Write-only attribute: reading it always raises ``AttributeError``.

        ``AttributeError`` (rather than a bare ``Exception``) keeps ``hasattr``
        and ``getattr(obj, name, default)`` working as expected.
        """
        raise AttributeError("Progress callback is write-only")

    @progress_callback.setter
    def progress_callback(self, callback):
        """Replace the active progress callback.

        This replaces any socket.io callback installed by ``__init__``.
        ``None`` matches the value ``__init__`` uses when no ``socketio`` is
        given.
        """
        self._progress_callback = callback