darshanjani commited on
Commit
58c979f
1 Parent(s): 09036b9

Upload 17 files

Browse files
Store/examples/airplane.png ADDED
Store/examples/bird.webp ADDED
Store/examples/car.jpg ADDED
Store/examples/cat.jpeg ADDED
Store/examples/deer.webp ADDED
Store/examples/dog1.jpg ADDED
Store/examples/frog1.webp ADDED
Store/examples/horse.jpg ADDED
Store/examples/shipp.jpg ADDED
Store/examples/truck1.jpg ADDED
Store/model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbd64f23fadf7bffb54d9f55e39771ebb15e40e3d64660d3972cc650def37d51
3
+ size 26333951
Utilities/config.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn.functional as F
2
+
3
+
4
# Central configuration constants shared by the model, utilities and the app.

# SEED

SEED = 1

# DATASET

# Class names, indexed by the integer label the model predicts.
CLASSES = (
    "Airplane",
    "Automobile",
    "Bird",
    "Cat",
    "Deer",
    "Dog",
    "Frog",
    "Horse",
    "Ship",
    "Truck",
)

SHUFFLE = True
DATA_DIR = "../data"
NUM_WORKERS = 4
PIN_MEMORY = True

# TRAINING HYPERPARAMETERS

CRITERION = F.cross_entropy
INPUT_SIZE = (3, 32, 32)
# Derived from CLASSES so the label list and the class count cannot drift apart.
NUM_CLASSES = len(CLASSES)
LEARNING_RATE = 0.001
WEIGHT_DECAY = 1e-4
BATCH_SIZE = 512
NUM_EPOCHS = 24
DROPOUT_PERCENTAGE = 0.05
LAYER_NORM = "bn"

# OPTIMIZER & SCHEDULER

LRFINDER_END_LR = 0.1
LRFINDER_NUM_ITERATIONS = 50
LRFINDER_STEP_MODE = "exp"

OCLR_DIV_FACTOR = 100
OCLR_FINAL_DIV_FACTOR = 100
OCLR_THREE_PHASE = False
OCLR_ANNEAL_STRATEGY = "linear"

# COMPUTE RELATED

ACCELERATOR = "cpu"
PRECISION = 32

# STORAGE

TRAINING_STAT_STORE = "Store/training_stats.csv"
MODEL_SAVE_PATH = "Store/model.pth"
# Alias kept because app.py reads the checkpoint location as config.MODEL_PATH.
MODEL_PATH = MODEL_SAVE_PATH
PRED_STORE_PATH = "Store/pred_store.pth"
EXAMPLE_IMG_PATH = "Store/examples/"

# VISUALIZATION

# When True the confusion matrix is row-normalised before plotting.
NORM_CONF_MAT = True
Utilities/model.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import pandas as pd
4
+ import pytorch_lightning as pl
5
+ import seaborn as sns
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.optim as optim
10
+ import torchmetrics
11
+ from torch.optim.lr_scheduler import OneCycleLR
12
+ from torch_lr_finder import LRFinder
13
+
14
+ from . import config # Custom config file
15
+ from .visualize import plot_incorrect_preds
16
+
17
+
18
class Net(pl.LightningModule):
    """Residual CNN classifier wrapped as a LightningModule.

    The network is a prep layer followed by three conv stages (``l1``, ``l2``,
    ``l3``); ``l1`` and ``l3`` each have a two-conv residual branch
    (``l1res``, ``l3res``). A 4x4 max-pool and a linear layer produce the
    class scores, which ``forward`` returns as log-probabilities.

    Args:
        num_classes: Number of output classes (also used by the metrics).
        dropout_percentage: Dropout probability applied after every conv block.
        norm: Normalisation type: ``'bn'``, ``'gn'`` or ``'ln'``.
        num_groups: Group count used when ``norm == 'gn'``.
        criterion: Callable ``criterion(pred, target)`` returning the loss.
        learning_rate: Initial learning rate handed to Adam.
        weight_decay: Weight decay handed to Adam.

    Raises:
        ValueError: If ``norm`` is not one of the supported values.
    """

    def __init__(
        self,
        num_classes=10,
        dropout_percentage=0,
        norm='bn',
        num_groups=2,
        criterion=F.cross_entropy,
        learning_rate=0.001,
        weight_decay=0.0
    ):
        super().__init__()

        # Define norm: ``self.norm(channels)`` builds one normalisation layer.
        if norm == 'bn':
            self.norm = nn.BatchNorm2d
        elif norm == 'gn':
            self.norm = lambda in_dim: nn.GroupNorm(
                num_groups=num_groups, num_channels=in_dim
            )
        elif norm == 'ln':
            # NOTE(review): one group per channel behaves like instance norm;
            # a LayerNorm equivalent would use num_groups=1 - confirm intent.
            self.norm = lambda in_dim: nn.GroupNorm(
                num_groups=in_dim, num_channels=in_dim
            )
        else:
            # Previously an unknown value left self.norm undefined and failed
            # later with an unrelated AttributeError.
            raise ValueError(
                f"Unsupported norm {norm!r}; expected 'bn', 'gn' or 'ln'"
            )

        # Define loss
        self.criterion = criterion

        # Define metrics
        self.accuracy = torchmetrics.Accuracy(
            task='multiclass', num_classes=num_classes
        )
        self.confusion_matrix = torchmetrics.ConfusionMatrix(
            task='multiclass', num_classes=num_classes
        )

        # Define the optimizer hyperparameters
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        # Prediction storage filled by test_step.
        self.pred_store = {
            "test_preds": torch.tensor([]),
            "test_labels": torch.tensor([]),
            # list of (image, target, predicted class, log-prob of prediction)
            "test_incorrect": [],
        }
        # Not written to anywhere in this class; kept for external users.
        self.log_store = {
            "train_loss_epoch": [],
            "train_acc_epoch": [],
            "val_loss_epoch": [],
            "val_acc_epoch": [],
            "test_loss_epoch": [],
            "test_acc_epoch": [],
        }

        # Define the network architecture. Layer order inside every
        # nn.Sequential is unchanged so existing checkpoints still load.
        p = dropout_percentage
        self.prep_layer = nn.Sequential(*self._conv_block(3, 64, p))  # 32x32x64
        self.l1 = nn.Sequential(
            *self._conv_block(64, 128, p, pool=True)  # 16x16x128
        )
        self.l1res = nn.Sequential(
            *self._conv_block(128, 128, p),
            *self._conv_block(128, 128, p),
        )
        self.l2 = nn.Sequential(
            *self._conv_block(128, 256, p, pool=True)  # 8x8x256
        )
        self.l3 = nn.Sequential(
            *self._conv_block(256, 512, p, pool=True)  # 4x4x512
        )
        self.l3res = nn.Sequential(
            *self._conv_block(512, 512, p),
            *self._conv_block(512, 512, p),
        )
        self.maxpool = nn.MaxPool2d(4, 4)

        # Classifier: sized from num_classes instead of a hard-coded 10.
        self.linear = nn.Linear(512, num_classes)

    def _conv_block(self, in_channels, out_channels, dropout, pool=False):
        """Return [Conv3x3, (MaxPool2x2), Norm, ReLU, Dropout] as a list of layers."""
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)]
        if pool:
            layers.append(nn.MaxPool2d(2, 2))
        layers.extend([self.norm(out_channels), nn.ReLU(), nn.Dropout(dropout)])
        return layers

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images."""
        x = self.prep_layer(x)
        x = self.l1(x)
        x = x + self.l1res(x)
        x = self.l2(x)
        x = self.l3(x)
        x = x + self.l3res(x)
        x = self.maxpool(x)
        x = x.view(-1, 512)
        x = self.linear(x)
        return F.log_softmax(x, dim=1)

    def _shared_step(self, batch, stage):
        """Run forward pass, loss and accuracy, and log them as ``{stage}_loss/acc``."""
        data, target = batch
        pred = self(data)
        loss = self.criterion(pred, target)
        accuracy = self.accuracy(pred, target)
        self.log_dict(
            {f"{stage}_loss": loss, f"{stage}_acc": accuracy},
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return data, target, pred, loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/accuracy; return the loss."""
        return self._shared_step(batch, "train")[3]

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss/accuracy; return the loss."""
        return self._shared_step(batch, "val")[3]

    def test_step(self, batch, batch_idx):
        """Compute and log test metrics and record predictions in ``pred_store``."""
        data, target, pred, loss = self._shared_step(batch, "test")

        # Update confusion matrix while tensors are still on the model device.
        self.confusion_matrix.update(pred, target)

        # Everything stored below lives on the CPU.
        data, target, pred = data.cpu(), target.cpu(), pred.cpu()
        argmax_pred = pred.argmax(dim=1)

        self.pred_store["test_preds"] = torch.cat(
            (self.pred_store["test_preds"], argmax_pred), dim=0
        )
        self.pred_store["test_labels"] = torch.cat(
            (self.pred_store["test_labels"], target), dim=0
        )

        for d, t, p, o in zip(data, target, argmax_pred, pred):
            if p.item() != t.item():
                self.pred_store["test_incorrect"].append((d, t, p, o[p.item()]))

        return loss

    def find_bestLR_LRFinder(self, optimizer):
        """Run an LR range test and return the LR with the steepest loss decrease.

        Args:
            optimizer: Optimizer instance the range test is run with.

        Returns:
            The learning rate at which the loss gradient is most negative.
        """
        # The keyword was misspelled ("criterian"), which raised a TypeError.
        lr_finder = LRFinder(self, optimizer, criterion=self.criterion)
        lr_finder.range_test(
            self.trainer.datamodule.train_dataloader(),
            end_lr=config.LRFINDER_END_LR,
            num_iter=config.LRFINDER_NUM_ITERATIONS,
            step_mode=config.LRFINDER_STEP_MODE
        )

        # Extract the loss and learning rate from history
        loss = np.array(lr_finder.history['loss'])
        lr = np.array(lr_finder.history['lr'])

        # Find the learning rate with steepest negative gradient
        best_lr = lr[np.argmin(np.gradient(loss))]

        # Plotting is best-effort: report the failure instead of hiding it.
        try:
            lr_finder.plot()
        except Exception as e:
            print("LR finder plot skipped: ", e)

        print("BEST_LR: ", best_lr)
        # Restore the model and optimizer state changed by the range test.
        lr_finder.reset()

        return best_lr

    def configure_optimizers(self):
        """Build Adam plus a per-step OneCycleLR peaking at the LR-finder result."""
        optimizer = self.get_only_optimizer()
        best_lr = self.find_bestLR_LRFinder(optimizer)
        scheduler = OneCycleLR(
            optimizer,
            max_lr=best_lr,
            steps_per_epoch=len(self.trainer.datamodule.train_dataloader()),
            epochs=config.NUM_EPOCHS,
            pct_start=5 / config.NUM_EPOCHS,
            div_factor=config.OCLR_DIV_FACTOR,
            three_phase=config.OCLR_THREE_PHASE,
            final_div_factor=config.OCLR_FINAL_DIV_FACTOR,
            anneal_strategy=config.OCLR_ANNEAL_STRATEGY
        )

        return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]

    def get_only_optimizer(self):
        """Return an Adam optimizer over all parameters (no scheduler)."""
        return optim.Adam(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )

    def on_test_end(self) -> None:
        """Plot the accumulated confusion matrix as a heatmap."""
        super().on_test_end()

        confmat = self.confusion_matrix.cpu().compute().numpy()
        if config.NORM_CONF_MAT:
            # Rows with no samples stay at zero instead of becoming NaN.
            row_sums = confmat.sum(axis=1, keepdims=True)
            confmat = np.divide(
                confmat,
                row_sums,
                out=np.zeros(confmat.shape, dtype=float),
                where=row_sums != 0,
            )
        labels = list(config.CLASSES)
        df_confmat = pd.DataFrame(confmat, index=labels, columns=labels)

        plt.figure(figsize=(7, 5))
        sns.heatmap(df_confmat, annot=True, cmap="Blues", fmt=".3f", linewidths=0.5)
        plt.tight_layout()
        plt.ylabel("True label")
        plt.xlabel("Predicted label")
        plt.show()

    def plot_incorrect_predictions_helper(self, num_imgs=10):
        """Plot ``num_imgs`` stored misclassified samples and return the figure."""
        # The figure was previously discarded, so callers always received None.
        return plot_incorrect_preds(
            self.pred_store["test_incorrect"], config.CLASSES, num_imgs
        )
Utilities/transforms.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import albumentations as A
2
+ from albumentations.pytorch import ToTensorV2
3
+
4
+ # Define the transforms (only test)
5
+
6
# Per-channel normalisation statistics applied to inference images
# (presumably the CIFAR-10 training-set mean/std - confirm against training code).
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.247, 0.243, 0.261)

# Inference pipeline: normalise, then convert the image to a torch tensor.
test_transforms = A.Compose(
    [
        A.Normalize(_NORM_MEAN, _NORM_STD),
        ToTensorV2(),
    ]
)
Utilities/utils.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from pytorch_grad_cam import GradCAM
3
+ from pytorch_grad_cam.utils.image import show_cam_on_image
4
+
5
+ from . import config
6
+ from .transforms import test_transforms
7
+
8
+
9
def generate_confidences(
    model,
    input_img,
    num_top_preds,
):
    """Classify one image and return the highest-confidence classes.

    Args:
        model: Network whose forward pass returns log-probabilities.
        input_img: Image passed to ``test_transforms`` as ``image=``.
        num_top_preds: How many of the highest-scoring classes to keep.

    Returns:
        A tuple ``(batched_input, confidences)`` where ``batched_input`` is the
        transformed image with a leading batch dimension and ``confidences``
        maps class name to probability, ordered from most to least likely.
    """
    input_img = test_transforms(image=input_img)["image"].unsqueeze(0)

    # Run in eval mode without autograd, then restore whatever mode the model
    # was in (the old code unconditionally switched it to train mode).
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            log_probs = model(input_img)[0]
    finally:
        model.train(was_training)
    probs = torch.exp(log_probs)

    confidences = {
        name: float(probs[i]) for i, name in enumerate(config.CLASSES)
    }
    # Keep only the num_top_preds most confident classes; UI sliders may hand
    # the count over as a float, which cannot be used as a slice bound.
    ranked = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    return input_img, dict(ranked[: int(num_top_preds)])
34
+
35
+
36
def generate_gradcam(
    model,
    org_img,
    input_img,
    show_gradcam,
    gradcam_layer,
    gradcam_opacity,
):
    """Overlay a GradCAM heatmap on the original image.

    Args:
        model: Network exposing the ``l2`` and ``l3`` sequential stages.
        org_img: Original image with values in 0-255 (it is divided by 255).
        input_img: Batched, preprocessed tensor fed to the model.
        show_gradcam: When falsy, nothing is computed and ``None`` is returned.
        gradcam_layer: ``-1`` targets the last layer of ``model.l3``,
            ``-2`` the last layer of ``model.l2``.
        gradcam_opacity: Heatmap opacity; the image weight is ``1 - opacity``.

    Returns:
        The blended visualisation, or ``None`` when ``show_gradcam`` is falsy.

    Raises:
        ValueError: If ``gradcam_layer`` is neither ``-1`` nor ``-2``.
    """
    if not show_gradcam:
        return None

    # Any other value used to leave target_layers unbound (UnboundLocalError).
    stage_by_layer = {-1: "l3", -2: "l2"}
    if gradcam_layer not in stage_by_layer:
        raise ValueError(
            f"gradcam_layer must be -1 or -2, got {gradcam_layer!r}"
        )
    target_layers = [getattr(model, stage_by_layer[gradcam_layer])[-1]]

    cam = GradCAM(
        model=model,
        target_layers=target_layers,
    )
    # targets=None: explained class is chosen by the library.
    grayscale_cam = cam(input_tensor=input_img, targets=None)
    grayscale_cam = grayscale_cam[0, :]

    return show_cam_on_image(
        org_img / 255,
        grayscale_cam,
        use_rgb=True,
        image_weight=(1 - gradcam_opacity),
    )
66
+
67
+
68
def generate_missclassified_imgs(
    model,
    show_misclassified,
    num_misclassified,
):
    """Return the model's misclassified-images plot, or ``None`` when disabled.

    Args:
        model: Object providing ``plot_incorrect_predictions_helper``.
        show_misclassified: When falsy, no plot is requested.
        num_misclassified: Number of images forwarded to the helper.
    """
    if not show_misclassified:
        return None
    return model.plot_incorrect_predictions_helper(num_misclassified)
Utilities/visualize.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ from torchvision import transforms
3
+ import random as rand
4
+
5
def plot_incorrect_preds(incorrect, classes, num_imgs):
    """Plot a random sample of misclassified images in rows of five.

    Args:
        incorrect: Sequence of ``(image, target, pred, output)`` records, where
            ``image`` is a normalised image tensor and ``target``/``pred`` are
            class indices (scalar tensors or ints).
        classes: Sequence mapping a class index to its display name.
        num_imgs: Number of images to draw; must be a positive multiple of 5.

    Returns:
        The matplotlib figure containing the grid.

    Raises:
        ValueError: If ``num_imgs`` is not a positive multiple of 5 or there
            are fewer than ``num_imgs`` records available.
    """
    # Explicit checks instead of asserts, which vanish under ``python -O``.
    if num_imgs <= 0 or num_imgs % 5 != 0:
        raise ValueError(f"num_imgs must be a positive multiple of 5, got {num_imgs}")
    num_imgs = int(num_imgs)
    if len(incorrect) < num_imgs:
        raise ValueError(
            f"requested {num_imgs} images but only {len(incorrect)} are available"
        )

    chosen = rand.sample(range(len(incorrect)), num_imgs)

    # Inverse of the test-time normalisation: mean = -mean/std, std = 1/std.
    unnormalize = transforms.Normalize(
        (-1.98947368, -1.98436214, -1.71072797), (4.048583, 4.11522634, 3.83141762)
    )
    to_pil = transforms.ToPILImage()

    fig = plt.figure(figsize=(10, num_imgs // 2))
    plt.suptitle("Target | Predicted Label")
    for position, record_idx in enumerate(chosen, start=1):
        # Each record is (image, target, pred, output). The old code indexed the
        # record with the loop counter, reading the wrong field and raising
        # IndexError from the fifth image onwards.
        image, target, pred = incorrect[record_idx][:3]
        plt.subplot(num_imgs // 5, 5, position, aspect="auto")
        plt.imshow(to_pil(unnormalize(image)))
        plt.title(f"{classes[int(target)]}|{classes[int(pred)]}")
        plt.xticks([])
        plt.yticks([])
    plt.tight_layout()
    return fig
app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+
4
+ from Utilities.model import Net
5
+ from Utilities import config
6
+ from Utilities.utils import generate_confidences, generate_gradcam, generate_missclassified_imgs
7
+
8
# Gradio input widgets, in the order generate_gradio_output expects its arguments.
inputs = [
    gr.Image(shape=(32, 32), label="Input Image"),
    gr.Slider(minimum=1, maximum=10, step=1, label="Number of Top Prediction to Display"),
    # `value` is the current keyword for the initial state (`default` is deprecated).
    gr.Checkbox(value=False, label="Show GradCAM"),
    gr.Slider(minimum=-2, maximum=-1, step=1, value=-1, label="GradCAM Layer (from the end)"),
    gr.Slider(minimum=0, maximum=1, value=0.5, label="GradCAM Heatmap Opacity"),
    gr.Checkbox(label="Show Incorrect Predictions"),
    gr.Slider(minimum=5, maximum=50, step=5, label="Number of Incorrect Predictions to Display"),
]

model = Net(
    num_classes=config.NUM_CLASSES,
    dropout_percentage=config.DROPOUT_PERCENTAGE,
    norm=config.LAYER_NORM,
    criterion=config.CRITERION,
    learning_rate=config.LEARNING_RATE,
    weight_decay=config.WEIGHT_DECAY,
)

device = torch.device(config.ACCELERATOR)

# config defines the checkpoint location as MODEL_SAVE_PATH; the previous
# config.MODEL_PATH lookup raised AttributeError at start-up.
model.load_state_dict(
    torch.load(config.MODEL_SAVE_PATH, map_location=device)
)
# The app only serves predictions, so keep dropout/normalisation in inference mode.
model.eval()

# Stored test predictions, used to display misclassified images.
model.pred_store = torch.load(config.PRED_STORE_PATH, map_location=device)
37
+
38
def generate_gradio_output(
    input_img,
    num_top_preds,
    show_gradcam,
    gradcam_layer,
    gradcam_opacity,
    show_misclassified,
    num_misclassified,
):
    """Gradio callback: produce all three outputs for one request.

    Returns:
        A tuple ``(confidences, gradcam_visualization, misclassified_plot)``
        built by ``generate_confidences``, ``generate_gradcam`` and
        ``generate_missclassified_imgs`` using the module-level ``model``.
    """
    # The preprocessed tensor from the classifier step is reused for GradCAM.
    model_input, top_confidences = generate_confidences(
        model, input_img, num_top_preds
    )
    cam_overlay = generate_gradcam(
        model,
        input_img,
        model_input,
        show_gradcam,
        gradcam_layer,
        gradcam_opacity,
    )
    misclassified_fig = generate_missclassified_imgs(
        model, show_misclassified, num_misclassified
    )
    return top_confidences, cam_overlay, misclassified_fig
69
+
70
# Output widgets, in the order generate_gradio_output returns its values.
confidence_label = gr.Label(visible=True, scale=0.5, label="Classification Confidences")
gradcam_image = gr.Image(shape=(32, 32), label="GradCAM Visualization").style(
    width=256, height=256, visible=True
)
misclassified_plot = gr.Plot(visible=True, label="Misclassified Images")
outputs = [confidence_label, gradcam_image, misclassified_plot]

# One row per example: file name followed by the remaining six input values
# (top predictions, show GradCAM, GradCAM layer, opacity, show misclassified, count).
_EXAMPLE_SETTINGS = (
    ("cat.jpeg", 3, True, -2, 0.68, True, 40),
    ("horse.jpg", 3, True, -2, 0.59, True, 25),
    ("bird.webp", 10, True, -1, 0.55, True, 20),
    ("dog1.jpg", 10, True, -1, 0.33, True, 45),
    ("frog1.webp", 5, True, -1, 0.64, True, 40),
    ("deer.webp", 1, True, -2, 0.45, True, 20),
    ("airplane.png", 3, True, -2, 0.43, True, 40),
    ("shipp.jpg", 7, True, -1, 0.6, True, 30),
    ("car.jpg", 2, True, -1, 0.68, True, 30),
    ("truck1.jpg", 5, True, -2, 0.51, True, 35),
)
examples = [
    [config.EXAMPLE_IMG_PATH + file_name, *settings]
    for file_name, *settings in _EXAMPLE_SETTINGS
]

title = "Image Classification (CIFAR10 - 10 Classes) with GradCAM"
description = """A simple Gradio interface to visualize the output of a CNN trained on CIFAR10 dataset with GradCAM and Misclassified images.
The architecture is inspired from David Page's (myrtle.ai) DAWNBench winning model archiecture.
Please input the image and select the number of top predictions to display - you will see the top predictions and their corresponding confidence scores.
You can also select whether to show GradCAM for the particular image (utilizes the gradients of the classification score with respect to the final convolutional feature map, to identify the parts of an input image that most impact the classification score).
You need to select the model layer where the gradients need to be plugged from - this affects how much of the image is used to compute the GradCAM.
You can also select whether to show misclassified images - these are the images that the model misclassified.
Some examples are provided in the examples tab.
"""

# Build the interface and start serving it.
demo = gr.Interface(
    fn=generate_gradio_output,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=examples,
)
demo.launch()