remove check for cuda
modeling_gemma.py CHANGED (+1 -1)
@@ -507,7 +507,7 @@ class GemmaSdpaAttention(GemmaAttention):
 
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
+        if causal_mask is not None:
             query_states = query_states.contiguous()
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
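
For context, a minimal runnable sketch of the behavior this change touches (toy shapes and a toy causal_mask, not the model's actual forward pass): after this commit the .contiguous() workaround is applied whenever a custom attention mask is passed, on any device, instead of only when the tensors live on CUDA.

import torch
import torch.nn.functional as F

# Toy shapes; in GemmaSdpaAttention these come from the projected hidden states.
batch, num_heads, q_len, head_dim = 1, 8, 16, 64

# Simulate non-contiguous tensors, e.g. as produced by a transpose after the q/k/v projections.
query_states = torch.randn(batch, q_len, num_heads, head_dim).transpose(1, 2)
key_states = torch.randn(batch, q_len, num_heads, head_dim).transpose(1, 2)
value_states = torch.randn(batch, q_len, num_heads, head_dim).transpose(1, 2)

# A toy additive causal mask (stands in for the `causal_mask` in the diff).
causal_mask = torch.full((batch, 1, q_len, q_len), float("-inf")).triu(diagonal=1)

# The patched guard: make inputs contiguous whenever a custom mask is used,
# to avoid the torch==2.1.2 memory-efficient SDPA bug referenced in the comment.
if causal_mask is not None:
    query_states = query_states.contiguous()
    key_states = key_states.contiguous()
    value_states = value_states.contiguous()

attn_output = F.scaled_dot_product_attention(
    query_states, key_states, value_states, attn_mask=causal_mask
)
print(attn_output.shape)  # torch.Size([1, 8, 16, 64])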