zifei9 committed
Commit 5915555
Parent: e67448d

remove check for cuda

Files changed (1):
  1. modeling_gemma.py (+1 -1)

modeling_gemma.py
@@ -507,7 +507,7 @@ class GemmaSdpaAttention(GemmaAttention):
 
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
+        if causal_mask is not None:
             query_states = query_states.contiguous()
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
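
The change drops the query_states.device.type == "cuda" condition, so the contiguous-copy workaround now applies on any device whenever a custom attention mask is passed. Below is a minimal, self-contained sketch of the patched behavior, not the model's actual forward pass: the tensor shapes and the boolean causal mask are illustrative assumptions, used only to feed torch.nn.functional.scaled_dot_product_attention with deliberately non-contiguous inputs.

import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 1, 4, 8, 16

# .transpose(1, 2) yields non-contiguous tensors, mimicking how q/k/v are
# reshaped from (batch, seq, heads, head_dim) inside the attention module.
query_states = torch.randn(batch, seq_len, heads, head_dim).transpose(1, 2)
key_states = torch.randn(batch, seq_len, heads, head_dim).transpose(1, 2)
value_states = torch.randn(batch, seq_len, heads, head_dim).transpose(1, 2)

# Illustrative boolean causal mask standing in for the model's causal_mask.
causal_mask = torch.ones(seq_len, seq_len).tril().bool()

# Patched behavior: make the inputs contiguous whenever a custom mask is
# present, on any device (the old code only did this on CUDA).
if causal_mask is not None:
    query_states = query_states.contiguous()
    key_states = key_states.contiguous()
    value_states = value_states.contiguous()

attn_output = F.scaled_dot_product_attention(
    query_states, key_states, value_states, attn_mask=causal_mask
)
print(attn_output.shape)  # torch.Size([1, 4, 8, 16])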