AMD Ryzen AI
AMD’s Ryzen™ AI family of laptop processors provides an integrated Neural Processing Unit (NPU) that offloads AI processing tasks from the host CPU and GPU. Ryzen™ AI software consists of the Vitis™ AI execution provider (EP) for ONNX Runtime, combined with quantization tools and a pre-optimized model zoo. Ryzen™ AI technology is built on the AMD XDNA™ architecture, which is purpose-built to run AI workloads efficiently and locally.
Optimum-AMD provides an easy interface for loading Hugging Face models and running inference on the Ryzen AI accelerator.
Installation
Ryzen AI Environment setup
A Ryzen AI environment needs to be enabled to use this library. Please refer to Ryzen AI’s Installation and Runtime Setup.
Note:
The RyzenAI Model requires a runtime configuration file. A default version of this runtime configuration file can be found in the Ryzen AI VOE package, extracted during installation under the name vaip_config.json.
For more information, refer to runtime-configuration-file.
If no runtime configuration file is provided, the library uses the configuration defined in the corresponding RyzenAIXXX model class. For the available configs, see ryzenai/configs/.
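Before loading a model, it can help to confirm that the runtime configuration file exists and parses as JSON. The following is a minimal sketch using only the standard library. The helper name load_vaip_config and the example path are hypothetical, and the sketch makes no assumptions about the keys inside the file.

```python
import json
from pathlib import Path


def load_vaip_config(path):
    """Return the parsed runtime configuration, or None if no usable file is found.

    Hypothetical helper: it only checks that the file exists and is valid JSON.
    """
    config_path = Path(path)
    if not config_path.is_file():
        return None
    try:
        with config_path.open(encoding="utf-8") as handle:
            return json.load(handle)
    except json.JSONDecodeError:
        return None


if __name__ == "__main__":
    # Assumed location; replace it with the path from your Ryzen AI installation.
    config = load_vaip_config("vaip_config.json")
    if config is None:
        print("No usable runtime configuration file found.")
    else:
        print("Runtime configuration file loaded.")
```

A None result here only means that no file was found at the given path. As noted above, the library then uses the configuration defined in the model class.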
Install Optimum-AMD
git clone https://github.com/huggingface/optimum-amd.git
cd optimum-amd
pip install -e .[ryzenai]
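After installation, a quick check that the packages can be located may save time before running the examples. This is a sketch under assumptions: the helper name is_installed is hypothetical, and onnxruntime is checked only because the Vitis AI execution provider runs on ONNX Runtime. The check says nothing about whether the NPU itself is set up correctly.

```python
import importlib.util


def is_installed(module_name):
    """Return True if the module can be located by the import system."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # find_spec imports parent packages first, so a missing or broken
        # parent (for example `optimum`) surfaces here as an exception.
        return False


if __name__ == "__main__":
    for name in ("optimum.amd.ryzenai", "onnxruntime"):
        status = "found" if is_installed(name) else "missing"
        print(f"{name}: {status}")
```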
Inference with pre-optimized models
Ryzen AI provides pre-optimized models for various tasks such as image classification, super-resolution, and object detection. Here’s an example that runs ResNet for image classification:
>>> from functools import partial
>>> from datasets import load_dataset
>>> from optimum.amd.ryzenai import RyzenAIModelForImageClassification
>>> from transformers import AutoImageProcessor, pipeline
>>> model_id = "amd/resnet50"
>>> model = RyzenAIModelForImageClassification.from_pretrained(model_id)
>>> processor = AutoImageProcessor.from_pretrained(model_id)
>>> # Load image
>>> dataset = load_dataset("imagenet-1k", split="validation", streaming=True, trust_remote_code=True)
>>> data = next(iter(dataset))
>>> image = data["image"]
>>> cls_pipe = pipeline(
... "image-classification", model=model, image_processor=partial(processor, data_format="channels_last")
... )
>>> outputs = cls_pipe(image)
>>> print(outputs)
Minimal working example for 🤗 Timm
Pre-requisites
- Export the model using Optimum Exporters
- Quantize the ONNX model using the RyzenAI quantization tools. For more information on quantization, refer to the Model Quantization guide.
Load model with Ryzen AI class
>>> import requests
>>> from PIL import Image
>>> from optimum.amd.ryzenai import RyzenAIModelForImageClassification
>>> from transformers import PretrainedConfig
>>> import timm
>>> import torch
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> # See quantize.py (https://huggingface.co/mohitsha/timm-resnet18-onnx-quantized-ryzen/blob/main/quantize.py) for more details on quantization.
>>> quantized_model_path = "mohitsha/timm-resnet18-onnx-quantized-ryzen"
>>> model = RyzenAIModelForImageClassification.from_pretrained(quantized_model_path)
>>> config = PretrainedConfig.from_pretrained(quantized_model_path)
>>> # preprocess config
>>> data_config = timm.data.resolve_data_config(pretrained_cfg=config.pretrained_cfg)
>>> transforms = timm.data.create_transform(**data_config, is_training=False)
>>> output = model(transforms(image).unsqueeze(0)).logits # unsqueeze single image into batch of 1
>>> top5_probabilities, top5_class_indices = torch.topk(torch.softmax(output, dim=1) * 100, k=5)
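The last line of the example converts the logits to percentages with softmax and keeps the five largest. To make that step concrete, here is a dependency-free sketch of the same computation. The function names and the sample logits are illustrative assumptions, not model output.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(values, k):
    """Return (values, indices) of the k largest entries, largest first."""
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=True)[:k]
    return [values[i] for i in order], order


if __name__ == "__main__":
    # Hypothetical logits for a 4-class model, not real model output.
    logits = [2.0, 0.5, -1.0, 3.0]
    percentages = [p * 100 for p in softmax(logits)]
    top_values, top_indices = top_k(percentages, k=2)
    print(top_indices)  # [3, 0]
```

The returned indices are class indices. Mapping them to human-readable labels requires the label list that belongs to the model.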
Minimal working example for 🤗 Transformers
Pre-requisites
- Export the model using Optimum Exporters
- Quantize the ONNX model using the RyzenAI quantization tools. For more information on quantization, refer to the Model Quantization guide.
Load model with Ryzen AI class
To load a Transformers model and run inference with RyzenAI, replace your AutoModelForXxx class with the corresponding RyzenAIModelForXxx class.
The example below shows this for image classification.
>>> import requests
>>> from PIL import Image
>>> from optimum.amd.ryzenai import RyzenAIModelForImageClassification
>>> from transformers import AutoFeatureExtractor, pipeline
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>>
>>> # See quantize.py (https://huggingface.co/mohitsha/transformers-resnet18-onnx-quantized-ryzen/blob/main/quantize.py) for more details on quantization.
>>> quantized_model_path = "mohitsha/transformers-resnet18-onnx-quantized-ryzen"
>>> model = RyzenAIModelForImageClassification.from_pretrained(quantized_model_path)
>>> feature_extractor = AutoFeatureExtractor.from_pretrained(quantized_model_path)
>>> cls_pipe = pipeline("image-classification", model=model, feature_extractor=feature_extractor)
>>> outputs = cls_pipe(image)
>>> print(outputs)
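The image-classification pipeline returns a list of dictionaries, each with a label and a score. The sketch below shows one way to pick the best prediction from such a list. The helper name best_prediction is hypothetical, and the sample values are made up to illustrate the shape, not taken from a real run.

```python
def best_prediction(outputs):
    """Return the (label, score) pair with the highest score."""
    best = max(outputs, key=lambda item: item["score"])
    return best["label"], best["score"]


if __name__ == "__main__":
    # Hypothetical values in the pipeline's output shape, not real results.
    sample_outputs = [
        {"label": "tabby, tabby cat", "score": 0.62},
        {"label": "Egyptian cat", "score": 0.21},
    ]
    label, score = best_prediction(sample_outputs)
    print(f"{label}: {score:.2f}")  # tabby, tabby cat: 0.62
```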