Maple728 committed
Commit bdc2191
Parent: 7e84fb4

fix a bug when flash-attn is not installed

Files changed (1): modeling_time_moe.py (+6, -2)
modeling_time_moe.py CHANGED
@@ -16,10 +16,14 @@ from .ts_generation_mixin import TSGenerationMixin
 
 logger = logging.get_logger(__name__)
 
-if is_flash_attn_2_available():
+# if is_flash_attn_2_available():
+#     from flash_attn import flash_attn_func, flash_attn_varlen_func
+#     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-
+except:
+    pass
 
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
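
The change swaps the `is_flash_attn_2_available()` guard for a try/except around the optional import, so `modeling_time_moe.py` still loads when the flash-attn package is missing or fails to import. Below is a minimal sketch of the same optional-dependency pattern; the `_flash_attn_available` flag and `attention_forward` function are illustrative names only, not part of this repository, and `except ImportError` is a narrower variant of the bare `except:` used in the commit.

# Minimal sketch (not from this repo) of the optional-import pattern the commit adopts.
import torch

try:
    # Succeeds only when the flash-attn package is installed and importable.
    from flash_attn import flash_attn_func
    _flash_attn_available = True       # hypothetical flag, for illustration
except ImportError:                    # narrower than the commit's bare `except:`
    flash_attn_func = None
    _flash_attn_available = False

def attention_forward(q, k, v):
    # q, k, v: (batch, seqlen, nheads, headdim); flash-attn needs fp16/bf16 on CUDA.
    if _flash_attn_available:
        return flash_attn_func(q, k, v, causal=True)
    # Fallback: PyTorch's built-in SDPA, which expects (batch, nheads, seqlen, headdim).
    out = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
    )
    return out.transpose(1, 2)

Recording availability in a module-level flag (rather than re-attempting the import at call time) keeps the import cost one-time and lets downstream attention classes branch on a single boolean.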