# Triton-InternVL2-2B / triton_flash_atn.py
# Hosting-page residue kept for provenance: "radna's picture",
# "Upload 20 files", "5c0cb68 verified".
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team
Extra Credits:
- Original flash attention paper (https://arxiv.org/abs/2205.14135)
- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)
"""
import pytest
import torch
import triton
import triton.language as tl
# Pick the fp8 data type
# AMD E4M3B8
# Note: When picking this f8 data type, scaling is required when using f8
# for the second gemm
# TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz')
# AMD E5M2B16
# True when the installed torch build exposes the `float8_e5m2fnuz` dtype.
# NOTE(review): nothing in the visible part of this file reads this flag.
TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz')
@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q,
                    K_block_ptr, V_block_ptr,
                    start_m,
                    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
                    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
                    N_CTX,
                    pre_load_v: tl.constexpr):
    # Inner loop of the forward pass for one block of BLOCK_M query rows.
    #
    # Walks over key/value blocks of width BLOCK_N and updates, per query row:
    #   acc - the un-normalised output accumulator,
    #   m_i - the running row maximum of the (already scaled) scores,
    #   l_i - the running sum of exp2(score - m_i).
    # Returns the updated (acc, l_i, m_i).
    #
    # STAGE selects the key range:
    #   1    -> keys [0, start_m * BLOCK_M): everything left of the diagonal block,
    #   2    -> keys [start_m * BLOCK_M, (start_m + 1) * BLOCK_M): the diagonal
    #           block, where the mask `query index >= key index` is applied,
    #   else -> keys [0, N_CTX): the whole sequence, no masking.
    # pre_load_v chooses whether V is loaded before or after the q.k product.
    # NOTE(review): K/V are loaded without boundary checks, so the key range is
    # presumably expected to be a whole number of BLOCK_N tiles - confirm.
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
        # The block pointers start at key 0; move them to the diagonal block.
        K_block_ptr = tl.advance(K_block_ptr, (0, lo))
        V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
    # causal = False
    else:
        lo, hi = 0, N_CTX
    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(K_block_ptr)
        if pre_load_v:
            v = tl.load(V_block_ptr)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        if STAGE == 2:
            # Masked positions start at -inf, so they stay -inf after the
            # dot product is added and contribute exp2(-inf) = 0 below.
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = tl.where(mask, qk, float("-inf"))
        qk += tl.dot(q, k)
        # New running maximum per row, then shift the scores by it.
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk = qk - m_ij[:, None]
        p = tl.math.exp2(qk)
        # -- update output accumulator --
        # alpha rescales everything accumulated under the previous maximum.
        alpha = tl.math.exp2(m_i - m_ij)
        acc = acc * alpha[:, None]
        if not pre_load_v:
            v = tl.load(V_block_ptr)
        acc += tl.dot(p.to(v.dtype), v)
        # -- update m_i and l_i
        l_ij = tl.sum(p, 1)
        l_i = l_i * alpha + l_ij
        # update m_i and l_i
        m_i = m_ij
        # Step both block pointers to the next BLOCK_N keys.
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
    return acc, l_i, m_i
# Autotuning is enabled for the forward kernel: Triton picks one of the
# configurations below per distinct value of the `key` arguments. BLOCK_M,
# BLOCK_N and pre_load_v are therefore supplied by the chosen config rather
# than by the caller.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16, 'waves_per_eu': 2,
                       'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=2),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16, 'waves_per_eu': 2,
                       'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2,
                       'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=1),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2,
                       'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=1),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'waves_per_eu': 2,
                       'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 3,
                       'slice_k_tile': 0, 'pre_load_v': True}, num_stages=1, num_warps=1),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 3,
                       'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=1),
    ],
    key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'],
)
@triton.jit
def _attn_fwd(Q, K, V, sm_scale, M, Out,
              stride_qz, stride_qh, stride_qm, stride_qk,
              stride_kz, stride_kh, stride_kn, stride_kk,
              stride_vz, stride_vh, stride_vk, stride_vn,
              stride_oz, stride_oh, stride_om, stride_on,
              Z, H,
              N_CTX,
              BLOCK_DMODEL: tl.constexpr,
              STAGE: tl.constexpr,
              BLOCK_M: tl.constexpr,
              BLOCK_N: tl.constexpr,
              pre_load_v: tl.constexpr,
              ):
    # Forward attention kernel.
    #
    # Program axis 0 selects a block of BLOCK_M query rows, axis 1 selects a
    # flattened (batch, head) index. Each program writes its block of the
    # attention output to Out and one value per query row to M
    # (running max + log2 of the softmax denominator, in base-2 units).
    #
    # STAGE is 3 for causal attention and 1 for non-causal attention.
    # NOTE(review): all loads/stores use block pointers without boundary
    # checks, so N_CTX is presumably a multiple of the block sizes - confirm.
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # Split the flattened index into batch and head and address every tensor
    # with its own batch/head strides. 64-bit arithmetic keeps the element
    # offset from wrapping on large tensors. For tensors where
    # stride_z == H * stride_h and Q/K/V/Out share strides this is the same
    # offset as `off_hz * stride_qh`.
    off_z = (off_hz // H).to(tl.int64)
    off_h = (off_hz % H).to(tl.int64)
    q_offset = off_z * stride_qz + off_h * stride_qh
    k_offset = off_z * stride_kz + off_h * stride_kh
    v_offset = off_z * stride_vz + off_h * stride_vh
    o_offset = off_z * stride_oz + off_h * stride_oh
    # block pointers
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    # K is addressed transposed (BLOCK_DMODEL x N_CTX) so q @ k needs no
    # explicit transpose inside the loop.
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    O_block_ptr = tl.make_block_ptr(
        base=Out + o_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize running max (m_i), running sum (l_i) and output accumulator
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
    # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs
    q = tl.load(Q_block_ptr)
    q = (q * qk_scale).to(q.dtype)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
                                        start_m,
                                        BLOCK_M, BLOCK_DMODEL, BLOCK_N,
                                        4 - STAGE, offs_m, offs_n, N_CTX,
                                        pre_load_v,
                                        )
    # stage 2: on-band (only taken for causal attention, STAGE = 3)
    if STAGE & 2:
        # barrier makes it easier for the compiler to schedule the
        # two loops independently
        tl.debug_barrier()
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
                                        start_m,
                                        BLOCK_M, BLOCK_DMODEL, BLOCK_N,
                                        2, offs_m, offs_n, N_CTX,
                                        pre_load_v,
                                        )
    # epilogue
    # normalise the accumulator and write back m
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(m_ptrs, m_i + tl.math.log2(l_i))
    tl.store(O_block_ptr, acc.to(Out.type.element_ty))
@triton.jit
def _attn_bwd_preprocess(O, DO,
                         Delta,
                         Z, H, N_CTX,
                         BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
                         ):
    # Backward-pass preprocessing: for every row handled by this program,
    # reduce the elementwise product of O and DO over the D_HEAD axis and
    # store the result in Delta.
    #
    # Program axis 0 selects a block of BLOCK_M rows, axis 1 a flattened
    # (batch, head) index. The index arithmetic treats O and DO as densely
    # packed: D_HEAD elements per row, N_CTX rows per (batch, head) slice.
    head_idx = tl.program_id(1)
    rows = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.arange(0, D_HEAD)
    slice_base = head_idx * D_HEAD * N_CTX
    o_ptrs = O + slice_base + rows[:, None] * D_HEAD + cols[None, :]
    do_ptrs = DO + slice_base + rows[:, None] * D_HEAD + cols[None, :]
    out_tile = tl.load(o_ptrs)
    # DO is promoted to float32 so the product is accumulated in float32.
    grad_tile = tl.load(do_ptrs).to(tl.float32)
    row_dot = tl.sum(out_tile * grad_tile, axis=1)
    tl.store(Delta + head_idx * N_CTX + rows, row_dot)
# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(dk, dv,
                   Q, k, v, sm_scale,
                   DO,
                   M, D,
                   # shared by Q/K/V/DO.
                   stride_tok, stride_d,
                   H, N_CTX, BLOCK_M1: tl.constexpr,
                   BLOCK_N1: tl.constexpr,
                   BLOCK_DMODEL: tl.constexpr,
                   # Filled in by the wrapper.
                   start_n, start_m, num_steps,
                   MASK: tl.constexpr):
    # Accumulates dK and dV for one block of BLOCK_N1 keys/values.
    #
    # k, v      - the already loaded key/value tiles for rows
    #             [start_n, start_n + BLOCK_N1).
    # Q, DO     - base pointers; tiles of BLOCK_M1 query rows are read from
    #             them starting at row start_m, for num_steps consecutive tiles.
    # M, D      - per-row values read at the query row indices (the forward
    #             pass statistics and the preprocessed delta respectively).
    # MASK      - when true, entries with query index < key index are zeroed.
    # Returns the updated (dk, dv) accumulators.
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    # Q is read transposed (BLOCK_DMODEL x BLOCK_M1) so k @ qT needs no
    # explicit transpose.
    QT_block_ptr = tl.make_block_ptr(
        base=Q,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_d, stride_tok),
        offsets=(0, start_m),
        block_shape=(BLOCK_DMODEL, BLOCK_M1),
        order=(0, 1)
    )
    DO_block_ptr = tl.make_block_ptr(
        base=DO,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_tok, stride_d),
        offsets=(start_m, 0),
        block_shape=(BLOCK_M1, BLOCK_DMODEL),
        order=(1, 0)
    )
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        qT = tl.load(QT_block_ptr)
        # Load m before computing qk to reduce pipeline stall.
        offs_m = curr_m + tl.arange(0, BLOCK_M1)
        m = tl.load(M + offs_m)
        qkT = tl.dot(k, qT)
        pT = tl.math.exp2(qkT - m[None, :])
        # Autoregressive masking.
        if MASK:
            mask = (offs_m[None, :] >= offs_n[:, None])
            pT = tl.where(mask, pT, 0.0)
        do = tl.load(DO_block_ptr)
        # Compute dV. The probabilities are cast to the dtype of `do` (rather
        # than a hard-coded float16) so both dot operands always match,
        # mirroring `p.to(v.dtype)` in the forward pass.
        ppT = pT.to(do.dtype)
        dv += tl.dot(ppT, do)
        # Per-row delta values for the current query rows.
        Di = tl.load(D + offs_m)
        # Compute dP and dS.
        dpT = tl.dot(v, tl.trans(do))
        dsT = pT * (dpT - Di[None, :])
        # Same reasoning as above: match the dtype of the other dot operand.
        dsT = dsT.to(qT.dtype)
        dk += tl.dot(dsT, tl.trans(qT))
        # Increment pointers.
        curr_m += step_m
        QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m))
        DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0))
    return dk, dv
# The main inner-loop logic for computing dQ.
@triton.jit
def _attn_bwd_dq(dq, q, K, V,
                 do, m, D,
                 # shared by Q/K/V/DO.
                 stride_tok, stride_d,
                 H, N_CTX,
                 BLOCK_M2: tl.constexpr,
                 BLOCK_N2: tl.constexpr,
                 BLOCK_DMODEL: tl.constexpr,
                 # Filled in by the wrapper.
                 start_m, start_n, num_steps,
                 MASK: tl.constexpr):
    # Accumulates dQ for one block of BLOCK_M2 query rows.
    #
    # q, do, m  - already loaded tiles/values for rows
    #             [start_m, start_m + BLOCK_M2).
    # K, V      - base pointers; tiles of BLOCK_N2 keys are read (transposed)
    #             starting at key start_n, for num_steps consecutive tiles.
    # D         - per-row delta values, read at the query row indices.
    # MASK      - when true, entries with query index < key index are zeroed.
    # Returns the updated dq accumulator.
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    KT_block_ptr = tl.make_block_ptr(
        base=K,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_d, stride_tok),
        offsets=(0, start_n),
        block_shape=(BLOCK_DMODEL, BLOCK_N2),
        order=(0, 1)
    )
    VT_block_ptr = tl.make_block_ptr(
        base=V,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_d, stride_tok),
        offsets=(0, start_n),
        block_shape=(BLOCK_DMODEL, BLOCK_N2),
        order=(0, 1)
    )
    # Per-row delta values for this block of query rows.
    Di = tl.load(D + offs_m)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    curr_n = start_n
    step_n = BLOCK_N2
    for blk_idx in range(num_steps):
        kT = tl.load(KT_block_ptr)
        qk = tl.dot(q, kT)
        p = tl.math.exp2(qk - m)
        # Autoregressive masking.
        if MASK:
            offs_n = curr_n + tl.arange(0, BLOCK_N2)
            mask = (offs_m[:, None] >= offs_n[None, :])
            p = tl.where(mask, p, 0.0)
        # Compute dP and dS.
        vT = tl.load(VT_block_ptr)
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di[:, None])
        # Cast to the dtype of kT (rather than a hard-coded float16) so both
        # dot operands always match, mirroring `p.to(v.dtype)` in the forward.
        ds = ds.to(kT.dtype)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        dq += tl.dot(ds, tl.trans(kT))
        # Increment pointers.
        curr_n += step_n
        KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n))
        VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n))
    return dq
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1},
                      num_stages=1, num_warps=4),
        triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
                      num_stages=1, num_warps=4),
        triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1},
                      num_stages=1, num_warps=4),
        triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2},
                      num_stages=1, num_warps=4),
        triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1},
                      num_stages=1, num_warps=4),
        triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2},
                      num_stages=1, num_warps=4),
        triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1},
                      num_stages=1, num_warps=4),
        triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
                      num_stages=1, num_warps=4),
        triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
                      num_stages=1, num_warps=8),
    ],
    key=['H', 'N_CTX', 'BLOCK_DMODEL'],
)
@triton.jit
def _attn_bwd(Q, K, V, sm_scale,
              DO,
              DQ, DK, DV,
              M, D,
              # shared by Q/K/V/DO.
              stride_z, stride_h, stride_tok, stride_d,
              # e.g. H = 16, N_CTX = 1024
              H, N_CTX,
              BLOCK_DMODEL: tl.constexpr,
              BLOCK_M1: tl.constexpr,
              BLOCK_N1: tl.constexpr,
              BLOCK_M2: tl.constexpr,
              BLOCK_N2: tl.constexpr,
              BLK_SLICE_FACTOR: tl.constexpr):
    # Backward attention kernel (causal pattern only: each key block only
    # receives contributions from query rows at or after it, and each query
    # block only from keys at or before it).
    #
    # Program axis 0 selects a block index `pid`, axis 2 a flattened
    # (batch, head) index. Each program first computes dK/dV for keys
    # [pid * BLOCK_N1, ...) and then dQ for queries [pid * BLOCK_M2, ...).
    #
    # The caller passes K already multiplied by sm_scale / ln(2), which is why
    # dk is multiplied by sm_scale and dq by ln(2) before being stored.
    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)
    # The same `pid` indexes both the dK/dV key block and the dQ query block,
    # and the launch grid is sized from BLOCK_N1, so the two sizes must agree.
    tl.static_assert(BLOCK_N1 == BLOCK_M2)
    bhid = tl.program_id(2)
    off_chz = (bhid * N_CTX).to(tl.int64)
    adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
    pid = tl.program_id(0)
    # offset pointers for batch/head
    Q += adj
    K += adj
    V += adj
    DO += adj
    DQ += adj
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz
    start_n = pid * BLOCK_N1
    # This assignment is important. It is what allows us to pick the diagonal
    # blocks. Later, when we want to do the lower triangular, we update start_m
    # after the first dkdv call.
    start_m = start_n
    MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
    dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
    K_block_ptr = tl.make_block_ptr(
        base=K,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_tok, stride_d),
        offsets=(start_n, 0),
        block_shape=(BLOCK_N1, BLOCK_DMODEL),
        order=(1, 0),
    )
    V_block_ptr = tl.make_block_ptr(
        base=V,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_tok, stride_d),
        offsets=(start_n, 0),
        block_shape=(BLOCK_N1, BLOCK_DMODEL),
        order=(1, 0),
    )
    # load K and V: they stay in SRAM throughout the inner loop for dkdv.
    k = tl.load(K_block_ptr)
    v = tl.load(V_block_ptr)
    # Compute dK and dV for the masked (diagonal) blocks, using the smaller
    # MASK_BLOCK_M1 tile along the query axis.
    num_steps = BLOCK_N1 // MASK_BLOCK_M1
    dk, dv = _attn_bwd_dkdv(dk, dv,
                            Q, k, v, sm_scale,
                            DO,
                            M, D,
                            stride_tok, stride_d,
                            H, N_CTX,
                            MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
                            start_n, start_m, num_steps,
                            MASK=True
                            )
    start_m += num_steps * MASK_BLOCK_M1
    num_steps = (N_CTX - start_m) // BLOCK_M1
    # Compute dK and dV for non-masked blocks.
    dk, dv = _attn_bwd_dkdv(
        dk, dv,
        Q, k, v, sm_scale,
        DO,
        M, D,
        stride_tok, stride_d,
        H, N_CTX,
        BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
        start_n, start_m, num_steps,
        MASK=False
    )
    # Write back dV. The accumulator is cast to the element type of the
    # output tensor (rather than a hard-coded float16), matching how the
    # forward kernel stores its output.
    DV_block_ptrs = tl.make_block_ptr(
        base=DV,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_tok, stride_d),
        offsets=(start_n, 0),
        block_shape=(BLOCK_N1, BLOCK_DMODEL),
        order=(1, 0)
    )
    tl.store(DV_block_ptrs, dv.to(DV.type.element_ty))
    # Write back dK.
    dk *= sm_scale
    DK_block_ptrs = tl.make_block_ptr(
        base=DK,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_tok, stride_d),
        offsets=(start_n, 0),
        block_shape=(BLOCK_N1, BLOCK_DMODEL),
        order=(1, 0)
    )
    tl.store(DK_block_ptrs, dk.to(DK.type.element_ty))
    # THIS BLOCK DOES DQ:
    start_m = pid * BLOCK_M2
    end_n = start_m + BLOCK_M2
    MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    Q_block_ptr = tl.make_block_ptr(
        base=Q,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_tok, stride_d),
        offsets=(start_m, 0),
        block_shape=(BLOCK_M2, BLOCK_DMODEL),
        order=(1, 0)
    )
    DO_block_ptr = tl.make_block_ptr(
        base=DO,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_tok, stride_d),
        offsets=(start_m, 0),
        block_shape=(BLOCK_M2, BLOCK_DMODEL),
        order=(1, 0)
    )
    q = tl.load(Q_block_ptr)
    do = tl.load(DO_block_ptr)
    dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
    m = tl.load(M + offs_m)
    m = m[:, None]
    # Compute dQ for masked (diagonal) blocks.
    # NOTE: This code scans each row of QK^T backward (from right to left,
    # but inside each call to _attn_bwd_dq, from left to right), but that's
    # not due to anything important. I just wanted to reuse the loop
    # structure for dK & dV above as much as possible.
    num_steps = BLOCK_M2 // MASK_BLOCK_N2
    dq = _attn_bwd_dq(dq, q, K, V,
                      do, m, D,
                      stride_tok, stride_d,
                      H, N_CTX,
                      BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL,
                      start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,
                      MASK=True
                      )
    end_n -= num_steps * MASK_BLOCK_N2
    # stage 2: the non-masked blocks to the left of the diagonal
    num_steps = end_n // BLOCK_N2
    dq = _attn_bwd_dq(dq, q, K, V,
                      do, m, D,
                      stride_tok, stride_d,
                      H, N_CTX,
                      BLOCK_M2, BLOCK_N2, BLOCK_DMODEL,
                      start_m, end_n - num_steps * BLOCK_N2, num_steps,
                      MASK=False
                      )
    # Write back dQ.
    DQ_block_ptr = tl.make_block_ptr(
        base=DQ,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_tok, stride_d),
        offsets=(start_m, 0),
        block_shape=(BLOCK_M2, BLOCK_DMODEL),
        order=(1, 0)
    )
    # Undo the 1/ln(2) factor that was folded into K by the caller.
    dq *= LN2
    tl.store(DQ_block_ptr, dq.to(DQ.type.element_ty))
# Module-level tensor allocated on the "cuda" device when this module is
# imported.
# NOTE(review): nothing in the visible part of this file reads `empty`; it
# looks like a leftover. Importing the module requires a usable CUDA/ROCm
# device because of this line - confirm before removing.
empty = torch.empty(128, device="cuda")
class _attention(torch.autograd.Function):
    """Autograd wrapper around the Triton flash-attention kernels.

    Inputs are indexed as 4-D tensors: ``q.shape[0] * q.shape[1]`` is used as
    the number of (batch, head) slices, ``q.shape[2]`` as the sequence length
    and ``q.shape[-1]`` as the head dimension.

    NOTE(review): the kernels load and store through block pointers without
    boundary checks, so the sequence length is presumably expected to be a
    multiple of the tile sizes chosen by the autotuner - confirm with callers.
    """

    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale):
        """Compute attention output for ``q``, ``k``, ``v``.

        Args:
            ctx: autograd context; receives the tensors and settings that
                ``backward`` needs.
            q, k, v: query/key/value tensors whose last dimensions must be
                equal and one of 16, 32, 64 or 128.
            causal: when truthy, the causal (lower-triangular) kernel path is
                used.
            sm_scale: scale applied to the q.k scores.

        Returns:
            The attention output, shaped like ``q`` with the dtype of ``v``.
        """
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        o = torch.empty_like(q, dtype=v.dtype)
        # Tile sizes, warps and stages come from the autotuner configs attached
        # to `_attn_fwd`, so nothing is selected by hand here.
        stage = 3 if causal else 1

        def grid(META):
            return (
                triton.cdiv(q.shape[2], META['BLOCK_M']),
                q.shape[0] * q.shape[1],
                1,
            )

        # One float32 value per (batch * head, query row), written by the
        # forward kernel and read again by the backward kernels.
        M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]),
                        device=q.device, dtype=torch.float32)
        _attn_fwd[grid](
            q, k, v, sm_scale, M, o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1],
            N_CTX=q.shape[2],
            BLOCK_DMODEL=Lk,
            STAGE=stage,
        )
        # Record the concrete launch grid that the tuned config produced.
        # BLOCK_M is read from the config's kwargs instead of being parsed
        # out of its string form, which depended on BLOCK_M being printed
        # first.
        best_config = _attn_fwd.get_best_config()
        block_m = int(best_config.kwargs['BLOCK_M'])
        ctx.save_for_backward(q, k, v, o, M)
        ctx.grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        """Compute gradients with respect to ``q``, ``k`` and ``v``.

        Args:
            ctx: autograd context filled in by ``forward``.
            do: gradient of the loss with respect to the forward output.

        Returns:
            ``(dq, dk, dv, None, None)`` - no gradient is produced for
            ``causal`` or ``sm_scale``.

        Raises:
            NotImplementedError: if the forward pass was non-causal. The
                backward kernel only accumulates the causal (lower-triangular)
                part of the attention matrix, so running it for non-causal
                attention would silently return wrong gradients.
        """
        if not ctx.causal:
            raise NotImplementedError(
                "backward is only implemented for causal attention; "
                "the backward kernel always applies the causal mask"
            )
        q, k, v, o, M = ctx.saved_tensors
        # Autograd may hand over a non-contiguous gradient; the kernels index
        # `do` with the same strides as q/k/v/o, so normalise it first.
        do = do.contiguous()
        assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        BATCH, N_HEAD, N_CTX = q.shape[:3]
        PRE_BLOCK = 128
        RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
        # The kernel works with base-2 exponentials, so fold the softmax scale
        # and 1/ln(2) into K once up front.
        arg_k = k * (ctx.sm_scale * RCP_LN2)
        assert N_CTX % PRE_BLOCK == 0
        pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
        delta = torch.empty_like(M)
        _attn_bwd_preprocess[pre_grid](
            o, do,
            delta,
            BATCH, N_HEAD, N_CTX,
            BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL
        )

        # Tile sizes for the main backward kernel are chosen by its autotuner.
        def grid(META):
            return (
                triton.cdiv(N_CTX, META['BLOCK_N1']),
                1,
                BATCH * N_HEAD,
            )

        _attn_bwd[grid](
            q, arg_k, v, ctx.sm_scale, do, dq, dk, dv,
            M, delta,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            N_HEAD, N_CTX,
            BLOCK_DMODEL=ctx.BLOCK_DMODEL
        )
        return dq, dk, dv, None, None
# Functional entry point; arguments follow `_attention.forward` after `ctx`:
# attention(q, k, v, causal, sm_scale).
attention = _attention.apply