alanoix committed
Commit a903e67
1 Parent(s): 6ea0ce6

Upload 2 files

Files changed (2):
  1. app.py +181 -0
  2. requirements.txt +19 -0
app.py ADDED
@@ -0,0 +1,181 @@
+ # -*- coding: utf-8 -*-
+
+ import random
+
+ import albumentations as aug
+ import cv2
+ import gradio as gr
+ import numpy as np
+ import tensorflow as tf
+ import torch
+ from tensorflow.keras import backend as K
+ from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor
+
+
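+ # FLAIR land-cover classes; index 0 ('None') is treated as the ignore class.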
+ classes = ['None', 'building', 'pervious surface', 'impervious surface',
+            'bare soil', 'water', 'coniferous', 'deciduous', 'brushwood',
+            'vineyard', 'herbaceous vegetation', 'agricultural land',
+            'plowed land']
+ id2label = dict(enumerate(classes))
+ label2id = {v: k for k, v in id2label.items()}
+ num_labels = len(id2label)
+
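+ # Load the fine-tuned SegFormer-B0 checkpoint with the FLAIR label mapping,
+ # then apply dynamic int8 quantization to its linear layers for faster
+ # CPU inference.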
+ segformer_b0_rgb_model = SegformerForSemanticSegmentation.from_pretrained(
+     "alanoix/segformer_b0_flair_one",
+     num_labels=num_labels,
+     id2label=id2label,
+     label2id=label2id)
+
+ segformer_rgb_feature_extractor = SegformerFeatureExtractor(
+     ignore_index=0, reduce_labels=False, do_resize=False,
+     do_rescale=False, do_normalize=False)
+
+ segformer_b0_rgb_model = torch.quantization.quantize_dynamic(
+     segformer_b0_rgb_model, {torch.nn.Linear}, dtype=torch.qint8)
+
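+ # Per-channel normalization statistics (presumably the training-set image
+ # statistics); applied to every input before inference.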
+ MEAN = np.array([0.44050665, 0.45704361, 0.42254708])
+ STD = np.array([0.20264351, 0.1782405, 0.17575739])
+
+ test_transform = aug.Compose([
+     aug.Normalize(mean=MEAN, std=STD),
+ ])
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ segformer_b0_rgb_model = segformer_b0_rgb_model.to(device)
+
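+ # One random RGB color per class id, used when rendering the predicted mask.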
+ class_colors = [(random.randint(0, 255), random.randint(0, 255),
+                  random.randint(0, 255)) for _ in range(5000)]
+
+
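+ # Map each class id in a prediction array to its RGB color.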
+ def get_colored_segmentation_image(seg_arr, n_classes, colors=class_colors):
+     output_height = seg_arr.shape[0]
+     output_width = seg_arr.shape[1]
+
+     seg_img = np.zeros((output_height, output_width, 3))
+
+     for c in range(n_classes):
+         seg_arr_c = seg_arr[:, :] == c
+         seg_img[:, :, 0] += (seg_arr_c * colors[c][0]).astype('uint8')
+         seg_img[:, :, 1] += (seg_arr_c * colors[c][1]).astype('uint8')
+         seg_img[:, :, 2] += (seg_arr_c * colors[c][2]).astype('uint8')
+
+     return seg_img
+
+
+ def get_legends(class_names, colors=class_colors):
+     n_classes = len(class_names)
+     legend = np.zeros(((len(class_names) * 25) + 25, 125, 3),
+                       dtype="uint8") + 255
+
+     class_names_colors = enumerate(zip(class_names[:n_classes],
+                                        colors[:n_classes]))
+
+     for (i, (class_name, color)) in class_names_colors:
+         color = [int(c) for c in color]
+         cv2.putText(legend, class_name, (5, (i * 25) + 17),
+                     cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1)
+         cv2.rectangle(legend, (100, (i * 25)), (125, (i * 25) + 25),
+                       tuple(color), -1)
+
+     return legend
+
+
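+ # Blend the color mask with the input image at 50% opacity.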
+ def overlay_seg_image(inp_img, seg_img):
+     original_h = inp_img.shape[0]
+     original_w = inp_img.shape[1]
+     seg_img = cv2.resize(seg_img, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
+
+     fused_img = (inp_img / 2 + seg_img / 2).astype('uint8')
+     return fused_img
+
+
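+ # Place the legend to the left of the segmentation image on a shared canvas.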
+ def concat_legends(seg_img, legend_img):
+     new_h = np.maximum(seg_img.shape[0], legend_img.shape[0])
+     new_w = seg_img.shape[1] + legend_img.shape[1]
+
+     out_img = np.zeros((new_h, new_w, 3)).astype('uint8') + legend_img[0, 0, 0]
+
+     out_img[:legend_img.shape[0], :legend_img.shape[1]] = np.copy(legend_img)
+     out_img[:seg_img.shape[0], legend_img.shape[1]:] = np.copy(seg_img)
+
+     return out_img
+
+
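+ # Render a color mask from predicted class ids, optionally overlaying it on
+ # the input image and attaching a class legend.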
+ def visualize_segmentation(seg_arr, inp_img=None, n_classes=None,
+                            colors=class_colors, class_names=None,
+                            overlay_img=False, show_legends=False,
+                            prediction_width=None, prediction_height=None):
+     if n_classes is None:
+         n_classes = np.max(seg_arr)
+
+     seg_img = get_colored_segmentation_image(seg_arr, n_classes, colors=colors)
+
+     if inp_img is not None:
+         original_h = inp_img.shape[0]
+         original_w = inp_img.shape[1]
+         seg_img = cv2.resize(seg_img, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
+
+     if (prediction_height is not None) and (prediction_width is not None):
+         seg_img = cv2.resize(seg_img, (prediction_width, prediction_height), interpolation=cv2.INTER_NEAREST)
+         if inp_img is not None:
+             inp_img = cv2.resize(inp_img,
+                                  (prediction_width, prediction_height))
+
+     if overlay_img:
+         assert inp_img is not None
+         seg_img = overlay_seg_image(inp_img, seg_img)
+
+     if show_legends:
+         assert class_names is not None
+         legend_img = get_legends(class_names, colors=colors)
+         seg_img = concat_legends(seg_img, legend_img)
+
+     return seg_img
+
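+ # Gradio callback: normalize the uploaded image, run the quantized SegFormer,
+ # upsample the logits to 512x512, and return the colorized prediction.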
+ def query_image(img):
+     image_to_pred = test_transform(image=img)['image']
+
+     pixel_values = segformer_rgb_feature_extractor(image_to_pred, return_tensors="pt").pixel_values.to(device)
+
+     with torch.no_grad():
+         outputs_segformer_b0_rgb = segformer_b0_rgb_model(pixel_values=pixel_values)
+     pred_segformer_b0_rgb = outputs_segformer_b0_rgb.logits.cpu().numpy()
+
+     pred = np.mean(np.array([K.softmax(pred_segformer_b0_rgb, axis=1)]), axis=0)
+     pred = tf.image.resize(tf.transpose(pred, perm=[0, 2, 3, 1]), size=[512, 512], method="bilinear")  # resize to 512x512
+     pred = np.argmax(pred, axis=-1)
+     pred = np.squeeze(pred)
+     result = pred.astype(np.uint8)
+
+     seg_img = visualize_segmentation(result, img, n_classes=num_labels,
+                                      colors=class_colors, overlay_img=True,
+                                      show_legends=True,
+                                      class_names=classes,
+                                      prediction_width=512,
+                                      prediction_height=512)
+
+     return seg_img
+
+ demo = gr.Interface(
+     query_image,
+     inputs=[gr.Image()],
+     outputs="image",
+     title="Image Segmentation Demo",
+     description="Upload an image to see the segmentation capabilities of this model.",
+     examples=["examples/IMG_011942.jpeg", "examples/IMG_005339.jpeg",
+               "examples/IMG_004753.jpeg", "examples/IMG_011617.jpeg",
+               "examples/IMG_003022.jpeg"],
+ )
+
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,19 @@
+ albumentations==1.2.1
+ evaluate==0.4.0
+ numpy==1.22.4
+ opencv_python==4.7.0.72
+ pandas==1.4.4
+ Pillow==9.4.0
+ rasterio==1.3.6
+ scikit_learn==1.2.2
+ torch==1.13.1
+ tqdm==4.65.0
+ transformers==4.27.3
+ GDAL==3.3.2
+ matplotlib==3.7.1
+ scikit_image==0.19.3
+ scipy==1.10.1
+ tensorflow==2.11.0
+