jupyterjazz commited on
Commit
e8e1e15
1 Parent(s): c232c27

refactor: raise error if flash attention is not installed

Browse files
Files changed (1) hide show
  1. rotary.py +5 -1
rotary.py CHANGED
@@ -6,7 +6,11 @@ from typing import Optional, Tuple, Union
6
 
7
  import torch
8
  from einops import rearrange, repeat
9
- from flash_attn.ops.triton.rotary import apply_rotary
 
 
 
 
10
 
11
 
12
  def rotate_half(x, interleaved=False):
 
6
 
7
  import torch
8
  from einops import rearrange, repeat
9
+ try:
10
+ from flash_attn.ops.triton.rotary import apply_rotary
11
+ except ImportError:
12
+ def apply_rotary(*args, **kwargs):
13
+ raise RuntimeError('RoPE requires flash-attention to be installed')
14
 
15
 
16
  def rotate_half(x, interleaved=False):