AliSaria commited on
Commit
0cadd71
1 Parent(s): 6b255fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -26
app.py CHANGED
@@ -6,40 +6,59 @@ import matplotlib.pyplot as plt
6
  from io import BytesIO
7
 
8
  # Load the trained model
9
- model = load_model('model2.h5') # Make sure 'model1.h5' is the correct path to your model
10
 
11
- # Prediction function for the Gradio app
12
  def predict_and_visualize(img):
13
- # Store the original image size
14
- original_size = img.size
15
-
16
- # Convert the input image to the target size expected by the model
17
- img_resized = img.resize((224,224))
18
- img_array = np.array(img_resized) / 255.0 # Normalize the image
19
- img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- # Make a prediction
22
- prediction = model.predict(img_array)
23
-
24
- # Assuming the model outputs a single-channel image, normalize to 0-255 range for display
25
- predicted_mask = (prediction[0, :, :, 0] * 255).astype(np.uint8)
26
-
27
- # Convert the prediction to a PIL image
28
- prediction_image = Image.fromarray(predicted_mask, mode='L') # 'L' mode is for grayscale
29
-
30
- # Resize the predicted image back to the original image size
31
- prediction_image = prediction_image.resize(original_size, Image.NEAREST)
32
 
33
- return prediction_image
 
 
 
34
 
35
- # Create the Gradio interface
36
  iface = gr.Interface(
37
  fn=predict_and_visualize,
38
- inputs=gr.Image(type="pil"), # We expect a PIL Image
39
- outputs=gr.Image(type="pil"), # We will return a PIL Image
40
  title="MilitarEye: Military Stealth Camouflage Detector",
41
- description="Please upload an image of a military personnel camouflaged in their surroundings. On the right, the model will attempt to predict the camouflage mask silhouette."
 
42
  )
43
 
44
- # Launch the Gradio app
45
  iface.launch()
 
6
  from io import BytesIO
7
 
8
  # Load the trained model
9
+ model = load_model('model2.h5')
10
 
 
11
  def predict_and_visualize(img):
12
+ # Input validation
13
+ if img is None:
14
+ raise gr.Error("Please upload an image")
15
+
16
+ try:
17
+ # Convert numpy array to PIL Image if necessary
18
+ if isinstance(img, np.ndarray):
19
+ img = Image.fromarray(img)
20
+
21
+ # Store the original image size
22
+ original_size = img.size
23
+
24
+ # Convert the input image to the target size expected by the model
25
+ img_resized = img.resize((224, 224))
26
+ img_array = np.array(img_resized) / 255.0 # Normalize the image
27
+
28
+ # Ensure the image has 3 channels (RGB)
29
+ if len(img_array.shape) == 2: # Grayscale image
30
+ img_array = np.stack((img_array,)*3, axis=-1)
31
+ elif img_array.shape[-1] == 4: # RGBA image
32
+ img_array = img_array[:, :, :3]
33
+
34
+ img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
35
 
36
+ # Make a prediction
37
+ prediction = model.predict(img_array)
38
+
39
+ # Process the prediction
40
+ predicted_mask = (prediction[0, :, :, 0] * 255).astype(np.uint8)
41
+
42
+ # Convert the prediction to a PIL image
43
+ prediction_image = Image.fromarray(predicted_mask, mode='L')
44
+
45
+ # Resize the predicted image back to the original image size
46
+ prediction_image = prediction_image.resize(original_size, Image.NEAREST)
47
 
48
+ return prediction_image
49
+
50
+ except Exception as e:
51
+ raise gr.Error(f"Error processing image: {str(e)}")
52
 
53
+ # Create the Gradio interface with examples
54
  iface = gr.Interface(
55
  fn=predict_and_visualize,
56
+ inputs=gr.Image(type="pil", label="Input Image"),
57
+ outputs=gr.Image(type="pil", label="Predicted Mask"),
58
  title="MilitarEye: Military Stealth Camouflage Detector",
59
+ description="Upload an image of a military personnel camouflaged in their surroundings. The model will predict the camouflage mask silhouette.",
60
+ allow_flagging="never"
61
  )
62
 
63
+ # Launch the app
64
  iface.launch()