Imagroune committed
Commit 13ca255
1 Parent(s): e7920db

Update modeling_feynmodel.py

Files changed (1)
  1. modeling_feynmodel.py +11 -1
modeling_feynmodel.py CHANGED
@@ -1458,7 +1458,17 @@ class FeynModelForCausalLM(Gemma2ForCausalLM):
         device = input_ids.device
         #print(f"22222222 +-+-+-+-+-+-+-+-+-+- sequence_length = input_ids.shape {sequence_length}")

-        dtype = self.lm_head.weight.dtype
+        # dtype = self.lm_head.weight.dtype
+        # Get the dtype of the lm_head weights
+        if hasattr(self.lm_head, 'weight'):
+            # Check whether weight is an attribute or a method
+            if isinstance(self.lm_head.weight, torch.Tensor):
+                dtype = self.lm_head.weight.dtype
+            elif callable(self.lm_head.weight):
+                dtype = self.lm_head.weight().dtype
+            else:
+                raise TypeError(f"Unexpected type for self.lm_head.weight: {type(self.lm_head.weight)}")
+
         min_dtype = torch.finfo(dtype).min

         attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
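The change handles an lm_head whose `weight` is exposed either as a `torch.Tensor` parameter or as a callable returning one, before taking `torch.finfo(dtype).min` for the mask fill value. Below is a minimal standalone sketch of that pattern, not part of the commit; the helper `resolve_head_dtype` and the `TiedHead` module are hypothetical names used only to exercise both branches.

    import torch
    import torch.nn as nn

    def resolve_head_dtype(lm_head: nn.Module) -> torch.dtype:
        # Mirrors the commit's logic: weight may be a Tensor attribute or a method.
        if hasattr(lm_head, 'weight'):
            if isinstance(lm_head.weight, torch.Tensor):
                return lm_head.weight.dtype
            elif callable(lm_head.weight):
                return lm_head.weight().dtype
            raise TypeError(f"Unexpected type for lm_head.weight: {type(lm_head.weight)}")
        raise AttributeError("lm_head has no 'weight' attribute")

    # Hypothetical head that exposes its weight as a method instead of a parameter.
    class TiedHead(nn.Module):
        def __init__(self, embedding: nn.Embedding):
            super().__init__()
            self._embedding = embedding

        def weight(self):
            return self._embedding.weight

        def forward(self, hidden):
            return hidden @ self._embedding.weight.t()

    plain = nn.Linear(16, 32, dtype=torch.float16)   # weight is a Tensor
    tied = TiedHead(nn.Embedding(32, 16))            # weight is a callable
    print(resolve_head_dtype(plain))                 # torch.float16
    print(resolve_head_dtype(tied))                  # torch.float32

    # Same downstream use as in the diff: the minimum representable value
    # of the resolved dtype serves as the masked-position fill value.
    min_dtype = torch.finfo(resolve_head_dtype(plain)).min

Note that, as committed, `dtype` stays unset when `lm_head` has no `weight` attribute at all, so the following `torch.finfo(dtype)` call would raise a NameError in that case; the sketch above raises an explicit AttributeError instead.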