KennethTM commited on
Commit
7c196be
1 Parent(s): 7dfc3c5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import gradio as gr
3
+ from PIL import Image
4
+ import onnxruntime as ort
5
+
6
+ def resize_and_crop(image, image_size):
7
+
8
+ # Resize the image such that the shortest side is image_size
9
+ original_size = image.size
10
+ ratio = float(image_size) / min(original_size)
11
+ new_size = tuple([int(x * ratio) for x in original_size])
12
+ resized_image = image.resize(new_size, Image.LANCZOS)
13
+
14
+ # Calculate coordinates for center cropping
15
+ left = (resized_image.width - image_size) / 2
16
+ top = (resized_image.height - image_size) / 2
17
+ right = (resized_image.width + image_size) / 2
18
+ bottom = (resized_image.height + image_size) / 2
19
+
20
+ # Crop the image to image_size x image_size
21
+ cropped_image = resized_image.crop((left, top, right, bottom))
22
+
23
+ array = np.array(cropped_image)
24
+ array = np.transpose(array, (2, 0, 1))
25
+ array = np.expand_dims(array, axis=0)
26
+ array = (array/255).astype(np.float32)
27
+
28
+ return array
29
+
30
+ # Read class labels from a text file
31
+ def read_labels(file_path):
32
+ with open(file_path, 'r') as f:
33
+ labels = [line.strip() for line in f.readlines()]
34
+ return labels
35
+
36
+ # Load the class labels
37
+ class_labels = read_labels('vocab_formatted.txt')
38
+
39
+ taxon_included = read_labels('taxon_included.txt')
40
+ string_taxon_included = ', '.join(sorted(taxon_included))
41
+
42
+ taxon_not_included = read_labels('taxon_not_included.txt')
43
+ string_taxon_not_included = ', '.join(sorted(taxon_not_included))
44
+
45
+ # Load the ONNX model
46
+ onnx_model = ort.InferenceSession('convnext_tiny.onnx')
47
+
48
+ input_name = onnx_model.get_inputs()[0].name
49
+ output_name = onnx_model.get_outputs()[0].name
50
+
51
+ # Define the inference function
52
+ def classify_image(image):
53
+ input_array = resize_and_crop(image, 320)
54
+ outputs = onnx_model.run([output_name], {input_name: input_array})[0]
55
+ result = {taxon: prob for taxon, prob in zip(class_labels, outputs[0])}
56
+ return result
57
+
58
+ # Create the Gradio interface
59
+ iface = gr.Interface(
60
+ fn=classify_image,
61
+ inputs=gr.Image(type="pil"),
62
+ outputs=gr.Label(num_top_classes=5),
63
+ title="Image Classification for Freshwater Fish Species of Denmark",
64
+ description=f"**Upload an image to classify it**.\n\nSpecies included (Danish common name):\n*{string_taxon_included}*\n\nSpecies not included (yet!):\n*{string_taxon_not_included}*",
65
+ )
66
+
67
+ # Launch the app
68
+ if __name__ == "__main__":
69
+ iface.launch()