"""Registry of model configurations, keyed by model name."""

from collections import namedtuple
from typing import Dict, List, Optional

# One record per registered model. The meaning of the ``i2s_model`` and
# ``online_model`` flags is defined by the consumers of this registry
# (not visible here) -- NOTE(review): document them where they are read.
ModelConfig = namedtuple(
    "ModelConfig", ["model_name", "i2s_model", "online_model", "model_path"]
)

# Maps model name -> ModelConfig. Insertion order follows registration order.
model_config: Dict[str, ModelConfig] = {}


class UnknownModelError(KeyError, AssertionError):
    """Raised by ``get_model_config`` for a name that was never registered.

    It derives from ``KeyError`` because this is a failed lookup. It also
    derives from ``AssertionError`` because earlier versions raised that via
    an ``assert``. Handlers written for either exception keep working.
    """


def register_model_config(
    model_name: str,
    i2s_model: bool,
    online_model: bool,
    model_path: Optional[str] = None,
):
    """Create a ``ModelConfig`` and store it in ``model_config``.

    Args:
        model_name: Registry key, also stored as ``ModelConfig.model_name``.
        i2s_model: Flag stored verbatim in the config.
        online_model: Flag stored verbatim in the config.
        model_path: Optional path stored in the config; ``None`` if omitted.

    Re-registering an existing name replaces the earlier entry.
    """
    config = ModelConfig(model_name, i2s_model, online_model, model_path)
    model_config[model_name] = config


def get_model_config(model_name: str) -> ModelConfig:
    """Return the ``ModelConfig`` registered under ``model_name``.

    Raises:
        UnknownModelError: If ``model_name`` was never registered. The
            message lists the registered names. This is an explicit raise,
            not an ``assert``, so the check also runs under ``python -O``.
    """
    try:
        return model_config[model_name]
    except KeyError:
        known = ", ".join(model_config) or "<none>"
        raise UnknownModelError(
            f"Unknown model {model_name!r}; registered models: {known}"
        ) from None


register_model_config(model_name="dreamfusion", i2s_model=False, online_model=False)
register_model_config(model_name="instant3d", i2s_model=False, online_model=False)
register_model_config(model_name="latent-nerf", i2s_model=False, online_model=False)
register_model_config(model_name="magic3d", i2s_model=False, online_model=False)
# Currently disabled (kept for reference; re-enable by uncommenting):
# register_model_config(model_name="mvdream", i2s_model=False, online_model=False)
# register_model_config(model_name="prolificdreamer", i2s_model=False, online_model=False)
register_model_config(model_name="dreamgaussian", i2s_model=True, online_model=False)
# Currently disabled (kept for reference; re-enable by uncommenting):
# register_model_config(model_name="wonder3d", i2s_model=True, online_model=False)
register_model_config(model_name="lgm", i2s_model=True, online_model=False)
register_model_config(model_name="openlrm", i2s_model=True, online_model=False)
register_model_config(model_name="triplane-gaussian", i2s_model=True, online_model=False)