mwitiderrick committed on
Commit
d9b2258
1 Parent(s): 08d7494

Create onnx_kv_inject.py

Browse files
Files changed (1) hide show
  1. onnx_kv_inject.py +17 -0
onnx_kv_inject.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import click
2
+ import os
3
+ import onnx
4
+ from sparseml.exporters.kv_cache_injector import KeyValueCacheInjector
5
+ from sparseml.onnx.utils import ONNXGraph
6
@click.command()
@click.option(
    '--input-file',
    required=True,
    type=click.Path(exists=True, dir_okay=False),
    help='Path to the input ONNX model file',
)
@click.option(
    '--output-file',
    required=True,
    help='Output path for the modified model',
)
def modify_model(input_file, output_file):
    """Apply sparseml's KeyValueCacheInjector to an ONNX model and save it.

    The model at --input-file is loaded, transformed by the injector,
    stripped of orphaned node branches, and written to --output-file.
    """
    # Only the graph is loaded; external tensor data is left on disk.
    # NOTE(review): because external data is not loaded, the saved model
    # presumably still points at the original external data files, so the
    # output likely needs to live next to them -- confirm.
    model = onnx.load(input_file, load_external_data=False)

    # Resolve to an absolute path first: os.path.dirname("model.onnx") is "",
    # which would hand the injector an empty model_path when the input is
    # given as a bare filename in the current directory.
    model_dir = os.path.dirname(os.path.abspath(input_file))
    model = KeyValueCacheInjector(model_path=model_dir).apply(model)

    # Remove node branches left orphaned in the graph after the injection.
    graph = ONNXGraph(model)
    graph.delete_orphaned_node_branches()

    onnx.save(model, output_file)
    print(f"Modified model saved to: {output_file}")


if __name__ == '__main__':
    modify_model()