# Spaces: No application file
import gc
import os

import gradio
import torch.cuda
from transformers import Pipeline
def choices():
    """Return every installed model as a ``'<model_type>/<model_name>'`` string.

    Model types are taken from ``.download.model_types`` in order; for each
    type, the installed names come from ``get_installed_models``.
    """
    from .download import model_types
    result = []
    for model_type in model_types:
        for model_name in get_installed_models(model_type):
            result.append(model_type + '/' + model_name)
    return result
def refresh_choices():
    """Build a gradio Dropdown update that clears the value and reloads the options from ``choices()``."""
    return gradio.Dropdown.update(value='', choices=choices())
def get_installed_models(model_type):
    """Return the names of the model directories installed for ``model_type``.

    Models are looked up in ``data/models/<model_type>/`` (relative to the
    current working directory). Each sub-directory counts as one model; plain
    files are ignored. The directory, and any missing parent directories, is
    created on first use, so a fresh checkout yields an empty list instead of
    an error. Names are returned in ``os.listdir`` order.
    """
    _dir = f'data/models/{model_type}'
    # makedirs also creates missing parents (os.mkdir fails when data/models
    # itself is absent), and exist_ok removes the check-then-create race.
    os.makedirs(_dir, exist_ok=True)
    return [name for name in os.listdir(_dir) if os.path.isdir(os.path.join(_dir, name))]
class ModelLoader:
    """Base class that loads and holds one model pipeline for a given model type.

    Models are loaded from ``data/models/<model_type>/<name>``. Subclasses
    implement ``get_response`` (and may override the loading hooks).
    """

    # Class-level default that subclasses may override. It is read when matching
    # loaders by model path; presumably True for loaders whose models are not
    # installed under data/models (TODO confirm).
    no_install = False

    def __init__(self, model_type):
        """Create an empty loader for ``model_type`` (also used as the pipeline task)."""
        self.type = model_type
        # None until load_model() succeeds, and again after unload_model().
        self.pipeline: Pipeline = None

    def load_model(self, name):
        """Load the model stored in ``data/models/<type>/<name>`` into ``self.pipeline``."""
        _dir = f'data/models/{self.type}/{name}'
        self.pipeline = self._load_internal(_dir)

    def _load_internal(self, path):
        """Build and return a pipeline for task ``self.type`` from the model at ``path``."""
        # ``transformers.Pipeline`` has no ``from_pretrained`` classmethod; the
        # ``pipeline`` factory function is the supported constructor.
        from transformers import pipeline
        return pipeline(task=self.type, model=path)

    def unload_model(self):
        """Drop the loaded pipeline and free cached GPU memory if it was off the CPU.

        Does nothing when no pipeline is loaded, so it is safe to call twice.
        """
        loaded = self.pipeline
        if loaded is None:
            return
        # Read the device before dropping the reference; once the attribute is
        # gone it can no longer be inspected. ``torch.device`` does not compare
        # equal to a plain string, so compare its ``type`` ('cpu', 'cuda', ...).
        device = getattr(loaded, 'device', 'cpu')
        device_type = getattr(device, 'type', device)
        # Reset instead of ``del`` so get_loaded_model() keeps working afterwards.
        self.pipeline = None
        del loaded
        # Collect first so the model's tensors are really unreferenced, then let
        # the CUDA caching allocator release the freed blocks.
        gc.collect()
        if device_type != 'cpu':
            torch.cuda.empty_cache()

    def get_loaded_model(self):
        """Return the currently loaded pipeline, or None if nothing is loaded."""
        return self.pipeline

    def get_response(self, *inputs):
        """Run inference on ``inputs``; must be implemented by subclasses."""
        raise NotImplementedError('Not implemented, please implement this method.')
def all_tts():
    """Return whatever ``webui.modules.implementations.tts.all_tts()`` yields.

    Presumably one loader instance per TTS implementation (TODO confirm). The
    import is done lazily inside the function.
    """
    import webui.modules.implementations as implementations
    tts_registry = implementations.tts
    return tts_registry.all_tts()
def all_tts_models():
    """Return the ``model`` attribute of every loader from ``all_tts()``, in order."""
    loaders = all_tts()
    return [loader.model for loader in loaders]
class TTSModelLoader(ModelLoader):
    """Base class for text-to-speech model loaders.

    Subclasses must define ``model`` (presumably as a class attribute holding a
    ``'<owner>/<name>'`` style id; TODO confirm). They implement loading,
    unloading, inference and the gradio component hooks.
    """

    # Provided by subclasses; __init__ derives ``trigger`` from it.
    model: str
    # ``model`` with every '/' replaced by '--'; from_model() matches against it.
    trigger: str

    def __init__(self):
        super().__init__('text-to-speech')
        self.trigger = self.model.replace('/', '--')

    # The ``progress=gradio.Progress()`` defaults below follow gradio's
    # progress-tracking convention (gradio inspects the default); they are part
    # of the public signature and are intentionally kept.
    def get_response(self, *inputs, progress=gradio.Progress()):
        """Run TTS inference on ``inputs``; must be implemented by subclasses."""
        raise NotImplementedError('Not implemented, please implement this method.')

    def load_model(self, progress=gradio.Progress()):
        """Load this loader's model; must be implemented by subclasses."""
        raise NotImplementedError('Not implemented, please implement this method.')

    def unload_model(self):
        """Release this loader's model; must be implemented by subclasses."""
        raise NotImplementedError('Not implemented, please implement this method.')

    @staticmethod
    def from_model(model_path):
        """Return the loader from ``all_tts()`` that matches ``model_path``, or None.

        Matching is case-insensitive. For each loader in order, a loader with
        ``no_install`` set matches when its trigger equals the whole path with
        '/' turned into '--'. Any loader matches when its trigger equals the
        last '/'-separated component of the path. The first match wins.
        """
        # Static: the function never used an instance, and without the
        # decorator a call through an instance would bind it to ``model_path``.
        lowered = model_path.lower()
        full_trigger = lowered.replace('/', '--')
        last_part = lowered.split('/')[-1]
        for model in all_tts():
            trigger = model.trigger.lower()
            if model.no_install and trigger == full_trigger:
                return model
            if trigger == last_part:
                return model
        return None

    def _components(self, **quick_kwargs):
        """Create this loader's gradio input components; must be implemented by subclasses."""
        raise NotImplementedError('Not implemented, please implement this method')

    def gradio_components(self):
        """Return this loader's components, created interactive and hidden; [] if none."""
        components = self._components(interactive=True, visible=False)
        return components or []