|
import numpy as np |
|
import tensorflow as tf |
|
from tensorflow import keras |
|
from tensorflow.keras import layers |
|
from matplotlib import pyplot as plt |
|
|
|
import numpy as np |
|
import tensorflow as tf |
|
import gradio as gr |
|
from huggingface_hub import from_pretrained_keras |
|
|
|
|
|
# Pretrained attention-based MIL classifier, downloaded from the Hugging Face Hub.
model = from_pretrained_keras("keras-io/attention_mil")

# Side length, in pixels, of the square grayscale images the demo works with
# (the description asks for MNIST digits, which are 28x28).
IMG_SIZE = 28
|
|
|
|
|
def plot(input_images=None, predictions=None, attention_weights=None):
    """Build the Gradio outputs (verdict text + figure) for one bag of images.

    Args:
        input_images: Sequence of images, each squeezable to a 2-D array for
            ``imshow``, or None to skip plotting.
        predictions: Bag-level class scores; ``np.argmax`` selects the class,
            where index 0 means "the bag does not contain an 8".
        attention_weights: One attention weight per image (any shape that
            flattens to one value per image), or None.

    Returns:
        ``[label, figure]`` when ``input_images`` is given, otherwise
        ``[label, prob_str]`` where ``prob_str`` lists every weight to two
        decimals (empty string when no weights are supplied).
    """
    bag_class = np.argmax(predictions)
    bag_label = (
        'This set of image does not contain number 8'
        if bag_class == 0
        else 'This set of image contains number 8'
    )

    # Flatten so both (n,) and (n, 1) weight arrays format as plain scalars;
    # tolerate a missing array instead of failing with a TypeError.
    weights = None if attention_weights is None else np.ravel(attention_weights)

    if input_images is None:
        if weights is None:
            prob_str = ""
        else:
            # Covers any bag size, not just the first three weights.
            prob_str = "Each image probability: " + ", ".join(
                f"{weight:.2f}" for weight in weights
            )
        return [bag_label, prob_str]

    figure = plt.figure(figsize=(8, 8))
    for j, image in enumerate(input_images):
        # Draw on explicit axes rather than pyplot's implicit "current" axes,
        # which is global state shared by every request in the process.
        axes = figure.add_subplot(1, len(input_images), j + 1)
        axes.grid(False)
        if weights is not None:
            axes.set_title(f"prob={weights[j]:.2f}")
        axes.imshow(np.squeeze(image))

    # Unregister the figure from pyplot so figures do not pile up across
    # requests; the Figure object itself can still be rendered/saved.
    plt.close(figure)
    return [bag_label, figure]
|
|
|
|
|
def preprocess_image(image):
    """Divide pixel values by 255 and prepend a batch axis of size 1.

    Presumably receives 8-bit grayscale pixels (0..255) from the Gradio input,
    in which case the result lies in [0, 1] -- confirm against the caller.
    """
    normalized = image / 255.0
    batched = np.expand_dims(normalized, 0)
    return batched
|
|
|
# Lazily built sub-model exposing the "alpha" attention layer; created on the
# first complete request and reused afterwards.
_intermediate_model = None


def _get_intermediate_model():
    """Return the cached sub-model mapping the bag inputs to the "alpha" output.

    Building it once avoids constructing a fresh ``keras.Model`` (and setting
    up its predict function again) on every request.
    """
    global _intermediate_model
    if _intermediate_model is None:
        _intermediate_model = keras.Model(model.input, model.get_layer("alpha").output)
    return _intermediate_model


def infer(input_images_1, input_images_2, input_images_3):
    """Classify a bag of three images and plot their attention weights.

    Args:
        input_images_1, input_images_2, input_images_3: Arrays from the three
            ``gr.Image(type='numpy', image_mode='L')`` inputs, or None when
            the user left that slot empty.

    Returns:
        A two-element list matching the interface outputs
        ``[gr.Label(), gr.Plot()]``: the bag verdict and a matplotlib figure.
        If any image is missing, a prompt message and ``None`` (no plot) are
        returned. Returning a bare ``None`` for the whole result cannot be
        mapped onto two output components.
    """
    images = (input_images_1, input_images_2, input_images_3)
    if any(image is None for image in images):
        return ["Please provide all three images.", None]

    # Each image becomes a batch of one; the model takes one input per image.
    bag = [preprocess_image(image) for image in images]

    # Swap the first two axes and drop singleton dimensions (as the original
    # demo did) so plot() receives flat vectors -- presumably (num_classes,)
    # scores and (bag_size,) attention weights.
    prediction = np.squeeze(np.swapaxes(model.predict(bag), 1, 0))
    alpha = _get_intermediate_model().predict(bag)
    attention_weights = np.squeeze(np.swapaxes(alpha, 1, 0))

    return plot(bag, predictions=prediction, attention_weights=attention_weights)
|
|
|
|
|
# Three grayscale uploaders, one per image in the bag. Their size is tied to
# IMG_SIZE (28, the MNIST digit size) instead of repeating the literal, so the
# constant actually governs the UI input shape.
input1 = gr.Image(shape=(IMG_SIZE, IMG_SIZE), type='numpy', image_mode='L', label='First image', show_label=True, visible=True)
input2 = gr.Image(shape=(IMG_SIZE, IMG_SIZE), type='numpy', image_mode='L', label='Second image', show_label=True, visible=True)
input3 = gr.Image(shape=(IMG_SIZE, IMG_SIZE), type='numpy', image_mode='L', label='Third image', show_label=True, visible=True)

# Outputs, in the order infer() returns them: verdict text, attention plot.
output = [gr.Label(), gr.Plot()]
|
|
|
|
|
# Text shown around the Gradio interface (page title, intro, footer credits).
title = "Bag of Image Classification"

description = (
    "This is the demo for Keras Implementation of Classification using "
    "Attention-based Deep Multiple Instance Learning (MIL). "
    "The model will try to predict whether number 8 is within the set of input images. "
    "As it was trained on MNIST dataset, please use MNIST image for precise result."
)

# HTML footer crediting the demo author and the upstream Keras example.
article = (
    'Author: <a href="https://huggingface.co/geninhu">Nhu Hoang</a>. '
    'Based on the following Keras example '
    '<a href="https://keras.io/examples/vision/attention_mil_classification">'
    ' Classification using Attention-based Deep Multiple Instance Learning (MIL)</a>'
    ' by <a href="https://www.linkedin.com/in/mohamadjaber1">Mohamad Jaber.</a>'
    ' <br> Check out the model '
    '<a href="https://huggingface.co/keras-io/attention_mil">here</a>'
)
|
|
|
# Wire the three uploaders and two outputs to infer(). Each example row names
# three sample image files under samples/, one per uploader.
gr_interface = gr.Interface(
    infer,
    inputs=[input1, input2, input3],
    outputs=output,
    allow_flagging='never',
    analytics_enabled=False,
    title=title,
    description=description,
    article=article,
    examples=[
        [f"samples/{digit}.png" for digit in row]
        for row in ((0, 6, 2), (1, 2, 3), (4, 8, 7), (8, 0, 9), (5, 6, 3), (7, 8, 9))
    ],
)

# Start the web app (runs at import time, as in the original script).
gr_interface.launch(enable_queue=True, debug=False)
|
|