# Spaces: Running
from enum import Enum | |
import torch | |
from model_classes import Model200M, Model5M, SyntheticV2 | |
from model_transforms import transform_200M, transform_5M, transform_synthetic | |
class ModelType(str, Enum):
    """Identifiers of the detector checkpoints known to this module.

    Members are also ``str`` instances, so each one compares equal to its
    raw string value.
    """

    MIDJOURNEY_200M = "midjourney_200M"
    DIFFUSIONS_200M = "diffusions_200M"
    MIDJOURNEY_5M = "midjourney_5M"
    DIFFUSIONS_5M = "diffusions_5M"
    SYNTHETIC_DETECTOR_V2 = "synthetic_detector_v2"

    def __str__(self):
        """Return the member's raw string value."""
        return str(self.value)

    @staticmethod
    def get_list():
        """Return the string values of all members, in definition order."""
        # Declared static because the function takes no ``self``: without the
        # decorator, calling it through a member would raise ``TypeError``.
        return [model_type.value for model_type in ModelType]


# NOTE(review): the flattened layout does not show whether ``get_list`` was
# class-level or module-level, so both ``ModelType.get_list()`` and
# ``get_list()`` are kept working. Drop the alias once callers are confirmed.
get_list = ModelType.get_list
def load_model(value: ModelType):
    """Load the checkpoint registered for ``value`` into its model object.

    The model object comes from ``type_to_class`` and the checkpoint path
    from ``type_to_path``. The checkpoint is read onto the CPU, applied with
    ``load_state_dict``, and the model is switched with ``eval()``.

    Args:
        value: Key into ``type_to_class`` and ``type_to_path``.

    Returns:
        The same object that is stored in ``type_to_class[value]``, after the
        checkpoint has been applied to it.

    Raises:
        KeyError: If ``value`` is missing from either lookup table.
    """
    network = type_to_class[value]
    checkpoint_path = type_to_path[value]

    # NOTE(review): ``torch.load`` unpickles the file, so it must only ever
    # be pointed at trusted checkpoint files.
    cpu = torch.device('cpu')
    state = torch.load(checkpoint_path, map_location=cpu)

    network.load_state_dict(state)
    network.eval()
    return network
# Maps each model type to a freshly constructed model object. Every key gets
# its own instance, including the types that share an architecture.
# NOTE: despite the name, the values are instances, not classes.
type_to_class = {
    model_type: architecture()
    for model_type, architecture in (
        (ModelType.MIDJOURNEY_200M, Model200M),
        (ModelType.DIFFUSIONS_200M, Model200M),
        (ModelType.MIDJOURNEY_5M, Model5M),
        (ModelType.DIFFUSIONS_5M, Model5M),
        (ModelType.SYNTHETIC_DETECTOR_V2, SyntheticV2),
    )
}
# Maps each model type to the checkpoint file that ``load_model`` reads.
type_to_path = {
    ModelType.MIDJOURNEY_200M: "models/midjourney200M.pt",
    ModelType.DIFFUSIONS_200M: "models/diffusions200M.pt",
    ModelType.MIDJOURNEY_5M: "models/midjourney5M.pt",
    ModelType.DIFFUSIONS_5M: "models/diffusions5M.pt",
    ModelType.SYNTHETIC_DETECTOR_V2: "models/synthetic_detector_v2.pt",
}
# Loads every model once, when this module is imported, keyed by model type.
# Iterating ``ModelType`` visits the members in definition order, so the
# resulting keys and their order match an explicit member-by-member listing.
type_to_loaded_model = {
    model_type: load_model(model_type) for model_type in ModelType
}
# Maps each model type to its transform object; model types that share an
# architecture also share the same transform.
type_to_transforms = {
    ModelType.MIDJOURNEY_200M: transform_200M,
    ModelType.DIFFUSIONS_200M: transform_200M,
    ModelType.MIDJOURNEY_5M: transform_5M,
    ModelType.DIFFUSIONS_5M: transform_5M,
    ModelType.SYNTHETIC_DETECTOR_V2: transform_synthetic,
}