Spaces:
Runtime error
Runtime error
File size: 5,227 Bytes
3579efb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
from typing import List, Union
import datasets
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from tqdm.auto import tqdm
from transformers import AutoFeatureExtractor, AutoModel
seed = 42
# Number of random hyperplanes used by the projection below, i.e. the number
# of bits packed into each hash value (so hashes lie in [0, 2**hash_size)).
hash_size = 8
hidden_dim = 768 # ViT-base
# Seed NumPy's global RNG so that `random_vectors` below is reproducible.
np.random.seed(seed)
# Device.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the feature extractor (preprocessing configuration: input size,
# normalization mean/std) for the checkpoint.
# NOTE(review): only the extractor is loaded here; the embedding model itself
# (AutoModel is imported) is presumably loaded elsewhere — confirm.
model_ckpt = "nateraw/vit-base-beans"
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
# Data transformation chain.
transformation_chain = T.Compose(
[
# Resize so the shorter edge becomes (256 / 224) * crop size (256 when the
# crop size is 224; aspect ratio is kept), center-crop to the model's input
# size, convert to a tensor and normalize with the extractor's statistics.
T.Resize(int((256 / 224) * extractor.size["height"])),
T.CenterCrop(extractor.size["height"]),
T.ToTensor(),
T.Normalize(mean=extractor.image_mean, std=extractor.image_std),
]
)
# Define random vectors to project with: shape (hidden_dim, hash_size), one
# random hyperplane normal per column.
random_vectors = np.random.randn(hash_size, hidden_dim).T
def hash_func(embedding, random_vectors=random_vectors):
    """Hash embeddings by the signs of their random projections.

    Args:
        embedding: a single embedding (1-D) or a batch of embeddings (2-D).
            Anything that is not already an ``np.ndarray`` is converted
            with ``np.array``.
        random_vectors: projection matrix. The default is the module-level
            ``random_vectors`` bound when this function is defined.

    Returns:
        A list with one int per row. Bit ``k`` of an entry is set when that
        row's projection onto column ``k`` of ``random_vectors`` is positive.
    """
    arr = embedding if isinstance(embedding, np.ndarray) else np.array(embedding)
    # Promote a single vector to a one-row batch so the result is always a list.
    if arr.ndim < 2:
        arr = arr[np.newaxis]
    above_plane = np.dot(arr, random_vectors) > 0
    return list(map(bool2int, above_plane))
def bool2int(x):
    """Pack a sequence of truthy/falsy bits into an int, least-significant first.

    Element ``x[0]`` becomes bit 0, ``x[1]`` becomes bit 1, and so on.
    An empty sequence yields ``0``.
    """
    return sum(1 << position for position, bit in enumerate(x) if bit)
def compute_hash(model: torch.nn.Module):
    """Build a batched ``datasets.map`` callback that adds LSH hashes.

    Args:
        model: embedding model called as ``model(pixel_values=...)``. Its
            output must expose ``last_hidden_state``, and the [CLS] token
            (position 0) is what gets hashed.

    Returns:
        A function ``pp(example_batch)`` that reads the images in
        ``example_batch["image"]``, stores one hash list per image under
        ``example_batch["hashes"]`` (mutating the batch in place), and
        returns the batch.
    """
    # Hugging Face models expose ``.device``. Plain ``torch.nn.Module``
    # instances do not, so fall back to the device of the first parameter.
    model_device = getattr(model, "device", None)
    if model_device is None:
        model_device = next(model.parameters()).device

    def _to_rgb(image):
        # Normalize() in transformation_chain uses 3-channel statistics;
        # grayscale / palette / RGBA PIL images would fail there.
        mode = getattr(image, "mode", None)
        if mode is not None and mode != "RGB":
            return image.convert("RGB")
        return image

    def pp(example_batch):
        # Preprocess each image and stack the results along a new leading
        # batch dimension.
        pixel_values = torch.stack(
            [transformation_chain(_to_rgb(image)) for image in example_batch["image"]]
        ).to(model_device)
        with torch.no_grad():
            # Pool by taking the representation of the [CLS] token.
            embeddings = (
                model(pixel_values=pixel_values).last_hidden_state[:, 0].cpu().numpy()
            )
        # Hash one row at a time. hash_func on a single vector returns a
        # one-element list, which is the per-image format Table.add iterates.
        example_batch["hashes"] = [hash_func(embedding) for embedding in embeddings]
        return example_batch

    return pp
class Table:
    """A single hash table mapping a hash value to the entries stored under it.

    Each entry is a dict of the form ``{"id_label": "<id>_<label>"}``.
    """

    def __init__(self, hash_size: int):
        self.table = {}
        # Kept for reference; add() and query() do not use it.
        self.hash_size = hash_size

    def add(self, id: int, hashes: List[int], label: int):
        """Store one item under every hash value in ``hashes``.

        A single entry dict is created and the same object is appended to
        each corresponding bucket.
        """
        record = {"id_label": f"{id!s}_{label!s}"}
        for value in hashes:
            self.table.setdefault(value, []).append(record)

    def query(self, hashes: List[int]):
        """Return every entry stored under any of ``hashes``.

        Entries are returned in query order, and unknown hash values are
        skipped.
        """
        return [
            record
            for value in hashes
            if value in self.table
            for record in self.table[value]
        ]
class LSH:
    """A set of ``num_tables`` separate ``Table`` objects queried together.

    NOTE(review): ``add`` inserts the same hashes into every table, and
    ``query`` probes every table with the same hashes. Each hit is therefore
    returned once per table, and the tables never differ.
    """

    def __init__(self, hash_size, num_tables):
        self.num_tables = num_tables
        self.tables = [Table(hash_size) for _ in range(self.num_tables)]

    def add(self, id: int, hash: List[int], label: int):
        """Insert the item into every table."""
        for tbl in self.tables:
            tbl.add(id, hash, label)

    def query(self, hashes: List[int]):
        """Concatenate the query results of all tables, in table order."""
        return [hit for tbl in self.tables for hit in tbl.query(hashes)]
class BuildLSHTable:
    """Index an image dataset with LSH and query it for similar images.

    Images are embedded and hashed by the callback from ``compute_hash``,
    then stored in an ``LSH`` index under ``"<position>_<label>"`` keys.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        batch_size: int = 48,
        hash_size: int = hash_size,
        dim: int = hidden_dim,
        num_tables: int = 10,
    ):
        """
        Args:
            model: embedding model, moved to the module-level ``device``.
                Must not be ``None``, because ``model.to`` is called
                unconditionally.
            batch_size: batch size used by ``datasets.map`` in ``build``.
            hash_size: forwarded to ``LSH``. NOTE(review): the actual number
                of hash bits comes from the module-level ``random_vectors``
                used by ``hash_func``, not from this argument.
            dim: embedding width; also the divisor applied to scores in
                ``query``.
            num_tables: number of tables in the ``LSH`` index.
        """
        self.hash_size = hash_size
        self.dim = dim
        self.num_tables = num_tables
        self.lsh = LSH(self.hash_size, self.num_tables)
        self.batch_size = batch_size
        # The same batched callback is used for indexing (build) and lookup (query).
        self.hash_fn = compute_hash(model.to(device))

    def build(self, ds: datasets.Dataset):
        """Hash every example of ``ds`` and insert it into the index.

        ``ds`` must be a single split (``datasets.Dataset``) with ``"image"``
        and ``"labels"`` columns. Examples are keyed by their position.
        """
        dataset_hashed = ds.map(self.hash_fn, batched=True, batch_size=self.batch_size)
        # Read only the two needed columns, once. Row indexing
        # (dataset_hashed[i]) materializes the whole row, image included,
        # and was previously done twice per example.
        all_hashes = dataset_hashed["hashes"]
        all_labels = dataset_hashed["labels"]
        for idx, (hashes, label) in enumerate(
            tqdm(zip(all_hashes, all_labels), total=len(dataset_hashed))
        ):
            self.lsh.add(idx, hashes, label)

    def query(self, image, verbose=True):
        """Return similarity scores for indexed items colliding with ``image``.

        Args:
            image: a file path (opened and converted to RGB) or an image
                object accepted by the hashing callback.
            verbose: if true, print the raw number of matches.

        Returns:
            A dict mapping each matched ``"<id>_<label>"`` key to the number
            of times it was returned, divided by ``self.dim``.
        """
        if isinstance(image, str):
            # Decode fully and release the file handle before hashing.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        # Compute the hashes of the query image and fetch the results.
        hashes = self.hash_fn(dict(image=[image]))["hashes"][0]
        results = self.lsh.query(hashes)
        if verbose:
            print("Matches:", len(results))
        # Count hits per candidate (first-seen order is preserved).
        # NOTE(review): this is not a Jaccard index, and dividing by the
        # embedding width is an arbitrary scale. With the current LSH (the same
        # hash in every table) every match gets the same count, so the scores
        # do not rank candidates — confirm the intended normalization.
        counts = {}
        for r in results:
            key = r["id_label"]
            counts[key] = counts.get(key, 0) + 1
        return {key: float(n) / self.dim for key, n in counts.items()}
|