fixing import for modeling_gemma.py
modeling_gemma.py  CHANGED  (+10 -31)
@@ -27,19 +27,19 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
-from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache, StaticCache
-from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import _flash_attention_forward
-from ...modeling_outputs import (
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_flash_attention_utils import _flash_attention_forward
+from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import ALL_LAYERNORM_LAYERS
-from ...utils import (
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_greater_or_equal_2_10,
@@ -47,7 +47,7 @@ from ...utils import (
     replace_return_docstrings,
 )
 from .configuration_gemma import GemmaConfig
-
+from transformers.models.gemma.modeling_gemma import GemmaRMSNorm
 
 logger = logging.get_logger(__name__)
 
@@ -105,27 +105,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
 
     return causal_mask
 
-
-class GemmaRMSNorm(nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.zeros(dim))
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        output = self._norm(x.float())
-        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
-        # See https://github.com/huggingface/transformers/pull/29402
-        output = output * (1.0 + self.weight.float())
-        return output.type_as(x)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.eps}"
-
-
 ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
 
 
@@ -528,7 +507,7 @@ class GemmaSdpaAttention(GemmaAttention):
 
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if causal_mask is not None:
+        if query_states.device.type == "cuda" and causal_mask is not None:
             query_states = query_states.contiguous()
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
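The relative "from ..." imports only resolve while modeling_gemma.py sits inside the transformers package itself; a standalone copy of the file (for example, one loaded as remote code from a model repository) needs the absolute transformers.* paths, and the duplicated GemmaRMSNorm definition can simply be imported from the installed library. Below is a minimal, illustrative sketch of exercising the rewritten imports, assuming a transformers release that ships Gemma (4.38 or newer) and a working torch install:

import torch

# The absolute imports introduced above resolve outside the transformers
# source tree, unlike the original relative "from ...module" form.
from transformers.activations import ACT2FN
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.models.gemma.modeling_gemma import GemmaRMSNorm

# GemmaRMSNorm stands in for the local class definition removed above:
# weight starts at zeros and is applied as (1 + weight) after normalization.
norm = GemmaRMSNorm(16, eps=1e-6)
x = torch.randn(2, 4, 16)
print(norm(x).shape)                          # torch.Size([2, 4, 16])
print(norm)                                   # e.g. GemmaRMSNorm((16,), eps=1e-06)
print(callable(ACT2FN["gelu_pytorch_tanh"]))  # Gemma's hidden activation is available

# ALL_LAYERNORM_LAYERS is the registry this file appends GemmaRMSNorm to;
# recent transformers versions may already have registered it upstream.
print(any(cls is GemmaRMSNorm for cls in ALL_LAYERNORM_LAYERS))

The last hunk matches the guard used by the upstream GemmaSdpaAttention: the .contiguous() workaround for the torch 2.1.2 SDPA bug (pytorch/pytorch#112577) is applied only to CUDA tensors with an explicit mask, so CPU runs skip the extra copies.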