Clasificación de imágenes
La clasificación de imágenes asigna una etiqueta o clase a una imagen. A diferencia de la clasificación de texto o audio, las entradas son los valores de los píxeles que representan una imagen. La clasificación de imágenes tiene muchos usos, como la detección de daños tras una catástrofe, el control de la salud de los cultivos o la búsqueda de signos de enfermedad en imágenes médicas.
Esta guía te mostrará como hacer fine-tune al ViT en el dataset Food-101 para clasificar un alimento en una imagen.
Consulta la página de la tarea de clasificación de imágenes para obtener más información sobre sus modelos, datasets y métricas asociadas.
Carga el dataset Food-101
Carga solo las primeras 5000 imágenes del dataset Food-101 de la biblioteca 🤗 de Datasets ya que es bastante grande:
>>> from datasets import load_dataset
>>> food = load_dataset("food101", split="train[:5000]")
Divide el dataset en un train y un test set:
>>> food = food.train_test_split(test_size=0.2)
A continuación, observa un ejemplo:
>>> food["train"][0]
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512 at 0x7F52AFC8AC50>,
'label': 79}
El campo image
contiene una imagen PIL, y cada label
es un número entero que representa una clase. Crea un diccionario que asigne un nombre de label a un entero y viceversa. El mapeo ayudará al modelo a recuperar el nombre de label a partir del número de la misma:
>>> labels = food["train"].features["label"].names
>>> label2id, id2label = dict(), dict()
>>> for i, label in enumerate(labels):
... label2id[label] = str(i)
... id2label[str(i)] = label
Ahora puedes convertir el número de label en un nombre de label para obtener más información:
>>> id2label[str(79)]
'prime_rib'
Cada clase de alimento - o label - corresponde a un número; 79
indica una costilla de primera en el ejemplo anterior.
Preprocesa
Carga el image processor de ViT para procesar la imagen en un tensor:
>>> from transformers import AutoImageProcessor
>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
Aplica varias transformaciones de imagen al dataset para hacer el modelo más robusto contra el overfitting. En este caso se utilizará el módulo transforms
de torchvision. Recorta una parte aleatoria de la imagen, cambia su tamaño y normalízala con la media y la desviación estándar de la imagen:
>>> from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
>>> normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
>>> _transforms = Compose([RandomResizedCrop(image_processor.size["height"]), ToTensor(), normalize])
Crea una función de preprocesamiento que aplique las transformaciones y devuelva los pixel_values
- los inputs al modelo - de la imagen:
>>> def transforms(examples):
... examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
... del examples["image"]
... return examples
Utiliza el método with_transform
de 🤗 Dataset para aplicar las transformaciones sobre todo el dataset. Las transformaciones se aplican sobre la marcha cuando se carga un elemento del dataset:
>>> food = food.with_transform(transforms)
Utiliza DefaultDataCollator
para crear un batch de ejemplos. A diferencia de otros data collators en 🤗 Transformers, el DefaultDataCollator no aplica un preprocesamiento adicional como el padding.
>>> from transformers import DefaultDataCollator
>>> data_collator = DefaultDataCollator()
Entrena
Carga ViT con AutoModelForImageClassification
. Especifica el número de labels, y pasa al modelo el mapping entre el número de label y la clase de label:
>>> from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
>>> model = AutoModelForImageClassification.from_pretrained(
... "google/vit-base-patch16-224-in21k",
... num_labels=len(labels),
... id2label=id2label,
... label2id=label2id,
... )
Si no estás familiarizado con el fine-tuning de un modelo con el Trainer
, echa un vistazo al tutorial básico aquí!
Al llegar a este punto, solo quedan tres pasos:
- Define tus hiperparámetros de entrenamiento en
TrainingArguments
. Es importante que no elimines las columnas que no se utilicen, ya que esto hará que desaparezca la columnaimage
. Sin la columnaimage
no puedes crearpixel_values
. Estableceremove_unused_columns=False
para evitar este comportamiento. - Pasa los training arguments al
Trainer
junto con el modelo, los datasets, tokenizer y data collator. - Llama
train()
para hacer fine-tune de tu modelo.
>>> training_args = TrainingArguments(
... output_dir="./results",
... per_device_train_batch_size=16,
... eval_strategy="steps",
... num_train_epochs=4,
... fp16=True,
... save_steps=100,
... eval_steps=100,
... logging_steps=10,
... learning_rate=2e-4,
... save_total_limit=2,
... remove_unused_columns=False,
... )
>>> trainer = Trainer(
... model=model,
... args=training_args,
... data_collator=data_collator,
... train_dataset=food["train"],
... eval_dataset=food["test"],
... processing_class=image_processor,
... )
>>> trainer.train()
Para ver un ejemplo más a profundidad de cómo hacer fine-tune a un modelo para clasificación de imágenes, echa un vistazo al correspondiente PyTorch notebook.