# Cloth_matching / app.py
# Author: vishalkatheriya - last commit 6ff5afe "Update app.py" (verified)
import hashlib
import io
import os
import pickle

import numpy as np
import streamlit as st
import tensorflow
from numpy.linalg import norm
from PIL import Image
from sklearn.neighbors import NearestNeighbors
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.layers import GlobalMaxPooling2D
from tensorflow.keras.preprocessing import image as keras_image
# Load the precomputed catalogue once per session and keep it in session state:
# ``feature_list`` holds one embedding per row and ``filenames`` is indexed by
# the same row numbers (the recommendation code below looks filenames up by
# neighbour index).
# NOTE(review): pickle.load can execute arbitrary code; these .pkl files must
# come from a trusted source (presumably they ship with the app) - confirm.
if 'feature_list' not in st.session_state:
    with open('embeddings.pkl', 'rb') as fh:  # `with` closes the handle
        st.session_state.feature_list = np.array(pickle.load(fh))
if 'filenames' not in st.session_state:
    with open('filenames.pkl', 'rb') as fh:
        st.session_state.filenames = pickle.load(fh)

# Build the feature extractor once per session: ImageNet-pretrained ResNet50
# without its classification head, frozen, followed by global max pooling so
# each 224x224 RGB image maps to a flat feature vector.
if 'model' not in st.session_state:
    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    base_model.trainable = False  # inference only; weights are never updated
    model = tensorflow.keras.Sequential([
        base_model,
        GlobalMaxPooling2D(),
    ])
    st.session_state.model = model

st.title('Fashion Recommender System')
def save_uploaded_file(uploaded_file):
    """Write an uploaded file into the local ``uploads`` directory.

    Args:
        uploaded_file: Upload object exposing ``name`` and ``getbuffer()``
            (a Streamlit ``UploadedFile`` in this app).

    Returns:
        The path the file was written to, or ``None`` if anything failed.
        On failure the error is also reported in the UI via ``st.error``.
    """
    try:
        # exist_ok=True avoids the check-then-create race of
        # os.path.exists() followed by os.makedirs().
        os.makedirs('uploads', exist_ok=True)
        # The file name is supplied by the client; keep only its final path
        # component so a name such as "../x" cannot escape uploads/.
        safe_name = os.path.basename(uploaded_file.name)
        file_path = os.path.join('uploads', safe_name)
        with open(file_path, 'wb') as f:
            f.write(uploaded_file.getbuffer())
        return file_path
    except Exception as e:  # UI boundary: report instead of crashing the app
        st.error(f"Error: {e}")
        return None
def load_and_process_image(uploaded_image, size=(256, 256)):
    """Load an image, convert it to RGB, resize it, and re-encode it as PNG.

    Args:
        uploaded_image: Path or file object accepted by ``PIL.Image.open``.
        size: Target ``(width, height)``; defaults to 256x256.

    Returns:
        ``io.BytesIO`` positioned at offset 0 containing the PNG-encoded image.
    """
    # Context manager closes the underlying file, which Image.open would
    # otherwise keep open.
    with Image.open(uploaded_image) as src:
        image = src if src.mode == 'RGB' else src.convert('RGB')
        # resize() returns a new, fully loaded image, so it stays usable
        # after the source file is closed.
        image = image.resize(size)
    image_bytes = io.BytesIO()
    # Images produced by convert()/resize() have ``format`` set to None, so the
    # previous "keep the source format" fallback always chose PNG anyway. PNG is
    # lossless, so no extra compression artefacts reach feature extraction.
    image.save(image_bytes, format='PNG')
    image_bytes.seek(0)
    return image_bytes
def feature_extraction(image_bytes, model):
    """Embed one image with ``model`` and L2-normalise the embedding.

    Args:
        image_bytes: Path or ``io.BytesIO`` readable by Keras ``load_img``.
        model: Keras model applied to a preprocessed batch of one
            224x224 image.

    Returns:
        1-D numpy array with unit L2 norm. If the embedding is all zeros it
        is returned unscaled, instead of being divided by zero into NaNs.
    """
    img = keras_image.load_img(image_bytes, target_size=(224, 224))
    img_array = keras_image.img_to_array(img)
    # Add a leading batch axis: the model is called on a batch of one image.
    expanded_img_array = np.expand_dims(img_array, axis=0)
    preprocessed_img = preprocess_input(expanded_img_array)
    result = model.predict(preprocessed_img).flatten()
    magnitude = norm(result)
    # Dividing a zero vector by its norm yields NaNs, which would break the
    # nearest-neighbour search downstream.
    if magnitude == 0:
        return result
    return result / magnitude
def recommend(features, feature_list, n_neighbors=6):
    """Return indices of the catalogue embeddings closest to ``features``.

    Args:
        features: 1-D query embedding.
        feature_list: 2-D array with one catalogue embedding per row.
        n_neighbors: Maximum number of neighbours to return (default 6).

    Returns:
        Array of shape ``(1, k)`` holding row indices into ``feature_list``,
        nearest first, where ``k = min(n_neighbors, len(feature_list))``.
    """
    # kneighbors() raises if asked for more neighbours than there are
    # fitted samples, so never request more than the catalogue holds.
    k = min(n_neighbors, len(feature_list))
    neighbors = NearestNeighbors(n_neighbors=k, algorithm='brute', metric='euclidean')
    neighbors.fit(feature_list)
    _, indices = neighbors.kneighbors([features])
    return indices
# Main flow: upload -> save to disk -> preprocess & display -> embed ->
# nearest-neighbour search -> show recommended filenames.
uploaded_file = st.file_uploader("Choose an image")
if uploaded_file is not None:
    file_path = save_uploaded_file(uploaded_file)
    if file_path:
        # Load and process the saved file, then display it.
        image_bytes = load_and_process_image(file_path)
        display_image = Image.open(image_bytes)
        st.image(display_image)

        # Re-extract features only when the uploaded *content* changes. Keying
        # on the file name alone would reuse stale features when a different
        # image is uploaded under the same name.
        upload_key = hashlib.sha256(uploaded_file.getbuffer()).hexdigest()
        if ('features' not in st.session_state
                or st.session_state.get('uploaded_file_key') != upload_key):
            image_bytes.seek(0)  # rewind in case displaying moved the read position
            st.session_state.features = feature_extraction(image_bytes, st.session_state.model)
            st.session_state.uploaded_file_key = upload_key
            st.session_state.uploaded_file_name = uploaded_file.name

        # Recommendation: show up to five neighbours, nearest first.
        indices = recommend(st.session_state.features, st.session_state.feature_list)
        for col, idx in zip(st.columns(5), indices[0][:5]):
            with col:
                st.write(st.session_state.filenames[idx])
    else:
        st.error("Some error occurred in file upload")