justinpinkney committed
Commit: e04297f • Parent(s): 004fadc
fix cache logic
Files changed: modelling_RW.py (+1 -1)
modelling_RW.py CHANGED
@@ -72,7 +72,7 @@ class RotaryEmbedding(torch.nn.Module):
         dtype=torch.bfloat16,
         start_idx: int = 0,
     ) -> torch.Tensor:
-        if seq_len != self.seq_len_cached:
+        if seq_len != self.seq_len_cached or self.start_idx != start_idx:
             self.seq_len_cached = seq_len
             t = torch.arange(start_idx, start_idx+seq_len, device=device).type_as(self.inv_freq)
             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
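The old condition invalidated the cached rotary tables only when seq_len changed, so a call with the same seq_len but a different start_idx silently reused cos/sin values computed for the wrong positions; the new condition also compares start_idx against the value the cache was built with. The snippet below is a minimal, hypothetical sketch of that caching pattern, not the actual modelling_RW.py code: the class name, constructor defaults, and the point where self.start_idx is stored are assumptions made for illustration.

```python
import torch


class RotarySketch(torch.nn.Module):
    """Hypothetical illustration of the cache-invalidation fix."""

    def __init__(self, head_dim: int = 64, base: int = 10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.start_idx = None  # assumed: remembered alongside the cached tables
        self.cos_cached = None
        self.sin_cached = None

    def cos_sin(self, seq_len: int, device="cpu", start_idx: int = 0):
        # Old logic checked only `seq_len != self.seq_len_cached`, so an
        # unchanged seq_len with a new start_idx kept stale tables.
        if seq_len != self.seq_len_cached or self.start_idx != start_idx:
            self.seq_len_cached = seq_len
            self.start_idx = start_idx
            t = torch.arange(start_idx, start_idx + seq_len, device=device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            self.cos_cached = emb.cos()
            self.sin_cached = emb.sin()
        return self.cos_cached, self.sin_cached
```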