# File size: 1,297 Bytes
# 5331d9c
# 503f165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
from deepforest import main

# Initialize the deepforest model and use the released version.
# This runs once at import time; the resulting module-level `model` is the
# instance that predict_and_visualize() calls for every request.
model = main.deepforest()
# NOTE(review): use_release() presumably fetches/loads the prebuilt release
# weights, which may require network access on first run — confirm.
model.use_release()

def predict_and_visualize(image):
    """
    Run the deepforest model on an image and return the annotated result.

    The array is written to a uniquely named temporary PNG because the model
    is invoked through its ``path`` argument. The file is removed again once
    prediction finishes, whether it succeeds or fails.

    Args:
    - image: An image array (Gradio passes the upload as a numpy array).

    Returns:
    - An RGB image array with the predictions visualized. If the model
      returns no plot (``None``), the input image is returned unchanged.

    Raises:
    - ValueError: If ``image`` is None (e.g. submit pressed without an upload).
    """
    if image is None:
        raise ValueError("No image provided. Please upload an image first.")

    # A single fixed path would be shared by every request, so concurrent
    # users could overwrite each other's upload. Use a unique file per call.
    fd, temp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)  # only the path is needed; plt.imsave reopens it by name
    try:
        plt.imsave(temp_path, image)
        img = model.predict_image(path=temp_path, return_plot=True)
    finally:
        # Always clean up so temporary uploads do not accumulate on disk.
        os.remove(temp_path)

    # NOTE(review): predict_image appears to return None when nothing is
    # detected — confirm against the installed deepforest version. Slicing
    # None would raise TypeError, so fall back to the untouched input.
    if img is None:
        return image

    # Since the output is BGR and matplotlib (and hence Gradio) needs RGB,
    # reverse the channel axis before returning.
    return img[:, :, ::-1]

# Build the Gradio UI: one uploaded image in, one annotated image out.
image_input = gr.Image(type="numpy", label="Upload Image")
image_output = gr.Image(label="Predicted Image")

iface = gr.Interface(
    fn=predict_and_visualize,
    inputs=image_input,
    outputs=image_output,
    title="DeepForest Tree Detection",
    description="Upload an image to detect trees using the DeepForest model.",
    examples=["./example.jpg"],
)

# Start serving the app.
iface.launch()