zifei9 committed
Commit e67448d
1 Parent(s): 8745eab

fixing import for modeling_gemma.py

Files changed (1)
  1. modeling_gemma.py +10 -31
modeling_gemma.py CHANGED
@@ -27,19 +27,19 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
-from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache, StaticCache
-from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import _flash_attention_forward
-from ...modeling_outputs import (
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_flash_attention_utils import _flash_attention_forward
+from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import ALL_LAYERNORM_LAYERS
-from ...utils import (
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_greater_or_equal_2_10,
@@ -47,7 +47,7 @@ from ...utils import (
     replace_return_docstrings,
 )
 from .configuration_gemma import GemmaConfig
-
+from transformers.models.gemma.modeling_gemma import GemmaRMSNorm
 
 logger = logging.get_logger(__name__)
 
@@ -105,27 +105,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
 
     return causal_mask
 
-
-class GemmaRMSNorm(nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.zeros(dim))
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        output = self._norm(x.float())
-        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
-        # See https://github.com/huggingface/transformers/pull/29402
-        output = output * (1.0 + self.weight.float())
-        return output.type_as(x)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.eps}"
-
-
 ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
 
 
@@ -528,7 +507,7 @@ class GemmaSdpaAttention(GemmaAttention):
 
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if causal_mask is not None:
+        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()
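The import changes in the first two hunks are what the commit title refers to: package-relative imports such as "from ...activations import ACT2FN" only resolve when the file is loaded as transformers.models.gemma.modeling_gemma, whereas a standalone copy of modeling_gemma.py has no parent package and fails with an ImportError (attempted relative import with no known parent package). A minimal sketch of the check, assuming the edited file sits in the current directory and transformers is installed:

# Load the edited file as a plain module, outside the transformers package tree,
# roughly the way remote-code loading does.
import importlib.util

spec = importlib.util.spec_from_file_location("modeling_gemma", "modeling_gemma.py")
modeling_gemma = importlib.util.module_from_spec(spec)
spec.loader.exec_module(modeling_gemma)  # raised ImportError with the old "..." relative imports

print(modeling_gemma.GemmaRMSNorm)  # re-exported via the new import in the second hunk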
 
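The removed GemmaRMSNorm definition matches the class now imported from transformers, so the import should be a drop-in replacement: normalize in float32, scale by (1 + weight), and cast back to the input dtype last (the Gemma-vs-Llama ordering the deleted comment points to). A small equivalence sketch, assuming a transformers version that ships GemmaRMSNorm at the imported path; the dim=8 toy size is arbitrary:

import torch
from transformers.models.gemma.modeling_gemma import GemmaRMSNorm

norm = GemmaRMSNorm(dim=8, eps=1e-6)
x = torch.randn(2, 4, 8, dtype=torch.float16)

# Reference computation mirroring the deleted local class.
ref = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)
ref = (ref * (1.0 + norm.weight.float())).to(x.dtype)
assert torch.allclose(norm(x), ref)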
 
 
 
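The last hunk narrows the contiguity workaround: scaled_dot_product_attention with the memory-efficient backend (torch==2.1.2) is bugged with non-contiguous query/key/value tensors when a custom attn_mask is supplied (pytorch/pytorch#112577), and with the new condition the extra copies are only made on CUDA devices, so CPU runs skip them. A sketch of the guard pattern; sdpa_with_workaround and the tensor shapes are illustrative, not from the file:

import torch
import torch.nn.functional as F

def sdpa_with_workaround(query, key, value, attn_mask=None):
    # Mirror the guarded copy from the diff: only force contiguous memory on
    # CUDA and only when a custom mask is passed.
    if query.device.type == "cuda" and attn_mask is not None:
        query, key, value = query.contiguous(), key.contiguous(), value.contiguous()
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)

q = k = v = torch.randn(1, 8, 16, 64)                      # (batch, heads, seq_len, head_dim)
causal = torch.tril(torch.ones(16, 16, dtype=torch.bool))  # True = attend
out = sdpa_with_workaround(q, k, v, attn_mask=causal)
print(out.shape)  # torch.Size([1, 8, 16, 64])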