{
"metadata": {
"kernelspec": {
"language": "python",
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.7.12",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"colab": {
"provenance": [],
"machine_shape": "hm",
"include_colab_link": true
},
"gpuClass": "standard"
},
"nbformat_minor": 0,
"nbformat": 4,
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"source": [
"## Setting up environment"
],
"metadata": {
"id": "kNdVjbIkwTnV"
}
},
{
"cell_type": "code",
"source": [
"# mount google drive\n",
"from google.colab import drive\n",
"drive.mount('/content/drive', force_remount=True)"
],
"metadata": {
"id": "RQDSDTzNwTRQ",
"outputId": "8a3f90fa-3621-4ef8-a33a-78fbefe4c85b",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/drive\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!git clone https://github.com/yxmauw/cxr-multilabel-clf.git"
],
"metadata": {
"id": "7LVf0IqtxO28",
"outputId": "b21853a5-e504-40ad-f4d0-5c06c8efe602",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cloning into 'cxr-multilabel-clf'...\n",
"remote: Enumerating objects: 63, done.\u001b[K\n",
"remote: Counting objects: 100% (63/63), done.\u001b[K\n",
"remote: Compressing objects: 100% (61/61), done.\u001b[K\n",
"remote: Total 63 (delta 28), reused 0 (delta 0), pack-reused 0\u001b[K\n",
"Unpacking objects: 100% (63/63), done.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!mkdir ~/.kaggle #Make a directory named “.kaggle”"
],
"metadata": {
"id": "gKoUXV34xws6"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!cp ./cxr-multilabel-clf//kaggle.json ~/.kaggle/ # Copy the “kaggle.json” into this new directory"
],
"metadata": {
"id": "vQ9JVUJqx8RD"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!chmod 600 ~/.kaggle/kaggle.json # Allocate the required permission for this file"
],
"metadata": {
"id": "5Q13Q3d3yEtl"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!kaggle competitions download -c ranzcr-clip-catheter-line-classification # download dataset"
],
"metadata": {
"id": "gIWBx5k-yIYI",
"outputId": "a853e922-85da-4609-d269-ef34880bec92",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Downloading ranzcr-clip-catheter-line-classification.zip to /content\n",
"100% 11.7G/11.7G [06:38<00:00, 38.3MB/s]\n",
"100% 11.7G/11.7G [06:38<00:00, 31.5MB/s]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!unzip ranzcr-clip-catheter-line-classification.zip #unzip folders"
],
"metadata": {
"id": "2oGiv_1SyxFW"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Data"
],
"metadata": {
"id": "L3RKSB9ZxE_F"
}
},
{
"cell_type": "code",
"source": [
"import numpy as np \n",
"import pandas as pd "
],
"metadata": {
"_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
"_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
"execution": {
"iopub.status.busy": "2022-10-28T02:46:40.023210Z",
"iopub.execute_input": "2022-10-28T02:46:40.024111Z",
"iopub.status.idle": "2022-10-28T02:46:40.047013Z",
"shell.execute_reply.started": "2022-10-28T02:46:40.024018Z",
"shell.execute_reply": "2022-10-28T02:46:40.046119Z"
},
"trusted": true,
"id": "6HJ9QhyevJlw"
},
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"source": [
"train_df = pd.read_csv('train.csv')\n",
"display(len(train_df))\n",
"display(train_df.head(3))\n",
"train_annot_df = pd.read_csv('train_annotations.csv')\n",
"display(len(train_annot_df))\n",
"display(train_annot_df.head(3))"
],
"metadata": {
"execution": {
"iopub.status.busy": "2022-10-28T02:46:44.475857Z",
"iopub.execute_input": "2022-10-28T02:46:44.476564Z",
"iopub.status.idle": "2022-10-28T02:46:44.724851Z",
"shell.execute_reply.started": "2022-10-28T02:46:44.476517Z",
"shell.execute_reply": "2022-10-28T02:46:44.723861Z"
},
"trusted": true,
"id": "q_tcXmTovJly",
"outputId": "e024c922-b5c6-4ddc-e4f1-102f7cb40ee6",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 421
}
},
"execution_count": 9,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"30083"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
" StudyInstanceUID ETT - Abnormal \\\n",
"0 1.2.826.0.1.3680043.8.498.26697628953273228189... 0 \n",
"1 1.2.826.0.1.3680043.8.498.46302891597398758759... 0 \n",
"2 1.2.826.0.1.3680043.8.498.23819260719748494858... 0 \n",
"\n",
" ETT - Borderline ETT - Normal NGT - Abnormal NGT - Borderline \\\n",
"0 0 0 0 0 \n",
"1 0 1 0 0 \n",
"2 0 0 0 0 \n",
"\n",
" NGT - Incompletely Imaged NGT - Normal CVC - Abnormal CVC - Borderline \\\n",
"0 0 1 0 0 \n",
"1 1 0 0 0 \n",
"2 0 0 0 1 \n",
"\n",
" CVC - Normal Swan Ganz Catheter Present PatientID \n",
"0 0 0 ec89415d1 \n",
"1 1 0 bf4c6da3c \n",
"2 0 0 3fc1c97e5 "
],
"text/html": [
"\n",
"
\n",
"
\n",
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" StudyInstanceUID | \n",
" ETT - Abnormal | \n",
" ETT - Borderline | \n",
" ETT - Normal | \n",
" NGT - Abnormal | \n",
" NGT - Borderline | \n",
" NGT - Incompletely Imaged | \n",
" NGT - Normal | \n",
" CVC - Abnormal | \n",
" CVC - Borderline | \n",
" CVC - Normal | \n",
" Swan Ganz Catheter Present | \n",
" PatientID | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 1.2.826.0.1.3680043.8.498.26697628953273228189... | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" ec89415d1 | \n",
"
\n",
" \n",
" 1 | \n",
" 1.2.826.0.1.3680043.8.498.46302891597398758759... | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" bf4c6da3c | \n",
"
\n",
" \n",
" 2 | \n",
" 1.2.826.0.1.3680043.8.498.23819260719748494858... | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 3fc1c97e5 | \n",
"
\n",
" \n",
"
\n",
"
\n",
"
\n",
" \n",
" \n",
"\n",
" \n",
"
\n",
"
\n",
" "
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"17999"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
" StudyInstanceUID label \\\n",
"0 1.2.826.0.1.3680043.8.498.12616281126973421762... CVC - Normal \n",
"1 1.2.826.0.1.3680043.8.498.12616281126973421762... CVC - Normal \n",
"2 1.2.826.0.1.3680043.8.498.72921907356394389969... CVC - Borderline \n",
"\n",
" data \n",
"0 [[1487, 1279], [1477, 1168], [1472, 1052], [14... \n",
"1 [[1328, 7], [1347, 101], [1383, 193], [1400, 2... \n",
"2 [[801, 1207], [812, 1112], [823, 1023], [842, ... "
],
"text/html": [
"\n",
" \n",
"
\n",
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" StudyInstanceUID | \n",
" label | \n",
" data | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 1.2.826.0.1.3680043.8.498.12616281126973421762... | \n",
" CVC - Normal | \n",
" [[1487, 1279], [1477, 1168], [1472, 1052], [14... | \n",
"
\n",
" \n",
" 1 | \n",
" 1.2.826.0.1.3680043.8.498.12616281126973421762... | \n",
" CVC - Normal | \n",
" [[1328, 7], [1347, 101], [1383, 193], [1400, 2... | \n",
"
\n",
" \n",
" 2 | \n",
" 1.2.826.0.1.3680043.8.498.72921907356394389969... | \n",
" CVC - Borderline | \n",
" [[801, 1207], [812, 1112], [823, 1023], [842, ... | \n",
"
\n",
" \n",
"
\n",
"
\n",
"
\n",
" \n",
" \n",
"\n",
" \n",
"
\n",
"
\n",
" "
]
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"# value counts\n",
"train_df.drop(columns=['StudyInstanceUID','PatientID']).agg(['sum'])\n",
"# unbalanced dataset"
],
"metadata": {
"execution": {
"iopub.status.busy": "2022-10-28T02:46:59.167135Z",
"iopub.execute_input": "2022-10-28T02:46:59.167596Z",
"iopub.status.idle": "2022-10-28T02:46:59.208167Z",
"shell.execute_reply.started": "2022-10-28T02:46:59.167559Z",
"shell.execute_reply": "2022-10-28T02:46:59.207260Z"
},
"trusted": true,
"id": "5huaW8WQvJl1",
"outputId": "5dcae458-0bfe-4dd5-aaab-bbbebfda5ab8",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 197
}
},
"execution_count": 10,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" ETT - Abnormal ETT - Borderline ETT - Normal NGT - Abnormal \\\n",
"sum 79 1138 7240 279 \n",
"\n",
" NGT - Borderline NGT - Incompletely Imaged NGT - Normal \\\n",
"sum 529 2748 4797 \n",
"\n",
" CVC - Abnormal CVC - Borderline CVC - Normal \\\n",
"sum 3195 8460 21324 \n",
"\n",
" Swan Ganz Catheter Present \n",
"sum 830 "
],
"text/html": [
"\n",
" \n",
"
\n",
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" ETT - Abnormal | \n",
" ETT - Borderline | \n",
" ETT - Normal | \n",
" NGT - Abnormal | \n",
" NGT - Borderline | \n",
" NGT - Incompletely Imaged | \n",
" NGT - Normal | \n",
" CVC - Abnormal | \n",
" CVC - Borderline | \n",
" CVC - Normal | \n",
" Swan Ganz Catheter Present | \n",
"
\n",
" \n",
" \n",
" \n",
" sum | \n",
" 79 | \n",
" 1138 | \n",
" 7240 | \n",
" 279 | \n",
" 529 | \n",
" 2748 | \n",
" 4797 | \n",
" 3195 | \n",
" 8460 | \n",
" 21324 | \n",
" 830 | \n",
"
\n",
" \n",
"
\n",
"
\n",
"
\n",
" \n",
" \n",
"\n",
" \n",
"
\n",
"
\n",
" "
]
},
"metadata": {},
"execution_count": 10
}
]
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"import seaborn as sns"
],
"metadata": {
"execution": {
"iopub.status.busy": "2022-10-28T02:47:01.769145Z",
"iopub.execute_input": "2022-10-28T02:47:01.769618Z",
"iopub.status.idle": "2022-10-28T02:47:02.507132Z",
"shell.execute_reply.started": "2022-10-28T02:47:01.769578Z",
"shell.execute_reply": "2022-10-28T02:47:02.506044Z"
},
"trusted": true,
"id": "VBczRJs9vJl2"
},
"execution_count": 10,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# value counts\n",
"train_df.drop(columns=['StudyInstanceUID','PatientID']).agg(['sum']).T.sort_values(by='sum').plot(kind='barh')\n",
"plt.legend(loc='lower right');"
],
"metadata": {
"execution": {
"iopub.status.busy": "2022-10-28T02:47:04.180989Z",
"iopub.execute_input": "2022-10-28T02:47:04.181680Z",
"iopub.status.idle": "2022-10-28T02:47:04.493106Z",
"shell.execute_reply.started": "2022-10-28T02:47:04.181644Z",
"shell.execute_reply": "2022-10-28T02:47:04.491889Z"
},
"trusted": true,
"id": "ePQF1z29vJl3",
"outputId": "e043571f-f8b6-4084-9dfe-055082040594",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 265
}
},
"execution_count": 12,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"source": [
"len(train_df.drop(columns=['StudyInstanceUID','PatientID']).agg(['sum']).T)\n",
"# num of classes"
],
"metadata": {
"execution": {
"iopub.status.busy": "2022-10-28T02:47:07.675004Z",
"iopub.execute_input": "2022-10-28T02:47:07.675678Z",
"iopub.status.idle": "2022-10-28T02:47:07.695859Z",
"shell.execute_reply.started": "2022-10-28T02:47:07.675642Z",
"shell.execute_reply": "2022-10-28T02:47:07.694954Z"
},
"trusted": true,
"id": "RA8cTCndvJl4",
"outputId": "ad5fb681-9460-4846-c56e-0adcc336c1b4",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"11"
]
},
"metadata": {},
"execution_count": 12
}
]
},
{
"cell_type": "markdown",
"source": [
"## Create Datasets"
],
"metadata": {
"id": "4ckn3ZfUz4T_"
}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import cv2\n",
"import numpy as np\n",
"from torchvision import transforms\n",
"from torch.utils.data import Dataset"
],
"metadata": {
"execution": {
"iopub.status.busy": "2022-10-28T02:47:09.612626Z",
"iopub.execute_input": "2022-10-28T02:47:09.613355Z",
"iopub.status.idle": "2022-10-28T02:47:11.485435Z",
"shell.execute_reply.started": "2022-10-28T02:47:09.613293Z",
"shell.execute_reply": "2022-10-28T02:47:11.484503Z"
},
"trusted": true,
"id": "RAfp_UQivJl4"
},
"execution_count": 11,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class ImageDataset(Dataset):\n",
" def __init__(self, csv, train, test):\n",
" self.csv = csv\n",
" self.train = train\n",
" self.test = test\n",
" self.all_image_names = self.csv[:]['StudyInstanceUID']\n",
" self.all_labels = np.array(self.csv.drop(['StudyInstanceUID', 'PatientID'], axis=1))\n",
" self.train_ratio = int(0.85 * len(self.csv))\n",
" self.valid_ratio = len(self.csv) - self.train_ratio\n",
" # set the training data images and labels\n",
" if self.train == True:\n",
" print(f\"Number of training images: {self.train_ratio}\")\n",
" self.image_names = list(self.all_image_names[:self.train_ratio])\n",
" self.labels = list(self.all_labels[:self.train_ratio])\n",
" # define the training transforms\n",
" self.transform = transforms.Compose([\n",
" transforms.ToPILImage(),\n",
" transforms.Resize((400, 400)),\n",
" transforms.RandomHorizontalFlip(p=0.5),\n",
" transforms.RandomRotation(degrees=45),\n",
" transforms.ToTensor(),\n",
" ])\n",
" # set the validation data images and labels\n",
" elif self.train == False and self.test == False:\n",
" print(f\"Number of validation images: {self.valid_ratio}\")\n",
" self.image_names = list(self.all_image_names[-self.valid_ratio:-10])\n",
" self.labels = list(self.all_labels[-self.valid_ratio:])\n",
" # define the validation transforms\n",
" self.transform = transforms.Compose([\n",
" transforms.ToPILImage(),\n",
" transforms.Resize((400, 400)),\n",
" transforms.ToTensor(),\n",
" ])\n",
" # set the test data images and labels, only last 10 images\n",
" # this, we will use in a separate inference script\n",
" elif self.test == True and self.train == False:\n",
" self.image_names = list(self.all_image_names[-10:])\n",
" self.labels = list(self.all_labels[-10:])\n",
" # define the test transforms\n",
" self.transform = transforms.Compose([\n",
" transforms.ToPILImage(),\n",
" transforms.ToTensor(),\n",
" ])\n",
" def __len__(self):\n",
" return len(self.image_names)\n",
" \n",
" def __getitem__(self, index):\n",
" image = cv2.imread(f\"./train/{self.image_names[index]}.jpg\")\n",
" # convert the image from BGR to RGB color format\n",
" image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
" # apply image transforms\n",
" image = self.transform(image)\n",
" targets = self.labels[index]\n",
" \n",
" return {\n",
" 'image': torch.tensor(image, dtype=torch.float32),\n",
" 'label': torch.tensor(targets, dtype=torch.float32)\n",
" }"
],
"metadata": {
"execution": {
"iopub.status.busy": "2022-10-28T02:47:13.898477Z",
"iopub.execute_input": "2022-10-28T02:47:13.899027Z",
"iopub.status.idle": "2022-10-28T02:47:13.912689Z",
"shell.execute_reply.started": "2022-10-28T02:47:13.898986Z",
"shell.execute_reply": "2022-10-28T02:47:13.911765Z"
},
"trusted": true,
"id": "KdLACwrLvJl5"
},
"execution_count": 12,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import torchvision\n",
"torchvision.__version__"
],
"metadata": {
"execution": {
"iopub.status.busy": "2022-10-28T03:05:44.948332Z",
"iopub.execute_input": "2022-10-28T03:05:44.948713Z",
"iopub.status.idle": "2022-10-28T03:05:44.955478Z",
"shell.execute_reply.started": "2022-10-28T03:05:44.948681Z",
"shell.execute_reply": "2022-10-28T03:05:44.954402Z"
},
"trusted": true,
"id": "v1ofiKiavJl6",
"outputId": "0ecc26fd-3a56-437f-ca65-b18197b3ae8a",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
}
},
"execution_count": 13,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"'0.13.1+cu113'"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
}
},
"metadata": {},
"execution_count": 13
}
]
},
{
"cell_type": "code",
"source": [
"print(dir(torchvision.models))"
],
"metadata": {
"execution": {
"iopub.status.busy": "2022-10-28T03:42:01.322455Z",
"iopub.execute_input": "2022-10-28T03:42:01.322825Z",
"iopub.status.idle": "2022-10-28T03:42:01.328721Z",
"shell.execute_reply.started": "2022-10-28T03:42:01.322792Z",
"shell.execute_reply": "2022-10-28T03:42:01.327749Z"
},
"trusted": true,
"id": "utBbFn0_vJl7",
"outputId": "9ee13f10-8c60-40a3-f6e6-2672810eecd2",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"['AlexNet', 'AlexNet_Weights', 'ConvNeXt', 'ConvNeXt_Base_Weights', 'ConvNeXt_Large_Weights', 'ConvNeXt_Small_Weights', 'ConvNeXt_Tiny_Weights', 'DenseNet', 'DenseNet121_Weights', 'DenseNet161_Weights', 'DenseNet169_Weights', 'DenseNet201_Weights', 'EfficientNet', 'EfficientNet_B0_Weights', 'EfficientNet_B1_Weights', 'EfficientNet_B2_Weights', 'EfficientNet_B3_Weights', 'EfficientNet_B4_Weights', 'EfficientNet_B5_Weights', 'EfficientNet_B6_Weights', 'EfficientNet_B7_Weights', 'EfficientNet_V2_L_Weights', 'EfficientNet_V2_M_Weights', 'EfficientNet_V2_S_Weights', 'GoogLeNet', 'GoogLeNetOutputs', 'GoogLeNet_Weights', 'Inception3', 'InceptionOutputs', 'Inception_V3_Weights', 'MNASNet', 'MNASNet0_5_Weights', 'MNASNet0_75_Weights', 'MNASNet1_0_Weights', 'MNASNet1_3_Weights', 'MobileNetV2', 'MobileNetV3', 'MobileNet_V2_Weights', 'MobileNet_V3_Large_Weights', 'MobileNet_V3_Small_Weights', 'RegNet', 'RegNet_X_16GF_Weights', 'RegNet_X_1_6GF_Weights', 'RegNet_X_32GF_Weights', 'RegNet_X_3_2GF_Weights', 'RegNet_X_400MF_Weights', 'RegNet_X_800MF_Weights', 'RegNet_X_8GF_Weights', 'RegNet_Y_128GF_Weights', 'RegNet_Y_16GF_Weights', 'RegNet_Y_1_6GF_Weights', 'RegNet_Y_32GF_Weights', 'RegNet_Y_3_2GF_Weights', 'RegNet_Y_400MF_Weights', 'RegNet_Y_800MF_Weights', 'RegNet_Y_8GF_Weights', 'ResNeXt101_32X8D_Weights', 'ResNeXt101_64X4D_Weights', 'ResNeXt50_32X4D_Weights', 'ResNet', 'ResNet101_Weights', 'ResNet152_Weights', 'ResNet18_Weights', 'ResNet34_Weights', 'ResNet50_Weights', 'ShuffleNetV2', 'ShuffleNet_V2_X0_5_Weights', 'ShuffleNet_V2_X1_0_Weights', 'ShuffleNet_V2_X1_5_Weights', 'ShuffleNet_V2_X2_0_Weights', 'SqueezeNet', 'SqueezeNet1_0_Weights', 'SqueezeNet1_1_Weights', 'SwinTransformer', 'Swin_B_Weights', 'Swin_S_Weights', 'Swin_T_Weights', 'VGG', 'VGG11_BN_Weights', 'VGG11_Weights', 'VGG13_BN_Weights', 'VGG13_Weights', 'VGG16_BN_Weights', 'VGG16_Weights', 'VGG19_BN_Weights', 'VGG19_Weights', 'ViT_B_16_Weights', 'ViT_B_32_Weights', 'ViT_H_14_Weights', 'ViT_L_16_Weights', 
'ViT_L_32_Weights', 'VisionTransformer', 'Wide_ResNet101_2_Weights', 'Wide_ResNet50_2_Weights', '_GoogLeNetOutputs', '_InceptionOutputs', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '_api', '_meta', '_utils', 'alexnet', 'convnext', 'convnext_base', 'convnext_large', 'convnext_small', 'convnext_tiny', 'densenet', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'detection', 'efficientnet', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_v2_l', 'efficientnet_v2_m', 'efficientnet_v2_s', 'get_weight', 'googlenet', 'inception', 'inception_v3', 'mnasnet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3', 'mobilenet', 'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small', 'mobilenetv2', 'mobilenetv3', 'optical_flow', 'quantization', 'regnet', 'regnet_x_16gf', 'regnet_x_1_6gf', 'regnet_x_32gf', 'regnet_x_3_2gf', 'regnet_x_400mf', 'regnet_x_800mf', 'regnet_x_8gf', 'regnet_y_128gf', 'regnet_y_16gf', 'regnet_y_1_6gf', 'regnet_y_32gf', 'regnet_y_3_2gf', 'regnet_y_400mf', 'regnet_y_800mf', 'regnet_y_8gf', 'resnet', 'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'resnext101_32x8d', 'resnext101_64x4d', 'resnext50_32x4d', 'segmentation', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0', 'shufflenetv2', 'squeezenet', 'squeezenet1_0', 'squeezenet1_1', 'swin_b', 'swin_s', 'swin_t', 'swin_transformer', 'vgg', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'video', 'vision_transformer', 'vit_b_16', 'vit_b_32', 'vit_h_14', 'vit_l_16', 'vit_l_32', 'wide_resnet101_2', 'wide_resnet50_2']\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from torchvision import models \n",
"import torch.nn as nn"
],
"metadata": {
"id": "_bjmGalO9UHo"
},
"execution_count": 14,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def model(pretrained, requires_grad):\n",
" model = models.efficientnet_v2_s(progress=True, pretrained=pretrained)\n",
" # to freeze the hidden layers\n",
" if requires_grad == False:\n",
" for param in model.parameters():\n",
" param.requires_grad = False\n",
" # to train the hidden layers\n",
" elif requires_grad == True:\n",
" for param in model.parameters():\n",
" param.requires_grad = True\n",
" # make the classification layer learnable\n",
" # we have 11 classes in total\n",
" model.classifier[1] = nn.Linear(in_features=1280, out_features=11)\n",
" return model"
],
"metadata": {
"execution": {
"iopub.status.busy": "2022-10-28T04:24:13.656142Z",
"iopub.execute_input": "2022-10-28T04:24:13.656782Z",
"iopub.status.idle": "2022-10-28T04:24:13.663143Z",
"shell.execute_reply.started": "2022-10-28T04:24:13.656746Z",
"shell.execute_reply": "2022-10-28T04:24:13.662004Z"
},
"trusted": true,
"id": "m_obBk7lvJl8"
},
"execution_count": 19,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!pip install torchmetrics "
],
"metadata": {
"id": "3okxQXjv0jYw",
"outputId": "10d86872-257a-4706-bdf4-a8d4f60997cf",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting torchmetrics\n",
" Downloading torchmetrics-0.10.2-py3-none-any.whl (529 kB)\n",
"\u001b[K |████████████████████████████████| 529 kB 14.3 MB/s \n",
"\u001b[?25hRequirement already satisfied: numpy>=1.17.2 in /usr/local/lib/python3.7/dist-packages (from torchmetrics) (1.21.6)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torchmetrics) (4.1.1)\n",
"Requirement already satisfied: torch>=1.3.1 in /usr/local/lib/python3.7/dist-packages (from torchmetrics) (1.12.1+cu113)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from torchmetrics) (21.3)\n",
"Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->torchmetrics) (3.0.9)\n",
"Installing collected packages: torchmetrics\n",
"Successfully installed torchmetrics-0.10.2\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Create Train and Validate functions"
],
"metadata": {
"id": "U-mCMBzyz_gY"
}
},
{
"cell_type": "code",
"source": [
"from tqdm import tqdm\n",
"from torchmetrics import Accuracy, AUROC, F1Score, Precision, Recall"
],
"metadata": {
"id": "kV8Yj6EZ9XhB"
},
"execution_count": 16,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# training function\n",
"def train(model, dataloader, optimizer, criterion, train_data, device):\n",
" print('Training')\n",
" model.train()\n",
" counter = 0\n",
" train_running_loss = 0.0\n",
" # instantiate metrics\n",
" acc = Accuracy()\n",
" auc = AUROC()\n",
" f1_score = F1Score()\n",
" precision = Precision()\n",
" recall = Recall()\n",
" preds = []\n",
" labels = []\n",
" for i, data in tqdm(enumerate(dataloader), total=int(len(train_data)/dataloader.batch_size)):\n",
" counter += 1\n",
" data, target = data['image'].to(device), data['label'].to(device)\n",
" labels.append(target.cpu().numpy().argmax(axis=1))\n",
" optimizer.zero_grad()\n",
" outputs = model(data)\n",
" # apply sigmoid activation to get all the outputs between 0 and 1\n",
" outputs = torch.sigmoid(outputs)\n",
" loss = criterion(outputs, target)\n",
" train_running_loss += loss.item()\n",
" # backpropagation\n",
" loss.backward()\n",
" # update optimizer parameters\n",
" optimizer.step()\n",
" preds.append(outputs.detach().cpu().numpy().argmax(axis=1))\n",
" \n",
" train_loss = train_running_loss / counter\n",
" preds = torch.tensor(np.concatenate(preds))\n",
" labels = torch.tensor(np.concatenate(labels))\n",
" train_acc = acc(preds, labels).item()\n",
" \n",
" return train_loss, train_acc"
],
"metadata": {
"execution": {
"iopub.status.busy": "2022-10-28T04:24:16.962617Z",
"iopub.execute_input": "2022-10-28T04:24:16.963764Z",
"iopub.status.idle": "2022-10-28T04:24:16.976195Z",
"shell.execute_reply.started": "2022-10-28T04:24:16.963715Z",
"shell.execute_reply": "2022-10-28T04:24:16.975023Z"
},
"trusted": true,
"id": "5y0w7dPrvJl9"
},
"execution_count": 20,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# validation function\n",
"def validate(model, dataloader, criterion, val_data, device):\n",
" print('Validating')\n",
" model.eval()\n",
" counter = 0\n",
" val_running_loss = 0.0\n",
" # instantiate metrics\n",
" acc = Accuracy()\n",
" auc = AUROC()\n",
" f1_score = F1Score()\n",
" precision = Precision()\n",
" recall = Recall()\n",
" preds = []\n",
" labels = []\n",
" with torch.no_grad():\n",
" for i, data in tqdm(enumerate(dataloader), total=int(len(val_data)/dataloader.batch_size)):\n",
" counter += 1\n",
" data, target = data['image'].to(device), data['label'].to(device)\n",
" labels.append(target.cpu().numpy().argmax(axis=1))\n",
" # make predictions\n",
" outputs = model(data)\n",
" # apply sigmoid activation to get all the outputs between 0 and 1\n",
" outputs = torch.sigmoid(outputs)\n",
" loss = criterion(outputs, target)\n",
" val_running_loss += loss.item()\n",
" preds.append(outputs.detach().cpu().numpy().argmax(axis=1))\n",
" \n",
" val_loss = val_running_loss / counter\n",
" preds = torch.tensor(np.concatenate(preds))\n",
" labels = torch.tensor(np.concatenate(labels))\n",
" val_acc = acc(preds, labels).item()\n",
" return val_loss, val_acc"
],
"metadata": {
"execution": {
"iopub.status.busy": "2022-10-28T04:24:20.716225Z",
"iopub.execute_input": "2022-10-28T04:24:20.716601Z",
"iopub.status.idle": "2022-10-28T04:24:20.727583Z",
"shell.execute_reply.started": "2022-10-28T04:24:20.716567Z",
"shell.execute_reply": "2022-10-28T04:24:20.726191Z"
},
"trusted": true,
"id": "F43wHu87vJl-"
},
"execution_count": 21,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import torch.optim as optim\n",
"import matplotlib\n",
"from torch.utils.data import DataLoader\n",
"matplotlib.style.use('ggplot')\n",
"# initialize the computation device\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
],
"metadata": {
"execution": {
"iopub.status.busy": "2022-10-28T04:24:24.617804Z",
"iopub.execute_input": "2022-10-28T04:24:24.618166Z",
"iopub.status.idle": "2022-10-28T04:24:24.623867Z",
"shell.execute_reply.started": "2022-10-28T04:24:24.618134Z",
"shell.execute_reply": "2022-10-28T04:24:24.622628Z"
},
"trusted": true,
"id": "TWolPeqivJl-"
},
"execution_count": 22,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Set model parameters"
],
"metadata": {
"id": "2RkrFR2w0EfW"
}
},
{
"cell_type": "code",
"source": [
"#intialize the model\n",
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
"\n",
"ENet_model = model(pretrained=False, requires_grad=True).to(device)\n",
"# learning parameters\n",
"lr = 0.0001\n",
"epochs = 25\n",
"batch_size = 4\n",
"optimizer = optim.Adam(ENet_model.parameters(), lr=lr)\n",
"scheduler = ReduceLROnPlateau(optimizer, 'min')\n",
"criterion = nn.BCELoss()"
],
"metadata": {
"execution": {
"iopub.status.busy": "2022-10-28T04:24:26.626925Z",
"iopub.execute_input": "2022-10-28T04:24:26.630098Z",
"iopub.status.idle": "2022-10-28T04:24:28.432430Z",
"shell.execute_reply.started": "2022-10-28T04:24:26.630044Z",
"shell.execute_reply": "2022-10-28T04:24:28.431444Z"
},
"trusted": true,
"id": "siSktp4cvJmA",
"outputId": "4559217e-7208-45c5-b3cb-5b2867c663cc",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": 23,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:209: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.\n",
" f\"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, \"\n",
"/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=None`.\n",
" warnings.warn(msg)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"train_data = ImageDataset(\n",
" train_df, train=True, test=False\n",
")\n",
"# validation dataset\n",
"valid_data = ImageDataset(\n",
" train_df, train=False, test=False\n",
")\n",
"# train data loader\n",
"train_loader = DataLoader(\n",
" train_data, \n",
" batch_size=batch_size,\n",
" shuffle=True\n",
")\n",
"# validation data loader\n",
"valid_loader = DataLoader(\n",
" valid_data, \n",
" batch_size=batch_size,\n",
" shuffle=False\n",
")"
],
"metadata": {
"execution": {
"iopub.status.busy": "2022-10-28T04:24:32.210103Z",
"iopub.execute_input": "2022-10-28T04:24:32.210655Z",
"iopub.status.idle": "2022-10-28T04:24:32.243373Z",
"shell.execute_reply.started": "2022-10-28T04:24:32.210615Z",
"shell.execute_reply": "2022-10-28T04:24:32.242479Z"
},
"trusted": true,
"id": "lUxQij0_vJmA",
"outputId": "7f7cf372-d419-459d-baa5-8f6c5fa7fcac",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": 24,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Number of training images: 25570\n",
"Number of validation images: 4513\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Create save model class function"
],
"metadata": {
"id": "89DxeKW50Kez"
}
},
{
"cell_type": "code",
"source": [
"class SaveBestModel:\n",
" \"\"\"\n",
" Class to save the best model while training. If the current epoch's \n",
" validation loss is less than the previous least less, then save the\n",
" model state.\n",
" \"\"\"\n",
" def __init__(\n",
" self, best_valid_loss=float('inf')\n",
" ):\n",
" self.best_valid_loss = best_valid_loss\n",
" \n",
" def __call__(\n",
" self, current_valid_loss, \n",
" epoch, model, optimizer, criterion\n",
" ):\n",
" if current_valid_loss < self.best_valid_loss:\n",
" self.best_valid_loss = current_valid_loss\n",
" print(f\"\\nBest validation loss: {self.best_valid_loss}\")\n",
" print(f\"\\nSaving best model for epoch: {epoch+1}\\n\")\n",
" file_path = f'drive/MyDrive/Colab Notebooks/Enet-ep{epoch+1}-val{current_valid_loss:.3f}.pth'\n",
" torch.save({\n",
" 'epoch': epoch+1,\n",
" 'model_state_dict': model.state_dict(),\n",
" 'optimizer_state_dict': optimizer.state_dict(),\n",
" 'loss': criterion,\n",
" }, file_path)"
],
"metadata": {
"id": "5TSUnoCtJQBS"
},
"execution_count": 26,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Save plots"
],
"metadata": {
"id": "ZyawvU8qDbYx"
}
},
{
"cell_type": "code",
"source": [
"def _plot_curves(train_values, valid_values, metric, colors, file_path):\n",
"    \"\"\"Plot one train/validation metric pair per epoch and save it to file_path.\"\"\"\n",
"    fig, ax = plt.subplots(figsize=(10, 7))\n",
"    ax.plot(\n",
"        train_values, color=colors[0], linestyle='-',\n",
"        label=f'train {metric}'\n",
"    )\n",
"    ax.plot(\n",
"        valid_values, color=colors[1], linestyle='-',\n",
"        label=f'validation {metric}'\n",
"    )\n",
"    ax.set_xlabel('Epochs')\n",
"    ax.set_ylabel(metric.capitalize())\n",
"    ax.legend()\n",
"    fig.savefig(file_path)\n",
"\n",
"\n",
"def save_plots(train_acc, valid_acc, train_loss, valid_loss,\n",
"               out_dir='drive/MyDrive/Colab Notebooks'):\n",
"    \"\"\"\n",
"    Save the accuracy and loss curves to disk as two PNG files.\n",
"\n",
"    Args:\n",
"        train_acc: per-epoch training accuracy values.\n",
"        valid_acc: per-epoch validation accuracy values.\n",
"        train_loss: per-epoch training loss values.\n",
"        valid_loss: per-epoch validation loss values.\n",
"        out_dir: directory that receives `Enet-acc.png` and `Enet-loss.png`.\n",
"    \"\"\"\n",
"    # accuracy plots\n",
"    _plot_curves(\n",
"        train_acc, valid_acc, 'accuracy', ('green', 'blue'),\n",
"        f'{out_dir}/Enet-acc.png'\n",
"    )\n",
"    # loss plots\n",
"    _plot_curves(\n",
"        train_loss, valid_loss, 'loss', ('orange', 'red'),\n",
"        f'{out_dir}/Enet-loss.png'\n",
"    )"
],
"metadata": {
"id": "eH_PFKRVDeH1"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Define the EarlyStopper class"
],
"metadata": {
"id": "M-c9Qc4d0UhR"
}
},
{
"cell_type": "code",
"source": [
"class EarlyStopper:\n",
"    \"\"\"\n",
"    Signal when the validation loss has stopped improving.\n",
"\n",
"    `counter` counts the calls whose loss exceeded the best loss seen so far\n",
"    by more than `min_delta`; it is reset to zero whenever a new best loss\n",
"    is observed. `early_stop` returns True once the counter reaches\n",
"    `patience`.\n",
"    \"\"\"\n",
"    def __init__(self, patience=1, min_delta=0):\n",
"        self.patience, self.min_delta = patience, min_delta\n",
"        self.counter = 0\n",
"        self.min_validation_loss = np.inf\n",
"\n",
"    def early_stop(self, validation_loss):\n",
"        \"\"\"Update the tracker with one validation loss; return True to stop.\"\"\"\n",
"        # A new best loss resets the counter.\n",
"        if validation_loss < self.min_validation_loss:\n",
"            self.min_validation_loss = validation_loss\n",
"            self.counter = 0\n",
"            return False\n",
"\n",
"        # A loss inside the tolerance band neither counts nor resets.\n",
"        tolerated_ceiling = self.min_validation_loss + self.min_delta\n",
"        if validation_loss > tolerated_ceiling:\n",
"            self.counter += 1\n",
"            return bool(self.counter >= self.patience)\n",
"        return False"
],
"metadata": {
"execution": {
"iopub.status.busy": "2022-10-28T04:24:36.170818Z",
"iopub.execute_input": "2022-10-28T04:24:36.171515Z",
"iopub.status.idle": "2022-10-28T04:24:36.178005Z",
"shell.execute_reply.started": "2022-10-28T04:24:36.171477Z",
"shell.execute_reply": "2022-10-28T04:24:36.176852Z"
},
"trusted": true,
"id": "LDFpjyp_vJmB"
},
"execution_count": 27,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Train model"
],
"metadata": {
"id": "bSIP5Yw40d7l"
}
},
{
"cell_type": "code",
"source": [
"# start the training and validation\n",
"train_loss = []\n",
"valid_loss = []\n",
"train_acc = []\n",
"val_acc = []\n",
"early_stopper = EarlyStopper(patience=5, min_delta=0.001)\n",
"save_best_model = SaveBestModel() # initialize SaveBestModel class\n",
"for epoch in range(epochs):\n",
"    print(f\"Epoch {epoch+1} of {epochs}\")\n",
"    train_epoch_loss, train_epoch_acc = train(\n",
"        ENet_model, train_loader, optimizer, criterion, train_data, device\n",
"    )\n",
"    valid_epoch_loss, val_epoch_acc = validate(\n",
"        ENet_model, valid_loader, criterion, valid_data, device\n",
"    )\n",
"    # Record and report the epoch before the early-stop check, so the epoch\n",
"    # that triggers the stop still shows up in the log and in the plots.\n",
"    train_loss.append(train_epoch_loss)\n",
"    valid_loss.append(valid_epoch_loss)\n",
"    train_acc.append(train_epoch_acc)\n",
"    val_acc.append(val_epoch_acc)\n",
"    print(f'Train Loss: {train_epoch_loss:.4f}; Val Loss: {valid_epoch_loss:.4f}; Train accuracy: {train_epoch_acc:.4f}; Val accuracy: {val_epoch_acc:.4f}')\n",
"    # save the best model till now if we have the least loss in the current epoch\n",
"    save_best_model(\n",
"        valid_epoch_loss, epoch, ENet_model, optimizer, criterion\n",
"    )\n",
"    print('='*50) # gap\n",
"    if early_stopper.early_stop(valid_epoch_loss):\n",
"        print(f'Early stopping triggered at epoch {epoch+1}')\n",
"        break\n",
"\n",
"save_plots(train_acc, val_acc, train_loss, valid_loss)\n",
"print('PLOTS SAVED')"
],
"metadata": {
"execution": {
"iopub.status.busy": "2022-10-28T04:24:39.710609Z",
"iopub.execute_input": "2022-10-28T04:24:39.710988Z"
},
"trusted": true,
"id": "umWa-LbavJmC",
"outputId": "c81c677e-4b6a-4e4c-d8b0-c7c9018cff6b",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": null,
"outputs": [
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `AUROC` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.\n",
" warnings.warn(*args, **kwargs)\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1 of 30\n",
"Training\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/6392 [00:00, ?it/s]/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:56: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
"6393it [55:18, 1.93it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Validating\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"1126it [07:01, 2.67it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.2887; Val Loss: 0.2701; Train accuracy: 0.4308; Val accuracy: 0.4373\n",
"\n",
"Best validation loss: 0.27008894395552985\n",
"\n",
"Saving best model for epoch: 1\n",
"\n",
"==================================================\n",
"Epoch 2 of 30\n",
"Training\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"6393it [53:15, 2.00it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Validating\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"1126it [07:03, 2.66it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.2763; Val Loss: 0.2648; Train accuracy: 0.4333; Val accuracy: 0.4895\n",
"\n",
"Best validation loss: 0.26484752869891864\n",
"\n",
"Saving best model for epoch: 2\n",
"\n",
"==================================================\n",
"Epoch 3 of 30\n",
"Training\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"6393it [53:31, 1.99it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Validating\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"1126it [07:03, 2.66it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.2719; Val Loss: 0.2589; Train accuracy: 0.4349; Val accuracy: 0.4317\n",
"\n",
"Best validation loss: 0.2588627720268222\n",
"\n",
"Saving best model for epoch: 3\n",
"\n",
"==================================================\n",
"Epoch 4 of 30\n",
"Training\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"6393it [53:18, 2.00it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Validating\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"1126it [07:02, 2.67it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.2665; Val Loss: 0.2580; Train accuracy: 0.4372; Val accuracy: 0.4666\n",
"\n",
"Best validation loss: 0.257986861266209\n",
"\n",
"Saving best model for epoch: 4\n",
"\n",
"==================================================\n",
"Epoch 5 of 30\n",
"Training\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"6393it [53:13, 2.00it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Validating\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"1126it [07:05, 2.65it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.2613; Val Loss: 0.2500; Train accuracy: 0.4460; Val accuracy: 0.4610\n",
"\n",
"Best validation loss: 0.25000611321461347\n",
"\n",
"Saving best model for epoch: 5\n",
"\n",
"==================================================\n",
"Epoch 6 of 30\n",
"Training\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"6393it [53:29, 1.99it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Validating\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"1126it [07:08, 2.63it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.2523; Val Loss: 0.2343; Train accuracy: 0.4856; Val accuracy: 0.6040\n",
"\n",
"Best validation loss: 0.2342516876571873\n",
"\n",
"Saving best model for epoch: 6\n",
"\n",
"==================================================\n",
"Epoch 7 of 30\n",
"Training\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"6393it [53:23, 2.00it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Validating\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"1126it [07:05, 2.65it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.2431; Val Loss: 0.2315; Train accuracy: 0.5421; Val accuracy: 0.6249\n",
"\n",
"Best validation loss: 0.23147968413907088\n",
"\n",
"Saving best model for epoch: 7\n",
"\n",
"==================================================\n",
"Epoch 8 of 30\n",
"Training\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"6393it [54:42, 1.95it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Validating\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"1126it [07:03, 2.66it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.2374; Val Loss: 0.2245; Train accuracy: 0.5726; Val accuracy: 0.6251\n",
"\n",
"Best validation loss: 0.2244645284842799\n",
"\n",
"Saving best model for epoch: 8\n",
"\n",
"==================================================\n",
"Epoch 9 of 30\n",
"Training\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"6393it [53:26, 1.99it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Validating\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"1126it [07:01, 2.67it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.2351; Val Loss: 0.2245; Train accuracy: 0.5784; Val accuracy: 0.6018\n",
"==================================================\n",
"Epoch 10 of 30\n",
"Training\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"6393it [53:15, 2.00it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Validating\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"1126it [07:00, 2.68it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.2323; Val Loss: 0.2268; Train accuracy: 0.5925; Val accuracy: 0.5412\n",
"==================================================\n",
"Epoch 11 of 30\n",
"Training\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"6393it [53:03, 2.01it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Validating\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"1126it [07:02, 2.66it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.2302; Val Loss: 0.2227; Train accuracy: 0.5969; Val accuracy: 0.6171\n",
"\n",
"Best validation loss: 0.2226692042768425\n",
"\n",
"Saving best model for epoch: 11\n",
"\n",
"==================================================\n",
"Epoch 12 of 30\n",
"Training\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"6393it [53:38, 1.99it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Validating\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"1126it [07:11, 2.61it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.2288; Val Loss: 0.2232; Train accuracy: 0.6007; Val accuracy: 0.5976\n",
"==================================================\n",
"Epoch 13 of 30\n",
"Training\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"6393it [53:04, 2.01it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Validating\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"1126it [07:08, 2.63it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.2271; Val Loss: 0.2196; Train accuracy: 0.5988; Val accuracy: 0.6154\n",
"\n",
"Best validation loss: 0.2195812863749361\n",
"\n",
"Saving best model for epoch: 13\n",
"\n",
"==================================================\n",
"Epoch 14 of 30\n",
"Training\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"6393it [53:41, 1.98it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Validating\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"1126it [07:08, 2.63it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.2255; Val Loss: 0.2187; Train accuracy: 0.6017; Val accuracy: 0.6258\n",
"\n",
"Best validation loss: 0.21869684002803866\n",
"\n",
"Saving best model for epoch: 14\n",
"\n",
"==================================================\n",
"Epoch 15 of 30\n",
"Training\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"6393it [53:47, 1.98it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Validating\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"1126it [07:12, 2.60it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.2238; Val Loss: 0.2193; Train accuracy: 0.6014; Val accuracy: 0.6023\n",
"==================================================\n",
"Epoch 16 of 30\n",
"Training\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"6393it [53:57, 1.97it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Validating\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"1126it [07:09, 2.62it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.2225; Val Loss: 0.2148; Train accuracy: 0.5998; Val accuracy: 0.5952\n",
"\n",
"Best validation loss: 0.21480766724510578\n",
"\n",
"Saving best model for epoch: 16\n",
"\n",
"==================================================\n",
"Epoch 17 of 30\n",
"Training\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"6393it [54:16, 1.96it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Validating\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"1126it [07:11, 2.61it/s]\n"
]
},
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"Train Loss: 0.2210; Val Loss: 0.2136; Train accuracy: 0.6009; Val accuracy: 0.6060\n",
"\n",
"Best validation loss: 0.2136137347509605\n",
"\n",
"Saving best model for epoch: 17\n",
"\n",
"==================================================\n",
"Epoch 18 of 30\n",
"Training\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"6393it [53:49, 1.98it/s]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Validating\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"1126it [07:06, 2.64it/s]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Train Loss: 0.2186; Val Loss: 0.2118; Train accuracy: 0.6012; Val accuracy: 0.6040\n",
"\n",
"Best validation loss: 0.21180420065196115\n",
"\n",
"Saving best model for epoch: 18\n",
"\n",
"==================================================\n",
"Epoch 19 of 30\n",
"Training\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"6393it [54:20, 1.96it/s]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Validating\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"1126it [07:10, 2.61it/s]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Train Loss: 0.2173; Val Loss: 0.2099; Train accuracy: 0.6003; Val accuracy: 0.5958\n",
"\n",
"Best validation loss: 0.20991460148569321\n",
"\n",
"Saving best model for epoch: 19\n",
"\n",
"==================================================\n",
"Epoch 20 of 30\n",
"Training\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
" 57%|█████▋ | 3646/6392 [30:30<23:46, 1.92it/s]"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Evaluate model"
],
"metadata": {
"id": "Uaa5H3Sd4A3d"
}
},
{
"cell_type": "code",
"source": [
"# Checkpoint to evaluate. The file name follows SaveBestModel's `Enet-ep{epoch}-val{loss}.pth` pattern; update it after retraining.\n",
"model_path = 'drive/MyDrive/Colab Notebooks/Enet-ep19-val0.210.pth'"
],
"metadata": {
"id": "PEJ3zT6C5Aae"
},
"execution_count": 17,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# initialize the computation device\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"# initialize the model; its weights are replaced by the checkpoint below\n",
"test_model = model(pretrained=False, requires_grad=False).to(device)\n",
"# load the model checkpoint (tensors are mapped to the CPU while unpickling)\n",
"checkpoint = torch.load(model_path, map_location=torch.device('cpu'))\n",
"# load model weights state_dict\n",
"test_model.load_state_dict(checkpoint['model_state_dict'])\n",
"# switch to evaluation mode; the trailing ';' keeps the full module repr out of the cell output\n",
"test_model.eval();"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pLQXP0jCHHkv",
"outputId": "1126c4e0-4deb-4dfb-9dc0-19c703e9004a"
},
"execution_count": 20,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:209: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.\n",
" f\"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, \"\n",
"/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=None`.\n",
" warnings.warn(msg)\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"EfficientNet(\n",
" (features): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Sequential(\n",
" (0): FusedMBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.0, mode=row)\n",
" )\n",
" (1): FusedMBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.005, mode=row)\n",
" )\n",
" )\n",
" (2): Sequential(\n",
" (0): FusedMBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(24, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(96, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.01, mode=row)\n",
" )\n",
" (1): FusedMBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(48, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.015000000000000003, mode=row)\n",
" )\n",
" (2): FusedMBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(48, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.02, mode=row)\n",
" )\n",
" (3): FusedMBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(48, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.025, mode=row)\n",
" )\n",
" )\n",
" (3): Sequential(\n",
" (0): FusedMBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(48, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.030000000000000006, mode=row)\n",
" )\n",
" (1): FusedMBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.035, mode=row)\n",
" )\n",
" (2): FusedMBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.04, mode=row)\n",
" )\n",
" (3): FusedMBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.045, mode=row)\n",
" )\n",
" )\n",
" (4): Sequential(\n",
" (0): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=256, bias=False)\n",
" (1): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(256, 16, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.05, mode=row)\n",
" )\n",
" (1): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
" (1): BatchNorm2d(512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.05500000000000001, mode=row)\n",
" )\n",
" (2): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
" (1): BatchNorm2d(512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.06000000000000001, mode=row)\n",
" )\n",
" (3): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
" (1): BatchNorm2d(512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.065, mode=row)\n",
" )\n",
" (4): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
" (1): BatchNorm2d(512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.07, mode=row)\n",
" )\n",
" (5): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
" (1): BatchNorm2d(512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.075, mode=row)\n",
" )\n",
" )\n",
" (5): Sequential(\n",
" (0): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(128, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(768, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)\n",
" (1): BatchNorm2d(768, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(768, 32, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(32, 768, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(768, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.08, mode=row)\n",
" )\n",
" (1): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
" (1): BatchNorm2d(960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.085, mode=row)\n",
" )\n",
" (2): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
" (1): BatchNorm2d(960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.09, mode=row)\n",
" )\n",
" (3): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
" (1): BatchNorm2d(960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.095, mode=row)\n",
" )\n",
" (4): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
" (1): BatchNorm2d(960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.1, mode=row)\n",
" )\n",
" (5): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
" (1): BatchNorm2d(960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.10500000000000001, mode=row)\n",
" )\n",
" (6): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
" (1): BatchNorm2d(960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.11000000000000001, mode=row)\n",
" )\n",
" (7): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
" (1): BatchNorm2d(960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.11500000000000002, mode=row)\n",
" )\n",
" (8): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
" (1): BatchNorm2d(960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.12000000000000002, mode=row)\n",
" )\n",
" )\n",
" (6): Sequential(\n",
" (0): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=960, bias=False)\n",
" (1): BatchNorm2d(960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(960, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.125, mode=row)\n",
" )\n",
" (1): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.13, mode=row)\n",
" )\n",
" (2): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.135, mode=row)\n",
" )\n",
" (3): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.14, mode=row)\n",
" )\n",
" (4): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.14500000000000002, mode=row)\n",
" )\n",
" (5): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.15, mode=row)\n",
" )\n",
" (6): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.155, mode=row)\n",
" )\n",
" (7): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.16, mode=row)\n",
" )\n",
" (8): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.165, mode=row)\n",
" )\n",
" (9): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.17, mode=row)\n",
" )\n",
" (10): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.175, mode=row)\n",
" )\n",
" (11): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.18, mode=row)\n",
" )\n",
" (12): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.185, mode=row)\n",
" )\n",
" (13): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.19, mode=row)\n",
" )\n",
" (14): MBConv(\n",
" (block): Sequential(\n",
" (0): Conv2dNormActivation(\n",
" (0): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (1): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
" (1): BatchNorm2d(1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" (2): SqueezeExcitation(\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (fc1): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
" (fc2): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
" (activation): SiLU(inplace=True)\n",
" (scale_activation): Sigmoid()\n",
" )\n",
" (3): Conv2dNormActivation(\n",
" (0): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (stochastic_depth): StochasticDepth(p=0.195, mode=row)\n",
" )\n",
" )\n",
" (7): Conv2dNormActivation(\n",
" (0): Conv2d(256, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(1280, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): SiLU(inplace=True)\n",
" )\n",
" )\n",
" (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
" (classifier): Sequential(\n",
" (0): Dropout(p=0.2, inplace=True)\n",
" (1): Linear(in_features=1280, out_features=11, bias=True)\n",
" )\n",
")"
]
},
"metadata": {},
"execution_count": 20
}
]
},
{
"cell_type": "code",
"source": [
"# prepare the test dataset and dataloader\n",
"# NOTE(review): test=True presumably selects the held-out test images inside ImageDataset - confirm there.\n",
"test_data = ImageDataset(\n",
"    train_df, train=False, test=True\n",
")\n",
"test_loader = DataLoader(\n",
"    test_data, \n",
"    batch_size=1,  # preds_preview() below reads only index [0] of every batch\n",
"    shuffle=False\n",
")"
],
"metadata": {
"id": "bUgrJlhhIfIT"
},
"execution_count": 23,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# get the list of label names from train_df: columns 1 to 11 (column 0 is skipped)\n",
"# NOTE(review): the classifier head has out_features=11, so this slice is expected to yield 11 names - confirm.\n",
"tube_statuses = train_df.columns.values[1:12]"
],
"metadata": {
"id": "S6-Hkf5_JVxQ"
},
"execution_count": 24,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Run a Loop to Get the Predictions\n",
"We will iterate over the test data loader (batch size 1) and plot the predicted and actual labels for each of the 10 test images."
],
"metadata": {
"id": "kR9hnKmuKPWa"
}
},
{
"cell_type": "code",
"source": [
"def preds_preview(threshold=0.4):\n",
"    \"\"\"Plot every image in `test_loader`, titled with its predicted and actual labels.\n",
"\n",
"    Relies on the notebook globals `test_loader`, `test_model`, `device` and\n",
"    `tube_statuses`. Only index 0 of each batch is read, so the loader is\n",
"    expected to use batch_size=1.\n",
"\n",
"    Args:\n",
"        threshold: a label is reported as predicted when its sigmoid\n",
"            probability is strictly greater than this value.\n",
"    \"\"\"\n",
"    for counter, data in enumerate(test_loader):\n",
"        image, target = data['image'].to(device), data['label']\n",
"        # index positions of the ground-truth labels (value == 1)\n",
"        target_indices = [i for i, flag in enumerate(target[0]) if flag == 1]\n",
"        # forward pass only: no autograd graph is needed for a preview\n",
"        with torch.no_grad():\n",
"            probabilities = torch.sigmoid(test_model(image)).cpu()\n",
"        # 1d list of label indices whose probability exceeds the threshold\n",
"        prob_indices = torch.nonzero(probabilities[0] > threshold).flatten().tolist()\n",
"        # each label name is followed by a single space, as in the plot titles so far\n",
"        string_predicted = ''.join(f\"{tube_statuses[i]} \" for i in prob_indices)\n",
"        string_actual = ''.join(f\"{tube_statuses[i]} \" for i in target_indices)\n",
"        image = image.squeeze(0).detach().cpu().numpy()\n",
"        # move axis 0 to the end; imshow expects the channel axis last\n",
"        image = np.transpose(image, (1, 2, 0))\n",
"        plt.imshow(image)\n",
"        plt.axis('off')\n",
"        plt.title(f\"PREDICTED: {string_predicted}\\nACTUAL: {string_actual}\")\n",
"        # plt.savefig(f\"drive/MyDrive/Colab Notebooks/inferences/inference_{counter}.jpg\")\n",
"        plt.show()\n",
" \n",
"preds_preview()"
],
"metadata": {
"outputId": "5af9103a-13aa-42c7-a869-f513ccaceb20",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "xyoD_ahut0gD"
},
"execution_count": 26,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:56: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"Should output 10 images for inference using test dataset. \n",
"\n",
"It looks like the model consistently outputs 3 labels under ETT, NGT and CVC, even though not all 3 categories are labelled for every image. "
],
"metadata": {
"id": "JZs7sA4iMPZL"
}
},
{
"cell_type": "markdown",
"source": [
"## Construct Kaggle submission file\n",
"\n",
"__Reference:__\n",
"1. [Kaggle notebook](https://www.kaggle.com/code/ammarali32/resnet200d-inference-single-model-lb-96-5)"
],
"metadata": {
"id": "ugCYbg3YRset"
}
},
{
"cell_type": "code",
"source": [
"# inference() loops over a collection of models, so wrap the single model in a list\n",
"test2_models = [test_model.to(device)]"
],
"metadata": {
"id": "X3U-Nult1PJ4"
},
"execution_count": 27,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# directory of the test images; TestDataset.__getitem__ reads f'{TEST_PATH}/{file_name}.jpg'\n",
"TEST_PATH = 'test'"
],
"metadata": {
"id": "xhp2xzTK0wHJ"
},
"execution_count": 28,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class TestDataset(Dataset):\n",
"    \"\"\"Dataset of unlabelled test images listed in a DataFrame.\n",
"\n",
"    Each item is the RGB image read from `{TEST_PATH}/{StudyInstanceUID}.jpg`,\n",
"    passed through `transform` when one is given.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, df, transform=None):\n",
"        \"\"\"\n",
"        Args:\n",
"            df: DataFrame with a 'StudyInstanceUID' column of image file stems.\n",
"            transform: optional callable invoked as transform(image=image)\n",
"                that returns a mapping with an 'image' key.\n",
"        \"\"\"\n",
"        self.df = df\n",
"        self.file_names = df['StudyInstanceUID'].values\n",
"        self.transform = transform\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Number of rows (images) in the DataFrame.\"\"\"\n",
"        return len(self.df)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        \"\"\"Load image `idx` as RGB and apply the transform, if any.\n",
"\n",
"        Raises:\n",
"            FileNotFoundError: if the image file is missing or cannot be decoded.\n",
"        \"\"\"\n",
"        file_name = self.file_names[idx]\n",
"        file_path = f'{TEST_PATH}/{file_name}.jpg'\n",
"        image = cv2.imread(file_path)\n",
"        # cv2.imread reports failure by returning None instead of raising\n",
"        if image is None:\n",
"            raise FileNotFoundError(f'Could not read image: {file_path}')\n",
"        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
"        if self.transform:\n",
"            image = self.transform(image=image)['image']\n",
"        return image"
],
"metadata": {
"id": "lrGNuL_7yWsc"
},
"execution_count": 29,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import albumentations\n",
"from albumentations import *\n",
"from albumentations.pytorch import ToTensorV2"
],
"metadata": {
"id": "0wVJEGmA0gt4"
},
"execution_count": 30,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def get_transforms():\n",
" return Compose([\n",
" Resize(400, 400),\n",
" Normalize(\n",
" ),\n",
" ToTensorV2(),\n",
" ])"
],
"metadata": {
"id": "2-j_tpYlybg_"
},
"execution_count": 31,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def inference(models, test_loader, device):\n",
" tk0 = tqdm(enumerate(test_loader), total=len(test_loader))\n",
" probs = []\n",
" for i, (images) in tk0:\n",
" images = images.to(device)\n",
" avg_preds = []\n",
" for model in models:\n",
" with torch.no_grad():\n",
" y_preds1 = model(images)\n",
" #y_preds2 = model(images.flip(-1)) # vertical flip \n",
" #y_preds = (y_preds1.sigmoid().to('cpu').numpy() + y_preds2.sigmoid().to('cpu').numpy()) / 2\n",
" y_preds = y_preds1.sigmoid().to('cpu').numpy()\n",
" avg_preds.append(y_preds)\n",
" avg_preds = np.mean(avg_preds, axis=0) # 2d nested prob arrays for each batch of 4 images\n",
" # best to convert preds using 0.4 threshold conversion to binary classes before appending using if-else statment\n",
" new_avg_preds = []\n",
" for prob_arr in avg_preds:\n",
" label_arr = [1 if prob >= 0.4 else 0 for prob in prob_arr] # threshold 0.4 probability\n",
" new_avg_preds.append(label_arr)\n",
" probs.append(new_avg_preds) \n",
" probs = np.concatenate(probs)\n",
" return probs"
],
"metadata": {
"id": "NrOVEkvKym-F"
},
"execution_count": 48,
"outputs": []
},
{
"cell_type": "code",
"source": [
"test = pd.read_csv('sample_submission.csv')"
],
"metadata": {
"id": "KMA72uioy4Ak"
},
"execution_count": 49,
"outputs": []
},
{
"cell_type": "code",
"source": [
"test_dataset = TestDataset(test, transform=get_transforms())\n",
"test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, \n",
" num_workers=4 , pin_memory=True)\n",
"predictions = inference(test2_models, test_loader, device)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "tWiWZMFAy96P",
"outputId": "3e92e5e7-c22e-4137-a230-979e9114dcb9"
},
"execution_count": 50,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 896/896 [11:43<00:00, 1.27it/s]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"target_cols = test.iloc[:, 1:12].columns.tolist()\n",
"test[target_cols] = predictions\n",
"test[['StudyInstanceUID'] + target_cols].to_csv('drive/MyDrive/Colab Notebooks/kaggle_submission.csv', index=False)\n",
"test.head()"
],
"metadata": {
"id": "o8qXw_jr_t91",
"outputId": "a6b3025c-ecf7-41a6-a468-97ca5dbae670",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 322
}
},
"execution_count": 52,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" StudyInstanceUID ETT - Abnormal \\\n",
"0 1.2.826.0.1.3680043.8.498.46923145579096002617... 0 \n",
"1 1.2.826.0.1.3680043.8.498.84006870182611080091... 0 \n",
"2 1.2.826.0.1.3680043.8.498.12219033294413119947... 0 \n",
"3 1.2.826.0.1.3680043.8.498.84994474380235968109... 0 \n",
"4 1.2.826.0.1.3680043.8.498.35798987793805669662... 0 \n",
"\n",
" ETT - Borderline ETT - Normal NGT - Abnormal NGT - Borderline \\\n",
"0 0 1 0 0 \n",
"1 0 1 0 0 \n",
"2 0 1 0 0 \n",
"3 0 0 0 0 \n",
"4 0 1 0 0 \n",
"\n",
" NGT - Incompletely Imaged NGT - Normal CVC - Abnormal CVC - Borderline \\\n",
"0 0 0 0 0 \n",
"1 0 0 0 1 \n",
"2 1 0 0 1 \n",
"3 0 0 0 1 \n",
"4 1 0 0 1 \n",
"\n",
" CVC - Normal Swan Ganz Catheter Present \n",
"0 1 0 \n",
"1 0 0 \n",
"2 1 0 \n",
"3 0 0 \n",
"4 0 0 "
],
"text/html": [
"\n",
" \n",
"
\n",
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" StudyInstanceUID | \n",
" ETT - Abnormal | \n",
" ETT - Borderline | \n",
" ETT - Normal | \n",
" NGT - Abnormal | \n",
" NGT - Borderline | \n",
" NGT - Incompletely Imaged | \n",
" NGT - Normal | \n",
" CVC - Abnormal | \n",
" CVC - Borderline | \n",
" CVC - Normal | \n",
" Swan Ganz Catheter Present | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 1.2.826.0.1.3680043.8.498.46923145579096002617... | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
"
\n",
" \n",
" 1 | \n",
" 1.2.826.0.1.3680043.8.498.84006870182611080091... | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 2 | \n",
" 1.2.826.0.1.3680043.8.498.12219033294413119947... | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 1 | \n",
" 0 | \n",
"
\n",
" \n",
" 3 | \n",
" 1.2.826.0.1.3680043.8.498.84994474380235968109... | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 4 | \n",
" 1.2.826.0.1.3680043.8.498.35798987793805669662... | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
"
\n",
"
\n",
"
\n",
" \n",
" \n",
"\n",
" \n",
"
\n",
"
\n",
" "
]
},
"metadata": {},
"execution_count": 52
}
]
}
]
}