|
import streamlit as st |
|
from helper import ( |
|
load_dataset, search, get_file_paths, |
|
get_cordinates, get_images_from_s3_to_display, |
|
get_images_with_bounding_boxes_from_s3, load_dataset_with_limit |
|
) |
|
import os |
|
import time |
|
|
|
|
|
|
|
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID") |
|
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY") |
|
|
|
|
|
datasets = ["MajorTom-Germany", "MajorTom-Netherlands","WayveScenes"] |
|
folder_path_dict = { |
|
"WayveScenes": "WayveScenes/", |
|
"MajorTom-Germany": "MajorTOM-DE/", |
|
"MajorTom-Netherlands": "MajorTOM-NL/", |
|
"MajorTom-UK" :"" |
|
} |
|
description = { |
|
"WayveScenes": "A large-scale dataset featuring diverse urban driving scenes, captured from vehicles to advance AI perception and navigation in complex environments.", |
|
"MajorTom-Germany": "A geospatial dataset containing satellite imagery from across Germany, designed for tasks like land-use classification, environmental monitoring, and earth observation analytics.", |
|
"MajorTom-Netherlands": "A geospatial dataset containing satellite imagery from across Netherlands, designed for tasks like land-use classification, environmental monitoring, and earth observation analytics.", |
|
"MajorTom-UK" :"A geospatial dataset containing satellite imagery from across the United Kingdom, designed for tasks like land-use classification, environmental monitoring, and earth observation analytics." |
|
} |
|
selection = { |
|
'WayveScenes': [1, 10], |
|
"MajorTom-Germany": [1, 1], |
|
"MajorTom-Netherlands": [1,1], |
|
"MajorTom-UK": [1,1] |
|
} |
|
|
|
example_queries = { |
|
'WayveScenes': "Parking Signs, Pedestrian Crossing, Traffic Light (Red, Green, Orange)", |
|
"MajorTom-Germany": "Airports, Golf Courses, Wind Mills, Solar Panels ", |
|
"MajorTom-Netherlands": "Airports, Golf Courses, Wind Mills, Solar Panels ", |
|
"MajorTom-UK": "Airports, Golf Courses, Wind Mills, Solar Panels " |
|
} |
|
|
|
|
|
|
|
bucket_name = "datasets-quasara-io" |
|
|
|
|
|
|
|
|
|
def main(): |
|
|
|
if 'search_in_small_objects' not in st.session_state: |
|
st.session_state.search_in_small_objects = False |
|
|
|
if 'dataset_number' not in st.session_state: |
|
st.session_state.dataset_number = 1 |
|
|
|
if 'df' not in st.session_state: |
|
st.session_state.df = None |
|
|
|
st.title("Semantic Search and Image Display") |
|
|
|
|
|
dataset_name = st.selectbox("Select Dataset", datasets) |
|
st.session_state.df = None |
|
|
|
folder_path = folder_path_dict[dataset_name] |
|
|
|
st.caption(description[dataset_name]) |
|
|
|
if st.checkbox("Enable Small Object Search", value=st.session_state.search_in_small_objects): |
|
st.session_state.search_in_small_objects = True |
|
st.text("Small Object Search Enabled") |
|
st.session_state.dataset_number = st.selectbox("Select Subset of Data", list(range(1, selection[dataset_name][1] + 1))) |
|
st.session_state.df = None |
|
st.text(f"You have selected Split Dataset {st.session_state.dataset_number}") |
|
else: |
|
st.session_state.search_in_small_objects = False |
|
st.text("Small Object Search Disabled") |
|
st.session_state.dataset_number = st.selectbox("Select Subset of Data", list(range(1, selection[dataset_name][0] + 1))) |
|
st.session_state.df = None |
|
st.text(f"You have selected Main Dataset {st.session_state.dataset_number}") |
|
|
|
df, total_rows = load_dataset_with_limit(dataset_name, st.session_state.dataset_number, st.session_state.search_in_small_objects, limit=1) |
|
dataset_limit = st.slider("Size of Dataset to be searched from", min_value=0, max_value=min(total_rows, 80000), value=int(min(total_rows, 80000)/2)) |
|
st.text(f'The smaller the dataset the faster the search will work.') |
|
|
|
|
|
|
|
try: |
|
loading_dataset_text = st.empty() |
|
loading_dataset_text.text("Loading Dataset...") |
|
loading_dataset_bar = st.progress(0) |
|
|
|
|
|
|
|
for i in range(0, 100, 25): |
|
time.sleep(0.2) |
|
loading_dataset_bar.progress(i + 25) |
|
|
|
|
|
df, total_rows = load_dataset_with_limit(dataset_name, st.session_state.dataset_number, st.session_state.search_in_small_objects, limit=dataset_limit) |
|
|
|
|
|
st.session_state.df = df |
|
loading_dataset_bar.progress(100) |
|
loading_dataset_text.text("Dataset loaded successfully!") |
|
st.success(f"Dataset loaded successfully with {len(df)} rows.") |
|
|
|
|
|
except Exception as e: |
|
st.error(f"Failed to load dataset: {e}") |
|
|
|
|
|
|
|
query = st.text_input("Enter your search query") |
|
st.text(f"Example Queries for your Dataset: {example_queries[dataset_name]}") |
|
|
|
limit = st.number_input("Number of results to display", min_value=1, max_value=10, value=10) |
|
|
|
|
|
if st.button("Search"): |
|
|
|
if not query: |
|
st.warning("Please enter a search query.") |
|
else: |
|
try: |
|
|
|
search_loading_text = st.empty() |
|
search_loading_text.text("Searching...") |
|
search_progress_bar = st.progress(0) |
|
|
|
|
|
df = st.session_state.df |
|
if st.session_state.search_in_small_objects: |
|
results = search(query, df, limit) |
|
top_k_paths = get_file_paths(df, results) |
|
top_k_cordinates = get_cordinates(df, results) |
|
search_type = 'Splits' |
|
else: |
|
|
|
results = search(query, df, limit) |
|
top_k_paths = get_file_paths(df, results) |
|
search_type = 'Main' |
|
|
|
|
|
search_progress_bar.progress(100) |
|
search_loading_text.text(f"Search completed among {dataset_limit} rows for {dataset_name} in {search_type} {st.session_state.dataset_number}") |
|
|
|
|
|
if st.session_state.search_in_small_objects and top_k_paths and top_k_cordinates: |
|
get_images_with_bounding_boxes_from_s3(bucket_name, top_k_paths, top_k_cordinates, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, folder_path) |
|
elif not st.session_state.search_in_small_objects and top_k_paths: |
|
st.write(f"Displaying top {len(top_k_paths)} results for query '{query}':") |
|
get_images_from_s3_to_display(bucket_name, top_k_paths, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, folder_path) |
|
|
|
else: |
|
st.write("No results found.") |
|
|
|
|
|
except Exception as e: |
|
st.error(f"Search failed: {e}") |
|
|
|
if __name__ == "__main__": |
|
main() |