Matt
commited on
Commit
•
e8c1eff
1
Parent(s):
27cdeb1
Correctly mark z as a buffer
Browse files- modeling_hyena.py +2 -2
modeling_hyena.py
CHANGED
@@ -62,8 +62,8 @@ class HyenaPositionalEmbedding(nn.Module):
|
|
62 |
f = torch.linspace(1e-4, bands - 1, bands)[None, None]
|
63 |
|
64 |
z = torch.cat([t, torch.cos(-f * w), torch.sin(-f * w)], dim=-1)
|
65 |
-
|
66 |
-
self.
|
67 |
self.register_buffer("t", t)
|
68 |
|
69 |
def forward(self, L):
|
|
|
62 |
f = torch.linspace(1e-4, bands - 1, bands)[None, None]
|
63 |
|
64 |
z = torch.cat([t, torch.cos(-f * w), torch.sin(-f * w)], dim=-1)
|
65 |
+
|
66 |
+
self.register_buffer("z", z)
|
67 |
self.register_buffer("t", t)
|
68 |
|
69 |
def forward(self, L):
|