arch_demo / app.py
davanstrien's picture
davanstrien HF staff
Update app.py
4c0e93e
raw
history blame
4.96 kB
import multiprocessing
import random
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from PIL.Image import Image, ANTIALIAS
import gradio as gr
from faiss import METRIC_INNER_PRODUCT
import requests
import pandas as pd
import backoff
from functools import lru_cache
# Number of CPUs reported by the host. NOTE(review): not referenced anywhere
# else in the visible code - confirm whether it is still needed.
cpu_count = multiprocessing.cpu_count()
# Sentence-transformers model loaded by name ("clip-ViT-B-16"); its encode()
# method is what get_nearest_k_examples uses to embed search queries.
model = SentenceTransformer("clip-ViT-B-16")
def resize_image(image: Image, size: int = 224) -> Image:
    """Resize an image so its shorter side equals ``size``, keeping aspect ratio.

    Args:
        image: Image to resize; must expose ``size`` as ``(width, height)``
            and a ``resize`` method.
        size: Target length in pixels of the shorter side. Square images
            become ``size`` x ``size``.

    Returns:
        The resized image produced by ``image.resize``.
    """
    w, h = image.size
    if w == h:
        target = (size, size)
    elif w > h:
        # Landscape: height is the short side, scale the width to match.
        target = (int(float(w) * (size / float(h))), size)
    else:
        # Portrait: width is the short side, so the *height* must be scaled.
        # Scaling the width here instead would always yield ``size`` and
        # squash the image into a square.
        target = (size, int(float(h) * (size / float(w))))
    return image.resize(target, ANTIALIAS)
# Load the train split and keep only the rows that actually carry an
# embedding, since those are the only ones that can be indexed.
dataset = load_dataset("davanstrien/ia-loaded-embedded-gpu", split="train").filter(
    lambda x: x["embedding"] is not None
)
# Attach a FAISS index over the "embedding" column using the inner-product
# metric; get_nearest_k_examples queries this index.
dataset.add_faiss_index("embedding", metric_type=METRIC_INNER_PRODUCT)
def get_nearest_k_examples(input, k):
    """Return the ``k`` indexed examples nearest to a query, with captions.

    The query is embedded with the module-level ``model`` and looked up in the
    "embedding" index of the module-level ``dataset``.

    Args:
        input: Query passed straight to ``model.encode`` (the UI supplies the
            search textbox contents).
        k: Number of neighbours to retrieve.

    Returns:
        A list of ``(image, caption)`` tuples, where the caption reports the
        example's last-modified date and crawl date.
    """
    embedding = model.encode(input)
    # TODO maybe switch to a FAISS range search with a similarity threshold
    # (e.g. 0.95) instead of a fixed k.
    _scores, hits = dataset.get_nearest_examples("embedding", query=embedding, k=k)

    captions = []
    for modified, crawl in zip(hits["last_modified_date"], hits["crawl_date"]):
        captions.append(f"last_modified {modified}, crawl date:{crawl}")

    # zip stops at the shorter sequence, so captions never outnumber images.
    return list(zip(hits["image"][:k], captions))
def return_random_sample(k=27):
    """Return ``k`` randomly chosen dataset images, resized and in RGB mode.

    Args:
        k: How many distinct rows to draw from the module-level ``dataset``.

    Returns:
        A list of images, each passed through ``resize_image`` and then
        converted to "RGB".
    """
    picked_rows = random.sample(range(len(dataset)), k)
    thumbnails = []
    for picture in dataset[picked_rows]["image"]:
        thumbnails.append(resize_image(picture).convert("RGB"))
    return thumbnails
def predict_subset(model_id, token):
    """Classify a random sample of dataset images with a hosted model.

    Ten rows are drawn at random from the dataset, each row's ``url`` is
    posted to the Hugging Face Inference API endpoint for ``model_id``, and
    the first entry of each JSON response is taken as the prediction.

    Args:
        model_id: Model identifier appended to the Inference API base URL.
        token: API token sent as a ``Bearer`` authorization header.

    Returns:
        A ``(gallery, df)`` tuple. ``gallery`` is a list of ``(url, caption)``
        pairs, one per sampled row. ``df`` is a DataFrame with columns
        ``labels`` and ``freqs`` counting how often each predicted label
        occurred. URLs that could not be classified contribute a ``None``
        label and score.
    """
    from collections import Counter

    API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
    headers = {"Authorization": f"Bearer {token}"}

    # Retry with exponential backoff (for up to 30s) while the endpoint
    # answers 503.
    @backoff.on_predicate(backoff.expo, lambda x: x.status_code == 503, max_time=30)
    def _query(url):
        # A timeout keeps a stalled connection from blocking the UI forever.
        r = requests.post(API_URL, headers=headers, data=url, timeout=30)
        print(r)
        return r

    # NOTE: this cache is rebuilt on every predict_subset call, so it only
    # de-duplicates repeated URLs within a single sample.
    @lru_cache(maxsize=1000)
    def query(url):
        try:
            data = _query(url).json()
            argmax = data[0]
            return {"score": argmax["score"], "label": argmax["label"]}
        except (requests.RequestException, ValueError, LookupError, TypeError):
            # Best effort: a network failure, a non-JSON body or an
            # unexpected payload shape yields an empty prediction instead of
            # aborting the whole batch.
            return {"score": None, "label": None}

    # The dataset is loaded again here rather than reusing the module-level
    # one. NOTE(review): presumably to work on a copy without the FAISS index
    # attached (an earlier deepcopy/drop_index attempt was abandoned) - confirm.
    fresh_dataset = load_dataset("davanstrien/ia-loaded-embedded-gpu", split="train")
    sample = fresh_dataset.select(random.sample(range(len(fresh_dataset)), 10))

    print("predicting...")
    # Read only the "url" column; iterating full rows would also materialise
    # every other column of each row just to read one string.
    urls = sample["url"]
    predictions = [query(url) for url in urls]

    gallery = [
        (url, f"{prediction['label'], prediction['score']}")
        for url, prediction in zip(urls, predictions)
    ]

    # Count each label once, in first-seen order.
    label_counts = Counter(prediction["label"] for prediction in predictions)
    df = pd.DataFrame(
        {"labels": list(label_counts.keys()), "freqs": list(label_counts.values())}
    )
    return gallery, df
# Three-tab UI: a random gallery, a text-to-image search and a model
# prediction view. Components are created in display order inside each tab.
with gr.Blocks() as demo:
    # Tab 1: show a fresh random sample of images on every click.
    with gr.Tab("Random image gallery"):
        button = gr.Button("Refresh")
        gallery = gr.Gallery().style(grid=9, height="1400")
        button.click(fn=return_random_sample, inputs=[], outputs=[gallery])

    # Tab 2: free-text search; the slider picks how many neighbours to show.
    with gr.Tab("image search"):
        text = gr.Textbox(label="Search for images")
        k = gr.Slider(minimum=3, maximum=18, step=1)
        button = gr.Button("search")
        gallery = gr.Gallery().style(grid=3)
        button.click(fn=get_nearest_k_examples, inputs=[text, k], outputs=[gallery])

    # Disabled draft of a Label Studio export tab, kept for reference.
    # with gr.Tab("Export for label studio"):
    #     button = gr.Button("Export")
    #     dataset2 = copy.deepcopy(dataset)
    #     # dataset2 = dataset2.remove_columns('image')
    #     # dataset2 = dataset2.rename_column("url", "image")
    #     csv = dataset2.to_csv("label_studio.csv")
    #     csv_file = gr.File("label_studio.csv")
    #     button.click(dataset.save_to_disk, [], [csv_file])

    # Tab 3: run a hosted model over a random subset and plot label counts.
    with gr.Tab("predict"):
        token = gr.Textbox(label="token", type="password")
        model_id = gr.Textbox(label="model_id")
        button = gr.Button("predict")
        plot = gr.BarPlot(x="labels", y="freqs", width=600, height=400, vertical=False)
        gallery = gr.Gallery()
        button.click(fn=predict_subset, inputs=[model_id, token], outputs=[gallery, plot])

demo.launch(enable_queue=True)