akhaliq's picture
akhaliq HF staff
Update click event to expose a Gradio API endpoint (#1)
b3d65ca
import torch
import torch.nn as nn
import numpy as np
import json
import captioning.utils.opts as opts
import captioning.models as models
import captioning.utils.misc as utils
import pytorch_lightning as pl
import gradio as gr
# Checkpoint class
class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
def on_keyboard_interrupt(self, trainer, pl_module):
# Save model when keyboard interrupt
filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
self._save_model(filepath)
device = 'cpu' #@param ["cuda", "cpu"] {allow-input: true}
reward = 'clips_grammar' #@param ["mle", "cider", "clips", "cider_clips", "clips_grammar"] {allow-input: true}
if reward == 'mle':
cfg = f'./configs/phase1/clipRN50_{reward}.yml'
else:
cfg = f'./configs/phase2/clipRN50_{reward}.yml'
print("Loading cfg from", cfg)
opt = opts.parse_opt(parse=False, cfg=cfg)
import gdown
if reward == "mle":
url = "https://drive.google.com/drive/folders/1hfHWDn5iXsdjB63E5zdZBAoRLWHQC3LD"
elif reward == "cider":
url = "https://drive.google.com/drive/folders/1MnSmCd8HFnBvQq_4K-q4vsVkzEw0OIOs"
elif reward == "clips":
url = "https://drive.google.com/drive/folders/1toceycN-qilHsbYjKalBLtHJck1acQVe"
elif reward == "cider_clips":
url = "https://drive.google.com/drive/folders/1toceycN-qilHsbYjKalBLtHJck1acQVe"
elif reward == "clips_grammar":
url = "https://drive.google.com/drive/folders/1nSX9aS7pPK4-OTHYtsUD_uEkwIQVIV7W"
gdown.download_folder(url, quiet=True, use_cookies=False, output="save/")
url = "https://drive.google.com/uc?id=1HNRE1MYO9wxmtMHLC8zURraoNFu157Dp"
gdown.download(url, quiet=True, use_cookies=False, output="data/")
dict_json = json.load(open('./data/cocotalk.json'))
print(dict_json.keys())
ix_to_word = dict_json['ix_to_word']
vocab_size = len(ix_to_word)
print('vocab size:', vocab_size)
seq_length = 1
opt.vocab_size = vocab_size
opt.seq_length = seq_length
opt.batch_size = 1
opt.vocab = ix_to_word
# opt.use_grammar = False
model = models.setup(opt)
del opt.vocab
ckpt_path = opt.checkpoint_path + '-last.ckpt'
print("Loading checkpoint from", ckpt_path)
raw_state_dict = torch.load(
ckpt_path,
map_location=device)
strict = True
state_dict = raw_state_dict['state_dict']
if '_vocab' in state_dict:
model.vocab = utils.deserialize(state_dict['_vocab'])
del state_dict['_vocab']
elif strict:
raise KeyError
if '_opt' in state_dict:
saved_model_opt = utils.deserialize(state_dict['_opt'])
del state_dict['_opt']
# Make sure the saved opt is compatible with the curren topt
need_be_same = ["caption_model",
"rnn_type", "rnn_size", "num_layers"]
for checkme in need_be_same:
if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
getattr(opt, checkme) in ['updown', 'topdown']:
continue
assert getattr(saved_model_opt, checkme) == getattr(
opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
elif strict:
raise KeyError
res = model.load_state_dict(state_dict, strict)
print(res)
model = model.to(device)
model.eval();
import clip
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
from timm.models.vision_transformer import resize_pos_embed
clip_model, clip_transform = clip.load("RN50", jit=False, device=device)
preprocess = Compose([
Resize((448, 448), interpolation=Image.BICUBIC),
CenterCrop((448, 448)),
ToTensor()
])
image_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to(device).reshape(3, 1, 1)
image_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to(device).reshape(3, 1, 1)
num_patches = 196 #600 * 1000 // 32 // 32
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, clip_model.visual.attnpool.positional_embedding.shape[-1], device=device),)
pos_embed.weight = resize_pos_embed(clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed)
clip_model.visual.attnpool.positional_embedding = pos_embed
def inference(img):
with torch.no_grad():
image = preprocess(img)
image = torch.tensor(np.stack([image])).to(device)
image -= image_mean
image /= image_std
tmp_att, tmp_fc = clip_model.encode_image(image)
tmp_att = tmp_att[0].permute(1, 2, 0)
tmp_fc = tmp_fc[0]
att_feat = tmp_att
fc_feat = tmp_fc
# Inference configurations
eval_kwargs = {}
eval_kwargs.update(vars(opt))
verbose = eval_kwargs.get('verbose', True)
verbose_beam = eval_kwargs.get('verbose_beam', 0)
verbose_loss = eval_kwargs.get('verbose_loss', 1)
# dataset = eval_kwargs.get('dataset', 'coco')
beam_size = eval_kwargs.get('beam_size', 1)
sample_n = eval_kwargs.get('sample_n', 1)
remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
with torch.no_grad():
fc_feats = torch.zeros((1,0)).to(device)
att_feats = att_feat.view(1, 196, 2048).float().to(device)
att_masks = None
# forward the model to also get generated samples for each image
# Only leave one feature for each image, in case duplicate sample
tmp_eval_kwargs = eval_kwargs.copy()
tmp_eval_kwargs.update({'sample_n': 1})
seq, seq_logprobs = model(
fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
seq = seq.data
sents = utils.decode_sequence(model.vocab, seq)
return sents[0]
demo = gr.Blocks()
with demo:
gr.Markdown(
"""
# Gradio Demo for [j-min/CLIP-Caption-Reward](https://github.com/j-min/CLIP-Caption-Reward)
""")
inp = gr.Image(type="pil")
out = gr.Textbox()
image_button = gr.Button("Run")
image_button.click(fn=inference,
inputs=inp,
outputs=out,
api_name="clip_caption")
demo.launch()