---
license: cc-by-nc-4.0
pipeline_tag: feature-extraction
---

# Hiera (Tiny)

Hiera is a hierarchical vision transformer that is substantially more efficient than earlier hierarchical architectures such as ConvNeXT and Swin. Vanilla transformers (Dosovitskiy et al., 2020) are popular because they are simple, scalable, and enable pretraining strategies such as MAE (He et al., 2022). However, because they keep the same spatial resolution and number of channels throughout the network, ViTs make inefficient use of their parameters. This is in contrast to prior "hierarchical" or "multi-scale" models (e.g., Krizhevsky et al. (2012); He et al. (2016)), which use fewer channels at higher spatial resolution in early stages, where features are simpler, and more channels at lower spatial resolution in later stages, where features are more complex. Those models, however, rely on specialized modules that add overhead to reach state-of-the-art ImageNet-1k accuracy, making them slower. Hiera removes this added complexity and instead lets the model learn spatial biases through MAE pretraining.
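To make the resolution/channel trade-off concrete, here is a small illustrative sketch. The stage widths, patch stride, and number of stages below are assumptions for demonstration, not Hiera's exact published configuration:

```python
# Illustrative only: stage layout of a generic hierarchical backbone.
# The channel counts and downsampling factors here are assumptions for
# demonstration, not Hiera's exact configuration.

def stage_shapes(input_size=224, patch_stride=4, base_channels=96, num_stages=4):
    """Return (spatial_resolution, channels) per stage: resolution halves
    while channels double, shifting compute toward richer features."""
    shapes = []
    res = input_size // patch_stride  # resolution after patch embedding
    ch = base_channels
    for _ in range(num_stages):
        shapes.append((res, ch))
        res //= 2
        ch *= 2
    return shapes

print(stage_shapes())  # [(56, 96), (28, 192), (14, 384), (7, 768)]
```

Early stages spend parameters on many spatial positions with few channels; late stages do the reverse, which is the efficiency argument behind hierarchical designs.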

## How to Use

Here is how to use this model to classify an image from the COCO 2017 dataset into one of the 1,000 ImageNet classes. First, clone the repository and install the dependencies:

```bash
git lfs install
git clone https://huggingface.co/merve/hiera-tiny-ft-224-in1k
pip install timm
cd hiera-tiny-ft-224-in1k
```
```python
import sys
sys.path.append("..")  # make the cloned hiera package importable

import requests
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

import hiera

# this repository hosts the tiny variant, so load hiera_tiny_224
model = hiera.hiera_tiny_224(pretrained=True, checkpoint="mae_in1k_ft_in1k")
model.eval()

input_size = 224
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

# preprocess the image: resize the short side to 256, then center-crop to 224
transform_list = [
    transforms.Resize(int((256 / 224) * input_size), interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(input_size),
]
transform_vis  = transforms.Compose(transform_list)
transform_norm = transforms.Compose(transform_list + [
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
img_vis = transform_vis(image)
img_norm = transform_norm(image)

# get the ImageNet class prediction
with torch.no_grad():
    out = model(img_norm[None, ...])
out.argmax(dim=-1).item()  # tabby cat
```
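If you want confidences rather than a single class index, apply a softmax over the logits and take the top-k entries. The helper below is a framework-free sketch of that step (mapping an index to a class name is omitted, since it depends on where you keep the ImageNet label file); with a real logits tensor you would use `out.softmax(-1).topk(k)` instead:

```python
import math

def topk_probs(logits, k=5):
    """Softmax over a plain list of logits, returning the k highest
    (index, probability) pairs. Illustrative sketch only."""
    m = max(logits)                               # subtract max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(enumerate(probs), key=lambda p: p[1], reverse=True)
    return ranked[:k]

print(topk_probs([2.0, 1.0, 0.1], k=2))
```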

You can try the fine-tuned model here.