jupyterjazz committed
Commit 11ba200 • 1 Parent(s): 77a17f7

refactor: revert alibi stuff

Signed-off-by: jupyterjazz <[email protected]>
mha.py
CHANGED
@@ -56,7 +56,15 @@ class FlashSelfAttention(nn.Module):
                            (default: 0.0)
     """
 
-    def __init__(
+    def __init__(
+        self,
+        causal=False,
+        softmax_scale=None,
+        attention_dropout=0.0,
+        window_size=(-1, -1),
+        alibi_slopes=None,
+        deterministic=False,
+    ):
         super().__init__()
         assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
@@ -64,6 +72,7 @@ class FlashSelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.window_size = window_size
         self.deterministic = deterministic
 
     def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
@@ -87,6 +96,8 @@ class FlashSelfAttention(nn.Module):
         assert qkv.is_cuda
         causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
+        if self.alibi_slopes is not None:
+            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
             assert max_seqlen is not None
@@ -99,6 +110,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
         else:
@@ -108,6 +120,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
 
@@ -123,7 +136,15 @@ class FlashCrossAttention(nn.Module):
                            (default: 0.0)
     """
 
-    def __init__(
+    def __init__(
+        self,
+        causal=False,
+        softmax_scale=None,
+        attention_dropout=0.0,
+        alibi_slopes=None,
+        window_size=(-1, -1),
+        deterministic=False,
+    ):
         super().__init__()
         assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
@@ -131,6 +152,7 @@ class FlashCrossAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.window_size = window_size
         self.deterministic = deterministic
 
     def forward(
@@ -160,6 +182,8 @@ class FlashCrossAttention(nn.Module):
         assert q.is_cuda and kv.is_cuda
         causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
+        if self.alibi_slopes is not None:
+            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
             assert max_seqlen is not None
@@ -179,6 +203,7 @@ class FlashCrossAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
         else:
@@ -192,6 +217,7 @@ class FlashCrossAttention(nn.Module):
                 causal=causal,
                 softmax_scale=self.softmax_scale,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
 
@@ -367,6 +393,7 @@ class MHA(nn.Module):
         rotary_emb_scale_base=None,
         rotary_emb_interleaved=False,
         use_alibi=False,
+        window_size=(-1, -1),
         fused_bias_fc=False,
         use_flash_attn=False,
         return_residual=False,
@@ -396,6 +423,8 @@ class MHA(nn.Module):
             alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
         else:
             alibi_slopes = None
+        if window_size != (-1, -1):
+            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
 
         self.num_heads = num_heads
         self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
@@ -426,12 +455,12 @@ class MHA(nn.Module):
         )
         wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
         inner_attn_cls = (
-            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
+            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else SelfAttention
         )
         inner_cross_attn_cls = (
-            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
+            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else CrossAttention
         )
@@ -584,7 +613,6 @@ class MHA(nn.Module):
             assert key_padding_mask is None
             assert self.use_flash_attn
             assert not self.dwconv
-            # assert self.rotary_emb_dim == 0
         if key_padding_mask is not None:
             assert cu_seqlens is None
             assert max_seqlen is None
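For reference, the alibi_slopes buffer that the diff keeps passing into the flash-attn kernels comes from get_alibi_slopes(num_heads) in MHA.__init__. That helper is not part of this diff; the sketch below shows the standard ALiBi slope recipe (Press et al.) it is assumed to follow, with an illustrative function name so it is not mistaken for this repo's exact implementation.

import math

def alibi_slopes_sketch(num_heads: int) -> list[float]:
    """Standard ALiBi head slopes: a geometric sequence starting at 2**(-8 / num_heads)."""

    def slopes_power_of_2(n: int) -> list[float]:
        start = 2 ** (-8.0 / n)
        return [start ** (i + 1) for i in range(n)]

    if math.log2(num_heads).is_integer():
        return slopes_power_of_2(num_heads)
    # For head counts that are not a power of two, the ALiBi paper interleaves
    # slopes taken from the next power of two to keep the spacing roughly geometric.
    closest = 2 ** math.floor(math.log2(num_heads))
    return (
        slopes_power_of_2(closest)
        + slopes_power_of_2(2 * closest)[0::2][: num_heads - closest]
    )

# e.g. alibi_slopes_sketch(8) -> [0.5, 0.25, 0.125, ..., 0.00390625]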
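The partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size) lines only pre-bind the two extra keyword arguments before MHA instantiates its inner attention module. A minimal usage sketch of that wiring follows; it assumes flash-attn 2.x is installed, a CUDA device is available, and that this file is importable as mha with get_alibi_slopes exposed there.

from functools import partial

import torch

from mha import FlashSelfAttention, get_alibi_slopes  # assumed import path

num_heads, head_dim = 8, 64
alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device="cuda")

# Same pre-binding as inner_attn_cls in the diff; window_size=(-1, -1) means
# no sliding-window restriction, i.e. full (optionally causal) attention.
inner_attn_cls = partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=(-1, -1))
inner_attn = inner_attn_cls(causal=False, attention_dropout=0.0)

# Packed QKV layout expected by flash_attn_qkvpacked_func: (batch, seqlen, 3, nheads, headdim),
# in fp16 or bf16 on GPU.  forward() casts alibi_slopes to float32 before the kernel call.
qkv = torch.randn(2, 128, 3, num_heads, head_dim, device="cuda", dtype=torch.float16)
out = inner_attn(qkv)  # -> (batch, seqlen, nheads, headdim)
print(out.shape)       # torch.Size([2, 128, 8, 64])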