Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import gradio as gr
|
3 |
+
from PIL import Image
|
4 |
+
import onnxruntime as ort
|
5 |
+
|
6 |
+
def resize_and_crop(image, image_size):
|
7 |
+
|
8 |
+
# Resize the image such that the shortest side is image_size
|
9 |
+
original_size = image.size
|
10 |
+
ratio = float(image_size) / min(original_size)
|
11 |
+
new_size = tuple([int(x * ratio) for x in original_size])
|
12 |
+
resized_image = image.resize(new_size, Image.LANCZOS)
|
13 |
+
|
14 |
+
# Calculate coordinates for center cropping
|
15 |
+
left = (resized_image.width - image_size) / 2
|
16 |
+
top = (resized_image.height - image_size) / 2
|
17 |
+
right = (resized_image.width + image_size) / 2
|
18 |
+
bottom = (resized_image.height + image_size) / 2
|
19 |
+
|
20 |
+
# Crop the image to image_size x image_size
|
21 |
+
cropped_image = resized_image.crop((left, top, right, bottom))
|
22 |
+
|
23 |
+
array = np.array(cropped_image)
|
24 |
+
array = np.transpose(array, (2, 0, 1))
|
25 |
+
array = np.expand_dims(array, axis=0)
|
26 |
+
array = (array/255).astype(np.float32)
|
27 |
+
|
28 |
+
return array
|
29 |
+
|
30 |
+
# Read class labels from a text file
|
31 |
+
def read_labels(file_path):
|
32 |
+
with open(file_path, 'r') as f:
|
33 |
+
labels = [line.strip() for line in f.readlines()]
|
34 |
+
return labels
|
35 |
+
|
36 |
+
# Load the class labels
|
37 |
+
class_labels = read_labels('vocab_formatted.txt')
|
38 |
+
|
39 |
+
taxon_included = read_labels('taxon_included.txt')
|
40 |
+
string_taxon_included = ', '.join(sorted(taxon_included))
|
41 |
+
|
42 |
+
taxon_not_included = read_labels('taxon_not_included.txt')
|
43 |
+
string_taxon_not_included = ', '.join(sorted(taxon_not_included))
|
44 |
+
|
45 |
+
# Load the ONNX model
|
46 |
+
onnx_model = ort.InferenceSession('convnext_tiny.onnx')
|
47 |
+
|
48 |
+
input_name = onnx_model.get_inputs()[0].name
|
49 |
+
output_name = onnx_model.get_outputs()[0].name
|
50 |
+
|
51 |
+
# Define the inference function
|
52 |
+
def classify_image(image):
|
53 |
+
input_array = resize_and_crop(image, 320)
|
54 |
+
outputs = onnx_model.run([output_name], {input_name: input_array})[0]
|
55 |
+
result = {taxon: prob for taxon, prob in zip(class_labels, outputs[0])}
|
56 |
+
return result
|
57 |
+
|
58 |
+
# Create the Gradio interface
|
59 |
+
iface = gr.Interface(
|
60 |
+
fn=classify_image,
|
61 |
+
inputs=gr.Image(type="pil"),
|
62 |
+
outputs=gr.Label(num_top_classes=5),
|
63 |
+
title="Image Classification for Freshwater Fish Species of Denmark",
|
64 |
+
description=f"**Upload an image to classify it**.\n\nSpecies included (Danish common name):\n*{string_taxon_included}*\n\nSpecies not included (yet!):\n*{string_taxon_not_included}*",
|
65 |
+
)
|
66 |
+
|
67 |
+
# Launch the app
|
68 |
+
if __name__ == "__main__":
|
69 |
+
iface.launch()
|