Update modeling_feynmodel.py
Browse files- modeling_feynmodel.py +15 -1
modeling_feynmodel.py
CHANGED
@@ -1469,7 +1469,21 @@ class FeynModelForCausalLM(Gemma2ForCausalLM):
|
|
1469 |
else:
|
1470 |
raise TypeError(f"Type inattendu pour self.lm_head.weight : {type(self.lm_head.weight)}")
|
1471 |
|
1472 |
-
min_dtype = torch.finfo(dtype).min
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1473 |
|
1474 |
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
1475 |
attention_mask,
|
|
|
1469 |
else:
|
1470 |
raise TypeError(f"Type inattendu pour self.lm_head.weight : {type(self.lm_head.weight)}")
|
1471 |
|
1472 |
+
# min_dtype = torch.finfo(dtype).min
|
1473 |
+
# Get the dtype of the lm_head weights
|
1474 |
+
if isinstance(self.lm_head, torch.ao.nn.quantized.dynamic.Linear):
|
1475 |
+
# For dynamically quantized modules, use _weight_bias()
|
1476 |
+
weight, bias = self.lm_head._weight_bias()
|
1477 |
+
dtype = weight.dtype
|
1478 |
+
else:
|
1479 |
+
dtype = self.lm_head.weight.dtype
|
1480 |
+
|
1481 |
+
# Check whether dtype is a floating-point data type
|
1482 |
+
if torch.is_floating_point(torch.empty(0, dtype=dtype)):
|
1483 |
+
min_dtype = torch.finfo(dtype).min
|
1484 |
+
else:
|
1485 |
+
min_dtype = torch.iinfo(dtype).min
|
1486 |
+
|
1487 |
|
1488 |
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
1489 |
attention_mask,
|