Spaces:
Running
on
Zero
Running
on
Zero
Feature(MInference): build demo
Browse files- .gitignore +415 -0
- LICENSE +21 -0
- README.md +130 -1
- app.py +145 -4
- images/MInference1_onepage.png +0 -0
- images/MInference_logo.png +0 -0
- images/benchmarks/needle_viz_LLaMA-3-8B-1M_ours_1K_1000K.png +0 -0
- images/benchmarks/ppl-LLaMA-3-262k.png +0 -0
- minference/__init__.py +27 -0
- minference/configs/Llama_3_8B_Instruct_262k_kv_out_v32_fit_o_best_pattern.json +1 -0
- minference/configs/Phi_3_mini_128k_instruct_kv_out_v32_fit_o_best_pattern.json +1 -0
- minference/configs/Yi_9B_200k_kv_out_v32_fit_o_best_pattern.json +1 -0
- minference/configs/model2path.py +17 -0
- minference/minference_configuration.py +49 -0
- minference/models_patch.py +100 -0
- minference/modules/inf_llm.py +1296 -0
- minference/modules/minference_forward.py +855 -0
- minference/modules/snap_kv.py +422 -0
- minference/ops/block_sparse_flash_attention.py +464 -0
- minference/ops/pit_sparse_flash_attention.py +740 -0
- minference/ops/pit_sparse_flash_attention_v2.py +735 -0
- minference/ops/streaming_kernel.py +763 -0
- minference/patch.py +1279 -0
- minference/version.py +14 -0
- requirements.txt +5 -0
.gitignore
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Ignore Visual Studio temporary files, build results, and
|
2 |
+
## files generated by popular Visual Studio add-ons.
|
3 |
+
##
|
4 |
+
## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore
|
5 |
+
|
6 |
+
# User-specific files
|
7 |
+
*.rsuser
|
8 |
+
*.suo
|
9 |
+
*.user
|
10 |
+
*.userosscache
|
11 |
+
*.sln.docstates
|
12 |
+
|
13 |
+
# User-specific files (MonoDevelop/Xamarin Studio)
|
14 |
+
*.userprefs
|
15 |
+
|
16 |
+
# Mono auto generated files
|
17 |
+
mono_crash.*
|
18 |
+
|
19 |
+
# Build results
|
20 |
+
[Dd]ebug/
|
21 |
+
[Dd]ebugPublic/
|
22 |
+
[Rr]elease/
|
23 |
+
[Rr]eleases/
|
24 |
+
x64/
|
25 |
+
x86/
|
26 |
+
[Ww][Ii][Nn]32/
|
27 |
+
[Aa][Rr][Mm]/
|
28 |
+
[Aa][Rr][Mm]64/
|
29 |
+
bld/
|
30 |
+
[Bb]in/
|
31 |
+
[Oo]bj/
|
32 |
+
[Ll]og/
|
33 |
+
[Ll]ogs/
|
34 |
+
|
35 |
+
# Visual Studio 2015/2017 cache/options directory
|
36 |
+
.vs/
|
37 |
+
# Uncomment if you have tasks that create the project's static files in wwwroot
|
38 |
+
#wwwroot/
|
39 |
+
|
40 |
+
# Visual Studio 2017 auto generated files
|
41 |
+
Generated\ Files/
|
42 |
+
|
43 |
+
# MSTest test Results
|
44 |
+
[Tt]est[Rr]esult*/
|
45 |
+
[Bb]uild[Ll]og.*
|
46 |
+
|
47 |
+
# NUnit
|
48 |
+
*.VisualState.xml
|
49 |
+
TestResult.xml
|
50 |
+
nunit-*.xml
|
51 |
+
|
52 |
+
# Build Results of an ATL Project
|
53 |
+
[Dd]ebugPS/
|
54 |
+
[Rr]eleasePS/
|
55 |
+
dlldata.c
|
56 |
+
|
57 |
+
# Benchmark Results
|
58 |
+
BenchmarkDotNet.Artifacts/
|
59 |
+
|
60 |
+
# .NET Core
|
61 |
+
project.lock.json
|
62 |
+
project.fragment.lock.json
|
63 |
+
artifacts/
|
64 |
+
|
65 |
+
# ASP.NET Scaffolding
|
66 |
+
ScaffoldingReadMe.txt
|
67 |
+
|
68 |
+
# StyleCop
|
69 |
+
StyleCopReport.xml
|
70 |
+
|
71 |
+
# Files built by Visual Studio
|
72 |
+
*_i.c
|
73 |
+
*_p.c
|
74 |
+
*_h.h
|
75 |
+
*.ilk
|
76 |
+
*.meta
|
77 |
+
*.obj
|
78 |
+
*.iobj
|
79 |
+
*.pch
|
80 |
+
*.pdb
|
81 |
+
*.ipdb
|
82 |
+
*.pgc
|
83 |
+
*.pgd
|
84 |
+
*.rsp
|
85 |
+
*.sbr
|
86 |
+
*.tlb
|
87 |
+
*.tli
|
88 |
+
*.tlh
|
89 |
+
*.tmp
|
90 |
+
*.tmp_proj
|
91 |
+
*_wpftmp.csproj
|
92 |
+
*.log
|
93 |
+
*.tlog
|
94 |
+
*.vspscc
|
95 |
+
*.vssscc
|
96 |
+
.builds
|
97 |
+
*.pidb
|
98 |
+
*.svclog
|
99 |
+
*.scc
|
100 |
+
|
101 |
+
# Chutzpah Test files
|
102 |
+
_Chutzpah*
|
103 |
+
|
104 |
+
# Visual C++ cache files
|
105 |
+
ipch/
|
106 |
+
*.aps
|
107 |
+
*.ncb
|
108 |
+
*.opendb
|
109 |
+
*.opensdf
|
110 |
+
*.sdf
|
111 |
+
*.cachefile
|
112 |
+
*.VC.db
|
113 |
+
*.VC.VC.opendb
|
114 |
+
|
115 |
+
# Visual Studio profiler
|
116 |
+
*.psess
|
117 |
+
*.vsp
|
118 |
+
*.vspx
|
119 |
+
*.sap
|
120 |
+
|
121 |
+
# Visual Studio Trace Files
|
122 |
+
*.e2e
|
123 |
+
|
124 |
+
# TFS 2012 Local Workspace
|
125 |
+
$tf/
|
126 |
+
|
127 |
+
# Guidance Automation Toolkit
|
128 |
+
*.gpState
|
129 |
+
|
130 |
+
# ReSharper is a .NET coding add-in
|
131 |
+
_ReSharper*/
|
132 |
+
*.[Rr]e[Ss]harper
|
133 |
+
*.DotSettings.user
|
134 |
+
|
135 |
+
# TeamCity is a build add-in
|
136 |
+
_TeamCity*
|
137 |
+
|
138 |
+
# DotCover is a Code Coverage Tool
|
139 |
+
*.dotCover
|
140 |
+
|
141 |
+
# AxoCover is a Code Coverage Tool
|
142 |
+
.axoCover/*
|
143 |
+
!.axoCover/settings.json
|
144 |
+
|
145 |
+
# Coverlet is a free, cross platform Code Coverage Tool
|
146 |
+
coverage*.json
|
147 |
+
coverage*.xml
|
148 |
+
coverage*.info
|
149 |
+
|
150 |
+
# Visual Studio code coverage results
|
151 |
+
*.coverage
|
152 |
+
*.coveragexml
|
153 |
+
|
154 |
+
# NCrunch
|
155 |
+
_NCrunch_*
|
156 |
+
.*crunch*.local.xml
|
157 |
+
nCrunchTemp_*
|
158 |
+
|
159 |
+
# MightyMoose
|
160 |
+
*.mm.*
|
161 |
+
AutoTest.Net/
|
162 |
+
|
163 |
+
# Web workbench (sass)
|
164 |
+
.sass-cache/
|
165 |
+
|
166 |
+
# Installshield output folder
|
167 |
+
[Ee]xpress/
|
168 |
+
|
169 |
+
# DocProject is a documentation generator add-in
|
170 |
+
DocProject/buildhelp/
|
171 |
+
DocProject/Help/*.HxT
|
172 |
+
DocProject/Help/*.HxC
|
173 |
+
DocProject/Help/*.hhc
|
174 |
+
DocProject/Help/*.hhk
|
175 |
+
DocProject/Help/*.hhp
|
176 |
+
DocProject/Help/Html2
|
177 |
+
DocProject/Help/html
|
178 |
+
|
179 |
+
# Click-Once directory
|
180 |
+
publish/
|
181 |
+
|
182 |
+
# Publish Web Output
|
183 |
+
*.[Pp]ublish.xml
|
184 |
+
*.azurePubxml
|
185 |
+
# Note: Comment the next line if you want to checkin your web deploy settings,
|
186 |
+
# but database connection strings (with potential passwords) will be unencrypted
|
187 |
+
*.pubxml
|
188 |
+
*.publishproj
|
189 |
+
|
190 |
+
# Microsoft Azure Web App publish settings. Comment the next line if you want to
|
191 |
+
# checkin your Azure Web App publish settings, but sensitive information contained
|
192 |
+
# in these scripts will be unencrypted
|
193 |
+
PublishScripts/
|
194 |
+
|
195 |
+
# NuGet Packages
|
196 |
+
*.nupkg
|
197 |
+
# NuGet Symbol Packages
|
198 |
+
*.snupkg
|
199 |
+
# The packages folder can be ignored because of Package Restore
|
200 |
+
**/[Pp]ackages/*
|
201 |
+
# except build/, which is used as an MSBuild target.
|
202 |
+
!**/[Pp]ackages/build/
|
203 |
+
# Uncomment if necessary however generally it will be regenerated when needed
|
204 |
+
#!**/[Pp]ackages/repositories.config
|
205 |
+
# NuGet v3's project.json files produces more ignorable files
|
206 |
+
*.nuget.props
|
207 |
+
*.nuget.targets
|
208 |
+
|
209 |
+
# Microsoft Azure Build Output
|
210 |
+
csx/
|
211 |
+
*.build.csdef
|
212 |
+
|
213 |
+
# Microsoft Azure Emulator
|
214 |
+
ecf/
|
215 |
+
rcf/
|
216 |
+
|
217 |
+
# Windows Store app package directories and files
|
218 |
+
AppPackages/
|
219 |
+
BundleArtifacts/
|
220 |
+
Package.StoreAssociation.xml
|
221 |
+
_pkginfo.txt
|
222 |
+
*.appx
|
223 |
+
*.appxbundle
|
224 |
+
*.appxupload
|
225 |
+
|
226 |
+
# Visual Studio cache files
|
227 |
+
# files ending in .cache can be ignored
|
228 |
+
*.[Cc]ache
|
229 |
+
# but keep track of directories ending in .cache
|
230 |
+
!?*.[Cc]ache/
|
231 |
+
|
232 |
+
# Others
|
233 |
+
ClientBin/
|
234 |
+
~$*
|
235 |
+
*~
|
236 |
+
*.dbmdl
|
237 |
+
*.dbproj.schemaview
|
238 |
+
*.jfm
|
239 |
+
*.pfx
|
240 |
+
*.publishsettings
|
241 |
+
orleans.codegen.cs
|
242 |
+
|
243 |
+
# Including strong name files can present a security risk
|
244 |
+
# (https://github.com/github/gitignore/pull/2483#issue-259490424)
|
245 |
+
#*.snk
|
246 |
+
|
247 |
+
# Since there are multiple workflows, uncomment next line to ignore bower_components
|
248 |
+
# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
|
249 |
+
#bower_components/
|
250 |
+
|
251 |
+
# RIA/Silverlight projects
|
252 |
+
Generated_Code/
|
253 |
+
|
254 |
+
# Backup & report files from converting an old project file
|
255 |
+
# to a newer Visual Studio version. Backup files are not needed,
|
256 |
+
# because we have git ;-)
|
257 |
+
_UpgradeReport_Files/
|
258 |
+
Backup*/
|
259 |
+
UpgradeLog*.XML
|
260 |
+
UpgradeLog*.htm
|
261 |
+
ServiceFabricBackup/
|
262 |
+
*.rptproj.bak
|
263 |
+
|
264 |
+
# SQL Server files
|
265 |
+
*.mdf
|
266 |
+
*.ldf
|
267 |
+
*.ndf
|
268 |
+
|
269 |
+
# Business Intelligence projects
|
270 |
+
*.rdl.data
|
271 |
+
*.bim.layout
|
272 |
+
*.bim_*.settings
|
273 |
+
*.rptproj.rsuser
|
274 |
+
*- [Bb]ackup.rdl
|
275 |
+
*- [Bb]ackup ([0-9]).rdl
|
276 |
+
*- [Bb]ackup ([0-9][0-9]).rdl
|
277 |
+
|
278 |
+
# Microsoft Fakes
|
279 |
+
FakesAssemblies/
|
280 |
+
|
281 |
+
# GhostDoc plugin setting file
|
282 |
+
*.GhostDoc.xml
|
283 |
+
|
284 |
+
# Node.js Tools for Visual Studio
|
285 |
+
.ntvs_analysis.dat
|
286 |
+
node_modules/
|
287 |
+
|
288 |
+
# Visual Studio 6 build log
|
289 |
+
*.plg
|
290 |
+
|
291 |
+
# Visual Studio 6 workspace options file
|
292 |
+
*.opt
|
293 |
+
|
294 |
+
# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
|
295 |
+
*.vbw
|
296 |
+
|
297 |
+
# Visual Studio 6 auto-generated project file (contains which files were open etc.)
|
298 |
+
*.vbp
|
299 |
+
|
300 |
+
# Visual Studio 6 workspace and project file (working project files containing files to include in project)
|
301 |
+
*.dsw
|
302 |
+
*.dsp
|
303 |
+
|
304 |
+
# Visual Studio 6 technical files
|
305 |
+
*.ncb
|
306 |
+
*.aps
|
307 |
+
|
308 |
+
# Visual Studio LightSwitch build output
|
309 |
+
**/*.HTMLClient/GeneratedArtifacts
|
310 |
+
**/*.DesktopClient/GeneratedArtifacts
|
311 |
+
**/*.DesktopClient/ModelManifest.xml
|
312 |
+
**/*.Server/GeneratedArtifacts
|
313 |
+
**/*.Server/ModelManifest.xml
|
314 |
+
_Pvt_Extensions
|
315 |
+
|
316 |
+
# Paket dependency manager
|
317 |
+
.paket/paket.exe
|
318 |
+
paket-files/
|
319 |
+
|
320 |
+
# FAKE - F# Make
|
321 |
+
.fake/
|
322 |
+
|
323 |
+
# CodeRush personal settings
|
324 |
+
.cr/personal
|
325 |
+
|
326 |
+
# Python Tools for Visual Studio (PTVS)
|
327 |
+
__pycache__/
|
328 |
+
*.pyc
|
329 |
+
|
330 |
+
# Cake - Uncomment if you are using it
|
331 |
+
# tools/**
|
332 |
+
# !tools/packages.config
|
333 |
+
|
334 |
+
# Tabs Studio
|
335 |
+
*.tss
|
336 |
+
|
337 |
+
# Telerik's JustMock configuration file
|
338 |
+
*.jmconfig
|
339 |
+
|
340 |
+
# BizTalk build output
|
341 |
+
*.btp.cs
|
342 |
+
*.btm.cs
|
343 |
+
*.odx.cs
|
344 |
+
*.xsd.cs
|
345 |
+
|
346 |
+
# OpenCover UI analysis results
|
347 |
+
OpenCover/
|
348 |
+
|
349 |
+
# Azure Stream Analytics local run output
|
350 |
+
ASALocalRun/
|
351 |
+
|
352 |
+
# MSBuild Binary and Structured Log
|
353 |
+
*.binlog
|
354 |
+
|
355 |
+
# NVidia Nsight GPU debugger configuration file
|
356 |
+
*.nvuser
|
357 |
+
|
358 |
+
# MFractors (Xamarin productivity tool) working folder
|
359 |
+
.mfractor/
|
360 |
+
|
361 |
+
# Local History for Visual Studio
|
362 |
+
.localhistory/
|
363 |
+
|
364 |
+
# Visual Studio History (VSHistory) files
|
365 |
+
.vshistory/
|
366 |
+
|
367 |
+
# BeatPulse healthcheck temp database
|
368 |
+
healthchecksdb
|
369 |
+
|
370 |
+
# Backup folder for Package Reference Convert tool in Visual Studio 2017
|
371 |
+
MigrationBackup/
|
372 |
+
|
373 |
+
# Ionide (cross platform F# VS Code tools) working folder
|
374 |
+
.ionide/
|
375 |
+
|
376 |
+
# Fody - auto-generated XML schema
|
377 |
+
FodyWeavers.xsd
|
378 |
+
|
379 |
+
# VS Code files for those working on multiple tools
|
380 |
+
.vscode/*
|
381 |
+
!.vscode/settings.json
|
382 |
+
!.vscode/tasks.json
|
383 |
+
!.vscode/launch.json
|
384 |
+
!.vscode/extensions.json
|
385 |
+
*.code-workspace
|
386 |
+
|
387 |
+
# Local History for Visual Studio Code
|
388 |
+
.history/
|
389 |
+
|
390 |
+
# Windows Installer files from build outputs
|
391 |
+
*.cab
|
392 |
+
*.msi
|
393 |
+
*.msix
|
394 |
+
*.msm
|
395 |
+
*.msp
|
396 |
+
|
397 |
+
# JetBrains Rider
|
398 |
+
*.sln.iml
|
399 |
+
|
400 |
+
# Experiments
|
401 |
+
data
|
402 |
+
!experiments/ruler/data
|
403 |
+
needle
|
404 |
+
results
|
405 |
+
*.json
|
406 |
+
*.jsonl
|
407 |
+
.vscode/
|
408 |
+
*.pt
|
409 |
+
*.pkl
|
410 |
+
!minference/configs/*
|
411 |
+
|
412 |
+
__pycache__
|
413 |
+
build/
|
414 |
+
*.egg-info/
|
415 |
+
*.so
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) Microsoft Corporation.
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -10,4 +10,133 @@ pinned: false
|
|
10 |
license: mit
|
11 |
---
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
license: mit
|
11 |
---
|
12 |
|
13 |
+
<div style="display: flex; align-items: center;">
|
14 |
+
<div style="width: 100px; margin-right: 10px; height:auto;" align="left">
|
15 |
+
<img src="images/MInference_logo.png" alt="MInference" width="100" align="left">
|
16 |
+
</div>
|
17 |
+
<div style="flex-grow: 1;" align="center">
|
18 |
+
<h2 align="center">MInference: Million-Tokens Prompt Inference for LLMs</h2>
|
19 |
+
</div>
|
20 |
+
</div>
|
21 |
+
|
22 |
+
<p align="center">
|
23 |
+
| <a href="https://llmlingua.com/"><b>Project Page</b></a> |
|
24 |
+
<a href="https://arxiv.org/abs/2407.02490"><b>Paper</b></a> |
|
25 |
+
<a href="https://huggingface.co/spaces/microsoft/MInference"><b>Demo</b></a> |
|
26 |
+
</p>
|
27 |
+
|
28 |
+
https://github.com/microsoft/MInference/assets/30883354/52613efc-738f-4081-8367-7123c81d6b19
|
29 |
+
|
30 |
+
## TL;DR
|
31 |
+
|
32 |
+
**MInference 1.0** leverages the dynamic sparse nature of LLMs' attention, which exhibits some static patterns, to speed up the pre-filling for long-context LLMs. It first determines offline which sparse pattern each head belongs to, then approximates the sparse index online and dynamically computes attention with the optimal custom kernels. This approach achieves up to a **10x speedup** for pre-filling on an A100 while maintaining accuracy.
|
33 |
+
|
34 |
+
- [MInference 1.0: Accelerating Pre-filling for Long-Context LLMs via Dynamic Sparse Attention](https://arxiv.org/abs/2407.02490) (Under Review)<br>
|
35 |
+
_Huiqiang Jiang†, Yucheng Li†, Chengruidong Zhang†, Qianhui Wu, Xufang Luo, Surin Ahn, Zhenhua Han, Amir H. Abdi, Dongsheng Li, Chin-Yew Lin, Yuqing Yang and Lili Qiu_
|
36 |
+
|
37 |
+
|
38 |
+
## 🎥 Overview
|
39 |
+
|
40 |
+
![Onepage of MInference](./images/MInference1_onepage.png)
|
41 |
+
|
42 |
+
## 🎯 Quick Start
|
43 |
+
|
44 |
+
### Requirements
|
45 |
+
|
46 |
+
- Torch
|
47 |
+
- FlashAttention-2
|
48 |
+
- Triton == 2.1.0
|
49 |
+
|
50 |
+
To get started with MInference, simply install it using pip:
|
51 |
+
|
52 |
+
```bash
|
53 |
+
pip install minference
|
54 |
+
```
|
55 |
+
|
56 |
+
### How to use MInference
|
57 |
+
|
58 |
+
For Hugging Face Transformers:
|
59 |
+
```diff
|
60 |
+
from transformers import pipeline
|
61 |
+
+from minference import MInference
|
62 |
+
|
63 |
+
pipe = pipeline("text-generation", model=model_name, torch_dtype="auto", device_map="auto")
|
64 |
+
|
65 |
+
# Patch MInference Module
|
66 |
+
+minference_patch = MInference("minference", model_name)
|
67 |
+
+pipe.model = minference_patch(pipe.model)
|
68 |
+
|
69 |
+
pipe(prompt, max_length=10)
|
70 |
+
```
|
71 |
+
|
72 |
+
For vLLM:
|
73 |
+
|
74 |
+
```diff
|
75 |
+
from vllm import LLM, SamplingParams
|
76 |
+
+from minference import MInference
|
77 |
+
|
78 |
+
llm = LLM(model_name, max_num_seqs=1, enforce_eager=True, max_model_len=128000)
|
79 |
+
|
80 |
+
# Patch MInference Module
|
81 |
+
+minference_patch = MInference("vllm", model_name)
|
82 |
+
+llm = minference_patch(llm)
|
83 |
+
|
84 |
+
outputs = llm.generate(prompts, sampling_params)
|
85 |
+
```
|
86 |
+
|
87 |
+
## FAQ
|
88 |
+
|
89 |
+
For more insights and answers, visit our [FAQ section](./Transparency_FAQ.md).
|
90 |
+
|
91 |
+
**Q1: How to effectively evaluate the impact of dynamic sparse attention on the capabilities of long-context LLMs?**
|
92 |
+
|
93 |
+
To effectively evaluate long-context LLM capabilities, we tested: 1) effective context window with RULER, 2) general long-context tasks with InfiniteBench, 3) retrieval tasks across different contexts and positions with Needle in a Haystack, and 4) language model prediction with PG-19.<br/>
|
94 |
+
We found that traditional methods perform poorly in retrieval tasks, with difficulty levels varying as follows: KV retrieval (every key as a needle) > Needle in a Haystack > Retrieval.Number > Retrieval PassKey. The key challenge is the semantic difference between needles and the haystack. Traditional methods perform better when the semantic difference is larger, as in passkey tasks. KV retrieval demands higher retrieval capabilities since any key can be a target, and multi-needle tasks are even more complex.<br/>
|
95 |
+
We will continue to update our results with more models and datasets in future versions.
|
96 |
+
|
97 |
+
**Q2: Does this dynamic sparse attention pattern only exist in long-context LLMs that are not fully trained?**
|
98 |
+
|
99 |
+
Firstly, attention is dynamically sparse, and this is true for both short- and long-contexts, a characteristic inherent to the attention mechanism.
|
100 |
+
Additionally, we selected the state-of-the-art open-source long-context LLM, LLaMA-3-8B-Instruct-1M, which has an effective context window size of 16K. With MInference, this can be extended to 32K.
|
101 |
+
We will continue to adapt our method to other advanced long-context LLMs and update our results. We will also explore the theoretical reasons behind this dynamic sparse attention pattern.
|
102 |
+
|
103 |
+
**Q3: What is the relationship between MInference, SSM, Linear Attention, and Sparse Attention?**
|
104 |
+
|
105 |
+
All four approaches (MInference, SSM, Linear Attention, and Sparse Attention) are efficient solutions for optimizing the high complexity of attention in Transformers, each introducing inductive bias from different perspectives. Notably, the latter three require training from scratch.
|
106 |
+
Additionally, recent works like Mamba-2 and Unified Implicit Attention Representation unify SSM and Linear Attention as static sparse attention. Mamba-2 itself is a block-wise sparse attention method.
|
107 |
+
Intuitively, the significant sparse redundancy in attention suggests that these approaches have potential. However, static sparse attention may not handle dynamic semantic associations well, especially in complex tasks. Dynamic sparse attention, on the other hand, holds potential for better managing these dynamic relationships.
|
108 |
+
|
109 |
+
## Citation
|
110 |
+
|
111 |
+
If you find MInference useful or relevant to your project and research, please kindly cite our paper:
|
112 |
+
|
113 |
+
```bibtex
|
114 |
+
@article{jiang2024minference,
|
115 |
+
title={MInference 1.0: Accelerating Pre-filling for Long-Context LLMs via Dynamic Sparse Attention},
|
116 |
+
author={Jiang, Huiqiang and Li, Yucheng and Zhang, Chengruidong and Wu, Qianhui and Luo, Xufang and Ahn, Surin and Han, Zhenhua and Abdi, Amir H and Li, Dongsheng and Lin, Chin-Yew and Yang, Yuqing and Qiu, Lili},
|
117 |
+
journal={arXiv},
|
118 |
+
year={2024}
|
119 |
+
}
|
120 |
+
```
|
121 |
+
|
122 |
+
## Contributing
|
123 |
+
|
124 |
+
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
125 |
+
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
126 |
+
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
127 |
+
|
128 |
+
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
|
129 |
+
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
|
130 |
+
provided by the bot. You will only need to do this once across all repos using our CLA.
|
131 |
+
|
132 |
+
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
133 |
+
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
134 |
+
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
135 |
+
|
136 |
+
## Trademarks
|
137 |
+
|
138 |
+
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
|
139 |
+
trademarks or logos is subject to and must follow
|
140 |
+
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
|
141 |
+
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
|
142 |
+
Any use of third-party trademarks or logos are subject to those third-party's policies.
|
app.py
CHANGED
@@ -1,7 +1,148 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import os
|
3 |
+
import spaces
|
4 |
+
from transformers import GemmaTokenizer, AutoModelForCausalLM
|
5 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
6 |
+
from threading import Thread
|
7 |
+
from minference import MInference
|
8 |
|
9 |
+
# Read the Hugging Face access token from the environment (None if it is not set)
|
10 |
+
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
11 |
|
12 |
+
|
13 |
+
DESCRIPTION = '''
|
14 |
+
<div>
|
15 |
+
<h1 style="text-align: center;">Meta Llama3 8B</h1>
|
16 |
+
<p>This Space demonstrates the instruction-tuned model <a href="https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct"><b>Meta Llama3 8b Chat</b></a>. Meta Llama3 is the new open LLM and comes in two sizes: 8b and 70b. Feel free to play with it, or duplicate to run privately!</p>
|
17 |
+
<p>🔎 For more details about the Llama3 release and how to use the model with <code>transformers</code>, take a look <a href="https://huggingface.co/blog/llama3">at our blog post</a>.</p>
|
18 |
+
<p>🦕 Looking for an even more powerful model? Check out the <a href="https://huggingface.co/chat/"><b>Hugging Chat</b></a> integration for Meta Llama 3 70b</p>
|
19 |
+
</div>
|
20 |
+
'''
|
21 |
+
|
22 |
+
LICENSE = """
|
23 |
+
<p/>
|
24 |
+
---
|
25 |
+
Built with Meta Llama 3
|
26 |
+
"""
|
27 |
+
|
28 |
+
PLACEHOLDER = """
|
29 |
+
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
|
30 |
+
<img src="https://ysharma-dummy-chat-app.hf.space/file=/tmp/gradio/8e75e61cc9bab22b7ce3dec85ab0e6db1da5d107/Meta_lockup_positive%20primary_RGB.jpg" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55; ">
|
31 |
+
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">Meta llama3</h1>
|
32 |
+
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
|
33 |
+
</div>
|
34 |
+
"""
|
35 |
+
|
36 |
+
|
37 |
+
css = """
|
38 |
+
h1 {
|
39 |
+
text-align: center;
|
40 |
+
display: block;
|
41 |
+
}
|
42 |
+
#duplicate-button {
|
43 |
+
margin: auto;
|
44 |
+
color: white;
|
45 |
+
background: #1565c0;
|
46 |
+
border-radius: 100vh;
|
47 |
+
}
|
48 |
+
"""
|
49 |
+
|
50 |
+
# Load the tokenizer and model
|
51 |
+
model_name = "gradientai/Llama-3-8B-Instruct-262k"
|
52 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
53 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto") # to("cuda:0")
|
54 |
+
minference_patch = MInference("minference", model_name)
|
55 |
+
model = minference_patch(model)
|
56 |
+
|
57 |
+
terminators = [
|
58 |
+
tokenizer.eos_token_id,
|
59 |
+
tokenizer.convert_tokens_to_ids("<|eot_id|>")
|
60 |
+
]
|
61 |
+
|
62 |
+
@spaces.GPU(duration=120)
|
63 |
+
def chat_llama3_8b(message: str,
|
64 |
+
history: list,
|
65 |
+
temperature: float,
|
66 |
+
max_new_tokens: int
|
67 |
+
) -> str:
|
68 |
+
"""
|
69 |
+
Generate a streaming response using the llama3-8b model.
|
70 |
+
Args:
|
71 |
+
message (str): The input message.
|
72 |
+
history (list): The conversation history used by ChatInterface.
|
73 |
+
temperature (float): The temperature for generating the response.
|
74 |
+
max_new_tokens (int): The maximum number of new tokens to generate.
|
75 |
+
Returns:
|
76 |
+
str: The generated response.
|
77 |
+
"""
|
78 |
+
conversation = []
|
79 |
+
for user, assistant in history:
|
80 |
+
conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
|
81 |
+
conversation.append({"role": "user", "content": message})
|
82 |
+
|
83 |
+
input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
|
84 |
+
|
85 |
+
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
|
86 |
+
|
87 |
+
generate_kwargs = dict(
|
88 |
+
input_ids= input_ids,
|
89 |
+
streamer=streamer,
|
90 |
+
max_new_tokens=max_new_tokens,
|
91 |
+
do_sample=True,
|
92 |
+
temperature=temperature,
|
93 |
+
eos_token_id=terminators,
|
94 |
+
)
|
95 |
+
# This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
|
96 |
+
if temperature == 0:
|
97 |
+
generate_kwargs['do_sample'] = False
|
98 |
+
|
99 |
+
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
100 |
+
t.start()
|
101 |
+
|
102 |
+
outputs = []
|
103 |
+
for text in streamer:
|
104 |
+
outputs.append(text)
|
105 |
+
#print(outputs)
|
106 |
+
yield "".join(outputs)
|
107 |
+
|
108 |
+
|
109 |
+
# Gradio block
|
110 |
+
chatbot=gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')
|
111 |
+
|
112 |
+
with gr.Blocks(fill_height=True, css=css) as demo:
|
113 |
+
|
114 |
+
gr.Markdown(DESCRIPTION)
|
115 |
+
gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
|
116 |
+
gr.ChatInterface(
|
117 |
+
fn=chat_llama3_8b,
|
118 |
+
chatbot=chatbot,
|
119 |
+
fill_height=True,
|
120 |
+
additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
|
121 |
+
additional_inputs=[
|
122 |
+
gr.Slider(minimum=0,
|
123 |
+
maximum=1,
|
124 |
+
step=0.1,
|
125 |
+
value=0.95,
|
126 |
+
label="Temperature",
|
127 |
+
render=False),
|
128 |
+
gr.Slider(minimum=128,
|
129 |
+
maximum=4096,
|
130 |
+
step=1,
|
131 |
+
value=512,
|
132 |
+
label="Max new tokens",
|
133 |
+
render=False ),
|
134 |
+
],
|
135 |
+
examples=[
|
136 |
+
['How to setup a human base on Mars? Give short answer.'],
|
137 |
+
['Explain theory of relativity to me like I’m 8 years old.'],
|
138 |
+
['What is 9,000 * 9,000?'],
|
139 |
+
['Write a pun-filled happy birthday message to my friend Alex.'],
|
140 |
+
['Justify why a penguin might make a good king of the jungle.']
|
141 |
+
],
|
142 |
+
cache_examples=False,
|
143 |
+
)
|
144 |
+
|
145 |
+
gr.Markdown(LICENSE)
|
146 |
+
|
147 |
+
if __name__ == "__main__":
|
148 |
+
demo.launch()
|
images/MInference1_onepage.png
ADDED
images/MInference_logo.png
ADDED
images/benchmarks/needle_viz_LLaMA-3-8B-1M_ours_1K_1000K.png
ADDED
images/benchmarks/ppl-LLaMA-3-262k.png
ADDED
minference/__init__.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Microsoft
|
2 |
+
# Licensed under The MIT License [see LICENSE for details]
|
3 |
+
# flake8: noqa
|
4 |
+
from .minference_configuration import MInferenceConfig
|
5 |
+
from .models_patch import MInference
|
6 |
+
from .ops.block_sparse_flash_attention import block_sparse_attention
|
7 |
+
from .ops.pit_sparse_flash_attention_v2 import vertical_slash_sparse_attention
|
8 |
+
from .ops.streaming_kernel import streaming_forward
|
9 |
+
from .patch import (
|
10 |
+
minference_patch,
|
11 |
+
minference_patch_kv_cache_cpu,
|
12 |
+
minference_patch_with_snapkv,
|
13 |
+
patch_hf,
|
14 |
+
)
|
15 |
+
from .version import VERSION as __version__
|
16 |
+
|
17 |
+
__all__ = [
|
18 |
+
"MInference",
|
19 |
+
"MInferenceConfig",
|
20 |
+
"minference_patch",
|
21 |
+
"minference_patch_kv_cache_cpu",
|
22 |
+
"minference_patch_with_snapkv",
|
23 |
+
"patch_hf",
|
24 |
+
"vertical_slash_sparse_attention",
|
25 |
+
"block_sparse_attention",
|
26 |
+
"streaming_forward",
|
27 |
+
]
|
minference/configs/Llama_3_8B_Instruct_262k_kv_out_v32_fit_o_best_pattern.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
[{"0": ["vertical_and_slash", 1000, 6096, 336], "1": ["vertical_and_slash", 1000, 6096, 26473], "2": ["vertical_and_slash", 1000, 6096, 0], "3": ["vertical_and_slash", 1000, 6096, 26958], "4": ["vertical_and_slash", 1000, 6096, 18905], "5": ["vertical_and_slash", 1000, 6096, 27990], "6": ["vertical_and_slash", 1000, 6096, 15162], "7": ["vertical_and_slash", 1000, 6096, 10529], "8": ["vertical_and_slash", 1000, 6096, 2], "9": ["vertical_and_slash", 1000, 6096, 828], "10": ["vertical_and_slash", 1000, 6096, 11405], "11": ["vertical_and_slash", 1000, 6096, 0], "12": ["vertical_and_slash", 1000, 6096, 55], "13": ["vertical_and_slash", 1000, 6096, 1], "14": ["vertical_and_slash", 1000, 6096, 0], "15": ["vertical_and_slash", 1000, 6096, 7021], "16": ["vertical_and_slash", 30, 800, 185169], "17": ["vertical_and_slash", 30, 800, 72929], "18": ["vertical_and_slash", 30, 800, 460008], "19": ["vertical_and_slash", 1000, 6096, 0], "20": ["vertical_and_slash", 1000, 6096, 71729], "21": ["vertical_and_slash", 1000, 6096, 52], "22": ["vertical_and_slash", 1000, 6096, 636], "23": ["vertical_and_slash", 1000, 6096, 75020], "24": ["vertical_and_slash", 1000, 6096, 23545], "25": ["vertical_and_slash", 1000, 6096, 90256], "26": ["vertical_and_slash", 1000, 6096, 45294], "27": ["vertical_and_slash", 1000, 6096, 32617], "28": ["vertical_and_slash", 3500, 100, 4777248], "29": ["vertical_and_slash", 3500, 100, 3996], "30": ["vertical_and_slash", 3500, 100, 590252], "31": ["vertical_and_slash", 3500, 100, 0]}, {"0": ["vertical_and_slash", 30, 800, 11048], "1": ["vertical_and_slash", 30, 800, 99768], "2": ["vertical_and_slash", 1000, 6096, 1393328], "3": ["vertical_and_slash", 30, 800, 97570], "4": ["vertical_and_slash", 30, 800, 9], "5": ["vertical_and_slash", 30, 800, 18], "6": ["vertical_and_slash", 30, 800, 216277], "7": ["vertical_and_slash", 30, 800, 148491], "8": ["vertical_and_slash", 100, 800, 543785], "9": ["vertical_and_slash", 1000, 6096, 2343829], "10": ["vertical_and_slash", 
100, 800, 251542], "11": ["vertical_and_slash", 30, 800, 1064367], "12": ["vertical_and_slash", 1000, 6096, 6092], "13": ["vertical_and_slash", 30, 800, 12654], "14": ["vertical_and_slash", 1000, 6096, 0], "15": ["vertical_and_slash", 1000, 6096, 101], "16": ["vertical_and_slash", 30, 800, 21873], "17": ["vertical_and_slash", 30, 800, 107039], "18": ["vertical_and_slash", 30, 800, 9011], "19": ["vertical_and_slash", 30, 800, 445736], "20": ["vertical_and_slash", 30, 800, 1906], "21": ["vertical_and_slash", 30, 800, 3058], "22": ["vertical_and_slash", 1000, 6096, 430742], "23": ["vertical_and_slash", 1000, 6096, 181839], "24": ["vertical_and_slash", 30, 800, 125666], "25": ["vertical_and_slash", 30, 800, 704271], "26": ["vertical_and_slash", 30, 800, 14405], "27": ["vertical_and_slash", 30, 800, 70563], "28": ["vertical_and_slash", 1000, 6096, 38630], "29": ["vertical_and_slash", 1000, 6096, 68041], "30": ["vertical_and_slash", 30, 800, 6942], "31": ["vertical_and_slash", 1000, 6096, 35430]}, {"0": ["vertical_and_slash", 30, 800, 2720], "1": ["vertical_and_slash", 1000, 6096, 3045], "2": ["vertical_and_slash", 30, 800, 785], "3": ["vertical_and_slash", 1000, 6096, 14146], "4": ["vertical_and_slash", 100, 800, 315229], "5": ["vertical_and_slash", 1000, 6096, 195280], "6": ["vertical_and_slash", 1000, 6096, 1640055], "7": ["vertical_and_slash", 30, 800, 21026], "8": ["vertical_and_slash", 30, 800, 1082], "9": ["vertical_and_slash", 30, 800, 1851], "10": ["vertical_and_slash", 100, 800, 97766], "11": ["vertical_and_slash", 30, 800, 14401], "12": ["vertical_and_slash", 100, 800, 55741], "13": ["vertical_and_slash", 30, 800, 100674], "14": ["vertical_and_slash", 100, 800, 5597503], "15": ["vertical_and_slash", 1000, 6096, 437796], "16": ["vertical_and_slash", 30, 800, 9647], "17": ["vertical_and_slash", 30, 800, 4590], "18": ["vertical_and_slash", 30, 800, 73], "19": ["vertical_and_slash", 1000, 6096, 823400], "20": ["vertical_and_slash", 1000, 6096, 464893], "21": 
["vertical_and_slash", 1000, 6096, 406520], "22": ["vertical_and_slash", 1000, 6096, 49477], "23": ["vertical_and_slash", 30, 800, 25445], "24": ["vertical_and_slash", 30, 800, 172935], "25": ["vertical_and_slash", 30, 800, 125813], "26": ["vertical_and_slash", 30, 800, 35964], "27": ["vertical_and_slash", 30, 800, 64113], "28": ["vertical_and_slash", 30, 800, 8780], "29": ["vertical_and_slash", 30, 800, 7883], "30": ["vertical_and_slash", 30, 800, 3944], "31": ["vertical_and_slash", 30, 800, 1049]}, {"0": ["vertical_and_slash", 1000, 6096, 119045], "1": ["vertical_and_slash", 1000, 6096, 21633], "2": ["vertical_and_slash", 1000, 6096, 54], "3": ["vertical_and_slash", 1000, 6096, 756], "4": ["vertical_and_slash", 30, 800, 1524], "5": ["vertical_and_slash", 30, 800, 7576], "6": ["vertical_and_slash", 30, 800, 212024], "7": ["vertical_and_slash", 30, 800, 106253], "8": ["vertical_and_slash", 30, 800, 4801], "9": ["vertical_and_slash", 30, 800, 311445], "10": ["vertical_and_slash", 30, 800, 31540], "11": ["vertical_and_slash", 30, 800, 7706], "12": ["vertical_and_slash", 1000, 6096, 397], "13": ["vertical_and_slash", 1000, 6096, 40], "14": ["vertical_and_slash", 100, 800, 181], "15": ["vertical_and_slash", 1000, 6096, 15], "16": ["vertical_and_slash", 30, 800, 424080], "17": ["vertical_and_slash", 30, 800, 66114], "18": ["vertical_and_slash", 30, 800, 132526], "19": ["vertical_and_slash", 30, 800, 1478993], "20": ["vertical_and_slash", 1000, 6096, 655153], "21": ["vertical_and_slash", 1000, 6096, 117322], "22": ["vertical_and_slash", 1000, 6096, 572237], "23": ["vertical_and_slash", 1000, 6096, 688623], "24": ["vertical_and_slash", 1000, 6096, 294], "25": ["vertical_and_slash", 1000, 6096, 5035], "26": ["vertical_and_slash", 30, 800, 3874], "27": ["vertical_and_slash", 1000, 6096, 618117], "28": ["vertical_and_slash", 30, 800, 545357], "29": ["vertical_and_slash", 30, 800, 1746675], "30": ["vertical_and_slash", 30, 800, 612225], "31": ["vertical_and_slash", 100, 800, 
232415]}, {"0": ["vertical_and_slash", 100, 800, 5379826], "1": ["vertical_and_slash", 100, 800, 4399425], "2": ["vertical_and_slash", 100, 800, 5842], "3": ["vertical_and_slash", 30, 800, 178263], "4": ["vertical_and_slash", 30, 800, 356], "5": ["vertical_and_slash", 30, 800, 2387916], "6": ["vertical_and_slash", 1000, 6096, 216595], "7": ["vertical_and_slash", 30, 800, 466], "8": ["vertical_and_slash", 1000, 6096, 832044], "9": ["vertical_and_slash", 1000, 6096, 59709], "10": ["vertical_and_slash", 1000, 6096, 1194089], "11": ["vertical_and_slash", 1000, 6096, 356408], "12": ["vertical_and_slash", 30, 800, 30528], "13": ["vertical_and_slash", 30, 800, 22217], "14": ["vertical_and_slash", 30, 800, 9162], "15": ["vertical_and_slash", 100, 800, 1641325], "16": ["vertical_and_slash", 1000, 6096, 489936], "17": ["vertical_and_slash", 30, 800, 58107], "18": ["vertical_and_slash", 1000, 6096, 8539], "19": ["vertical_and_slash", 1000, 6096, 508038], "20": ["vertical_and_slash", 100, 800, 2632857], "21": ["vertical_and_slash", 1000, 6096, 79517], "22": ["vertical_and_slash", 30, 800, 330362], "23": ["vertical_and_slash", 1000, 6096, 85961], "24": ["vertical_and_slash", 30, 800, 23942], "25": ["vertical_and_slash", 30, 800, 75337], "26": ["vertical_and_slash", 30, 800, 3544417], "27": ["vertical_and_slash", 30, 800, 146427], "28": ["vertical_and_slash", 1000, 6096, 10561], "29": ["vertical_and_slash", 100, 800, 8759352], "30": ["vertical_and_slash", 100, 800, 8425], "31": ["vertical_and_slash", 30, 800, 22]}, {"0": ["vertical_and_slash", 30, 800, 50473], "1": ["vertical_and_slash", 1000, 6096, 277369], "2": ["vertical_and_slash", 30, 800, 59349], "3": ["vertical_and_slash", 30, 800, 27256], "4": ["vertical_and_slash", 30, 800, 112822], "5": ["vertical_and_slash", 1000, 6096, 346887], "6": ["vertical_and_slash", 1000, 6096, 84774], "7": ["vertical_and_slash", 1000, 6096, 954773], "8": ["vertical_and_slash", 1000, 6096, 1210908], "9": ["vertical_and_slash", 1000, 6096, 
1679398], "10": ["vertical_and_slash", 1000, 6096, 2474351], "11": ["vertical_and_slash", 1000, 6096, 80495], "12": ["vertical_and_slash", 30, 800, 56761], "13": ["vertical_and_slash", 30, 800, 27757], "14": ["vertical_and_slash", 30, 800, 8811], "15": ["vertical_and_slash", 30, 800, 31547], "16": ["vertical_and_slash", 100, 800, 93167], "17": ["vertical_and_slash", 1000, 6096, 1464896], "18": ["vertical_and_slash", 1000, 6096, 434459], "19": ["vertical_and_slash", 30, 800, 1654521], "20": ["vertical_and_slash", 1000, 6096, 414], "21": ["vertical_and_slash", 1000, 6096, 76207], "22": ["vertical_and_slash", 1000, 6096, 8583], "23": ["vertical_and_slash", 1000, 6096, 1471], "24": ["vertical_and_slash", 1000, 6096, 231656], "25": ["vertical_and_slash", 500, 700, 95889], "26": ["vertical_and_slash", 30, 800, 62035], "27": ["vertical_and_slash", 1000, 6096, 43859], "28": ["vertical_and_slash", 30, 800, 23458], "29": ["vertical_and_slash", 30, 800, 53092], "30": ["vertical_and_slash", 30, 800, 74240], "31": ["vertical_and_slash", 30, 800, 45214]}, {"0": ["vertical_and_slash", 30, 800, 507], "1": ["vertical_and_slash", 100, 800, 8490], "2": ["vertical_and_slash", 100, 800, 3952118], "3": ["vertical_and_slash", 100, 800, 2475164], "4": ["vertical_and_slash", 100, 800, 8038], "5": ["vertical_and_slash", 30, 800, 2620494], "6": ["vertical_and_slash", 1000, 6096, 57306], "7": ["vertical_and_slash", 30, 800, 18889], "8": ["vertical_and_slash", 30, 800, 14900], "9": ["vertical_and_slash", 30, 800, 310453], "10": ["vertical_and_slash", 30, 800, 5494], "11": ["vertical_and_slash", 30, 800, 16096], "12": ["vertical_and_slash", 30, 800, 45897], "13": ["vertical_and_slash", 30, 800, 120295], "14": ["vertical_and_slash", 30, 800, 1446587], "15": ["vertical_and_slash", 30, 800, 133562], "16": ["vertical_and_slash", 30, 800, 81561], "17": ["vertical_and_slash", 100, 800, 1091558], "18": ["vertical_and_slash", 30, 800, 1104027], "19": ["vertical_and_slash", 30, 800, 95228], "20": 
["vertical_and_slash", 1000, 6096, 81766], "21": ["vertical_and_slash", 1000, 6096, 1604474], "22": ["vertical_and_slash", 30, 800, 1720847], "23": ["vertical_and_slash", 30, 800, 254367], "24": ["vertical_and_slash", 1000, 6096, 69837], "25": ["vertical_and_slash", 1000, 6096, 1346498], "26": ["vertical_and_slash", 1000, 6096, 251707], "27": ["vertical_and_slash", 1000, 6096, 21055], "28": ["vertical_and_slash", 100, 800, 1310349], "29": ["vertical_and_slash", 1000, 6096, 523], "30": ["vertical_and_slash", 100, 800, 5], "31": ["vertical_and_slash", 1000, 6096, 4114]}, {"0": ["vertical_and_slash", 30, 800, 2076100], "1": ["vertical_and_slash", 30, 800, 742482], "2": ["vertical_and_slash", 30, 800, 84396], "3": ["vertical_and_slash", 100, 800, 6621015], "4": ["vertical_and_slash", 30, 800, 269671], "5": ["vertical_and_slash", 30, 800, 142041], "6": ["vertical_and_slash", 1000, 6096, 2493869], "7": ["vertical_and_slash", 1000, 6096, 2460341], "8": ["vertical_and_slash", 30, 800, 352690], "9": ["vertical_and_slash", 30, 800, 134441], "10": ["vertical_and_slash", 1000, 6096, 112278], "11": ["vertical_and_slash", 30, 800, 62933], "12": ["vertical_and_slash", 30, 800, 150459], "13": ["vertical_and_slash", 1000, 6096, 120036], "14": ["vertical_and_slash", 100, 800, 433238], "15": ["vertical_and_slash", 100, 800, 2723047], "16": ["vertical_and_slash", 1000, 6096, 112925], "17": ["vertical_and_slash", 1000, 6096, 23380], "18": ["vertical_and_slash", 1000, 6096, 92620], "19": ["vertical_and_slash", 1000, 6096, 37993], "20": ["vertical_and_slash", 100, 800, 74928], "21": ["vertical_and_slash", 3500, 100, 14191655], "22": ["vertical_and_slash", 1000, 6096, 514675], "23": ["vertical_and_slash", 100, 800, 9577073], "24": ["vertical_and_slash", 100, 800, 531136], "25": ["vertical_and_slash", 1000, 6096, 30007], "26": ["vertical_and_slash", 1000, 6096, 170687], "27": ["vertical_and_slash", 30, 800, 540287], "28": ["vertical_and_slash", 30, 800, 1435852], "29": 
["vertical_and_slash", 30, 800, 948060], "30": ["vertical_and_slash", 1000, 6096, 37219], "31": ["vertical_and_slash", 1000, 6096, 211641]}, {"0": ["vertical_and_slash", 1000, 6096, 582795], "1": ["vertical_and_slash", 1000, 6096, 6289238], "2": ["vertical_and_slash", 1000, 6096, 570805], "3": ["vertical_and_slash", 1000, 6096, 198493], "4": ["vertical_and_slash", 30, 800, 112215], "5": ["vertical_and_slash", 30, 800, 5387246], "6": ["vertical_and_slash", 30, 800, 754350], "7": ["vertical_and_slash", 1000, 6096, 164737], "8": ["vertical_and_slash", 1000, 6096, 8597099], "9": ["vertical_and_slash", 1000, 6096, 13891466], "10": ["vertical_and_slash", 100, 800, 12184646], "11": ["vertical_and_slash", 1000, 6096, 3397834], "12": ["vertical_and_slash", 1000, 6096, 274297], "13": ["vertical_and_slash", 30, 800, 505818], "14": ["vertical_and_slash", 1000, 6096, 382749], "15": ["vertical_and_slash", 1000, 6096, 53485], "16": ["vertical_and_slash", 1000, 6096, 63748], "17": ["vertical_and_slash", 1000, 6096, 743437], "18": ["vertical_and_slash", 1000, 6096, 884226], "19": ["vertical_and_slash", 1000, 6096, 32754], "20": ["vertical_and_slash", 30, 800, 154807], "21": ["vertical_and_slash", 30, 800, 515833], "22": ["vertical_and_slash", 30, 800, 379827], "23": ["vertical_and_slash", 30, 800, 5140670], "24": ["vertical_and_slash", 1000, 6096, 8857], "25": ["vertical_and_slash", 1000, 6096, 9739], "26": ["vertical_and_slash", 1000, 6096, 3362559], "27": ["vertical_and_slash", 1000, 6096, 3602170], "28": ["vertical_and_slash", 1000, 6096, 286758], "29": ["vertical_and_slash", 1000, 6096, 1091568], "30": ["vertical_and_slash", 1000, 6096, 464410], "31": ["vertical_and_slash", 1000, 6096, 9113238]}, {"0": ["vertical_and_slash", 1000, 6096, 4112309], "1": ["vertical_and_slash", 1000, 6096, 6237157], "2": ["vertical_and_slash", 1000, 6096, 12411496], "3": ["vertical_and_slash", 1000, 6096, 3333545], "4": ["vertical_and_slash", 1000, 6096, 1082199], "5": ["vertical_and_slash", 1000, 
6096, 3624535], "6": ["vertical_and_slash", 1000, 6096, 85587], "7": ["vertical_and_slash", 1000, 6096, 5060732], "8": ["vertical_and_slash", 30, 800, 981020], "9": ["vertical_and_slash", 30, 800, 647089], "10": ["vertical_and_slash", 30, 800, 1168497], "11": ["vertical_and_slash", 30, 800, 241811], "12": ["vertical_and_slash", 1000, 6096, 14258787], "13": ["vertical_and_slash", 1000, 6096, 13881708], "14": ["vertical_and_slash", 100, 800, 9807781], "15": ["vertical_and_slash", 1000, 6096, 11824390], "16": ["vertical_and_slash", 1000, 6096, 382173], "17": ["vertical_and_slash", 1000, 6096, 682553], "18": ["vertical_and_slash", 1000, 6096, 228115], "19": ["vertical_and_slash", 1000, 6096, 730935], "20": ["vertical_and_slash", 1000, 6096, 10237660], "21": ["vertical_and_slash", 1000, 6096, 210229], "22": ["vertical_and_slash", 1000, 6096, 4883397], "23": ["vertical_and_slash", 1000, 6096, 569329], "24": ["vertical_and_slash", 100, 800, 4152], "25": ["vertical_and_slash", 1000, 6096, 235235], "26": ["vertical_and_slash", 100, 800, 22473], "27": ["vertical_and_slash", 3500, 100, 14276508], "28": ["vertical_and_slash", 1000, 6096, 2277550], "29": ["vertical_and_slash", 1000, 6096, 1821096], "30": ["vertical_and_slash", 30, 800, 1212061], "31": ["vertical_and_slash", 1000, 6096, 13192107]}, {"0": ["vertical_and_slash", 1000, 6096, 812453], "1": ["vertical_and_slash", 1000, 6096, 6634405], "2": ["vertical_and_slash", 1000, 6096, 6896128], "3": ["vertical_and_slash", 1000, 6096, 12539813], "4": ["vertical_and_slash", 1000, 6096, 90867], "5": ["vertical_and_slash", 1000, 6096, 592412], "6": ["vertical_and_slash", 1000, 6096, 1863965], "7": ["vertical_and_slash", 1000, 6096, 1412714], "8": ["vertical_and_slash", 100, 800, 4723238], "9": ["vertical_and_slash", 30, 800, 73268], "10": ["vertical_and_slash", 1000, 6096, 522198], "11": ["vertical_and_slash", 30, 800, 144456], "12": ["vertical_and_slash", 1000, 6096, 218571], "13": ["vertical_and_slash", 1000, 6096, 4766244], 
"14": ["vertical_and_slash", 1000, 6096, 519409], "15": ["vertical_and_slash", 100, 800, 257427], "16": ["vertical_and_slash", 30, 800, 913307], "17": ["vertical_and_slash", 1000, 6096, 272105], "18": ["vertical_and_slash", 1000, 6096, 10253560], "19": ["vertical_and_slash", 1000, 6096, 103219], "20": ["vertical_and_slash", 1000, 6096, 825917], "21": ["vertical_and_slash", 1000, 6096, 1573906], "22": ["vertical_and_slash", 1000, 6096, 1401963], "23": ["vertical_and_slash", 1000, 6096, 903562], "24": ["vertical_and_slash", 1000, 6096, 116448], "25": ["vertical_and_slash", 500, 700, 10497021], "26": ["vertical_and_slash", 1000, 6096, 1451038], "27": ["vertical_and_slash", 100, 800, 9129837], "28": ["vertical_and_slash", 1000, 6096, 6069558], "29": ["vertical_and_slash", 100, 800, 4906900], "30": ["vertical_and_slash", 100, 800, 1935350], "31": ["vertical_and_slash", 1000, 6096, 13438131]}, {"0": ["vertical_and_slash", 1000, 6096, 200475], "1": ["vertical_and_slash", 1000, 6096, 2525357], "2": ["vertical_and_slash", 1000, 6096, 1581552], "3": ["vertical_and_slash", 1000, 6096, 1585962], "4": ["vertical_and_slash", 100, 800, 2468769], "5": ["vertical_and_slash", 1000, 6096, 2284149], "6": ["vertical_and_slash", 1000, 6096, 3954975], "7": ["vertical_and_slash", 1000, 6096, 12242517], "8": ["vertical_and_slash", 1000, 6096, 407981], "9": ["vertical_and_slash", 1000, 6096, 387918], "10": ["vertical_and_slash", 30, 800, 494970], "11": ["vertical_and_slash", 1000, 6096, 237593], "12": ["vertical_and_slash", 1000, 6096, 13227100], "13": ["vertical_and_slash", 1000, 6096, 7150283], "14": ["vertical_and_slash", 1000, 6096, 1460829], "15": ["vertical_and_slash", 1000, 6096, 5830515], "16": ["vertical_and_slash", 30, 800, 321990], "17": ["vertical_and_slash", 500, 700, 412885], "18": ["vertical_and_slash", 30, 800, 7754087], "19": ["vertical_and_slash", 30, 800, 593222], "20": ["vertical_and_slash", 1000, 6096, 9430066], "21": ["vertical_and_slash", 1000, 6096, 11445545], "22": 
["vertical_and_slash", 1000, 6096, 10096832], "23": ["vertical_and_slash", 1000, 6096, 11108827], "24": ["vertical_and_slash", 1000, 6096, 2040566], "25": ["vertical_and_slash", 1000, 6096, 1293645], "26": ["vertical_and_slash", 1000, 6096, 1681146], "27": ["vertical_and_slash", 1000, 6096, 1621078], "28": ["vertical_and_slash", 3500, 100, 14482863], "29": ["vertical_and_slash", 3500, 100, 14306340], "30": ["vertical_and_slash", 3500, 100, 14736032], "31": ["vertical_and_slash", 30, 800, 59474]}, {"0": ["vertical_and_slash", 30, 800, 2015977], "1": ["vertical_and_slash", 1000, 6096, 1851908], "2": ["vertical_and_slash", 500, 700, 3019045], "3": ["vertical_and_slash", 30, 800, 2275137], "4": ["vertical_and_slash", 1000, 6096, 111007], "5": ["vertical_and_slash", 1000, 6096, 74876], "6": ["vertical_and_slash", 1000, 6096, 291657], "7": ["vertical_and_slash", 1000, 6096, 72059], "8": ["vertical_and_slash", 100, 800, 4966732], "9": ["vertical_and_slash", 30, 800, 1227926], "10": ["vertical_and_slash", 1000, 6096, 817635], "11": ["vertical_and_slash", 100, 800, 1996081], "12": ["vertical_and_slash", 30, 800, 320794], "13": ["vertical_and_slash", 30, 800, 641018], "14": ["vertical_and_slash", 1000, 6096, 784584], "15": ["vertical_and_slash", 500, 700, 615730], "16": ["vertical_and_slash", 30, 800, 130637], "17": ["vertical_and_slash", 500, 700, 237719], "18": ["vertical_and_slash", 30, 800, 484009], "19": ["vertical_and_slash", 30, 800, 71667], "20": ["vertical_and_slash", 30, 800, 6034932], "21": ["vertical_and_slash", 30, 800, 279606], "22": ["vertical_and_slash", 30, 800, 273046], "23": ["vertical_and_slash", 500, 700, 5343396], "24": ["vertical_and_slash", 30, 800, 424419], "25": ["vertical_and_slash", 30, 800, 268585], "26": ["vertical_and_slash", 500, 700, 469509], "27": ["vertical_and_slash", 30, 800, 1150183], "28": ["vertical_and_slash", 30, 800, 567665], "29": ["vertical_and_slash", 30, 800, 689969], "30": ["vertical_and_slash", 30, 800, 3124447], "31": 
["vertical_and_slash", 500, 700, 1311816]}, {"0": ["vertical_and_slash", 1000, 6096, 13054849], "1": ["vertical_and_slash", 1000, 6096, 11676492], "2": ["vertical_and_slash", 1000, 6096, 13662962], "3": ["vertical_and_slash", 1000, 6096, 13009510], "4": ["vertical_and_slash", 1000, 6096, 13228770], "5": ["vertical_and_slash", 1000, 6096, 13738897], "6": ["vertical_and_slash", 1000, 6096, 4327684], "7": ["vertical_and_slash", 100, 800, 1780647], "8": ["vertical_and_slash", 1000, 6096, 12984525], "9": ["vertical_and_slash", 1000, 6096, 10106452], "10": ["vertical_and_slash", 1000, 6096, 13121645], "11": ["vertical_and_slash", 1000, 6096, 7143877], "12": ["vertical_and_slash", 1000, 6096, 1302273], "13": ["vertical_and_slash", 1000, 6096, 12189960], "14": ["vertical_and_slash", 1000, 6096, 10369892], "15": ["vertical_and_slash", 1000, 6096, 6251432], "16": ["vertical_and_slash", 1000, 6096, 13767358], "17": ["vertical_and_slash", 1000, 6096, 14264179], "18": ["vertical_and_slash", 1000, 6096, 14027354], "19": ["vertical_and_slash", 1000, 6096, 12810299], "20": ["vertical_and_slash", 1000, 6096, 11500719], "21": ["vertical_and_slash", 1000, 6096, 8729013], "22": ["vertical_and_slash", 100, 800, 1386474], "23": ["vertical_and_slash", 1000, 6096, 8809015], "24": ["vertical_and_slash", 30, 800, 1192385], "25": ["vertical_and_slash", 100, 800, 6597145], "26": ["vertical_and_slash", 100, 800, 11801029], "27": ["vertical_and_slash", 1000, 6096, 981847], "28": ["vertical_and_slash", 1000, 6096, 3790181], "29": ["vertical_and_slash", 30, 800, 1641474], "30": ["vertical_and_slash", 1000, 6096, 4214917], "31": ["vertical_and_slash", 1000, 6096, 3423871]}, {"0": ["vertical_and_slash", 1000, 6096, 7281028], "1": ["vertical_and_slash", 1000, 6096, 6327889], "2": ["vertical_and_slash", 1000, 6096, 5161807], "3": ["vertical_and_slash", 1000, 6096, 6944365], "4": ["vertical_and_slash", 1000, 6096, 10798408], "5": ["vertical_and_slash", 1000, 6096, 11848526], "6": 
["vertical_and_slash", 1000, 6096, 5023703], "7": ["vertical_and_slash", 1000, 6096, 6869756], "8": ["vertical_and_slash", 30, 800, 2070673], "9": ["vertical_and_slash", 30, 800, 2108039], "10": ["vertical_and_slash", 30, 800, 2478923], "11": ["vertical_and_slash", 30, 800, 1062019], "12": ["vertical_and_slash", 1000, 6096, 10483422], "13": ["vertical_and_slash", 1000, 6096, 13220734], "14": ["vertical_and_slash", 1000, 6096, 10864461], "15": ["vertical_and_slash", 1000, 6096, 10380263], "16": ["vertical_and_slash", 1000, 6096, 12606664], "17": ["vertical_and_slash", 1000, 6096, 12755695], "18": ["vertical_and_slash", 1000, 6096, 14481440], "19": ["vertical_and_slash", 1000, 6096, 12125755], "20": ["vertical_and_slash", 1000, 6096, 13727938], "21": ["vertical_and_slash", 100, 800, 9986525], "22": ["vertical_and_slash", 1000, 6096, 13802294], "23": ["vertical_and_slash", 1000, 6096, 8589854], "24": ["vertical_and_slash", 1000, 6096, 8696624], "25": ["vertical_and_slash", 1000, 6096, 6711141], "26": ["vertical_and_slash", 30, 800, 11407], "27": ["vertical_and_slash", 1000, 6096, 10286733], "28": ["vertical_and_slash", 100, 800, 14346519], "29": ["vertical_and_slash", 3500, 100, 14822370], "30": ["vertical_and_slash", 1000, 6096, 13996996], "31": ["vertical_and_slash", 3500, 100, 13837843]}, {"0": ["vertical_and_slash", 30, 800, 187826], "1": ["vertical_and_slash", 1000, 6096, 319682], "2": ["vertical_and_slash", 1000, 6096, 717971], "3": ["vertical_and_slash", 1000, 6096, 12248225], "4": ["vertical_and_slash", 30, 800, 2311494], "5": ["vertical_and_slash", 1000, 6096, 354949], "6": ["vertical_and_slash", 30, 800, 2723442], "7": ["vertical_and_slash", 30, 800, 217627], "8": ["vertical_and_slash", 500, 700, 1800505], "9": ["vertical_and_slash", 30, 800, 5395314], "10": ["vertical_and_slash", 30, 800, 10715415], "11": ["vertical_and_slash", 100, 800, 13267898], "12": ["vertical_and_slash", 30, 800, 282819], "13": ["vertical_and_slash", 1000, 6096, 8417130], "14": 
["vertical_and_slash", 1000, 6096, 5380564], "15": ["vertical_and_slash", 1000, 6096, 9802765], "16": ["vertical_and_slash", 1000, 6096, 385044], "17": ["vertical_and_slash", 1000, 6096, 2048601], "18": ["vertical_and_slash", 1000, 6096, 2798283], "19": ["vertical_and_slash", 100, 800, 11985153], "20": ["vertical_and_slash", 1000, 6096, 9560488], "21": ["vertical_and_slash", 1000, 6096, 8719957], "22": ["vertical_and_slash", 1000, 6096, 10883722], "23": ["vertical_and_slash", 1000, 6096, 11184293], "24": ["vertical_and_slash", 1000, 6096, 5049287], "25": ["vertical_and_slash", 1000, 6096, 6119952], "26": ["vertical_and_slash", 1000, 6096, 11948638], "27": ["vertical_and_slash", 1000, 6096, 4654529], "28": ["vertical_and_slash", 1000, 6096, 269543], "29": ["vertical_and_slash", 1000, 6096, 1183543], "30": ["vertical_and_slash", 1000, 6096, 4018748], "31": ["vertical_and_slash", 30, 800, 208750]}, {"0": ["vertical_and_slash", 3500, 100, 14712977], "1": ["vertical_and_slash", 1000, 6096, 7977346], "2": ["vertical_and_slash", 100, 800, 12022826], "3": ["vertical_and_slash", 100, 800, 7525648], "4": ["vertical_and_slash", 500, 700, 627445], "5": ["vertical_and_slash", 1000, 6096, 1067661], "6": ["vertical_and_slash", 500, 700, 199111], "7": ["vertical_and_slash", 100, 800, 1462908], "8": ["vertical_and_slash", 1000, 6096, 12608289], "9": ["vertical_and_slash", 1000, 6096, 3815760], "10": ["vertical_and_slash", 100, 800, 5050623], "11": ["vertical_and_slash", 3500, 100, 6790875], "12": ["vertical_and_slash", 30, 800, 284918], "13": ["vertical_and_slash", 500, 700, 277887], "14": ["vertical_and_slash", 500, 700, 236664], "15": ["vertical_and_slash", 30, 800, 3582148], "16": ["vertical_and_slash", 100, 800, 13373963], "17": ["vertical_and_slash", 100, 800, 682950], "18": ["vertical_and_slash", 1000, 6096, 7136486], "19": ["vertical_and_slash", 1000, 6096, 13769505], "20": ["vertical_and_slash", 1000, 6096, 9883913], "21": ["vertical_and_slash", 1000, 6096, 10833503], "22": 
["vertical_and_slash", 30, 800, 62940], "23": ["vertical_and_slash", 1000, 6096, 4652762], "24": ["vertical_and_slash", 1000, 6096, 5480379], "25": ["vertical_and_slash", 3500, 100, 14131887], "26": ["vertical_and_slash", 100, 800, 9221283], "27": ["vertical_and_slash", 1000, 6096, 4197162], "28": ["vertical_and_slash", 30, 800, 4438611], "29": ["vertical_and_slash", 30, 800, 354648], "30": ["vertical_and_slash", 30, 800, 7285775], "31": ["vertical_and_slash", 30, 800, 4392079]}, {"0": ["vertical_and_slash", 1000, 6096, 2131686], "1": ["vertical_and_slash", 1000, 6096, 3609919], "2": ["vertical_and_slash", 1000, 6096, 899481], "3": ["vertical_and_slash", 100, 800, 3219776], "4": ["vertical_and_slash", 3500, 100, 11460535], "5": ["vertical_and_slash", 1000, 6096, 154336], "6": ["vertical_and_slash", 3500, 100, 14438950], "7": ["vertical_and_slash", 100, 800, 6652113], "8": ["vertical_and_slash", 100, 800, 9133667], "9": ["vertical_and_slash", 100, 800, 8048731], "10": ["vertical_and_slash", 1000, 6096, 528931], "11": ["vertical_and_slash", 30, 800, 2635938], "12": ["vertical_and_slash", 30, 800, 8546455], "13": ["vertical_and_slash", 500, 700, 7229697], "14": ["vertical_and_slash", 1000, 6096, 32195], "15": ["vertical_and_slash", 1000, 6096, 230534], "16": ["vertical_and_slash", 100, 800, 2475909], "17": ["vertical_and_slash", 30, 800, 2484470], "18": ["vertical_and_slash", 100, 800, 8168145], "19": ["vertical_and_slash", 3500, 100, 6348588], "20": ["vertical_and_slash", 500, 700, 290337], "21": ["vertical_and_slash", 3500, 100, 12830116], "22": ["vertical_and_slash", 100, 800, 11406972], "23": ["vertical_and_slash", 1000, 6096, 9663426], "24": ["vertical_and_slash", 3500, 100, 14333500], "25": ["vertical_and_slash", 3500, 100, 14787732], "26": ["vertical_and_slash", 1000, 6096, 13209856], "27": ["vertical_and_slash", 100, 800, 14623240], "28": ["vertical_and_slash", 1000, 6096, 6321698], "29": ["vertical_and_slash", 1000, 6096, 10324255], "30": 
["vertical_and_slash", 100, 800, 1338], "31": ["vertical_and_slash", 1000, 6096, 5182275]}, {"0": ["vertical_and_slash", 100, 800, 2653574], "1": ["vertical_and_slash", 1000, 6096, 156404], "2": ["vertical_and_slash", 1000, 6096, 3288754], "3": ["vertical_and_slash", 1000, 6096, 597358], "4": ["vertical_and_slash", 1000, 6096, 13162000], "5": ["vertical_and_slash", 100, 800, 3304599], "6": ["vertical_and_slash", 100, 800, 2334228], "7": ["vertical_and_slash", 30, 800, 151547], "8": ["vertical_and_slash", 1000, 6096, 8084555], "9": ["vertical_and_slash", 1000, 6096, 6986695], "10": ["vertical_and_slash", 30, 800, 1349542], "11": ["vertical_and_slash", 1000, 6096, 62139], "12": ["vertical_and_slash", 500, 700, 586215], "13": ["vertical_and_slash", 30, 800, 3339401], "14": ["vertical_and_slash", 500, 700, 9080591], "15": ["vertical_and_slash", 100, 800, 1860621], "16": ["vertical_and_slash", 1000, 6096, 11577402], "17": ["vertical_and_slash", 1000, 6096, 6483036], "18": ["vertical_and_slash", 1000, 6096, 10223119], "19": ["vertical_and_slash", 1000, 6096, 2516899], "20": ["vertical_and_slash", 100, 800, 14689692], "21": ["vertical_and_slash", 1000, 6096, 9574317], "22": ["vertical_and_slash", 1000, 6096, 14315469], "23": ["vertical_and_slash", 1000, 6096, 11084722], "24": ["vertical_and_slash", 30, 800, 5714332], "25": ["vertical_and_slash", 30, 800, 440501], "26": ["vertical_and_slash", 30, 800, 135011], "27": ["vertical_and_slash", 100, 800, 1143711], "28": ["vertical_and_slash", 1000, 6096, 10833817], "29": ["vertical_and_slash", 100, 800, 9389405], "30": ["vertical_and_slash", 1000, 6096, 7182171], "31": ["vertical_and_slash", 1000, 6096, 3116752]}, {"0": ["vertical_and_slash", 1000, 6096, 2272762], "1": ["vertical_and_slash", 100, 800, 9251901], "2": ["vertical_and_slash", 1000, 6096, 3172792], "3": ["vertical_and_slash", 1000, 6096, 11166637], "4": ["vertical_and_slash", 1000, 6096, 267179], "5": ["vertical_and_slash", 100, 800, 1956945], "6": 
["vertical_and_slash", 1000, 6096, 431457], "7": ["vertical_and_slash", 100, 800, 215074], "8": ["vertical_and_slash", 30, 800, 160167], "9": ["vertical_and_slash", 1000, 6096, 13251530], "10": ["vertical_and_slash", 100, 800, 1045212], "11": ["vertical_and_slash", 1000, 6096, 7767754], "12": ["vertical_and_slash", 100, 800, 8430862], "13": ["vertical_and_slash", 100, 800, 12275346], "14": ["vertical_and_slash", 1000, 6096, 12967454], "15": ["vertical_and_slash", 1000, 6096, 776792], "16": ["vertical_and_slash", 30, 800, 4940981], "17": ["vertical_and_slash", 1000, 6096, 4687476], "18": ["vertical_and_slash", 30, 800, 3396568], "19": ["vertical_and_slash", 1000, 6096, 6330177], "20": ["vertical_and_slash", 100, 800, 10772100], "21": ["vertical_and_slash", 1000, 6096, 431927], "22": ["vertical_and_slash", 100, 800, 5368777], "23": ["vertical_and_slash", 100, 800, 11971880], "24": ["vertical_and_slash", 1000, 6096, 3355141], "25": ["vertical_and_slash", 30, 800, 7775685], "26": ["vertical_and_slash", 1000, 6096, 17862], "27": ["vertical_and_slash", 1000, 6096, 2368170], "28": ["vertical_and_slash", 1000, 6096, 887652], "29": ["vertical_and_slash", 1000, 6096, 342019], "30": ["vertical_and_slash", 1000, 6096, 2031], "31": ["vertical_and_slash", 100, 800, 851845]}, {"0": ["vertical_and_slash", 1000, 6096, 9577296], "1": ["vertical_and_slash", 1000, 6096, 6130994], "2": ["vertical_and_slash", 1000, 6096, 932158], "3": ["vertical_and_slash", 1000, 6096, 6193523], "4": ["vertical_and_slash", 30, 800, 4212495], "5": ["vertical_and_slash", 1000, 6096, 82539], "6": ["vertical_and_slash", 1000, 6096, 2033854], "7": ["vertical_and_slash", 100, 800, 973812], "8": ["vertical_and_slash", 1000, 6096, 96691], "9": ["vertical_and_slash", 1000, 6096, 7346123], "10": ["vertical_and_slash", 1000, 6096, 3425225], "11": ["vertical_and_slash", 1000, 6096, 5656378], "12": ["vertical_and_slash", 1000, 6096, 13585373], "13": ["vertical_and_slash", 3500, 100, 12228455], "14": 
["vertical_and_slash", 100, 800, 14994473], "15": ["vertical_and_slash", 1000, 6096, 12825284], "16": ["vertical_and_slash", 1000, 6096, 8256], "17": ["vertical_and_slash", 1000, 6096, 287798], "18": ["vertical_and_slash", 1000, 6096, 3485339], "19": ["vertical_and_slash", 1000, 6096, 4049013], "20": ["vertical_and_slash", 1000, 6096, 10172329], "21": ["vertical_and_slash", 100, 800, 70376], "22": ["vertical_and_slash", 500, 700, 624964], "23": ["vertical_and_slash", 1000, 6096, 7478718], "24": ["vertical_and_slash", 1000, 6096, 11234418], "25": ["vertical_and_slash", 100, 800, 12774404], "26": ["vertical_and_slash", 1000, 6096, 10820183], "27": ["vertical_and_slash", 1000, 6096, 8669939], "28": ["vertical_and_slash", 100, 800, 46], "29": ["vertical_and_slash", 30, 800, 2478], "30": ["vertical_and_slash", 1000, 6096, 343890], "31": ["vertical_and_slash", 1000, 6096, 485618]}, {"0": ["vertical_and_slash", 1000, 6096, 2552], "1": ["vertical_and_slash", 1000, 6096, 3940587], "2": ["vertical_and_slash", 1000, 6096, 2070936], "3": ["vertical_and_slash", 1000, 6096, 232875], "4": ["vertical_and_slash", 30, 800, 751140], "5": ["vertical_and_slash", 100, 800, 231769], "6": ["vertical_and_slash", 30, 800, 2274515], "7": ["vertical_and_slash", 30, 800, 989564], "8": ["vertical_and_slash", 3500, 100, 14768346], "9": ["vertical_and_slash", 30, 800, 1208594], "10": ["vertical_and_slash", 30, 800, 1770328], "11": ["vertical_and_slash", 1000, 6096, 8752930], "12": ["vertical_and_slash", 3500, 100, 46312], "13": ["vertical_and_slash", 100, 800, 289542], "14": ["vertical_and_slash", 3500, 100, 306397], "15": ["vertical_and_slash", 3500, 100, 56350], "16": ["vertical_and_slash", 100, 800, 356204], "17": ["vertical_and_slash", 3500, 100, 1500240], "18": ["vertical_and_slash", 1000, 6096, 150152], "19": ["vertical_and_slash", 100, 800, 101799], "20": ["vertical_and_slash", 1000, 6096, 299393], "21": ["vertical_and_slash", 1000, 6096, 8627429], "22": ["vertical_and_slash", 1000, 6096, 
3529325], "23": ["vertical_and_slash", 1000, 6096, 1448873], "24": ["vertical_and_slash", 1000, 6096, 1712901], "25": ["vertical_and_slash", 500, 700, 4048433], "26": ["vertical_and_slash", 1000, 6096, 3837844], "27": ["vertical_and_slash", 1000, 6096, 5399791], "28": ["vertical_and_slash", 1000, 6096, 5525857], "29": ["vertical_and_slash", 1000, 6096, 4847570], "30": ["vertical_and_slash", 1000, 6096, 7521944], "31": ["vertical_and_slash", 1000, 6096, 6944849]}, {"0": ["vertical_and_slash", 3500, 100, 12061195], "1": ["vertical_and_slash", 3500, 100, 13821114], "2": ["vertical_and_slash", 1000, 6096, 11831232], "3": ["vertical_and_slash", 1000, 6096, 1990608], "4": ["vertical_and_slash", 1000, 6096, 1126789], "5": ["vertical_and_slash", 1000, 6096, 164058], "6": ["vertical_and_slash", 1000, 6096, 1546250], "7": ["vertical_and_slash", 3500, 100, 3197616], "8": ["vertical_and_slash", 1000, 6096, 4347461], "9": ["vertical_and_slash", 100, 800, 6182587], "10": ["vertical_and_slash", 100, 800, 344594], "11": ["vertical_and_slash", 100, 800, 4476113], "12": ["vertical_and_slash", 1000, 6096, 13461002], "13": ["vertical_and_slash", 1000, 6096, 10764088], "14": ["vertical_and_slash", 1000, 6096, 12256526], "15": ["vertical_and_slash", 1000, 6096, 13680456], "16": ["vertical_and_slash", 30, 800, 247807], "17": ["vertical_and_slash", 30, 800, 283870], "18": ["vertical_and_slash", 30, 800, 8225577], "19": ["vertical_and_slash", 30, 800, 448632], "20": ["vertical_and_slash", 1000, 6096, 4175564], "21": ["vertical_and_slash", 1000, 6096, 2726117], "22": ["vertical_and_slash", 1000, 6096, 310838], "23": ["vertical_and_slash", 1000, 6096, 204919], "24": ["vertical_and_slash", 30, 800, 875524], "25": ["vertical_and_slash", 30, 800, 1182277], "26": ["vertical_and_slash", 30, 800, 4252580], "27": ["vertical_and_slash", 100, 800, 728402], "28": ["vertical_and_slash", 1000, 6096, 12755775], "29": ["vertical_and_slash", 1000, 6096, 13455097], "30": ["vertical_and_slash", 100, 800, 
10492805], "31": ["vertical_and_slash", 3500, 100, 11957996]}, {"0": ["vertical_and_slash", 500, 700, 386640], "1": ["vertical_and_slash", 100, 800, 819517], "2": ["vertical_and_slash", 30, 800, 1170984], "3": ["vertical_and_slash", 100, 800, 626489], "4": ["vertical_and_slash", 1000, 6096, 5856605], "5": ["vertical_and_slash", 1000, 6096, 12960788], "6": ["vertical_and_slash", 1000, 6096, 13042017], "7": ["vertical_and_slash", 1000, 6096, 12542120], "8": ["vertical_and_slash", 1000, 6096, 24167], "9": ["vertical_and_slash", 100, 800, 440430], "10": ["vertical_and_slash", 3500, 100, 748759], "11": ["vertical_and_slash", 1000, 6096, 4655], "12": ["vertical_and_slash", 1000, 6096, 10739360], "13": ["vertical_and_slash", 1000, 6096, 9336615], "14": ["vertical_and_slash", 3500, 100, 14305575], "15": ["vertical_and_slash", 3500, 100, 13833292], "16": ["vertical_and_slash", 30, 800, 3412], "17": ["vertical_and_slash", 500, 700, 16614], "18": ["vertical_and_slash", 1000, 6096, 839930], "19": ["vertical_and_slash", 500, 700, 77296], "20": ["vertical_and_slash", 1000, 6096, 11148082], "21": ["vertical_and_slash", 100, 800, 2483383], "22": ["vertical_and_slash", 3500, 100, 11902907], "23": ["vertical_and_slash", 100, 800, 2194], "24": ["vertical_and_slash", 1000, 6096, 4441496], "25": ["vertical_and_slash", 3500, 100, 10827107], "26": ["vertical_and_slash", 100, 800, 105753], "27": ["vertical_and_slash", 1000, 6096, 5261357], "28": ["vertical_and_slash", 30, 800, 61603], "29": ["vertical_and_slash", 30, 800, 108480], "30": ["vertical_and_slash", 30, 800, 30219], "31": ["vertical_and_slash", 30, 800, 31426]}, {"0": ["vertical_and_slash", 1000, 6096, 136760], "1": ["vertical_and_slash", 100, 800, 827733], "2": ["vertical_and_slash", 100, 800, 670059], "3": ["vertical_and_slash", 3500, 100, 502020], "4": ["vertical_and_slash", 100, 800, 469444], "5": ["vertical_and_slash", 100, 800, 162670], "6": ["vertical_and_slash", 1000, 6096, 22310], "7": ["vertical_and_slash", 1000, 6096, 
465], "8": ["vertical_and_slash", 30, 800, 951054], "9": ["vertical_and_slash", 30, 800, 799102], "10": ["vertical_and_slash", 30, 800, 936020], "11": ["vertical_and_slash", 30, 800, 2027181], "12": ["vertical_and_slash", 3500, 100, 5986265], "13": ["vertical_and_slash", 500, 700, 3941412], "14": ["vertical_and_slash", 100, 800, 10557303], "15": ["vertical_and_slash", 100, 800, 1533916], "16": ["vertical_and_slash", 3500, 100, 11870953], "17": ["vertical_and_slash", 3500, 100, 12342581], "18": ["vertical_and_slash", 3500, 100, 12699180], "19": ["vertical_and_slash", 1000, 6096, 5138869], "20": ["vertical_and_slash", 1000, 6096, 12477033], "21": ["vertical_and_slash", 1000, 6096, 872144], "22": ["vertical_and_slash", 3500, 100, 13382501], "23": ["vertical_and_slash", 1000, 6096, 11531397], "24": ["vertical_and_slash", 1000, 6096, 13884364], "25": ["vertical_and_slash", 1000, 6096, 13611635], "26": ["vertical_and_slash", 1000, 6096, 13516676], "27": ["vertical_and_slash", 1000, 6096, 12560863], "28": ["vertical_and_slash", 500, 700, 3865996], "29": ["vertical_and_slash", 30, 800, 3343532], "30": ["vertical_and_slash", 30, 800, 179777], "31": ["vertical_and_slash", 3500, 100, 3863085]}, {"0": ["vertical_and_slash", 3500, 100, 6771823], "1": ["vertical_and_slash", 3500, 100, 10770780], "2": ["vertical_and_slash", 1000, 6096, 108476], "3": ["vertical_and_slash", 1000, 6096, 917033], "4": ["vertical_and_slash", 3500, 100, 9994951], "5": ["vertical_and_slash", 3500, 100, 13503132], "6": ["vertical_and_slash", 3500, 100, 11843766], "7": ["vertical_and_slash", 3500, 100, 10714999], "8": ["vertical_and_slash", 100, 800, 650037], "9": ["vertical_and_slash", 30, 800, 321924], "10": ["vertical_and_slash", 100, 800, 306681], "11": ["vertical_and_slash", 100, 800, 76181], "12": ["vertical_and_slash", 3500, 100, 12194592], "13": ["vertical_and_slash", 1000, 6096, 12635491], "14": ["vertical_and_slash", 3500, 100, 11953805], "15": ["vertical_and_slash", 3500, 100, 12355730], "16": 
["vertical_and_slash", 100, 800, 614284], "17": ["vertical_and_slash", 100, 800, 512751], "18": ["vertical_and_slash", 3500, 100, 2679940], "19": ["vertical_and_slash", 100, 800, 1749683], "20": ["vertical_and_slash", 30, 800, 563622], "21": ["vertical_and_slash", 30, 800, 9985639], "22": ["vertical_and_slash", 30, 800, 1055029], "23": ["vertical_and_slash", 30, 800, 501782], "24": ["vertical_and_slash", 30, 800, 68229], "25": ["vertical_and_slash", 100, 800, 211743], "26": ["vertical_and_slash", 100, 800, 1690702], "27": ["vertical_and_slash", 30, 800, 2720080], "28": ["vertical_and_slash", 30, 800, 3884686], "29": ["vertical_and_slash", 30, 800, 3303748], "30": ["vertical_and_slash", 30, 800, 3335960], "31": ["vertical_and_slash", 30, 800, 2469116]}, {"0": ["vertical_and_slash", 1000, 6096, 726797], "1": ["vertical_and_slash", 100, 800, 5833160], "2": ["vertical_and_slash", 1000, 6096, 1766748], "3": ["vertical_and_slash", 1000, 6096, 6021028], "4": ["vertical_and_slash", 1000, 6096, 3120126], "5": ["vertical_and_slash", 30, 800, 3103142], "6": ["vertical_and_slash", 1000, 6096, 22974], "7": ["vertical_and_slash", 1000, 6096, 616209], "8": ["vertical_and_slash", 100, 800, 5571258], "9": ["vertical_and_slash", 30, 800, 2259315], "10": ["vertical_and_slash", 1000, 6096, 438342], "11": ["vertical_and_slash", 100, 800, 5557528], "12": ["vertical_and_slash", 3500, 100, 12954645], "13": ["vertical_and_slash", 1000, 6096, 12677660], "14": ["vertical_and_slash", 3500, 100, 13038925], "15": ["vertical_and_slash", 1000, 6096, 11239328], "16": ["vertical_and_slash", 3500, 100, 5247646], "17": ["vertical_and_slash", 500, 700, 384866], "18": ["vertical_and_slash", 1000, 6096, 655131], "19": ["vertical_and_slash", 3500, 100, 8826025], "20": ["vertical_and_slash", 100, 800, 4478606], "21": ["vertical_and_slash", 100, 800, 3881052], "22": ["vertical_and_slash", 100, 800, 6027887], "23": ["vertical_and_slash", 3500, 100, 8475077], "24": ["vertical_and_slash", 1000, 6096, 103633], 
"25": ["vertical_and_slash", 1000, 6096, 76484], "26": ["vertical_and_slash", 100, 800, 22432], "27": ["vertical_and_slash", 1000, 6096, 1313063], "28": ["vertical_and_slash", 1000, 6096, 6617078], "29": ["vertical_and_slash", 3500, 100, 12355842], "30": ["vertical_and_slash", 100, 800, 1401085], "31": ["vertical_and_slash", 3500, 100, 11350169]}, {"0": ["vertical_and_slash", 100, 800, 142456], "1": ["vertical_and_slash", 500, 700, 290481], "2": ["vertical_and_slash", 30, 800, 195338], "3": ["vertical_and_slash", 30, 800, 235375], "4": ["vertical_and_slash", 3500, 100, 13220328], "5": ["vertical_and_slash", 1000, 6096, 13040738], "6": ["vertical_and_slash", 3500, 100, 14847993], "7": ["vertical_and_slash", 1000, 6096, 12236451], "8": ["vertical_and_slash", 30, 800, 1360565], "9": ["vertical_and_slash", 30, 800, 115757], "10": ["vertical_and_slash", 30, 800, 806615], "11": ["vertical_and_slash", 30, 800, 5655605], "12": ["vertical_and_slash", 1000, 6096, 803465], "13": ["vertical_and_slash", 1000, 6096, 7601845], "14": ["vertical_and_slash", 30, 800, 8869563], "15": ["vertical_and_slash", 100, 800, 9177143], "16": ["vertical_and_slash", 1000, 6096, 612999], "17": ["vertical_and_slash", 100, 800, 2657352], "18": ["vertical_and_slash", 1000, 6096, 297015], "19": ["vertical_and_slash", 1000, 6096, 309571], "20": ["vertical_and_slash", 1000, 6096, 13160644], "21": ["vertical_and_slash", 1000, 6096, 14006964], "22": ["vertical_and_slash", 3500, 100, 14287913], "23": ["vertical_and_slash", 3500, 100, 14586379], "24": ["vertical_and_slash", 1000, 6096, 12023244], "25": ["vertical_and_slash", 30, 800, 12092108], "26": ["vertical_and_slash", 500, 700, 6005169], "27": ["vertical_and_slash", 500, 700, 9574963], "28": ["vertical_and_slash", 1000, 6096, 1696021], "29": ["vertical_and_slash", 30, 800, 1516298], "30": ["vertical_and_slash", 1000, 6096, 2303483], "31": ["vertical_and_slash", 1000, 6096, 903636]}, {"0": ["vertical_and_slash", 3500, 100, 7496361], "1": 
["vertical_and_slash", 30, 800, 571560], "2": ["vertical_and_slash", 100, 800, 3025676], "3": ["vertical_and_slash", 30, 800, 5167076], "4": ["vertical_and_slash", 30, 800, 501453], "5": ["vertical_and_slash", 30, 800, 342659], "6": ["vertical_and_slash", 30, 800, 2561588], "7": ["vertical_and_slash", 30, 800, 869660], "8": ["vertical_and_slash", 100, 800, 10740412], "9": ["vertical_and_slash", 30, 800, 87115], "10": ["vertical_and_slash", 3500, 100, 9800623], "11": ["vertical_and_slash", 30, 800, 9191448], "12": ["vertical_and_slash", 1000, 6096, 289817], "13": ["vertical_and_slash", 3500, 100, 9009480], "14": ["vertical_and_slash", 1000, 6096, 1799625], "15": ["vertical_and_slash", 1000, 6096, 4984031], "16": ["vertical_and_slash", 3500, 100, 3381538], "17": ["vertical_and_slash", 100, 800, 11456778], "18": ["vertical_and_slash", 3500, 100, 14316760], "19": ["vertical_and_slash", 100, 800, 5228661], "20": ["vertical_and_slash", 3500, 100, 5831971], "21": ["vertical_and_slash", 500, 700, 10184028], "22": ["vertical_and_slash", 30, 800, 578221], "23": ["vertical_and_slash", 3500, 100, 6213253], "24": ["vertical_and_slash", 1000, 6096, 6146366], "25": ["vertical_and_slash", 1000, 6096, 1477166], "26": ["vertical_and_slash", 30, 800, 318810], "27": ["vertical_and_slash", 1000, 6096, 8654738], "28": ["vertical_and_slash", 500, 700, 3294065], "29": ["vertical_and_slash", 100, 800, 8531992], "30": ["vertical_and_slash", 100, 800, 2564233], "31": ["vertical_and_slash", 100, 800, 113957]}, {"0": ["vertical_and_slash", 100, 800, 530019], "1": ["vertical_and_slash", 100, 800, 647580], "2": ["vertical_and_slash", 30, 800, 4990437], "3": ["vertical_and_slash", 30, 800, 317415], "4": ["vertical_and_slash", 100, 800, 365956], "5": ["vertical_and_slash", 100, 800, 1689094], "6": ["vertical_and_slash", 100, 800, 454281], "7": ["vertical_and_slash", 30, 800, 266331], "8": ["vertical_and_slash", 3500, 100, 3603593], "9": ["vertical_and_slash", 100, 800, 14614370], "10": 
["vertical_and_slash", 1000, 6096, 5361097], "11": ["vertical_and_slash", 100, 800, 14371859], "12": ["vertical_and_slash", 30, 800, 1232558], "13": ["vertical_and_slash", 30, 800, 546028], "14": ["vertical_and_slash", 30, 800, 853313], "15": ["vertical_and_slash", 30, 800, 194933], "16": ["vertical_and_slash", 3500, 100, 14304381], "17": ["vertical_and_slash", 1000, 6096, 815541], "18": ["vertical_and_slash", 100, 800, 5138518], "19": ["vertical_and_slash", 3500, 100, 9565094], "20": ["vertical_and_slash", 1000, 6096, 2035169], "21": ["vertical_and_slash", 1000, 6096, 3375423], "22": ["vertical_and_slash", 1000, 6096, 3777615], "23": ["vertical_and_slash", 1000, 6096, 12354929], "24": ["vertical_and_slash", 30, 800, 1763576], "25": ["vertical_and_slash", 30, 800, 3727796], "26": ["vertical_and_slash", 30, 800, 2744406], "27": ["vertical_and_slash", 30, 800, 1997757], "28": ["vertical_and_slash", 1000, 6096, 12257], "29": ["vertical_and_slash", 1000, 6096, 1169443], "30": ["vertical_and_slash", 3500, 100, 5723144], "31": ["vertical_and_slash", 3500, 100, 5420298]}, {"0": ["vertical_and_slash", 1000, 6096, 2447512], "1": ["vertical_and_slash", 3500, 100, 10860908], "2": ["vertical_and_slash", 100, 800, 9108572], "3": ["vertical_and_slash", 3500, 100, 11624453], "4": ["vertical_and_slash", 100, 800, 6925192], "5": ["vertical_and_slash", 100, 800, 9369879], "6": ["vertical_and_slash", 3500, 100, 11865786], "7": ["vertical_and_slash", 30, 800, 9628595], "8": ["vertical_and_slash", 1000, 6096, 6302171], "9": ["vertical_and_slash", 3500, 100, 8455497], "10": ["vertical_and_slash", 30, 800, 6885122], "11": ["vertical_and_slash", 1000, 6096, 5076785], "12": ["vertical_and_slash", 1000, 6096, 12769698], "13": ["vertical_and_slash", 1000, 6096, 13513363], "14": ["vertical_and_slash", 1000, 6096, 14089388], "15": ["vertical_and_slash", 1000, 6096, 14501815], "16": ["vertical_and_slash", 1000, 6096, 1619566], "17": ["vertical_and_slash", 1000, 6096, 5031895], "18": 
["vertical_and_slash", 1000, 6096, 3833561], "19": ["vertical_and_slash", 100, 800, 12325460], "20": ["vertical_and_slash", 1000, 6096, 320906], "21": ["vertical_and_slash", 3500, 100, 13924855], "22": ["vertical_and_slash", 100, 800, 10478874], "23": ["vertical_and_slash", 30, 800, 4410655], "24": ["vertical_and_slash", 3500, 100, 14767197], "25": ["vertical_and_slash", 1000, 6096, 4108672], "26": ["vertical_and_slash", 100, 800, 14797906], "27": ["vertical_and_slash", 3500, 100, 14643144], "28": ["vertical_and_slash", 100, 800, 10556268], "29": ["vertical_and_slash", 3500, 100, 14575250], "30": ["vertical_and_slash", 1000, 6096, 14076831], "31": ["vertical_and_slash", 1000, 6096, 10779010]}, {"0": ["vertical_and_slash", 30, 800, 4744885], "1": ["vertical_and_slash", 30, 800, 4794511], "2": ["vertical_and_slash", 30, 800, 9418373], "3": ["vertical_and_slash", 30, 800, 2291979], "4": ["vertical_and_slash", 30, 800, 10009392], "5": ["vertical_and_slash", 30, 800, 981769], "6": ["vertical_and_slash", 30, 800, 3395467], "7": ["vertical_and_slash", 100, 800, 5966942], "8": ["vertical_and_slash", 30, 800, 7092993], "9": ["vertical_and_slash", 30, 800, 2176489], "10": ["vertical_and_slash", 30, 800, 4330010], "11": ["vertical_and_slash", 1000, 6096, 2664159], "12": ["vertical_and_slash", 30, 800, 7282328], "13": ["vertical_and_slash", 30, 800, 14135136], "14": ["vertical_and_slash", 1000, 6096, 791118], "15": ["vertical_and_slash", 30, 800, 9266081], "16": ["vertical_and_slash", 3500, 100, 14422288], "17": ["vertical_and_slash", 3500, 100, 11457529], "18": ["vertical_and_slash", 30, 800, 4503306], "19": ["vertical_and_slash", 100, 800, 11937543], "20": ["vertical_and_slash", 3500, 100, 14538141], "21": ["vertical_and_slash", 3500, 100, 13564714], "22": ["vertical_and_slash", 100, 800, 9671640], "23": ["vertical_and_slash", 30, 800, 2841456], "24": ["vertical_and_slash", 30, 800, 1395156], "25": ["vertical_and_slash", 30, 800, 989026], "26": ["vertical_and_slash", 30, 
800, 10617339], "27": ["vertical_and_slash", 30, 800, 8170836], "28": ["vertical_and_slash", 100, 800, 2032096], "29": ["vertical_and_slash", 3500, 100, 13931334], "30": ["vertical_and_slash", 3500, 100, 14790424], "31": ["vertical_and_slash", 1000, 6096, 4133248]}]
|
minference/configs/Phi_3_mini_128k_instruct_kv_out_v32_fit_o_best_pattern.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
[{"0": ["vertical_and_slash", 1000, 6096, 0.33349305391311646], "1": ["vertical_and_slash", 1000, 6096, 0.4378805160522461], "2": ["vertical_and_slash", 1000, 6096, 0.48282963037490845], "3": ["vertical_and_slash", 1000, 6096, 0.37695789337158203], "4": ["vertical_and_slash", 1000, 6096, 0.38924556970596313], "5": ["vertical_and_slash", 1000, 6096, 0.3510749340057373], "6": ["vertical_and_slash", 1000, 6096, 0.39886632561683655], "7": ["vertical_and_slash", 1000, 6096, 0.8939290046691895], "8": ["vertical_and_slash", 1000, 6096, 0.44007450342178345], "9": ["vertical_and_slash", 1000, 6096, 0.3897586464881897], "10": ["vertical_and_slash", 1000, 6096, 0.40355661511421204], "11": ["vertical_and_slash", 1000, 6096, 0.36381030082702637], "12": ["vertical_and_slash", 1000, 6096, 0.4459313154220581], "13": ["vertical_and_slash", 1000, 6096, 0.3341565728187561], "14": ["vertical_and_slash", 1000, 6096, 0.384276419878006], "15": ["vertical_and_slash", 1000, 6096, 0.34818336367607117], "16": ["vertical_and_slash", 1000, 6096, 0.3867861330509186], "17": ["vertical_and_slash", 1000, 6096, 0.3639705777168274], "18": ["vertical_and_slash", 1000, 6096, 0.3512721359729767], "19": ["vertical_and_slash", 1000, 6096, 0.4681489169597626], "20": ["vertical_and_slash", 1000, 6096, 0.4651115834712982], "21": ["vertical_and_slash", 1000, 6096, 0.3882596790790558], "22": ["vertical_and_slash", 1000, 6096, 0.47017091512680054], "23": ["vertical_and_slash", 1000, 6096, 0.8037586808204651], "24": ["vertical_and_slash", 1000, 6096, 0.3913174867630005], "25": ["vertical_and_slash", 1000, 6096, 0.5203016400337219], "26": ["vertical_and_slash", 1000, 6096, 0.47166702151298523], "27": ["vertical_and_slash", 1000, 6096, 0.760438084602356], "28": ["vertical_and_slash", 1000, 6096, 0.943070650100708], "29": ["vertical_and_slash", 1000, 6096, 0.4118039011955261], "30": ["vertical_and_slash", 1000, 6096, 0.6815055012702942], "31": ["vertical_and_slash", 1000, 6096, 0.6300445795059204]}, {"0": 
["vertical_and_slash", 1000, 6096, 0.6439709663391113], "1": ["vertical_and_slash", 1000, 6096, 0.5207313895225525], "2": ["vertical_and_slash", 1000, 6096, 0.47401225566864014], "3": ["vertical_and_slash", 1000, 6096, 0.5988013744354248], "4": ["vertical_and_slash", 1000, 6096, 0.6021823287010193], "5": ["vertical_and_slash", 1000, 6096, 0.4162128269672394], "6": ["vertical_and_slash", 1000, 6096, 0.7858797311782837], "7": ["vertical_and_slash", 1000, 6096, 0.6350969672203064], "8": ["vertical_and_slash", 1000, 6096, 0.5817031860351562], "9": ["vertical_and_slash", 1000, 6096, 0.9291586875915527], "10": ["vertical_and_slash", 1000, 6096, 0.6078806519508362], "11": ["vertical_and_slash", 1000, 6096, 0.5813876986503601], "12": ["vertical_and_slash", 1000, 6096, 0.7652914524078369], "13": ["vertical_and_slash", 1000, 6096, 0.4502100944519043], "14": ["vertical_and_slash", 1000, 6096, 0.6180105209350586], "15": ["vertical_and_slash", 1000, 6096, 0.7175759673118591], "16": ["vertical_and_slash", 1000, 6096, 0.6323421597480774], "17": ["vertical_and_slash", 3500, 100, 0.479082852602005], "18": ["vertical_and_slash", 1000, 6096, 0.6011233329772949], "19": ["vertical_and_slash", 1000, 6096, 0.8908118605613708], "20": ["vertical_and_slash", 1000, 6096, 0.9255861639976501], "21": ["vertical_and_slash", 1000, 6096, 0.795491099357605], "22": ["vertical_and_slash", 1000, 6096, 0.5210989117622375], "23": ["vertical_and_slash", 1000, 6096, 0.5200297236442566], "24": ["vertical_and_slash", 1000, 6096, 0.5280771255493164], "25": ["vertical_and_slash", 1000, 6096, 0.7380014657974243], "26": ["vertical_and_slash", 1000, 6096, 0.9885807633399963], "27": ["vertical_and_slash", 30, 800, 0.8718840479850769], "28": ["vertical_and_slash", 1000, 6096, 0.6302862167358398], "29": ["vertical_and_slash", 1000, 6096, 0.5750876069068909], "30": ["vertical_and_slash", 1000, 6096, 0.45260417461395264], "31": ["vertical_and_slash", 1000, 6096, 0.6499432325363159]}, {"0": ["vertical_and_slash", 
1000, 6096, 0.7977765798568726], "1": ["vertical_and_slash", 1000, 6096, 0.8083621859550476], "2": ["vertical_and_slash", 1000, 6096, 0.5935484170913696], "3": ["vertical_and_slash", 1000, 6096, 0.5435713529586792], "4": ["vertical_and_slash", 1000, 6096, 0.5687218904495239], "5": ["vertical_and_slash", 1000, 6096, 0.854501485824585], "6": ["vertical_and_slash", 1000, 6096, 0.6359673142433167], "7": ["vertical_and_slash", 1000, 6096, 0.5785433053970337], "8": ["vertical_and_slash", 1000, 6096, 0.8543683290481567], "9": ["vertical_and_slash", 1000, 6096, 0.762371838092804], "10": ["vertical_and_slash", 1000, 6096, 0.6970657706260681], "11": ["vertical_and_slash", 1000, 6096, 0.6844046115875244], "12": ["vertical_and_slash", 1000, 6096, 0.7364732623100281], "13": ["vertical_and_slash", 1000, 6096, 0.8335257172584534], "14": ["vertical_and_slash", 1000, 6096, 0.7734203934669495], "15": ["vertical_and_slash", 1000, 6096, 0.7341973185539246], "16": ["vertical_and_slash", 1000, 6096, 0.7554108500480652], "17": ["vertical_and_slash", 1000, 6096, 0.9054623246192932], "18": ["vertical_and_slash", 1000, 6096, 0.6300320029258728], "19": ["vertical_and_slash", 1000, 6096, 0.70512455701828], "20": ["vertical_and_slash", 1000, 6096, 0.6085258722305298], "21": ["vertical_and_slash", 1000, 6096, 0.6398192644119263], "22": ["vertical_and_slash", 1000, 6096, 0.5992570519447327], "23": ["vertical_and_slash", 1000, 6096, 0.7130728363990784], "24": ["vertical_and_slash", 1000, 6096, 0.8504863977432251], "25": ["vertical_and_slash", 1000, 6096, 0.5748745799064636], "26": ["vertical_and_slash", 1000, 6096, 0.7758736610412598], "27": ["vertical_and_slash", 1000, 6096, 0.5538337230682373], "28": ["vertical_and_slash", 1000, 6096, 0.7384650707244873], "29": ["vertical_and_slash", 1000, 6096, 0.6905707120895386], "30": ["vertical_and_slash", 1000, 6096, 0.6217074990272522], "31": ["vertical_and_slash", 1000, 6096, 0.9545422196388245]}, {"0": ["vertical_and_slash", 500, 700, 
0.9924208521842957], "1": ["vertical_and_slash", 100, 750, 0.9987075924873352], "2": ["vertical_and_slash", 500, 700, 0.9915499687194824], "3": ["vertical_and_slash", 100, 750, 0.9940086007118225], "4": ["vertical_and_slash", 100, 750, 0.9947375655174255], "5": ["vertical_and_slash", 100, 750, 0.9920898675918579], "6": ["vertical_and_slash", 100, 750, 0.9960256218910217], "7": ["vertical_and_slash", 100, 750, 0.995691180229187], "8": ["vertical_and_slash", 100, 750, 0.9113738536834717], "9": ["vertical_and_slash", 100, 750, 0.9700976014137268], "10": ["vertical_and_slash", 3500, 100, 0.9520721435546875], "11": ["vertical_and_slash", 100, 750, 0.9561598300933838], "12": ["vertical_and_slash", 100, 750, 0.8256366848945618], "13": ["vertical_and_slash", 100, 750, 0.9905430674552917], "14": ["vertical_and_slash", 500, 700, 0.9822967648506165], "15": ["vertical_and_slash", 100, 750, 0.9880149960517883], "16": ["vertical_and_slash", 100, 750, 0.9570814967155457], "17": ["vertical_and_slash", 100, 750, 0.9678364396095276], "18": ["vertical_and_slash", 3500, 100, 0.9819864630699158], "19": ["vertical_and_slash", 100, 750, 0.9930639266967773], "20": ["vertical_and_slash", 3500, 100, 0.9928342700004578], "21": ["vertical_and_slash", 3500, 100, 0.9522428512573242], "22": ["vertical_and_slash", 100, 750, 0.9961853623390198], "23": ["vertical_and_slash", 100, 750, 0.9895046353340149], "24": ["vertical_and_slash", 100, 750, 0.9106875061988831], "25": ["vertical_and_slash", 100, 750, 0.9944272041320801], "26": ["vertical_and_slash", 100, 750, 0.9603897333145142], "27": ["vertical_and_slash", 100, 750, 0.9967218637466431], "28": ["vertical_and_slash", 100, 750, 0.9922856092453003], "29": ["vertical_and_slash", 100, 750, 0.9425711631774902], "30": ["vertical_and_slash", 1000, 6096, 0.6492345333099365], "31": ["vertical_and_slash", 500, 700, 0.957703709602356]}, {"0": ["vertical_and_slash", 100, 750, 0.9920511841773987], "1": ["vertical_and_slash", 3500, 100, 0.9784621000289917], 
"2": ["vertical_and_slash", 100, 750, 0.9945407509803772], "3": ["vertical_and_slash", 100, 750, 0.9613493084907532], "4": ["vertical_and_slash", 100, 750, 0.8482271432876587], "5": ["vertical_and_slash", 500, 700, 0.9943300485610962], "6": ["vertical_and_slash", 100, 750, 0.9810841083526611], "7": ["vertical_and_slash", 3500, 100, 0.9297769069671631], "8": ["vertical_and_slash", 100, 750, 0.8839191198348999], "9": ["vertical_and_slash", 100, 750, 0.9955653548240662], "10": ["vertical_and_slash", 100, 750, 0.9484658241271973], "11": ["vertical_and_slash", 100, 750, 0.994473397731781], "12": ["vertical_and_slash", 500, 700, 0.9420907497406006], "13": ["vertical_and_slash", 100, 750, 0.9161052107810974], "14": ["vertical_and_slash", 100, 750, 0.9645522832870483], "15": ["vertical_and_slash", 100, 750, 0.9875764846801758], "16": ["vertical_and_slash", 100, 750, 0.7891636490821838], "17": ["vertical_and_slash", 1000, 6096, 0.7788199186325073], "18": ["vertical_and_slash", 100, 750, 0.9488416910171509], "19": ["vertical_and_slash", 3500, 100, 0.9959850311279297], "20": ["vertical_and_slash", 100, 750, 0.9768155217170715], "21": ["vertical_and_slash", 100, 750, 0.995807945728302], "22": ["vertical_and_slash", 3500, 100, 0.8900895118713379], "23": ["vertical_and_slash", 100, 750, 0.9586788415908813], "24": ["vertical_and_slash", 100, 750, 0.9651024341583252], "25": ["vertical_and_slash", 3500, 100, 0.9384130239486694], "26": ["vertical_and_slash", 100, 750, 0.9855350255966187], "27": ["vertical_and_slash", 100, 750, 0.9657205939292908], "28": ["vertical_and_slash", 3500, 100, 0.9184022545814514], "29": ["vertical_and_slash", 100, 750, 0.866909384727478], "30": ["vertical_and_slash", 1000, 6096, 0.7826077342033386], "31": ["vertical_and_slash", 100, 750, 0.9975974559783936]}, {"0": ["vertical_and_slash", 100, 750, 0.9865456223487854], "1": ["vertical_and_slash", 100, 750, 0.9591361880302429], "2": ["vertical_and_slash", 100, 750, 0.9168012142181396], "3": 
["vertical_and_slash", 500, 700, 0.9530511498451233], "4": ["vertical_and_slash", 1000, 6096, 0.8645423650741577], "5": ["vertical_and_slash", 500, 700, 0.9792267084121704], "6": ["vertical_and_slash", 100, 750, 0.9941954612731934], "7": ["vertical_and_slash", 100, 750, 0.960307776927948], "8": ["vertical_and_slash", 3500, 100, 0.9855586886405945], "9": ["vertical_and_slash", 100, 750, 0.9828901886940002], "10": ["vertical_and_slash", 100, 750, 0.8591288328170776], "11": ["vertical_and_slash", 100, 750, 0.917044460773468], "12": ["vertical_and_slash", 100, 750, 0.9849950075149536], "13": ["vertical_and_slash", 100, 750, 0.8859434723854065], "14": ["vertical_and_slash", 100, 750, 0.9971017241477966], "15": ["vertical_and_slash", 500, 700, 0.9620269536972046], "16": ["vertical_and_slash", 500, 700, 0.9597799181938171], "17": ["vertical_and_slash", 500, 700, 0.9934410452842712], "18": ["vertical_and_slash", 3500, 100, 0.9977172017097473], "19": ["vertical_and_slash", 500, 700, 0.9520473480224609], "20": ["vertical_and_slash", 3500, 100, 0.9906032085418701], "21": ["vertical_and_slash", 100, 750, 0.9745447635650635], "22": ["vertical_and_slash", 100, 750, 0.9957244396209717], "23": ["vertical_and_slash", 100, 750, 0.9829675555229187], "24": ["vertical_and_slash", 100, 750, 0.9565562009811401], "25": ["vertical_and_slash", 100, 750, 0.9823064804077148], "26": ["vertical_and_slash", 100, 750, 0.987698495388031], "27": ["vertical_and_slash", 1000, 6096, 0.8219541907310486], "28": ["vertical_and_slash", 1000, 6096, 0.7586351633071899], "29": ["vertical_and_slash", 100, 750, 0.9752539992332458], "30": ["vertical_and_slash", 100, 750, 0.9929803609848022], "31": ["vertical_and_slash", 100, 750, 0.9185792803764343]}, {"0": ["vertical_and_slash", 100, 750, 0.9146243333816528], "1": ["vertical_and_slash", 100, 750, 0.9178520441055298], "2": ["vertical_and_slash", 3500, 100, 0.9930599331855774], "3": ["vertical_and_slash", 100, 750, 0.9993709325790405], "4": 
["vertical_and_slash", 500, 700, 0.9853806495666504], "5": ["vertical_and_slash", 100, 750, 0.9141497015953064], "6": ["vertical_and_slash", 100, 750, 0.992788553237915], "7": ["vertical_and_slash", 100, 750, 0.9772038459777832], "8": ["vertical_and_slash", 1000, 6096, 0.6869983673095703], "9": ["vertical_and_slash", 100, 750, 0.9871460795402527], "10": ["vertical_and_slash", 100, 750, 0.9741801619529724], "11": ["vertical_and_slash", 100, 750, 0.9956739544868469], "12": ["vertical_and_slash", 100, 750, 0.9555794596672058], "13": ["vertical_and_slash", 3500, 100, 0.8615856766700745], "14": ["vertical_and_slash", 3500, 100, 0.9012727737426758], "15": ["vertical_and_slash", 100, 750, 0.9786412715911865], "16": ["vertical_and_slash", 3500, 100, 0.7491975426673889], "17": ["vertical_and_slash", 100, 750, 0.9849361181259155], "18": ["vertical_and_slash", 3500, 100, 0.9097980856895447], "19": ["vertical_and_slash", 1000, 6096, 0.8621278405189514], "20": ["vertical_and_slash", 500, 700, 0.9943590760231018], "21": ["vertical_and_slash", 100, 750, 0.8645753264427185], "22": ["vertical_and_slash", 100, 750, 0.9920986294746399], "23": ["vertical_and_slash", 1000, 6096, 0.8657084703445435], "24": ["vertical_and_slash", 3500, 100, 0.9750965237617493], "25": ["vertical_and_slash", 3500, 100, 0.8507974147796631], "26": ["vertical_and_slash", 3500, 100, 0.9118348360061646], "27": ["vertical_and_slash", 3500, 100, 0.9703859090805054], "28": ["vertical_and_slash", 3500, 100, 0.9725451469421387], "29": ["vertical_and_slash", 1000, 6096, 0.7008982300758362], "30": ["vertical_and_slash", 1000, 6096, 0.838621199131012], "31": ["vertical_and_slash", 100, 750, 0.9929103255271912]}, {"0": ["vertical_and_slash", 1000, 6096, 0.7402030825614929], "1": ["vertical_and_slash", 1000, 6096, 0.8565414547920227], "2": ["vertical_and_slash", 100, 750, 0.9612839221954346], "3": ["vertical_and_slash", 1000, 6096, 0.9598837494850159], "4": ["vertical_and_slash", 1000, 6096, 0.7645464539527893], "5": 
["vertical_and_slash", 100, 750, 0.9872377514839172], "6": ["vertical_and_slash", 1000, 6096, 0.7918620705604553], "7": ["vertical_and_slash", 500, 700, 0.9622856378555298], "8": ["vertical_and_slash", 100, 750, 0.8891160488128662], "9": ["vertical_and_slash", 500, 700, 0.9844319224357605], "10": ["vertical_and_slash", 500, 700, 0.9876360297203064], "11": ["vertical_and_slash", 500, 700, 0.9688720703125], "12": ["vertical_and_slash", 1000, 6096, 0.5671995878219604], "13": ["vertical_and_slash", 100, 750, 0.9620596170425415], "14": ["vertical_and_slash", 1000, 6096, 0.6478529572486877], "15": ["vertical_and_slash", 100, 750, 0.9807542562484741], "16": ["vertical_and_slash", 3500, 100, 0.9823787212371826], "17": ["vertical_and_slash", 100, 750, 0.8980384469032288], "18": ["vertical_and_slash", 1000, 6096, 0.8713955879211426], "19": ["vertical_and_slash", 100, 750, 0.9611169099807739], "20": ["vertical_and_slash", 100, 750, 0.9941024780273438], "21": ["vertical_and_slash", 100, 750, 0.9876882433891296], "22": ["vertical_and_slash", 3500, 100, 0.9474965333938599], "23": ["vertical_and_slash", 100, 750, 0.9415712952613831], "24": ["vertical_and_slash", 100, 750, 0.9960836172103882], "25": ["vertical_and_slash", 100, 750, 0.9898598194122314], "26": ["vertical_and_slash", 100, 750, 0.9720168113708496], "27": ["vertical_and_slash", 100, 750, 0.985356330871582], "28": ["vertical_and_slash", 3500, 100, 0.9795358180999756], "29": ["vertical_and_slash", 100, 750, 0.970496654510498], "30": ["vertical_and_slash", 3500, 100, 0.999195396900177], "31": ["vertical_and_slash", 100, 750, 0.9589951038360596]}, {"0": ["vertical_and_slash", 1000, 6096, 0.8079184889793396], "1": ["stream_llm", 100, 800, 0.96484375], "2": ["vertical_and_slash", 1000, 6096, 0.6607644557952881], "3": ["vertical_and_slash", 30, 800, 0.9899947047233582], "4": ["vertical_and_slash", 1000, 6096, 0.9565256237983704], "5": ["vertical_and_slash", 1000, 6096, 0.9755614995956421], "6": ["vertical_and_slash", 30, 800, 
0.9720635414123535], "7": ["vertical_and_slash", 30, 800, 0.9191414713859558], "8": ["stream_llm", 100, 800, 0.9921875], "9": ["vertical_and_slash", 1000, 6096, 0.6984944939613342], "10": ["stream_llm", 100, 800, 0.97265625], "11": ["vertical_and_slash", 30, 800, 0.955635666847229], "12": ["vertical_and_slash", 1000, 6096, 0.9949175715446472], "13": ["vertical_and_slash", 30, 800, 0.9833577871322632], "14": ["vertical_and_slash", 1000, 6096, 0.612384021282196], "15": ["vertical_and_slash", 1000, 6096, 0.9294421076774597], "16": ["vertical_and_slash", 30, 800, 0.9978874921798706], "17": ["vertical_and_slash", 30, 800, 0.9265275001525879], "18": ["vertical_and_slash", 500, 700, 0.8441793322563171], "19": ["vertical_and_slash", 1000, 6096, 0.9973151087760925], "20": ["vertical_and_slash", 30, 800, 0.8883945941925049], "21": ["vertical_and_slash", 1000, 6096, 0.9890816807746887], "22": ["vertical_and_slash", 30, 800, 0.9924365282058716], "23": ["stream_llm", 100, 800, 0.98828125], "24": ["vertical_and_slash", 1000, 6096, 0.9733841419219971], "25": ["vertical_and_slash", 1000, 6096, 0.8846827149391174], "26": ["vertical_and_slash", 1000, 6096, 0.8909521698951721], "27": ["vertical_and_slash", 30, 800, 0.95379239320755], "28": ["vertical_and_slash", 30, 800, 0.989055871963501], "29": ["vertical_and_slash", 30, 800, 0.9804853796958923], "30": ["vertical_and_slash", 30, 800, 0.9921841621398926], "31": ["vertical_and_slash", 30, 800, 0.9727922677993774]}, {"0": ["stream_llm", 100, 800, 0.984375], "1": ["vertical_and_slash", 30, 800, 0.9801875352859497], "2": ["vertical_and_slash", 3500, 100, 0.9504685997962952], "3": ["vertical_and_slash", 500, 700, 0.5719053745269775], "4": ["vertical_and_slash", 30, 800, 0.9975548386573792], "5": ["vertical_and_slash", 30, 800, 0.9834421873092651], "6": ["vertical_and_slash", 500, 700, 0.876423180103302], "7": ["vertical_and_slash", 1000, 6096, 0.9761123657226562], "8": ["vertical_and_slash", 1000, 6096, 0.6793014407157898], "9": 
["vertical_and_slash", 30, 800, 0.8573703765869141], "10": ["vertical_and_slash", 500, 700, 0.9037665128707886], "11": ["stream_llm", 100, 800, 0.94921875], "12": ["stream_llm", 100, 800, 0.59375], "13": ["vertical_and_slash", 30, 800, 0.9938877820968628], "14": ["vertical_and_slash", 30, 800, 0.9964749217033386], "15": ["stream_llm", 100, 800, 0.9765625], "16": ["vertical_and_slash", 500, 700, 0.9928801655769348], "17": ["stream_llm", 100, 800, 0.859375], "18": ["stream_llm", 100, 800, 0.93359375], "19": ["vertical_and_slash", 500, 700, 0.9897311329841614], "20": ["stream_llm", 100, 800, 0.96875], "21": ["stream_llm", 100, 800, 0.9296875], "22": ["vertical_and_slash", 1000, 6096, 0.49674782156944275], "23": ["vertical_and_slash", 1000, 6096, 0.5498730540275574], "24": ["vertical_and_slash", 1000, 6096, 0.6677294373512268], "25": ["vertical_and_slash", 30, 800, 0.8520674109458923], "26": ["vertical_and_slash", 30, 800, 0.9708148241043091], "27": ["vertical_and_slash", 1000, 6096, 0.9498739838600159], "28": ["vertical_and_slash", 30, 800, 0.9852201342582703], "29": ["vertical_and_slash", 30, 800, 0.9892252683639526], "30": ["vertical_and_slash", 30, 800, 0.9976245164871216], "31": ["stream_llm", 100, 800, 0.91796875]}, {"0": ["vertical_and_slash", 30, 800, 0.976232647895813], "1": ["vertical_and_slash", 1000, 6096, 0.850098729133606], "2": ["vertical_and_slash", 30, 800, 0.9943907260894775], "3": ["stream_llm", 100, 800, 0.984375], "4": ["vertical_and_slash", 1000, 6096, 0.9408355355262756], "5": ["stream_llm", 100, 800, 0.62109375], "6": ["vertical_and_slash", 30, 800, 0.9146958589553833], "7": ["stream_llm", 100, 800, 0.578125], "8": ["vertical_and_slash", 1000, 6096, 0.9866257905960083], "9": ["stream_llm", 100, 800, 0.8671875], "10": ["stream_llm", 100, 800, 0.98828125], "11": ["stream_llm", 100, 800, 0.80078125], "12": ["vertical_and_slash", 30, 800, 0.9795709252357483], "13": ["vertical_and_slash", 1000, 6096, 0.9181753396987915], "14": ["vertical_and_slash", 
30, 800, 0.9088999032974243], "15": ["stream_llm", 100, 800, 1.0], "16": ["stream_llm", 100, 800, 0.93359375], "17": ["vertical_and_slash", 1000, 6096, 0.7872908115386963], "18": ["stream_llm", 100, 800, 0.96875], "19": ["vertical_and_slash", 30, 800, 0.9915726184844971], "20": ["vertical_and_slash", 30, 800, 0.9914611577987671], "21": ["stream_llm", 100, 800, 0.94921875], "22": ["stream_llm", 100, 800, 0.91796875], "23": ["vertical_and_slash", 3500, 100, 0.4178726077079773], "24": ["vertical_and_slash", 1000, 6096, 0.9209551811218262], "25": ["stream_llm", 100, 800, 0.953125], "26": ["vertical_and_slash", 1000, 6096, 0.8251335024833679], "27": ["vertical_and_slash", 1000, 6096, 0.7916073799133301], "28": ["stream_llm", 100, 800, 0.98046875], "29": ["vertical_and_slash", 30, 800, 0.9805914163589478], "30": ["vertical_and_slash", 30, 800, 0.9889715313911438], "31": ["vertical_and_slash", 30, 800, 0.7096468210220337]}, {"0": ["vertical_and_slash", 3500, 100, 0.9098867774009705], "1": ["vertical_and_slash", 1000, 6096, 0.9131186008453369], "2": ["vertical_and_slash", 1000, 6096, 0.6216369271278381], "3": ["vertical_and_slash", 3500, 100, 0.9781222939491272], "4": ["vertical_and_slash", 1000, 6096, 0.6995159983634949], "5": ["vertical_and_slash", 30, 800, 0.7733919620513916], "6": ["stream_llm", 100, 800, 0.8046875], "7": ["stream_llm", 100, 800, 0.9921875], "8": ["vertical_and_slash", 1000, 6096, 0.9208213686943054], "9": ["vertical_and_slash", 30, 800, 0.9892569780349731], "10": ["stream_llm", 100, 800, 0.65234375], "11": ["vertical_and_slash", 3500, 100, 0.8766616582870483], "12": ["stream_llm", 100, 800, 0.69140625], "13": ["vertical_and_slash", 30, 800, 0.9681114554405212], "14": ["vertical_and_slash", 30, 800, 0.954004168510437], "15": ["vertical_and_slash", 1000, 6096, 0.6683151721954346], "16": ["vertical_and_slash", 1000, 6096, 0.9404566287994385], "17": ["vertical_and_slash", 1000, 6096, 0.629856288433075], "18": ["vertical_and_slash", 500, 700, 
0.9569997191429138], "19": ["vertical_and_slash", 1000, 6096, 0.9538705348968506], "20": ["stream_llm", 100, 800, 0.85546875], "21": ["vertical_and_slash", 1000, 6096, 0.8144884705543518], "22": ["vertical_and_slash", 30, 800, 0.95702064037323], "23": ["stream_llm", 100, 800, 0.99609375], "24": ["vertical_and_slash", 1000, 6096, 0.8552843928337097], "25": ["stream_llm", 100, 800, 0.93359375], "26": ["vertical_and_slash", 1000, 6096, 0.8885473012924194], "27": ["vertical_and_slash", 30, 800, 0.9034969210624695], "28": ["vertical_and_slash", 30, 800, 0.8834430575370789], "29": ["stream_llm", 100, 800, 0.59765625], "30": ["stream_llm", 100, 800, 0.98046875], "31": ["vertical_and_slash", 1000, 6096, 0.5801111459732056]}, {"0": ["vertical_and_slash", 1000, 6096, 0.9783773422241211], "1": ["vertical_and_slash", 1000, 6096, 0.9992927312850952], "2": ["vertical_and_slash", 30, 800, 0.9968302845954895], "3": ["vertical_and_slash", 3500, 100, 0.45828360319137573], "4": ["vertical_and_slash", 30, 800, 0.836064875125885], "5": ["vertical_and_slash", 1000, 6096, 0.8009666800498962], "6": ["vertical_and_slash", 3500, 100, 0.6518401503562927], "7": ["vertical_and_slash", 30, 800, 0.9921544790267944], "8": ["vertical_and_slash", 1000, 6096, 0.4855879545211792], "9": ["vertical_and_slash", 1000, 6096, 0.9904646277427673], "10": ["vertical_and_slash", 3500, 100, 0.8973155617713928], "11": ["vertical_and_slash", 1000, 6096, 0.8983845710754395], "12": ["stream_llm", 100, 800, 0.82421875], "13": ["vertical_and_slash", 1000, 6096, 0.8326148390769958], "14": ["vertical_and_slash", 1000, 6096, 0.44982603192329407], "15": ["vertical_and_slash", 30, 800, 0.9292823076248169], "16": ["stream_llm", 100, 800, 0.83203125], "17": ["vertical_and_slash", 500, 700, 0.8943775296211243], "18": ["vertical_and_slash", 3500, 100, 0.8824247121810913], "19": ["vertical_and_slash", 1000, 6096, 0.8916551470756531], "20": ["stream_llm", 100, 800, 0.84765625], "21": ["vertical_and_slash", 1000, 6096, 
0.5656689405441284], "22": ["vertical_and_slash", 3500, 100, 0.9858580827713013], "23": ["vertical_and_slash", 3500, 100, 0.6534677743911743], "24": ["vertical_and_slash", 1000, 6096, 0.7796179056167603], "25": ["stream_llm", 100, 800, 0.984375], "26": ["stream_llm", 100, 800, 0.8125], "27": ["vertical_and_slash", 1000, 6096, 0.8051357269287109], "28": ["vertical_and_slash", 1000, 6096, 0.9759415984153748], "29": ["vertical_and_slash", 3500, 100, 0.9613996148109436], "30": ["vertical_and_slash", 30, 800, 0.9861305952072144], "31": ["vertical_and_slash", 1000, 6096, 0.5375377535820007]}, {"0": ["vertical_and_slash", 1000, 6096, 0.9526095390319824], "1": ["vertical_and_slash", 1000, 6096, 0.9219456315040588], "2": ["vertical_and_slash", 1000, 6096, 0.6329025626182556], "3": ["vertical_and_slash", 1000, 6096, 0.9703953862190247], "4": ["vertical_and_slash", 3500, 100, 0.9341285228729248], "5": ["stream_llm", 100, 800, 0.98828125], "6": ["vertical_and_slash", 3500, 100, 0.975139319896698], "7": ["vertical_and_slash", 30, 800, 0.9698626399040222], "8": ["vertical_and_slash", 1000, 6096, 0.8665440082550049], "9": ["vertical_and_slash", 1000, 6096, 0.9887139797210693], "10": ["vertical_and_slash", 1000, 6096, 0.9663894772529602], "11": ["vertical_and_slash", 500, 700, 0.9613908529281616], "12": ["vertical_and_slash", 1000, 6096, 0.9625579118728638], "13": ["vertical_and_slash", 3500, 100, 0.8293338418006897], "14": ["vertical_and_slash", 1000, 6096, 0.9918296933174133], "15": ["vertical_and_slash", 3500, 100, 0.6993081569671631], "16": ["vertical_and_slash", 1000, 6096, 0.7726790904998779], "17": ["vertical_and_slash", 30, 800, 0.9927448034286499], "18": ["vertical_and_slash", 3500, 100, 0.9216746091842651], "19": ["vertical_and_slash", 1000, 6096, 0.9197890758514404], "20": ["vertical_and_slash", 1000, 6096, 0.5418304800987244], "21": ["vertical_and_slash", 3500, 100, 0.7247577905654907], "22": ["vertical_and_slash", 1000, 6096, 0.8909022212028503], "23": 
["vertical_and_slash", 3500, 100, 0.6162543892860413], "24": ["vertical_and_slash", 1000, 6096, 0.9798792600631714], "25": ["stream_llm", 100, 800, 0.9921875], "26": ["vertical_and_slash", 1000, 6096, 0.839588463306427], "27": ["stream_llm", 100, 800, 0.921875], "28": ["vertical_and_slash", 1000, 6096, 0.9863616228103638], "29": ["vertical_and_slash", 1000, 6096, 0.9895434975624084], "30": ["vertical_and_slash", 1000, 6096, 0.9338933825492859], "31": ["vertical_and_slash", 1000, 6096, 0.9152888655662537]}, {"0": ["vertical_and_slash", 100, 750, 0.7857484221458435], "1": ["vertical_and_slash", 3500, 100, 0.9863781332969666], "2": ["vertical_and_slash", 3500, 100, 0.9732434153556824], "3": ["vertical_and_slash", 1000, 6096, 0.7411113381385803], "4": ["vertical_and_slash", 1000, 6096, 0.9037321209907532], "5": ["vertical_and_slash", 1000, 6096, 0.7728227376937866], "6": ["vertical_and_slash", 3500, 100, 0.9566982388496399], "7": ["vertical_and_slash", 1000, 6096, 0.8955481648445129], "8": ["vertical_and_slash", 500, 700, 0.8905653357505798], "9": ["vertical_and_slash", 3500, 100, 0.9852890968322754], "10": ["vertical_and_slash", 1000, 6096, 0.5732011795043945], "11": ["vertical_and_slash", 3500, 100, 0.9701256155967712], "12": ["vertical_and_slash", 3500, 100, 0.8983554244041443], "13": ["vertical_and_slash", 100, 750, 0.9726784825325012], "14": ["vertical_and_slash", 3500, 100, 0.6008065938949585], "15": ["vertical_and_slash", 1000, 6096, 0.6582738161087036], "16": ["vertical_and_slash", 3500, 100, 0.9488815665245056], "17": ["vertical_and_slash", 100, 750, 0.9958171844482422], "18": ["vertical_and_slash", 3500, 100, 0.8186895847320557], "19": ["vertical_and_slash", 500, 700, 0.9635193347930908], "20": ["vertical_and_slash", 1000, 6096, 0.9248959422111511], "21": ["vertical_and_slash", 3500, 100, 0.9385164976119995], "22": ["vertical_and_slash", 100, 750, 0.9387568235397339], "23": ["vertical_and_slash", 1000, 6096, 0.8735635876655579], "24": ["vertical_and_slash", 
500, 700, 0.890371561050415], "25": ["vertical_and_slash", 100, 750, 0.9905737638473511], "26": ["vertical_and_slash", 3500, 100, 0.946341335773468], "27": ["vertical_and_slash", 3500, 100, 0.942945659160614], "28": ["vertical_and_slash", 100, 750, 0.994683027267456], "29": ["vertical_and_slash", 500, 700, 0.9688966870307922], "30": ["vertical_and_slash", 1000, 6096, 0.9828435778617859], "31": ["vertical_and_slash", 1000, 6096, 0.8722150325775146]}, {"0": ["vertical_and_slash", 500, 700, 0.9728457927703857], "1": ["vertical_and_slash", 100, 750, 0.9586004018783569], "2": ["vertical_and_slash", 3500, 100, 0.9719207882881165], "3": ["vertical_and_slash", 3500, 100, 0.6680086851119995], "4": ["vertical_and_slash", 3500, 100, 0.970458984375], "5": ["vertical_and_slash", 3500, 100, 0.7634486556053162], "6": ["vertical_and_slash", 3500, 100, 0.7259127497673035], "7": ["vertical_and_slash", 100, 750, 0.9781140089035034], "8": ["vertical_and_slash", 3500, 100, 0.9952470064163208], "9": ["vertical_and_slash", 3500, 100, 0.9868772625923157], "10": ["vertical_and_slash", 3500, 100, 0.558458685874939], "11": ["vertical_and_slash", 1000, 6096, 0.7121242880821228], "12": ["vertical_and_slash", 1000, 6096, 0.7061645984649658], "13": ["vertical_and_slash", 3500, 100, 0.923751711845398], "14": ["vertical_and_slash", 1000, 6096, 0.8015576601028442], "15": ["vertical_and_slash", 500, 700, 0.9007270932197571], "16": ["vertical_and_slash", 3500, 100, 0.9591111540794373], "17": ["vertical_and_slash", 500, 700, 0.9750815033912659], "18": ["vertical_and_slash", 100, 750, 0.9805834293365479], "19": ["vertical_and_slash", 3500, 100, 0.8620939254760742], "20": ["vertical_and_slash", 3500, 100, 0.9881291389465332], "21": ["vertical_and_slash", 500, 700, 0.9975225925445557], "22": ["vertical_and_slash", 3500, 100, 0.9125117063522339], "23": ["vertical_and_slash", 3500, 100, 0.8796795010566711], "24": ["vertical_and_slash", 3500, 100, 0.9172841310501099], "25": ["vertical_and_slash", 1000, 
6096, 0.8340160846710205], "26": ["vertical_and_slash", 1000, 6096, 0.8479950428009033], "27": ["vertical_and_slash", 3500, 100, 0.9778053164482117], "28": ["vertical_and_slash", 100, 750, 0.9912164211273193], "29": ["vertical_and_slash", 1000, 6096, 0.6634088754653931], "30": ["vertical_and_slash", 3500, 100, 0.9486925601959229], "31": ["vertical_and_slash", 3500, 100, 0.985546350479126]}, {"0": ["vertical_and_slash", 3500, 100, 0.7207826375961304], "1": ["vertical_and_slash", 1000, 6096, 0.7674809098243713], "2": ["vertical_and_slash", 1000, 6096, 0.5480814576148987], "3": ["vertical_and_slash", 3500, 100, 0.974454939365387], "4": ["vertical_and_slash", 100, 750, 0.9901475310325623], "5": ["vertical_and_slash", 3500, 100, 0.9111185073852539], "6": ["vertical_and_slash", 3500, 100, 0.8977652192115784], "7": ["vertical_and_slash", 500, 700, 0.8826637864112854], "8": ["vertical_and_slash", 3500, 100, 0.9674721956253052], "9": ["vertical_and_slash", 500, 700, 0.9511355757713318], "10": ["vertical_and_slash", 3500, 100, 0.9368802309036255], "11": ["vertical_and_slash", 3500, 100, 0.7037530541419983], "12": ["vertical_and_slash", 3500, 100, 0.8404982089996338], "13": ["vertical_and_slash", 3500, 100, 0.9477558732032776], "14": ["vertical_and_slash", 1000, 6096, 0.5408625602722168], "15": ["vertical_and_slash", 1000, 6096, 0.8930901288986206], "16": ["vertical_and_slash", 500, 700, 0.9620649814605713], "17": ["vertical_and_slash", 3500, 100, 0.9665637016296387], "18": ["vertical_and_slash", 3500, 100, 0.9973539710044861], "19": ["vertical_and_slash", 3500, 100, 0.9200847744941711], "20": ["vertical_and_slash", 100, 750, 0.9846996068954468], "21": ["vertical_and_slash", 3500, 100, 0.9522152543067932], "22": ["vertical_and_slash", 3500, 100, 0.9200462102890015], "23": ["vertical_and_slash", 3500, 100, 0.7189115285873413], "24": ["vertical_and_slash", 3500, 100, 0.9400286078453064], "25": ["vertical_and_slash", 3500, 100, 0.9140079617500305], "26": ["vertical_and_slash", 
3500, 100, 0.9733141660690308], "27": ["vertical_and_slash", 3500, 100, 0.9182970523834229], "28": ["vertical_and_slash", 500, 700, 0.7845987677574158], "29": ["vertical_and_slash", 500, 700, 0.953305721282959], "30": ["vertical_and_slash", 1000, 6096, 0.9332642555236816], "31": ["vertical_and_slash", 500, 700, 0.8975687026977539]}, {"0": ["vertical_and_slash", 1000, 6096, 0.8796314001083374], "1": ["vertical_and_slash", 3500, 100, 0.9541191458702087], "2": ["vertical_and_slash", 3500, 100, 0.9853596091270447], "3": ["vertical_and_slash", 3500, 100, 0.9959757924079895], "4": ["vertical_and_slash", 500, 700, 0.942274272441864], "5": ["vertical_and_slash", 3500, 100, 0.9958774447441101], "6": ["vertical_and_slash", 3500, 100, 0.762219250202179], "7": ["vertical_and_slash", 3500, 100, 0.9778050780296326], "8": ["vertical_and_slash", 3500, 100, 0.9803900718688965], "9": ["vertical_and_slash", 3500, 100, 0.9493845701217651], "10": ["vertical_and_slash", 100, 750, 0.9833114147186279], "11": ["vertical_and_slash", 3500, 100, 0.9671387076377869], "12": ["vertical_and_slash", 3500, 100, 0.8459083437919617], "13": ["vertical_and_slash", 3500, 100, 0.9625062346458435], "14": ["vertical_and_slash", 3500, 100, 0.9926583766937256], "15": ["vertical_and_slash", 3500, 100, 0.9901418089866638], "16": ["vertical_and_slash", 3500, 100, 0.9975236058235168], "17": ["vertical_and_slash", 3500, 100, 0.8961046934127808], "18": ["vertical_and_slash", 3500, 100, 0.9677743315696716], "19": ["vertical_and_slash", 1000, 6096, 0.7324523329734802], "20": ["vertical_and_slash", 1000, 6096, 0.7565687298774719], "21": ["vertical_and_slash", 3500, 100, 0.9934558272361755], "22": ["vertical_and_slash", 1000, 6096, 0.695542573928833], "23": ["vertical_and_slash", 3500, 100, 0.9594518542289734], "24": ["vertical_and_slash", 3500, 100, 0.9845080375671387], "25": ["vertical_and_slash", 3500, 100, 0.9140312075614929], "26": ["vertical_and_slash", 3500, 100, 0.9816687107086182], "27": 
["vertical_and_slash", 3500, 100, 0.9777555465698242], "28": ["vertical_and_slash", 3500, 100, 0.948824405670166], "29": ["vertical_and_slash", 3500, 100, 0.48502659797668457], "30": ["vertical_and_slash", 3500, 100, 0.9340038895606995], "31": ["vertical_and_slash", 3500, 100, 0.9162462949752808]}, {"0": ["vertical_and_slash", 3500, 100, 0.9923238754272461], "1": ["vertical_and_slash", 3500, 100, 0.9678853750228882], "2": ["vertical_and_slash", 100, 750, 0.9968323111534119], "3": ["vertical_and_slash", 500, 700, 0.9936473965644836], "4": ["vertical_and_slash", 3500, 100, 0.9588732123374939], "5": ["vertical_and_slash", 500, 700, 0.9791616797447205], "6": ["vertical_and_slash", 3500, 100, 0.919694721698761], "7": ["vertical_and_slash", 1000, 6096, 0.626932680606842], "8": ["vertical_and_slash", 3500, 100, 0.9546087980270386], "9": ["vertical_and_slash", 500, 700, 0.8930553793907166], "10": ["vertical_and_slash", 100, 750, 0.9767886996269226], "11": ["vertical_and_slash", 1000, 6096, 0.7312592267990112], "12": ["vertical_and_slash", 3500, 100, 0.9913722276687622], "13": ["vertical_and_slash", 3500, 100, 0.9425638914108276], "14": ["vertical_and_slash", 3500, 100, 0.9949523210525513], "15": ["vertical_and_slash", 100, 750, 0.7187187671661377], "16": ["vertical_and_slash", 3500, 100, 0.9734897017478943], "17": ["vertical_and_slash", 3500, 100, 0.9750894904136658], "18": ["vertical_and_slash", 3500, 100, 0.9543801546096802], "19": ["vertical_and_slash", 3500, 100, 0.94287109375], "20": ["vertical_and_slash", 1000, 6096, 0.7409213185310364], "21": ["vertical_and_slash", 3500, 100, 0.9789512753486633], "22": ["vertical_and_slash", 3500, 100, 0.9824472069740295], "23": ["vertical_and_slash", 3500, 100, 0.9614876508712769], "24": ["vertical_and_slash", 500, 700, 0.9097415208816528], "25": ["vertical_and_slash", 3500, 100, 0.7589483857154846], "26": ["vertical_and_slash", 3500, 100, 0.9711624979972839], "27": ["vertical_and_slash", 500, 700, 0.9924762845039368], "28": 
["vertical_and_slash", 3500, 100, 0.8917614221572876], "29": ["vertical_and_slash", 500, 700, 0.9802823066711426], "30": ["vertical_and_slash", 3500, 100, 0.9433683156967163], "31": ["vertical_and_slash", 3500, 100, 0.9959222078323364]}, {"0": ["vertical_and_slash", 3500, 100, 0.8028379678726196], "1": ["vertical_and_slash", 3500, 100, 0.9934322237968445], "2": ["vertical_and_slash", 3500, 100, 0.9233330488204956], "3": ["vertical_and_slash", 500, 700, 0.9530222415924072], "4": ["vertical_and_slash", 1000, 6096, 0.7554510831832886], "5": ["vertical_and_slash", 3500, 100, 0.9931245446205139], "6": ["vertical_and_slash", 3500, 100, 0.8175129890441895], "7": ["vertical_and_slash", 500, 700, 0.9769982695579529], "8": ["vertical_and_slash", 3500, 100, 0.7803007364273071], "9": ["vertical_and_slash", 3500, 100, 0.8488234281539917], "10": ["vertical_and_slash", 1000, 6096, 0.7556964159011841], "11": ["vertical_and_slash", 100, 750, 0.9249212145805359], "12": ["vertical_and_slash", 1000, 6096, 0.5030975937843323], "13": ["vertical_and_slash", 3500, 100, 0.7736669778823853], "14": ["vertical_and_slash", 3500, 100, 0.8432313203811646], "15": ["vertical_and_slash", 3500, 100, 0.8078522086143494], "16": ["vertical_and_slash", 1000, 6096, 0.6152622699737549], "17": ["vertical_and_slash", 1000, 6096, 0.4801797866821289], "18": ["vertical_and_slash", 3500, 100, 0.7792356610298157], "19": ["vertical_and_slash", 3500, 100, 0.9260709285736084], "20": ["vertical_and_slash", 3500, 100, 0.9572370052337646], "21": ["vertical_and_slash", 500, 700, 0.9757252335548401], "22": ["vertical_and_slash", 100, 750, 0.9295142889022827], "23": ["vertical_and_slash", 100, 750, 0.8406566381454468], "24": ["vertical_and_slash", 500, 700, 0.9934183955192566], "25": ["vertical_and_slash", 3500, 100, 0.9811476469039917], "26": ["vertical_and_slash", 1000, 6096, 0.43748241662979126], "27": ["vertical_and_slash", 1000, 6096, 0.8173736929893494], "28": ["vertical_and_slash", 1000, 6096, 0.7964892983436584], 
"29": ["vertical_and_slash", 1000, 6096, 0.5660628080368042], "30": ["vertical_and_slash", 100, 750, 0.8858906626701355], "31": ["vertical_and_slash", 3500, 100, 0.7301779389381409]}, {"0": ["vertical_and_slash", 1000, 6096, 0.8143554925918579], "1": ["vertical_and_slash", 3500, 100, 0.8302785754203796], "2": ["vertical_and_slash", 3500, 100, 0.9859114289283752], "3": ["vertical_and_slash", 3500, 100, 0.6922958493232727], "4": ["vertical_and_slash", 3500, 100, 0.9597254991531372], "5": ["vertical_and_slash", 1000, 6096, 0.8074929714202881], "6": ["vertical_and_slash", 3500, 100, 0.7841739654541016], "7": ["vertical_and_slash", 3500, 100, 0.9443768262863159], "8": ["vertical_and_slash", 3500, 100, 0.9327424764633179], "9": ["vertical_and_slash", 3500, 100, 0.8796824812889099], "10": ["vertical_and_slash", 3500, 100, 0.9468095302581787], "11": ["vertical_and_slash", 3500, 100, 0.9797954559326172], "12": ["vertical_and_slash", 3500, 100, 0.9876496195793152], "13": ["vertical_and_slash", 100, 750, 0.9684455394744873], "14": ["vertical_and_slash", 3500, 100, 0.9720463156700134], "15": ["vertical_and_slash", 3500, 100, 0.9134085774421692], "16": ["vertical_and_slash", 100, 750, 0.9962508678436279], "17": ["vertical_and_slash", 3500, 100, 0.9967661499977112], "18": ["vertical_and_slash", 3500, 100, 0.9218150973320007], "19": ["vertical_and_slash", 3500, 100, 0.9165892601013184], "20": ["vertical_and_slash", 500, 700, 0.9811153411865234], "21": ["vertical_and_slash", 1000, 6096, 0.8401690721511841], "22": ["vertical_and_slash", 100, 750, 0.9827044606208801], "23": ["vertical_and_slash", 500, 700, 0.9265505075454712], "24": ["vertical_and_slash", 3500, 100, 0.8814885020256042], "25": ["vertical_and_slash", 1000, 6096, 0.8774723410606384], "26": ["vertical_and_slash", 1000, 6096, 0.8981026411056519], "27": ["vertical_and_slash", 100, 750, 0.995216429233551], "28": ["vertical_and_slash", 3500, 100, 0.9950628280639648], "29": ["vertical_and_slash", 500, 700, 
0.9678530693054199], "30": ["vertical_and_slash", 100, 750, 0.9900303483009338], "31": ["vertical_and_slash", 3500, 100, 0.9148485064506531]}, {"0": ["vertical_and_slash", 3500, 100, 0.7734143137931824], "1": ["vertical_and_slash", 3500, 100, 0.9431662559509277], "2": ["vertical_and_slash", 100, 750, 0.9125087857246399], "3": ["vertical_and_slash", 3500, 100, 0.9382316470146179], "4": ["vertical_and_slash", 1000, 6096, 0.7059416174888611], "5": ["vertical_and_slash", 3500, 100, 0.6978054642677307], "6": ["vertical_and_slash", 3500, 100, 0.9927070140838623], "7": ["vertical_and_slash", 3500, 100, 0.9393529295921326], "8": ["vertical_and_slash", 100, 750, 0.9231113195419312], "9": ["vertical_and_slash", 3500, 100, 0.9985975623130798], "10": ["vertical_and_slash", 500, 700, 0.9555321335792542], "11": ["vertical_and_slash", 3500, 100, 0.9785676002502441], "12": ["vertical_and_slash", 500, 700, 0.9968464374542236], "13": ["vertical_and_slash", 3500, 100, 0.9894333481788635], "14": ["vertical_and_slash", 500, 700, 0.8927757740020752], "15": ["vertical_and_slash", 3500, 100, 0.9463996887207031], "16": ["vertical_and_slash", 3500, 100, 0.9756723642349243], "17": ["vertical_and_slash", 3500, 100, 0.970882773399353], "18": ["vertical_and_slash", 1000, 6096, 0.6809303164482117], "19": ["vertical_and_slash", 3500, 100, 0.9938862919807434], "20": ["vertical_and_slash", 3500, 100, 0.9821802973747253], "21": ["vertical_and_slash", 3500, 100, 0.9383650422096252], "22": ["vertical_and_slash", 3500, 100, 0.8643637299537659], "23": ["vertical_and_slash", 100, 750, 0.9771586656570435], "24": ["vertical_and_slash", 500, 700, 0.976405143737793], "25": ["vertical_and_slash", 3500, 100, 0.9743276238441467], "26": ["vertical_and_slash", 3500, 100, 0.9265220761299133], "27": ["vertical_and_slash", 3500, 100, 0.9841408729553223], "28": ["vertical_and_slash", 500, 700, 0.9391534328460693], "29": ["vertical_and_slash", 3500, 100, 0.9312986135482788], "30": ["vertical_and_slash", 3500, 100, 
0.8832992911338806], "31": ["vertical_and_slash", 3500, 100, 0.9811874628067017]}, {"0": ["vertical_and_slash", 3500, 100, 0.9956807494163513], "1": ["vertical_and_slash", 3500, 100, 0.9670407772064209], "2": ["vertical_and_slash", 100, 750, 0.9973832964897156], "3": ["vertical_and_slash", 100, 750, 0.99891597032547], "4": ["vertical_and_slash", 3500, 100, 0.9931758642196655], "5": ["vertical_and_slash", 100, 750, 0.996113121509552], "6": ["vertical_and_slash", 3500, 100, 0.9983065724372864], "7": ["vertical_and_slash", 3500, 100, 0.9833848476409912], "8": ["vertical_and_slash", 3500, 100, 0.9948523640632629], "9": ["vertical_and_slash", 3500, 100, 0.8683006167411804], "10": ["vertical_and_slash", 3500, 100, 0.9931465983390808], "11": ["vertical_and_slash", 100, 750, 0.984261691570282], "12": ["vertical_and_slash", 100, 750, 0.9601353406906128], "13": ["vertical_and_slash", 500, 700, 0.9203216433525085], "14": ["vertical_and_slash", 3500, 100, 0.9650700092315674], "15": ["vertical_and_slash", 100, 750, 0.984341561794281], "16": ["vertical_and_slash", 3500, 100, 0.9989381432533264], "17": ["vertical_and_slash", 1000, 6096, 0.8591818809509277], "18": ["vertical_and_slash", 500, 700, 0.959535539150238], "19": ["vertical_and_slash", 3500, 100, 0.9685975909233093], "20": ["vertical_and_slash", 3500, 100, 0.9992274045944214], "21": ["vertical_and_slash", 3500, 100, 0.9054502248764038], "22": ["vertical_and_slash", 3500, 100, 0.9957486391067505], "23": ["vertical_and_slash", 3500, 100, 0.9970229864120483], "24": ["vertical_and_slash", 3500, 100, 0.933996319770813], "25": ["vertical_and_slash", 3500, 100, 0.9522771239280701], "26": ["vertical_and_slash", 3500, 100, 0.8640444278717041], "27": ["vertical_and_slash", 3500, 100, 0.9864702820777893], "28": ["vertical_and_slash", 1000, 6096, 0.8701584935188293], "29": ["vertical_and_slash", 3500, 100, 0.9872081279754639], "30": ["vertical_and_slash", 3500, 100, 0.9637035727500916], "31": ["vertical_and_slash", 3500, 100, 
0.7964584827423096]}, {"0": ["vertical_and_slash", 500, 700, 0.944079577922821], "1": ["vertical_and_slash", 1000, 6096, 0.7686152458190918], "2": ["vertical_and_slash", 3500, 100, 0.9423201680183411], "3": ["vertical_and_slash", 3500, 100, 0.9597930908203125], "4": ["vertical_and_slash", 3500, 100, 0.9981894493103027], "5": ["vertical_and_slash", 100, 750, 0.9951789975166321], "6": ["vertical_and_slash", 3500, 100, 0.9678981304168701], "7": ["vertical_and_slash", 3500, 100, 0.8912110924720764], "8": ["vertical_and_slash", 100, 750, 0.9829361438751221], "9": ["vertical_and_slash", 500, 700, 0.9326693415641785], "10": ["vertical_and_slash", 3500, 100, 0.7954592108726501], "11": ["vertical_and_slash", 3500, 100, 0.9361847639083862], "12": ["vertical_and_slash", 3500, 100, 0.9777213335037231], "13": ["vertical_and_slash", 100, 750, 0.7402770519256592], "14": ["vertical_and_slash", 1000, 6096, 0.8369068503379822], "15": ["vertical_and_slash", 3500, 100, 0.8386251926422119], "16": ["vertical_and_slash", 500, 700, 0.9928125143051147], "17": ["vertical_and_slash", 3500, 100, 0.9980320930480957], "18": ["vertical_and_slash", 100, 750, 0.99200838804245], "19": ["vertical_and_slash", 3500, 100, 0.9937632083892822], "20": ["vertical_and_slash", 1000, 6096, 0.8582853674888611], "21": ["vertical_and_slash", 500, 700, 0.8901017308235168], "22": ["vertical_and_slash", 3500, 100, 0.9825611710548401], "23": ["vertical_and_slash", 3500, 100, 0.9956728219985962], "24": ["vertical_and_slash", 3500, 100, 0.992565929889679], "25": ["vertical_and_slash", 3500, 100, 0.9841880202293396], "26": ["vertical_and_slash", 1000, 6096, 0.8873481750488281], "27": ["vertical_and_slash", 100, 750, 0.9767672419548035], "28": ["vertical_and_slash", 3500, 100, 0.9931612610816956], "29": ["vertical_and_slash", 3500, 100, 0.9209384918212891], "30": ["vertical_and_slash", 100, 750, 0.7578334212303162], "31": ["vertical_and_slash", 3500, 100, 0.9578611850738525]}, {"0": ["vertical_and_slash", 100, 750, 
0.9389412999153137], "1": ["vertical_and_slash", 100, 750, 0.9428157210350037], "2": ["vertical_and_slash", 3500, 100, 0.9956400990486145], "3": ["vertical_and_slash", 100, 750, 0.9144065976142883], "4": ["vertical_and_slash", 1000, 6096, 0.8475824594497681], "5": ["vertical_and_slash", 100, 750, 0.996335506439209], "6": ["vertical_and_slash", 3500, 100, 0.9988783597946167], "7": ["vertical_and_slash", 3500, 100, 0.94597989320755], "8": ["vertical_and_slash", 3500, 100, 0.9713111519813538], "9": ["vertical_and_slash", 100, 750, 0.9670871496200562], "10": ["vertical_and_slash", 3500, 100, 0.9996585249900818], "11": ["vertical_and_slash", 3500, 100, 0.9820530414581299], "12": ["vertical_and_slash", 3500, 100, 0.9983968138694763], "13": ["vertical_and_slash", 3500, 100, 0.9315072298049927], "14": ["vertical_and_slash", 3500, 100, 0.9930176138877869], "15": ["vertical_and_slash", 500, 700, 0.9945250749588013], "16": ["vertical_and_slash", 100, 750, 0.9049948453903198], "17": ["vertical_and_slash", 3500, 100, 0.9992651343345642], "18": ["vertical_and_slash", 500, 700, 0.9942126274108887], "19": ["vertical_and_slash", 500, 700, 0.9891477227210999], "20": ["vertical_and_slash", 3500, 100, 0.9028084874153137], "21": ["vertical_and_slash", 100, 750, 0.9475080370903015], "22": ["vertical_and_slash", 500, 700, 0.9690455794334412], "23": ["vertical_and_slash", 3500, 100, 0.9446419477462769], "24": ["vertical_and_slash", 3500, 100, 0.9801247715950012], "25": ["vertical_and_slash", 100, 750, 0.9777910113334656], "26": ["vertical_and_slash", 3500, 100, 0.7017547488212585], "27": ["vertical_and_slash", 3500, 100, 0.9493237137794495], "28": ["vertical_and_slash", 100, 750, 0.9993017315864563], "29": ["vertical_and_slash", 3500, 100, 0.893531858921051], "30": ["vertical_and_slash", 3500, 100, 0.9467594623565674], "31": ["vertical_and_slash", 3500, 100, 0.9743610620498657]}, {"0": ["vertical_and_slash", 3500, 100, 0.985114574432373], "1": ["vertical_and_slash", 500, 700, 
0.9950987696647644], "2": ["vertical_and_slash", 3500, 100, 0.7027000784873962], "3": ["vertical_and_slash", 3500, 100, 0.9855831265449524], "4": ["vertical_and_slash", 3500, 100, 0.9874288439750671], "5": ["vertical_and_slash", 1000, 6096, 0.7125917673110962], "6": ["vertical_and_slash", 3500, 100, 0.9454708695411682], "7": ["vertical_and_slash", 3500, 100, 0.9898356199264526], "8": ["vertical_and_slash", 3500, 100, 0.9445544481277466], "9": ["vertical_and_slash", 3500, 100, 0.988140344619751], "10": ["vertical_and_slash", 500, 700, 0.981208860874176], "11": ["vertical_and_slash", 500, 700, 0.9874861836433411], "12": ["vertical_and_slash", 3500, 100, 0.9963038563728333], "13": ["vertical_and_slash", 100, 750, 0.9972052574157715], "14": ["vertical_and_slash", 3500, 100, 0.9943816065788269], "15": ["vertical_and_slash", 100, 750, 0.8364889025688171], "16": ["vertical_and_slash", 100, 750, 0.9870871901512146], "17": ["vertical_and_slash", 100, 750, 0.998099684715271], "18": ["vertical_and_slash", 3500, 100, 0.8674955368041992], "19": ["vertical_and_slash", 500, 700, 0.9969808459281921], "20": ["vertical_and_slash", 3500, 100, 0.8848986625671387], "21": ["vertical_and_slash", 1000, 6096, 0.867315411567688], "22": ["vertical_and_slash", 500, 700, 0.9908551573753357], "23": ["vertical_and_slash", 100, 750, 0.8952099680900574], "24": ["vertical_and_slash", 500, 700, 0.9714990854263306], "25": ["vertical_and_slash", 100, 750, 0.8733819723129272], "26": ["vertical_and_slash", 3500, 100, 0.9205271005630493], "27": ["vertical_and_slash", 3500, 100, 0.9833540916442871], "28": ["vertical_and_slash", 3500, 100, 0.9445760846138], "29": ["vertical_and_slash", 3500, 100, 0.9536135792732239], "30": ["vertical_and_slash", 500, 700, 0.9753504991531372], "31": ["vertical_and_slash", 1000, 6096, 0.8801259398460388]}, {"0": ["vertical_and_slash", 3500, 100, 0.9614631533622742], "1": ["vertical_and_slash", 3500, 100, 0.9763227105140686], "2": ["vertical_and_slash", 100, 750, 
0.970956563949585], "3": ["vertical_and_slash", 100, 750, 0.9151788949966431], "4": ["vertical_and_slash", 3500, 100, 0.9920399188995361], "5": ["vertical_and_slash", 3500, 100, 0.9422896504402161], "6": ["vertical_and_slash", 3500, 100, 0.986482560634613], "7": ["vertical_and_slash", 3500, 100, 0.9976206421852112], "8": ["vertical_and_slash", 100, 750, 0.9943424463272095], "9": ["vertical_and_slash", 3500, 100, 0.9936824440956116], "10": ["vertical_and_slash", 3500, 100, 0.9882729649543762], "11": ["vertical_and_slash", 100, 750, 0.9862287640571594], "12": ["vertical_and_slash", 500, 700, 0.9886087775230408], "13": ["vertical_and_slash", 3500, 100, 0.9989089369773865], "14": ["vertical_and_slash", 3500, 100, 0.9651134610176086], "15": ["vertical_and_slash", 3500, 100, 0.9826948046684265], "16": ["vertical_and_slash", 3500, 100, 0.9450136423110962], "17": ["vertical_and_slash", 3500, 100, 0.9979375004768372], "18": ["vertical_and_slash", 3500, 100, 0.9520789384841919], "19": ["vertical_and_slash", 3500, 100, 0.9316532015800476], "20": ["vertical_and_slash", 100, 750, 0.9904720187187195], "21": ["vertical_and_slash", 3500, 100, 0.999125599861145], "22": ["vertical_and_slash", 3500, 100, 0.9995089769363403], "23": ["vertical_and_slash", 100, 750, 0.9886007308959961], "24": ["vertical_and_slash", 3500, 100, 0.9961583018302917], "25": ["vertical_and_slash", 3500, 100, 0.9961526393890381], "26": ["vertical_and_slash", 3500, 100, 0.9557645916938782], "27": ["vertical_and_slash", 3500, 100, 0.8775650262832642], "28": ["vertical_and_slash", 3500, 100, 0.986892580986023], "29": ["vertical_and_slash", 3500, 100, 0.9749740958213806], "30": ["vertical_and_slash", 3500, 100, 0.8765645027160645], "31": ["vertical_and_slash", 3500, 100, 0.9494763016700745]}, {"0": ["vertical_and_slash", 3500, 100, 0.9797922372817993], "1": ["vertical_and_slash", 3500, 100, 0.9958779811859131], "2": ["vertical_and_slash", 3500, 100, 0.9976977705955505], "3": ["vertical_and_slash", 3500, 100, 
0.9764806628227234], "4": ["vertical_and_slash", 3500, 100, 0.9868356585502625], "5": ["vertical_and_slash", 1000, 6096, 0.8740545511245728], "6": ["vertical_and_slash", 3500, 100, 0.9939981698989868], "7": ["vertical_and_slash", 1000, 6096, 0.7613811492919922], "8": ["vertical_and_slash", 3500, 100, 0.9811347723007202], "9": ["vertical_and_slash", 3500, 100, 0.9840614795684814], "10": ["vertical_and_slash", 1000, 6096, 0.8657892346382141], "11": ["vertical_and_slash", 3500, 100, 0.9502456188201904], "12": ["vertical_and_slash", 100, 750, 0.9104490280151367], "13": ["vertical_and_slash", 3500, 100, 0.9950721263885498], "14": ["vertical_and_slash", 3500, 100, 0.9724959135055542], "15": ["vertical_and_slash", 1000, 6096, 0.8955191373825073], "16": ["vertical_and_slash", 3500, 100, 0.9936071038246155], "17": ["vertical_and_slash", 3500, 100, 0.9285928606987], "18": ["vertical_and_slash", 3500, 100, 0.756338357925415], "19": ["vertical_and_slash", 3500, 100, 0.9665532112121582], "20": ["vertical_and_slash", 100, 750, 0.9970663785934448], "21": ["vertical_and_slash", 3500, 100, 0.9806201457977295], "22": ["vertical_and_slash", 1000, 6096, 0.8115424513816833], "23": ["vertical_and_slash", 1000, 6096, 0.8631585836410522], "24": ["vertical_and_slash", 3500, 100, 0.9782901406288147], "25": ["vertical_and_slash", 3500, 100, 0.9858242273330688], "26": ["vertical_and_slash", 3500, 100, 0.9617720246315002], "27": ["vertical_and_slash", 3500, 100, 0.997412919998169], "28": ["vertical_and_slash", 3500, 100, 0.8432300090789795], "29": ["vertical_and_slash", 500, 700, 0.9955722093582153], "30": ["vertical_and_slash", 3500, 100, 0.9938695430755615], "31": ["vertical_and_slash", 3500, 100, 0.9511440396308899]}, {"0": ["vertical_and_slash", 3500, 100, 0.988155722618103], "1": ["vertical_and_slash", 3500, 100, 0.9747615456581116], "2": ["vertical_and_slash", 100, 750, 0.9718871712684631], "3": ["vertical_and_slash", 100, 750, 0.9756971597671509], "4": ["vertical_and_slash", 3500, 100, 
0.947630763053894], "5": ["vertical_and_slash", 100, 750, 0.99262934923172], "6": ["vertical_and_slash", 3500, 100, 0.9955495595932007], "7": ["vertical_and_slash", 3500, 100, 0.8609271049499512], "8": ["vertical_and_slash", 3500, 100, 0.974815845489502], "9": ["vertical_and_slash", 3500, 100, 0.9884821772575378], "10": ["vertical_and_slash", 3500, 100, 0.9901348352432251], "11": ["vertical_and_slash", 100, 750, 0.9968274831771851], "12": ["vertical_and_slash", 3500, 100, 0.9918603897094727], "13": ["vertical_and_slash", 500, 700, 0.9757610559463501], "14": ["vertical_and_slash", 3500, 100, 0.9900703430175781], "15": ["vertical_and_slash", 500, 700, 0.9938023090362549], "16": ["vertical_and_slash", 1000, 6096, 0.8913345336914062], "17": ["vertical_and_slash", 500, 700, 0.9903258681297302], "18": ["vertical_and_slash", 100, 750, 0.9566823244094849], "19": ["vertical_and_slash", 100, 750, 0.9777167439460754], "20": ["vertical_and_slash", 3500, 100, 0.9674810767173767], "21": ["vertical_and_slash", 100, 750, 0.9178389310836792], "22": ["vertical_and_slash", 100, 750, 0.9882655143737793], "23": ["vertical_and_slash", 100, 750, 0.9989043474197388], "24": ["vertical_and_slash", 1000, 6096, 0.8574219942092896], "25": ["vertical_and_slash", 3500, 100, 0.9944363236427307], "26": ["vertical_and_slash", 3500, 100, 0.9970851540565491], "27": ["vertical_and_slash", 500, 700, 0.9904334545135498], "28": ["vertical_and_slash", 3500, 100, 0.9851230978965759], "29": ["vertical_and_slash", 3500, 100, 0.9900650978088379], "30": ["vertical_and_slash", 3500, 100, 0.9743545055389404], "31": ["vertical_and_slash", 500, 700, 0.9190711975097656]}, {"0": ["vertical_and_slash", 100, 750, 0.9716458320617676], "1": ["vertical_and_slash", 3500, 100, 0.9384027719497681], "2": ["vertical_and_slash", 3500, 100, 0.9696847796440125], "3": ["vertical_and_slash", 3500, 100, 0.9812428951263428], "4": ["vertical_and_slash", 1000, 6096, 0.5853931903839111], "5": ["vertical_and_slash", 3500, 100, 
0.7994469404220581], "6": ["vertical_and_slash", 3500, 100, 0.9933062791824341], "7": ["vertical_and_slash", 3500, 100, 0.986369788646698], "8": ["vertical_and_slash", 3500, 100, 0.8895794153213501], "9": ["vertical_and_slash", 1000, 6096, 0.8238524794578552], "10": ["vertical_and_slash", 500, 700, 0.93126380443573], "11": ["vertical_and_slash", 3500, 100, 0.962100088596344], "12": ["vertical_and_slash", 3500, 100, 0.8438158631324768], "13": ["vertical_and_slash", 500, 700, 0.9969620108604431], "14": ["vertical_and_slash", 1000, 6096, 0.8904788494110107], "15": ["vertical_and_slash", 100, 750, 0.9925360679626465], "16": ["vertical_and_slash", 3500, 100, 0.9222993850708008], "17": ["vertical_and_slash", 1000, 6096, 0.6627880334854126], "18": ["vertical_and_slash", 1000, 6096, 0.8668970465660095], "19": ["vertical_and_slash", 3500, 100, 0.9340634346008301], "20": ["vertical_and_slash", 3500, 100, 0.9503065347671509], "21": ["vertical_and_slash", 3500, 100, 0.9436649680137634], "22": ["vertical_and_slash", 3500, 100, 0.9768727421760559], "23": ["vertical_and_slash", 100, 750, 0.988473653793335], "24": ["vertical_and_slash", 3500, 100, 0.8777113556861877], "25": ["vertical_and_slash", 3500, 100, 0.8750200271606445], "26": ["vertical_and_slash", 1000, 6096, 0.4957360625267029], "27": ["vertical_and_slash", 3500, 100, 0.9804278016090393], "28": ["vertical_and_slash", 1000, 6096, 0.8486401438713074], "29": ["vertical_and_slash", 3500, 100, 0.8954175114631653], "30": ["vertical_and_slash", 3500, 100, 0.9651874899864197], "31": ["vertical_and_slash", 3500, 100, 0.9620938301086426]}, {"0": ["vertical_and_slash", 100, 750, 0.920842707157135], "1": ["vertical_and_slash", 3500, 100, 0.7215947508811951], "2": ["vertical_and_slash", 3500, 100, 0.9858340620994568], "3": ["vertical_and_slash", 3500, 100, 0.7861597537994385], "4": ["vertical_and_slash", 3500, 100, 0.7639158964157104], "5": ["vertical_and_slash", 3500, 100, 0.887671947479248], "6": ["vertical_and_slash", 3500, 100, 
0.8891316652297974], "7": ["vertical_and_slash", 1000, 6096, 0.8906923532485962], "8": ["vertical_and_slash", 3500, 100, 0.8836961984634399], "9": ["vertical_and_slash", 3500, 100, 0.7728190422058105], "10": ["vertical_and_slash", 3500, 100, 0.9507467746734619], "11": ["vertical_and_slash", 500, 700, 0.7829118967056274], "12": ["vertical_and_slash", 100, 750, 0.8214483857154846], "13": ["vertical_and_slash", 3500, 100, 0.7196475863456726], "14": ["vertical_and_slash", 500, 700, 0.8691932559013367], "15": ["vertical_and_slash", 1000, 6096, 0.6569814085960388], "16": ["vertical_and_slash", 100, 750, 0.9087151288986206], "17": ["vertical_and_slash", 3500, 100, 0.7609643936157227], "18": ["vertical_and_slash", 3500, 100, 0.8670530319213867], "19": ["vertical_and_slash", 1000, 6096, 0.7779831290245056], "20": ["vertical_and_slash", 100, 750, 0.923963725566864], "21": ["vertical_and_slash", 1000, 6096, 0.5714190006256104], "22": ["vertical_and_slash", 500, 700, 0.6351447105407715], "23": ["vertical_and_slash", 100, 750, 0.870464026927948], "24": ["vertical_and_slash", 1000, 6096, 0.6272542476654053], "25": ["vertical_and_slash", 1000, 6096, 0.7302500009536743], "26": ["vertical_and_slash", 3500, 100, 0.9410015940666199], "27": ["vertical_and_slash", 3500, 100, 0.793304979801178], "28": ["vertical_and_slash", 1000, 6096, 0.837500274181366], "29": ["vertical_and_slash", 1000, 6096, 0.766721248626709], "30": ["vertical_and_slash", 1000, 6096, 0.7082650065422058], "31": ["vertical_and_slash", 3500, 100, 0.8947907090187073]}, {"0": ["vertical_and_slash", 100, 750, 0.8983681797981262], "1": ["vertical_and_slash", 1000, 6096, 0.9650430083274841], "2": ["vertical_and_slash", 500, 700, 0.9532706141471863], "3": ["vertical_and_slash", 3500, 100, 0.8198072910308838], "4": ["vertical_and_slash", 1000, 6096, 0.840558648109436], "5": ["vertical_and_slash", 3500, 100, 0.8227204084396362], "6": ["vertical_and_slash", 1000, 6096, 0.5979130268096924], "7": ["vertical_and_slash", 1000, 
6096, 0.7691975235939026], "8": ["vertical_and_slash", 1000, 6096, 0.8089779615402222], "9": ["vertical_and_slash", 100, 750, 0.8689324855804443], "10": ["vertical_and_slash", 100, 750, 0.8621079325675964], "11": ["vertical_and_slash", 500, 700, 0.9871177673339844], "12": ["vertical_and_slash", 1000, 6096, 0.9468575716018677], "13": ["vertical_and_slash", 100, 750, 0.9075571894645691], "14": ["vertical_and_slash", 1000, 6096, 0.911694347858429], "15": ["vertical_and_slash", 100, 750, 0.9817390441894531], "16": ["vertical_and_slash", 1000, 6096, 0.7491167783737183], "17": ["vertical_and_slash", 1000, 6096, 0.8255623579025269], "18": ["vertical_and_slash", 1000, 6096, 0.8701649308204651], "19": ["vertical_and_slash", 3500, 100, 0.838506817817688], "20": ["vertical_and_slash", 1000, 6096, 0.8749529123306274], "21": ["vertical_and_slash", 500, 700, 0.8783859610557556], "22": ["vertical_and_slash", 3500, 100, 0.9302544593811035], "23": ["vertical_and_slash", 100, 750, 0.9118035435676575], "24": ["vertical_and_slash", 1000, 6096, 0.7892093658447266], "25": ["vertical_and_slash", 100, 750, 0.904501736164093], "26": ["vertical_and_slash", 3500, 100, 0.947079598903656], "27": ["vertical_and_slash", 1000, 6096, 0.5719630718231201], "28": ["vertical_and_slash", 3500, 100, 0.9740545153617859], "29": ["vertical_and_slash", 100, 750, 0.8365178108215332], "30": ["vertical_and_slash", 3500, 100, 0.8893513083457947], "31": ["vertical_and_slash", 1000, 6096, 0.923209547996521]}]
|
minference/configs/Yi_9B_200k_kv_out_v32_fit_o_best_pattern.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
[{"0": ["vertical_and_slash", 1000, 4096, 12982], "1": ["vertical_and_slash", 1000, 4096, 54], "2": ["vertical_and_slash", 1000, 4096, 0], "3": ["vertical_and_slash", 1000, 4096, 5], "4": ["vertical_and_slash", 1000, 4096, 57], "5": ["vertical_and_slash", 1000, 4096, 93], "6": ["vertical_and_slash", 1000, 4096, 5], "7": ["vertical_and_slash", 1000, 4096, 0], "8": ["vertical_and_slash", 1000, 4096, 4], "9": ["vertical_and_slash", 1000, 4096, 8], "10": ["vertical_and_slash", 1000, 4096, 10020], "11": ["vertical_and_slash", 1000, 4096, 0], "12": ["vertical_and_slash", 1000, 4096, 222290], "13": ["vertical_and_slash", 1000, 4096, 162], "14": ["vertical_and_slash", 1000, 4096, 3], "15": ["vertical_and_slash", 1000, 4096, 11], "16": ["vertical_and_slash", 1000, 4096, 10], "17": ["vertical_and_slash", 1000, 4096, 4], "18": ["vertical_and_slash", 1000, 4096, 26297], "19": ["vertical_and_slash", 1000, 4096, 3], "20": ["vertical_and_slash", 1000, 4096, 0], "21": ["vertical_and_slash", 1000, 4096, 0], "22": ["vertical_and_slash", 1000, 4096, 1627], "23": ["vertical_and_slash", 1000, 4096, 7], "24": ["vertical_and_slash", 1000, 4096, 0], "25": ["vertical_and_slash", 1000, 4096, 859], "26": ["vertical_and_slash", 1000, 4096, 0], "27": ["vertical_and_slash", 1000, 4096, 0], "28": ["vertical_and_slash", 1000, 4096, 484], "29": ["vertical_and_slash", 1000, 4096, 1239], "30": ["vertical_and_slash", 1000, 4096, 0], "31": ["vertical_and_slash", 1000, 4096, 0]}, {"0": ["vertical_and_slash", 1000, 4096, 430388], "1": ["vertical_and_slash", 1000, 4096, 299591], "2": ["vertical_and_slash", 1000, 4096, 5802], "3": ["vertical_and_slash", 1000, 4096, 22390], "4": ["vertical_and_slash", 1000, 4096, 284950], "5": ["vertical_and_slash", 1000, 4096, 237516], "6": ["vertical_and_slash", 1000, 4096, 39541], "7": ["vertical_and_slash", 1000, 4096, 46216], "8": ["vertical_and_slash", 1000, 4096, 782645], "9": ["vertical_and_slash", 1000, 4096, 8], "10": ["vertical_and_slash", 1000, 4096, 18], "11": 
["vertical_and_slash", 1000, 4096, 18890], "12": ["vertical_and_slash", 1000, 4096, 141], "13": ["vertical_and_slash", 1000, 4096, 53457], "14": ["vertical_and_slash", 1000, 4096, 34], "15": ["vertical_and_slash", 1000, 4096, 0], "16": ["vertical_and_slash", 1000, 4096, 246481], "17": ["vertical_and_slash", 1000, 4096, 135148], "18": ["vertical_and_slash", 1000, 4096, 48561], "19": ["vertical_and_slash", 1000, 4096, 54785], "20": ["vertical_and_slash", 1000, 4096, 95382], "21": ["vertical_and_slash", 1000, 4096, 387], "22": ["vertical_and_slash", 1000, 4096, 1750], "23": ["vertical_and_slash", 1000, 4096, 201661], "24": ["vertical_and_slash", 1000, 4096, 51272], "25": ["vertical_and_slash", 1000, 4096, 115255], "26": ["vertical_and_slash", 1000, 4096, 6], "27": ["vertical_and_slash", 1000, 4096, 6895], "28": ["vertical_and_slash", 1000, 4096, 2335], "29": ["vertical_and_slash", 1000, 4096, 23041], "30": ["vertical_and_slash", 1000, 4096, 6270087], "31": ["vertical_and_slash", 1000, 4096, 0]}, {"0": ["vertical_and_slash", 100, 800, 11], "1": ["vertical_and_slash", 30, 800, 5], "2": ["vertical_and_slash", 30, 800, 2790], "3": ["vertical_and_slash", 30, 800, 37], "4": ["vertical_and_slash", 30, 800, 2903], "5": ["vertical_and_slash", 30, 800, 1], "6": ["vertical_and_slash", 30, 800, 101], "7": ["vertical_and_slash", 100, 800, 16677], "8": ["vertical_and_slash", 1000, 4096, 99796], "9": ["vertical_and_slash", 30, 800, 8116], "10": ["vertical_and_slash", 30, 800, 1993], "11": ["vertical_and_slash", 1000, 4096, 2561], "12": ["vertical_and_slash", 30, 800, 21], "13": ["vertical_and_slash", 30, 800, 9624], "14": ["vertical_and_slash", 1000, 4096, 3894510], "15": ["vertical_and_slash", 1000, 4096, 66775], "16": ["vertical_and_slash", 30, 800, 1569], "17": ["vertical_and_slash", 1000, 4096, 146958], "18": ["vertical_and_slash", 30, 800, 29976], "19": ["vertical_and_slash", 1000, 4096, 269566], "20": ["vertical_and_slash", 100, 800, 50639], "21": ["vertical_and_slash", 30, 
800, 114641], "22": ["vertical_and_slash", 1000, 4096, 238607], "23": ["vertical_and_slash", 100, 800, 302385], "24": ["vertical_and_slash", 1000, 4096, 4893], "25": ["vertical_and_slash", 30, 800, 322], "26": ["vertical_and_slash", 1000, 4096, 3639], "27": ["vertical_and_slash", 100, 800, 131], "28": ["vertical_and_slash", 1000, 4096, 348560], "29": ["vertical_and_slash", 1000, 4096, 14611], "30": ["vertical_and_slash", 30, 800, 86], "31": ["vertical_and_slash", 1000, 4096, 900]}, {"0": ["vertical_and_slash", 100, 800, 64], "1": ["vertical_and_slash", 1000, 4096, 10], "2": ["vertical_and_slash", 500, 700, 77], "3": ["vertical_and_slash", 1000, 4096, 4193], "4": ["vertical_and_slash", 100, 800, 83525], "5": ["vertical_and_slash", 1000, 4096, 6], "6": ["vertical_and_slash", 1000, 4096, 27907], "7": ["vertical_and_slash", 1000, 4096, 42], "8": ["vertical_and_slash", 30, 800, 21349], "9": ["vertical_and_slash", 30, 800, 5018], "10": ["vertical_and_slash", 30, 800, 1663], "11": ["vertical_and_slash", 30, 800, 86902], "12": ["vertical_and_slash", 30, 800, 781], "13": ["vertical_and_slash", 100, 800, 339811], "14": ["vertical_and_slash", 100, 800, 696206], "15": ["vertical_and_slash", 30, 800, 47681], "16": ["vertical_and_slash", 30, 800, 4251], "17": ["vertical_and_slash", 1000, 4096, 6373945], "18": ["vertical_and_slash", 100, 800, 289132], "19": ["vertical_and_slash", 1000, 4096, 10273], "20": ["vertical_and_slash", 1000, 4096, 457078], "21": ["vertical_and_slash", 1000, 4096, 1372461], "22": ["vertical_and_slash", 100, 800, 11108], "23": ["vertical_and_slash", 100, 800, 2979], "24": ["vertical_and_slash", 1000, 4096, 30365], "25": ["vertical_and_slash", 500, 700, 142429], "26": ["vertical_and_slash", 500, 700, 6300], "27": ["vertical_and_slash", 30, 800, 4711], "28": ["vertical_and_slash", 500, 700, 4810], "29": ["vertical_and_slash", 500, 700, 25571], "30": ["vertical_and_slash", 500, 700, 7924], "31": ["vertical_and_slash", 500, 700, 3337]}, {"0": 
["vertical_and_slash", 30, 800, 34678], "1": ["vertical_and_slash", 30, 800, 13104], "2": ["vertical_and_slash", 30, 800, 4929], "3": ["vertical_and_slash", 100, 800, 9351380], "4": ["vertical_and_slash", 100, 800, 333814], "5": ["vertical_and_slash", 100, 800, 603408], "6": ["vertical_and_slash", 30, 800, 18975], "7": ["vertical_and_slash", 30, 800, 8848], "8": ["vertical_and_slash", 100, 800, 1690132], "9": ["vertical_and_slash", 30, 800, 59610], "10": ["vertical_and_slash", 500, 700, 1234], "11": ["vertical_and_slash", 1000, 4096, 74422], "12": ["vertical_and_slash", 1000, 4096, 504212], "13": ["vertical_and_slash", 30, 800, 3100], "14": ["vertical_and_slash", 100, 800, 1160], "15": ["vertical_and_slash", 500, 700, 5784], "16": ["vertical_and_slash", 30, 800, 18695], "17": ["vertical_and_slash", 30, 800, 2090], "18": ["vertical_and_slash", 30, 800, 28562], "19": ["vertical_and_slash", 30, 800, 34339], "20": ["vertical_and_slash", 30, 800, 2544], "21": ["vertical_and_slash", 30, 800, 1914], "22": ["vertical_and_slash", 30, 800, 83258], "23": ["vertical_and_slash", 30, 800, 7898], "24": ["vertical_and_slash", 30, 800, 11609], "25": ["vertical_and_slash", 1000, 4096, 64138], "26": ["vertical_and_slash", 1000, 4096, 514471], "27": ["vertical_and_slash", 500, 700, 39930], "28": ["vertical_and_slash", 30, 800, 477456], "29": ["vertical_and_slash", 100, 800, 4526], "30": ["vertical_and_slash", 1000, 4096, 30006], "31": ["vertical_and_slash", 30, 800, 92845]}, {"0": ["vertical_and_slash", 30, 800, 55378], "1": ["vertical_and_slash", 1000, 4096, 17441], "2": ["vertical_and_slash", 100, 800, 1890658], "3": ["vertical_and_slash", 30, 800, 39922], "4": ["vertical_and_slash", 30, 800, 3841], "5": ["vertical_and_slash", 30, 800, 16402], "6": ["vertical_and_slash", 30, 800, 9274], "7": ["vertical_and_slash", 100, 800, 2756], "8": ["vertical_and_slash", 100, 800, 190896], "9": ["vertical_and_slash", 1000, 4096, 30060], "10": ["vertical_and_slash", 1000, 4096, 1123342], "11": 
["vertical_and_slash", 1000, 4096, 260812], "12": ["vertical_and_slash", 1000, 4096, 4395769], "13": ["vertical_and_slash", 1000, 4096, 1803359], "14": ["vertical_and_slash", 30, 800, 17625], "15": ["vertical_and_slash", 1000, 4096, 1501177], "16": ["vertical_and_slash", 1000, 4096, 236955], "17": ["vertical_and_slash", 1000, 4096, 27239], "18": ["vertical_and_slash", 1000, 4096, 84045], "19": ["vertical_and_slash", 1000, 4096, 112395], "20": ["vertical_and_slash", 1000, 4096, 289351], "21": ["vertical_and_slash", 1000, 4096, 1200493], "22": ["vertical_and_slash", 100, 800, 5628], "23": ["vertical_and_slash", 1000, 4096, 53], "24": ["vertical_and_slash", 30, 800, 1001179], "25": ["vertical_and_slash", 1000, 4096, 1417294], "26": ["vertical_and_slash", 100, 800, 712290], "27": ["vertical_and_slash", 1000, 4096, 111462], "28": ["vertical_and_slash", 1000, 4096, 2382091], "29": ["vertical_and_slash", 30, 800, 10632], "30": ["vertical_and_slash", 100, 800, 404628], "31": ["vertical_and_slash", 1000, 4096, 36025]}, {"0": ["vertical_and_slash", 1000, 4096, 683931], "1": ["vertical_and_slash", 1000, 4096, 1978224], "2": ["vertical_and_slash", 30, 800, 529064], "3": ["vertical_and_slash", 30, 800, 20483], "4": ["vertical_and_slash", 30, 800, 226587], "5": ["vertical_and_slash", 30, 800, 100650], "6": ["vertical_and_slash", 30, 800, 88814], "7": ["vertical_and_slash", 30, 800, 25415], "8": ["vertical_and_slash", 1000, 4096, 126846], "9": ["vertical_and_slash", 100, 800, 83585], "10": ["vertical_and_slash", 1000, 4096, 53117], "11": ["vertical_and_slash", 1000, 4096, 30196], "12": ["vertical_and_slash", 1000, 4096, 81511], "13": ["vertical_and_slash", 1000, 4096, 25087], "14": ["vertical_and_slash", 1000, 4096, 52332], "15": ["vertical_and_slash", 1000, 4096, 1662596], "16": ["vertical_and_slash", 30, 800, 26199], "17": ["vertical_and_slash", 30, 800, 72420], "18": ["vertical_and_slash", 30, 800, 74770], "19": ["vertical_and_slash", 30, 800, 94064], "20": 
["vertical_and_slash", 30, 800, 10369], "21": ["vertical_and_slash", 1000, 4096, 2802268], "22": ["vertical_and_slash", 30, 800, 32077], "23": ["vertical_and_slash", 500, 700, 751949], "24": ["vertical_and_slash", 100, 800, 23111], "25": ["vertical_and_slash", 100, 800, 13161], "26": ["vertical_and_slash", 100, 800, 164196], "27": ["vertical_and_slash", 1000, 4096, 12766], "28": ["vertical_and_slash", 1000, 4096, 37748], "29": ["vertical_and_slash", 1000, 4096, 394580], "30": ["vertical_and_slash", 30, 800, 1161581], "31": ["vertical_and_slash", 1000, 4096, 1070988]}, {"0": ["vertical_and_slash", 100, 800, 4619], "1": ["vertical_and_slash", 1000, 4096, 3223], "2": ["vertical_and_slash", 100, 800, 65675], "3": ["vertical_and_slash", 30, 800, 56], "4": ["vertical_and_slash", 30, 800, 93], "5": ["vertical_and_slash", 30, 800, 72], "6": ["vertical_and_slash", 500, 700, 3523], "7": ["vertical_and_slash", 1000, 4096, 12230], "8": ["vertical_and_slash", 100, 800, 9301307], "9": ["vertical_and_slash", 1000, 4096, 418350], "10": ["vertical_and_slash", 1000, 4096, 994569], "11": ["vertical_and_slash", 100, 800, 399778], "12": ["vertical_and_slash", 1000, 4096, 2677334], "13": ["vertical_and_slash", 1000, 4096, 409432], "14": ["vertical_and_slash", 30, 800, 1233050], "15": ["vertical_and_slash", 1000, 4096, 5697704], "16": ["vertical_and_slash", 100, 800, 294], "17": ["vertical_and_slash", 30, 800, 50017], "18": ["vertical_and_slash", 30, 800, 1566], "19": ["vertical_and_slash", 30, 800, 2368], "20": ["vertical_and_slash", 30, 800, 3051012], "21": ["vertical_and_slash", 1000, 4096, 15983], "22": ["vertical_and_slash", 1000, 4096, 48], "23": ["vertical_and_slash", 1000, 4096, 312543], "24": ["vertical_and_slash", 30, 800, 4820], "25": ["vertical_and_slash", 30, 800, 100931], "26": ["vertical_and_slash", 30, 800, 69743], "27": ["vertical_and_slash", 30, 800, 22187], "28": ["vertical_and_slash", 30, 800, 3936], "29": ["vertical_and_slash", 30, 800, 4611], "30": 
["vertical_and_slash", 30, 800, 21928], "31": ["vertical_and_slash", 30, 800, 133206]}, {"0": ["vertical_and_slash", 100, 800, 41811], "1": ["vertical_and_slash", 30, 800, 4226], "2": ["vertical_and_slash", 100, 800, 11930], "3": ["vertical_and_slash", 30, 800, 629146], "4": ["vertical_and_slash", 100, 800, 511736], "5": ["vertical_and_slash", 100, 800, 1408], "6": ["vertical_and_slash", 30, 800, 18012], "7": ["vertical_and_slash", 30, 800, 897], "8": ["vertical_and_slash", 30, 800, 107705], "9": ["vertical_and_slash", 30, 800, 152957], "10": ["vertical_and_slash", 30, 800, 272002], "11": ["vertical_and_slash", 30, 800, 5216722], "12": ["vertical_and_slash", 30, 800, 509504], "13": ["vertical_and_slash", 30, 800, 72091], "14": ["vertical_and_slash", 30, 800, 166293], "15": ["vertical_and_slash", 30, 800, 426344], "16": ["vertical_and_slash", 30, 800, 316624], "17": ["vertical_and_slash", 1000, 4096, 158902], "18": ["vertical_and_slash", 30, 800, 162502], "19": ["vertical_and_slash", 1000, 4096, 2464314], "20": ["vertical_and_slash", 1000, 4096, 5817909], "21": ["vertical_and_slash", 100, 800, 1141235], "22": ["vertical_and_slash", 30, 800, 452577], "23": ["vertical_and_slash", 30, 800, 193960], "24": ["vertical_and_slash", 30, 800, 538157], "25": ["vertical_and_slash", 30, 800, 1355759], "26": ["vertical_and_slash", 100, 800, 141236], "27": ["vertical_and_slash", 30, 800, 87608], "28": ["vertical_and_slash", 30, 800, 102946], "29": ["vertical_and_slash", 30, 800, 81254], "30": ["vertical_and_slash", 30, 800, 6194794], "31": ["vertical_and_slash", 30, 800, 2092660]}, {"0": ["vertical_and_slash", 30, 800, 278589], "1": ["vertical_and_slash", 30, 800, 1071731], "2": ["vertical_and_slash", 30, 800, 1991650], "3": ["vertical_and_slash", 30, 800, 308703], "4": ["vertical_and_slash", 30, 800, 1024242], "5": ["vertical_and_slash", 30, 800, 3107957], "6": ["vertical_and_slash", 30, 800, 926801], "7": ["vertical_and_slash", 30, 800, 2887199], "8": ["vertical_and_slash", 
1000, 4096, 4152662], "9": ["vertical_and_slash", 100, 800, 15773492], "10": ["vertical_and_slash", 30, 800, 667496], "11": ["vertical_and_slash", 30, 800, 767325], "12": ["vertical_and_slash", 30, 800, 490706], "13": ["vertical_and_slash", 100, 800, 3083166], "14": ["vertical_and_slash", 100, 800, 14433242], "15": ["vertical_and_slash", 30, 800, 514502], "16": ["vertical_and_slash", 1000, 4096, 4574900], "17": ["vertical_and_slash", 1000, 4096, 1828093], "18": ["vertical_and_slash", 30, 800, 3790483], "19": ["vertical_and_slash", 1000, 4096, 9164424], "20": ["vertical_and_slash", 1000, 4096, 1011346], "21": ["vertical_and_slash", 1000, 4096, 1768867], "22": ["vertical_and_slash", 100, 800, 3253894], "23": ["vertical_and_slash", 1000, 4096, 882663], "24": ["vertical_and_slash", 100, 800, 1974998], "25": ["vertical_and_slash", 500, 700, 1452483], "26": ["vertical_and_slash", 100, 800, 12992816], "27": ["vertical_and_slash", 1000, 4096, 4441511], "28": ["vertical_and_slash", 100, 800, 3146531], "29": ["vertical_and_slash", 1000, 4096, 7002295], "30": ["vertical_and_slash", 100, 800, 7974855], "31": ["vertical_and_slash", 1000, 4096, 2767293]}, {"0": ["vertical_and_slash", 30, 800, 517042], "1": ["vertical_and_slash", 30, 800, 9471250], "2": ["vertical_and_slash", 30, 800, 67128], "3": ["vertical_and_slash", 100, 800, 13225828], "4": ["vertical_and_slash", 1000, 4096, 8138531], "5": ["vertical_and_slash", 30, 800, 169424], "6": ["vertical_and_slash", 30, 800, 165102], "7": ["vertical_and_slash", 1000, 4096, 898000], "8": ["vertical_and_slash", 100, 800, 498306], "9": ["vertical_and_slash", 100, 800, 12016777], "10": ["vertical_and_slash", 100, 800, 13078398], "11": ["vertical_and_slash", 1000, 4096, 569449], "12": ["vertical_and_slash", 1000, 4096, 4419468], "13": ["vertical_and_slash", 100, 800, 2308923], "14": ["vertical_and_slash", 100, 800, 188999], "15": ["vertical_and_slash", 30, 800, 685736], "16": ["vertical_and_slash", 100, 800, 161819], "17": 
["vertical_and_slash", 100, 800, 1878966], "18": ["vertical_and_slash", 100, 800, 7840855], "19": ["vertical_and_slash", 30, 800, 207320], "20": ["vertical_and_slash", 100, 800, 2233365], "21": ["vertical_and_slash", 100, 800, 685239], "22": ["vertical_and_slash", 1000, 4096, 1493618], "23": ["vertical_and_slash", 30, 800, 1137958], "24": ["vertical_and_slash", 30, 800, 115113], "25": ["vertical_and_slash", 30, 800, 809754], "26": ["vertical_and_slash", 30, 800, 1328591], "27": ["vertical_and_slash", 30, 800, 697970], "28": ["vertical_and_slash", 1000, 4096, 14409], "29": ["vertical_and_slash", 30, 800, 376399], "30": ["vertical_and_slash", 30, 800, 71599], "31": ["vertical_and_slash", 30, 800, 431162]}, {"0": ["vertical_and_slash", 30, 800, 7073664], "1": ["vertical_and_slash", 100, 800, 4139486], "2": ["vertical_and_slash", 30, 800, 126298], "3": ["vertical_and_slash", 30, 800, 626891], "4": ["vertical_and_slash", 1000, 4096, 244457], "5": ["vertical_and_slash", 30, 800, 338124], "6": ["vertical_and_slash", 100, 800, 4247346], "7": ["vertical_and_slash", 100, 800, 1853876], "8": ["vertical_and_slash", 1000, 4096, 6355420], "9": ["vertical_and_slash", 100, 800, 988264], "10": ["vertical_and_slash", 1000, 4096, 984583], "11": ["vertical_and_slash", 100, 800, 914211], "12": ["vertical_and_slash", 1000, 4096, 570502], "13": ["vertical_and_slash", 1000, 4096, 10187572], "14": ["vertical_and_slash", 1000, 4096, 3408578], "15": ["vertical_and_slash", 1000, 4096, 11375984], "16": ["vertical_and_slash", 100, 800, 5144098], "17": ["vertical_and_slash", 1000, 4096, 350031], "18": ["vertical_and_slash", 1000, 4096, 1299268], "19": ["vertical_and_slash", 1000, 4096, 790117], "20": ["vertical_and_slash", 100, 800, 24094], "21": ["vertical_and_slash", 30, 800, 3856442], "22": ["vertical_and_slash", 100, 800, 383726], "23": ["vertical_and_slash", 500, 700, 832], "24": ["vertical_and_slash", 100, 800, 7717427], "25": ["vertical_and_slash", 1000, 4096, 4545251], "26": 
["vertical_and_slash", 30, 800, 7922478], "27": ["vertical_and_slash", 1000, 4096, 2809849], "28": ["vertical_and_slash", 1000, 4096, 4392930], "29": ["vertical_and_slash", 100, 800, 2998060], "30": ["vertical_and_slash", 100, 800, 6173903], "31": ["vertical_and_slash", 1000, 4096, 2536227]}, {"0": ["vertical_and_slash", 30, 800, 1733117], "1": ["vertical_and_slash", 100, 800, 2514524], "2": ["vertical_and_slash", 1000, 4096, 12567570], "3": ["vertical_and_slash", 1000, 4096, 2817534], "4": ["vertical_and_slash", 1000, 4096, 10571712], "5": ["vertical_and_slash", 100, 800, 1311331], "6": ["vertical_and_slash", 30, 800, 4202358], "7": ["vertical_and_slash", 30, 800, 4970102], "8": ["vertical_and_slash", 30, 800, 88687], "9": ["vertical_and_slash", 30, 800, 293880], "10": ["vertical_and_slash", 500, 700, 70693], "11": ["vertical_and_slash", 30, 800, 13849], "12": ["vertical_and_slash", 30, 800, 238706], "13": ["vertical_and_slash", 30, 800, 78435], "14": ["vertical_and_slash", 30, 800, 164251], "15": ["vertical_and_slash", 30, 800, 199789], "16": ["vertical_and_slash", 30, 800, 200684], "17": ["vertical_and_slash", 1000, 4096, 1761919], "18": ["vertical_and_slash", 30, 800, 210071], "19": ["vertical_and_slash", 30, 800, 68554], "20": ["vertical_and_slash", 30, 800, 484345], "21": ["vertical_and_slash", 30, 800, 1489873], "22": ["vertical_and_slash", 30, 800, 301028], "23": ["vertical_and_slash", 30, 800, 1124431], "24": ["vertical_and_slash", 100, 800, 636179], "25": ["vertical_and_slash", 100, 800, 611008], "26": ["vertical_and_slash", 1000, 4096, 1639], "27": ["vertical_and_slash", 1000, 4096, 8255730], "28": ["vertical_and_slash", 1000, 4096, 6678469], "29": ["vertical_and_slash", 1000, 4096, 628985], "30": ["vertical_and_slash", 1000, 4096, 348316], "31": ["vertical_and_slash", 1000, 4096, 2159698]}, {"0": ["vertical_and_slash", 100, 800, 7105558], "1": ["vertical_and_slash", 30, 800, 1085603], "2": ["vertical_and_slash", 1000, 4096, 7896209], "3": 
["vertical_and_slash", 30, 800, 193488], "4": ["vertical_and_slash", 100, 800, 1467223], "5": ["vertical_and_slash", 30, 800, 13794329], "6": ["vertical_and_slash", 1000, 4096, 15661583], "7": ["vertical_and_slash", 1000, 4096, 21334871], "8": ["vertical_and_slash", 1000, 4096, 6158120], "9": ["vertical_and_slash", 1000, 4096, 7414022], "10": ["vertical_and_slash", 100, 800, 14091447], "11": ["vertical_and_slash", 1000, 4096, 15589771], "12": ["vertical_and_slash", 1000, 4096, 14632639], "13": ["vertical_and_slash", 100, 800, 1695539], "14": ["vertical_and_slash", 30, 800, 2605978], "15": ["vertical_and_slash", 1000, 4096, 12495330], "16": ["vertical_and_slash", 1000, 4096, 14564586], "17": ["vertical_and_slash", 500, 700, 962969], "18": ["vertical_and_slash", 1000, 4096, 12281016], "19": ["vertical_and_slash", 1000, 4096, 4614742], "20": ["vertical_and_slash", 100, 800, 11940535], "21": ["vertical_and_slash", 100, 800, 2445981], "22": ["vertical_and_slash", 100, 800, 2485005], "23": ["vertical_and_slash", 1000, 4096, 6864324], "24": ["vertical_and_slash", 1000, 4096, 16230551], "25": ["vertical_and_slash", 100, 800, 9358656], "26": ["vertical_and_slash", 100, 800, 14973598], "27": ["vertical_and_slash", 1000, 4096, 14250781], "28": ["vertical_and_slash", 1000, 4096, 18030248], "29": ["vertical_and_slash", 1000, 4096, 20247786], "30": ["vertical_and_slash", 1000, 4096, 12736495], "31": ["vertical_and_slash", 100, 800, 9012943]}, {"0": ["vertical_and_slash", 100, 800, 4792757], "1": ["vertical_and_slash", 100, 800, 5568805], "2": ["vertical_and_slash", 1000, 4096, 12086343], "3": ["vertical_and_slash", 100, 800, 7359182], "4": ["vertical_and_slash", 100, 800, 13719718], "5": ["vertical_and_slash", 1000, 4096, 17051068], "6": ["vertical_and_slash", 100, 800, 15947388], "7": ["vertical_and_slash", 1000, 4096, 9143327], "8": ["vertical_and_slash", 1000, 4096, 21263361], "9": ["vertical_and_slash", 1000, 4096, 17189141], "10": ["vertical_and_slash", 1000, 4096, 
7802422], "11": ["vertical_and_slash", 1000, 4096, 18488560], "12": ["vertical_and_slash", 100, 800, 14938800], "13": ["vertical_and_slash", 100, 800, 11012944], "14": ["vertical_and_slash", 1000, 4096, 19104830], "15": ["vertical_and_slash", 3500, 100, 32379], "16": ["vertical_and_slash", 100, 800, 3067742], "17": ["vertical_and_slash", 100, 800, 1977488], "18": ["vertical_and_slash", 1000, 4096, 15351109], "19": ["vertical_and_slash", 30, 800, 1627281], "20": ["vertical_and_slash", 30, 800, 1280991], "21": ["vertical_and_slash", 100, 800, 12133497], "22": ["vertical_and_slash", 1000, 4096, 17870425], "23": ["vertical_and_slash", 30, 800, 4040253], "24": ["vertical_and_slash", 1000, 4096, 6272625], "25": ["vertical_and_slash", 100, 800, 1225145], "26": ["vertical_and_slash", 100, 800, 2746332], "27": ["vertical_and_slash", 100, 800, 4525182], "28": ["vertical_and_slash", 100, 800, 6274770], "29": ["vertical_and_slash", 100, 800, 6919161], "30": ["vertical_and_slash", 100, 800, 3456148], "31": ["vertical_and_slash", 100, 800, 23867]}, {"0": ["vertical_and_slash", 1000, 4096, 7275761], "1": ["vertical_and_slash", 100, 800, 5068315], "2": ["vertical_and_slash", 100, 800, 11162394], "3": ["vertical_and_slash", 100, 800, 3672939], "4": ["vertical_and_slash", 3500, 100, 20894613], "5": ["vertical_and_slash", 1000, 4096, 7938372], "6": ["vertical_and_slash", 100, 800, 12544912], "7": ["vertical_and_slash", 100, 800, 2008695], "8": ["vertical_and_slash", 1000, 4096, 3368310], "9": ["vertical_and_slash", 30, 800, 1508993], "10": ["vertical_and_slash", 1000, 4096, 3495386], "11": ["vertical_and_slash", 3500, 100, 16438193], "12": ["vertical_and_slash", 100, 800, 7069375], "13": ["vertical_and_slash", 100, 800, 10686684], "14": ["vertical_and_slash", 30, 800, 501489], "15": ["vertical_and_slash", 100, 800, 6067001], "16": ["vertical_and_slash", 100, 800, 6935788], "17": ["vertical_and_slash", 1000, 4096, 3300792], "18": ["vertical_and_slash", 100, 800, 7398154], "19": 
["vertical_and_slash", 100, 800, 5788636], "20": ["vertical_and_slash", 100, 800, 4456802], "21": ["vertical_and_slash", 100, 800, 2680176], "22": ["vertical_and_slash", 100, 800, 5544567], "23": ["vertical_and_slash", 1000, 4096, 13475356], "24": ["vertical_and_slash", 1000, 4096, 4901727], "25": ["vertical_and_slash", 1000, 4096, 3768996], "26": ["vertical_and_slash", 1000, 4096, 5368869], "27": ["vertical_and_slash", 3500, 100, 14218181], "28": ["vertical_and_slash", 1000, 4096, 13003444], "29": ["vertical_and_slash", 1000, 4096, 5716382], "30": ["vertical_and_slash", 3500, 100, 19916116], "31": ["vertical_and_slash", 1000, 4096, 11776798]}, {"0": ["vertical_and_slash", 100, 800, 13001986], "1": ["vertical_and_slash", 1000, 4096, 7570569], "2": ["vertical_and_slash", 100, 800, 951160], "3": ["vertical_and_slash", 100, 800, 11933179], "4": ["vertical_and_slash", 30, 800, 5365811], "5": ["vertical_and_slash", 100, 800, 10272574], "6": ["vertical_and_slash", 1000, 4096, 6527670], "7": ["vertical_and_slash", 100, 800, 12930014], "8": ["vertical_and_slash", 100, 800, 359537], "9": ["vertical_and_slash", 100, 800, 10654966], "10": ["vertical_and_slash", 100, 800, 1330316], "11": ["vertical_and_slash", 100, 800, 9971156], "12": ["vertical_and_slash", 1000, 4096, 5781478], "13": ["vertical_and_slash", 100, 800, 6032127], "14": ["vertical_and_slash", 100, 800, 1418329], "15": ["vertical_and_slash", 100, 800, 13069922], "16": ["vertical_and_slash", 100, 800, 8547563], "17": ["vertical_and_slash", 100, 800, 970921], "18": ["vertical_and_slash", 1000, 4096, 9256328], "19": ["vertical_and_slash", 1000, 4096, 12447206], "20": ["vertical_and_slash", 100, 800, 153856], "21": ["vertical_and_slash", 100, 800, 8022371], "22": ["vertical_and_slash", 3500, 100, 18626483], "23": ["vertical_and_slash", 100, 800, 3180643], "24": ["vertical_and_slash", 30, 800, 3549186], "25": ["vertical_and_slash", 100, 800, 2600992], "26": ["vertical_and_slash", 3500, 100, 21080570], "27": 
["vertical_and_slash", 1000, 4096, 2995096], "28": ["vertical_and_slash", 30, 800, 13324952], "29": ["vertical_and_slash", 100, 800, 7015426], "30": ["vertical_and_slash", 100, 800, 17142326], "31": ["vertical_and_slash", 30, 800, 2059831]}, {"0": ["vertical_and_slash", 100, 800, 336984], "1": ["vertical_and_slash", 1000, 4096, 11908787], "2": ["vertical_and_slash", 1000, 4096, 11465673], "3": ["vertical_and_slash", 1000, 4096, 3870378], "4": ["vertical_and_slash", 1000, 4096, 1000373], "5": ["vertical_and_slash", 1000, 4096, 6450804], "6": ["vertical_and_slash", 1000, 4096, 6602987], "7": ["vertical_and_slash", 1000, 4096, 6552477], "8": ["vertical_and_slash", 30, 800, 8671938], "9": ["vertical_and_slash", 100, 800, 3906764], "10": ["vertical_and_slash", 1000, 4096, 7300294], "11": ["vertical_and_slash", 100, 800, 9068418], "12": ["vertical_and_slash", 100, 800, 5573415], "13": ["vertical_and_slash", 100, 800, 4302354], "14": ["vertical_and_slash", 30, 800, 969401], "15": ["vertical_and_slash", 100, 800, 132492], "16": ["vertical_and_slash", 1000, 4096, 10575265], "17": ["vertical_and_slash", 30, 800, 114557], "18": ["vertical_and_slash", 1000, 4096, 1669778], "19": ["vertical_and_slash", 30, 800, 244697], "20": ["vertical_and_slash", 30, 800, 401989], "21": ["vertical_and_slash", 1000, 4096, 257876], "22": ["vertical_and_slash", 100, 800, 1656276], "23": ["vertical_and_slash", 100, 800, 6627755], "24": ["vertical_and_slash", 100, 800, 17069094], "25": ["vertical_and_slash", 1000, 4096, 17310922], "26": ["vertical_and_slash", 3500, 100, 19238326], "27": ["vertical_and_slash", 100, 800, 10416201], "28": ["vertical_and_slash", 1000, 4096, 9125015], "29": ["vertical_and_slash", 100, 800, 17113558], "30": ["vertical_and_slash", 1000, 4096, 12041930], "31": ["vertical_and_slash", 1000, 4096, 6060396]}, {"0": ["vertical_and_slash", 1000, 4096, 9259982], "1": ["vertical_and_slash", 1000, 4096, 8618567], "2": ["vertical_and_slash", 100, 800, 3876940], "3": 
["vertical_and_slash", 1000, 4096, 12767960], "4": ["vertical_and_slash", 1000, 4096, 6112941], "5": ["vertical_and_slash", 1000, 4096, 9851048], "6": ["vertical_and_slash", 1000, 4096, 5763271], "7": ["vertical_and_slash", 1000, 4096, 12744434], "8": ["vertical_and_slash", 100, 800, 12512293], "9": ["vertical_and_slash", 1000, 4096, 2367543], "10": ["vertical_and_slash", 100, 800, 12342103], "11": ["vertical_and_slash", 100, 800, 3126675], "12": ["vertical_and_slash", 1000, 4096, 13617286], "13": ["vertical_and_slash", 1000, 4096, 8094518], "14": ["vertical_and_slash", 1000, 4096, 851614], "15": ["vertical_and_slash", 1000, 4096, 10519480], "16": ["vertical_and_slash", 100, 800, 1706372], "17": ["vertical_and_slash", 100, 800, 248757], "18": ["vertical_and_slash", 100, 800, 4394336], "19": ["vertical_and_slash", 100, 800, 1886529], "20": ["vertical_and_slash", 1000, 4096, 6486541], "21": ["vertical_and_slash", 100, 800, 1175436], "22": ["vertical_and_slash", 100, 800, 7864652], "23": ["vertical_and_slash", 100, 800, 1001917], "24": ["vertical_and_slash", 100, 800, 2494293], "25": ["vertical_and_slash", 1000, 4096, 7698995], "26": ["vertical_and_slash", 100, 800, 2946712], "27": ["vertical_and_slash", 100, 800, 5464103], "28": ["vertical_and_slash", 100, 800, 2608538], "29": ["vertical_and_slash", 100, 800, 1606308], "30": ["vertical_and_slash", 1000, 4096, 5981702], "31": ["vertical_and_slash", 3500, 100, 18590832]}, {"0": ["vertical_and_slash", 100, 800, 4688244], "1": ["vertical_and_slash", 100, 800, 11368272], "2": ["vertical_and_slash", 100, 800, 2558719], "3": ["vertical_and_slash", 1000, 4096, 9536926], "4": ["vertical_and_slash", 1000, 4096, 12315283], "5": ["vertical_and_slash", 1000, 4096, 6272119], "6": ["vertical_and_slash", 1000, 4096, 4450200], "7": ["vertical_and_slash", 100, 800, 5822568], "8": ["vertical_and_slash", 1000, 4096, 13523232], "9": ["vertical_and_slash", 100, 800, 816607], "10": ["vertical_and_slash", 1000, 4096, 15825338], "11": 
["vertical_and_slash", 100, 800, 1133867], "12": ["vertical_and_slash", 100, 800, 10722989], "13": ["vertical_and_slash", 100, 800, 2466001], "14": ["vertical_and_slash", 100, 800, 16732584], "15": ["vertical_and_slash", 100, 800, 1052553], "16": ["vertical_and_slash", 100, 800, 8602649], "17": ["vertical_and_slash", 100, 800, 8851217], "18": ["vertical_and_slash", 100, 800, 6104130], "19": ["vertical_and_slash", 1000, 4096, 18459502], "20": ["vertical_and_slash", 100, 800, 8076967], "21": ["vertical_and_slash", 1000, 4096, 4863209], "22": ["vertical_and_slash", 1000, 4096, 8892415], "23": ["vertical_and_slash", 1000, 4096, 9542798], "24": ["vertical_and_slash", 100, 800, 1384183], "25": ["vertical_and_slash", 100, 800, 4035455], "26": ["vertical_and_slash", 100, 800, 536763], "27": ["vertical_and_slash", 1000, 4096, 2058585], "28": ["vertical_and_slash", 100, 800, 4195607], "29": ["vertical_and_slash", 100, 800, 2407136], "30": ["vertical_and_slash", 100, 800, 2106926], "31": ["vertical_and_slash", 100, 800, 3807607]}, {"0": ["vertical_and_slash", 100, 800, 15975096], "1": ["vertical_and_slash", 3500, 100, 20664973], "2": ["vertical_and_slash", 1000, 4096, 943914], "3": ["vertical_and_slash", 100, 800, 14363276], "4": ["vertical_and_slash", 100, 800, 720326], "5": ["vertical_and_slash", 1000, 4096, 7725879], "6": ["vertical_and_slash", 1000, 4096, 11411255], "7": ["vertical_and_slash", 1000, 4096, 9492657], "8": ["vertical_and_slash", 1000, 4096, 16448227], "9": ["vertical_and_slash", 100, 800, 6180918], "10": ["vertical_and_slash", 1000, 4096, 10942342], "11": ["vertical_and_slash", 1000, 4096, 12047657], "12": ["vertical_and_slash", 100, 800, 2376658], "13": ["vertical_and_slash", 1000, 4096, 17780083], "14": ["vertical_and_slash", 1000, 4096, 8548356], "15": ["vertical_and_slash", 100, 800, 4545880], "16": ["vertical_and_slash", 30, 800, 2020350], "17": ["vertical_and_slash", 100, 800, 15875867], "18": ["vertical_and_slash", 30, 800, 661201], "19": 
["vertical_and_slash", 1000, 4096, 14915782], "20": ["vertical_and_slash", 100, 800, 4106388], "21": ["vertical_and_slash", 30, 800, 14163451], "22": ["vertical_and_slash", 100, 800, 1759639], "23": ["vertical_and_slash", 1000, 4096, 2391070], "24": ["vertical_and_slash", 100, 800, 10749758], "25": ["vertical_and_slash", 100, 800, 8022438], "26": ["vertical_and_slash", 100, 800, 1013941], "27": ["vertical_and_slash", 100, 800, 3537516], "28": ["vertical_and_slash", 100, 800, 1252545], "29": ["vertical_and_slash", 100, 800, 1155740], "30": ["vertical_and_slash", 1000, 4096, 2590667], "31": ["vertical_and_slash", 100, 800, 3320946]}, {"0": ["vertical_and_slash", 1000, 4096, 8025205], "1": ["vertical_and_slash", 500, 700, 2286667], "2": ["vertical_and_slash", 1000, 4096, 2104863], "3": ["vertical_and_slash", 1000, 4096, 2160060], "4": ["vertical_and_slash", 1000, 4096, 4209178], "5": ["vertical_and_slash", 1000, 4096, 5703899], "6": ["vertical_and_slash", 100, 800, 15566139], "7": ["vertical_and_slash", 500, 700, 464012], "8": ["vertical_and_slash", 1000, 4096, 632556], "9": ["vertical_and_slash", 1000, 4096, 10933130], "10": ["vertical_and_slash", 3500, 100, 6376023], "11": ["vertical_and_slash", 30, 800, 53293], "12": ["vertical_and_slash", 3500, 100, 9195722], "13": ["vertical_and_slash", 100, 800, 130891], "14": ["vertical_and_slash", 100, 800, 1266310], "15": ["vertical_and_slash", 100, 800, 12042893], "16": ["vertical_and_slash", 100, 800, 1440252], "17": ["vertical_and_slash", 100, 800, 5003178], "18": ["vertical_and_slash", 100, 800, 9451180], "19": ["vertical_and_slash", 100, 800, 16518635], "20": ["vertical_and_slash", 1000, 4096, 16574448], "21": ["vertical_and_slash", 100, 800, 10001073], "22": ["vertical_and_slash", 100, 800, 6194150], "23": ["vertical_and_slash", 100, 800, 1990080], "24": ["vertical_and_slash", 100, 800, 14105574], "25": ["vertical_and_slash", 3500, 100, 49578], "26": ["vertical_and_slash", 100, 800, 1368613], "27": 
["vertical_and_slash", 100, 800, 882483], "28": ["vertical_and_slash", 100, 800, 200592], "29": ["vertical_and_slash", 100, 800, 4144857], "30": ["vertical_and_slash", 30, 800, 2059620], "31": ["vertical_and_slash", 1000, 4096, 7650136]}, {"0": ["vertical_and_slash", 3500, 100, 20200147], "1": ["vertical_and_slash", 100, 800, 18033672], "2": ["vertical_and_slash", 100, 800, 19227421], "3": ["vertical_and_slash", 1000, 4096, 7658465], "4": ["vertical_and_slash", 100, 800, 4862174], "5": ["vertical_and_slash", 100, 800, 6197824], "6": ["vertical_and_slash", 100, 800, 5687873], "7": ["vertical_and_slash", 100, 800, 13005015], "8": ["vertical_and_slash", 1000, 4096, 6677727], "9": ["vertical_and_slash", 500, 700, 1282697], "10": ["vertical_and_slash", 30, 800, 3148411], "11": ["vertical_and_slash", 500, 700, 8985965], "12": ["vertical_and_slash", 100, 800, 11107850], "13": ["vertical_and_slash", 30, 800, 2077544], "14": ["vertical_and_slash", 1000, 4096, 10030857], "15": ["vertical_and_slash", 100, 800, 1625067], "16": ["vertical_and_slash", 100, 800, 332660], "17": ["vertical_and_slash", 3500, 100, 17539067], "18": ["vertical_and_slash", 500, 700, 97483], "19": ["vertical_and_slash", 30, 800, 10910089], "20": ["vertical_and_slash", 500, 700, 49927], "21": ["vertical_and_slash", 1000, 4096, 2959963], "22": ["vertical_and_slash", 1000, 4096, 1232910], "23": ["vertical_and_slash", 100, 800, 482216], "24": ["vertical_and_slash", 3500, 100, 2789809], "25": ["vertical_and_slash", 3500, 100, 1787013], "26": ["vertical_and_slash", 100, 800, 6121965], "27": ["vertical_and_slash", 100, 800, 10417031], "28": ["vertical_and_slash", 100, 800, 476098], "29": ["vertical_and_slash", 3500, 100, 13019985], "30": ["vertical_and_slash", 100, 800, 15057321], "31": ["vertical_and_slash", 100, 800, 7206530]}, {"0": ["vertical_and_slash", 30, 800, 3863946], "1": ["vertical_and_slash", 3500, 100, 373838], "2": ["vertical_and_slash", 30, 800, 2498107], "3": ["vertical_and_slash", 30, 800, 
1774834], "4": ["vertical_and_slash", 30, 800, 13518574], "5": ["vertical_and_slash", 30, 800, 17864279], "6": ["vertical_and_slash", 30, 800, 4971247], "7": ["vertical_and_slash", 30, 800, 15064092], "8": ["vertical_and_slash", 1000, 4096, 173702], "9": ["vertical_and_slash", 100, 800, 2079528], "10": ["vertical_and_slash", 1000, 4096, 1395995], "11": ["vertical_and_slash", 100, 800, 16807189], "12": ["vertical_and_slash", 1000, 4096, 3387818], "13": ["vertical_and_slash", 1000, 4096, 215373], "14": ["vertical_and_slash", 1000, 4096, 7656048], "15": ["vertical_and_slash", 1000, 4096, 3284167], "16": ["vertical_and_slash", 100, 800, 208560], "17": ["vertical_and_slash", 100, 800, 12910224], "18": ["vertical_and_slash", 100, 800, 2482406], "19": ["vertical_and_slash", 100, 800, 591300], "20": ["vertical_and_slash", 500, 700, 2512230], "21": ["vertical_and_slash", 100, 800, 650819], "22": ["vertical_and_slash", 100, 800, 750172], "23": ["vertical_and_slash", 100, 800, 98380], "24": ["vertical_and_slash", 1000, 4096, 12591674], "25": ["vertical_and_slash", 100, 800, 7520129], "26": ["vertical_and_slash", 3500, 100, 19780031], "27": ["vertical_and_slash", 1000, 4096, 11324806], "28": ["vertical_and_slash", 100, 800, 2339301], "29": ["vertical_and_slash", 3500, 100, 20537162], "30": ["vertical_and_slash", 100, 800, 1802458], "31": ["vertical_and_slash", 1000, 4096, 4121953]}, {"0": ["vertical_and_slash", 100, 800, 1406058], "1": ["vertical_and_slash", 30, 800, 20495], "2": ["vertical_and_slash", 100, 800, 265247], "3": ["vertical_and_slash", 30, 800, 6044172], "4": ["vertical_and_slash", 100, 800, 15417162], "5": ["vertical_and_slash", 100, 800, 20101], "6": ["vertical_and_slash", 30, 800, 12443], "7": ["vertical_and_slash", 100, 800, 1029], "8": ["vertical_and_slash", 30, 800, 49334], "9": ["vertical_and_slash", 30, 800, 30976], "10": ["vertical_and_slash", 30, 800, 127540], "11": ["vertical_and_slash", 30, 800, 3597689], "12": ["vertical_and_slash", 30, 800, 32317], 
"13": ["vertical_and_slash", 30, 800, 202557], "14": ["vertical_and_slash", 30, 800, 531805], "15": ["vertical_and_slash", 30, 800, 606518], "16": ["vertical_and_slash", 30, 800, 1152706], "17": ["vertical_and_slash", 1000, 4096, 5604379], "18": ["vertical_and_slash", 30, 800, 663403], "19": ["vertical_and_slash", 1000, 4096, 11655952], "20": ["vertical_and_slash", 100, 800, 15102172], "21": ["vertical_and_slash", 100, 800, 4674143], "22": ["vertical_and_slash", 500, 700, 1539328], "23": ["vertical_and_slash", 100, 800, 3051857], "24": ["vertical_and_slash", 30, 800, 123576], "25": ["vertical_and_slash", 100, 800, 964667], "26": ["vertical_and_slash", 30, 800, 41505], "27": ["vertical_and_slash", 30, 800, 59560], "28": ["vertical_and_slash", 100, 800, 17208], "29": ["vertical_and_slash", 30, 800, 82626], "30": ["vertical_and_slash", 30, 800, 1815531], "31": ["vertical_and_slash", 100, 800, 2897668]}, {"0": ["vertical_and_slash", 30, 800, 48323], "1": ["vertical_and_slash", 30, 800, 689675], "2": ["vertical_and_slash", 30, 800, 542041], "3": ["vertical_and_slash", 30, 800, 8544], "4": ["vertical_and_slash", 30, 800, 102588], "5": ["vertical_and_slash", 100, 800, 2064154], "6": ["vertical_and_slash", 30, 800, 845227], "7": ["vertical_and_slash", 30, 800, 2922720], "8": ["vertical_and_slash", 1000, 4096, 2932415], "9": ["vertical_and_slash", 1000, 4096, 3062180], "10": ["vertical_and_slash", 100, 800, 485119], "11": ["vertical_and_slash", 30, 800, 215049], "12": ["vertical_and_slash", 100, 800, 387511], "13": ["vertical_and_slash", 100, 800, 1447813], "14": ["vertical_and_slash", 1000, 4096, 3878389], "15": ["vertical_and_slash", 100, 800, 376333], "16": ["vertical_and_slash", 3500, 100, 13506969], "17": ["vertical_and_slash", 100, 800, 12850708], "18": ["vertical_and_slash", 30, 800, 372529], "19": ["vertical_and_slash", 1000, 4096, 3746168], "20": ["vertical_and_slash", 100, 800, 170359], "21": ["vertical_and_slash", 100, 800, 1130785], "22": ["vertical_and_slash", 
100, 800, 116224], "23": ["vertical_and_slash", 100, 800, 1001182], "24": ["vertical_and_slash", 100, 800, 335681], "25": ["vertical_and_slash", 100, 800, 3392285], "26": ["vertical_and_slash", 1000, 4096, 4420760], "27": ["vertical_and_slash", 3500, 100, 12258981], "28": ["vertical_and_slash", 500, 700, 1941188], "29": ["vertical_and_slash", 1000, 4096, 7639240], "30": ["vertical_and_slash", 500, 700, 8277346], "31": ["vertical_and_slash", 3500, 100, 3442659]}, {"0": ["vertical_and_slash", 30, 800, 945264], "1": ["vertical_and_slash", 1000, 4096, 3474994], "2": ["vertical_and_slash", 500, 700, 218918], "3": ["vertical_and_slash", 3500, 100, 20221076], "4": ["vertical_and_slash", 3500, 100, 21680113], "5": ["vertical_and_slash", 30, 800, 94866], "6": ["vertical_and_slash", 30, 800, 190907], "7": ["vertical_and_slash", 1000, 4096, 1708889], "8": ["vertical_and_slash", 100, 800, 2832752], "9": ["vertical_and_slash", 1000, 4096, 613061], "10": ["vertical_and_slash", 1000, 4096, 7381575], "11": ["vertical_and_slash", 1000, 4096, 1462120], "12": ["vertical_and_slash", 1000, 4096, 3338671], "13": ["vertical_and_slash", 100, 800, 1664528], "14": ["vertical_and_slash", 500, 700, 143074], "15": ["vertical_and_slash", 30, 800, 433035], "16": ["vertical_and_slash", 500, 700, 210886], "17": ["vertical_and_slash", 100, 800, 8632139], "18": ["vertical_and_slash", 100, 800, 17521811], "19": ["vertical_and_slash", 30, 800, 194306], "20": ["vertical_and_slash", 100, 800, 3156950], "21": ["vertical_and_slash", 100, 800, 2413125], "22": ["vertical_and_slash", 1000, 4096, 10110205], "23": ["vertical_and_slash", 100, 800, 695569], "24": ["vertical_and_slash", 30, 800, 32256], "25": ["vertical_and_slash", 30, 800, 396762], "26": ["vertical_and_slash", 30, 800, 726815], "27": ["vertical_and_slash", 30, 800, 499056], "28": ["vertical_and_slash", 30, 800, 24234], "29": ["vertical_and_slash", 30, 800, 87299], "30": ["vertical_and_slash", 30, 800, 82758], "31": ["vertical_and_slash", 30, 
800, 447266]}, {"0": ["vertical_and_slash", 100, 800, 13520320], "1": ["vertical_and_slash", 100, 800, 1746572], "2": ["vertical_and_slash", 100, 800, 81358], "3": ["vertical_and_slash", 100, 800, 53915], "4": ["vertical_and_slash", 100, 800, 16824352], "5": ["vertical_and_slash", 100, 800, 124419], "6": ["vertical_and_slash", 100, 800, 5336412], "7": ["vertical_and_slash", 100, 800, 1005227], "8": ["vertical_and_slash", 1000, 4096, 17919472], "9": ["vertical_and_slash", 100, 800, 5089389], "10": ["vertical_and_slash", 1000, 4096, 2318753], "11": ["vertical_and_slash", 100, 800, 2351529], "12": ["vertical_and_slash", 1000, 4096, 1068220], "13": ["vertical_and_slash", 1000, 4096, 18765314], "14": ["vertical_and_slash", 1000, 4096, 11512280], "15": ["vertical_and_slash", 1000, 4096, 14722530], "16": ["vertical_and_slash", 100, 800, 1542041], "17": ["vertical_and_slash", 3500, 100, 19279869], "18": ["vertical_and_slash", 100, 800, 4711439], "19": ["vertical_and_slash", 3500, 100, 3688560], "20": ["vertical_and_slash", 3500, 100, 224250], "21": ["vertical_and_slash", 100, 800, 10537230], "22": ["vertical_and_slash", 100, 800, 749819], "23": ["vertical_and_slash", 100, 800, 25187], "24": ["vertical_and_slash", 100, 800, 13068183], "25": ["vertical_and_slash", 100, 800, 17508351], "26": ["vertical_and_slash", 100, 800, 12981109], "27": ["vertical_and_slash", 100, 800, 15314279], "28": ["vertical_and_slash", 100, 800, 15558838], "29": ["vertical_and_slash", 100, 800, 3774507], "30": ["vertical_and_slash", 100, 800, 6486179], "31": ["vertical_and_slash", 100, 800, 15420283]}, {"0": ["vertical_and_slash", 100, 800, 1793383], "1": ["vertical_and_slash", 100, 800, 8103093], "2": ["vertical_and_slash", 1000, 4096, 12596743], "3": ["vertical_and_slash", 1000, 4096, 5012316], "4": ["vertical_and_slash", 1000, 4096, 12870742], "5": ["vertical_and_slash", 100, 800, 3459141], "6": ["vertical_and_slash", 30, 800, 10224901], "7": ["vertical_and_slash", 100, 800, 3753981], "8": 
["vertical_and_slash", 30, 800, 140040], "9": ["vertical_and_slash", 30, 800, 550671], "10": ["vertical_and_slash", 100, 800, 94454], "11": ["vertical_and_slash", 30, 800, 8909], "12": ["vertical_and_slash", 30, 800, 152077], "13": ["vertical_and_slash", 30, 800, 49171], "14": ["vertical_and_slash", 30, 800, 107813], "15": ["vertical_and_slash", 30, 800, 128764], "16": ["vertical_and_slash", 30, 800, 617322], "17": ["vertical_and_slash", 1000, 4096, 6019612], "18": ["vertical_and_slash", 100, 800, 766582], "19": ["vertical_and_slash", 30, 800, 52503], "20": ["vertical_and_slash", 30, 800, 300294], "21": ["vertical_and_slash", 30, 800, 1577098], "22": ["vertical_and_slash", 100, 800, 838126], "23": ["vertical_and_slash", 100, 800, 1218912], "24": ["vertical_and_slash", 100, 800, 1720664], "25": ["vertical_and_slash", 100, 800, 1377743], "26": ["vertical_and_slash", 1000, 4096, 900287], "27": ["vertical_and_slash", 1000, 4096, 12066126], "28": ["vertical_and_slash", 1000, 4096, 14264762], "29": ["vertical_and_slash", 1000, 4096, 71284], "30": ["vertical_and_slash", 1000, 4096, 3218291], "31": ["vertical_and_slash", 1000, 4096, 13215387]}, {"0": ["vertical_and_slash", 100, 800, 18645971], "1": ["vertical_and_slash", 30, 800, 587932], "2": ["vertical_and_slash", 1000, 4096, 10538505], "3": ["vertical_and_slash", 30, 800, 158559], "4": ["vertical_and_slash", 100, 800, 3376593], "5": ["vertical_and_slash", 100, 800, 18383338], "6": ["vertical_and_slash", 1000, 4096, 10074810], "7": ["vertical_and_slash", 1000, 4096, 19347044], "8": ["vertical_and_slash", 1000, 4096, 6794450], "9": ["vertical_and_slash", 1000, 4096, 3529136], "10": ["vertical_and_slash", 1000, 4096, 6952639], "11": ["vertical_and_slash", 1000, 4096, 9362393], "12": ["vertical_and_slash", 1000, 4096, 5368732], "13": ["vertical_and_slash", 100, 800, 705065], "14": ["vertical_and_slash", 100, 800, 628184], "15": ["vertical_and_slash", 1000, 4096, 7575979], "16": ["vertical_and_slash", 1000, 4096, 14825324], 
"17": ["vertical_and_slash", 100, 800, 584190], "18": ["vertical_and_slash", 1000, 4096, 14770220], "19": ["vertical_and_slash", 100, 800, 7324628], "20": ["vertical_and_slash", 100, 800, 13439080], "21": ["vertical_and_slash", 100, 800, 2173728], "22": ["vertical_and_slash", 100, 800, 1300676], "23": ["vertical_and_slash", 3500, 100, 20507565], "24": ["vertical_and_slash", 3500, 100, 20826931], "25": ["vertical_and_slash", 100, 800, 16503925], "26": ["vertical_and_slash", 3500, 100, 20607984], "27": ["vertical_and_slash", 1000, 4096, 9100775], "28": ["vertical_and_slash", 3500, 100, 20540180], "29": ["vertical_and_slash", 1000, 4096, 19978707], "30": ["vertical_and_slash", 100, 800, 18084829], "31": ["vertical_and_slash", 100, 800, 15584755]}, {"0": ["vertical_and_slash", 100, 800, 14519032], "1": ["vertical_and_slash", 100, 800, 13637880], "2": ["vertical_and_slash", 3500, 100, 19712241], "3": ["vertical_and_slash", 100, 800, 14417159], "4": ["vertical_and_slash", 100, 800, 18931772], "5": ["vertical_and_slash", 3500, 100, 20278735], "6": ["vertical_and_slash", 100, 800, 21000177], "7": ["vertical_and_slash", 3500, 100, 20181815], "8": ["vertical_and_slash", 1000, 4096, 20667264], "9": ["vertical_and_slash", 1000, 4096, 13546806], "10": ["vertical_and_slash", 1000, 4096, 8056555], "11": ["vertical_and_slash", 1000, 4096, 14544259], "12": ["vertical_and_slash", 3500, 100, 14988539], "13": ["vertical_and_slash", 100, 800, 9925552], "14": ["vertical_and_slash", 1000, 4096, 16502140], "15": ["vertical_and_slash", 3500, 100, 1394], "16": ["vertical_and_slash", 100, 800, 6786191], "17": ["vertical_and_slash", 100, 800, 5142369], "18": ["vertical_and_slash", 1000, 4096, 18139060], "19": ["vertical_and_slash", 100, 800, 1817633], "20": ["vertical_and_slash", 100, 800, 1586931], "21": ["vertical_and_slash", 1000, 4096, 2981991], "22": ["vertical_and_slash", 1000, 4096, 19814245], "23": ["vertical_and_slash", 100, 800, 3823591], "24": ["vertical_and_slash", 1000, 4096, 
11968181], "25": ["vertical_and_slash", 100, 800, 4245870], "26": ["vertical_and_slash", 100, 800, 6065658], "27": ["vertical_and_slash", 100, 800, 17099315], "28": ["vertical_and_slash", 100, 800, 14002976], "29": ["vertical_and_slash", 100, 800, 15062395], "30": ["vertical_and_slash", 3500, 100, 9832421], "31": ["vertical_and_slash", 100, 800, 329163]}, {"0": ["vertical_and_slash", 100, 800, 17881284], "1": ["vertical_and_slash", 100, 800, 6096065], "2": ["vertical_and_slash", 100, 800, 19512309], "3": ["vertical_and_slash", 100, 800, 1361094], "4": ["vertical_and_slash", 3500, 100, 21385650], "5": ["vertical_and_slash", 100, 800, 14152330], "6": ["vertical_and_slash", 100, 800, 15379238], "7": ["vertical_and_slash", 100, 800, 936209], "8": ["vertical_and_slash", 3500, 100, 7644919], "9": ["vertical_and_slash", 100, 800, 162434], "10": ["vertical_and_slash", 100, 800, 11548456], "11": ["vertical_and_slash", 100, 800, 11141282], "12": ["vertical_and_slash", 3500, 100, 6011727], "13": ["vertical_and_slash", 100, 800, 16026110], "14": ["vertical_and_slash", 100, 800, 466578], "15": ["vertical_and_slash", 100, 800, 4799040], "16": ["vertical_and_slash", 100, 800, 15252019], "17": ["vertical_and_slash", 1000, 4096, 7350605], "18": ["vertical_and_slash", 100, 800, 16896477], "19": ["vertical_and_slash", 1000, 4096, 5715502], "20": ["vertical_and_slash", 100, 800, 9885275], "21": ["vertical_and_slash", 100, 800, 8062274], "22": ["vertical_and_slash", 100, 800, 11341966], "23": ["vertical_and_slash", 3500, 100, 21639689], "24": ["vertical_and_slash", 1000, 4096, 7313536], "25": ["vertical_and_slash", 1000, 4096, 1858640], "26": ["vertical_and_slash", 100, 800, 17665215], "27": ["vertical_and_slash", 100, 800, 13827567], "28": ["vertical_and_slash", 1000, 4096, 16279088], "29": ["vertical_and_slash", 1000, 4096, 2728376], "30": ["vertical_and_slash", 1000, 4096, 20378804], "31": ["vertical_and_slash", 1000, 4096, 11218561]}, {"0": ["vertical_and_slash", 100, 800, 
10702989], "1": ["vertical_and_slash", 100, 800, 13911357], "2": ["vertical_and_slash", 100, 800, 2089505], "3": ["vertical_and_slash", 100, 800, 5795130], "4": ["vertical_and_slash", 100, 800, 6198580], "5": ["vertical_and_slash", 100, 800, 11025874], "6": ["vertical_and_slash", 1000, 4096, 4765707], "7": ["vertical_and_slash", 100, 800, 9275261], "8": ["vertical_and_slash", 100, 800, 356772], "9": ["vertical_and_slash", 100, 800, 6507763], "10": ["vertical_and_slash", 100, 800, 1057022], "11": ["vertical_and_slash", 100, 800, 16390639], "12": ["vertical_and_slash", 1000, 4096, 6504148], "13": ["vertical_and_slash", 100, 800, 5815163], "14": ["vertical_and_slash", 100, 800, 781258], "15": ["vertical_and_slash", 1000, 4096, 5306413], "16": ["vertical_and_slash", 100, 800, 7571947], "17": ["vertical_and_slash", 100, 800, 2246584], "18": ["vertical_and_slash", 1000, 4096, 6370179], "19": ["vertical_and_slash", 1000, 4096, 16329738], "20": ["vertical_and_slash", 100, 800, 810202], "21": ["vertical_and_slash", 100, 800, 9614219], "22": ["vertical_and_slash", 3500, 100, 21023608], "23": ["vertical_and_slash", 100, 800, 3697853], "24": ["vertical_and_slash", 500, 700, 623385], "25": ["vertical_and_slash", 100, 800, 2872545], "26": ["vertical_and_slash", 3500, 100, 21443890], "27": ["vertical_and_slash", 1000, 4096, 964593], "28": ["vertical_and_slash", 1000, 4096, 6046647], "29": ["vertical_and_slash", 1000, 4096, 3390663], "30": ["vertical_and_slash", 3500, 100, 21396110], "31": ["vertical_and_slash", 500, 700, 1185821]}, {"0": ["vertical_and_slash", 100, 800, 929038], "1": ["vertical_and_slash", 1000, 4096, 11917459], "2": ["vertical_and_slash", 1000, 4096, 11189817], "3": ["vertical_and_slash", 1000, 4096, 5290948], "4": ["vertical_and_slash", 100, 800, 2444153], "5": ["vertical_and_slash", 1000, 4096, 7367448], "6": ["vertical_and_slash", 1000, 4096, 3929914], "7": ["vertical_and_slash", 1000, 4096, 2907293], "8": ["vertical_and_slash", 30, 800, 8631190], "9": 
["vertical_and_slash", 100, 800, 7657567], "10": ["vertical_and_slash", 1000, 4096, 5754225], "11": ["vertical_and_slash", 100, 800, 16484372], "12": ["vertical_and_slash", 100, 800, 7369987], "13": ["vertical_and_slash", 100, 800, 3365312], "14": ["vertical_and_slash", 30, 800, 461151], "15": ["vertical_and_slash", 500, 700, 315608], "16": ["vertical_and_slash", 1000, 4096, 16240364], "17": ["vertical_and_slash", 100, 800, 253597], "18": ["vertical_and_slash", 1000, 4096, 925109], "19": ["vertical_and_slash", 100, 800, 133339], "20": ["vertical_and_slash", 100, 800, 578256], "21": ["vertical_and_slash", 1000, 4096, 1817521], "22": ["vertical_and_slash", 3500, 100, 4918245], "23": ["vertical_and_slash", 1000, 4096, 114317], "24": ["vertical_and_slash", 3500, 100, 20949654], "25": ["vertical_and_slash", 3500, 100, 21380515], "26": ["vertical_and_slash", 1000, 4096, 20796309], "27": ["vertical_and_slash", 100, 800, 11897642], "28": ["vertical_and_slash", 1000, 4096, 17534343], "29": ["vertical_and_slash", 1000, 4096, 20051889], "30": ["vertical_and_slash", 1000, 4096, 20184777], "31": ["vertical_and_slash", 3500, 100, 20262011]}, {"0": ["vertical_and_slash", 1000, 4096, 8179346], "1": ["vertical_and_slash", 1000, 4096, 2423899], "2": ["vertical_and_slash", 100, 800, 13818895], "3": ["vertical_and_slash", 1000, 4096, 6522601], "4": ["vertical_and_slash", 1000, 4096, 1060263], "5": ["vertical_and_slash", 1000, 4096, 4157137], "6": ["vertical_and_slash", 1000, 4096, 6990380], "7": ["vertical_and_slash", 1000, 4096, 10763715], "8": ["vertical_and_slash", 100, 800, 10123257], "9": ["vertical_and_slash", 1000, 4096, 9156840], "10": ["vertical_and_slash", 1000, 4096, 16029616], "11": ["vertical_and_slash", 100, 800, 1673944], "12": ["vertical_and_slash", 1000, 4096, 15001358], "13": ["vertical_and_slash", 1000, 4096, 11496585], "14": ["vertical_and_slash", 100, 800, 9006039], "15": ["vertical_and_slash", 1000, 4096, 13032008], "16": ["vertical_and_slash", 100, 800, 
4813070], "17": ["vertical_and_slash", 100, 800, 1475285], "18": ["vertical_and_slash", 100, 800, 8000337], "19": ["vertical_and_slash", 100, 800, 8837856], "20": ["vertical_and_slash", 1000, 4096, 16977677], "21": ["vertical_and_slash", 100, 800, 4416649], "22": ["vertical_and_slash", 100, 800, 17025902], "23": ["vertical_and_slash", 100, 800, 602195], "24": ["vertical_and_slash", 3500, 100, 5765045], "25": ["vertical_and_slash", 100, 800, 13009069], "26": ["vertical_and_slash", 100, 800, 3523767], "27": ["vertical_and_slash", 100, 800, 6546733], "28": ["vertical_and_slash", 3500, 100, 3452012], "29": ["vertical_and_slash", 100, 800, 1510491], "30": ["vertical_and_slash", 3500, 100, 17227596], "31": ["vertical_and_slash", 3500, 100, 19660969]}, {"0": ["vertical_and_slash", 3500, 100, 6623789], "1": ["vertical_and_slash", 3500, 100, 3902994], "2": ["vertical_and_slash", 3500, 100, 6994928], "3": ["vertical_and_slash", 1000, 4096, 5149770], "4": ["vertical_and_slash", 3500, 100, 14836158], "5": ["vertical_and_slash", 100, 800, 17515427], "6": ["vertical_and_slash", 3500, 100, 7911558], "7": ["vertical_and_slash", 3500, 100, 9338861], "8": ["vertical_and_slash", 3500, 100, 14090410], "9": ["vertical_and_slash", 100, 800, 2492955], "10": ["vertical_and_slash", 3500, 100, 21732500], "11": ["vertical_and_slash", 100, 800, 2898121], "12": ["vertical_and_slash", 3500, 100, 10852444], "13": ["vertical_and_slash", 100, 800, 1940039], "14": ["vertical_and_slash", 3500, 100, 16338195], "15": ["vertical_and_slash", 100, 800, 2006495], "16": ["vertical_and_slash", 3500, 100, 10259390], "17": ["vertical_and_slash", 100, 800, 4065419], "18": ["vertical_and_slash", 100, 800, 12733273], "19": ["vertical_and_slash", 1000, 4096, 11751394], "20": ["vertical_and_slash", 100, 800, 15251186], "21": ["vertical_and_slash", 1000, 4096, 12287035], "22": ["vertical_and_slash", 1000, 4096, 5114508], "23": ["vertical_and_slash", 1000, 4096, 13162100], "24": ["vertical_and_slash", 100, 800, 
8000122], "25": ["vertical_and_slash", 100, 800, 9281634], "26": ["vertical_and_slash", 100, 800, 1846488], "27": ["vertical_and_slash", 3500, 100, 8590692], "28": ["vertical_and_slash", 100, 800, 8643063], "29": ["vertical_and_slash", 100, 800, 5758817], "30": ["vertical_and_slash", 100, 800, 5877183], "31": ["vertical_and_slash", 100, 800, 7796906]}, {"0": ["vertical_and_slash", 100, 800, 20597532], "1": ["vertical_and_slash", 3500, 100, 21758452], "2": ["vertical_and_slash", 1000, 4096, 4144141], "3": ["vertical_and_slash", 100, 800, 20261887], "4": ["vertical_and_slash", 1000, 4096, 2512370], "5": ["vertical_and_slash", 3500, 100, 17706009], "6": ["vertical_and_slash", 1000, 4096, 19693735], "7": ["vertical_and_slash", 1000, 4096, 12879585], "8": ["vertical_and_slash", 3500, 100, 18330550], "9": ["vertical_and_slash", 1000, 4096, 395315], "10": ["vertical_and_slash", 100, 800, 12936460], "11": ["vertical_and_slash", 3500, 100, 20489362], "12": ["vertical_and_slash", 100, 800, 2920447], "13": ["vertical_and_slash", 3500, 100, 19704987], "14": ["vertical_and_slash", 3500, 100, 19332279], "15": ["vertical_and_slash", 100, 800, 8771256], "16": ["vertical_and_slash", 100, 800, 5611994], "17": ["vertical_and_slash", 100, 800, 16087138], "18": ["vertical_and_slash", 500, 700, 891236], "19": ["vertical_and_slash", 3500, 100, 21427139], "20": ["vertical_and_slash", 100, 800, 1823410], "21": ["vertical_and_slash", 30, 800, 15408418], "22": ["vertical_and_slash", 500, 700, 9266226], "23": ["vertical_and_slash", 3500, 100, 17195724], "24": ["vertical_and_slash", 1000, 4096, 7809063], "25": ["vertical_and_slash", 100, 800, 14083150], "26": ["vertical_and_slash", 100, 800, 4139113], "27": ["vertical_and_slash", 100, 800, 10706318], "28": ["vertical_and_slash", 1000, 4096, 1105380], "29": ["vertical_and_slash", 100, 800, 3630717], "30": ["vertical_and_slash", 1000, 4096, 10664933], "31": ["vertical_and_slash", 100, 800, 9143007]}, {"0": ["vertical_and_slash", 1000, 4096, 
301018], "1": ["vertical_and_slash", 3500, 100, 1784828], "2": ["vertical_and_slash", 3500, 100, 7055406], "3": ["vertical_and_slash", 3500, 100, 2086934], "4": ["vertical_and_slash", 1000, 4096, 4101320], "5": ["vertical_and_slash", 1000, 4096, 1042376], "6": ["vertical_and_slash", 3500, 100, 16976048], "7": ["vertical_and_slash", 500, 700, 1459641], "8": ["vertical_and_slash", 3500, 100, 1180323], "9": ["vertical_and_slash", 3500, 100, 21763195], "10": ["vertical_and_slash", 3500, 100, 5825008], "11": ["vertical_and_slash", 100, 800, 53453], "12": ["vertical_and_slash", 3500, 100, 11794796], "13": ["vertical_and_slash", 3500, 100, 1783957], "14": ["vertical_and_slash", 100, 800, 1440345], "15": ["vertical_and_slash", 100, 800, 16828397], "16": ["vertical_and_slash", 100, 800, 2469338], "17": ["vertical_and_slash", 100, 800, 4665593], "18": ["vertical_and_slash", 3500, 100, 10580848], "19": ["vertical_and_slash", 3500, 100, 19252331], "20": ["vertical_and_slash", 3500, 100, 20024825], "21": ["vertical_and_slash", 100, 800, 14850871], "22": ["vertical_and_slash", 3500, 100, 12678003], "23": ["vertical_and_slash", 100, 800, 1782447], "24": ["vertical_and_slash", 1000, 4096, 13287971], "25": ["vertical_and_slash", 3500, 100, 1097488], "26": ["vertical_and_slash", 1000, 4096, 2633009], "27": ["vertical_and_slash", 3500, 100, 1055757], "28": ["vertical_and_slash", 3500, 100, 742496], "29": ["vertical_and_slash", 1000, 4096, 4194904], "30": ["vertical_and_slash", 3500, 100, 1577446], "31": ["vertical_and_slash", 1000, 4096, 10526781]}, {"0": ["vertical_and_slash", 1000, 4096, 12079479], "1": ["vertical_and_slash", 3500, 100, 19962962], "2": ["vertical_and_slash", 1000, 4096, 12450062], "3": ["vertical_and_slash", 1000, 4096, 10400447], "4": ["vertical_and_slash", 100, 800, 11323650], "5": ["vertical_and_slash", 1000, 4096, 4102038], "6": ["vertical_and_slash", 1000, 4096, 3338557], "7": ["vertical_and_slash", 3500, 100, 9984816], "8": ["vertical_and_slash", 100, 800, 
14524592], "9": ["vertical_and_slash", 100, 800, 2065326], "10": ["vertical_and_slash", 30, 800, 4596708], "11": ["vertical_and_slash", 500, 700, 10708236], "12": ["vertical_and_slash", 500, 700, 13397191], "13": ["vertical_and_slash", 500, 700, 1011260], "14": ["vertical_and_slash", 1000, 4096, 13165340], "15": ["vertical_and_slash", 1000, 4096, 825692], "16": ["vertical_and_slash", 3500, 100, 2810461], "17": ["vertical_and_slash", 3500, 100, 19569698], "18": ["vertical_and_slash", 3500, 100, 2251981], "19": ["vertical_and_slash", 500, 700, 5559642], "20": ["vertical_and_slash", 3500, 100, 1522515], "21": ["vertical_and_slash", 1000, 4096, 982286], "22": ["vertical_and_slash", 1000, 4096, 2085881], "23": ["vertical_and_slash", 100, 800, 2055023], "24": ["vertical_and_slash", 1000, 4096, 1242380], "25": ["vertical_and_slash", 3500, 100, 1869920], "26": ["vertical_and_slash", 3500, 100, 12180284], "27": ["vertical_and_slash", 3500, 100, 14622044], "28": ["vertical_and_slash", 1000, 4096, 557560], "29": ["vertical_and_slash", 1000, 4096, 6987039], "30": ["vertical_and_slash", 100, 800, 15769951], "31": ["vertical_and_slash", 100, 800, 7721569]}, {"0": ["vertical_and_slash", 500, 700, 4382254], "1": ["vertical_and_slash", 3500, 100, 84219], "2": ["vertical_and_slash", 500, 700, 4734463], "3": ["vertical_and_slash", 500, 700, 3186548], "4": ["vertical_and_slash", 1000, 4096, 4063246], "5": ["vertical_and_slash", 1000, 4096, 12708225], "6": ["vertical_and_slash", 500, 700, 7742943], "7": ["vertical_and_slash", 100, 800, 15424159], "8": ["vertical_and_slash", 1000, 4096, 6301506], "9": ["vertical_and_slash", 1000, 4096, 2079847], "10": ["vertical_and_slash", 1000, 4096, 4217027], "11": ["vertical_and_slash", 1000, 4096, 6297884], "12": ["vertical_and_slash", 3500, 100, 4824003], "13": ["vertical_and_slash", 1000, 4096, 3960801], "14": ["vertical_and_slash", 1000, 4096, 10405673], "15": ["vertical_and_slash", 1000, 4096, 8272702], "16": ["vertical_and_slash", 3500, 100, 
2874719], "17": ["vertical_and_slash", 1000, 4096, 13248253], "18": ["vertical_and_slash", 3500, 100, 16731069], "19": ["vertical_and_slash", 1000, 4096, 3488474], "20": ["vertical_and_slash", 3500, 100, 4911794], "21": ["vertical_and_slash", 3500, 100, 3300649], "22": ["vertical_and_slash", 3500, 100, 2239972], "23": ["vertical_and_slash", 1000, 4096, 847410], "24": ["vertical_and_slash", 1000, 4096, 12556756], "25": ["vertical_and_slash", 3500, 100, 10893823], "26": ["vertical_and_slash", 1000, 4096, 14168165], "27": ["vertical_and_slash", 1000, 4096, 14127548], "28": ["vertical_and_slash", 1000, 4096, 5277617], "29": ["vertical_and_slash", 1000, 4096, 16652651], "30": ["vertical_and_slash", 1000, 4096, 7991739], "31": ["vertical_and_slash", 3500, 100, 16136482]}, {"0": ["vertical_and_slash", 100, 800, 3776409], "1": ["vertical_and_slash", 100, 800, 3972530], "2": ["vertical_and_slash", 100, 800, 10166976], "3": ["vertical_and_slash", 100, 800, 13449519], "4": ["vertical_and_slash", 30, 800, 4621777], "5": ["vertical_and_slash", 30, 800, 17026761], "6": ["vertical_and_slash", 30, 800, 11401344], "7": ["vertical_and_slash", 100, 800, 3178997], "8": ["vertical_and_slash", 1000, 4096, 14919677], "9": ["vertical_and_slash", 100, 800, 13489170], "10": ["vertical_and_slash", 1000, 4096, 12483196], "11": ["vertical_and_slash", 1000, 4096, 18647183], "12": ["vertical_and_slash", 1000, 4096, 18488628], "13": ["vertical_and_slash", 3500, 100, 18285318], "14": ["vertical_and_slash", 3500, 100, 19771087], "15": ["vertical_and_slash", 100, 800, 11952058], "16": ["vertical_and_slash", 1000, 4096, 671303], "17": ["vertical_and_slash", 3500, 100, 20413410], "18": ["vertical_and_slash", 1000, 4096, 693843], "19": ["vertical_and_slash", 3500, 100, 20183012], "20": ["vertical_and_slash", 3500, 100, 4751982], "21": ["vertical_and_slash", 1000, 4096, 1190840], "22": ["vertical_and_slash", 3500, 100, 8189368], "23": ["vertical_and_slash", 3500, 100, 4191516], "24": 
["vertical_and_slash", 100, 800, 9072597], "25": ["vertical_and_slash", 3500, 100, 6214053], "26": ["vertical_and_slash", 1000, 4096, 8848124], "27": ["vertical_and_slash", 3500, 100, 1231805], "28": ["vertical_and_slash", 3500, 100, 3468573], "29": ["vertical_and_slash", 3500, 100, 16841594], "30": ["vertical_and_slash", 3500, 100, 12565098], "31": ["vertical_and_slash", 3500, 100, 4308210]}, {"0": ["vertical_and_slash", 100, 800, 405030], "1": ["vertical_and_slash", 3500, 100, 12737242], "2": ["vertical_and_slash", 1000, 4096, 6996254], "3": ["vertical_and_slash", 3500, 100, 4831216], "4": ["vertical_and_slash", 3500, 100, 5890590], "5": ["vertical_and_slash", 1000, 4096, 3008671], "6": ["vertical_and_slash", 1000, 4096, 4998230], "7": ["vertical_and_slash", 1000, 4096, 6509194], "8": ["vertical_and_slash", 3500, 100, 1774041], "9": ["vertical_and_slash", 3500, 100, 1372562], "10": ["vertical_and_slash", 3500, 100, 9111804], "11": ["vertical_and_slash", 1000, 4096, 1109182], "12": ["vertical_and_slash", 100, 800, 371771], "13": ["vertical_and_slash", 3500, 100, 905824], "14": ["vertical_and_slash", 1000, 4096, 4934535], "15": ["vertical_and_slash", 1000, 4096, 2841896], "16": ["vertical_and_slash", 3500, 100, 4614245], "17": ["vertical_and_slash", 3500, 100, 6900617], "18": ["vertical_and_slash", 3500, 100, 2824788], "19": ["vertical_and_slash", 100, 800, 6589423], "20": ["vertical_and_slash", 500, 700, 6357101], "21": ["vertical_and_slash", 30, 800, 5731632], "22": ["vertical_and_slash", 30, 800, 7261064], "23": ["vertical_and_slash", 500, 700, 9172114], "24": ["vertical_and_slash", 1000, 4096, 210349], "25": ["vertical_and_slash", 1000, 4096, 4526369], "26": ["vertical_and_slash", 1000, 4096, 2326769], "27": ["vertical_and_slash", 3500, 100, 5989844], "28": ["vertical_and_slash", 3500, 100, 1393004], "29": ["vertical_and_slash", 3500, 100, 2114704], "30": ["vertical_and_slash", 3500, 100, 776564], "31": ["vertical_and_slash", 3500, 100, 2826514]}, {"0": 
["vertical_and_slash", 1000, 4096, 4747927], "1": ["vertical_and_slash", 3500, 100, 14468785], "2": ["vertical_and_slash", 3500, 100, 10124003], "3": ["vertical_and_slash", 3500, 100, 6702061], "4": ["vertical_and_slash", 1000, 4096, 2311190], "5": ["vertical_and_slash", 1000, 4096, 2412642], "6": ["vertical_and_slash", 1000, 4096, 2782532], "7": ["vertical_and_slash", 3500, 100, 6699063], "8": ["vertical_and_slash", 100, 800, 10899273], "9": ["vertical_and_slash", 100, 800, 571205], "10": ["vertical_and_slash", 1000, 4096, 2224039], "11": ["vertical_and_slash", 3500, 100, 5206481], "12": ["vertical_and_slash", 100, 800, 6039530], "13": ["vertical_and_slash", 3500, 100, 6121024], "14": ["vertical_and_slash", 1000, 4096, 915849], "15": ["vertical_and_slash", 3500, 100, 4393793], "16": ["vertical_and_slash", 1000, 4096, 4168491], "17": ["vertical_and_slash", 3500, 100, 5568206], "18": ["vertical_and_slash", 1000, 4096, 1087118], "19": ["vertical_and_slash", 1000, 4096, 2691708], "20": ["vertical_and_slash", 3500, 100, 4351677], "21": ["vertical_and_slash", 3500, 100, 3933999], "22": ["vertical_and_slash", 3500, 100, 3997663], "23": ["vertical_and_slash", 3500, 100, 3522236], "24": ["vertical_and_slash", 3500, 100, 9956224], "25": ["vertical_and_slash", 3500, 100, 4192895], "26": ["vertical_and_slash", 3500, 100, 9150842], "27": ["vertical_and_slash", 3500, 100, 12754903], "28": ["vertical_and_slash", 3500, 100, 7346979], "29": ["vertical_and_slash", 100, 800, 9422285], "30": ["vertical_and_slash", 100, 800, 3140769], "31": ["vertical_and_slash", 500, 700, 2415994]}, {"0": ["vertical_and_slash", 3500, 100, 4352921], "1": ["vertical_and_slash", 1000, 4096, 3398326], "2": ["vertical_and_slash", 3500, 100, 5788760], "3": ["vertical_and_slash", 1000, 4096, 2945608], "4": ["vertical_and_slash", 3500, 100, 1988612], "5": ["vertical_and_slash", 1000, 4096, 3736165], "6": ["vertical_and_slash", 1000, 4096, 9670660], "7": ["vertical_and_slash", 3500, 100, 3803388], "8": 
["vertical_and_slash", 3500, 100, 3612542], "9": ["vertical_and_slash", 3500, 100, 4948698], "10": ["vertical_and_slash", 3500, 100, 4880140], "11": ["vertical_and_slash", 3500, 100, 2083345], "12": ["vertical_and_slash", 3500, 100, 4683160], "13": ["vertical_and_slash", 3500, 100, 3650326], "14": ["vertical_and_slash", 3500, 100, 1071456], "15": ["vertical_and_slash", 1000, 4096, 3490570], "16": ["vertical_and_slash", 1000, 4096, 1082160], "17": ["vertical_and_slash", 3500, 100, 6888781], "18": ["vertical_and_slash", 1000, 4096, 2664476], "19": ["vertical_and_slash", 3500, 100, 2759933], "20": ["vertical_and_slash", 100, 800, 653736], "21": ["vertical_and_slash", 3500, 100, 9517662], "22": ["vertical_and_slash", 3500, 100, 3973048], "23": ["vertical_and_slash", 3500, 100, 5761264], "24": ["vertical_and_slash", 3500, 100, 13615692], "25": ["vertical_and_slash", 1000, 4096, 5235320], "26": ["vertical_and_slash", 3500, 100, 10009513], "27": ["vertical_and_slash", 1000, 4096, 2682717], "28": ["vertical_and_slash", 3500, 100, 11382630], "29": ["vertical_and_slash", 3500, 100, 3802301], "30": ["vertical_and_slash", 1000, 4096, 3025864], "31": ["vertical_and_slash", 1000, 4096, 1725752]}, {"0": ["vertical_and_slash", 1000, 4096, 12877084], "1": ["vertical_and_slash", 1000, 4096, 11642564], "2": ["vertical_and_slash", 1000, 4096, 10978654], "3": ["vertical_and_slash", 3500, 100, 14674762], "4": ["vertical_and_slash", 1000, 4096, 8335239], "5": ["vertical_and_slash", 1000, 4096, 11808042], "6": ["vertical_and_slash", 1000, 4096, 10213550], "7": ["vertical_and_slash", 3500, 100, 14957853], "8": ["vertical_and_slash", 500, 700, 19867441], "9": ["vertical_and_slash", 100, 800, 10566603], "10": ["vertical_and_slash", 3500, 100, 19670449], "11": ["vertical_and_slash", 1000, 4096, 12608408], "12": ["vertical_and_slash", 3500, 100, 19432490], "13": ["vertical_and_slash", 3500, 100, 21127812], "14": ["vertical_and_slash", 3500, 100, 16648204], "15": ["vertical_and_slash", 1000, 
4096, 10819630], "16": ["vertical_and_slash", 3500, 100, 5741199], "17": ["vertical_and_slash", 3500, 100, 2265976], "18": ["vertical_and_slash", 1000, 4096, 1571848], "19": ["vertical_and_slash", 3500, 100, 12168656], "20": ["vertical_and_slash", 3500, 100, 12687129], "21": ["vertical_and_slash", 1000, 4096, 4052254], "22": ["vertical_and_slash", 3500, 100, 9260206], "23": ["vertical_and_slash", 1000, 4096, 4467273], "24": ["vertical_and_slash", 100, 800, 17813181], "25": ["vertical_and_slash", 3500, 100, 21532596], "26": ["vertical_and_slash", 1000, 4096, 14291589], "27": ["vertical_and_slash", 1000, 4096, 17941032], "28": ["vertical_and_slash", 1000, 4096, 20269858], "29": ["vertical_and_slash", 100, 800, 16481898], "30": ["vertical_and_slash", 100, 800, 14035138], "31": ["vertical_and_slash", 3500, 100, 5218579]}, {"0": ["vertical_and_slash", 1000, 4096, 15472775], "1": ["vertical_and_slash", 500, 700, 16487444], "2": ["vertical_and_slash", 1000, 4096, 13062108], "3": ["vertical_and_slash", 1000, 4096, 17155780], "4": ["vertical_and_slash", 1000, 4096, 9528835], "5": ["vertical_and_slash", 1000, 4096, 18482684], "6": ["vertical_and_slash", 1000, 4096, 17086801], "7": ["vertical_and_slash", 100, 800, 16495168], "8": ["vertical_and_slash", 1000, 4096, 6931295], "9": ["vertical_and_slash", 3500, 100, 21960054], "10": ["vertical_and_slash", 1000, 4096, 13941150], "11": ["vertical_and_slash", 3500, 100, 6249722], "12": ["vertical_and_slash", 1000, 4096, 12292065], "13": ["vertical_and_slash", 3500, 100, 14056066], "14": ["vertical_and_slash", 1000, 4096, 17988711], "15": ["vertical_and_slash", 3500, 100, 13838932], "16": ["vertical_and_slash", 3500, 100, 11542474], "17": ["vertical_and_slash", 1000, 4096, 10272174], "18": ["vertical_and_slash", 3500, 100, 10106952], "19": ["vertical_and_slash", 1000, 4096, 11953729], "20": ["vertical_and_slash", 1000, 4096, 12125335], "21": ["vertical_and_slash", 1000, 4096, 5421557], "22": ["vertical_and_slash", 1000, 4096, 
17046156], "23": ["vertical_and_slash", 1000, 4096, 13763363], "24": ["vertical_and_slash", 1000, 4096, 14971340], "25": ["vertical_and_slash", 1000, 4096, 13949429], "26": ["vertical_and_slash", 1000, 4096, 13427580], "27": ["vertical_and_slash", 1000, 4096, 12712355], "28": ["vertical_and_slash", 1000, 4096, 10262417], "29": ["vertical_and_slash", 1000, 4096, 14593517], "30": ["vertical_and_slash", 3500, 100, 19020287], "31": ["vertical_and_slash", 1000, 4096, 16309396]}, {"0": ["vertical_and_slash", 100, 800, 6402139], "1": ["vertical_and_slash", 500, 700, 8580595], "2": ["vertical_and_slash", 3500, 100, 6974040], "3": ["vertical_and_slash", 500, 700, 9230357], "4": ["vertical_and_slash", 500, 700, 1458178], "5": ["vertical_and_slash", 3500, 100, 12626929], "6": ["vertical_and_slash", 500, 700, 7367522], "7": ["vertical_and_slash", 30, 800, 16753754], "8": ["vertical_and_slash", 100, 800, 16185443], "9": ["vertical_and_slash", 30, 800, 13212259], "10": ["vertical_and_slash", 30, 800, 16869582], "11": ["vertical_and_slash", 100, 800, 8982160], "12": ["vertical_and_slash", 3500, 100, 15101824], "13": ["vertical_and_slash", 500, 700, 10028751], "14": ["vertical_and_slash", 30, 800, 18999889], "15": ["vertical_and_slash", 100, 800, 15535188], "16": ["vertical_and_slash", 1000, 4096, 3376934], "17": ["vertical_and_slash", 1000, 4096, 3838435], "18": ["vertical_and_slash", 1000, 4096, 2789787], "19": ["vertical_and_slash", 1000, 4096, 9668519], "20": ["vertical_and_slash", 500, 700, 16137894], "21": ["vertical_and_slash", 1000, 4096, 3380197], "22": ["vertical_and_slash", 500, 700, 6788616], "23": ["vertical_and_slash", 1000, 4096, 4978497], "24": ["vertical_and_slash", 3500, 100, 9896749], "25": ["vertical_and_slash", 500, 700, 20982412], "26": ["vertical_and_slash", 1000, 4096, 5738438], "27": ["vertical_and_slash", 1000, 4096, 14533987], "28": ["vertical_and_slash", 3500, 100, 11385648], "29": ["vertical_and_slash", 30, 800, 11091461], "30": ["vertical_and_slash", 
1000, 4096, 7801211], "31": ["vertical_and_slash", 1000, 4096, 12946499]}, {"0": ["vertical_and_slash", 1000, 4096, 8005141], "1": ["vertical_and_slash", 30, 800, 9683398], "2": ["vertical_and_slash", 100, 800, 15684848], "3": ["vertical_and_slash", 30, 800, 10783581], "4": ["vertical_and_slash", 30, 800, 12674711], "5": ["vertical_and_slash", 100, 800, 17627426], "6": ["vertical_and_slash", 500, 700, 6603740], "7": ["vertical_and_slash", 30, 800, 8037793], "8": ["vertical_and_slash", 1000, 4096, 18603355], "9": ["vertical_and_slash", 100, 800, 18175297], "10": ["vertical_and_slash", 1000, 4096, 15415235], "11": ["vertical_and_slash", 100, 800, 8188133], "12": ["vertical_and_slash", 100, 800, 16790430], "13": ["vertical_and_slash", 1000, 4096, 4440951], "14": ["vertical_and_slash", 1000, 4096, 12155674], "15": ["vertical_and_slash", 3500, 100, 18728501], "16": ["vertical_and_slash", 30, 800, 8282869], "17": ["vertical_and_slash", 30, 800, 18611641], "18": ["vertical_and_slash", 30, 800, 7125529], "19": ["vertical_and_slash", 30, 800, 9867525], "20": ["vertical_and_slash", 100, 800, 8121064], "21": ["vertical_and_slash", 100, 800, 8406786], "22": ["vertical_and_slash", 30, 800, 11020990], "23": ["vertical_and_slash", 30, 800, 4944682], "24": ["vertical_and_slash", 30, 800, 16714152], "25": ["vertical_and_slash", 30, 800, 9194588], "26": ["vertical_and_slash", 500, 700, 9003731], "27": ["vertical_and_slash", 1000, 4096, 6939820], "28": ["vertical_and_slash", 500, 700, 10839557], "29": ["vertical_and_slash", 500, 700, 14432584], "30": ["vertical_and_slash", 100, 800, 12363347], "31": ["vertical_and_slash", 30, 800, 14465663]}]
|
minference/configs/model2path.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
# Directory containing this module; the pattern JSON files are resolved
# relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# (model id, pattern file name) pairs. The 1048k Llama-3 variant points at the
# same pattern file as the 262k variant.
_PATTERN_FILES = (
    (
        "gradientai/Llama-3-8B-Instruct-262k",
        "Llama_3_8B_Instruct_262k_kv_out_v32_fit_o_best_pattern.json",
    ),
    (
        "gradientai/Llama-3-8B-Instruct-Gradient-1048k",
        "Llama_3_8B_Instruct_262k_kv_out_v32_fit_o_best_pattern.json",
    ),
    (
        "01-ai/Yi-9B-200K",
        "Yi_9B_200k_kv_out_v32_fit_o_best_pattern.json",
    ),
    (
        "microsoft/Phi-3-mini-128k-instruct",
        "Phi_3_mini_128k_instruct_kv_out_v32_fit_o_best_pattern.json",
    ),
)

# Maps each supported model id to the path of its pattern config file.
MODEL2PATH = {
    model_id: os.path.join(BASE_DIR, file_name)
    for model_id, file_name in _PATTERN_FILES
}
|
minference/minference_configuration.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from .configs.model2path import MODEL2PATH
|
4 |
+
|
5 |
+
|
6 |
+
class MInferenceConfig:
    """Options describing how a model should be patched.

    The constructor validates ``attn_type`` against ``ATTENTION_TYPES`` and
    resolves the pattern config path, either from ``config_path`` directly or
    by looking ``model_name`` up in ``MODEL2PATH``.
    """

    # Attention variants accepted by the constructor.
    ATTENTION_TYPES = [
        "minference",
        "minference_with_dense",
        "static",
        "dilated1",
        "dilated2",
        "streaming",
        "inf_llm",
        "vllm",
    ]

    def __init__(
        self,
        attn_type: str = "minference",
        model_name: str = None,
        config_path: str = None,
        starting_layer: int = -1,
        kv_cache_cpu: bool = False,
        use_snapkv: bool = False,
        is_search: bool = False,
        attn_kwargs: dict = None,
        **kwargs,
    ):
        """Validate and store the options.

        Args:
            attn_type: One of ``ATTENTION_TYPES``.
            model_name: Key into ``MODEL2PATH``; only consulted when
                ``config_path`` is None.
            config_path: Explicit pattern config path; takes precedence over
                ``model_name``.
            starting_layer: Stored as-is.
            kv_cache_cpu: Stored as-is.
            use_snapkv: Stored as-is.
            is_search: Stored as-is.
            attn_kwargs: Extra attention options. ``None`` (the default) is
                replaced by a fresh empty dict so instances never share one.
            **kwargs: Accepted and ignored.

        Raises:
            AssertionError: If ``attn_type`` is unknown, or ``config_path`` is
                None and ``model_name`` is not in ``MODEL2PATH``.
        """
        super().__init__()
        assert (
            attn_type in self.ATTENTION_TYPES
        ), f"The attention_type {attn_type} you specified is not supported."
        self.attn_type = attn_type
        self.config_path = self.update_config_path(config_path, model_name)
        self.model_name = model_name
        self.is_search = is_search
        self.starting_layer = starting_layer
        self.kv_cache_cpu = kv_cache_cpu
        self.use_snapkv = use_snapkv
        # A literal ``{}`` default would be one dict shared by every instance
        # created without ``attn_kwargs``; build a new one per instance.
        self.attn_kwargs = {} if attn_kwargs is None else attn_kwargs

    def update_config_path(self, config_path: str, model_name: str):
        """Return ``config_path`` if given, else the path registered for ``model_name``.

        Raises:
            AssertionError: If ``config_path`` is None and ``model_name`` is
                not a key of ``MODEL2PATH``.
        """
        if config_path is not None:
            return config_path
        assert (
            model_name in MODEL2PATH
        ), f"The model {model_name} you specified is not supported. You are welcome to add it and open a PR :)"
        return MODEL2PATH[model_name]
|
minference/models_patch.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from .minference_configuration import MInferenceConfig
|
4 |
+
from .patch import minference_patch, minference_patch_vllm, patch_hf
|
5 |
+
|
6 |
+
|
7 |
+
class MInference:
    """Callable that applies the configured attention patch to a model.

    The options are wrapped in a ``MInferenceConfig``; calling the instance
    with a model dispatches on ``config.attn_type`` and returns the patched
    model.
    """

    def __init__(
        self,
        attn_type: str = "minference",
        model_name: str = None,
        config_path: str = None,
        starting_layer: int = -1,
        kv_cache_cpu: bool = False,
        use_snapkv: bool = False,
        is_search: bool = False,
        attn_kwargs: dict = None,
        **kwargs,
    ):
        """Build the ``MInferenceConfig`` from the given options.

        All arguments are forwarded to ``MInferenceConfig``. ``attn_kwargs``
        defaults to ``None`` and is replaced here by a fresh empty dict, so
        no dict object is shared between instances.
        """
        super().__init__()
        self.config = MInferenceConfig(
            attn_type=attn_type,
            model_name=model_name,
            config_path=config_path,
            starting_layer=starting_layer,
            kv_cache_cpu=kv_cache_cpu,
            use_snapkv=use_snapkv,
            is_search=is_search,
            # Normalise here so the config always receives a real dict.
            attn_kwargs={} if attn_kwargs is None else attn_kwargs,
            **kwargs,
        )

    def __call__(self, model):
        """Patch ``model``; shorthand for ``patch_model(model)``."""
        return self.patch_model(model)

    def patch_model(self, model):
        """Apply the patch selected by ``self.config.attn_type`` and return the model.

        For every type except ``"vllm"`` the starting layer and config path
        are first copied onto ``model.config``.

        Raises:
            ValueError: If ``attn_type`` matches none of the handled values.
        """
        attn_type = self.config.attn_type

        if attn_type != "vllm":
            model.config.starting_layer = self.config.starting_layer
            model.config.config_path = self.config.config_path

        # Variants that only raise one boolean flag on model.config before
        # the common minference patch is applied.
        flag_by_type = {
            "minference_with_dense": "dense",
            "dilated1": "dilated1",
            "static": "static_pattern",
            "dilated2": "dilated2",
        }

        if attn_type == "minference":
            model.config.is_search = self.config.is_search
            model = minference_patch(model, self.config)

        elif attn_type in flag_by_type:
            setattr(model.config, flag_by_type[attn_type], True)
            model = minference_patch(model, self.config)

        elif attn_type == "streaming":
            model.config.streaming = True
            # User-supplied attn_kwargs override the defaults below.
            model.config.streaming_kwargs = {
                "n_local": 3968,
                "n_init": 128,
                **self.config.attn_kwargs,
            }
            model = minference_patch(model, self.config)

        elif attn_type == "streaming2":
            # NOTE(review): "streaming2" is not listed in
            # MInferenceConfig.ATTENTION_TYPES, so a config built through the
            # constructor never reaches this branch.
            model = patch_hf(
                model,
                attn_type="streaming",
                attn_kwargs={"n_local": 3968, "n_init": 128, **self.config.attn_kwargs},
            )
        elif attn_type == "inf_llm":
            model = patch_hf(
                model,
                attn_type="inf_llm",
                attn_kwargs={
                    "block_size": 128,
                    "n_init": 128,
                    "n_local": 4096,
                    "topk": 16,
                    "repr_topk": 4,
                    "max_cached_block": 32,
                    "exc_block_size": 512,
                    "base": 1000000,
                    "distance_scale": 1.0,
                    "dense_decoding": True,
                    **self.config.attn_kwargs,
                },
            )
        elif attn_type == "vllm":
            model = minference_patch_vllm(model, self.config.config_path)
        else:
            raise ValueError(
                f"The attention type {self.config.attn_type} you specified is not supported."
            )
        return model
|
minference/modules/inf_llm.py
ADDED
@@ -0,0 +1,1296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from flash_attn import flash_attn_func
|
6 |
+
from transformers.modeling_outputs import CausalLMOutput
|
7 |
+
|
8 |
+
from ..ops.streaming_kernel import TritonMultiStageDotProductionAttention
|
9 |
+
|
10 |
+
|
11 |
+
class CudaCache:
    """Pool of ``num_units`` equally sized rows in one CUDA tensor, handed out by index."""

    def __init__(self, num_units, unit_size, dtype):
        """Allocate the backing tensor and mark every row as free."""
        self.num_units = num_units
        self.unit_size = unit_size
        self.dtype = dtype
        # One row per unit; torch.empty leaves the contents uninitialised.
        self.data = torch.empty((num_units, unit_size), device="cuda", dtype=dtype)
        # Indices of rows not currently handed out.
        self.idle_set = set(range(num_units))

    def alloc(self):
        """Take a free row and return ``(row_tensor, row_index)``.

        Raises:
            AssertionError: If no row is free.
        """
        assert len(self.idle_set) > 0

        slot = self.idle_set.pop()
        return self.data[slot], slot

    def delete(self, idx):
        """Mark row ``idx`` as free again; it must currently be in use."""
        assert idx not in self.idle_set
        self.idle_set.add(idx)
|
28 |
+
|
29 |
+
|
30 |
+
class MemoryUnit:
    """A (key, value) tensor pair held on the CPU, optionally mirrored in a cache slot.

    ``cpu_data`` always holds the pair. ``gpu_data`` / ``gpu_data_id`` hold a
    slot obtained from ``cache`` while the pair is loaded, and ``event`` is the
    CUDA event recorded after the copy into that slot.
    """

    def __init__(
        self,
        kv: Tuple[torch.Tensor, torch.Tensor],
        cache: "CudaCache",
        load_to_cache: bool = False,
        pin_memory: bool = False,
    ):
        """Store ``kv`` on the CPU and optionally copy it into a cache slot.

        Args:
            kv: The (key, value) tensors.
            cache: Slot allocator providing ``alloc()`` and ``delete(idx)``.
            load_to_cache: If True, immediately allocate a slot and copy
                ``kv`` into it.
            pin_memory: If True, pin the CPU copies.
        """
        self.cache = cache

        if kv[0].is_cuda:
            cpu_data = tuple(_t.contiguous().to("cpu", non_blocking=True) for _t in kv)
        else:
            cpu_data = tuple(_t.contiguous() for _t in kv)

        if pin_memory:
            cpu_data = tuple(_t.pin_memory() for _t in cpu_data)

        if load_to_cache:
            gpu_data, gpu_data_id = cache.alloc()
            # The flat slot is viewed as a stacked (2, *kv_shape) tensor.
            gpu_data = gpu_data.view((2,) + kv[0].shape)
            gpu_data[0].copy_(kv[0], non_blocking=True)
            gpu_data[1].copy_(kv[1], non_blocking=True)
            event = torch.cuda.Event()
            event.record(torch.cuda.current_stream())
        else:
            gpu_data, gpu_data_id = None, None
            event = None

        self.cpu_data = cpu_data
        self.gpu_data = gpu_data
        self.gpu_data_id = gpu_data_id
        self.event = event

    def load(
        self, target: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> "Tuple[bool, Optional[torch.cuda.Event]]":
        """Ensure the pair is in a cache slot and optionally copy it into ``target``.

        Args:
            target: Optional (key, value) destination tensors to fill.

        Returns:
            ``(loaded, target_event)``. ``loaded`` is True when this call
            allocated a new slot, False when the pair was already cached.
            ``target_event`` is an event recorded after the copy into
            ``target``, or None when no ``target`` was given.
        """
        # Defined up front: the original left this unbound (UnboundLocalError)
        # when the pair was not cached and no target was passed.
        target_event = None

        if self.gpu_data is not None:
            if target is not None:
                target[0].copy_(self.gpu_data[0], non_blocking=True)
                target[1].copy_(self.gpu_data[1], non_blocking=True)
                target_event = torch.cuda.Event()
                target_event.record(torch.cuda.current_stream())

            return False, target_event

        gpu_data, gpu_data_id = self.cache.alloc()
        gpu_data = gpu_data.view((2,) + self.cpu_data[0].shape)
        if target is not None:
            target[0].copy_(self.cpu_data[0], non_blocking=True)
            target[1].copy_(self.cpu_data[1], non_blocking=True)
            target_event = torch.cuda.Event()
            target_event.record(torch.cuda.current_stream())
            # The slot is filled from target rather than from the CPU copy.
            gpu_data[0].copy_(target[0], non_blocking=True)
            gpu_data[1].copy_(target[1], non_blocking=True)
        else:
            gpu_data[0].copy_(self.cpu_data[0], non_blocking=True)
            gpu_data[1].copy_(self.cpu_data[1], non_blocking=True)

        event = torch.cuda.Event()
        event.record(torch.cuda.current_stream())
        self.event = event
        self.gpu_data = gpu_data
        self.gpu_data_id = gpu_data_id

        return True, target_event

    def get(self):
        """Return the cached (2, ...) tensor after waiting on its copy event.

        Raises:
            AssertionError: If the pair is not currently cached.
        """
        assert self.gpu_data is not None
        self.event.wait()
        return self.gpu_data

    def offload(self):
        """Release the cache slot; the CPU copy is kept.

        Raises:
            AssertionError: If the pair is not currently cached.
        """
        assert self.gpu_data is not None
        self.event.wait()
        self.gpu_data = None
        self.cache.delete(self.gpu_data_id)
        self.gpu_data_id = None
|
109 |
+
|
110 |
+
|
111 |
+
class VectorTensor:
    """Growable 2-D tensor of row vectors with inner-product top-k lookup.

    ``data`` is a buffer of ``cache_size`` rows, of which the first ``length``
    are valid. The buffer doubles whenever an append would overflow it.
    """

    def __init__(self, hidden_size, element_dtype):
        """Allocate an initial 16-row CUDA buffer of width ``hidden_size``."""
        init_cached_size = 16
        self.data = torch.empty(
            (init_cached_size, hidden_size), dtype=element_dtype, device="cuda"
        )
        self.length = 0
        self.cache_size = init_cached_size
        self.hidden_size = hidden_size

    def append_cache(self):
        """Double the buffer capacity, keeping the existing rows."""
        new_cache_size = self.cache_size * 2
        data_shape = self.data.shape
        # Allocate on the buffer's own device (not whatever "cuda" currently
        # resolves to) so growth never moves the data to another device.
        new_data = torch.empty(
            (new_cache_size,) + data_shape[1:],
            device=self.data.device,
            dtype=self.data.dtype,
        )
        new_data[: self.cache_size, ...].copy_(self.data)
        self.data = new_data
        self.cache_size = new_cache_size

    def append(self, tensor: torch.Tensor):
        """Append the rows of ``tensor`` (dtype, width and contiguity are asserted)."""
        assert tensor.dtype == self.data.dtype
        assert tensor.size(1) == self.hidden_size
        assert tensor.is_contiguous()

        append_l = tensor.size(0)

        while self.length + append_l > self.cache_size:
            self.append_cache()

        self.data[self.length : self.length + append_l, ...].copy_(tensor)

        self.length += append_l

    def get_data(self):
        """Return a view of the valid rows."""
        return self.data[: self.length, ...]

    def get_topk(self, tensor: torch.Tensor, topk):  # inner product
        """Return the indices of the ``topk`` rows with the largest inner product with ``tensor``.

        ``tensor`` must be 1-D of size ``hidden_size``. The result is a Python
        list of ints.
        """
        assert tensor.dim() == 1 and tensor.size(0) == self.hidden_size
        logits = torch.matmul(self.data[: self.length], tensor[:, None]).squeeze(dim=-1)
        assert logits.dim() == 1 and logits.size(0) == self.length
        return logits.topk(topk, dim=0).indices.cpu().tolist()

    def __len__(self):
        """Return the number of valid rows."""
        return self.length
|
156 |
+
|
157 |
+
|
158 |
+
class Faiss:
    """Inner-product vector index backed by ``faiss.IndexFlatIP``.

    Offers the same ``append`` / ``get_topk`` / ``__len__`` methods as
    ``VectorTensor``; ``get_data`` is not supported.
    """

    def __init__(self, hidden_size, element_dtype):
        """Create an empty index of dimension ``hidden_size``; ``element_dtype`` is unused."""
        import faiss

        # We use the CPU index here because the GPU index requires a long initialization time
        self.index = faiss.IndexFlatIP(hidden_size)
        self.hidden_size = hidden_size

    def append(self, tensor: torch.Tensor):
        """Add the rows of a 2-D ``tensor`` to the index as float32."""
        assert tensor.dim() == 2 and tensor.size(1) == self.hidden_size
        rows = tensor.cpu().float().numpy().astype("float32")
        self.index.add(rows)

    def get_data(self):
        """Not supported; always raises ``ValueError``."""
        raise ValueError

    def get_topk(self, tensor: torch.Tensor, topk):
        """Return the index positions of the ``topk`` best matches for a 1-D query."""
        assert tensor.dim() == 1 and tensor.size(0) == self.hidden_size
        query = tensor[None, :].cpu().float().numpy().astype("float32")
        found = self.index.search(query, topk)
        # Element 1 of the search result holds the indices; row 0 is our single query.
        return found[1][0].tolist()

    def __len__(self):
        """Return the number of vectors in the index."""
        return self.index.ntotal
|
181 |
+
|
182 |
+
|
183 |
+
# Module-wide CUDA stream. Stays None until a ContextManager constructed with
# async_global_stream=True creates it (once) via torch.cuda.Stream().
GLOBAL_STREAM = None
|
184 |
+
|
185 |
+
|
186 |
+
class ContextManager:
|
187 |
+
def __init__(
    self,
    position_embedding,
    n_init,
    n_local,
    block_size,
    max_cached_block,
    topk,
    exc_block_size,
    score_decay: Optional[float] = None,
    repr_topk: int = 1,
    cache_strategy="lru",
    chunk_topk_calc: Optional[int] = None,
    async_global_stream: bool = False,
    pin_memory: bool = False,
    faiss: bool = False,
    perhead: bool = False,
    dense_decoding: bool = False,
):
    """Record the configuration; tensor state is created later by ``init``.

    Raises:
        AssertionError: If ``exc_block_size > n_local`` or ``cache_strategy``
            is neither ``"lru"`` nor ``"lru-s"``.
    """
    # Sizes and position handling.
    self.length = 0
    self.position_embedding = position_embedding
    self.n_init = n_init
    self.n_local = n_local
    self.block_size = block_size
    self.max_cached_block = max_cached_block
    self.exc_block_size = exc_block_size
    self.score_decay = score_decay
    assert exc_block_size <= n_local  # no global token in input

    # Retrieval and caching options.
    self.topk = topk
    self.Attn = TritonMultiStageDotProductionAttention
    self.initialized = False
    self.repr_topk = repr_topk
    self.cache_strategy = cache_strategy
    self.load_count = 0
    self.chunk_topk_calc = chunk_topk_calc
    self.async_global_stream = async_global_stream
    self.pin_memory = pin_memory
    self.faiss = faiss
    self.perhead = perhead

    self.dense_decoding = dense_decoding

    # The module-level stream is created at most once, on first request.
    global GLOBAL_STREAM
    if self.async_global_stream and GLOBAL_STREAM is None:
        GLOBAL_STREAM = torch.cuda.Stream()

    assert cache_strategy in ["lru", "lru-s"]

    # Block scores are only computed for the "lru-s" strategy.
    self.calc_block_score = cache_strategy == "lru-s"
|
239 |
+
|
240 |
+
def remove_lru_blocks(
    self, u, num_remove: Optional[int] = None, ignore_blocks=None
):
    """Offload the lowest-scored cached blocks of unit ``u``.

    Args:
        u: Unit index into ``cached_blocks`` / ``global_blocks``.
        num_remove: How many blocks to evict; defaults to the excess over
            ``max_cached_block``.
        ignore_blocks: Optional collection of block ids that must not be
            evicted.
    """
    if num_remove is None:
        num_remove = len(self.cached_blocks[u]) - self.max_cached_block

    if num_remove <= 0:
        return

    # Ascending by stored score, so the lowest-scored blocks go first.
    by_score = sorted(self.cached_blocks[u].items(), key=lambda item: item[1])

    evicted = 0
    for block_id, _score in by_score:
        if ignore_blocks is not None and block_id in ignore_blocks:
            continue

        self.global_blocks[u][block_id].offload()
        self.cached_blocks[u].pop(block_id)
        evicted += 1

        if evicted >= num_remove:
            return
|
262 |
+
|
263 |
+
def get_block_k(self, k, score):
    """Pick, per head, the ``repr_topk`` key vectors of a block with the highest score.

    ``k`` is first expanded with ``from_group_kv``; its leading dimensions
    must then equal ``score.shape``. The result has shape
    ``(num_units, unit_size, repr_topk, dim_head)``.
    """
    assert isinstance(score, torch.Tensor)
    assert k.dim() >= 2
    k = self.from_group_kv(k)
    assert k.shape[:-1] == score.shape
    assert k.shape[-2] == self.block_size

    picked_shape = (self.num_units, self.unit_size, self.repr_topk)
    top_positions = score.topk(self.repr_topk, dim=-1).indices
    assert top_positions.shape == picked_shape

    # Broadcast each chosen position across the head dimension for gather.
    gather_index = top_positions[:, :, :, None].expand(
        self.num_units, self.unit_size, self.repr_topk, self.dim_head
    )
    return torch.gather(k, -2, gather_index)
|
279 |
+
|
280 |
+
def from_group_kv(self, tensor):
    """Repeat each KV head so the tensor has ``num_heads`` heads.

    ``tensor`` must be 4-D with ``num_heads_kv`` heads in dimension 1. It is
    returned unchanged when ``num_heads == num_heads_kv``; otherwise every KV
    head is repeated ``num_heads // num_heads_kv`` times consecutively.
    """
    assert tensor.dim() == 4
    assert tensor.size(1) == self.num_heads_kv
    if self.num_heads == self.num_heads_kv:
        return tensor

    length, dim_head = tensor.shape[2], tensor.shape[3]
    group_size = self.num_heads // self.num_heads_kv

    # Add a singleton group axis, broadcast it, then fold it into the head axis.
    grouped = tensor.view((self.num_units, self.unit_size_kv, 1, length, dim_head))
    grouped = grouped.expand(
        (self.num_units, self.unit_size_kv, group_size, length, dim_head)
    )
    return grouped.reshape((self.num_units, self.num_heads, length, dim_head))
|
292 |
+
|
293 |
+
def init(self, local_q, local_k, local_v, global_q, global_k, global_v):
    """Derive the shape attributes from the first inputs and allocate all buffers.

    All six tensors must be 4-D CUDA tensors that agree on batch size,
    sequence length and head dimension; dimension 1 may be either the query
    head count or the KV head count. Sets ``self.initialized`` to True at the
    end.
    """
    assert local_q.dim() == 4
    batch_size, num_heads, len_q, dim_head = local_q.shape
    num_heads_kv = local_k.size(1)

    for _t in (local_q, local_k, local_v, global_q, global_k, global_v):
        assert _t.size(0) == batch_size
        assert _t.size(1) == num_heads or _t.size(1) == num_heads_kv
        assert _t.size(2) == len_q
        assert _t.size(3) == dim_head
        assert _t.is_cuda

    self.batch_size = batch_size
    self.num_heads = num_heads
    self.num_heads_kv = num_heads_kv
    self.dim_head = dim_head
    # One "unit" per batch element; unit sizes are the head counts.
    self.num_units = batch_size
    self.unit_size = num_heads
    self.unit_size_kv = num_heads_kv

    def _empty_kv(like):
        # Zero-length KV tensor with the dtype and device of ``like``.
        return torch.empty(
            (self.num_units, self.unit_size_kv, 0, dim_head),
            dtype=like.dtype,
            device=like.device,
        )

    self.global_blocks = [[] for _ in range(self.num_units)]  # [[memory_unit]]
    self.cached_blocks = [
        {} for _ in range(self.num_units)
    ]  # [[block_id: block_score]
    self.num_global_block = 0

    # Per-unit index of block representative keys.
    if self.faiss:
        self.block_k = [
            Faiss(dim_head * self.unit_size, global_k.dtype)
            for _ in range(self.num_units)
        ]
    else:
        self.block_k = [
            VectorTensor(dim_head * self.unit_size, global_k.dtype)
            for _ in range(self.num_units)
        ]

    self.local_k = _empty_kv(local_k)
    self.local_v = _empty_kv(local_v)

    if self.dense_decoding:
        self.dense_k = _empty_kv(local_k)
        self.dense_v = _empty_kv(local_v)

    self.global_remainder = (_empty_kv(global_k), _empty_kv(global_v))

    self.global_remainder_local_score = torch.empty(
        (self.num_units, self.unit_size, 0),
        dtype=global_k.dtype,
        device=global_k.device,
    )

    # Both init tensors take their dtype and device from global_k.
    self.init_k = _empty_kv(global_k)
    self.init_v = _empty_kv(global_k)
    self.init_exc = False
    self.dtype = local_q.dtype
    self.position_embedding._update_cos_sin_tables_len(
        self.n_local + self.exc_block_size + 1, local_k.device, local_k.dim()
    )

    # Length of the shared key/value buffer: topk blocks, one exc block,
    # one extra block and the n_init tokens.
    buffer_len = (
        self.topk * self.block_size
        + self.exc_block_size
        + self.block_size
        + self.n_init
    )
    self.global_buffer = torch.zeros(
        (2, self.num_units, self.unit_size_kv, buffer_len, dim_head),
        dtype=global_k.dtype,
        device=global_k.device,
    )
    # -1 marks a buffer slot that holds no block yet.
    self.global_buffer_block_id_list = [
        [-1] * self.topk for _ in range(self.num_units)
    ]
    self.global_buffer_init_st = 0
    self.global_buffer_init_ed = 0
    self.cuda_cache = CudaCache(
        self.max_cached_block * self.num_units,
        self.unit_size_kv * self.block_size * dim_head * 2,
        local_k.dtype,
    )

    self.initialized = True
|
411 |
+
|
412 |
+
def calc_block_topk(self, global_h_q):
    """Pick the global blocks each unit should attend to for this chunk.

    Three cases, checked in this order:

    * chunked top-k is active (``self._use_chunk_topk``): return the entry
      of ``self._cached_topk`` selected by ``self._topk_cur``;
    * there are no more than ``self.topk`` global blocks: every unit gets
      all block indices ``0 .. len(self.global_blocks[0]) - 1``;
    * otherwise: average the query over dim 2, flatten the per-unit heads
      and ask each ``self.block_k[u]`` for its ``self.topk`` best blocks.

    Returns:
        A list with one entry per unit (for the non-cached cases, a list
        of block indices per unit).
    """
    if self._use_chunk_topk:
        return self._cached_topk[self._topk_cur]

    if self.num_global_block <= self.topk:
        return [
            list(range(len(self.global_blocks[0])))
            for _ in range(self.num_units)
        ]

    pooled_q = global_h_q.mean(dim=2, keepdim=False)
    assert pooled_q.shape == (self.num_units, self.unit_size, self.dim_head)
    flat_q = pooled_q.reshape(self.num_units, self.dim_head * self.unit_size)

    return [
        self.block_k[unit].get_topk(flat_q[unit], self.topk)
        for unit in range(self.num_units)
    ]
|
433 |
+
|
434 |
+
def get_global_hidden_and_mask(self, len_q, block_topk):
    """Stage keys/values for the global attention pass in ``self.global_buffer``.

    The buffer is filled along its sequence axis (dim -2) in this order:
    selected blocks (``block_num * block_size`` positions), then the
    initial tokens (``self.init_k`` / ``self.init_v``), then the part of
    ``self.global_remainder`` starting at ``self._global_remainder_st``.

    Args:
        len_q: number of query positions in the current chunk.
        block_topk: per-unit lists of selected block indices; every list
            must have the same length. The lists are sorted in place.

    Returns:
        ``(global_h_k, global_h_v, sliding_window, global_block_map,
        block_num)`` where the first two are views of the buffer truncated
        to the filled length, ``sliding_window`` is a 2-tuple passed on to
        the attention kernel, ``global_block_map[u]`` lists which block
        occupies each of the first ``block_num`` slots for unit ``u``.
    """
    assert len(block_topk) == self.num_units
    global_block_map = [[] for _ in range(self.num_units)]
    # Number of remainder positions that fall outside the local window
    # once the len_q new queries are accounted for (never negative).
    global_remainder_len = max(
        self._global_remainder_ed
        - self._global_remainder_st
        + len_q
        - self.n_local,
        0,
    )
    init_len = self.init_k.size(-2)
    sliding_window = None

    # Index 0 of the buffer holds keys, index 1 holds values.
    global_h_k = self.global_buffer[0]
    global_h_v = self.global_buffer[1]

    block_num = len(block_topk[0])
    for u in range(self.num_units):
        assert len(block_topk[u]) == block_num

        block_topk[u].sort()
        # Start from what is already resident in the buffer so blocks that
        # are selected again do not have to be loaded a second time.
        global_block_map[u] = deepcopy(self.global_buffer_block_id_list[u])
        for b_idx in block_topk[u]:
            if b_idx in global_block_map[u]:
                continue

            # Find a slot that is empty (-1) or holds a block that is not
            # part of the current selection, and claim it for b_idx.
            # NOTE(review): if no slot qualifies, st/ed stay -1 and the load
            # below targets an empty slice; presumably unreachable because
            # at most self.topk blocks are selected - confirm.
            st = -1
            ed = -1
            for j in range(self.topk):
                if (
                    global_block_map[u][j] == -1
                    or global_block_map[u][j] not in block_topk[u]
                ):
                    st = j * self.block_size
                    ed = st + self.block_size
                    global_block_map[u][j] = b_idx
                    break

            assert b_idx in self.cached_blocks[u]
            self.global_blocks[u][b_idx].load(
                (global_h_k[u, :, st:ed, :], global_h_v[u, :, st:ed, :])
            )

    # The initial tokens sit right after the selected blocks; they are only
    # re-copied when their position in the buffer changed since last call.
    init_st = block_num * self.block_size
    init_ed = init_st + init_len
    if (
        self.global_buffer_init_st != init_st
        or self.global_buffer_init_ed != init_ed
    ):
        global_h_k[:, :, init_st:init_ed, :].copy_(self.init_k, non_blocking=True)
        global_h_v[:, :, init_st:init_ed, :].copy_(self.init_v, non_blocking=True)

    ed = init_ed

    # The remainder segment follows the initial tokens.
    rmd_st = init_ed
    rmd_ed = rmd_st + global_remainder_len
    ed = rmd_ed
    global_h_k[:, :, rmd_st:rmd_ed, :].copy_(
        self.global_remainder[0][
            :,
            :,
            self._global_remainder_st : self._global_remainder_st
            + global_remainder_len,
            :,
        ],
        non_blocking=True,
    )
    global_h_v[:, :, rmd_st:rmd_ed, :].copy_(
        self.global_remainder[1][
            :,
            :,
            self._global_remainder_st : self._global_remainder_st
            + global_remainder_len,
            :,
        ],
        non_blocking=True,
    )

    sliding_window = (self.global_remainder[0].size(-2) + rmd_st, self.n_local)

    # Remember the buffer layout for the next call.
    self.global_buffer_block_id_list = deepcopy(global_block_map)
    self.global_buffer_init_st = init_st
    self.global_buffer_init_ed = init_ed

    for u in range(self.num_units):
        # Slots past block_num must be empty and the first block_num slots
        # must all be occupied before the map is truncated for the caller.
        assert max(global_block_map[u][block_num:] + [-1]) == -1
        assert min(global_block_map[u][:block_num] + [0]) > -1
        global_block_map[u] = list(global_block_map[u][:block_num])

    # Expose only the filled prefix of the buffer.
    global_h_k = global_h_k[:, :, :ed, :]
    global_h_v = global_h_v[:, :, :ed, :]
    return global_h_k, global_h_v, sliding_window, global_block_map, block_num
|
526 |
+
|
527 |
+
def update_block_score(
    self, global_score: torch.FloatTensor, global_block_map, global_block_num
):
    """Decay the cached-block scores and add the scores from this step.

    Does nothing when ``global_score`` is ``None``. Otherwise the first
    ``global_block_num * self.block_size`` score columns are summed per
    block and over the heads of each unit, moved to the CPU, and added to
    ``self.cached_blocks[u]`` for the blocks listed in
    ``global_block_map[u]`` after every existing entry has been multiplied
    by ``self.score_decay``.
    """
    if global_score is None:
        return

    scored_len = global_block_num * self.block_size
    trimmed = global_score[:, :, :scored_len]
    assert trimmed.shape == (self.num_units, self.unit_size, scored_len)

    per_block = trimmed.view(
        self.num_units, self.unit_size, global_block_num, self.block_size
    )
    per_unit = per_block.sum(dim=-1).sum(dim=1)
    assert per_unit.shape == (self.num_units, global_block_num)

    # Blocking copy: the values are read on the host right below.
    host_scores = per_unit.to(device="cpu", non_blocking=False)

    for unit in range(self.num_units):
        cache = self.cached_blocks[unit]
        for block_id in cache:
            cache[block_id] = cache[block_id] * self.score_decay

        gained = host_scores[unit].tolist()
        block_ids = global_block_map[unit]
        assert len(gained) >= len(block_ids)
        for block_id, delta in zip(block_ids, gained):
            cache[block_id] += delta
|
552 |
+
|
553 |
+
def _append(self, local_q, local_k, local_v, global_q):
    """Compute attention for one chunk: local window first, then global blocks.

    Args:
        local_q, local_k, local_v: query chunk and the slice of cached local
            keys/values it attends to; positions are applied here through
            ``self.position_embedding``.
        global_q: query chunk used for block selection and the global pass.

    Returns:
        ``(o, loc_score)`` where ``o`` is the attention output viewed as
        ``(self.batch_size, self.num_heads, -1, self.dim_head)`` and
        ``loc_score`` is the first score returned by the attention helper
        (the one produced by the local pass).

    Raises:
        ValueError: if ``self.cache_strategy`` is neither ``"lru"`` nor
            ``"lru-s"``.
    """
    # get local_h_q, local_h_k, local_h_v
    local_h_q, local_h_k = self.position_embedding(local_q, local_k)
    local_h_v = local_v

    # calc local result first to overlap host-device communication
    attn = self.Attn(local_h_q.shape, local_h_q.dtype, local_h_q.device)
    attn.append(
        local_h_q, local_h_k, local_h_v, get_score=True, sliding_window=self.n_local
    )

    # calc topk global repr k and load cache
    with torch.cuda.stream(GLOBAL_STREAM):
        block_topk = self.calc_block_topk(global_q)

        for u in range(self.num_units):
            # Blocks to evict = current overflow plus one per selected
            # block that is not cached yet.
            num_remove = len(self.cached_blocks[u]) - self.max_cached_block
            for bidx in block_topk[u]:
                if bidx not in self.cached_blocks[u]:
                    num_remove += 1

            # update cache
            self.remove_lru_blocks(u, num_remove, block_topk[u])

        if self.cache_strategy == "lru":
            # Stamp every selected block with a fresh, increasing counter.
            self.load_count += 1
            for u in range(self.num_units):
                for bidx in block_topk[u]:
                    self.cached_blocks[u][bidx] = self.load_count

        elif self.cache_strategy == "lru-s":
            # Score-based variant: selected blocks are (re)set to 0 here.
            for u in range(self.num_units):
                for bidx in block_topk[u]:
                    self.cached_blocks[u][bidx] = 0
        else:
            raise ValueError

        # get global_h_k, global_h_v, global_mask
        # Because exc_block_size <= n_local, no global_k, global_v used in global part
        global_h_q = global_q
        (
            global_h_k,
            global_h_v,
            global_sliding_window,
            global_block_map,
            global_block_num,
        ) = self.get_global_hidden_and_mask(local_h_q.size(-2), block_topk)

    if self.async_global_stream:
        # The current stream must not read the staged buffer before the
        # work queued on GLOBAL_STREAM has finished.
        torch.cuda.current_stream().wait_stream(GLOBAL_STREAM)

    # calc global result
    attn.append(
        global_h_q,
        global_h_k,
        global_h_v,
        end=True,
        get_score=self.calc_block_score,
        sliding_window=global_sliding_window,
        complement_sliding_window=True,
    )

    o, score_list = attn.get_result()
    loc_score = score_list[0]
    glb_score = score_list[1]

    if self.async_global_stream:
        # Conversely, GLOBAL_STREAM must wait for the attention result
        # before the score update below consumes glb_score.
        GLOBAL_STREAM.wait_stream(torch.cuda.current_stream())

    # update global score
    with torch.cuda.stream(GLOBAL_STREAM):
        self.update_block_score(glb_score, global_block_map, global_block_num)

    return o.view((self.batch_size, self.num_heads, -1, self.dim_head)), loc_score
|
627 |
+
|
628 |
+
def get_batched_topk(self, global_q):
    """Select top-k global blocks for every execution chunk of ``global_q``.

    The query length (dim 2) is split into chunks of ``self.exc_block_size``
    positions plus an optional shorter tail. For each chunk the queries are
    averaged, scored against the per-block representative keys (averaged
    over the heads of a unit), and the ``self.topk`` best block indices are
    kept.

    Returns:
        A list with one entry per chunk; each entry is a list with one list
        of block indices per unit. When there are no more than ``self.topk``
        global blocks, every entry simply enumerates all blocks.
    """
    q_len = global_q.shape[2]
    total_chunks = (q_len + self.exc_block_size - 1) // self.exc_block_size
    full_chunks = q_len // self.exc_block_size

    if self.num_global_block <= self.topk:
        return [
            [list(range(len(self.global_blocks[0]))) for _ in range(self.num_units)]
            for _ in range(total_chunks)
        ]

    assert global_q.dim() == 4
    assert global_q.shape[:2] == (self.num_units, self.unit_size)
    assert global_q.shape[3] == self.dim_head

    stacked = torch.cat(
        [self.block_k[unit].get_data()[None, :, :] for unit in range(self.num_units)],
        dim=0,
    )
    assert stacked.shape == (
        self.num_units,
        self.num_global_block,
        self.dim_head * self.unit_size,
    )
    per_head = stacked.reshape(
        self.num_units, self.num_global_block, self.unit_size, self.dim_head
    )
    # (num_units, unit_size, dim_head, num_global_block), ready for matmul
    keys_t = per_head.permute(0, 2, 1, 3).contiguous().transpose(-1, -2)

    selections = []
    split = full_chunks * self.exc_block_size

    if full_chunks > 0:
        chunk_q = global_q[:, :, :split, :].reshape(
            self.num_units,
            self.unit_size,
            full_chunks,
            self.exc_block_size,
            self.dim_head,
        )
        chunk_q = chunk_q.mean(dim=-2)
        assert chunk_q.shape == (
            self.num_units,
            self.unit_size,
            full_chunks,
            self.dim_head,
        )
        # (num_units, full_chunks, num_global_block)
        chunk_score = torch.matmul(chunk_q, keys_t).mean(dim=1)
        assert chunk_score.shape == (
            self.num_units,
            full_chunks,
            self.num_global_block,
        )

        picked = chunk_score.topk(self.topk, dim=-1).indices.cpu()
        for chunk in range(full_chunks):
            per_unit = [picked[unit, chunk].tolist() for unit in range(self.num_units)]
            assert all(len(ids) == self.topk for ids in per_unit)
            selections.append(per_unit)

    if full_chunks != total_chunks:
        tail_q = global_q[:, :, split:, :].reshape(
            self.num_units,
            self.unit_size,
            q_len - split,
            self.dim_head,
        )
        tail_q = tail_q.mean(dim=-2, keepdim=True)
        assert tail_q.shape == (
            self.num_units,
            self.unit_size,
            1,
            self.dim_head,
        )
        tail_score = torch.matmul(tail_q, keys_t)
        assert tail_score.shape == (
            self.num_units,
            self.unit_size,
            1,
            self.num_global_block,
        )
        tail_score = tail_score.squeeze(dim=2).mean(dim=1)
        assert tail_score.shape == (self.num_units, self.num_global_block)

        picked = tail_score.topk(self.topk, dim=-1).indices.cpu()
        per_unit = [picked[unit].tolist() for unit in range(self.num_units)]
        assert all(len(ids) == self.topk for ids in per_unit)
        selections.append(per_unit)

    return selections
|
736 |
+
|
737 |
+
def append_global(self, exc_length, kv_length, local_score):
    """Advance the remainder window and turn overflowing tokens into blocks.

    After a chunk of ``exc_length`` positions has been attended to, this
    accumulates its local attention scores, moves the oldest remainder
    tokens into the initial-token store until ``self.n_init`` tokens are
    held, and then packs every further full ``self.block_size`` span that
    has left the local window into a new ``MemoryUnit`` plus a
    representative key in ``self.block_k``.

    Args:
        exc_length: number of positions in the chunk just processed.
        kv_length: length of the key/value slice that chunk attended to;
            only used to validate the shape of ``local_score``.
        local_score: local attention scores whose first three dimensions
            are ``(num_units, unit_size, kv_length)``.
    """
    # Work on local copies of the cursors; they are written back at the end.
    global_remainder_ed = self._global_remainder_ed + exc_length
    global_remainder_st = self._global_remainder_st

    global_remainder_len = global_remainder_ed - global_remainder_st

    assert local_score.shape[:3] == (self.num_units, self.unit_size, kv_length)
    # Keep at most the last exc_length + n_local score columns and add them
    # onto the matching tail of the running per-token score.
    local_score = local_score[:, :, -exc_length - self.n_local :]
    self.global_remainder_local_score[
        :, :, global_remainder_ed - local_score.size(-1) : global_remainder_ed
    ].add_(local_score)

    if not self.init_exc and global_remainder_len > self.n_local:
        global_k = self.global_remainder[0]
        global_v = self.global_remainder[1]

        # Take as many tokens as are still missing from the initial store,
        # but never tokens that are still inside the local window.
        append_init_len = min(
            self.n_init - self.init_k.size(-2), global_remainder_len - self.n_local
        )
        self.init_k = torch.cat(
            (
                self.init_k,
                global_k[
                    :,
                    :,
                    global_remainder_st : global_remainder_st + append_init_len,
                    :,
                ],
            ),
            dim=-2,
        )
        self.init_v = torch.cat(
            (
                self.init_v,
                global_v[
                    :,
                    :,
                    global_remainder_st : global_remainder_st + append_init_len,
                    :,
                ],
            ),
            dim=-2,
        )
        global_remainder_st += append_init_len
        global_remainder_len -= append_init_len

        if self.init_k.size(-2) == self.n_init:
            # Initial store is full; skip this branch from now on.
            self.init_exc = True

    # Emit one block per iteration while a whole block lies outside the
    # local window.
    while global_remainder_len - self.block_size >= self.n_local:
        global_remainder_len -= self.block_size
        for u in range(self.num_units):
            self.global_blocks[u].append(
                (
                    MemoryUnit(
                        (
                            self.global_remainder[0][
                                u,
                                :,
                                global_remainder_st : global_remainder_st
                                + self.block_size,
                                :,
                            ],
                            self.global_remainder[1][
                                u,
                                :,
                                global_remainder_st : global_remainder_st
                                + self.block_size,
                                :,
                            ],
                        ),
                        self.cuda_cache,
                        False,
                        self.pin_memory,
                    )
                )
            )

        # Representative keys for the new block, derived from its keys and
        # the accumulated local scores of the same positions.
        global_block_k = self.get_block_k(
            self.global_remainder[0][
                :, :, global_remainder_st : global_remainder_st + self.block_size, :
            ],
            self.global_remainder_local_score[
                :, :, global_remainder_st : global_remainder_st + self.block_size
            ],
        )
        assert global_block_k.shape == (
            self.num_units,
            self.unit_size,
            self.repr_topk,
            self.dim_head,
        )
        # Collapse the repr_topk representatives into one vector per head,
        # then flatten the heads of each unit into a single row.
        global_block_k = global_block_k.mean(dim=-2, keepdim=False)
        global_block_k = global_block_k.reshape(
            self.num_units, self.unit_size * self.dim_head
        )
        global_block_k = global_block_k[:, None, :]

        self.num_global_block += 1
        for u in range(self.num_units):
            self.block_k[u].append(global_block_k[u])
        global_remainder_st += self.block_size

    self._global_remainder_ed = global_remainder_ed
    self._global_remainder_st = global_remainder_st
|
842 |
+
|
843 |
+
def append(
    self,
    local_q,
    local_k,
    local_v,
    global_q,
    global_k,
    global_v,
):
    """Feed new tokens through the context manager and return their attention output.

    The new keys/values are appended to the local cache and to the global
    remainder, then the queries are processed in chunks of
    ``self.exc_block_size`` positions via ``_append`` / ``append_global``.
    Afterwards the local cache is trimmed to the last ``self.n_local``
    positions and the consumed prefix of the remainder is dropped.

    Args:
        local_q, local_k, local_v: tensors for the local pass; the sequence
            axis is dim -2 and dim 0 is the batch.
        global_q, global_k, global_v: tensors for the global pass.

    Returns:
        The per-chunk outputs concatenated along dim -2. With
        ``self.perhead`` set, the result is viewed back to
        ``(batch_size, num_heads, input_length, -1)``.
    """
    batch_size = local_q.size(0)
    input_length = local_q.size(-2)

    if self.perhead:
        # Per-head mode: fold the heads into the batch axis so every head
        # becomes its own unit, repeating k/v heads to match the q heads.
        num_heads = local_q.size(1)
        num_heads_kv = local_v.size(1)

        def repeat_kv(t):
            t = t.view(batch_size, num_heads_kv, 1, input_length, -1)
            t = t.expand(
                batch_size,
                num_heads_kv,
                num_heads // num_heads_kv,
                input_length,
                -1,
            )
            t = t.reshape(batch_size * num_heads, 1, input_length, -1)
            return t

        local_q = local_q.view(batch_size * num_heads, 1, input_length, -1)
        local_k = repeat_kv(local_k)
        local_v = repeat_kv(local_v)
        global_q = global_q.view(batch_size * num_heads, 1, input_length, -1)
        global_k = repeat_kv(global_k)
        global_v = repeat_kv(global_v)

    if not self.initialized:
        # Lazy setup: buffers are allocated from the first tensors seen.
        self.init(local_q, local_k, local_v, global_q, global_k, global_v)

    input_length = local_q.size(-2)

    if self.async_global_stream:
        GLOBAL_STREAM.wait_stream(torch.cuda.current_stream())

    # append local and global tensor
    self.local_k = torch.cat((self.local_k, local_k), dim=-2)
    self.local_v = torch.cat((self.local_v, local_v), dim=-2)
    kv_length = self.local_k.size(-2)

    if self.dense_decoding:
        # Also keep the full, untrimmed keys/values for the dense path.
        self.dense_k = torch.cat((self.dense_k, local_k), dim=-2)
        self.dense_v = torch.cat((self.dense_v, local_v), dim=-2)

    # append global remainder
    with torch.cuda.stream(GLOBAL_STREAM):
        # The cursors cover what was already in the remainder; the new
        # tokens are appended behind them and consumed chunk by chunk.
        self._global_remainder_st = 0
        self._global_remainder_ed = self.global_remainder[0].size(-2)

        self.global_remainder = (
            torch.cat((self.global_remainder[0], global_k), dim=-2),
            torch.cat((self.global_remainder[1], global_v), dim=-2),
        )

        # Zero-initialised score slots for the newly appended tokens.
        self.global_remainder_local_score = torch.cat(
            (
                self.global_remainder_local_score,
                torch.zeros(
                    (self.num_units, self.unit_size, global_k.size(-2)),
                    dtype=global_k.dtype,
                    device=global_k.device,
                ),
            ),
            dim=-1,
        )

    with torch.cuda.stream(GLOBAL_STREAM):
        global_q = self.position_embedding.apply_rotary_pos_emb_one_angle(
            global_q, self.n_local
        )

    # Batched top-k is only used for multi-token inputs and only when a
    # chunk size for it was configured.
    use_chunk_topk = self.chunk_topk_calc is not None and input_length > 1
    self._use_chunk_topk = use_chunk_topk
    if use_chunk_topk:
        exc_block_num = input_length // self.exc_block_size
        exc_block_per_topk_chunk = self.chunk_topk_calc // self.exc_block_size
        # Query offsets at which a new batch of top-k results is computed.
        calc_cur_list = [
            i * self.exc_block_size
            for i in range(0, exc_block_num + 1, exc_block_per_topk_chunk)
        ]
        if calc_cur_list[-1] < input_length:
            calc_cur_list.append(input_length)
        self._topk_cur = 0
        self._topk_calc_cur = -1

    o_list = []

    for st in range(0, input_length, self.exc_block_size):
        ed = min(st + self.exc_block_size, input_length)
        if use_chunk_topk and calc_cur_list[self._topk_calc_cur + 1] < ed:
            # calculate topk and sync with host here
            assert ed <= calc_cur_list[self._topk_calc_cur + 2]
            self._topk_calc_cur += 1
            with torch.cuda.stream(GLOBAL_STREAM):
                self._cached_topk = self.get_batched_topk(
                    global_q[
                        :,
                        :,
                        calc_cur_list[self._topk_calc_cur] : calc_cur_list[
                            self._topk_calc_cur + 1
                        ],
                        :,
                    ]
                )
            self._topk_cur = 0

        # Slice of the local cache visible to this chunk: up to n_local
        # positions before the chunk start through the chunk end.
        kv_st = max(kv_length + st - input_length - self.n_local, 0)
        kv_ed = kv_length + ed - input_length
        chunk_o, local_score = self._append(
            local_q[:, :, st:ed, :],
            self.local_k[:, :, kv_st:kv_ed, :],
            self.local_v[:, :, kv_st:kv_ed, :],
            global_q[:, :, st:ed, :],
        )
        o_list.append(chunk_o)

        # append global
        with torch.cuda.stream(GLOBAL_STREAM):
            self.append_global(ed - st, kv_ed - kv_st, local_score)

        if self.async_global_stream:
            torch.cuda.current_stream().wait_stream(GLOBAL_STREAM)

        if use_chunk_topk:
            self._topk_cur += 1

    self.length += input_length

    # update local and global tensor
    if self.local_k.size(-2) >= self.n_local:
        self.local_k = self.local_k[:, :, -self.n_local :, :]
        self.local_v = self.local_v[:, :, -self.n_local :, :]

    assert self._global_remainder_ed == self.global_remainder[0].size(-2)
    with torch.cuda.stream(GLOBAL_STREAM):
        # Drop everything before the start cursor; it has been moved into
        # the initial store or into blocks by append_global.
        self.global_remainder = (
            self.global_remainder[0][:, :, self._global_remainder_st :, :],
            self.global_remainder[1][:, :, self._global_remainder_st :, :],
        )
        self.global_remainder_local_score = self.global_remainder_local_score[
            :, :, self._global_remainder_st :
        ]

    ret = torch.cat(o_list, dim=-2)

    if self.perhead:
        ret = ret.view(batch_size, num_heads, input_length, -1)

    return ret
|
1000 |
+
|
1001 |
+
def size(self, *args, **kwargs):
    """Return ``self.length``, the running count of positions appended.

    Any positional or keyword arguments are accepted and ignored.
    """
    consumed = self.length
    return consumed
|
1003 |
+
|
1004 |
+
|
1005 |
+
def inf_llm_forward(
    n_local,
    n_init,
    topk,
    block_size,
    max_cached_block,
    exc_block_size,
    repr_topk: int = 1,
    cache_strategy="lru",
    score_decay=None,
    chunk_topk_calc=None,
    async_global_stream=True,
    pin_memory=False,
    faiss=False,
    perhead=False,
    dense_decoding=False,
    *args,
    **kwargs
):
    """Build an attention ``forward`` function bound to the given settings.

    All parameters are captured by the returned closure and forwarded to
    ``ContextManager`` when a cache has to be created; ``dense_decoding``
    additionally enables a dense single-token path. Extra positional and
    keyword arguments are accepted and ignored.

    Returns:
        The ``forward`` function defined below.
    """

    def forward(
        self,
        query: torch.Tensor,
        key_value: torch.Tensor,
        position_bias: Optional[torch.Tensor],
        use_cache: bool,
        past_key_value,
        project_q,
        project_k,
        project_v,
        attention_out,
        dim_head,
        num_heads,
        num_heads_kv,
    ):
        """Project the inputs, attend through the context manager, project out.

        Returns ``(o, past_key_value)`` when ``use_cache`` is true and ``o``
        alone otherwise; ``o`` has shape
        ``(batch, len_q, dim_head * num_heads)`` before ``attention_out``.
        """
        batch_size = query.size(0)
        len_q = query.size(1)
        len_k = key_value.size(1)

        # assert use_cache

        h_q = project_q(query)  # (batch, len_q, num_heads * dim_head)
        h_k = project_k(key_value)  # (batch, len_k, num_heads_kv * dim_head)
        h_v = project_v(key_value)  # (batch, len_k, num_heads_kv * dim_head)

        h_q = (
            h_q.view(batch_size, len_q, num_heads, dim_head)
            .permute(0, 2, 1, 3)
            .contiguous()
        )  # (batch, num_heads, len_q, dim_head)
        h_k = (
            h_k.view(batch_size, len_k, num_heads_kv, dim_head)
            .permute(0, 2, 1, 3)
            .contiguous()
        )  # (batch, num_heads_kv, len_k, dim_head)
        h_v = (
            h_v.view(batch_size, len_k, num_heads_kv, dim_head)
            .permute(0, 2, 1, 3)
            .contiguous()
        )  # (batch, num_heads_kv, len_k, dim_head)

        if len_q == 1 and dense_decoding:
            # Dense single-token path: attend over every stored key/value
            # instead of going through the block-sparse context manager.
            # NOTE(review): reads past_key_value.dense_k, so this branch
            # presumably requires an existing cache - confirm it is never
            # reached with past_key_value=None.
            past_k = past_key_value.dense_k
            past_v = past_key_value.dense_v

            h_k = torch.cat((past_k, h_k), dim=-2)
            h_v = torch.cat((past_v, h_v), dim=-2)

            past_key_value.dense_k = h_k
            past_key_value.dense_v = h_v

            # position_bias is called here as the position embedding.
            h_q, h_k = position_bias(h_q, h_k)

            # (batch_size, seqlen, nheads, headdim)
            h_q = h_q.transpose(1, 2)
            h_k = h_k.transpose(1, 2)
            h_v = h_v.transpose(1, 2)

            # (batch_size, seqlen, nheads, headdim)
            o = flash_attn_func(h_q, h_k, h_v, causal=True)

            o = o.reshape(batch_size, len_q, dim_head * num_heads)
            o = attention_out(o)

            if use_cache:
                return o, past_key_value
            else:
                return o

        if past_key_value is None:
            # First call: create the cache with the captured settings.
            past_key_value = ContextManager(
                position_bias,
                n_init,
                n_local,
                block_size,
                max_cached_block,
                topk,
                exc_block_size,
                score_decay,
                repr_topk,
                cache_strategy,
                chunk_topk_calc,
                async_global_stream,
                pin_memory,
                faiss,
                perhead,
                dense_decoding=dense_decoding,
            )

        # The same projections feed both the local and the global pass.
        local_q, local_k, local_v = h_q, h_k, h_v
        global_q, global_k, global_v = h_q, h_k, h_v

        o = past_key_value.append(
            local_q,
            local_k,
            local_v,
            global_q,
            global_k,
            global_v,
        )

        # (batch, num_heads, len_q, dim_head) -> (batch, len_q, num_heads * dim_head)
        o = o.view(batch_size, num_heads, len_q, dim_head).permute(0, 2, 1, 3)
        o = o.reshape(batch_size, len_q, dim_head * num_heads)
        o = attention_out(o)

        if use_cache:
            return o, past_key_value
        else:
            return o

    return forward
|
1135 |
+
|
1136 |
+
|
1137 |
+
class GreedySearch:
    """Greedy (argmax) decoder that prefills long prompts in chunks.

    The wrapped model is switched to eval mode on construction. Inputs are
    moved to ``model.device``. ``past_kv`` holds the cache returned by the
    last ``_decode`` call until ``clear`` is called; ``generate`` clears it
    after every request.
    """

    def __init__(self, model, tokenizer):
        model.eval()
        self.device = model.device
        self.model = model
        self.tokenizer = tokenizer
        self.past_kv = None

    def clear(self):
        """Drop the cached key/value state."""
        self.past_kv = None

    def _process_texts(self, input_text):
        """Tokenize ``input_text`` into batched int tensors on the model device.

        Returns:
            A dict with ``"input_ids"`` and an all-ones ``"attention_mask"``,
            both of shape ``(1, num_tokens)``.
        """
        input_ids = self.tokenizer.encode(input_text)
        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": [1] * len(input_ids),
        }
        return {
            key: torch.tensor(value).int().unsqueeze(0).to(self.device)
            for key, value in model_inputs.items()
        }

    def generate(self, text=None, input_ids=None, **kwargs):
        """Greedy-decode from ``text`` or from ready-made ``input_ids``.

        ``text`` is only tokenized when ``input_ids`` is ``None``. Remaining
        keyword arguments are forwarded to ``_decode``.

        Returns:
            The prompt ids followed by the generated ids, shape ``(1, n)``.
        """
        if input_ids is None:
            model_inputs = self._process_texts(text)
            input_ids = model_inputs["input_ids"]

        with torch.inference_mode():
            result = self._decode(input_ids, **kwargs)

        self.clear()
        return result

    def _decode(
        self,
        input_ids,
        max_length=100,
        extra_end_token_ids=None,
        chunk_size: int = 4096,
        output=False,
    ):
        """Prefill the prompt in chunks, then generate token by token.

        Args:
            input_ids: prompt ids, 1-D or of shape ``(1, n)``.
            max_length: maximum number of tokens to generate.
            extra_end_token_ids: token ids that stop generation in addition
                to the tokenizer's EOS id; ``None`` means no extra ids.
            chunk_size: prompt tokens per prefill forward pass; ``None``
                prefills the whole prompt in one pass.
            output: when true, stream the decoded text to stdout.

        Returns:
            The prompt ids followed by the generated ids, shape ``(1, n)``.
            The stopping token itself is not appended.
        """
        # Imported up front: the original imported ``sys`` only inside the
        # printing branch, so the final newline write failed with
        # UnboundLocalError whenever nothing had been printed.
        import sys

        if input_ids.dim() == 1:
            input_ids = input_ids[None, :]
        input_ids = input_ids.to(self.device)
        attention_mask = torch.ones_like(input_ids)
        assert input_ids.size(0) == 1
        length = input_ids.size(1)
        end_token_ids = list(extra_end_token_ids or []) + [
            self.tokenizer.eos_token_id
        ]
        logits = None
        past_key_values = self.past_kv
        output_text = ""

        for i in range(max_length + 1):
            if i == 0:
                if chunk_size is None:
                    chunk_size = input_ids.size(1)
                # Prefill everything except the last prompt token in chunks;
                # the last token is fed separately to obtain the first logits.
                for st in range(0, input_ids.size(1) - 1, chunk_size):
                    ed = min(input_ids.size(1) - 1, st + chunk_size)
                    out = self.model(
                        input_ids=input_ids[:, st:ed],
                        attention_mask=attention_mask[:, :ed],
                        use_cache=True,
                        return_dict=True,
                        past_key_values=past_key_values,
                    )
                    logits, past_key_values = out.logits, out.past_key_values

            out = self.model(
                input_ids=input_ids[:, -1:],
                attention_mask=attention_mask,
                use_cache=True,
                return_dict=True,
                past_key_values=past_key_values,
            )
            logits, past_key_values = out.logits, out.past_key_values

            logits = logits[:, -1, :]
            word = logits.argmax(dim=-1)
            if word.item() in end_token_ids or i == max_length:
                break

            input_ids = torch.cat((input_ids, word.view(1, 1)), dim=-1)
            attention_mask = torch.cat(
                (
                    attention_mask,
                    torch.ones(
                        (attention_mask.size(0), 1),
                        dtype=torch.int,
                        device=attention_mask.device,
                    ),
                ),
                dim=-1,
            )
            if output:
                # Re-decode the whole continuation and print only the new
                # suffix, so multi-token characters come out intact.
                tmp = self.tokenizer.decode(input_ids.squeeze(0)[length:])
                if len(tmp) > len(output_text):
                    sys.stdout.write(tmp[len(output_text) :])
                    sys.stdout.flush()
                    output_text = tmp

        self.past_kv = past_key_values

        if output:
            sys.stdout.write("\n")
            sys.stdout.flush()

        # return [self.tokenizer.decode(input_ids.squeeze(0)[length:])]
        return input_ids
|
1260 |
+
|
1261 |
+
|
1262 |
+
class InfLLMGenerator(GreedySearch):
    """``GreedySearch`` behind a ``generate`` / ``__call__`` interface."""

    def generate(
        self,
        input_ids=None,
        generation_config=None,
        pad_token_id=None,
        max_new_tokens=None,
    ):
        """Greedy-decode up to ``max_new_tokens`` tokens after ``input_ids``.

        ``max_new_tokens`` falls back to ``generation_config.max_new_tokens``
        when it is ``None``. A non-``None`` ``pad_token_id`` is treated as an
        additional stop token. The prompt is prefilled 8192 tokens at a time.
        """
        if max_new_tokens is None:
            max_new_tokens = generation_config.max_new_tokens
        return super().generate(
            text=None,
            input_ids=input_ids,
            max_length=max_new_tokens,
            chunk_size=8192,
            extra_end_token_ids=[pad_token_id] if pad_token_id is not None else [],
        )

    @torch.no_grad()
    def __call__(self, input_ids=None, *args, **kwargs):
        """Run the model over ``input_ids`` in chunks and return all logits.

        Returns:
            ``CausalLMOutput`` whose ``logits`` are the bfloat16 chunk logits
            concatenated along dim 1.
        """
        # chunked forward
        # NOTE(review): each chunk is forwarded without past_key_values;
        # confirm the wrapped model carries its own state between calls.
        chunk_size = 8192
        chunk_logits = []
        for st in range(0, input_ids.size(1), chunk_size):
            torch.cuda.empty_cache()
            ed = min(input_ids.size(1), st + chunk_size)
            out = self.model(
                input_ids=input_ids[:, st:ed],
            )
            chunk_logits.append(out.logits.to(torch.bfloat16))

        if chunk_logits:
            # One concatenation at the end instead of re-copying the growing
            # result after every chunk.
            all_logits = torch.cat(chunk_logits, dim=1)
        else:
            # Empty input: keep returning the empty placeholder tensor.
            all_logits = torch.empty(0, dtype=torch.bfloat16).to(input_ids.device)

        return CausalLMOutput(logits=all_logits)
|
minference/modules/minference_forward.py
ADDED
@@ -0,0 +1,855 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
from importlib import import_module
|
5 |
+
|
6 |
+
from transformers.models.llama.modeling_llama import *
|
7 |
+
from vllm.attention.backends.flash_attn import *
|
8 |
+
|
9 |
+
from ..ops.block_sparse_flash_attention import block_sparse_attention
|
10 |
+
from ..ops.pit_sparse_flash_attention_v2 import vertical_slash_sparse_attention
|
11 |
+
from ..ops.streaming_kernel import streaming_forward, streaming_forward2
|
12 |
+
from .snap_kv import *
|
13 |
+
|
14 |
+
# Number of trailing query positions used when estimating sparse patterns.
last_q = 64
# Prefer the GPU (as before), but fall back to the CPU so that merely importing
# this module does not raise on a machine without CUDA.
arange = torch.arange(last_q, device="cuda" if torch.cuda.is_available() else "cpu")
# Boolean lower-triangular (diagonal included) mask of shape
# (1, 1, last_q, last_q): entry [.., i, j] is True when i >= j.
LAST_Q_MASK = arange[None, None, :, None] >= arange[None, None, None, :]
# Lazily filled by the forward functions: True when rotary_emb.forward accepts
# a ``seq_len`` keyword, False otherwise.
ROPE_TYPE = None
# Cache for the additive causal mask built by search_pattern().
SEARCH_MASK = None
|
19 |
+
|
20 |
+
def init_minference_parameters(self):
    """Populate the MInference attributes of an attention module.

    Reads ``self.config.to_dict()`` and sets ``starting_layer``, ``is_search``,
    ``ne_inf`` (reset to ``None``), ``config_path``, ``best_pattern``,
    ``vertical`` and ``slash`` on ``self``.

    ``best_pattern`` maps an integer head index to the entry stored for that
    head in the JSON file at ``config_path`` (indexed by ``self.layer_idx``).
    It is empty when the file does not exist or has no entry for this layer.

    On first use for a module, the function also binds the module-level name
    ``apply_rotary_pos_emb`` to the function of the same name found in the
    Python module that defines ``self.rotary_emb``'s class.
    """
    config = self.config.to_dict()
    self.starting_layer = config.get("starting_layer", 0)
    self.is_search = config.get("is_search", False)

    self.ne_inf = None
    self.config_path = config.get("config_path", "")

    # Parse the pattern file once and close it deterministically; the previous
    # code opened (and leaked) the file twice and decoded the JSON twice.
    self.best_pattern = {}
    if os.path.exists(self.config_path):
        with open(self.config_path) as pattern_file:
            layer_patterns = json.load(pattern_file)
        if self.layer_idx < len(layer_patterns):
            self.best_pattern = {
                int(head): pattern
                for head, pattern in layer_patterns[self.layer_idx].items()
            }
    self.vertical, self.slash = None, None

    # Resolve the model-specific RoPE helper only once per module instance;
    # the instance attribute acts as an "already done" flag.
    if "apply_rotary_pos_emb" not in self.__dict__:
        global apply_rotary_pos_emb
        model_path = self.rotary_emb.__class__.__module__
        apply_rotary_pos_emb = getattr(import_module(model_path), "apply_rotary_pos_emb")
        self.apply_rotary_pos_emb = True
|
42 |
+
|
43 |
+
def sum_all_diagonal_matrix(mat: torch.Tensor):
    """Sum every diagonal of the trailing 2-D matrices of ``mat``.

    Args:
        mat: tensor of shape ``(b, h, n, m)``.

    Returns:
        Tensor of shape ``(b, h, n + m - 1)``. Entry ``d`` is the sum of the
        diagonal with offset ``d - (n - 1)``, so the diagonals run from the
        bottom-left corner (offset ``-(n - 1)``) to the top-right corner
        (offset ``m - 1``).

    The earlier implementation hard-coded the batch and head sizes of the
    strided view to 1; this version handles any ``b`` and ``h`` and gives the
    same result when both are 1.
    """
    b, h, n, m = mat.shape
    row_width = 2 * n + m
    # Zero padding, allocated directly on the target device.
    zero_mat = torch.zeros((b, h, n, n), device=mat.device)
    # Pad left and right so every shifted row read below stays inside its row.
    mat_padded = torch.cat((zero_mat, mat, zero_mat), -1).contiguous()
    # A row stride of ``row_width + 1`` shifts row i left by i elements, which
    # lines each diagonal of ``mat`` up in a single column.
    mat_strided = mat_padded.as_strided(
        (b, h, n, n + m),
        (h * n * row_width, n * row_width, row_width + 1, 1),
    )
    sum_diags = torch.sum(mat_strided, 2)
    # Column 0 corresponds to offset -n, which holds no elements.
    return sum_diags[..., 1:]
|
50 |
+
|
51 |
+
def gather(t, dim, i):
    """Gather from ``t`` along ``dim``, broadcasting ``i`` over the other axes.

    ``i`` is expanded to the shape of ``t`` on every axis except ``dim``,
    where it keeps its own length, and is then passed to ``Tensor.gather``.
    Negative ``dim`` values count from the last axis.
    """
    if dim < 0:
        dim = dim + t.ndim
    leading = t.shape[:dim]
    trailing = t.shape[dim + 1:]
    index = i.expand(*leading, i.shape[dim], *trailing)
    return t.gather(dim, index)
|
55 |
+
|
56 |
+
def gather_qkv(q, k, v, attention_mask):
    """Dense scaled dot-product attention with an additive mask.

    Scores are ``q @ k^T / sqrt(q.size(-1)) + attention_mask``. The softmax is
    evaluated in float32 and cast back to ``q.dtype`` before weighting ``v``.
    """
    scale = math.sqrt(q.size(-1))
    scores = torch.matmul(q, k.transpose(2, 3)) / scale
    scores = scores + attention_mask
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32)
    probs = probs.to(q.dtype)
    return torch.matmul(probs, v)
|
61 |
+
|
62 |
+
def search_pattern(q, k, head):
    """Pick the sparse attention pattern that best covers dense attention.

    The full causal attention matrix for one head is computed, then several
    candidate sparse masks are scored by how much attention probability they
    retain on query rows from index 2500 onwards. The best-scoring candidate
    is printed and returned.

    Args:
        q, k: query and key tensors; the sequence axis is dim 2 and the head
            dimension is the last axis. The helpers build ``(1, 1, ...)``
            buffers, so a single batch element and head are presumably
            expected (TODO confirm against callers).
        head: head index, used only in the printed log line.

    Returns:
        ``(best_ty, best_v, best_s, best_score)``. A "stream_llm" winner is
        reported as "vertical_and_slash", and a "block_sparse" winner is
        replaced by ``("vertical_and_slash", 1000, 6096)``.

    NOTE(review): scoring slices rows ``2500:`` and one candidate uses a
    vertical size of 3500, so this search appears to assume prompts of several
    thousand tokens; shorter inputs give NaN scores or a negative top-k size.
    """
    q_len = q.shape[2]
    head_dim = q.shape[-1]

    def coverage(est_attn):
        # Mean (over rows >= 2500) of the attention mass kept by ``est_attn``.
        attn_weights_x = attn_weights * est_attn
        return attn_weights_x[:,:,2500:].sum(-1).mean(-1).squeeze().float().detach().cpu().numpy()

    def vertical_and_slash(vertical_size, slash_size):
        last_q = 64
        q_len = q.shape[2]
        # Estimate the pattern from the last ``last_q`` queries only.
        qk_idxs = [ii + q_len for ii in range(-last_q, 0)]
        qk = torch.matmul(q[:,:,qk_idxs,:], k.transpose(2, 3)) / math.sqrt(head_dim) + attention_mask[:,:,qk_idxs]
        qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
        vertical = qk.sum(-2, keepdim=True)
        # Force the first 30 key columns to always be kept.
        vertical[...,:30] = 10000
        # Indices of the columns to DROP (the lowest-scoring ones).
        vertical_topk = torch.topk(-vertical, q_len - vertical_size, -1).indices

        slash = sum_all_diagonal_matrix(qk)[...,:-last_q + 1]
        # Force the last 30 diagonals of the slice to always be kept.
        slash[...,-30:] = 10000
        slash = torch.topk(slash, slash_size, -1).indices - (q_len - 1)
        # Materialise the selected diagonals as a dense (q_len, q_len) 0/1 matrix.
        slash = torch.stack([torch.sparse.spdiags(torch.ones(slash_size, q_len), slash.cpu()[0][_], (q_len, q_len)).to_dense() for _ in range(1)]).to(q.device)

        est_attn = torch.ones_like(attn_weights)
        dim = 3
        est_attn = est_attn.scatter(3, vertical_topk.expand(*est_attn.shape[:dim], vertical_topk.shape[dim], *est_attn.shape[dim + 1 :]), 0)
        est_attn = est_attn + slash

        est_attn = (est_attn > 0).float()
        est_attn = torch.tril(est_attn)
        return coverage(est_attn)

    def stream_llm(vertical_size, slash_size):
        q_len = q.shape[2]

        # Local band of width ``slash_size`` plus the first ``vertical_size`` columns.
        mask = torch.triu(torch.tril(torch.ones(q_len, q_len), 0), -slash_size).to(q)
        mask[:,:vertical_size] = 1
        mask = mask.unsqueeze(0).unsqueeze(1)

        est_attn = torch.tril(mask)
        return coverage(est_attn)

    def block_sparse(topk_ratio, slash_size=None):
        # Mean-pool q and k into blocks of 32 positions (zero padded).
        block_num = (q_len - 1) // 32 + 1
        block_q = torch.zeros(1,1,block_num * 32,head_dim).to(q)
        block_q[:,:,:q_len] = q
        block_q = block_q.reshape(1,1,block_num,32,-1).mean(-2)
        block_k = torch.zeros(1,1,block_num * 32,head_dim).to(k)
        block_k[:,:,:q_len] = k
        block_k = block_k.reshape(1,1,block_num,32,-1).mean(-2)

        qk = torch.matmul(block_q, block_k.transpose(2, 3)) + attention_mask[:,:,:block_num,:block_num]
        est_attn = torch.ones_like(qk)
        # Indices of the blocks to DROP for each query block.
        block_topk = torch.topk(-qk, block_num - block_num//topk_ratio, -1).indices

        dim = 3
        est_attn = est_attn.scatter(3, block_topk.expand(*est_attn.shape[:dim], block_topk.shape[dim], *est_attn.shape[dim + 1 :]), 0)
        # Expand the block mask back to token resolution.
        est_attn = est_attn.unsqueeze(3).unsqueeze(-1).repeat(1,1,1,32,1,32).reshape(1,1,block_num * 32, block_num * 32)[...,:q_len,:q_len]
        est_attn = torch.tril(est_attn)
        return coverage(est_attn)

    global SEARCH_MASK
    # Rebuild the cached causal mask when it does not match this call; the
    # previous code cached the first mask forever and always built it on
    # "cuda", which breaks for a different q_len or a different device.
    if SEARCH_MASK is None or SEARCH_MASK.shape[-1] != q_len or SEARCH_MASK.device != q.device:
        attention_mask = torch.full((q_len, q_len), torch.finfo(q.dtype).min, device=q.device)
        mask_cond = torch.arange(attention_mask.size(-1), device=q.device)
        attention_mask.masked_fill_(mask_cond < (mask_cond + 1).view(attention_mask.size(-1), 1), 0)
        attention_mask = attention_mask[None, None, :]
        SEARCH_MASK = attention_mask
    else:
        attention_mask = SEARCH_MASK
    attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim) + attention_mask
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
    best_s, best_v, best_score, best_ty = 0, 0, 0, ""
    for ty, fc in [("stream_llm", stream_llm), ("vertical_and_slash", vertical_and_slash), ("block_sparse", block_sparse)]:
        if ty == "stream_llm":
            vs_list = [(100, 800)]
        elif ty == "vertical_and_slash":
            vs_list = [(30, 800), (100, 750), (500, 700), (3500, 100)]
        else:
            vs_list = [(8, 1)]
        for v_size, s_size in vs_list:
            score = fc(v_size, s_size).item()
            if score > best_score:
                best_score = score
                best_s, best_v = s_size, v_size
                best_ty = ty
    if best_ty == "stream_llm":
        best_ty = "vertical_and_slash"
    if best_ty == "block_sparse":
        best_ty, best_v, best_s = "vertical_and_slash", 1000, 6096
    print(head, best_ty, best_v, best_s, best_score)
    return (best_ty, best_v, best_s, best_score)
|
161 |
+
|
162 |
+
def search_pattern_v2(q, k, v, head):
    """Score candidate sparse kernels against dense flash attention.

    Each candidate kernel is run on ``(q, k, v)`` and scored by the number of
    output elements that differ from the dense reference by more than 5e-3
    (lower is better). The best candidate is printed.

    Args:
        q, k, v: per-head tensors whose sequence axis is dim 2.
        head: head index, used only in the printed log line.

    Returns:
        A list of ``[type, vertical_size, slash_size, score]`` entries, one per
        evaluated candidate (not just the best one).
    """
    q_len = q.shape[2]
    head_dim = q.shape[-1]

    def vertical_and_slash_kernel(q, k, v, vertical_size, slash_size):
        vertical_size, slash_size = min(q_len, max(vertical_size, 30)), min(q_len, max(slash_size, 50))
        last_q = 64
        qk = torch.einsum('bhmk, bhnk -> bhmn', q[:,:,-last_q:,:], k)
        # Causal masking inside the last ``last_q`` x ``last_q`` square. The
        # mask is moved to q's device so layers placed on another GPU work,
        # matching the kernel in gather_last_q_vertical_slash_topk_v4.
        qk[:, :, :, -last_q:] = torch.where(LAST_Q_MASK.to(q.device), qk[:, :, :, -last_q:], -torch.inf)
        qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
        vertical = qk.sum(-2, keepdim=True)
        # Always keep the first 30 key columns.
        vertical[...,:30] = torch.inf
        vertical_topk = torch.topk(vertical, vertical_size, -1).indices

        slash = sum_all_diagonal_matrix(qk)[...,:-last_q + 1]
        # Always keep the last 30 diagonals of the slice.
        slash[...,-30:] = torch.inf
        slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices

        return vertical_slash_sparse_attention(q, k, v, vertical_topk, slash)

    def dense(q, k, v, vertical_size=None, slash_size=None):
        return flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1,2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, 1, q_len, head_dim)

    def block_sparse_kernel(q, k, v, vertical_size=None, slash_size=None):
        topk = 100
        return block_sparse_attention(q, k, v, topk)

    best_s, best_v, best_score, best_ty = 0, 0, float("inf"), ""
    bsz = q.shape[0]
    all_info = []
    ref = dense(q, k, v)
    for ty, fc in [("stream_llm", streaming_forward), ("vertical_and_slash", vertical_and_slash_kernel), ("block_sparse", block_sparse_kernel)]:
        if ty == "stream_llm":
            vs_list = [(100, 800)]
        elif ty == "vertical_and_slash":
            vs_list = [(30, 800), (100, 800), (100, 750), (500, 700), (3500, 100), (1000, 4096)]
        else:
            vs_list = [(10, 1)]
        for v_size, s_size in vs_list:
            candidate = fc(q, k, v, v_size, s_size)
            # Count the elements that deviate noticeably from the dense output.
            score = ((ref - candidate).abs() > 5e-3).sum().item()
            all_info.append([ty, v_size, s_size, score])
            if score < best_score:
                best_score = score
                best_s, best_v = s_size, v_size
                best_ty = ty
    print(head, best_ty, best_v, best_s, best_score)
    return all_info
|
210 |
+
|
211 |
+
def shift_matrix(mat):
    """Right-align each row of ``mat`` onto the main diagonal.

    For an input of shape ``(b, h, n, n)`` the result has the same shape, with
    ``out[..., i, j] = mat[..., i, n - 1 - i + j]`` for ``j <= i`` and zero
    above the diagonal. The last element of every row therefore lands on the
    diagonal and the row is shifted left as needed.

    The earlier implementation hard-coded the batch and head sizes of the
    strided view to 1; this version handles any ``b`` and ``h`` and gives the
    same result when both are 1.
    """
    b, h, _, n = mat.shape
    row_width = 3 * n
    # Zero padding, allocated directly on the target device.
    zero_mat = torch.zeros((b, h, n, n), device=mat.device)
    mat_padded = torch.cat((zero_mat, mat, zero_mat), -1).contiguous()
    # A row stride of ``row_width - 1`` shifts row i right by i elements.
    mat_strided = mat_padded.as_strided(
        (b, h, n, row_width),
        (h * n * row_width, n * row_width, row_width - 1, 1),
    )
    return mat_strided[..., 2 * n - 1:-1]
|
217 |
+
|
218 |
+
def repeat(self, q, k, v, attention_mask):
    """Approximate attention by reusing the last query's scores for every row.

    For a single query position this falls back to ``gather_qkv``. Otherwise
    the scores of the last query against all keys are copied to every query
    row, realigned with ``shift_matrix``, masked, soft-maxed in float32 and
    applied to ``v``.
    """
    q_len = q.shape[2]
    if q_len == 1:
        return gather_qkv(q, k, v, attention_mask)

    last_query = q[:, :, -1:, :]
    scores = torch.matmul(last_query, k.transpose(2, 3)) / math.sqrt(self.head_dim)
    # Tile the single score row over all query positions, then shift each copy.
    tiled = scores.repeat(1, 1, q_len, 1)
    masked = shift_matrix(tiled) + attention_mask
    probs = nn.functional.softmax(masked, dim=-1, dtype=torch.float32).to(q.dtype)
    return torch.matmul(probs, v)
|
228 |
+
|
229 |
+
def gather_last_q_vertical_slash_topk_v4(self, q, k, v, head_id):
    """Run attention for one head using the configured sparse pattern.

    Dispatch order (first match wins), driven by ``self.config.to_dict()``:
    ``dilated1`` / ``dilated2`` -> fixed dilated vertical+slash pattern;
    ``dense`` -> flash attention; ``streaming`` -> ``streaming_forward`` with
    ``self.config.streaming_kwargs``; ``static_pattern`` -> vertical+slash
    indices computed once and cached on ``self.vs``; ``vs_only`` ->
    vertical+slash; a single query token -> dense; otherwise the kernel named
    by ``self.best_pattern[head_id]`` (default
    ``("vertical_and_slash", 1000, 6096, 1)``).

    Args:
        q, k, v: per-head tensors whose sequence axis is dim 2.
        head_id: key into ``self.best_pattern``.

    Returns:
        The attention output produced by the selected kernel.

    Changes from the previous version: three nested helpers that were never
    called (and referenced undefined names) were removed; the config dict is
    built once instead of up to six times per call; the static kernel now
    moves ``LAST_Q_MASK`` to ``q.device`` like the dynamic kernel does.
    """
    q_len = q.shape[2]
    bsz = q.shape[0]
    # ``to_dict()`` builds a fresh dict on every call, so do it once.
    config = self.config.to_dict()

    def dilated(q, k, v, pattern):
        q_len = q.shape[2]
        n_init = min(1024, q_len)
        vertical_topk = torch.arange(0, n_init, device=q.device)[None, None, None, :]

        slash = torch.arange(0, q_len, device=q.device)
        if pattern == 'dilated1':
            # 8k local with 1 interval
            slash = slash[-8192::2][None, None, :]
        elif pattern == 'dilated2':
            # 2k dense local + 4k local with 1 interval
            slash = torch.cat([slash[-2048:], slash[-6144:-2048:2]], 0)[None, None, :]

        slash = (q_len - 1) - slash
        return vertical_slash_sparse_attention(q, k, v, vertical_topk, slash)

    def vertical_and_slash_kernel(q, k, v, vertical_size, slash_size):
        vertical_size, slash_size = min(q_len, max(vertical_size, 30)), min(q_len, max(slash_size, 50))
        last_q = min(64, q_len)
        qk = torch.einsum('bhmk, bhnk -> bhmn', q[:,:,-last_q:,:], k)
        # Causal masking inside the trailing ``last_q`` x ``last_q`` square.
        qk[:, :, :, -last_q:] = torch.where(LAST_Q_MASK[...,-last_q:,-last_q:].to(q.device), qk[:, :, :, -last_q:], -torch.inf)
        qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
        vertical = qk.sum(-2, keepdim=True)
        # Always keep the first 30 key columns.
        vertical[...,:30] = torch.inf
        vertical_topk = torch.topk(vertical, vertical_size, -1).indices

        # NOTE(review): when last_q == 1 the slice below is ``[..., :0]``
        # (empty); this path appears to assume q_len > 1 - confirm.
        slash = sum_all_diagonal_matrix(qk)[...,:-last_q + 1]
        # Always keep the last 100 diagonals of the slice.
        slash[...,-100:] = torch.inf
        slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices

        return vertical_slash_sparse_attention(q, k, v, vertical_topk, slash)

    def vertical_and_slash_kernel_static(q, k, v, vertical_size, slash_size):
        if "vs" in self.__dict__:
            # Reuse the indices computed on an earlier call.
            vertical_topk, slash = self.vs
        else:
            vertical_size, slash_size = min(q_len, max(vertical_size, 30)), min(q_len, max(slash_size, 50))
            last_q = 64
            qk = torch.einsum('bhmk, bhnk -> bhmn', q[:,:,-last_q:,:], k)
            qk[:, :, :, -last_q:] = torch.where(LAST_Q_MASK.to(q.device), qk[:, :, :, -last_q:], -torch.inf)
            qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
            vertical = qk.sum(-2, keepdim=True)
            vertical[...,:30] = torch.inf
            vertical_topk = torch.topk(vertical, vertical_size, -1).indices

            slash = sum_all_diagonal_matrix(qk)[...,:-last_q + 1]
            slash[...,-30:] = torch.inf
            slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices
            self.vs = vertical_topk, slash

        return vertical_slash_sparse_attention(q, k, v, vertical_topk, slash)

    def dense(q, k, v, vertical_size=None, slash_size=None):
        return flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1,2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, 1, q_len, self.head_dim)

    def block_sparse_kernel(q, k, v, vertical_size=None, slash_size=None):
        topk = 100
        return block_sparse_attention(q, k, v, topk)

    if config.get("dilated1", False):
        return dilated(q, k, v, 'dilated1')
    if config.get("dilated2", False):
        return dilated(q, k, v, 'dilated2')
    if config.get("dense", False):
        return dense(q, k, v)
    if config.get("streaming", False):
        return streaming_forward(q, k, v, self.config.streaming_kwargs["n_init"], self.config.streaming_kwargs["n_local"])

    ty, vertical_size, slash_size, _ = self.best_pattern.get(head_id, ("vertical_and_slash", 1000, 6096, 1))

    if config.get("static_pattern", False):
        return vertical_and_slash_kernel_static(q, k, v, vertical_size, slash_size)
    if config.get("vs_only", False):
        return vertical_and_slash_kernel(q, k, v, vertical_size, slash_size)

    # Decoding a single token: nothing to sparsify.
    if q_len == 1:
        return dense(q, k, v)

    fc = {
        "stream_llm": streaming_forward,
        "vertical_and_slash": vertical_and_slash_kernel,
        "block_sparse": block_sparse_kernel,
    }[ty]
    return fc(q, k, v, vertical_size, slash_size)
|
386 |
+
|
387 |
+
def apply_rotary_pos_emb_single(q, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply rotary position embedding to a single tensor.

    ``cos`` and ``sin`` get an extra axis at ``unsqueeze_dim`` so they
    broadcast against ``q``. ``position_ids`` is accepted for signature
    compatibility but is not used: ``cos`` / ``sin`` are not indexed here.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    straight = q * cos_b
    rotated = rotate_half(q) * sin_b
    return straight + rotated
|
393 |
+
|
394 |
+
def minference_forward():
    """Build the MInference replacement for an attention module's ``forward``.

    The returned function is meant to be bound to an attention module that
    exposes ``q_proj``/``k_proj``/``v_proj`` (or a fused ``qkv_proj``),
    ``o_proj``, ``rotary_emb`` and the usual head-count attributes.
    """
    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value,
        output_attentions,
        use_cache,
        **kwargs,
    ):
        """Attention forward pass with per-head sparse kernels.

        For multi-token inputs every head is processed separately: layers at
        or beyond ``self.starting_layer`` use
        ``gather_last_q_vertical_slash_topk_v4``; earlier layers (and all
        layers while searching) use flash attention when available and
        ``gather_qkv`` otherwise. Single-token inputs use flash attention over
        all heads at once.

        When ``self.is_search`` is set, the best pattern of every head is
        searched with ``search_pattern`` and appended to the JSON file at
        ``self.config_path``.

        Returns:
            ``(attn_output, None, past_key_value)``.

        Raises:
            ValueError: a cache was passed but ``self.layer_idx`` is ``None``.
            AssertionError: search mode is on and the pattern file already
                holds an entry for this layer.
        """
        self.init_minference_parameters()
        self.ne_inf = torch.finfo(hidden_states.dtype).min

        bsz, q_len, _ = hidden_states.size()

        if "q_proj" in self.__dict__["_modules"]:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)
        else:
            # Fused projection. NOTE(review): splitting into equal chunks of
            # num_heads * head_dim presumably assumes the key/value width
            # equals the query width - confirm for grouped-query models.
            qkv = self.qkv_proj(hidden_states)
            query_pos = self.num_heads * self.head_dim
            query_states, key_states, value_states = torch.split(qkv, query_pos, -1)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        # Detect once which calling convention the rotary embedding uses.
        global ROPE_TYPE
        if ROPE_TYPE is None:
            ROPE_TYPE = "seq_len" in inspect.signature(self.rotary_emb.forward).parameters
        if ROPE_TYPE:
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        else:
            cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        if self.is_search:
            if os.path.exists(self.config_path):
                # Close the file promptly instead of leaking the handle.
                with open(self.config_path) as json_file:
                    config_list = json.load(json_file)
                if self.layer_idx < len(config_list):
                    # Same exception type as the former ``assert False``, but
                    # it survives ``python -O`` and explains the failure.
                    raise AssertionError(
                        f"Pattern file {self.config_path!r} already has an entry for layer {self.layer_idx}."
                    )
            else:
                config_list = []
            config = {}
            print("Layer", self.layer_idx)
        if q_len != 1:
            output = torch.empty_like(query_states)
            for head in range(query_states.size(1)):
                q = query_states[:, head, :, :].unsqueeze(1)
                k = key_states[:, head, :, :].unsqueeze(1)
                v = value_states[:, head, :, :].unsqueeze(1)
                if self.is_search:
                    config[head] = search_pattern(q, k, head)
                if self.layer_idx >= self.starting_layer and not self.is_search:
                    attn_output = self.gather_last_q_vertical_slash_topk_v4(q, k, v, head)
                elif is_flash_attn_2_available():
                    attn_output = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1,2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, 1, q_len, self.head_dim)
                else:
                    attn_output = gather_qkv(q, k, v, attention_mask)
                output[:, head:head + 1] = attn_output
            if self.is_search:
                # Persist this layer's per-head patterns after the existing ones.
                config_list.append(config)
                with open(self.config_path, 'w') as json_file:
                    json.dump(config_list, json_file)
        else:
            output = flash_attn_func(query_states.transpose(1, 2), key_states.transpose(1, 2), value_states.transpose(1,2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, query_states.size(1), q_len, self.head_dim)
        attn_output = output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

    return forward
|
484 |
+
|
485 |
+
def minference_kv_cache_cpu_forward():
    """Build an attention ``forward`` that projects one head at a time and
    stages new keys/values in CPU tensors before handing them to the cache."""
    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value,
        output_attentions,
        use_cache,
        **kwargs,
    ):
        # Returns (hidden_states, None, past_key_value); note that
        # ``hidden_states`` is overwritten in place with the attention output.
        self.init_minference_parameters()
        self.ne_inf = torch.finfo(hidden_states.dtype).min

        bsz, q_len, hidden_dim = hidden_states.size()
        kv_seq_len = q_len
        if use_cache and past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        # Detect once which calling convention the rotary embedding uses.
        global ROPE_TYPE
        if ROPE_TYPE is None:
            ROPE_TYPE = "seq_len" in inspect.signature(self.rotary_emb.forward).parameters
        if ROPE_TYPE:
            cos, sin = self.rotary_emb(hidden_states, seq_len=kv_seq_len)
        else:
            cos, sin = self.rotary_emb(hidden_states, position_ids)
        cache_kwargs = {"sin": sin, "cos": cos}

        # Per-head outputs are written into this buffer.
        attn_out = torch.empty_like(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
        # Number of key/value heads (query heads divided by the group size).
        act_num_heads = self.num_heads // self.num_key_value_groups
        if use_cache:
            # CPU staging buffers for this step's keys and values.
            k = torch.zeros(bsz, act_num_heads, q_len, self.head_dim).to(hidden_states.dtype).cpu()
            v = torch.zeros(bsz, act_num_heads, q_len, self.head_dim).to(hidden_states.dtype).cpu()
        part_k, part_v = None, None
        for head in range(self.num_heads):
            # Project only this head's slice of the query weight.
            if "q_proj" in self.__dict__["_modules"]:
                part_q = F.linear(hidden_states, self.q_proj.weight.view(self.num_heads, self.head_dim, hidden_dim)[head]).unsqueeze(2)
            else:
                part_q = F.linear(hidden_states, self.qkv_proj.weight.view(3, self.num_heads, self.head_dim, hidden_dim)[0][head]).unsqueeze(2)
            part_q = apply_rotary_pos_emb_single(part_q.transpose(1, 2), cos, sin, position_ids)

            # Keys/values are recomputed only for the first query head of each
            # group; the other heads of the group reuse part_k / part_v.
            if head % self.num_key_value_groups == 0:
                if "q_proj" in self.__dict__["_modules"]:
                    part_k = F.linear(hidden_states, self.k_proj.weight.view(act_num_heads, self.head_dim, hidden_dim)[head // self.num_key_value_groups]).unsqueeze(2)
                    part_v = F.linear(hidden_states, self.v_proj.weight.view(act_num_heads, self.head_dim, hidden_dim)[head // self.num_key_value_groups]).unsqueeze(2).transpose(1, 2)
                else:
                    part_k = F.linear(hidden_states, self.qkv_proj.weight.view(3, act_num_heads, self.head_dim, hidden_dim)[1][head // self.num_key_value_groups]).unsqueeze(2)
                    part_v = F.linear(hidden_states, self.qkv_proj.weight.view(3, act_num_heads, self.head_dim, hidden_dim)[2][head // self.num_key_value_groups]).unsqueeze(2).transpose(1, 2)

                part_k = apply_rotary_pos_emb_single(part_k.transpose(1, 2), cos, sin, position_ids)
                if use_cache and past_key_value is not None:
                    # Stage the new keys/values on the CPU for the final update().
                    k[:,head // self.num_key_value_groups] = part_k.cpu()
                    v[:,head // self.num_key_value_groups] = part_v.cpu()
                    # NOTE(review): ``get`` is not part of the cache API visible
                    # here; presumably it returns this head's cached plus new
                    # keys/values - confirm against the cache implementation.
                    part_k, part_v = past_key_value.get(part_k, part_v, self.layer_idx, head // self.num_key_value_groups, cache_kwargs)

            if self.layer_idx >= self.starting_layer:
                part_o = self.gather_last_q_vertical_slash_topk_v4(part_q, part_k, part_v, head)
            else:
                part_o = flash_attn_func(part_q, part_k, part_v.transpose(1, 2), 0.0, softmax_scale=None, causal=True).view(bsz, part_q.shape[1], self.head_dim)
            attn_out[:, :, head, :] = part_o

        if use_cache and past_key_value is not None:
            past_key_value.update(k, v, self.layer_idx, cache_kwargs)
        # Output projection written directly into ``hidden_states`` (in place).
        # Only ``o_proj.weight`` is used here; any bias on o_proj is not applied.
        torch.matmul(attn_out.view(bsz, q_len, hidden_dim), self.o_proj.weight.T, out=hidden_states)
        torch.cuda.empty_cache()
        return (hidden_states, None, past_key_value)

    return forward
|
553 |
+
|
554 |
+
def minference_with_snapkv_forward():
    """Build an attention ``forward`` that combines MInference with SnapKV."""
    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value,
        output_attentions,
        use_cache,
        **kwargs,
    ):
        """Per-head attention with SnapKV cache compression.

        When the current keys cover the whole sequence, a compressed copy of
        the keys/values (from ``self.kv_cluster.update_kv``) is stored in the
        cache while attention itself uses the uncompressed tensors. Otherwise
        the new keys/values are appended to the cache and attention uses the
        cache contents.

        Returns:
            ``(attn_output, None, past_key_value)``.
        """
        self.init_minference_parameters()
        self.ne_inf = torch.finfo(hidden_states.dtype).min

        init_snapkv(self)

        bsz, q_len, _ = hidden_states.size()

        def split_heads(projected, n_heads):
            return projected.view(bsz, q_len, n_heads, self.head_dim).transpose(1, 2)

        query_states = split_heads(self.q_proj(hidden_states), self.num_heads)
        key_states = split_heads(self.k_proj(hidden_states), self.num_key_value_heads)
        value_states = split_heads(self.v_proj(hidden_states), self.num_key_value_heads)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )

            # [SnapKV] prefer the length tracked on the module; fall back to
            # the cache when it is missing or zero.
            tracked_len = getattr(self, "kv_seq_len", 0)
            if tracked_len != 0:
                kv_seq_len += tracked_len
            else:
                kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        global ROPE_TYPE
        if ROPE_TYPE is None:
            ROPE_TYPE = "seq_len" in inspect.signature(self.rotary_emb.forward).parameters
        if ROPE_TYPE:
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        else:
            cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            if key_states.shape[-2] == kv_seq_len:
                # [SnapKV] register the length and store the compressed keys/values.
                self.kv_seq_len = kv_seq_len
                key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)
                past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
            else:
                self.kv_seq_len += q_len
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        use_sparse = self.layer_idx >= self.starting_layer
        if use_sparse:
            assert query_states.size(1) == key_states.size(1) == value_states.size(1)

        output = torch.empty_like(query_states)
        for head in range(query_states.size(1)):
            q = query_states[:, head, :, :].unsqueeze(1)
            k = key_states[:, head, :, :].unsqueeze(1)
            v = value_states[:, head, :, :].unsqueeze(1)
            if use_sparse:
                head_out = self.gather_last_q_vertical_slash_topk_v4(q, k, v, head)
            elif is_flash_attn_2_available():
                head_out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1,2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, 1, q.shape[2], self.head_dim)
            else:
                head_out = gather_qkv(q, k, v, attention_mask)
            output[:, head:head + 1] = head_out

        merged = output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
        return self.o_proj(merged), None, past_key_value

    return forward
|
649 |
+
|
650 |
+
def gather_last_q_vertical_slash_topk_vllm(self, q, k, v, head_id):
    """Run attention for one head using the kernel named in ``self.best_pattern``.

    ``self.best_pattern[head_id]`` is unpacked as
    ``(ty, vertical_size, slash_size, _)``. ``ty`` selects one of
    ``"stream_llm"``, ``"vertical_and_slash"`` or ``"block_sparse"``.
    A single-token query (``q.shape[2] == 1``) always takes the dense
    flash-attention path, regardless of ``ty``.

    Args:
        q, k, v: per-head tensors. The prefill caller in this file passes
            them as (1, 1, seq_len, head_size) -- TODO confirm for other
            callers.
        head_id: key into ``self.best_pattern``.

    Returns:
        The output of the selected kernel. The dense path views its result
        as (bsz, 1, q_len, head_dim).

    Raises:
        KeyError: if ``head_id`` is not in ``self.best_pattern`` or ``ty`` is
            not one of the three supported pattern names.
    """
    # Bound up front so the nested kernels below can close over them.
    bsz = q.shape[0]
    q_len = q.shape[2]
    head_dim = q.size(-1)

    def vertical_and_slash_kernel(q, k, v, vertical_size, slash_size):
        # Keep at least 30 vertical / 50 slash entries, but never more than q_len.
        vertical_size = min(q_len, max(vertical_size, 30))
        slash_size = min(q_len, max(slash_size, 50))

        # Score only the last (up to 64) queries against all keys.
        last_q = min(64, q_len)
        qk = torch.einsum('bhmk, bhnk -> bhmn', q[:, :, -last_q:, :], k)

        # LAST_Q_MASK is defined elsewhere in this module; presumably a causal
        # mask for the trailing last_q x last_q corner -- confirm.
        qk[:, :, :, -last_q:] = torch.where(
            LAST_Q_MASK[..., -last_q:, -last_q:], qk[:, :, :, -last_q:], -torch.inf
        )
        qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)

        # Column ("vertical") scores; the first 30 key positions are forced in.
        vertical = qk.sum(-2, keepdim=True)
        vertical[..., :30] = torch.inf
        vertical_topk = torch.topk(vertical, vertical_size, -1).indices

        # Diagonal ("slash") scores; the last 100 entries are forced in.
        slash = sum_all_diagonal_matrix(qk)[..., :-last_q + 1]
        slash[..., -100:] = torch.inf
        slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices

        return vertical_slash_sparse_attention(q, k, v, vertical_topk, slash)

    def block_sparse_kernel(q, k, v, vertical_size=None, slash_size=None):
        # The per-head sizes are ignored; a fixed top-k of 100 is used.
        topk = 100
        return block_sparse_attention(q, k, v, topk)

    def dense(q, k, v, vertical_size=None, slash_size=None):
        # flash_attn_func is fed (batch, seq, heads, dim), hence the transposes.
        return flash_attn_func(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
            0.0, softmax_scale=None, causal=q_len != 1,
        ).view(bsz, 1, q_len, head_dim)

    ty, vertical_size, slash_size, _ = self.best_pattern[head_id]

    if q_len == 1:
        return dense(q, k, v)

    fc = {
        "stream_llm": streaming_forward,
        "vertical_and_slash": vertical_and_slash_kernel,
        "block_sparse": block_sparse_kernel,
    }[ty]
    return fc(q, k, v, vertical_size, slash_size)
|
693 |
+
|
694 |
+
def minference_vllm_forward(
    pattern_config
):
    """Build a vLLM attention ``forward`` whose prompt phase uses MInference kernels.

    Args:
        pattern_config: indexable by ``layer_idx``; each entry is a mapping
            from head id (anything ``int()`` accepts) to that head's pattern
            entry, stored on the module as ``self.best_pattern``.

    Returns:
        The ``forward`` function, intended to be bound to an attention module.
    """
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata[FlashAttentionMetadata],
        kv_scale: float,
        layer_idx: int,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
            kv_scale: forwarded to the paged-cache write and decode kernels.
            layer_idx: selects this layer's entry in ``pattern_config``.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        self.best_pattern = {int(ii): jj for ii, jj in pattern_config[layer_idx].items()}

        def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
            """
            Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep) for
            token-major tensors: (seqlen, num_key_value_heads, head_dim) ->
            (seqlen, num_key_value_heads * n_rep, head_dim).
            """
            slen, num_key_value_heads, head_dim = hidden_states.shape
            if n_rep == 1:
                return hidden_states
            # The repeat axis must sit *after* the kv-head axis so each kv head
            # is repeated n_rep times consecutively (kv0, kv0, ..., kv1, ...).
            # Inserting it before the head axis would tile the heads instead
            # and pair query heads with the wrong kv heads under GQA.
            hidden_states = hidden_states[:, :, None, :].expand(
                slen, num_key_value_heads, n_rep, head_dim
            )
            return hidden_states.reshape(slen, num_key_value_heads * n_rep, head_dim)

        def minference_prefill_func(
            q, k, v,
        ):
            # q, k, v: (seq_len, num_heads, head_size)
            if q.size(-2) != k.size(-2):
                k = repeat_kv(k, q.size(-2) // k.size(-2))
                v = repeat_kv(v, q.size(-2) // v.size(-2))

            output = torch.empty_like(q)
            for head in range(q.size(-2)):
                # Slice one head and reshape to (1, 1, seq_len, head_size).
                q_head = q[:, head, :].unsqueeze(1)[None, ...].transpose(1, 2)
                k_head = k[:, head, :].unsqueeze(1)[None, ...].transpose(1, 2)
                v_head = v[:, head, :].unsqueeze(1)[None, ...].transpose(1, 2)

                out = self.gather_last_q_vertical_slash_topk_vllm(q_head, k_head, v_head, head)

                # Back to (seq_len, 1, head_size) for the per-head output slot.
                out = out.transpose(1, 2).squeeze(0).contiguous()
                output[:, head:head+1, :] = out
            return output

        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache,
                                                attn_metadata.slot_mapping,
                                                attn_metadata.kv_cache_dtype,
                                                kv_scale)

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length. The dense
                # flash_attn_varlen_func call that used to run here is replaced
                # by the per-head MInference prefill.
                out = minference_prefill_func(query, key, value)
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # prefix-enabled attention
                # TODO(Hai) this triton kernel has regression issue (broke) to
                # deal with different data types between KV and FP8 KV cache,
                # to be addressed separately.
                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.subquery_start_loc,
                    prefill_meta.prompt_lens_tensor,
                    prefill_meta.context_lens,
                    prefill_meta.max_subquery_len,
                    self.alibi_slopes,
                )
        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.context_lens,
                decode_meta.max_context_len,
                attn_metadata.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)

    return forward
|
minference/modules/snap_kv.py
ADDED
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import time
|
3 |
+
import warnings
|
4 |
+
from importlib.metadata import version
|
5 |
+
from typing import List, Optional, Tuple, Union
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import transformers
|
11 |
+
from transformers.cache_utils import Cache, DynamicCache
|
12 |
+
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
13 |
+
from transformers.utils import logging
|
14 |
+
|
15 |
+
logger = logging.get_logger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
# https://github.com/huggingface/transformers/blob/v4.37-release/src/transformers/models/llama/modeling_llama.py
|
19 |
+
def llama_flash_attn2_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Flash-attention forward for Llama with SnapKV cache compression.

    When the keys seen in this call cover the whole sequence length computed
    for it, they are compressed with ``self.kv_cluster.update_kv`` before
    being stored in ``past_key_value``; attention for this call still uses
    the uncompressed keys/values. Otherwise the new keys/values are appended
    to the cache and ``self.kv_seq_len`` is advanced by ``q_len``.

    Returns:
        ``(attn_output, None, past_key_value)`` -- attention weights are never
        returned by this path.

    Raises:
        ValueError: if a cache is supplied but ``self.layer_idx`` is None.
    """
    # [SnapKV] make sure self.kv_cluster is available before it is used below.
    init_snapkv(self)

    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )
        # The legacy argument takes precedence over attention_mask.
        attention_mask = kwargs.pop("padding_mask")

    # The flash-attention path does not produce attention weights.
    output_attentions = False
    attn_weights = None

    bsz, q_len, _ = hidden_states.size()

    def split_heads(projected, n_heads):
        # (bsz, q_len, n_heads * head_dim) -> (bsz, n_heads, q_len, head_dim)
        return projected.view(bsz, q_len, n_heads, self.head_dim).transpose(1, 2)

    query_states = split_heads(self.q_proj(hidden_states), self.num_heads)
    key_states = split_heads(self.k_proj(hidden_states), self.num_key_value_heads)
    value_states = split_heads(self.v_proj(hidden_states), self.num_key_value_heads)

    has_cache = past_key_value is not None
    kv_seq_len = key_states.shape[-2]
    if has_cache:
        if self.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )
        # [SnapKV] prefer the length tracked on the module; the cache itself
        # only holds the compressed keys, so fall back to it only when nothing
        # has been tracked yet (attribute missing or zero).
        tracked_len = getattr(self, "kv_seq_len", 0)
        if tracked_len != 0:
            kv_seq_len += tracked_len
        else:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )
    # [SnapKV] KV heads are repeated before caching so the cluster sees one
    # key/value head per query head.
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if has_cache:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        if key_states.shape[-2] == kv_seq_len:
            # This call's keys span the full length: remember it and store
            # only the compressed keys/values.
            self.kv_seq_len = kv_seq_len
            compressed_keys, compressed_values = self.kv_cluster.update_kv(
                key_states,
                query_states,
                value_states,
                attention_mask,
                self.num_key_value_groups,
            )
            past_key_value.update(
                compressed_keys, compressed_values, self.layer_idx, cache_kwargs
            )
        else:
            self.kv_seq_len += q_len
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

    # _flash_attention_forward is fed (bsz, seq, heads, head_dim).
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    dropout_rate = self.attention_dropout if self.training else 0.0

    # Inputs that arrive as float32 are cast back to the dtype the model is
    # expected to run in: autocast dtype, pre-quantization dtype, or the
    # projection weight dtype, in that order.
    if query_states.dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        logger.warning_once(
            f"The input hidden states seems to be silently casted in float32, this might be related to"
            f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
            f" {target_dtype}."
        )

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    attn_output = self._flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        q_len,
        dropout=dropout_rate,
    )

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)

    return attn_output, attn_weights, past_key_value
|
162 |
+
|
163 |
+
|
164 |
+
def prepare_inputs_for_generation_llama(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    **kwargs,
):
    """Assemble the model inputs for one generation step (SnapKV variant).

    Without a cache, every layer's ``self_attn.kv_seq_len`` is reset to 0.
    With a non-``Cache`` cache object, the past length is read from the first
    layer's ``self_attn.kv_seq_len`` rather than from the cached tensors.

    Returns:
        dict with ``inputs_embeds`` or ``input_ids``, plus ``position_ids``,
        ``past_key_values``, ``use_cache`` and ``attention_mask``.
    """
    if past_key_values is None:
        # [SnapKV] a fresh generation: clear the per-layer length counters.
        for decoder_layer in self.model.layers:
            decoder_layer.self_attn.kv_seq_len = 0
    else:
        if isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
        else:
            # [SnapKV] use the tracked length instead of the cache tensor shape.
            past_length = self.model.layers[0].self_attn.kv_seq_len
            cache_length = past_length
            max_cache_length = None

        # Keep only the unprocessed tokens.
        n_ids = input_ids.shape[1]
        if attention_mask is not None and attention_mask.shape[1] > n_ids:
            # Some inputs were passed only through the cache (e.g. embeds).
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        elif past_length < n_ids:
            # input_ids holds every token so far; drop the processed prefix.
            input_ids = input_ids[:, past_length:]
        # Otherwise input_ids is assumed to hold only unprocessed tokens.

        # Crop the mask if the cache is about to exceed its maximum length.
        will_overflow = (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        )
        if will_overflow:
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids", None)
    if position_ids is None and attention_mask is not None:
        # Build position ids on the fly for batched generation; masked-out
        # positions get the value 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # inputs_embeds is only used for the first generation step.
    use_embeds = inputs_embeds is not None and past_key_values is None
    model_inputs = {"inputs_embeds": inputs_embeds} if use_embeds else {"input_ids": input_ids}
    model_inputs["position_ids"] = position_ids
    model_inputs["past_key_values"] = past_key_values
    model_inputs["use_cache"] = kwargs.get("use_cache")
    model_inputs["attention_mask"] = attention_mask
    return model_inputs
|
228 |
+
|
229 |
+
|
230 |
+
# Version-suffixed alias of the SnapKV attention forward; this is the name
# replace_llama() installs on LlamaFlashAttention2.forward.
llama_flash_attn2_forward_4_37 = llama_flash_attn2_forward
|
231 |
+
# Version-suffixed alias; this is the name replace_llama() installs on
# LlamaForCausalLM.prepare_inputs_for_generation.
prepare_inputs_for_generation_llama_4_37 = prepare_inputs_for_generation_llama
|
232 |
+
|
233 |
+
|
234 |
+
@torch.no_grad()
def rope_forward(self, x, seq_len):
    """Return the rotary ``(cos, sin)`` tables for positions ``0..seq_len-1``.

    Args:
        x: tensor used only for its device and dtype.
        seq_len: number of positions to generate.

    Returns:
        Two tensors of shape (1, seq_len, 2 * len(self.inv_freq)), cast to
        ``x.dtype``.
    """
    positions = torch.arange(seq_len, device=x.device)[None, :]
    batch = positions.shape[0]
    inv_freq = self.inv_freq[None, :, None].float().expand(batch, -1, 1)
    positions = positions[:, None, :].float()

    # Autocast is disabled so the angles are computed in float32; "mps" and
    # non-string device types fall back to "cpu" for the autocast context.
    autocast_device = x.device.type
    if not isinstance(autocast_device, str) or autocast_device == "mps":
        autocast_device = "cpu"

    with torch.autocast(device_type=autocast_device, enabled=False):
        angles = torch.matmul(inv_freq.float(), positions.float()).transpose(1, 2)
        # Each frequency appears twice along the last dimension.
        table = torch.cat((angles, angles), dim=-1)
        cos_table = table.cos()
        sin_table = table.sin()
    return cos_table.to(dtype=x.dtype), sin_table.to(dtype=x.dtype)
|
256 |
+
|
257 |
+
|
258 |
+
##################
|
259 |
+
|
260 |
+
# perform qk calculation and get indices
|
261 |
+
# this version will not update in inference mode
|
262 |
+
|
263 |
+
|
264 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
265 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head axis, matching
    torch.repeat_interleave(x, dim=1, repeats=n_rep): (batch, num_key_value_heads,
    seqlen, head_dim) -> (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    The input itself is returned when ``n_rep == 1``.
    """
    n_batch, n_kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add a repeat axis right after the head axis, broadcast it, then fold it
    # into the head axis so copies of the same head stay adjacent.
    expanded = hidden_states.unsqueeze(2).expand(n_batch, n_kv_heads, n_rep, seq_len, dim)
    return expanded.reshape(n_batch, n_kv_heads * n_rep, seq_len, dim)
|
277 |
+
|
278 |
+
|
279 |
+
class SnapKVCluster:
    """Select which prompt keys/values to keep, based on recent-query attention.

    The last ``window_size`` positions are always kept; from the earlier
    positions, the ``max_capacity_prompt - window_size`` with the highest
    pooled attention from the last ``window_size`` queries are kept.
    """

    def __init__(
        self,
        window_size=64,
        max_capacity_prompt=256 + 64,
        kernel_size=5,
        pooling="avgpool",
    ):
        # Construction and reset share the same validation and assignments.
        self.reset(
            window_size=window_size,
            max_capacity_prompt=max_capacity_prompt,
            kernel_size=kernel_size,
            pooling=pooling,
        )

    def reset(
        self,
        window_size=64,
        max_capacity_prompt=256 + 64,
        kernel_size=5,
        pooling="avgpool",
    ):
        """Overwrite all settings; the capacity must exceed the window size."""
        self.window_size = window_size
        self.max_capacity_prompt = max_capacity_prompt
        assert self.max_capacity_prompt - self.window_size > 0
        self.kernel_size = kernel_size
        self.pooling = pooling

    def update_kv(
        self,
        key_states,
        query_states,
        value_states,
        attention_mask,
        num_key_value_groups,
    ):
        """Return the (possibly compressed) key and value states.

        Inputs shorter than ``max_capacity_prompt`` are returned unchanged.
        ``attention_mask`` and ``num_key_value_groups`` are accepted but not
        read; a causal mask for the trailing window is built internally.

        Raises:
            ValueError: if ``self.pooling`` is neither "avgpool" nor "maxpool".
        """
        # Only valid in the prefill phase, where queries and keys align.
        assert key_states.shape[-2] == query_states.shape[-2]
        bsz, num_heads, q_len, head_dim = query_states.shape
        if q_len < self.max_capacity_prompt:
            return key_states, value_states

        window = self.window_size
        budget = self.max_capacity_prompt - window

        # Scores of the last `window` queries against every key.
        recent_queries = query_states[..., -window:, :]
        scores = torch.matmul(recent_queries, key_states.transpose(2, 3)) / math.sqrt(head_dim)

        # Causal mask for the trailing window x window corner: zero on and
        # below the diagonal, dtype-minimum strictly above it.
        causal = torch.triu(
            torch.full(
                (window, window),
                torch.finfo(scores.dtype).min,
                device=scores.device,
            ),
            diagonal=1,
        )
        scores[:, :, -window:, -window:] += causal[None, None, :, :]

        probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query_states.dtype)
        # Total attention each pre-window key receives from the recent queries.
        votes = probs[:, :, -window:, :-window].sum(dim=-2)

        if self.pooling == "avgpool":
            pool = F.avg_pool1d
        elif self.pooling == "maxpool":
            pool = F.max_pool1d
        else:
            raise ValueError("Pooling method not supported")
        # Same-length smoothing so neighbours of strong keys are favoured too.
        attn_cache = pool(
            votes,
            kernel_size=self.kernel_size,
            padding=self.kernel_size // 2,
            stride=1,
        )

        keep = attn_cache.topk(budget, dim=-1).indices
        keep = keep.unsqueeze(-1).expand(-1, -1, -1, head_dim)

        kept_keys = key_states[:, :, :-window, :].gather(dim=2, index=keep)
        kept_values = value_states[:, :, :-window, :].gather(dim=2, index=keep)

        # Selected earlier entries first, then the untouched recent window.
        key_states = torch.cat([kept_keys, key_states[:, :, -window:, :]], dim=2)
        value_states = torch.cat([kept_values, value_states[:, :, -window:, :]], dim=2)
        return key_states, value_states
|
374 |
+
|
375 |
+
|
376 |
+
# Give an attention module a `kv_cluster` (a SnapKVCluster) built from
# `self.config`. Config attributes that are missing are filled in first:
# window_size=64, max_capacity_prompt=4096, kernel_size=13, pooling="avgpool".
# NOTE(review): indentation is lost in this copy, so it is unclear whether the
# cluster is rebuilt on every call or only when `kv_cluster` is absent --
# confirm against the original file before relying on either behaviour.
def init_snapkv(self):
|
377 |
+
if not hasattr(self, "kv_cluster"):
|
378 |
+
if not hasattr(self.config, "window_size"):
|
379 |
+
self.config.window_size = 64
|
380 |
+
if not hasattr(self.config, "max_capacity_prompt"):
|
381 |
+
self.config.max_capacity_prompt = 4096
|
382 |
+
if not hasattr(self.config, "kernel_size"):
|
383 |
+
self.config.kernel_size = 13
|
384 |
+
if not hasattr(self.config, "pooling"):
|
385 |
+
self.config.pooling = "avgpool"
|
386 |
+
self.kv_cluster = SnapKVCluster(
|
387 |
+
window_size=self.config.window_size,
|
388 |
+
max_capacity_prompt=self.config.max_capacity_prompt,
|
389 |
+
kernel_size=self.config.kernel_size,
|
390 |
+
pooling=self.config.pooling,
|
391 |
+
)
|
392 |
+
|
393 |
+
|
394 |
+
############
|
395 |
+
|
396 |
+
|
397 |
+
def check_version():
    """Return the installed ``transformers`` version string.

    Returns:
        The version reported by package metadata, or an empty string when no
        metadata for ``transformers`` can be found. (Previously that case
        printed a message and then failed with ``UnboundLocalError`` on the
        return statement.)
    """
    # Imported locally: the module header only imports `version`.
    from importlib.metadata import PackageNotFoundError

    try:
        return version("transformers")
    except PackageNotFoundError as e:
        print(f"Transformers not installed: {e}")
        # An empty string keeps `in` / `startswith` checks in callers working
        # and makes them treat the version as unknown.
        return ""
|
403 |
+
|
404 |
+
|
405 |
+
def replace_llama():
    """Monkey-patch transformers' Llama classes with the SnapKV implementations.

    Replaces ``LlamaForCausalLM.prepare_inputs_for_generation`` and
    ``LlamaFlashAttention2.forward``. A warning is emitted, and the patch is
    still applied, when the installed transformers release is not one of the
    tested ones.
    """
    transformers_version = check_version()
    version_list = ["4.37"]

    # Match on the release prefix ("4.37" or "4.37.<anything>") rather than on
    # a bare substring, which would also accept e.g. "5.4.37". The loop
    # variable is not called `version`, so it no longer shadows the imported
    # importlib.metadata.version.
    is_tested = any(
        transformers_version == tested or transformers_version.startswith(tested + ".")
        for tested in version_list
    )
    if not is_tested:
        warnings.warn(
            f"Transformers version {transformers_version} might not be compatible with SnapKV. SnapKV is tested with Transformers version {version_list}."
        )

    transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation = (
        prepare_inputs_for_generation_llama_4_37
    )
    transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward = (
        llama_flash_attn2_forward_4_37
    )
|
minference/ops/block_sparse_flash_attention.py
ADDED
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import triton
|
5 |
+
import triton.language as tl
|
6 |
+
import pycuda.autoprimaryctx
|
7 |
+
from pycuda.compiler import SourceModule
|
8 |
+
|
9 |
+
from flash_attn import flash_attn_varlen_func
|
10 |
+
|
11 |
+
|
12 |
+
# @triton.autotune(
|
13 |
+
# configs=[
|
14 |
+
# triton.Config({}, num_stages=1, num_warps=4),
|
15 |
+
# triton.Config({}, num_stages=1, num_warps=8),
|
16 |
+
# triton.Config({}, num_stages=2, num_warps=4),
|
17 |
+
# triton.Config({}, num_stages=2, num_warps=8),
|
18 |
+
# triton.Config({}, num_stages=3, num_warps=4),
|
19 |
+
# triton.Config({}, num_stages=3, num_warps=8),
|
20 |
+
# triton.Config({}, num_stages=4, num_warps=4),
|
21 |
+
# triton.Config({}, num_stages=4, num_warps=8),
|
22 |
+
# triton.Config({}, num_stages=5, num_warps=4),
|
23 |
+
# triton.Config({}, num_stages=5, num_warps=8),
|
24 |
+
# ],
|
25 |
+
# key=['N_CTX'],
|
26 |
+
# )
|
27 |
+
@triton.jit
def triton_block_sparse_attn_kernel(
    Q, K, V, seqlens, sm_scale,
    block_index,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vn, stride_vk,
    stride_oz, stride_oh, stride_om, stride_ok,
    Z, H, N_CTX,
    NUM_ROWS, MAX_BLOCKS_PRE_ROW,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    dtype: tl.constexpr,
):
    # Causal block-sparse attention forward pass with an online softmax.
    #
    # Program (i, j) handles query rows [i * BLOCK_M, (i + 1) * BLOCK_M) of the
    # flattened (batch, head) pair j. The key/value blocks to visit are read from
    # `block_index`, laid out as [Z * H, NUM_ROWS, MAX_BLOCKS_PRE_ROW].
    pid_m = tl.program_id(0)
    pid_hz = tl.program_id(1)
    batch_id = pid_hz // H
    head_id = pid_hz % H

    # Query tiles that start at or beyond the sequence length have nothing to do.
    seqlen = tl.load(seqlens + batch_id)
    if pid_m * BLOCK_M >= seqlen:
        return

    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    lane_n = tl.arange(0, BLOCK_N)
    lane_d = tl.arange(0, BLOCK_DMODEL)

    # NOTE: Out is addressed with Q's batch/head strides and V with K's; the
    # stride_oz/oh and stride_vz/vh arguments are not read by this kernel.
    qo_offset = batch_id * stride_qz + head_id * stride_qh
    kv_offset = batch_id * stride_kz + head_id * stride_kh

    q_ptrs = Q + qo_offset + rows[:, None] * stride_qm + lane_d[None, :] * stride_qk
    k_ptrs = K + kv_offset + lane_d[:, None] * stride_kk
    v_ptrs = V + kv_offset + lane_d[None, :] * stride_vk
    o_ptrs = Out + qo_offset + rows[:, None] * stride_om + lane_d[None, :] * stride_ok

    blocks_ptr = block_index + (pid_hz * NUM_ROWS + pid_m) * MAX_BLOCKS_PRE_ROW

    # Online-softmax state: running row maximum, running denominator, accumulator.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # Fold log2(e) into the softmax scale so the loop can use exp2 instead of exp
    # (CSE and LICM do not work as expected with `exp` inside the loop).
    qk_scale = sm_scale * 1.44269504
    # q is loaded once and reused for every key block.
    q = tl.load(q_ptrs)
    q = (q * qk_scale).to(dtype)

    m_mask = rows[:, None] < seqlen
    # Never visit more blocks than can lie at or before the end of this query tile.
    block_count = tl.minimum((pid_m + 1) * BLOCK_M // BLOCK_N, MAX_BLOCKS_PRE_ROW)

    for sparse_block_idx in range(block_count):
        real_block_idx = tl.load(blocks_ptr + sparse_block_idx)
        cols = real_block_idx * BLOCK_N + lane_n
        # -- load k, v --
        k = tl.load(k_ptrs + cols[None, :] * stride_kn)
        v = tl.load(v_ptrs + cols[:, None] * stride_vn)
        # -- compute qk, with -inf wherever the row is padding or the key is in the future --
        causal_mask = cols[None, :] <= rows[:, None]
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.where(m_mask & causal_mask, qk, float("-inf"))
        qk += tl.dot(q, k)
        # -- compute scaling constant --
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])
        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(dtype), v)
        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new

    # Normalise and write back; rows at or beyond seqlen are not stored.
    acc /= l_i[:, None]
    tl.store(o_ptrs, acc.to(dtype), mask=m_mask)
|
111 |
+
|
112 |
+
|
113 |
+
def triton_block_sparse_forward(
    q,            # [BATCH, N_HEADS, N_CTX, D_HEAD]
    k,            # [BATCH, N_HEADS, N_CTX, D_HEAD]
    v,            # [BATCH, N_HEADS, N_CTX, D_HEAD]
    seqlens,      # [BATCH, ]
    block_index,  # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), MAX_BLOCKS_PRE_ROW]
    sm_scale,
    block_size_M=64,
    block_size_N=64,
) -> torch.Tensor:
    """Launch the block-sparse attention kernel and return its output.

    The output has the same shape and dtype as ``q`` and starts zero-filled.
    One program is launched per (query tile, batch * head) pair.
    """
    # shape constraints
    Lq, Lk, Lv = (t.shape[-1] for t in (q, k, v))
    assert Lq == Lk and Lk == Lv
    assert Lk in {16, 32, 64, 128}

    batch_size, num_heads, context_size = q.shape[0], q.shape[1], q.shape[2]
    out = torch.zeros_like(q)
    launch_grid = (triton.cdiv(context_size, block_size_M), batch_size * num_heads, 1)

    # Anything that is not bfloat16 is computed as float16 inside the kernel.
    if q.dtype == torch.bfloat16:
        kernel_dtype = tl.bfloat16
    else:
        kernel_dtype = tl.float16

    # Four strides per tensor, in q, k, v, out order, as the kernel signature expects.
    strides = [t.stride(axis) for t in (q, k, v, out) for axis in range(4)]

    triton_block_sparse_attn_kernel[launch_grid](
        q, k, v, seqlens, sm_scale,
        block_index,
        out,
        *strides,
        batch_size, num_heads, context_size,
        block_index.shape[-2], block_index.shape[-1],
        BLOCK_M=block_size_M, BLOCK_N=block_size_N,
        BLOCK_DMODEL=Lk,
        dtype=kernel_dtype,
        num_warps=4, num_stages=2,
    )

    return out
|
147 |
+
|
148 |
+
|
149 |
+
def torch_build_index(
    query: torch.Tensor,  # [BATCH, N_HEADS, N_CTX, D_HEAD]
    key: torch.Tensor,  # [BATCH, N_HEADS, N_CTX, D_HEAD]
    top_k: int,
    block_size_M: int = 64,
    block_size_N: int = 64,
):
    """Select, per query block, the key blocks with the highest pooled score.

    Queries and keys are mean-pooled over blocks of ``block_size_M`` and
    ``block_size_N`` tokens. Key blocks that start after the query block
    starts are scored ``-inf``. ``top_k`` is clamped to the number of key
    blocks.

    Returns:
        int32 tensor ``[BATCH, N_HEADS, num_query_blocks, top_k]`` holding the
        selected key-block indices of each row in ascending order.
    """
    batch_size, num_heads, context_size, head_dim = query.shape

    def _block_means(tokens, block_size):
        blocked = tokens.reshape((batch_size, num_heads, -1, block_size, head_dim))
        return blocked.mean(dim=-2)

    pooled_q = _block_means(query, block_size_M)
    pooled_k = _block_means(key, block_size_N)

    # First token position of every query / key block.
    q_starts = torch.arange(pooled_q.shape[-2], dtype=torch.int32, device=query.device) * block_size_M
    k_starts = torch.arange(pooled_k.shape[-2], dtype=torch.int32, device=key.device) * block_size_N

    scores = torch.einsum('bhmk, bhnk -> bhmn', pooled_q, pooled_k)
    visible = q_starts.view(1, 1, -1, 1) >= k_starts.view(1, 1, 1, -1)
    scores = scores.where(visible, -torch.inf)

    num_selected = min(top_k, context_size // block_size_N)
    selected = torch.topk(scores, num_selected, dim=-1).indices
    return selected.to(torch.int32).sort(dim=-1).values
|
165 |
+
|
166 |
+
|
167 |
+
def make_causal_mask(seqlens, device, context_size):
    """Build a boolean causal mask of shape ``[BATCH, 1, context_size, context_size]``.

    Entry ``[b, 0, i, j]`` is True when ``j <= i`` and both ``i`` and ``j`` are
    below ``seqlens[b]``; rows and columns at or beyond the sequence length are
    cleared.
    """
    positions = torch.arange(context_size, dtype=torch.int32, device=device)
    lower_triangle = positions.view(1, 1, -1, 1) >= positions.view(1, 1, 1, -1)
    mask = lower_triangle.repeat((seqlens.shape[0], 1, 1, 1))
    for batch_idx, length in enumerate(seqlens):
        mask[batch_idx, :, length:] = False   # padded query rows
        mask[batch_idx, ..., length:] = False  # padded key columns
    return mask
|
176 |
+
|
177 |
+
|
178 |
+
def make_block_mask(block_index, causal_mask, device, block_size_M=64, block_size_N=64):
    """Expand a block index into a dense boolean attention mask.

    For every query block ``i`` and every key-block id stored in
    ``block_index[b, h, i]``, the corresponding ``block_size_M x block_size_N``
    tile is switched on. The result is then AND-ed in place with
    ``causal_mask`` and returned with shape
    ``[BATCH, N_HEADS, context_size, context_size]``.
    """
    batch_size, num_heads, num_rows, max_blocks_per_row = block_index.shape
    context_size = causal_mask.shape[-1]
    mask = torch.zeros(
        (batch_size, num_heads, context_size, context_size),
        dtype=torch.bool,
        device=device,
    )
    # Read the whole index once instead of one element at a time.
    for b, per_batch in enumerate(block_index.tolist()):
        for h, per_head in enumerate(per_batch):
            for row, key_blocks in enumerate(per_head):
                row_begin = row * block_size_M
                row_slice = slice(row_begin, row_begin + block_size_M)
                for key_block in key_blocks:
                    col_begin = key_block * block_size_N
                    mask[b, h, row_slice, col_begin:col_begin + block_size_N] = True
    mask.logical_and_(causal_mask)
    return mask
|
194 |
+
|
195 |
+
|
196 |
+
def plot_mask(mask, name, batch=0, head=0):
    """Save a heatmap of ``mask[batch, head]`` to the file ``name``.

    Args:
        mask: tensor indexable as ``mask[batch, head]``; the selected slice is
            moved to the CPU and converted to NumPy before plotting.
        name: output path handed to ``plt.savefig``.
        batch: batch index to plot.
        head: head index to plot.
    """
    # Plotting libraries are imported lazily so the module works without them.
    import matplotlib.pyplot as plt
    import seaborn as sns

    fig = plt.figure(figsize=(16, 12))
    try:
        sns.heatmap(mask[batch, head].cpu().numpy())
        plt.savefig(name)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate figures.
        plt.close(fig)
|
204 |
+
|
205 |
+
|
206 |
+
@triton.jit
def triton_dense_fwd_kernel(
    Q, K, V, seqlens, sm_scale,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vn, stride_vk,
    stride_oz, stride_oh, stride_om, stride_ok,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    dtype: tl.constexpr,
):
    # Dense causal attention forward pass with an online softmax.
    #
    # Program (i, j) handles query rows [i * BLOCK_M, (i + 1) * BLOCK_M) of the
    # flattened (batch, head) pair j and walks the key/value blocks from
    # position 0 up to the end of its own query tile.
    pid_m = tl.program_id(0)
    pid_hz = tl.program_id(1)
    batch_id = pid_hz // H
    head_id = pid_hz % H

    # Query tiles that start at or beyond the sequence length have nothing to do.
    seqlen = tl.load(seqlens + batch_id)
    if pid_m * BLOCK_M >= seqlen:
        return

    # NOTE: Out is addressed with Q's batch/head strides and V with K's; the
    # stride_oz/oh and stride_vz/vh arguments are not read by this kernel.
    qo_offset = batch_id * stride_qz + head_id * stride_qh
    kv_offset = batch_id * stride_kz + head_id * stride_kh
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qo_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    # K is viewed transposed, [D_HEAD, N_CTX], so tl.dot(q, k) needs no transpose.
    K_block_ptr = tl.make_block_ptr(
        base=K + kv_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + kv_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_vn, stride_vk),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )

    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    lane_n = tl.arange(0, BLOCK_N)

    # Online-softmax state: running row maximum, running denominator, accumulator.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # Fold log2(e) into the softmax scale so the loop can use exp2 instead of exp
    # (CSE and LICM do not work as expected with `exp` inside the loop).
    qk_scale = sm_scale * 1.44269504
    # q is loaded once and reused for every key block.
    q = tl.load(Q_block_ptr)
    q = (q * qk_scale).to(dtype)

    # Causality: keys past the end of this query tile are never visited.
    lo = 0
    hi = (pid_m + 1) * BLOCK_M
    m_mask = rows[:, None] < seqlen

    for start_n in range(lo, hi, BLOCK_N):
        n_mask = (start_n + lane_n[None, :]) <= rows[:, None]
        # -- load k, v --
        k = tl.load(K_block_ptr)
        v = tl.load(V_block_ptr)
        # -- compute qk, with -inf wherever the row is padding or the key is in the future --
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.where(m_mask & n_mask, qk, float("-inf"))
        qk += tl.dot(q, k)
        # -- compute scaling constant --
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])
        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(dtype), v)
        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
        # -- step to the next key/value block --
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))

    # Normalise; rows at or beyond seqlen are written as zeros.
    acc = tl.where(m_mask, acc / l_i[:, None], 0.0)
    O_block_ptr = tl.make_block_ptr(
        base=Out + qo_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_ok),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    tl.store(O_block_ptr, acc.to(dtype))
|
305 |
+
|
306 |
+
|
307 |
+
def triton_dense_forward(q, k, v, seqlens, sm_scale, block_size_M=128, block_size_N=64) -> torch.Tensor:
    """Launch the dense causal attention kernel and return its output.

    The output has the same shape and dtype as ``q`` and starts zero-filled.
    One program is launched per (query tile, batch * head) pair.
    """
    # shape constraints
    Lq, Lk, Lv = (t.shape[-1] for t in (q, k, v))
    assert Lq == Lk and Lk == Lv
    assert Lk in {16, 32, 64, 128}

    batch_size, num_heads, context_size = q.shape[0], q.shape[1], q.shape[2]
    out = torch.zeros_like(q)
    launch_grid = (triton.cdiv(context_size, block_size_M), batch_size * num_heads, 1)
    # Head dimensions above 64 are launched with eight warps, the rest with four.
    num_warps = 8 if Lk > 64 else 4

    # Anything that is not bfloat16 is computed as float16 inside the kernel.
    if q.dtype == torch.bfloat16:
        kernel_dtype = tl.bfloat16
    else:
        kernel_dtype = tl.float16

    # Four strides per tensor, in q, k, v, out order, as the kernel signature expects.
    strides = [t.stride(axis) for t in (q, k, v, out) for axis in range(4)]

    triton_dense_fwd_kernel[launch_grid](
        q, k, v, seqlens, sm_scale,
        out,
        *strides,
        batch_size, num_heads, context_size,
        BLOCK_M=block_size_M, BLOCK_N=block_size_N,
        BLOCK_DMODEL=Lk,
        dtype=kernel_dtype,
        num_warps=num_warps, num_stages=4,
    )

    return out
|
331 |
+
|
332 |
+
|
333 |
+
def flash_attn_forward(q, k, v, seqlens, sm_scale, context_size) -> torch.Tensor:
    """Run causal, dropout-free attention through ``flash_attn_varlen_func``.

    ``seqlens`` is passed as the cumulative sequence lengths of both queries
    and keys, and ``context_size`` as the maximum sequence length of both.
    """
    varlen_args = dict(
        cu_seqlens_q=seqlens,
        cu_seqlens_k=seqlens,
        max_seqlen_q=context_size,
        max_seqlen_k=context_size,
    )
    return flash_attn_varlen_func(
        q, k, v,
        dropout_p=0.0,
        softmax_scale=sm_scale,
        causal=True,
        **varlen_args,
    )
|
346 |
+
|
347 |
+
|
348 |
+
def torch_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor,
    sm_scale: float,
) -> torch.Tensor:
    """Reference masked attention written with plain PyTorch ops.

    Scores outside ``mask`` are set to ``-inf`` before the softmax. The
    subtracted row maximum is floored at zero and ``1e-6`` is added to the
    denominator, so a fully masked row yields zeros instead of NaN.
    """
    scores = torch.einsum('bhmk, bhnk -> bhmn', query, key) * sm_scale
    scores = scores.where(mask, -torch.inf)

    row_max = scores.max(-1, keepdim=True).values
    # Negative maxima (including -inf for fully masked rows) are replaced by 0.
    row_max = torch.where(row_max < 0, 0.0, row_max)

    weights = torch.exp(scores - row_max)
    weights = weights / (weights.sum(-1, keepdim=True) + 1e-6)
    return torch.einsum('bhmn, bhnk -> bhmk', weights, value)
|
363 |
+
|
364 |
+
|
365 |
+
def profile(fn, total_flops, tag, warmup=25, rep=100):
    """Benchmark ``fn`` and print its runtime and throughput.

    Args:
        fn: zero-argument callable to time.
        total_flops: floating-point operations performed by one call of ``fn``.
        tag: label printed in front of the measurement.
        warmup: forwarded to ``triton.testing.do_bench``.
        rep: forwarded to ``triton.testing.do_bench``.
    """
    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    # The runtime is in milliseconds: FLOP / ms * 1e3 is FLOP/s, and a further
    # 1e-9 turns that into GFLOP/s, hence the combined factor of 1e-6.
    gflops = total_flops / ms * 1e-6
    print(f'{tag}: {ms:.3f} ms | {gflops:.3f} GFLOP/s')
|
369 |
+
|
370 |
+
|
371 |
+
def test_flash_attention(
    seqlens=None,
    dtype=torch.float16,
    device="cuda",
    torch_test=True,
    batch_size=4,
    num_heads=32,
    context_size=1024,
    head_dim=128,
    top_k=5,
    block_size_M=64,
    block_size_N=64,
):
    """Check and benchmark the dense, block-sparse and flash-attn forward paths.

    Random q/k/v are generated; when ``seqlens`` is None, random per-batch
    lengths in ``[context_size // 2, context_size)`` are drawn as well. With
    ``torch_test`` set, the Triton and flash-attn outputs are compared against
    the PyTorch reference (``atol=1e-2``). Each path is then timed with
    ``profile``.
    """
    print('========================================')
    print(f'BATCH={batch_size}, N_CTX={context_size}, N_HEADS={num_heads}, D_HEAD={head_dim}')

    qkv_shape = (batch_size, num_heads, context_size, head_dim)
    q, k, v = (torch.randn(qkv_shape, dtype=dtype, device=device) for _ in range(3))
    if seqlens is not None:
        seqlens = torch.tensor(seqlens, dtype=torch.int32, device=device)
    else:
        seqlens = torch.randint(context_size // 2, context_size, (batch_size, ), dtype=torch.int32, device=device)

    # Approximate count of unmasked (query, key) pairs under a causal mask.
    dense_mask_nnz = seqlens.to(torch.float32).square().sum().item() * num_heads / 2
    sm_scale = head_dim ** -0.5

    causal_mask = make_causal_mask(seqlens, device, context_size)
    ref_o_dense = torch_forward(q, k, v, causal_mask, sm_scale) if torch_test else None

    # ---- block index -------------------------------------------------------
    def build_index():
        return torch_build_index(q, k, top_k, block_size_M, block_size_N)

    block_index = build_index()
    row_ids = torch.arange(block_index.shape[-2], device=device)
    block_index_mask = row_ids[None, None, :, None] * block_size_M >= block_index * block_size_N
    sparse_mask_nnz = block_index_mask.to(torch.float32).sum().item() * block_size_M * block_size_N
    print(f'block mask sparsity: {1 - sparse_mask_nnz / dense_mask_nnz}')
    profile(build_index, 0., 'torch-index')

    ref_o_sparse = None
    if torch_test:
        block_mask = make_block_mask(block_index, causal_mask, device, block_size_M, block_size_N)
        ref_o_sparse = torch_forward(q, k, v, block_mask, sm_scale)

    # ---- Triton dense ------------------------------------------------------
    def run_dense():
        return triton_dense_forward(q, k, v, seqlens, sm_scale)

    output = run_dense()
    if torch_test:
        torch.testing.assert_close(output, ref_o_dense, atol=1e-2, rtol=0)
    profile(run_dense, 2. * head_dim * dense_mask_nnz, 'triton-dense')

    # ---- Triton block-sparse -----------------------------------------------
    def run_sparse():
        return triton_block_sparse_forward(q, k, v, seqlens, block_index, sm_scale, block_size_M, block_size_N)

    output = run_sparse()
    if torch_test:
        torch.testing.assert_close(output, ref_o_sparse, atol=1e-2, rtol=0)
    profile(run_sparse, 2. * head_dim * sparse_mask_nnz, 'triton-sparse')

    # ---- flash-attn (variable-length layout) -------------------------------
    def pack_varlen(x):
        # [B, H, N, D] -> [B, N, H, D], then keep only the valid tokens of each
        # batch entry and concatenate them along the token axis.
        x = x.swapaxes(1, 2).contiguous()
        return torch.concatenate([x[i, :seqlen, :, :] for i, seqlen in enumerate(seqlens)])

    q_packed, k_packed, v_packed = pack_varlen(q), pack_varlen(k), pack_varlen(v)
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

    def run_flash():
        return flash_attn_forward(q_packed, k_packed, v_packed, cu_seqlens, sm_scale, context_size)

    output = run_flash()
    # Undo the packing: zero-pad every sequence back to context_size tokens.
    unpacked = []
    for i in range(batch_size):
        begin, end = cu_seqlens[i], cu_seqlens[i + 1]
        unpacked.append(
            torch.nn.functional.pad(
                output[begin:end, :, :],
                (0, 0, 0, 0, 0, context_size + begin - end)
            )
        )
    output = torch.stack(unpacked).swapaxes(1, 2).contiguous()
    if torch_test:
        torch.testing.assert_close(output, ref_o_dense, atol=1e-2, rtol=0)
    profile(run_flash, 2. * head_dim * dense_mask_nnz, 'flash-dense')
    print('========================================\n')
|
445 |
+
|
446 |
+
|
447 |
+
def block_sparse_flash_attention_forward(
    query: torch.Tensor,  # [BATCH, N_HEADS, N_CTX, D_HEAD]
    key: torch.Tensor,    # [BATCH, N_HEADS, N_CTX, D_HEAD]
    value: torch.Tensor,  # [BATCH, N_HEADS, N_CTX, D_HEAD]
    top_k: int,
    block_size_M: int = 64,
    block_size_N: int = 64,
):
    """Causal block-sparse attention over the ``top_k`` best key blocks per query block.

    The inputs are zero-padded along the token axis to a multiple of
    ``block_size_M``, a block index is built from mean-pooled queries and keys,
    the Triton block-sparse kernel is run, and the padding is stripped again.

    Args:
        query, key, value: ``[BATCH, N_HEADS, N_CTX, D_HEAD]`` tensors.
        top_k: number of key blocks kept per query block.
        block_size_M: query tile size.
        block_size_N: key tile size.

    Returns:
        Tensor of shape ``[BATCH, N_HEADS, N_CTX, D_HEAD]``.
    """
    batch_size, num_heads, context_size, head_dim = query.shape

    # Pad only as far as the next multiple of block_size_M; an already aligned
    # length needs no padding at all.
    pad = -context_size % block_size_M
    if pad:
        query = torch.nn.functional.pad(query, [0, 0, 0, pad, 0, 0, 0, 0])
        key = torch.nn.functional.pad(key, [0, 0, 0, pad, 0, 0, 0, 0])
        value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0])

    # The kernel reads seqlens[batch], so every batch entry needs its own value.
    seqlens = torch.full((batch_size,), context_size, dtype=torch.int32, device=query.device)
    sm_scale = head_dim ** -0.5

    # Query blocks use block_size_M and key blocks use block_size_N, so the
    # index has exactly one row per query tile launched by the kernel.
    block_index = torch_build_index(query, key, top_k, block_size_M, block_size_N)
    out = triton_block_sparse_forward(query, key, value, seqlens, block_index, sm_scale, block_size_M, block_size_N)
    return out[..., :context_size, :]
|
minference/ops/pit_sparse_flash_attention.py
ADDED
@@ -0,0 +1,740 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pycuda.autoprimaryctx
|
3 |
+
import torch
|
4 |
+
import triton
|
5 |
+
import triton.language as tl
|
6 |
+
from flash_attn import flash_attn_varlen_func
|
7 |
+
from pycuda.compiler import SourceModule
|
8 |
+
|
9 |
+
|
10 |
+
@triton.autotune(
    # Every combination of 1-5 pipeline stages with 4 or 8 warps.
    configs=[
        triton.Config({}, num_stages=stages, num_warps=warps)
        for stages in (1, 2, 3, 4, 5)
        for warps in (4, 8)
    ],
    key=['N_CTX'],
)
@triton.jit
def triton_sparse_fwd_kernel(
    Q, K, V, seqlens, sm_scale,
    col_count, col_index,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vn, stride_vk,
    stride_oz, stride_oh, stride_om, stride_ok,
    Z, H, N_CTX,
    NUM_ROWS, MAX_COLS_PRE_ROW,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    dtype: tl.constexpr,
):
    # Column-sparse attention forward pass with an online softmax.
    #
    # Program (i, j) handles query rows [i * BLOCK_M, (i + 1) * BLOCK_M) of the
    # flattened (batch, head) pair j. It attends to the `col_count` key columns
    # listed for that tile in `col_index`
    # ([Z * H, NUM_ROWS, MAX_COLS_PRE_ROW]), gathered BLOCK_N at a time.
    pid_m = tl.program_id(0)
    pid_hz = tl.program_id(1)
    batch_id = pid_hz // H
    head_id = pid_hz % H

    # Query tiles that start at or beyond the sequence length have nothing to do.
    seqlen = tl.load(seqlens + batch_id)
    if pid_m * BLOCK_M >= seqlen:
        return

    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    lane_n = tl.arange(0, BLOCK_N)
    lane_d = tl.arange(0, BLOCK_DMODEL)

    # NOTE: Out is addressed with Q's batch/head strides and V with K's; the
    # stride_oz/oh and stride_vz/vh arguments are not read by this kernel.
    qo_offset = batch_id * stride_qz + head_id * stride_qh
    kv_offset = batch_id * stride_kz + head_id * stride_kh

    q_ptrs = Q + qo_offset + rows[:, None] * stride_qm + lane_d[None, :] * stride_qk
    k_ptrs = K + kv_offset + lane_d[:, None] * stride_kk
    v_ptrs = V + kv_offset + lane_d[None, :] * stride_vk
    o_ptrs = Out + qo_offset + rows[:, None] * stride_om + lane_d[None, :] * stride_ok

    tile_id = pid_hz * NUM_ROWS + pid_m
    num_cols = tl.load(col_count + tile_id)
    cols_ptr = col_index + tile_id * MAX_COLS_PRE_ROW

    # Online-softmax state: running row maximum, running denominator, accumulator.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # Fold log2(e) into the softmax scale so the loops can use exp2 instead of exp
    # (CSE and LICM do not work as expected with `exp` inside the loop).
    qk_scale = sm_scale * 1.44269504
    # q is loaded once and reused for every column chunk.
    q = tl.load(q_ptrs)
    q = (q * qk_scale).to(dtype)

    m_mask = rows[:, None] < seqlen
    # Columns before `split` (a multiple of BLOCK_N) are processed without a
    # causal mask; the remaining columns go through the masked loop below.
    split = tl.maximum(num_cols - BLOCK_N, 0) & ~(BLOCK_N - 1)

    # ---- chunks handled without causal or bounds masking ----
    for start_n in range(0, split, BLOCK_N):
        cols = tl.load(cols_ptr + start_n + lane_n)
        # -- load k, v --
        k = tl.load(k_ptrs + cols[None, :] * stride_kn)
        v = tl.load(v_ptrs + cols[:, None] * stride_vn)
        # -- compute qk --
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.where(m_mask, qk, float("-inf"))
        qk += tl.dot(q, k)
        # -- compute scaling constant --
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])
        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(dtype), v)
        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new

    # ---- trailing chunks: bounds-checked column load plus causal mask ----
    for start_n in range(split, num_cols, BLOCK_N):
        n_mask = start_n + lane_n < num_cols
        # Out-of-range lanes read column N_CTX - 1 instead of the index buffer.
        cols = tl.load(cols_ptr + start_n + lane_n, mask=n_mask, other=N_CTX - 1)
        causal_mask = cols[None, :] <= rows[:, None]
        # -- load k, v --
        k = tl.load(k_ptrs + cols[None, :] * stride_kn)
        v = tl.load(v_ptrs + cols[:, None] * stride_vn)
        # -- compute qk --
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.where(m_mask & causal_mask, qk, float("-inf"))
        qk += tl.dot(q, k)
        # -- compute scaling constant --
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])
        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(dtype), v)
        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new

    # Normalise; rows at or beyond seqlen are zeroed and not stored.
    acc = tl.where(m_mask, acc / l_i[:, None], 0.0)
    tl.store(o_ptrs, acc.to(dtype), mask=m_mask)
|
127 |
+
|
128 |
+
|
129 |
+
def triton_sparse_forward(
    q,          # [BATCH, N_HEADS, N_CTX, D_HEAD]
    k,          # [BATCH, N_HEADS, N_CTX, D_HEAD]
    v,          # [BATCH, N_HEADS, N_CTX, D_HEAD]
    seqlens,    # [BATCH, ]
    col_count,  # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    col_index,  # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), MAX_COLS_PRE_ROW]
    sm_scale,
    block_size_M=64,
    block_size_N=64,
) -> torch.Tensor:
    """Launch the column-sparse attention kernel and return its output.

    The output has the same shape and dtype as ``q`` and starts zero-filled.
    One program is launched per (query tile, batch * head) pair.
    """
    # shape constraints
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    assert Lk in {16, 32, 64, 128}

    o = torch.zeros_like(q)
    grid = (triton.cdiv(q.shape[2], block_size_M), q.shape[0] * q.shape[1], 1)
    # Anything that is not bfloat16 is computed as float16 inside the kernel.
    dtype = tl.bfloat16 if q.dtype == torch.bfloat16 else tl.float16

    # num_warps / num_stages are deliberately not passed: the kernel is wrapped
    # in @triton.autotune, which selects them per N_CTX.
    triton_sparse_fwd_kernel[grid](
        q, k, v, seqlens, sm_scale,
        col_count, col_index,
        o,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        o.stride(0), o.stride(1), o.stride(2), o.stride(3),
        q.shape[0], q.shape[1], q.shape[2],
        col_index.shape[-2], col_index.shape[-1],
        BLOCK_M=block_size_M, BLOCK_N=block_size_N,
        BLOCK_DMODEL=Lk,
        dtype=dtype,
    )

    return o
|
165 |
+
|
166 |
+
|
167 |
+
def torch_build_index(seqlens, vertical_indexes, slash_indexes, block_size_M=64):
    """Build per-row-block column lists from vertical and slash indexes.

    For every block of ``block_size_M`` query rows ``[start_m, end_m)`` this
    merges two sources of key columns into one list:

    * vertical entries ``v_idx``: single columns, kept when they lie below the
      current slash range;
    * slash entries ``d``: the column range
      ``[max(0, end_m - d - block_size_M), end_m - d)``, clipped to ``seqlen``.

    NOTE(review): the merge walks both inputs once from front to back, which
    presumably expects ``vertical_indexes`` ascending and ``slash_indexes``
    descending along the last axis -- confirm against the caller.

    Args:
        seqlens: ``[BATCH]`` integer tensor of sequence lengths.
        vertical_indexes: ``[BATCH, N_HEADS, NNZ_V]`` integer tensor.
        slash_indexes: ``[BATCH, N_HEADS, NNZ_S]`` integer tensor.
        block_size_M: number of query rows per row block.

    Returns:
        ``(col_count, col_index)`` as int32 tensors on ``seqlens.device``:
        ``col_count[b, h, m]`` is the number of valid entries at the front of
        ``col_index[b, h, m]``; unused slots stay zero.
    """
    # Longest sequence, rounded up to a multiple of 4.
    max_cols_per_row = (seqlens.max().item() + 3) & (-4)
    batch_size, num_heads, NNZ_S = slash_indexes.shape
    NNZ_V = vertical_indexes.shape[-1]
    # Ceiling division with plain integers.
    num_rows = -(-max_cols_per_row // block_size_M)
    col_count = torch.zeros((batch_size, num_heads, num_rows), dtype=torch.int32)
    col_index = torch.zeros((batch_size, num_heads, num_rows, max_cols_per_row), dtype=torch.int32)
    for b in range(batch_size):
        seqlen = int(seqlens[b])
        for h in range(num_heads):
            # Pull the index rows out once instead of one .item() per access.
            verticals = vertical_indexes[b, h].tolist()
            slashes = slash_indexes[b, h].tolist()
            for m, start_m in enumerate(range(0, seqlen, block_size_M)):
                end_m = start_m + block_size_M
                tmp_col_count = 0
                cursor, s, v = -1, 0, 0
                # With no vertical entries, end_m acts as an "exhausted" sentinel
                # that keeps both vertical loops from running.
                v_idx = verticals[0] if NNZ_V > 0 else end_m
                # Skip slash entries that are too large for this row block.
                while s < NNZ_S and slashes[s] >= end_m:
                    s += 1
                if s < NNZ_S:
                    s_idx = end_m - slashes[s]
                    s_range = min(s_idx, block_size_M)
                else:
                    s_idx = seqlen
                    s_range = 0
                # Merge phase: emit whichever of the next vertical column or the
                # next slash range comes first.
                while s_idx <= end_m and v_idx < end_m:
                    if v_idx < s_idx:
                        # Verticals inside the slash range are emitted by the range.
                        if v_idx < s_idx - s_range:
                            col_index[b, h, m, tmp_col_count] = v_idx
                            tmp_col_count += 1
                        v += 1
                        if v < NNZ_V:
                            v_idx = verticals[v]
                        else:
                            break
                    else:
                        # `cursor` stops overlapping slash ranges from being emitted twice.
                        for idx in range(max(cursor, s_idx - s_range), min(s_idx, seqlen)):
                            col_index[b, h, m, tmp_col_count] = idx
                            tmp_col_count += 1
                        cursor = s_idx
                        s += 1
                        if s < NNZ_S:
                            s_idx = end_m - slashes[s]
                            s_range = min(s_idx, block_size_M)
                        else:
                            break
                # Drain the remaining slash ranges.
                while s_idx <= end_m and s < NNZ_S:
                    for idx in range(max(cursor, s_idx - s_range), min(s_idx, seqlen)):
                        col_index[b, h, m, tmp_col_count] = idx
                        tmp_col_count += 1
                    cursor = s_idx
                    s += 1
                    if s < NNZ_S:
                        s_idx = end_m - slashes[s]
                        s_range = min(s_idx, block_size_M)
                    else:
                        break
                # Drain the remaining vertical columns.
                while v_idx < end_m and v < NNZ_V:
                    if v_idx < s_idx - s_range:
                        col_index[b, h, m, tmp_col_count] = v_idx
                        tmp_col_count += 1
                    v += 1
                    if v < NNZ_V:
                        v_idx = verticals[v]
                    else:
                        break
                col_count[b, h, m] = tmp_col_count
    return col_count.to(seqlens.device), col_index.to(seqlens.device)
|
234 |
+
|
235 |
+
|
236 |
+
|
237 |
+
# CUDA source of the index-building kernel, compiled below with PyCUDA.
#
# Launch layout (see the kernel body): blockIdx.x selects the head, blockIdx.y
# the batch entry, and blockIdx.z * blockDim.x + threadIdx.x selects one block
# of BLOCK_SIZE_M query rows.  Each thread therefore fills exactly one row of
# `col_count` / `col_index`.
#
# Per row block the thread merges two inputs:
#   * `vertical_indexes` -- individual column ids, kept only when they lie
#     below the current slash range (`v_idx < s_idx - s_range`);
#   * `slash_indexes`    -- diagonal offsets, converted to `end_m - offset` and
#     expanded by `save_list` into runs of consecutive column ids.
# The number of ids written is stored in `col_count[0]` for that row.
#
# `save_list` writes runs longer than four entries with 16-byte `int4` stores
# once the output offset is 4-aligned.
#
# NOTE(review): `vertical_indexes[0]` is read unconditionally, so NNZ_V == 0
# looks unsupported -- confirm with callers.
# NOTE(review): `row_offset * N_CTX` is computed in 32-bit `int`; presumably
# this can overflow for very large batch * heads * rows * N_CTX -- verify.
PYCUDA_BUILD_INDEX_KERNEL_CODE = '''\
__device__ int min(int x, int y) {
    return x < y ? x : y;
}

__device__ int max(int x, int y) {
    return x > y ? x : y;
}

__device__ void save_list(int* output, int loop_start, int loop_end, int& offset) {
    if (loop_start + 4 >= loop_end) {
        for (int idx = loop_start; idx < loop_end; idx++, offset++) {
            output[offset] = idx;
        }
        return;
    }
    int4 tmp_int4;
    int int4_start = ((offset + 3) & (-4)) - offset + loop_start;
    int int4_end = ((offset + loop_end - loop_start) & (-4)) - offset + loop_start;
    for (int idx = loop_start; idx < int4_start; idx++, offset++) {
        output[offset] = idx;
    }
    for (int idx = int4_start; idx < int4_end; idx += 4, offset += 4) {
        tmp_int4.x = idx + 0;
        tmp_int4.y = idx + 1;
        tmp_int4.z = idx + 2;
        tmp_int4.w = idx + 3;
        (reinterpret_cast<int4*>(&output[offset]))[0] = tmp_int4;
    }
    for (int idx = int4_end; idx < loop_end; idx++, offset++) {
        output[offset] = idx;
    }
}

__global__ void PYCUDA_BUILD_INDEX_KERNEL(
    const int* seqlens, // [BATCH, ]
    const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
    int* col_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* col_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), N_CTX]
    int N_HEADS,
    int N_CTX,
    int BLOCK_SIZE_M,
    int N_ROWS,
    int NNZ_V,
    int NNZ_S
) {
    const int batch_idx = blockIdx.y;
    const int head_idx = blockIdx.x;
    const int group_idx = blockIdx.z;

    int seqlen = seqlens[batch_idx];
    int block_idx_m = group_idx * blockDim.x + threadIdx.x;
    int start_m = block_idx_m * BLOCK_SIZE_M;
    if (start_m >= seqlen) {
        return;
    }
    int end_m = start_m + BLOCK_SIZE_M;
    vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
    slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
    int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
    col_count += row_offset;
    col_index += row_offset * N_CTX;

    int tmp_col_count = 0, cursor = -1, s = 0, v = 0;
    int v_idx = vertical_indexes[v];
    /*
    int left = 0, right = NNZ_S - 1;
    int tmp_s_idx = 0, target = end_m - 1;
    s = (left + right) >> 1;
    while (left + 1 < right) {
        tmp_s_idx = slash_indexes[s];
        if (tmp_s_idx > target) {
            left = s;
        } else if (tmp_s_idx < target) {
            right = s;
        } else {
            break;
        }
        s = (left + right) >> 1;
    }
    */
    while (s < NNZ_S && slash_indexes[s] >= end_m) s++;

    int s_idx = (s < NNZ_S) ? (end_m - slash_indexes[s]) : seqlen;
    int s_range = (s < NNZ_S) ? min(s_idx, BLOCK_SIZE_M) : 0;

    while (s_idx <= end_m && v_idx < end_m) {
        if (v_idx < s_idx) {
            if (v_idx < s_idx - s_range) {
                col_index[tmp_col_count] = v_idx;
                tmp_col_count++;
            }
            v++;
            if (v < NNZ_V) {
                v_idx = vertical_indexes[v];
            } else {
                break;
            }
        } else {
            save_list(col_index, max(cursor, s_idx - s_range), min(s_idx, seqlen), tmp_col_count);
            cursor = s_idx;
            s++;
            if (s < NNZ_S) {
                s_idx = end_m - slash_indexes[s];
                s_range = min(s_idx, BLOCK_SIZE_M);
            } else {
                break;
            }
        }
    }
    while (s_idx <= end_m && s < NNZ_S) {
        save_list(col_index, max(cursor, s_idx - s_range), min(s_idx, seqlen), tmp_col_count);
        cursor = s_idx;
        s++;
        if (s < NNZ_S) {
            s_idx = end_m - slash_indexes[s];
            s_range = min(s_idx, BLOCK_SIZE_M);
        } else {
            break;
        }
    }
    while (v_idx < end_m && v < NNZ_V) {
        if (v_idx < s_idx - s_range) {
            col_index[tmp_col_count] = v_idx;
            tmp_col_count++;
        }
        v++;
        if (v < NNZ_V) {
            v_idx = vertical_indexes[v];
        } else {
            break;
        }
    }
    col_count[0] = tmp_col_count;
}
'''
# Compile the CUDA source at import time and keep a handle to the kernel entry point.
PYCUDA_BUILD_INDEX_KERNEL = SourceModule(
    PYCUDA_BUILD_INDEX_KERNEL_CODE,
    options=['-std=c++14', '-O3'],
).get_function(f'PYCUDA_BUILD_INDEX_KERNEL')
|
378 |
+
|
379 |
+
|
380 |
+
def pycuda_build_index(seqlens, vertical_indexes, slash_indexes, block_size_M=64):
    """Build the per-row-block column index on the GPU via the PyCUDA kernel.

    Args:
        seqlens: tensor of sequence lengths, one entry per batch element.
        vertical_indexes: tensor shaped [batch, heads, NNZ_V].
        slash_indexes: tensor shaped [batch, heads, NNZ_S].
        block_size_M: number of query rows handled per row block.

    Returns:
        ``(col_count, col_index)`` as int32 tensors on ``seqlens.device``;
        ``col_count`` is [batch, heads, rows] and ``col_index`` is
        [batch, heads, rows, width].
    """
    # The row width is the longest sequence rounded up to a multiple of four,
    # which keeps every row start aligned for the kernel's int4 stores.
    width = (seqlens.max().item() + 3) & (-4)
    n_batch, n_heads, n_slash = slash_indexes.shape
    n_vertical = vertical_indexes.shape[-1]
    n_rows = triton.cdiv(width, block_size_M)

    target = seqlens.device
    col_count = torch.zeros((n_batch, n_heads, n_rows), dtype=torch.int32, device=target)
    col_index = torch.zeros((n_batch, n_heads, n_rows, width), dtype=torch.int32, device=target)

    threads_per_block = 64
    # Grid: x -> head, y -> batch, z -> group of `threads_per_block` row blocks.
    launch_grid = (n_heads, n_batch, triton.cdiv(n_rows, threads_per_block))
    PYCUDA_BUILD_INDEX_KERNEL(
        seqlens,
        vertical_indexes,
        slash_indexes,
        col_count,
        col_index,
        np.int32(n_heads),
        np.int32(width),
        np.int32(block_size_M),
        np.int32(n_rows),
        np.int32(n_vertical),
        np.int32(n_slash),
        grid=launch_grid,
        block=(threads_per_block, 1, 1),
    )
    return col_count, col_index
|
399 |
+
|
400 |
+
|
401 |
+
def make_causal_mask(seqlens, device, context_size):
    """Return a boolean causal mask of shape [batch, 1, context_size, context_size].

    Entry ``[b, 0, i, j]`` is True when ``j <= i`` and both ``i`` and ``j``
    are below ``seqlens[b]``; rows and columns past the sequence length are
    entirely False.
    """
    positions = torch.arange(context_size, dtype=torch.int32, device=device)
    lower_triangle = positions[:, None] >= positions[None, :]
    # in_bounds[b, p] tells whether position p belongs to sequence b.
    in_bounds = positions[None, :] < seqlens.to(device)[:, None]
    row_ok = in_bounds[:, None, :, None]
    col_ok = in_bounds[:, None, None, :]
    return lower_triangle[None, None, :, :] & row_ok & col_ok
|
410 |
+
|
411 |
+
|
412 |
+
def make_finegrained_mask(vertical_indexes, slash_indexes, causal_mask, device):
    """Materialise the exact vertical + slash sparsity pattern as a dense mask.

    For every (batch, head) each listed vertical index turns on a whole
    column, and each slash index ``d`` turns on the diagonal ``row - col == d``.
    The result is finally intersected with ``causal_mask``.

    Returns:
        Bool tensor shaped [batch, heads, context, context].
    """
    n_batch, n_heads, _ = vertical_indexes.shape
    n_ctx = causal_mask.shape[-1]
    positions = torch.arange(n_ctx, dtype=torch.int32, device=device)
    # row - col for every cell; compared against each slash offset below.
    offsets = positions[:, None] - positions[None, :]
    sparse_mask = torch.zeros((n_batch, n_heads, n_ctx, n_ctx), dtype=torch.bool, device=device)
    for b in range(n_batch):
        for h in range(n_heads):
            head_mask = sparse_mask[b, h]  # view: writes land in sparse_mask
            columns = vertical_indexes[b, h].to(device=head_mask.device, dtype=torch.long)
            head_mask[:, columns] = True
            for slash in slash_indexes[b, h]:
                head_mask.logical_or_(offsets == slash)
    sparse_mask.logical_and_(causal_mask)
    return sparse_mask
|
425 |
+
|
426 |
+
|
427 |
+
def make_block_mask(col_count, col_index, seqlens, causal_mask, device, block_size_M=64):
    """Expand a (col_count, col_index) block index into a dense boolean mask.

    For row block ``m`` of each (batch, head), the first ``col_count[b, h, m]``
    entries of ``col_index[b, h, m]`` are the columns switched on for all
    ``block_size_M`` rows of that block.  The result is intersected with
    ``causal_mask`` before being returned.

    Returns:
        Bool tensor shaped [batch, heads, context, context].
    """
    n_batch, n_heads, _ = col_count.shape
    n_ctx = causal_mask.shape[-1]
    block_mask = torch.zeros((n_batch, n_heads, n_ctx, n_ctx), dtype=torch.bool, device=device)
    for b in range(n_batch):
        row_starts = range(0, int(seqlens[b]), block_size_M)
        for h in range(n_heads):
            for m, first_row in enumerate(row_starts):
                n_active = int(col_count[b, h, m])
                active = col_index[b, h, m, :n_active]
                rows = block_mask[b, h, first_row:first_row + block_size_M]  # view
                rows[:, active.to(device=block_mask.device, dtype=torch.long)] = True
    block_mask.logical_and_(causal_mask)
    return block_mask
|
439 |
+
|
440 |
+
|
441 |
+
def plot_mask(mask, name, batch=0, head=0):
    """Save a heatmap of one (batch, head) slice of ``mask`` to the file ``name``.

    Args:
        mask: tensor indexable as ``mask[batch, head]`` yielding a 2-D slice.
        name: output path handed to matplotlib's ``savefig``.
        batch: batch index of the slice to draw.
        head: head index of the slice to draw.
    """
    # Plotting libraries are imported lazily so they stay optional for the
    # rest of the module.
    import matplotlib.pyplot as plt
    import seaborn as sns

    fig = plt.figure(figsize=(16, 12))
    try:
        # seaborn draws on the current axes, i.e. on the figure just created.
        sns.heatmap(mask[batch, head].cpu().numpy())
        fig.savefig(name)
    finally:
        # Release the figure; otherwise pyplot keeps every figure alive and
        # repeated calls accumulate memory.
        plt.close(fig)
|
449 |
+
|
450 |
+
|
451 |
+
@triton.jit
def triton_dense_fwd_kernel(
    Q, K, V, seqlens, sm_scale,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vn, stride_vk,
    stride_oz, stride_oh, stride_om, stride_ok,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    dtype: tl.constexpr,
):
    # Dense causal attention forward pass using a running (online) softmax.
    #
    # Program axis 0 picks a block of BLOCK_M query rows; axis 1 is a flat
    # (batch, head) id that is split with `// H` and `% H`.  `seqlens` holds
    # one length per batch entry; programs whose row block starts at or past
    # that length exit immediately.
    #
    # NOTE(review): Q/K/V are loaded without boundary checks, so this
    # presumably expects N_CTX to be a multiple of the block sizes -- confirm.
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)

    seqlen = tl.load(seqlens + off_hz // H)
    if start_m * BLOCK_M >= seqlen:
        return

    # Base offsets of this (batch, head) slice inside Q/Out and K/V.
    qo_offset = (off_hz // H) * stride_qz + (off_hz % H) * stride_qh
    kv_offset = (off_hz // H) * stride_kz + (off_hz % H) * stride_kh
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qo_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    # K is addressed transposed ([D, N]) so that q @ k needs no explicit transpose.
    K_block_ptr = tl.make_block_ptr(
        base=K + kv_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + kv_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_vn, stride_vk),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    # m_i: running row maximum, l_i: running softmax denominator,
    # acc: running un-normalised output.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
    # load q: it will stay in SRAM throughout
    q = tl.load(Q_block_ptr)
    q = (q * qk_scale).to(dtype)
    # loop over k, v and update accumulator
    # Causality: only key blocks up to the end of this row block are visited.
    lo = 0
    hi = (start_m + 1) * BLOCK_M
    # Rows at or beyond the sequence length are masked out everywhere below.
    m_mask = offs_m[:, None] < seqlen

    for start_n in range(lo, hi, BLOCK_N):
        # Per-element causal mask inside the current key block.
        n_mask = (start_n + offs_n[None, :]) <= offs_m[:, None]
        # -- load k, v --
        k = tl.load(K_block_ptr)
        v = tl.load(V_block_ptr)
        # -- compute qk --
        # Start from 0 / -inf so that adding the dot product keeps masked cells at -inf.
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.where(m_mask & n_mask, qk, float("-inf"))
        qk += tl.dot(q, k)
        # -- compute scaling constant --
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])
        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(dtype), v)
        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
        # update pointers
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
    # write back O
    # Normalise by the softmax denominator; rows past seqlen are forced to 0.
    acc = tl.where(m_mask, acc / l_i[:, None], 0.0)
    O_block_ptr = tl.make_block_ptr(
        base=Out + qo_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_ok),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    # NOTE(review): `mask=` is combined with a block pointer here; some Triton
    # releases reject that combination -- verify against the pinned version.
    tl.store(O_block_ptr, acc.to(dtype), mask=m_mask)
|
550 |
+
|
551 |
+
|
552 |
+
def triton_dense_forward(q, k, v, seqlens, sm_scale, block_size_M=128, block_size_N=64) -> torch.Tensor:
    """Launch the dense causal attention Triton kernel.

    Args:
        q, k, v: 4-D tensors indexed as [batch, head, position, head_dim];
            all three must share the same head_dim, one of 16/32/64/128.
        seqlens: per-batch sequence lengths read by the kernel.
        sm_scale: softmax scaling factor applied to q @ k^T.
        block_size_M: query rows per program.
        block_size_N: key columns per inner-loop step.

    Returns:
        Output tensor with the shape and dtype of ``q``; rows the kernel does
        not write remain zero.
    """
    # shape constraints
    head_dim = k.shape[-1]
    assert q.shape[-1] == head_dim and head_dim == v.shape[-1]
    assert head_dim in {16, 32, 64, 128}

    out = torch.zeros_like(q)
    n_batch, n_heads, n_ctx = q.shape[0], q.shape[1], q.shape[2]
    # One program per (row block, batch * head).
    launch_grid = (triton.cdiv(n_ctx, block_size_M), n_batch * n_heads, 1)
    warps = 8 if head_dim > 64 else 4
    # Anything that is not bfloat16 is computed in float16 inside the kernel.
    compute_dtype = tl.bfloat16 if q.dtype == torch.bfloat16 else tl.float16
    triton_dense_fwd_kernel[launch_grid](
        q, k, v, seqlens, sm_scale,
        out,
        *q.stride(),
        *k.stride(),
        *v.stride(),
        *out.stride(),
        n_batch, n_heads, n_ctx,
        BLOCK_M=block_size_M,
        BLOCK_N=block_size_N,
        BLOCK_DMODEL=head_dim,
        dtype=compute_dtype,
        num_warps=warps,
        num_stages=4,
    )
    return out
|
576 |
+
|
577 |
+
|
578 |
+
def flash_attn_forward(q, k, v, seqlens, sm_scale, context_size) -> torch.Tensor:
    """Run causal variable-length FlashAttention with identical q/k layouts.

    ``seqlens`` is forwarded as both ``cu_seqlens_q`` and ``cu_seqlens_k``
    and ``context_size`` as both maximum sequence lengths; dropout is
    disabled.
    """
    varlen_options = dict(
        cu_seqlens_q=seqlens,
        cu_seqlens_k=seqlens,
        max_seqlen_q=context_size,
        max_seqlen_k=context_size,
        dropout_p=0.0,
        softmax_scale=sm_scale,
        causal=True,
    )
    return flash_attn_varlen_func(q, k, v, **varlen_options)
|
591 |
+
|
592 |
+
|
593 |
+
def torch_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor,
    sm_scale: float,
) -> torch.Tensor:
    """Reference masked attention written with plain PyTorch ops.

    Scores are ``query @ key^T * sm_scale``; cells where ``mask`` is False
    are set to -inf before a hand-written softmax.  The subtracted row
    maximum is floored at zero and ``1e-6`` is added to the denominator, so
    a fully masked row produces zeros instead of NaN.
    """
    scores = torch.einsum('bhmk, bhnk -> bhmn', query, key) * sm_scale
    scores = scores.masked_fill(torch.logical_not(mask), -torch.inf)
    row_peak = scores.max(-1, keepdim=True).values
    # Never shift by a negative maximum (this also covers the -inf of a fully
    # masked row).
    row_peak = row_peak.masked_fill(row_peak < 0, 0.0)
    weights = torch.exp(scores - row_peak)
    denominator = weights.sum(-1, keepdim=True) + 1e-6
    weights = weights / denominator
    return torch.einsum('bhmn, bhnk -> bhmk', weights, value)
|
608 |
+
|
609 |
+
|
610 |
+
def profile(fn, total_flops, tag, warmup=25, rep=100):
    """Benchmark ``fn`` and print its latency and throughput.

    Args:
        fn: zero-argument callable to time.
        total_flops: floating-point operations performed by one call.
        tag: label printed in front of the measurement.
        warmup: forwarded to ``triton.testing.do_bench``.
        rep: forwarded to ``triton.testing.do_bench``.
    """
    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    # `ms` is in milliseconds: convert to seconds before scaling to GFLOP,
    # otherwise the printed figure is 1000x too small for the GFLOP/s label.
    gflops = total_flops / (ms * 1e-3) * 1e-9
    print(f'{tag}: {ms:.3f} ms | {gflops:.3f} GFLOP/s')
|
614 |
+
|
615 |
+
|
616 |
+
def test_flash_attention(
    seqlens=None,
    vertical_indexes=None,
    slash_indexes=None,
    dtype=torch.float16,
    device="cuda",
    torch_test=True,
    batch_size=4,
    num_heads=32,
    context_size=1024,
    head_dim=128,
    sparsity=0.995,
    block_size_M=64,
    block_size_N=64,
):
    # End-to-end check and benchmark of the index builder and the attention
    # kernels on random inputs.
    #
    # When `seqlens`, `vertical_indexes` or `slash_indexes` are not supplied
    # they are drawn at random.  With `torch_test=True` every kernel output is
    # compared against the dense PyTorch reference (`torch_forward`); the
    # timings are printed either way.
    print('========================================')
    print(f'BATCH={batch_size}, N_CTX={context_size}, N_HEADS={num_heads}, D_HEAD={head_dim}')
    q = torch.randn((batch_size, num_heads, context_size, head_dim), dtype=dtype, device=device)
    k = torch.randn((batch_size, num_heads, context_size, head_dim), dtype=dtype, device=device)
    v = torch.randn((batch_size, num_heads, context_size, head_dim), dtype=dtype, device=device)
    if seqlens is None:
        seqlens = torch.randint(context_size // 2, context_size, (batch_size, ), dtype=torch.int32, device=device)
    else:
        seqlens = torch.tensor(seqlens, dtype=torch.int32, device=device)
    # Approximate number of attended (row, col) pairs under the causal mask:
    # seqlen^2 / 2 per head, summed over the batch.
    dense_mask_nnz = seqlens.to(torch.float32).square().sum().item() * num_heads / 2
    sm_scale = head_dim ** -0.5

    causal_mask = make_causal_mask(seqlens, device, context_size)
    if torch_test:
        ref_o_dense = torch_forward(q, k, v, causal_mask, sm_scale)

    if vertical_indexes is None or slash_indexes is None:
        # Number of vertical lines / slashes kept per head.
        nnz = int((1 - sparsity) * context_size)
        # Random distinct columns per (batch, head), sorted ascending.
        vertical_indexes = torch.stack([
            torch.stack([
                torch.randperm(seqlen, dtype=torch.int32, device=device)[:nnz].sort(descending=False)[0]
                for _ in range(num_heads)
            ])
            for seqlen in seqlens
        ])
        # Random distinct offsets in [1, seqlen - 1], sorted descending, with
        # offset 0 (the main diagonal) appended for every head.
        slash_indexes = torch.concatenate([
            torch.stack([
                torch.stack([
                    torch.randperm(seqlen - 1, dtype=torch.int32, device=device)[:nnz].sort(descending=True)[0] + 1
                    for _ in range(num_heads)
                ])
                for seqlen in seqlens
            ]),
            torch.zeros((batch_size, num_heads, 1), dtype=torch.int32, device=device)
        ], dim=-1)
    col_count, col_index = pycuda_build_index(seqlens, vertical_indexes, slash_indexes, block_size_M)
    if torch_test:
        # The CUDA index builder must agree with the PyTorch reference builder.
        col_count_ref, col_index_ref = torch_build_index(seqlens, vertical_indexes, slash_indexes, block_size_M)
        # import ipdb; ipdb.set_trace()
        torch.testing.assert_close(col_count_ref, col_count)
        torch.testing.assert_close(col_index_ref, col_index)
    # Each selected column is attended by a whole block of block_size_M rows.
    sparse_mask_nnz = col_count.to(torch.float32).sum().item() * block_size_M
    print(f'block mask sparsity: {1 - sparse_mask_nnz / dense_mask_nnz}')
    pycuda_build_index_fn = lambda: pycuda_build_index(seqlens, vertical_indexes, slash_indexes, block_size_M)
    # Index building performs no floating-point work, hence 0 FLOPs.
    profile(pycuda_build_index_fn, 0., 'pycuda-index')

    if torch_test:
        finegrained_mask = make_finegrained_mask(vertical_indexes, slash_indexes, causal_mask, device)
        block_mask = make_block_mask(col_count, col_index, seqlens, causal_mask, device, block_size_M)
        # plot_mask(finegrained_mask, 'mask.png', 2, 26)
        # plot_mask(block_mask, 'mask-1.png', 2, 26)
        ref_o_sparse = torch_forward(q, k, v, block_mask, sm_scale)

    triton_dense_fn = lambda: triton_dense_forward(q, k, v, seqlens, sm_scale)
    output = triton_dense_fn()
    if torch_test:
        torch.testing.assert_close(output, ref_o_dense, atol=1e-2, rtol=0)
    profile(triton_dense_fn, 2. * head_dim * dense_mask_nnz, 'triton-dense')

    triton_sparse_fn = lambda: triton_sparse_forward(q, k, v, seqlens, col_count, col_index, sm_scale, block_size_M, block_size_N)
    output = triton_sparse_fn()
    if torch_test:
        torch.testing.assert_close(output, ref_o_sparse, atol=1e-2, rtol=0)
    profile(triton_sparse_fn, 2. * head_dim * sparse_mask_nnz, 'triton-sparse')

    # Re-pack q/k/v for the variable-length FlashAttention call: move the
    # position axis in front of the head axis, drop the padding of every
    # sequence and concatenate all sequences along the first axis.
    q = q.swapaxes(1, 2).contiguous()
    k = k.swapaxes(1, 2).contiguous()
    v = v.swapaxes(1, 2).contiguous()
    q = torch.concatenate([q[i, :seqlen, :, :] for i, seqlen in enumerate(seqlens)])
    k = torch.concatenate([k[i, :seqlen, :, :] for i, seqlen in enumerate(seqlens)])
    v = torch.concatenate([v[i, :seqlen, :, :] for i, seqlen in enumerate(seqlens)])
    # From here on `seqlens` holds cumulative offsets with a leading 0, not lengths.
    seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

    flash_fn = lambda: flash_attn_forward(q, k, v, seqlens, sm_scale, context_size)
    output = flash_fn()
    # Undo the packing: slice each sequence back out, zero-pad it to
    # context_size and restore the [batch, head, position, dim] layout.
    output = torch.stack([
        torch.nn.functional.pad(
            output[seqlens[i]:seqlens[i + 1], :, :],
            (0, 0, 0, 0, 0, context_size + seqlens[i] - seqlens[i + 1])
        )
        for i in range(batch_size)
    ]).swapaxes(1, 2).contiguous()
    if torch_test:
        torch.testing.assert_close(output, ref_o_dense, atol=1e-2, rtol=0)
    profile(flash_fn, 2. * head_dim * dense_mask_nnz, 'flash-dense')
    print('========================================\n')
|
717 |
+
|
718 |
+
|
719 |
+
def pit_sparse_flash_attention_forward(
    query: torch.Tensor,  # [BATCH, N_HEADS, N_CTX, D_HEAD]
    key: torch.Tensor,    # [BATCH, N_HEADS, N_CTX, D_HEAD]
    value: torch.Tensor,  # [BATCH, N_HEADS, N_CTX, D_HEAD]
    v_idx: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
    s_idx: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
    block_size_M: int = 64,
    block_size_N: int = 64,
):
    """Sparse attention over the given vertical and slash indexes.

    The inputs are zero-padded along the sequence axis, the column index is
    built on the GPU, the sparse Triton kernel is run, and the padding is
    sliced off again.

    Args:
        query, key, value: attention inputs, [batch, heads, ctx, head_dim].
        v_idx: vertical (column) indexes per head; sorted ascending here.
        s_idx: slash (diagonal offset) indexes per head; sorted descending here.
        block_size_M: query rows per block.  The padding arithmetic uses a
            bit mask, so it equals a modulo only for powers of two.
        block_size_N: key columns per kernel step.

    Returns:
        Attention output restricted to the original ``ctx`` rows.
    """
    q_len = query.shape[2]
    # Always pads by 1..block_size_M rows (a full extra block when q_len is
    # already aligned); kept as is, the output is sliced back to q_len below.
    pad = block_size_M - (q_len & (block_size_M - 1))
    query = torch.nn.functional.pad(query, [0, 0, 0, pad, 0, 0, 0, 0])
    key = torch.nn.functional.pad(key, [0, 0, 0, pad, 0, 0, 0, 0])
    value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0])
    batch_size, num_heads, context_size, head_dim = query.shape
    v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0]
    s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0]
    # One length per batch entry: the index kernel reads seqlens[batch_idx],
    # so a single-element tensor would be read out of bounds for batch > 1.
    seqlens = torch.full((batch_size,), context_size, dtype=torch.int32, device=query.device)
    sm_scale = head_dim ** -0.5
    col_count, col_index = pycuda_build_index(seqlens, v_idx, s_idx, block_size_M)
    out = triton_sparse_forward(
        query, key, value, seqlens, col_count, col_index, sm_scale, block_size_M, block_size_N,
    )
    # Drop the rows that were added as padding.
    return out[..., :q_len, :]
|
minference/ops/pit_sparse_flash_attention_v2.py
ADDED
@@ -0,0 +1,735 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import math
|
4 |
+
|
5 |
+
import triton
|
6 |
+
import triton.language as tl
|
7 |
+
import pycuda.autoprimaryctx
|
8 |
+
from pycuda.compiler import SourceModule
|
9 |
+
|
10 |
+
from flash_attn import flash_attn_varlen_func
|
11 |
+
|
12 |
+
|
13 |
+
# @triton.autotune(
|
14 |
+
# configs=[
|
15 |
+
# triton.Config({}, num_stages=1, num_warps=4),
|
16 |
+
# triton.Config({}, num_stages=1, num_warps=8),
|
17 |
+
# triton.Config({}, num_stages=2, num_warps=4),
|
18 |
+
# triton.Config({}, num_stages=2, num_warps=8),
|
19 |
+
# triton.Config({}, num_stages=3, num_warps=4),
|
20 |
+
# triton.Config({}, num_stages=3, num_warps=8),
|
21 |
+
# triton.Config({}, num_stages=4, num_warps=4),
|
22 |
+
# triton.Config({}, num_stages=4, num_warps=8),
|
23 |
+
# triton.Config({}, num_stages=5, num_warps=4),
|
24 |
+
# triton.Config({}, num_stages=5, num_warps=8),
|
25 |
+
# ],
|
26 |
+
# key=['N_CTX'],
|
27 |
+
# )
|
28 |
+
@triton.jit
def triton_sparse_fwd_kernel(
    Q, K, V, seqlens, sm_scale,
    block_count, block_offset, column_count, column_index,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vn, stride_vk,
    stride_oz, stride_oh, stride_om, stride_ok,
    Z, H, N_CTX,
    NUM_ROWS, NNZ_S, NNZ_V,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    dtype: tl.constexpr,
):
    # Sparse attention forward pass using a running (online) softmax.
    #
    # Program axis 0 picks a block of BLOCK_M query rows; axis 1 is a flat
    # (batch, head) id split with `// H` and `% H`.  For its row block the
    # program attends to two kinds of keys:
    #   1. `block_count` contiguous key blocks whose start columns are listed
    #      in `block_offset` (row stride NNZ_S);
    #   2. `column_count` individual key columns listed in `column_index`
    #      (row stride NNZ_V), gathered BLOCK_N at a time.
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)

    seqlen = tl.load(seqlens + off_hz // H)
    # Row blocks that start at or past the sequence length have nothing to do.
    if start_m * BLOCK_M >= seqlen:
        return

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    # Base offsets of this (batch, head) slice inside Q/Out and K/V.
    qo_offset = (off_hz // H) * stride_qz + (off_hz % H) * stride_qh
    kv_offset = (off_hz // H) * stride_kz + (off_hz % H) * stride_kh

    # k_ptrs / v_ptrs only carry the head-dim offsets; the key position is
    # added per iteration because the columns are not contiguous.
    q_ptrs = Q + qo_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    k_ptrs = K + kv_offset + offs_d[:, None] * stride_kk
    v_ptrs = V + kv_offset + offs_d[None, :] * stride_vk
    o_ptrs = Out + qo_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_ok

    # Sparse index entries belonging to this (batch, head, row block).
    num_blks = tl.load(block_count + off_hz * NUM_ROWS + start_m)
    blks_ptr = block_offset + (off_hz * NUM_ROWS + start_m) * NNZ_S
    num_cols = tl.load(column_count + off_hz * NUM_ROWS + start_m)
    cols_ptr = column_index + (off_hz * NUM_ROWS + start_m) * NNZ_V

    # initialize pointer to m and l
    # m_i: running row maximum, l_i: running softmax denominator,
    # acc: running un-normalised output.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
    # load q: it will stay in SRAM throughout
    # NOTE(review): this load is unmasked, so Q presumably must be padded to a
    # multiple of BLOCK_M rows -- confirm with callers.
    q = tl.load(q_ptrs)
    q = (q * qk_scale).to(dtype)

    # loop over k, v and update accumulator
    m_mask = offs_m[:, None] < seqlen

    # Pass 1: contiguous key blocks (slash pattern).
    for block_index in range(num_blks):
        start_n = tl.load(blks_ptr + block_index)
        cols = start_n + offs_n
        n_mask = cols < seqlen
        # -- load k, v --
        k = tl.load(k_ptrs + cols[None, :] * stride_kn, mask=n_mask[None, :], other=0.0)
        v = tl.load(v_ptrs + cols[:, None] * stride_vn, mask=n_mask[:, None], other=0.0)
        # -- compute qk --
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        # Keys after the query position are excluded element-wise.
        causal_mask = cols[None, :] <= offs_m[:, None]
        qk = tl.where(m_mask & causal_mask, qk, float("-inf"))
        qk += tl.dot(q, k)
        # -- compute scaling constant --
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])
        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(dtype), v)
        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new

    # Pass 2: individually listed key columns (vertical pattern).
    # No causal mask is applied here; only the tail of the column list is
    # masked.  NOTE(review): presumably the index builder only lists columns
    # that are causally valid for the whole row block -- verify.
    for start_n in range(0, num_cols, BLOCK_N):
        n_mask = start_n + offs_n < num_cols
        cols = tl.load(cols_ptr + start_n + offs_n, mask=n_mask, other=0)
        # -- load k, v --
        k = tl.load(k_ptrs + cols[None, :] * stride_kn, mask=n_mask[None, :], other=0.0)
        v = tl.load(v_ptrs + cols[:, None] * stride_vn, mask=n_mask[:, None], other=0.0)
        # -- compute qk --
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.where(m_mask & n_mask, qk, float("-inf"))
        qk += tl.dot(q, k)
        # -- compute scaling constant --
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])
        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(dtype), v)
        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new

    # write back O
    # Rows at or past seqlen are computed but never stored (see the mask below).
    acc /= l_i[:, None]
    # acc = tl.where(m_mask, acc / l_i[:, None], 0.0)
    tl.store(o_ptrs, acc.to(dtype), mask=m_mask)
|
134 |
+
|
135 |
+
|
136 |
+
def triton_sparse_forward(
    q: torch.Tensor,             # [BATCH, N_HEADS, N_CTX, D_HEAD]
    k: torch.Tensor,             # [BATCH, N_HEADS, N_CTX, D_HEAD]
    v: torch.Tensor,             # [BATCH, N_HEADS, N_CTX, D_HEAD]
    seqlens: torch.Tensor,       # [BATCH, ]
    block_count: torch.Tensor,   # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    block_offset: torch.Tensor,  # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    column_count: torch.Tensor,  # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    column_index: torch.Tensor,  # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    sm_scale: float,
    block_size_M: int = 64,
    block_size_N: int = 64,
) -> torch.Tensor:
    """Launch the sparse attention Triton kernel.

    The kernel combines contiguous key blocks (``block_count`` /
    ``block_offset``) with individually listed key columns
    (``column_count`` / ``column_index``) for every block of
    ``block_size_M`` query rows.

    Returns:
        Output tensor with the shape and dtype of ``q``; rows the kernel does
        not write remain zero.
    """
    # shape constraints
    head_dim = k.shape[-1]
    assert q.shape[-1] == head_dim and head_dim == v.shape[-1]
    assert head_dim in {16, 32, 64, 128}

    out = torch.zeros_like(q)
    n_batch, n_heads, n_ctx = q.shape[0], q.shape[1], q.shape[2]
    # One program per (row block, batch * head).
    launch_grid = (triton.cdiv(n_ctx, block_size_M), n_batch * n_heads, 1)
    # Anything that is not bfloat16 is computed in float16 inside the kernel.
    compute_dtype = tl.bfloat16 if q.dtype == torch.bfloat16 else tl.float16

    # The trailing dimensions of the index tensors give the kernel its row strides.
    n_rows = block_count.shape[-1]
    slash_width = block_offset.shape[-1]
    vertical_width = column_index.shape[-1]

    triton_sparse_fwd_kernel[launch_grid](
        q, k, v, seqlens, sm_scale,
        block_count, block_offset, column_count, column_index,
        out,
        *q.stride(),
        *k.stride(),
        *v.stride(),
        *out.stride(),
        n_batch, n_heads, n_ctx,
        n_rows, slash_width, vertical_width,
        BLOCK_M=block_size_M,
        BLOCK_N=block_size_N,
        BLOCK_DMODEL=head_dim,
        dtype=compute_dtype,
        num_warps=4,
        num_stages=2,
    )
    return out
|
173 |
+
|
174 |
+
|
175 |
+
def torch_build_index(seqlens, vertical_indexes, slash_indexes, context_size, block_size_M=64, block_size_N=64):
    """Pure-Python construction of the block / column sparse index.

    The test driver in this file compares the result against
    ``pycuda_build_index``. For every (batch, head, query row block) it
    records:

    * ``block_count`` / ``block_offset``: start offsets of the key blocks
      (step ``block_size_N``) derived from the slash indexes;
    * ``column_count`` / ``column_index``: vertical indexes not already covered
      by one of those key blocks.

    The tensors are built on the default (CPU) device and moved to
    ``seqlens.device`` on return.

    NOTE(review): the initial skip loop appears to assume ``slash_indexes`` is
    sorted in descending order and contains at least one value below the end of
    every row block (otherwise ``s`` runs past ``NNZ_S``) -- confirm with callers.
    """
    device = seqlens.device
    batch_size, num_heads, NNZ_S = slash_indexes.shape
    NNZ_V = vertical_indexes.shape[-1]
    num_rows = triton.cdiv(context_size, block_size_M)
    block_count = torch.zeros((batch_size, num_heads, num_rows), dtype=torch.int32)
    block_offset = torch.zeros((batch_size, num_heads, num_rows, NNZ_S), dtype=torch.int32)
    column_count = torch.zeros((batch_size, num_heads, num_rows), dtype=torch.int32)
    column_index = torch.zeros((batch_size, num_heads, num_rows, NNZ_V), dtype=torch.int32)

    for b in range(batch_size):
        seqlen = seqlens[b]
        for h in range(num_heads):
            # Only row blocks that start before seqlen are filled; the rest stay zero.
            for m, start_m in enumerate(range(0, seqlen, block_size_M)):
                end_m = start_m + block_size_M
                s = 0
                # Skip slash indexes that are >= the end of this row block.
                while slash_indexes[b, h, s] >= end_m:
                    s += 1
                # Exclusive end of the key range for this slash, never below block_size_M.
                s_idx = max(end_m - slash_indexes[b, h, s], block_size_M)
                s += 1
                range_start = s_idx - block_size_M
                range_end = s_idx
                tmp_blocks = []
                while s < NNZ_S:
                    s_idx = max(end_m - slash_indexes[b, h, s], block_size_M)
                    if s_idx > range_end + block_size_M:
                        # Gap to the current range: flush it and start a new one.
                        tmp_blocks += list(range(range_start, range_end, block_size_N))
                        range_start = s_idx - block_size_M
                        range_end = s_idx
                    elif s_idx > range_end:
                        # Within block_size_M of the current range: extend it by one block.
                        range_end += block_size_M
                    s += 1
                # Flush the last open range.
                tmp_blocks += list(range(range_start, range_end, block_size_N))
                block_count[b, h, m] = len(tmp_blocks)
                block_offset[b, h, m, :len(tmp_blocks)] = torch.tensor(tmp_blocks, dtype=block_offset.dtype)
                tmp_columns = vertical_indexes[b, h].cpu().numpy().tolist()
                # Keep only columns before the end of the last range ...
                tmp_columns = [col for col in tmp_columns if col < range_end]
                # ... and drop those already covered by a selected key block.
                # (range_start / range_end are reused as loop variables here.)
                for range_start in tmp_blocks:
                    range_end = range_start + block_size_N
                    tmp_columns = [col for col in tmp_columns if col < range_start or col >= range_end]
                column_count[b, h, m] = len(tmp_columns)
                column_index[b, h, m, :len(tmp_columns)] = torch.tensor(tmp_columns, dtype=block_offset.dtype)

    return block_count.to(device), block_offset.to(device), column_count.to(device), column_index.to(device)
|
219 |
+
|
220 |
+
|
221 |
+
# CUDA source of the index-building kernel, compiled below with PyCUDA.
#
# Launch layout (see the kernel body): blockIdx.x is the head, blockIdx.y the
# batch entry and blockIdx.z a group of row blocks; each thread handles the row
# block `group_idx * blockDim.x + threadIdx.x` and returns early when that row
# block starts at or beyond `seqlens[batch_idx]`.
#
# Per row block the kernel walks the vertical and slash index lists together:
# slash indexes are turned into key ranges that are merged and written to
# `block_offset` in steps of BLOCK_SIZE_N (via `save_blocks`), while vertical
# indexes that fall before the current range are written to `column_index`.
# The resulting counts are stored in `block_count` / `column_count`.
#
# NOTE(review): the first `while (s_idx >= end_m)` loop has no bound on `s`; it
# appears to rely on the slash list containing a value below `end_m` -- confirm.
PYCUDA_BUILD_INDEX_KERNEL_CODE = '''\
__device__ int min(int x, int y) {
    return x < y ? x : y;
}

__device__ int max(int x, int y) {
    return x > y ? x : y;
}

__device__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) {
    for (int idx = range_start; idx < range_end; idx += block_size) {
        block_offset[block_count++] = idx;
    }
}

__global__ void PYCUDA_BUILD_INDEX_KERNEL(
    const int* seqlens,           // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int* block_count,             // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,            // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,            // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,            // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int N_HEADS,
    int N_ROWS,
    int BLOCK_SIZE_M,
    int BLOCK_SIZE_N,
    int NNZ_V,
    int NNZ_S
) {
    const int batch_idx = blockIdx.y;
    const int head_idx = blockIdx.x;
    const int group_idx = blockIdx.z;

    int seqlen = seqlens[batch_idx];
    int block_idx_m = group_idx * blockDim.x + threadIdx.x;
    int start_m = block_idx_m * BLOCK_SIZE_M;
    if (start_m >= seqlen) {
        return;
    }
    int end_m = start_m + BLOCK_SIZE_M;
    vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
    slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
    int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
    block_count += row_offset;
    block_offset += row_offset * NNZ_S;
    column_count += row_offset;
    column_index += row_offset * NNZ_V;

    int tmp_col_cnt = 0, tmp_blk_cnt = 0;
    int s = 0, v = 0;
    int v_idx = vertical_indexes[v++];
    int s_idx = slash_indexes[s++];
    while (s_idx >= end_m) {
        s_idx = slash_indexes[s++];
    }
    s_idx = max(end_m - s_idx, BLOCK_SIZE_M);
    int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
    while (1) {
        if (v_idx < range_end) {
            if (v_idx < range_start) {
                column_index[tmp_col_cnt++] = v_idx;
            }
            if (v < NNZ_V) {
                v_idx = vertical_indexes[v++];
            } else {
                v_idx = end_m + BLOCK_SIZE_M;
            }
        } else {
            if (s < NNZ_S) {
                s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M);
            } else {
                save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
                break;
            }
            if (s_idx > range_end + BLOCK_SIZE_M) {
                save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
                range_start = s_idx - BLOCK_SIZE_M;
                range_end = s_idx;
            } else if (s_idx > range_end) {
                range_end += BLOCK_SIZE_M;
            }
        }
    }

    block_count[0] = tmp_blk_cnt;
    column_count[0] = tmp_col_cnt;
}
'''
# Compiled at import time; the resulting callable is launched by pycuda_build_index.
PYCUDA_BUILD_INDEX_KERNEL = SourceModule(
    PYCUDA_BUILD_INDEX_KERNEL_CODE,
    options=['-std=c++14', '-O3'],
).get_function(f'PYCUDA_BUILD_INDEX_KERNEL')
|
314 |
+
|
315 |
+
|
316 |
+
def pycuda_build_index(seqlens, vertical_indexes, slash_indexes, context_size, block_size_M=64, block_size_N=64):
    """Build the block / column sparse index on the GPU.

    Allocates four zero-filled int32 tensors on ``seqlens.device`` and lets
    ``PYCUDA_BUILD_INDEX_KERNEL`` fill them.

    Returns:
        ``(block_count, block_offset, column_count, column_index)`` with shapes
        ``[B, H, R]``, ``[B, H, R, NNZ_S]``, ``[B, H, R]`` and
        ``[B, H, R, NNZ_V]``, where ``R = cdiv(context_size, block_size_M)``.
    """
    batch_size, num_heads, NNZ_S = slash_indexes.shape
    NNZ_V = vertical_indexes.shape[-1]
    num_rows = triton.cdiv(context_size, block_size_M)

    per_row = (batch_size, num_heads, num_rows)
    alloc = dict(dtype=torch.int32, device=seqlens.device)
    block_count = torch.zeros(per_row, **alloc)
    block_offset = torch.zeros(per_row + (NNZ_S,), **alloc)
    column_count = torch.zeros(per_row, **alloc)
    column_index = torch.zeros(per_row + (NNZ_V,), **alloc)

    # One thread per row block; grid is (heads, batch, groups of row blocks).
    threads_per_block = 64
    launch_grid = (num_heads, batch_size, triton.cdiv(num_rows, threads_per_block))
    PYCUDA_BUILD_INDEX_KERNEL(
        seqlens, vertical_indexes, slash_indexes,
        block_count, block_offset, column_count, column_index,
        np.int32(num_heads), np.int32(num_rows),
        np.int32(block_size_M), np.int32(block_size_N),
        np.int32(NNZ_V), np.int32(NNZ_S),
        grid=launch_grid,
        block=(threads_per_block, 1, 1),
    )
    return block_count, block_offset, column_count, column_index
|
337 |
+
|
338 |
+
|
339 |
+
def make_causal_mask(seqlens, device, context_size):
    """Return a boolean ``[BATCH, 1, context_size, context_size]`` causal mask.

    Entry ``[b, 0, i, j]`` is True when ``j <= i`` and both ``i`` and ``j`` are
    below ``seqlens[b]``.
    """
    positions = torch.arange(context_size, dtype=torch.int32, device=device)
    lower_triangle = positions[None, None, :, None] >= positions[None, None, None, :]
    num_sequences = seqlens.shape[0]
    mask = lower_triangle.repeat((num_sequences, 1, 1, 1))
    for batch_idx in range(num_sequences):
        valid_len = seqlens[batch_idx]
        # Blank out query rows and key columns past the sequence length.
        mask[batch_idx, :, valid_len:, :] = False
        mask[batch_idx, :, :, valid_len:] = False
    return mask
|
348 |
+
|
349 |
+
|
350 |
+
def make_finegrained_mask(vertical_indexes, slash_indexes, causal_mask, device):
    """Build the element-wise vertical + slash mask, restricted to ``causal_mask``.

    For every (batch, head), each vertical index marks a whole key column and
    each slash index marks the diagonal where ``query - key`` equals that
    index. The result is a boolean ``[BATCH, N_HEADS, N_CTX, N_CTX]`` tensor.
    """
    n_batch, n_heads, _ = vertical_indexes.shape
    n_ctx = causal_mask.shape[-1]
    positions = torch.arange(n_ctx, dtype=torch.int32, device=device)
    # query_minus_key[i, j] == i - j
    query_minus_key = positions[:, None] - positions[None, :]
    mask = torch.zeros((n_batch, n_heads, n_ctx, n_ctx), dtype=torch.bool, device=device)
    for b in range(n_batch):
        for h in range(n_heads):
            head_mask = mask[b, h]  # view: writes below land in `mask`
            for column in vertical_indexes[b, h]:
                head_mask[:, column] = True
            for diagonal in slash_indexes[b, h]:
                head_mask.logical_or_(query_minus_key == diagonal)
    mask.logical_and_(causal_mask)
    return mask
|
363 |
+
|
364 |
+
|
365 |
+
def make_block_mask(
    block_count: torch.Tensor,
    block_offset: torch.Tensor,
    column_count: torch.Tensor,
    column_index: torch.Tensor,
    seqlens: torch.Tensor,
    causal_mask: torch.Tensor,
    device: torch.device,
    block_size_M: int = 64,
    block_size_N: int = 64,
):
    """Expand a block / column sparse index into a dense boolean mask.

    For each (batch, head, query row block) the first ``column_count`` entries
    of ``column_index`` mark single key columns and the first ``block_count``
    entries of ``block_offset`` mark key blocks of width ``block_size_N``. The
    result is AND-ed with ``causal_mask``.

    The ``block_size_N`` default used to be the float literal ``64.``, which
    made the key-block slice bounds floats and raised ``TypeError`` whenever
    the default was used; it is now the integer ``64``.

    Returns:
        Boolean tensor of shape ``[BATCH, N_HEADS, N_CTX, N_CTX]`` on ``device``,
        where ``N_CTX = causal_mask.shape[-1]``.
    """
    batch_size, num_heads, _ = block_count.shape
    context_size = causal_mask.shape[-1]
    block_mask = torch.zeros((batch_size, num_heads, context_size, context_size), dtype=torch.bool, device=device)
    for b in range(batch_size):
        seqlen = int(seqlens[b])
        for h in range(num_heads):
            for m, start_m in enumerate(range(0, seqlen, block_size_M)):
                end_m = start_m + block_size_M
                num_cols = int(column_count[b, h, m])
                for col in column_index[b, h, m, :num_cols].tolist():
                    block_mask[b, h, start_m:end_m, col] = True
                num_blocks = int(block_count[b, h, m])
                for blk_start in block_offset[b, h, m, :num_blocks].tolist():
                    block_mask[b, h, start_m:end_m, blk_start:blk_start + block_size_N] = True
    block_mask.logical_and_(causal_mask)
    return block_mask
|
391 |
+
|
392 |
+
|
393 |
+
def plot_mask(mask, name, batch=0, head=0):
    """Save a heatmap of ``mask[batch, head]`` to the file ``name``.

    matplotlib and seaborn are imported lazily so they are only required when
    a plot is actually requested. The figure is closed after saving so that
    repeated calls do not accumulate open figures.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    fig = plt.figure(figsize=(16, 12))
    try:
        mask = mask[batch, head].cpu().numpy()
        sns.heatmap(mask)
        plt.savefig(name)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
|
401 |
+
|
402 |
+
|
403 |
+
@triton.jit
def triton_dense_fwd_kernel(
    Q, K, V, seqlens, sm_scale,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vn, stride_vk,
    stride_oz, stride_oh, stride_om, stride_ok,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    dtype: tl.constexpr,
):
    # Dense causal attention forward pass with an online softmax.
    # program_id(0) selects the query row block; program_id(1) is the flattened
    # (batch, head) index, split below with `// H` and `% H`.
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)

    # Row blocks that start at or past this batch entry's length do nothing.
    seqlen = tl.load(seqlens + off_hz // H)
    if start_m * BLOCK_M >= seqlen:
        return

    qo_offset = (off_hz // H) * stride_qz + (off_hz % H) * stride_qh
    kv_offset = (off_hz // H) * stride_kz + (off_hz % H) * stride_kh
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qo_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    # K is addressed as (BLOCK_DMODEL, N_CTX) so tl.dot(q, k) needs no transpose.
    K_block_ptr = tl.make_block_ptr(
        base=K + kv_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    # NOTE(review): V reuses kv_offset, which is computed from K's strides.
    V_block_ptr = tl.make_block_ptr(
        base=V + kv_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_vn, stride_vk),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
    # load q: it will stay in SRAM throughout
    q = tl.load(Q_block_ptr)
    q = (q * qk_scale).to(dtype)
    # loop over k, v and update accumulator
    lo = 0
    hi = (start_m + 1) * BLOCK_M
    # Query rows at or beyond seqlen are masked out.
    m_mask = offs_m[:, None] < seqlen

    for start_n in range(lo, hi, BLOCK_N):
        # Causal mask: key position <= query position.
        n_mask = (start_n + offs_n[None, :]) <= offs_m[:, None]
        # -- load k, v --
        k = tl.load(K_block_ptr)
        v = tl.load(V_block_ptr)
        # -- compute qk --
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk = tl.where(m_mask & n_mask, qk, float("-inf"))
        qk += tl.dot(q, k)
        # -- compute scaling constant --
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])
        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(dtype), v)
        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
        # update pointers
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
    # write back O
    # Rows outside seqlen are written as zeros instead of acc / l_i.
    acc = tl.where(m_mask, acc / l_i[:, None], 0.0)
    O_block_ptr = tl.make_block_ptr(
        base=Out + qo_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_ok),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    tl.store(O_block_ptr, acc.to(dtype))
|
502 |
+
|
503 |
+
|
504 |
+
def triton_dense_forward(q, k, v, seqlens, sm_scale, block_size_M=128, block_size_N=64) -> torch.Tensor:
    """Launch ``triton_dense_fwd_kernel`` and return the attention output.

    The output has the same shape and dtype as ``q`` and starts out zero-filled.
    The head dimensions of ``q``, ``k`` and ``v`` must match and be one of
    16, 32, 64 or 128.
    """
    # Head-dimension constraints.
    q_dim = q.shape[-1]
    k_dim = k.shape[-1]
    v_dim = v.shape[-1]
    assert q_dim == k_dim and k_dim == v_dim
    assert k_dim in {16, 32, 64, 128}

    out = torch.zeros_like(q)
    n_batch, n_heads, n_ctx = q.shape[0], q.shape[1], q.shape[2]
    launch_grid = (triton.cdiv(n_ctx, block_size_M), n_batch * n_heads, 1)

    # Head dimensions above 64 get 8 warps, everything else 4.
    if k_dim <= 64:
        warps = 4
    else:
        warps = 8

    # Anything that is not bfloat16 is handed to the kernel as float16.
    if q.dtype == torch.bfloat16:
        kernel_dtype = tl.bfloat16
    else:
        kernel_dtype = tl.float16

    # Strides in kernel-argument order: q0..q3, k0..k3, v0..v3, o0..o3.
    strides = [t.stride(d) for t in (q, k, v, out) for d in range(4)]

    triton_dense_fwd_kernel[launch_grid](
        q, k, v, seqlens, sm_scale,
        out,
        *strides,
        n_batch, n_heads, n_ctx,
        BLOCK_M=block_size_M, BLOCK_N=block_size_N,
        BLOCK_DMODEL=k_dim,
        dtype=kernel_dtype,
        num_warps=warps, num_stages=4,
    )

    return out
|
528 |
+
|
529 |
+
|
530 |
+
def flash_attn_forward(q, k, v, seqlens, sm_scale, context_size) -> torch.Tensor:
    """Run causal ``flash_attn_varlen_func`` on packed q / k / v.

    ``seqlens`` is used as the cumulative sequence lengths of both queries and
    keys, and ``context_size`` as the maximum sequence length of both.
    Dropout is disabled.
    """
    varlen_options = dict(
        cu_seqlens_q=seqlens,
        cu_seqlens_k=seqlens,
        max_seqlen_q=context_size,
        max_seqlen_k=context_size,
        dropout_p=0.0,
        softmax_scale=sm_scale,
        causal=True,
    )
    return flash_attn_varlen_func(q, k, v, **varlen_options)
|
543 |
+
|
544 |
+
|
545 |
+
def torch_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor,
    sm_scale: float,
) -> torch.Tensor:
    """Reference masked attention computed with plain PyTorch ops.

    Scores where ``mask`` is False are set to ``-inf``. The softmax subtracts
    the row maximum clamped to be non-negative and adds ``1e-6`` to the
    denominator, so a fully masked row yields zeros instead of NaN.
    """
    scores = torch.einsum('bhmk, bhnk -> bhmn', query, key) * sm_scale
    scores = torch.where(mask, scores, -torch.inf)
    row_max = scores.max(-1, keepdim=True).values
    # Negative maxima (including -inf for fully masked rows) are replaced by 0.
    row_max = row_max.clamp(min=0.0)
    weights = torch.exp(scores - row_max)
    normalizer = weights.sum(-1, keepdim=True) + 1e-6
    probs = weights / normalizer
    return torch.einsum('bhmn, bhnk -> bhmk', probs, value)
|
560 |
+
|
561 |
+
|
562 |
+
def profile(fn, total_flops, tag, warmup=25, rep=100):
    """Benchmark ``fn`` with ``triton.testing.do_bench`` and print its throughput.

    Args:
        fn: zero-argument callable to benchmark.
        total_flops: floating-point operations performed by one call of ``fn``.
        tag: label printed in front of the measurement.
        warmup: forwarded to ``do_bench``.
        rep: forwarded to ``do_bench``.

    The previous conversion (``total_flops / ms * 1e-9``) produced TFLOP/s while
    labelling it GFLOP/s; the value is now really in GFLOP/s.
    """
    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    # do_bench reports milliseconds: convert to seconds, then FLOP/s -> GFLOP/s.
    seconds = ms * 1e-3
    gflops = total_flops / seconds * 1e-9
    print(f'{tag}: {ms:.3f} ms | {gflops:.3f} GFLOP/s')
|
566 |
+
|
567 |
+
|
568 |
+
def test_flash_attention(
    query=None,
    key=None,
    value=None,
    seqlens=None,
    vertical_indexes=None,
    slash_indexes=None,
    dtype=torch.float16,
    device="cuda",
    torch_test=True,
    batch_size=4,
    num_heads=32,
    context_size=2048,
    head_dim=128,
    nnz_v=100,
    nnz_s=10,
    block_size_M=64,
    block_size_N=64,
):
    """Check and benchmark the index builder and the attention kernels.

    Random inputs are generated for anything not supplied. With
    ``torch_test=True`` the PyCUDA index is compared with ``torch_build_index``
    and every kernel output is compared with the ``torch_forward`` reference
    (``atol=1e-2``); mask heatmaps are written to ``mask.png`` and
    ``mask-1.png``. Timings are printed through ``profile``.
    """
    print('========================================')
    # q/k/v are random unless all three of query/key/value... note the condition:
    # random data is used only when *all* three are None.
    if query is None and key is None and value is None:
        q = torch.randn((batch_size, num_heads, context_size, head_dim), dtype=dtype, device=device)
        k = torch.randn((batch_size, num_heads, context_size, head_dim), dtype=dtype, device=device)
        v = torch.randn((batch_size, num_heads, context_size, head_dim), dtype=dtype, device=device)
    else:
        q = torch.tensor(query, dtype=dtype, device=device)
        k = torch.tensor(key, dtype=dtype, device=device)
        v = torch.tensor(value, dtype=dtype, device=device)
        # The supplied tensors override the size arguments.
        batch_size, num_heads, context_size, head_dim = q.shape
    print(f'BATCH={batch_size}, N_CTX={context_size}, N_HEADS={num_heads}, D_HEAD={head_dim}')
    if seqlens is None:
        seqlens = torch.randint(context_size // 2, context_size, (batch_size, ), dtype=torch.int32, device=device)
    else:
        seqlens = torch.tensor(seqlens, dtype=torch.int32, device=device)
    print(seqlens)
    # Approximate number of causal (query, key) pairs: sum(seqlen^2) / 2 per head.
    dense_mask_nnz = seqlens.to(torch.float32).square().sum().item() * num_heads / 2
    sm_scale = head_dim ** -0.5

    if torch_test:
        causal_mask = make_causal_mask(seqlens, device, context_size)
        ref_o_dense = torch_forward(q, k, v, causal_mask, sm_scale)

    if vertical_indexes is None or slash_indexes is None:
        # nnz_v random key columns per head, sorted ascending.
        vertical_indexes = torch.stack([
            torch.stack([
                torch.randperm(seqlen, dtype=torch.int32, device=device)[:nnz_v].sort(descending=False)[0]
                for _ in range(num_heads)
            ])
            for seqlen in seqlens
        ])
        # nnz_s - 1 random non-zero diagonals sorted descending, followed by diagonal 0.
        slash_indexes = torch.concatenate([
            torch.stack([
                torch.stack([
                    torch.randperm(seqlen - 1, dtype=torch.int32, device=device)[:nnz_s - 1].sort(descending=True)[0] + 1
                    for _ in range(num_heads)
                ])
                for seqlen in seqlens
            ]),
            torch.zeros((batch_size, num_heads, 1), dtype=torch.int32, device=device)
        ], dim=-1)
    pycuda_build_index_fn = lambda: pycuda_build_index(
        seqlens, vertical_indexes, slash_indexes, context_size, block_size_M, block_size_N
    )
    indexes = pycuda_build_index_fn()
    block_count, block_offset, column_count, column_index = indexes
    if torch_test:
        # The GPU index must match the Python reference exactly.
        block_count_ref, block_offset_ref, column_count_ref, column_index_ref = torch_build_index(
            seqlens, vertical_indexes, slash_indexes, context_size, block_size_M, block_size_N
        )
        torch.testing.assert_close(block_count_ref, block_count)
        torch.testing.assert_close(block_offset_ref, block_offset)
        torch.testing.assert_close(column_count_ref, column_count)
        torch.testing.assert_close(column_index_ref, column_index)
    # Each selected column covers block_size_M rows; each block covers M x N entries.
    sparse_mask_nnz = column_count.to(torch.float64).sum().item() * block_size_M + \
        block_count.to(torch.float64).sum().item() * block_size_M * block_size_N
    print(f'block mask sparsity: {1 - sparse_mask_nnz / dense_mask_nnz}')

    pycuda_build_index_fn = lambda: pycuda_build_index(
        seqlens, vertical_indexes, slash_indexes, context_size, block_size_M, block_size_N
    )
    profile(pycuda_build_index_fn, 0., 'pycuda-index')

    if torch_test:
        finegrained_mask = make_finegrained_mask(vertical_indexes, slash_indexes, causal_mask, device)
        block_mask = make_block_mask(*indexes, seqlens, causal_mask, device, block_size_M, block_size_N)
        plot_mask(finegrained_mask, 'mask.png', 0, 0)
        plot_mask(block_mask, 'mask-1.png', 0, 0)
        ref_o_sparse = torch_forward(q, k, v, block_mask, sm_scale)

    triton_dense_fn = lambda: triton_dense_forward(q, k, v, seqlens, sm_scale)
    output_triton_dense = triton_dense_fn()
    if torch_test:
        # Note: not correct for context_size % block_size_M != 0
        torch.testing.assert_close(output_triton_dense, ref_o_dense, atol=1e-2, rtol=0)
    profile(triton_dense_fn, 2. * head_dim * dense_mask_nnz, 'triton-dense')

    triton_sparse_fn = lambda: triton_sparse_forward(q, k, v, seqlens, *indexes, sm_scale, block_size_M, block_size_N)
    output_triton_sparse = triton_sparse_fn()
    if torch_test:
        torch.testing.assert_close(output_triton_sparse, ref_o_sparse, atol=1e-2, rtol=0)
    profile(triton_sparse_fn, 2. * head_dim * sparse_mask_nnz, 'triton-sparse')

    # Repack to the varlen layout: [total_tokens, N_HEADS, D_HEAD] with the
    # valid prefixes of all sequences concatenated.
    q = q.swapaxes(1, 2).contiguous()
    k = k.swapaxes(1, 2).contiguous()
    v = v.swapaxes(1, 2).contiguous()
    q = torch.concatenate([q[i, :seqlen, :, :] for i, seqlen in enumerate(seqlens)])
    k = torch.concatenate([k[i, :seqlen, :, :] for i, seqlen in enumerate(seqlens)])
    v = torch.concatenate([v[i, :seqlen, :, :] for i, seqlen in enumerate(seqlens)])
    # From here on `seqlens` holds cumulative lengths with a leading 0.
    seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

    flash_fn = lambda: flash_attn_forward(q, k, v, seqlens, sm_scale, context_size)
    output_flash = flash_fn()
    # Unpack back to [BATCH, N_HEADS, N_CTX, D_HEAD], zero-padding each sequence to context_size.
    output_flash = torch.stack([
        torch.nn.functional.pad(
            output_flash[seqlens[i]:seqlens[i + 1], :, :],
            (0, 0, 0, 0, 0, context_size + seqlens[i] - seqlens[i + 1])
        )
        for i in range(batch_size)
    ]).swapaxes(1, 2).contiguous()
    if torch_test:
        torch.testing.assert_close(output_flash, ref_o_dense, atol=1e-2, rtol=0)
    profile(flash_fn, 2. * head_dim * dense_mask_nnz, 'flash-dense')
    print('========================================\n')

    # When the sparse mask covers at least as much as the dense one, both paths must agree.
    if torch_test and sparse_mask_nnz >= dense_mask_nnz:
        torch.testing.assert_close(output_flash, output_triton_sparse, atol=1e-2, rtol=0)
|
694 |
+
|
695 |
+
|
696 |
+
def pit_sparse_flash_attention_forward(
    query: torch.Tensor,  # [BATCH, N_HEADS, N_CTX, D_HEAD]
    key: torch.Tensor,    # [BATCH, N_HEADS, N_CTX, D_HEAD]
    value: torch.Tensor,  # [BATCH, N_HEADS, N_CTX, D_HEAD]
    v_idx: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
    s_idx: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
    block_size_M: int = 64,
    block_size_N: int = 64,
):
    """Vertical-and-slash sparse attention forward pass.

    Pads the inputs, builds the block / column index with
    ``pycuda_build_index`` and runs ``triton_sparse_forward``.

    Args:
        query, key, value: ``[BATCH, N_HEADS, N_CTX, D_HEAD]`` tensors.
        v_idx: vertical (key column) indexes per head; sorted ascending here.
        s_idx: slash (diagonal) indexes per head; sorted descending here.
        block_size_M: query rows per block.
        block_size_N: key columns per block.

    Returns:
        Tensor of shape ``[BATCH, N_HEADS, N_CTX, D_HEAD]`` (padding removed).

    ``seqlens`` now carries one entry per batch element. It used to contain a
    single element, which the index kernel (``seqlens[batch_idx]``) would read
    out of bounds for ``BATCH > 1``. The result for ``BATCH == 1`` is unchanged.
    """
    batch_size, num_heads, context_size, head_dim = query.shape

    # Pad the sequence axis up to a multiple of block_size_M. The bit mask
    # assumes block_size_M is a power of two; an already aligned length still
    # receives one full extra block (kept from the original behaviour).
    pad = block_size_M - (context_size & (block_size_M - 1))
    query = torch.nn.functional.pad(query, [0, 0, 0, pad, 0, 0, 0, 0])
    key = torch.nn.functional.pad(key, [0, 0, 0, pad, 0, 0, 0, 0])
    value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0])

    # Zero-pad unusual head dimensions up to the next power of two.
    if head_dim not in [16, 32, 64, 128, 256, 512]:
        target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim
        query = torch.nn.functional.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0])
        key = torch.nn.functional.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0])
        value = torch.nn.functional.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0])

    v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0]
    s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0]
    # Every batch element uses the full (unpadded) context length.
    seqlens = torch.full((batch_size,), context_size, dtype=torch.int32, device=query.device)
    # The softmax scale is based on the original, unpadded head dimension.
    sm_scale = head_dim ** -0.5
    block_count, block_offset, column_count, column_index = pycuda_build_index(
        seqlens, v_idx, s_idx, context_size, block_size_M, block_size_N,
    )
    out = triton_sparse_forward(
        query, key, value, seqlens,
        block_count, block_offset, column_count, column_index,
        sm_scale, block_size_M, block_size_N,
    )
    # Strip the sequence and head-dimension padding again.
    return out[..., :context_size, :head_dim]
|
minference/ops/streaming_kernel.py
ADDED
@@ -0,0 +1,763 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Fused Attention
|
3 |
+
===============
|
4 |
+
|
5 |
+
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
|
6 |
+
Credits: OpenAI kernel team
|
7 |
+
|
8 |
+
Extra Credits:
|
9 |
+
- Original flash attention paper (https://arxiv.org/abs/2205.14135)
|
10 |
+
- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)
|
11 |
+
|
12 |
+
"""
|
13 |
+
|
14 |
+
import math
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import triton
|
18 |
+
import triton.language as tl
|
19 |
+
|
20 |
+
# Module-level tile-size constants (their uses fall outside this excerpt).
_BLOCK_N=64
_BLOCK_M=64
|
22 |
+
|
23 |
+
@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q,
                    K_block_ptr, V_block_ptr,
                    start_m, qk_scale, N_CTX,
                    sliding_window_offset, sliding_window_size,
                    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, SLIDING_WINDOW: tl.constexpr,
                    IS_EVEN_M: tl.constexpr, IS_EVEN_N: tl.constexpr, COMPLEMENT_SLIDING_WINDOW: tl.constexpr
    ):
    # Inner loop of the attention forward pass: iterates over key/value blocks
    # and updates the online-softmax state (acc, l_i, m_i), which is returned.
    #
    # With SLIDING_WINDOW set, a key is kept when
    #   dist = query_pos - key_pos + sliding_window_offset
    # satisfies 0 <= dist < sliding_window_size; with COMPLEMENT_SLIDING_WINDOW
    # the kept set is dist >= sliding_window_size instead.

    # range of values handled by this stage
    if SLIDING_WINDOW and not COMPLEMENT_SLIDING_WINDOW:
        # NOTE(review): the outer condition requires COMPLEMENT_SLIDING_WINDOW to
        # be False, so the first branch below is unreachable; the complement
        # case always takes the full-range `else` at the bottom.
        if COMPLEMENT_SLIDING_WINDOW:
            lo = 0
            hi = (((start_m + 1) * BLOCK_M + sliding_window_offset - sliding_window_size + BLOCK_N - 1) // BLOCK_N) * BLOCK_N
        else:
            # Block-aligned bounds of the keys visible to this query block.
            lo = ((start_m * BLOCK_M + sliding_window_offset - sliding_window_size + 1) // BLOCK_N) * BLOCK_N
            hi = ((((start_m + 1) * BLOCK_M - 1) + sliding_window_offset + BLOCK_N) // BLOCK_N) * BLOCK_N
            if lo < 0:
                lo = 0
            if hi > N_CTX:
                hi = N_CTX

            # lo = 0
            # hi = N_CTX
        lo = tl.multiple_of(lo, BLOCK_N)
        # Skip the key/value blocks before `lo`.
        K_block_ptr = tl.advance(K_block_ptr, (0, lo))
        V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
    else:
        lo, hi = 0, N_CTX

    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if IS_EVEN_N:
            k = tl.load(K_block_ptr)
        else:
            # The last block may be partial: out-of-range elements are read as zero.
            k = tl.load(K_block_ptr, boundary_check=(0, 1), padding_option="zero")

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk = qk * qk_scale

        if SLIDING_WINDOW:
            dist = tl.arange(0, BLOCK_M)[:, None] - tl.arange(0, BLOCK_N)[None, :] \
                + start_m * BLOCK_M - start_n + sliding_window_offset

            if COMPLEMENT_SLIDING_WINDOW:
                mask = (dist >= sliding_window_size)
            else:
                mask = (dist >= 0) & (dist < sliding_window_size)

            qk = tl.where(mask, qk, float("-inf"))

        if not IS_EVEN_N:
            # Mask key positions at or beyond N_CTX.
            qk = tl.where(((tl.arange(0, BLOCK_N) + start_n) < N_CTX)[None, :], qk, float("-inf"))

        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk = qk - m_ij[:, None]
        p = tl.math.exp2(qk)

        # Force masked entries to exactly 0 (a fully masked row gives
        # -inf - (-inf) = NaN in the subtraction above).
        if SLIDING_WINDOW:
            p = tl.where(mask, p, 0)

        if not IS_EVEN_N:
            p = tl.where(((tl.arange(0, BLOCK_N) + start_n) < N_CTX)[None, :], p, 0)

        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        tmp = m_i - m_ij
        alpha_mask = (tmp != tmp)  # check nan
        alpha = tl.math.exp2(tmp)
        # A NaN difference is replaced by a rescale factor of 1.
        alpha = tl.where(alpha_mask, 1., alpha)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        # update acc
        if IS_EVEN_N:
            v = tl.load(V_block_ptr)
        else:
            v = tl.load(V_block_ptr, boundary_check=(0, 1), padding_option="zero")

        acc += tl.dot(p.to(v.dtype), v)
        # update m_i and l_i
        m_i = m_ij
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))

    return acc, l_i, m_i
|
111 |
+
|
112 |
+
|
113 |
+
@triton.heuristics(
    {
        "IS_EVEN_M": lambda args: args["N_CTX"] % args["BLOCK_M"] == 0,
        "IS_EVEN_N": lambda args: args["NKV_CTX"] % args["BLOCK_N"] == 0,
    }
)
@triton.jit
def _attn_fwd(Q, K, V, sm_scale, M, Out, L,
              stride_qz, stride_qh, stride_qm, stride_qk,
              stride_kz, stride_kh, stride_kn, stride_kk,
              stride_vz, stride_vh, stride_vk, stride_vn,
              stride_oz, stride_oh, stride_om, stride_on,
              Z, H, H_KV,
              N_CTX,
              ROUND_CTX,
              NKV_CTX,
              sliding_window_offset,
              sliding_window_size,
              IS_EVEN_M: tl.constexpr,
              IS_EVEN_N: tl.constexpr,
              BLOCK_M: tl.constexpr,
              BLOCK_DMODEL: tl.constexpr,
              BLOCK_N: tl.constexpr,
              END: tl.constexpr,
              INIT: tl.constexpr,
              SLIDING_WINDOW: tl.constexpr,
              COMPLEMENT_SLIDING_WINDOW: tl.constexpr
              ):
    # One attention stage for a single (query block, batch*head) program.
    #
    # The softmax statistics live in M (running row maximum, in the base-2
    # domain because the scale is multiplied by 1/log(2)) and L (running row
    # sum); Out holds the un-normalised accumulator between stages.
    #   INIT -> start from fresh statistics instead of loading M / L / Out.
    #   END  -> normalise the accumulator by the row sum and fold log2(l)
    #           into the stored maximum; otherwise the row sum is written to L.
    # IS_EVEN_M / IS_EVEN_N are filled in by the heuristics above.

    # Grid axis 0 enumerates query blocks, axis 1 the flattened (batch, head).
    pid_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    batch_idx = off_hz // H
    head_idx = off_hz % H
    # Query heads share key/value heads in groups of H // H_KV.
    kv_head_idx = head_idx // (H // H_KV)

    # 64-bit offsets of this program's (batch, head) slice in each tensor.
    batch64 = batch_idx.to(tl.int64)
    head64 = head_idx.to(tl.int64)
    kv_head64 = kv_head_idx.to(tl.int64)
    q_offset = batch64 * stride_qz + head64 * stride_qh
    k_offset = batch64 * stride_kz + kv_head64 * stride_kh
    v_offset = batch64 * stride_vz + kv_head64 * stride_vh
    o_offset = batch64 * stride_oz + head64 * stride_oh

    # Block pointers. K is addressed as (BLOCK_DMODEL, NKV_CTX), i.e. already
    # transposed for the q @ k product.
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(BLOCK_DMODEL, NKV_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(NKV_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    # The output (like M and L) is sized by ROUND_CTX rather than N_CTX.
    O_block_ptr = tl.make_block_ptr(
        base=Out + o_offset,
        shape=(ROUND_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )

    # Per-row pointers into the statistics buffers.
    row_ids = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    m_ptrs = M + off_hz * ROUND_CTX + row_ids
    l_ptrs = L + off_hz * ROUND_CTX + row_ids

    if INIT:
        # First stage: fresh statistics and an empty accumulator.
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    else:
        # Later stage: resume from the previous stage's state. These loads
        # are not bounds-checked along the query length.
        m_i = tl.load(m_ptrs).to(tl.float32)
        l_i = tl.load(l_ptrs).to(tl.float32)
        acc = tl.load(O_block_ptr).to(tl.float32)

    # Fold 1/log(2) into the scale so the inner loop can use exp2.
    qk_scale = sm_scale * 1.4426950408889634

    # Load the query block once; it is reused for every key block.
    if IS_EVEN_M:
        q = tl.load(Q_block_ptr)
    else:
        q = tl.load(Q_block_ptr, boundary_check=(0, 1), padding_option="zero")

    acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
                                    pid_m, qk_scale, NKV_CTX,
                                    sliding_window_offset, sliding_window_size,
                                    BLOCK_M, BLOCK_DMODEL, BLOCK_N, SLIDING_WINDOW, IS_EVEN_M, IS_EVEN_N,
                                    COMPLEMENT_SLIDING_WINDOW)

    # Epilogue.
    if END:
        m_i += tl.math.log2(l_i)
        acc = acc / l_i[:, None]
    else:
        tl.store(l_ptrs, l_i)

    tl.store(m_ptrs, m_i)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty))
|
222 |
+
|
223 |
+
|
224 |
+
@triton.heuristics(
    {
        "IS_EVEN_M": lambda args: args["N_CTX"] % args["BLOCK_M"] == 0,
        "IS_EVEN_N": lambda args: args["NKV_CTX"] % args["BLOCK_N"] == 0,
    }
)
@triton.jit
def _score_kernel(
    Q, K, M, sm_scale, Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_oz, stride_oh, stride_on,
    Z, H, H_KV,
    N_CTX,
    ROUND_CTX,
    NKV_CTX,
    sliding_window_offset,
    sliding_window_size,
    SLIDING_WINDOW: tl.constexpr,
    COMPLEMENT_SLIDING_WINDOW: tl.constexpr,
    IS_EVEN_M: tl.constexpr,
    IS_EVEN_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Per-key attention mass: for one block of keys, accumulate
    # exp2(q.k * scale - M[row]) over every query row and store the column
    # sums into Out. M holds one value per query row (ROUND_CTX rows per
    # batch*head) and is used as the exponent offset for that row.
    # IS_EVEN_M / IS_EVEN_N are filled in by the heuristics above.

    # Grid axis 0 enumerates key blocks, axis 1 the flattened (batch, head).
    start_n = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    # Query heads share key heads in groups of H // H_KV.
    off_hkv = off_h // (H // H_KV)
    q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    k_offset = off_z.to(tl.int64) * stride_kz + off_hkv.to(tl.int64) * stride_kh
    m_ptrs = M + off_hz * ROUND_CTX + tl.arange(0, BLOCK_M)
    o = tl.zeros([BLOCK_M], dtype=tl.float32)

    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(0, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # K is addressed as (BLOCK_DMODEL, NKV_CTX), i.e. already transposed.
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(BLOCK_DMODEL, NKV_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, start_n * BLOCK_N),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )

    # This program's key block is loaded once and reused for all query blocks.
    if IS_EVEN_N:
        k = tl.load(K_block_ptr)
    else:
        k = tl.load(K_block_ptr, boundary_check=(0, 1), padding_option="zero")

    lo = 0
    hi = ROUND_CTX
    # Fold 1/log(2) into the scale so exp2 can be used below.
    qk_scale = sm_scale
    qk_scale *= 1.4426950408889634  # 1/log(2)

    for start_m in range(lo, hi, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        if IS_EVEN_M:
            q = tl.load(Q_block_ptr)
        else:
            q = tl.load(Q_block_ptr, boundary_check=(0, 1), padding_option="zero")

        m = tl.load(m_ptrs)

        # calc qk
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk = qk * qk_scale

        if SLIDING_WINDOW:
            # Query position minus key position, shifted by the window offset.
            dist = tl.arange(0, BLOCK_M)[:, None] - tl.arange(0, BLOCK_N)[None, :] \
                + start_m - start_n * BLOCK_N + sliding_window_offset

            if COMPLEMENT_SLIDING_WINDOW:
                mask = (dist >= sliding_window_size)
            else:
                mask = (dist >= 0) & (dist < sliding_window_size)

        qk = qk - m[:, None]
        p = tl.math.exp2(qk)  # (BLOCK_M, BLOCK_N)

        if SLIDING_WINDOW:
            p = tl.where(mask, p, 0)

        # Drop query rows at or beyond N_CTX. The loop runs up to ROUND_CTX,
        # so such rows can exist regardless of how the *key* length divides
        # BLOCK_N; the previous `if not IS_EVEN_N` guard let them leak into
        # the sums whenever NKV_CTX was a multiple of BLOCK_N. The mask is
        # therefore applied unconditionally.
        p = tl.where(
            ((tl.arange(0, BLOCK_M) + start_m) < N_CTX)[:, None],
            p, 0
        )

        o += tl.sum(p, axis=0)

        Q_block_ptr = tl.advance(Q_block_ptr, offsets=(BLOCK_M, 0))
        m_ptrs = m_ptrs + BLOCK_M

    o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh
    o_range = tl.arange(0, BLOCK_N) + start_n * BLOCK_N
    o_ptrs = Out + o_offset + o_range
    # Key columns past NKV_CTX are padding and are not written back.
    tl.store(o_ptrs, o.to(Out.type.element_ty), mask=o_range < NKV_CTX)
|
335 |
+
|
336 |
+
def get_score(q, k, m, sliding_window, complement_sliding_window):
    """Launch ``_score_kernel`` and return one score per key position.

    Args:
        q: 4-D query tensor; its first two dimensions must match ``m``.
        k: 4-D key tensor; ``k.size(1)`` is passed to the kernel as the
            number of key/value heads.
        m: 3-D per-row statistics tensor read by the kernel; its last
            dimension is passed as the rounded query length.
        sliding_window: ``None`` or an ``(offset, size)`` pair.
        complement_sliding_window: forwarded to the kernel as
            ``COMPLEMENT_SLIDING_WINDOW``.

    Returns:
        Tensor of shape ``(q.size(0), q.size(1), k.size(2))`` with ``k``'s
        dtype and device.

    If Triton reports ``OutOfResources``, the module-level block sizes are
    halved (for this and all later calls) and the launch is retried once.
    """
    global _BLOCK_M
    global _BLOCK_N

    assert q.dim() == 4
    assert k.dim() == 4
    assert m.dim() == 3
    assert q.shape[:2] == m.shape[:2]

    q_len = q.size(-2)
    kv_len = k.size(-2)
    round_len = m.size(-1)
    scores = torch.zeros(
        (q.size(0), q.size(1), k.size(2)),
        dtype=k.dtype, device=k.device
    )

    if sliding_window is None:
        window_offset, window_size = None, None
    else:
        window_offset, window_size = sliding_window

    softmax_scale = 1 / math.sqrt(q.size(-1))

    def grid(META):
        # One program per key block and per (batch, head) pair.
        return (
            triton.cdiv(k.shape[2], META["BLOCK_N"]),
            q.shape[0] * q.shape[1],
        )

    def launch():
        # Reads the module-level block sizes at call time, so a retry picks
        # up the halved values.
        _score_kernel[grid](
            q, k, m, softmax_scale, scores,
            *q.stride(),
            *k.stride(),
            *scores.stride(),
            q.size(0), q.size(1), k.size(1),
            q_len, round_len, kv_len,
            window_offset,
            window_size,
            SLIDING_WINDOW=(sliding_window is not None),
            COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,
            BLOCK_M=_BLOCK_M,
            BLOCK_N=_BLOCK_N,
            BLOCK_DMODEL=q.size(-1)
        )

    try:
        launch()
    except triton.OutOfResources as err:
        from warnings import warn
        _BLOCK_N = _BLOCK_N // 2
        _BLOCK_M = _BLOCK_M // 2
        warn(f"Triton Attention Output Resources. {err}\nUse smaller block size {_BLOCK_N}.")
        launch()

    return scores
|
401 |
+
|
402 |
+
def _forward(
    q, k, v, sm_scale,
    o = None, m = None, l = None, end = False,
    sliding_window=None, init=False,
    complement_sliding_window=False
):
    """Run one ``_attn_fwd`` stage, updating ``o``, ``m`` and ``l`` in place.

    Args:
        q, k, v: 4-D tensors whose last dimensions must be equal and one of
            16, 32, 64 or 128.
        sm_scale: softmax scale handed to the kernel.
        o, m, l: accumulator and per-row statistics buffers shared between
            stages.
        end: passed as ``END``; when true the returned ``o`` is cropped to
            the real query length and cast to ``q.dtype``.
        sliding_window: ``None`` or an ``(offset, size)`` pair.
        init: passed as ``INIT``.
        complement_sliding_window: passed as ``COMPLEMENT_SLIDING_WINDOW``.

    Returns:
        ``(o, m, l)``.

    If Triton reports ``OutOfResources``, the module-level block sizes are
    halved (for this and all later calls) and the launch is retried once.
    """
    global _BLOCK_M
    global _BLOCK_N

    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk == Lv
    assert Lk in {16, 32, 64, 128}

    # Query length rounded up to a multiple of 64.
    q_round_len = math.ceil(q.shape[2] / 64) * 64

    if sliding_window is None:
        window_offset, window_size = None, None
    else:
        window_offset, window_size = sliding_window

    def grid(META):
        # One program per query block and per (batch, head) pair.
        return (
            triton.cdiv(q.shape[2], META["BLOCK_M"]),
            q.shape[0] * q.shape[1],
        )

    def launch():
        # Reads the module-level block sizes at call time, so a retry picks
        # up the halved values.
        _attn_fwd[grid](
            q, k, v, sm_scale, m, o, l,
            *q.stride(),
            *k.stride(),
            *v.stride(),
            *o.stride(),
            q.shape[0], q.shape[1], k.shape[1],
            q.shape[2],
            q_round_len,
            k.shape[2],
            window_offset,
            window_size,
            BLOCK_DMODEL=Lk,
            END=end,
            INIT=init,
            BLOCK_M=_BLOCK_M,
            BLOCK_N=_BLOCK_N,
            SLIDING_WINDOW=(sliding_window is not None),
            COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,
            num_warps=4,
            num_stages=4
        )

    try:
        launch()
    except triton.OutOfResources as err:
        _BLOCK_N = _BLOCK_N // 2
        _BLOCK_M = _BLOCK_M // 2
        from warnings import warn
        warn(f"Triton Attention Output Resources. {err}\nUse smaller block size {_BLOCK_N}.")
        launch()

    if end:
        # Drop the rows added by rounding and return in the query dtype.
        o = o[:, :, :q.shape[2], :].contiguous().to(q.dtype)

    return o, m, l
|
484 |
+
|
485 |
+
class MultiStageDotProductionAttention:
    """Base interface for attention that is accumulated over several stages.

    Subclasses implement ``append``; ``get_result`` returns the output
    tensor together with the list of collected scores.
    """

    def __init__(
        self,
        q_shape,
        dtype,
        device,
    ):
        self.q_shape, self.dtype, self.device = q_shape, dtype, device
        self.end = False
        self.score_list = []
        # Zero-filled placeholder with the query's shape.
        self.ret = torch.zeros(q_shape, dtype=dtype, device=device)

    def append(
        self,
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
        sliding_window=None, complement_sliding_window: bool = False,
        end=False, get_score=False,
        *args, **kwargs
    ):
        """Add one attention stage; must be provided by subclasses."""
        raise NotImplementedError

    def get_result(self):
        """Return ``(ret, score_list)``."""
        return self.ret, self.score_list
|
513 |
+
|
514 |
+
|
515 |
+
class TritonMultiStageDotProductionAttention(MultiStageDotProductionAttention):
    """Multi-stage attention backed by the Triton kernels in this module.

    Each ``append`` call runs one ``_forward`` stage on shared accumulator
    buffers. The stage flagged ``end=True`` finalizes the result and computes
    scores for every stage that requested them.
    """

    def __init__(self, q_shape, dtype, device):
        self.q_shape = q_shape
        self.dtype = dtype
        self.device = device

        # Buffers are sized by the query length rounded up to a multiple of 64.
        q_round_len = math.ceil(q_shape[2] / 64) * 64
        stat_shape = (q_shape[0], q_shape[1], q_round_len)

        # Accumulator and per-row statistics are kept in float32.
        self.o = torch.empty(stat_shape + (q_shape[3],), device=device, dtype=torch.float32)
        self.m = torch.empty(stat_shape, device=device, dtype=torch.float32)
        self.l = torch.empty(stat_shape, device=device, dtype=torch.float32)

        # One entry per appended stage; None marks "no score requested".
        self.q_list = []
        self.k_list = []
        self.sliding_window_list = []
        self.complement_sliding_window_list = []
        self.score_list = []

        self.end = False
        self.init = False

    def finalize(self):
        """Compute the requested per-stage scores and publish the output."""
        self.end = True
        stages = zip(
            self.q_list, self.k_list,
            self.sliding_window_list, self.complement_sliding_window_list,
        )
        for stage_q, stage_k, window, complement in stages:
            if stage_q is None:
                self.score_list.append(None)
                continue
            self.score_list.append(get_score(stage_q, stage_k, self.m, window, complement))

        self.ret = self.o

    def append(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, end=False, get_score=False, sliding_window = None, complement_sliding_window: bool = False):
        """Run one attention stage.

        Args:
            q, k, v: stage inputs; ``q`` must have the shape given at
                construction.
            end: mark this as the last stage and finalize afterwards.
            get_score: remember this stage's inputs so ``finalize`` computes
                a score tensor for it.
            sliding_window: ``None``, an ``(offset, size)`` pair, or an int
                size (the offset then defaults to ``k_len - q_len``).
            complement_sliding_window: forwarded to ``_forward``.
        """
        assert q.shape == self.q_shape

        if isinstance(sliding_window, int):
            sliding_window = (k.shape[2] - q.shape[2], sliding_window)

        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

        # The first stage initializes the shared buffers (init=True).
        stage_out = _forward(
            q, k, v, 1 / math.sqrt(q.shape[-1]), self.o, self.m, self.l,
            sliding_window=sliding_window, end=end, init=not self.init,
            complement_sliding_window=complement_sliding_window
        )
        self.init = True
        self.o, self.m, self.l = stage_out

        if get_score:
            stage = (q, k, sliding_window, complement_sliding_window)
        else:
            stage = (None, None, None, None)
        self.q_list.append(stage[0])
        self.k_list.append(stage[1])
        self.sliding_window_list.append(stage[2])
        self.complement_sliding_window_list.append(stage[3])

        if end:
            assert not self.end
            self.finalize()
|
583 |
+
|
584 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head dimension,
    matching torch.repeat_interleave(x, dim=1, repeats=n_rep): the input
    (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    The input itself is returned when ``n_rep`` is 1.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep != 1:
        # Insert a repeat axis after the head axis, broadcast it, then fold
        # it back into the head axis.
        expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
        return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
    return hidden_states
|
594 |
+
|
595 |
+
def streaming_forward(
    q, k, v,
    n_init, n_local,
):
    """Streaming attention: a local sliding window plus the first tokens.

    Args:
        q, k, v: tensors of identical 4-D shape
            (bsz, num_heads, seqlen, head_dim). They should already carry
            RoPE, and k/v should already be repeated to match q's heads.
        n_init: number of leading key/value positions always attended to.
        n_local: sliding-window size.

    Returns:
        The attention output, cropped back to the original ``head_dim``.
    """
    assert q.dim() == 4  # (bsz, num_heads, seqlen, head_dim)
    assert q.shape == k.shape == v.shape

    head_dim = q.shape[-1]
    if head_dim not in [16, 32, 64, 128, 256, 512]:
        # Zero-pad the head dimension up to the next power of two.
        pad = 2 ** math.ceil(math.log2(head_dim)) - head_dim
        # append() derives the softmax scale from the *padded* width, which
        # would yield 1/sqrt(padded_dim) instead of 1/sqrt(head_dim).
        # Pre-scaling q by sqrt(padded_dim / head_dim) restores the intended
        # scale; the zero padding itself does not change q.k.
        q = q * math.sqrt((head_dim + pad) / head_dim)
        q = torch.nn.functional.pad(q, [0, pad, 0, 0, 0, 0, 0, 0])
        k = torch.nn.functional.pad(k, [0, pad, 0, 0, 0, 0, 0, 0])
        v = torch.nn.functional.pad(v, [0, pad, 0, 0, 0, 0, 0, 0])

    q_len = q.size(2)
    k_len = k.size(2)

    attn = TritonMultiStageDotProductionAttention(q.shape, q.dtype, q.device)

    if k_len > n_local:
        init_k = k[:, :, :n_init, :].contiguous()
        init_v = v[:, :, :n_init, :].contiguous()

        # Stage 1: local window over all keys.
        attn.append(q, k, v, sliding_window=n_local)
        # Stage 2: the initial tokens, restricted to positions outside the
        # local window so nothing is counted twice.
        attn.append(
            q, init_k, init_v, end=True,
            sliding_window=(k_len - q_len, n_local), complement_sliding_window=True
        )
    else:
        # Everything fits inside the window; a single stage is enough.
        attn.append(q, k, v, sliding_window=n_local, end=True)

    score, _ = attn.get_result()
    # Remove the padded channels again.
    return score[..., :head_dim]
|
631 |
+
|
632 |
+
def streaming_forward2(
    q, k, v,
    n_init, n_local,
):
    """Two-stage streaming attention without head-dimension padding.

    Always runs both stages: the local sliding window over ``k``/``v``, then
    the first ``n_init`` positions outside that window. When the keys do not
    exceed the window, the second stage receives empty key/value tensors.

    Returns the attention output from ``get_result``.
    """
    q_len, k_len = q.size(2), k.size(2)

    attn = TritonMultiStageDotProductionAttention(q.shape, q.dtype, q.device)

    if k_len > n_local:
        init_k = k[:, :, :n_init, :].contiguous()
        init_v = v[:, :, :n_init, :].contiguous()
    else:
        # Zero-length placeholders with matching dtype and device.
        init_k = k.new_empty((k.size(0), k.size(1), 0, k.size(3)))
        init_v = v.new_empty((v.size(0), v.size(1), 0, v.size(3)))

    # Stage 1: local window.
    attn.append(q, k, v, sliding_window=n_local)
    # Stage 2: initial tokens, restricted to the complement of the window.
    attn.append(
        q, init_k, init_v, end=True,
        sliding_window=(k_len - q_len, n_local), complement_sliding_window=True
    )

    output, _ = attn.get_result()
    return output
|
663 |
+
|
664 |
+
def stream_llm_forward(n_local, n_init, *args, **kwargs):
    """Build an attention ``forward`` that keeps initial and local tokens.

    Args:
        n_local: sliding-window size.
        n_init: number of leading key/value positions that are always kept.

    Returns:
        A ``forward`` function. It returns the projected attention output,
        plus the updated ``(keys, values, total_length)`` cache when
        ``use_cache`` is true.
    """
    Attn = TritonMultiStageDotProductionAttention

    def forward(self, query : torch.Tensor,
                key_value : torch.Tensor,
                position_bias : torch.Tensor,
                use_cache: bool,
                past_key_value,
                project_q, project_k, project_v, attention_out,
                dim_head, num_heads, num_heads_kv
    ):
        batch_size = query.size(0)
        len_q = query.size(1)
        len_k = key_value.size(1)

        def split_heads(x, length, heads):
            # (batch, length, heads * dim_head) -> (batch, heads, length, dim_head)
            return x.view(batch_size, length, heads, dim_head).permute(0, 2, 1, 3).contiguous()

        proj_q = project_q(query)      # (batch, len_q, num_heads * dim_head)
        proj_k = project_k(key_value)  # (batch, len_k, num_heads_kv * dim_head)
        proj_v = project_v(key_value)  # (batch, len_k, num_heads_kv * dim_head)

        h_q = split_heads(proj_q, len_q, num_heads)     # (batch, num_heads, len_q, dim_head)
        h_k = split_heads(proj_k, len_k, num_heads_kv)  # (batch, num_heads_kv, len_k, dim_head)
        h_v = split_heads(proj_v, len_k, num_heads_kv)  # (batch, num_heads_kv, len_k, dim_head)

        if past_key_value is not None:
            # Prepend the cached keys/values; the third cache entry is the
            # total length seen so far.
            h_k = torch.cat([past_key_value[0], h_k], dim=-2)
            h_v = torch.cat([past_key_value[1], h_v], dim=-2)
            len_k += past_key_value[2]

        current_key_value = None
        if use_cache:
            if len_k > n_local + n_init:
                # Keep only the first n_init and the last n_local positions.
                tail_start = max(0, h_k.size(-2) - n_local)
                h_k_cache = torch.cat([h_k[:, :, :n_init, :], h_k[:, :, tail_start:, :]], dim=2)
                h_v_cache = torch.cat([h_v[:, :, :n_init, :], h_v[:, :, tail_start:, :]], dim=2)
            else:
                h_k_cache, h_v_cache = h_k, h_v
            current_key_value = (h_k_cache, h_v_cache, len_k)

        # Local stage: keep at most the last len_q + n_local keys/values.
        local_k, local_v = h_k, h_v
        if len_q + n_local < local_k.size(-2):
            local_k = local_k[:, :, local_k.size(-2) - len_q - n_local:, :].contiguous().clone()
            local_v = local_v[:, :, local_v.size(-2) - len_q - n_local:, :].contiguous().clone()

        local_h_q, local_h_k = position_bias(h_q, local_k)
        local_h_v = local_v

        # Initial-token stage.
        if len_k > n_local:
            # The query rotation must run first: it is called before the
            # cached cos/sin tables are read for the keys below.
            init_h_q = position_bias.apply_rotary_pos_emb_one_angle(
                h_q, n_local + n_init
            )
            init_h_k = position_bias.apply_rotary_pos_emb(
                h_k[:, :, :n_init, :].contiguous(),
                n_init, n_init, position_bias._cos_cached, position_bias._sin_cached
            )
            init_h_v = h_v[:, :, :n_init, :].contiguous()
        else:
            # Nothing lies outside the window: use empty keys/values.
            init_h_q = h_q
            init_h_k = torch.empty(
                (batch_size, num_heads_kv, 0, dim_head),
                device=h_k.device,
                dtype=h_k.dtype
            )
            init_h_v = torch.empty(
                (batch_size, num_heads_kv, 0, dim_head),
                device=h_v.device,
                dtype=h_v.dtype
            )

        attn = Attn(local_h_q.shape, local_h_q.dtype, local_h_q.device)
        attn.append(local_h_q, local_h_k, local_h_v, sliding_window=n_local)
        attn.append(
            init_h_q, init_h_k, init_h_v, end=True,
            sliding_window=(len_k - len_q, n_local),
            complement_sliding_window=True
        )
        score, _ = attn.get_result()

        # (batch, num_heads, len_q, dim_head) -> (batch, len_q, num_heads * dim_head)
        score = score.view(batch_size, num_heads, len_q, dim_head).permute(0, 2, 1, 3).contiguous()
        score = score.reshape(batch_size, len_q, num_heads * dim_head)

        score = attention_out(score)

        return (score, current_key_value) if use_cache else score

    return forward
|
minference/patch.py
ADDED
@@ -0,0 +1,1279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import transformers
|
5 |
+
from transformers.cache_utils import *
|
6 |
+
from transformers.models.llama.modeling_llama import *
|
7 |
+
|
8 |
+
from .modules.inf_llm import InfLLMGenerator, inf_llm_forward
|
9 |
+
from .modules.minference_forward import (
|
10 |
+
gather_last_q_vertical_slash_topk_v4,
|
11 |
+
gather_last_q_vertical_slash_topk_vllm,
|
12 |
+
init_minference_parameters,
|
13 |
+
minference_forward,
|
14 |
+
minference_kv_cache_cpu_forward,
|
15 |
+
minference_vllm_forward,
|
16 |
+
minference_with_snapkv_forward,
|
17 |
+
search_pattern,
|
18 |
+
sum_all_diagonal_matrix,
|
19 |
+
)
|
20 |
+
from .ops.streaming_kernel import stream_llm_forward
|
21 |
+
|
22 |
+
|
23 |
+
class RotaryEmbeddingESM(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.

    The cos/sin tables are cached on the module and are only rebuilt when a
    sequence longer than the cached one is requested.
    """

    def __init__(
        self,
        dim: int,
        base: Union[int, float] = 10000,
        distance_scale: Union[int, float] = 1,
    ):
        """
        Args:
            dim: size of the rotated (last) dimension; one inverse frequency
                is generated for every even index below ``dim``.
            base: base of the geometric progression of frequencies.
            distance_scale: multiplier applied to the position index before
                it is combined with the inverse frequencies.
        """
        super().__init__()
        self.base = base
        self.distance_scale = distance_scale

        # Generate and save the inverse frequency buffer (non trainable).
        # CUDA is preferred (it used to be hard-coded), but fall back to the
        # CPU so the module can still be built on a machine without a GPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        exponents = torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim
        inv_freq = 1.0 / (base**exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        self._seq_len_cached = -1
        self._cos_cached = None
        self._sin_cached = None

    def rotate_half(self, x):
        """Split the last dimension in two halves (x1, x2) and return (-x2, x1)."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    @staticmethod
    def _slice_positions(table, rank, start, stop):
        """Select positions [start, stop) on the sequence axis of a rank-2/3/4 table."""
        if rank == 2:
            return table[start:stop, :]
        if rank == 3:
            return table[:, start:stop, :]
        if rank == 4:
            return table[:, :, start:stop, :]
        # Any other rank is passed through unsliced.
        return table

    def _rotate(self, x, cos, sin):
        """Apply the rotation in float32 and cast back to the dtype of ``x``."""
        dtype = x.dtype
        return ((x.float() * cos) + (self.rotate_half(x).float() * sin)).to(dtype)

    def _rebuild_tables(self, seq_len, device, rank):
        """Recompute the cached cos/sin tables for ``seq_len`` positions.

        The tables get leading singleton axes so that their rank equals
        ``rank`` (2, 3 or 4). For any other rank only the cached length is
        updated and the previous tables are kept.
        """
        self._seq_len_cached = seq_len
        t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        freqs = torch.outer(t * self.distance_scale, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        if rank == 2:
            self._cos_cached = emb.cos()
            self._sin_cached = emb.sin()
        elif rank == 3:
            self._cos_cached = emb.cos()[None, :, :]
            self._sin_cached = emb.sin()[None, :, :]
        elif rank == 4:
            self._cos_cached = emb.cos()[None, None, :, :]
            self._sin_cached = emb.sin()[None, None, :, :]

    def apply_rotary_pos_emb(self, x, length, right, cos, sin):
        """Rotate ``x`` using table positions ``[right - length, right)``."""
        rank = cos.dim()
        cos = self._slice_positions(cos, rank, right - length, right)
        sin = self._slice_positions(sin, rank, right - length, right)
        return self._rotate(x, cos, sin)

    def _update_cos_sin_tables(self, x, seq_dim):
        """Grow the cache if ``x`` is longer than it along ``seq_dim``; return (cos, sin)."""
        seq_len = x.size(seq_dim)
        if seq_len > self._seq_len_cached:
            self._rebuild_tables(seq_len, x.device, x.dim())
        return self._cos_cached, self._sin_cached

    def _update_cos_sin_tables_len(self, seq_len, device, dim=None):
        """Grow the cache to ``seq_len`` positions if needed; return (cos, sin).

        When ``dim`` is None the rank of the existing tables is reused, which
        requires that the cache has been populated before.
        """
        if seq_len > self._seq_len_cached:
            if dim is None:
                assert self._cos_cached is not None
                dim = self._cos_cached.dim()
            self._rebuild_tables(seq_len, device, dim)
        return self._cos_cached, self._sin_cached

    def apply_rotary_pos_emb_one_angle(self, x: torch.Tensor, index):
        """Rotate ``x`` with the single table row at position ``index - 1``."""
        cos, sin = self._update_cos_sin_tables_len(index, x.device)
        rank = cos.dim()
        cos = self._slice_positions(cos, rank, index - 1, index)
        sin = self._slice_positions(sin, rank, index - 1, index)
        return self._rotate(x, cos, sin)

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, seq_dim=-2
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``q`` and ``k``.

        The tables are sized from ``k``; ``q`` uses the last ``q.size(seq_dim)``
        positions of that range, so a shorter query is aligned with the end
        of the keys.
        """
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
            k, seq_dim=seq_dim
        )
        return (
            self.apply_rotary_pos_emb(
                q, q.size(seq_dim), k.size(seq_dim), self._cos_cached, self._sin_cached
            ),
            self.apply_rotary_pos_emb(
                k, k.size(seq_dim), k.size(seq_dim), self._cos_cached, self._sin_cached
            ),
        )
|
137 |
+
|
138 |
+
|
139 |
+
# Maps an attention-type name to the forward implementation registered for it.
# NOTE: "ATTN_FORWRAD" is a misspelling of "ATTN_FORWARD"; the original name is
# kept so that existing references keep working.
ATTN_FORWRAD = {
    "streaming": stream_llm_forward,
    "minference": minference_forward,
    "inf_llm": inf_llm_forward,
}
# Correctly spelled alias; it is the same dict object as ATTN_FORWRAD.
ATTN_FORWARD = ATTN_FORWRAD
|
144 |
+
|
145 |
+
|
146 |
+
def huggingface_forward(forward):
    """Wrap ``forward`` so it can be called with the keyword signature used here for attention modules.

    The returned function forwards the hidden states twice (as the second and
    third positional argument), followed by the position ids, the cache flag,
    the past key/value and the projection layers / head counts read from
    ``self``. It always returns a 3-tuple ``(output, None, past_key_value)``.
    """

    def hf_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        # Attention weights are never produced by the wrapped implementation.
        assert not output_attentions
        module_args = (
            self.q_proj,
            self.k_proj,
            self.v_proj,
            self.o_proj,
            self.head_dim,
            self.num_heads,
            self.num_key_value_heads,
        )
        result = forward(
            self,
            hidden_states,
            hidden_states,
            position_ids,
            use_cache,
            past_key_value,
            *module_args,
        )
        # With caching the wrapped call yields (output, cache); otherwise only
        # the output.
        attn_output, present = result if use_cache else (result, None)
        return attn_output, None, present

    return hf_forward
|
182 |
+
|
183 |
+
|
184 |
+
def hf_437_prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    **kwargs,
):
    """Assemble the model inputs for one generation step.

    Trims ``input_ids`` to the tokens that are not yet covered by
    ``past_key_values``, derives ``position_ids`` from ``attention_mask`` when
    none are supplied through ``kwargs``, and returns a dict with either
    ``inputs_embeds`` or ``input_ids`` plus ``position_ids``,
    ``past_key_values``, ``use_cache`` and ``attention_mask``.
    """
    if past_key_values is not None:
        if isinstance(past_key_values, transformers.cache_utils.Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
        else:
            # Legacy tuple cache: read the length from the first cached tensor.
            past_length = past_key_values[0][0].shape[2]
            cache_length = past_length
            max_cache_length = None

        # Keep only the unprocessed tokens.
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            # Case 1: the mask is longer than input_ids, so part of the input
            # only exists in the cache (e.g. when input_embeds were passed).
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        elif past_length < input_ids.shape[1]:
            # Case 2: input_ids holds every token; drop the processed prefix.
            input_ids = input_ids[:, past_length:]
        # Case 3 (past_length >= input_ids.shape[1]): input_ids is assumed to
        # contain only unprocessed tokens already.

        # Crop the attention mask if the cache is about to exceed its maximum
        # length.
        exceeds_cache = (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        )
        if exceeds_cache:
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids")
    if position_ids is None and attention_mask is not None:
        # Build position_ids on the fly for batch generation; masked-out
        # positions get the value 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # `inputs_embeds` are only used in the first generation step.
    first_step_with_embeds = inputs_embeds is not None and past_key_values is None
    if first_step_with_embeds:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs["position_ids"] = position_ids
    model_inputs["past_key_values"] = past_key_values
    model_inputs["use_cache"] = kwargs.get("use_cache")
    model_inputs["attention_mask"] = attention_mask
    return model_inputs
|
244 |
+
|
245 |
+
|
246 |
+
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    **kwargs,
):
    """Assemble the model inputs (including ``cache_position``) for one generation step.

    Trims ``input_ids`` to the tokens not yet covered by the cache, derives
    ``position_ids`` from ``attention_mask`` when none are supplied through
    ``kwargs``, and builds ``cache_position`` when it is not given.

    Returns:
        dict with ``inputs_embeds`` or ``input_ids``, plus ``position_ids``,
        ``cache_position``, ``past_key_values``, ``use_cache`` and
        ``attention_mask``.
    """
    # With a static cache `past_key_values` is None and the cache is looked up
    # on the first layer's attention module instead.
    has_static_cache = False
    if past_key_values is None:
        past_key_values = getattr(
            getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None
        )
        has_static_cache = past_key_values is not None

    past_length = 0
    if past_key_values is not None:
        if isinstance(past_key_values, transformers.cache_utils.Cache):
            past_length = (
                cache_position[0]
                if cache_position is not None
                else past_key_values.get_seq_length()
            )
            max_length = past_key_values.get_max_length()
            max_cache_length = (
                torch.tensor(max_length, device=input_ids.device)
                if max_length is not None
                else None
            )
            if max_cache_length is None:
                cache_length = past_length
            else:
                # `past_length` is a plain int when it comes from
                # get_seq_length(); torch.min needs tensors on both sides.
                cache_length = torch.min(
                    max_cache_length,
                    torch.as_tensor(past_length, device=input_ids.device),
                )
        else:
            # Legacy tuple cache. `cache_position` is preferred over the shape
            # of the stored tensors; the shape is only used as a fallback when
            # no `cache_position` was passed (this used to raise a TypeError).
            if cache_position is not None:
                cache_length = past_length = cache_position[0]
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
            max_cache_length = None

        # Keep only the unprocessed tokens:
        # 1 - If the attention_mask is longer than input_ids, some of the
        #     inputs were passed exclusively as part of the cache (e.g. when
        #     passing input_embeds as input).
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        # 2 - If past_length is smaller than input_ids', input_ids holds all
        #     input tokens and the processed prefix can be discarded.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise input_ids is assumed to hold only unprocessed tokens.

        # Crop the attention mask if we are about to go beyond the maximum
        # cache length.
        if (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        ):
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # Create position_ids on the fly for batch generation.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # If `inputs_embeds` are passed, they are only used in the 1st generation step.
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        # `contiguous()` keeps the stride static during decoding; torchdynamo
        # otherwise recompiles graphs because the input stride is a guard.
        # Ref: https://github.com/huggingface/transformers/pull/29114
        model_inputs = {"input_ids": input_ids.contiguous()}

    input_length = (
        position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
    )
    if cache_position is None:
        cache_position = torch.arange(
            past_length, past_length + input_length, device=input_ids.device
        )
    else:
        cache_position = cache_position[-input_length:]

    # A static cache stays on the attention modules and is not passed along.
    if has_static_cache:
        past_key_values = None

    model_inputs.update(
        {
            "position_ids": position_ids,
            "cache_position": cache_position,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
    )
    return model_inputs
|
348 |
+
|
349 |
+
|
350 |
+
def prepare_inputs_for_generation_snapkv(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    **kwargs,
):
    """Assemble the model inputs for one generation step when SnapKV is active.

    Without a cache, ``kv_seq_len`` is reset to 0 on every layer's attention
    module. With a legacy (non-``Cache``) cache, the past length is read from
    the first layer's ``kv_seq_len`` instead of the stored tensors. The
    returned dict holds ``inputs_embeds`` or ``input_ids`` plus
    ``position_ids``, ``past_key_values``, ``use_cache`` and
    ``attention_mask``.
    """
    if past_key_values is None:  # [SnapKV]
        for decoder_layer in self.model.layers:
            decoder_layer.self_attn.kv_seq_len = 0
    else:
        if isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
        else:
            # The length comes from the attention module's counter rather than
            # from past_key_values[0][0].shape[2].
            past_length = self.model.layers[0].self_attn.kv_seq_len
            cache_length = past_length
            max_cache_length = None

        # Keep only the unprocessed tokens.
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            # Case 1: the mask is longer than input_ids, so part of the input
            # only exists in the cache (e.g. when input_embeds were passed).
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        elif past_length < input_ids.shape[1]:
            # Case 2: input_ids holds every token; drop the processed prefix.
            input_ids = input_ids[:, past_length:]
        # Case 3 (past_length >= input_ids.shape[1]): input_ids is assumed to
        # contain only unprocessed tokens already.

        # Crop the attention mask if the cache is about to exceed its maximum
        # length.
        exceeds_cache = (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        )
        if exceeds_cache:
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids")
    if position_ids is None and attention_mask is not None:
        # Build position_ids on the fly for batch generation; masked-out
        # positions get the value 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # `inputs_embeds` are only used in the first generation step.
    first_step_with_embeds = inputs_embeds is not None and past_key_values is None
    if first_step_with_embeds:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs["position_ids"] = position_ids
    model_inputs["past_key_values"] = past_key_values
    model_inputs["use_cache"] = kwargs.get("use_cache")
    model_inputs["attention_mask"] = attention_mask
    return model_inputs
|
414 |
+
|
415 |
+
|
416 |
+
def _prepare_decoder_attention_mask_inference(
|
417 |
+
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
418 |
+
):
|
419 |
+
# [bsz, seq_len]
|
420 |
+
if past_key_values_length > 0 and attention_mask is not None:
|
421 |
+
attention_mask = torch.cat(
|
422 |
+
(
|
423 |
+
torch.full(
|
424 |
+
(input_shape[0], past_key_values_length),
|
425 |
+
True,
|
426 |
+
dtype=attention_mask.dtype,
|
427 |
+
device=attention_mask.device,
|
428 |
+
),
|
429 |
+
attention_mask,
|
430 |
+
),
|
431 |
+
dim=-1,
|
432 |
+
)
|
433 |
+
|
434 |
+
if attention_mask is not None and torch.all(attention_mask):
|
435 |
+
return None # This uses the faster call when training with full samples
|
436 |
+
|
437 |
+
return attention_mask
|
438 |
+
|
439 |
+
|
440 |
+
def forward_llama_decoder_layer(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """
    Decoder-layer forward that runs the layer norms and the MLP over the
    sequence in slices of 32000 positions.

    Note that the input layer norm is written back into ``hidden_states`` in
    place, so the tensor passed by the caller is modified.

    Args:
        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
        attention_mask (`torch.FloatTensor`, *optional*): forwarded unchanged to `self.self_attn`.
        position_ids (`torch.LongTensor`, *optional*): forwarded unchanged to `self.self_attn`.
        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        output_attentions (`bool`, *optional*):
            If truthy, the attention weights returned by `self.self_attn` are appended to the output tuple.
        use_cache (`bool`, *optional*):
            If truthy, the present key/value returned by `self.self_attn` is appended to the output tuple.
        padding_mask (`torch.LongTensor`, *optional*): forwarded unchanged to `self.self_attn`.

    Returns:
        Tuple starting with the layer output, optionally followed by the
        attention weights and the present key/value (in that order).
    """
    chunk_size = 32000

    # Keep an untouched copy for the residual connection before the in-place norm.
    skip = hidden_states.clone()
    _, total_len, _ = hidden_states.shape

    for lo in range(0, total_len, chunk_size):
        hi = min(total_len, lo + chunk_size)
        hidden_states[:, lo:hi, :] = self.input_layernorm(hidden_states[:, lo:hi, :])

    # Self Attention
    attn_output, attn_weights, present = self.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        padding_mask=padding_mask,
    )
    hidden_states = skip + attn_output

    # Fully Connected: norm + MLP per slice, accumulated onto the slice itself.
    for lo in range(0, total_len, chunk_size):
        hi = min(total_len, lo + chunk_size)
        piece = hidden_states[:, lo:hi, :].clone()
        piece = self.mlp(self.post_attention_layernorm(piece))
        hidden_states[:, lo:hi, :] += piece

    result = (hidden_states,)
    if output_attentions:
        result += (attn_weights,)
    if use_cache:
        result += (present,)
    return result
|
502 |
+
|
503 |
+
|
504 |
+
def forward_llama_model(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> "Union[Tuple, BaseModelOutputWithPast]":
    """Run the embedding, every decoder layer and the final norm.

    The final norm is applied in place over slices of 32000 positions.
    Flags left as ``None`` fall back to the values in ``self.config``.

    Args:
        input_ids / inputs_embeds: exactly one of the two must be given.
        attention_mask: optional mask; an all-ones boolean mask covering past
            and current positions is created when it is missing.
        position_ids: optional; generated from the past length when missing.
        past_key_values: a ``Cache`` instance or a legacy cache, only
            inspected when ``use_cache`` is truthy.
        use_cache, output_attentions, output_hidden_states, return_dict:
            control which extra outputs are collected and the return type.

    Returns:
        ``BaseModelOutputWithPast`` when ``return_dict`` is truthy, otherwise
        a tuple of the non-``None`` values among (hidden states, cache,
        all hidden states, all attentions).

    Raises:
        ValueError: if both or neither of ``input_ids`` and ``inputs_embeds``
            are provided.
    """
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both input_ids and inputs_embeds at the same time"
        )
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape[:2]
    elif inputs_embeds is not None:
        batch_size, seq_length = inputs_embeds.shape[:2]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    seq_length_with_past = seq_length
    past_key_values_length = 0

    if use_cache:
        use_legacy_cache = not isinstance(past_key_values, Cache)
        if use_legacy_cache:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        past_key_values_length = past_key_values.get_usable_length(seq_length)
        seq_length_with_past = seq_length_with_past + past_key_values_length

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length,
            seq_length + past_key_values_length,
            dtype=torch.long,
            device=device,
        )
        position_ids = position_ids.unsqueeze(0)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    # The previous version also derived a `padding_mask` here by scanning the
    # whole mask for zeros, but the value was never used; it has been dropped.
    if attention_mask is None:
        attention_mask = torch.ones(
            (batch_size, seq_length_with_past),
            dtype=torch.bool,
            device=inputs_embeds.device,
        )

    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for decoder_layer in self.layers:
        if output_hidden_states:
            # Store a copy: the decoder layer installed together with this
            # forward normalises its input tensor in place, which would
            # otherwise overwrite the state recorded here.
            all_hidden_states += (hidden_states.clone(),)

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    # Final norm, applied in place slice by slice.
    batch, seq_len, embed_dim = hidden_states.shape
    for start_idx in range(0, seq_len, 32000):
        end_idx = min(seq_len, start_idx + 32000)
        hidden_states[:, start_idx:end_idx, :] = self.norm(
            hidden_states[:, start_idx:end_idx, :]
        )

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = None
    if use_cache:
        next_cache = (
            next_decoder_cache.to_legacy_cache()
            if use_legacy_cache
            else next_decoder_cache
        )
    if not return_dict:
        return tuple(
            v
            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
            if v is not None
        )
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
|
661 |
+
|
662 |
+
|
663 |
+
def forward_llama_for_causal_lm(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Causal-LM forward that avoids materialising logits for the whole sequence.

    * With ``labels``: the LM head and a summed cross-entropy are evaluated
      over slices of 32000 positions; the sum is divided by the number of
      entries of ``labels[:, 1:]`` that are ``>= 0``. ``logits`` is ``None``.
    * Without ``labels``: logits are computed for the last position only,
      unless the config dict has a truthy ``is_ppl`` entry, in which case
      logits are computed for every position. ``loss`` is ``None``.

    ``return_dict`` is accepted for signature compatibility; the result is
    always a ``CausalLMOutputWithPast`` holding ``loss``, ``logits`` and
    ``past_key_values``.
    """
    # assert labels is not None
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    # The inner model is always asked for a ModelOutput: `past_key_values` is
    # read by attribute below, which fails on the tuple returned when
    # `return_dict` is False.
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
    )
    torch.cuda.empty_cache()

    hidden_states = outputs[0]
    if labels is not None:
        loss_fct = CrossEntropyLoss(reduction="sum")
        # Taken from the hidden states rather than `input_ids`, so that
        # `labels` also work together with `inputs_embeds`.
        valid_seq_len = hidden_states.shape[-2] - 1
        valid_seq_len_slide_win = torch.sum(labels[:, 1:] >= 0).item()
        loss = 0.0

        for start_idx in range(0, valid_seq_len, 32000):
            end_idx = min(start_idx + 32000, valid_seq_len)
            shift_logits = self.lm_head(
                hidden_states[..., start_idx:end_idx, :]
            ).float()
            # Position i predicts token i + 1.
            shift_labels = labels[..., start_idx + 1 : end_idx + 1].contiguous()
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss += loss_fct(shift_logits, shift_labels)

        loss /= valid_seq_len_slide_win
        logits = None
    else:
        if self.config.to_dict().get("is_ppl", False):
            logits = self.lm_head(hidden_states)
        else:
            logits = self.lm_head(hidden_states[:, -1:]).float()
        loss = None

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
    )
|
740 |
+
|
741 |
+
|
742 |
+
def minference_patch(model, config):
    """Patch ``model`` in place so that it runs the MInference forwards, and return it.

    Dispatches to the KV-cache-on-CPU or SnapKV variant when the
    corresponding ``config`` flag is set. Otherwise every attention module
    gets the MInference forward and helper methods, every decoder layer gets
    the chunked layer forward, and the model-level ``forward``,
    ``prepare_inputs_for_generation`` and attention-mask preparation are
    replaced.

    Args:
        model: causal-LM wrapper exposing ``model.model.layers``.
        config: object with ``kv_cache_cpu`` and ``use_snapkv`` attributes.
    """
    if config.kv_cache_cpu:
        return minference_patch_kv_cache_cpu(model)
    if config.use_snapkv:
        return minference_patch_with_snapkv(model)

    # The concrete classes are taken from the first layer so that the patch
    # follows whatever architecture `model` actually uses.
    Attention = model.model.layers[0].self_attn.__class__
    DecoderLayer = model.model.layers[0].__class__

    forward = minference_forward()

    def update_module(m):
        # `__get__` binds the plain functions to the instance as methods.
        if isinstance(m, Attention):
            m.init_minference_parameters = init_minference_parameters.__get__(
                m, Attention
            )
            m.gather_last_q_vertical_slash_topk_v4 = (
                gather_last_q_vertical_slash_topk_v4.__get__(m, Attention)
            )
            m.forward = forward.__get__(m, Attention)
        if isinstance(m, DecoderLayer):
            m.forward = forward_llama_decoder_layer.__get__(m, DecoderLayer)

    model.apply(update_module)
    model.prepare_inputs_for_generation = hf_437_prepare_inputs_for_generation.__get__(
        model, model.__class__
    )
    model.model._use_sdpa = False

    model.model._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask_inference.__get__(
            model.model, model.model.__class__
        )
    )
    model.model.forward = forward_llama_model.__get__(
        model.model, model.model.__class__
    )
    model.forward = forward_llama_for_causal_lm.__get__(model, model.__class__)

    print("Patched model for minference..")
    return model
|
786 |
+
|
787 |
+
|
788 |
+
def minference_patch_kv_cache_cpu(model):
    """Patch ``model`` in place for MInference with the CPU KV-cache forward, and return it.

    NOTE: this also replaces ``update`` and ``get`` on the
    ``transformers.cache_utils.DynamicCache`` class itself, so every
    ``DynamicCache`` in the process is affected, not only the ones used by
    ``model``.

    Args:
        model: causal-LM wrapper exposing ``model.model.layers``.
    """
    transformers.cache_utils.DynamicCache.update = cpu_cache_update
    transformers.cache_utils.DynamicCache.get = cpu_cache_get

    # The concrete classes are taken from the first layer so that the patch
    # follows whatever architecture `model` actually uses.
    Attention = model.model.layers[0].self_attn.__class__
    DecoderLayer = model.model.layers[0].__class__

    forward = minference_kv_cache_cpu_forward()

    def update_module(m):
        # `__get__` binds the plain functions to the instance as methods.
        if isinstance(m, Attention):
            m.init_minference_parameters = init_minference_parameters.__get__(
                m, Attention
            )
            m.gather_last_q_vertical_slash_topk_v4 = (
                gather_last_q_vertical_slash_topk_v4.__get__(m, Attention)
            )
            m.forward = forward.__get__(m, Attention)
        if isinstance(m, DecoderLayer):
            m.forward = forward_llama_decoder_layer.__get__(m, DecoderLayer)

    model.apply(update_module)
    model.prepare_inputs_for_generation = hf_437_prepare_inputs_for_generation.__get__(
        model, model.__class__
    )
    model.model._use_sdpa = False

    model.model._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask_inference.__get__(
            model.model, model.model.__class__
        )
    )
    model.model.forward = forward_llama_model.__get__(
        model.model, model.model.__class__
    )
    model.forward = forward_llama_for_causal_lm.__get__(model, model.__class__)

    print("Patched model for MInference load KV Cache to CPU.")
    return model
|
830 |
+
|
831 |
+
|
832 |
+
def minference_patch_with_snapkv(model):
    """Patch ``model`` in place for MInference combined with SnapKV, and return it.

    Same procedure as the plain MInference patch, except that the attention
    forward comes from ``minference_with_snapkv_forward()`` and
    ``prepare_inputs_for_generation`` is replaced by the SnapKV-aware variant.

    Args:
        model: causal-LM wrapper exposing ``model.model.layers``.
    """
    # The concrete classes are taken from the first layer so that the patch
    # follows whatever architecture `model` actually uses.
    Attention = model.model.layers[0].self_attn.__class__
    DecoderLayer = model.model.layers[0].__class__

    forward = minference_with_snapkv_forward()

    def update_module(m):
        # `__get__` binds the plain functions to the instance as methods.
        if isinstance(m, Attention):
            m.init_minference_parameters = init_minference_parameters.__get__(
                m, Attention
            )
            m.gather_last_q_vertical_slash_topk_v4 = (
                gather_last_q_vertical_slash_topk_v4.__get__(m, Attention)
            )
            m.forward = forward.__get__(m, Attention)
        if isinstance(m, DecoderLayer):
            m.forward = forward_llama_decoder_layer.__get__(m, DecoderLayer)

    model.apply(update_module)
    model.prepare_inputs_for_generation = prepare_inputs_for_generation_snapkv.__get__(
        model, model.__class__
    )
    model.model._use_sdpa = False

    model.model._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask_inference.__get__(
            model.model, model.model.__class__
        )
    )
    model.model.forward = forward_llama_model.__get__(
        model.model, model.model.__class__
    )
    model.forward = forward_llama_for_causal_lm.__get__(model, model.__class__)

    # The message used to read "SanpKV".
    print("Patched model for minference with SnapKV..")
    return model
|
871 |
+
|
872 |
+
|
873 |
+
def llama_model_forward_vllm(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run every decoder layer in order, passing each its own layer index.

    ``inputs_embeds`` is used as the initial hidden state when given;
    otherwise ``input_ids`` is embedded with ``self.get_input_embeddings``.
    Each layer receives ``kv_caches[i]`` and ``layer_idx=i`` and returns the
    updated ``(hidden_states, residual)`` pair; the final pair is passed to
    ``self.norm`` and the first element of its result is returned.
    """
    hidden_states = (
        self.get_input_embeddings(input_ids)
        if inputs_embeds is None
        else inputs_embeds
    )
    residual = None
    for layer_idx, layer in enumerate(self.layers):
        hidden_states, residual = layer(
            positions,
            hidden_states,
            kv_caches[layer_idx],
            attn_metadata,
            residual,
            layer_idx=layer_idx,
        )
    normed, _ = self.norm(hidden_states, residual)
    return normed
|
898 |
+
|
899 |
+
|
900 |
+
def llama_layer_forward_vllm(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata,
    residual: Optional[torch.Tensor],
    layer_idx: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Decoder-layer forward that forwards ``layer_idx`` to self-attention.

    When ``residual`` is ``None`` the incoming hidden states become the
    residual and are normalised alone; otherwise ``self.input_layernorm`` is
    called with both and returns the new ``(hidden_states, residual)`` pair.
    The attention output then goes through ``self.post_attention_layernorm``
    (again with the residual) and ``self.mlp``.

    Returns:
        ``(hidden_states, residual)`` after the MLP.
    """
    # Self Attention
    if residual is not None:
        hidden_states, residual = self.input_layernorm(hidden_states, residual)
    else:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

    attn_output = self.self_attn(
        positions=positions,
        hidden_states=hidden_states,
        kv_cache=kv_cache,
        attn_metadata=attn_metadata,
        layer_idx=layer_idx,
    )

    # Fully Connected
    normed, residual = self.post_attention_layernorm(attn_output, residual)
    return self.mlp(normed), residual
|
927 |
+
|
928 |
+
|
929 |
+
def llama_attn_forward_vllm(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata,
    layer_idx: int,
) -> torch.Tensor:
    """Attention-block forward that passes ``layer_idx`` to ``self.attn``.

    The projected QKV tensor is split along the last dimension into chunks of
    ``q_size``, ``kv_size`` and ``kv_size``; ``self.rotary_emb`` is applied to
    the query and key; ``self.attn`` is called with ``self.kv_scale`` and the
    layer index; the first element of ``self.o_proj``'s result is returned.
    """
    fused_qkv, _ = self.qkv_proj(hidden_states)
    split_sizes = [self.q_size, self.kv_size, self.kv_size]
    query, key, value = fused_qkv.split(split_sizes, dim=-1)
    query, key = self.rotary_emb(positions, query, key)
    context = self.attn(
        query, key, value, kv_cache, attn_metadata, self.kv_scale, layer_idx
    )
    projected, _ = self.o_proj(context)
    return projected
|
943 |
+
|
944 |
+
|
945 |
+
def vllm_attn_forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: Optional[torch.Tensor],
    attn_metadata,
    kv_scale: float = 1.0,
    layer_idx: int = 0,
) -> torch.Tensor:
    """Delegate to ``self.impl.forward``, passing ``layer_idx`` as the last argument."""
    impl_forward = self.impl.forward
    return impl_forward(
        query,
        key,
        value,
        kv_cache,
        attn_metadata,
        kv_scale,
        layer_idx,
    )
|
958 |
+
|
959 |
+
|
960 |
+
def minference_patch_vllm(
    llm,
    config_file,
):
    """Patch a vLLM engine in place so its Llama attention runs MInference.

    The JSON at ``config_file`` is loaded and handed to
    ``minference_vllm_forward`` to build the attention implementation's new
    ``forward``. The model, decoder-layer and attention wrappers are rebound
    to the ``*_vllm`` forwards in this module, which pass a ``layer_idx``
    through the call chain.

    Args:
        llm: Object exposing
            ``llm_engine.model_executor.driver_worker.model_runner.model``.
        config_file: Path to the JSON config passed to
            ``minference_vllm_forward``.

    Returns:
        The same ``llm`` object, patched.
    """
    from vllm.attention import Attention
    from vllm.model_executor.models.llama import (
        LlamaAttention,
        LlamaDecoderLayer,
        LlamaModel,
    )

    # Use a context manager so the config file handle is closed promptly.
    with open(config_file) as config_fp:
        config = json.load(config_fp)
    attn_forward = minference_vllm_forward(config)

    def update_module(m):
        if isinstance(m, Attention):
            m.forward = vllm_attn_forward.__get__(m, Attention)

            # Patch the backend implementation under its own name instead of
            # rebinding ``m``, so the checks below still test the module that
            # ``apply`` is visiting.
            impl = m.impl
            impl_cls = impl.__class__
            impl.gather_last_q_vertical_slash_topk_vllm = (
                gather_last_q_vertical_slash_topk_vllm.__get__(impl, impl_cls)
            )
            impl.forward = attn_forward.__get__(impl, impl_cls)
        if isinstance(m, LlamaDecoderLayer):
            m.forward = llama_layer_forward_vllm.__get__(m, LlamaDecoderLayer)
        if isinstance(m, LlamaModel):
            m.forward = llama_model_forward_vllm.__get__(m, LlamaModel)
        if isinstance(m, LlamaAttention):
            m.forward = llama_attn_forward_vllm.__get__(m, LlamaAttention)

    llm.llm_engine.model_executor.driver_worker.model_runner.model.apply(update_module)

    print("Patched model for minference with VLLM..")
    return llm
|
996 |
+
|
997 |
+
|
998 |
+
def patch_hf(
    model,
    attn_type: str = "inf_llm",
    attn_kwargs: Optional[dict] = None,
    base=None,
    distance_scale=None,
    **kwargs,
):
    """Patch a Hugging Face causal LM in place with an alternative attention.

    The attention forward is built as
    ``huggingface_forward(ATTN_FORWRAD[attn_type](**attn_kwargs))`` and bound
    onto every attention module. The inner model's ``forward`` is replaced by
    a version that passes ``model.model.position_bias`` (a
    ``RotaryEmbeddingESM`` built from the model's own rotary embedding) as
    ``position_ids``. The original callables stay reachable as
    ``_old_forward`` / ``_old_prepare_inputs_for_generation``.

    Args:
        model: A Llama, Mistral or Qwen2 ``*ForCausalLM`` instance, or a model
            whose class is named ``MiniCPMForCausalLM`` or ``Phi3ForCausalLM``.
        attn_type: Key into ``ATTN_FORWRAD`` selecting the attention variant.
        attn_kwargs: Keyword arguments for the attention factory. The dict is
            copied, never mutated.
        base: Rotary base; defaults to the model's ``rotary_emb.base``.
        distance_scale: Rotary distance scale; defaults to ``1.0``.
        **kwargs: Extra attention arguments; they override ``attn_kwargs``.

    Returns:
        The patched model, wrapped in ``InfLLMGenerator`` when ``attn_type``
        is ``"inf_llm"``.

    Raises:
        ValueError: If the model class is not one of the supported ones.
    """
    # Merge into a fresh dict: the previous ``{}`` default was shared between
    # calls, so ``update`` leaked kwargs from one call into the next (and
    # mutated any dict supplied by the caller).
    attn_kwargs = {**(attn_kwargs or {}), **kwargs}
    # This approach lacks scalability and will be refactored.
    from transformers import LlamaForCausalLM, MistralForCausalLM, Qwen2ForCausalLM
    from transformers.models.llama.modeling_llama import BaseModelOutputWithPast

    def model_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        *args,
        **kwargs,
    ):
        """Replacement for the inner model's ``forward``.

        The caller's ``position_ids`` is ignored; each layer receives
        ``self.position_bias`` in that slot instead.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
            )
        elif input_ids is not None:
            # The unpacking doubles as a rank check on the input.
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
            if hasattr(self, "config") and hasattr(self.config, "scale_emb"):
                inputs_embeds = inputs_embeds * self.config.scale_emb

        # Per-layer caches are collected into a tuple when caching is on.
        pkv = tuple() if use_cache else None

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for i, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=self.position_bias,
                past_key_value=(
                    past_key_values[i] if past_key_values is not None else None
                ),
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache follows the attention weights when those are returned.
                _cache = layer_outputs[2 if output_attentions else 1]
                pkv = pkv + (_cache,)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # Apply the final norm in place, 32000 positions at a time
        # (presumably to bound peak memory on long inputs -- TODO confirm).
        norm_chunk = 32000
        total_len = hidden_states.size(1)
        for start_idx in range(0, total_len, norm_chunk):
            end_idx = min(total_len, start_idx + norm_chunk)
            hidden_states[:, start_idx:end_idx, :] = self.norm(
                hidden_states[:, start_idx:end_idx, :]
            )

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, pkv, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=pkv,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    forward = huggingface_forward(ATTN_FORWRAD[attn_type](**attn_kwargs))

    # MiniCPM and Phi3 are matched by class name rather than by an imported
    # class.
    is_supported = isinstance(
        model, (LlamaForCausalLM, MistralForCausalLM, Qwen2ForCausalLM)
    ) or model.__class__.__name__ in ("MiniCPMForCausalLM", "Phi3ForCausalLM")
    if not is_supported:
        raise ValueError(
            "Only supports llama, mistral, qwen2, MiniCPM and Phi3 models."
        )
    Attention = model.model.layers[0].self_attn.__class__
    Model = model.model.__class__

    hf_rope = model.model.layers[0].self_attn.rotary_emb
    base = base if base is not None else hf_rope.base
    distance_scale = distance_scale if distance_scale is not None else 1.0
    rope = RotaryEmbeddingESM(hf_rope.dim, base, distance_scale)
    model.model.position_bias = rope
    model.model.hf_position_bias = hf_rope

    def set_forward(m):
        if isinstance(m, Attention):
            m._old_forward = m.forward
            m.forward = forward.__get__(m, Attention)

    model.apply(set_forward)

    model._old_prepare_inputs_for_generation = model.prepare_inputs_for_generation
    model.prepare_inputs_for_generation = prepare_inputs_for_generation.__get__(
        model, model.__class__
    )
    model.model._old_forward = model.model.forward
    model.model.forward = model_forward.__get__(model.model, Model)

    if attn_type == "inf_llm":
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model.config._name_or_path
        )
        model = InfLLMGenerator(model, tokenizer)

    print("Patched model ...")
    return model
|
1178 |
+
|
1179 |
+
|
1180 |
+
def fp8_cache_update(
|
1181 |
+
self,
|
1182 |
+
key_states: torch.Tensor,
|
1183 |
+
value_states: torch.Tensor,
|
1184 |
+
layer_idx: int,
|
1185 |
+
cache_kwargs: Optional[Dict[str, Any]] = None,
|
1186 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
1187 |
+
"""
|
1188 |
+
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
1189 |
+
|
1190 |
+
Parameters:
|
1191 |
+
key_states (`torch.Tensor`):
|
1192 |
+
The new key states to cache.
|
1193 |
+
value_states (`torch.Tensor`):
|
1194 |
+
The new value states to cache.
|
1195 |
+
layer_idx (`int`):
|
1196 |
+
The index of the layer to cache the states for.
|
1197 |
+
cache_kwargs (`Dict[str, Any]`, `optional`):
|
1198 |
+
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
|
1199 |
+
|
1200 |
+
Return:
|
1201 |
+
A tuple containing the updated key and value states.
|
1202 |
+
"""
|
1203 |
+
# Update the number of seen tokens
|
1204 |
+
if layer_idx == 0:
|
1205 |
+
self.seen_tokens += key_states.shape[-2]
|
1206 |
+
|
1207 |
+
# Update the cache
|
1208 |
+
if len(self.key_cache) <= layer_idx:
|
1209 |
+
self.key_cache.append(key_states.to(torch.float8_e5m2))
|
1210 |
+
self.value_cache.append(value_states.to(torch.float8_e5m2))
|
1211 |
+
else:
|
1212 |
+
self.key_cache[layer_idx] = torch.cat(
|
1213 |
+
[self.key_cache[layer_idx], key_states.to(torch.float8_e5m2)], dim=-2
|
1214 |
+
)
|
1215 |
+
self.value_cache[layer_idx] = torch.cat(
|
1216 |
+
[self.value_cache[layer_idx], value_states.to(torch.float8_e5m2)], dim=-2
|
1217 |
+
)
|
1218 |
+
|
1219 |
+
return self.key_cache[layer_idx].to(key_states.dtype), self.value_cache[
|
1220 |
+
layer_idx
|
1221 |
+
].to(key_states.dtype)
|
1222 |
+
|
1223 |
+
|
1224 |
+
def cpu_cache_update(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    cache_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    """Append new key/value states to the cache, storing them on the CPU.

    Nothing is returned: the merged states for a head are read back with
    ``cpu_cache_get``.

    Parameters:
        key_states (`torch.Tensor`):
            The new key states to cache.
        value_states (`torch.Tensor`):
            The new value states to cache.
        layer_idx (`int`):
            The index of the layer to cache the states for.
        cache_kwargs (`Dict[str, Any]`, `optional`):
            Unused.
    """
    # Update the number of seen tokens; the counter may live under either
    # attribute name.
    if layer_idx == 0:
        new_tokens = key_states.shape[-2]
        if "_seen_tokens" in self.__dict__:
            self._seen_tokens += new_tokens
        else:
            self.seen_tokens += new_tokens

    key_cpu = key_states.cpu()
    value_cpu = value_states.cpu()

    # Update the cache
    if len(self.key_cache) <= layer_idx:
        self.key_cache.append(key_cpu)
        self.value_cache.append(value_cpu)
    else:
        self.key_cache[layer_idx] = torch.cat(
            [self.key_cache[layer_idx], key_cpu], dim=-2
        )
        self.value_cache[layer_idx] = torch.cat(
            [self.value_cache[layer_idx], value_cpu], dim=-2
        )
|
1248 |
+
|
1249 |
+
|
1250 |
+
def cpu_cache_get(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    head_idx: int,
    cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return cached plus new key/value states for a single head.

    The cache itself is not modified. When ``layer_idx`` has no cache entry
    yet, the inputs are returned unchanged. Otherwise the slice
    ``[:, head_idx : head_idx + 1]`` of the cached tensors is moved to the
    device of ``key_states`` and concatenated in front of the new states
    along dim ``-2``.

    Parameters:
        key_states (`torch.Tensor`):
            The new key states for the requested head.
        value_states (`torch.Tensor`):
            The new value states for the requested head.
        layer_idx (`int`):
            The index of the layer to read.
        head_idx (`int`):
            Index selected along dim 1 of the cached tensors.
        cache_kwargs (`Dict[str, Any]`, `optional`):
            Unused.

    Return:
        A tuple ``(key_states, value_states)`` including the cached prefix.
    """
    # NOTE(review): this getter also advances the seen-token counter for
    # layer 0, as `cpu_cache_update` does -- confirm the repeated increment
    # is intended.
    if layer_idx == 0:
        new_tokens = key_states.shape[-2]
        if "_seen_tokens" in self.__dict__:
            self._seen_tokens += new_tokens
        else:
            self.seen_tokens += new_tokens

    if len(self.key_cache) <= layer_idx:
        return key_states, value_states

    # Move the cached slice to wherever the new states live instead of the
    # current CUDA device: torch.cat needs both operands on one device, which
    # a bare `.cuda()` does not guarantee for sharded models (and it fails
    # outright without CUDA).
    device = key_states.device
    head = slice(head_idx, head_idx + 1)
    cached_key = self.key_cache[layer_idx][:, head].to(device)
    cached_value = self.value_cache[layer_idx][:, head].to(device)

    key_states = torch.cat([cached_key, key_states], dim=-2)
    value_states = torch.cat([cached_value, value_states], dim=-2)
    return key_states, value_states
|
minference/version.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Microsoft
|
2 |
+
# Licensed under The MIT License [see LICENSE for details]
|
3 |
+
|
4 |
+
_MAJOR = "0"
_MINOR = "1"
# On master and in a nightly release the patch should be one ahead of the last
# released build.
_PATCH = "0"
# This is mainly for nightly builds which have the suffix ".dev$DATE". See
# https://semver.org/#is-v123-a-semantic-version for the semantics.
_SUFFIX = "alpha.1"

# "<major>.<minor>"
VERSION_SHORT = f"{_MAJOR}.{_MINOR}"
# "<major>.<minor>.<patch><suffix>" -- the suffix is appended with no separator.
VERSION = f"{_MAJOR}.{_MINOR}.{_PATCH}{_SUFFIX}"
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
flash_attn
|
2 |
+
triton==2.1.0
|
3 |
+
pycuda==2023.1
|
4 |
+
accelerate
|
5 |
+
transformers
|