from transformers import Owlv2TextModel, Owlv2Processor, AutoTokenizer
import json

import torch
import tqdm

bsz = 8  # text queries encoded per forward pass

# id -> name strings; underscores in the names are swapped for spaces below.
with open("id_to_str.json") as f:
    data = json.load(f)
keys = list(data.keys())

proc = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
tokenizer = AutoTokenizer.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2TextModel.from_pretrained("google/owlv2-base-patch16-ensemble")
model.eval()

# A plain dict of tensors; nn.ParameterDict would wrap each value in a
# Parameter (requires_grad=True) and rejects keys containing ".".
embed_dict = {}

num_batches = (len(keys) + bsz - 1) // bsz  # ceil division keeps the final partial batch
for i in tqdm.tqdm(range(num_batches)):
    batch_keys = keys[i * bsz : (i + 1) * bsz]
    batch = [data[key].replace("_", " ") for key in batch_keys]

    # Flag any query that the OWLv2 text tower's 16-token limit will truncate.
    tokenized = tokenizer(batch)
    for ids in tokenized["input_ids"]:
        if len(ids) > 16:
            print(f"truncated: {tokenizer.decode(ids)}")

    inputs = proc(text=batch, return_tensors="pt")
    with torch.no_grad():  # inference only, no autograd bookkeeping
        output = model(**inputs)

    for k, key in enumerate(batch_keys):
        # .clone() so each entry does not keep the whole batch tensor alive.
        embed_dict[key] = output.pooler_output[k, :].clone()

torch.save(embed_dict, "embeds.pt")
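
# Usage sketch (an assumption, not part of the original script): reload the
# saved file and compare two class embeddings by cosine similarity. Assumes
# id_to_str.json holds at least two entries.
import torch.nn.functional as F

embeds = torch.load("embeds.pt")
id_a, id_b = list(embeds)[:2]  # any two ids from id_to_str.json
sim = F.cosine_similarity(embeds[id_a], embeds[id_b], dim=0)
print(f"{data[id_a]} vs {data[id_b]}: cosine similarity {sim.item():.4f}")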