Optimum documentation

Export a model to ONNX with optimum.exporters.onnx

Summary

Exporting a model to ONNX is as simple as:

optimum-cli export onnx --model gpt2 gpt2_onnx/

Check out the help for more options:

optimum-cli export onnx --help

Why use ONNX?

If you need to deploy 🤗 Transformers or 🤗 Diffusers models in production environments, we recommend exporting them to a serialized format that can be loaded and executed on specialized runtimes and hardware. In this guide, we’ll show you how to export these models to ONNX (Open Neural Network eXchange).

ONNX is an open standard that defines a common set of operators and a common file format to represent deep learning models in a wide variety of frameworks, including PyTorch and TensorFlow. When a model is exported to the ONNX format, these operators are used to construct a computational graph (often called an intermediate representation) which represents the flow of data through the neural network.
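To make the idea of a computational graph concrete, here is a toy sketch in plain Python. It is not the ONNX file format and does not use the onnx package; the Node class, the OPERATORS table and the run_graph function are hypothetical names invented for this illustration. It only shows the principle: a model becomes an ordered list of nodes, each applying a named operator from a shared set to named tensors.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List

# A tiny, made-up operator set standing in for ONNX's standardized operators.
OPERATORS: Dict[str, Callable[..., List[float]]] = {
    "Add": lambda a, b: [x + y for x, y in zip(a, b)],
    "Mul": lambda a, b: [x * y for x, y in zip(a, b)],
    "Relu": lambda a: [max(0.0, x) for x in a],
}


@dataclass
class Node:
    op_type: str        # name of the operator to apply
    inputs: List[str]   # names of the tensors the operator reads
    outputs: List[str]  # names of the tensors it writes (one per node in this toy)


def run_graph(nodes: List[Node], feeds: Dict[str, List[float]]) -> Dict[str, List[float]]:
    """Evaluate the nodes in order, storing every intermediate tensor by name."""
    values = dict(feeds)
    for node in nodes:
        args = [values[name] for name in node.inputs]
        values[node.outputs[0]] = OPERATORS[node.op_type](*args)
    return values


# The graph for y = Relu(x * w + b), written as three operator nodes.
graph = [
    Node("Mul", ["x", "w"], ["xw"]),
    Node("Add", ["xw", "b"], ["z"]),
    Node("Relu", ["z"], ["y"]),
]

result = run_graph(graph, {"x": [1.0, -2.0], "w": [3.0, 3.0], "b": [0.5, 0.5]})
print(result["y"])  # [3.5, 0.0]
```

Because the graph only refers to operators by name, any runtime that implements the same operator set can execute it, which is the property ONNX relies on.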

By exposing a graph with standardized operators and data types, ONNX makes it easy to switch between frameworks. For example, a model trained in PyTorch can be exported to ONNX format and then imported in TensorRT or OpenVINO.

Once exported, a model can be optimized for inference via techniques such as graph optimization and quantization. Check the optimum.onnxruntime subpackage to optimize and run ONNX models!

🤗 Optimum supports the ONNX export through configuration objects. These configuration objects come ready-made for a number of model architectures, and are designed to be easily extended to other architectures.

To check the supported architectures, go to the configuration reference page.

Exporting a model to ONNX using the CLI

To export a 🤗 Transformers or 🤗 Diffusers model to ONNX, you’ll first need to install some extra dependencies:

pip install optimum[exporters]

The Optimum ONNX export can be used through the Optimum command line:

optimum-cli export onnx --help

usage: optimum-cli <command> [<args>] export onnx [-h] -m MODEL [--task TASK] [--monolith] [--device DEVICE] [--opset OPSET] [--atol ATOL]
                                                  [--framework {pt,tf}] [--pad_token_id PAD_TOKEN_ID] [--cache_dir CACHE_DIR] [--trust-remote-code]
                                                  [--no-post-process] [--optimize {O1,O2,O3,O4}] [--batch_size BATCH_SIZE]
                                                  [--sequence_length SEQUENCE_LENGTH] [--num_choices NUM_CHOICES] [--width WIDTH] [--height HEIGHT]
                                                  [--num_channels NUM_CHANNELS] [--feature_size FEATURE_SIZE] [--nb_max_frames NB_MAX_FRAMES]
                                                  [--audio_sequence_length AUDIO_SEQUENCE_LENGTH]
                                                  output

optional arguments:
  -h, --help            show this help message and exit

Required arguments:
  -m MODEL, --model MODEL
                        Model ID on huggingface.co or path on disk to load model from.
  output                Path indicating the directory where to store generated ONNX model.

Optional arguments:
  --task TASK           The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among: ['default', 'fill-mask', 'text-generation', 'text2text-generation', 'text-classification', 'token-classification', 'multiple-choice', 'object-detection', 'question-answering', 'image-classification', 'image-segmentation', 'masked-im', 'semantic-segmentation', 'automatic-speech-recognition', 'audio-classification', 'audio-frame-classification', 'automatic-speech-recognition', 'audio-xvector', 'image-to-text', 'zero-shot-object-detection', 'image-to-image', 'inpainting', 'text-to-image']. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder.
  --monolith            Force to export the model as a single ONNX file. By default, the ONNX exporter may break the model in several ONNX files, for example for encoder-decoder models where the encoder should be run only once while the decoder is looped over.
  --device DEVICE       The device to use to do the export. Defaults to "cpu".
  --opset OPSET         If specified, ONNX opset version to export the model with. Otherwise, the default opset will be used.
  --atol ATOL           If specified, the absolute difference tolerance when validating the model. Otherwise, the default atol for the model will be used.
  --framework {pt,tf}   The framework to use for the ONNX export. If not provided, will attempt to use the local checkpoint's original framework or what is available in the environment.
  --pad_token_id PAD_TOKEN_ID
                        This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it.
  --cache_dir CACHE_DIR
                        Path indicating where to store cache.
  --trust-remote-code   Allows to use custom code for the modeling hosted in the model repository. This option should only be set for repositories you trust and in which you have read the code, as it will execute on your local machine arbitrary code present in the model repository.
  --no-post-process     Allows to disable any post-processing done by default on the exported ONNX models. For example, the merging of decoder and decoder-with-past models into a single ONNX model file to reduce memory usage.
  --optimize {O1,O2,O3,O4}
                        Allows to run ONNX Runtime optimizations directly during the export. Some of these optimizations are specific to ONNX Runtime, and the resulting ONNX will not be usable with other runtime as OpenVINO or TensorRT. Possible options:
                            - O1: Basic general optimizations
                            - O2: Basic and extended general optimizations, transformers-specific fusions
                            - O3: Same as O2 with GELU approximation
                            - O4: Same as O3 with mixed precision (fp16, GPU-only, requires `--device cuda`)

Exporting a checkpoint can be done as follows:

optimum-cli export onnx --model distilbert-base-uncased-distilled-squad distilbert_base_uncased_squad_onnx/

You should see the following logs (along with potential logs from PyTorch / TensorFlow that were hidden here for clarity):

Automatic task detection to question-answering.
Framework not specified. Using pt to export the model.
Using framework PyTorch: 1.12.1

Validating ONNX model...
        -[✓] ONNX model output names match reference model (start_logits, end_logits)
        - Validating ONNX Model output "start_logits":
                -[✓] (2, 16) matches (2, 16)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "end_logits":
                -[✓] (2, 16) matches (2, 16)
                -[✓] all values close (atol: 0.0001)
All good, model saved at: distilbert_base_uncased_squad_onnx/model.onnx

This exports an ONNX graph of the checkpoint defined by the --model argument. As you can see, the task was automatically detected. This was possible because the model was on the Hub.
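The "all values close (atol: 0.0001)" lines in the logs report an absolute-tolerance comparison between the reference model's outputs and the exported model's outputs. As a rough sketch of what such a check means (the all_close helper and the sample values below are made up for illustration, and the exporter's actual validation may differ, for example by also applying a relative tolerance):

```python
from typing import Sequence


def all_close(reference: Sequence[float], exported: Sequence[float], atol: float = 1e-4) -> bool:
    """Return True when every pair of values differs by at most `atol`."""
    if len(reference) != len(exported):
        return False
    return all(abs(r - e) <= atol for r, e in zip(reference, exported))


# Made-up values: the second list differs from the first by about 1e-5 per element.
reference_logits = [2.6287, 7.6111, -1.2488]
exported_logits = [2.62872, 7.61108, -1.24881]

print(all_close(reference_logits, exported_logits))            # True
print(all_close(reference_logits, [2.6287, 7.7111, -1.2488]))  # False: one value is off by 0.1
```

Passing a larger --atol loosens this comparison, which can be reasonable for reduced-precision exports where small numerical differences are expected.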

For local models, you need to provide the --task argument; otherwise the export defaults to the model architecture without any task-specific head:

optimum-cli export onnx --model local_path --task question-answering distilbert_base_uncased_squad_onnx/

Note that providing the --task argument for a model on the Hub will disable the automatic task detection.
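If you drive the export from a Python script, one option is to assemble the same command line and hand it to subprocess. The sketch below is only an illustration: build_export_command is a hypothetical helper, not part of Optimum, and it only uses flags that appear in the help output above.

```python
import shlex
from typing import List, Optional


def build_export_command(
    model: str,
    output: str,
    task: Optional[str] = None,
    opset: Optional[int] = None,
    monolith: bool = False,
) -> List[str]:
    """Assemble an `optimum-cli export onnx` invocation as an argument list."""
    command = ["optimum-cli", "export", "onnx", "--model", model]
    if task is not None:
        command += ["--task", task]
    if opset is not None:
        command += ["--opset", str(opset)]
    if monolith:
        command.append("--monolith")
    command.append(output)  # the output directory is the positional argument
    return command


command = build_export_command(
    "local_path",
    "distilbert_base_uncased_squad_onnx/",
    task="question-answering",
)
print(shlex.join(command))

# To actually run the export (requires optimum[exporters] to be installed):
# import subprocess
# subprocess.run(command, check=True)
```

Keeping the command as a list rather than a single string avoids shell-quoting problems when a model path contains spaces.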

The resulting model.onnx file can then be run on one of the many accelerators that support the ONNX standard. For example, we can load and run the model with ONNX Runtime using the optimum.onnxruntime package as follows:

>>> from transformers import AutoTokenizer
>>> from optimum.onnxruntime import ORTModelForQuestionAnswering

>>> tokenizer = AutoTokenizer.from_pretrained("distilbert_base_uncased_squad_onnx")
>>> model = ORTModelForQuestionAnswering.from_pretrained("distilbert_base_uncased_squad_onnx")
>>> inputs = tokenizer("What am I using?", "Using DistilBERT with ONNX Runtime!", return_tensors="pt")
>>> outputs = model(**inputs)

Printing the outputs gives:

QuestionAnsweringModelOutput(loss=None, start_logits=tensor([[-4.7652, -1.0452, -7.0409, -4.6864, -4.0277, -6.2021, -4.9473,  2.6287,
          7.6111, -1.2488, -2.0551, -0.9350,  4.9758, -0.7707,  2.1493, -2.0703,
         -4.3232, -4.9472]]), end_logits=tensor([[ 0.4382, -1.6502, -6.3654, -6.0661, -4.1482, -3.5779, -0.0774, -3.6168,
         -1.8750, -2.8910,  6.2582,  0.5425, -3.7699,  3.8232, -1.5073,  6.2311,
          3.3604, -0.0772]]), hidden_states=None, attentions=None)

As you can see, converting a model to ONNX does not mean leaving the Hugging Face ecosystem. You end up with an API similar to that of regular 🤗 Transformers models!
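To turn start and end logits like the ones printed above into an answer span, a common approach is to pick the (start, end) pair with the highest combined score, subject to start <= end. The following is a minimal sketch in plain Python, using the logits copied from the output above; best_span and max_answer_len are hypothetical names chosen for this example, and with real tensors you would first convert them to lists (for instance with .tolist()).

```python
from typing import List, Tuple


def best_span(start_logits: List[float], end_logits: List[float], max_answer_len: int = 15) -> Tuple[int, int]:
    """Return the (start, end) token indices maximizing start_logit + end_logit with start <= end."""
    best_score = float("-inf")
    best = (0, 0)
    for start, start_score in enumerate(start_logits):
        # Only consider ends at or after the start, within the maximum answer length.
        last = min(len(end_logits), start + max_answer_len)
        for end in range(start, last):
            score = start_score + end_logits[end]
            if score > best_score:
                best_score, best = score, (start, end)
    return best


start_logits = [-4.7652, -1.0452, -7.0409, -4.6864, -4.0277, -6.2021, -4.9473, 2.6287,
                7.6111, -1.2488, -2.0551, -0.9350, 4.9758, -0.7707, 2.1493, -2.0703,
                -4.3232, -4.9472]
end_logits = [0.4382, -1.6502, -6.3654, -6.0661, -4.1482, -3.5779, -0.0774, -3.6168,
              -1.8750, -2.8910, 6.2582, 0.5425, -3.7699, 3.8232, -1.5073, 6.2311,
              3.3604, -0.0772]

start, end = best_span(start_logits, end_logits)
print(start, end)  # 8 10

# The indices refer to token positions; they could be mapped back to text with
# tokenizer.decode(inputs["input_ids"][0][start : end + 1]).
```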

It is also possible to export the model to ONNX directly from the ORTModelForQuestionAnswering class by doing the following:

>>> model = ORTModelForQuestionAnswering.from_pretrained("distilbert-base-uncased-distilled-squad", export=True)

For more information, check the optimum.onnxruntime documentation page on this topic.

The process is identical for TensorFlow checkpoints on the Hub. For example, we can export a pure TensorFlow checkpoint from the Keras organization as follows:

optimum-cli export onnx --model keras-io/transformers-qa distilbert_base_cased_squad_onnx/

Exporting a model to be used with Optimum’s ORTModel

Models exported through optimum-cli export onnx can be used directly in ORTModel. This is especially useful for encoder-decoder models: the export splits the encoder and decoder into two .onnx files, as the encoder is usually run only once while the decoder may be run several times in autoregressive tasks.
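A quick way to see how an export was split is to list the .onnx files in the output directory. The sketch below uses only the standard library; list_onnx_files is a hypothetical helper, and the demonstration creates empty placeholder files in a temporary directory, so the file names shown are assumptions for illustration rather than a guarantee of what a given export produces.

```python
import tempfile
from pathlib import Path
from typing import List


def list_onnx_files(export_dir: str) -> List[str]:
    """Return the sorted names of the .onnx files found directly in `export_dir`."""
    return sorted(path.name for path in Path(export_dir).glob("*.onnx"))


# Demonstration on placeholder files (assumed names, empty content).
with tempfile.TemporaryDirectory() as tmp:
    for name in ("encoder_model.onnx", "decoder_model.onnx", "config.json"):
        (Path(tmp) / name).touch()
    found = list_onnx_files(tmp)

print(found)  # ['decoder_model.onnx', 'encoder_model.onnx']
```

On a real export you would call list_onnx_files on the output directory passed to optimum-cli.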

Exporting a model using past keys/values in the decoder

When exporting a decoder model used for generation, it can be useful to have the exported ONNX model reuse past keys and values. This avoids recomputing the same intermediate activations at every generation step.
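A back-of-the-envelope count shows why reusing past keys and values helps. The sketch below counts how many token positions are fed through the decoder during greedy generation, with and without a cache. It is a simplified model (it ignores the cost of attending over cached keys), and positions_processed is a hypothetical helper written for this illustration.

```python
def positions_processed(prompt_len: int, new_tokens: int, use_past: bool) -> int:
    """Count token positions fed to the decoder while generating `new_tokens` tokens."""
    if new_tokens <= 0:
        return 0
    if use_past:
        # First step processes the whole prompt; each later step only the newest token.
        return prompt_len + (new_tokens - 1)
    # Without a cache, step i re-processes the prompt plus the i tokens generated so far.
    return sum(prompt_len + i for i in range(new_tokens))


without_cache = positions_processed(prompt_len=8, new_tokens=12, use_past=False)
with_cache = positions_processed(prompt_len=8, new_tokens=12, use_past=True)
print(without_cache, with_cache)  # 162 19
```

The gap grows quadratically with the number of generated tokens, which is why the export with past keys/values is the default for generation tasks.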

In the ONNX export, the past keys/values are reused by default. This behavior corresponds to --task text2text-generation-with-past, --task text-generation-with-past, or --task automatic-speech-recognition-with-past. To disable the export with past keys/values reuse, explicitly pass the task text2text-generation, text-generation or automatic-speech-recognition to optimum-cli export onnx.

A model exported using past keys/values can be used directly with Optimum’s ORTModel:

optimum-cli export onnx --model gpt2 gpt2_onnx/

and

>>> from transformers import AutoTokenizer
>>> from optimum.onnxruntime import ORTModelForCausalLM

>>> tokenizer = AutoTokenizer.from_pretrained("./gpt2_onnx/")
>>> model = ORTModelForCausalLM.from_pretrained("./gpt2_onnx/")
>>> inputs = tokenizer("My name is Arthur and I live in", return_tensors="pt")
>>> gen_tokens = model.generate(**inputs)
>>> print(tokenizer.batch_decode(gen_tokens))
# prints ['My name is Arthur and I live in the United States of America. I am a member of the']

Selecting a task

Specifying a --task should not be necessary in most cases when exporting from a model on the Hugging Face Hub.

However, if you need to check which tasks the ONNX export supports for a given model architecture, we’ve got you covered. First, you can check the list of supported tasks for both PyTorch and TensorFlow here.

For each model architecture, you can find the list of supported tasks via the TasksManager. For example, for DistilBERT, for the ONNX export, we have:

>>> from optimum.exporters.tasks import TasksManager

>>> distilbert_tasks = list(TasksManager.get_supported_tasks_for_model_type("distilbert", "onnx").keys())
>>> print(distilbert_tasks)
['default', 'fill-mask', 'text-classification', 'multiple-choice', 'token-classification', 'question-answering']

You can then pass one of these tasks to the --task argument in the optimum-cli export onnx command, as mentioned above.

Custom export of Transformers models

Customize the export of official Transformers models

Optimum gives advanced users finer-grained control over the configuration of the ONNX export. This is especially useful if you would like to export models with different keyword arguments, for example using output_attentions=True or output_hidden_states=True.

To support these use cases, main_export supports two arguments, model_kwargs and custom_onnx_configs, which are used as follows:

  • model_kwargs overrides some of the default arguments to the model's forward method; in practice the model is called as model(**reference_model_inputs, **model_kwargs).
  • custom_onnx_configs should be a Dict[str, OnnxConfig], mapping from the submodel name (usually model, encoder_model, decoder_model, or decoder_with_past_model) to a custom ONNX configuration for the given submodel.

The complete example below exports a model with output_attentions=True.

from optimum.exporters.onnx import main_export
from optimum.exporters.onnx.model_configs import WhisperOnnxConfig
from transformers import AutoConfig

from optimum.exporters.onnx.base import ConfigBehavior
from typing import Dict

class CustomWhisperOnnxConfig(WhisperOnnxConfig):
    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        common_outputs = super().outputs

        if self._behavior is ConfigBehavior.ENCODER:
            for i in range(self._config.encoder_layers):
                common_outputs[f"encoder_attentions.{i}"] = {0: "batch_size"}
        elif self._behavior is ConfigBehavior.DECODER:
            for i in range(self._config.decoder_layers):
                common_outputs[f"decoder_attentions.{i}"] = {
                    0: "batch_size",
                    2: "decoder_sequence_length",
                    3: "past_decoder_sequence_length + 1"
                }
            for i in range(self._config.decoder_layers):
                common_outputs[f"cross_attentions.{i}"] = {
                    0: "batch_size",
                    2: "decoder_sequence_length",
                    3: "encoder_sequence_length_out"
                }

        return common_outputs

    @property
    def torch_to_onnx_output_map(self):
        if self._behavior is ConfigBehavior.ENCODER:
            # The encoder export uses WhisperEncoder that returns the key "attentions"
            return {"attentions": "encoder_attentions"}
        else:
            return {}

model_id = "openai/whisper-tiny.en"
config = AutoConfig.from_pretrained(model_id)

custom_whisper_onnx_config = CustomWhisperOnnxConfig(
        config=config,
        task="automatic-speech-recognition",
)

encoder_config = custom_whisper_onnx_config.with_behavior("encoder")
decoder_config = custom_whisper_onnx_config.with_behavior("decoder", use_past=False)
decoder_with_past_config = custom_whisper_onnx_config.with_behavior("decoder", use_past=True)

custom_onnx_configs={
    "encoder_model": encoder_config,
    "decoder_model": decoder_config,
    "decoder_with_past_model": decoder_with_past_config,
}

main_export(
    model_id,
    output="custom_whisper_onnx",
    no_post_process=True,
    model_kwargs={"output_attentions": True},
    custom_onnx_configs=custom_onnx_configs
)

For tasks that require only a single ONNX file (e.g. encoder-only), an exported model with custom inputs/outputs can then be used with the class optimum.onnxruntime.ORTModelForCustomTasks for inference with ONNX Runtime on CPU or GPU.

Customize the export of Transformers models with custom modeling

Optimum supports the export of Transformers models with custom modeling code that require trust_remote_code=True. These models are not officially supported in the Transformers library, but remain usable with its functionality such as pipelines and generation.

Examples of such models are THUDM/chatglm2-6b and mosaicml/mpt-30b.

To export custom models, a dictionary custom_onnx_configs needs to be passed to main_export(), with the ONNX config definition for all the subparts of the model to export (for example, encoder and decoder subparts). The example below exports an MPT model such as mosaicml/mpt-7b:

from optimum.exporters.onnx import main_export

from transformers import AutoConfig

from optimum.exporters.onnx.config import TextDecoderOnnxConfig
from optimum.utils import NormalizedTextConfig, DummyPastKeyValuesGenerator
from typing import Dict


class MPTDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
    """
    MPT swaps the two last dimensions for the key cache compared to usual transformers
    decoder models, thus the redefinition here.
    """
    def generate(self, input_name: str, framework: str = "pt"):
        past_key_shape = (
            self.batch_size,
            self.num_attention_heads,
            self.hidden_size // self.num_attention_heads,
            self.sequence_length,
        )
        past_value_shape = (
            self.batch_size,
            self.num_attention_heads,
            self.sequence_length,
            self.hidden_size // self.num_attention_heads,
        )
        return [
            (
                self.random_float_tensor(past_key_shape, framework=framework),
                self.random_float_tensor(past_value_shape, framework=framework),
            )
            for _ in range(self.num_layers)
        ]

class CustomMPTOnnxConfig(TextDecoderOnnxConfig):
    DUMMY_INPUT_GENERATOR_CLASSES = (MPTDummyPastKeyValuesGenerator,) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
    DUMMY_PKV_GENERATOR_CLASS = MPTDummyPastKeyValuesGenerator

    DEFAULT_ONNX_OPSET = 14  # aten::tril operator requires opset>=14
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
        hidden_size="d_model",
        num_layers="n_layers",
        num_attention_heads="n_heads"
    )

    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
        """
        Adapted from https://github.com/huggingface/optimum/blob/v1.9.0/optimum/exporters/onnx/base.py#L625
        """
        if direction not in ["inputs", "outputs"]:
            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

        if direction == "inputs":
            decoder_sequence_name = "past_sequence_length"
            name = "past_key_values"
        else:
            decoder_sequence_name = "past_sequence_length + 1"
            name = "present"

        for i in range(self._normalized_config.num_layers):
            inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch_size", 3: decoder_sequence_name}
            inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 2: decoder_sequence_name}


model_id = "/home/fxmarty/hf_internship/optimum/tiny-mpt-random-remote-code"
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

onnx_config = CustomMPTOnnxConfig(
    config=config,
    task="text-generation",
    use_past_in_inputs=False,
    use_present_in_outputs=True,
)
onnx_config_with_past = CustomMPTOnnxConfig(config, task="text-generation", use_past=True)

custom_onnx_configs = {
    "decoder_model": onnx_config,
    "decoder_with_past_model": onnx_config_with_past,
}

main_export(
    model_id,
    output="mpt_onnx",
    task="text-generation-with-past",
    trust_remote_code=True,
    custom_onnx_configs=custom_onnx_configs,
    no_post_process=True,
)

Moreover, the advanced argument fn_get_submodels of main_export lets you customize how the submodels are extracted when the model needs to be exported as several submodels. Examples of such functions can be [consulted here](link to utils.py relevant code once merged).
