Imagroune committed on
Commit
c20defd
1 Parent(s): 13ca255

Update modeling_feynmodel.py

Browse files
Files changed (1) hide show
  1. modeling_feynmodel.py +15 -1
modeling_feynmodel.py CHANGED
@@ -1469,7 +1469,21 @@ class FeynModelForCausalLM(Gemma2ForCausalLM):
1469
  else:
1470
  raise TypeError(f"Type inattendu pour self.lm_head.weight : {type(self.lm_head.weight)}")
1471
 
1472
- min_dtype = torch.finfo(dtype).min
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1473
 
1474
  attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1475
  attention_mask,
 
1469
  else:
1470
  raise TypeError(f"Type inattendu pour self.lm_head.weight : {type(self.lm_head.weight)}")
1471
 
1472
+ # min_dtype = torch.finfo(dtype).min
1473
+ # Obtenir le dtype des poids de lm_head
1474
+ if isinstance(self.lm_head, torch.ao.nn.quantized.dynamic.Linear):
1475
+ # Pour les modules quantifiés dynamiquement, utiliser _weight_bias()
1476
+ weight, bias = self.lm_head._weight_bias()
1477
+ dtype = weight.dtype
1478
+ else:
1479
+ dtype = self.lm_head.weight.dtype
1480
+
1481
+ # Vérifier si dtype est un type de données en virgule flottante
1482
+ if torch.is_floating_point(torch.empty(0, dtype=dtype)):
1483
+ min_dtype = torch.finfo(dtype).min
1484
+ else:
1485
+ min_dtype = torch.iinfo(dtype).min
1486
+
1487
 
1488
  attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1489
  attention_mask,