Transformers documentation

Clasificación de imágenes

Hugging Face's logo
Join the Hugging Face community

and get access to the augmented documentation experience

to get started

Clasificación de imágenes

La clasificación de imágenes asigna una etiqueta o clase a una imagen. A diferencia de la clasificación de texto o audio, las entradas son los valores de los píxeles que representan una imagen. La clasificación de imágenes tiene muchos usos, como la detección de daños tras una catástrofe, el control de la salud de los cultivos o la búsqueda de signos de enfermedad en imágenes médicas.

Esta guía te mostrará como hacer fine-tune al ViT en el dataset Food-101 para clasificar un alimento en una imagen.

Consulta la página de la tarea de clasificación de imágenes para obtener más información sobre sus modelos, datasets y métricas asociadas.

Carga el dataset Food-101

Carga solo las primeras 5000 imágenes del dataset Food-101 de la biblioteca 🤗 de Datasets ya que es bastante grande:

>>> from datasets import load_dataset

>>> food = load_dataset("food101", split="train[:5000]")

Divide el dataset en un train y un test set:

>>> food = food.train_test_split(test_size=0.2)

A continuación, observa un ejemplo:

>>> food["train"][0]
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512 at 0x7F52AFC8AC50>,
 'label': 79}

El campo image contiene una imagen PIL, y cada label es un número entero que representa una clase. Crea un diccionario que asigne un nombre de label a un entero y viceversa. El mapeo ayudará al modelo a recuperar el nombre de label a partir del número de la misma:

>>> labels = food["train"].features["label"].names
>>> label2id, id2label = dict(), dict()
>>> for i, label in enumerate(labels):
...     label2id[label] = str(i)
...     id2label[str(i)] = label

Ahora puedes convertir el número de label en un nombre de label para obtener más información:

>>> id2label[str(79)]
'prime_rib'

Cada clase de alimento - o label - corresponde a un número; 79 indica una costilla de primera en el ejemplo anterior.

Preprocesa

Carga el image processor de ViT para procesar la imagen en un tensor:

>>> from transformers import AutoImageProcessor

>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

Aplica varias transformaciones de imagen al dataset para hacer el modelo más robusto contra el overfitting. En este caso se utilizará el módulo transforms de torchvision. Recorta una parte aleatoria de la imagen, cambia su tamaño y normalízala con la media y la desviación estándar de la imagen:

>>> from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

>>> normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
>>> _transforms = Compose([RandomResizedCrop(image_processor.size["height"]), ToTensor(), normalize])

Crea una función de preprocesamiento que aplique las transformaciones y devuelva los pixel_values - los inputs al modelo - de la imagen:

>>> def transforms(examples):
...     examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
...     del examples["image"]
...     return examples

Utiliza el método with_transform de 🤗 Dataset para aplicar las transformaciones sobre todo el dataset. Las transformaciones se aplican sobre la marcha cuando se carga un elemento del dataset:

>>> food = food.with_transform(transforms)

Utiliza DefaultDataCollator para crear un batch de ejemplos. A diferencia de otros data collators en 🤗 Transformers, el DefaultDataCollator no aplica un preprocesamiento adicional como el padding.

>>> from transformers import DefaultDataCollator

>>> data_collator = DefaultDataCollator()

Entrena

Carga ViT con AutoModelForImageClassification. Especifica el número de labels, y pasa al modelo el mapping entre el número de label y la clase de label:

>>> from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

>>> model = AutoModelForImageClassification.from_pretrained(
...     "google/vit-base-patch16-224-in21k",
...     num_labels=len(labels),
...     id2label=id2label,
...     label2id=label2id,
... )

Si no estás familiarizado con el fine-tuning de un modelo con el Trainer, echa un vistazo al tutorial básico aquí!

Al llegar a este punto, solo quedan tres pasos:

  1. Define tus hiperparámetros de entrenamiento en TrainingArguments. Es importante que no elimines las columnas que no se utilicen, ya que esto hará que desaparezca la columna image. Sin la columna image no puedes crear pixel_values. Establece remove_unused_columns=False para evitar este comportamiento.
  2. Pasa los training arguments al Trainer junto con el modelo, los datasets, tokenizer y data collator.
  3. Llama train() para hacer fine-tune de tu modelo.
>>> training_args = TrainingArguments(
...     output_dir="./results",
...     per_device_train_batch_size=16,
...     eval_strategy="steps",
...     num_train_epochs=4,
...     fp16=True,
...     save_steps=100,
...     eval_steps=100,
...     logging_steps=10,
...     learning_rate=2e-4,
...     save_total_limit=2,
...     remove_unused_columns=False,
... )

>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     data_collator=data_collator,
...     train_dataset=food["train"],
...     eval_dataset=food["test"],
...     processing_class=image_processor,
... )

>>> trainer.train()

Para ver un ejemplo más a profundidad de cómo hacer fine-tune a un modelo para clasificación de imágenes, echa un vistazo al correspondiente PyTorch notebook.

< > Update on GitHub