fix issues with erf and xavier init
modeling_siglip.py  CHANGED  (+15 -9)
@@ -95,7 +95,12 @@ def _trunc_normal_(tensor, mean, std, a, b):
 
     # Use inverse cdf transform for normal distribution to get truncated
     # standard normal
-    tensor.erfinv_()
+    if tensor.dtype == torch.bfloat16:
+        tensor = tensor.to(torch.float32)
+        tensor.erfinv_()
+        tensor = tensor.to(torch.bfloat16)
+    else:
+        tensor.erfinv_()
 
     # Transform to proper mean, std
     tensor.mul_(std * math.sqrt(2.0))
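Note on the hunk above: `erfinv_` has limited bfloat16 support on some PyTorch backends, which appears to be why the truncated-normal helper now upcasts bfloat16 tensors to float32 before applying the inverse error function and casts back afterwards. A minimal standalone sketch of the same guard (the helper name `erfinv_compat` is illustrative, not part of the patch):

import torch


def erfinv_compat(t: torch.Tensor) -> torch.Tensor:
    # erfinv_ is not reliably available for bfloat16 on every backend, so
    # compute in float32 and cast back; other float dtypes stay in place.
    if t.dtype == torch.bfloat16:
        return t.to(torch.float32).erfinv_().to(torch.bfloat16)
    return t.erfinv_()


x = torch.rand(4, dtype=torch.bfloat16) * 2 - 1  # erfinv is defined on (-1, 1)
y = erfinv_compat(x)

The sketch returns the result rather than relying purely on in-place mutation, since the float32 upcast necessarily allocates a new tensor.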
@@ -670,6 +675,7 @@ class SiglipPreTrainedModel(PreTrainedModel):
 
     def _init_weights(self, module):
         """Initialize the weights"""
+
         if isinstance(module, SiglipVisionEmbeddings):
             width = (
                 self.config.vision_config.hidden_size
@@ -680,22 +686,22 @@ class SiglipPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.Embedding):
             default_flax_embed_init(module.weight)
         elif isinstance(module, SiglipAttention):
-            nn.init.xavier_uniform_(module.q_proj.weight)
-            nn.init.xavier_uniform_(module.k_proj.weight)
-            nn.init.xavier_uniform_(module.v_proj.weight)
-            nn.init.xavier_uniform_(module.out_proj.weight)
+            nn.init.normal_(module.q_proj.weight)
+            nn.init.normal_(module.k_proj.weight)
+            nn.init.normal_(module.v_proj.weight)
+            nn.init.normal_(module.out_proj.weight)
             nn.init.zeros_(module.q_proj.bias)
             nn.init.zeros_(module.k_proj.bias)
             nn.init.zeros_(module.v_proj.bias)
             nn.init.zeros_(module.out_proj.bias)
         elif isinstance(module, SiglipMLP):
-            nn.init.xavier_uniform_(module.fc1.weight)
-            nn.init.xavier_uniform_(module.fc2.weight)
+            nn.init.normal_(module.fc1.weight)
+            nn.init.normal_(module.fc2.weight)
             nn.init.normal_(module.fc1.bias, std=1e-6)
             nn.init.normal_(module.fc2.bias, std=1e-6)
         elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
-            nn.init.xavier_uniform_(module.probe.data)
-            nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
+            nn.init.normal_(module.probe.data)
+            nn.init.normal_(module.attention.in_proj_weight.data)
             nn.init.zeros_(module.attention.in_proj_bias.data)
         elif isinstance(module, SiglipModel):
             logit_scale_init = torch.log(torch.tensor(1.0))
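For comparison, the last hunk replaces Xavier/Glorot uniform initialization with a plain standard-normal draw on the attention and MLP weights. A small sketch contrasting the two initializers on a throwaway layer (the 768-wide layer is illustrative, not taken from the patch):

import torch
from torch import nn

layer = nn.Linear(768, 768, bias=False)  # illustrative width, not from the patch

# Old scheme: Xavier/Glorot uniform, bounded by +/- sqrt(6 / (fan_in + fan_out)),
# i.e. a standard deviation of roughly sqrt(2 / (768 + 768)) ~= 0.036 here.
nn.init.xavier_uniform_(layer.weight)
print("xavier_uniform_ std:", layer.weight.std().item())

# New scheme: nn.init.normal_ defaults to mean=0.0, std=1.0.
nn.init.normal_(layer.weight)
print("normal_ std:", layer.weight.std().item())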