Support for gradient checkpointing and Flash Attention
Any plans to support gradient checkpointing and flash attention for training/finetuning? It would be very helpful for getting this working with fewer resources.
I think FlashAttention is already used under certain conditions: the attention implementation calls PyTorch's scaled_dot_product_attention
function, which dispatches to a FlashAttention kernel when its requirements are met. For debugging, you can force the use of this kernel with the appropriate context manager:
import torch

# Restrict SDPA to the FlashAttention backend only; if the flash kernel
# cannot be used here, PyTorch raises an error instead of silently
# falling back to the math or memory-efficient implementations.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False,
):
    model.generate(**input)
Note that the referenced code contains a branch which may execute a naive attention implementation, so even if you enforce FlashAttention on the PyTorch side, you should still make sure the if-statement takes the first branch.
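For illustration, that kind of branch typically looks something like the sketch below. This is a hypothetical example, not the repository's actual code; the function name, the use_sdpa flag, and the condition are placeholders.

import math
import torch
import torch.nn.functional as F

def attention(q, k, v, use_sdpa=True):
    # Hypothetical branch: only the first path can dispatch to FlashAttention.
    if use_sdpa:
        # Fused kernel path: PyTorch SDPA, which may use the FlashAttention kernel.
        return F.scaled_dot_product_attention(q, k, v)
    else:
        # Naive path: materializes the full attention score matrix in memory.
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
        return torch.softmax(scores, dim=-1) @ v

If the code ends up in the second (naive) branch, the context manager above has no effect, which is why checking the branch matters.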
If you upgrade to PyTorch 2.2.0, you should be able to use the Flash Attention 2.0 kernels integrated into PyTorch directly. Follow the instructions in the previous comment; there's no need to pull a separate branch.
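As a quick sanity check after upgrading, you can call scaled_dot_product_attention directly with only the flash backend enabled. This is a minimal sketch assuming a CUDA device and half-precision inputs (the flash kernel requires fp16/bf16); with the other backends disabled, it either runs the FlashAttention kernel or raises a RuntimeError instead of silently falling back.

import torch
import torch.nn.functional as F

# Dummy (batch, heads, seq_len, head_dim) tensors in fp16 on the GPU.
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# Only the flash backend is allowed, so a failure surfaces as an error.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)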