Update modeling_feynmodel.py
Browse files- modeling_feynmodel.py +11 -1
modeling_feynmodel.py
CHANGED
@@ -1458,7 +1458,17 @@ class FeynModelForCausalLM(Gemma2ForCausalLM):
|
|
1458 |
device = input_ids.device
|
1459 |
#print(f"22222222 +-+-+-+-+-+-+-+-+-+- sequence_length = input_ids.shape {sequence_length}")
|
1460 |
|
1461 |
-
dtype = self.lm_head.weight.dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1462 |
min_dtype = torch.finfo(dtype).min
|
1463 |
|
1464 |
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
|
|
1458 |
device = input_ids.device
|
1459 |
#print(f"22222222 +-+-+-+-+-+-+-+-+-+- sequence_length = input_ids.shape {sequence_length}")
|
1460 |
|
1461 |
+
# dtype = self.lm_head.weight.dtype
|
1462 |
+
# Obtenir le dtype des poids de lm_head
|
1463 |
+
if hasattr(self.lm_head, 'weight'):
|
1464 |
+
# Vérifier si weight est un attribut ou une méthode
|
1465 |
+
if isinstance(self.lm_head.weight, torch.Tensor):
|
1466 |
+
dtype = self.lm_head.weight.dtype
|
1467 |
+
elif callable(self.lm_head.weight):
|
1468 |
+
dtype = self.lm_head.weight().dtype
|
1469 |
+
else:
|
1470 |
+
raise TypeError(f"Type inattendu pour self.lm_head.weight : {type(self.lm_head.weight)}")
|
1471 |
+
|
1472 |
min_dtype = torch.finfo(dtype).min
|
1473 |
|
1474 |
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|