import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from matplotlib import pyplot as plt
import gradio as gr
from huggingface_hub import from_pretrained_keras

# load the pretrained attention-based MIL model from the Hugging Face Hub
model = from_pretrained_keras('keras-io/attention_mil')

# functions for inference
IMG_SIZE = 28


def plot(input_images=None, predictions=None, attention_weights=None):
    # Turn the bag prediction into a verdict string and, when the images are
    # available, a matplotlib figure showing each image with its attention weight.
    bag_class = np.argmax(predictions)
    bag_class = 'This set of images does not contain the number 8' if bag_class == 0 else 'This set of images contains the number 8'
    # attention_weights = [round(i, 2) for i in attention_weights]
    prob_str = f"Each image probability: {attention_weights[0]:.2f}, {attention_weights[1]:.2f}, {attention_weights[2]:.2f}"
    if input_images is not None:
        figure = plt.figure(figsize=(8, 8))
        for j in range(len(input_images)):
            image = input_images[j]
            figure.add_subplot(1, len(input_images), j + 1)
            plt.grid(False)
            if attention_weights is not None:
                plt.title(f"prob={attention_weights[j]:.2f}")
            plt.imshow(np.squeeze(input_images[j]))
        return [bag_class, plt.gcf()]
    return [bag_class, prob_str]


def preprocess_image(image):
    # scale the image to floats in [0, 1] and add a leading batch dimension
    # image = image[:, :, 0]
    image = image / 255.0
    image = np.expand_dims(image, axis=0)
    return image


def infer(input_images_1, input_images_2, input_images_3):
    if (input_images_1 is not None) and (input_images_2 is not None) and (input_images_3 is not None):
        # Normalize the input images
        input_images_1 = preprocess_image(input_images_1)
        input_images_2 = preprocess_image(input_images_2)
        input_images_3 = preprocess_image(input_images_3)
        # Predict the bag class
        prediction = model.predict([input_images_1, input_images_2, input_images_3])
        prediction = np.squeeze(np.swapaxes(prediction, 1, 0))
        # Extract the per-image attention weights from the "alpha" layer
        intermediate_model = keras.Model(model.input, model.get_layer("alpha").output)
        intermediate_predictions = intermediate_model.predict([input_images_1, input_images_2, input_images_3])
        attention_weights = np.squeeze(np.swapaxes(intermediate_predictions, 1, 0))
        return plot(
            [input_images_1, input_images_2, input_images_3],
            predictions=prediction,
            attention_weights=attention_weights
        )


# get the inputs
input1 = gr.Image(shape=(28, 28), type='numpy', image_mode='L', label='First image', show_label=True, visible=True)
input2 = gr.Image(shape=(28, 28), type='numpy', image_mode='L', label='Second image', show_label=True, visible=True)
input3 = gr.Image(shape=(28, 28), type='numpy', image_mode='L', label='Third image', show_label=True, visible=True)

# the app outputs the predicted bag class and a plot of the per-image attention weights
output = [gr.Label(), gr.Plot()]
# output = [gr.Plot()]

# it's good practice to pass examples, a description and a title to guide users
title = 'Bag of Images Classification'
description = 'This is the demo for the Keras implementation of classification using attention-based deep Multiple Instance Learning (MIL). The model predicts whether the number 8 appears in the set of input images. Since it was trained on the MNIST dataset, please use MNIST images for accurate results.'
gr_interface = gr.Interface(
    infer,
    inputs=[input1, input2, input3],
    outputs=output,
    allow_flagging='never',
    analytics_enabled=False,
    title=title,
    description=description,
    # examples = [[f'{i}.png' for i in range(0,3)], [f'{i}.png' for i in range(3,6)], [f'{i}.png' for i in range(6,9)], '9.png']
    examples=[
        ['samples/0.png', 'samples/6.png', 'samples/2.png'],
        ['samples/1.png', 'samples/2.png', 'samples/3.png'],
        ['samples/4.png', 'samples/8.png', 'samples/7.png'],
        ['samples/8.png', 'samples/0.png', 'samples/9.png'],
        ['samples/5.png', 'samples/6.png', 'samples/3.png'],
        ['samples/7.png', 'samples/8.png', 'samples/9.png'],
    ],
)
gr_interface.launch(enable_queue=True, debug=False)