pranjalchitale commited on
Commit
65a5cec
1 Parent(s): 2b5678b

Update modeling_indictrans.py

Browse files
Files changed (1) hide show
  1. modeling_indictrans.py +7 -4
modeling_indictrans.py CHANGED
@@ -54,10 +54,13 @@ logger = logging.get_logger(__name__)
54
 
55
  INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
56
 
57
- if is_flash_attn_2_available():
58
- from flash_attn import flash_attn_func, flash_attn_varlen_func
59
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
60
-
 
 
 
61
 
62
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
63
  def _get_unpad_data(attention_mask):
 
54
 
55
  INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
56
 
57
+ try:
58
+ if is_flash_attn_2_available():
59
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
60
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
61
+ except:
62
+ pass
63
+
64
 
65
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
66
  def _get_unpad_data(attention_mask):