from io import BytesIO
from typing import List, Union

import boto3
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from datasets import load_dataset
from open_clip import create_model_from_pretrained, get_tokenizer
from PIL import Image
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics.pairwise import cosine_similarity

# OpenCLIP checkpoint (SigLIP ViT-SO400M/14 at 384px) used to embed both text
# and image queries. The stored dataset vectors are compared against these
# embeddings, so they must come from the same model.
MODEL_NAME = 'hf-hub:timm/ViT-SO400M-14-SigLIP-384'

# Initialize the model globally to avoid reloading it on every query.
model, preprocess = create_model_from_pretrained(MODEL_NAME)
tokenizer = get_tokenizer(MODEL_NAME)


def encode_query(query: Union[str, Image.Image]) -> torch.Tensor:
    """
    Encode a text or image query with the module-level OpenCLIP model.

    Parameters
    ----------
    query : Union[str, Image.Image]
        The query, either a text string or a PIL ``Image``.

    Returns
    -------
    torch.Tensor
        The embedding produced by ``model.encode_image`` (for images) or
        ``model.encode_text`` (for text) for a batch containing one query.

    Raises
    ------
    ValueError
        If ``query`` is neither a string nor a PIL ``Image``.
    """
    if isinstance(query, Image.Image):
        # Preprocess the image and add a leading batch dimension of size 1.
        pixels = preprocess(query).unsqueeze(0)
        with torch.no_grad():
            return model.encode_image(pixels)
    if isinstance(query, str):
        tokens = tokenizer(query, context_length=model.context_length)
        with torch.no_grad():
            return model.encode_text(tokens)
    raise ValueError("Query must be either a string or an Image.")


def load_hf_datasets(dataset_name):
    """
    Load the ``Main`` split of a ``quasara-io`` Hugging Face dataset.

    Parameters
    ----------
    dataset_name : str
        Name of the dataset under the ``quasara-io`` organisation on the Hub.

    Returns
    -------
    pandas.DataFrame
        The ``Main`` split, converted with ``Dataset.to_pandas()``.
    """
    dataset = load_dataset(f"quasara-io/{dataset_name}")
    # Only the 'Main' split is used by this module.
    return dataset['Main'].to_pandas()


def get_image_vectors(df):
    """
    Stack the per-row embeddings in ``df['Vector']`` into a single tensor.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have a ``'Vector'`` column. Each cell is assumed to be a 1-D
        array-like of equal length (TODO confirm against the dataset schema).

    Returns
    -------
    torch.Tensor
        A float32 tensor with one row per DataFrame row.
    """
    image_vectors = np.vstack(df['Vector'].to_numpy())
    return torch.tensor(image_vectors, dtype=torch.float32)


def search(query, df, limit, offset, scoring_func, search_in_images, search_in_small_objects):
    """
    Rank the rows of ``df`` by cosine similarity between their stored vectors
    and the embedding of ``query``.

    Parameters
    ----------
    query : Union[str, Image.Image]
        Text or image query, encoded with ``encode_query``.
    df : pandas.DataFrame
        Dataset with a ``'Vector'`` column of precomputed embeddings.
    limit : int or None
        Maximum number of indices to return. ``None`` returns every row
        after ``offset``.
    offset : int
        Number of top-ranked results to skip. Use it for pagination.
    scoring_func : str
        Currently ignored. Cosine similarity is always used.
    search_in_images : bool
        Must be true. Only whole-image vector search is implemented.
    search_in_small_objects : bool
        Currently ignored.

    Returns
    -------
    numpy.ndarray
        Positional row indices into ``df``, most similar first.

    Raises
    ------
    NotImplementedError
        If ``search_in_images`` is false. Previously this path fell through
        to an ``UnboundLocalError``.
    """
    if not search_in_images:
        raise NotImplementedError(
            "Only image search (search_in_images=True) is currently supported."
        )

    # np.vstack raises on zero rows, so short-circuit an empty dataset.
    if df.empty:
        return np.array([], dtype=np.intp)

    # Take the single query row out of the batch and move it to NumPy.
    query_vector = encode_query(query)[0, :].detach().cpu().numpy()
    image_vectors = get_image_vectors(df).numpy()

    similarities = cosine_similarity(query_vector[np.newaxis, :], image_vectors)[0]

    # Descending order of similarity. Then apply offset/limit as a slice.
    ranked = np.argsort(-similarities)
    end = None if limit is None else offset + limit
    return ranked[offset:end]


def get_file_paths(df, top_k_indices, column_name='File_Path'):
    """
    Retrieve values of one column for the given positional row indices.

    Parameters
    ----------
    df : pandas.DataFrame
        The data to read from.
    top_k_indices : array-like of int
        Positional row indices (as returned by ``search``).
    column_name : str, default 'File_Path'
        Name of the column to fetch.

    Returns
    -------
    list
        Values of ``column_name`` for the given rows, in index order.
    """
    return df.iloc[top_k_indices][column_name].tolist()


def get_images_from_s3_to_display(bucket_name, file_paths, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, folder_name=None):
    """
    Download images from AWS S3 and render them in a Streamlit app.

    Parameters
    ----------
    bucket_name : str
        Name of the S3 bucket.
    file_paths : list of str
        Object keys (relative to ``folder_name``) to fetch and display.
    AWS_ACCESS_KEY_ID : str
        AWS access key used to create the S3 client.
    AWS_SECRET_ACCESS_KEY : str
        AWS secret key used to create the S3 client.
    folder_name : str, optional
        Prefix prepended verbatim to every key. Include any trailing ``'/'``.
        ``None`` means no prefix.

    Returns
    -------
    None
        Images are displayed directly via ``st.image``.
    """
    s3 = boto3.client(
        's3',
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    )

    # Without this, a None folder_name was interpolated as the literal
    # string "None", producing keys like "Noneimage.jpg".
    prefix = folder_name or ""

    for file_path in file_paths:
        s3_object = s3.get_object(Bucket=bucket_name, Key=f"{prefix}{file_path}")
        img_data = s3_object['Body'].read()

        img = Image.open(BytesIO(img_data))
        st.image(img, caption=file_path, use_column_width=True)


def main():
    """
    Run a sample text search against a Hugging Face dataset.

    Returns
    -------
    list
        ``File_Path`` values of the top results.
    """
    dataset_name = "StopSign_test"
    query = "black car"
    limit = 10
    offset = 0
    scoring_func = "cosine"
    search_in_images = True
    search_in_small_objects = False

    df = load_hf_datasets(dataset_name)
    results = search(query, df, limit, offset, scoring_func, search_in_images, search_in_small_objects)
    top_k_paths = get_file_paths(df, results)
    return top_k_paths


if __name__ == "__main__":
    main()