Spaces:
Build error
Build error
import matplotlib.pyplot as plt | |
import nmslib | |
import numpy as np | |
import os | |
import streamlit as st | |
from PIL import Image | |
from transformers import CLIPProcessor, FlaxCLIPModel | |
# Hub id passed to CLIPProcessor.from_pretrained() in load_model().
BASELINE_MODEL = "openai/clip-vit-base-patch32"

# Checkpoint passed to FlaxCLIPModel.from_pretrained() in load_model().
# Earlier local-path alternative, kept for reference:
# MODEL_PATH = "/home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1"
MODEL_PATH = "flax-community/clip-rsicd-v2"

# Tab-separated file of precomputed image vectors read by load_index().
# Earlier alternatives, kept for reference:
# IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-baseline.tsv"
# IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"

# Directory holding the image files that are listed and displayed by the app.
# Earlier alternative, kept for reference:
# IMAGES_DIR = "/home/shared/data/rsicd_images"
IMAGES_DIR = "./images"
def load_index():
    """Load the precomputed image vectors and build an NMSLib search index.

    Reads IMAGE_VECTOR_FILE line by line. Each non-blank line is split on
    tabs: the first column is taken as the image file name and the second
    as a comma-separated list of floats (the image vector). The vectors are
    then added to an HNSW index created with the ``cosinesimil`` space.

    Returns:
        A ``(filenames, index)`` tuple. ``filenames`` is in file order, the
        same order in which the vectors were added to ``index``.

    Raises:
        OSError: if IMAGE_VECTOR_FILE cannot be opened.
        IndexError: if a non-blank line has no second tab-separated column.
        ValueError: if a vector component cannot be parsed as a float.
    """
    filenames, image_vecs = [], []
    # The context manager closes the file even if a row fails to parse.
    with open(IMAGE_VECTOR_FILE, "r") as fvec:
        for line in fvec:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            cols = line.split("\t")
            filenames.append(cols[0])
            image_vecs.append(np.array([float(x) for x in cols[1].split(",")]))

    V = np.array(image_vecs)
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.addDataPointBatch(V)
    index.createIndex({"post": 2}, print_progress=True)
    return filenames, index
def load_model():
    """Load the CLIP model and its input processor.

    Returns:
        A ``(model, processor)`` tuple: a ``FlaxCLIPModel`` loaded from
        MODEL_PATH and a ``CLIPProcessor`` loaded from BASELINE_MODEL.
    """
    # Tuple elements are evaluated left to right, so the model is still
    # loaded before the processor.
    return (
        FlaxCLIPModel.from_pretrained(MODEL_PATH),
        CLIPProcessor.from_pretrained(BASELINE_MODEL),
    )
def load_example_images():
    """Group the file names found in IMAGES_DIR by their class prefix.

    The class prefix is the part of the file name before the first
    underscore. File names without an underscore are ignored.

    Returns:
        A dict mapping each class prefix to the list of file names that
        start with it, in the order ``os.listdir`` returned them.
    """
    names_by_class = {}
    for entry in os.listdir(IMAGES_DIR):
        prefix, separator, _ = entry.partition("_")
        if not separator:
            # No underscore in the name, so there is no class prefix.
            continue
        names_by_class.setdefault(prefix, []).append(entry)
    return names_by_class
def app():
    """Render the Streamlit "Image to Image Retrieval" page.

    Loads the vector index, the CLIP model and the example image names,
    shows an introduction with up to ten example file names, and asks for an
    image file name. When the "Find Similar" button is pressed, the named
    image is encoded with the model, its nearest neighbours are looked up in
    the index, and up to ten of them (excluding the query image itself) are
    displayed three per row with their scores.

    An empty name, a name containing a directory component, or a name that
    does not match a file in IMAGES_DIR is reported with ``st.error`` rather
    than raising.
    """
    filenames, index = load_index()
    model, processor = load_model()

    example_images = load_example_images()
    # One randomly chosen file per class, keeping the first ten classes only.
    example_image_list = sorted(
        [names[np.random.randint(0, len(names))]
         for names in example_images.values()][0:10]
    )

    st.title("Image to Image Retrieval")
    st.markdown("""
The CLIP model from OpenAI is trained in a self-supervised manner using
contrastive learning to project images and caption text onto a common
embedding space. We have fine-tuned the model using the RSICD dataset
(10k images and ~50k captions from the remote sensing domain).
This demo shows the image to image retrieval capabilities of this model, i.e.,
given an image file name as a query, we use our fine-tuned CLIP model
to project the query image to the image/caption embedding space and search
for nearby images (by cosine similarity) in this space.
Our fine-tuned CLIP model was previously used to generate image vectors for
our demo, and NMSLib was used for fast vector access.
Here are some randomly generated image files from our corpus. You can
copy paste one of these below or use one from the results of a text to
image search -- {:s}
""".format(", ".join("`{:s}`".format(example) for example in example_image_list)))

    image_name = st.text_input("Provide an Image File Name")
    submit_button = st.button("Find Similar")
    if not submit_button:
        return

    # Pasted names often carry stray whitespace.
    image_name = image_name.strip()
    image_path = os.path.join(IMAGES_DIR, image_name)
    # Reject empty input, names with a directory component (which could point
    # outside IMAGES_DIR), and names that do not match an existing file.
    if (not image_name
            or os.path.basename(image_name) != image_name
            or not os.path.isfile(image_path)):
        st.error("Image file `{:s}` was not found. Please provide one of the "
                 "file names listed above.".format(image_name))
        return

    image = Image.fromarray(plt.imread(image_path))
    inputs = processor(images=image, return_tensors="jax", padding=True)
    query_vec = np.asarray(model.get_image_features(**inputs))

    # Ask for 11 neighbours so that 10 remain after the query image itself
    # is dropped from the results.
    ids, distances = index.knnQuery(query_vec, k=11)

    images, captions = [], []
    for neighbor_id, distance in zip(ids, distances):
        result_filename = filenames[neighbor_id]
        if result_filename == image_name:
            continue
        images.append(plt.imread(os.path.join(IMAGES_DIR, result_filename)))
        # The index returns a distance; the caption reports 1 - distance.
        captions.append(
            "{:s} (score: {:.3f})".format(result_filename, 1.0 - distance))
    images = images[0:10]
    captions = captions[0:10]

    # Show the results three per row; rows with no images are not rendered.
    for start in range(0, len(images), 3):
        st.image(images[start:start + 3], caption=captions[start:start + 3])