|
import numpy as np |
|
from sentence_transformers import SentenceTransformer, util |
|
from open_clip import create_model_from_pretrained, get_tokenizer |
|
import torch |
|
from datasets import load_dataset |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
import torch.nn as nn |
|
import boto3 |
|
import streamlit as st |
|
from PIL import Image |
|
from PIL import ImageDraw |
|
from io import BytesIO |
|
import pandas as pd |
|
from typing import List, Union |
|
import concurrent.futures |
|
|
|
|
|
|
|
# Hub identifier of the SigLIP checkpoint; kept in one place so the model and
# its tokenizer are always created from the same checkpoint.
_SIGLIP_HUB_ID = 'hf-hub:timm/ViT-SO400M-14-SigLIP-384'

# Created once at import time and reused by encode_query() for every query.
model, preprocess = create_model_from_pretrained(_SIGLIP_HUB_ID)
tokenizer = get_tokenizer(_SIGLIP_HUB_ID)
|
|
|
|
|
|
|
def encode_query(query: Union[str, Image.Image]) -> torch.Tensor:
    """
    Embed a text or image query with the module-level OpenCLIP model.

    Parameters
    ----------
    query : Union[str, Image.Image]
        Either a text string, which is tokenized with the module-level
        ``tokenizer``, or a PIL image, which is run through the module-level
        ``preprocess`` transform.

    Returns
    -------
    torch.Tensor
        The result of ``model.encode_image`` or ``model.encode_text``,
        computed with gradient tracking disabled.

    Raises
    ------
    ValueError
        If ``query`` is neither a string nor a PIL image.
    """
    # Pick the encoder and build its input first, then run a single forward pass.
    if isinstance(query, Image.Image):
        # unsqueeze(0) adds the leading batch dimension to the preprocessed image.
        model_input = preprocess(query).unsqueeze(0)
        encoder = model.encode_image
    elif isinstance(query, str):
        model_input = tokenizer(query, context_length=model.context_length)
        encoder = model.encode_text
    else:
        raise ValueError("Query must be either a string or an Image.")

    with torch.no_grad():
        return encoder(model_input)
|
|
|
def load_dataset_with_limit(dataset_name, dataset_subset, search_in_small_objects, limit=1000):
    """
    Load one split of a ``quasara-io`` Hugging Face dataset and limit the number of rows.

    Parameters
    ----------
    dataset_name : str
        Dataset name without the ``quasara-io/`` prefix (the prefix is added here).
    dataset_subset : str
        Suffix of the split name to load.
    search_in_small_objects : bool
        When truthy the ``Splits_<dataset_subset>`` split is loaded,
        otherwise ``Main_<dataset_subset>``.
    limit : int or None, default 1000
        Maximum number of rows to keep. Rows are drawn at random with a fixed
        seed (``random_state=42``), so repeated calls return the same sample.
        If the split has fewer rows than ``limit``, all rows are returned.
        ``None`` returns every row without sampling.

    Returns
    -------
    tuple
        ``(df, total_rows)`` where ``df`` is the (possibly sampled) pandas
        DataFrame and ``total_rows`` is the row count of the full split.
    """
    split_prefix = 'Splits' if search_in_small_objects else 'Main'
    dataset = load_dataset(
        f"quasara-io/{dataset_name}",
        split=f'{split_prefix}_{dataset_subset}',
    )
    total_rows = dataset.num_rows

    df = dataset.to_pandas()
    if limit is not None:
        # DataFrame.sample (without replacement) raises ValueError when n is
        # larger than the number of rows, so clamp the request to what exists.
        df = df.sample(n=min(limit, len(df)), random_state=42)

    return df, total_rows
|
|
|
def get_image_vectors(df):
    """
    Stack the entries of ``df['Vector']`` into a single float32 tensor.

    Parameters:
    - df: pandas DataFrame with a 'Vector' column holding one embedding per row
      (presumably 1-D arrays of equal length; verify against the dataset)

    Returns:
    - torch.Tensor of dtype float32 built from the vertically stacked vectors
    """
    per_row_vectors = df['Vector'].to_numpy()
    stacked = np.vstack(per_row_vectors)
    return torch.tensor(stacked, dtype=torch.float32)
|
|
|
|
|
def search(query, df, limit, search_in_images=True):
    """
    Rank the rows of ``df`` by cosine similarity to ``query``.

    Parameters:
    - query: str or PIL image, passed to encode_query()
    - df: pandas DataFrame with a 'Vector' column (read via get_image_vectors())
    - limit: int, maximum number of indices to return
    - search_in_images: bool, must be truthy; no other search mode is implemented

    Returns:
    - numpy array of positional row indices into df, most similar first,
      containing at most `limit` entries

    Raises:
    - NotImplementedError: if search_in_images is falsy
    """
    if not search_in_images:
        # Without this guard the function previously failed later with an
        # unbound-local error, because only the image path assigns the
        # variables used below.
        raise NotImplementedError(
            "search_in_images=False is not supported: only image-vector search is implemented."
        )

    # encode_query returns a batch; row 0 is the single query embedding.
    query_vector = encode_query(query)[0, :].detach().numpy()
    image_vectors = get_image_vectors(df).detach().numpy()

    cosine_similarities = cosine_similarity([query_vector], image_vectors)

    # Negate so that argsort's ascending order yields highest similarity first.
    top_k_indices = np.argsort(-cosine_similarities[0])[:limit]
    return top_k_indices
|
|
|
|
|
def batch_search(query, df, batch_size=100000, limit=10):
    """
    Find the rows of ``df`` most similar to ``query``, scoring the vectors in batches.

    Parameters:
    - query: str or PIL image, passed to encode_query()
    - df: pandas DataFrame with a 'Vector' column (read via get_image_vectors())
    - batch_size: int, number of vectors scored per cosine_similarity call
    - limit: int, maximum number of indices to return

    Returns:
    - list of positional row indices into df (usable with df.iloc), ordered
      from most to least similar, containing at most `limit` entries
    """
    vectors = get_image_vectors(df).numpy()
    query_vector = encode_query(query)[0].detach().numpy()

    candidate_indices = []
    candidate_scores = []
    for start in range(0, len(vectors), batch_size):
        batch_vectors = vectors[start:start + batch_size]
        batch_similarities = cosine_similarity([query_vector], batch_vectors)[0]

        # Best `limit` rows of this batch; argsort positions are batch-local,
        # so shift them by `start` to get positions in the full DataFrame.
        local_top = np.argsort(-batch_similarities)[:limit]
        candidate_indices.append(local_top + start)
        candidate_scores.append(batch_similarities[local_top])

    if not candidate_indices:
        return []

    # Merge the per-batch winners into one global ranking. The global top
    # `limit` is guaranteed to be among them because each batch contributed
    # its own top `limit`. A stable sort keeps earlier rows first on ties.
    all_indices = np.concatenate(candidate_indices)
    all_scores = np.concatenate(candidate_scores)
    best = np.argsort(-all_scores, kind="stable")[:limit]
    return list(all_indices[best])
|
|
|
|
|
def get_file_paths(df, top_k_indices, column_name='File_Path'):
    """
    Look up the values of one column for the rows at the given positions.

    Parameters:
    - df: pandas DataFrame containing the data
    - top_k_indices: positional row indices (list or numpy array), as accepted by df.iloc
    - column_name: str, the column to read (defaults to 'File_Path')

    Returns:
    - list with the column's value for each requested row, in the order of top_k_indices
    """
    matched_rows = df.iloc[top_k_indices]
    return matched_rows[column_name].tolist()
|
def get_cordinates(df, top_k_indices, column_name='Coordinate'):
    """
    Retrieve the coordinate entries (or any specific column) for the top K rows.

    Parameters:
    - df: pandas DataFrame containing the data
    - top_k_indices: positional row indices (list or numpy array), as accepted by df.iloc
    - column_name: str, the name of the column to fetch (defaults to 'Coordinate')

    Returns:
    - list of values from the specified column, one per requested row, in the
      order of top_k_indices
    """
    selected = df.iloc[top_k_indices]
    coordinates = selected[column_name]
    return coordinates.tolist()
|
|
|
def get_images_from_s3_to_display(bucket_name, file_paths, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, folder_name):
    """
    Download images from AWS S3 and render each one in the Streamlit app.

    Parameters:
    - bucket_name: str, the name of the S3 bucket
    - file_paths: list, object paths to fetch; each is appended to folder_name to form the S3 key
    - AWS_ACCESS_KEY_ID: str, access key ID used to create the S3 client
    - AWS_SECRET_ACCESS_KEY: str, secret access key used to create the S3 client
    - folder_name: str, prefix placed directly in front of every file path

    Returns:
    - None (images are displayed via st.image, captioned with their file path)
    """
    credentials = {
        'aws_access_key_id': AWS_ACCESS_KEY_ID,
        'aws_secret_access_key': AWS_SECRET_ACCESS_KEY,
    }
    client = boto3.client('s3', **credentials)

    for path in file_paths:
        # The key is the folder prefix and the path joined with no separator.
        response = client.get_object(Bucket=bucket_name, Key=f"{folder_name}{path}")
        raw_bytes = response['Body'].read()

        picture = Image.open(BytesIO(raw_bytes))
        st.image(picture, caption=path, use_column_width=True)
|
|
|
|
|
|
|
def _draw_bounding_boxes(draw, box_coords, file_path):
    """Outline a single box, or each box in a sequence of boxes, in red on ``draw``."""
    box_types = (np.ndarray, list, tuple)
    if not isinstance(box_coords, box_types):
        raise ValueError(f"Bounding box data for {file_path} is not in an iterable format.")

    if len(box_coords) == 0:
        # No boxes for this image: nothing to draw. Previously this case fell
        # into the single-box branch and failed while unpacking four values.
        return

    # A nested first element means a collection of boxes; otherwise the value
    # itself is one [x_min, y_min, x_max, y_max] box.
    boxes = box_coords if isinstance(box_coords[0], box_types) else [box_coords]
    for box in boxes:
        x_min, y_min, x_max, y_max = map(int, box)
        draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)


def get_images_with_bounding_boxes_from_s3(bucket_name, file_paths, bounding_boxes, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, folder_name):
    """
    Retrieve and display images from AWS S3 with corresponding bounding boxes in a Streamlit app.

    Parameters:
    - bucket_name: str, the name of the S3 bucket
    - file_paths: list, a list of file paths to retrieve from S3
    - bounding_boxes: list with one entry per file path; each entry is either a
      single box [x_min, y_min, x_max, y_max] or a sequence of such boxes
      (numpy array, list or tuple). An empty entry draws no box.
    - AWS_ACCESS_KEY_ID: str, AWS access key ID for authentication
    - AWS_SECRET_ACCESS_KEY: str, AWS secret access key for authentication
    - folder_name: str, prefix placed directly in front of every file path to form the S3 key

    Returns:
    - None (directly displays images in the Streamlit app with bounding boxes)

    Raises:
    - ValueError: if a bounding box entry is not a numpy array, list or tuple,
      or if a box does not contain exactly four values
    """
    s3 = boto3.client(
        's3',
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    )

    # zip pairs each path with its boxes and stops at the shorter of the two lists.
    for file_path, box_coords in zip(file_paths, bounding_boxes):
        s3_object = s3.get_object(Bucket=bucket_name, Key=f"{folder_name}{file_path}")
        img_data = s3_object['Body'].read()

        img = Image.open(BytesIO(img_data))
        draw = ImageDraw.Draw(img)
        _draw_bounding_boxes(draw, box_coords, file_path)

        st.image(img, caption=file_path, use_column_width=True)
|
|
|
|
|
|