Leyo commited on
Commit
3b15dc9
1 Parent(s): 9d8f3cc

add convert siglip to hf

Browse files
Files changed (1) hide show
  1. convert_siglip_to_hf.py +413 -0
convert_siglip_to_hf.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert SigLIP checkpoints from the original repository.
16
+
17
+ URL: https://github.com/google-research/big_vision/tree/main
18
+ """
19
+
20
+
21
+ import argparse
22
+ import collections
23
+ from pathlib import Path
24
+
25
+ import numpy as np
26
+ import requests
27
+ import torch
28
+ from huggingface_hub import hf_hub_download
29
+ from numpy import load
30
+ from PIL import Image
31
+
32
+ from transformers import SiglipConfig, SiglipImageProcessor, SiglipModel, SiglipProcessor, SiglipTokenizer
33
+ from transformers.utils import logging
34
+
35
+
36
+ logging.set_verbosity_info()
37
+ logger = logging.get_logger(__name__)
38
+
39
+
40
+ model_name_to_checkpoint = {
41
+ # base checkpoints
42
+ "siglip-base-patch16-224": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_224_63724782.npz",
43
+ "siglip-base-patch16-256": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_256_60500360.npz",
44
+ "siglip-base-patch16-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_384_68578854.npz",
45
+ "siglip-base-patch16-512": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_512_68580893.npz",
46
+ # large checkpoints
47
+ "siglip-large-patch16-256": "/Users/nielsrogge/Documents/SigLIP/webli_en_l16_256_60552751.npz",
48
+ "siglip-large-patch16-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_l16_384_63634585.npz",
49
+ # multilingual checkpoint
50
+ "siglip-base-patch16-256-i18n": "/Users/nielsrogge/Documents/SigLIP/webli_i18n_b16_256_66117334.npz",
51
+ # so400m checkpoints
52
+ "siglip-so400m-patch14-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_so400m_384_58765454.npz",
53
+ }
54
+
55
+ model_name_to_image_size = {
56
+ "siglip-base-patch16-224": 224,
57
+ "siglip-base-patch16-256": 256,
58
+ "siglip-base-patch16-384": 384,
59
+ "siglip-base-patch16-512": 512,
60
+ "siglip-large-patch16-256": 256,
61
+ "siglip-large-patch16-384": 384,
62
+ "siglip-base-patch16-256-i18n": 256,
63
+ "siglip-so400m-patch14-384": 384,
64
+ }
65
+
66
+
67
+ def get_siglip_config(model_name):
68
+ config = SiglipConfig()
69
+
70
+ vocab_size = 250000 if "i18n" in model_name else 32000
71
+ image_size = model_name_to_image_size[model_name]
72
+ patch_size = 16 if "patch16" in model_name else 14
73
+
74
+ # size of the architecture
75
+ config.vision_config.image_size = image_size
76
+ config.vision_config.patch_size = patch_size
77
+ config.text_config.vocab_size = vocab_size
78
+
79
+ if "base" in model_name:
80
+ pass
81
+ elif "large" in model_name:
82
+ config.text_config.hidden_size = 1024
83
+ config.text_config.intermediate_size = 4096
84
+ config.text_config.num_hidden_layers = 24
85
+ config.text_config.num_attention_heads = 16
86
+ config.vision_config.hidden_size = 1024
87
+ config.vision_config.intermediate_size = 4096
88
+ config.vision_config.num_hidden_layers = 24
89
+ config.vision_config.num_attention_heads = 16
90
+ elif "so400m" in model_name:
91
+ config.text_config.hidden_size = 1152
92
+ config.text_config.intermediate_size = 4304
93
+ config.text_config.num_hidden_layers = 27
94
+ config.text_config.num_attention_heads = 16
95
+ config.vision_config.hidden_size = 1152
96
+ config.vision_config.intermediate_size = 4304
97
+ config.vision_config.num_hidden_layers = 27
98
+ config.vision_config.num_attention_heads = 16
99
+ else:
100
+ raise ValueError("Model not supported")
101
+
102
+ return config
103
+
104
+
105
+ def create_rename_keys(config):
106
+ rename_keys = []
107
+ # fmt: off
108
+
109
+ # vision encoder
110
+
111
+ rename_keys.append(("params/img/embedding/kernel", "vision_model.embeddings.patch_embedding.weight"))
112
+ rename_keys.append(("params/img/embedding/bias", "vision_model.embeddings.patch_embedding.bias"))
113
+ rename_keys.append(("params/img/pos_embedding", "vision_model.embeddings.position_embedding.weight"))
114
+
115
+ for i in range(config.vision_config.num_hidden_layers):
116
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/scale", f"vision_model.encoder.layers.{i}.layer_norm1.weight"))
117
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias"))
118
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/scale", f"vision_model.encoder.layers.{i}.layer_norm2.weight"))
119
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias"))
120
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"vision_model.encoder.layers.{i}.mlp.fc1.weight"))
121
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias"))
122
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"vision_model.encoder.layers.{i}.mlp.fc2.weight"))
123
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias"))
124
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"vision_model.encoder.layers.{i}.self_attn.k_proj.weight"))
125
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"vision_model.encoder.layers.{i}.self_attn.k_proj.bias"))
126
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"vision_model.encoder.layers.{i}.self_attn.v_proj.weight"))
127
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"vision_model.encoder.layers.{i}.self_attn.v_proj.bias"))
128
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"vision_model.encoder.layers.{i}.self_attn.q_proj.weight"))
129
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"vision_model.encoder.layers.{i}.self_attn.q_proj.bias"))
130
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"vision_model.encoder.layers.{i}.self_attn.out_proj.weight"))
131
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"vision_model.encoder.layers.{i}.self_attn.out_proj.bias"))
132
+
133
+ rename_keys.append(("params/img/Transformer/encoder_norm/scale", "vision_model.post_layernorm.weight"))
134
+ rename_keys.append(("params/img/Transformer/encoder_norm/bias", "vision_model.post_layernorm.bias"))
135
+
136
+ rename_keys.append(("params/img/MAPHead_0/probe", "vision_model.head.probe"))
137
+ rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/scale", "vision_model.head.layernorm.weight"))
138
+ rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/bias", "vision_model.head.layernorm.bias"))
139
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/kernel", "vision_model.head.mlp.fc1.weight"))
140
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/bias", "vision_model.head.mlp.fc1.bias"))
141
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/kernel", "vision_model.head.mlp.fc2.weight"))
142
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/bias", "vision_model.head.mlp.fc2.bias"))
143
+ rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/kernel", "vision_model.head.attention.out_proj.weight"))
144
+ rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/bias", "vision_model.head.attention.out_proj.bias"))
145
+
146
+ # text encoder
147
+
148
+ rename_keys.append(("params/txt/Embed_0/embedding", "text_model.embeddings.token_embedding.weight"))
149
+ rename_keys.append(("params/txt/pos_embedding", "text_model.embeddings.position_embedding.weight"))
150
+
151
+ for i in range(config.text_config.num_hidden_layers):
152
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/scale", f"text_model.encoder.layers.{i}.layer_norm1.weight"))
153
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/bias", f"text_model.encoder.layers.{i}.layer_norm1.bias"))
154
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/scale", f"text_model.encoder.layers.{i}.layer_norm2.weight"))
155
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/bias", f"text_model.encoder.layers.{i}.layer_norm2.bias"))
156
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"text_model.encoder.layers.{i}.mlp.fc1.weight"))
157
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"text_model.encoder.layers.{i}.mlp.fc1.bias"))
158
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"text_model.encoder.layers.{i}.mlp.fc2.weight"))
159
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"text_model.encoder.layers.{i}.mlp.fc2.bias"))
160
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"text_model.encoder.layers.{i}.self_attn.k_proj.weight"))
161
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"text_model.encoder.layers.{i}.self_attn.k_proj.bias"))
162
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"text_model.encoder.layers.{i}.self_attn.v_proj.weight"))
163
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"text_model.encoder.layers.{i}.self_attn.v_proj.bias"))
164
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"text_model.encoder.layers.{i}.self_attn.q_proj.weight"))
165
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"text_model.encoder.layers.{i}.self_attn.q_proj.bias"))
166
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"text_model.encoder.layers.{i}.self_attn.out_proj.weight"))
167
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"text_model.encoder.layers.{i}.self_attn.out_proj.bias"))
168
+
169
+ rename_keys.append(("params/txt/Encoder_0/encoder_norm/scale", "text_model.final_layer_norm.weight"))
170
+ rename_keys.append(("params/txt/Encoder_0/encoder_norm/bias", "text_model.final_layer_norm.bias"))
171
+ rename_keys.append(("params/txt/head/kernel", "text_model.head.weight"))
172
+ rename_keys.append(("params/txt/head/bias", "text_model.head.bias"))
173
+
174
+ # learned temperature and bias
175
+ rename_keys.append(("params/t", "logit_scale"))
176
+ rename_keys.append(("params/b", "logit_bias"))
177
+
178
+ # fmt: on
179
+ return rename_keys
180
+
181
+
182
+ def rename_key(dct, old, new, config):
183
+ val = dct.pop(old)
184
+
185
+ if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "vision" in new:
186
+ val = val.reshape(-1, config.vision_config.hidden_size)
187
+ if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "text" in new:
188
+ val = val.reshape(-1, config.text_config.hidden_size)
189
+
190
+ if "patch_embedding.weight" in new:
191
+ val = val.transpose(3, 2, 0, 1)
192
+ elif new.endswith("weight") and "position_embedding" not in new and "token_embedding" not in new:
193
+ val = val.T
194
+
195
+ if "position_embedding" in new and "vision" in new:
196
+ val = val.reshape(-1, config.vision_config.hidden_size)
197
+ if "position_embedding" in new and "text" in new:
198
+ val = val.reshape(-1, config.text_config.hidden_size)
199
+
200
+ if new.endswith("bias"):
201
+ val = val.reshape(-1)
202
+
203
+ dct[new] = torch.from_numpy(val)
204
+
205
+
206
+ def read_in_q_k_v_head(state_dict, config):
207
+ # read in individual input projection layers
208
+ key_proj_weight = (
209
+ state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/kernel")
210
+ .reshape(-1, config.vision_config.hidden_size)
211
+ .T
212
+ )
213
+ key_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/bias").reshape(-1)
214
+ value_proj_weight = (
215
+ state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/kernel")
216
+ .reshape(-1, config.vision_config.hidden_size)
217
+ .T
218
+ )
219
+ value_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/bias").reshape(-1)
220
+ query_proj_weight = (
221
+ state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/kernel")
222
+ .reshape(-1, config.vision_config.hidden_size)
223
+ .T
224
+ )
225
+ query_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/bias").reshape(-1)
226
+
227
+ # next, add them to the state dict as a single matrix + vector
228
+ state_dict["vision_model.head.attention.in_proj_weight"] = torch.from_numpy(
229
+ np.concatenate([query_proj_weight, key_proj_weight, value_proj_weight], axis=0)
230
+ )
231
+ state_dict["vision_model.head.attention.in_proj_bias"] = torch.from_numpy(
232
+ np.concatenate([query_proj_bias, key_proj_bias, value_proj_bias], axis=0)
233
+ )
234
+
235
+
236
+ # We will verify our results on an image of cute cats
237
+ def prepare_img():
238
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
239
+ image = Image.open(requests.get(url, stream=True).raw)
240
+ return image
241
+
242
+
243
+ def flatten_nested_dict(params, parent_key="", sep="/"):
244
+ items = []
245
+
246
+ for k, v in params.items():
247
+ new_key = parent_key + sep + k if parent_key else k
248
+
249
+ if isinstance(v, collections.abc.MutableMapping):
250
+ items.extend(flatten_nested_dict(v, new_key, sep=sep).items())
251
+ else:
252
+ items.append((new_key, v))
253
+ return dict(items)
254
+
255
+
256
+ @torch.no_grad()
257
+ def convert_siglip_checkpoint(model_name, pytorch_dump_folder_path, verify_logits=True, push_to_hub=False):
258
+ """
259
+ Copy/paste/tweak model's weights to our SigLIP structure.
260
+ """
261
+
262
+ # define default SigLIP configuration
263
+ config = get_siglip_config(model_name)
264
+
265
+ # get checkpoint
266
+ checkpoint = model_name_to_checkpoint[model_name]
267
+
268
+ # get vocab file
269
+ if "i18n" in model_name:
270
+ vocab_file = "/Users/nielsrogge/Documents/SigLIP/multilingual_vocab/sentencepiece.model"
271
+ else:
272
+ vocab_file = "/Users/nielsrogge/Documents/SigLIP/english_vocab/sentencepiece.model"
273
+
274
+ # load original state dict
275
+ data = load(checkpoint)
276
+ state_dict = flatten_nested_dict(data)
277
+
278
+ # remove and rename some keys
279
+ rename_keys = create_rename_keys(config)
280
+ for src, dest in rename_keys:
281
+ rename_key(state_dict, src, dest, config)
282
+
283
+ # qkv matrices of attention pooling head need special treatment
284
+ read_in_q_k_v_head(state_dict, config)
285
+
286
+ # load HuggingFace model
287
+ model = SiglipModel(config).eval()
288
+ model.load_state_dict(state_dict)
289
+
290
+ # create processor
291
+ # important: make tokenizer not return attention_mask since original one doesn't require it
292
+ image_size = config.vision_config.image_size
293
+ size = {"height": image_size, "width": image_size}
294
+ image_processor = SiglipImageProcessor(size=size)
295
+ tokenizer = SiglipTokenizer(vocab_file=vocab_file, model_input_names=["input_ids"])
296
+ processor = SiglipProcessor(image_processor=image_processor, tokenizer=tokenizer)
297
+
298
+ # verify on dummy images and texts
299
+ url_1 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-ipod.jpg"
300
+ image_1 = Image.open(requests.get(url_1, stream=True).raw).convert("RGB")
301
+ url_2 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-blank.jpg"
302
+ image_2 = Image.open(requests.get(url_2, stream=True).raw).convert("RGB")
303
+ texts = ["an apple", "a picture of an apple"]
304
+
305
+ inputs = processor(images=[image_1, image_2], text=texts, return_tensors="pt", padding="max_length")
306
+
307
+ # verify input_ids against original ones
308
+ if image_size == 224:
309
+ filename = "siglip_pixel_values.pt"
310
+ elif image_size == 256:
311
+ filename = "siglip_pixel_values_256.pt"
312
+ elif image_size == 384:
313
+ filename = "siglip_pixel_values_384.pt"
314
+ elif image_size == 512:
315
+ filename = "siglip_pixel_values_512.pt"
316
+ else:
317
+ raise ValueError("Image size not supported")
318
+
319
+ filepath = hf_hub_download(repo_id="nielsr/test-image", filename=filename, repo_type="dataset")
320
+ original_pixel_values = torch.load(filepath)
321
+ filepath = hf_hub_download(repo_id="nielsr/test-image", filename="siglip_input_ids.pt", repo_type="dataset")
322
+ original_input_ids = torch.load(filepath)
323
+
324
+ if "i18n" not in model_name:
325
+ assert inputs.input_ids.tolist() == original_input_ids.tolist()
326
+
327
+ print("Mean of original pixel values:", original_pixel_values.mean())
328
+ print("Mean of new pixel values:", inputs.pixel_values.mean())
329
+
330
+ # note: we're testing with original pixel values here since we don't have exact pixel values
331
+ with torch.no_grad():
332
+ outputs = model(input_ids=inputs.input_ids, pixel_values=original_pixel_values)
333
+
334
+ # with torch.no_grad():
335
+ # outputs = model(input_ids=inputs.input_ids, pixel_values=inputs.pixel_values)
336
+
337
+ print(outputs.logits_per_image[:3, :3])
338
+
339
+ probs = torch.sigmoid(outputs.logits_per_image) # these are the probabilities
340
+ print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
341
+ print(f"{probs[0][1]:.1%} that image 0 is '{texts[1]}'")
342
+
343
+ if verify_logits:
344
+ if model_name == "siglip-base-patch16-224":
345
+ expected_slice = torch.tensor(
346
+ [[-2.9621, -2.1672], [-0.2713, 0.2910]],
347
+ )
348
+ elif model_name == "siglip-base-patch16-256":
349
+ expected_slice = torch.tensor(
350
+ [[-3.1146, -1.9894], [-0.7312, 0.6387]],
351
+ )
352
+ elif model_name == "siglip-base-patch16-384":
353
+ expected_slice = torch.tensor(
354
+ [[-2.8098, -2.1891], [-0.4242, 0.4102]],
355
+ )
356
+ elif model_name == "siglip-base-patch16-512":
357
+ expected_slice = torch.tensor(
358
+ [[-2.7899, -2.2668], [-0.4295, -0.0735]],
359
+ )
360
+ elif model_name == "siglip-large-patch16-256":
361
+ expected_slice = torch.tensor(
362
+ [[-1.5827, -0.5801], [-0.9153, 0.1363]],
363
+ )
364
+ elif model_name == "siglip-large-patch16-384":
365
+ expected_slice = torch.tensor(
366
+ [[-2.1523, -0.2899], [-0.2959, 0.7884]],
367
+ )
368
+ elif model_name == "siglip-so400m-patch14-384":
369
+ expected_slice = torch.tensor([[-1.2441, -0.6649], [-0.7060, 0.7374]])
370
+ elif model_name == "siglip-base-patch16-256-i18n":
371
+ expected_slice = torch.tensor(
372
+ [[-0.9064, 0.1073], [-0.0299, 0.5304]],
373
+ )
374
+
375
+ assert torch.allclose(outputs.logits_per_image[:3, :3], expected_slice, atol=1e-4)
376
+ print("Looks ok!")
377
+
378
+ if pytorch_dump_folder_path is not None:
379
+ Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
380
+ print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
381
+ model.save_pretrained(pytorch_dump_folder_path)
382
+ print(f"Saving processor to {pytorch_dump_folder_path}")
383
+ processor.save_pretrained(pytorch_dump_folder_path)
384
+
385
+ if push_to_hub:
386
+ model.push_to_hub(f"nielsr/{model_name}")
387
+ processor.push_to_hub(f"nielsr/{model_name}")
388
+
389
+
390
+ if __name__ == "__main__":
391
+ parser = argparse.ArgumentParser()
392
+ # Required parameters
393
+ parser.add_argument(
394
+ "--model_name",
395
+ default="siglip-base-patch16-224",
396
+ type=str,
397
+ choices=model_name_to_checkpoint.keys(),
398
+ help="Name of the model you'd like to convert.",
399
+ )
400
+ parser.add_argument(
401
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
402
+ )
403
+ parser.add_argument(
404
+ "--verify_logits",
405
+ action="store_false",
406
+ help="Whether to verify logits against the original implementation.",
407
+ )
408
+ parser.add_argument(
409
+ "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
410
+ )
411
+
412
+ args = parser.parse_args()
413
+ convert_siglip_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.verify_logits, args.push_to_hub)