File size: 13,392 Bytes
fba8607
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248a504
fba8607
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b145c8
 
 
 
 
 
 
 
 
e839607
4b145c8
 
 
 
 
 
 
fba8607
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
import os
import torch
import gradio as gr
import time
import clip
import requests
import csv
import json
import wget

# Precomputed CLIP ViT-L/14 zero-shot classifier weights, keyed by the local
# filename each one is saved under in ./prompts (load_models() reads them from
# exactly these paths).
url_dict = {'clip_ViTL14_openimage_classifier_weights.pt': 'https://raw.githubusercontent.com/geonm/socratic-models-demo/master/prompts/clip_ViTL14_openimage_classifier_weights.pt',
            'clip_ViTL14_place365_classifier_weights.pt': 'https://raw.githubusercontent.com/geonm/socratic-models-demo/master/prompts/clip_ViTL14_place365_classifier_weights.pt',
            'clip_ViTL14_tencentml_classifier_weights.pt': 'https://raw.githubusercontent.com/geonm/socratic-models-demo/master/prompts/clip_ViTL14_tencentml_classifier_weights.pt'}

os.makedirs('./prompts', exist_ok=True)
for filename, url in url_dict.items():
    target = os.path.join('./prompts', filename)
    # wget.download() never overwrites: when the target already exists it saves
    # a numbered copy ("name (1).pt") instead, so unconditional downloads would
    # pile up duplicates on every restart. Only fetch what is missing.
    if os.path.exists(target):
        continue
    wget.download(url, out=target)

# Hide every GPU so the demo runs on CPU (the UI description says so). This has
# to happen before anything initializes CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

# Hosted BLOOM endpoint on the Hugging Face Inference API. HF_TOKEN must be set
# in the environment; a missing token raises KeyError at import time.
API_URL = "https://api-inference.huggingface.co/models/bigscience/bloom"
HF_TOKEN = os.environ["HF_TOKEN"]

def load_openimage_classnames(csv_path):
    """Load OpenImage class names from a CSV file.

    The last column of each row is taken as the class name. Rows are numbered
    from 0 in file order; zeroshot_classifier() uses these numbers to map
    top-k row indices of the OpenImage classifier weights back to names.

    Args:
        csv_path: path to the class-name CSV file.

    Returns:
        dict[int, str]: row index -> class name.
    """
    # A context manager closes the file (the old code leaked the handle).
    # newline='' is what the csv module requires, so quoted fields with
    # embedded newlines parse correctly.
    with open(csv_path, newline='', encoding='utf-8') as csv_file:
        return {idx: row[-1] for idx, row in enumerate(csv.reader(csv_file))}


def load_tencentml_classnames(txt_path):
    """Load class names from a plain-text file with one name per line.

    Surrounding whitespace (including the newline) is stripped from each line,
    and lines are numbered from 0 in file order. load_models() also uses this
    for the Places365 class-name file.

    Args:
        txt_path: path to the class-name text file.

    Returns:
        dict[int, str]: line index -> class name.
    """
    # A context manager closes the file (the old code leaked the handle).
    with open(txt_path, encoding='utf-8') as txt_file:
        return {idx: line.strip() for idx, line in enumerate(txt_file)}


def build_simple_classifier(clip_model, text_list, template, device):
    """Build a small zero-shot classifier from a list of class texts.

    Each entry of ``text_list`` is rendered into a sentence by ``template``,
    tokenized with CLIP, encoded with ``clip_model.encode_text`` and
    L2-normalised along the last dimension.

    Args:
        clip_model: CLIP model used to encode the sentences.
        text_list: class texts, in the order that defines class indices.
        template: callable turning one class text into a prompt sentence.
        device: device the tokens are moved to before encoding.

    Returns:
        (weights, classnames): the normalised text embeddings and a dict
        mapping index -> original class text.
    """
    prompts = list(map(template, text_list))
    with torch.no_grad():
        tokens = clip.tokenize(prompts).to(device)
        weights = clip_model.encode_text(tokens)
        weights /= weights.norm(dim=-1, keepdim=True)
    classnames = dict(enumerate(text_list))
    return weights, classnames


def _load_classifier_weights(path, device, dtype):
    """Load a precomputed zero-shot classifier weight tensor.

    The tensor is moved to ``device`` and cast to ``dtype``. Casting with
    ``.type(torch.FloatTensor)`` would always produce a CPU float32 tensor,
    which cannot be multiplied with image features living on CUDA (or in
    float16). ``.to(device=..., dtype=...)`` keeps both aligned with the model.
    """
    # NOTE(review): torch.load unpickles the file, so only use it on trusted
    # files (these are fetched from the demo's own GitHub repository).
    weights = torch.load(path, map_location=device)
    return weights.to(device=device, dtype=dtype)


def load_models():
    """Load CLIP ViT-L/14 and every zero-shot classifier the demo uses.

    Returns:
        dict with:
          * 'clip_model', 'clip_preprocess': from ``clip.load("ViT-L/14")``;
          * '<name>_classifier_weights': text embeddings, one row per class,
            for name in openimage, tencentml, place365, imgtype, ppl, ifppl;
          * '<name>_classnames': row index -> class name for each classifier;
          * 'device': "cuda" or "cpu".
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print('\tLoading CLIP ViT-L/14')
    clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)
    # Precomputed weights must share the model's dtype for the matmul in
    # zeroshot_classifier(): float32 on CPU; CLIP keeps float16 weights on CUDA.
    weight_dtype = clip_model.dtype

    print('\tLoading precomputed zeroshot classifier')
    openimage_classifier_weights = _load_classifier_weights('./prompts/clip_ViTL14_openimage_classifier_weights.pt', device, weight_dtype)
    openimage_classnames = load_openimage_classnames('./prompts/openimage-classnames.csv')
    tencentml_classifier_weights = _load_classifier_weights('./prompts/clip_ViTL14_tencentml_classifier_weights.pt', device, weight_dtype)
    tencentml_classnames = load_tencentml_classnames('./prompts/tencent-ml-classnames.txt')
    place365_classifier_weights = _load_classifier_weights('./prompts/clip_ViTL14_place365_classifier_weights.pt', device, weight_dtype)
    # Places365 names use the same one-name-per-line text format.
    place365_classnames = load_tencentml_classnames('./prompts/place365-classnames.txt')

    print('\tBuilding simple zeroshot classifier')
    img_types = ['photo', 'cartoon', 'sketch', 'painting']
    ppl_texts = ['no people', 'people']
    ifppl_texts = ['is one person', 'are two people', 'are three people', 'are several people', 'are many people']
    imgtype_classifier_weights, imgtype_classnames = build_simple_classifier(clip_model, img_types, lambda c: f'This is a {c}.', device)
    ppl_classifier_weights, ppl_classnames = build_simple_classifier(clip_model, ppl_texts, lambda c: f'There are {c} in this photo.', device)
    ifppl_classifier_weights, ifppl_classnames = build_simple_classifier(clip_model, ifppl_texts, lambda c: f'There {c} in this photo.', device)

    return {
        'clip_model': clip_model,
        'clip_preprocess': clip_preprocess,
        'openimage_classifier_weights': openimage_classifier_weights,
        'openimage_classnames': openimage_classnames,
        'tencentml_classifier_weights': tencentml_classifier_weights,
        'tencentml_classnames': tencentml_classnames,
        'place365_classifier_weights': place365_classifier_weights,
        'place365_classnames': place365_classnames,
        'imgtype_classifier_weights': imgtype_classifier_weights,
        'imgtype_classnames': imgtype_classnames,
        'ppl_classifier_weights': ppl_classifier_weights,
        'ppl_classnames': ppl_classnames,
        'ifppl_classifier_weights': ifppl_classifier_weights,
        'ifppl_classnames': ifppl_classnames,
        'device': device,
    }


def drop_gpu(tensor):
    """Return ``tensor`` as a NumPy array in host memory.

    ``.cpu()`` is a no-op for tensors already on the CPU, so no CUDA check is
    needed. ``.detach()`` lets this also accept tensors that track gradients,
    which ``.numpy()`` would otherwise reject.
    """
    return tensor.detach().cpu().numpy()


def _topk_classes(image_features, name, k=None):
    """Rank the classes of one classifier stored in ``model_dict``.

    Uses ``model_dict[f'{name}_classifier_weights']`` and
    ``model_dict[f'{name}_classnames']``. It computes
    ``softmax(100 * image_features @ weights.T)`` and returns the top-``k``
    ``(scores, class_names)`` for the first row, scores as a NumPy array in
    descending order. ``k=None`` ranks every class. A larger ``k`` is clamped
    to the number of classes so ``topk`` cannot fail.
    """
    weights = model_dict[f'{name}_classifier_weights']
    classnames = model_dict[f'{name}_classnames']
    num_classes = weights.shape[0]
    k = num_classes if k is None else min(k, num_classes)
    sim = (100.0 * image_features @ weights.T).softmax(dim=-1)
    scores, indices = [drop_gpu(tensor) for tensor in sim[0].topk(k)]
    return scores, [classnames[idx] for idx in indices]


def zeroshot_classifier(image):
    """Run every CLIP zero-shot classifier on ``image``.

    Args:
        image: input image accepted by ``model_dict['clip_preprocess']``
            (the Gradio input is configured with ``type="pil"``).

    Returns:
        A 13-tuple: the L2-normalised CLIP image features (batch of one),
        then ``(scores, class_names)`` pairs for the openimage (top 10),
        tencentml (top 10), place365 (top 10), imgtype, ppl and ifppl
        classifiers (the last three rank all classes).
    """
    image_input = model_dict['clip_preprocess'](image).unsqueeze(0).to(model_dict['device'])
    with torch.no_grad():
        image_features = model_dict['clip_model'].encode_image(image_input)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        openimage_scores, openimage_classes = _topk_classes(image_features, 'openimage', 10)
        tencentml_scores, tencentml_classes = _topk_classes(image_features, 'tencentml', 10)
        place365_scores, place365_classes = _topk_classes(image_features, 'place365', 10)
        imgtype_scores, imgtype_classes = _topk_classes(image_features, 'imgtype')
        ppl_scores, ppl_classes = _topk_classes(image_features, 'ppl')
        ifppl_scores, ifppl_classes = _topk_classes(image_features, 'ifppl')

    return image_features, openimage_scores, openimage_classes, tencentml_scores, tencentml_classes,\
           place365_scores, place365_classes, imgtype_scores, imgtype_classes,\
           ppl_scores, ppl_classes, ifppl_scores, ifppl_classes


def generate_prompt(openimage_classes, tencentml_classes, place365_classes, imgtype_classes, ppl_classes, ifppl_classes):
    """Build the BLOOM captioning prompt from zero-shot classification results.

    Each ``*_classes`` argument is a list of class names ordered best-first
    (as returned by zeroshot_classifier). The prompt uses the top image type,
    the people / people-count result, the top three places, every Tencent-ML
    class and the top two OpenImage classes.

    Returns:
        str: the prompt text.
    """
    img_type = imgtype_classes[0]
    # With the classifiers from load_models(), ppl_classes[0] is 'people' or
    # 'no people'. When people are present, the finer people-count classifier
    # supplies the phrase instead (e.g. 'is one person').
    if ppl_classes[0] == 'people':
        ppl_result = ifppl_classes[0]
    else:
        ppl_result = f'are {ppl_classes[0]}'

    places = place365_classes
    object_list = ', '.join([*tencentml_classes, *openimage_classes[:2]])

    # NOTE: the 4-space indentation of the continuation lines is part of the
    # prompt text sent to BLOOM; keep it unchanged so generations stay the same.
    prompt_caption = f'''I am an intelligent image captioning bot.
    This image is a {img_type}. There {ppl_result}.
    I think this photo was taken at a {places[0]}, {places[1]}, or {places[2]}.
    I think there might be a {object_list} in this {img_type}.
    A creative short caption I can generate to describe this image is:'''

    return prompt_caption


def generate_captions(prompt, num_captions=3, timeout=60):
    """Ask the hosted BLOOM model to continue ``prompt`` and return captions.

    Sends ``num_captions`` POST requests to API_URL. From each reply it keeps
    the generated continuation up to and including its first '.'.

    Args:
        prompt: prompt text, e.g. from generate_prompt().
        num_captions: number of requests (and therefore captions). With the
            greedy decoding used here the replies are presumably identical.
        timeout: seconds to wait for each HTTP request before giving up.

    Returns:
        list[str]: one caption per request.

    Raises:
        requests.HTTPError: if the API answers with a non-2xx status.
        RuntimeError: if a reply is not the expected list of generations.
    """
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}

    max_length = 16
    seed = 42
    sample_or_greedy = 'Greedy'
    parameters = {
        "max_new_tokens": max_length,
        "do_sample": False,
        "seed": seed,
        "early_stopping": False,
        "length_penalty": 0.0,
        "eos_token_id": None,
    }
    if sample_or_greedy == "Sample":
        parameters.update({"top_p": 0.7, "do_sample": True})

    payload = {"inputs": prompt, "parameters": parameters, "options": {"use_cache": False}}

    bloom_results = []
    for _ in range(num_captions):
        # Without a timeout, requests may block forever on a stalled connection.
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()
        output = response.json()
        # A successful reply is a list of {'generated_text': ...}. Report
        # anything else (e.g. an error dict) instead of failing with an opaque
        # KeyError/TypeError.
        if not (isinstance(output, list) and output and isinstance(output[0], dict)
                and 'generated_text' in output[0]):
            raise RuntimeError(f'Unexpected response from {API_URL}: {output!r}')
        # The returned text presumably echoes the prompt; strip it and keep
        # only the first sentence of the continuation.
        generated_text = output[0]['generated_text'].replace(prompt, '').split('.')[0] + '.'
        bloom_results.append(generated_text)
    return bloom_results


def sorting_texts(image_features, captions):
    """Rank ``captions`` by CLIP similarity to ``image_features``.

    The captions are encoded with the CLIP text encoder and L2-normalised.
    They are scored with ``softmax(100 * image_features @ text.T)`` and ranked
    via ``topk`` over all captions.

    Returns:
        (scores, sorted_captions): the softmax scores as a NumPy array
        (highest first) and the captions reordered to match.
    """
    device = model_dict['device']
    clip_model = model_dict['clip_model']
    with torch.no_grad():
        caption_tokens = clip.tokenize(captions).to(device)
        caption_features = clip_model.encode_text(caption_tokens)
        caption_features /= caption_features.norm(dim=-1, keepdim=True)

        logits = 100.0 * image_features @ caption_features.T
        probs = logits.softmax(dim=-1)
        top_scores, top_indices = probs[0].topk(len(captions))
        scores = drop_gpu(top_scores)
        order = drop_gpu(top_indices)
        ranked = [captions[i] for i in order]

    return scores, ranked


def postprocess_results(scores, classes):
    """Pair each score with its label as ``{'score': ..., 'output': ...}``.

    Every score is first rounded to four decimal places via fixed-point
    formatting. Pairs are then formed in order and stop at the shorter of the
    two inputs.

    Returns:
        list[dict]: one ``{'score': float, 'output': label}`` per pair.
    """
    rounded = [float(f'{float(value):.4f}') for value in scores]
    return [{'score': score, 'output': label} for score, label in zip(rounded, classes)]


def image_captioning(image):
    """Caption ``image`` end to end and return a JSON-serialisable report.

    Pipeline: CLIP zero-shot classification -> prompt construction -> one
    BLOOM caption request -> CLIP ranking of the returned caption(s).

    Returns:
        dict with keys:
          * 'inference_time': wall-clock seconds for the CLIP stage and for
            prompt building plus the BLOOM request;
          * 'generated_captions': ranked captions with scores;
          * 'reasoning': the scored classes from every zero-shot classifier.
    """
    t_start = time.time()
    (image_features,
     openimage_scores, openimage_classes,
     tencentml_scores, tencentml_classes,
     place365_scores, place365_classes,
     imgtype_scores, imgtype_classes,
     ppl_scores, ppl_classes,
     ifppl_scores, ifppl_classes) = zeroshot_classifier(image)
    t_clip = time.time()

    prompt = generate_prompt(openimage_classes, tencentml_classes, place365_classes,
                             imgtype_classes, ppl_classes, ifppl_classes)
    captions = generate_captions(prompt, num_captions=1)
    t_bloom = time.time()

    caption_scores, ranked_captions = sorting_texts(image_features, captions)

    timings = {'CLIP inference': t_clip - t_start,
               'BLOOM request': t_bloom - t_clip}
    ranked = postprocess_results(caption_scores, ranked_captions)
    reasoning = {
        'openimage_results': postprocess_results(openimage_scores, openimage_classes),
        'tencentml_results': postprocess_results(tencentml_scores, tencentml_classes),
        'place365_results': postprocess_results(place365_scores, place365_classes),
        'imgtype_results': postprocess_results(imgtype_scores, imgtype_classes),
        'ppl_results': postprocess_results(ppl_scores, ppl_classes),
        'ifppl_results': postprocess_results(ifppl_scores, ifppl_classes),
    }
    return {
        'inference_time': timings,
        'generated_captions': ranked,
        'reasoning': reasoning,
    }


if __name__ == '__main__':
    print('\tinit models')

    # Assigned at module scope, so zeroshot_classifier(), sorting_texts() and
    # the other helpers read it as a module global. The former
    # `global model_dict` statement here was a no-op and is removed.
    model_dict = load_models()

    # define gradio demo
    inputs = [gr.inputs.Image(type="pil", label="Image")]

    outputs = gr.outputs.JSON()

    title = "Socratic models for image captioning with BLOOM"

    description = """
    ## Details
    **Without any fine-tuning**, we can do image captioning using Visual-Language models (e.g., CLIP, SLIP, ...) and Large language models (e.g., GPT, BLOOM, ...).
    In this demo, I chose BLOOM as the language model and CLIP ViT-L/14 as the visual-language model.
    The image caption is generated as follows:
    1. Classify whether there are people, where the location is, and what objects are in the input image using the visual-language model.
    2. Then, build a prompt using classified results.
    3. Request BLOOM API with the prompt.

    This demo is slightly different from the original method proposed in the Socratic Models paper.
    I used not only Tencent ML class names but also OpenImage class names, and I adopted BLOOM as the large language model.

    If you want the demo using GPT3 from OpenAI, check https://github.com/geonm/socratic-models-demo.

    Demo is running on CPU.
    """

    article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2204.00598'>Socratic Models: Composing Zero-Shot Multimodal Reasoning with Language</a></p>"
    examples = ['k21-1.jpg']

    gr.Interface(image_captioning,
                 inputs,
                 outputs,
                 title=title,
                 description=description,
                 article=article,
                 examples=examples,
                 ).launch()