Upload model_atom.py with huggingface_hub
model_atom.py CHANGED (+2 -2)
@@ -160,7 +160,7 @@ class LlamaRotaryEmbedding(nn.Module):
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.max_seq_len_cached = seq_len
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-
+        self.inv_freq = self.inv_freq.to(device)
         freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
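The added line moves the inverse-frequency vector onto the same device as the position index t before torch.outer is called. Since the register_buffer call for inv_freq is commented out in this file (visible in the second hunk), inv_freq does not migrate automatically when the module is moved to GPU, and torch.outer would otherwise fail on a CPU/CUDA device mismatch. Below is a minimal standalone sketch of the patched cache computation; the dim, base, and seq_len values are illustrative assumptions, not values from model_atom.py:

import torch

dim, base, seq_len = 128, 10000.0, 16
device = "cuda" if torch.cuda.is_available() else "cpu"

# inv_freq is typically built on CPU at construction time
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
inv_freq = inv_freq.to(device)           # the patched step: align devices first
freqs = torch.outer(t, inv_freq)         # shape (seq_len, dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)  # shape (seq_len, dim)
cos_cached, sin_cached = emb.cos(), emb.sin()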
@@ -211,7 +211,7 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
             base = self.base * (
                 (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
             ) ** (self.dim / (self.dim - 2))
-            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+            self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
             # self.register_buffer("inv_freq", inv_freq, persistent=False)
 
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
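The second hunk fixes a stale-cache bug in the dynamic NTK variant: with register_buffer commented out, the rescaled frequencies were previously written to a local inv_freq and then discarded, so the torch.outer(t, self.inv_freq) call further down kept using the old frequencies; assigning to self.inv_freq makes the recomputation take effect. The rescaling follows the dynamic NTK rule base' = base * ((s * seq_len / max_position_embeddings) - (s - 1)) ** (dim / (dim - 2)), where s is the scaling factor. Here is a hedged sketch of that recomputation as a free function; the seq_len > max_position_embeddings guard is assumed from the upstream transformers implementation, and the parameter defaults are illustrative:

import torch

def ntk_inv_freq(seq_len, dim=128, base=10000.0,
                 max_position_embeddings=4096, scaling_factor=2.0,
                 device="cpu"):
    # Stretch the rotary base only once the context exceeds the trained range
    if seq_len > max_position_embeddings:
        base = base * (
            (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
        ) ** (dim / (dim - 2))
    return 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))

print(ntk_inv_freq(2048)[:3])  # in-range length: base unchanged
print(ntk_inv_freq(8192)[:3])  # beyond the trained range: frequencies shrink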