Shan1990 committed on
Commit
9d3d7be
1 Parent(s): 06c7c87

Fix RMSNorm weight init bug.


Use torch.ones to initialize the RMSNorm weight. torch.empty returns an uninitialized tensor, so the weight can start with arbitrary values that may fall outside the valid float range.
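A quick illustration of the hazard (a minimal sketch, not part of the commit; the shape 4096 is only an example):

import torch

# torch.empty returns uninitialized memory, so the bytes reinterpreted as
# floats can be arbitrarily large, inf, or NaN -- nothing guarantees a sane
# starting value for the parameter.
w_empty = torch.empty(4096, dtype=torch.float16)
print(torch.isfinite(w_empty).all())  # may be False

# torch.ones gives the identity scale RMSNorm expects before trained weights
# are loaded.
w_ones = torch.ones(4096, dtype=torch.float16)
print(torch.isfinite(w_ones).all())   # always True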

Files changed (1)
  1. modeling_chatglm.py +1 -1
modeling_chatglm.py CHANGED
@@ -181,7 +181,7 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
 class RMSNorm(torch.nn.Module):
     def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
         super().__init__()
-        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
+        self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype))
         self.eps = eps
 
     def forward(self, hidden_states: torch.Tensor):
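For context, a minimal sketch of the module with the fixed initialization. The forward body below is a standard RMSNorm formulation and is an assumption on my part, since the diff only shows __init__; the actual body in modeling_chatglm.py may differ.

import torch

class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        # torch.ones: the layer starts as an identity scale, so a freshly built
        # model produces finite outputs even before a checkpoint is loaded.
        self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        # Normalize by the root mean square over the last dimension, then apply
        # the learned per-channel scale.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return (self.weight * hidden_states).to(self.weight.dtype)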