Text Generation
Transformers
PyTorch
RefinedWeb
falcon-40b
rlhf
falcon
custom_code
text-generation-inference
Inference Endpoints
ohallstrom commited on
Commit
2bf3643
1 Parent(s): cceda44

fix bug when num_kv > 1

Browse files
Files changed (1) hide show
  1. modeling_RW.py +2 -2
modeling_RW.py CHANGED
@@ -290,8 +290,8 @@ class Attention(nn.Module):
290
 
291
  if alibi is None:
292
  query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
293
- key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
294
- value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
295
 
296
  attn_output = F.scaled_dot_product_attention(
297
  query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
 
290
 
291
  if alibi is None:
292
  query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
293
+ key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
294
+ value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
295
 
296
  attn_output = F.scaled_dot_product_attention(
297
  query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True