import os
import tempfile
import numpy as np
from matplotlib import rcParams
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.utils import load_img, save_img, img_to_array
from tensorflow.keras.applications.vgg19 import preprocess_input
from tensorflow.keras.layers import GlobalAveragePooling2D
from pymilvus import connections, Collection, utility
from requests import get
from shutil import rmtree
import streamlit as st
import zipfile

IMAGE_SIZE = (224, 224)  # every image is resized to this before embedding or display
TOP_N = 15  # number of similar images searched for and displayed
REQUEST_TIMEOUT = 30  # seconds to wait for the remote image server


@st.cache_resource
def _extract_vegetable_images(archive_path: str = "Vegetable Images.zip", dest: str = "."):
    """Unzip the vegetable image archive once per server process.

    Streamlit re-executes this whole script on every widget interaction; caching
    the call prevents the archive from being re-extracted on each rerun.
    """
    with zipfile.ZipFile(archive_path, 'r') as zip_ref:
        zip_ref.extractall(dest)


# unzip vegetable images (the result paths read from Milvus are loaded from disk;
# presumably they point into this extracted archive)
_extract_vegetable_images()


@st.cache_resource
class ImageVectorizer:
    '''
    Get vector representation of an image using VGG19 model fine tuned on vegetable images for classification
    '''

    def __init__(self):
        self.__model = self.get_model()

    @staticmethod
    def get_model():
        """Build the embedding model.

        The saved classifier is truncated at its ``block5_pool`` layer and followed
        by global average pooling, so the output is one feature vector per image.
        """
        # loading saved VGG model finetuned on vegetable images for classification
        model = load_model('vegetable_classification_model_vgg.h5')
        top = model.get_layer('block5_pool').output
        top = GlobalAveragePooling2D()(top)
        return Model(inputs=model.input, outputs=top)

    def vectorize(self, img_path: str):
        """Return the embedding vector for the image stored at ``img_path``."""
        test_image = load_img(img_path, color_mode="rgb", target_size=IMAGE_SIZE)
        test_image = img_to_array(test_image)
        test_image = preprocess_input(test_image)
        batch = np.array([test_image])  # the model expects a batch; this is a batch of one
        return self.__model(batch).numpy()[0]


@st.cache_resource
def get_milvus_collection():
    """Connect to Milvus and return the loaded image-embedding collection.

    Connection settings come from the ``URI``, ``TOKEN`` and ``COLLECTION_NAME``
    environment variables. Cached so the connection is made once per process.

    Raises:
        RuntimeError: if ``COLLECTION_NAME`` is not set.
    """
    uri = os.environ.get("URI")
    token = os.environ.get("TOKEN")
    connections.connect("default", uri=uri, token=token)
    print("Connected to DB")
    collection_name = os.environ.get("COLLECTION_NAME")
    if not collection_name:
        raise RuntimeError("COLLECTION_NAME environment variable is not set")
    collection = Collection(name=collection_name)
    collection.load()  # a collection must be loaded before it can be searched
    return collection


def plot_images(input_image_path: str, similar_img_paths: list):
    """Show the query image followed by a 3-column grid of the retrieved images.

    Output is written into the module-level ``placeholder`` container, which the
    script section at the bottom of this file creates before calling this.

    Args:
        input_image_path: path of the downloaded query image.
        similar_img_paths: paths of the retrieved images; any count is accepted
            (the grid grows to as many rows as needed).
    """
    cols = 3  # columns in subplots
    rows = max(1, -(-len(similar_img_paths) // cols))  # ceil division, at least one row
    # squeeze=False keeps ``ax`` 2-D even when only one row is needed
    fig, ax = plt.subplots(rows, cols, figsize=(12, 4 * rows), squeeze=False)
    for idx, cell in enumerate(ax.flat):  # row-major, same order as the paths
        cell.axis("off")
        if idx < len(similar_img_paths):  # trailing cells of the last row stay blank
            sim_image = load_img(similar_img_paths[idx], color_mode="rgb", target_size=IMAGE_SIZE)
            cell.imshow(sim_image)
    fig.subplots_adjust(wspace=0.01, hspace=0.01)

    # display input image
    # NOTE(review): this mutates matplotlib's global rcParams, so it only affects
    # figures created on later reruns, not ``fig`` above — confirm intent.
    rcParams.update({'figure.autolayout': True})
    input_image = load_img(input_image_path, color_mode="rgb", target_size=IMAGE_SIZE)
    try:
        with placeholder.container():
            st.markdown('<h4>Input image</h4>', unsafe_allow_html=True)
            st.image(input_image)
            st.write(' \n')
            # display similar images
            st.markdown('<h4>Similar images</h4>', unsafe_allow_html=True)
            st.pyplot(fig)
    finally:
        # pyplot keeps every figure alive until closed; without this each rerun leaks one
        plt.close(fig)


def find_similar_images(img_path: str, top_n: int = TOP_N):
    """Embed ``img_path``, fetch its ``top_n`` nearest neighbours from Milvus by
    L2 distance and render them with ``plot_images``."""
    search_params = {"metric_type": "L2"}
    search_vec = vectorizer.vectorize(img_path)
    result = collection.search(
        [search_vec],
        anns_field='image_vector',  # vector field (as named in the schema) to search over
        param=search_params,
        limit=top_n,
        guarantee_timestamp=1,
        output_fields=['image_path'],  # which fields to return in output
    )
    # one query vector was sent; flatten over every result set regardless
    similar_img_paths = [hit.entity.get('image_path') for hits in result for hit in hits]
    plot_images(img_path, similar_img_paths)


def delete_file(path_: str):
    """Recursively remove ``path_`` if it exists; removal errors are ignored."""
    if os.path.exists(path_):
        rmtree(path=path_, ignore_errors=True)


def process_input_image(img_url):
    """Download ``img_url`` into a fresh directory under ``./uploads`` and return the file path.

    Each call gets its own directory, so concurrent Streamlit sessions neither
    overwrite each other's image nor delete it when cleaning up via
    ``delete_file(os.path.dirname(path))``.

    Raises:
        requests.HTTPError: if the server returns an error status (so an HTML error
            page is never saved and handed to the model as an image).
        requests.Timeout: if the server does not respond within ``REQUEST_TIMEOUT``.
    """
    uploads_root = os.path.join('.', 'uploads')
    os.makedirs(uploads_root, exist_ok=True)
    download_dir = tempfile.mkdtemp(prefix='input_', dir=uploads_root)
    upload_file_path = os.path.join(download_dir, "input.jpg")
    headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36'}
    try:
        response = get(img_url, headers=headers, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()
        with open(upload_file_path, "wb") as file:
            file.write(response.content)
    except BaseException:
        # BaseException so Streamlit's rerun/stop signals also clean up; always re-raised
        rmtree(download_dir, ignore_errors=True)
        raise
    return upload_file_path


vectorizer = ImageVectorizer()
collection = get_milvus_collection()

try:
    desc = f'''<p>Allowed Vegetables: Broad Bean, Bitter Gourd, Bottle Gourd, Green Round Brinjal, Broccoli, Cabbage, Capsicum, Carrot, Cauliflower, Cucumber, Raw Papaya, Potato, Green Pumpkin, Radish, Tomato.</p>
<p>Image embeddings are extracted from a fine-tuned VGG model. The model is fine-tuned on images clicked using a mobile phone camera. Embeddings of 20,000 vegetable images are stored in Milvus vector database. Embeddings of the input image are computed and {TOP_N} most similar images (based on L2 distance) are displayed.</p>
'''
    st.markdown(desc, unsafe_allow_html=True)
    img_url = st.text_input("Paste the image URL of a vegetable and hit Enter:", "").strip()
    placeholder = st.empty()
    if img_url:
        placeholder.empty()
        img_path = process_input_image(img_url)
        try:
            find_similar_images(img_path, TOP_N)
        finally:
            # remove this request's download directory even if the search fails
            delete_file(os.path.dirname(img_path))
except Exception as e:
    st.error(f'An unexpected error occurred: \n{e}')