File size: 3,503 Bytes
9ad81d2
 
 
 
 
 
 
 
 
 
 
 
b14c338
9ad81d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
from model import MaskedAutoencoderViT, mae_vit_base_patch16
import numpy as np
from PIL import Image
import torch.nn.functional as F
from einops import rearrange
from transformers import AutoTokenizer
from collections import OrderedDict
from huggingface_hub import hf_hub_download

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# The checkpoint comes from a fixed Hub repo; only point this at files you trust
# (weights_only=True is the safer option on torch versions that support it).
ckpt = torch.load(hf_hub_download('tennant/MUG', 'mae_bert_vit_b_cc3m.pth'), map_location='cpu')

# Checkpoint keys are expected to carry the 'image_encoder.model.' prefix
# (presumably the image branch of the full MUG model); strip it so they match
# the bare MAE model built below. Keys lacking the prefix are kept intact
# instead of having their first characters sliced off unconditionally.
_CKPT_PREFIX = 'image_encoder.model.'
new_dict = OrderedDict(
    (k[len(_CKPT_PREFIX):] if k.startswith(_CKPT_PREFIX) else k, v)
    for k, v in ckpt.items()
)

# load_state_dict is strict by default, so any key mismatch fails loudly here.
model = mae_vit_base_patch16(uni_dim=768, less_u=True)
model.load_state_dict(new_dict)
if torch.cuda.is_available():
    model.cuda()
model.eval()

@torch.no_grad()
def visual_recon(x, model, mask_ratio=0.75):
    """Mask part of ``x``, reconstruct it with the model, and return display tensors.

    Args:
        x: image batch in NCHW layout, normalized the way the model expects
            (the Gradio callback in this file applies ImageNet mean/std).
        model: masked autoencoder exposing ``patchify``, ``unpatchify``,
            ``forward_encoder`` and ``forward_decoder``.
        mask_ratio: fraction of patches hidden from the encoder.

    Returns:
        ``(masked, masked_recon, recon, latent)``: the first three are NHWC
        tensors (input with hidden patches zeroed, input with hidden patches
        filled from the reconstruction, full reconstruction); ``latent`` is
        the encoder output.
    """
    target = model.patchify(x)
    # The decoder output is de-normalized with each patch's own mean/variance
    # (presumably the model was trained on per-patch normalized pixel targets).
    mean = target.mean(dim=-1, keepdim=True)
    var = target.var(dim=-1, keepdim=True)

    latent, mask, ids_restore, _ = model.forward_encoder(x, mask_ratio=mask_ratio)
    y, _ = model.forward_decoder(latent, ids_restore)
    y = y * (var + 1.e-6) ** .5 + mean
    y = torch.einsum('nchw->nhwc', model.unpatchify(y))

    # Spread the per-patch mask (1 = removed, 0 = kept) over every value of its
    # patch. target.shape[-1] is exactly the patchified width unpatchify expects,
    # so this holds for any patch geometry.
    mask = mask.unsqueeze(-1).repeat(1, 1, target.shape[-1])
    mask = torch.einsum('nchw->nhwc', model.unpatchify(mask))

    x = torch.einsum('nchw->nhwc', x)
    visible = x * (1 - mask)
    return visible, visible + y * mask, y, latent

@torch.no_grad()
def caption_next_word(latent, model, tokenizer, prefix='a photo of a'):
    """Greedily predict the token that follows ``prefix`` for a single image.

    Args:
        latent: encoder output for exactly one image (batch dimension 1).
        model: model exposing ``embed_text``, ``multimodal_layers`` (pairs of
            callables applied as ``layer(text_features, latent)``) and
            ``to_logits``.
        tokenizer: tokenizer called on a list of strings (returning a mapping
            with ``'input_ids'``) and used to ``decode`` the predicted id.
        prefix: caption text generated so far.

    Returns:
        The decoded string of the highest-scoring next token.

    Raises:
        ValueError: if ``latent`` holds more than one image.
    """
    if latent.shape[0] != 1:
        raise ValueError('can only caption one image at a time')

    # Drop the last token (the [SEP] a BERT tokenizer appends) so the model's
    # final position predicts the token right after the prefix. Keep the ids on
    # the same device as the image features.
    ids = torch.tensor(tokenizer([prefix])['input_ids'])[:, :-1].to(latent.device)

    x_l = model.embed_text(ids)
    for cross_attn1, cross_attn2 in model.multimodal_layers:
        x_l = cross_attn1(x_l, latent)
        x_l = cross_attn2(x_l, latent)

    logits = model.to_logits(x_l)
    # Greedy decoding: highest-scoring token at the last position.
    next_id = logits[0, -1].argmax(dim=-1).item()
    return tokenizer.decode(next_id)

def caption(max_len, latent, model, tokenizer, prefix='a photo of a'):
    """Greedily extend ``prefix`` into a caption of at most ``max_len`` words.

    Each step asks ``caption_next_word`` for one more token.
    WordPiece continuations (tokens starting with ``##``) are glued onto the
    previous word rather than appended as separate words. This keeps the text
    readable and keeps the re-tokenized prefix meaningful on the next step.
    Generation stops early if the tokenizer's separator or padding token is
    produced.

    Args:
        max_len: maximum number of whitespace-separated words in the result,
            prefix included.
        latent: encoder output for one image.
        model: captioning model passed through to ``caption_next_word``.
        tokenizer: tokenizer passed through to ``caption_next_word``; its
            ``sep_token``/``pad_token`` attributes (if any) end generation.
        prefix: text the caption starts with.

    Returns:
        The caption with words joined by single spaces.
    """
    words = prefix.split()
    stop_tokens = {
        tok
        for tok in (getattr(tokenizer, 'sep_token', None), getattr(tokenizer, 'pad_token', None))
        if tok
    }
    # One model call per missing word at most. Merged continuations don't add
    # words, so without a fixed budget the loop could run forever.
    for _ in range(max_len - len(words)):
        next_word = caption_next_word(latent, model, tokenizer, prefix=' '.join(words))
        if next_word in stop_tokens:
            break
        if words and next_word.startswith('##') and len(next_word) > 2:
            words[-1] += next_word[2:]
        else:
            words.append(next_word)
    return ' '.join(words)


def gr_caption(x):
    """Gradio callback: reconstruct a masked view of an image and caption it.

    Args:
        x: image from the Gradio ``Image`` input, treated as an H x W x 3
            array of 0-255 values (presumably uint8 RGB — confirm with the
            Gradio version in use).

    Returns:
        ``(masked, masked_recon, recon, caption)``: three H x W x 3 float
        arrays clipped to [0, 1] for display, followed by the caption string.
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    def to_display(img):
        # Drop the batch axis, undo the ImageNet normalization, clamp to [0, 1].
        arr = img.squeeze(0).cpu().detach().numpy()
        return np.clip(arr * std + mean, a_min=0., a_max=1.)

    # Scale to [0, 1], apply ImageNet normalization, then HWC -> NCHW.
    pixels = (np.array(x) / 255. - mean) / std
    batch = torch.tensor(pixels).float().unsqueeze(0).permute(0, 3, 1, 2)
    if torch.cuda.is_available():
        batch = batch.cuda()

    masked, masked_recon, recon, latent = visual_recon(batch, model)
    text = caption(10, latent, model, tokenizer)

    masked, masked_recon, recon = (to_display(t) for t in (masked, masked_recon, recon))
    return masked, masked_recon, recon, text

import gradio as gr

# One image input; outputs follow gr_caption's return order: masked input,
# masked input + reconstruction, full reconstruction, then the caption text.
demo = gr.Interface(
    gr_caption,
    inputs=[gr.Image(shape=(224, 224))],
    outputs=[gr.Image(shape=(224, 224)) for _ in range(3)] + ['text'],
)

# Start the server only when run as a script, so importing this module
# (e.g. from tooling or tests) does not launch it as a side effect.
if __name__ == '__main__':
    demo.launch()