wybertwang commited on
Commit
c426a27
1 Parent(s): 3cb3d90

Upload 78 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
.gitattributes CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ test_img/img18.jpg filter=lfs diff=lfs merge=lfs -text
36
+ test_img/img22.jpg filter=lfs diff=lfs merge=lfs -text
DejaVuSansCondensed-Bold.ttf ADDED
Binary file (632 kB). View file
 
Image/demo1.svg ADDED
Image/demo2.svg ADDED
Image/title.svg ADDED
LICENSE ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2023, Teng Wang
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its
16
+ contributors may be used to endorse or promote products derived from
17
+ this software without specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
app_old.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+ import string
3
+ import gradio as gr
4
+ import requests
5
+ from caas import CaptionAnything
6
+ import torch
7
+ import json
8
+ import sys
9
+ import argparse
10
+ from caas import parse_augment
11
+ import os
12
+
13
+ # download sam checkpoint if not downloaded
14
+ def download_checkpoint(url, folder, filename):
15
+ os.makedirs(folder, exist_ok=True)
16
+ filepath = os.path.join(folder, filename)
17
+
18
+ if not os.path.exists(filepath):
19
+ response = requests.get(url, stream=True)
20
+ with open(filepath, "wb") as f:
21
+ for chunk in response.iter_content(chunk_size=8192):
22
+ if chunk:
23
+ f.write(chunk)
24
+
25
+ return filepath
26
+ checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
27
+ folder = "segmenter"
28
+ filename = "sam_vit_h_4b8939.pth"
29
+
30
+ title = """<h1 align="center">Caption-Anything</h1>"""
31
+ description = """Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them.
32
+ <br> <strong>Code</strong>: GitHub repo: <a href='https://github.com/ttengwang/Caption-Anything' target='_blank'></a>
33
+ """
34
+
35
+ examples = [
36
+ ["test_img/img2.jpg", "[[1000, 700, 1]]"]
37
+ ]
38
+
39
+ args = parse_augment()
40
+
41
+ def get_prompt(chat_input, click_state):
42
+ points = click_state[0]
43
+ labels = click_state[1]
44
+ inputs = json.loads(chat_input)
45
+ for input in inputs:
46
+ points.append(input[:2])
47
+ labels.append(input[2])
48
+
49
+ prompt = {
50
+ "prompt_type":["click"],
51
+ "input_point":points,
52
+ "input_label":labels,
53
+ "multimask_output":"True",
54
+ }
55
+ return prompt
56
+
57
+ def inference_seg_cap(image_input, chat_input, language, sentiment, factuality, length, state, click_state):
58
+ controls = {'length': length,
59
+ 'sentiment': sentiment,
60
+ 'factuality': factuality,
61
+ 'language': language}
62
+ prompt = get_prompt(chat_input, click_state)
63
+ print('prompt: ', prompt, 'controls: ', controls)
64
+ out = model.inference(image_input, prompt, controls)
65
+ state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
66
+ for k, v in out['generated_captions'].items():
67
+ state = state + [(f'{k}: {v}', None)]
68
+ click_state[2].append(out['generated_captions']['raw_caption'])
69
+ image_output_mask = out['mask_save_path']
70
+ image_output_crop = out['crop_save_path']
71
+ return state, state, click_state, image_output_mask, image_output_crop
72
+
73
+
74
+ def upload_callback(image_input, state):
75
+ state = state + [('Image size: ' + str(image_input.size), None)]
76
+ return state
77
+
78
+ # get coordinate in format [[x,y,positive/negative]]
79
+ def get_select_coords(image_input, point_prompt, language, sentiment, factuality, length, state, click_state, evt: gr.SelectData):
80
+ print("point_prompt: ", point_prompt)
81
+ if point_prompt == 'Positive Point':
82
+ coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
83
+ else:
84
+ coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
85
+ return (coordinate,) + inference_seg_cap(image_input, coordinate, language, sentiment, factuality, length, state, click_state)
86
+
87
+ def chat_with_points(chat_input, click_state, state):
88
+ points, labels, captions = click_state
89
+ point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\n. Now begin chatting! Human: {chat_input}\nAI: "
90
+ # "The image is of width {width} and height {height}."
91
+
92
+ prev_visual_context = ""
93
+ pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
94
+ prev_visual_context = ', '.join(pos_points) + captions[-1] + '\n'
95
+ chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context, "chat_input": chat_input})
96
+ response = model.text_refiner.llm(chat_prompt)
97
+ state = state + [(chat_input, response)]
98
+ return state, state
99
+
100
+ def init_openai_api_key(api_key):
101
+ os.environ['OPENAI_API_KEY'] = api_key
102
+ global model
103
+ model = CaptionAnything(args)
104
+
105
+ css='''
106
+ #image_upload{min-height:200px}
107
+ #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 200px}
108
+ '''
109
+
110
+ with gr.Blocks(css=css) as iface:
111
+ state = gr.State([])
112
+ click_state = gr.State([[],[],[]])
113
+ caption_state = gr.State([[]])
114
+ gr.Markdown(title)
115
+ gr.Markdown(description)
116
+
117
+ with gr.Column():
118
+ openai_api_key = gr.Textbox(
119
+ placeholder="Input your openAI API key and press Enter",
120
+ show_label=False,
121
+ lines=1,
122
+ type="password",
123
+ )
124
+ openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key])
125
+
126
+ with gr.Row():
127
+ with gr.Column(scale=0.7):
128
+ image_input = gr.Image(type="pil", interactive=True, label="Image", elem_id="image_upload").style(height=260,scale=1.0)
129
+
130
+ with gr.Row(scale=0.7):
131
+ point_prompt = gr.Radio(
132
+ choices=["Positive Point", "Negative Point"],
133
+ value="Positive Point",
134
+ label="Points",
135
+ interactive=True,
136
+ )
137
+
138
+ # with gr.Row():
139
+ language = gr.Radio(
140
+ choices=["English", "Chinese", "French", "Spanish", "Arabic", "Portuguese","Cantonese"],
141
+ value="English",
142
+ label="Language",
143
+ interactive=True,
144
+ )
145
+ sentiment = gr.Radio(
146
+ choices=["Positive", "Natural", "Negative"],
147
+ value="Natural",
148
+ label="Sentiment",
149
+ interactive=True,
150
+ )
151
+ factuality = gr.Radio(
152
+ choices=["Factual", "Imagination"],
153
+ value="Factual",
154
+ label="Factuality",
155
+ interactive=True,
156
+ )
157
+ length = gr.Slider(
158
+ minimum=5,
159
+ maximum=100,
160
+ value=10,
161
+ step=1,
162
+ interactive=True,
163
+ label="Length",
164
+ )
165
+
166
+ with gr.Column(scale=1.5):
167
+ with gr.Row():
168
+ image_output_mask= gr.Image(type="pil", interactive=False, label="Mask").style(height=260,scale=1.0)
169
+ image_output_crop= gr.Image(type="pil", interactive=False, label="Cropped Image by Mask", show_progress=False).style(height=260,scale=1.0)
170
+ chatbot = gr.Chatbot(label="Chat Output",).style(height=450,scale=0.5)
171
+
172
+ with gr.Row():
173
+ with gr.Column(scale=0.7):
174
+ prompt_input = gr.Textbox(lines=1, label="Input Prompt (A list of points like : [[100, 200, 1], [200,300,0]])")
175
+ prompt_input.submit(
176
+ inference_seg_cap,
177
+ [
178
+ image_input,
179
+ prompt_input,
180
+ language,
181
+ sentiment,
182
+ factuality,
183
+ length,
184
+ state,
185
+ click_state
186
+ ],
187
+ [chatbot, state, click_state, image_output_mask, image_output_crop],
188
+ show_progress=False
189
+ )
190
+
191
+ image_input.upload(
192
+ upload_callback,
193
+ [image_input, state],
194
+ [chatbot]
195
+ )
196
+
197
+ with gr.Row():
198
+ clear_button = gr.Button(value="Clear Click", interactive=True)
199
+ clear_button.click(
200
+ lambda: ("", [[], [], []], None, None),
201
+ [],
202
+ [prompt_input, click_state, image_output_mask, image_output_crop],
203
+ queue=False,
204
+ show_progress=False
205
+ )
206
+
207
+ clear_button = gr.Button(value="Clear", interactive=True)
208
+ clear_button.click(
209
+ lambda: ("", [], [], [[], [], []], None, None),
210
+ [],
211
+ [prompt_input, chatbot, state, click_state, image_output_mask, image_output_crop],
212
+ queue=False,
213
+ show_progress=False
214
+ )
215
+
216
+ submit_button = gr.Button(
217
+ value="Submit", interactive=True, variant="primary"
218
+ )
219
+ submit_button.click(
220
+ inference_seg_cap,
221
+ [
222
+ image_input,
223
+ prompt_input,
224
+ language,
225
+ sentiment,
226
+ factuality,
227
+ length,
228
+ state,
229
+ click_state
230
+ ],
231
+ [chatbot, state, click_state, image_output_mask, image_output_crop],
232
+ show_progress=False
233
+ )
234
+
235
+ # select coordinate
236
+ image_input.select(
237
+ get_select_coords,
238
+ inputs=[image_input,point_prompt,language,sentiment,factuality,length,state,click_state],
239
+ outputs=[prompt_input, chatbot, state, click_state, image_output_mask, image_output_crop],
240
+ show_progress=False
241
+ )
242
+
243
+ image_input.change(
244
+ lambda: ("", [], [[], [], []]),
245
+ [],
246
+ [chatbot, state, click_state],
247
+ queue=False,
248
+ )
249
+
250
+ with gr.Column(scale=1.5):
251
+ chat_input = gr.Textbox(lines=1, label="Chat Input")
252
+ chat_input.submit(chat_with_points, [chat_input, click_state, state], [chatbot, state])
253
+
254
+
255
+ examples = gr.Examples(
256
+ examples=examples,
257
+ inputs=[image_input, prompt_input],
258
+ )
259
+
260
+ iface.queue(concurrency_count=1, api_open=False, max_size=10)
261
+ iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=args.gradio_share)
caas.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from captioner import build_captioner, BaseCaptioner
2
+ from segmenter import build_segmenter
3
+ from text_refiner import build_text_refiner
4
+ import os
5
+ import argparse
6
+ import pdb
7
+ import time
8
+ from PIL import Image
9
+
10
+ class CaptionAnything():
11
+ def __init__(self, args):
12
+ self.args = args
13
+ self.captioner = build_captioner(args.captioner, args.device, args)
14
+ self.segmenter = build_segmenter(args.segmenter, args.device, args)
15
+ if not args.disable_gpt:
16
+ self.init_refiner()
17
+
18
+
19
+ def init_refiner(self):
20
+ if os.environ.get('OPENAI_API_KEY', None):
21
+ self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args)
22
+
23
+ def inference(self, image, prompt, controls, disable_gpt=False):
24
+ # segment with prompt
25
+ print("CA prompt: ", prompt, "CA controls",controls)
26
+ seg_mask = self.segmenter.inference(image, prompt)[0, ...]
27
+ mask_save_path = f'result/mask_{time.time()}.png'
28
+ if not os.path.exists(os.path.dirname(mask_save_path)):
29
+ os.makedirs(os.path.dirname(mask_save_path))
30
+ new_p = Image.fromarray(seg_mask.astype('int') * 255.)
31
+ if new_p.mode != 'RGB':
32
+ new_p = new_p.convert('RGB')
33
+ new_p.save(mask_save_path)
34
+ print('seg_mask path: ', mask_save_path)
35
+ print("seg_mask.shape: ", seg_mask.shape)
36
+ # captioning with mask
37
+ if self.args.enable_reduce_tokens:
38
+ caption, crop_save_path = self.captioner.inference_with_reduced_tokens(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, regular_box = self.args.regular_box)
39
+ else:
40
+ caption, crop_save_path = self.captioner.inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, regular_box = self.args.regular_box)
41
+ # refining with TextRefiner
42
+ context_captions = []
43
+ if self.args.context_captions:
44
+ context_captions.append(self.captioner.inference(image))
45
+ if not disable_gpt and hasattr(self, "text_refiner"):
46
+ refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions)
47
+ else:
48
+ refined_caption = {'raw_caption': caption}
49
+ out = {'generated_captions': refined_caption,
50
+ 'crop_save_path': crop_save_path,
51
+ 'mask_save_path': mask_save_path,
52
+ 'context_captions': context_captions}
53
+ return out
54
+
55
+ def parse_augment():
56
+ parser = argparse.ArgumentParser()
57
+ parser.add_argument('--captioner', type=str, default="blip")
58
+ parser.add_argument('--segmenter', type=str, default="base")
59
+ parser.add_argument('--text_refiner', type=str, default="base")
60
+ parser.add_argument('--segmenter_checkpoint', type=str, default="segmenter/sam_vit_h_4b8939.pth")
61
+ parser.add_argument('--seg_crop_mode', type=str, default="w_bg", choices=['wo_bg', 'w_bg'], help="whether to add or remove background of the image when captioning")
62
+ parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
63
+ parser.add_argument('--context_captions', action="store_true", help="use surrounding captions to enhance current caption")
64
+ parser.add_argument('--regular_box', action="store_true", default = False, help="crop image with a regular box")
65
+ parser.add_argument('--device', type=str, default="cuda:0")
66
+ parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")
67
+ parser.add_argument('--debug', action="store_true")
68
+ parser.add_argument('--gradio_share', action="store_true")
69
+ parser.add_argument('--disable_gpt', action="store_true")
70
+ parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
71
+ parser.add_argument('--disable_reuse_features', action="store_true", default=False)
72
+ args = parser.parse_args()
73
+
74
+ if args.debug:
75
+ print(args)
76
+ return args
77
+
78
+ if __name__ == "__main__":
79
+ args = parse_augment()
80
+ # image_path = 'test_img/img3.jpg'
81
+ image_path = 'test_img/img13.jpg'
82
+ prompts = [
83
+ {
84
+ "prompt_type":["click"],
85
+ "input_point":[[500, 300], [1000, 500]],
86
+ "input_label":[1, 0],
87
+ "multimask_output":"True",
88
+ },
89
+ {
90
+ "prompt_type":["click"],
91
+ "input_point":[[900, 800]],
92
+ "input_label":[1],
93
+ "multimask_output":"True",
94
+ }
95
+ ]
96
+ controls = {
97
+ "length": "30",
98
+ "sentiment": "positive",
99
+ # "imagination": "True",
100
+ "imagination": "False",
101
+ "language": "English",
102
+ }
103
+
104
+ model = CaptionAnything(args)
105
+ for prompt in prompts:
106
+ print('*'*30)
107
+ print('Image path: ', image_path)
108
+ image = Image.open(image_path)
109
+ print(image)
110
+ print('Visual controls (SAM prompt):\n', prompt)
111
+ print('Language controls:\n', controls)
112
+ out = model.inference(image_path, prompt, controls)
113
+
114
+
captioner/README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ To run BLIP/BLIP2, you should install transformers from source!
2
+ ```
3
+ !pip install git+https://github.com/huggingface/transformers.git
4
+ ```
5
+ To run filter module, you should install CLIP repo as a Python package as follow:
6
+ ```
7
+ !pip install ftfy regex tqdm
8
+ !pip install git+https://github.com/openai/CLIP.git
9
+ ```
10
+ To accelerate BLIP2 with int8, you should install accelerate
11
+ ```
12
+ !pip install accelerate bitsandbytes
13
+ ```
captioner/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .blip import BLIPCaptioner
2
+ from .blip2 import BLIP2Captioner
3
+ from .git import GITCaptioner
4
+ from .base_captioner import BaseCaptioner
5
+
6
+
7
+ def build_captioner(type, device, args=None):
8
+ if type == 'blip':
9
+ return BLIPCaptioner(device, enable_filter=args.clip_filter)
10
+ elif type == 'blip2':
11
+ return BLIP2Captioner(device, enable_filter=args.clip_filter)
12
+ elif type == 'git':
13
+ return GITCaptioner(device, enable_filter=args.clip_filter)
14
+ else:
15
+ raise NotImplementedError("")
captioner/base_captioner.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image, ImageDraw, ImageOps
3
+ from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
4
+ import json
5
+ import pdb
6
+ import cv2
7
+ import numpy as np
8
+ from typing import Union
9
+ import time
10
+ import clip
11
+
12
+ def boundary(inputs):
13
+
14
+ col = inputs.shape[1]
15
+ inputs = inputs.reshape(-1)
16
+ lens = len(inputs)
17
+
18
+ for i in range(lens):
19
+ if inputs[i] != False:
20
+ break
21
+ for j in range(lens):
22
+ if inputs[lens - 1 - j] != False:
23
+ break
24
+ start = i
25
+ end = lens - 1 - j
26
+ top = start // col
27
+ bottom = end // col
28
+
29
+ return top, bottom
30
+
31
+ def new_seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]):
32
+
33
+ if type(seg_mask) == str:
34
+ seg_mask = Image.open(seg_mask)
35
+ elif type(seg_mask) == np.ndarray:
36
+ seg_mask = Image.fromarray(seg_mask)
37
+ seg_mask = np.array(seg_mask) > 0
38
+ size = max(seg_mask.shape[0], seg_mask.shape[1])
39
+ top, bottom = boundary(seg_mask)
40
+ left, right = boundary(seg_mask.T)
41
+ return [left / size, top / size, right / size, bottom / size]
42
+
43
+ def seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]):
44
+ if type(seg_mask) == str:
45
+ seg_mask = cv2.imread(seg_mask, cv2.IMREAD_GRAYSCALE)
46
+ _, seg_mask = cv2.threshold(seg_mask, 127, 255, 0)
47
+ elif type(seg_mask) == np.ndarray:
48
+ assert seg_mask.ndim == 2 # only support single-channel segmentation mask
49
+ seg_mask = seg_mask.astype('uint8')
50
+ if seg_mask.dtype == 'bool':
51
+ seg_mask = seg_mask * 255
52
+ contours, hierarchy = cv2.findContours(seg_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
53
+ contours = np.concatenate(contours, axis=0)
54
+ rect = cv2.minAreaRect(contours)
55
+ box = cv2.boxPoints(rect)
56
+ if rect[-1] >= 45:
57
+ newstart = box.argmin(axis=0)[1] # leftmost
58
+ else:
59
+ newstart = box.argmax(axis=0)[0] # topmost
60
+ box = np.concatenate([box[newstart:], box[:newstart]], axis=0)
61
+ box = np.int0(box)
62
+ return box
63
+
64
+ def get_w_h(rect_points):
65
+ w = np.linalg.norm(rect_points[0] - rect_points[1], ord=2).astype('int')
66
+ h = np.linalg.norm(rect_points[0] - rect_points[3], ord=2).astype('int')
67
+ return w, h
68
+
69
+ def cut_box(img, rect_points):
70
+ w, h = get_w_h(rect_points)
71
+ dst_pts = np.array([[h, 0], [h, w], [0, w], [0, 0],], dtype="float32")
72
+ transform = cv2.getPerspectiveTransform(rect_points.astype("float32"), dst_pts)
73
+ cropped_img = cv2.warpPerspective(img, transform, (h, w))
74
+ return cropped_img
75
+
76
+ class BaseCaptioner:
77
+ def __init__(self, device, enable_filter=False):
78
+ print(f"Initializing ImageCaptioning to {device}")
79
+ self.device = device
80
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
81
+ self.processor = None
82
+ self.model = None
83
+ self.enable_filter = enable_filter
84
+ if enable_filter:
85
+ self.filter, self.preprocess = clip.load('ViT-B/32', device)
86
+ self.threshold = 0.2
87
+
88
+ @torch.no_grad()
89
+ def filter_caption(self, image: Union[np.ndarray, Image.Image, str], caption: str):
90
+
91
+ if type(image) == str: # input path
92
+ image = Image.open(image)
93
+ elif type(image) == np.ndarray:
94
+ image = Image.fromarray(image)
95
+
96
+ image = self.preprocess(image).unsqueeze(0).to(self.device) # (1, 3, 224, 224)
97
+ text = clip.tokenize(caption).to(self.device) # (1, 77)
98
+ image_features = self.filter.encode_image(image) # (1, 512)
99
+ text_features = self.filter.encode_text(text) # (1, 512)
100
+ image_features /= image_features.norm(dim = -1, keepdim = True)
101
+ text_features /= text_features.norm(dim = -1, keepdim = True)
102
+ similarity = torch.matmul(image_features, text_features.transpose(1, 0)).item()
103
+ if similarity < self.threshold:
104
+ print('There seems to be nothing where you clicked.')
105
+ out = ""
106
+ else:
107
+ out = caption
108
+ print(f'Clip score of the caption is {similarity}')
109
+ return out
110
+
111
+
112
+ def inference(self, image: Union[np.ndarray, Image.Image, str], filter: bool=False):
113
+ raise NotImplementedError()
114
+
115
+ def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, filter: bool=False):
116
+ raise NotImplementedError()
117
+
118
+ def inference_box(self, image: Union[np.ndarray, Image.Image, str], box: Union[list, np.ndarray], filter=False):
119
+ if type(image) == str: # input path
120
+ image = Image.open(image)
121
+ elif type(image) == np.ndarray:
122
+ image = Image.fromarray(image)
123
+
124
+ if np.array(box).size == 4: # [x0, y0, x1, y1], where (x0, y0), (x1, y1) represent top-left and bottom-right corners
125
+ size = max(image.width, image.height)
126
+ x1, y1, x2, y2 = box
127
+ image_crop = np.array(image.crop((x1 * size, y1 * size, x2 * size, y2 * size)))
128
+ elif np.array(box).size == 8: # four corners of an irregular rectangle
129
+ image_crop = cut_box(np.array(image), box)
130
+
131
+ crop_save_path = f'result/crop_{time.time()}.png'
132
+ Image.fromarray(image_crop).save(crop_save_path)
133
+ print(f'croped image saved in {crop_save_path}')
134
+ caption = self.inference(image_crop, filter)
135
+ return caption, crop_save_path
136
+
137
+
138
+ def inference_seg(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str], crop_mode="w_bg", filter=False, regular_box = False):
139
+ if type(image) == str:
140
+ image = Image.open(image)
141
+ if type(seg_mask) == str:
142
+ seg_mask = Image.open(seg_mask)
143
+ elif type(seg_mask) == np.ndarray:
144
+ seg_mask = Image.fromarray(seg_mask)
145
+ seg_mask = seg_mask.resize(image.size)
146
+ seg_mask = np.array(seg_mask) > 0
147
+
148
+ if crop_mode=="wo_bg":
149
+ image = np.array(image) * seg_mask[:,:,np.newaxis]
150
+ else:
151
+ image = np.array(image)
152
+
153
+ if regular_box:
154
+ min_area_box = new_seg_to_box(seg_mask)
155
+ else:
156
+ min_area_box = seg_to_box(seg_mask)
157
+ return self.inference_box(image, min_area_box, filter)
158
+
159
+
160
+ def generate_seg_cropped_image(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str], crop_mode="w_bg", regular_box = False):
161
+ if type(image) == str:
162
+ image = Image.open(image)
163
+ if type(seg_mask) == str:
164
+ seg_mask = Image.open(seg_mask)
165
+ elif type(seg_mask) == np.ndarray:
166
+ seg_mask = Image.fromarray(seg_mask)
167
+ seg_mask = seg_mask.resize(image.size)
168
+ seg_mask = np.array(seg_mask) > 0
169
+
170
+ if crop_mode=="wo_bg":
171
+ image = np.array(image) * seg_mask[:,:,np.newaxis]
172
+ else:
173
+ image = np.array(image)
174
+
175
+ if regular_box:
176
+ box = new_seg_to_box(seg_mask)
177
+ else:
178
+ box = seg_to_box(seg_mask)
179
+
180
+ if np.array(box).size == 4: # [x0, y0, x1, y1], where (x0, y0), (x1, y1) represent top-left and bottom-right corners
181
+ size = max(image.shape[0], image.shape[1])
182
+ x1, y1, x2, y2 = box
183
+ image_crop = np.array(image.crop((x1 * size, y1 * size, x2 * size, y2 * size)))
184
+ elif np.array(box).size == 8: # four corners of an irregular rectangle
185
+ image_crop = cut_box(np.array(image), box)
186
+ crop_save_path = f'result/crop_{time.time()}.png'
187
+ Image.fromarray(image_crop).save(crop_save_path)
188
+ print(f'croped image saved in {crop_save_path}')
189
+ return crop_save_path
190
+
191
+
192
+ if __name__ == '__main__':
193
+ model = BaseCaptioner(device='cuda:0')
194
+ image_path = 'test_img/img2.jpg'
195
+ seg_mask = np.zeros((15,15))
196
+ seg_mask[5:10, 5:10] = 1
197
+ seg_mask = 'image/SAM/img10.jpg.raw_mask.png'
198
+ print(model.inference_seg(image_path, seg_mask))
199
+
captioner/blip.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image, ImageDraw, ImageOps
3
+ from transformers import BlipProcessor
4
+ from .modeling_blip import BlipForConditionalGeneration
5
+ import json
6
+ import pdb
7
+ import cv2
8
+ import numpy as np
9
+ from typing import Union
10
+ from .base_captioner import BaseCaptioner
11
+ import torchvision.transforms.functional as F
12
+
13
+
14
+ class BLIPCaptioner(BaseCaptioner):
15
+ def __init__(self, device, enable_filter=False):
16
+ super().__init__(device, enable_filter)
17
+ self.device = device
18
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
19
+ self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
20
+ self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=self.torch_dtype).to(self.device)
21
+
22
+ @torch.no_grad()
23
+ def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
24
+ if type(image) == str: # input path
25
+ image = Image.open(image)
26
+ inputs = self.processor(image, return_tensors="pt").to(self.device, self.torch_dtype)
27
+ out = self.model.generate(**inputs, max_new_tokens=50)
28
+ captions = self.processor.decode(out[0], skip_special_tokens=True)
29
+ if self.enable_filter and filter:
30
+ captions = self.filter_caption(image, captions)
31
+ print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
32
+ return captions
33
+
34
+ @torch.no_grad()
35
+ def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg", filter=False, regular_box = False):
36
+ crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode, regular_box=regular_box)
37
+ if type(image) == str: # input path
38
+ image = Image.open(image)
39
+ inputs = self.processor(image, return_tensors="pt")
40
+ pixel_values = inputs.pixel_values.to(self.device, self.torch_dtype)
41
+ _, _, H, W = pixel_values.shape
42
+ seg_mask = Image.fromarray(seg_mask.astype(float))
43
+ seg_mask = seg_mask.resize((H, W))
44
+ seg_mask = F.pil_to_tensor(seg_mask) > 0.5
45
+ seg_mask = seg_mask.float()
46
+ pixel_masks = seg_mask.unsqueeze(0).to(self.device)
47
+ out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50)
48
+ captions = self.processor.decode(out[0], skip_special_tokens=True)
49
+ if self.enable_filter and filter:
50
+ captions = self.filter_caption(image, captions)
51
+ print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
52
+ return captions, crop_save_path
53
+
54
+
55
+ if __name__ == '__main__':
56
+ model = BLIPCaptioner(device='cuda:0')
57
+ # image_path = 'test_img/img2.jpg'
58
+ image_path = '/group/30042/wybertwang/project/woa_visgpt/chatARC/image/SAM/img10.jpg'
59
+ seg_mask = np.zeros((15,15))
60
+ seg_mask[5:10, 5:10] = 1
61
+ seg_mask = 'test_img/img10.jpg.raw_mask.png'
62
+ image_path = 'test_img/img2.jpg'
63
+ seg_mask = 'test_img/img2.jpg.raw_mask.png'
64
+ print(f'process image {image_path}')
65
+ print(model.inference_with_reduced_tokens(image_path, seg_mask))
66
+
captioner/blip2.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image, ImageDraw, ImageOps
3
+ from transformers import AutoProcessor, Blip2ForConditionalGeneration
4
+ import json
5
+ import pdb
6
+ import cv2
7
+ import numpy as np
8
+ from typing import Union
9
+ from .base_captioner import BaseCaptioner
10
+
11
+ class BLIP2Captioner(BaseCaptioner):
12
+ def __init__(self, device, dialogue: bool = False, enable_filter: bool = False):
13
+ super().__init__(device, enable_filter)
14
+ self.device = device
15
+ self.dialogue = dialogue
16
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
17
+ self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
18
+ self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map = 'sequential', load_in_8bit=True)
19
+ @torch.no_grad()
20
+ def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
21
+ if type(image) == str: # input path
22
+ image = Image.open(image)
23
+
24
+ if not self.dialogue:
25
+ inputs = self.processor(image, text = 'Ignore the black background! This is a photo of ', return_tensors="pt").to(self.device, self.torch_dtype)
26
+ out = self.model.generate(**inputs, max_new_tokens=50)
27
+ captions = self.processor.decode(out[0], skip_special_tokens=True)
28
+ if self.enable_filter and filter:
29
+ captions = self.filter_caption(image, captions)
30
+ print(f"\nProcessed ImageCaptioning by BLIP2Captioner, Output Text: {captions}")
31
+ return captions
32
+ else:
33
+ context = []
34
+ template = "Question: {} Answer: {}."
35
+ while(True):
36
+ input_texts = input()
37
+ if input_texts == 'end':
38
+ break
39
+ prompt = " ".join([template.format(context[i][0], context[i][1]) for i in range(len(context))]) + " Question: " + input_texts + " Answer:"
40
+ inputs = self.processor(image, text = prompt, return_tensors="pt").to(self.device, self.torch_dtype)
41
+ out = self.model.generate(**inputs, max_new_tokens=50)
42
+ captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
43
+ context.append((input_texts, captions))
44
+
45
+ return captions
46
+
47
+ if __name__ == '__main__':
48
+
49
+ dialogue = False
50
+ model = BLIP2Captioner(device='cuda:4', dialogue = dialogue, cache_dir = '/nvme-ssd/fjj/Caption-Anything/model_cache')
51
+ image_path = 'test_img/img2.jpg'
52
+ seg_mask = np.zeros((224,224))
53
+ seg_mask[50:200, 50:200] = 1
54
+ print(f'process image {image_path}')
55
+ print(model.inference_seg(image_path, seg_mask))
captioner/git.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import GitProcessor, AutoProcessor
2
+ from .modeling_git import GitForCausalLM
3
+ from PIL import Image
4
+ import torch
5
+ from .base_captioner import BaseCaptioner
6
+ import numpy as np
7
+ from typing import Union
8
+ import torchvision.transforms.functional as F
9
+
10
+
11
+ class GITCaptioner(BaseCaptioner):
12
+ def __init__(self, device, enable_filter=False):
13
+ super().__init__(device, enable_filter)
14
+ self.device = device
15
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
16
+ self.processor = AutoProcessor.from_pretrained("microsoft/git-large")
17
+ self.model = GitForCausalLM.from_pretrained("microsoft/git-large", torch_dtype=self.torch_dtype).to(self.device)
18
+
19
+ @torch.no_grad()
20
+ def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
21
+ if type(image) == str: # input path
22
+ image = Image.open(image)
23
+ pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device, self.torch_dtype)
24
+ generated_ids = self.model.generate(pixel_values=pixel_values, max_new_tokens=50)
25
+ generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
26
+ if self.enable_filter and filter:
27
+ captions = self.filter_caption(image, captions)
28
+ print(f"\nProcessed ImageCaptioning by GITCaptioner, Output Text: {generated_caption}")
29
+ return generated_caption
30
+
31
+ @torch.no_grad()
32
+ def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg", filter=False, regular_box = False):
33
+ crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode, regular_box=regular_box)
34
+ if type(image) == str: # input path
35
+ image = Image.open(image)
36
+ inputs = self.processor(images=image, return_tensors="pt")
37
+ pixel_values = inputs.pixel_values.to(self.device, self.torch_dtype)
38
+ _, _, H, W = pixel_values.shape
39
+ seg_mask = Image.fromarray(seg_mask.astype(float))
40
+ seg_mask = seg_mask.resize((H, W))
41
+ seg_mask = F.pil_to_tensor(seg_mask) > 0.5
42
+ seg_mask = seg_mask.float()
43
+ pixel_masks = seg_mask.unsqueeze(0).to(self.device)
44
+ out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50)
45
+ captions = self.processor.decode(out[0], skip_special_tokens=True)
46
+ if self.enable_filter and filter:
47
+ captions = self.filter_caption(image, captions)
48
+ print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
49
+ return captions, crop_save_path
50
+
51
+ if __name__ == '__main__':
52
+ model = GITCaptioner(device='cuda:2', enable_filter=False)
53
+ image_path = 'test_img/img2.jpg'
54
+ seg_mask = np.zeros((224,224))
55
+ seg_mask[50:200, 50:200] = 1
56
+ print(f'process image {image_path}')
57
+ print(model.inference_with_reduced_tokens(image_path, seg_mask))
captioner/modeling_blip.py ADDED
@@ -0,0 +1,1476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The Salesforce Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch BLIP model."""
16
+
17
+ from dataclasses import dataclass
18
+ from typing import Any, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+ from torch.nn.functional import normalize
24
+
25
+ from transformers.activations import ACT2FN
26
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
27
+ from transformers.modeling_utils import PreTrainedModel
28
+ from transformers.utils import (
29
+ ModelOutput,
30
+ add_start_docstrings,
31
+ add_start_docstrings_to_model_forward,
32
+ logging,
33
+ replace_return_docstrings,
34
+ )
35
+ from transformers.models.blip.configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig
36
+ from transformers.models.blip.modeling_blip_text import BlipTextLMHeadModel, BlipTextModel
37
+ from .vit_pixel_masks_utils import ViTPatchMaskGenerator
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+ _CHECKPOINT_FOR_DOC = "Salesforce/blip-vqa-base"
42
+
43
+ BLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
44
+ "Salesforce/blip-vqa-base",
45
+ "Salesforce/blip-vqa-capfit-large",
46
+ "Salesforce/blip-image-captioning-base",
47
+ "Salesforce/blip-image-captioning-large",
48
+ "Salesforce/blip-itm-base-coco",
49
+ "Salesforce/blip-itm-large-coco",
50
+ "Salesforce/blip-itm-base-flikr",
51
+ "Salesforce/blip-itm-large-flikr",
52
+ # See all BLIP models at https://huggingface.co/models?filter=blip
53
+ ]
54
+
55
+
56
+ # Copied from transformers.models.clip.modeling_clip.contrastive_loss
57
+ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
58
+ return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
59
+
60
+
61
+ # Copied from transformers.models.clip.modeling_clip.clip_loss with clip->blip
62
+ def blip_loss(similarity: torch.Tensor) -> torch.Tensor:
63
+ caption_loss = contrastive_loss(similarity)
64
+ image_loss = contrastive_loss(similarity.t())
65
+ return (caption_loss + image_loss) / 2.0
66
+
67
+
68
+ @dataclass
69
+ class BlipForConditionalGenerationModelOutput(ModelOutput):
70
+ """
71
+ Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
72
+ last hidden states. This class also adds the loss term from the text decoder.
73
+
74
+ Args:
75
+ loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
76
+ Languge modeling loss from the text decoder.
77
+ decoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
78
+ Prediction scores of the language modeling head of the text decoder model.
79
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*):
80
+ The image embeddings obtained after applying the Vision Transformer model to the input image.
81
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
82
+ Sequence of hidden-states at the output of the last layer of the model.
83
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
84
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
85
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
86
+
87
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
88
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed):
89
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
90
+ sequence_length)`.
91
+
92
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
93
+ heads.
94
+ """
95
+
96
+ loss: Optional[Tuple[torch.FloatTensor]] = None
97
+ decoder_logits: Optional[Tuple[torch.FloatTensor]] = None
98
+ image_embeds: Optional[torch.FloatTensor] = None
99
+ last_hidden_state: torch.FloatTensor = None
100
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
101
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
102
+
103
+
104
+ @dataclass
105
+ class BlipTextVisionModelOutput(ModelOutput):
106
+ """
107
+ Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
108
+ last hidden states. This class also adds the loss term from the text decoder.
109
+
110
+ Args:
111
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
112
+ Languge modeling loss from the text decoder.
113
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
114
+ The image embeddings obtained by applying the projection layer to the pooler_output.
115
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
116
+ Sequence of hidden-states at the output of the last layer of the model.
117
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
118
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
119
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
120
+
121
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
122
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
123
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
124
+ sequence_length)`.
125
+
126
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
127
+ heads.
128
+ """
129
+
130
+ loss: Optional[torch.FloatTensor] = None
131
+ image_embeds: Optional[torch.FloatTensor] = None
132
+ last_hidden_state: torch.FloatTensor = None
133
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
134
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
135
+
136
+
137
+ @dataclass
138
+ class BlipImageTextMatchingModelOutput(ModelOutput):
139
+ """
140
+ Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
141
+ last hidden states. This class also adds the loss term from the text decoder as well as the image-text similarity
142
+ scores.
143
+
144
+ Args:
145
+ itm_score (`torch.FloatTensor`):
146
+ The image-text similarity scores.
147
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
148
+ Languge modeling loss from the text decoder.
149
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
150
+ The image embeddings obtained by applying the projection layer to the pooler_output.
151
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
152
+ Sequence of hidden-states at the output of the last layer of the model.
153
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
154
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
155
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
156
+
157
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
158
+ vision_pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*):
159
+ Last layer hidden-state of the vision of the vision-only branch of the model.
160
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
161
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
162
+ sequence_length)`.
163
+
164
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
165
+ heads.
166
+ question_embeds (`torch.FloatTensor`):
167
+ The question embeddings obtained by the text projection layer.
168
+ """
169
+
170
+ itm_score: Optional[torch.FloatTensor] = None
171
+ loss: Optional[torch.FloatTensor] = None
172
+ image_embeds: Optional[torch.FloatTensor] = None
173
+ last_hidden_state: torch.FloatTensor = None
174
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
175
+ vision_pooler_output: Optional[torch.FloatTensor] = None
176
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
177
+ question_embeds: Optional[Tuple[torch.FloatTensor]] = None
178
+
179
+
180
+ @dataclass
181
+ class BlipOutput(ModelOutput):
182
+ """
183
+ Args:
184
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
185
+ Contrastive loss for image-text similarity.
186
+ logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
187
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
188
+ similarity scores.
189
+ logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
190
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
191
+ similarity scores.
192
+ text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
193
+ The text embeddings obtained by applying the projection layer to the pooled output of [`BlipTextModel`].
194
+ image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
195
+ The image embeddings obtained by applying the projection layer to the pooled output of [`BlipVisionModel`].
196
+ text_model_output(`BaseModelOutputWithPooling`):
197
+ The output of the [`BlipTextModel`].
198
+ vision_model_output(`BaseModelOutputWithPooling`):
199
+ The output of the [`BlipVisionModel`].
200
+ """
201
+
202
+ loss: Optional[torch.FloatTensor] = None
203
+ logits_per_image: torch.FloatTensor = None
204
+ logits_per_text: torch.FloatTensor = None
205
+ text_embeds: torch.FloatTensor = None
206
+ image_embeds: torch.FloatTensor = None
207
+ text_model_output: BaseModelOutputWithPooling = None
208
+ vision_model_output: BaseModelOutputWithPooling = None
209
+
210
+ def to_tuple(self) -> Tuple[Any]:
211
+ return tuple(
212
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
213
+ for k in self.keys()
214
+ )
215
+
216
+
217
+ class BlipVisionEmbeddings(nn.Module):
218
+ def __init__(self, config: BlipVisionConfig):
219
+ super().__init__()
220
+ self.config = config
221
+ self.embed_dim = config.hidden_size
222
+ self.image_size = config.image_size
223
+ self.patch_size = config.patch_size
224
+
225
+ self.class_embedding = nn.Parameter(
226
+ torch.randn(1, 1, self.embed_dim),
227
+ )
228
+
229
+ self.patch_embedding = nn.Conv2d(
230
+ in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
231
+ )
232
+
233
+ self.num_patches = (self.image_size // self.patch_size) ** 2
234
+ self.num_positions = self.num_patches + 1
235
+
236
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
237
+
238
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
239
+ batch_size = pixel_values.shape[0]
240
+ target_dtype = self.patch_embedding.weight.dtype
241
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
242
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
243
+
244
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
245
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
246
+ embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
247
+ return embeddings
248
+
249
+
250
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Blip
251
+ class BlipTextEmbeddings(nn.Module):
252
+ def __init__(self, config: BlipTextConfig):
253
+ super().__init__()
254
+ embed_dim = config.hidden_size
255
+
256
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
257
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
258
+
259
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
260
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
261
+
262
+ def forward(
263
+ self,
264
+ input_ids: Optional[torch.LongTensor] = None,
265
+ position_ids: Optional[torch.LongTensor] = None,
266
+ inputs_embeds: Optional[torch.FloatTensor] = None,
267
+ ) -> torch.Tensor:
268
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
269
+
270
+ if position_ids is None:
271
+ position_ids = self.position_ids[:, :seq_length]
272
+
273
+ if inputs_embeds is None:
274
+ inputs_embeds = self.token_embedding(input_ids)
275
+
276
+ position_embeddings = self.position_embedding(position_ids)
277
+ embeddings = inputs_embeds + position_embeddings
278
+
279
+ return embeddings
280
+
281
+
282
+ class BlipAttention(nn.Module):
283
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
284
+
285
+ def __init__(self, config):
286
+ super().__init__()
287
+ self.config = config
288
+ self.embed_dim = config.hidden_size
289
+ self.num_heads = config.num_attention_heads
290
+ self.head_dim = self.embed_dim // self.num_heads
291
+ if self.head_dim * self.num_heads != self.embed_dim:
292
+ raise ValueError(
293
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
294
+ f" {self.num_heads})."
295
+ )
296
+ self.scale = self.head_dim**-0.5
297
+ self.dropout = nn.Dropout(config.attention_dropout)
298
+
299
+ self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim)
300
+
301
+ self.projection = nn.Linear(self.embed_dim, self.embed_dim)
302
+
303
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
304
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
305
+
306
+ def forward(
307
+ self,
308
+ hidden_states: torch.Tensor,
309
+ head_mask: Optional[torch.Tensor] = None,
310
+ output_attentions: Optional[bool] = False,
311
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
312
+ """Input shape: Batch x Time x Channel"""
313
+
314
+ bsz, tgt_len, embed_dim = hidden_states.size()
315
+
316
+ mixed_qkv = self.qkv(hidden_states)
317
+ mixed_qkv = (
318
+ self.qkv(hidden_states)
319
+ .reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads)
320
+ .permute(2, 0, 3, 1, 4)
321
+ )
322
+ query_states, key_states, value_states = (
323
+ mixed_qkv[0],
324
+ mixed_qkv[1],
325
+ mixed_qkv[2],
326
+ )
327
+
328
+ # Take the dot product between "query" and "key" to get the raw attention scores.
329
+ attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
330
+
331
+ attention_scores = attention_scores * self.scale
332
+
333
+ # Normalize the attention scores to probabilities.
334
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
335
+
336
+ # This is actually dropping out entire tokens to attend to, which might
337
+ # seem a bit unusual, but is taken from the original Transformer paper.
338
+ attention_probs = self.dropout(attention_probs)
339
+
340
+ # Mask heads if we want to
341
+ if head_mask is not None:
342
+ attention_probs = attention_probs * head_mask
343
+
344
+ context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
345
+
346
+ new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
347
+ context_layer = context_layer.reshape(new_context_layer_shape)
348
+
349
+ output = self.projection(context_layer)
350
+
351
+ outputs = (output, attention_probs) if output_attentions else (output, None)
352
+
353
+ return outputs
354
+
355
+
356
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Blip
357
+ class BlipMLP(nn.Module):
358
+ def __init__(self, config):
359
+ super().__init__()
360
+ self.config = config
361
+ self.activation_fn = ACT2FN[config.hidden_act]
362
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
363
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
364
+
365
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
366
+ hidden_states = self.fc1(hidden_states)
367
+ hidden_states = self.activation_fn(hidden_states)
368
+ hidden_states = self.fc2(hidden_states)
369
+ return hidden_states
370
+
371
+
372
+ class BlipEncoderLayer(nn.Module):
373
+ def __init__(self, config: BlipConfig):
374
+ super().__init__()
375
+ self.embed_dim = config.hidden_size
376
+ self.self_attn = BlipAttention(config)
377
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
378
+ self.mlp = BlipMLP(config)
379
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
380
+
381
+ def forward(
382
+ self,
383
+ hidden_states: torch.Tensor,
384
+ attention_mask: torch.Tensor,
385
+ output_attentions: Optional[bool] = False,
386
+ ) -> Tuple[torch.FloatTensor]:
387
+ """
388
+ Args:
389
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
390
+ attention_mask (`torch.FloatTensor`): attention mask of size
391
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
392
+ `(config.encoder_attention_heads,)`.
393
+ output_attentions (`bool`, *optional*):
394
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
395
+ returned tensors for more detail.
396
+ """
397
+ residual = hidden_states
398
+
399
+ hidden_states = self.layer_norm1(hidden_states)
400
+ hidden_states, attn_weights = self.self_attn(
401
+ hidden_states=hidden_states,
402
+ head_mask=attention_mask,
403
+ output_attentions=output_attentions,
404
+ )
405
+ hidden_states = hidden_states + residual
406
+ residual = hidden_states
407
+ hidden_states = self.layer_norm2(hidden_states)
408
+ hidden_states = self.mlp(hidden_states)
409
+
410
+ hidden_states = hidden_states + residual
411
+
412
+ outputs = (hidden_states,)
413
+
414
+ if output_attentions:
415
+ outputs += (attn_weights,)
416
+
417
+ return outputs
418
+
419
+
420
+ class BlipPreTrainedModel(PreTrainedModel):
421
+ """
422
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
423
+ models.
424
+ """
425
+
426
+ config_class = BlipConfig
427
+ base_model_prefix = "blip"
428
+ supports_gradient_checkpointing = True
429
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
430
+
431
+ def _init_weights(self, module):
432
+ """Initialize the weights"""
433
+ factor = self.config.initializer_range
434
+ if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
435
+ module.weight.data.normal_(mean=0.0, std=factor)
436
+ if hasattr(module, "bias") and module.bias is not None:
437
+ module.bias.data.zero_()
438
+
439
+ if isinstance(module, BlipVisionEmbeddings):
440
+ if hasattr(self.config, "vision_config"):
441
+ factor = self.config.vision_config.initializer_range
442
+ nn.init.trunc_normal_(
443
+ module.position_embedding,
444
+ mean=0.0,
445
+ std=factor,
446
+ )
447
+
448
+ nn.init.trunc_normal_(
449
+ module.class_embedding,
450
+ mean=0.0,
451
+ std=factor,
452
+ )
453
+
454
+ elif isinstance(module, nn.LayerNorm):
455
+ module.bias.data.zero_()
456
+ module.weight.data.fill_(1.0)
457
+ elif isinstance(module, nn.Linear) and module.bias is not None:
458
+ module.bias.data.zero_()
459
+
460
+ def _set_gradient_checkpointing(self, module, value=False):
461
+ if isinstance(module, BlipEncoder):
462
+ module.gradient_checkpointing = value
463
+
464
+
465
+ BLIP_START_DOCSTRING = r"""
466
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
467
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
468
+ etc.)
469
+
470
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
471
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
472
+ and behavior.
473
+
474
+ Parameters:
475
+ config ([`BlipConfig`]): Model configuration class with all the parameters of the model.
476
+ Initializing with a config file does not load the weights associated with the model, only the
477
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
478
+ """
479
+
480
+ BLIP_TEXT_INPUTS_DOCSTRING = r"""
481
+ Args:
482
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
483
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
484
+ it.
485
+
486
+ Indices can be obtained using [`AutoProcessor`]. See [`BlipProcessor.__call__`] for details.
487
+
488
+ [What are input IDs?](../glossary#input-ids)
489
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
490
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
491
+
492
+ - 1 for tokens that are **not masked**,
493
+ - 0 for tokens that are **masked**.
494
+
495
+ [What are attention masks?](../glossary#attention-mask)
496
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
497
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
498
+ config.max_position_embeddings - 1]`.
499
+
500
+ [What are position IDs?](../glossary#position-ids)
501
+ output_attentions (`bool`, *optional*):
502
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
503
+ tensors for more detail.
504
+ output_hidden_states (`bool`, *optional*):
505
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
506
+ more detail.
507
+ return_dict (`bool`, *optional*):
508
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
509
+ """
510
+
511
+ BLIP_VISION_INPUTS_DOCSTRING = r"""
512
+ Args:
513
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
514
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
515
+ [`BlipImageProcessor`]. See [`BlipImageProcessor.__call__`] for details.
516
+ output_attentions (`bool`, *optional*):
517
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
518
+ tensors for more detail.
519
+ output_hidden_states (`bool`, *optional*):
520
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
521
+ more detail.
522
+ return_dict (`bool`, *optional*):
523
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
524
+ """
525
+
526
+ BLIP_INPUTS_DOCSTRING = r"""
527
+ Args:
528
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
529
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
530
+ it.
531
+
532
+ Indices can be obtained using [`AutoProcessor`]. See [`BlipProcessor.__call__`] for details.
533
+
534
+ [What are input IDs?](../glossary#input-ids)
535
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
536
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
537
+
538
+ - 1 for tokens that are **not masked**,
539
+ - 0 for tokens that are **masked**.
540
+
541
+ [What are attention masks?](../glossary#attention-mask)
542
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
543
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
544
+ config.max_position_embeddings - 1]`.
545
+
546
+ [What are position IDs?](../glossary#position-ids)
547
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
548
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
549
+ [`BlipImageProcessor`]. See [`BlipImageProcessor.__call__`] for details.
550
+ return_loss (`bool`, *optional*):
551
+ Whether or not to return the contrastive loss.
552
+ output_attentions (`bool`, *optional*):
553
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
554
+ tensors for more detail.
555
+ output_hidden_states (`bool`, *optional*):
556
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
557
+ more detail.
558
+ return_dict (`bool`, *optional*):
559
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
560
+ """
561
+
562
+
563
+ class BlipEncoder(nn.Module):
564
+ """
565
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
566
+ [`BlipEncoderLayer`].
567
+
568
+ Args:
569
+ config (`BlipConfig`):
570
+ The corresponding vision configuration for the `BlipEncoder`.
571
+ """
572
+
573
+ def __init__(self, config: BlipConfig):
574
+ super().__init__()
575
+ self.config = config
576
+ self.layers = nn.ModuleList([BlipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
577
+ self.gradient_checkpointing = False
578
+
579
+ def forward(
580
+ self,
581
+ inputs_embeds,
582
+ attention_mask: Optional[torch.LongTensor] = None,
583
+ output_attentions: Optional[bool] = None,
584
+ output_hidden_states: Optional[bool] = None,
585
+ return_dict: Optional[bool] = None,
586
+ ) -> Union[Tuple, BaseModelOutput]:
587
+ r"""
588
+ Args:
589
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
590
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
591
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
592
+ than the model's internal embedding lookup matrix.
593
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
594
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
595
+
596
+ - 1 for tokens that are **not masked**,
597
+ - 0 for tokens that are **masked**.
598
+
599
+ [What are attention masks?](../glossary#attention-mask)
600
+ output_attentions (`bool`, *optional*):
601
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
602
+ returned tensors for more detail.
603
+ output_hidden_states (`bool`, *optional*):
604
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
605
+ for more detail.
606
+ return_dict (`bool`, *optional*):
607
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
608
+ """
609
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
610
+ output_hidden_states = (
611
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
612
+ )
613
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
614
+
615
+ encoder_states = () if output_hidden_states else None
616
+ all_attentions = () if output_attentions else None
617
+
618
+ hidden_states = inputs_embeds
619
+ for idx, encoder_layer in enumerate(self.layers):
620
+ if output_hidden_states:
621
+ encoder_states = encoder_states + (hidden_states,)
622
+ if self.gradient_checkpointing and self.training:
623
+
624
+ def create_custom_forward(module):
625
+ def custom_forward(*inputs):
626
+ return module(*inputs, output_attentions)
627
+
628
+ return custom_forward
629
+
630
+ layer_outputs = torch.utils.checkpoint.checkpoint(
631
+ create_custom_forward(encoder_layer),
632
+ hidden_states,
633
+ attention_mask,
634
+ )
635
+ else:
636
+ layer_outputs = encoder_layer(
637
+ hidden_states,
638
+ attention_mask,
639
+ output_attentions=output_attentions,
640
+ )
641
+
642
+ hidden_states = layer_outputs[0]
643
+
644
+ if output_attentions:
645
+ all_attentions = all_attentions + (layer_outputs[1],)
646
+
647
+ if output_hidden_states:
648
+ encoder_states = encoder_states + (hidden_states,)
649
+
650
+ if not return_dict:
651
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
652
+ return BaseModelOutput(
653
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
654
+ )
655
+
656
+
657
+ class BlipVisionModel(BlipPreTrainedModel):
658
+ main_input_name = "pixel_values"
659
+ config_class = BlipVisionConfig
660
+
661
+ def __init__(self, config: BlipVisionConfig):
662
+ super().__init__(config)
663
+ self.config = config
664
+ embed_dim = config.hidden_size
665
+ self.embeddings = BlipVisionEmbeddings(config)
666
+ self.patch_mask_generator = ViTPatchMaskGenerator(config.patch_size)
667
+ self.encoder = BlipEncoder(config)
668
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
669
+
670
+ self.post_init()
671
+
672
+ @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
673
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=BlipVisionConfig)
674
+ def forward(
675
+ self,
676
+ pixel_values: Optional[torch.FloatTensor] = None,
677
+ pixel_masks: Optional[torch.LongTensor] = None,
678
+ output_attentions: Optional[bool] = None,
679
+ output_hidden_states: Optional[bool] = None,
680
+ return_dict: Optional[bool] = None,
681
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
682
+ r"""
683
+ Returns:
684
+
685
+ """
686
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
687
+ output_hidden_states = (
688
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
689
+ )
690
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
691
+
692
+ if pixel_values is None:
693
+ raise ValueError("You have to specify pixel_values")
694
+
695
+ hidden_states = self.embeddings(pixel_values)
696
+ B, N, D = hidden_states.shape
697
+ # print('Before mask:', hidden_states.shape)
698
+ if pixel_masks is not None:
699
+ assert pixel_masks.shape[0] == 1
700
+ patch_masks = self.patch_mask_generator(pixel_masks)
701
+ # print(patch_masks.shape)
702
+ patch_masks = patch_masks.unsqueeze(-1).expand_as(hidden_states)
703
+ hidden_states = hidden_states.masked_select(patch_masks).view(B, -1, D)
704
+ # print('After mask:', hidden_states.shape)
705
+
706
+ encoder_outputs = self.encoder(
707
+ inputs_embeds=hidden_states,
708
+ output_attentions=output_attentions,
709
+ output_hidden_states=output_hidden_states,
710
+ return_dict=return_dict,
711
+ )
712
+
713
+ last_hidden_state = encoder_outputs[0]
714
+ last_hidden_state = self.post_layernorm(last_hidden_state)
715
+
716
+ pooled_output = last_hidden_state[:, 0, :]
717
+ pooled_output = self.post_layernorm(pooled_output)
718
+
719
+ if not return_dict:
720
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
721
+
722
+ return BaseModelOutputWithPooling(
723
+ last_hidden_state=last_hidden_state,
724
+ pooler_output=pooled_output,
725
+ hidden_states=encoder_outputs.hidden_states,
726
+ attentions=encoder_outputs.attentions,
727
+ )
728
+
729
+ def get_input_embeddings(self):
730
+ return self.embeddings
731
+
732
+
733
+ @add_start_docstrings(BLIP_START_DOCSTRING)
734
+ class BlipModel(BlipPreTrainedModel):
735
+ config_class = BlipConfig
736
+
737
+ def __init__(self, config: BlipConfig):
738
+ super().__init__(config)
739
+
740
+ if not isinstance(config.text_config, BlipTextConfig):
741
+ raise ValueError(
742
+ "config.text_config is expected to be of type BlipTextConfig but is of type"
743
+ f" {type(config.text_config)}."
744
+ )
745
+
746
+ if not isinstance(config.vision_config, BlipVisionConfig):
747
+ raise ValueError(
748
+ "config.vision_config is expected to be of type BlipVisionConfig but is of type"
749
+ f" {type(config.vision_config)}."
750
+ )
751
+
752
+ text_config = config.text_config
753
+ vision_config = config.vision_config
754
+
755
+ self.projection_dim = config.projection_dim
756
+ self.text_embed_dim = text_config.hidden_size
757
+ self.vision_embed_dim = vision_config.hidden_size
758
+
759
+ self.text_model = BlipTextModel(text_config)
760
+ self.vision_model = BlipVisionModel(vision_config)
761
+
762
+ self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
763
+ self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
764
+ self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value)
765
+
766
+ # Initialize weights and apply final processing
767
+ self.post_init()
768
+
769
+ @add_start_docstrings_to_model_forward(BLIP_TEXT_INPUTS_DOCSTRING)
770
+ def get_text_features(
771
+ self,
772
+ input_ids: Optional[torch.Tensor] = None,
773
+ attention_mask: Optional[torch.Tensor] = None,
774
+ position_ids: Optional[torch.Tensor] = None,
775
+ return_dict: Optional[bool] = None,
776
+ ) -> torch.FloatTensor:
777
+ r"""
778
+ Returns:
779
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
780
+ applying the projection layer to the pooled output of [`BlipTextModel`].
781
+
782
+ Examples:
783
+
784
+ ```python
785
+ >>> from transformers import AutoProcessor, BlipModel
786
+
787
+ >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
788
+ >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
789
+
790
+ >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
791
+ >>> text_features = model.get_text_features(**inputs)
792
+ ```"""
793
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
794
+
795
+ text_outputs = self.text_model(
796
+ input_ids=input_ids,
797
+ attention_mask=attention_mask,
798
+ position_ids=position_ids,
799
+ return_dict=return_dict,
800
+ )
801
+
802
+ pooled_output = text_outputs[1]
803
+ text_features = self.text_projection(pooled_output)
804
+
805
+ return text_features
806
+
807
+ @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
808
+ def get_image_features(
809
+ self,
810
+ pixel_values: Optional[torch.FloatTensor] = None,
811
+ return_dict: Optional[bool] = None,
812
+ ) -> torch.FloatTensor:
813
+ r"""
814
+ Returns:
815
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
816
+ applying the projection layer to the pooled output of [`BlipVisionModel`].
817
+
818
+ Examples:
819
+
820
+ ```python
821
+ >>> from PIL import Image
822
+ >>> import requests
823
+ >>> from transformers import AutoProcessor, BlipModel
824
+
825
+ >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
826
+ >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
827
+
828
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
829
+ >>> image = Image.open(requests.get(url, stream=True).raw)
830
+
831
+ >>> inputs = processor(images=image, return_tensors="pt")
832
+
833
+ >>> image_features = model.get_image_features(**inputs)
834
+ ```"""
835
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
836
+
837
+ vision_outputs = self.vision_model(
838
+ pixel_values=pixel_values,
839
+ return_dict=return_dict,
840
+ )
841
+
842
+ pooled_output = vision_outputs[1] # pooled_output
843
+ image_features = self.visual_projection(pooled_output)
844
+
845
+ return image_features
846
+
847
+ @add_start_docstrings_to_model_forward(BLIP_INPUTS_DOCSTRING)
848
+ @replace_return_docstrings(output_type=BlipOutput, config_class=BlipConfig)
849
+ def forward(
850
+ self,
851
+ input_ids: Optional[torch.LongTensor] = None,
852
+ pixel_values: Optional[torch.FloatTensor] = None,
853
+ pixel_masks: Optional[torch.FloatTensor] = None,
854
+ attention_mask: Optional[torch.Tensor] = None,
855
+ position_ids: Optional[torch.LongTensor] = None,
856
+ return_loss: Optional[bool] = None,
857
+ output_attentions: Optional[bool] = None,
858
+ output_hidden_states: Optional[bool] = None,
859
+ return_dict: Optional[bool] = None,
860
+ ) -> Union[Tuple, BlipOutput]:
861
+ r"""
862
+ Returns:
863
+
864
+ Examples:
865
+
866
+ ```python
867
+ >>> from PIL import Image
868
+ >>> import requests
869
+ >>> from transformers import AutoProcessor, BlipModel
870
+
871
+ >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
872
+ >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
873
+
874
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
875
+ >>> image = Image.open(requests.get(url, stream=True).raw)
876
+
877
+ >>> inputs = processor(
878
+ ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
879
+ ... )
880
+
881
+ >>> outputs = model(**inputs)
882
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
883
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
884
+ ```"""
885
+ # Use BLIP model's config for some fields (if specified) instead of those of vision & text components.
886
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
887
+ output_hidden_states = (
888
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
889
+ )
890
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
891
+
892
+ vision_outputs = self.vision_model(
893
+ pixel_values=pixel_values,
894
+ pixel_masks=pixel_masks,
895
+ output_attentions=output_attentions,
896
+ output_hidden_states=output_hidden_states,
897
+ return_dict=return_dict,
898
+ )
899
+
900
+ text_outputs = self.text_model(
901
+ input_ids=input_ids,
902
+ attention_mask=attention_mask,
903
+ position_ids=position_ids,
904
+ output_attentions=output_attentions,
905
+ output_hidden_states=output_hidden_states,
906
+ return_dict=return_dict,
907
+ )
908
+
909
+ image_embeds = vision_outputs[1]
910
+ image_embeds = self.visual_projection(image_embeds)
911
+
912
+ text_embeds = text_outputs[1]
913
+ text_embeds = self.text_projection(text_embeds)
914
+
915
+ # normalized features
916
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
917
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
918
+
919
+ # cosine similarity as logits
920
+ logit_scale = self.logit_scale.exp()
921
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
922
+ logits_per_image = logits_per_text.t()
923
+
924
+ loss = None
925
+ if return_loss:
926
+ loss = blip_loss(logits_per_text)
927
+
928
+ if not return_dict:
929
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
930
+ return ((loss,) + output) if loss is not None else output
931
+
932
+ return BlipOutput(
933
+ loss=loss,
934
+ logits_per_image=logits_per_image,
935
+ logits_per_text=logits_per_text,
936
+ text_embeds=text_embeds,
937
+ image_embeds=image_embeds,
938
+ text_model_output=text_outputs,
939
+ vision_model_output=vision_outputs,
940
+ )
941
+
942
+
943
+ @add_start_docstrings(
944
+ """
945
+ BLIP Model for image captioning. The model consists of a vision encoder and a text decoder. One can optionally pass
946
+ `input_ids` to the model, which serve as a text prompt, to make the text decoder continue the prompt. Otherwise,
947
+ the decoder starts generating text from the [BOS] (beginning-of-sequence) token. will start generating the caption
948
+ from the text input. If no text input is provided, the decoder will start with the [BOS] token only.
949
+ """,
950
+ BLIP_START_DOCSTRING,
951
+ )
952
+ class BlipForConditionalGeneration(BlipPreTrainedModel):
953
+ config_class = BlipConfig
954
+ _keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"]
955
+ main_input_name = "pixel_values"
956
+
957
+ def __init__(self, config: BlipConfig):
958
+ super().__init__(config)
959
+
960
+ self.vision_model = BlipVisionModel(config.vision_config)
961
+
962
+ self.text_decoder = BlipTextLMHeadModel(config.text_config)
963
+
964
+ self.decoder_input_ids = config.text_config.bos_token_id
965
+ self.decoder_pad_token_id = config.text_config.pad_token_id
966
+
967
+ # Initialize weights and apply final processing
968
+ self.post_init()
969
+
970
+ def get_input_embeddings(self) -> nn.Module:
971
+ return self.vision_model.embeddings.patch_embedding
972
+
973
+ @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
974
+ @replace_return_docstrings(output_type=BlipForConditionalGenerationModelOutput, config_class=BlipVisionConfig)
975
+ def forward(
976
+ self,
977
+ pixel_values: torch.FloatTensor,
978
+ input_ids: Optional[torch.LongTensor] = None,
979
+ attention_mask: Optional[torch.LongTensor] = None,
980
+ output_attentions: Optional[bool] = None,
981
+ output_hidden_states: Optional[bool] = None,
982
+ labels: Optional[torch.LongTensor] = None,
983
+ return_dict: Optional[bool] = None,
984
+ ) -> Union[Tuple, BlipForConditionalGenerationModelOutput]:
985
+ r"""
986
+ Returns:
987
+
988
+ Examples:
989
+
990
+ ```python
991
+ >>> from PIL import Image
992
+ >>> import requests
993
+ >>> from transformers import AutoProcessor, BlipForConditionalGeneration
994
+
995
+ >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
996
+ >>> model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
997
+
998
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
999
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1000
+ >>> text = "A picture of"
1001
+
1002
+ >>> inputs = processor(images=image, text=text, return_tensors="pt")
1003
+
1004
+ >>> outputs = model(**inputs)
1005
+ ```"""
1006
+
1007
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1008
+
1009
+ vision_outputs = self.vision_model(
1010
+ pixel_values=pixel_values,
1011
+ output_attentions=output_attentions,
1012
+ output_hidden_states=output_hidden_states,
1013
+ return_dict=return_dict,
1014
+ )
1015
+
1016
+ image_embeds = vision_outputs[0]
1017
+
1018
+ outputs = self.text_decoder(
1019
+ input_ids=input_ids,
1020
+ attention_mask=attention_mask,
1021
+ encoder_hidden_states=image_embeds,
1022
+ labels=labels,
1023
+ return_dict=return_dict,
1024
+ reduction="mean",
1025
+ )
1026
+
1027
+ if not return_dict:
1028
+ outputs = (outputs[0], outputs[1], image_embeds, vision_outputs[0]) + vision_outputs[2:]
1029
+ return tuple(output for output in outputs if output is not None)
1030
+
1031
+ return BlipForConditionalGenerationModelOutput(
1032
+ loss=outputs.loss,
1033
+ decoder_logits=outputs.logits,
1034
+ image_embeds=image_embeds,
1035
+ last_hidden_state=vision_outputs.last_hidden_state,
1036
+ hidden_states=vision_outputs.hidden_states,
1037
+ attentions=vision_outputs.attentions,
1038
+ )
1039
+
1040
+ @torch.no_grad()
1041
+ def generate(
1042
+ self,
1043
+ pixel_values: torch.FloatTensor,
1044
+ pixel_masks: torch.Tensor = None,
1045
+ input_ids: Optional[torch.LongTensor] = None,
1046
+ attention_mask: Optional[torch.LongTensor] = None,
1047
+ **generate_kwargs,
1048
+ ) -> torch.LongTensor:
1049
+ r"""
1050
+ Overrides *generate* function to be able to use the model as a conditional generator
1051
+
1052
+ Parameters:
1053
+ pixel_values (*torch.FloatTensor* of shape *(batch_size, image_width, image_height)*:
1054
+ Input image to be processed
1055
+ input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
1056
+ The sequence used as a prompt for the generation.
1057
+ attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
1058
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1059
+
1060
+
1061
+ Examples:
1062
+ ```python
1063
+ >>> from PIL import Image
1064
+ >>> import requests
1065
+ >>> from transformers import AutoProcessor, BlipForConditionalGeneration
1066
+
1067
+ >>> model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
1068
+ >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
1069
+
1070
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1071
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1072
+
1073
+ >>> inputs = processor(images=image, return_tensors="pt")
1074
+
1075
+ >>> outputs = model.generate(**inputs)
1076
+ >>> print(processor.decode(outputs[0], skip_special_tokens=True))
1077
+ two cats are laying on a couch
1078
+ ```
1079
+ """
1080
+
1081
+ batch_size = pixel_values.shape[0]
1082
+ vision_outputs = self.vision_model(
1083
+ pixel_values=pixel_values,
1084
+ pixel_masks=pixel_masks,
1085
+ )
1086
+
1087
+ image_embeds = vision_outputs[0]
1088
+
1089
+ image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device)
1090
+
1091
+ if isinstance(input_ids, list):
1092
+ input_ids = torch.LongTensor(input_ids)
1093
+ elif input_ids is None:
1094
+ input_ids = (
1095
+ torch.LongTensor([[self.decoder_input_ids, self.config.text_config.eos_token_id]])
1096
+ .repeat(batch_size, 1)
1097
+ .to(image_embeds.device)
1098
+ )
1099
+
1100
+ input_ids[:, 0] = self.config.text_config.bos_token_id
1101
+ attention_mask = attention_mask[:, :-1] if attention_mask is not None else None
1102
+
1103
+ outputs = self.text_decoder.generate(
1104
+ input_ids=input_ids[:, :-1],
1105
+ eos_token_id=self.config.text_config.sep_token_id,
1106
+ pad_token_id=self.config.text_config.pad_token_id,
1107
+ attention_mask=attention_mask,
1108
+ encoder_hidden_states=image_embeds,
1109
+ encoder_attention_mask=image_attention_mask,
1110
+ **generate_kwargs,
1111
+ )
1112
+
1113
+ return outputs
1114
+
1115
+
1116
+ @add_start_docstrings(
1117
+ """
1118
+ BLIP Model for visual question answering. The model consists of a vision encoder, a text encoder as well as a text
1119
+ decoder. The vision encoder will encode the input image, the text encoder will encode the input question together
1120
+ with the encoding of the image, and the text decoder will output the answer to the question.
1121
+ """,
1122
+ BLIP_START_DOCSTRING,
1123
+ )
1124
+ class BlipForQuestionAnswering(BlipPreTrainedModel):
1125
+ config_class = BlipConfig
1126
+ _keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"]
1127
+
1128
+ def __init__(self, config: BlipConfig):
1129
+ super().__init__(config)
1130
+
1131
+ self.vision_model = BlipVisionModel(config.vision_config)
1132
+
1133
+ self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)
1134
+
1135
+ self.text_decoder = BlipTextLMHeadModel(config.text_config)
1136
+
1137
+ self.decoder_pad_token_id = config.text_config.pad_token_id
1138
+ self.decoder_start_token_id = config.text_config.bos_token_id
1139
+
1140
+ # Initialize weights and apply final processing
1141
+ self.post_init()
1142
+
1143
+ def get_input_embeddings(self) -> nn.Module:
1144
+ return self.vision_model.embeddings.patch_embedding
1145
+
1146
+ # Adapted from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right
1147
+ def _shift_right(self, input_ids):
1148
+ pad_token_id = self.decoder_pad_token_id
1149
+
1150
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
1151
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
1152
+ shifted_input_ids[..., 0] = self.decoder_start_token_id
1153
+
1154
+ # replace possible -100 values in labels by `pad_token_id`
1155
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
1156
+
1157
+ return shifted_input_ids
1158
+
1159
+ @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
1160
+ @replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
1161
+ def forward(
1162
+ self,
1163
+ input_ids: torch.LongTensor,
1164
+ pixel_values: torch.FloatTensor,
1165
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1166
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1167
+ attention_mask: Optional[torch.LongTensor] = None,
1168
+ output_attentions: Optional[bool] = None,
1169
+ output_hidden_states: Optional[bool] = None,
1170
+ labels: Optional[torch.LongTensor] = None,
1171
+ return_dict: Optional[bool] = None,
1172
+ ) -> Union[Tuple, BlipTextVisionModelOutput]:
1173
+ r"""
1174
+ Returns:
1175
+
1176
+ Examples:
1177
+
1178
+ ```python
1179
+ >>> from PIL import Image
1180
+ >>> import requests
1181
+ >>> from transformers import AutoProcessor, BlipForQuestionAnswering
1182
+
1183
+ >>> model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
1184
+ >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
1185
+
1186
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1187
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1188
+
1189
+ >>> # training
1190
+ >>> text = "How many cats are in the picture?"
1191
+ >>> label = "2"
1192
+ >>> inputs = processor(images=image, text=text, return_tensors="pt")
1193
+ >>> labels = processor(text=label, return_tensors="pt").input_ids
1194
+
1195
+ >>> inputs["labels"] = labels
1196
+ >>> outputs = model(**inputs)
1197
+ >>> loss = outputs.loss
1198
+ >>> loss.backward()
1199
+
1200
+ >>> # inference
1201
+ >>> text = "How many cats are in the picture?"
1202
+ >>> inputs = processor(images=image, text=text, return_tensors="pt")
1203
+ >>> outputs = model.generate(**inputs)
1204
+ >>> print(processor.decode(outputs[0], skip_special_tokens=True))
1205
+ 2
1206
+ ```"""
1207
+ if labels is None and decoder_input_ids is None:
1208
+ raise ValueError(
1209
+ "Either `decoder_input_ids` or `labels` should be passed when calling `forward` with"
1210
+ " `BlipForQuestionAnswering`. if you are training the model make sure that `labels` is passed, if you"
1211
+ " are using the model for inference make sure that `decoder_input_ids` is passed or call `generate`"
1212
+ )
1213
+
1214
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1215
+
1216
+ vision_outputs = self.vision_model(
1217
+ pixel_values=pixel_values,
1218
+ output_attentions=output_attentions,
1219
+ output_hidden_states=output_hidden_states,
1220
+ return_dict=return_dict,
1221
+ )
1222
+
1223
+ image_embeds = vision_outputs[0]
1224
+ image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long)
1225
+
1226
+ question_embeds = self.text_encoder(
1227
+ input_ids=input_ids,
1228
+ attention_mask=attention_mask,
1229
+ encoder_hidden_states=image_embeds,
1230
+ encoder_attention_mask=image_attention_mask,
1231
+ return_dict=return_dict,
1232
+ )
1233
+
1234
+ question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
1235
+
1236
+ if labels is not None and decoder_input_ids is None:
1237
+ # get decoder inputs from shifting lm labels to the right - this is used in training mode
1238
+ decoder_input_ids = self._shift_right(labels)
1239
+ # replace possible -100 values in labels by `pad_token_id`
1240
+ labels = labels.masked_fill(labels == self.decoder_pad_token_id, -100)
1241
+
1242
+ answer_output = self.text_decoder(
1243
+ input_ids=decoder_input_ids,
1244
+ attention_mask=decoder_attention_mask,
1245
+ encoder_hidden_states=question_embeds,
1246
+ encoder_attention_mask=attention_mask,
1247
+ labels=labels,
1248
+ return_dict=return_dict,
1249
+ reduction="mean",
1250
+ )
1251
+
1252
+ if labels is not None:
1253
+ decoder_loss = answer_output.loss.mean() if return_dict else answer_output[0].mean()
1254
+ else:
1255
+ decoder_loss = None
1256
+
1257
+ if not return_dict:
1258
+ outputs = (decoder_loss, image_embeds, vision_outputs[0]) + vision_outputs[2:]
1259
+ return tuple(output for output in outputs if output is not None)
1260
+
1261
+ return BlipTextVisionModelOutput(
1262
+ loss=decoder_loss,
1263
+ image_embeds=image_embeds,
1264
+ last_hidden_state=vision_outputs.last_hidden_state,
1265
+ hidden_states=vision_outputs.hidden_states,
1266
+ attentions=vision_outputs.attentions,
1267
+ )
1268
+
1269
+ @torch.no_grad()
1270
+ def generate(
1271
+ self,
1272
+ input_ids: torch.LongTensor,
1273
+ pixel_values: torch.FloatTensor,
1274
+ pixel_masks: torch.Tensor = None,
1275
+ attention_mask: Optional[torch.LongTensor] = None,
1276
+ **generate_kwargs,
1277
+ ) -> torch.LongTensor:
1278
+ r"""
1279
+ Overrides *generate* function to be able to use the model as a conditional generator
1280
+
1281
+ Parameters:
1282
+ input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*):
1283
+ The sequence used as a prompt for the generation.
1284
+ pixel_values (*torch.FloatTensor* of shape *(batch_size, image_width, image_height)*:
1285
+ Input image to be processed
1286
+ attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
1287
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`. `1` for
1288
+ tokens that are NOT MASKED, `0` for MASKED tokens.
1289
+ **generate_kwargs:
1290
+ Additional arguments passed to the *generate* function of the decoder
1291
+
1292
+
1293
+ Examples:
1294
+ ```python
1295
+ >>> from PIL import Image
1296
+ >>> import requests
1297
+ >>> from transformers import AutoProcessor, BlipForQuestionAnswering
1298
+
1299
+ >>> model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
1300
+ >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
1301
+
1302
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1303
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1304
+ >>> text = "How many cats are in the picture?"
1305
+
1306
+ >>> inputs = processor(images=image, text=text, return_tensors="pt")
1307
+
1308
+ >>> outputs = model.generate(**inputs)
1309
+ >>> print(processor.decode(outputs[0], skip_special_tokens=True))
1310
+ 2
1311
+ ```
1312
+ """
1313
+ vision_outputs = self.vision_model(
1314
+ pixel_values=pixel_values,
1315
+ pixel_masks=pixel_masks
1316
+ )
1317
+
1318
+ image_embeds = vision_outputs[0]
1319
+
1320
+ image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device)
1321
+
1322
+ if isinstance(input_ids, list):
1323
+ input_ids = torch.LongTensor(input_ids)
1324
+
1325
+ question_outputs = self.text_encoder(
1326
+ input_ids=input_ids,
1327
+ attention_mask=attention_mask,
1328
+ encoder_hidden_states=image_embeds,
1329
+ encoder_attention_mask=image_attention_mask,
1330
+ return_dict=False,
1331
+ )
1332
+
1333
+ question_embeds = question_outputs[0]
1334
+
1335
+ question_attention_mask = torch.ones(question_embeds.size()[:-1], dtype=torch.long).to(question_embeds.device)
1336
+
1337
+ bos_ids = torch.full(
1338
+ (question_embeds.size(0), 1), fill_value=self.decoder_start_token_id, device=question_embeds.device
1339
+ )
1340
+
1341
+ outputs = self.text_decoder.generate(
1342
+ input_ids=bos_ids,
1343
+ eos_token_id=self.config.text_config.sep_token_id,
1344
+ pad_token_id=self.config.text_config.pad_token_id,
1345
+ encoder_hidden_states=question_embeds,
1346
+ encoder_attention_mask=question_attention_mask,
1347
+ **generate_kwargs,
1348
+ )
1349
+
1350
+ return outputs
1351
+
1352
+
1353
+ @add_start_docstrings(
1354
+ """
1355
+ BLIP Model with a vision and text projector, and a classification head on top. The model is used in the context of
1356
+ image-text retrieval. Given an image and a text, the model returns the probability of the text being relevant to
1357
+ the image.
1358
+ """,
1359
+ BLIP_START_DOCSTRING,
1360
+ )
1361
+ class BlipForImageTextRetrieval(BlipPreTrainedModel):
1362
+ config_class = BlipConfig
1363
+
1364
+ def __init__(self, config: BlipConfig):
1365
+ super().__init__(config)
1366
+
1367
+ self.vision_model = BlipVisionModel(config.vision_config)
1368
+
1369
+ self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)
1370
+
1371
+ # vision projection layer
1372
+ self.vision_proj = nn.Linear(config.vision_config.hidden_size, config.image_text_hidden_size)
1373
+
1374
+ # text projection layer
1375
+ self.text_proj = nn.Linear(config.text_config.hidden_size, config.image_text_hidden_size)
1376
+
1377
+ # image text matching head
1378
+ self.itm_head = nn.Linear(config.text_config.hidden_size, 2)
1379
+
1380
+ self.decoder_pad_token_id = (
1381
+ config.text_config.pad_token_id
1382
+ if not hasattr(config, "decoder_pad_token_id")
1383
+ else config.decoder_pad_token_id
1384
+ )
1385
+ self.decoder_start_token_id = (
1386
+ config.text_config.bos_token_id
1387
+ if not hasattr(config, "decoder_start_token_id")
1388
+ else config.decoder_start_token_id
1389
+ )
1390
+
1391
+ # Initialize weights and apply final processing
1392
+ self.post_init()
1393
+
1394
+ def get_input_embeddings(self) -> nn.Module:
1395
+ return self.vision_model.embeddings.patch_embedding
1396
+
1397
+ @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
1398
+ @replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
1399
+ def forward(
1400
+ self,
1401
+ input_ids: torch.LongTensor,
1402
+ pixel_values: torch.FloatTensor,
1403
+ use_itm_head: Optional[bool] = True,
1404
+ attention_mask: Optional[torch.LongTensor] = None,
1405
+ output_attentions: Optional[bool] = None,
1406
+ output_hidden_states: Optional[bool] = None,
1407
+ return_dict: Optional[bool] = None,
1408
+ ) -> Union[Tuple, BlipTextVisionModelOutput]:
1409
+ r"""
1410
+ Returns:
1411
+
1412
+ Examples:
1413
+
1414
+ ```python
1415
+ >>> from PIL import Image
1416
+ >>> import requests
1417
+ >>> from transformers import AutoProcessor, BlipForImageTextRetrieval
1418
+
1419
+ >>> model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
1420
+ >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-base-coco")
1421
+
1422
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1423
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1424
+ >>> text = "an image of a cat"
1425
+
1426
+ >>> inputs = processor(images=image, text=text, return_tensors="pt")
1427
+ >>> outputs = model(**inputs)
1428
+ ```
1429
+ """
1430
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1431
+
1432
+ vision_outputs = self.vision_model(
1433
+ pixel_values=pixel_values,
1434
+ output_attentions=output_attentions,
1435
+ output_hidden_states=output_hidden_states,
1436
+ return_dict=return_dict,
1437
+ )
1438
+
1439
+ image_embeds = vision_outputs[0]
1440
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long)
1441
+
1442
+ if use_itm_head:
1443
+ question_embeds = self.text_encoder(
1444
+ input_ids=input_ids,
1445
+ attention_mask=attention_mask,
1446
+ encoder_hidden_states=image_embeds,
1447
+ encoder_attention_mask=image_atts,
1448
+ return_dict=return_dict,
1449
+ )
1450
+ question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
1451
+
1452
+ output = self.itm_head(question_embeds[:, 0, :])
1453
+ else:
1454
+ question_embeds = self.text_encoder(
1455
+ input_ids=input_ids,
1456
+ attention_mask=attention_mask,
1457
+ return_dict=return_dict,
1458
+ )
1459
+ question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
1460
+
1461
+ image_feat = normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
1462
+ text_feat = normalize(self.text_proj(question_embeds[:, 0, :]), dim=-1)
1463
+
1464
+ output = image_feat @ text_feat.t()
1465
+
1466
+ if not return_dict:
1467
+ outputs = (output, vision_outputs[0]) + vision_outputs[2:] + (question_embeds,)
1468
+ return tuple(output for output in outputs if output is not None)
1469
+
1470
+ return BlipImageTextMatchingModelOutput(
1471
+ itm_score=output,
1472
+ last_hidden_state=vision_outputs.last_hidden_state,
1473
+ hidden_states=vision_outputs.hidden_states,
1474
+ attentions=vision_outputs.attentions,
1475
+ question_embeds=question_embeds,
1476
+ )
captioner/modeling_git.py ADDED
@@ -0,0 +1,1587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 Microsoft Research and The HuggingFace Inc. team.
3
+ # All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch GIT model."""
17
+
18
+
19
+ import math
20
+ from dataclasses import dataclass
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import CrossEntropyLoss
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.file_utils import ModelOutput
30
+ from transformers.modeling_outputs import (
31
+ BaseModelOutput,
32
+ BaseModelOutputWithPast,
33
+ BaseModelOutputWithPooling,
34
+ CausalLMOutputWithPast,
35
+ )
36
+ from transformers.modeling_utils import PreTrainedModel
37
+ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
38
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
39
+ from transformers.models.git.configuration_git import GitConfig, GitVisionConfig
40
+ from .vit_pixel_masks_utils import ViTPatchMaskGenerator
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+ _CHECKPOINT_FOR_DOC = "microsoft/git-base"
46
+ _CONFIG_FOR_DOC = "GitConfig"
47
+
48
+ GIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
49
+ "microsoft/git-base",
50
+ # See all GIT models at https://huggingface.co/models?filter=git
51
+ ]
52
+
53
+
54
+ @dataclass
55
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Git
56
+ class GitVisionModelOutput(ModelOutput):
57
+ """
58
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
59
+
60
+ Args:
61
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
62
+ The image embeddings obtained by applying the projection layer to the pooler_output.
63
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
64
+ Sequence of hidden-states at the output of the last layer of the model.
65
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
66
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
67
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
68
+
69
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
70
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
71
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
72
+ sequence_length)`.
73
+
74
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
75
+ heads.
76
+ """
77
+
78
+ image_embeds: Optional[torch.FloatTensor] = None
79
+ last_hidden_state: torch.FloatTensor = None
80
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
81
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
82
+
83
+
84
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
85
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
86
+ """
87
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
88
+ """
89
+ bsz, src_len = mask.size()
90
+ tgt_len = tgt_len if tgt_len is not None else src_len
91
+
92
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
93
+
94
+ inverted_mask = 1.0 - expanded_mask
95
+
96
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
97
+
98
+
99
+ class GitEmbeddings(nn.Module):
100
+ """Construct the embeddings from word and position embeddings."""
101
+
102
+ def __init__(self, config):
103
+ super().__init__()
104
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
105
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
106
+
107
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
108
+ # any TensorFlow checkpoint file
109
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
110
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
111
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
112
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
113
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
114
+
115
+ def forward(
116
+ self,
117
+ input_ids: Optional[torch.LongTensor] = None,
118
+ position_ids: Optional[torch.LongTensor] = None,
119
+ inputs_embeds: Optional[torch.FloatTensor] = None,
120
+ past_key_values_length: int = 0,
121
+ ) -> torch.Tensor:
122
+ if input_ids is not None:
123
+ input_shape = input_ids.size()
124
+ else:
125
+ input_shape = inputs_embeds.size()[:-1]
126
+
127
+ seq_length = input_shape[1]
128
+
129
+ if position_ids is None:
130
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
131
+
132
+ if inputs_embeds is None:
133
+ embeddings = self.word_embeddings(input_ids)
134
+ else:
135
+ embeddings = inputs_embeds
136
+
137
+ if self.position_embedding_type == "absolute":
138
+ position_embeddings = self.position_embeddings(position_ids)
139
+ embeddings += position_embeddings
140
+ embeddings = self.LayerNorm(embeddings)
141
+ embeddings = self.dropout(embeddings)
142
+ return embeddings
143
+
144
+
145
+ class GitSelfAttention(nn.Module):
146
+ def __init__(self, config, position_embedding_type=None):
147
+ super().__init__()
148
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
149
+ raise ValueError(
150
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
151
+ f"heads ({config.num_attention_heads})"
152
+ )
153
+
154
+ self.num_attention_heads = config.num_attention_heads
155
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
156
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
157
+ self.image_patch_tokens = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1)
158
+ if config.num_image_with_embedding is not None:
159
+ self.image_patch_tokens *= config.num_image_with_embedding
160
+
161
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
162
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
163
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
164
+
165
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
166
+ self.position_embedding_type = position_embedding_type or getattr(
167
+ config, "position_embedding_type", "absolute"
168
+ )
169
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
170
+ self.max_position_embeddings = config.max_position_embeddings
171
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
172
+
173
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
174
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
175
+ x = x.view(new_x_shape)
176
+ return x.permute(0, 2, 1, 3)
177
+
178
+ def forward(
179
+ self,
180
+ hidden_states: torch.Tensor,
181
+ attention_mask: Optional[torch.FloatTensor] = None,
182
+ head_mask: Optional[torch.FloatTensor] = None,
183
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
184
+ output_attentions: Optional[bool] = False,
185
+ pixel_values_present: Optional[bool] = False,
186
+ image_token_num: Optional[int] = None
187
+ ) -> Tuple[torch.Tensor]:
188
+ mixed_query_layer = self.query(hidden_states)
189
+ if image_token_num is not None:
190
+ cutoff = image_token_num
191
+ else:
192
+ cutoff = self.image_patch_tokens if pixel_values_present else 0
193
+ if past_key_value is not None:
194
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
195
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
196
+ key_layer = torch.cat([key_layer[:, :, :cutoff, :], past_key_value[0], key_layer[:, :, -1:, :]], dim=2)
197
+ value_layer = torch.cat(
198
+ [value_layer[:, :, :cutoff, :], past_key_value[1], value_layer[:, :, -1:, :]], dim=2
199
+ )
200
+ else:
201
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
202
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
203
+
204
+ query_layer = self.transpose_for_scores(mixed_query_layer)
205
+
206
+ use_cache = past_key_value is not None
207
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
208
+ # Further calls to cross_attention layer can then reuse all cross-attention
209
+ # key/value_states (first "if" case)
210
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
211
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
212
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
213
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
214
+ # NOTE: like in other caches, we store the text component. In GIT it means we discard the image component.
215
+ past_key_value = (
216
+ key_layer[:, :, cutoff:, :],
217
+ value_layer[:, :, cutoff:, :],
218
+ )
219
+
220
+ # Take the dot product between "query" and "key" to get the raw attention scores.
221
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
222
+
223
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
224
+ query_length, key_length = query_layer.shape[2], key_layer.shape[2]
225
+ if use_cache:
226
+ position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
227
+ -1, 1
228
+ )
229
+ else:
230
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
231
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
232
+ distance = position_ids_l - position_ids_r
233
+
234
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
235
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
236
+
237
+ if self.position_embedding_type == "relative_key":
238
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
239
+ attention_scores = attention_scores + relative_position_scores
240
+ elif self.position_embedding_type == "relative_key_query":
241
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
242
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
243
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
244
+
245
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
246
+ if attention_mask is not None:
247
+ # Apply the attention mask is (precomputed for all layers in GitModel forward() function)
248
+ attention_scores = attention_scores + attention_mask
249
+
250
+ # Normalize the attention scores to probabilities.
251
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
252
+
253
+ # This is actually dropping out entire tokens to attend to, which might
254
+ # seem a bit unusual, but is taken from the original Transformer paper.
255
+ attention_probs = self.dropout(attention_probs)
256
+
257
+ # Mask heads if we want to
258
+ if head_mask is not None:
259
+ attention_probs = attention_probs * head_mask
260
+
261
+ context_layer = torch.matmul(attention_probs, value_layer)
262
+
263
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
264
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
265
+ context_layer = context_layer.view(new_context_layer_shape)
266
+
267
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
268
+
269
+ outputs = outputs + (past_key_value,)
270
+ return outputs
271
+
272
+
273
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
274
+ class GitSelfOutput(nn.Module):
275
+ def __init__(self, config):
276
+ super().__init__()
277
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
278
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
279
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
280
+
281
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
282
+ hidden_states = self.dense(hidden_states)
283
+ hidden_states = self.dropout(hidden_states)
284
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
285
+ return hidden_states
286
+
287
+
288
+ class GitAttention(nn.Module):
289
+ # Copied from transformers.models.bert.modeling_bert.BertAttention.__init__ with Bert->Git
290
+ def __init__(self, config, position_embedding_type=None):
291
+ super().__init__()
292
+ self.self = GitSelfAttention(config, position_embedding_type=position_embedding_type)
293
+ self.output = GitSelfOutput(config)
294
+ self.pruned_heads = set()
295
+
296
+ # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
297
+ def prune_heads(self, heads):
298
+ if len(heads) == 0:
299
+ return
300
+ heads, index = find_pruneable_heads_and_indices(
301
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
302
+ )
303
+
304
+ # Prune linear layers
305
+ self.self.query = prune_linear_layer(self.self.query, index)
306
+ self.self.key = prune_linear_layer(self.self.key, index)
307
+ self.self.value = prune_linear_layer(self.self.value, index)
308
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
309
+
310
+ # Update hyper params and store pruned heads
311
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
312
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
313
+ self.pruned_heads = self.pruned_heads.union(heads)
314
+
315
+ def forward(
316
+ self,
317
+ hidden_states: torch.Tensor,
318
+ attention_mask: Optional[torch.FloatTensor] = None,
319
+ head_mask: Optional[torch.FloatTensor] = None,
320
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
321
+ output_attentions: Optional[bool] = False,
322
+ pixel_values_present: Optional[bool] = False,
323
+ image_token_num: Optional[int] = None
324
+ ) -> Tuple[torch.Tensor]:
325
+ self_outputs = self.self(
326
+ hidden_states,
327
+ attention_mask,
328
+ head_mask,
329
+ past_key_value,
330
+ output_attentions,
331
+ pixel_values_present,
332
+ image_token_num
333
+ )
334
+ attention_output = self.output(self_outputs[0], hidden_states)
335
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
336
+ return outputs
337
+
338
+
339
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate
340
+ class GitIntermediate(nn.Module):
341
+ def __init__(self, config):
342
+ super().__init__()
343
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
344
+ if isinstance(config.hidden_act, str):
345
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
346
+ else:
347
+ self.intermediate_act_fn = config.hidden_act
348
+
349
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
350
+ hidden_states = self.dense(hidden_states)
351
+ hidden_states = self.intermediate_act_fn(hidden_states)
352
+ return hidden_states
353
+
354
+
355
+ # Copied from transformers.models.bert.modeling_bert.BertOutput
356
+ class GitOutput(nn.Module):
357
+ def __init__(self, config):
358
+ super().__init__()
359
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
360
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
361
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
362
+
363
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
364
+ hidden_states = self.dense(hidden_states)
365
+ hidden_states = self.dropout(hidden_states)
366
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
367
+ return hidden_states
368
+
369
+
370
+ class GitLayer(nn.Module):
371
+ def __init__(self, config):
372
+ super().__init__()
373
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
374
+ self.seq_len_dim = 1
375
+ self.attention = GitAttention(config)
376
+ self.intermediate = GitIntermediate(config)
377
+ self.output = GitOutput(config)
378
+
379
+ def forward(
380
+ self,
381
+ hidden_states: torch.Tensor,
382
+ attention_mask: Optional[torch.FloatTensor] = None,
383
+ head_mask: Optional[torch.FloatTensor] = None,
384
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
385
+ output_attentions: Optional[bool] = False,
386
+ pixel_values_present: Optional[bool] = False,
387
+ image_token_num: Optional[bool] = None,
388
+ ) -> Tuple[torch.Tensor]:
389
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
390
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
391
+ self_attention_outputs = self.attention(
392
+ hidden_states,
393
+ attention_mask,
394
+ head_mask,
395
+ output_attentions=output_attentions,
396
+ past_key_value=self_attn_past_key_value,
397
+ pixel_values_present=pixel_values_present,
398
+ image_token_num=image_token_num
399
+ )
400
+ attention_output = self_attention_outputs[0]
401
+
402
+ # if decoder, the last output is tuple of self-attn cache
403
+ outputs = self_attention_outputs[1:-1]
404
+ present_key_value = self_attention_outputs[-1]
405
+
406
+ layer_output = apply_chunking_to_forward(
407
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
408
+ )
409
+ outputs = (layer_output,) + outputs
410
+
411
+ # if decoder, return the attn key/values as the last output
412
+ outputs = outputs + (present_key_value,)
413
+
414
+ return outputs
415
+
416
+ def feed_forward_chunk(self, attention_output):
417
+ intermediate_output = self.intermediate(attention_output)
418
+ layer_output = self.output(intermediate_output, attention_output)
419
+ return layer_output
420
+
421
+
422
+ class GitEncoder(nn.Module):
423
+ # Copied from transformers.models.bert.modeling_bert.BertEncoder.__init__ with Bert->Git
424
+ def __init__(self, config):
425
+ super().__init__()
426
+ self.config = config
427
+ self.layer = nn.ModuleList([GitLayer(config) for _ in range(config.num_hidden_layers)])
428
+ self.gradient_checkpointing = False
429
+
430
+ def forward(
431
+ self,
432
+ hidden_states: torch.Tensor,
433
+ attention_mask: Optional[torch.FloatTensor] = None,
434
+ head_mask: Optional[torch.FloatTensor] = None,
435
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
436
+ use_cache: Optional[bool] = None,
437
+ output_attentions: Optional[bool] = False,
438
+ output_hidden_states: Optional[bool] = False,
439
+ pixel_values_present: Optional[bool] = False,
440
+ image_token_num: Optional[int] = None,
441
+ return_dict: Optional[bool] = True,
442
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
443
+ if self.gradient_checkpointing and self.training:
444
+ if use_cache:
445
+ logger.warning_once(
446
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
447
+ )
448
+ use_cache = False
449
+
450
+ all_hidden_states = () if output_hidden_states else None
451
+ all_self_attentions = () if output_attentions else None
452
+
453
+ next_decoder_cache = () if use_cache else None
454
+ for i, layer_module in enumerate(self.layer):
455
+ if output_hidden_states:
456
+ all_hidden_states = all_hidden_states + (hidden_states,)
457
+
458
+ layer_head_mask = head_mask[i] if head_mask is not None else None
459
+ past_key_value = past_key_values[i] if past_key_values is not None else None
460
+
461
+ if self.gradient_checkpointing and self.training:
462
+
463
+ def create_custom_forward(module):
464
+ def custom_forward(*inputs):
465
+ return module(*inputs, past_key_value, output_attentions)
466
+
467
+ return custom_forward
468
+
469
+ layer_outputs = torch.utils.checkpoint.checkpoint(
470
+ create_custom_forward(layer_module),
471
+ hidden_states,
472
+ attention_mask,
473
+ layer_head_mask,
474
+ )
475
+ else:
476
+ layer_outputs = layer_module(
477
+ hidden_states,
478
+ attention_mask,
479
+ layer_head_mask,
480
+ past_key_value,
481
+ output_attentions,
482
+ pixel_values_present,
483
+ image_token_num,
484
+
485
+ )
486
+
487
+ hidden_states = layer_outputs[0]
488
+ if use_cache:
489
+ next_decoder_cache += (layer_outputs[-1],)
490
+ if output_attentions:
491
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
492
+
493
+ if output_hidden_states:
494
+ all_hidden_states = all_hidden_states + (hidden_states,)
495
+
496
+ if not return_dict:
497
+ return tuple(
498
+ v
499
+ for v in [
500
+ hidden_states,
501
+ next_decoder_cache,
502
+ all_hidden_states,
503
+ all_self_attentions,
504
+ ]
505
+ if v is not None
506
+ )
507
+ return BaseModelOutputWithPast(
508
+ last_hidden_state=hidden_states,
509
+ past_key_values=next_decoder_cache,
510
+ hidden_states=all_hidden_states,
511
+ attentions=all_self_attentions,
512
+ )
513
+
514
+
515
+ class GitPreTrainedModel(PreTrainedModel):
516
+ """
517
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
518
+ models.
519
+ """
520
+
521
+ config_class = GitConfig
522
+ base_model_prefix = "git"
523
+ supports_gradient_checkpointing = True
524
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
525
+
526
+ def _init_weights(self, module):
527
+ """Initialize the weights"""
528
+ if isinstance(module, GitVisionEmbeddings):
529
+ nn.init.normal_(module.class_embedding, mean=0.0, std=self.config.initializer_range)
530
+ nn.init.normal_(module.patch_embedding.weight, std=self.config.initializer_range)
531
+ nn.init.normal_(module.position_embedding.weight, std=self.config.initializer_range)
532
+ if isinstance(module, nn.Linear):
533
+ # Slightly different from the TF version which uses truncated_normal for initialization
534
+ # cf https://github.com/pytorch/pytorch/pull/5617
535
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
536
+ if module.bias is not None:
537
+ module.bias.data.zero_()
538
+ elif isinstance(module, nn.Embedding):
539
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
540
+ if module.padding_idx is not None:
541
+ module.weight.data[module.padding_idx].zero_()
542
+ elif isinstance(module, nn.LayerNorm):
543
+ module.bias.data.zero_()
544
+ module.weight.data.fill_(1.0)
545
+
546
+ def _set_gradient_checkpointing(self, module, value=False):
547
+ if isinstance(module, (GitEncoder, GitVisionEncoder)):
548
+ module.gradient_checkpointing = value
549
+
550
+
551
+ GIT_START_DOCSTRING = r"""
552
+
553
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
554
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
555
+ etc.)
556
+
557
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
558
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
559
+ and behavior.
560
+
561
+ Parameters:
562
+ config ([`GitConfig`]): Model configuration class with all the parameters of the model.
563
+ Initializing with a config file does not load the weights associated with the model, only the
564
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
565
+ """
566
+
567
+ GIT_INPUTS_DOCSTRING = r"""
568
+ Args:
569
+ input_ids (`torch.LongTensor` of shape `({0})`):
570
+ Indices of input sequence tokens in the vocabulary.
571
+
572
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
573
+ [`PreTrainedTokenizer.__call__`] for details.
574
+
575
+ [What are input IDs?](../glossary#input-ids)
576
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
577
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
578
+
579
+ - 1 for tokens that are **not masked**,
580
+ - 0 for tokens that are **masked**.
581
+
582
+ [What are attention masks?](../glossary#attention-mask)
583
+
584
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
585
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
586
+ config.max_position_embeddings - 1]`.
587
+
588
+ [What are position IDs?](../glossary#position-ids)
589
+
590
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
591
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
592
+ [`CLIPImageProcessor.__call__`] for details.
593
+
594
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
595
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
596
+
597
+ - 1 indicates the head is **not masked**,
598
+ - 0 indicates the head is **masked**.
599
+
600
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
601
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
602
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
603
+ model's internal embedding lookup matrix.
604
+ output_attentions (`bool`, *optional*):
605
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
606
+ tensors for more detail.
607
+ output_hidden_states (`bool`, *optional*):
608
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
609
+ more detail.
610
+ return_dict (`bool`, *optional*):
611
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
612
+ """
613
+
614
+
615
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Git
616
+ class GitVisionEmbeddings(nn.Module):
617
+ def __init__(self, config: GitVisionConfig):
618
+ super().__init__()
619
+ self.config = config
620
+ self.embed_dim = config.hidden_size
621
+ self.image_size = config.image_size
622
+ self.patch_size = config.patch_size
623
+
624
+ self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
625
+
626
+ self.patch_embedding = nn.Conv2d(
627
+ in_channels=config.num_channels,
628
+ out_channels=self.embed_dim,
629
+ kernel_size=self.patch_size,
630
+ stride=self.patch_size,
631
+ bias=False,
632
+ )
633
+
634
+ self.num_patches = (self.image_size // self.patch_size) ** 2
635
+ self.num_positions = self.num_patches + 1
636
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
637
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
638
+
639
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
640
+ batch_size = pixel_values.shape[0]
641
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
642
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
643
+
644
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
645
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
646
+ embeddings = embeddings + self.position_embedding(self.position_ids)
647
+ return embeddings
648
+
649
+
650
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP
651
+ class GitVisionMLP(nn.Module):
652
+ def __init__(self, config):
653
+ super().__init__()
654
+ self.config = config
655
+ self.activation_fn = ACT2FN[config.hidden_act]
656
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
657
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
658
+
659
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
660
+ hidden_states = self.fc1(hidden_states)
661
+ hidden_states = self.activation_fn(hidden_states)
662
+ hidden_states = self.fc2(hidden_states)
663
+ return hidden_states
664
+
665
+
666
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention
667
+ class GitVisionAttention(nn.Module):
668
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
669
+
670
+ def __init__(self, config):
671
+ super().__init__()
672
+ self.config = config
673
+ self.embed_dim = config.hidden_size
674
+ self.num_heads = config.num_attention_heads
675
+ self.head_dim = self.embed_dim // self.num_heads
676
+ if self.head_dim * self.num_heads != self.embed_dim:
677
+ raise ValueError(
678
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
679
+ f" {self.num_heads})."
680
+ )
681
+ self.scale = self.head_dim**-0.5
682
+ self.dropout = config.attention_dropout
683
+
684
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
685
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
686
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
687
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
688
+
689
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
690
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
691
+
692
+ def forward(
693
+ self,
694
+ hidden_states: torch.Tensor,
695
+ attention_mask: Optional[torch.Tensor] = None,
696
+ causal_attention_mask: Optional[torch.Tensor] = None,
697
+ output_attentions: Optional[bool] = False,
698
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
699
+ """Input shape: Batch x Time x Channel"""
700
+
701
+ bsz, tgt_len, embed_dim = hidden_states.size()
702
+
703
+ # get query proj
704
+ query_states = self.q_proj(hidden_states) * self.scale
705
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
706
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
707
+
708
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
709
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
710
+ key_states = key_states.view(*proj_shape)
711
+ value_states = value_states.view(*proj_shape)
712
+
713
+ src_len = key_states.size(1)
714
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
715
+
716
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
717
+ raise ValueError(
718
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
719
+ f" {attn_weights.size()}"
720
+ )
721
+
722
+ # apply the causal_attention_mask first
723
+ if causal_attention_mask is not None:
724
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
725
+ raise ValueError(
726
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
727
+ f" {causal_attention_mask.size()}"
728
+ )
729
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
730
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
731
+
732
+ if attention_mask is not None:
733
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
734
+ raise ValueError(
735
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
736
+ )
737
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
738
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
739
+
740
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
741
+
742
+ if output_attentions:
743
+ # this operation is a bit akward, but it's required to
744
+ # make sure that attn_weights keeps its gradient.
745
+ # In order to do so, attn_weights have to reshaped
746
+ # twice and have to be reused in the following
747
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
748
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
749
+ else:
750
+ attn_weights_reshaped = None
751
+
752
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
753
+
754
+ attn_output = torch.bmm(attn_probs, value_states)
755
+
756
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
757
+ raise ValueError(
758
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
759
+ f" {attn_output.size()}"
760
+ )
761
+
762
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
763
+ attn_output = attn_output.transpose(1, 2)
764
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
765
+
766
+ attn_output = self.out_proj(attn_output)
767
+
768
+ return attn_output, attn_weights_reshaped
769
+
770
+
771
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->GitVision
772
+ class GitVisionEncoderLayer(nn.Module):
773
+ def __init__(self, config: GitVisionConfig):
774
+ super().__init__()
775
+ self.embed_dim = config.hidden_size
776
+ self.self_attn = GitVisionAttention(config)
777
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
778
+ self.mlp = GitVisionMLP(config)
779
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
780
+
781
+ def forward(
782
+ self,
783
+ hidden_states: torch.Tensor,
784
+ attention_mask: torch.Tensor,
785
+ causal_attention_mask: torch.Tensor,
786
+ output_attentions: Optional[bool] = False,
787
+ ) -> Tuple[torch.FloatTensor]:
788
+ """
789
+ Args:
790
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
791
+ attention_mask (`torch.FloatTensor`): attention mask of size
792
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
793
+ `(config.encoder_attention_heads,)`.
794
+ output_attentions (`bool`, *optional*):
795
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
796
+ returned tensors for more detail.
797
+ """
798
+ residual = hidden_states
799
+
800
+ hidden_states = self.layer_norm1(hidden_states)
801
+ hidden_states, attn_weights = self.self_attn(
802
+ hidden_states=hidden_states,
803
+ attention_mask=attention_mask,
804
+ causal_attention_mask=causal_attention_mask,
805
+ output_attentions=output_attentions,
806
+ )
807
+ hidden_states = residual + hidden_states
808
+
809
+ residual = hidden_states
810
+ hidden_states = self.layer_norm2(hidden_states)
811
+ hidden_states = self.mlp(hidden_states)
812
+ hidden_states = residual + hidden_states
813
+
814
+ outputs = (hidden_states,)
815
+
816
+ if output_attentions:
817
+ outputs += (attn_weights,)
818
+
819
+ return outputs
820
+
821
+
822
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->GitVision, CLIPConfig
823
+ class GitVisionEncoder(nn.Module):
824
+ """
825
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
826
+ [`GitVisionEncoderLayer`].
827
+
828
+ Args:
829
+ config: GitVisionConfig
830
+ """
831
+
832
+ def __init__(self, config: GitVisionConfig):
833
+ super().__init__()
834
+ self.config = config
835
+ self.layers = nn.ModuleList([GitVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
836
+ self.gradient_checkpointing = False
837
+
838
+ def forward(
839
+ self,
840
+ inputs_embeds,
841
+ attention_mask: Optional[torch.Tensor] = None,
842
+ causal_attention_mask: Optional[torch.Tensor] = None,
843
+ output_attentions: Optional[bool] = None,
844
+ output_hidden_states: Optional[bool] = None,
845
+ return_dict: Optional[bool] = None,
846
+ ) -> Union[Tuple, BaseModelOutput]:
847
+ r"""
848
+ Args:
849
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
850
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
851
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
852
+ than the model's internal embedding lookup matrix.
853
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
854
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
855
+
856
+ - 1 for tokens that are **not masked**,
857
+ - 0 for tokens that are **masked**.
858
+
859
+ [What are attention masks?](../glossary#attention-mask)
860
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
861
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
862
+
863
+ - 1 for tokens that are **not masked**,
864
+ - 0 for tokens that are **masked**.
865
+
866
+ [What are attention masks?](../glossary#attention-mask)
867
+ output_attentions (`bool`, *optional*):
868
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
869
+ returned tensors for more detail.
870
+ output_hidden_states (`bool`, *optional*):
871
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
872
+ for more detail.
873
+ return_dict (`bool`, *optional*):
874
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
875
+ """
876
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
877
+ output_hidden_states = (
878
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
879
+ )
880
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
881
+
882
+ encoder_states = () if output_hidden_states else None
883
+ all_attentions = () if output_attentions else None
884
+
885
+ hidden_states = inputs_embeds
886
+ for idx, encoder_layer in enumerate(self.layers):
887
+ if output_hidden_states:
888
+ encoder_states = encoder_states + (hidden_states,)
889
+ if self.gradient_checkpointing and self.training:
890
+
891
+ def create_custom_forward(module):
892
+ def custom_forward(*inputs):
893
+ return module(*inputs, output_attentions)
894
+
895
+ return custom_forward
896
+
897
+ layer_outputs = torch.utils.checkpoint.checkpoint(
898
+ create_custom_forward(encoder_layer),
899
+ hidden_states,
900
+ attention_mask,
901
+ causal_attention_mask,
902
+ )
903
+ else:
904
+ layer_outputs = encoder_layer(
905
+ hidden_states,
906
+ attention_mask,
907
+ causal_attention_mask,
908
+ output_attentions=output_attentions,
909
+ )
910
+
911
+ hidden_states = layer_outputs[0]
912
+
913
+ if output_attentions:
914
+ all_attentions = all_attentions + (layer_outputs[1],)
915
+
916
+ if output_hidden_states:
917
+ encoder_states = encoder_states + (hidden_states,)
918
+
919
+ if not return_dict:
920
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
921
+ return BaseModelOutput(
922
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
923
+ )
924
+
925
+
926
+ GIT_VISION_INPUTS_DOCSTRING = r"""
927
+ Args:
928
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
929
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
930
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
931
+ output_attentions (`bool`, *optional*):
932
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
933
+ tensors for more detail.
934
+ output_hidden_states (`bool`, *optional*):
935
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
936
+ more detail.
937
+ return_dict (`bool`, *optional*):
938
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
939
+ """
940
+
941
+
942
+ class GitVisionTransformer(nn.Module):
943
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.__init__ with CLIPEncoder->GitVisionEncoder, CLIP->Git
944
+ def __init__(self, config: GitVisionConfig):
945
+ super().__init__()
946
+ self.config = config
947
+ embed_dim = config.hidden_size
948
+
949
+ self.embeddings = GitVisionEmbeddings(config)
950
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
951
+ self.patch_mask_generator = ViTPatchMaskGenerator(config.patch_size)
952
+ self.encoder = GitVisionEncoder(config)
953
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
954
+
955
+ @add_start_docstrings_to_model_forward(GIT_VISION_INPUTS_DOCSTRING)
956
+ @replace_return_docstrings(output_type=BaseModelOutput, config_class=GitVisionConfig)
957
+ def forward(
958
+ self,
959
+ pixel_values: Optional[torch.FloatTensor] = None,
960
+ pixel_masks: Optional[torch.Tensor] = None,
961
+ output_attentions: Optional[bool] = None,
962
+ output_hidden_states: Optional[bool] = None,
963
+ return_dict: Optional[bool] = None,
964
+ ) -> Union[Tuple, BaseModelOutput]:
965
+ r"""
966
+ Returns:
967
+
968
+ """
969
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
970
+ output_hidden_states = (
971
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
972
+ )
973
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
974
+
975
+ if pixel_values is None:
976
+ raise ValueError("You have to specify pixel_values")
977
+
978
+ hidden_states = self.embeddings(pixel_values)
979
+ B, N, D = hidden_states.shape
980
+ # print('Before mask:', hidden_states.shape)
981
+ if pixel_masks is not None:
982
+ assert pixel_masks.shape[0] == 1
983
+ patch_masks = self.patch_mask_generator(pixel_masks)
984
+ # print(patch_masks.shape)
985
+ patch_masks = patch_masks.unsqueeze(-1).expand_as(hidden_states)
986
+ hidden_states = hidden_states.masked_select(patch_masks).view(B, -1, D)
987
+ # print('After mask:', hidden_states.shape)
988
+ hidden_states = self.pre_layrnorm(hidden_states)
989
+
990
+ encoder_outputs = self.encoder(
991
+ inputs_embeds=hidden_states,
992
+ output_attentions=output_attentions,
993
+ output_hidden_states=output_hidden_states,
994
+ return_dict=return_dict,
995
+ )
996
+
997
+ last_hidden_state = encoder_outputs[0]
998
+
999
+ last_hidden_state = self.post_layernorm(last_hidden_state)
1000
+
1001
+ if not return_dict:
1002
+ return (last_hidden_state,) + encoder_outputs[1:]
1003
+
1004
+ return BaseModelOutput(
1005
+ last_hidden_state=last_hidden_state,
1006
+ hidden_states=encoder_outputs.hidden_states,
1007
+ attentions=encoder_outputs.attentions,
1008
+ )
1009
+
1010
+
1011
+ @add_start_docstrings(
1012
+ """The vision model from CLIP, used in GIT, without any head or projection on top.""",
1013
+ GIT_START_DOCSTRING,
1014
+ )
1015
+ class GitVisionModel(GitPreTrainedModel):
1016
+ config_class = GitVisionConfig
1017
+ main_input_name = "pixel_values"
1018
+
1019
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.__init__ with CLIP->Git
1020
+ def __init__(self, config: GitVisionConfig):
1021
+ super().__init__(config)
1022
+ self.vision_model = GitVisionTransformer(config)
1023
+ # Initialize weights and apply final processing
1024
+ self.post_init()
1025
+
1026
+ def get_input_embeddings(self) -> nn.Module:
1027
+ return self.vision_model.embeddings.patch_embedding
1028
+
1029
+ @add_start_docstrings_to_model_forward(GIT_VISION_INPUTS_DOCSTRING)
1030
+ @replace_return_docstrings(output_type=BaseModelOutput, config_class=GitVisionConfig)
1031
+ def forward(
1032
+ self,
1033
+ pixel_values: Optional[torch.FloatTensor] = None,
1034
+ pixel_masks: Optional[torch.Tensor] = None,
1035
+ output_attentions: Optional[bool] = None,
1036
+ output_hidden_states: Optional[bool] = None,
1037
+ return_dict: Optional[bool] = None,
1038
+ ) -> Union[Tuple, BaseModelOutput]:
1039
+ r"""
1040
+ Returns:
1041
+
1042
+ Examples:
1043
+
1044
+ ```python
1045
+ >>> from PIL import Image
1046
+ >>> import requests
1047
+ >>> from transformers import AutoProcessor, GitVisionModel
1048
+
1049
+ >>> processor = AutoProcessor.from_pretrained("microsoft/git-base")
1050
+ >>> model = GitVisionModel.from_pretrained("microsoft/git-base")
1051
+
1052
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1053
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1054
+
1055
+ >>> inputs = processor(images=image, return_tensors="pt")
1056
+
1057
+ >>> outputs = model(**inputs)
1058
+ >>> last_hidden_state = outputs.last_hidden_state
1059
+ ```"""
1060
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1061
+
1062
+ return self.vision_model(
1063
+ pixel_values=pixel_values,
1064
+ pixel_masks=pixel_masks,
1065
+ output_attentions=output_attentions,
1066
+ output_hidden_states=output_hidden_states,
1067
+ return_dict=return_dict,
1068
+ )
1069
+
1070
+
1071
+ class GitProjection(nn.Module):
1072
+ def __init__(self, config: GitConfig):
1073
+ super().__init__()
1074
+ self.config = config
1075
+ self.visual_projection = nn.Sequential(
1076
+ nn.Linear(config.vision_config.hidden_size, config.hidden_size),
1077
+ nn.LayerNorm(config.hidden_size, eps=config.vision_config.layer_norm_eps),
1078
+ )
1079
+
1080
+ def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
1081
+ return self.visual_projection(embeddings)
1082
+
1083
+
1084
+ @add_start_docstrings(
1085
+ "The bare GIT Model transformer consisting of a CLIP image encoder and text decoder outputting raw hidden-states"
1086
+ " without any specific head on top.",
1087
+ GIT_START_DOCSTRING,
1088
+ )
1089
+ class GitModel(GitPreTrainedModel):
1090
+ def __init__(self, config):
1091
+ super().__init__(config)
1092
+ self.config = config
1093
+
1094
+ self.embeddings = GitEmbeddings(config)
1095
+ self.image_encoder = GitVisionModel(config.vision_config)
1096
+ self.encoder = GitEncoder(config)
1097
+
1098
+ self.visual_projection = GitProjection(config)
1099
+
1100
+ if config.num_image_with_embedding is not None:
1101
+ self.img_temperal_embedding = nn.ParameterList(
1102
+ nn.Parameter(torch.zeros(1, 1, config.vision_config.hidden_size))
1103
+ for _ in range(config.num_image_with_embedding)
1104
+ )
1105
+
1106
+ # Initialize weights and apply final processing
1107
+ self.post_init()
1108
+
1109
+ def get_input_embeddings(self):
1110
+ return self.embeddings.word_embeddings
1111
+
1112
+ def set_input_embeddings(self, value):
1113
+ self.embeddings.word_embeddings = value
1114
+
1115
+ def _prune_heads(self, heads_to_prune):
1116
+ """
1117
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1118
+ class PreTrainedModel
1119
+ """
1120
+ for layer, heads in heads_to_prune.items():
1121
+ self.encoder.layer[layer].attention.prune_heads(heads)
1122
+
1123
+ def _generate_future_mask(self, size: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
1124
+ # Default mask is for forward direction. Flip for backward direction.
1125
+ mask = torch.triu(torch.ones(size, size, device=device, dtype=dtype), diagonal=1)
1126
+ mask = mask.masked_fill(mask == 1, float("-inf"))
1127
+ return mask
1128
+
1129
+ def create_attention_mask(self, tgt, memory, tgt_mask, past_key_values_length, memory_key_padding_mask=None):
1130
+ num_tgt = tgt.shape[1]
1131
+ num_memory = memory.shape[1]
1132
+ device = tgt.device
1133
+ dtype = tgt.dtype
1134
+ top_left = torch.zeros((num_memory, num_memory), device=device, dtype=dtype)
1135
+ top_right = torch.full(
1136
+ (num_memory, num_tgt + past_key_values_length),
1137
+ float("-inf"),
1138
+ device=tgt.device,
1139
+ dtype=dtype,
1140
+ )
1141
+ bottom_left = torch.zeros(
1142
+ (num_tgt, num_memory),
1143
+ dtype=dtype,
1144
+ device=tgt_mask.device,
1145
+ )
1146
+
1147
+ if past_key_values_length > 0:
1148
+ tgt_mask = torch.zeros(
1149
+ (tgt_mask.shape[0], tgt_mask.shape[0] + past_key_values_length),
1150
+ dtype=dtype,
1151
+ device=tgt_mask.device,
1152
+ )
1153
+
1154
+ left = torch.cat((top_left, bottom_left), dim=0)
1155
+ right = torch.cat((top_right, tgt_mask.to(dtype)), dim=0)
1156
+
1157
+ full_attention_mask = torch.cat((left, right), dim=1)[None, :]
1158
+
1159
+ if memory_key_padding_mask is None:
1160
+ memory_key_padding_mask = torch.full((memory.shape[0], memory.shape[1]), fill_value=False, device=device)
1161
+ # if it is False, it means valid. That is, it is not a padding
1162
+ if memory_key_padding_mask.dtype != torch.bool:
1163
+ raise ValueError("Memory key padding mask must be a boolean tensor.")
1164
+ zero_negative_infinity = torch.zeros_like(memory_key_padding_mask, dtype=tgt.dtype)
1165
+ zero_negative_infinity[memory_key_padding_mask] = float("-inf")
1166
+ full_attention_mask = full_attention_mask.expand(
1167
+ (memory_key_padding_mask.shape[0], num_memory + num_tgt, num_memory + past_key_values_length + num_tgt)
1168
+ )
1169
+ full_attention_mask = full_attention_mask.clone()
1170
+ origin_left = full_attention_mask[:, :, :num_memory]
1171
+ update = zero_negative_infinity[:, None, :]
1172
+ full_attention_mask[:, :, :num_memory] = origin_left + update
1173
+
1174
+ # add axis for multi-head
1175
+ full_attention_mask = full_attention_mask[:, None, :, :]
1176
+
1177
+ return full_attention_mask
1178
+
1179
+ @add_start_docstrings_to_model_forward(GIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1180
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
1181
+ def forward(
1182
+ self,
1183
+ input_ids: Optional[torch.Tensor] = None,
1184
+ attention_mask: Optional[torch.Tensor] = None,
1185
+ position_ids: Optional[torch.Tensor] = None,
1186
+ pixel_values: Optional[torch.Tensor] = None,
1187
+ pixel_masks: Optional[torch.Tensor] = None,
1188
+ head_mask: Optional[torch.Tensor] = None,
1189
+ inputs_embeds: Optional[torch.Tensor] = None,
1190
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1191
+ use_cache: Optional[bool] = None,
1192
+ output_attentions: Optional[bool] = None,
1193
+ output_hidden_states: Optional[bool] = None,
1194
+ return_dict: Optional[bool] = None,
1195
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]:
1196
+ r"""
1197
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1198
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1199
+
1200
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1201
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1202
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1203
+ use_cache (`bool`, *optional*):
1204
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1205
+ `past_key_values`).
1206
+
1207
+ Returns:
1208
+
1209
+ Examples:
1210
+
1211
+ ```python
1212
+ >>> from transformers import AutoProcessor, AutoModel
1213
+ >>> import requests
1214
+ >>> from PIL import Image
1215
+
1216
+ >>> processor = AutoProcessor.from_pretrained("microsoft/git-base")
1217
+ >>> model = AutoModel.from_pretrained("microsoft/git-base")
1218
+
1219
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1220
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1221
+
1222
+ >>> text = "this is an image of two cats"
1223
+
1224
+ >>> inputs = processor(text, images=image, return_tensors="pt")
1225
+
1226
+ >>> outputs = model(**inputs)
1227
+ >>> last_hidden_state = outputs.last_hidden_state
1228
+ ```"""
1229
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1230
+ output_hidden_states = (
1231
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1232
+ )
1233
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1234
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1235
+
1236
+ if input_ids is not None and inputs_embeds is not None:
1237
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1238
+ elif input_ids is not None:
1239
+ input_shape = input_ids.size()
1240
+ elif inputs_embeds is not None:
1241
+ input_shape = inputs_embeds.size()[:-1]
1242
+ else:
1243
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1244
+
1245
+ seq_length = input_shape[1]
1246
+
1247
+ # past_key_values_length
1248
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1249
+
1250
+ # Prepare head mask if needed
1251
+ # 1.0 in head_mask indicate we keep the head
1252
+ # attention_probs has shape bsz x n_heads x N x N
1253
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1254
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1255
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1256
+
1257
+ projected_visual_features = None
1258
+ if pixel_values is not None:
1259
+ if pixel_values.ndim == 4:
1260
+ # here we assume pixel_values is of shape (batch_size, num_channels, height, width)
1261
+ visual_features = self.image_encoder(pixel_values=pixel_values, pixel_masks=pixel_masks).last_hidden_state
1262
+
1263
+ elif pixel_values.ndim == 5:
1264
+ # here we assume pixel_values is of shape (batch_size, num_frames, num_channels, height, width)
1265
+ visual_features = []
1266
+ for frame_idx in range(pixel_values.shape[1]):
1267
+ visual_features_frame = self.image_encoder(pixel_values[:, frame_idx, :, :]).last_hidden_state
1268
+ visual_features_frame += self.img_temperal_embedding[frame_idx]
1269
+ visual_features.append(visual_features_frame)
1270
+
1271
+ # finally, concatenate all features along sequence dimension
1272
+ visual_features = torch.cat(visual_features, dim=1)
1273
+
1274
+ else:
1275
+ raise ValueError("pixel_values must be of rank 4 or 5")
1276
+
1277
+ projected_visual_features = self.visual_projection(visual_features)
1278
+ image_token_num = projected_visual_features.shape[1]
1279
+ embedding_output = self.embeddings(
1280
+ input_ids=input_ids,
1281
+ position_ids=position_ids,
1282
+ inputs_embeds=inputs_embeds,
1283
+ past_key_values_length=past_key_values_length,
1284
+ )
1285
+
1286
+ if projected_visual_features is None:
1287
+ projected_visual_features = torch.zeros(
1288
+ (embedding_output.shape[0], 0, embedding_output.shape[2]),
1289
+ dtype=embedding_output.dtype,
1290
+ device=embedding_output.device,
1291
+ )
1292
+
1293
+ # Repeat visual features to match embedding batch size.
1294
+ projected_visual_features = projected_visual_features.repeat(
1295
+ embedding_output.size(0) // projected_visual_features.size(0), 1, 1
1296
+ )
1297
+
1298
+ # concatenate patch token and text token embeddings
1299
+ hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)
1300
+
1301
+ # By default, an additive causal mask is created
1302
+ # for masking the future (one direction).
1303
+ tgt_mask = self._generate_future_mask(seq_length, embedding_output.dtype, embedding_output.device)
1304
+
1305
+ # Create an attention mask of shape (batch_size, 1, tgt_seq_len, src_seq_len)
1306
+ combined_attention_mask = self.create_attention_mask(
1307
+ tgt=embedding_output,
1308
+ memory=projected_visual_features,
1309
+ tgt_mask=tgt_mask,
1310
+ past_key_values_length=past_key_values_length,
1311
+ )
1312
+
1313
+ if attention_mask is not None:
1314
+ # if the user provides an attention mask, we add it to the default one
1315
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1316
+ expanded_attn_mask = _expand_mask(attention_mask, embedding_output.dtype, tgt_len=input_shape[-1]).to(
1317
+ embedding_output.device
1318
+ )
1319
+ if past_key_values_length > 0:
1320
+ expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :]
1321
+ else:
1322
+ combined_attention_mask[:, :, -input_shape[1] :, -input_shape[1] :] += expanded_attn_mask
1323
+
1324
+ encoder_outputs = self.encoder(
1325
+ hidden_states,
1326
+ attention_mask=combined_attention_mask,
1327
+ head_mask=head_mask,
1328
+ past_key_values=past_key_values,
1329
+ use_cache=use_cache,
1330
+ output_attentions=output_attentions,
1331
+ output_hidden_states=output_hidden_states,
1332
+ return_dict=return_dict,
1333
+ pixel_values_present=pixel_values is not None,
1334
+ image_token_num=image_token_num
1335
+ )
1336
+ sequence_output = encoder_outputs[0]
1337
+
1338
+ if not return_dict:
1339
+ return (sequence_output,) + encoder_outputs[1:]
1340
+
1341
+ return BaseModelOutputWithPast(
1342
+ last_hidden_state=sequence_output,
1343
+ past_key_values=encoder_outputs.past_key_values,
1344
+ hidden_states=encoder_outputs.hidden_states,
1345
+ attentions=encoder_outputs.attentions,
1346
+ )
1347
+
1348
+
1349
+ @add_start_docstrings(
1350
+ """GIT Model with a `language modeling` head on top for autoregressive language modeling.""", GIT_START_DOCSTRING
1351
+ )
1352
+ class GitForCausalLM(GitPreTrainedModel):
1353
+ def __init__(self, config):
1354
+ super().__init__(config)
1355
+
1356
+ self.git = GitModel(config)
1357
+ self.output = nn.Linear(config.hidden_size, config.vocab_size)
1358
+
1359
+ # Initialize weights and apply final processing
1360
+ self.post_init()
1361
+
1362
+ def get_output_embeddings(self):
1363
+ return self.output
1364
+
1365
+ def set_output_embeddings(self, new_embeddings):
1366
+ self.output = new_embeddings
1367
+
1368
+ @add_start_docstrings_to_model_forward(GIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1369
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1370
+ def forward(
1371
+ self,
1372
+ input_ids: Optional[torch.Tensor] = None,
1373
+ attention_mask: Optional[torch.Tensor] = None,
1374
+ position_ids: Optional[torch.Tensor] = None,
1375
+ pixel_values: Optional[torch.Tensor] = None,
1376
+ pixel_masks: Optional[torch.Tensor] = None,
1377
+ head_mask: Optional[torch.Tensor] = None,
1378
+ inputs_embeds: Optional[torch.Tensor] = None,
1379
+ labels: Optional[torch.Tensor] = None,
1380
+ past_key_values: Optional[List[torch.Tensor]] = None,
1381
+ use_cache: Optional[bool] = None,
1382
+ output_attentions: Optional[bool] = None,
1383
+ output_hidden_states: Optional[bool] = None,
1384
+ return_dict: Optional[bool] = None,
1385
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
1386
+ r"""
1387
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1388
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1389
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
1390
+ ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`
1391
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1392
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1393
+
1394
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1395
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1396
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1397
+ use_cache (`bool`, *optional*):
1398
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1399
+ `past_key_values`).
1400
+
1401
+ Returns:
1402
+
1403
+ Examples:
1404
+
1405
+ Image captioning example:
1406
+
1407
+ ```python
1408
+ >>> from transformers import AutoProcessor, AutoModelForCausalLM
1409
+ >>> import requests
1410
+ >>> from PIL import Image
1411
+
1412
+ >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
1413
+ >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")
1414
+
1415
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1416
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1417
+
1418
+ >>> pixel_values = processor(images=image, return_tensors="pt").pixel_values
1419
+
1420
+ >>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
1421
+ >>> generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
1422
+ >>> print(generated_caption)
1423
+ two cats sleeping on a pink blanket next to remotes.
1424
+ ```
1425
+
1426
+ Visual question answering (VQA) example:
1427
+
1428
+ ```python
1429
+ >>> from transformers import AutoProcessor, AutoModelForCausalLM
1430
+ >>> from huggingface_hub import hf_hub_download
1431
+ >>> from PIL import Image
1432
+
1433
+ >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-textvqa")
1434
+ >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-textvqa")
1435
+
1436
+ >>> file_path = hf_hub_download(repo_id="nielsr/textvqa-sample", filename="bus.png", repo_type="dataset")
1437
+ >>> image = Image.open(file_path).convert("RGB")
1438
+
1439
+ >>> pixel_values = processor(images=image, return_tensors="pt").pixel_values
1440
+
1441
+ >>> question = "what does the front of the bus say at the top?"
1442
+
1443
+ >>> input_ids = processor(text=question, add_special_tokens=False).input_ids
1444
+ >>> input_ids = [processor.tokenizer.cls_token_id] + input_ids
1445
+ >>> input_ids = torch.tensor(input_ids).unsqueeze(0)
1446
+
1447
+ >>> generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
1448
+ >>> print(processor.batch_decode(generated_ids, skip_special_tokens=True))
1449
+ ['what does the front of the bus say at the top? special']
1450
+ ```
1451
+
1452
+ Video captioning example:
1453
+
1454
+ ```python
1455
+ >>> import av
1456
+ >>> import numpy as np
1457
+ >>> from PIL import Image
1458
+ >>> from huggingface_hub import hf_hub_download
1459
+ >>> from transformers import AutoProcessor, AutoModelForCausalLM
1460
+
1461
+ >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-vatex")
1462
+ >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vatex")
1463
+
1464
+ >>> # set seed for reproducability
1465
+ >>> np.random.seed(45)
1466
+
1467
+
1468
+ >>> def read_video_pyav(container, indices):
1469
+ ... '''
1470
+ ... Decode the video with PyAV decoder.
1471
+ ... Args:
1472
+ ... container (`av.container.input.InputContainer`): PyAV container.
1473
+ ... indices (`List[int]`): List of frame indices to decode.
1474
+ ... Returns:
1475
+ ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
1476
+ ... '''
1477
+ ... frames = []
1478
+ ... container.seek(0)
1479
+ ... start_index = indices[0]
1480
+ ... end_index = indices[-1]
1481
+ ... for i, frame in enumerate(container.decode(video=0)):
1482
+ ... if i > end_index:
1483
+ ... break
1484
+ ... if i >= start_index and i in indices:
1485
+ ... frames.append(frame)
1486
+ ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
1487
+
1488
+
1489
+ >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
1490
+ ... converted_len = int(clip_len * frame_sample_rate)
1491
+ ... end_idx = np.random.randint(converted_len, seg_len)
1492
+ ... start_idx = end_idx - converted_len
1493
+ ... indices = np.linspace(start_idx, end_idx, num=clip_len)
1494
+ ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
1495
+ ... return indices
1496
+
1497
+
1498
+ >>> # load video
1499
+ >>> file_path = hf_hub_download(
1500
+ ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
1501
+ ... )
1502
+ >>> container = av.open(file_path)
1503
+
1504
+ >>> # sample frames
1505
+ >>> num_frames = model.config.num_image_with_embedding
1506
+ >>> indices = sample_frame_indices(
1507
+ ... clip_len=num_frames, frame_sample_rate=4, seg_len=container.streams.video[0].frames
1508
+ ... )
1509
+ >>> frames = read_video_pyav(container, indices)
1510
+
1511
+ >>> pixel_values = processor(images=list(frames), return_tensors="pt").pixel_values
1512
+
1513
+ >>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
1514
+
1515
+ >>> print("Generated caption:", processor.batch_decode(generated_ids, skip_special_tokens=True))
1516
+ Generated caption: ['a woman is sitting at a table and she is talking about the food she is holding.']
1517
+ ```
1518
+ """
1519
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1520
+ if labels is not None:
1521
+ use_cache = False
1522
+
1523
+ outputs = self.git(
1524
+ input_ids,
1525
+ attention_mask=attention_mask,
1526
+ position_ids=position_ids,
1527
+ pixel_values=pixel_values,
1528
+ pixel_masks=pixel_masks,
1529
+ head_mask=head_mask,
1530
+ inputs_embeds=inputs_embeds,
1531
+ past_key_values=past_key_values,
1532
+ use_cache=use_cache,
1533
+ output_attentions=output_attentions,
1534
+ output_hidden_states=output_hidden_states,
1535
+ return_dict=return_dict,
1536
+ )
1537
+
1538
+ sequence_output = outputs[0]
1539
+ logits = self.output(sequence_output)
1540
+
1541
+ loss = None
1542
+ if labels is not None:
1543
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1544
+ num_image_tokens = self.git.encoder.layer[0].attention.self.image_patch_tokens
1545
+ shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
1546
+ labels = labels[:, 1:].contiguous()
1547
+ loss_fct = CrossEntropyLoss()
1548
+ loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), labels.view(-1))
1549
+
1550
+ if not return_dict:
1551
+ output = (logits,) + outputs[1:]
1552
+ return ((loss,) + output) if loss is not None else output
1553
+
1554
+ return CausalLMOutputWithPast(
1555
+ loss=loss,
1556
+ logits=logits,
1557
+ past_key_values=outputs.past_key_values,
1558
+ hidden_states=outputs.hidden_states,
1559
+ attentions=outputs.attentions,
1560
+ )
1561
+
1562
+ def prepare_inputs_for_generation(
1563
+ self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
1564
+ ):
1565
+ # cut decoder_input_ids if past_key_values is used
1566
+ if past_key_values is not None:
1567
+ input_ids = input_ids[:, -1:]
1568
+
1569
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1570
+ input_shape = input_ids.shape
1571
+ if attention_mask is None:
1572
+ attention_mask = input_ids.new_ones(input_shape)
1573
+
1574
+ return {
1575
+ "input_ids": input_ids,
1576
+ "attention_mask": attention_mask,
1577
+ "pixel_values": kwargs.get("pixel_values", None),
1578
+ "pixel_masks": kwargs.get("pixel_masks", None),
1579
+ "past_key_values": past_key_values,
1580
+ "use_cache": use_cache,
1581
+ }
1582
+
1583
+ def _reorder_cache(self, past_key_values, beam_idx):
1584
+ reordered_past = ()
1585
+ for layer_past in past_key_values:
1586
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1587
+ return reordered_past
captioner/vit_pixel_masks_utils.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ class ViTPatchMaskGenerator(nn.Module):
7
+ def __init__(self, patch_size) -> None:
8
+ super(ViTPatchMaskGenerator, self).__init__()
9
+ self.patch_size = patch_size
10
+ self.pool = nn.MaxPool2d(kernel_size=patch_size, stride=patch_size)
11
+
12
+ def forward(self, pixel_masks):
13
+ patch_mask = self.pool(pixel_masks)
14
+ patch_mask = patch_mask.bool().flatten(1)
15
+ cls_token_mask = patch_mask.new_ones([patch_mask.shape[0], 1]).bool()
16
+ patch_mask = torch.cat([cls_token_mask, patch_mask], dim=-1)
17
+ return patch_mask
image_editing_utils.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image, ImageDraw, ImageFont
2
+ import copy
3
+ import numpy as np
4
+
5
+ def wrap_text(text, font, max_width):
6
+ lines = []
7
+ words = text.split(' ')
8
+ current_line = ''
9
+
10
+ for word in words:
11
+ if font.getsize(current_line + word)[0] <= max_width:
12
+ current_line += word + ' '
13
+ else:
14
+ lines.append(current_line)
15
+ current_line = word + ' '
16
+
17
+ lines.append(current_line)
18
+ return lines
19
+
20
+ def create_bubble_frame(image, text, point, font_path='DejaVuSansCondensed-Bold.ttf', font_size_ratio=0.033):
21
+ # Load the image
22
+ if type(image) == np.ndarray:
23
+ image = Image.fromarray(image)
24
+
25
+ image = copy.deepcopy(image)
26
+ width, height = image.size
27
+
28
+ # Calculate max_text_width and font_size based on image dimensions and total number of characters
29
+ total_chars = len(text)
30
+ max_text_width = int(0.33 * width)
31
+ font_size = int(height * font_size_ratio)
32
+
33
+ # Load the font
34
+ font = ImageFont.truetype(font_path, font_size)
35
+
36
+ # Wrap the text to fit within the max_text_width
37
+ lines = wrap_text(text, font, max_text_width)
38
+ text_width, text_height = font.getsize(lines[0])
39
+ text_height = text_height * len(lines)
40
+
41
+ # Define bubble frame dimensions
42
+ padding = 10
43
+ bubble_width = text_width + 2 * padding
44
+ bubble_height = text_height + 2 * padding
45
+
46
+ # Create a new image for the bubble frame
47
+ bubble = Image.new('RGBA', (bubble_width, bubble_height), (255, 255, 255, 0))
48
+
49
+ # Draw the bubble frame on the new image
50
+ draw = ImageDraw.Draw(bubble)
51
+ draw.rectangle([(0, 0), (bubble_width - 1, bubble_height - 1)], fill=(255, 255, 255, 0), outline=(255, 255, 255, 0), width=2)
52
+
53
+ # Draw the wrapped text line by line
54
+ y_text = padding
55
+ for line in lines:
56
+ draw.text((padding, y_text), line, font=font, fill=(255, 255, 255, 255))
57
+ y_text += font.getsize(line)[1]
58
+
59
+ # Calculate the bubble frame position
60
+ x, y = point
61
+ if x + bubble_width > width:
62
+ x = width - bubble_width
63
+ if y + bubble_height > height:
64
+ y = height - bubble_height
65
+
66
+ # Paste the bubble frame onto the image
67
+ image.paste(bubble, (x, y), bubble)
68
+ return image
segmenter/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from segmenter.base_segmenter import BaseSegmenter
2
+
3
+
4
+ def build_segmenter(type, device, args=None):
5
+ if type == 'base':
6
+ return BaseSegmenter(device, args.segmenter_checkpoint, reuse_feature=not args.disable_reuse_features)
segmenter/base_segmenter.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import torch
3
+ import cv2
4
+ from PIL import Image, ImageDraw, ImageOps
5
+ import numpy as np
6
+ from typing import Union
7
+ from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
8
+ import matplotlib.pyplot as plt
9
+ import PIL
10
+
11
+ class BaseSegmenter:
12
+ def __init__(self, device, checkpoint, model_type='vit_h', reuse_feature = True):
13
+ print(f"Initializing BaseSegmenter to {device}")
14
+ self.device = device
15
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
16
+ self.processor = None
17
+ self.model_type = model_type
18
+ self.checkpoint = checkpoint
19
+ self.model = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
20
+ self.model.to(device=self.device)
21
+ self.reuse_feature = reuse_feature
22
+ self.predictor = SamPredictor(self.model)
23
+ self.mask_generator = SamAutomaticMaskGenerator(self.model)
24
+ self.image_embedding = None
25
+ self.image = None
26
+
27
+
28
+ @torch.no_grad()
29
+ def set_image(self, image: Union[np.ndarray, Image.Image, str]):
30
+ if type(image) == str: # input path
31
+ image = Image.open(image)
32
+ image = np.array(image)
33
+ elif type(image) == Image.Image:
34
+ image = np.array(image)
35
+ self.image = image
36
+ if self.reuse_feature:
37
+ self.predictor.set_image(image)
38
+ self.image_embedding = self.predictor.get_image_embedding()
39
+ print(self.image_embedding.shape)
40
+
41
+
42
+ @torch.no_grad()
43
+ def inference(self, image, control):
44
+ if 'everything' in control['prompt_type']:
45
+ masks = self.mask_generator.generate(image)
46
+ new_masks = np.concatenate([mask["segmentation"][np.newaxis,:] for mask in masks])
47
+ return new_masks
48
+ else:
49
+ if not self.reuse_feature:
50
+ self.set_image(image)
51
+ self.predictor.set_image(self.image)
52
+ else:
53
+ assert self.image_embedding is not None
54
+ self.predictor.features = self.image_embedding
55
+
56
+ if 'mutimask_output' in control:
57
+ masks, scores, logits = self.predictor.predict(
58
+ point_coords = np.array(control['input_point']),
59
+ point_labels = np.array(control['input_label']),
60
+ multimask_output = True,
61
+ )
62
+ elif 'input_boxes' in control:
63
+ transformed_boxes = self.predictor.transform.apply_boxes_torch(
64
+ torch.tensor(control["input_boxes"], device=self.predictor.device),
65
+ image.shape[:2]
66
+ )
67
+ masks, _, _ = self.predictor.predict_torch(
68
+ point_coords=None,
69
+ point_labels=None,
70
+ boxes=transformed_boxes,
71
+ multimask_output=False,
72
+ )
73
+ masks = masks.squeeze(1).cpu().numpy()
74
+
75
+ else:
76
+ input_point = np.array(control['input_point']) if 'click' in control['prompt_type'] else None
77
+ input_label = np.array(control['input_label']) if 'click' in control['prompt_type'] else None
78
+ input_box = np.array(control['input_box']) if 'box' in control['prompt_type'] else None
79
+
80
+ masks, scores, logits = self.predictor.predict(
81
+ point_coords = input_point,
82
+ point_labels = input_label,
83
+ box = input_box,
84
+ multimask_output = False,
85
+ )
86
+
87
+ if 0 in control['input_label']:
88
+ mask_input = logits[np.argmax(scores), :, :]
89
+ masks, scores, logits = self.predictor.predict(
90
+ point_coords=input_point,
91
+ point_labels=input_label,
92
+ box = input_box,
93
+ mask_input=mask_input[None, :, :],
94
+ multimask_output=False,
95
+ )
96
+
97
+ return masks
98
+
99
+ if __name__ == "__main__":
100
+ image_path = 'segmenter/images/truck.jpg'
101
+ prompts = [
102
+ # {
103
+ # "prompt_type":["click"],
104
+ # "input_point":[[500, 375]],
105
+ # "input_label":[1],
106
+ # "multimask_output":"True",
107
+ # },
108
+ {
109
+ "prompt_type":["click"],
110
+ "input_point":[[1000, 600], [1325, 625]],
111
+ "input_label":[1, 0],
112
+ },
113
+ # {
114
+ # "prompt_type":["click", "box"],
115
+ # "input_box":[425, 600, 700, 875],
116
+ # "input_point":[[575, 750]],
117
+ # "input_label": [0]
118
+ # },
119
+ # {
120
+ # "prompt_type":["box"],
121
+ # "input_boxes": [
122
+ # [75, 275, 1725, 850],
123
+ # [425, 600, 700, 875],
124
+ # [1375, 550, 1650, 800],
125
+ # [1240, 675, 1400, 750],
126
+ # ]
127
+ # },
128
+ # {
129
+ # "prompt_type":["everything"]
130
+ # },
131
+ ]
132
+
133
+ init_time = time.time()
134
+ segmenter = BaseSegmenter(
135
+ device='cuda',
136
+ # checkpoint='sam_vit_h_4b8939.pth',
137
+ checkpoint='segmenter/sam_vit_h_4b8939.pth',
138
+ model_type='vit_h',
139
+ reuse_feature=True
140
+ )
141
+ print(f'init time: {time.time() - init_time}')
142
+
143
+ image_path = 'test_img/img2.jpg'
144
+ infer_time = time.time()
145
+ for i, prompt in enumerate(prompts):
146
+ print(f'{prompt["prompt_type"]} mode')
147
+ image = Image.open(image_path)
148
+ segmenter.set_image(np.array(image))
149
+ masks = segmenter.inference(np.array(image), prompt)
150
+ Image.fromarray(masks[0]).save('seg.png')
151
+ print(masks.shape)
152
+
153
+ print(f'infer time: {time.time() - infer_time}')
segmenter/images/truck.jpg ADDED
segmenter/readme.md ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### Prepare SAM
2
+ ```
3
+ pip install git+https://github.com/facebookresearch/segment-anything.git
4
+ ```
5
+ or
6
+ ```
7
+ git clone [email protected]:facebookresearch/segment-anything.git
8
+ cd segment-anything; pip install -e .
9
+ ```
10
+
11
+ ```
12
+ pip install opencv-python pycocotools matplotlib onnxruntime onnx
13
+ ```
14
+ ### Download the checkpoint:
15
+
16
+ https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
17
+
18
+ ### Inference
19
+
20
+ The prompts are in json format:
21
+
22
+ ```
23
+ prompts = [
24
+ {
25
+ "prompt_type":["click"],
26
+ "input_point":[[500, 375]],
27
+ "input_label":[1],
28
+ "multimask_output":"True",
29
+ },
30
+ {
31
+ "prompt_type":["click"],
32
+ "input_point":[[500, 375], [1125, 625]],
33
+ "input_label":[1, 0],
34
+ },
35
+ {
36
+ "prompt_type":["click", "box"],
37
+ "input_box":[425, 600, 700, 875],
38
+ "input_point":[[575, 750]],
39
+ "input_label": [0]
40
+ },
41
+ {
42
+ "prompt_type":["box"],
43
+ "input_boxes": [
44
+ [75, 275, 1725, 850],
45
+ [425, 600, 700, 875],
46
+ [1375, 550, 1650, 800],
47
+ [1240, 675, 1400, 750],
48
+ ]
49
+ },
50
+ {
51
+ "prompt_type":["everything"]
52
+ },
53
+ ]
54
+ ```
55
+
56
+ In `base_segmenter.py`:
57
+ ```
58
+ segmenter = BaseSegmenter(
59
+ device='cuda',
60
+ checkpoint='sam_vit_h_4b8939.pth',
61
+ model_type='vit_h'
62
+ )
63
+
64
+ for i, prompt in enumerate(prompts):
65
+ masks = segmenter.inference(image_path, prompt)
66
+ ```
67
+
68
+ Outputs are masks (True and False numpy Matrix), shape: (num of masks, height, weight)
test_img/img1.jpg ADDED
test_img/img1.jpg.raw_mask.png ADDED
test_img/img10.jpg ADDED
test_img/img10.jpg.raw_mask.png ADDED
test_img/img11.jpg ADDED
test_img/img12.jpg ADDED
test_img/img12.jpg.raw_mask.png ADDED
test_img/img13.jpg ADDED
test_img/img13.jpg.raw_mask.png ADDED
test_img/img14.jpg ADDED
test_img/img14.jpg.raw_mask.png ADDED
test_img/img15.jpg ADDED
test_img/img15.jpg.raw_mask.png ADDED
test_img/img16.jpg ADDED
test_img/img16.jpg.raw_mask.png ADDED
test_img/img17.jpg ADDED
test_img/img18.jpg ADDED

Git LFS Details

  • SHA256: e02c393a23aadd1304497e3a9b41144df166d1cfda33ea3e00eed94e27da3aa4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.37 MB
test_img/img19.jpg ADDED
test_img/img2.jpg ADDED
test_img/img2.jpg.raw_mask.png ADDED
test_img/img20.jpg ADDED
test_img/img21.jpg ADDED
test_img/img22.jpg ADDED

Git LFS Details

  • SHA256: 5c5159bf7114d08967f95475176670043115b157bf700efa34190260cd917662
  • Pointer size: 132 Bytes
  • Size of remote file: 1.03 MB
test_img/img23.jpg ADDED
test_img/img24.jpg ADDED
test_img/img25.jpg ADDED
test_img/img27.jpg ADDED
test_img/img28.jpg ADDED