Spaces:
Sleeping
Sleeping
import gradio as gr | |
from tensorflow.keras.models import load_model | |
from PIL import Image | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from io import BytesIO | |
# Path of the trained Keras model file, loaded once when the module is
# imported so every request reuses the same in-memory model.
MODEL_PATH = 'model2.h5'
model = load_model(MODEL_PATH)
def _preprocess(img, size=(224, 224)):
    """Return a batch of one RGB image resized to ``size`` and scaled to [0, 1]."""
    # Normalising the mode first handles every PIL mode uniformly. The old
    # shape-based checks only covered 2-D grayscale and 4-channel inputs, so
    # palette (P) images reached the model as raw palette indices, LA images
    # had only 2 channels, and CMYK channels were treated as RGB.
    # For L / RGB / RGBA inputs this gives the same channels as before
    # (L is replicated to 3 channels, RGBA drops alpha).
    rgb = img.convert("RGB").resize(size)
    img_array = np.asarray(rgb, dtype=np.float32) / 255.0
    # Add the leading batch dimension expected by model.predict.
    return np.expand_dims(img_array, axis=0)


def _prediction_to_mask(prediction, size):
    """Convert the model output into a grayscale mask image of ``size``."""
    # Indexing assumes a (batch, H, W, channels) output whose first channel is
    # presumably a per-pixel probability in [0, 1] (e.g. sigmoid) -- TODO confirm.
    # Clipping prevents out-of-range values from wrapping in the uint8 cast.
    probs = np.clip(prediction[0, :, :, 0], 0.0, 1.0)
    mask = (probs * 255).astype(np.uint8)
    # A 2-D uint8 array is inferred as mode "L"; passing mode= explicitly is
    # deprecated in recent Pillow releases.
    mask_image = Image.fromarray(mask)
    # NEAREST does not interpolate, so mask edges are not blurred when scaling
    # back to the original image size.
    return mask_image.resize(size, Image.NEAREST)


def predict_and_visualize(img):
    """Predict a camouflage mask for an uploaded image.

    The image is converted to RGB, resized to the model's 224x224 input,
    scaled to [0, 1] and run through the module-level ``model``. The first
    output channel is then turned into a grayscale mask at the original
    image size.

    Args:
        img: A PIL image, or a numpy array accepted by ``Image.fromarray``.

    Returns:
        A grayscale (mode ``L``) PIL image with the same size as ``img``.

    Raises:
        gr.Error: If no image was supplied or any processing step fails.
    """
    # Input validation
    if img is None:
        raise gr.Error("Please upload an image")
    try:
        # Convert numpy array to PIL Image if necessary
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        # PIL reports size as (width, height), which is what resize() expects.
        original_size = img.size
        img_array = _preprocess(img)
        prediction = model.predict(img_array)
        return _prediction_to_mask(prediction, original_size)
    except Exception as e:
        # UI boundary: show any failure in Gradio, keeping the cause chained
        # so the full traceback still appears in the server logs.
        raise gr.Error(f"Error processing image: {str(e)}") from e
# Build the Gradio UI: one image in, one predicted mask out.
# No example images are configured; pass examples=[...] to add some.
iface = gr.Interface(
    fn=predict_and_visualize,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Image(type="pil", label="Predicted Mask"),
    title="MilitarEye: Military Stealth Camouflage Detector",
    description=(
        "Upload an image of military personnel camouflaged in their "
        "surroundings. The model will predict the camouflage mask silhouette."
    ),
    # Hide the flagging button.
    # NOTE(review): newer Gradio releases rename this to flagging_mode; switch
    # once the pinned Gradio version supports it.
    allow_flagging="never",
)
# Start the web server; this runs as soon as the module is executed.
iface.launch()