CyberZHG committed
Commit 17f5623
1 Parent(s): 561820f

Use input attention mask instead of causal mask in attention


The current implementation passes is_causal=True to F.scaled_dot_product_attention and ignores the input attention mask, so it does not work with left/leading padding: queries can still attend to the leading pad positions. This commit builds attention_mask_float from the input attention mask (in the query's dtype instead of hard-coded torch.bfloat16) and passes it to scaled_dot_product_attention with is_causal=False.
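For context, a minimal sketch of the failure mode and of the mask construction this patch switches to. The shapes, tensor names, and the hand-built causal mask below are illustrative assumptions, not code from modelling_RW.py; in the model itself the boolean attention_mask handed to Attention already combines padding and causality upstream.

import torch
import torch.nn.functional as F

batch_size, num_heads, seq_len, head_dim = 2, 4, 5, 16
q = torch.randn(batch_size, num_heads, seq_len, head_dim)
k = torch.randn(batch_size, num_heads, seq_len, head_dim)
v = torch.randn(batch_size, num_heads, seq_len, head_dim)

# Left-padded batch: True marks positions that must never be attended to.
# is_causal=True only applies a lower-triangular mask, so these leading
# pad positions would still receive attention.
pad = torch.tensor([[True, True, False, False, False],
                    [False, False, False, False, False]])

# Fold padding and causality into one boolean mask (True = masked),
# shape [batch_size, 1, seq_len, seq_len].
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
bool_mask = causal[None, None] | pad[:, None, None, :]

# Same construction as the patch: 0.0 where attention is allowed, -1e9
# where it is not, in the query's dtype rather than hard-coded bfloat16.
attention_mask_float = (bool_mask * 1.0).masked_fill(bool_mask, -1e9).to(q.dtype)

# is_causal=False, because causality is already baked into the float mask.
out = F.scaled_dot_product_attention(q, k, v, attention_mask_float, 0.0, is_causal=False)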

Files changed (1)
  1. modelling_RW.py +2 -2
modelling_RW.py CHANGED
@@ -281,13 +281,14 @@ class Attention(nn.Module):
         else:
             present = None
 
+        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(query_layer.dtype)
         if alibi is None:
             query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
             key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
             value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
 
             attn_output = F.scaled_dot_product_attention(
-                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
+                query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
             )
 
             x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
@@ -300,7 +301,6 @@ class Attention(nn.Module):
             assert not output_attentions  # not supported.
             return outputs
         else:
-            attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
             matmul_result = query_layer @ key_layer.transpose(-1, -2)
 
             # change view to [batch_size, num_heads, q_length, kv_length]