Spaces:
Runtime error
Runtime error
File size: 5,951 Bytes
4b1ee17 d927e44 4b1ee17 1c0510f d927e44 1c0510f 4b1ee17 7301ba2 4b1ee17 d927e44 4b1ee17 1c0510f a5fc389 4b1ee17 f6bc454 0d33000 f6bc454 0d33000 f6bc454 0d33000 dc4cbf0 4b1ee17 1c0510f 4b1ee17 a43c327 1c0510f 4b1ee17 a43c327 4b1ee17 1c0510f 4b1ee17 1610241 e02b592 1610241 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import os
import numpy as np
from matplotlib import rcParams
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.utils import load_img, save_img, img_to_array
from tensorflow.keras.applications.vgg19 import preprocess_input
from tensorflow.keras.layers import GlobalAveragePooling2D
from pymilvus import connections, Collection, utility
from requests import get
import streamlit as st
import zipfile
# Vegetable image archive extraction.
def unzip_images():
    '''
    Extract every member of "Vegetable Images.zip" (read from the current
    working directory) into the current working directory, then print a
    confirmation line.
    '''
    archive = zipfile.ZipFile("Vegetable Images.zip", 'r')
    with archive:
        archive.extractall('.')
    print('unzipped images')
# Extract the archive only when the image folder is not already present,
# so restarts reuse previously extracted images.
_images_extracted = os.path.exists('Vegetable Images/')
if not _images_extracted:
    unzip_images()
class ImageVectorizer:
    '''
    Turn an image file into a 1-D embedding vector.

    The embedding is the globally average-pooled output of the 'block5_pool'
    layer of a VGG model (loaded from 'vegetable_classification_model_vgg.h5')
    that was fine-tuned to classify vegetable images.
    '''

    def __init__(self):
        self.__model = self.get_model()

    @staticmethod
    @st.cache_resource
    def get_model():
        '''
        Build the feature extractor: the saved classifier truncated at
        'block5_pool' and followed by GlobalAveragePooling2D.

        Wrapped in st.cache_resource so the model is built once and reused.
        '''
        classifier = load_model('vegetable_classification_model_vgg.h5')
        pooled = GlobalAveragePooling2D()(classifier.get_layer('block5_pool').output)
        extractor = Model(inputs=classifier.input, outputs=pooled)
        print('loaded model')
        return extractor

    def vectorize(self, img_path: str):
        '''
        Return the embedding of the image stored at `img_path`.

        The image is loaded as RGB, resized to 224x224, passed through the
        VGG19 `preprocess_input`, wrapped in a batch of one, and run through
        the feature extractor. The single row of the output is returned.
        '''
        pixels = img_to_array(load_img(img_path, color_mode="rgb", target_size=(224, 224)))
        batch = np.stack([preprocess_input(pixels)])
        return self.__model(batch).numpy()[0]
@st.cache_resource
def get_milvus_collection():
    '''
    Connect to Milvus and return the image-embedding collection, loaded for search.

    Settings are read from environment variables:
        URI             -- endpoint passed to connections.connect
        TOKEN           -- auth token passed to connections.connect
        COLLECTION_NAME -- name of the existing collection to open (required)

    Raises:
        RuntimeError: if COLLECTION_NAME is unset or empty.

    Decorated with st.cache_resource, so repeated calls reuse the same
    Collection instead of reconnecting.
    '''
    collection_name = os.environ.get("COLLECTION_NAME")
    if not collection_name:
        # Fail fast with an actionable message instead of an opaque
        # pymilvus error raised after the connection is already made.
        raise RuntimeError("COLLECTION_NAME environment variable is not set")
    uri = os.environ.get("URI")
    token = os.environ.get("TOKEN")
    connections.connect("default", uri=uri, token=token)
    print("Connected to DB")
    collection = Collection(name=collection_name)
    # Collections must be loaded before they can be searched.
    collection.load()
    return collection
def plot_images(input_image_path: str, similar_img_paths: list):
    '''
    Render the query image and a grid of its similar images into the page.

    Parameters:
        input_image_path: path of the query image, shown on its own via st.image.
        similar_img_paths: paths of the retrieved images, drawn left-to-right,
            top-to-bottom in a 3-column matplotlib grid (one row per 3 images;
            15 images give the 5x3 layout). Any number of paths is accepted,
            including fewer than fill the last row.

    Output goes into the module-level Streamlit `placeholder`, which must
    already exist when this is called (it is created in the page body).
    '''
    cols = 3  # columns in subplots
    # ceil(len / cols) rows, at least one so plt.subplots gets a valid shape
    rows = max(1, (len(similar_img_paths) + cols - 1) // cols)
    # squeeze=False keeps `ax` 2-D even for a single row
    fig, ax = plt.subplots(rows, cols, figsize=(12, 4 * rows), squeeze=False)
    try:
        # Hide every cell first so unused cells in a partial last row stay blank.
        for cell in ax.flat:
            cell.axis("off")
        for cell, path in zip(ax.flat, similar_img_paths):
            cell.imshow(load_img(path, color_mode="rgb", target_size=(224, 224)))
        fig.subplots_adjust(wspace=0.01, hspace=0.01)
        # NOTE(review): this mutates global rcParams after `fig` was created,
        # so it presumably only influences figures created on later runs.
        # Confirm this is intended.
        rcParams.update({'figure.autolayout': True})
        input_image = load_img(input_image_path, color_mode="rgb", target_size=(224, 224))
        with placeholder.container():
            # display input image
            st.markdown('<p style="font-size: 20px; font-weight: bold">Input image</p>', unsafe_allow_html=True)
            st.image(input_image)
            st.write(' \n')
            # display similar images
            st.markdown('<p style="font-size: 20px; font-weight: bold">Similar images</p>', unsafe_allow_html=True)
            st.pyplot(fig)
    finally:
        # pyplot keeps every figure alive until it is closed. Without this,
        # each query leaks one figure for the lifetime of the server.
        plt.close(fig)
def find_similar_images(img_path: str, top_n: int = 15):
    '''
    Embed the image at `img_path`, find its `top_n` nearest neighbours in
    Milvus by L2 distance, and display them with plot_images.

    Uses the module-level `vectorizer` and `collection` objects.
    '''
    query_vector = vectorizer.vectorize(img_path)
    search_results = collection.search(
        [query_vector],
        anns_field='image_vector',  # vector field from the collection schema
        param={"metric_type": "L2"},
        limit=top_n,
        # NOTE(review): presumably relaxes read consistency so the search does
        # not wait for pending writes. Confirm against the pymilvus docs.
        guarantee_timestamp=1,
        output_fields=['image_path'],  # return each hit's stored image path
    )
    # Flatten the per-query hit lists (a single query here) into one list of paths.
    neighbour_paths = []
    for hits in search_results:
        for hit in hits:
            neighbour_paths.append(hit.entity.get('image_path'))
    plot_images(img_path, neighbour_paths)
def delete_file(path_: str):
    '''
    Remove the file at `path_`. A file that does not exist is silently ignored.

    Other OS errors (e.g. `path_` names a directory) still propagate.
    '''
    try:
        os.remove(path_)
    except FileNotFoundError:
        # Already gone or never written: nothing to clean up. EAFP avoids the
        # race between an existence check and the removal.
        pass
@st.cache_resource
def get_upload_path():
    '''
    Ensure ./uploads exists and return the path ./uploads/input.jpg.

    Cached with st.cache_resource, so the directory is only created on the
    first call. If ./uploads is removed later, it is not recreated.

    NOTE(review): every session writes to this single fixed path, so
    concurrent users can overwrite each other's uploaded image. Consider
    per-session filenames.
    '''
    upload_dir = os.path.join('.', 'uploads')
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    os.makedirs(upload_dir, exist_ok=True)
    return os.path.join(upload_dir, "input.jpg")
def process_input_image(img_url, timeout=15):
    '''
    Download the image at `img_url` to the upload path and return that path.

    Parameters:
        img_url: URL entered by the user.
        timeout: seconds to wait for the server (connect and read) before
            giving up. Defaults to 15.

    Returns:
        Path of the saved file (from get_upload_path()).

    Raises:
        requests.HTTPError: the server answered with a 4xx/5xx status.
        requests.RequestException: network failure or timeout.
    '''
    # NOTE(review): img_url is untrusted input fetched server-side (SSRF risk).
    # Consider restricting schemes/hosts and capping the response size.
    upload_file_path = get_upload_path()
    headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36'}
    # Without a timeout, requests can block indefinitely on an unresponsive host.
    response = get(img_url, headers=headers, timeout=timeout)
    # Fail here instead of saving an HTML error page as "input.jpg".
    response.raise_for_status()
    with open(upload_file_path, "wb") as file:
        file.write(response.content)
    return upload_file_path
# Shared, module-level resources used by find_similar_images(). They are
# created left to right: the model first, then the Milvus connection.
vectorizer, collection = ImageVectorizer(), get_milvus_collection()
try:
    st.markdown("<h3>Find Similar Vegetable Images</h3>", unsafe_allow_html=True)
    desc = '''<p style="font-size: 15px;">Allowed Vegetables: Broad Bean, Bitter Gourd, Bottle Gourd,
Green Round Brinjal, Broccoli, Cabbage, Capsicum, Carrot, Cauliflower, Cucumber,
Raw Papaya, Potato, Green Pumpkin, Radish, Tomato.
</p>
<p style="font-size: 13px;">Image embeddings are extracted from a fine-tuned VGG model. The model is fine-tuned on <a href="https://www.kaggle.com/datasets/misrakahmed/vegetable-image-dataset" target="_blank">images</a> clicked using a mobile phone camera.
Embeddings of 20,000 vegetable images are stored in Milvus vector database. Embeddings of the input image are computed and 15 most similar images (based on L2 distance) are displayed.</p>
'''
    st.markdown(desc, unsafe_allow_html=True)
    img_url = st.text_input("Paste the image URL of a vegetable and hit Enter:", "")
    # Module-level on purpose: plot_images() renders into this global placeholder.
    placeholder = st.empty()
    if img_url:
        placeholder.empty()
        img_path = process_input_image(img_url)
        try:
            find_similar_images(img_path, 15)
        finally:
            # Remove the downloaded image even if search or plotting fails.
            delete_file(img_path)
except Exception as e:
    # Top-level UI boundary: report any failure in the page instead of crashing the app.
    st.error(f'An unexpected error occurred: \n{e}')
|