Fix the kv-cache dimensions
Browse filesHello!
I have noticed that the dimension of the kv-cache here is weird, and does not match the hugginface transformers modeling_bloom.py file.
Is the departure from the bloom dimension intended?
Judging from the copy-pasted comments, it looks like a bug - also, `_convert_to_rw_cache` & its `_convert_to_standard_cache` counterpart matches bloom dimensions.
- modelling_RW.py +1 -1
modelling_RW.py
CHANGED
@@ -271,7 +271,7 @@ class Attention(nn.Module):
|
|
271 |
# concatenate along seq_length dimension:
|
272 |
# - key: [batch_size * self.num_heads, head_dim, kv_length]
|
273 |
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
274 |
-
key_layer = torch.cat((past_key, key_layer), dim=
|
275 |
value_layer = torch.cat((past_value, value_layer), dim=1)
|
276 |
|
277 |
_, kv_length, _ = key_layer.shape
|
|
|
271 |
# concatenate along seq_length dimension:
|
272 |
# - key: [batch_size * self.num_heads, head_dim, kv_length]
|
273 |
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
274 |
+
key_layer = torch.cat((past_key, key_layer), dim=2)
|
275 |
value_layer = torch.cat((past_value, value_layer), dim=1)
|
276 |
|
277 |
_, kv_length, _ = key_layer.shape
|