vilarin committed
Commit 77e2827
1 Parent(s): 393567b

Update app.py

Files changed (1)
  1. app.py +8 -1
app.py CHANGED
@@ -1,3 +1,9 @@
+import subprocess
+subprocess.run(
+    'pip install flash-attn --no-build-isolation',
+    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
+    shell=True
+)
 import os
 import time
 import spaces
@@ -36,7 +42,8 @@ h3 {
 
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_ID,
-    torch_dtype=torch.float16,
+    torch_dtype=torch.float16,
+    attn_implementation="flash_attention_2",
     trust_remote_code=True).cuda()
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 
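
Taken together, the two hunks install flash-attn when app.py starts (with FLASH_ATTENTION_SKIP_CUDA_BUILD set so the CUDA kernels are not compiled at install time) and then load the model with FlashAttention-2 in half precision. Below is a minimal sketch of the resulting load path under those assumptions; the real MODEL_ID and the rest of app.py are not shown in this diff, so "org/model" is only a placeholder.

# Sketch of app.py's load path after this commit (placeholder MODEL_ID).
import subprocess
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},  # skip compiling CUDA kernels during install
    shell=True
)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "org/model"  # placeholder; the actual value is defined elsewhere in app.py

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,                # fp16 weights; FlashAttention-2 requires fp16/bf16
    attn_implementation="flash_attention_2",  # uses the flash-attn package installed above
    trust_remote_code=True).cuda()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)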