g8a9 commited on
Commit
0f27d7b
1 Parent(s): 7369efb

Add code to precompute embeddings.

Browse files
Files changed (2) hide show
  1. app.py +3 -0
  2. precompute_features.py +50 -0
app.py CHANGED
@@ -47,6 +47,9 @@ def get_image_features():
47
 
48
 
49
  """
 
 
 
50
  # CLIP Italian Demo (Flax Community Week)
51
  """
52
 
 
47
 
48
 
49
  """
50
+
51
+ # 👋 Ciao!
52
+
53
  # CLIP Italian Demo (Flax Community Week)
54
  """
55
 
precompute_features.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+ from argparse import ArgumentParser
3
+ from jax import numpy as jnp
4
+ from torchvision import datasets, transforms
5
+ from torchvision.transforms import CenterCrop, Normalize, Resize, ToTensor
6
+ from torchvision.transforms.functional import InterpolationMode
7
+ from transformers import AutoTokenizer
8
+ from modeling_hybrid_clip import FlaxHybridCLIP
9
+ import utils
10
+ import torch
11
+
12
+
13
+ if __name__ == "__main__":
14
+ parser = ArgumentParser()
15
+ parser.add_argument("in_dir")
16
+ parser.add_argument("out_file")
17
+ args = parser.parse_args()
18
+
19
+ model = FlaxHybridCLIP.from_pretrained("clip-italian/clip-italian")
20
+
21
+ tokenizer = AutoTokenizer.from_pretrained(
22
+ "dbmdz/bert-base-italian-xxl-uncased", cache_dir=None, use_fast=True
23
+ )
24
+
25
+ image_size = model.config.vision_config.image_size
26
+
27
+ val_preprocess = transforms.Compose(
28
+ [
29
+ Resize([image_size], interpolation=InterpolationMode.BICUBIC),
30
+ CenterCrop(image_size),
31
+ ToTensor(),
32
+ Normalize(
33
+ (0.48145466, 0.4578275, 0.40821073),
34
+ (0.26862954, 0.26130258, 0.27577711),
35
+ ),
36
+ ]
37
+ )
38
+
39
+ dataset = utils.CustomDataSet(args.in_dir, transform=val_preprocess)
40
+
41
+ loader = torch.utils.data.DataLoader(
42
+ dataset,
43
+ batch_size=256,
44
+ shuffle=False,
45
+ num_workers=16,
46
+ drop_last=False,
47
+ )
48
+
49
+ image_features = utils.precompute_image_features(model, loader)
50
+ jnp.save(f"static/features/{args.out_file}", image_features)