ohallstrom committed
Commit 2bf3643
1 Parent(s): cceda44
fix bug when num_kv > 1
Files changed: modeling_RW.py (+2 -2)
modeling_RW.py CHANGED
@@ -290,8 +290,8 @@ class Attention(nn.Module):
 
         if alibi is None:
             query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
-            key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
-            value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
+            key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
+            value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
 
             attn_output = F.scaled_dot_product_attention(
                 query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
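For context, a minimal sketch (not part of the commit) of why the reshape must use num_heads rather than num_kv before calling F.scaled_dot_product_attention. The tensor sizes below, and the assumption that key/value arrive already expanded to batch_size * num_heads rows upstream, are illustrative and not taken from modeling_RW.py.

import torch
import torch.nn.functional as F

# Illustrative sizes (assumptions, not values from the model config):
batch_size, num_heads, num_kv, seq_len, head_dim = 2, 8, 2, 16, 64

# The query always carries num_heads heads.
query = torch.randn(batch_size, num_heads, seq_len, head_dim)

# Assume key/value have already been expanded to one row per attention head
# upstream, i.e. they arrive flattened as (batch_size * num_heads, seq_len, head_dim).
key_flat = torch.randn(batch_size * num_heads, seq_len, head_dim)
value_flat = torch.randn(batch_size * num_heads, seq_len, head_dim)

# Buggy reshape: dim 1 becomes num_kv, which neither equals nor broadcasts
# against the query's num_heads when num_kv > 1, so SDPA rejects the shapes.
key_bad = key_flat.reshape(batch_size, num_kv, -1, head_dim)  # (2, 2, 64, 64)

# Fixed reshape from the commit: dim 1 is num_heads, matching the query.
key_good = key_flat.reshape(batch_size, num_heads, -1, head_dim)      # (2, 8, 16, 64)
value_good = value_flat.reshape(batch_size, num_heads, -1, head_dim)  # (2, 8, 16, 64)

attn_output = F.scaled_dot_product_attention(
    query, key_good, value_good, None, 0.0, is_causal=True
)
print(attn_output.shape)  # torch.Size([2, 8, 16, 64])

Depending on how keys and values are laid out upstream, a num_kv reshape can still broadcast by accident when num_kv == 1, which is consistent with the commit title: the mismatch only surfaces when num_kv > 1.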