# onnx_kv_inject.py -- created by mwitiderrick (commit d9b2258, "Create onnx_kv_inject.py").
import click
import os
import onnx
from sparseml.exporters.kv_cache_injector import KeyValueCacheInjector
from sparseml.onnx.utils import ONNXGraph
@click.command()
@click.option(
    "--input-file",
    required=True,
    type=click.Path(exists=True, dir_okay=False),
    help="Path to the input ONNX model file",
)
@click.option(
    "--output-file",
    required=True,
    type=click.Path(dir_okay=False),
    help="Output path for the modified model",
)
def modify_model(input_file, output_file):
    """Inject a key/value cache into an ONNX model and save the result.

    The model at ``input_file`` is loaded without its external data, passed
    through ``KeyValueCacheInjector``, stripped of orphaned node branches,
    and written to ``output_file``.

    Args:
        input_file: Path to an existing ONNX model file.
        output_file: Path the modified model is written to.
    """
    # Resolve to an absolute path first so that a bare filename such as
    # "model.onnx" yields its real containing directory instead of "".
    model_dir = os.path.dirname(os.path.abspath(input_file))

    # External tensor data is deliberately not loaded; only the graph is
    # modified here.
    # NOTE(review): any external-data references are therefore saved as-is;
    # presumably the output must live next to the original weights - confirm.
    model = onnx.load(input_file, load_external_data=False)

    # The injector is given the directory that contains the input model
    # (presumably to locate the model's configuration - confirm).
    model = KeyValueCacheInjector(model_path=model_dir).apply(model)

    # Drop node branches left unconnected after the injection.
    graph = ONNXGraph(model)
    graph.delete_orphaned_node_branches()

    onnx.save(model, output_file)
    print(f"Modified model saved to: {output_file}")
# Script entry point: run the click command only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    modify_model()