File size: 2,184 Bytes
7369efb
 
 
88974f6
 
 
7369efb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5ad46a
7369efb
 
 
 
 
c5ad46a
7369efb
 
 
 
 
 
c5ad46a
 
 
 
 
 
 
7369efb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import natsort
from tqdm import tqdm
import torch
from jax import numpy as jnp
from PIL import Image as PilImage


class CustomDataSet(torch.utils.data.Dataset):
    """Dataset over every entry of a single directory, loaded as RGB images.

    Entries are ordered with ``natsort.natsorted`` so that, e.g., ``img2.jpg``
    comes before ``img10.jpg``.

    Args:
        main_dir: Directory whose entries are treated as image files. Every
            name returned by ``os.listdir`` is included; no extension or
            file-type filtering is performed.
        transform: Callable applied to each loaded RGB ``PIL.Image``; its
            return value is what ``__getitem__`` yields (presumably a
            tensor-producing transform -- confirm against callers).
    """

    def __init__(self, main_dir, transform):
        self.main_dir = main_dir
        self.transform = transform
        # Natural sort keeps numbered file names in human order.
        self.total_imgs = natsort.natsorted(os.listdir(main_dir))

    def __len__(self):
        """Return the number of directory entries in the dataset."""
        return len(self.total_imgs)

    def get_image_name(self, idx):
        """Return the file name (not the full path) of entry ``idx``."""
        return self.total_imgs[idx]

    def __getitem__(self, idx):
        """Load entry ``idx`` as an RGB image and return ``transform(image)``."""
        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
        # Open inside a context manager so the file handle is released
        # promptly instead of leaking one per sample. ``convert`` loads the
        # pixel data and returns a new image, so it remains valid after the
        # source file is closed.
        with PilImage.open(img_loc) as img:
            image = img.convert("RGB")
        return self.transform(image)


def text_encoder(text, model, tokenizer, *, max_length=96):
    """Encode one text query into an L2-normalised embedding row.

    Args:
        text: Query string to embed.
        model: Object exposing ``get_text_features(input_ids, attention_mask)``
            (presumably a Flax CLIP-style model -- confirm against callers).
        tokenizer: Callable tokenizer accepting the Hugging Face-style keyword
            arguments used below (``max_length``, ``truncation``, ``padding``,
            ``return_tensors``).
        max_length: Token length the query is padded/truncated to. Defaults
            to 96, the value that used to be hard-coded here.

    Returns:
        ``jnp`` array of shape ``(1, D)`` holding the unit-norm embedding,
        assuming ``get_text_features`` returns a ``(batch, D)`` array.
    """
    inputs = tokenizer(
        [text],
        max_length=max_length,
        truncation=True,
        padding="max_length",
        return_tensors="np",
    )
    features = model.get_text_features(inputs["input_ids"], inputs["attention_mask"])
    # The batch contains exactly one query, so keep only its row.
    embedding = features[0]
    embedding = embedding / jnp.linalg.norm(embedding)
    # Restore a leading batch axis so callers get a (1, D) row.
    return jnp.expand_dims(embedding, axis=0)


def precompute_image_features(model, loader):
    """Embed every image batch from ``loader`` and stack the unit-norm features.

    Args:
        model: Object exposing ``get_image_features(pixel_values)`` that takes
            the channels-last NumPy batch produced below (presumably a Flax
            CLIP-style model -- confirm against callers).
        loader: Iterable of ``torch`` tensor batches. The permute below
            assumes each batch is laid out (batch, channels, height, width)
            -- confirm against the dataset transform.

    Returns:
        ``jnp`` array with one L2-normalised feature row per image, in loader
        order (shape ``(N, D)`` assuming the model returns ``(batch, D)``).
        Returns ``jnp.array([])`` when ``loader`` yields nothing.
    """
    batches = []
    for images in tqdm(loader):
        # Move the channel axis last before handing the batch to the model.
        pixels = images.permute(0, 2, 3, 1).numpy()
        features = model.get_image_features(pixels)
        batches.append(features / jnp.linalg.norm(features, axis=-1, keepdims=True))
    if not batches:
        # Keep the historical empty-input result; concatenate would raise here.
        return jnp.array([])
    # A single concatenate avoids splitting each batch into per-row arrays
    # only to re-stack them afterwards.
    return jnp.concatenate(batches, axis=0)


def find_image(text_query, model, dataset, tokenizer, image_features, n, dataset_name):
    """Return the ``n`` dataset entries whose image features best match a query.

    Similarity is the dot product between each row of ``image_features`` and
    the normalised text embedding from ``text_encoder``. Results are ordered
    from most to least similar.

    Args:
        text_query: Free-text query.
        model: Passed through to ``text_encoder``.
        dataset: For ``"Unsplash"``, an object exposing ``get_image_name(idx)``.
            For ``"CC"``, an indexable collection whose items are returned as-is.
        tokenizer: Passed through to ``text_encoder``.
        image_features: Array whose rows are image embeddings aligned with
            ``dataset`` indices (presumably the ``(N, D)`` output of
            ``precompute_image_features`` -- confirm against callers).
        n: Maximum number of results. Values <= 0 yield an empty list; values
            larger than the number of images yield every image.
        dataset_name: Either ``"Unsplash"`` or ``"CC"``.

    Returns:
        For ``"Unsplash"``, a list of ``"photos/<file name>"`` strings. For
        ``"CC"``, a list of ``dataset[idx]`` items.

    Raises:
        ValueError: If ``dataset_name`` is not supported. This is checked up
            front, before the model is run.
    """
    if dataset_name not in ("Unsplash", "CC"):
        raise ValueError(f"{dataset_name} not supported here")

    zeroshot_weights = text_encoder(text_query, model, tokenizer)
    zeroshot_weights = zeroshot_weights / jnp.linalg.norm(zeroshot_weights)
    distances = jnp.dot(image_features, zeroshot_weights.reshape(-1, 1))

    # Sort once, then reverse so the best matches come first. Slicing caps
    # the result at the number of images. Integer indexing past the end would
    # either raise or be silently clamped by JAX, depending on version.
    ranked = jnp.argsort(distances, axis=0)[::-1, 0]
    # Plain Python ints index lists, pandas objects, etc. without surprises.
    top = ranked[: max(n, 0)].tolist()

    if dataset_name == "Unsplash":
        return ["photos/" + dataset.get_image_name(idx) for idx in top]
    return [dataset[idx] for idx in top]