fix a bug when flash-attn is not installed
modeling_time_moe.py (+6 -2) CHANGED
@@ -16,10 +16,14 @@ from .ts_generation_mixin import TSGenerationMixin
 
 logger = logging.get_logger(__name__)
 
-if is_flash_attn_2_available():
+# if is_flash_attn_2_available():
+#     from flash_attn import flash_attn_func, flash_attn_varlen_func
+#     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-
+except:
+    pass
 
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
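For reference, a common variant of this guarded-import pattern (not part of this commit) records whether flash-attn imported successfully, so attention code can branch on a flag instead of re-trying the import. A minimal sketch; the flag name _flash_attn_available and the helper _use_flash_attention are illustrative assumptions, not names from this repository:

# Sketch of a guarded flash-attn import with an explicit availability flag.
# `_flash_attn_available` and `_use_flash_attention` are hypothetical names, not from this commit.
try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
    _flash_attn_available = True
except ImportError:
    # Catching ImportError (rather than a bare `except`) keeps unrelated errors visible.
    _flash_attn_available = False


def _use_flash_attention() -> bool:
    # Downstream attention modules can check this flag to decide whether to call flash_attn_func.
    return _flash_attn_available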