Spaces:
Runtime error
Runtime error
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Xueyan Zou ([email protected])
# --------------------------------------------------------
import torch | |
import numpy as np | |
from PIL import Image | |
from torchvision import transforms | |
from detectron2.data import MetadataCatalog | |
from xdecoder.language.loss import vl_similarity | |
# Preprocessing used by text_retrieval: an int size makes torchvision's Resize
# scale the shorter image side to 224 px, using bicubic interpolation.
# The InterpolationMode enum is the documented argument type; passing PIL's
# integer constants (Image.BICUBIC) is a deprecated compatibility path.
transform_ret = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
])

# Same preprocessing at 512 px on the shorter side. Not used in this module's
# visible code; presumably meant for grounding tasks ("grd") — TODO confirm.
transform_grd = transforms.Compose([
    transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
])

# detectron2 metadata for the COCO 2017 panoptic training split.
# NOTE: the misspelled name `metedata` is kept as-is because other modules may
# import it under this name.
metedata = MetadataCatalog.get('coco_2017_train_panoptic')
def text_retrieval(model, image, texts, inpainting_text, *args, **kwargs):
    """Rank comma-separated candidate captions against an image with X-Decoder.

    Args:
        model: Wrapper whose ``.model`` exposes ``evaluate`` and
            ``sem_seg_head.predictor.lang_encoder`` (the attributes used below).
        image: Input image. It is resized with ``transform_ret`` and turned
            into an HWC array, so a PIL image is expected; non-RGB PIL images
            are converted to RGB first.
        texts: Comma-separated candidate captions, e.g. ``"a cat, a dog"``.
            Whitespace around each caption is stripped and empty entries
            (e.g. from a trailing comma) are ignored.
        inpainting_text: Unused; kept for interface compatibility.
        *args, **kwargs: Unused; kept for interface compatibility.

    Returns:
        ``(None, out_str, None)``. ``out_str`` lists up to five captions with
        their softmax probabilities, highest first, each formatted as
        ``"caption:0.87; "``. It is empty when ``texts`` contains no
        non-empty caption.
    """
    # Parse captions before doing any model work. Empty entries would
    # otherwise be embedded and ranked as a caption of their own.
    labels = [label.strip() for label in texts.split(',')]
    labels = [label for label in labels if label]
    if not labels:
        return None, '', None

    # Grayscale ('L') images give a 2-D array and RGBA images give 4 channels.
    # Neither matches the (H, W, 3) layout the HWC -> CHW permute expects.
    if isinstance(image, Image.Image) and image.mode != 'RGB':
        image = image.convert('RGB')

    with torch.no_grad():
        image = np.asarray(transform_ret(image))
        # np.asarray on a PIL image may be read-only. Copy it so torch gets a
        # writable buffer, then reorder HWC -> CHW.
        images = torch.from_numpy(image.copy()).permute(2, 0, 1).cuda()
        batch_inputs = [{'image': images, 'image_id': 0}]
        outputs = model.model.evaluate(batch_inputs)

        # Keep the last 'captions' entry per output so cat() gives one row per
        # image (assumes 'captions' is a stacked embedding tensor — TODO confirm).
        v_emb = torch.cat([out['captions'][-1:] for out in outputs])
        # L2-normalise; the epsilon avoids division by zero.
        v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

        lang_encoder = model.model.sem_seg_head.predictor.lang_encoder
        lang_encoder.get_text_embeddings(labels, is_eval=False, name='caption', prompt=False)
        # The embeddings are read back from the encoder attribute
        # '<name>_text_embeddings', presumably set by get_text_embeddings.
        t_emb = lang_encoder.caption_text_embeddings
        temperature = lang_encoder.logit_scale
        logits = vl_similarity(v_emb, t_emb, temperature)

        # Row 0 is the single input image; topk returns values sorted descending.
        topk_prob, topk_idx = logits.softmax(-1)[0].topk(min(5, len(labels)))
        out_str = ''.join(
            "{}:{:.2f}; ".format(labels[idx.item()], prob.item())
            for prob, idx in zip(topk_prob, topk_idx)
        )

    torch.cuda.empty_cache()
    return None, out_str, None