Flash attention

#34
by utensil - opened

I'm trying to figure out whether Falcon is using Flash Attention (it is, per its model card), but I found no related code in the repo, such as from flash_attn.flash_attention import FlashMHA, etc. Am I missing something?

You can find the model code using scaled_dot_product_attention, check here. It is expected to work in flash attention mode, but due to an issue we have to wait for PyTorch 2.1 to benefit from the flash or memory_efficient attention backends.
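For reference, here is a minimal, self-contained sketch (not taken from the Falcon modelling code) of how scaled_dot_product_attention can be restricted to the flash / memory-efficient kernels on PyTorch >= 2.0; the tensor shapes and sdp_kernel flags are illustrative assumptions.

```python
# Sketch only: dispatching F.scaled_dot_product_attention to the fused kernels.
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) -- illustrative sizes, not Falcon's
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# Allow only the flash and memory-efficient backends; the math fallback is disabled,
# so an error is raised if neither fused kernel can handle these inputs.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_mem_efficient=True, enable_math=False
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([1, 8, 128, 64])
```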

Thank you for the reply and the pointers, and the great work in general!

As for the xFormers attention mentioned in the issue, my test shows that Falcon can already work with it and saves ~15% VRAM (the exact number may vary in different settings).
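Roughly, the swap in my test looks like the sketch below: the attention matmul+softmax is replaced with xFormers' memory-efficient kernel. The shapes are assumptions and this is not the exact patch applied to the Falcon modelling code.

```python
# Sketch only: causal attention via xFormers' memory-efficient kernel.
import torch
import xformers.ops as xops

# xFormers expects (batch, seq_len, heads, head_dim) -- illustrative sizes
q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)

# Computes causal attention without materialising the full (seq x seq) score matrix,
# which is where the VRAM saving comes from.
out = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())
print(out.shape)  # torch.Size([1, 128, 8, 64])
```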

May I also assume that with PyTorch 2.1, Falcon will work with BetterTransformer (which, to my knowledge, includes flash attention)? Link: https://huggingface.co/docs/optimum/bettertransformer/overview
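If that turns out to be the case, the usage would presumably look like the sketch below. This is only an illustration of the Optimum BetterTransformer API; whether Falcon is actually supported depends on the optimum/transformers versions, and the checkpoint name is just the public tiiuae/falcon-7b model used as an example.

```python
# Sketch only: converting a model to BetterTransformer (SDPA-backed) modules.
import torch
from transformers import AutoModelForCausalLM
from optimum.bettertransformer import BetterTransformer

model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # Falcon shipped its modelling code in the repo at the time
)

# Swap supported modules for their BetterTransformer equivalents.
model = BetterTransformer.transform(model)
```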
