HugoLaurencon committed
Commit 248de06 · 1 Parent(s): 28f2e7e

initial files PR #26522 in `transformers`
configuration_siglip.py ADDED
@@ -0,0 +1,444 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Siglip model configuration"""
16
+
17
+ import os
18
+ from collections import OrderedDict
19
+ from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
20
+
21
+
22
+ if TYPE_CHECKING:
23
+ from ...processing_utils import ProcessorMixin
24
+ from ...utils import TensorType
25
+
26
+ from ...configuration_utils import PretrainedConfig
27
+ from ...onnx import OnnxConfig
28
+ from ...utils import logging
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+ SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = {
34
+ "google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/config.json",
35
+ }
36
+
37
+
38
+ class SiglipTextConfig(PretrainedConfig):
39
+ r"""
40
+ This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a
41
+ Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a
42
+ configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip
43
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
44
+
45
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
46
+ documentation from [`PretrainedConfig`] for more information.
47
+
48
+ Args:
49
+ vocab_size (`int`, *optional*, defaults to 49408):
50
+ Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by
51
+ the `input_ids` passed when calling [`SiglipModel`].
52
+ hidden_size (`int`, *optional*, defaults to 512):
53
+ Dimensionality of the encoder layers and the pooler layer.
54
+ intermediate_size (`int`, *optional*, defaults to 2048):
55
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
56
+ num_hidden_layers (`int`, *optional*, defaults to 12):
57
+ Number of hidden layers in the Transformer encoder.
58
+ num_attention_heads (`int`, *optional*, defaults to 8):
59
+ Number of attention heads for each attention layer in the Transformer encoder.
60
+ max_position_embeddings (`int`, *optional*, defaults to 64):
61
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
62
+ just in case (e.g., 512 or 1024 or 2048).
63
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
64
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
65
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
66
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
67
+ The epsilon used by the layer normalization layers.
68
+ attention_dropout (`float`, *optional*, defaults to 0.0):
69
+ The dropout ratio for the attention probabilities.
70
+ initializer_range (`float`, *optional*, defaults to 0.02):
71
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
72
+ initializer_factor (`float`, *optional*, defaults to 1):
73
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
74
+ testing).
75
+
76
+ Example:
77
+
78
+ ```python
79
+ >>> from transformers import SiglipTextConfig, SiglipTextModel
80
+
81
+ >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
82
+ >>> configuration = SiglipTextConfig()
83
+
84
+ >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
85
+ >>> model = SiglipTextModel(configuration)
86
+
87
+ >>> # Accessing the model configuration
88
+ >>> configuration = model.config
89
+ ```"""
90
+ model_type = "siglip_text_model"
91
+
92
+ def __init__(
93
+ self,
94
+ vocab_size=49408,
95
+ hidden_size=512,
96
+ intermediate_size=2048,
97
+ projection_dim=512,
98
+ num_hidden_layers=12,
99
+ num_attention_heads=8,
100
+ max_position_embeddings=64,
101
+ hidden_act="gelu_pytorch_tanh",
102
+ layer_norm_eps=1e-6,
103
+ attention_dropout=0.0,
104
+ initializer_range=0.02,
105
+ initializer_factor=1.0,
106
+ # This differs from `CLIPTokenizer`'s default and from openai/siglip
107
+ # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
108
+ pad_token_id=1,
109
+ bos_token_id=49406,
110
+ eos_token_id=49407,
111
+ **kwargs,
112
+ ):
113
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
114
+
115
+ self.vocab_size = vocab_size
116
+ self.hidden_size = hidden_size
117
+ self.intermediate_size = intermediate_size
118
+ self.projection_dim = projection_dim
119
+ self.num_hidden_layers = num_hidden_layers
120
+ self.num_attention_heads = num_attention_heads
121
+ self.max_position_embeddings = max_position_embeddings
122
+ self.layer_norm_eps = layer_norm_eps
123
+ self.hidden_act = hidden_act
124
+ self.initializer_range = initializer_range
125
+ self.initializer_factor = initializer_factor
126
+ self.attention_dropout = attention_dropout
127
+
128
+ @classmethod
129
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
130
+ cls._set_token_in_kwargs(kwargs)
131
+
132
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
133
+
134
+ # get the text config dict if we are loading from SiglipConfig
135
+ if config_dict.get("model_type") == "siglip":
136
+ config_dict = config_dict["text_config"]
137
+
138
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
139
+ logger.warning(
140
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
141
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
142
+ )
143
+
144
+ return cls.from_dict(config_dict, **kwargs)
145
+
146
+
147
+ class SiglipVisionConfig(PretrainedConfig):
148
+ r"""
149
+ This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
150
+ Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
151
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
152
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
153
+
154
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
155
+ documentation from [`PretrainedConfig`] for more information.
156
+
157
+ Args:
158
+ hidden_size (`int`, *optional*, defaults to 768):
159
+ Dimensionality of the encoder layers and the pooler layer.
160
+ intermediate_size (`int`, *optional*, defaults to 3072):
161
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
162
+ num_hidden_layers (`int`, *optional*, defaults to 12):
163
+ Number of hidden layers in the Transformer encoder.
164
+ num_attention_heads (`int`, *optional*, defaults to 12):
165
+ Number of attention heads for each attention layer in the Transformer encoder.
166
+ image_size (`int`, *optional*, defaults to 224):
167
+ The size (resolution) of each image.
168
+ patch_size (`int`, *optional*, defaults to 32):
169
+ The size (resolution) of each patch.
170
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
171
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
172
+ `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
173
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
174
+ The epsilon used by the layer normalization layers.
175
+ attention_dropout (`float`, *optional*, defaults to 0.0):
176
+ The dropout ratio for the attention probabilities.
177
+ initializer_range (`float`, *optional*, defaults to 0.02):
178
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
179
+ initializer_factor (`float`, *optional*, defaults to 1):
180
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
181
+ testing).
182
+
183
+ Example:
184
+
185
+ ```python
186
+ >>> from transformers import SiglipVisionConfig, SiglipVisionModel
187
+
188
+ >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
189
+ >>> configuration = SiglipVisionConfig()
190
+
191
+ >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
192
+ >>> model = SiglipVisionModel(configuration)
193
+
194
+ >>> # Accessing the model configuration
195
+ >>> configuration = model.config
196
+ ```"""
197
+
198
+ model_type = "siglip_vision_model"
199
+
200
+ def __init__(
201
+ self,
202
+ hidden_size=768,
203
+ intermediate_size=3072,
204
+ projection_dim=512,
205
+ num_hidden_layers=12,
206
+ num_attention_heads=12,
207
+ num_channels=3,
208
+ image_size=224,
209
+ patch_size=32,
210
+ hidden_act="gelu_pytorch_tanh",
211
+ layer_norm_eps=1e-6,
212
+ attention_dropout=0.0,
213
+ initializer_range=0.02,
214
+ initializer_factor=1.0,
215
+ **kwargs,
216
+ ):
217
+ super().__init__(**kwargs)
218
+
219
+ self.hidden_size = hidden_size
220
+ self.intermediate_size = intermediate_size
221
+ self.projection_dim = projection_dim
222
+ self.num_hidden_layers = num_hidden_layers
223
+ self.num_attention_heads = num_attention_heads
224
+ self.num_channels = num_channels
225
+ self.patch_size = patch_size
226
+ self.image_size = image_size
227
+ self.initializer_range = initializer_range
228
+ self.initializer_factor = initializer_factor
229
+ self.attention_dropout = attention_dropout
230
+ self.layer_norm_eps = layer_norm_eps
231
+ self.hidden_act = hidden_act
232
+
233
+ @classmethod
234
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
235
+ cls._set_token_in_kwargs(kwargs)
236
+
237
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
238
+
239
+ # get the vision config dict if we are loading from SiglipConfig
240
+ if config_dict.get("model_type") == "siglip":
241
+ config_dict = config_dict["vision_config"]
242
+
243
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
244
+ logger.warning(
245
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
246
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
247
+ )
248
+
249
+ return cls.from_dict(config_dict, **kwargs)
250
+
251
+
252
+ class SiglipConfig(PretrainedConfig):
253
+ r"""
254
+ [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to
255
+ instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs.
256
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip
257
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
258
+
259
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
260
+ documentation from [`PretrainedConfig`] for more information.
261
+
262
+ Args:
263
+ text_config (`dict`, *optional*):
264
+ Dictionary of configuration options used to initialize [`SiglipTextConfig`].
265
+ vision_config (`dict`, *optional*):
266
+ Dictionary of configuration options used to initialize [`SiglipVisionConfig`].
267
+ projection_dim (`int`, *optional*, defaults to 512):
268
+ Dimensionality of text and vision projection layers.
269
+ logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
270
+ The initial value of the *logit_scale* parameter. Default is used as per the original Siglip implementation.
271
+ kwargs (*optional*):
272
+ Dictionary of keyword arguments.
273
+
274
+ Example:
275
+
276
+ ```python
277
+ >>> from transformers import SiglipConfig, SiglipModel
278
+
279
+ >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration
280
+ >>> configuration = SiglipConfig()
281
+
282
+ >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration
283
+ >>> model = SiglipModel(configuration)
284
+
285
+ >>> # Accessing the model configuration
286
+ >>> configuration = model.config
287
+
288
+ >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig
289
+ >>> from transformers import SiglipTextConfig, SiglipVisionConfig
290
+
291
+ >>> # Initializing a SiglipText and SiglipVision configuration
292
+ >>> config_text = SiglipTextConfig()
293
+ >>> config_vision = SiglipVisionConfig()
294
+
295
+ >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision)
296
+ ```"""
297
+
298
+ model_type = "siglip"
299
+
300
+ def __init__(
301
+ self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
302
+ ):
303
+ # If `_config_dict` exist, we use them for the backward compatibility.
304
+ # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
305
+ # of confusion!).
306
+ text_config_dict = kwargs.pop("text_config_dict", None)
307
+ vision_config_dict = kwargs.pop("vision_config_dict", None)
308
+
309
+ super().__init__(**kwargs)
310
+
311
+ # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
312
+ # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most
313
+ # cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`.
314
+ if text_config_dict is not None:
315
+ if text_config is None:
316
+ text_config = {}
317
+
318
+ # This is the complete result when using `text_config_dict`.
319
+ _text_config_dict = SiglipTextConfig(**text_config_dict).to_dict()
320
+
321
+ # Give a warning if the values exist in both `_text_config_dict` and `text_config` but are different.
322
+ for key, value in _text_config_dict.items():
323
+ if key in text_config and value != text_config[key] and key not in ["transformers_version"]:
324
+ # If specified in `text_config_dict`
325
+ if key in text_config_dict:
326
+ message = (
327
+ f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. "
328
+ f'The value `text_config_dict["{key}"]` will be used instead.'
329
+ )
330
+ # If inferred from default argument values (just to be super careful)
331
+ else:
332
+ message = (
333
+ f"`text_config_dict` is provided which will be used to initialize `SiglipTextConfig`. The "
334
+ f'value `text_config["{key}"]` will be overridden.'
335
+ )
336
+ logger.warning(message)
337
+
338
+ # Update all values in `text_config` with the ones in `_text_config_dict`.
339
+ text_config.update(_text_config_dict)
340
+
341
+ if vision_config_dict is not None:
342
+ if vision_config is None:
343
+ vision_config = {}
344
+
345
+ # This is the complete result when using `vision_config_dict`.
346
+ _vision_config_dict = SiglipVisionConfig(**vision_config_dict).to_dict()
347
+ # convert keys to string instead of integer
348
+ if "id2label" in _vision_config_dict:
349
+ _vision_config_dict["id2label"] = {
350
+ str(key): value for key, value in _vision_config_dict["id2label"].items()
351
+ }
352
+
353
+ # Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but are different.
354
+ for key, value in _vision_config_dict.items():
355
+ if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]:
356
+ # If specified in `vision_config_dict`
357
+ if key in vision_config_dict:
358
+ message = (
359
+ f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different "
360
+ f'values. The value `vision_config_dict["{key}"]` will be used instead.'
361
+ )
362
+ # If inferred from default argument values (just to be super careful)
363
+ else:
364
+ message = (
365
+ f"`vision_config_dict` is provided which will be used to initialize `SiglipVisionConfig`. "
366
+ f'The value `vision_config["{key}"]` will be overridden.'
367
+ )
368
+ logger.warning(message)
369
+
370
+ # Update all values in `vision_config` with the ones in `_vision_config_dict`.
371
+ vision_config.update(_vision_config_dict)
372
+
373
+ if text_config is None:
374
+ text_config = {}
375
+ logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.")
376
+
377
+ if vision_config is None:
378
+ vision_config = {}
379
+ logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.")
380
+
381
+ self.text_config = SiglipTextConfig(**text_config)
382
+ self.vision_config = SiglipVisionConfig(**vision_config)
383
+
384
+ self.projection_dim = projection_dim
385
+ self.logit_scale_init_value = logit_scale_init_value
386
+ self.initializer_factor = 1.0
387
+
388
+ @classmethod
389
+ def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs):
390
+ r"""
391
+ Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision
392
+ model configuration.
393
+
394
+ Returns:
395
+ [`SiglipConfig`]: An instance of a configuration object
396
+ """
397
+
398
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
399
+
400
+
401
+ class SiglipOnnxConfig(OnnxConfig):
402
+ @property
403
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
404
+ return OrderedDict(
405
+ [
406
+ ("input_ids", {0: "batch", 1: "sequence"}),
407
+ ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
408
+ ("attention_mask", {0: "batch", 1: "sequence"}),
409
+ ]
410
+ )
411
+
412
+ @property
413
+ def outputs(self) -> Mapping[str, Mapping[int, str]]:
414
+ return OrderedDict(
415
+ [
416
+ ("logits_per_image", {0: "batch"}),
417
+ ("logits_per_text", {0: "batch"}),
418
+ ("text_embeds", {0: "batch"}),
419
+ ("image_embeds", {0: "batch"}),
420
+ ]
421
+ )
422
+
423
+ @property
424
+ def atol_for_validation(self) -> float:
425
+ return 1e-4
426
+
427
+ def generate_dummy_inputs(
428
+ self,
429
+ processor: "ProcessorMixin",
430
+ batch_size: int = -1,
431
+ seq_length: int = -1,
432
+ framework: Optional["TensorType"] = None,
433
+ ) -> Mapping[str, Any]:
434
+ text_input_dict = super().generate_dummy_inputs(
435
+ processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework
436
+ )
437
+ image_input_dict = super().generate_dummy_inputs(
438
+ processor.image_processor, batch_size=batch_size, framework=framework
439
+ )
440
+ return {**text_input_dict, **image_input_dict}
441
+
442
+ @property
443
+ def default_onnx_opset(self) -> int:
444
+ return 14
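The three configuration classes above compose as follows; a minimal sketch, assuming the classes are importable from this PR branch (the checkpoint name is the one listed in `SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP`):

```python
from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig

# A legacy `text_config_dict` is expanded into a full `SiglipTextConfig`, and its values
# override conflicting keys in `text_config` (a warning is logged for each clash).
config = SiglipConfig(
    text_config={"hidden_size": 512},
    text_config_dict={"hidden_size": 768},  # wins over the value in `text_config`
)
assert config.text_config.hidden_size == 768

# The usual composition path goes through the dedicated classmethod.
config = SiglipConfig.from_text_vision_configs(SiglipTextConfig(), SiglipVisionConfig())

# Loading a sub-config directly from a full SigLIP checkpoint picks out the nested dict
# (the `model_type == "siglip"` branch of `from_pretrained` above):
# text_config = SiglipTextConfig.from_pretrained("google/siglip-base-patch16-224")
```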
convert_siglip_to_hf.py ADDED
@@ -0,0 +1,294 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert SigLIP checkpoints from the original repository.
16
+
17
+ URL: https://github.com/google-research/big_vision/tree/main
18
+ """
19
+
20
+
21
+ import argparse
22
+ import collections
23
+ from pathlib import Path
24
+
25
+ import numpy as np
26
+ import requests
27
+ import torch
28
+ from huggingface_hub import hf_hub_download
29
+ from numpy import load
30
+ from PIL import Image
31
+
32
+ from transformers import SiglipConfig, SiglipModel
33
+ from transformers.utils import logging
34
+
35
+
36
+ logging.set_verbosity_info()
37
+ logger = logging.get_logger(__name__)
38
+
39
+
40
+ def get_siglip_config(model_name):
41
+ config = SiglipConfig()
42
+
43
+ # size of the architecture
44
+ if "base" in model_name:
45
+ config.vision_config.image_size = 224
46
+ config.vision_config.patch_size = 16
47
+ config.text_config.vocab_size = 32000
48
+ config.text_config.hidden_size = 768
49
+ config.text_config.intermediate_size = 3072
50
+ config.text_config.max_position_embeddings = 64
51
+ config.text_config.num_attention_heads = 12
52
+ elif "large" in model_name:
53
+ config.vision_config.hidden_size = 1024
54
+ config.vision_config.num_hidden_layers = 24
55
+ config.vision_config.num_attention_heads = 16
56
+ else:
57
+ raise ValueError("Model not supported")
58
+
59
+ return config
60
+
61
+
62
+ def create_rename_keys(config):
63
+ rename_keys = []
64
+ # fmt: off
65
+
66
+ # vision encoder
67
+
68
+ rename_keys.append(("params/img/embedding/kernel", "vision_model.vision_model.embeddings.patch_embedding.weight"))
69
+ rename_keys.append(("params/img/embedding/bias", "vision_model.vision_model.embeddings.patch_embedding.bias"))
70
+ rename_keys.append(("params/img/pos_embedding", "vision_model.vision_model.embeddings.position_embedding.weight"))
71
+
72
+ for i in range(config.vision_config.num_hidden_layers):
73
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/scale", f"vision_model.vision_model.encoder.layers.{i}.layer_norm1.weight"))
74
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/bias", f"vision_model.vision_model.encoder.layers.{i}.layer_norm1.bias"))
75
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/scale", f"vision_model.vision_model.encoder.layers.{i}.layer_norm2.weight"))
76
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/bias", f"vision_model.vision_model.encoder.layers.{i}.layer_norm2.bias"))
77
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"vision_model.vision_model.encoder.layers.{i}.mlp.fc1.weight"))
78
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"vision_model.vision_model.encoder.layers.{i}.mlp.fc1.bias"))
79
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"vision_model.vision_model.encoder.layers.{i}.mlp.fc2.weight"))
80
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"vision_model.vision_model.encoder.layers.{i}.mlp.fc2.bias"))
81
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"vision_model.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"))
82
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"vision_model.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"))
83
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"vision_model.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"))
84
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"vision_model.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"))
85
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"vision_model.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"))
86
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"vision_model.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"))
87
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"vision_model.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"))
88
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"vision_model.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"))
89
+
90
+ rename_keys.append(("params/img/Transformer/encoder_norm/scale", "vision_model.vision_model.post_layernorm.weight"))
91
+ rename_keys.append(("params/img/Transformer/encoder_norm/bias", "vision_model.vision_model.post_layernorm.bias"))
92
+
93
+ rename_keys.append(("params/img/MAPHead_0/probe", "vision_model.vision_model.head.probe"))
94
+ rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/scale", "vision_model.vision_model.head.layernorm.weight"))
95
+ rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/bias", "vision_model.vision_model.head.layernorm.bias"))
96
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/kernel", "vision_model.vision_model.head.mlp.fc1.weight"))
97
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/bias", "vision_model.vision_model.head.mlp.fc1.bias"))
98
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/kernel", "vision_model.vision_model.head.mlp.fc2.weight"))
99
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/bias", "vision_model.vision_model.head.mlp.fc2.bias"))
100
+ rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/kernel", "vision_model.vision_model.head.attention.out_proj.weight"))
101
+ rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/bias", "vision_model.vision_model.head.attention.out_proj.bias"))
102
+
103
+ # text encoder
104
+
105
+ rename_keys.append(("params/txt/Embed_0/embedding", "text_model.text_model.embeddings.token_embedding.weight"))
106
+ rename_keys.append(("params/txt/pos_embedding", "text_model.text_model.embeddings.position_embedding.weight"))
107
+
108
+ for i in range(config.text_config.num_hidden_layers):
109
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/scale", f"text_model.text_model.encoder.layers.{i}.layer_norm1.weight"))
110
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/bias", f"text_model.text_model.encoder.layers.{i}.layer_norm1.bias"))
111
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/scale", f"text_model.text_model.encoder.layers.{i}.layer_norm2.weight"))
112
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/bias", f"text_model.text_model.encoder.layers.{i}.layer_norm2.bias"))
113
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"text_model.text_model.encoder.layers.{i}.mlp.fc1.weight"))
114
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"text_model.text_model.encoder.layers.{i}.mlp.fc1.bias"))
115
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"text_model.text_model.encoder.layers.{i}.mlp.fc2.weight"))
116
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"text_model.text_model.encoder.layers.{i}.mlp.fc2.bias"))
117
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"text_model.text_model.encoder.layers.{i}.self_attn.k_proj.weight"))
118
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"text_model.text_model.encoder.layers.{i}.self_attn.k_proj.bias"))
119
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"text_model.text_model.encoder.layers.{i}.self_attn.v_proj.weight"))
120
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"text_model.text_model.encoder.layers.{i}.self_attn.v_proj.bias"))
121
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"text_model.text_model.encoder.layers.{i}.self_attn.q_proj.weight"))
122
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"text_model.text_model.encoder.layers.{i}.self_attn.q_proj.bias"))
123
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"text_model.text_model.encoder.layers.{i}.self_attn.out_proj.weight"))
124
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"text_model.text_model.encoder.layers.{i}.self_attn.out_proj.bias"))
125
+
126
+ rename_keys.append(("params/txt/Encoder_0/encoder_norm/scale", "text_model.text_model.final_layer_norm.weight"))
127
+ rename_keys.append(("params/txt/Encoder_0/encoder_norm/bias", "text_model.text_model.final_layer_norm.bias"))
128
+ rename_keys.append(("params/txt/head/kernel", "text_model.text_model.head.weight"))
129
+ rename_keys.append(("params/txt/head/bias", "text_model.text_model.head.bias"))
130
+
131
+ # learned temperature and bias
132
+ rename_keys.append(("params/t", "temperature"))
133
+ rename_keys.append(("params/b", "bias"))
134
+
135
+ # fmt: on
136
+ return rename_keys
137
+
138
+
139
+ def rename_key(dct, old, new, config):
140
+ val = dct.pop(old)
141
+
142
+ if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "vision" in new:
143
+ val = val.reshape(-1, config.vision_config.hidden_size)
144
+ if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "text" in new:
145
+ val = val.reshape(-1, config.text_config.hidden_size)
146
+
147
+ if "patch_embedding.weight" in new:
148
+ val = val.transpose(3, 2, 0, 1)
149
+ elif new.endswith("weight") and "position_embedding" not in new and "token_embedding" not in new:
150
+ val = val.T
151
+
152
+ if "position_embedding" in new and "vision" in new:
153
+ val = val.reshape(-1, config.vision_config.hidden_size)
154
+ if "position_embedding" in new and "text" in new:
155
+ val = val.reshape(-1, config.text_config.hidden_size)
156
+
157
+ if new.endswith("bias"):
158
+ val = val.reshape(-1)
159
+
160
+ dct[new] = torch.from_numpy(val)
161
+
162
+
163
+ def read_in_q_k_v_head(state_dict, config):
164
+ # read in individual input projection layers
165
+ key_proj_weight = (
166
+ state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/kernel")
167
+ .reshape(-1, config.vision_config.hidden_size)
168
+ .T
169
+ )
170
+ key_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/bias").reshape(-1)
171
+ value_proj_weight = (
172
+ state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/kernel")
173
+ .reshape(-1, config.vision_config.hidden_size)
174
+ .T
175
+ )
176
+ value_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/bias").reshape(-1)
177
+ query_proj_weight = (
178
+ state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/kernel")
179
+ .reshape(-1, config.vision_config.hidden_size)
180
+ .T
181
+ )
182
+ query_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/bias").reshape(-1)
183
+
184
+ # next, add them to the state dict as a single matrix + vector
185
+ state_dict["vision_model.vision_model.head.attention.in_proj_weight"] = torch.from_numpy(
186
+ np.concatenate([query_proj_weight, key_proj_weight, value_proj_weight], axis=0)
187
+ )
188
+ state_dict["vision_model.vision_model.head.attention.in_proj_bias"] = torch.from_numpy(
189
+ np.concatenate([query_proj_bias, key_proj_bias, value_proj_bias], axis=0)
190
+ )
191
+
192
+
193
+ # We will verify our results on an image of cute cats
194
+ def prepare_img():
195
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
196
+ image = Image.open(requests.get(url, stream=True).raw)
197
+ return image
198
+
199
+
200
+ def flatten_nested_dict(params, parent_key="", sep="/"):
201
+ items = []
202
+
203
+ for k, v in params.items():
204
+ new_key = parent_key + sep + k if parent_key else k
205
+
206
+ if isinstance(v, collections.abc.MutableMapping):
207
+ items.extend(flatten_nested_dict(v, new_key, sep=sep).items())
208
+ else:
209
+ items.append((new_key, v))
210
+ return dict(items)
211
+
212
+
213
+ @torch.no_grad()
214
+ def convert_siglip_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
215
+ """
216
+ Copy/paste/tweak model's weights to our SigLIP structure.
217
+ """
218
+
219
+ # define default SigLIP configuration
220
+ config = get_siglip_config(model_name)
221
+
222
+ # load original state dict
223
+ data = load("/Users/nielsrogge/Documents/SigLIP/webli_en_b16_224_63724782.npz")
224
+ state_dict = flatten_nested_dict(data)
225
+
226
+ # remove and rename some keys
227
+ rename_keys = create_rename_keys(config)
228
+ for src, dest in rename_keys:
229
+ rename_key(state_dict, src, dest, config)
230
+
231
+ # qkv matrices of attention pooling head need special treatment
232
+ read_in_q_k_v_head(state_dict, config)
233
+
234
+ # load HuggingFace model
235
+ model = SiglipModel(config).eval()
236
+ model.load_state_dict(state_dict)
237
+
238
+ print("Original temperature:", data["params/t"])
239
+
240
+ # TODO create image processor
241
+ # url = "http://images.cocodataset.org/val2017/000000039769.jpg"
242
+ # image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
243
+ # preprocess image
244
+ #
245
+ # pixel_values = processor(image, return_tensors="pt").pixel_values
246
+
247
+ filepath = hf_hub_download(repo_id="nielsr/test-image", filename="pixel_values_siglip.npy", repo_type="dataset")
248
+ pixel_values = np.load(filepath)
249
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2)
250
+ filepath = hf_hub_download(repo_id="nielsr/test-image", filename="input_ids_siglip.npy", repo_type="dataset")
251
+ input_ids = np.load(filepath)
252
+ input_ids = torch.from_numpy(input_ids)
253
+
254
+ with torch.no_grad():
255
+ outputs = model(input_ids=input_ids, pixel_values=pixel_values)
256
+
257
+ # assert values
258
+ expected_slice = torch.tensor(
259
+ [[-2.9621, -2.1672, -1.7837], [-0.2713, 0.2910, -10.6595], [-13.6617, -13.1611, -17.4408]]
260
+ )
261
+ assert torch.allclose(outputs.logits_per_image[:3, :3], expected_slice, atol=1e-4)
262
+ print("Looks ok!")
263
+
264
+ if pytorch_dump_folder_path is not None:
265
+ Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
266
+ print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
267
+ model.save_pretrained(pytorch_dump_folder_path)
268
+ # print(f"Saving processor to {pytorch_dump_folder_path}")
269
+ # processor.save_pretrained(pytorch_dump_folder_path)
270
+
271
+ if push_to_hub:
272
+ model.push_to_hub(f"nielsr/{model_name}")
273
+ # processor.push_to_hub(f"nielsr/{model_name}")
274
+
275
+
276
+ if __name__ == "__main__":
277
+ parser = argparse.ArgumentParser()
278
+ # Required parameters
279
+ parser.add_argument(
280
+ "--model_name",
281
+ default="siglip-base-patch16-224",
282
+ type=str,
283
+ choices=["siglip-base-patch16-224"],
284
+ help="Name of the model you'd like to convert.",
285
+ )
286
+ parser.add_argument(
287
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
288
+ )
289
+ parser.add_argument(
290
+ "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
291
+ )
292
+
293
+ args = parser.parse_args()
294
+ convert_siglip_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
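For reference, a sketch of how the script above would be invoked and of what `flatten_nested_dict` produces; the dump folder is a placeholder, and the `.npz` checkpoint path is currently hardcoded inside `convert_siglip_checkpoint`:

```python
# python convert_siglip_to_hf.py --model_name siglip-base-patch16-224 \
#     --pytorch_dump_folder_path ./siglip-base-patch16-224

from convert_siglip_to_hf import flatten_nested_dict  # assumes the script is on the path

# The nested big_vision parameter tree is flattened into "/"-joined keys, which is what
# the source names in `create_rename_keys` match against.
params = {"img": {"embedding": {"kernel": 1, "bias": 2}}, "t": 3}
assert flatten_nested_dict(params) == {
    "img/embedding/kernel": 1,
    "img/embedding/bias": 2,
    "t": 3,
}
```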
image_processing_siglip.py ADDED
@@ -0,0 +1,229 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for SigLIP."""
16
+
17
+ from typing import Dict, Optional, Union
18
+
19
+ import numpy as np
20
+
21
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
22
+ from ...image_transforms import (
23
+ rescale,
24
+ resize,
25
+ to_channel_dimension_format,
26
+ )
27
+ from ...image_utils import (
28
+ ChannelDimension,
29
+ ImageInput,
30
+ PILImageResampling,
31
+ infer_channel_dimension_format,
32
+ is_scaled_image,
33
+ make_list_of_images,
34
+ to_numpy_array,
35
+ valid_images,
36
+ )
37
+ from ...utils import TensorType, is_vision_available, logging
38
+
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+
43
+ if is_vision_available():
44
+ import PIL
45
+
46
+
47
+ class SiglipImageProcessor(BaseImageProcessor):
48
+ r"""
49
+ Constructs a SigLIP image processor.
50
+
51
+ Args:
52
+ do_resize (`bool`, *optional*, defaults to `True`):
53
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
54
+ `do_resize` in the `preprocess` method.
55
+ size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
56
+ Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
57
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
58
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
59
+ do_rescale (`bool`, *optional*, defaults to `True`):
60
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
61
+ the `preprocess` method.
62
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
63
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
64
+ method.
65
+ """
66
+
67
+ model_input_names = ["pixel_values"]
68
+
69
+ def __init__(
70
+ self,
71
+ do_resize: bool = True,
72
+ size: Dict[str, int] = None,
73
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
74
+ do_rescale: bool = True,
75
+ rescale_factor: Union[int, float] = 1 / 255,
76
+ **kwargs,
77
+ ) -> None:
78
+ super().__init__(**kwargs)
79
+ size = size if size is not None else {"height": 224, "width": 224}
80
+ size = get_size_dict(size, default_to_square=False)
81
+
82
+ self.do_resize = do_resize
83
+ self.size = size
84
+ self.resample = resample
85
+ self.do_rescale = do_rescale
86
+ self.rescale_factor = rescale_factor
87
+
88
+ def rescale(
89
+ self,
90
+ image: np.ndarray,
91
+ rescale_factor: float,
92
+ data_format: Optional[Union[str, ChannelDimension]] = None,
93
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
94
+ **kwargs,
95
+ ) -> np.ndarray:
96
+ """
97
+ Rescale an image by a scale factor. image = image * scale, after which image = image * 2 - 1.
98
+
99
+ Args:
100
+ image (`np.ndarray`):
101
+ Image to rescale.
102
+ rescale_factor (`float`):
103
+ The scaling factor to rescale pixel values by.
104
+ data_format (`str` or `ChannelDimension`, *optional*):
105
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
106
+ image is used. Can be one of:
107
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
108
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
109
+ input_data_format (`ChannelDimension` or `str`, *optional*):
110
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
111
+ from the input image. Can be one of:
112
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
113
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
114
+
115
+ Returns:
116
+ `np.ndarray`: The rescaled image.
117
+ """
118
+ # first, rescale to 0->1
119
+ rescaled_image = rescale(
120
+ image, scale=rescale_factor, data_format=data_format, input_data_format=input_data_format, **kwargs
121
+ )
122
+
123
+ # next, rescale to -1->1
124
+ rescaled_image = 2 * rescaled_image - 1
125
+
126
+ return rescaled_image
127
+
128
+ def preprocess(
129
+ self,
130
+ images: ImageInput,
131
+ do_resize: bool = None,
132
+ size: Dict[str, int] = None,
133
+ resample: PILImageResampling = None,
134
+ do_rescale: bool = None,
135
+ rescale_factor: float = None,
136
+ return_tensors: Optional[Union[str, TensorType]] = None,
137
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
138
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
139
+ **kwargs,
140
+ ) -> BatchFeature:
141
+ """
142
+ Preprocess an image or batch of images.
143
+
144
+ Args:
145
+ images (`ImageInput`):
146
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
147
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
148
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
149
+ Whether to resize the image.
150
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
151
+ Size of the image after resizing.
152
+ resample (`int`, *optional*, defaults to `self.resample`):
153
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
154
+ has an effect if `do_resize` is set to `True`.
155
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
156
+ Whether to rescale the image.
157
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
158
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
159
+ return_tensors (`str` or `TensorType`, *optional*):
160
+ The type of tensors to return. Can be one of:
161
+ - Unset: Return a list of `np.ndarray`.
162
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
163
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
164
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
165
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
166
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
167
+ The channel dimension format for the output image. Can be one of:
168
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
169
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
170
+ - Unset: Use the channel dimension format of the input image.
171
+ input_data_format (`ChannelDimension` or `str`, *optional*):
172
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
173
+ from the input image. Can be one of:
174
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
175
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
176
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
177
+ """
178
+ do_resize = do_resize if do_resize is not None else self.do_resize
179
+ size = size if size is not None else self.size
180
+ size = get_size_dict(size, param_name="size", default_to_square=False)
181
+ resample = resample if resample is not None else self.resample
182
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
183
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
184
+
185
+ images = make_list_of_images(images)
186
+
187
+ if not valid_images(images):
188
+ raise ValueError(
189
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
190
+ "torch.Tensor, tf.Tensor or jax.ndarray."
191
+ )
192
+
193
+ if do_resize and size is None:
194
+ raise ValueError("Size must be specified if do_resize is True.")
195
+
196
+ if do_rescale and rescale_factor is None:
197
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
198
+
199
+ # All transformations expect numpy arrays.
200
+ images = [to_numpy_array(image) for image in images]
201
+
202
+ if is_scaled_image(images[0]) and do_rescale:
203
+ logger.warning_once(
204
+ "It looks like you are trying to rescale already rescaled images. If the input"
205
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
206
+ )
207
+
208
+ if input_data_format is None:
209
+ # We assume that all images have the same channel dimension format.
210
+ input_data_format = infer_channel_dimension_format(images[0])
211
+
212
+ if do_resize:
213
+ images = [
214
+ resize(image=image, size=(size["height"], size["width"]), resample=resample, input_data_format=input_data_format)
215
+ for image in images
216
+ ]
217
+
218
+ if do_rescale:
219
+ images = [
220
+ self.rescale(image=image, rescale_factor=rescale_factor, input_data_format=input_data_format)
221
+ for image in images
222
+ ]
223
+
224
+ images = [
225
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
226
+ ]
227
+
228
+ data = {"pixel_values": images}
229
+ return BatchFeature(data=data, tensor_type=return_tensors)
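A minimal usage sketch for the image processor above, assuming it is importable from this PR branch; it illustrates the two-step rescaling (`0..255 -> 0..1 -> -1..1`) implemented in `rescale`:

```python
import numpy as np
from PIL import Image

from transformers import SiglipImageProcessor  # assumes this PR branch is installed

processor = SiglipImageProcessor()  # defaults: resize to 224x224, rescale by 1/255
image = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))  # all-black test image

inputs = processor(images=image, return_tensors="np")
pixel_values = inputs["pixel_values"]

assert pixel_values.shape == (1, 3, 224, 224)
# 0 -> 0.0 after the 1/255 rescale, then 2 * 0.0 - 1 = -1.0
assert np.allclose(pixel_values, -1.0)
```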
modeling_siglip.py ADDED
@@ -0,0 +1,1184 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Google AI and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Siglip model."""
16
+
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Any, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+
25
+ from ...activations import ACT2FN
26
+ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
27
+ from ...modeling_utils import PreTrainedModel
28
+ from ...utils import (
29
+ ModelOutput,
30
+ add_start_docstrings,
31
+ add_start_docstrings_to_model_forward,
32
+ logging,
33
+ replace_return_docstrings,
34
+ )
35
+ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
36
+
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+ _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
41
+
42
+ SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
43
+ "google/siglip-base-patch16-224",
44
+ # See all SigLIP models at https://huggingface.co/models?filter=siglip
45
+ ]
46
+
47
+
48
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
49
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
50
+ """
51
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
52
+ """
53
+ bsz, src_len = mask.size()
54
+ tgt_len = tgt_len if tgt_len is not None else src_len
55
+
56
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
57
+
58
+ inverted_mask = 1.0 - expanded_mask
59
+
60
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
61
+
62
+
63
+ # contrastive loss function, adapted from
64
+ # https://sachinruk.github.io/blog/2021-03-07-siglip.html
65
+ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
66
+ return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
67
+
68
+
69
+ # Copied from transformers.models.clip.modeling_clip.clip_loss with clip->siglip
70
+ def siglip_loss(similarity: torch.Tensor) -> torch.Tensor:
71
+ caption_loss = contrastive_loss(similarity)
72
+ image_loss = contrastive_loss(similarity.t())
73
+ return (caption_loss + image_loss) / 2.0
74
+
75
+
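As a quick sanity check of the two helpers above (a standalone illustration, assuming the module is on the path): `contrastive_loss` is cross-entropy with the diagonal of the similarity matrix as targets, and `siglip_loss` averages it over rows and columns:

```python
import torch

# Re-uses `siglip_loss` defined above in this module.
similarity = torch.tensor([[4.0, 0.0], [0.0, 4.0]])  # two well-aligned image/text pairs
loss = siglip_loss(similarity)
print(round(float(loss), 3))  # ~0.018, small because the diagonal logits dominate
```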
76
+ @dataclass
77
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
78
+ class SiglipVisionModelOutput(ModelOutput):
79
+ """
80
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
81
+
82
+ Args:
83
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
84
+ The image embeddings obtained by applying the projection layer to the pooler_output.
85
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
86
+ Sequence of hidden-states at the output of the last layer of the model.
87
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
88
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
89
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
90
+
91
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
92
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
93
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
94
+ sequence_length)`.
95
+
96
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
97
+ heads.
98
+ """
99
+
100
+ image_embeds: Optional[torch.FloatTensor] = None
101
+ last_hidden_state: torch.FloatTensor = None
102
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
103
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
104
+
105
+
106
+ @dataclass
107
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
108
+ class SiglipTextModelOutput(ModelOutput):
109
+ """
110
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
111
+
112
+ Args:
113
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
114
+ The text embeddings obtained by applying the projection layer to the pooler_output.
115
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
116
+ Sequence of hidden-states at the output of the last layer of the model.
117
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
118
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
119
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
120
+
121
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
122
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
123
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
124
+ sequence_length)`.
125
+
126
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
127
+ heads.
128
+ """
129
+
130
+ text_embeds: Optional[torch.FloatTensor] = None
131
+ last_hidden_state: torch.FloatTensor = None
132
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
133
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
134
+
135
+
136
+ @dataclass
137
+ # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
138
+ class SiglipOutput(ModelOutput):
139
+ """
140
+ Args:
141
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
142
+ Contrastive loss for image-text similarity.
143
+ logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
144
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
145
+ similarity scores.
146
+ logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
147
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
148
+ similarity scores.
149
+ text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim)`):
150
+ The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
151
+ image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim)`):
152
+ The image embeddings obtained by applying the projection layer to the pooled output of
153
+ [`SiglipVisionModel`].
154
+ text_model_output(`BaseModelOutputWithPooling`):
155
+ The output of the [`SiglipTextModel`].
156
+ vision_model_output(`BaseModelOutputWithPooling`):
157
+ The output of the [`SiglipVisionModel`].
158
+ """
159
+
160
+ loss: Optional[torch.FloatTensor] = None
161
+ logits_per_image: torch.FloatTensor = None
162
+ logits_per_text: torch.FloatTensor = None
163
+ text_embeds: torch.FloatTensor = None
164
+ image_embeds: torch.FloatTensor = None
165
+ text_model_output: BaseModelOutputWithPooling = None
166
+ vision_model_output: BaseModelOutputWithPooling = None
167
+
168
+ def to_tuple(self) -> Tuple[Any]:
169
+ return tuple(
170
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
171
+ for k in self.keys()
172
+ )
173
+
174
+
175
+ class SiglipVisionEmbeddings(nn.Module):
176
+ def __init__(self, config: SiglipVisionConfig):
177
+ super().__init__()
178
+ self.config = config
179
+ self.embed_dim = config.hidden_size
180
+ self.image_size = config.image_size
181
+ self.patch_size = config.patch_size
182
+
183
+ self.patch_embedding = nn.Conv2d(
184
+ in_channels=config.num_channels,
185
+ out_channels=self.embed_dim,
186
+ kernel_size=self.patch_size,
187
+ stride=self.patch_size,
188
+ padding="valid",
189
+ )
190
+
191
+ self.num_patches = (self.image_size // self.patch_size) ** 2
192
+ self.num_positions = self.num_patches
193
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
194
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
195
+
196
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
197
198
+
199
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
200
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
201
+
202
204
+
205
+ embeddings = embeddings + self.position_embedding(self.position_ids)
206
+ return embeddings
207
+
208
+
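+ # Shape walkthrough for the vision embeddings above (illustrative; the 224px / patch-16 numbers
+ # are assumptions based on the checkpoint name, not read from any config):
+ #   pixel_values                          (batch, 3, 224, 224)
+ #   patch_embedding (Conv2d, stride 16)   (batch, hidden_size, 14, 14)
+ #   flatten(2).transpose(1, 2)            (batch, 196, hidden_size)
+ #   + position_embedding(position_ids)    (batch, 196, hidden_size)
+
+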
209
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
210
+ class SiglipTextEmbeddings(nn.Module):
211
+ def __init__(self, config: SiglipTextConfig):
212
+ super().__init__()
213
+ embed_dim = config.hidden_size
214
+
215
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
216
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
217
+
218
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
219
+ self.register_buffer(
220
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
221
+ )
222
+
223
+ def forward(
224
+ self,
225
+ input_ids: Optional[torch.LongTensor] = None,
226
+ position_ids: Optional[torch.LongTensor] = None,
227
+ inputs_embeds: Optional[torch.FloatTensor] = None,
228
+ ) -> torch.Tensor:
229
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
230
+
231
+ if position_ids is None:
232
+ position_ids = self.position_ids[:, :seq_length]
233
+
234
+ if inputs_embeds is None:
235
+ inputs_embeds = self.token_embedding(input_ids)
236
+
237
+ position_embeddings = self.position_embedding(position_ids)
238
+ embeddings = inputs_embeds + position_embeddings
239
+
240
+ return embeddings
241
+
242
+
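+ # Shape note (illustrative): `input_ids` of shape (batch, seq_len) are mapped to token
+ # embeddings of shape (batch, seq_len, hidden_size); `position_embedding(position_ids)` has
+ # shape (1, seq_len, hidden_size) and broadcasts over the batch dimension in the addition.
+
+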
243
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->Siglip
244
+ class SiglipAttention(nn.Module):
245
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
246
+
247
+ def __init__(self, config):
248
+ super().__init__()
249
+ self.config = config
250
+ self.embed_dim = config.hidden_size
251
+ self.num_heads = config.num_attention_heads
252
+ self.head_dim = self.embed_dim // self.num_heads
253
+ if self.head_dim * self.num_heads != self.embed_dim:
254
+ raise ValueError(
255
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
256
+ f" {self.num_heads})."
257
+ )
258
+ self.scale = self.head_dim**-0.5
259
+ self.dropout = config.attention_dropout
260
+
261
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
262
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
263
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
264
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
265
+
266
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
267
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
268
+
269
+ def forward(
270
+ self,
271
+ hidden_states: torch.Tensor,
272
+ attention_mask: Optional[torch.Tensor] = None,
273
+ causal_attention_mask: Optional[torch.Tensor] = None,
274
+ output_attentions: Optional[bool] = False,
275
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
276
+ """Input shape: Batch x Time x Channel"""
277
+
278
+ bsz, tgt_len, embed_dim = hidden_states.size()
279
+
280
+ # get query proj
281
+ query_states = self.q_proj(hidden_states) * self.scale
282
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
283
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
284
+
285
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
286
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
287
+ key_states = key_states.view(*proj_shape)
288
+ value_states = value_states.view(*proj_shape)
289
+
290
+ src_len = key_states.size(1)
291
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
292
+
293
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
294
+ raise ValueError(
295
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
296
+ f" {attn_weights.size()}"
297
+ )
298
+
299
+ # apply the causal_attention_mask first
300
+ if causal_attention_mask is not None:
301
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
302
+ raise ValueError(
303
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
304
+ f" {causal_attention_mask.size()}"
305
+ )
306
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
307
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
308
+
309
+ if attention_mask is not None:
310
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
311
+ raise ValueError(
312
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
313
+ )
314
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
315
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
316
+
317
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
318
+
319
+ if output_attentions:
320
+ # this operation is a bit awkward, but it's required to
321
+ # make sure that attn_weights keeps its gradient.
322
+ # In order to do so, attn_weights have to be reshaped
323
+ # twice and have to be reused in the following
324
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
325
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
326
+ else:
327
+ attn_weights_reshaped = None
328
+
329
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
330
+
331
+ attn_output = torch.bmm(attn_probs, value_states)
332
+
333
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
334
+ raise ValueError(
335
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
336
+ f" {attn_output.size()}"
337
+ )
338
+
339
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
340
+ attn_output = attn_output.transpose(1, 2)
341
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
342
+
343
+ attn_output = self.out_proj(attn_output)
344
+
345
+ return attn_output, attn_weights_reshaped
346
+
347
+
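+ # Usage sketch for the attention block (illustrative; the default `SiglipVisionConfig` values
+ # hidden_size=768 and num_attention_heads=12 are assumptions):
+ #   attn = SiglipAttention(SiglipVisionConfig())
+ #   out, weights = attn(torch.randn(2, 196, 768), output_attentions=True)
+ #   out.shape     == (2, 196, 768)
+ #   weights.shape == (2, 12, 196, 196)
+
+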
348
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
349
+ class SiglipMLP(nn.Module):
350
+ def __init__(self, config):
351
+ super().__init__()
352
+ self.config = config
353
+ self.activation_fn = ACT2FN[config.hidden_act]
354
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
355
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
356
+
357
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
358
+ hidden_states = self.fc1(hidden_states)
359
+ hidden_states = self.activation_fn(hidden_states)
360
+ hidden_states = self.fc2(hidden_states)
361
+ return hidden_states
362
+
363
+
364
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
365
+ class SiglipEncoderLayer(nn.Module):
366
+ def __init__(self, config: SiglipConfig):
367
+ super().__init__()
368
+ self.embed_dim = config.hidden_size
369
+ self.self_attn = SiglipAttention(config)
370
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
371
+ self.mlp = SiglipMLP(config)
372
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
373
+
374
+ def forward(
375
+ self,
376
+ hidden_states: torch.Tensor,
377
+ attention_mask: torch.Tensor,
378
+ causal_attention_mask: torch.Tensor,
379
+ output_attentions: Optional[bool] = False,
380
+ ) -> Tuple[torch.FloatTensor]:
381
+ """
382
+ Args:
383
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
384
+ attention_mask (`torch.FloatTensor`): attention mask of size
385
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
386
+ `(config.encoder_attention_heads,)`.
387
+ output_attentions (`bool`, *optional*):
388
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
389
+ returned tensors for more detail.
390
+ """
391
+ residual = hidden_states
392
+
393
+ hidden_states = self.layer_norm1(hidden_states)
394
+ hidden_states, attn_weights = self.self_attn(
395
+ hidden_states=hidden_states,
396
+ attention_mask=attention_mask,
397
+ causal_attention_mask=causal_attention_mask,
398
+ output_attentions=output_attentions,
399
+ )
400
+ hidden_states = residual + hidden_states
401
+
402
+ residual = hidden_states
403
+ hidden_states = self.layer_norm2(hidden_states)
404
+ hidden_states = self.mlp(hidden_states)
405
+ hidden_states = residual + hidden_states
406
+
407
+ outputs = (hidden_states,)
408
+
409
+ if output_attentions:
410
+ outputs += (attn_weights,)
411
+
412
+ return outputs
413
+
414
+
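+ # The encoder layer above follows the standard pre-LayerNorm transformer pattern:
+ #   hidden_states = hidden_states + SelfAttention(LayerNorm1(hidden_states))
+ #   hidden_states = hidden_states + MLP(LayerNorm2(hidden_states))
+
+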
415
+ class SiglipPreTrainedModel(PreTrainedModel):
416
+ """
417
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
418
+ models.
419
+ """
420
+
421
+ config_class = SiglipConfig
422
+ base_model_prefix = "siglip"
423
+ supports_gradient_checkpointing = True
424
+
425
+ def _init_weights(self, module):
426
+ """Initialize the weights"""
427
+ factor = self.config.initializer_factor
428
+ if isinstance(module, SiglipTextEmbeddings):
429
+ module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
430
+ module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
431
+ elif isinstance(module, SiglipVisionEmbeddings):
432
+ factor = self.config.initializer_factor
433
+ nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
434
+ nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
435
+ elif isinstance(module, SiglipAttention):
436
+ factor = self.config.initializer_factor
437
+ in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
438
+ out_proj_std = (module.embed_dim**-0.5) * factor
439
+ nn.init.normal_(module.q_proj.weight, std=in_proj_std)
440
+ nn.init.normal_(module.k_proj.weight, std=in_proj_std)
441
+ nn.init.normal_(module.v_proj.weight, std=in_proj_std)
442
+ nn.init.normal_(module.out_proj.weight, std=out_proj_std)
443
+ elif isinstance(module, SiglipMLP):
444
+ factor = self.config.initializer_factor
445
+ in_proj_std = (
446
+ (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
447
+ )
448
+ fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
449
+ nn.init.normal_(module.fc1.weight, std=fc_std)
450
+ nn.init.normal_(module.fc2.weight, std=in_proj_std)
451
+ if isinstance(module, nn.LayerNorm):
452
+ module.bias.data.zero_()
453
+ module.weight.data.fill_(1.0)
454
+ if isinstance(module, nn.Linear) and module.bias is not None:
455
+ module.bias.data.zero_()
456
+
457
+ def _set_gradient_checkpointing(self, module, value=False):
458
+ if isinstance(module, SiglipEncoder):
459
+ module.gradient_checkpointing = value
460
+
461
+
462
+ SIGLIP_START_DOCSTRING = r"""
463
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
464
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
465
+ etc.)
466
+
467
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
468
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
469
+ and behavior.
470
+
471
+ Parameters:
472
+ config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
473
+ Initializing with a config file does not load the weights associated with the model, only the
474
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
475
+ """
476
+
477
+ SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
478
+ Args:
479
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
480
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
481
+ it.
482
+
483
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
484
+ [`PreTrainedTokenizer.__call__`] for details.
485
+
486
+ [What are input IDs?](../glossary#input-ids)
487
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
488
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
489
+
490
+ - 1 for tokens that are **not masked**,
491
+ - 0 for tokens that are **masked**.
492
+
493
+ [What are attention masks?](../glossary#attention-mask)
494
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
495
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
496
+ config.max_position_embeddings - 1]`.
497
+
498
+ [What are position IDs?](../glossary#position-ids)
499
+ output_attentions (`bool`, *optional*):
500
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
501
+ tensors for more detail.
502
+ output_hidden_states (`bool`, *optional*):
503
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
504
+ more detail.
505
+ return_dict (`bool`, *optional*):
506
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
507
+ """
508
+
509
+ SIGLIP_VISION_INPUTS_DOCSTRING = r"""
510
+ Args:
511
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
512
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
513
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
514
+ output_attentions (`bool`, *optional*):
515
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
516
+ tensors for more detail.
517
+ output_hidden_states (`bool`, *optional*):
518
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
519
+ more detail.
520
+ return_dict (`bool`, *optional*):
521
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
522
+ """
523
+
524
+ SIGLIP_INPUTS_DOCSTRING = r"""
525
+ Args:
526
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
527
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
528
+ it.
529
+
530
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
531
+ [`PreTrainedTokenizer.__call__`] for details.
532
+
533
+ [What are input IDs?](../glossary#input-ids)
534
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
535
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
536
+
537
+ - 1 for tokens that are **not masked**,
538
+ - 0 for tokens that are **masked**.
539
+
540
+ [What are attention masks?](../glossary#attention-mask)
541
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
542
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
543
+ config.max_position_embeddings - 1]`.
544
+
545
+ [What are position IDs?](../glossary#position-ids)
546
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
547
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
548
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
549
+ return_loss (`bool`, *optional*):
550
+ Whether or not to return the contrastive loss.
551
+ output_attentions (`bool`, *optional*):
552
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
553
+ tensors for more detail.
554
+ output_hidden_states (`bool`, *optional*):
555
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
556
+ more detail.
557
+ return_dict (`bool`, *optional*):
558
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
559
+ """
560
+
561
+
562
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
563
+ class SiglipEncoder(nn.Module):
564
+ """
565
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
566
+ [`SiglipEncoderLayer`].
567
+
568
+ Args:
569
+ config: SiglipConfig
570
+ """
571
+
572
+ def __init__(self, config: SiglipConfig):
573
+ super().__init__()
574
+ self.config = config
575
+ self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
576
+ self.gradient_checkpointing = False
577
+
578
+ def forward(
579
+ self,
580
+ inputs_embeds,
581
+ attention_mask: Optional[torch.Tensor] = None,
582
+ causal_attention_mask: Optional[torch.Tensor] = None,
583
+ output_attentions: Optional[bool] = None,
584
+ output_hidden_states: Optional[bool] = None,
585
+ return_dict: Optional[bool] = None,
586
+ ) -> Union[Tuple, BaseModelOutput]:
587
+ r"""
588
+ Args:
589
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
590
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
591
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
592
+ than the model's internal embedding lookup matrix.
593
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
594
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
595
+
596
+ - 1 for tokens that are **not masked**,
597
+ - 0 for tokens that are **masked**.
598
+
599
+ [What are attention masks?](../glossary#attention-mask)
600
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
601
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
602
+
603
+ - 1 for tokens that are **not masked**,
604
+ - 0 for tokens that are **masked**.
605
+
606
+ [What are attention masks?](../glossary#attention-mask)
607
+ output_attentions (`bool`, *optional*):
608
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
609
+ returned tensors for more detail.
610
+ output_hidden_states (`bool`, *optional*):
611
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
612
+ for more detail.
613
+ return_dict (`bool`, *optional*):
614
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
615
+ """
616
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
617
+ output_hidden_states = (
618
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
619
+ )
620
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
621
+
622
+ encoder_states = () if output_hidden_states else None
623
+ all_attentions = () if output_attentions else None
624
+
625
+ hidden_states = inputs_embeds
626
+ for idx, encoder_layer in enumerate(self.layers):
627
+ if output_hidden_states:
628
+ encoder_states = encoder_states + (hidden_states,)
629
+ if self.gradient_checkpointing and self.training:
630
+
631
+ def create_custom_forward(module):
632
+ def custom_forward(*inputs):
633
+ return module(*inputs, output_attentions)
634
+
635
+ return custom_forward
636
+
637
+ layer_outputs = torch.utils.checkpoint.checkpoint(
638
+ create_custom_forward(encoder_layer),
639
+ hidden_states,
640
+ attention_mask,
641
+ causal_attention_mask,
642
+ )
643
+ else:
644
+ layer_outputs = encoder_layer(
645
+ hidden_states,
646
+ attention_mask,
647
+ causal_attention_mask,
648
+ output_attentions=output_attentions,
649
+ )
650
+
651
+ hidden_states = layer_outputs[0]
652
+
653
+ if output_attentions:
654
+ all_attentions = all_attentions + (layer_outputs[1],)
655
+
656
+ if output_hidden_states:
657
+ encoder_states = encoder_states + (hidden_states,)
658
+
659
+ if not return_dict:
660
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
661
+ return BaseModelOutput(
662
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
663
+ )
664
+
665
+
666
+ class SiglipTextTransformer(nn.Module):
667
+ def __init__(self, config: SiglipTextConfig):
668
+ super().__init__()
669
+ self.config = config
670
+ embed_dim = config.hidden_size
671
+ self.embeddings = SiglipTextEmbeddings(config)
672
+ self.encoder = SiglipEncoder(config)
673
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
674
+
675
+ self.head = nn.Linear(embed_dim, embed_dim)
676
+
677
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
678
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
679
+ def forward(
680
+ self,
681
+ input_ids: Optional[torch.Tensor] = None,
682
+ attention_mask: Optional[torch.Tensor] = None,
683
+ position_ids: Optional[torch.Tensor] = None,
684
+ output_attentions: Optional[bool] = None,
685
+ output_hidden_states: Optional[bool] = None,
686
+ return_dict: Optional[bool] = None,
687
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
688
+ r"""
689
+ Returns:
690
+
691
+ """
692
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
693
+ output_hidden_states = (
694
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
695
+ )
696
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
697
+
698
+ if input_ids is None:
699
+ raise ValueError("You have to specify input_ids")
700
+
701
+ input_shape = input_ids.size()
702
+ input_ids = input_ids.view(-1, input_shape[-1])
703
+
704
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
705
+
706
+ # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
707
+ # expand attention_mask
708
+ if attention_mask is not None:
709
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
710
+ attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
711
+
712
+ encoder_outputs = self.encoder(
713
+ inputs_embeds=hidden_states,
714
+ attention_mask=None,
715
+ causal_attention_mask=None,
716
+ output_attentions=output_attentions,
717
+ output_hidden_states=output_hidden_states,
718
+ return_dict=return_dict,
719
+ )
720
+
721
+ last_hidden_state = encoder_outputs[0]
722
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
723
+
724
725
+
726
+ # Assuming "sticky" EOS tokenization, last token is always EOS.
727
+ pooled_output = last_hidden_state[:, -1, :]
728
+ pooled_output = self.head(pooled_output)
729
+
730
731
+
732
+ if not return_dict:
733
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
734
+
735
+ return BaseModelOutputWithPooling(
736
+ last_hidden_state=last_hidden_state,
737
+ pooler_output=pooled_output,
738
+ hidden_states=encoder_outputs.hidden_states,
739
+ attentions=encoder_outputs.attentions,
740
+ )
741
+
742
+
743
+ @add_start_docstrings(
744
+ """The text model from SigLIP without any head or projection on top.""",
745
+ SIGLIP_START_DOCSTRING,
746
+ )
747
+ class SiglipTextModel(SiglipPreTrainedModel):
748
+ config_class = SiglipTextConfig
749
+
750
+ _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"]
751
+
752
+ def __init__(self, config: SiglipTextConfig):
753
+ super().__init__(config)
754
+ self.text_model = SiglipTextTransformer(config)
755
+ # Initialize weights and apply final processing
756
+ self.post_init()
757
+
758
+ def get_input_embeddings(self) -> nn.Module:
759
+ return self.text_model.embeddings.token_embedding
760
+
761
+ def set_input_embeddings(self, value):
762
+ self.text_model.embeddings.token_embedding = value
763
+
764
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
765
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
766
+ def forward(
767
+ self,
768
+ input_ids: Optional[torch.Tensor] = None,
769
+ attention_mask: Optional[torch.Tensor] = None,
770
+ position_ids: Optional[torch.Tensor] = None,
771
+ output_attentions: Optional[bool] = None,
772
+ output_hidden_states: Optional[bool] = None,
773
+ return_dict: Optional[bool] = None,
774
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
775
+ r"""
776
+ Returns:
777
+
778
+ Examples:
779
+
780
+ ```python
781
+ >>> from transformers import AutoTokenizer, SiglipTextModel
782
+
783
+ >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
784
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
785
+
786
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
787
+
788
+ >>> outputs = model(**inputs)
789
+ >>> last_hidden_state = outputs.last_hidden_state
790
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
791
+ ```"""
792
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
793
+
794
+ return self.text_model(
795
+ input_ids=input_ids,
796
+ attention_mask=attention_mask,
797
+ position_ids=position_ids,
798
+ output_attentions=output_attentions,
799
+ output_hidden_states=output_hidden_states,
800
+ return_dict=return_dict,
801
+ )
802
+
803
+
804
+ class SiglipVisionTransformer(nn.Module):
805
+ def __init__(self, config: SiglipVisionConfig):
806
+ super().__init__()
807
+ self.config = config
808
+ embed_dim = config.hidden_size
809
+
810
+ self.embeddings = SiglipVisionEmbeddings(config)
811
+ self.encoder = SiglipEncoder(config)
812
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
813
+ self.head = SiglipMultiheadAttentionPoolingHead(config)
814
+
815
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
816
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
817
+ def forward(
818
+ self,
819
+ pixel_values,
820
+ output_attentions: Optional[bool] = None,
821
+ output_hidden_states: Optional[bool] = None,
822
+ return_dict: Optional[bool] = None,
823
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
824
+ r"""
825
+ Returns:
826
+
827
+ """
828
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
829
+ output_hidden_states = (
830
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
831
+ )
832
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
833
+
834
+ hidden_states = self.embeddings(pixel_values)
835
+
836
+ encoder_outputs = self.encoder(
837
+ inputs_embeds=hidden_states,
838
+ output_attentions=output_attentions,
839
+ output_hidden_states=output_hidden_states,
840
+ return_dict=return_dict,
841
+ )
842
+
843
+ last_hidden_state = encoder_outputs[0]
844
+ last_hidden_state = self.post_layernorm(last_hidden_state)
845
+
846
847
+
848
+ pooled_output = self.head(last_hidden_state)
849
+
850
852
+
853
+ if not return_dict:
854
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
855
+
856
+ return BaseModelOutputWithPooling(
857
+ last_hidden_state=last_hidden_state,
858
+ pooler_output=pooled_output,
859
+ hidden_states=encoder_outputs.hidden_states,
860
+ attentions=encoder_outputs.attentions,
861
+ )
862
+
863
+
864
+ class SiglipMultiheadAttentionPoolingHead(nn.Module):
865
+ """Multihead Attention Pooling."""
866
+
867
+ def __init__(self, config: SiglipVisionConfig):
868
+ super().__init__()
869
+
870
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
871
+ self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
872
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
873
+ self.mlp = SiglipMLP(config)
874
+
875
+ def forward(self, hidden_state):
876
+ batch_size = hidden_state.shape[0]
877
+ probe = self.probe.repeat(batch_size, 1, 1)
878
+
879
883
+
884
+ hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
885
+
886
+ residual = hidden_state
887
+ hidden_state = self.layernorm(hidden_state)
888
+ hidden_state = residual + self.mlp(hidden_state)
889
+
890
+ return hidden_state[:, 0]
891
+
892
+
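+ # Pooling sketch: the learned probe acts as a single query token that cross-attends over all
+ # patch tokens, so the head maps (batch, num_patches, hidden_size) -> (batch, hidden_size);
+ # `hidden_state[:, 0]` selects the probe position since the query length is 1.
+
+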
893
+ @add_start_docstrings(
894
+ """The vision model from SigLIP without any head or projection on top.""",
895
+ SIGLIP_START_DOCSTRING,
896
+ )
897
+ class SiglipVisionModel(SiglipPreTrainedModel):
898
+ config_class = SiglipVisionConfig
899
+ main_input_name = "pixel_values"
900
+
901
+ def __init__(self, config: SiglipVisionConfig):
902
+ super().__init__(config)
903
+
904
+ self.vision_model = SiglipVisionTransformer(config)
905
+
906
+ # Initialize weights and apply final processing
907
+ self.post_init()
908
+
909
+ def get_input_embeddings(self) -> nn.Module:
910
+ return self.vision_model.embeddings.patch_embedding
911
+
912
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
913
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
914
+ def forward(
915
+ self,
916
+ pixel_values,
917
+ output_attentions: Optional[bool] = None,
918
+ output_hidden_states: Optional[bool] = None,
919
+ return_dict: Optional[bool] = None,
920
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
921
+ r"""
922
+ Returns:
923
+
924
+ Examples:
925
+
926
+ ```python
927
+ >>> from PIL import Image
928
+ >>> import requests
929
+ >>> from transformers import AutoProcessor, SiglipVisionModel
930
+
931
+ >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
932
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
933
+
934
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
935
+ >>> image = Image.open(requests.get(url, stream=True).raw)
936
+
937
+ >>> inputs = processor(images=image, return_tensors="pt")
938
+
939
+ >>> outputs = model(**inputs)
940
+ >>> last_hidden_state = outputs.last_hidden_state
941
+ >>> pooled_output = outputs.pooler_output # pooled features (attention-pooled patch tokens)
942
+ ```"""
943
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
944
+
945
+ return self.vision_model(
946
+ pixel_values=pixel_values,
947
+ output_attentions=output_attentions,
948
+ output_hidden_states=output_hidden_states,
949
+ return_dict=return_dict,
950
+ )
951
+
952
+
953
+ @add_start_docstrings(SIGLIP_START_DOCSTRING)
954
+ class SiglipModel(SiglipPreTrainedModel):
955
+ config_class = SiglipConfig
956
+
957
+ def __init__(self, config: SiglipConfig):
958
+ super().__init__(config)
959
+
960
+ if not isinstance(config.text_config, SiglipTextConfig):
961
+ raise ValueError(
962
+ "config.text_config is expected to be of type SiglipTextConfig but is of type"
963
+ f" {type(config.text_config)}."
964
+ )
965
+
966
+ if not isinstance(config.vision_config, SiglipVisionConfig):
967
+ raise ValueError(
968
+ "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
969
+ f" {type(config.vision_config)}."
970
+ )
971
+
972
+ text_config = config.text_config
973
+ vision_config = config.vision_config
974
+
975
+ self.text_model = SiglipTextModel(text_config)
976
+ self.vision_model = SiglipVisionModel(vision_config)
977
+
978
+ self.temperature = nn.Parameter(
979
+ torch.randn(
980
+ 1,
981
+ )
982
+ )
983
+ self.bias = nn.Parameter(
984
+ torch.randn(
985
+ 1,
986
+ )
987
+ )
988
+
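+ # Note: the SigLIP paper reports initializing the log-temperature to log(10) and the bias to
+ # -10 so that optimization starts close to the mostly-negative pair prior; the random
+ # initialization above does not follow that scheme.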
989
+ # Initialize weights and apply final processing
990
+ self.post_init()
991
+
992
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
993
+ def get_text_features(
994
+ self,
995
+ input_ids: Optional[torch.Tensor] = None,
996
+ attention_mask: Optional[torch.Tensor] = None,
997
+ position_ids: Optional[torch.Tensor] = None,
998
+ output_attentions: Optional[bool] = None,
999
+ output_hidden_states: Optional[bool] = None,
1000
+ return_dict: Optional[bool] = None,
1001
+ ) -> torch.FloatTensor:
1002
+ r"""
1003
+ Returns:
1004
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
1005
+ applying the projection layer to the pooled output of [`SiglipTextModel`].
1006
+
1007
+ Examples:
1008
+
1009
+ ```python
1010
+ >>> from transformers import AutoTokenizer, SiglipModel
1011
+
1012
+ >>> model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
1013
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
1014
+
1015
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
1016
+ >>> text_features = model.get_text_features(**inputs)
1017
+ ```"""
1018
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1019
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1020
+ output_hidden_states = (
1021
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1022
+ )
1023
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1024
+
1025
+ text_outputs = self.text_model(
1026
+ input_ids=input_ids,
1027
+ attention_mask=attention_mask,
1028
+ position_ids=position_ids,
1029
+ output_attentions=output_attentions,
1030
+ output_hidden_states=output_hidden_states,
1031
+ return_dict=return_dict,
1032
+ )
1033
+
1034
+ pooled_output = text_outputs[1]
1035
+
1036
+ return pooled_output
1037
+
1038
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1039
+ def get_image_features(
1040
+ self,
1041
+ pixel_values: Optional[torch.FloatTensor] = None,
1042
+ output_attentions: Optional[bool] = None,
1043
+ output_hidden_states: Optional[bool] = None,
1044
+ return_dict: Optional[bool] = None,
1045
+ ) -> torch.FloatTensor:
1046
+ r"""
1047
+ Returns:
1048
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
1049
+ applying the projection layer to the pooled output of [`SiglipVisionModel`].
1050
+
1051
+ Examples:
1052
+
1053
+ ```python
1054
+ >>> from PIL import Image
1055
+ >>> import requests
1056
+ >>> from transformers import AutoProcessor, SiglipModel
1057
+
1058
+ >>> model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
1059
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1060
+
1061
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1062
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1063
+
1064
+ >>> inputs = processor(images=image, return_tensors="pt")
1065
+
1066
+ >>> image_features = model.get_image_features(**inputs)
1067
+ ```"""
1068
+ # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
1069
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1070
+ output_hidden_states = (
1071
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1072
+ )
1073
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1074
+
1075
+ vision_outputs = self.vision_model(
1076
+ pixel_values=pixel_values,
1077
+ output_attentions=output_attentions,
1078
+ output_hidden_states=output_hidden_states,
1079
+ return_dict=return_dict,
1080
+ )
1081
+
1082
+ pooled_output = vision_outputs[1]
1083
+
1084
+ return pooled_output
1085
+
1086
+ @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
1087
+ @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig)
1088
+ def forward(
1089
+ self,
1090
+ input_ids: Optional[torch.LongTensor] = None,
1091
+ pixel_values: Optional[torch.FloatTensor] = None,
1092
+ attention_mask: Optional[torch.Tensor] = None,
1093
+ position_ids: Optional[torch.LongTensor] = None,
1094
+ return_loss: Optional[bool] = None,
1095
+ output_attentions: Optional[bool] = None,
1096
+ output_hidden_states: Optional[bool] = None,
1097
+ return_dict: Optional[bool] = None,
1098
+ ) -> Union[Tuple, SiglipOutput]:
1099
+ r"""
1100
+ Returns:
1101
+
1102
+ Examples:
1103
+
1104
+ ```python
1105
+ >>> from PIL import Image
1106
+ >>> import requests
1107
+ >>> from transformers import AutoProcessor, SiglipModel
1108
+
1109
+ >>> model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
1110
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1111
+
1112
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1113
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1114
+
1115
+ >>> inputs = processor(
1116
+ ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
1117
+ ... )
1118
+
1119
+ >>> outputs = model(**inputs)
1120
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
1121
+ >>> probs = logits_per_image.sigmoid() # SigLIP is trained with a sigmoid loss, so a sigmoid gives per-pair match probabilities
1122
+ ```"""
1123
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1124
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1125
+ output_hidden_states = (
1126
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1127
+ )
1128
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1129
+
1130
+ vision_outputs = self.vision_model(
1131
+ pixel_values=pixel_values,
1132
+ output_attentions=output_attentions,
1133
+ output_hidden_states=output_hidden_states,
1134
+ return_dict=return_dict,
1135
+ )
1136
+
1137
+ text_outputs = self.text_model(
1138
+ input_ids=input_ids,
1139
+ attention_mask=attention_mask,
1140
+ position_ids=position_ids,
1141
+ output_attentions=output_attentions,
1142
+ output_hidden_states=output_hidden_states,
1143
+ return_dict=return_dict,
1144
+ )
1145
+
1146
+ image_embeds = vision_outputs[1]
1147
+ text_embeds = text_outputs[1]
1148
+
1149
+ # normalized features
1150
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1151
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1152
+
1153
1155
+
1156
+ # cosine similarity as logits
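+ # Unlike CLIP, these logits include a learned additive bias, and SigLIP pairs them with an
+ # element-wise sigmoid rather than a softmax over the batch.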
1157
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.temperature.exp() + self.bias
1158
+ logits_per_image = logits_per_text.t()
1159
1167
+
1168
+ loss = None
1169
+ if return_loss:
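+ # A pairwise sigmoid loss over `logits_per_text` (see the sketch near `siglip_loss` above)
+ # is the intended objective; it is left unimplemented in this revision.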
1170
+ raise NotImplementedError("SigLIP loss to be implemented")
1171
+
1172
+ if not return_dict:
1173
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1174
+ return ((loss,) + output) if loss is not None else output
1175
+
1176
+ return SiglipOutput(
1177
+ loss=loss,
1178
+ logits_per_image=logits_per_image,
1179
+ logits_per_text=logits_per_text,
1180
+ text_embeds=text_embeds,
1181
+ image_embeds=image_embeds,
1182
+ text_model_output=text_outputs,
1183
+ vision_model_output=vision_outputs,
1184
+ )