justinpinkney committed
Commit e04297f
1 parent: 004fadc

fix cache logic

Files changed (1)
  1. modelling_RW.py +1 -1
modelling_RW.py:

@@ -72,7 +72,7 @@ class RotaryEmbedding(torch.nn.Module):
         dtype=torch.bfloat16,
         start_idx: int = 0,
     ) -> torch.Tensor:
-        if seq_len != self.seq_len_cached and self.start_idx != start_idx:
+        if seq_len != self.seq_len_cached or self.start_idx != start_idx:
             self.seq_len_cached = seq_len
             t = torch.arange(start_idx, start_idx+seq_len, device=device).type_as(self.inv_freq)
             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
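
For context, here is a minimal sketch of why the condition needs `or`. The class below (`TinyRotaryCache`, a hypothetical stand-in, not the full `RotaryEmbedding` from modelling_RW.py) caches the cos/sin tables the same way: with `and`, a call that changes only `start_idx` (or only `seq_len`) would skip the rebuild and reuse tables computed for the old offset, giving stale rotary embeddings; with `or`, the cache is refreshed whenever either value differs.

```python
# Minimal sketch illustrating the cache-invalidation fix (assumed names, not the
# original class). The tables must be rebuilt whenever *either* seq_len or
# start_idx changes.
import torch


class TinyRotaryCache(torch.nn.Module):
    def __init__(self, head_dim: int = 64, base: int = 10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.start_idx = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, seq_len: int, start_idx: int = 0, device="cpu"):
        # The fixed condition: rebuild if either the length or the offset changed.
        # With `and`, changing only start_idx would fall through to the stale cache.
        if seq_len != self.seq_len_cached or self.start_idx != start_idx:
            self.seq_len_cached = seq_len
            self.start_idx = start_idx
            t = torch.arange(start_idx, start_idx + seq_len, device=device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            self.cos_cached = emb.cos()
            self.sin_cached = emb.sin()
        return self.cos_cached, self.sin_cached


# Usage: the second call changes only start_idx; with the old `and` condition it
# would have returned the tables computed for start_idx=0.
rot = TinyRotaryCache()
cos_a, _ = rot(seq_len=8, start_idx=0)
cos_b, _ = rot(seq_len=8, start_idx=4)
assert not torch.equal(cos_a, cos_b)
```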