Spaces:
Runtime error
Runtime error
from .config import set_layer_config | |
from .helpers import load_checkpoint | |
from .gen_efficientnet import * | |
from .mobilenetv3 import * | |
def create_model(
        model_name='mnasnet_100',
        pretrained=None,
        num_classes=1000,
        in_chans=3,
        checkpoint_path='',
        **kwargs):
    """Build a model by resolving ``model_name`` to a factory callable.

    The name is looked up in this module's global namespace (which includes
    everything pulled in by the module's imports) and the resulting callable
    is invoked with the assembled keyword arguments.

    Args:
        model_name: Name of a callable available in this module's globals.
        pretrained: Forwarded to the factory as ``pretrained``. When truthy,
            ``checkpoint_path`` is ignored.
        num_classes: Forwarded to the factory as ``num_classes``.
        in_chans: Forwarded to the factory as ``in_chans``.
        checkpoint_path: When non-empty and ``pretrained`` is falsy, the
            created model and this path are passed to ``load_checkpoint``.
        **kwargs: Additional keyword arguments forwarded to the factory.

    Returns:
        The object returned by the resolved factory.

    Raises:
        RuntimeError: If ``model_name`` does not name a callable in this
            module's globals.
    """
    model_kwargs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained, **kwargs)

    # Single lookup instead of "in" + index. Requiring a callable keeps
    # non-function globals (e.g. '__name__', '__doc__') from slipping through
    # and failing later with an opaque "'str' object is not callable".
    # NOTE(review): any callable global still qualifies, not only model
    # factories; a dedicated registry would be needed to narrow this further.
    create_fn = globals().get(model_name)
    if not callable(create_fn):
        raise RuntimeError('Unknown model (%s)' % model_name)
    model = create_fn(**model_kwargs)

    # An explicit checkpoint is only applied when pretrained weights were
    # not requested.
    if checkpoint_path and not pretrained:
        load_checkpoint(model, checkpoint_path)

    return model