neggles committed on
Commit 8a8582e
1 Parent(s): f664757

change dreamsim class hierarchy a bit

Files changed (3)
  1. __init__.py +3 -2
  2. model.py +55 -40
  3. vit.py +3 -3
__init__.py CHANGED
@@ -1,9 +1,10 @@
-from .model import DreamsimEnsemble, DreamsimModel
+from .model import DreamsimBackbone, DreamsimEnsemble, DreamsimModel
 from .vit import VisionTransformer, vit_base_dreamsim
 
 __all__ = [
-    "DreamsimModel",
+    "DreamsimBackbone",
     "DreamsimEnsemble",
+    "DreamsimModel",
     "VisionTransformer",
     "vit_base_dreamsim",
 ]
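A quick, hedged sketch of how the widened exports might be consumed; the installed package name ("dreamsim" below) is an assumption, since only relative imports appear in this diff.

# Hypothetical usage; the top-level package name "dreamsim" is assumed.
from dreamsim import DreamsimBackbone, DreamsimEnsemble, DreamsimModel

def is_dreamsim(model) -> bool:
    # With the new hierarchy, both concrete variants share one base class,
    # so a single isinstance check covers DreamsimModel and DreamsimEnsemble.
    return isinstance(model, DreamsimBackbone)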
model.py CHANGED
@@ -1,3 +1,5 @@
+from abc import abstractmethod
+
 import torch
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.models.modeling_utils import ModelMixin
@@ -9,7 +11,31 @@ from .common import ensure_tuple
 from .vit import VisionTransformer, vit_base_dreamsim
 
 
-class DreamsimModel(ModelMixin, ConfigMixin):
+class DreamsimBackbone(ModelMixin, ConfigMixin):
+    @abstractmethod
+    def forward_features(self, x: Tensor) -> Tensor:
+        raise NotImplementedError("abstract base class was called ;_;")
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Dreamsim forward pass for similarity computation.
+        Args:
+            x (Tensor): Input tensor of shape [2, B, 3, H, W].
+
+        Returns:
+            sim (torch.Tensor): dreamsim similarity score of shape [B].
+        """
+        inputs = x.view(-1, 3, *x.shape[-2:])
+
+        x = self.forward_features(inputs).view(*x.shape[:2], -1)
+
+        return 1 - F.cosine_similarity(x[0], x[1], dim=1)
+
+    def compile(self, *args, **kwargs):
+        """Compile the model with Inductor. This is a no-op unless overridden by a subclass."""
+        return self
+
+
+class DreamsimModel(DreamsimBackbone):
     @register_to_config
     def __init__(
         self,
@@ -25,7 +51,7 @@ class DreamsimModel(ModelMixin, ConfigMixin):
         super().__init__()
 
         self.image_size = ensure_tuple(image_size, 2)
-        self.patch_size = patch_size
+        self.patch_size = ensure_tuple(patch_size, 2)
         self.layer_norm_eps = layer_norm_eps
         self.pre_norm = pre_norm
         self.do_resize = do_resize
@@ -49,6 +75,12 @@ class DreamsimModel(ModelMixin, ConfigMixin):
         )
         self.img_norm = T.Normalize(mean=self.img_mean, std=self.img_std)
 
+    def compile(self, *, mode: str = "reduce-overhead", force: bool = False, **kwargs):
+        if (not self._compiled) or force:
+            self.extractor = torch.compile(self.extractor, mode=mode, **kwargs)
+            self._compiled = True
+        return self
+
     def transforms(self, x: Tensor) -> Tensor:
         if self.do_resize:
             x = self.resize(x)
@@ -60,42 +92,29 @@ class DreamsimModel(ModelMixin, ConfigMixin):
         x = self.transforms(x)
         x = self.extractor.forward(x, norm=self.pre_norm)
 
-        x.div_(x.norm(dim=1, keepdim=True))
-        x.sub_(x.mean(dim=1, keepdim=True))
+        x = x.div(x.norm(dim=1, keepdim=True))
+        x = x.sub(x.mean(dim=1, keepdim=True))
         return x
 
-    def forward(self, x: Tensor) -> Tensor:
-        """Dreamsim forward pass for similarity computation.
-        Args:
-            x (Tensor): Input tensor of shape [2, B, 3, H, W].
-
-        Returns:
-            sim (torch.Tensor): dreamsim similarity score of shape [B].
-        """
-        all_images = x.view(-1, 3, *x.shape[-2:])
-
-        x = self.forward_features(all_images)
-        x = x.view(*x.shape[:2], -1)
-
-        return 1 - F.cosine_similarity(x[0], x[1], dim=1)
-
-
-class DreamsimEnsemble(ModelMixin, ConfigMixin):
+
+class DreamsimEnsemble(DreamsimBackbone):
     @register_to_config
     def __init__(
         self,
         image_size: int = 224,
         patch_size: int = 16,
         layer_norm_eps: float | tuple[float, ...] = (1e-6, 1e-5, 1e-5),
-        num_classes: tuple[int, int, int] = (0, 512, 512),
+        num_classes: int | tuple[int, ...] = (0, 512, 512),
        do_resize: bool = False,
     ) -> None:
        super().__init__()
        if isinstance(layer_norm_eps, float):
            layer_norm_eps = (layer_norm_eps,) * 3
+        if isinstance(num_classes, int):
+            num_classes = (num_classes,) * 3
 
         self.image_size = ensure_tuple(image_size, 2)
-        self.patch_size = patch_size
+        self.patch_size = ensure_tuple(patch_size, 2)
         self.do_resize = do_resize
 
         self.dino: VisionTransformer = vit_base_dreamsim(
@@ -137,10 +156,21 @@ class DreamsimEnsemble(ModelMixin, ConfigMixin):
             std=(0.26862954, 0.26130258, 0.27577711),
         )
 
+        self._compiled = False
+
+    def compile(self, *, mode: str = "reduce-overhead", force: bool = False, **kwargs):
+        if (not self._compiled) or force:
+            self.dino = torch.compile(self.dino, mode=mode, **kwargs)
+            self.clip1 = torch.compile(self.clip1, mode=mode, **kwargs)
+            self.clip2 = torch.compile(self.clip2, mode=mode, **kwargs)
+            self._compiled = True
+        return self
+
     def transforms(self, x: Tensor, resize: bool = False) -> tuple[Tensor, Tensor, Tensor]:
         if resize:
             x = self.resize(x)
-        return self.dino_norm(x), self.clip_norm(x), self.clip_norm(x)
+        x = self.dino_norm(x), self.clip_norm(x), self.clip_norm(x)
+        return x
 
     def forward_features(self, x: Tensor) -> Tensor:
         if x.ndim == 3:
@@ -153,21 +183,6 @@ class DreamsimEnsemble(ModelMixin, ConfigMixin):
         x_clip2 = self.clip2.forward(x_clip2, norm=True)
 
         z: Tensor = torch.cat([x_dino, x_clip1, x_clip2], dim=1)
-        z.div_(z.norm(dim=1, keepdim=True))
-        z.sub_(z.mean(dim=1, keepdim=True))
+        z = z.div(z.norm(dim=1, keepdim=True))
+        z = z.sub(z.mean(dim=1, keepdim=True))
         return z
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Dreamsim forward pass for similarity computation.
-        Args:
-            x (Tensor): Input tensor of shape [2, B, 3, H, W].
-
-        Returns:
-            sim (torch.Tensor): dreamsim similarity score of shape [B].
-        """
-        all_images = x.view(-1, 3, *x.shape[-2:])
-
-        x = self.forward_features(all_images)
-        x = x.view(*x.shape[:2], -1)
-
-        return 1 - F.cosine_similarity(x[0], x[1], dim=1)
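For context, a minimal usage sketch of the shared forward contract that DreamsimBackbone introduces above: two image batches are stacked into a [2, B, 3, H, W] tensor and the model returns one distance per pair. The helper name and the assumption of a preloaded model taking 224x224 RGB input are illustrative, not part of the commit.

import torch

def pairwise_distance(model, img_a: torch.Tensor, img_b: torch.Tensor) -> torch.Tensor:
    # model: any DreamsimBackbone subclass (DreamsimModel or DreamsimEnsemble).
    # img_a, img_b: [B, 3, H, W] image batches, stacked to [2, B, 3, H, W].
    x = torch.stack([img_a, img_b], dim=0)
    with torch.inference_mode():
        return model(x)  # shape [B]; 1 - cosine similarity of the pooled features

# compile() is now part of the shared interface: a no-op on the base class,
# torch.compile(mode="reduce-overhead") on the subclasses.
# model = model.compile()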
vit.py CHANGED
@@ -179,9 +179,9 @@ class PatchEmbed(nn.Module):
         dynamic_pad: bool = False,
     ):
         super().__init__()
-        self.img_size = ensure_tuple(img_size)
-        self.patch_size = ensure_tuple(patch_size)
-        self.num_patches = (img_size // patch_size) ** 2
+        self.img_size = ensure_tuple(img_size, 2)
+        self.patch_size = ensure_tuple(patch_size, 2)
+        self.num_patches = (self.img_size[0] // self.patch_size[0]) * (self.img_size[1] // self.patch_size[1])
 
         self.dynamic_pad = dynamic_pad
 
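The num_patches change in PatchEmbed matters once the input is not square; a small worked example under the new per-axis formula (values are illustrative, and ensure_tuple(x, 2) is assumed to expand an int into a 2-tuple, matching its use elsewhere in this diff).

# Illustrative values only.
img_size = (224, 336)    # (H, W)
patch_size = (16, 16)

# New formula: per-axis grid sizes multiplied together.
num_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
assert num_patches == 14 * 21 == 294

# For square inputs the result is unchanged from the old
# (img_size // patch_size) ** 2 form: 224 // 16 = 14, so 196 patches.
assert (224 // 16) ** 2 == 196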