# Spaces:
# Running
# Running
import pytorch_lightning as pl
import timm
import torch
from pytorch_lightning.core.mixins import HyperparametersMixin
from torch import nn
class Model200M(torch.nn.Module):
    """timm ConvNeXt-Large backbone followed by a small two-output MLP head.

    The backbone is created through ``timm.create_model`` with
    ``pretrained=False`` and ``num_classes=0``. Its output is passed straight
    into ``self.clf``, whose first layer expects 1536 input features and whose
    last layer emits 2 values per sample.
    """

    def __init__(self) -> None:
        """Build the backbone and the ``1536 -> 128 -> 2`` head."""
        super().__init__()
        # NOTE(review): pretrained=False, so weights are presumably restored
        # from a checkpoint by the caller -- confirm.
        self.model = timm.create_model(
            'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384',
            pretrained=False,
            num_classes=0,
        )
        # The attribute names ``model`` and ``clf`` become state-dict key
        # prefixes, so they must stay as they are.
        self.clf = nn.Sequential(
            nn.Linear(1536, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2),
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Run ``image`` through the backbone, then through the head.

        Args:
            image: Input accepted by the backbone (presumably a batch of
                image tensors -- verify against the caller).

        Returns:
            The output of ``self.clf`` applied to the backbone features;
            the last dimension has size 2.
        """
        image_features = self.model(image)
        return self.clf(image_features)
class Model5M(torch.nn.Module):
    """timm MobileNetV3-Large backbone followed by a small two-output MLP head.

    The backbone is created through ``timm.create_model`` with
    ``pretrained=False`` and ``num_classes=0``. Its output is passed straight
    into ``self.clf``, whose first layer expects 1280 input features and whose
    last layer emits 2 values per sample.
    """

    def __init__(self) -> None:
        """Build the backbone and the ``1280 -> 128 -> 2`` head."""
        super().__init__()
        # NOTE(review): pretrained=False, so weights are presumably restored
        # from a checkpoint by the caller -- confirm.
        self.model = timm.create_model(
            'timm/tf_mobilenetv3_large_100.in1k',
            pretrained=False,
            num_classes=0,
        )
        # The attribute names ``model`` and ``clf`` become state-dict key
        # prefixes, so they must stay as they are.
        self.clf = nn.Sequential(
            nn.Linear(1280, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2),
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Run ``image`` through the backbone, then through the head.

        Args:
            image: Input accepted by the backbone (presumably a batch of
                image tensors -- verify against the caller).

        Returns:
            The output of ``self.clf`` applied to the backbone features;
            the last dimension has size 2.
        """
        image_features = self.model(image)
        return self.clf(image_features)
class SyntheticV2(pl.LightningModule, HyperparametersMixin):
    """Lightning module wrapping a timm ConvNeXt-Large backbone and MLP head.

    The backbone is created through ``timm.create_model`` with
    ``pretrained=False`` and ``num_classes=0``. Its output is passed straight
    into ``self.clf``, whose first layer expects 1536 input features and whose
    last layer emits 2 values per sample.
    """

    def __init__(self) -> None:
        """Build the backbone and the ``1536 -> 128 -> 2`` head."""
        super().__init__()
        # NOTE(review): pretrained=False, so weights are presumably restored
        # from a Lightning checkpoint by the caller -- confirm.
        self.model = timm.create_model(
            'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384',
            pretrained=False,
            num_classes=0,
        )
        # The attribute names ``model`` and ``clf`` become state-dict key
        # prefixes, so they must stay as they are.
        self.clf = nn.Sequential(
            nn.Linear(1536, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2),
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Run ``image`` through the backbone, then through the head.

        Args:
            image: Input accepted by the backbone (presumably a batch of
                image tensors -- verify against the caller).

        Returns:
            The output of ``self.clf`` applied to the backbone features;
            the last dimension has size 2.
        """
        image_features = self.model(image)
        return self.clf(image_features)