# dreamsim/model.py
from abc import abstractmethod

import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from torch import Tensor
from torch.nn import functional as F
from torchvision.transforms import v2 as T

from .common import ensure_tuple
from .vit import VisionTransformer, vit_base_dreamsim


class DreamsimBackbone(ModelMixin, ConfigMixin):
@abstractmethod
def forward_features(self, x: Tensor) -> Tensor:
raise NotImplementedError("abstract base class was called ;_;")

    def forward(self, x: Tensor) -> Tensor:
"""Dreamsim forward pass for similarity computation.
Args:
x (Tensor): Input tensor of shape [2, B, 3, H, W].
Returns:
sim (torch.Tensor): dreamsim similarity score of shape [B].
"""
inputs = x.view(-1, 3, *x.shape[-2:])
x = self.forward_features(inputs).view(*x.shape[:2], -1)
return 1 - F.cosine_similarity(x[0], x[1], dim=1)

    def compile(self, *args, **kwargs):
"""Compile the model with Inductor. This is a no-op unless overridden by a subclass."""
return self


class DreamsimModel(DreamsimBackbone):
@register_to_config
def __init__(
self,
image_size: int = 224,
patch_size: int = 16,
layer_norm_eps: float = 1e-6,
pre_norm: bool = False,
act_layer: str = "gelu",
img_mean: tuple[float, float, float] = (0.485, 0.456, 0.406),
img_std: tuple[float, float, float] = (0.229, 0.224, 0.225),
do_resize: bool = False,
) -> None:
super().__init__()
self.image_size = ensure_tuple(image_size, 2)
self.patch_size = ensure_tuple(patch_size, 2)
self.layer_norm_eps = layer_norm_eps
self.pre_norm = pre_norm
self.do_resize = do_resize
self.img_mean = img_mean
self.img_std = img_std
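        # pre_norm=True corresponds to the CLIP-style backbone, which carries a 512-dim
        # projection head; the DINO-style backbone (pre_norm=False) uses none
        # (see the DreamsimEnsemble defaults below).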
num_classes = 512 if self.pre_norm else 0
self.extractor: VisionTransformer = vit_base_dreamsim(
image_size=image_size,
patch_size=patch_size,
layer_norm_eps=layer_norm_eps,
num_classes=num_classes,
pre_norm=pre_norm,
act_layer=act_layer,
)
self.resize = T.Resize(
self.image_size,
interpolation=T.InterpolationMode.BICUBIC,
antialias=True,
)
        self.img_norm = T.Normalize(mean=self.img_mean, std=self.img_std)
        self._compiled = False

    def compile(self, *, mode: str = "reduce-overhead", force: bool = False, **kwargs):
if (not self._compiled) or force:
self.extractor = torch.compile(self.extractor, mode=mode, **kwargs)
self._compiled = True
return self

    def transforms(self, x: Tensor) -> Tensor:
if self.do_resize:
x = self.resize(x)
return self.img_norm(x)

    def forward_features(self, x: Tensor) -> Tensor:
if x.ndim == 3:
x = x.unsqueeze(0)
x = self.transforms(x)
x = self.extractor.forward(x, norm=self.pre_norm)
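        # Scale each embedding to unit L2 norm, then center it by subtracting its
        # mean over the feature dimension.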
x = x.div(x.norm(dim=1, keepdim=True))
x = x.sub(x.mean(dim=1, keepdim=True))
return x


class DreamsimEnsemble(DreamsimBackbone):
@register_to_config
def __init__(
self,
image_size: int = 224,
patch_size: int = 16,
layer_norm_eps: float | tuple[float, ...] = (1e-6, 1e-5, 1e-5),
num_classes: int | tuple[int, ...] = (0, 512, 512),
do_resize: bool = False,
) -> None:
super().__init__()
if isinstance(layer_norm_eps, float):
layer_norm_eps = (layer_norm_eps,) * 3
if isinstance(num_classes, int):
num_classes = (num_classes,) * 3
self.image_size = ensure_tuple(image_size, 2)
self.patch_size = ensure_tuple(patch_size, 2)
self.do_resize = do_resize
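        # Three ViT-Base backbones: a DINO-style one (post-norm, GELU, no projection head)
        # and two CLIP-style ones (pre-norm, 512-dim heads, different GELU variants).
        # Their features are concatenated in forward_features.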
self.dino: VisionTransformer = vit_base_dreamsim(
image_size=self.image_size,
patch_size=self.patch_size,
layer_norm_eps=layer_norm_eps[0],
num_classes=num_classes[0],
pre_norm=False,
act_layer="gelu",
)
self.clip1: VisionTransformer = vit_base_dreamsim(
image_size=self.image_size,
patch_size=self.patch_size,
layer_norm_eps=layer_norm_eps[1],
num_classes=num_classes[1],
pre_norm=True,
act_layer="quick_gelu",
)
self.clip2: VisionTransformer = vit_base_dreamsim(
image_size=self.image_size,
patch_size=self.patch_size,
layer_norm_eps=layer_norm_eps[2],
num_classes=num_classes[2],
pre_norm=True,
act_layer="gelu",
)
self.resize = T.Resize(
self.image_size,
interpolation=T.InterpolationMode.BICUBIC,
antialias=True,
)
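        # ImageNet mean/std for the DINO branch, OpenAI CLIP mean/std for the CLIP branches.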
self.dino_norm = T.Normalize(
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
)
self.clip_norm = T.Normalize(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
)
self._compiled = False

    def compile(self, *, mode: str = "reduce-overhead", force: bool = False, **kwargs):
if (not self._compiled) or force:
self.dino = torch.compile(self.dino, mode=mode, **kwargs)
self.clip1 = torch.compile(self.clip1, mode=mode, **kwargs)
self.clip2 = torch.compile(self.clip2, mode=mode, **kwargs)
self._compiled = True
return self

    def transforms(self, x: Tensor, resize: bool = False) -> tuple[Tensor, Tensor, Tensor]:
if resize:
x = self.resize(x)
        return self.dino_norm(x), self.clip_norm(x), self.clip_norm(x)

    def forward_features(self, x: Tensor) -> Tensor:
if x.ndim == 3:
x = x.unsqueeze(0)
x_dino, x_clip1, x_clip2 = self.transforms(x, self.do_resize)
# these expect to always receive a batch, and will return a batch
x_dino = self.dino.forward(x_dino, norm=False)
x_clip1 = self.clip1.forward(x_clip1, norm=True)
x_clip2 = self.clip2.forward(x_clip2, norm=True)
z: Tensor = torch.cat([x_dino, x_clip1, x_clip2], dim=1)
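        # Scale the concatenated ensemble embedding to unit L2 norm, then center it
        # by subtracting its mean over the feature dimension.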
z = z.div(z.norm(dim=1, keepdim=True))
z = z.sub(z.mean(dim=1, keepdim=True))
return z
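

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module). Weights are
    # randomly initialised here; in practice the models would be loaded from pretrained
    # checkpoints (e.g. via ModelMixin.from_pretrained).
    pair = torch.rand(2, 4, 3, 224, 224)  # [2, B, 3, H, W]: two batches of 4 RGB images

    single = DreamsimModel().eval()
    ensemble = DreamsimEnsemble().eval()

    with torch.inference_mode():
        # Both return a Dreamsim distance of shape [B]; lower means more similar.
        print(single(pair).shape, ensemble(pair).shape)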