{ "cells": [ { "cell_type": "markdown", "id": "4ef57047", "metadata": {}, "source": [ "# Using PEFT with timm" ] }, { "cell_type": "markdown", "id": "80561acc", "metadata": {}, "source": [ "`peft` allows us to train any model with LoRA as long as the layer type is supported. Since `Conv2D` is one of the supported layer types, it makes sense to test it on image models.\n", "\n", "In this short notebook, we will demonstrate this with an image classification task using [`timm`](https://huggingface.co/docs/timm/index)." ] }, { "cell_type": "markdown", "id": "aa26c285", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "markdown", "id": "552b9040", "metadata": {}, "source": [ "Make sure that you have the latest version of `peft` installed. To ensure that, run this in your Python environment:\n", " \n", " python -m pip install --upgrade peft\n", " \n", "Also, ensure that `timm` is installed:\n", "\n", " python -m pip install --upgrade timm" ] }, { "cell_type": "code", "execution_count": 1, "id": "e600b7d5", "metadata": {}, "outputs": [], "source": [ "import timm\n", "import torch\n", "from PIL import Image\n", "from timm.data import resolve_data_config\n", "from timm.data.transforms_factory import create_transform" ] }, { "cell_type": "code", "execution_count": 2, "id": "73a2ae54", "metadata": {}, "outputs": [], "source": [ "import peft\n", "from datasets import load_dataset" ] }, { "cell_type": "code", "execution_count": 3, "id": "82c628fd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.manual_seed(0)" ] }, { "cell_type": "markdown", "id": "701ab69c", "metadata": {}, "source": [ "## Loading the pre-trained base model" ] }, { "cell_type": "markdown", "id": "20bff51a", "metadata": {}, "source": [ "We use a small pretrained `timm` model, `PoolFormer`. Find more info on its [model card](https://huggingface.co/timm/poolformer_m36.sail_in1k)." ] }, { "cell_type": "code", "execution_count": 4, "id": "495cb3d6", "metadata": {}, "outputs": [], "source": [ "model_id_timm = \"timm/poolformer_m36.sail_in1k\"" ] }, { "cell_type": "markdown", "id": "2dc06f9b", "metadata": {}, "source": [ "We tell `timm` that we deal with 3 classes, to ensure that the classification layer has the correct size." ] }, { "cell_type": "code", "execution_count": 5, "id": "090564bc", "metadata": {}, "outputs": [], "source": [ "model = timm.create_model(model_id_timm, pretrained=True, num_classes=3)" ] }, { "cell_type": "markdown", "id": "beca5794", "metadata": {}, "source": [ "These are the transformations steps necessary to process the image." ] }, { "cell_type": "code", "execution_count": 6, "id": "9df2e113", "metadata": {}, "outputs": [], "source": [ "transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))" ] }, { "cell_type": "markdown", "id": "3f809dfa", "metadata": {}, "source": [ "## Data" ] }, { "cell_type": "markdown", "id": "a398fe22", "metadata": {}, "source": [ "For this exercise, we use the \"beans\" dataset. More details on the dataset can be found on [its datasets page](https://huggingface.co/datasets/beans). For our purposes, what's important is that we have image inputs and the target we're trying to predict is one of three classes for each image." 
] }, { "cell_type": "code", "execution_count": 7, "id": "0fddc704", "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Found cached dataset beans (/home/vinh/.cache/huggingface/datasets/beans/default/0.0.0/90c755fb6db1c0ccdad02e897a37969dbf070bed3755d4391e269ff70642d791)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "05592574da474b81ab736d6babb5e19d", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds_train[0][\"image\"]" ] }, { "cell_type": "markdown", "id": "880ea6c4", "metadata": {}, "source": [ "We define a small processing function which is responsible for loading and transforming the images, as well as extracting the labels." ] }, { "cell_type": "code", "execution_count": 10, "id": "142df842", "metadata": {}, "outputs": [], "source": [ "def process(batch):\n", " x = torch.cat([transform(img).unsqueeze(0) for img in batch[\"image\"]])\n", " y = torch.tensor(batch[\"labels\"])\n", " return {\"x\": x, \"y\": y}" ] }, { "cell_type": "code", "execution_count": 11, "id": "9744257b", "metadata": {}, "outputs": [], "source": [ "ds_train.set_transform(process)\n", "ds_valid.set_transform(process)" ] }, { "cell_type": "code", "execution_count": 12, "id": "282374be", "metadata": {}, "outputs": [], "source": [ "train_loader = torch.utils.data.DataLoader(ds_train, batch_size=32)\n", "valid_loader = torch.utils.data.DataLoader(ds_valid, batch_size=32)" ] }, { "cell_type": "markdown", "id": "5dcd3329", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "markdown", "id": "969bc374", "metadata": {}, "source": [ "This is just a function that performs the train loop, nothing fancy happening." ] }, { "cell_type": "code", "execution_count": 13, "id": "b9fc9588", "metadata": {}, "outputs": [], "source": [ "def train(model, optimizer, criterion, train_dataloader, valid_dataloader, epochs):\n", " for epoch in range(epochs):\n", " model.train()\n", " train_loss = 0\n", " for batch in train_dataloader:\n", " xb, yb = batch[\"x\"], batch[\"y\"]\n", " xb, yb = xb.to(device), yb.to(device)\n", " outputs = model(xb)\n", " lsm = torch.nn.functional.log_softmax(outputs, dim=-1)\n", " loss = criterion(lsm, yb)\n", " train_loss += loss.detach().float()\n", " loss.backward()\n", " optimizer.step()\n", " optimizer.zero_grad()\n", "\n", " model.eval()\n", " valid_loss = 0\n", " correct = 0\n", " n_total = 0\n", " for batch in valid_dataloader:\n", " xb, yb = batch[\"x\"], batch[\"y\"]\n", " xb, yb = xb.to(device), yb.to(device)\n", " with torch.no_grad():\n", " outputs = model(xb)\n", " lsm = torch.nn.functional.log_softmax(outputs, dim=-1)\n", " loss = criterion(lsm, yb)\n", " valid_loss += loss.detach().float()\n", " correct += (outputs.argmax(-1) == yb).sum().item()\n", " n_total += len(yb)\n", "\n", " train_loss_total = (train_loss / len(train_dataloader)).item()\n", " valid_loss_total = (valid_loss / len(valid_dataloader)).item()\n", " valid_acc_total = correct / n_total\n", " print(f\"{epoch=:<2} {train_loss_total=:.4f} {valid_loss_total=:.4f} {valid_acc_total=:.4f}\")" ] }, { "cell_type": "markdown", "id": "3fd58357", "metadata": {}, "source": [ "### Selecting which layers to fine-tune with LoRA" ] }, { "cell_type": "markdown", "id": "7987321c", "metadata": {}, "source": [ "Let's take a look at the layers of our model. 
We only print the first 30, since there are quite a few:" ] }, { "cell_type": "code", "execution_count": 14, "id": "55a7be4d", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "[('', timm.models.metaformer.MetaFormer),\n", " ('stem', timm.models.metaformer.Stem),\n", " ('stem.conv', torch.nn.modules.conv.Conv2d),\n", " ('stem.norm', torch.nn.modules.linear.Identity),\n", " ('stages', torch.nn.modules.container.Sequential),\n", " ('stages.0', timm.models.metaformer.MetaFormerStage),\n", " ('stages.0.downsample', torch.nn.modules.linear.Identity),\n", " ('stages.0.blocks', torch.nn.modules.container.Sequential),\n", " ('stages.0.blocks.0', timm.models.metaformer.MetaFormerBlock),\n", " ('stages.0.blocks.0.norm1', timm.layers.norm.GroupNorm1),\n", " ('stages.0.blocks.0.token_mixer', timm.models.metaformer.Pooling),\n", " ('stages.0.blocks.0.token_mixer.pool', torch.nn.modules.pooling.AvgPool2d),\n", " ('stages.0.blocks.0.drop_path1', torch.nn.modules.linear.Identity),\n", " ('stages.0.blocks.0.layer_scale1', timm.models.metaformer.Scale),\n", " ('stages.0.blocks.0.res_scale1', torch.nn.modules.linear.Identity),\n", " ('stages.0.blocks.0.norm2', timm.layers.norm.GroupNorm1),\n", " ('stages.0.blocks.0.mlp', timm.layers.mlp.Mlp),\n", " ('stages.0.blocks.0.mlp.fc1', torch.nn.modules.conv.Conv2d),\n", " ('stages.0.blocks.0.mlp.act', torch.nn.modules.activation.GELU),\n", " ('stages.0.blocks.0.mlp.drop1', torch.nn.modules.dropout.Dropout),\n", " ('stages.0.blocks.0.mlp.norm', torch.nn.modules.linear.Identity),\n", " ('stages.0.blocks.0.mlp.fc2', torch.nn.modules.conv.Conv2d),\n", " ('stages.0.blocks.0.mlp.drop2', torch.nn.modules.dropout.Dropout),\n", " ('stages.0.blocks.0.drop_path2', torch.nn.modules.linear.Identity),\n", " ('stages.0.blocks.0.layer_scale2', timm.models.metaformer.Scale),\n", " ('stages.0.blocks.0.res_scale2', torch.nn.modules.linear.Identity),\n", " ('stages.0.blocks.1', timm.models.metaformer.MetaFormerBlock),\n", " ('stages.0.blocks.1.norm1', timm.layers.norm.GroupNorm1),\n", " ('stages.0.blocks.1.token_mixer', timm.models.metaformer.Pooling),\n", " ('stages.0.blocks.1.token_mixer.pool', torch.nn.modules.pooling.AvgPool2d)]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[(n, type(m)) for n, m in model.named_modules()][:30]" ] }, { "cell_type": "markdown", "id": "09af9349", "metadata": {}, "source": [ "Most of these layers are not good targets for LoRA, but we see a couple that should interest us. Their names are `'stages.0.blocks.0.mlp.fc1'`, etc. With a bit of regex, we can match them easily.\n", "\n", "Also, we should inspect the name of the classification layer, since we want to train that one too!" 
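] }, { "cell_type": "markdown", "id": "3a8e12cf", "metadata": {}, "source": [ "Before inspecting the head, let's quickly verify that such a regex picks out the `mlp.fc` layers. When `target_modules` is given as a string, `peft` treats it as a regular expression that has to match the full module name, which we can mimic with Python's `re` module (the pattern here is just a sketch of what we'll use below):" ] }, { "cell_type": "code", "execution_count": null, "id": "c5f9d430", "metadata": {}, "outputs": [], "source": [ "# Sanity check: which module names would a string target_modules pattern hit?\n", "# peft requires the regex to match the full module name, hence re.fullmatch.\n", "import re\n", "\n", "pattern = r\".*\\.mlp\\.fc\\d\"\n", "matched = [name for name, _ in model.named_modules() if re.fullmatch(pattern, name)]\n", "print(f\"{len(matched)} modules match, e.g. {matched[:2]}\")" ] },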
] }, { "cell_type": "code", "execution_count": 15, "id": "8b98d9ef", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "[('head.global_pool.flatten', torch.nn.modules.linear.Identity),\n", " ('head.norm', timm.layers.norm.LayerNorm2d),\n", " ('head.flatten', torch.nn.modules.flatten.Flatten),\n", " ('head.drop', torch.nn.modules.linear.Identity),\n", " ('head.fc', torch.nn.modules.linear.Linear)]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[(n, type(m)) for n, m in model.named_modules()][-5:]" ] }, { "cell_type": "markdown", "id": "00e75b78", "metadata": {}, "source": [ " config = peft.LoraConfig(\n", " r=8,\n", " target_modules=r\".*\\.mlp\\.fc\\d|head\\.fc\",\n", " )" ] }, { "cell_type": "markdown", "id": "23814d70", "metadata": {}, "source": [ "Okay, this gives us all the information we need to fine-tune this model. With a bit of regex, we match the convolutional layers that should be targeted for LoRA. We also want to train the classification layer `'head.fc'` (without LoRA), so we add it to the `modules_to_save`." ] }, { "cell_type": "code", "execution_count": 16, "id": "81029587", "metadata": {}, "outputs": [], "source": [ "config = peft.LoraConfig(r=8, target_modules=r\".*\\.mlp\\.fc\\d\", modules_to_save=[\"head.fc\"])" ] }, { "cell_type": "markdown", "id": "e05876bc", "metadata": {}, "source": [ "Finally, let's create the `peft` model, the optimizer and criterion, and we can get started. As shown below, less than 2% of the model's total parameters are updated thanks to `peft`." ] }, { "cell_type": "code", "execution_count": 17, "id": "8cc5c5db", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "trainable params: 1,064,454 || all params: 56,467,974 || trainable%: 1.88505789139876\n" ] } ], "source": [ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "peft_model = peft.get_peft_model(model, config).to(device)\n", "optimizer = torch.optim.Adam(peft_model.parameters(), lr=2e-4)\n", "criterion = torch.nn.CrossEntropyLoss()\n", "peft_model.print_trainable_parameters()" ] }, { "cell_type": "code", "execution_count": 18, "id": "9e557e42", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch=0 train_loss_total=1.2999 valid_loss_total=1.0624 valid_acc_total=0.4436\n", "epoch=1 train_loss_total=1.0200 valid_loss_total=0.8906 valid_acc_total=0.7594\n", "epoch=2 train_loss_total=0.8874 valid_loss_total=0.6894 valid_acc_total=0.8045\n", "epoch=3 train_loss_total=0.7440 valid_loss_total=0.4797 valid_acc_total=0.8045\n", "epoch=4 train_loss_total=0.6025 valid_loss_total=0.3419 valid_acc_total=0.8120\n", "epoch=5 train_loss_total=0.4820 valid_loss_total=0.2589 valid_acc_total=0.8421\n", "epoch=6 train_loss_total=0.3567 valid_loss_total=0.2101 valid_acc_total=0.8722\n", "epoch=7 train_loss_total=0.2835 valid_loss_total=0.1385 valid_acc_total=0.9098\n", "epoch=8 train_loss_total=0.1815 valid_loss_total=0.1108 valid_acc_total=0.9474\n", "epoch=9 train_loss_total=0.1341 valid_loss_total=0.0785 valid_acc_total=0.9699\n", "CPU times: user 4min 3s, sys: 36.3 s, total: 4min 40s\n", "Wall time: 3min 32s\n" ] } ], "source": [ "%time train(peft_model, optimizer, criterion, train_loader, valid_dataloader=valid_loader, epochs=10)" ] }, { "cell_type": "markdown", "id": "94162859", "metadata": {}, "source": [ "We get an accuracy of ~0.97, despite only training a tiny amount of parameters. That's a really nice result." 
] }, { "cell_type": "markdown", "id": "9c16bad8", "metadata": {}, "source": [ "## Sharing the model through Hugging Face Hub" ] }, { "cell_type": "markdown", "id": "2e1e16c7", "metadata": {}, "source": [ "### Pushing the model to Hugging Face Hub" ] }, { "cell_type": "markdown", "id": "ec596b3b", "metadata": {}, "source": [ "If we want to share the fine-tuned weights with the world, we can upload them to Hugging Face Hub like this:" ] }, { "cell_type": "code", "execution_count": 19, "id": "b583579d", "metadata": {}, "outputs": [], "source": [ "user = \"BenjaminB\" # put your user name here\n", "model_name = \"peft-lora-with-timm-model\"\n", "model_id = f\"{user}/{model_name}\"" ] }, { "cell_type": "code", "execution_count": 20, "id": "f1db67e4", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "aed1f9c3fa334be1b5f208efe5ba27e6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Upload 1 LFS files: 0%| | 0/1 [00:00